Skip to content

AOT Backend API

Backend

This is the base class for all backends.

aitune.torch.backend.Backend

Backend()

Bases: ABC

Backend interface for tuning a module.

Initialize the backend.

Source code in aitune/torch/backend/backend.py
def __init__(self):
    """Initialize the backend in its pristine, unbuilt state."""
    # No device is bound yet; build()/deploy() will assign one later.
    self._device = None
    self.state = BackendState.INIT

device property

device

Get the device of the backend.

Returns:

  • device

    The device the module is using.

is_active property

is_active

Returns True if the backend is active.

is_jit abstractmethod property

is_jit

Returns True if the backend is a JIT backend.

This method ensures that the backend has this property defined.

name property

name

Name of the backend.

activate

activate()

Activates backend.

After activating, the backend should be ready to do inference.

Source code in aitune/torch/backend/backend.py
@nvtx.annotate(domain="AITune", color="black")
def activate(self):
    """Bring the backend into the active state.

    Once active, the backend is ready to serve inference calls.
    """
    current = self.state
    # Guard against states from which activation is illegal.
    if current == BackendState.INIT:
        raise RuntimeError(f"Cannot activate backend {self.name}, backend should be built first")
    if current == BackendState.DEPLOYED:
        raise RuntimeError(f"Cannot activate backend {self.name}, backend is already deployed")

    # Activation is a no-op unless we are in a resumable state.
    if current in (BackendState.INACTIVE, BackendState.CHECKPOINT_LOADED):
        self._activate()
        self.state = BackendState.ACTIVE

build

build(module, graph_spec, data, device, cache_dir)

Build the model with the given arguments.

Building a backend should be idempotent, i.e. it should not cause side effects. A model is not necessarily purely functional and can have internal state (like the KV cache for LLMs). That is why build may run each sample of inputs at most once, so that subsequent calls have exactly the same state as the first call for the given sample.

After building, the backend should be activated.

Source code in aitune/torch/backend/backend.py
def build(
    self,
    module: nn.Module,
    graph_spec: GraphSpec,
    data: list[Sample],
    device: torch.device,
    cache_dir: Path,
) -> "Backend":
    """Build the model with the given arguments.

    Building a backend should be idempotent i.e. do not cause side effects. A model is not necessarily pure
    functional and can have an internal state (like kv cache for LLMs). That is why build can call a sample of
    inputs at most once so that subsequent calls have exact same state as the first call for the given sample.

    After building, the backend should be activated.

    Args:
        module: The module to build a backend for.
        graph_spec: Graph specification forwarded to the concrete ``_build``.
        data: Input samples the backend may consume during building (each at most once).
        device: Device to bind the backend to; checked via ``_assert_device``.
        cache_dir: Directory where build artifacts may be cached.

    Returns:
        The built backend, in the ACTIVE state and ready for inference.

    Raises:
        RuntimeError: If build() has already been called on this backend.
    """
    # Guard clause: build is a one-shot transition out of INIT.
    if self.state != BackendState.INIT:
        raise RuntimeError(f"Backend {self.name} build should be called only once")
    try:
        self._assert_device(device)
        self._set_device(device)
        ready_backend = self._build(module, graph_spec, data, cache_dir)
    except Exception as e:
        self._logger.error("Failed to build backend(%s): %s", self.__class__.__name__, e, exc_info=True)
        # Bare raise re-raises the active exception with its traceback intact
        # (the original `raise e` is redundant and less idiomatic).
        raise
    self.state = BackendState.ACTIVE
    return ready_backend

deactivate

deactivate()

Deactivates backend.

After deactivating, the backend cannot be used to do inference.

Source code in aitune/torch/backend/backend.py
def deactivate(self):
    """Deactivates backend.

    After deactivating, the backend cannot be used to do inference.

    Raises:
        RuntimeError: If the backend was never built, is already deployed,
            or was loaded from a checkpoint (such backends cannot be
            deactivated).
    """
    if self.state == BackendState.INIT:
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend should be built first")
    if self.state == BackendState.DEPLOYED:
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend is already deployed")
    if self.state == BackendState.CHECKPOINT_LOADED:
        # Fixed message: the old text claimed the backend "has already been deployed",
        # but that is the distinct DEPLOYED case handled above; this state means it
        # was loaded from a checkpoint.
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend was loaded from a checkpoint")

    if self.state == BackendState.ACTIVE:
        self._deactivate()
        self._clean_memory()
        self.state = BackendState.INACTIVE

deploy

deploy(device)

Deploys the backend.

