Skip to content

TorchAO Backend API

TorchAOBackend

aitune.torch.backend.TorchAOBackend

TorchAOBackend(config=None)

Bases: Backend

Backend that performs quantization via TorchAO.

Supported quantizations
  • int8wo
  • int8dq
  • fp8wo
  • fp8dq

If you would like to use a customized quantization, you can pass in a quantization config.

Initializes backend.

Parameters:

Source code in aitune/torch/backend/torchao_backend.py
def __init__(
    self,
    config: TorchAOBackendConfig | None = None,
):
    """Initializes backend.

    Args:
        config: The configuration to use. When omitted, a config with the
            default quantization is created.
    """
    super().__init__()
    # Fall back to the default quantization when no (truthy) config is given.
    if not config:
        config = TorchAOBackendConfig(quantization=DEFAULT_QUANTIZATION)
    self._config = config

    # These are populated later, during build().
    self._quant_module = None
    self._orig_module = None
    self._data = None

device property

device

Get the device of the backend.

Returns:

  • device

    The device the module is using.

is_active property

is_active

Returns True if the backend is active.

name property

name

Name of a backend.

activate

activate()

Activates backend.

After activating, the backend should be ready to do inference.

Source code in aitune/torch/backend/backend.py
@nvtx.annotate(domain="AITune", color="black")
def activate(self):
    """Activates backend.

    After activating, the backend should be ready to do inference.

    Raises:
        RuntimeError: If the backend was never built, or is already deployed.
    """
    if self.state == BackendState.INIT:
        raise RuntimeError(f"Cannot activate backend {self.name}, backend should be built first")
    if self.state == BackendState.DEPLOYED:
        raise RuntimeError(f"Cannot activate backend {self.name}, backend is already deployed")

    # Only these two states transition to ACTIVE; an already-ACTIVE backend is a no-op.
    if self.state in (BackendState.INACTIVE, BackendState.CHECKPOINT_LOADED):
        self._activate()
        self.state = BackendState.ACTIVE

build

build(module, graph_spec, data, device, cache_dir)

Build the model with the given arguments.

Building a backend should be idempotent, i.e. it should not cause side effects. A model is not necessarily purely functional and can have internal state (like the KV cache for LLMs). That is why build may call a sample of inputs at most once, so that subsequent calls start from exactly the same state as the first call for the given sample.

After building, the backend should be activated.

Source code in aitune/torch/backend/backend.py
def build(
    self,
    module: nn.Module,
    graph_spec: GraphSpec,
    data: list[Sample],
    device: torch.device,
    cache_dir: Path,
) -> "Backend":
    """Build the model with the given arguments.

    Building a backend should be idempotent, i.e. it should not cause side effects. A model is not
    necessarily purely functional and can have internal state (like the KV cache for LLMs). That is why
    build may call a sample of inputs at most once, so that subsequent calls start from exactly the same
    state as the first call for the given sample.

    After building, the backend should be activated.

    Args:
        module: The module to build the backend for.
        graph_spec: The graph specification of the module.
        data: Samples the build may run through the module (each at most once).
        device: The device to build on.
        cache_dir: Directory used for build artifacts.

    Returns:
        Backend: The built backend, in the ACTIVE state.

    Raises:
        RuntimeError: If build has already been called (state is not INIT).
    """
    # build() is single-shot: any state other than INIT means it already ran.
    if self.state != BackendState.INIT:
        raise RuntimeError(f"Backend {self.name} build should be called only once")

    try:
        self._assert_device(device)
        self._set_device(device)
        ready_backend = self._build(module, graph_spec, data, cache_dir)
        self.state = BackendState.ACTIVE
        return ready_backend
    except Exception as e:
        self._logger.error("Failed to build backend(%s): %s", self.__class__.__name__, e, exc_info=True)
        # Bare `raise` re-raises the active exception without appending this
        # frame to it (unlike `raise e`), preserving the original traceback.
        raise

deactivate

deactivate()

Deactivates backend.

After deactivating, the backend cannot be used to do inference.

