Skip to content

HighestThroughputStrategy API

aitune.torch.tune_strategy.HighestThroughputStrategy

HighestThroughputStrategy(backends=None, measurement_stop_strategy=None, profiling_stop_strategy=None, **kwargs)

Bases: TuneStrategyFindMaxBatchSizeExtension

Searches and selects the backend with the highest throughput.

Initializes strategy.

Parameters:

  • backends (list[Backend] | None, default: None ) –

    List of backends to tune.

  • measurement_stop_strategy (MeasuringStopStrategy | None, default: None ) –

    Measurement stop strategy.

  • profiling_stop_strategy (ProfilingStopStrategy | None, default: None ) –

    Profiling stop strategy.

  • kwargs (Any, default: {} ) –

    Additional arguments for the parent class

Source code in aitune/torch/tune_strategy/highest_throughput_strategy.py
def __init__(
    self,
    backends: list[Backend] | None = None,
    measurement_stop_strategy: MeasuringStopStrategy | None = None,
    profiling_stop_strategy: ProfilingStopStrategy | None = None,
    **kwargs: Any,
):
    """Set up the highest-throughput tuning strategy.

    Each argument that is ``None`` (or otherwise falsy) is replaced by a default.

    Args:
        backends: List of backends to tune; falls back to ``self._default_backends()``.
        measurement_stop_strategy: Measurement stop strategy; falls back to a
            ``StableWindowMeasuringStopStrategy`` built from ``DEFAULT_WINDOW_SIZE``
            and ``DEFAULT_STABILITY_PERCENTAGE``.
        profiling_stop_strategy: Profiling stop strategy; falls back to a
            ``ThroughputSaturatedProfilingStopStrategy`` built from
            ``DEFAULT_THROUGHPUT_CUTOFF_THRESHOLD`` and ``DEFAULT_THROUGHPUT_BACKOFF_LIMIT``.
        kwargs: Forwarded unchanged to the parent class initializer.
    """
    super().__init__(**kwargs)

    if not backends:
        backends = self._default_backends()
    self._backends = backends

    if not measurement_stop_strategy:
        measurement_stop_strategy = StableWindowMeasuringStopStrategy(
            window_size=DEFAULT_WINDOW_SIZE,
            stability_percentage=DEFAULT_STABILITY_PERCENTAGE,
        )
    self._measurement_stop_strategy = measurement_stop_strategy

    if not profiling_stop_strategy:
        profiling_stop_strategy = ThroughputSaturatedProfilingStopStrategy(
            throughput_cutoff_threshold=DEFAULT_THROUGHPUT_CUTOFF_THRESHOLD,
            throughput_backoff_limit=DEFAULT_THROUGHPUT_BACKOFF_LIMIT,
        )
    self._profiling_stop_strategy = profiling_stop_strategy

    # Tuning results; the list starts out empty.
    self.results: list[HighestThroughputStrategyResult] = []

check_correctness

check_correctness(backend, name, graph_spec, data)

Check outputs for NaN/inf values and verify that output tensor shapes match the graph spec.

Parameters:

  • backend (Backend) –

    The backend to check.

  • name (str) –

    The name of the module.

  • graph_spec (GraphSpec) –

    The graph spec of the module.

  • data (list[Sample]) –

    The data to check.

Note

This method should be called by the _tune method to check the correctness of the backend.

You can disable the correctness check by calling enable_correctness_check(False).

Raises:

  • CorrectnessCheckError

    if the backend fails any check.

Source code in aitune/torch/tune_strategy/tune_strategy.py
def check_correctness(self, backend: Backend, name: str, graph_spec: GraphSpec, data: list[Sample]):
    """Check backend outputs for NaN/inf values and unexpected tensor shapes.

    Every sample in ``data`` is run through ``backend.infer`` under ``torch.no_grad()``
    on deep copies of the sample's args/kwargs. Each output is validated with
    ``check_output_correctness`` and its tensor shapes are compared against
    ``graph_spec.output_spec.tensor_specs`` with ``check_output_tensor_shapes``.

    Args:
        backend: The backend to check.
        name: The name of the module; used as a prefix of the reported output name.
        graph_spec: The graph spec of the module.
        data: The data to check; each sample is unpacked as an ``(args, kwargs)`` pair.

    Note:
        This method should be called by the ``_tune`` method to check the correctness of the backend.

        You can disable the correctness check by calling `enable_correctness_check(False)`.

    Raises:
        CorrectnessCheckError: if the backend fails any check.
    """
    description = backend.describe()
    if not self._enable_correctness_check:
        self._logger.debug(
            "Correctness check is disabled for %s and graph spec %s",
            description,
            graph_spec,
        )
        return

    self._logger.debug("Checking correctness for %s and graph spec %s", description, graph_spec)
    # Loop-invariant label for error reports; built once instead of per sample.
    output_name = f"{name}.{graph_spec.name}.{description}.output"
    expected_tensor_specs = graph_spec.output_spec.tensor_specs
    with torch.no_grad():
        for args, kwargs in data:
            # Deep copies ensure the backend never receives the caller's original sample objects.
            outputs = backend.infer(*deepcopy(args), **deepcopy(kwargs))
            check_output_correctness(outputs, name=output_name)
            outputs_metadata = SampleMetadata.from_outputs(outputs)
            check_output_tensor_shapes(expected_tensor_specs, outputs_metadata.tensor_specs)

