Execution Graphs
Overview
During the tuning process, AITune analyzes module inputs to detect unique execution graphs—distinct computational paths through your model based on input characteristics. Understanding execution graphs is crucial because:
- Separate Optimization: Each graph is tuned independently with its own optimized backend
- Dynamic Shape Support: Graphs capture relationships between batch sizes, dynamic dimensions, and static shapes
- Input Routing: At inference time, inputs are automatically routed to the correct optimized graph
What Defines an Execution Graph?
AITune creates a new execution graph when it encounters inputs that differ in these ways:
- Tensor Rank: Tensors with different numbers of dimensions
module(torch.randn(1, 10)) # Graph 0: rank-2 tensor
module(torch.randn(1, 10, 5)) # Graph 1: rank-3 tensor (different graph!)
- Argument Structure: Different combinations of positional and keyword arguments
module(torch.randn(1, 10)) # Graph 0
module(torch.randn(1, 10), mask=True) # Graph 1: additional kwarg
- Non-Tensor Arguments (in strict mode): Different primitive values or configurations
Important: Tensors with the same rank but different shapes belong to the same graph. AITune handles shape variations through dynamic shape tracking (batch axes and dynamic dimensions).
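To make this grouping rule concrete, the sketch below builds a toy graph key from tensor ranks and keyword-argument names. This is a hypothetical illustration (toy_graph_key is not part of the AITune API; the real detection compares full metadata), but the effect is the same: same-rank inputs with different shapes share a graph, while a new rank or a new argument structure creates a new one.

```python
import torch

def toy_graph_key(args, kwargs):
    # Hypothetical sketch: tensor ranks plus kwarg names. Shapes are
    # deliberately excluded, so same-rank tensors map to the same graph.
    ranks = tuple(a.dim() for a in args if isinstance(a, torch.Tensor))
    return (ranks, tuple(sorted(kwargs)))

k1 = toy_graph_key((torch.randn(1, 10),), {})
k2 = toy_graph_key((torch.randn(8, 10),), {})              # same rank, new shape
k3 = toy_graph_key((torch.randn(1, 10, 5),), {})           # different rank
k4 = toy_graph_key((torch.randn(1, 10),), {"mask": True})  # extra kwarg

print(k1 == k2)  # True: same graph despite different shapes
print(k1 == k3)  # False: rank differs, new graph
print(k1 == k4)  # False: argument structure differs, new graph
```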
Graph Detection in the Tuning Workflow
Execution graphs are detected during the Sample Gathering Phase of tuning:
- Your model is executed with samples from the dataset
- Each wrapped module records input/output metadata using SampleMetadata
- AITune compares metadata to identify unique graph patterns
- Each unique pattern becomes a separate GraphSpec
- During Module Tuning, each graph is optimized independently
For a complete overview of the tuning process, see Tuning Workflow.
What This Guide Covers
This guide explores the technical details of execution graphs through the lens of SampleMetadata, the core class responsible for:
- Capturing tensor metadata (shapes, dtypes, structure)
- Detecting batch axes vs. dynamic dimensions
- Tracking shape ranges (min/max) across samples
- Enabling dynamic batching and shape inference
By the end of this guide, you'll understand how AITune:
- Identifies which inputs belong to the same graph
- Learns dynamic shape patterns from multiple samples
- Uses this information to optimize each execution path separately
Introduction to SampleMetadata
What is SampleMetadata?
SampleMetadata is a class designed to capture and track metadata about function inputs and outputs, particularly focusing on PyTorch tensors. It serves several important purposes:
- Tensor Tracking: Automatically discovers and tracks all tensors in complex nested data structures
- Shape Inference: Learns about dynamic dimensions and batch axes by observing multiple samples
- Model Optimization: Enables optimization backends to understand input/output characteristics
- Dynamic Batching: Supports scaling tensors to different batch sizes based on learned patterns
Key Concepts
- Locators: Navigate through nested structures (tuples, lists, dicts, dataclasses and registered user types) to find tensors
- TensorSpec: Underlying representation that tracks shape, dtype, and batch axis information
- Dynamic Dimensions: Dimensions that vary across samples (e.g., sequence length in NLP)
- Batch Axes: Dimensions that scale proportionally with batch size
Let's start with a simple example:
from dataclasses import dataclass
import torch
from aitune.torch.module.sample_metadata import SampleMetadata, InfoLevel
# Create a simple tensor and capture its metadata
simple_tensor = torch.randn(2, 3, 4)
args = (simple_tensor,)
kwargs = {}
metadata = SampleMetadata.from_inputs(args, kwargs)
print(repr(metadata))
Tensors:
╒═══════════╤════════╤═══════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════╪════════╪═══════════╪═════════════╪═════════════╪═══════════════╡
│ [0] │ args_0 │ [2, 3, 4] │ [2, 3, 4] │ [2, 3, 4] │ torch.float32 │
╘═══════════╧════════╧═══════════╧═════════════╧═════════════╧═══════════════╛
The output shows that SampleMetadata automatically detected the tensor, assigned it a name (args_0), and captured its shape [2, 3, 4] along with data type information.
Creating Metadata from Inputs
The primary way to create SampleMetadata is through the from_inputs() static method. This method accepts:
- args: Positional arguments (typically a tuple)
- kwargs: Keyword arguments (a dictionary)
- strict: Boolean flag controlling whether to track non-tensor data (default: False)
Let's explore different input patterns:
# Example 1: Multiple tensors in args
args = (
    torch.randn(2, 3),
    torch.randn(4, 5),
)
kwargs = {}
meta1 = SampleMetadata.from_inputs(args, kwargs)
print("Example 1 - Multiple args:")
print(repr(meta1))
Example 1 - Multiple args:
Tensors:
╒═══════════╤════════╤═════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════╪════════╪═════════╪═════════════╪═════════════╪═══════════════╡
│ [0] │ args_0 │ [2, 3] │ [2, 3] │ [2, 3] │ torch.float32 │
├───────────┼────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ [1] │ args_1 │ [4, 5] │ [4, 5] │ [4, 5] │ torch.float32 │
╘═══════════╧════════╧═════════╧═════════════╧═════════════╧═══════════════╛
# Example 2: Tensors in kwargs
args = ()
kwargs = {
"input_tensor": torch.randn(3, 4),
"mask": torch.randn(3, 1),
}
meta2 = SampleMetadata.from_inputs(args, kwargs)
print("Example 2 - Kwargs only:")
print(repr(meta2))
Example 2 - Kwargs only:
Tensors:
╒══════════════════╤═════════════════════╤═════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞══════════════════╪═════════════════════╪═════════╪═════════════╪═════════════╪═══════════════╡
│ ['input_tensor'] │ kwargs_input_tensor │ [3, 4] │ [3, 4] │ [3, 4] │ torch.float32 │
├──────────────────┼─────────────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['mask'] │ kwargs_mask │ [3, 1] │ [3, 1] │ [3, 1] │ torch.float32 │
╘══════════════════╧═════════════════════╧═════════╧═════════════╧═════════════╧═══════════════╛
# Example 3: Mixed primitives and tensors
args = (
    "some_string",      # Primitive - ignored by default
    torch.randn(2, 2),  # Tensor - tracked
    42,                 # Primitive - ignored by default
)
kwargs = {
    "data": torch.randn(3, 3),
    "learning_rate": 0.001,  # Primitive - ignored by default
}
meta3 = SampleMetadata.from_inputs(args, kwargs)
print("Example 3 - Mixed types (strict=False):")
print(repr(meta3))
print("\nNotice that only tensors are tracked!")
Example 3 - Mixed types (strict=False):
Tensors:
╒═══════════╤═════════════╤═════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════╪═════════════╪═════════╪═════════════╪═════════════╪═══════════════╡
│ [1] │ args_1 │ [2, 2] │ [2, 2] │ [2, 2] │ torch.float32 │
├───────────┼─────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['data'] │ kwargs_data │ [3, 3] │ [3, 3] │ [3, 3] │ torch.float32 │
╘═══════════╧═════════════╧═════════╧═════════════╧═════════════╧═══════════════╛
Notice that only tensors are tracked!
Strict vs. Non-Strict Mode
By default during tuning, AITune operates in strict mode (config.strict_mode=True), which means SampleMetadata captures both tensors and non-tensor data (primitives, strings, etc.). This ensures different argument values create different execution graphs.
However, when calling SampleMetadata.from_inputs() directly with strict=False, it only tracks tensors and ignores all other data types. This is useful when you only care about tensor shapes for optimization purposes.
Strict mode (strict=True) is useful for:
- Validating that function signatures match expected patterns
- Debugging data flow through complex pipelines
- Ensuring reproducibility of function calls
Let's compare the two modes:
# Same inputs, different modes
args = (1, 2, 3, torch.randn(2, 2))
kwargs = {"t": torch.randn(2, 3), "other": "abc"}
# Non-strict mode (default)
meta_non_strict = SampleMetadata.from_inputs(args, kwargs, strict=False)
print("Non-Strict Mode (strict=False):")
print(repr(meta_non_strict))
print("\n" + "="*80 + "\n")
# Strict mode
meta_strict = SampleMetadata.from_inputs(args, kwargs, strict=True)
print("Strict Mode (strict=True):")
print(repr(meta_strict))
Non-Strict Mode (strict=False):
Tensors:
╒═══════════╤══════════╤═════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════╪══════════╪═════════╪═════════════╪═════════════╪═══════════════╡
│ [3] │ args_3 │ [2, 2] │ [2, 2] │ [2, 2] │ torch.float32 │
├───────────┼──────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['t'] │ kwargs_t │ [2, 3] │ [2, 3] │ [2, 3] │ torch.float32 │
╘═══════════╧══════════╧═════════╧═════════════╧═════════════╧═══════════════╛
================================================================================
Strict Mode (strict=True):
Tensors:
╒═══════════╤══════════╤═════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════╪══════════╪═════════╪═════════════╪═════════════╪═══════════════╡
│ [3] │ args_3 │ [2, 2] │ [2, 2] │ [2, 2] │ torch.float32 │
├───────────┼──────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['t'] │ kwargs_t │ [2, 3] │ [2, 3] │ [2, 3] │ torch.float32 │
╘═══════════╧══════════╧═════════╧═════════════╧═════════════╧═══════════════╛
Other:
╒═══════════╤══════════════╤═════════╕
│ Locator │ Name │ Value │
╞═══════════╪══════════════╪═════════╡
│ [0] │ args_0 │ 1 │
├───────────┼──────────────┼─────────┤
│ [1] │ args_1 │ 2 │
├───────────┼──────────────┼─────────┤
│ [2] │ args_2 │ 3 │
├───────────┼──────────────┼─────────┤
│ ['other'] │ kwargs_other │ abc │
╘═══════════╧══════════════╧═════════╛
Notice that in strict mode, we see an additional "Other" section that includes the primitive values (1, 2, 3, and "abc").
Working with Nested Structures
One of the most powerful features of SampleMetadata is its ability to handle deeply nested data structures. Real-world model inputs often involve complex combinations of:
- Tuples and Lists: For variable-length sequences
- Dictionaries: For named parameters
- Dataclasses: For structured configuration objects
SampleMetadata uses Locators to navigate these structures and find all tensors, no matter how deeply nested they are.
Let's create a complex nested example:
# Define a custom dataclass
@dataclass
class ModelInput:
    data: torch.Tensor
    metadata: str
# Create complex nested structure
args = [
    "first_arg",
    torch.randn(1),                                    # Simple tensor
    (torch.randn(2), torch.randn(3)),                  # Tuple of tensors
    {"t": torch.randn(4)},                             # Dict with tensor
    ModelInput(data=torch.randn(5), metadata="info"),  # Dataclass with tensor
]
kwargs = {
    "t1": torch.randn(1, 1),
    "t2": [torch.randn(2, 2), torch.randn(3, 3)],  # List of tensors
    "t3": ModelInput(data=torch.randn(4, 4), metadata="xyz"),
    "last": "other",
}
nested_meta = SampleMetadata.from_inputs(args, kwargs, strict=True)
print(repr(nested_meta))
Tensors:
╒═════════════╤════════════════╤═════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═════════════╪════════════════╪═════════╪═════════════╪═════════════╪═══════════════╡
│ [1] │ args_1 │ [1] │ [1] │ [1] │ torch.float32 │
├─────────────┼────────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ [2][0] │ args_2_0 │ [2] │ [2] │ [2] │ torch.float32 │
├─────────────┼────────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ [2][1] │ args_2_1 │ [3] │ [3] │ [3] │ torch.float32 │
├─────────────┼────────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ [3]['t'] │ args_3_t │ [4] │ [4] │ [4] │ torch.float32 │
├─────────────┼────────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ [4].data │ args_4.data │ [5] │ [5] │ [5] │ torch.float32 │
├─────────────┼────────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['t1'] │ kwargs_t1 │ [1, 1] │ [1, 1] │ [1, 1] │ torch.float32 │
├─────────────┼────────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['t2'][0] │ kwargs_t2_0 │ [2, 2] │ [2, 2] │ [2, 2] │ torch.float32 │
├─────────────┼────────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['t2'][1] │ kwargs_t2_1 │ [3, 3] │ [3, 3] │ [3, 3] │ torch.float32 │
├─────────────┼────────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['t3'].data │ kwargs_t3.data │ [4, 4] │ [4, 4] │ [4, 4] │ torch.float32 │
╘═════════════╧════════════════╧═════════╧═════════════╧═════════════╧═══════════════╛
Other:
╒═════════════════╤════════════════════╤═══════════╕
│ Locator │ Name │ Value │
╞═════════════════╪════════════════════╪═══════════╡
│ [0] │ args_0 │ first_arg │
├─────────────────┼────────────────────┼───────────┤
│ [4].metadata │ args_4.metadata │ info │
├─────────────────┼────────────────────┼───────────┤
│ ['last'] │ kwargs_last │ other │
├─────────────────┼────────────────────┼───────────┤
│ ['t3'].metadata │ kwargs_t3.metadata │ xyz │
╘═════════════════╧════════════════════╧═══════════╛
Understanding Locators
The Locator column shows the path to each tensor in the nested structure:
- [1]: Second element of args (0-indexed)
- [2][0]: First element of the tuple at args[2]
- [3]['t']: Value at key 't' in the dict at args[3]
- [4].data: The 'data' attribute of the dataclass at args[4]
- ['t2'][0]: First element of the list at kwargs['t2']
- ['t3'].data: The 'data' attribute of the dataclass at kwargs['t3']
This allows SampleMetadata to precisely locate and manipulate tensors in complex structures.
Describing Metadata - InfoLevel
SampleMetadata provides three levels of detail when displaying information, controlled by the InfoLevel enum:
- InfoLevel.SHORT: Compact representation with just tensor names
- InfoLevel.MEDIUM: Includes locators and current shapes (simple table format)
- InfoLevel.FULL: Complete details including min/max shapes and dtypes (fancy table format)
Let's see the same metadata displayed at all three levels:
# Create sample metadata
args = (torch.randn(2, 3), torch.randn(4, 5, 6))
kwargs = {"mask": torch.randn(2, 1)}
meta = SampleMetadata.from_inputs(args, kwargs)
print("InfoLevel.SHORT:")
print(meta.describe(InfoLevel.SHORT))
print("\n" + "="*80 + "\n")
print("InfoLevel.MEDIUM:")
print(meta.describe(InfoLevel.MEDIUM))
print("\n" + "="*80 + "\n")
print("InfoLevel.FULL:")
print(meta.describe(InfoLevel.FULL))
InfoLevel.SHORT:
Tensors: args_0, args_1, kwargs_mask
================================================================================
InfoLevel.MEDIUM:
Tensors:
Locator Name Shape
--------- ----------- ---------
[0] args_0 [2, 3]
[1] args_1 [4, 5, 6]
['mask'] kwargs_mask [2, 1]
================================================================================
InfoLevel.FULL:
Tensors:
╒═══════════╤═════════════╤═══════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════╪═════════════╪═══════════╪═════════════╪═════════════╪═══════════════╡
│ [0] │ args_0 │ [2, 3] │ [2, 3] │ [2, 3] │ torch.float32 │
├───────────┼─────────────┼───────────┼─────────────┼─────────────┼───────────────┤
│ [1] │ args_1 │ [4, 5, 6] │ [4, 5, 6] │ [4, 5, 6] │ torch.float32 │
├───────────┼─────────────┼───────────┼─────────────┼─────────────┼───────────────┤
│ ['mask'] │ kwargs_mask │ [2, 1] │ [2, 1] │ [2, 1] │ torch.float32 │
╘═══════════╧═════════════╧═══════════╧═════════════╧═════════════╧═══════════════╛
The FULL level is particularly useful because it shows:
- Min Shape: The smallest dimensions seen for each axis
- Max Shape: The largest dimensions seen for each axis
- Dtype: The PyTorch data type of the tensor
These become interesting when we start tracking multiple samples with different shapes.
Dynamic Shape Tracking
One of the most sophisticated features of SampleMetadata is its ability to learn about dynamic dimensions and batch axes by observing multiple samples with different shapes.
How It Works
When you call update_shapes_seen() with metadata from a different sample:
- Batch Axis Detection: If a dimension scales proportionally with batch size and the multiplier is an integer, it's marked as a batch axis (e.g., batch0, batch1)
- Dynamic Dimension Detection: If a dimension changes but not proportionally to batch size, it's marked as a dynamic dimension (e.g., dim0, dim1)
- Min/Max Tracking: The minimum and maximum values seen for each dimension are tracked
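The classification rule can be sketched as a small function comparing one dimension across two observed samples (an illustrative approximation, not AITune's exact algorithm):

```python
def classify_dim(size1, bs1, size2, bs2):
    # Hypothetical sketch: classify one dimension given its size in two
    # samples taken at batch sizes bs1 and bs2.
    if size1 == size2:
        return "static"
    # Proportional to batch size with an integer multiplier -> batch axis
    if size2 * bs1 == size1 * bs2 and size1 % bs1 == 0:
        return "batch"  # multiplier = size1 // bs1
    return "dynamic"

print(classify_dim(1, 1, 2, 2))    # batch (multiplier 1)
print(classify_dim(2, 1, 4, 2))    # batch (multiplier 2)
print(classify_dim(10, 1, 25, 2))  # dynamic: 10 -> 25 is not proportional
print(classify_dim(5, 1, 5, 2))    # static
```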
Let's see this in action:
# Create initial metadata with batch size 1
args_initial = [
    torch.randn(1),
    torch.randn(2),
    torch.randn(5),
]
kwargs_initial = {
    "data": torch.randn(1, 10),  # First dim batch, second dynamic
}
meta_initial = SampleMetadata.from_inputs(args_initial, kwargs_initial, strict=False, batch_size=1)
print("Initial Metadata (batch_size=1):")
print(meta_initial.describe(InfoLevel.FULL))
Initial Metadata (batch_size=1):
Tensors:
╒═══════════╤═════════════╤═════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════╪═════════════╪═════════╪═════════════╪═════════════╪═══════════════╡
│ [0] │ args_0 │ [1] │ [1] │ [1] │ torch.float32 │
├───────────┼─────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ [1] │ args_1 │ [2] │ [2] │ [2] │ torch.float32 │
├───────────┼─────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ [2] │ args_2 │ [5] │ [5] │ [5] │ torch.float32 │
├───────────┼─────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['data'] │ kwargs_data │ [1, 10] │ [1, 10] │ [1, 10] │ torch.float32 │
╘═══════════╧═════════════╧═════════╧═════════════╧═════════════╧═══════════════╛
# Create second metadata with different shapes and batch size 2
args_second = [
    torch.randn(2),   # Doubled (batch axis)
    torch.randn(5),   # Changed but not proportionally (dynamic)
    torch.randn(15),  # Changed but not proportionally (dynamic)
]
kwargs_second = {
    "data": torch.randn(2, 25),  # First dim doubled, second changed
}
meta_second = SampleMetadata.from_inputs(args_second, kwargs_second, strict=False, batch_size=2)
print("Second Metadata (batch_size=2):")
print(meta_second.describe(InfoLevel.FULL))
Second Metadata (batch_size=2):
Tensors:
╒═══════════╤═════════════╤═════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════╪═════════════╪═════════╪═════════════╪═════════════╪═══════════════╡
│ [0] │ args_0 │ [2] │ [2] │ [2] │ torch.float32 │
├───────────┼─────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ [1] │ args_1 │ [5] │ [5] │ [5] │ torch.float32 │
├───────────┼─────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ [2] │ args_2 │ [15] │ [15] │ [15] │ torch.float32 │
├───────────┼─────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['data'] │ kwargs_data │ [2, 25] │ [2, 25] │ [2, 25] │ torch.float32 │
╘═══════════╧═════════════╧═════════╧═════════════╧═════════════╧═══════════════╛
# Update the initial metadata with information from the second sample
meta_initial.update_shapes_seen(meta_second)
print("Updated Metadata (after seeing both samples):")
print(meta_initial.describe(InfoLevel.FULL))
Updated Metadata (after seeing both samples):
Tensors:
╒═══════════╤═════════════╤════════════════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════╪═════════════╪════════════════════╪═════════════╪═════════════╪═══════════════╡
│ [0] │ args_0 │ ['batch0'] │ [1] │ [2] │ torch.float32 │
├───────────┼─────────────┼────────────────────┼─────────────┼─────────────┼───────────────┤
│ [1] │ args_1 │ ['dim0'] │ [2] │ [5] │ torch.float32 │
├───────────┼─────────────┼────────────────────┼─────────────┼─────────────┼───────────────┤
│ [2] │ args_2 │ ['dim0'] │ [5] │ [15] │ torch.float32 │
├───────────┼─────────────┼────────────────────┼─────────────┼─────────────┼───────────────┤
│ ['data'] │ kwargs_data │ ['batch0', 'dim1'] │ [1, 10] │ [2, 25] │ torch.float32 │
╘═══════════╧═════════════╧════════════════════╧═════════════╧═════════════╧═══════════════╛
Understanding the Results
After updating, notice how the shapes have been transformed:
- batch0: Dimensions that doubled when batch size doubled (1→2)
- dim0, dim1: Dimensions that changed but not proportionally to batch size
- Min/Max Shape: Now show the range of values observed
This information is crucial for:
- Model compilation: Backends can create optimized graphs for dynamic shapes
- Memory planning: Knowing the range helps allocate appropriate buffers
- Validation: Ensuring new inputs fall within expected ranges
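As a sketch of the validation use case, a hypothetical range check against the learned min/max shapes might look like this (within_observed_range is an illustrative helper, not part of the AITune API):

```python
import torch

def within_observed_range(tensor, min_shape, max_shape):
    # Every dimension must fall inside the [min, max] range learned from
    # the profiling samples; the rank must match as well.
    if tensor.dim() != len(min_shape):
        return False
    return all(lo <= d <= hi for d, lo, hi in zip(tensor.shape, min_shape, max_shape))

# Ranges learned above for kwargs_data: min [1, 10], max [2, 25]
print(within_observed_range(torch.randn(2, 20), [1, 10], [2, 25]))  # True
print(within_observed_range(torch.randn(4, 20), [1, 10], [2, 25]))  # False: batch 4 > 2
```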
Batch Manipulation
Once SampleMetadata has learned about batch axes through update_shapes_seen(), it can use the make_batch() method to scale tensors to a target batch size.
How make_batch() Works
The method uses batch axis multipliers to determine how to scale each dimension:
- Multiplier = 1: Standard batch axis, scales linearly with batch size
- Multiplier > 1: Stacked batch axis (e.g., when inputs are vertically stacked)
- Slicing: If current size > target, slice the tensor
- Repeating: If current size < target, repeat the tensor
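The slicing/repeating rule for a single axis can be approximated with a small helper. This is a simplified sketch (scale_batch_axis is hypothetical, and it assumes the target size is a multiple of the current size when repeating), not the actual make_batch() implementation:

```python
import torch

def scale_batch_axis(tensor, axis, multiplier, batch_size):
    # The target size on a batch axis is batch_size * multiplier:
    # slice if the tensor is too large, repeat if it is too small.
    target = batch_size * multiplier
    current = tensor.shape[axis]
    if current >= target:
        return tensor.narrow(axis, 0, target)  # slice down
    repeats = [1] * tensor.dim()
    repeats[axis] = target // current
    return tensor.repeat(*repeats)             # repeat up

print(scale_batch_axis(torch.randn(2, 3), 0, 2, 10).shape)  # torch.Size([20, 3])
print(scale_batch_axis(torch.randn(4, 3), 0, 1, 2).shape)   # torch.Size([2, 3])
```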
Let's see this in action:
# First, create metadata and teach it about batch axes
args1 = [torch.randn(1, 5), torch.randn(2, 3)]
kwargs1 = {"mask": torch.randn(1, 10)}
meta = SampleMetadata.from_inputs(args1, kwargs1, batch_size=1)
# Second sample with batch size 2
args2 = [torch.randn(2, 5), torch.randn(4, 3)]
kwargs2 = {"mask": torch.randn(2, 10)}
meta2 = SampleMetadata.from_inputs(args2, kwargs2, batch_size=2)
meta.update_shapes_seen(meta2)
print("Learned Metadata:")
print(meta.describe(InfoLevel.FULL))
Learned Metadata:
Tensors:
╒═══════════╤═════════════╤════════════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════╪═════════════╪════════════════╪═════════════╪═════════════╪═══════════════╡
│ [0] │ args_0 │ ['batch0', 5] │ [1, 5] │ [2, 5] │ torch.float32 │
├───────────┼─────────────┼────────────────┼─────────────┼─────────────┼───────────────┤
│ [1] │ args_1 │ ['batch0', 3] │ [2, 3] │ [4, 3] │ torch.float32 │
├───────────┼─────────────┼────────────────┼─────────────┼─────────────┼───────────────┤
│ ['mask'] │ kwargs_mask │ ['batch0', 10] │ [1, 10] │ [2, 10] │ torch.float32 │
╘═══════════╧═════════════╧════════════════╧═════════════╧═════════════╧═══════════════╛
# Now use make_batch to scale to a larger batch size
original_args = [torch.randn(1, 5), torch.randn(2, 3)]
original_kwargs = {"mask": torch.randn(1, 10)}
print("Original shapes:")
print(f" args[0]: {original_args[0].shape}")
print(f" args[1]: {original_args[1].shape}")
print(f" kwargs['mask']: {original_kwargs['mask'].shape}")
print()
# Scale to batch size 10
batched_args, batched_kwargs = meta.make_batch(original_args, original_kwargs, batch_size=10)
print("After make_batch(batch_size=10):")
print(f" args[0]: {batched_args[0].shape}")
print(f" args[1]: {batched_args[1].shape}")
print(f" kwargs['mask']: {batched_kwargs['mask'].shape}")
Original shapes:
args[0]: torch.Size([1, 5])
args[1]: torch.Size([2, 3])
kwargs['mask']: torch.Size([1, 10])
After make_batch(batch_size=10):
args[0]: torch.Size([10, 5])
args[1]: torch.Size([20, 3])
kwargs['mask']: torch.Size([10, 10])
Notice how:
- The first dimensions (batch axes) scaled to match the target batch size of 10
- args[1] has a multiplier of 2 (it's a stacked batch), so it scaled to 20 (10 × 2)
- Non-batch dimensions (the 5, 3, and 10) remained unchanged
TensorSpec Deep Dive
SampleMetadata is actually a container for multiple TensorSpec objects, where each TensorSpec represents one tensor in the input/output structure.
TensorSpec Attributes
- name: Symbolic name (e.g., args_0, kwargs_mask)
- shape: Current shape representation (may include symbolic dimensions)
- min_shape: Minimum dimensions observed
- max_shape: Maximum dimensions observed
- dtype: PyTorch data type
- _bs_multipliers: Internal batch size multipliers for each axis
Let's inspect TensorSpec objects directly:
# Create metadata with dynamic shapes
args1 = [torch.randn(1, 5)]
kwargs1 = {"data": torch.randn(1, 10)}
meta = SampleMetadata.from_inputs(args1, kwargs1, batch_size=1)
args2 = [torch.randn(2, 5)]
kwargs2 = {"data": torch.randn(2, 20)}
meta2 = SampleMetadata.from_inputs(args2, kwargs2, batch_size=2)
meta.update_shapes_seen(meta2)
# Access individual TensorSpec objects
print("Individual TensorSpec objects:\n")
for locator, tensor_spec in meta.tensor_data:
print(f"Locator: {locator}")
print(f" Name: {tensor_spec.name}")
print(f" Shape: {tensor_spec.shape}")
print(f" Min Shape: {tensor_spec.min_shape}")
print(f" Max Shape: {tensor_spec.max_shape}")
print(f" Dtype: {tensor_spec.dtype}")
print(f" Has batch axis: {tensor_spec.has_batch_axis()}")
print(f" Has dynamic axis: {tensor_spec.has_dynamic_axis()}")
print(f" Batch multipliers: {tensor_spec.get_batch_axis_multipliers()}")
Individual TensorSpec objects:
Locator: [0]
Name: args_0
Shape: ['batch0', 5]
Min Shape: [1, 5]
Max Shape: [2, 5]
Dtype: torch.float32
Has batch axis: True
Has dynamic axis: False
Batch multipliers: {0: 1}
Locator: ['data']
Name: kwargs_data
Shape: ['batch0', 'batch1']
Min Shape: [1, 10]
Max Shape: [2, 20]
Dtype: torch.float32
Has batch axis: True
Has dynamic axis: False
Batch multipliers: {0: 1, 1: 10}
Useful TensorSpec Methods
- has_batch_axis(): Returns True if the tensor has at least one batch dimension
- has_dynamic_axis(): Returns True if the tensor has at least one dynamic dimension
- get_batch_axis_multipliers(): Returns a dict mapping axis index to its batch multiplier
- matches(other): Checks if two TensorSpecs are compatible
These methods are used internally by SampleMetadata to perform operations like make_batch().
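As a rough illustration of what a compatibility check like matches() might involve (the real semantics may differ), symbolic dimensions can be treated as wildcards while concrete dimensions must agree:

```python
def specs_match(shape_a, shape_b):
    # Hypothetical sketch: symbolic entries (strings such as 'batch0' or
    # 'dim1') match any size; concrete entries must be equal exactly.
    if len(shape_a) != len(shape_b):
        return False
    return all(
        isinstance(a, str) or isinstance(b, str) or a == b
        for a, b in zip(shape_a, shape_b)
    )

print(specs_match(["batch0", 5], [3, 5]))  # True: batch axis is a wildcard
print(specs_match(["batch0", 5], [3, 7]))  # False: concrete dim differs
print(specs_match([2, 5], [2, 5, 1]))      # False: rank mismatch
```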
Practical Example - Complete Workflow
Let's put everything together with a realistic scenario: profiling a model with variable-length sequences (like in NLP tasks).
Scenario
We have a language model that takes:
- Input IDs with shape (batch_size, sequence_length)
- Attention mask with shape (batch_size, sequence_length)
- Position IDs with shape (batch_size, sequence_length)
We'll profile it with different batch sizes and sequence lengths to learn the dynamic shapes.
# Simulate model profiling
@dataclass
class ModelInputs:
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    position_ids: torch.Tensor
# Sample 1: batch_size=1, seq_len=10
sample1_args = ()
sample1_kwargs = {
    "inputs": ModelInputs(
        input_ids=torch.randint(0, 1000, (1, 10)),
        attention_mask=torch.ones(1, 10),
        position_ids=torch.arange(10).unsqueeze(0),
    )
}
metadata = SampleMetadata.from_inputs(sample1_args, sample1_kwargs, batch_size=1, strict=False)
print("After Sample 1 (batch=1, seq_len=10):")
print(metadata.describe(InfoLevel.FULL))
After Sample 1 (batch=1, seq_len=10):
Tensors:
╒═══════════════════════════╤══════════════════════════════╤═════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════════════════════╪══════════════════════════════╪═════════╪═════════════╪═════════════╪═══════════════╡
│ ['inputs'].input_ids │ kwargs_inputs.input_ids │ [1, 10] │ [1, 10] │ [1, 10] │ torch.int64 │
├───────────────────────────┼──────────────────────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['inputs'].attention_mask │ kwargs_inputs.attention_mask │ [1, 10] │ [1, 10] │ [1, 10] │ torch.float32 │
├───────────────────────────┼──────────────────────────────┼─────────┼─────────────┼─────────────┼───────────────┤
│ ['inputs'].position_ids │ kwargs_inputs.position_ids │ [1, 10] │ [1, 10] │ [1, 10] │ torch.int64 │
╘═══════════════════════════╧══════════════════════════════╧═════════╧═════════════╧═════════════╧═══════════════╛
# Sample 2: batch_size=2, seq_len=15
sample2_args = ()
sample2_kwargs = {
    "inputs": ModelInputs(
        input_ids=torch.randint(0, 1000, (2, 15)),
        attention_mask=torch.ones(2, 15),
        position_ids=torch.arange(15).unsqueeze(0).repeat(2, 1),
    )
}
metadata2 = SampleMetadata.from_inputs(sample2_args, sample2_kwargs, batch_size=2, strict=False)
metadata.update_shapes_seen(metadata2)
print("After Sample 2 (batch=2, seq_len=15):")
print(metadata.describe(InfoLevel.FULL))
After Sample 2 (batch=2, seq_len=15):
Tensors:
╒═══════════════════════════╤══════════════════════════════╤════════════════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════════════════════╪══════════════════════════════╪════════════════════╪═════════════╪═════════════╪═══════════════╡
│ ['inputs'].input_ids │ kwargs_inputs.input_ids │ ['batch0', 'dim1'] │ [1, 10] │ [2, 15] │ torch.int64 │
├───────────────────────────┼──────────────────────────────┼────────────────────┼─────────────┼─────────────┼───────────────┤
│ ['inputs'].attention_mask │ kwargs_inputs.attention_mask │ ['batch0', 'dim1'] │ [1, 10] │ [2, 15] │ torch.float32 │
├───────────────────────────┼──────────────────────────────┼────────────────────┼─────────────┼─────────────┼───────────────┤
│ ['inputs'].position_ids │ kwargs_inputs.position_ids │ ['batch0', 'dim1'] │ [1, 10] │ [2, 15] │ torch.int64 │
╘═══════════════════════════╧══════════════════════════════╧════════════════════╧═════════════╧═════════════╧═══════════════╛
# Sample 3: batch_size=4, seq_len=20
sample3_args = ()
sample3_kwargs = {
    "inputs": ModelInputs(
        input_ids=torch.randint(0, 1000, (4, 20)),
        attention_mask=torch.ones(4, 20),
        position_ids=torch.arange(20).unsqueeze(0).repeat(4, 1),
    )
}
metadata3 = SampleMetadata.from_inputs(sample3_args, sample3_kwargs, batch_size=4, strict=False)
metadata.update_shapes_seen(metadata3)
print("After Sample 3 (batch=4, seq_len=20):")
print(metadata.describe(InfoLevel.FULL))
After Sample 3 (batch=4, seq_len=20):
Tensors:
╒═══════════════════════════╤══════════════════════════════╤════════════════════╤═════════════╤═════════════╤═══════════════╕
│ Locator │ Name │ Shape │ Min Shape │ Max Shape │ Dtype │
╞═══════════════════════════╪══════════════════════════════╪════════════════════╪═════════════╪═════════════╪═══════════════╡
│ ['inputs'].input_ids │ kwargs_inputs.input_ids │ ['batch0', 'dim1'] │ [1, 10] │ [4, 20] │ torch.int64 │
├───────────────────────────┼──────────────────────────────┼────────────────────┼─────────────┼─────────────┼───────────────┤
│ ['inputs'].attention_mask │ kwargs_inputs.attention_mask │ ['batch0', 'dim1'] │ [1, 10] │ [4, 20] │ torch.float32 │
├───────────────────────────┼──────────────────────────────┼────────────────────┼─────────────┼─────────────┼───────────────┤
│ ['inputs'].position_ids │ kwargs_inputs.position_ids │ ['batch0', 'dim1'] │ [1, 10] │ [4, 20] │ torch.int64 │
╘═══════════════════════════╧══════════════════════════════╧════════════════════╧═════════════╧═════════════╧═══════════════╛
Analysis
After observing three samples with different batch sizes and sequence lengths:
- First dimension: Identified as batch0 because it scaled proportionally with batch size (1→2→4)
- Second dimension: Identified as dim1 because it varied dynamically (10→15→20)
- Min/Max ranges: Captured the observed ranges for both dimensions
This information can now be used by optimization backends to compile efficient code for these dynamic shapes.
# Now we can create inputs for any batch size!
test_input_args = ()
test_input_kwargs = {
    "inputs": ModelInputs(
        input_ids=torch.randint(0, 1000, (2, 12)),
        attention_mask=torch.ones(2, 12),
        position_ids=torch.arange(12).unsqueeze(0).repeat(2, 1),
    )
}
print("Original test input shapes:")
print(f" input_ids: {test_input_kwargs['inputs'].input_ids.shape}")
print(f" attention_mask: {test_input_kwargs['inputs'].attention_mask.shape}")
print(f" position_ids: {test_input_kwargs['inputs'].position_ids.shape}")
print()
# Scale to batch size 8
scaled_args, scaled_kwargs = metadata.make_batch(test_input_args, test_input_kwargs, batch_size=8)
print("After scaling to batch_size=8:")
print(f" input_ids: {scaled_kwargs['inputs'].input_ids.shape}")
print(f" attention_mask: {scaled_kwargs['inputs'].attention_mask.shape}")
print(f" position_ids: {scaled_kwargs['inputs'].position_ids.shape}")
print("\nNote: Batch dimension scaled to 8, but sequence length (dim1) remained at 12")
Original test input shapes:
input_ids: torch.Size([2, 12])
attention_mask: torch.Size([2, 12])
position_ids: torch.Size([2, 12])
After scaling to batch_size=8:
input_ids: torch.Size([8, 12])
attention_mask: torch.Size([8, 12])
position_ids: torch.Size([8, 12])
Note: Batch dimension scaled to 8, but sequence length (dim1) remained at 12
Summary
Key Takeaways
- Purpose: SampleMetadata captures and tracks metadata about tensors in complex data structures, enabling model optimization and dynamic batching.
- Creation: Use SampleMetadata.from_inputs(args, kwargs, strict=bool) to create metadata from function inputs.
- Strict Mode: Controls whether only tensors (strict=False) or all data types (strict=True) are tracked.
- Nested Structures: Automatically handles tuples, lists, dicts, and dataclasses using Locators.
- InfoLevel: Three display modes (SHORT, MEDIUM, FULL) provide different levels of detail.
- Dynamic Shape Learning: update_shapes_seen() learns about batch axes and dynamic dimensions by observing multiple samples.
- Batch Manipulation: make_batch() can scale tensors to any batch size based on learned batch axis multipliers.
- TensorSpec: The underlying representation of each tensor, containing shape, dtype, and batch information.
Use in AITune Pipeline
SampleMetadata is a fundamental building block in the AITune library, used by:
- RecordingModule: Captures input/output metadata during profiling
- Backends: Use metadata to configure optimized execution (TensorRT, TorchScript, etc.)
- Graph Compilation: Enables the creation of optimized graphs for dynamic shapes
Source Code
For more details, see:
- aitune/torch/module/sample_metadata.py
- aitune/torch/module/tensor_spec.py
- aitune/torch/module/locator.py
If you would like to tinker with SampleMetadata, you can find this example in notebooks/sample_metadata_walkthrough.ipynb.