Skip to content

Torch-TensorRT AOT Backend API

TorchTensorRTAotBackend

aitune.torch.backend.TorchTensorRTAotBackend

TorchTensorRTAotBackend(config=None)

Bases: Backend

Backend that compiles a model using TensorRT.

Initialize TensorRT backend.

Parameters:

Source code in aitune/torch/backend/torch_tensorrt_aot_backend.py
def __init__(
    self,
    config: TorchTensorRTAotBackendConfig | None = None,
):
    """Set up the TensorRT AOT backend.

    Args:
        config: Compilation settings; a default configuration is used when omitted.
    """
    super().__init__()
    # Fail fast if the runtime prerequisites (CUDA, torch_tensorrt) are missing.
    assert_cuda_is_available()
    assert_torch_tensorrt()

    # Fall back to the default configuration when none was supplied.
    self._config = config or TorchTensorRTAotBackendConfig()

    # Populated later by the build step.
    self._opt_module = None
    self._exported_model_path = None

device property

device

Get the device of the backend.

Returns:

  • device

    The device the module is using.

is_active property

is_active

Returns True if the backend is active.

name property

name

Name of a backend.

activate

activate()

Activates backend.

After activating, the backend should be ready to do inference.

Source code in aitune/torch/backend/backend.py
@nvtx.annotate(domain="AITune", color="black")
def activate(self):
    """Activate the backend so it is ready to serve inference.

    Raises:
        RuntimeError: If the backend was never built, or is already deployed.
    """
    current = self.state
    if current == BackendState.INIT:
        raise RuntimeError(f"Cannot activate backend {self.name}, backend should be built first")
    if current == BackendState.DEPLOYED:
        raise RuntimeError(f"Cannot activate backend {self.name}, backend is already deployed")

    # Only backends that are inactive or freshly restored from a checkpoint
    # actually need (re)activation; ACTIVE is a no-op.
    if current in (BackendState.INACTIVE, BackendState.CHECKPOINT_LOADED):
        self._activate()
        self.state = BackendState.ACTIVE

build

build(module, graph_spec, data, device, cache_dir)

Build the model with the given arguments.

Building a backend should be idempotent, i.e. it should not cause observable side effects. A model is not necessarily purely functional and can have internal state (such as the KV cache for LLMs). That is why build may invoke the model on a sample of inputs at most once, so that subsequent calls see exactly the same state as the first call for the given sample.

After building, the backend should be activated.

Source code in aitune/torch/backend/backend.py
def build(
    self,
    module: nn.Module,
    graph_spec: GraphSpec,
    data: list[Sample],
    device: torch.device,
    cache_dir: Path,
) -> "Backend":
    """Build the model with the given arguments.

    Building should be idempotent, i.e. cause no observable side effects. A
    model is not necessarily purely functional and may carry internal state
    (e.g. a KV cache for LLMs), so build may invoke the model on a sample of
    inputs at most once — subsequent calls then see exactly the same state as
    the first call for that sample.

    After a successful build the backend is active.

    Raises:
        RuntimeError: If build is called more than once.
    """
    # Guard clause: build is a one-shot transition out of INIT.
    if self.state != BackendState.INIT:
        raise RuntimeError(f"Backend {self.name} build should be called only once")

    try:
        self._assert_device(device)
        self._set_device(device)
        ready_backend = self._build(module, graph_spec, data, cache_dir)
        self.state = BackendState.ACTIVE
        return ready_backend
    except Exception as e:
        self._logger.error("Failed to build backend(%s): %s", self.__class__.__name__, e, exc_info=True)
        raise e

deactivate

deactivate()

Deactivates backend.

After deactivating, the backend cannot be used to do inference.

Source code in aitune/torch/backend/backend.py
def deactivate(self):
    """Deactivates backend.

    After deactivating, the backend cannot be used to do inference.

    Raises:
        RuntimeError: If the backend was never built, is deployed, or was
            loaded from a checkpoint.
    """
    if self.state == BackendState.INIT:
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend should be built first")
    if self.state == BackendState.DEPLOYED:
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend is already deployed")
    if self.state == BackendState.CHECKPOINT_LOADED:
        # Fixed message: this state means a checkpoint was loaded, not that the
        # backend was deployed — the DEPLOYED case is already handled above.
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend has been loaded from a checkpoint")

    if self.state == BackendState.ACTIVE:
        self._deactivate()
        self._clean_memory()
        self.state = BackendState.INACTIVE

deploy

deploy(device)

Deploys the backend.

After deploying, the backend is ready to do inference. Backend cannot be deactivated anymore.

Parameters:

  • device (device | None) –

    The device to deploy the backend on.

Source code in aitune/torch/backend/backend.py
def deploy(self, device: torch.device | None):
    """Put the backend into its final, deployed state.

    Once deployed, the backend serves inference and can no longer be
    deactivated.

    Args:
        device: The device to deploy the backend on.

    Raises:
        RuntimeError: If the backend was not restored from a checkpoint.
    """
    # Deployment is only legal from the checkpoint-restored state.
    if self.state != BackendState.CHECKPOINT_LOADED:
        raise RuntimeError(f"Cannot deploy backend {self.name}, backend should be loaded from a checkpoint")

    self._set_device(device)
    self._deploy()
    self.state = BackendState.DEPLOYED

describe

describe()

Returns the description of the backend.