clone

clone()

Clones the tune strategy.

Source code in aitune/torch/tune_strategy/tune_strategy.py
def clone(self) -> "TuneStrategy":
    """Return an independent deep copy of this tune strategy.

    All nested state is copied recursively with ``deepcopy``, so the clone can be
    reconfigured without affecting the original.
    """
    duplicate = deepcopy(self)
    return duplicate

default_profiling_config staticmethod

default_profiling_config(batching=True, max_batch_size=DEFAULT_MAX_BATCH_SIZE, window_size=DEFAULT_WINDOW_SIZE, stability_percentage=DEFAULT_STABILITY_PERCENTAGE, throughput_cutoff_threshold=DEFAULT_THROUGHPUT_CUTOFF_THRESHOLD, throughput_backoff_limit=DEFAULT_THROUGHPUT_BACKOFF_LIMIT)

Get profiling config for finding max batch size.

Parameters:

  • batching (bool, default: True ) –

    Whether to profile with batching.

  • max_batch_size (int, default: DEFAULT_MAX_BATCH_SIZE ) –

    Upper bound used to construct the candidate batch sizes: 2^n for n in range(max_batch_size.bit_length()), i.e. 1, 2, 4, … up to the largest power of two not exceeding max_batch_size.

  • window_size (int, default: DEFAULT_WINDOW_SIZE ) –

    Window size for measuring stop strategy.

  • stability_percentage (float, default: DEFAULT_STABILITY_PERCENTAGE ) –

    Stability percentage for measuring stop strategy.

  • throughput_cutoff_threshold (float, default: DEFAULT_THROUGHPUT_CUTOFF_THRESHOLD ) –

    Throughput cutoff threshold for profiling stop strategy.

  • throughput_backoff_limit (int, default: DEFAULT_THROUGHPUT_BACKOFF_LIMIT ) –

    Throughput backoff limit for profiling stop strategy.

Returns:

  • ProfilingConfig

    Profiling config for finding max batch size.

Note

The profiling config uses the same defaults as the highest throughput strategy.

Source code in aitune/torch/tune_strategy/extension/find_max_batch_size_extension.py
@staticmethod
def default_profiling_config(
    batching: bool = True,
    max_batch_size: int = DEFAULT_MAX_BATCH_SIZE,
    window_size: int = DEFAULT_WINDOW_SIZE,
    stability_percentage: float = DEFAULT_STABILITY_PERCENTAGE,
    throughput_cutoff_threshold: float = DEFAULT_THROUGHPUT_CUTOFF_THRESHOLD,
    throughput_backoff_limit: int = DEFAULT_THROUGHPUT_BACKOFF_LIMIT,
) -> ProfilingConfig:
    """Build the profiling config used to find the max batch size.

    Args:
        batching: Whether to profile with batching.
        max_batch_size: Upper bound for the candidate batch sizes. The candidates are
            ``2**n`` for ``n in range(max_batch_size.bit_length())``, i.e. 1, 2, 4, ... up
            to the largest power of two not exceeding ``max_batch_size`` (none when it is 0).
        window_size: Window size for the measuring stop strategy.
        stability_percentage: Stability percentage for the measuring stop strategy.
        throughput_cutoff_threshold: Throughput cutoff threshold for the profiling stop strategy.
        throughput_backoff_limit: Throughput backoff limit for the profiling stop strategy.

    Returns:
        Profiling config for finding the max batch size. It measures with
        ``ModelExecutionTimeMeasuringStrategy``, stops measuring with
        ``StableWindowMeasuringStopStrategy`` and stops profiling with
        ``ThroughputSaturatedProfilingStopStrategy``.

    Note:
        The profiling config will use defaults from the highest throughput strategy.
    """
    # 1 << n == 2**n for the non-negative exponents produced by range().
    candidate_batch_sizes = [1 << exponent for exponent in range(max_batch_size.bit_length())]
    measuring = ModelExecutionTimeMeasuringStrategy()
    measurement_stop = StableWindowMeasuringStopStrategy(
        window_size=window_size,
        stability_percentage=stability_percentage,
    )
    profiling_stop = ThroughputSaturatedProfilingStopStrategy(
        throughput_cutoff_threshold=throughput_cutoff_threshold,
        throughput_backoff_limit=throughput_backoff_limit,
    )
    return ProfilingConfig(
        batching=batching,
        batch_sizes=candidate_batch_sizes,
        measuring_strategy=measuring,
        measurement_stop_strategy=measurement_stop,
        profiling_stop_strategy=profiling_stop,
    )

