Skip to content

Torch Inductor Backend API

TorchInductorBackend

aitune.torch.backend.TorchInductorBackend

TorchInductorBackend(config=None)

Bases: Backend

Backend that does torch compilation with Inductor.

Initializes backend.

Parameters:

Source code in aitune/torch/backend/torch_inductor_backend.py
def __init__(
    self,
    config: TorchInductorBackendConfig | None = None,
):
    """Set up the Inductor-compilation backend.

    Args:
        config: Configuration for torch compile with the inductor backend.
            A default ``TorchInductorBackendConfig`` is used when None.
    """
    super().__init__()

    # Configuration: caller-supplied, or defaults.
    self._config = config if config is not None else TorchInductorBackendConfig()

    # Build-time state — all populated later by build(); None until then.
    self._compiled_module = None
    self._orig_module = None
    self._output_dtype = None
    self._data = None

device property

device

Get the device of the backend.

Returns:

  • device

    The device the module is using.

is_active property

is_active

Returns True if the backend is active.

name property

name

Name of a backend.

activate

activate()

Activates backend.

After activating, the backend should be ready to do inference.

Source code in aitune/torch/backend/backend.py
@nvtx.annotate(domain="AITune", color="black")
def activate(self):
    """Activates backend.

    After activating, the backend should be ready to do inference.

    Raises:
        RuntimeError: If the backend was never built, or is already deployed.
    """
    if self.state == BackendState.INIT:
        raise RuntimeError(f"Cannot activate backend {self.name}, backend should be built first")
    if self.state == BackendState.DEPLOYED:
        raise RuntimeError(f"Cannot activate backend {self.name}, backend is already deployed")

    # Only INACTIVE and CHECKPOINT_LOADED transition to ACTIVE; activating an
    # already-ACTIVE backend is a no-op.
    if self.state in (BackendState.INACTIVE, BackendState.CHECKPOINT_LOADED):
        self._activate()
        self.state = BackendState.ACTIVE

build

build(module, graph_spec, data, device, cache_dir)

Build the model with the given arguments.

Building a backend should be idempotent, i.e., it should not cause side effects. A model is not necessarily purely functional and can have an internal state (like kv cache for LLMs). That is why build may call a sample of inputs at most once, so that subsequent calls have exactly the same state as the first call for the given sample.

After building, the backend should be activated.

Source code in aitune/torch/backend/backend.py
def build(
    self,
    module: nn.Module,
    graph_spec: GraphSpec,
    data: list[Sample],
    device: torch.device,
    cache_dir: Path,
) -> "Backend":
    """Build the model with the given arguments.

    Building a backend should be idempotent, i.e. it should not cause side effects. A model is not
    necessarily purely functional and can have an internal state (like kv cache for LLMs). That is why
    build may call a sample of inputs at most once so that subsequent calls have exactly the same state
    as the first call for the given sample.

    After building, the backend should be activated.

    Args:
        module: The module to build the backend around.
        graph_spec: Graph specification passed through to ``_build``.
        data: Samples available to the build step.
        device: Target device; checked via ``_assert_device`` before use.
        cache_dir: Directory passed to ``_build`` for build artifacts.

    Returns:
        The built backend returned by ``_build``, ready for inference.

    Raises:
        RuntimeError: If the backend has already been built (state is not INIT).
    """
    # Guard clause: build is a one-shot operation.
    if self.state != BackendState.INIT:
        raise RuntimeError(f"Backend {self.name} build should be called only once")

    try:
        self._assert_device(device)
        self._set_device(device)
        ready_backend = self._build(module, graph_spec, data, cache_dir)
        self.state = BackendState.ACTIVE
        return ready_backend
    except Exception as e:
        self._logger.error("Failed to build backend(%s): %s", self.__class__.__name__, e, exc_info=True)
        # Bare raise preserves the original traceback (``raise e`` would re-anchor it here).
        raise

deactivate

deactivate()

Deactivates backend.

After deactivating, the backend cannot be used to do inference.

Source code in aitune/torch/backend/backend.py
def deactivate(self):
    """Deactivates backend.

    After deactivating, the backend cannot be used to do inference.

    Raises:
        RuntimeError: If the backend was never built, is deployed, or has a
            checkpoint loaded — none of these states can be deactivated.
    """
    if self.state == BackendState.INIT:
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend should be built first")
    if self.state == BackendState.DEPLOYED:
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend is already deployed")
    if self.state == BackendState.CHECKPOINT_LOADED:
        # NOTE(review): the previous message said the backend "has already been deployed",
        # but this branch handles CHECKPOINT_LOADED — DEPLOYED is handled above.
        raise RuntimeError(f"Cannot deactivate backend {self.name}, backend has a checkpoint loaded")

    # Only an ACTIVE backend transitions to INACTIVE; deactivating an already
    # INACTIVE backend is a no-op.
    if self.state == BackendState.ACTIVE:
        self._deactivate()
        self._clean_memory()  # release cached resources after deactivation
        self.state = BackendState.INACTIVE