Source code in aitune/torch/backend/torch_tensorrt_aot_backend.py
def describe(self) -> str:
    """Human-readable summary: class name wrapping the config description."""
    config_summary = self._config.describe()
    return f"{type(self).__name__}({config_summary})"

from_dict classmethod

from_dict(module, state_dict)

Creates a backend from a state_dict.

Source code in aitune/torch/backend/torch_tensorrt_aot_backend.py
@classmethod
def from_dict(cls, module: torch.nn.Module | None, state_dict: dict) -> "TorchTensorRTAotBackend":
    """Creates a backend from a state_dict."""
    if state_dict.get(cls.STATE_TYPE) != cls.__name__:
        raise ValueError(f"Invalid state_dict type: {state_dict.get(cls.STATE_TYPE)}")

    backend = cls()  # create with default args, config is not needed
    backend._exported_model_path = state_dict[cls.EXPORTED_MODEL_PATH_KEY]
    backend._device = state_dict[cls.STATE_DEVICE]
    backend.state = BackendState.CHECKPOINT_LOADED
    return backend

infer

infer(*args, **kwargs)

Run inference with the given arguments.

Parameters:

  • args (Any, default: () ) –

    Variable length argument list.

  • kwargs (Any, default: {} ) –

    Arbitrary keyword arguments.

Returns:

  • Any ( Any ) –

    The result of the inference.

Source code in aitune/torch/backend/backend.py
def infer(self, *args: Any, **kwargs: Any) -> Any:
    """Run inference with the given arguments.

    Args:
        args: Positional arguments forwarded to the underlying model.
        kwargs: Keyword arguments forwarded to the underlying model.

    Returns:
        Any: The result of the inference.

    Raises:
        RuntimeError: If the backend is neither active nor deployed.
    """
    # Inference is only valid in the ACTIVE or DEPLOYED states.
    if self.state not in (BackendState.ACTIVE, BackendState.DEPLOYED):
        raise RuntimeError(f"Cannot run inference, backend {self.name} should be activated first")

    return self._infer(*args, **kwargs)

key

key()

Returns the key of the backend.

Source code in aitune/torch/backend/torch_tensorrt_aot_backend.py
def key(self) -> str:
    """Cache key combining the backend class name with the config key."""
    return "_".join((type(self).__name__, self._config.key()))

to_dict

to_dict()

Returns the state_dict of the backend.

Source code in aitune/torch/backend/torch_tensorrt_aot_backend.py
def to_dict(self) -> dict:
    """Serialize the backend state required to restore it later.

    Raises:
        RuntimeError: If the model has not been built/exported yet.
    """
    # The exported model path only exists after a successful build.
    if self._exported_model_path is None:
        raise RuntimeError("No exported model path available. Model must be built first.")

    state = {self.STATE_TYPE: type(self).__name__}
    state[self.EXPORTED_MODEL_PATH_KEY] = self._exported_model_path
    state[self.STATE_DEVICE] = self._device
    return state

TorchTensorRTAotBackendConfig

aitune.torch.backend.TorchTensorRTAotBackendConfig dataclass

TorchTensorRTAotBackendConfig(ir='dynamo', compile_config=(lambda: TorchTensorRTConfig(enabled_precisions={float16}))(), pickle_protocol=DEFAULT_PICKLE_PROTOCOL)

Bases: BackendConfig

Configuration for TorchTensorRTAotBackend.

See torch_tensorrt/dynamo/_settings.py CompilationSettings for compile_config(TorchTensorRTConfig)

describe

describe()

Describe the backend configuration. Display only changed fields.

Source code in aitune/torch/backend/torch_tensorrt_aot_backend.py
def describe(self) -> str:
    """Summarize the configuration, listing only fields changed from defaults."""
    defaults = self.__class__()

    # Compare the nested compile config against a default instance; only the
    # enabled_precisions field is considered interesting enough to always show.
    compile_changes = self._get_changed_fields(
        self.compile_config, defaults.compile_config, include=["enabled_precisions"]
    )
    summary = [f"compile_config=TorchTensorRTConfig({','.join(compile_changes)})"]

    # Then append any remaining top-level fields that differ from defaults.
    summary.extend(self._get_changed_fields(self, defaults, exclude=["compile_config"]))

    return ",".join(summary)

key

key()

Returns the keys of the backend configuration.

Source code in aitune/torch/backend/backend.py
def key(self) -> str:
    """Stable hash key derived from the JSON-serialized configuration."""
    serializable = self._to_json(self.to_dict())
    return hash_string(json.dumps(serializable))

to_dict

to_dict()

Returns the backend configuration as a dictionary.

Source code in aitune/torch/backend/torch_tensorrt_aot_backend.py
def to_dict(self) -> dict:
    """Return the backend configuration as a plain dictionary.

    The nested ``compile_config`` dataclass is converted explicitly so the
    result contains only built-in types.
    """
    state_dict = asdict(self)
    state_dict["compile_config"] = asdict(self.compile_config)  # explicitly convert to dict
    return state_dict

to_json

to_json(path)

Saves the backend configuration to a file.

Source code in aitune/torch/backend/backend.py
def to_json(self, path: Path):
    """Write the JSON-serializable configuration to the given file path.

    Args:
        path: Destination file for the JSON document.
    """
    serializable = self._to_json(self.to_dict())
    with open(path, "w") as f:
        json.dump(serializable, f, indent=2)