Skip to content

AOT Tuning Module API

Module

aitune.torch.Module

Module(module, name=None, strategy=None, strategies=None)

Bases: CallableObjectProxy

AITune module wrapper.

This class wraps a torch module and provides tuning functionality. The module can be in 3 different states: - passthrough: the module is not tuned, and will behave identically to the original module. - recording: the module records samples and detects multi-graphs, this is a necessary step before tuning. - tuned: the module is tuned and uses underlying tuned module.

You can switch between passthrough and recording with the enable_passthrough and enable_recording methods.

You can go from tuned to passthrough/recording only if the force flag is True - then the module will be reset.

This wrapper can be used in place of a torch module, and will behave identically to the original module.

Example

model = torch.nn.Linear(10, 10)
model = Module(model, "name")

Initializes module.

Parameters:

  • module (Module) –

    torch module to wrap.

  • name (str | None, default: None ) –

    module name which differentiates it from other tuned modules.

  • strategy (TuneStrategy | None, default: None ) –

    optional strategy, which will be used for all encountered graphs

  • strategies (StrategyList | StrategyMap | None, default: None ) –

    list or dict of strategies, each for a graph

Source code in aitune/torch/module/wrapper_module.py
def __init__(
    self,
    module: torch.nn.Module,
    name: str | None = None,
    strategy: TuneStrategy | None = None,
    strategies: StrategyList | StrategyMap | None = None,
):
    """Initializes module.

    Args:
        module: torch module to wrap.
        name: module name which differentiates it from other tuned modules.
        strategy: optional strategy, which will be used for all encountered graphs.
        strategies: list or dict of strategies, each for a graph.
    """
    super().__init__(module)
    # wrapt ObjectProxy convention: attribute names prefixed with "_self_" are
    # stored on the proxy itself; all other attribute writes are forwarded to
    # the wrapped object.
    self._self_name = sanitize_model_name(name) or get_object_name(module)
    # Keep the unproxied forward so it can be restored (e.g. during tuning).
    self._self_orig_forward = module.forward
    # NOTE(review): the next two names lack the "_self_" prefix, so wrapt
    # forwards these assignments to the wrapped torch module instead of
    # storing them on the proxy — confirm this is intentional.
    self._original_forward_pre_hooks = module._forward_pre_hooks
    self._original_forward_hooks = module._forward_hooks
    # Route all forward() calls through this wrapper's _forward interceptor.
    self._self_proxy_forward = wrapt.decorator(self._forward)(module.forward)
    module.forward = self._self_proxy_forward
    self._setup_strategies(strategy, strategies)

    self._self_device = get_module_device(module) or torch.device("cpu")
    self._self_state = ModuleState.INIT
    self._self_wrapper = None  # active Passthrough/Recording/Tuned delegate
    self._self_prev_recording = None  # saved recording that can be resumed later

    # NOTE(review): no "_self_" prefix — this write is forwarded to the wrapped
    # module by wrapt; confirm intentional.
    self._system_monitor = SystemMonitor()

    MODULE_REGISTRY.register(self._self_name, self)

device property

device

Get the device of the module.

Returns:

  • device

    The device the module is using.

graph_specs property

graph_specs

Multi-graphs of the module.

module property

module

Get the wrapped torch module.

name property

name

Get the name of the module.

state property

state

Get the state of the module.

__getitem__

__getitem__(key)

Delegate getitem calls to the wrapped module.

This allows the proxy to handle indexing operations on the wrapped module, which is particularly useful for Sequential modules and other indexable modules.

Parameters:

  • key (Any) –

    The index or key to use for accessing the wrapped module.

Returns:

  • Any

    The result of accessing the wrapped module with the given key.

Source code in aitune/torch/module/wrapper_module.py
def __getitem__(self, key: Any) -> Any:
    """Forward indexing to the wrapped module.

    Lets the proxy support subscript access (e.g. ``wrapper[0]``), which
    matters for ``torch.nn.Sequential`` and other indexable containers.

    Args:
        key: The index or key to use for accessing the wrapped module.

    Returns:
        The result of accessing the wrapped module with the given key.
    """
    wrapped = self.__wrapped__
    return wrapped[key]

activate

activate()

Activates the module backends.

Source code in aitune/torch/module/wrapper_module.py
def activate(self):
    """Activates the module backends."""
    state = self._self_state
    if state == ModuleState.PASSTHROUGH:
        # Nothing to activate in passthrough mode.
        return
    if state == ModuleState.TUNED:
        self._activate_wrapper()
        return
    raise RuntimeError("Module is not tuned. Cannot activate backends.")

