Skip to content

AOT Inspect API

inspect

aitune.torch.inspect

inspect(obj, dataset, inference_function=None, number_of_iterations=DEFAULT_INSPECT_ITERATIONS, warmup_iterations=DEFAULT_WARMUP_ITERATIONS, min_depth=0, max_depth=5)

Inspect the provided callable object, searching for nn.Module members executed as part of the forward pass.

Parameters:

  • obj (Callable | Module) –

    Callable object to inspect.

  • dataset (DatasetLike | DataLoaderFactory | Tensor) –

    Input data used to run inference during inspection: a dataset-like object, a data loader factory, or a tensor.

  • inference_function (Callable | None, default: None ) –

    Custom inference function to use for inspection, obj is used by default.

  • number_of_iterations (int, default: DEFAULT_INSPECT_ITERATIONS ) –

    Number of iterations to run for inference.

  • warmup_iterations (int, default: DEFAULT_WARMUP_ITERATIONS ) –

    Number of iterations to run for warmup.

  • min_depth (int, default: 0 ) –

    Minimum depth of the modules to inspect. If inspecting root-level modules does not work, try increasing this value.

  • max_depth (int, default: 5 ) –

    Maximum depth of the modules to inspect

Returns: InspectedModulesInfo object.

Source code in aitune/torch/inspecting/inspecting.py
def inspect(
    obj: Callable | torch.nn.Module,
    dataset: DatasetLike | DataLoaderFactory | torch.Tensor,
    inference_function: Callable | None = None,
    number_of_iterations: int = DEFAULT_INSPECT_ITERATIONS,
    warmup_iterations: int = DEFAULT_WARMUP_ITERATIONS,
    min_depth: int = 0,
    max_depth: int = 5,
) -> InspectedModulesInfo:
    """Inspect the provided callable object, searching for nn.Module members executed as part of the forward pass.

    After a warmup phase, up to ``number_of_iterations`` inference calls are executed under
    ``torch.no_grad()`` and their total elapsed time (``time.perf_counter``) is measured. The modules
    recorded by the inspector are returned together with that total time.

    Args:
        obj: Callable object to inspect.
        dataset: Input data used to run inference during inspection: a dataset-like object,
            a data loader factory, or a tensor.
        inference_function: Custom inference function to use for inspection, obj is used by default.
        number_of_iterations: Number of iterations to run for inference.
        warmup_iterations: Number of iterations to run for warmup.
        min_depth: Minimum depth of the modules to inspect. If inspecting root-level modules
            does not work, try increasing this value.
        max_depth: Maximum depth of the modules to inspect.

    Returns:
        InspectedModulesInfo object.

    Raises:
        ValueError: Propagated from ``InspectedModulesInfo.add_module`` if two inspected modules
            share the same object path.
    """
    setup_logging(format_string=LOG_FORMAT)

    logger.info("Inspecting object searching for executed nn.Module members.")
    model_inspector = ModuleInspector(min_depth=min_depth, max_depth=max_depth)
    model_inspector.inspect(obj)

    # If no inference function is provided, use the obj
    if inference_function is None:
        inference_function = obj

    try:
        dataset = ensure_enough_samples(dataset, max(number_of_iterations, warmup_iterations))

        # Warmup, run the model on a few samples, to make sure the model is compiled and the cache is warm
        _warmup(inference_function, dataset, warmup_iterations)
        # Presumably clears per-module timings accumulated during warmup, so only the loop below is reported.
        _reset_total_execution_time(model_inspector)

        # Run the model on the dataset for the number of iterations
        total_execution_time = 0.0
        for _, args, kwargs in samples_generator(dataset, [1], max_num_batches_per_batch_size=number_of_iterations):
            # Synchronize before and after the call (presumably a device sync) so that queued
            # asynchronous work is attributed to this iteration's measurement.
            synchronize()
            start_time = time.perf_counter()
            with torch.no_grad():
                inference_function(*args, **kwargs)
            synchronize()
            end_time = time.perf_counter()
            total_execution_time += end_time - start_time

        modules = model_inspector.get_modules()

        logger.info("Inspection done. Found %d candidate modules for tuning.", len(modules))

        inspected_modules_info = InspectedModulesInfo(total_execution_time, number_of_iterations)
        for module in modules:
            inspected_modules_info.add_module(module)
    finally:
        # Reset the inspector even when warmup, inference, or registration raises, so whatever
        # `model_inspector.inspect(obj)` attached (presumably forward hooks — verify in
        # ModuleInspector) does not outlive this call.
        model_inspector.reset()

    return inspected_modules_info

wrap

aitune.torch.wrap

wrap(obj, modules, strategy=None, strategies=None)

Wrap provided modules with inspection logic.

Parameters:

  • obj (object) –

    Callable object to wrap.

  • modules (list[ModuleInfo]) –

    List of ModuleInfo objects describing the modules to wrap.

  • strategy (TuneStrategy | None, default: None ) –

    Strategy to use for patching.

  • strategies (StrategyList | StrategyMap | None, default: None ) –

    Strategies to use for patching.

Returns:

  • object

    Wrapped callable object.