Source code in aitune/torch/backend/backend.py
def deactivate(self):
    """Deactivates backend.

    After deactivating, the backend cannot be used to do inference.

    Raises:
        RuntimeError: If the backend was never built, is already deployed, or
            was loaded from a checkpoint (checkpoint-loaded backends can only
            be deployed, not deactivated).
    """
    if self.state == BackendState.INIT:
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend should be built first")
    if self.state == BackendState.DEPLOYED:
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend is already deployed")
    if self.state == BackendState.CHECKPOINT_LOADED:
        # BUG FIX: the previous message claimed the backend "has already been
        # deployed", but CHECKPOINT_LOADED means a checkpoint was loaded and
        # deployment has NOT happened yet (deploy() requires this state).
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend is loaded from a checkpoint")

    if self.state == BackendState.ACTIVE:
        self._deactivate()
        self._clean_memory()
        self.state = BackendState.INACTIVE

deploy

deploy(device)

Deploys the backend.

After deploying, the backend is ready to do inference. Backend cannot be deactivated anymore.

Parameters:

  • device (device | None) –

    The device to deploy the backend on.

Source code in aitune/torch/backend/backend.py
def deploy(self, device: torch.device | None):
    """Deploys the backend.

    After deploying, the backend is ready to do inference. Backend cannot be deactivated anymore.

    Args:
        device: The device to deploy the backend on.

    Raises:
        RuntimeError: If the backend state is not CHECKPOINT_LOADED.
    """
    if self.state == BackendState.CHECKPOINT_LOADED:
        self._set_device(device)
        self._deploy()
        self.state = BackendState.DEPLOYED
    else:
        raise RuntimeError(f"Cannot deploy backend {self.name}, backend should be loaded from a checkpoint")

describe

describe()

Returns the description of the backend.

Source code in aitune/torch/backend/torchao_backend.py
def describe(self) -> str:
    """Returns a human-readable description of the backend.

    The description combines the backend class name with the description of
    its configuration.
    """
    backend_name = self.__class__.__name__
    config_description = self._config.describe()
    return f"{backend_name}({config_description})"

from_dict classmethod

from_dict(module, state_dict)

Creates a backend from a state_dict.

Source code in aitune/torch/backend/torchao_backend.py
@classmethod
def from_dict(cls, module: torch.nn.Module | None, state_dict: dict):
    """Creates a backend from a state_dict.

    Args:
        module: The module whose weights will be restored from the state_dict;
            must not be None.
        state_dict: Serialized backend state as produced by to_dict().

    Raises:
        ValueError: If the state_dict type does not match this class, or no
            module is provided.
    """
    state_type = state_dict.get(cls.STATE_TYPE)
    if state_type != cls.__name__:
        raise ValueError(f"Invalid state_dict type: {state_dict.get(cls.STATE_TYPE)}")
    if module is None:
        raise ValueError("Module is required to create a backend from a state_dict.")

    backend = cls(config=TorchAOBackendConfig.from_dict(state_dict[cls.STATE_CONFIG]))
    backend._data = state_dict[cls.STATE_DATA]
    backend._device = state_dict[cls.STATE_DEVICE]
    backend._orig_module = module
    # strict=False: tolerate extra/missing keys when restoring the weights.
    module.load_state_dict(state_dict[cls.STATE_ORIG_MODULE], strict=False)
    backend.state = BackendState.CHECKPOINT_LOADED

    return backend

infer

infer(*args, **kwargs)

Run inference with the given arguments.

Parameters:

  • args (Any, default: () ) –

    Variable length argument list.

  • kwargs (Any, default: {} ) –

    Arbitrary keyword arguments.

Returns:

  • Any ( Any ) –

    The result of the inference.

Source code in aitune/torch/backend/backend.py
def infer(self, *args: Any, **kwargs: Any) -> Any:
    """Run inference with the given arguments.

    Args:
        args: Positional arguments forwarded to the backend's _infer().
        kwargs: Keyword arguments forwarded to the backend's _infer().

    Returns:
        Any: The result of the inference.

    Raises:
        RuntimeError: If the backend is neither active nor deployed.
    """
    if self.state not in (BackendState.ACTIVE, BackendState.DEPLOYED):
        raise RuntimeError(f"Cannot run inference, backend {self.name} should be activated first")

    return self._infer(*args, **kwargs)

key

key()

Returns the key of the backend.

Source code in aitune/torch/backend/torchao_backend.py
def key(self) -> str:
    """Returns the key of the backend.

    The key joins the backend class name and the configuration key with an
    underscore.
    """
    class_name = self.__class__.__name__
    config_key = self._config.key()
    return f"{class_name}_{config_key}"