deploy

deploy(device)

Deploys the backend.

After deploying, the backend is ready to do inference. Backend cannot be deactivated anymore.

Parameters:

  • device (device | None) –

    The device to deploy the backend on.

Source code in aitune/torch/backend/backend.py
def deploy(self, device: torch.device | None):
    """Deploys the backend.

    After deploying, the backend is ready to do inference and can no longer be
    deactivated.

    Args:
        device: The device to deploy the backend on.

    Raises:
        RuntimeError: If no checkpoint has been loaded into the backend.
    """
    # Deployment is only valid from the CHECKPOINT_LOADED state.
    if self.state != BackendState.CHECKPOINT_LOADED:
        raise RuntimeError(f"Cannot deploy backend {self.name}, backend should be loaded from a checkpoint")

    self._set_device(device)
    self._deploy()
    self.state = BackendState.DEPLOYED

describe

describe()

Returns the description of the backend.

Source code in aitune/torch/backend/torch_inductor_backend.py
def describe(self) -> str:
    """Return a human-readable description: class name plus the config summary."""
    cls_name = type(self).__name__
    config_summary = self._config.describe()
    return f"{cls_name}({config_summary})"

from_dict classmethod

from_dict(module, state_dict)

Creates a backend from a state_dict.

Source code in aitune/torch/backend/torch_inductor_backend.py
@classmethod
def from_dict(cls, module: torch.nn.Module | None, state_dict: dict):
    """Creates a backend from a state_dict.

    Raises:
        ValueError: If the state_dict type tag does not match this class, or
            if no module is supplied.
    """
    if state_dict.get(cls.STATE_TYPE) != cls.__name__:
        raise ValueError(f"Invalid state_dict type: {state_dict.get(cls.STATE_TYPE)}")

    if module is None:
        raise ValueError("Module is required to create a backend from a state_dict.")

    # Rebuild the backend from the serialized pieces.
    backend = cls(config=TorchInductorBackendConfig.from_dict(state_dict[cls.STATE_CONFIG]))
    backend._output_dtype = state_dict[cls.STATE_OUTPUT_DTYPE]
    backend._data = state_dict[cls.STATE_DATA]
    backend._device = state_dict[cls.STATE_DEVICE]
    backend._orig_module = module
    # strict=False: tolerates key mismatches between module and checkpoint — TODO confirm intent.
    module.load_state_dict(state_dict[cls.STATE_ORIG_MODULE], strict=False)
    backend.state = BackendState.CHECKPOINT_LOADED
    return backend

infer

infer(*args, **kwargs)

Run inference with the given arguments.

Parameters:

  • args (Any, default: () ) –

    Variable length argument list.

  • kwargs (Any, default: {} ) –

    Arbitrary keyword arguments.

Returns:

  • Any ( Any ) –

    The result of the inference.

Source code in aitune/torch/backend/backend.py
def infer(self, *args: Any, **kwargs: Any) -> Any:
    """Run inference with the given arguments.

    Args:
        args: Positional arguments forwarded to ``_infer``.
        kwargs: Keyword arguments forwarded to ``_infer``.

    Returns:
        Any: The result of the inference.

    Raises:
        RuntimeError: If the backend is neither active nor deployed.
    """
    if self.state not in (BackendState.ACTIVE, BackendState.DEPLOYED):
        raise RuntimeError(f"Cannot run inference, backend {self.name} should be activated first")

    return self._infer(*args, **kwargs)

key

key()

Returns the key of the backend.

Source code in aitune/torch/backend/torch_inductor_backend.py
def key(self) -> str:
    """Return the backend key: class name joined with the config key."""
    parts = (type(self).__name__, self._config.key())
    return "_".join(parts)

to_dict

to_dict()

Returns the state_dict of the backend.

Source code in aitune/torch/backend/torch_inductor_backend.py
def to_dict(self):
    """Returns the state_dict of the backend.

    Returns:
        dict: Serializable state holding the type tag, config, output dtype,
        build data, the original module's weights, and the device.

    Raises:
        RuntimeError: If build() has not been called yet (no original module).
    """
    # Compare against None explicitly: ``not self._orig_module`` would wrongly
    # trigger for modules that are present but falsy (e.g. an empty
    # nn.Sequential, whose __len__ is 0).
    if self._orig_module is None:
        raise RuntimeError("Backend has not been properly initialized. Please call build() first.")

    return {
        self.STATE_TYPE: self.__class__.__name__,
        self.STATE_CONFIG: self._config.to_dict(),
        self.STATE_OUTPUT_DTYPE: self._output_dtype,
        self.STATE_DATA: self._data,
        self.STATE_ORIG_MODULE: self._orig_module.state_dict(),
        self.STATE_DEVICE: self._device,
    }