describe

describe()

Describes what the strategy is doing.

Source code in aitune/torch/tune_strategy/tune_strategy.py
def describe(self) -> str:
    """Return the strategy description: the parts from ``_describe_parts()`` joined by newlines."""
    parts = self._describe_parts()
    return "\n".join(parts)

enable_correctness_check

enable_correctness_check(enable=True)

Enable/disable correctness checking.

Source code in aitune/torch/tune_strategy/tune_strategy.py
def enable_correctness_check(self, enable: bool = True) -> "TuneStrategy":
    """Turn the correctness check on or off.

    Args:
        enable: Value stored in ``self._enable_correctness_check``; ``check_correctness``
            returns early without checking when it is falsy.

    Returns:
        ``self``, so calls can be chained.
    """
    strategy = self
    strategy._enable_correctness_check = enable
    return strategy

enable_find_max_batch_size

enable_find_max_batch_size(enable=True)

Enables or disables find max batch size.

Source code in aitune/torch/tune_strategy/extension/find_max_batch_size_extension.py
def enable_find_max_batch_size(self, enable: bool = True) -> "TuneStrategyFindMaxBatchSizeExtension":
    """Turn the max-batch-size search on or off.

    Args:
        enable: Value stored in ``self.find_config.enable_find_max_batch_size``;
            ``find_max_batch_size`` does nothing when it is falsy.

    Returns:
        ``self``, so calls can be chained.
    """
    config = self.find_config
    config.enable_find_max_batch_size = enable
    return self

find_max_batch_size

find_max_batch_size(module, name, graph_spec, data, device, cache_dir)

Finds max batch size for the module.

Source code in aitune/torch/tune_strategy/extension/find_max_batch_size_extension.py
def find_max_batch_size(
    self,
    module: nn.Module,
    name: str,
    graph_spec: GraphSpec,
    data: list[Sample],
    device: torch.device,
    cache_dir: Path,
):
    """Find the max batch size for the module and record it in ``graph_spec``.

    Does nothing when ``self.find_config.enable_find_max_batch_size`` is falsy.
    Otherwise it works in three steps:

    1. Build an instance of ``self.find_config.default_backend_class`` under
       ``cache_dir / "find_max_batch_size"``.
    2. Profile it with ``calculate_highest_throughput_for_backend`` using
       ``self.find_config.profiling_config``.
    3. Store the returned max batch size via ``graph_spec.input_spec.update_max_batch_size``.

    Args:
        module: The module to build the backend from.
        name: The name of the module (used in log messages).
        graph_spec: The graph spec of the module; its input spec is updated in place.
        data: Samples used for building and profiling; ``data[0]`` is passed to
            ``update_max_batch_size``.
        device: Device passed to ``backend.build``.
        cache_dir: Directory under which the ``find_max_batch_size`` artifacts and logs are placed.

    Raises:
        Exception: Any error from building or profiling is re-raised after its traceback
            is written to ``error.log`` in the find-max-batch-size cache directory.
    """
    if not self.find_config.enable_find_max_batch_size:
        return

    self._logger.info("🚀 Finding max batch size for %s", name)
    find_max_batch_size_cache_dir = cache_dir / "find_max_batch_size"
    build_log_file = self._log_file(find_max_batch_size_cache_dir, "build.log")
    try:
        backend = self.find_config.default_backend_class()
        with control_output(log_file=build_log_file):
            # Build on a copy so the caller's samples are not handed to the build step.
            backend.build(module, graph_spec, deepcopy(data), device, find_max_batch_size_cache_dir)

        max_batch_size, best_throughput, _ = calculate_highest_throughput_for_backend(
            backend,
            name,
            graph_spec,
            data,
            self.find_config.profiling_config,
        )
        self._logger.info(
            "✅ Max batch size for %s is %d with throughput %.2f samples/s",
            name,
            max_batch_size,
            best_throughput,
        )
        graph_spec.input_spec.update_max_batch_size(data[0], max_batch_size)
    except Exception:
        error_log_file = self._log_file(find_max_batch_size_cache_dir, "error.log")
        # Explicit UTF-8: a traceback may contain characters the locale encoding cannot
        # represent, and an encoding error here would mask the original exception.
        error_log_file.write_text(
            f"Build log file: {build_log_file}\n\nError:\n{traceback.format_exc()}",
            encoding="utf-8",
        )
        # Log the failure at warning level so it stays visible when INFO output is filtered.
        self._logger.warning("⚠️ Finding max batch size for `%s` failed (log file: %s)", name, error_log_file)
        raise