Source code in aitune/torch/inspecting/wrapping.py
def wrap(
    obj: object,
    modules: list[ModuleInfo],
    strategy: TuneStrategy | None = None,
    strategies: StrategyList | StrategyMap | None = None,
) -> object:
    """Wrap the provided modules in ``Module`` wrappers configured with the given strategy/strategies.

    ``MODULE_REGISTRY`` is cleared before any wrapping happens. Modules are processed in list
    order: each entry with a parent is wrapped and installed on that parent via
    ``parent.set_wrapped``. The first entry without a parent causes ``obj`` itself to be wrapped
    and returned immediately. Entries after it are not processed, but entries before it remain
    wrapped.

    Args:
        obj: Callable object to wrap.
        modules: List of ModuleInfo objects describing the modules to wrap.
        strategy: Strategy to use for patching.
        strategies: Strategies to use for patching.

    Returns:
        A new ``Module`` wrapping ``obj`` if a parentless entry is encountered. Otherwise ``obj``
        itself, after ``set_wrapped`` has been called on each listed module's parent.
    """
    setup_logging(format_string=LOG_FORMAT)
    MODULE_REGISTRY.clear()
    for module_info in modules:
        # Fall back through object path -> module name -> class name so the log line is never empty.
        logger.info("Wrapping module: %s", module_info.object_path or module_info.name or obj.__class__.__name__)
        if module_info.parent is None:
            # A parentless entry stands for the object itself: wrap it whole and stop here.
            return Module(obj, name=obj.__class__.__name__, strategy=strategy, strategies=strategies)

        ait_module = Module(module_info.module, name=module_info.name, strategy=strategy, strategies=strategies)
        module_info.parent.set_wrapped(module_info.name, ait_module)

    return obj

InspectedModulesInfo

aitune.torch.inspecting.InspectedModulesInfo

InspectedModulesInfo(total_execution_time, number_of_batches)

Information about inspected modules.

Initialize the inspected modules specification.

Source code in aitune/torch/inspecting/module_info.py
def __init__(self, total_execution_time: float, number_of_batches: int):
    """Create an empty container for inspection results.

    Args:
        total_execution_time: Total measured execution time (reported in seconds by ``describe``).
        number_of_batches: Number of batch iterations the time was measured over.
    """
    self._total_execution_time = total_execution_time
    self._number_of_batches = number_of_batches
    # Registered modules keyed by their object path; filled in by add_module.
    self._modules: dict[str, ModuleInfo] = {}

add_module

add_module(module)

Add a module to the specification.

Parameters:

  • module (ModuleInfo) –

    ModuleInfo object.

Source code in aitune/torch/inspecting/module_info.py
def add_module(self, module: ModuleInfo):
    """Register a module in this specification, keyed by its object path.

    Args:
        module: ModuleInfo object to register.

    Raises:
        ValueError: If a module with the same object path is already registered.
    """
    path = module.object_path
    if path not in self._modules:
        self._modules[path] = module
        return
    raise ValueError(f"Module `{module.name}` in `{module.object_path}` already exists")

describe

describe()

Describe the inspected modules specification.

Source code in aitune/torch/inspecting/module_info.py
def describe(self) -> None:
    """Print a table summarizing the execution statistics of the inspected modules.

    One row is printed per registered module, in registration order. The table is followed by
    the total measured execution time and the number of batch iterations. A module's share of
    the total is shown as 0 when the total execution time is not positive.
    """
    print("Module Execution Summary:")  # noqa: T201
    print("=" * 138)  # noqa: T201
    # Header widths must match the row format below. The "% of Total" column is 11 wide because
    # each row renders a 10-wide number followed by a literal "%".
    print(  # noqa: T201
        f"{'Module Name':^20}  {'Calls':^8}  {'Total Time (s)':^15}  {'Avg Time (s)':^15}  {'% of Total':^11}  {'# of params':^15}  {'# of layers':^15}  {'precisions':^25}"
    )
    print("-" * 138)  # noqa: T201

    for info in self._modules.values():
        percentage = (
            (info.total_execution_time / self._total_execution_time) * 100 if self._total_execution_time > 0 else 0
        )

        precisions = ", ".join(str(p) for p in info.precisions)
        # A module's name may be None (get_modules sorts it as ""). Formatting None with a width
        # spec raises TypeError, so substitute an empty string.
        name = info.name or ""

        print(  # noqa: T201
            f"{name:<20}  "
            f"{info.execution_count:>8}  "
            f"{info.total_execution_time:>15.4f}  "
            f"{info.average_execution_time:>15.4f}  "
            f"{percentage:>10.2f}%  "
            f"{info.num_parameters:>15}  "
            f"{info.num_layers:>15}  "
            f"{precisions:>25}"
        )

    print("-" * 138)  # noqa: T201
    print(f"Total execution time: {self._total_execution_time:.6f} seconds")  # noqa: T201
    print(f"Number of batch iterations: {self._number_of_batches}")  # noqa: T201
    print("=" * 138)  # noqa: T201

get_modules

get_modules(min_execution_percentage=None, limit=None)

Get the list of modules.

Parameters:

  • min_execution_percentage (float | None, default: None ) –

    Minimum share of the total execution time a module must account for to be included, expressed as a fraction (e.g. 0.05 for 5%).

  • limit (int | None, default: None ) –

    Maximum number of modules to return.

Returns:

  • list[ModuleInfo]

    List of ModuleInfo objects.

Source code in aitune/torch/inspecting/module_info.py
def get_modules(
    self, min_execution_percentage: float | None = None, limit: int | None = None
) -> list["ModuleInfo"]:
    """Get the list of modules.

    Args:
        min_execution_percentage: Minimum execution percentage to include a module.
        limit: Maximum number of modules to return.

    Returns:
        List of ModuleInfo objects.
    """
    modules = []
    sorted_modules = sorted(
        self._modules.values(),
        key=lambda x: (-x.total_execution_time, x.name or ""),  # Negative for reverse=True, empty string for None
    )
    for module in sorted_modules:
        if min_execution_percentage is None or (
            module.total_execution_time / self._total_execution_time >= min_execution_percentage
        ):
            modules.append(module)

    if limit is not None:
        modules = modules[:limit]

    return modules