deactivate

deactivate()

Deactivates the module backends.

Source code in aitune/torch/module/wrapper_module.py
def deactivate(self):
    """Deactivates the module backends."""
    if self._self_state != ModuleState.TUNED:
        # Only a tuned module has backends to deactivate.
        return
    cast(TunedModule, self._self_wrapper).deactivate()

enable_passthrough

enable_passthrough(force=False)

Enables passthrough mode.

Parameters:

  • force (bool, default: False ) –

    if True, force the module to be in the passthrough mode.

Source code in aitune/torch/module/wrapper_module.py
def enable_passthrough(self, force: bool = False):
    """Enables passthrough mode.

    In passthrough mode the wrapper behaves identically to the original module.

    Args:
        force: if True, force the module to be in the passthrough mode.

    Raises:
        RuntimeError: if the module is already tuned and ``force`` is False.
    """
    if self._self_state == ModuleState.TUNED:
        if not force:
            raise RuntimeError("Module is already tuned. Use force=True to reset tuned module.")
        self._reset()
        self.enable_passthrough()
        # Bug fix: the recursive call above has already installed a fresh
        # PassthroughModule; without this return, control fell through and
        # deactivated + re-created the wrapper a second time.
        return
    elif self._self_state == ModuleState.RECORDING:
        # Keep the current recording so a later enable_recording() can resume it.
        self._self_prev_recording = self._self_wrapper

    self._deactivate_wrapper()
    self._self_state = ModuleState.PASSTHROUGH
    self._self_wrapper = PassthroughModule(
        self.__wrapped__,
        self._self_device,
    )

enable_recording

enable_recording(force=False)

Enables recording mode.

Parameters:

  • force (bool, default: False ) –

    if True, force the module to be in the recording mode.

Source code in aitune/torch/module/wrapper_module.py
def enable_recording(self, force: bool = False):
    """Enables recording mode.

    Args:
        force: if True, force the module to be in the recording mode.
    """
    if self._self_state == ModuleState.TUNED:
        if not force:
            raise RuntimeError("Module is already tuned. Use force=True to reset tuned module.")
        self._reset()
        self.enable_recording()
        return
    if self._self_state not in (ModuleState.PASSTHROUGH, ModuleState.INIT):
        # Already recording: nothing to do.
        return

    self._deactivate_wrapper()
    self._self_state = ModuleState.RECORDING
    previous = self._self_prev_recording
    if previous is not None:
        # Resume the recording that was interrupted earlier.
        self._self_wrapper = previous
        self._self_prev_recording = None
    else:
        self._self_wrapper = RecordingModule(
            self.__wrapped__,
            self._self_name,
        )

from_dict staticmethod

from_dict(module, state_dict, device=None)

Create a wrapper module from a state dictionary.

Source code in aitune/torch/module/wrapper_module.py
@staticmethod
def from_dict(module: torch.nn.Module, state_dict: dict, device: torch.device | None = None):
    """Create a wrapper module from a state dictionary.

    Args:
        module: torch module the tuned state was produced for.
        state_dict: dictionary previously produced by ``to_dict``.
        device: optional device to deploy the tuned wrapper on.

    Returns:
        A ``Module`` wrapper restored to the TUNED state.

    Raises:
        ValueError: if ``state_dict`` is not a valid wrapper state dictionary.
    """
    if not Module.is_state_dict_valid(state_dict):
        # Bug fix: ``Module.__class__.__name__`` is the metaclass name ("type"),
        # not the class name; ``Module.__name__`` is the intended value.
        raise ValueError(f"Invalid dictionary format for {Module.__name__}")

    self = Module(module, state_dict[Module.NAME_KEY])
    self._self_state = ModuleState.TUNED
    self._self_wrapper = TunedModule.from_dict(module, state_dict[Module.TUNED_MODULE_KEY])
    self._deploy_wrapper(device)
    return self

is_state_dict_valid staticmethod

is_state_dict_valid(state_dict)

Check if the state dictionary has a wrapper module.

