Module(module, name=None, strategy=None, strategies=None)
Bases: CallableObjectProxy
AITune module wrapper.
This class wraps a torch module and provides tuning functionality.
The module can be in 3 different states:
- passthrough: the module is not tuned, and will behave identically to the original module.
- recording: the module records samples and detects multi-graphs, this is a necessary step before tuning.
- tuned: the module is tuned and uses underlying tuned module.
You can switch between the passthrough and recording states with the enable_passthrough and enable_recording methods.
You can go from tuned to passthrough/recording only if the force flag is True - then the module will be reset.
This wrapper can be used in place of a torch module, and will behave identically to the original module.
Example
model = torch.nn.Linear(10, 10)
model = Module(model, "name")
Initializes module.
Parameters:
-
module
(Module)
–
-
name
(str | None, default:
None
)
–
module name which differentiates it from other tuned modules.
-
strategy
(TuneStrategy | None, default:
None
)
–
optional strategy, which will be used for all encountered graphs
-
strategies
(StrategyList | StrategyMap | None, default:
None
)
–
list or dict of strategies, each for a graph
Source code in aitune/torch/module/wrapper_module.py
| def __init__(
self,
module: torch.nn.Module,
name: str | None = None,
strategy: TuneStrategy | None = None,
strategies: StrategyList | StrategyMap | None = None,
):
"""Initializes module.

Wraps ``module`` in the proxy, swaps its ``forward`` for a proxied version,
and registers the wrapper under its resolved name in the module registry.

Args:
module: torch module to wrap.
name: module name which differentiates it from other tuned modules.
strategy: optional strategy, which will be used for all encountered graphs
strategies: list or dict of strategies, each for a graph
"""
super().__init__(module)
# Resolved name: sanitized caller-supplied name, else derived from the module object.
self._self_name = sanitize_model_name(name) or get_object_name(module)
# Keep a reference to the original forward so it can be restored later (see tune()).
self._self_orig_forward = module.forward
# NOTE(review): attributes without the wrapt "_self_" prefix are stored on the
# wrapped module rather than on this proxy — confirm that is intended for the
# two hook references below and for _system_monitor.
self._original_forward_pre_hooks = module._forward_pre_hooks
self._original_forward_hooks = module._forward_hooks
# Route every forward() call through self._forward via a wrapt decorator.
self._self_proxy_forward = wrapt.decorator(self._forward)(module.forward)
module.forward = self._self_proxy_forward
self._setup_strategies(strategy, strategies)
# Fall back to CPU when get_module_device cannot determine a device.
self._self_device = get_module_device(module) or torch.device("cpu")
self._self_state = ModuleState.INIT
self._self_wrapper = None
self._self_prev_recording = None
self._system_monitor = SystemMonitor()
MODULE_REGISTRY.register(self._self_name, self)
|
device
property
Get the device of the module.
Returns:
-
device
–
The device the module is using.
graph_specs
property
Multi-graphs of the module.
module
property
Get the underlying wrapped module.
name
property
Get the name of the module.
state
property
Get the state of the module.
__getitem__
Delegate getitem calls to the wrapped module.
This allows the proxy to handle indexing operations on the wrapped module,
which is particularly useful for Sequential modules and other indexable modules.
Parameters:
-
key
(Any)
–
The index or key to use for accessing the wrapped module.
Returns:
-
Any
–
The result of accessing the wrapped module with the given key.
Source code in aitune/torch/module/wrapper_module.py
def __getitem__(self, key: Any) -> Any:
    """Delegate indexing to the wrapped module.

    Lets the proxy support subscripting (e.g. ``Sequential`` layers accessed
    by position) by forwarding the operation to the wrapped module.

    Args:
        key: The index or key to use for accessing the wrapped module.

    Returns:
        The result of accessing the wrapped module with the given key.
    """
    wrapped_module = self.__wrapped__
    return wrapped_module[key]
activate
Activates the module backends.
Source code in aitune/torch/module/wrapper_module.py
def activate(self):
    """Activates the module backends.

    Raises:
        RuntimeError: if the module is neither tuned nor in passthrough mode.
    """
    state = self._self_state
    if state == ModuleState.PASSTHROUGH:
        # Nothing to activate when the original module runs untouched.
        return
    if state == ModuleState.TUNED:
        self._activate_wrapper()
        return
    raise RuntimeError("Module is not tuned. Cannot activate backends.")
