
Ahead-of-Time Tuning Guide

Ahead-of-time tuning is a mode where you explicitly control which modules to tune. This method provides precise control over the tuning process and is recommended for production environments.

Overview

Ahead-of-time tuning follows a four-step workflow:

  1. Inspect: Analyze your model or pipeline to identify tunable modules
  2. Wrap: Wrap selected modules for tuning
  3. Tune: Execute the tuning process across different backends
  4. Persist: Save and load tuned models for later deployment

This approach offers several advantages:

  • Control: Explicitly choose which modules to tune, pick strategies and backends, and mix different technologies
  • Performance: Benchmark and select optimal configurations
  • Speed: Save the tuned model as a deployable artifact and load it in the production environment without re-tuning
  • Reproducibility: Deterministic tuning results

Quick Start

Here's a complete example using Stable Diffusion:

import aitune.torch as ait
from diffusers import DiffusionPipeline

# Initialize pipeline
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers")
pipe.to("cuda")

# Prepare input data
input_data = [{"prompt": "A beautiful landscape with mountains and a lake"}]

# Step 1: Inspect pipeline to discover modules
modules_info = ait.inspect(pipe, input_data)

# Display discovered modules
modules_info.describe()

# Step 2: Wrap modules for tuning
modules = modules_info.get_modules()
pipe = ait.wrap(pipe, modules)

# Step 3: Tune the pipeline
ait.tune(pipe, input_data)

# Step 4: Save the tuned pipeline
ait.save(pipe, "tuned_pipe.ait")

# Use the tuned pipeline; diffusers pipelines return an output object whose .images holds the results
images = pipe(["A beautiful landscape with mountains and a lake"]).images

Detailed Workflow

1. Inspection Phase

The inspect function analyzes your model or pipeline to identify PyTorch modules that can be tuned. For a detailed guide on inspection, see the AOT Inspect Guide.

2. Wrapping Phase

Given the list of modules from the previous step, you can wrap them for tuning. Under the hood, each torch.nn.Module is wrapped in an AITune Module, which acts like a proxy object: it intercepts forward calls to collect input data, tunes the module, and then serves the tuned version.

The following line shows how to wrap modules.

model = ait.wrap(model, modules)
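To make the proxy idea concrete, here is a minimal plain-Python sketch. RecordingProxy is a hypothetical class, not AITune's implementation of ait.Module: it forwards calls to the wrapped callable and records their inputs until a replacement ("tuned") callable is set, after which it serves the replacement. Attribute access falls through to the wrapped object, as a proxy would.

```python
from typing import Any, Callable, List, Optional, Tuple


class RecordingProxy:
    """Hypothetical proxy: records calls until a tuned callable is set, then serves it.

    Illustration only; this is not how AITune implements ait.Module.
    """

    def __init__(self, wrapped: Callable[..., Any]) -> None:
        self._wrapped = wrapped                            # original callable
        self._tuned: Optional[Callable[..., Any]] = None   # replacement after "tuning"
        self.recorded: List[Tuple[tuple, dict]] = []       # captured (args, kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if self._tuned is not None:
            return self._tuned(*args, **kwargs)  # serve the tuned version
        self.recorded.append((args, kwargs))     # collect data for tuning
        return self._wrapped(*args, **kwargs)    # fall back to the original

    def set_tuned(self, tuned: Callable[..., Any]) -> None:
        self._tuned = tuned

    def __getattr__(self, name: str) -> Any:
        # Only called for attributes not found on the proxy: forward them to the wrapped object
        if name in ("_wrapped", "_tuned", "recorded"):
            raise AttributeError(name)
        return getattr(self._wrapped, name)


proxy = RecordingProxy(lambda x: 2 * x)
print(proxy(3), len(proxy.recorded))  # 6 1  (served by the original, call recorded)
proxy.set_tuned(lambda x: x + x)
print(proxy(4), len(proxy.recorded))  # 8 1  (served by the "tuned" callable)
```

Because callers keep using the same object, code that calls the module (for example, a pipeline calling its transformer) does not need to change when the tuned version takes over.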

You can also specify tuning strategies during wrapping:

import aitune.torch as ait

strategy = ait.OneBackendStrategy(backend=ait.backend.TensorRTBackend())
model = ait.wrap(model, modules, strategy=strategy)

If you want more control over which modules are wrapped, you can wrap torch.nn.Module instances manually with ait.Module. When wrapping manually, you can specify a strategy for each module separately, so a single model can combine different strategies and backends.

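import aitune.torch as ait
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers")
pipe.to("cuda")

# strategy_for_transformer and strategy_for_text_encoder are strategies created beforehand,
# e.g. with ait.OneBackendStrategy as shown above (SD3 pipelines have no `unet` component)
pipe.transformer = ait.Module(pipe.transformer, strategy=strategy_for_transformer)
pipe.text_encoder = ait.Module(pipe.text_encoder, strategy=strategy_for_text_encoder)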
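When several components are wrapped by hand, a small helper can keep the per-module choices in one mapping. wrap_attributes below is a hypothetical helper, not part of AITune; it replaces each named attribute with wrapper(original, strategy) and fails fast with AttributeError if a component name does not exist on the object.

```python
from typing import Any, Callable, Mapping


def wrap_attributes(
    obj: Any,
    plan: Mapping[str, Any],
    wrapper: Callable[[Any, Any], Any],
) -> Any:
    """Replace each attribute named in `plan` with wrapper(original, strategy)."""
    for attr, strategy in plan.items():
        original = getattr(obj, attr)  # raises AttributeError if the component is missing
        setattr(obj, attr, wrapper(original, strategy))
    return obj
```

With AITune, the wrapper would be something like `lambda module, strategy: ait.Module(module, strategy=strategy)`, and the plan a mapping from component names such as "transformer" to the strategy chosen for each.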

3. Tuning Phase

The tune function executes the actual tuning:

ait.tune(
    func=model,                         # The wrapped callable module or pipeline to tune
    dataset=input_data,                 # Dataset to use for tuning (list, Dataset, DataLoaderFactory, or Tensor)
    batch_sizes=[1, 4, 8],              # Optional: multiple batch sizes. Defaults to [1, 2]
    max_num_batches_per_batch_size=10,  # Max batches per size. Defaults to None (all)
    device="cuda",                      # Device for tuning. Defaults to "cuda:0"
    dry_run=False,                      # Set True to test without tuning
    disable_external_logging=False,     # Disable third-party logs
    clear_cache=False,                  # Clear AITune cache before tuning
)

Tuning Parameters

  • func: The wrapped callable (model or pipeline)
  • dataset: Dataset for tuning. It can be a list of samples, a torch.utils.data.Dataset, a DataLoaderFactory, a Tensor, or a sequence of tensors, dictionaries, or strings
  • batch_sizes: List of batch sizes to tune against. If not specified, values [1, 2] will be used
  • max_num_batches_per_batch_size: Maximum number of batches per batch size. If None, all batches will be used
  • device: Device to use for tuning. Defaults to "cuda:0"
  • dry_run: If True, performs a dry run without actual tuning
  • disable_external_logging: Disable logging from external libraries
  • clear_cache: Clear AITune cache before tuning
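As a rough mental model of how batch_sizes and max_num_batches_per_batch_size interact, the hypothetical plan_batches helper below groups samples into batches for each batch size and caps how many batches are kept. It assumes consecutive grouping and drops an incomplete trailing batch; AITune's actual batching may differ.

```python
from typing import Dict, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def plan_batches(
    samples: Sequence[T],
    batch_sizes: Sequence[int] = (1, 2),
    max_num_batches_per_batch_size: Optional[int] = None,
) -> Dict[int, List[List[T]]]:
    """Group consecutive samples into batches per batch size (illustrative sketch)."""
    plan: Dict[int, List[List[T]]] = {}
    for size in batch_sizes:
        # Consecutive chunks of `size` samples; an incomplete trailing chunk is dropped (assumption)
        batches = [
            list(samples[i : i + size])
            for i in range(0, len(samples) - size + 1, size)
        ]
        if max_num_batches_per_batch_size is not None:
            # Keep at most this many batches; None means "use all batches"
            batches = batches[:max_num_batches_per_batch_size]
        plan[size] = batches
    return plan
```

For example, with 10 samples, `plan_batches(samples, [1, 4], max_num_batches_per_batch_size=3)` keeps three single-sample batches and the two full batches of four.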

Tuning time depends on the size of the tuned modules, the strategy used, and the number of backends. Modules are tuned one at a time. When a strategy has several candidate backends, it selects the one that meets the strategy's criteria. Each backend is validated for correct numeric results (no NaN or infinite values) and correct output shapes.
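The selection and validation described above can be sketched in plain Python as a "first working backend" loop, the same idea as the First Wins Strategy in the dry-run log on this page (evaluate backends in order, return the first working one). In this toy version, backends are plain callables, tensors are flat lists of floats, and "shape" is just the list length; first_working_backend and outputs_are_valid are hypothetical names, not AITune's API, and real strategies may apply other criteria.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple

# Toy stand-ins: a "tensor" is a flat list of floats, a backend is a named callable
Output = List[float]
Backend = Tuple[str, Callable[[Output], Output]]


def outputs_are_valid(output: Output, expected_len: int) -> bool:
    """Check the 'shape' (here: the length) and reject NaN or infinite values."""
    return len(output) == expected_len and all(math.isfinite(v) for v in output)


def first_working_backend(
    backends: Sequence[Backend], sample: Output, expected_len: int
) -> Optional[str]:
    """Evaluate backends in order and return the name of the first valid one."""
    for name, run in backends:
        try:
            result = run(sample)
        except Exception:
            continue  # a backend that raises is treated as not working
        if outputs_are_valid(result, expected_len):
            return name
    return None  # no backend passed validation


candidates = [
    ("nan_backend", lambda x: [float("nan")] * len(x)),  # rejected: NaN values
    ("short_backend", lambda x: x[:-1]),                 # rejected: wrong shape
    ("eager_backend", lambda x: [2.0 * v for v in x]),   # accepted
]
print(first_working_backend(candidates, [1.0, 2.0, 3.0], expected_len=3))  # eager_backend
```

Ordering matters in this design: faster but more fragile backends go first, and a conservative fallback (like eager execution) goes last so that some backend always passes.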

Note: If you specify a batch size that is not a power of 2, it is still used to gather samples, but the search for the highest throughput rounds it up to the next power of 2.
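The rounding rule in the note is simple arithmetic. A minimal Python sketch, where round_up_to_power_of_two is a hypothetical helper and not part of AITune:

```python
def round_up_to_power_of_two(batch_size: int) -> int:
    """Return the smallest power of 2 that is >= batch_size (batch_size must be >= 1)."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    # (n - 1).bit_length() is the number of bits needed for n - 1, so 1 << it is >= n
    return 1 << (batch_size - 1).bit_length()


print([round_up_to_power_of_two(b) for b in [1, 3, 4, 6, 8, 12]])  # [1, 4, 4, 8, 8, 16]
```

So batch_sizes=[1, 3, 6] would gather samples at 1, 3, and 6, while the throughput search would consider 1, 4, and 8.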

4. Persistence Phase

Once tuned, you can save your model for later use. This is crucial for production deployments to avoid re-tuning every time:

# Save the tuned model/pipeline
ait.save(pipe, "tuned_model.ait")

The tuned artifact is saved in the checkpoints folder. The save function creates the following files:

  • tuned_model.ait: The compressed checkpoint containing tuned and original weights
  • tuned_model_sha256_sums.txt: SHA256 hashes for verification
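Before loading an artifact in production, you may want to check it against the SHA256 sums file. The sketch below assumes, without confirmation from this guide, that the file uses the common `sha256sum` line format (`<hex digest>  <file name>`) with file names relative to the sums file's directory; sha256_of and verify_sha256_sums are hypothetical helpers built only on Python's standard hashlib and pathlib.

```python
import hashlib
from pathlib import Path


def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Compute the SHA-256 hex digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_sha256_sums(sums_file: Path) -> bool:
    """Check every '<hex digest>  <file name>' line (assumed format) in sums_file."""
    base = sums_file.parent
    for line in sums_file.read_text().splitlines():
        if not line.strip():
            continue  # skip blank lines
        expected, name = line.split(maxsplit=1)
        target = base / name.lstrip("*")  # sha256sum marks binary mode with a leading '*'
        if not target.is_file() or sha256_of(target) != expected.lower():
            return False
    return True
```

For example, you could call `verify_sha256_sums(Path("tuned_model_sha256_sums.txt"))` before `ait.load` and refuse to start if it returns False. If the file does follow the `sha256sum` format, `sha256sum -c tuned_model_sha256_sums.txt` performs the same check from the shell.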

For inference, load the tuned model or pipeline:

import aitune.torch as ait
from diffusers import DiffusionPipeline

# Note: the original object must be initialized before loading the tuned artifact
pipe = DiffusionPipeline.from_pretrained(...)
pipe = ait.load(pipe, "tuned_model.ait")
# pipe is ready for use

Custom Inference Functions

For complex pipelines, you can provide a custom inference function:

def custom_inference(prompt, num_steps=50):
    """Fix height and width at 1024; the number of steps defaults to 50."""
    return pipe(
        prompt=prompt,
        num_inference_steps=num_steps,
        height=1024,
        width=1024,
    )

modules_info = ait.inspect(
    pipe,
    input_data,
    inference_function=custom_inference
)
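To see how dictionary samples can line up with a custom function's parameters, the sketch below replaces the real pipeline with a hypothetical stub_pipe that simply echoes its arguments. It assumes, without confirmation here, that each dictionary sample in input_data is passed to the inference function as keyword arguments (custom_inference(**sample)); under that assumption, every key in a sample must match a parameter name of the function.

```python
from typing import Any, Dict


def stub_pipe(**kwargs: Any) -> Dict[str, Any]:
    """Hypothetical stand-in for the pipeline: returns the arguments it was called with."""
    return kwargs


def custom_inference(prompt: str, num_steps: int = 50) -> Dict[str, Any]:
    """Same structure as the example above: height and width fixed at 1024, steps default to 50."""
    return stub_pipe(
        prompt=prompt,
        num_inference_steps=num_steps,
        height=1024,
        width=1024,
    )


input_data = [{"prompt": "A beautiful landscape with mountains and a lake"}]
for sample in input_data:
    # Assumption: each dict sample is unpacked into keyword arguments
    print(custom_inference(**sample)["height"])  # 1024
```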

Configuration Options

AITune has a global configuration for the tuning process, and each backend has its own configuration.

Global Configuration

You can configure AITune globally:

from aitune.torch import config

# Set cache directory
config.cache_dir = "/path/to/cache"

# Set minimum samples for tuning
config.min_num_samples = 5

# Set maximum stored samples per graph
config.max_num_samples_stored = 100

# Device to move model after tuning
config.device_after_tuning = "cuda"

# Enable/disable strict mode for input validation
config.strict_mode = True

# Enable HuggingFace integrations
config.enable_hf_integrations = True

Backend-Specific Configuration

Each backend has its own corresponding configuration:

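from aitune.torch.backend import TensorRTBackendConfig, TensorRTBackend

# Named trt_config so it does not shadow the global `config` imported above
trt_config = TensorRTBackendConfig(
    use_cuda_graphs=True,
    workspace_size=1 << 30,  # 1 GiB (2**30 bytes)
)
backend = TensorRTBackend(trt_config)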

See the backend-specific documentation for all available options.

Dry Run Mode

You can run tuning in dry-run mode. In this mode, AITune records data samples, detects batch and dynamic axes, and detects execution graphs, but does not call any backend to tune. This lets you check that everything works as expected before running the actual tuning.

Dry-run mode is enabled with the dry_run argument:

import logging

# make sure logging is configured so the dry-run report is printed
logging.basicConfig(level=logging.INFO, force=True)
# invoke dry-run tuning
ait.tune(pipe, input_data, dry_run=True)

Example output from dry-run

2026-01-26 16:23:44,360 - INFO - ════════════════════════════════════════════════════════════════
2026-01-26 16:23:44,360 - INFO - 🎯 Tuning module: `transformer` (all graphs)
2026-01-26 16:23:44,367 - INFO - ------------------------------------------------------------
2026-01-26 16:23:44,367 - INFO - 🚀 Tuning graph `0` for module `transformer` (DRY RUN):
2026-01-26 16:23:44,368 - INFO -   number of parameters: 2028328000
2026-01-26 16:23:44,368 - INFO -   number of layers: 6
2026-01-26 16:23:44,369 - INFO -   precisions: torch.float16
2026-01-26 16:23:44,369 - INFO -   graph_spec:
2026-01-26 16:23:44,369 - INFO -     input_spec:
 Tensors:
╒═══════════════════════════╤══════════════════════════════╤══════════════════════════╤═══════════════════╤═══════════════════╤═══════════════╕
│ Locator                   │ Name                         │ Shape                    │ Min Shape         │ Max Shape         │ Dtype         │
╞═══════════════════════════╪══════════════════════════════╪══════════════════════════╪═══════════════════╪═══════════════════╪═══════════════╡
│ ['encoder_hidden_states'] │ kwargs_encoder_hidden_states │ ['batch0', 333, 4096]    │ [2, 333, 4096]    │ [4, 333, 4096]    │ torch.float16 │
├───────────────────────────┼──────────────────────────────┼──────────────────────────┼───────────────────┼───────────────────┼───────────────┤
│ ['hidden_states']         │ kwargs_hidden_states         │ ['batch0', 16, 128, 128] │ [2, 16, 128, 128] │ [4, 16, 128, 128] │ torch.float16 │
├───────────────────────────┼──────────────────────────────┼──────────────────────────┼───────────────────┼───────────────────┼───────────────┤
│ ['pooled_projections']    │ kwargs_pooled_projections    │ ['batch0', 2048]         │ [2, 2048]         │ [4, 2048]         │ torch.float16 │
├───────────────────────────┼──────────────────────────────┼──────────────────────────┼───────────────────┼───────────────────┼───────────────┤
│ ['timestep']              │ kwargs_timestep              │ ['batch0']               │ [2]               │ [4]               │ torch.float32 │
╘═══════════════════════════╧══════════════════════════════╧══════════════════════════╧═══════════════════╧═══════════════════╧═══════════════╛
Other:
╒════════════════════════════╤═══════════════════════════════╤═════════╕
│ Locator                    │ Name                          │ Value   │
╞════════════════════════════╪═══════════════════════════════╪═════════╡
│ ['joint_attention_kwargs'] │ kwargs_joint_attention_kwargs │ None    │
├────────────────────────────┼───────────────────────────────┼─────────┤
│ ['return_dict']            │ kwargs_return_dict            │ False   │
╘════════════════════════════╧═══════════════════════════════╧═════════╛

2026-01-26 16:23:44,370 - INFO -     output_spec:
 Tensors:
╒═══════════╤═══════════╤══════════════════════════╤═══════════════════╤═══════════════════╤═══════════════╕
│ Locator   │ Name      │ Shape                    │ Min Shape         │ Max Shape         │ Dtype         │
╞═══════════╪═══════════╪══════════════════════════╪═══════════════════╪═══════════════════╪═══════════════╡
│ [0]       │ outputs_0 │ ['batch0', 16, 128, 128] │ [2, 16, 128, 128] │ [4, 16, 128, 128] │ torch.float16 │
╘═══════════╧═══════════╧══════════════════════════╧═══════════════════╧═══════════════════╧═══════════════╛

2026-01-26 16:23:44,370 - INFO -   num samples: 1
2026-01-26 16:23:44,370 - INFO -   device: cuda:0
2026-01-26 16:23:44,370 - INFO -   cache_dir: /home/pbazan/.cache/aitune/transformer/0
2026-01-26 16:23:44,371 - INFO -   strategy:
2026-01-26 16:23:44,371 - INFO -     name: First Wins Strategy
2026-01-26 16:23:44,371 - INFO -     description: evaluate backends in order, return first working backend
2026-01-26 16:23:44,371 - INFO -     backends:
2026-01-26 16:23:44,371 - INFO -       TensorRTBackend(quantization_config=None)
2026-01-26 16:23:44,371 - INFO -       TorchInductorBackend()
2026-01-26 16:23:44,372 - INFO -       TorchEagerBackend()
2026-01-26 16:23:44,372 - INFO - ✅ Tuning module: `transformer` (all graphs) completed.
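The Min Shape and Max Shape columns in the report can be read as per-dimension ranges over the recorded samples, and a dimension whose minimum and maximum differ (such as `batch0` above) is a dynamic axis. A plain-Python sketch of that bookkeeping, using a hypothetical shape_ranges helper; this illustrates the idea and is not how AITune detects axes internally.

```python
from typing import List, Sequence, Tuple


def shape_ranges(shapes: Sequence[Sequence[int]]) -> Tuple[List[int], List[int], List[int]]:
    """Per-dimension min/max over recorded shapes, plus the indices of dimensions that vary."""
    if not shapes:
        raise ValueError("need at least one recorded shape")
    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        raise ValueError("all recorded shapes must have the same rank")
    min_shape = [min(s[d] for s in shapes) for d in range(rank)]
    max_shape = [max(s[d] for s in shapes) for d in range(rank)]
    dynamic_axes = [d for d in range(rank) if min_shape[d] != max_shape[d]]
    return min_shape, max_shape, dynamic_axes


# Shapes matching the hidden_states Min Shape and Max Shape columns above
recorded = [[2, 16, 128, 128], [4, 16, 128, 128]]
print(shape_ranges(recorded))  # ([2, 16, 128, 128], [4, 16, 128, 128], [0])
```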

Next Steps