TorchAO Backend Guide

The TorchAO backend leverages PyTorch's torchao library for quantization-based model tuning. It supports both weight-only and dynamic quantization schemes.

Overview

  • Weight-Only Quantization: INT8, FP8
  • Dynamic Quantization: INT8 and FP8 with dynamic activations
  • Easy Configuration: Predefined quantization types
  • Pure PyTorch: No external dependencies beyond torchao

Quick Start

from aitune.torch.backend import TorchAOBackend, TorchAOBackendConfig
import aitune.torch as ait

# Configure with FP8 weight-only quantization
config = TorchAOBackendConfig(quantization="fp8wo")
backend = TorchAOBackend(config)

# Use in tuning
strategy = ait.OneBackendStrategy(backend=backend)

model = ait.Module(model, "my-model", strategy=strategy)
ait.tune(model, input_data)

Quantization Types

Weight-Only Quantization

# INT8 weight-only
config = TorchAOBackendConfig(quantization="int8wo")

# FP8 weight-only (default)
config = TorchAOBackendConfig(quantization="fp8wo")

Dynamic Quantization

# INT8 dynamic (activations + weights)
config = TorchAOBackendConfig(quantization="int8dq")

# FP8 dynamic (activations + weights)
config = TorchAOBackendConfig(quantization="fp8dq")

Configuration Options

Using Predefined Types

config = TorchAOBackendConfig(
    quantization="int8wo",  # Choose quantization type
)

Custom Configuration

from torchao.quantization import Int8WeightOnlyConfig

custom_config = Int8WeightOnlyConfig()

config = TorchAOBackendConfig(
    quantization_config=custom_config,
)

Quantization Comparison

Type     Weights  Activations  Memory Reduction  Speed      Accuracy
int8wo   INT8     FP16/FP32    ~2x               High       Better
int8dq   INT8     INT8         ~2x               Very High  Good
fp8wo    FP8      FP16/FP32    ~2x               Very High  Excellent
fp8dq    FP8      FP8          ~2x               Very High  Excellent
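The ~2x figure in the table can be sanity-checked with plain PyTorch arithmetic: INT8 and FP8 store one byte per weight, versus two for FP16 and four for FP32. This sketch counts parameter bytes for a toy model (the model itself is illustrative, not part of the aitune API) and ignores quantization scales and other metadata, which add a small overhead in practice:

```python
import torch
import torch.nn as nn

def param_bytes(model: nn.Module) -> int:
    # Total bytes occupied by all parameter tensors.
    return sum(p.numel() * p.element_size() for p in model.parameters())

# Toy stand-in for a real model.
model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))

fp32_bytes = param_bytes(model)
fp16_bytes = param_bytes(model.half())

# INT8/FP8 weights occupy 1 byte per element, so the reduction
# relative to FP16 is roughly 2x (scales/metadata not counted).
int8_bytes = sum(p.numel() for p in model.parameters())

print(f"fp32: {fp32_bytes} B, fp16: {fp16_bytes} B, int8/fp8: {int8_bytes} B")
print(f"fp16 -> 8-bit reduction: {fp16_bytes / int8_bytes:.1f}x")
```

Against an FP32 baseline the same arithmetic gives ~4x, which is why the table's reduction column assumes an FP16 starting point.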

Best Practices

  1. Start with FP8: Best accuracy/performance trade-off
  2. Use INT8 for Memory: When memory is critical
  3. Dynamic Quantization: Higher throughput on compute-bound workloads, at a small accuracy cost versus weight-only
  4. Validate Accuracy: Always test quantized model accuracy
  5. Calibration Data: Use representative samples
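Practice 4 above can be sketched with plain PyTorch: run the original and quantized models on representative samples and compare outputs by max absolute error and cosine similarity. The helper below is illustrative, not part of the aitune API; here the "quantized" model is simulated by rounding a copy's weights to an int8 grid, and the 0.99 threshold is an arbitrary starting point you should tune for your task:

```python
import copy
import torch
import torch.nn as nn

def validate_outputs(original: nn.Module, quantized: nn.Module,
                     sample: torch.Tensor, cos_threshold: float = 0.99):
    """Compare model outputs on a representative input batch."""
    original.eval(); quantized.eval()
    with torch.no_grad():
        ref = original(sample).flatten()
        out = quantized(sample).flatten()
    max_err = (ref - out).abs().max().item()
    cos_sim = torch.nn.functional.cosine_similarity(ref, out, dim=0).item()
    assert cos_sim >= cos_threshold, f"accuracy degraded: cos_sim={cos_sim:.4f}"
    return max_err, cos_sim

# Stand-in for a quantized model: fake-quantize weights of a copy to int8.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
quantized = copy.deepcopy(model)
with torch.no_grad():
    for p in quantized.parameters():
        scale = p.abs().max() / 127
        p.copy_(torch.round(p / scale) * scale)  # snap to int8 grid

max_err, cos_sim = validate_outputs(model, quantized, torch.randn(8, 16))
print(f"max abs error: {max_err:.4f}, cosine similarity: {cos_sim:.4f}")
```

For real task-level validation, replace the output comparison with your evaluation metric (accuracy, perplexity, etc.) on a held-out set of representative samples, per practice 5.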

Troubleshooting

Issue: Accuracy loss too high

Solution: Try less aggressive quantization:

# Instead of int8wo, try fp8wo
config = TorchAOBackendConfig(quantization="fp8wo")

Issue: Not enough speed improvement

Solution: Try dynamic quantization:

config = TorchAOBackendConfig(quantization="fp8dq")

Next Steps