deactivate
Deactivates the module backends.
Source code in aitune/torch/module/wrapper_module.py
def deactivate(self):
    """Deactivates the module backends.

    No-op unless the module is in the TUNED state.
    """
    if self._self_state != ModuleState.TUNED:
        return
    tuned = cast(TunedModule, self._self_wrapper)
    tuned.deactivate()
enable_passthrough
enable_passthrough(force=False)
Enables passthrough mode.
Parameters:
-
force
(bool, default:
False
)
–
if True, force the module to be in the passthrough mode.
Source code in aitune/torch/module/wrapper_module.py
| def enable_passthrough(self, force: bool = False):
"""Enables passthrough mode.

Args:
force: if True, force the module to be in the passthrough mode.

Raises:
RuntimeError: if the module is already tuned and ``force`` is False.
"""
if self._self_state == ModuleState.TUNED:
if force:
# Drop the tuned wrapper, then re-enter to take the non-tuned path.
self._reset()
self.enable_passthrough()
else:
raise RuntimeError("Module is already tuned. Use force=True to reset tuned module.")
elif self._self_state == ModuleState.RECORDING:
# Stash the current recording so a later enable_recording() can resume it.
self._self_prev_recording = self._self_wrapper
self._deactivate_wrapper()
self._self_state = ModuleState.PASSTHROUGH
self._self_wrapper = PassthroughModule(
self.__wrapped__,
self._self_device,
)
|
enable_recording
enable_recording(force=False)
Enables recording mode.
Parameters:
-
force
(bool, default:
False
)
–
if True, force the module to be in the recording mode.
Source code in aitune/torch/module/wrapper_module.py
| def enable_recording(self, force: bool = False):
"""Enables recording mode.

Args:
force: if True, force the module to be in the recording mode.

Raises:
RuntimeError: if the module is already tuned and ``force`` is False.
"""
if self._self_state == ModuleState.TUNED:
if force:
# Drop the tuned wrapper, then re-enter to take the non-tuned path.
self._reset()
self.enable_recording()
else:
raise RuntimeError("Module is already tuned. Use force=True to reset tuned module.")
elif self._self_state in [ModuleState.PASSTHROUGH, ModuleState.INIT]:
self._deactivate_wrapper()
self._self_state = ModuleState.RECORDING
if self._self_prev_recording is None:
self._self_wrapper = RecordingModule(
self.__wrapped__,
self._self_name,
)
else:
# continue with previous recording
self._self_wrapper = self._self_prev_recording
self._self_prev_recording = None
|
from_dict
staticmethod
from_dict(module, state_dict, device=None)
Create a wrapper module from a state dictionary.
Source code in aitune/torch/module/wrapper_module.py
@staticmethod
def from_dict(module: torch.nn.Module, state_dict: dict, device: torch.device | None = None):
    """Create a wrapper module from a state dictionary.

    Args:
        module: the torch module the stored tuned state belongs to.
        state_dict: dictionary previously produced by ``to_dict``.
        device: optional device to deploy the tuned wrapper on.

    Returns:
        A ``Module`` wrapper restored to the TUNED state.

    Raises:
        ValueError: if ``state_dict`` is not a valid ``Module`` state dictionary.
    """
    if not Module.is_state_dict_valid(state_dict):
        # BUG FIX: Module.__class__.__name__ is the metaclass name ("type"),
        # not "Module" — use Module.__name__ for the intended message,
        # consistent with the TYPE_KEY written by to_dict.
        raise ValueError(f"Invalid dictionary format for {Module.__name__}")
    self = Module(module, state_dict[Module.NAME_KEY])
    self._self_state = ModuleState.TUNED
    self._self_wrapper = TunedModule.from_dict(module, state_dict[Module.TUNED_MODULE_KEY])
    self._deploy_wrapper(device)
    return self