Source code in aitune/torch/module/wrapper_module.py
@staticmethod
def is_state_dict_valid(state_dict: dict):
    """Check if the state dictionary has a wrapper module.

    Args:
        state_dict: candidate state dictionary.

    Returns:
        bool: True if ``state_dict`` declares a ``Module`` wrapper entry.
    """
    # Robustness fix: use .get() so a dict that merely lacks TYPE_KEY is
    # reported as invalid instead of raising KeyError; bool() normalizes the
    # and-chain so callers always receive a proper boolean.
    return bool(state_dict) and Module.__name__ in state_dict.get(Module.TYPE_KEY, "")

to_dict

to_dict(*args, destination=None, prefix='', keep_vars=False)

Convert the wrapper module to a state dictionary.

The method follows torch.state_dict format so that it could be called recursively from torch.nn.Module.state_dict. The name of the function follows aitune convention of from/to dict. Special alias is created to match torch convention i.e. state_dict.

Source code in aitune/torch/module/wrapper_module.py
def to_dict(self, *args, destination=None, prefix="", keep_vars=False):
    """Serialize the wrapper module into a state dictionary.

    The signature mirrors torch's ``state_dict`` so this method can be invoked
    recursively from ``torch.nn.Module.state_dict``. The from/to-dict naming
    follows the aitune convention; a ``state_dict`` alias exists to match the
    torch convention.
    """
    if self._self_state != ModuleState.TUNED:
        raise RuntimeError("Module is not tuned. Cannot save state_dict.")
    if destination is None:
        destination = OrderedDict()

    entry = {
        self.TYPE_KEY: Module.__name__,  # self.__class__ returns wrapped class, use direct class object instead
        self.NAME_KEY: self._self_name,
        self.TUNED_MODULE_KEY: self._self_wrapper.to_dict(),  # pytype: disable=attribute-error
    }
    destination[prefix] = entry
    return destination

tune

tune(device=None, strategy=None, dry_run=False)

Tunes the module.

Parameters:

  • device (device | None, default: None ) –

    the device to use for tuning

  • strategy (TuneStrategy | None, default: None ) –

    the tuning strategy to use

  • dry_run (bool, default: False ) –

    if True, only dry run the tuning

Note

if module has already defined strategy/strategies those will take precedence over the provided one.

Source code in aitune/torch/module/wrapper_module.py
def tune(
    self,
    device: torch.device | None = None,
    strategy: TuneStrategy | None = None,
    dry_run: bool = False,
):
    """Tunes the module.

    Args:
        device: the device to use for tuning
        strategy: the tuning strategy to use
        dry_run: if True, only dry run the tuning

    Note:
        if module has already defined strategy/strategies those will take precedence over the provided one.

    Raises:
        ValueError: if the module is already tuned, or if it has not recorded
            any samples yet (INIT or PASSTHROUGH state).
    """
    if self._self_state == ModuleState.TUNED:
        raise ValueError(f"Module: '{self._self_name}' has already been tuned. Reset it to do it again.")
    elif self._self_state == ModuleState.INIT or self._self_state == ModuleState.PASSTHROUGH:
        raise ValueError(f"Module: '{self._self_name}' has not recorded any samples. Cannot tune it.")

    # Fall back to the device captured at construction time.
    device = device or self._self_device

    recording = cast(RecordingModule, self._self_wrapper)
    backends: OrderedDict[SampleMetadata, Backend] = OrderedDict()
    strategies = self._get_strategies_for_graph_specs(strategy, recording.graph_specs, dry_run)

    # NOTE: the loop variable shadows the ``strategy`` parameter; from here on
    # ``strategy`` is the per-graph strategy resolved above.
    for strategy, graph_spec in zip(strategies, recording.graph_specs, strict=True):
        cache_dir = self._create_graph_cache_dir(graph_spec)

        data = recording.samples_for_graph_spec(graph_spec)
        if dry_run:
            strategy.tune_dry_run(self.__wrapped__, self._self_name, graph_spec, data, device, cache_dir)
        else:
            try:
                # Run tuning against the module's original (unproxied) forward.
                self._restore_original_forward()
                backends[graph_spec.input_spec] = strategy.tune(
                    self.__wrapped__, self._self_name, graph_spec, data, device, cache_dir
                )
            finally:
                # Re-install the proxy forward even if tuning raised.
                self._proxy_forward()

    if not dry_run:
        # Keep the recording so it can be resumed after a reset, then switch
        # the wrapper to the tuned backends.
        self._self_prev_recording = recording
        self._self_state = ModuleState.TUNED
        self._self_wrapper = TunedModule(backends)
        self._offload(backends)