After deploying, the backend is ready to do inference. Backend cannot be deactivated anymore.

Parameters:

  • device (device | None) –

    The device to deploy the backend on.

Source code in aitune/torch/backend/backend.py
def deploy(self, device: torch.device | None):
    """Finalize the backend for serving.

    Once deployed, the backend can run inference but can no longer be
    deactivated.

    Args:
        device: The device to deploy the backend on.
    """
    # Deployment is only legal from the checkpoint-loaded state.
    if self.state != BackendState.CHECKPOINT_LOADED:
        raise RuntimeError(f"Cannot deploy backend {self.name}, backend should be loaded from a checkpoint")
    self._set_device(device)
    self._deploy()
    self.state = BackendState.DEPLOYED

describe abstractmethod

describe()

Returns the description of the backend.

Source code in aitune/torch/backend/backend.py
@abstractmethod
def describe(self) -> str:
    """Return a human-readable description of the backend."""
    ...

from_dict abstractmethod classmethod

from_dict(state_dict)

Creates a backend from a state_dict.

Source code in aitune/torch/backend/backend.py
@classmethod
@abstractmethod
def from_dict(cls, state_dict: dict):
    """Reconstruct a backend instance from a serialized state_dict."""
    ...

infer

infer(*args, **kwargs)

Run inference with the given arguments.

Parameters:

  • args (Any, default: () ) –

    Variable length argument list.

  • kwargs (Any, default: {} ) –

    Arbitrary keyword arguments.

Returns:

  • Any ( Any ) –

    The result of the inference.

Source code in aitune/torch/backend/backend.py
def infer(self, *args: Any, **kwargs: Any) -> Any:
    """Execute the backend on the given inputs.

    Args:
        args: Positional arguments forwarded to the underlying ``_infer``.
        kwargs: Keyword arguments forwarded to the underlying ``_infer``.

    Returns:
        Any: Whatever ``_infer`` produces for these inputs.
    """
    # Inference is legal only once the backend is ACTIVE or DEPLOYED.
    if self.state not in (BackendState.ACTIVE, BackendState.DEPLOYED):
        raise RuntimeError(f"Cannot run inference, backend {self.name} should be activated first")
    return self._infer(*args, **kwargs)

key abstractmethod

key()

Returns the key of the backend.

Source code in aitune/torch/backend/backend.py
@abstractmethod
def key(self) -> str:
    """Return the unique key identifying this backend."""
    ...

to_dict abstractmethod

to_dict()

Returns the state_dict of the backend.

Note: if there are any binary artifacts (files) which should be stored by a backend, they must be passed as Python `Path` objects. Such objects will be bundled with a checkpoint.

Source code in aitune/torch/backend/backend.py
@abstractmethod
def to_dict(self):
    """Returns the state_dict of the backend.

    Note: if there are any binary artifacts (files) which should be stored by a backend,
    they must be passed as Python `Path` objects. Such objects will be bundled with a checkpoint.
    """
    pass

BackendConfig

aitune.torch.backend.backend.BackendConfig dataclass

BackendConfig()

Configuration for a backend.

describe

describe()

Describe the backend configuration. Display only changed fields.

Source code in aitune/torch/backend/backend.py
def describe(self) -> str:
    """Summarize the configuration, listing only fields that differ from defaults."""
    # Compare against a freshly constructed default instance of the same class.
    baseline = self.__class__()
    diff = self._get_changed_fields(self, baseline)
    return ",".join(diff)

key

key()

Returns the keys of the backend configuration.

Source code in aitune/torch/backend/backend.py
def key(self) -> str:
    """Derive a stable key from the JSON-serialized configuration contents."""
    # dict -> JSON-safe dict -> canonical string -> hash.
    serializable = self._to_json(self.to_dict())
    return hash_string(json.dumps(serializable))

to_dict

to_dict()

Returns the state_dict of the backend.

Source code in aitune/torch/backend/backend.py
def to_dict(self):
    """Returns the state_dict of the backend."""
    # dataclasses.asdict recursively converts this dataclass (and any nested
    # dataclasses) into plain dicts.
    return asdict(self)

to_json

to_json(path)

Saves the backend configuration to a file.

Source code in aitune/torch/backend/backend.py
def to_json(self, path: Path):
    """Saves the backend configuration to a file.

    Args:
        path: Destination file path; overwritten if it already exists.
    """
    config_dict = self.to_dict()
    config_dict = self._to_json(config_dict)
    # Explicit encoding: the default is the platform locale encoding, which
    # makes the file non-portable across systems.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2)