Skip to content

Torch-TensorRT JIT Backend API

TorchTensorRTJitBackend

aitune.torch.backend.TorchTensorRTJitBackend

TorchTensorRTJitBackend(config=None)

Bases: Backend

Torch-TensorRT backend that compiles the module through torch.compile(backend="torch_tensorrt").

The backend does not use intermediate formats, and the compiled model is not stored.

Initialize TorchTensorRTJitBackend.

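Parameters:

  • config (TorchTensorRTJitBackendConfig | None, default: None) –

    Configuration for torch compile with torch_tensorrt.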

Source code in aitune/torch/backend/torch_tensorrt_jit_backend.py
def __init__(
    self,
    config: TorchTensorRTJitBackendConfig | None = None,
):
    """Create a Torch-TensorRT JIT backend.

    CUDA and torch_tensorrt availability are checked before the configuration is stored.

    Args:
        config: Configuration for torch compile with torch_tensorrt. When omitted (or falsy),
            a default ``TorchTensorRTJitBackendConfig()`` is used.
    """
    super().__init__()
    # Fail fast when the runtime prerequisites are missing.
    assert_cuda_is_available()
    assert_torch_tensorrt()

    self._config = config if config else TorchTensorRTJitBackendConfig()

    # Build-time state; starts empty (``from_dict`` fills ``_data`` and ``_orig_module``
    # when restoring from a checkpoint).
    self._compiled_module = self._orig_module = self._data = None

device property

device

Get the device of the backend.

Returns:

  • device

    The device the module is using.

is_active property

is_active

Returns True if the backend is active.

name property

name

Name of a backend.

activate

activate()

Activates backend.

After activating, the backend should be ready to do inference.

Source code in aitune/torch/backend/backend.py
@nvtx.annotate(domain="AITune", color="black")
def activate(self):
    """Activate the backend so it is ready to run inference.

    INACTIVE and CHECKPOINT_LOADED backends transition to ACTIVE. Any other allowed state
    (e.g. an already ACTIVE backend) is left unchanged.

    Raises:
        RuntimeError: If the backend has not been built yet or has already been deployed.
    """
    current = self.state
    if current == BackendState.INIT:
        raise RuntimeError(f"Cannot activate backend {self.name}, backend should be built first")
    if current == BackendState.DEPLOYED:
        raise RuntimeError(f"Cannot activate backend {self.name}, backend is already deployed")

    if current in (BackendState.INACTIVE, BackendState.CHECKPOINT_LOADED):
        self._activate()
        self.state = BackendState.ACTIVE

build

build(module, graph_spec, data, device, cache_dir)

Build the model with the given arguments.

Building a backend should be idempotent, i.e. it should not cause side effects. A model is not necessarily purely functional and can have an internal state (like a KV cache for LLMs). That is why build can call a sample of inputs at most once, so that subsequent calls have exactly the same state as the first call for the given sample.

After building, the backend should be activated.

Source code in aitune/torch/backend/backend.py
def build(
    self,
    module: nn.Module,
    graph_spec: GraphSpec,
    data: list[Sample],
    device: torch.device,
    cache_dir: Path,
) -> "Backend":
    """Build the model with the given arguments.

    Building a backend should be idempotent, i.e. it should not cause side effects. A model is not
    necessarily purely functional and can have an internal state (like a KV cache for LLMs). That is why
    build can call a sample of inputs at most once, so that subsequent calls have exactly the same state
    as the first call for the given sample.

    On success the backend transitions to the ACTIVE state.

    Args:
        module: The module to build the backend for.
        graph_spec: Graph specification forwarded to ``_build``.
        data: Input samples forwarded to ``_build``.
        device: Target device; validated with ``_assert_device`` and then set on the backend.
        cache_dir: Cache directory forwarded to ``_build``.

    Returns:
        The backend returned by ``_build``.

    Raises:
        RuntimeError: If ``build`` has already been called on this backend.
        Exception: Any error raised while validating the device or building is logged and re-raised.
    """
    if self.state != BackendState.INIT:
        raise RuntimeError(f"Backend {self.name} build should be called only once")

    try:
        self._assert_device(device)
        self._set_device(device)
        ready_backend = self._build(module, graph_spec, data, cache_dir)
    except Exception as e:
        self._logger.error("Failed to build backend(%s): %s", self.__class__.__name__, e, exc_info=True)
        # A bare ``raise`` re-raises the active exception unchanged.
        raise

    self.state = BackendState.ACTIVE
    return ready_backend

deactivate

deactivate()

Deactivates backend.

After deactivating, the backend cannot be used to do inference.

Source code in aitune/torch/backend/backend.py
def deactivate(self):
    """Deactivates backend.

    After deactivating, the backend cannot be used to do inference. An ACTIVE backend has
    ``_deactivate`` and ``_clean_memory`` called and becomes INACTIVE. An already INACTIVE
    backend is left unchanged.

    Raises:
        RuntimeError: If the backend has not been built, has been deployed, or was loaded
            from a checkpoint.
    """
    if self.state == BackendState.INIT:
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend should be built first")
    if self.state == BackendState.DEPLOYED:
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend is already deployed")
    if self.state == BackendState.CHECKPOINT_LOADED:
        # CHECKPOINT_LOADED precedes deployment (``deploy`` requires this state), so the backend
        # is not deployed yet; report the actual reason.
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend was loaded from a checkpoint")

    if self.state == BackendState.ACTIVE:
        self._deactivate()
        self._clean_memory()
        self.state = BackendState.INACTIVE

deploy

deploy(device)

Deploys the backend.

After deploying, the backend is ready to do inference. The backend cannot be deactivated anymore.

Parameters:

  • device (device | None) –

    The device to deploy the backend on.

Source code in aitune/torch/backend/backend.py
def deploy(self, device: torch.device | None):
    """Deploy a checkpoint-loaded backend for inference.

    Sets the device, calls ``_deploy`` and moves the backend to the DEPLOYED state. Once
    deployed, the backend can no longer be deactivated.

    Args:
        device: The device to deploy the backend on.

    Raises:
        RuntimeError: If the backend was not loaded from a checkpoint.
    """
    if self.state == BackendState.CHECKPOINT_LOADED:
        self._set_device(device)
        self._deploy()
        self.state = BackendState.DEPLOYED
        return

    raise RuntimeError(f"Cannot deploy backend {self.name}, backend should be loaded from a checkpoint")

describe

describe()

Returns the description of the backend.

Source code in aitune/torch/backend/torch_tensorrt_jit_backend.py
def describe(self) -> str:
    """Returns a description of the form ``ClassName(<config description>)``."""
    class_name = self.__class__.__name__
    config_description = self._config.describe()
    return f"{class_name}({config_description})"

from_dict classmethod

from_dict(module, state_dict)

Creates a backend from a state_dict.

Source code in aitune/torch/backend/torch_tensorrt_jit_backend.py
@classmethod
def from_dict(cls, module: torch.nn.Module | None, state_dict: dict):
    """Restore a backend from a state dictionary (as produced by ``to_dict``).

    Args:
        module: The original module. It is attached to the backend and its weights are loaded
            from the checkpoint with ``strict=False``.
        state_dict: Serialized backend state.

    Returns:
        A new backend in the CHECKPOINT_LOADED state.

    Raises:
        ValueError: If ``state_dict`` was produced by a different backend type, or ``module`` is None.
    """
    stored_type = state_dict.get(cls.STATE_TYPE)
    if stored_type != cls.__name__:
        raise ValueError(f"Invalid state_dict type: {stored_type}")
    if module is None:
        raise ValueError("Module is required to create a backend from a state_dict.")

    backend = cls(config=TorchTensorRTJitBackendConfig.from_dict(state_dict[cls.STATE_CONFIG]))
    backend._data = state_dict[cls.STATE_DATA]
    backend._device = state_dict[cls.STATE_DEVICE]
    backend._orig_module = module
    # strict=False: mismatched (missing/unexpected) keys in the checkpoint are tolerated.
    module.load_state_dict(state_dict[cls.STATE_ORIG_MODULE], strict=False)
    backend.state = BackendState.CHECKPOINT_LOADED
    return backend

infer

infer(*args, **kwargs)

Run inference with the given arguments.

Parameters:

  • args (Any, default: () ) –

    Variable length argument list.

  • kwargs (Any, default: {} ) –

    Arbitrary keyword arguments.

Returns:

  • Any ( Any ) –

    The result of the inference.

Source code in aitune/torch/backend/backend.py
def infer(self, *args: Any, **kwargs: Any) -> Any:
    """Run inference by forwarding all arguments to ``_infer``.

    Args:
        args: Variable length argument list.
        kwargs: Arbitrary keyword arguments.

    Returns:
        Any: The result of the inference.

    Raises:
        RuntimeError: If the backend is neither ACTIVE nor DEPLOYED.
    """
    runnable_states = (BackendState.ACTIVE, BackendState.DEPLOYED)
    if self.state not in runnable_states:
        raise RuntimeError(f"Cannot run inference, backend {self.name} should be activated first")
    return self._infer(*args, **kwargs)

key

key()

Returns the key of the backend.

Source code in aitune/torch/backend/torch_tensorrt_jit_backend.py
def key(self) -> str:
    """Returns ``<ClassName>_<config key>``, identifying this backend and its configuration."""
    prefix = self.__class__.__name__
    suffix = self._config.key()
    return f"{prefix}_{suffix}"

to_dict

to_dict()

Returns the state_dict of the backend.

Source code in aitune/torch/backend/torch_tensorrt_jit_backend.py
def to_dict(self):
    """Serialize the backend into a dictionary.

    The result holds the backend type name, the config (via ``to_dict``), the stored data,
    the original module's ``state_dict()`` and the device.

    Raises:
        RuntimeError: If no original module is attached to the backend.
    """
    orig_module = self._orig_module
    if orig_module is None:
        raise RuntimeError("Backend has not been properly initialized. Please call build() first.")

    serialized = {self.STATE_TYPE: self.__class__.__name__}
    serialized[self.STATE_CONFIG] = self._config.to_dict()
    serialized[self.STATE_DATA] = self._data
    serialized[self.STATE_ORIG_MODULE] = orig_module.state_dict()
    serialized[self.STATE_DEVICE] = self._device
    return serialized

TorchTensorRTJitBackendConfig

aitune.torch.backend.TorchTensorRTJitBackendConfig dataclass

TorchTensorRTJitBackendConfig(compile_config=(lambda: TorchTensorRTConfig(enabled_precisions={float16}))(), fullgraph=False, dynamic_shapes=None, autocast_enabled=False, autocast_dtype=None)

Bases: BackendConfig

Configuration for torch.compile(backend="torch_tensorrt").

describe

describe()

Describe the backend configuration. Display only changed fields.

Source code in aitune/torch/backend/torch_tensorrt_jit_backend.py
def describe(self) -> str:
    """Describe the configuration, showing only fields that differ from a default instance.

    The compile config is always rendered as ``compile_config=TorchTensorRTConfig(...)``. Its
    comparison is restricted with ``include=["enabled_precisions"]``. The remaining top-level
    fields are compared with ``compile_config`` excluded, and appended after it.
    """
    defaults = self.__class__()
    compile_changes = self._get_changed_fields(
        self.compile_config, defaults.compile_config, include=["enabled_precisions"]
    )
    compile_part = f"compile_config=TorchTensorRTConfig({','.join(compile_changes)})"
    other_changes = self._get_changed_fields(self, defaults, exclude=["compile_config"])
    return ",".join([compile_part, *other_changes])

from_dict classmethod

from_dict(state_dict)

Convert dict to TorchTensorRTJitBackendConfig.

Source code in aitune/torch/backend/torch_tensorrt_jit_backend.py
@classmethod
def from_dict(cls, state_dict: dict):
    """Convert dict to TorchTensorRTJitBackendConfig.

    The input dictionary is left unmodified.

    Args:
        state_dict: Mapping of config field names to values. ``compile_config`` holds the keyword
            arguments for ``TorchTensorRTConfig``; an existing ``TorchTensorRTConfig`` instance is
            also accepted and used as-is.

    Returns:
        A new ``TorchTensorRTJitBackendConfig`` instance.
    """
    # Work on a shallow copy. Writing back into the caller's dict would replace its nested
    # "compile_config" dict with a TorchTensorRTConfig object, and a second conversion of the
    # same dict would then fail on ``**``.
    fields = dict(state_dict)
    compile_config = fields["compile_config"]
    if not isinstance(compile_config, TorchTensorRTConfig):
        fields["compile_config"] = TorchTensorRTConfig(**compile_config)
    return cls(**fields)

key

key()

Returns the key (a hash of the serialized configuration) of the backend configuration.

Source code in aitune/torch/backend/backend.py
def key(self) -> str:
    """Returns a hash-based key identifying this backend configuration.

    The configuration is converted with ``to_dict``, made JSON-friendly with ``_to_json``,
    serialized with ``json.dumps`` and hashed with ``hash_string``.
    """
    serialized = json.dumps(self._to_json(self.to_dict()))
    return hash_string(serialized)

to_dict

to_dict()

Convert TorchTensorRTJitBackendConfig to dict.

Source code in aitune/torch/backend/torch_tensorrt_jit_backend.py
def to_dict(self):
    """Convert TorchTensorRTJitBackendConfig to a plain dictionary.

    ``compile_config`` is stored as a dictionary of its own fields. It is converted explicitly,
    in addition to the conversion ``asdict(self)`` performs.
    """
    return {**asdict(self), "compile_config": asdict(self.compile_config)}

to_json

to_json(path)

Saves the backend configuration to a file.

Source code in aitune/torch/backend/backend.py
def to_json(self, path: Path):
    """Write the backend configuration to ``path`` as JSON indented by two spaces.

    The dictionary from ``to_dict`` is passed through ``_to_json`` before serialization.
    """
    payload = self._to_json(self.to_dict())
    with open(path, "w") as handle:
        handle.write(json.dumps(payload, indent=2))