is_state_dict_valid
staticmethod
is_state_dict_valid(state_dict)
Check if the state dictionary has a wrapper module.
Source code in aitune/torch/module/wrapper_module.py
@staticmethod
def is_state_dict_valid(state_dict: dict) -> bool:
    """Check if the state dictionary has a wrapper module.

    Args:
        state_dict: candidate state dictionary.

    Returns:
        True if ``state_dict`` is non-empty and its type entry references this
        wrapper class, False otherwise.
    """
    if not state_dict:
        return False
    # Robustness: a predicate should report "invalid" rather than raise, so use
    # .get to tolerate dictionaries that lack TYPE_KEY entirely.
    return Module.__name__ in state_dict.get(Module.TYPE_KEY, "")
to_dict
to_dict(*args, destination=None, prefix='', keep_vars=False)
Convert the wrapper module to a state dictionary.
The method follows torch.state_dict format so that it could be called recursively from torch.nn.Module.state_dict.
The name of the function follows aitune convention of from/to dict. Special alias is created to match torch
convention i.e. state_dict.
Source code in aitune/torch/module/wrapper_module.py
def to_dict(self, *args, destination=None, prefix="", keep_vars=False):
    """Convert the wrapper module to a state dictionary.

    The method follows torch.state_dict format so that it could be called recursively from torch.nn.Module.state_dict.
    The name of the function follows aitune convention of from/to dict. Special alias is created to match torch
    convention i.e. state_dict.
    """
    if self._self_state != ModuleState.TUNED:
        raise RuntimeError("Module is not tuned. Cannot save state_dict.")
    if destination is None:
        destination = OrderedDict()
    entry = {
        self.TYPE_KEY: Module.__name__,  # self.__class__ returns wrapped class, use direct class object instead
        self.NAME_KEY: self._self_name,
        self.TUNED_MODULE_KEY: self._self_wrapper.to_dict(),  # pytype: disable=attribute-error
    }
    destination[prefix] = entry
    return destination
tune
tune(device=None, strategy=None, dry_run=False)
Tunes the module.
Parameters:
-
device
(device | None, default:
None
)
–
the device to use for tuning
-
strategy
(TuneStrategy | None, default:
None
)
–
the tuning strategy to use
-
dry_run
(bool, default:
False
)
–
if True, only dry run the tuning
Note
if module has already defined strategy/strategies those will take precedence over the provided one.
Source code in aitune/torch/module/wrapper_module.py
| def tune(
self,
device: torch.device | None = None,
strategy: TuneStrategy | None = None,
dry_run: bool = False,
):
"""Tunes the module.

Args:
device: the device to use for tuning
strategy: the tuning strategy to use
dry_run: if True, only dry run the tuning

Raises:
ValueError: if the module is already tuned, or has recorded no samples.

Note:
if module has already defined strategy/strategies those will take precedence over the provided one.
"""
if self._self_state == ModuleState.TUNED:
raise ValueError(f"Module: '{self._self_name}' has already been tuned. Reset it to do it again.")
elif self._self_state == ModuleState.INIT or self._self_state == ModuleState.PASSTHROUGH:
raise ValueError(f"Module: '{self._self_name}' has not recorded any samples. Cannot tune it.")
device = device or self._self_device
recording = cast(RecordingModule, self._self_wrapper)
backends: OrderedDict[SampleMetadata, Backend] = OrderedDict()
# One strategy per recorded graph spec; module-level strategies win over the argument.
strategies = self._get_strategies_for_graph_specs(strategy, recording.graph_specs, dry_run)
# NOTE(review): the loop variable below shadows the `strategy` parameter.
for strategy, graph_spec in zip(strategies, recording.graph_specs, strict=True):
cache_dir = self._create_graph_cache_dir(graph_spec)
data = recording.samples_for_graph_spec(graph_spec)
if dry_run:
strategy.tune_dry_run(self.__wrapped__, self._self_name, graph_spec, data, device, cache_dir)
else:
try:
# Tune against the original forward; the proxy forward is restored in
# `finally` even if tuning raises.
self._restore_original_forward()
backends[graph_spec.input_spec] = strategy.tune(
self.__wrapped__, self._self_name, graph_spec, data, device, cache_dir
)
finally:
self._proxy_forward()
if not dry_run:
# Keep the recording so it can be resumed after a forced reset.
self._self_prev_recording = recording
self._self_state = ModuleState.TUNED
self._self_wrapper = TunedModule(backends)
self._offload(backends)
|