to_dict

to_dict()

Returns the state_dict of the backend.

Source code in aitune/torch/backend/torchao_backend.py
def to_dict(self):
    """Returns the state_dict of the backend.

    The dict bundles the backend type name, serialized config, the original
    module's state_dict, the stored data, and the device.

    Raises:
        RuntimeError: If build() has not populated the original module yet.
    """
    module = self._orig_module
    if module is None:
        raise RuntimeError("Backend has not been properly initialized. Please call build() first.")
    state = {self.STATE_TYPE: self.__class__.__name__}
    state[self.STATE_CONFIG] = self._config.to_dict()
    state[self.STATE_ORIG_MODULE] = module.state_dict()
    state[self.STATE_DATA] = self._data
    state[self.STATE_DEVICE] = self._device
    return state

TorchAOBackendConfig

aitune.torch.backend.TorchAOBackendConfig dataclass

TorchAOBackendConfig(quantization=None, quantization_config=None)

Bases: BackendConfig

Configuration for TorchAOBackend.

__post_init__

__post_init__()

Post init for TorchAOBackendConfig.

Source code in aitune/torch/backend/torchao_backend.py
def __post_init__(self):
    """Post init for TorchAOBackendConfig."""
    if self.quantization is not None and self.quantization_config is not None:
        raise ValueError("Only one of quantization or quantization_config should be provided.")
    if self.quantization is None and self.quantization_config is None:
        raise ValueError("Either quantization or quantization_config should be provided.")

    if not self.quantization_config:
        self.quantization_config = self._get_quantization_config(self.quantization)

describe

describe()

Returns the description of the backend.

Source code in aitune/torch/backend/torchao_backend.py
def describe(self) -> str:
    """Returns the description of the backend.

    Builds a baseline instance of the quantization-config class that carries
    this config's required (defaultless) field values, then reports the fields
    that _get_changed_fields finds different/notable relative to that baseline.
    """
    # Collect the required constructor fields (no default, no default_factory)
    # together with their current values.
    required = {}
    for f in fields(self.quantization_config.__class__):
        if f.default is MISSING and f.default_factory is MISSING:
            required[f.name] = getattr(self.quantization_config, f.name)

    # BUG FIX: the baseline must be constructed with keyword arguments.
    # `cls(*required)` unpacks the dict's *keys* (field-name strings) as
    # positional values, producing a nonsense baseline instance.
    changed_fields = self._get_changed_fields(
        self.quantization_config,
        self.quantization_config.__class__(**required),
        include=list(required.keys()),
    )
    return f"quantization_config={self.quantization_config.__class__.__name__}({','.join(changed_fields)})"

from_dict classmethod

from_dict(state_dict)

Convert dict to TorchAOBackendConfig.

Source code in aitune/torch/backend/torchao_backend.py
@classmethod
def from_dict(cls, state_dict: dict):
    """Convert dict to TorchAOBackendConfig.

    Deserializes the stored quantization config and wraps it in a new
    configuration instance.
    """
    quant_config = loads(state_dict["quantization_config"])
    return cls(quantization_config=quant_config)

key

key()

Returns the key of the backend configuration.

Source code in aitune/torch/backend/torchao_backend.py
def key(self) -> str:
    """Returns the key of the backend configuration.

    The key is a hash of the JSON rendering of the quantization config, so it
    is derived purely from the configuration contents.
    """
    json_ready = self._to_json({"quantization_config": self.quantization_config})
    return hash_string(json.dumps(json_ready))

to_dict

to_dict()

Convert TorchAOBackendConfig to dict.

Source code in aitune/torch/backend/torchao_backend.py
def to_dict(self):
    """Convert TorchAOBackendConfig to dict.

    Serializes the quantization config so the result can be stored and later
    restored via from_dict().
    """
    serialized = dumps(self.quantization_config)
    return {"quantization_config": serialized}

to_json

to_json(path)

Saves the backend configuration to a file.

Source code in aitune/torch/backend/backend.py
def to_json(self, path: Path):
    """Saves the backend configuration to a file.

    Args:
        path: Destination file; the JSON-converted configuration is written
            there with 2-space indentation.
    """
    serializable = self._to_json(self.to_dict())
    with open(path, "w") as f:
        json.dump(serializable, f, indent=2)