TorchInductorBackendConfig

aitune.torch.backend.TorchInductorBackendConfig dataclass

TorchInductorBackendConfig(fullgraph=False, dynamic=None, mode=None, options=None, autocast_enabled=False, autocast_dtype=None)

Bases: BackendConfig

Configuration for torch.compile with inductor backend.

Parameters:

  • fullgraph (bool, default: False ) –

    If False (default), torch.compile attempts to discover compilable regions in the function it will tune. If True, then we require the entire function to be captured into a single graph. If this is not possible (that is, if there are graph breaks), then this will raise an error.

  • dynamic (bool or None, default: None ) –

    Use dynamic shape tracing. When this is True, we will up-front attempt to generate a kernel that is as dynamic as possible to avoid recompilations when sizes change. This may not always work as some operations/optimizations will force specialization; use TORCH_LOGS=dynamic to debug overspecialization. When this is False, we will NEVER generate dynamic kernels, we will always specialize. By default (None), we automatically detect if dynamism has occurred and compile a more dynamic kernel upon recompile.

  • mode (str, default: None ) –

    Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs"

    • "default" is the default mode, which is a good balance between performance and overhead

    • "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs, useful for small batches. Reduction of overhead can come at the cost of more memory usage, as we will cache the workspace memory required for the invocation so that we do not have to reallocate it on subsequent runs. Reduction of overhead is not guaranteed to work; today, we only reduce overhead for CUDA only graphs which do not mutate inputs. There are other circumstances where CUDA graphs are not applicable; use TORCH_LOGS=perf_hints to debug.

    • "max-autotune" is a mode that leverages Triton or template based matrix multiplications on supported devices and Triton based convolutions on GPU. It enables CUDA graphs by default on GPU.

    • "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs

    • To see the exact configs that each mode sets you can call torch._inductor.list_mode_options()

  • options (dict, default: None ) –

    A dictionary of options to pass to the backend. To see the full list of configs it supports, call torch._inductor.list_options().

  • autocast_enabled (bool, default: False ) –

    If True, enable autocast.

  • autocast_dtype (dtype, default: None ) –

    The dtype to use for autocast.

__post_init__

__post_init__()

Post init.

Source code in aitune/torch/backend/torch_inductor_backend.py
def __post_init__(self):
    """Post init."""
    # Check that mode and options are not both supplied in config
    if self.mode is not None and self.options is not None:
        raise ValueError(
            "Cannot specify both 'mode' and 'options' parameters in config. "
            "Use either 'mode' for predefined configurations or 'options' "
            "for custom configurations, but not both."
        )

describe

describe()

Describe the backend configuration. Display only changed fields.

Source code in aitune/torch/backend/backend.py
def describe(self) -> str:
    """Describe the backend configuration. Display only changed fields."""
    # Diff this instance against a default-constructed one and join the result.
    defaults = self.__class__()
    return ",".join(self._get_changed_fields(self, defaults))

from_dict classmethod

from_dict(state_dict)

Convert dict to TorchInductorBackendConfig.

Source code in aitune/torch/backend/torch_inductor_backend.py
@classmethod
def from_dict(cls, state_dict: dict):
    """Convert dict to TorchInductorBackendConfig.

    The keys of ``state_dict`` must match the dataclass field names.
    """
    return cls(**state_dict)

key

key()

Returns the keys of the backend configuration.

Source code in aitune/torch/backend/backend.py
def key(self) -> str:
    """Return a stable key derived from the configuration contents."""
    # Serialize to a JSON string and hash it, so equal configs share a key.
    serializable = self._to_json(self.to_dict())
    return hash_string(json.dumps(serializable))

to_dict

to_dict()

Returns the state_dict of the backend.

Source code in aitune/torch/backend/backend.py
def to_dict(self):
    """Returns the state_dict of the backend.

    Uses :func:`dataclasses.asdict`, which recursively converts nested
    dataclasses as well.
    """
    return asdict(self)

to_json

to_json(path)

Saves the backend configuration to a file.

Source code in aitune/torch/backend/backend.py
def to_json(self, path: Path):
    """Saves the backend configuration to a file as indented JSON."""
    # Convert to a JSON-serializable dict, then write it out.
    serializable = self._to_json(self.to_dict())
    with open(path, "w") as fh:
        json.dump(serializable, fh, indent=2)