set_find_max_batch_size_default_backend_class

set_find_max_batch_size_default_backend_class(default_backend_class)

Sets default backend class for find max batch size.

Source code in aitune/torch/tune_strategy/extension/find_max_batch_size_extension.py
def set_find_max_batch_size_default_backend_class(
    self, default_backend_class: type[Backend]
) -> "TuneStrategyFindMaxBatchSizeExtension":
    """Choose the backend class that is built when searching for the max batch size.

    Args:
        default_backend_class: Backend class stored in
            ``self.find_config.default_backend_class``; ``find_max_batch_size``
            instantiates it with no arguments.

    Returns:
        ``self``, so calls can be chained.
    """
    config = self.find_config
    config.default_backend_class = default_backend_class
    return self

set_find_max_batch_size_profiling_config

set_find_max_batch_size_profiling_config(profiling_config)

Sets profiling config for find max batch size.

Source code in aitune/torch/tune_strategy/extension/find_max_batch_size_extension.py
def set_find_max_batch_size_profiling_config(
    self, profiling_config: ProfilingConfig
) -> "TuneStrategyFindMaxBatchSizeExtension":
    """Choose the profiling config used when searching for the max batch size.

    Args:
        profiling_config: Config stored in ``self.find_config.profiling_config`` and
            passed to ``calculate_highest_throughput_for_backend`` by ``find_max_batch_size``.

    Returns:
        ``self``, so calls can be chained.
    """
    config = self.find_config
    config.profiling_config = profiling_config
    return self

tune

tune(module, name, graph_spec, data, device, cache_dir)

Tunes the given torch module with the provided graph_spec and data and returns the selected backend.

Source code in aitune/torch/tune_strategy/tune_strategy.py
def tune(
    self,
    module: nn.Module,
    name: str,
    graph_spec: GraphSpec,
    data: list[Sample],
    device: torch.device,
    cache_dir: Path,
) -> Backend:
    """Tune the given torch module and return the selected backend.

    The strategy is described first. Then the ``_pre_tune``, ``_tune`` and
    ``_post_tune`` hooks run in that order inside a ``Timer`` reporting to ``self._sink``.

    Args:
        module: The torch module to tune.
        name: The name of the module.
        graph_spec: The graph spec of the module.
        data: Samples used for tuning.
        device: Target device.
        cache_dir: Directory for tuning artifacts.

    Returns:
        The backend produced by ``_tune``.
    """
    self._describe(module, name, graph_spec, data, device, cache_dir)
    timer_name = f"Tune `{self.__class__.__name__}`"
    with Timer(name=timer_name, sink=self._sink):
        self._pre_tune(module, name, graph_spec, data, device, cache_dir)
        selected_backend = self._tune(module, name, graph_spec, data, device, cache_dir)
        self._post_tune(selected_backend, name, graph_spec, data)
        return selected_backend

tune_dry_run

tune_dry_run(module, name, graph_spec, data, device, cache_dir)

Performs a tune dry run: describes the tuning without running it.

Source code in aitune/torch/tune_strategy/tune_strategy.py
def tune_dry_run(
    self,
    module: nn.Module,
    name: str,
    graph_spec: GraphSpec,
    data: list[Sample],
    device: torch.device,
    cache_dir: Path,
):
    """Describe the tuning that ``tune`` would perform, without running it.

    Only forwards all arguments to ``self._describe`` with ``dry_run=True``; no
    tuning hooks are called and ``None`` is returned.
    """
    describe_args = (module, name, graph_spec, data, device, cache_dir)
    self._describe(*describe_args, dry_run=True)