Skip to content

Utilities

model_navigator.api.utilities

Public utilities for the Model Navigator API.

UnpackedDataloader

UnpackedDataloader(dataloader, unpack_fn)

A wrapper around a SizedDataLoader that applies a function to each sample.

Parameters:

  • dataloader (SizedDataLoader) –

    A SizedDataLoader.

  • unpack_fn (Callable) –

    A function that takes a sample and returns a new sample.

Returns:

  • An iterator over the samples in the dataloader with the unpack_fn applied.

Example

dataloader = [1, 2, 3]
unpacked_dataloader = UnpackedDataloader(dataloader, lambda x: x + 1)

# Iterating unpacked_dataloader now yields 2, 3, 4, i.e. list(unpacked_dataloader) == [2, 3, 4]

Initialize the UnpackedDataloader.

Source code in model_navigator/api/utilities.py
def __init__(self, dataloader: SizedDataLoader, unpack_fn: Callable):
    """Store the wrapped dataloader together with its per-sample transform.

    Args:
        dataloader: The SizedDataLoader whose samples will be transformed.
        unpack_fn: Callable applied to every sample when iterating.
    """
    # Both attributes are only read later by __iter__ / __len__; nothing is consumed here.
    self._unpack_fn = unpack_fn
    self._dataloader = dataloader

__iter__

__iter__()

Return an iterator over the samples in the dataloader with the unpack_fn applied.

Source code in model_navigator/api/utilities.py
def __iter__(self):
    """Lazily yield every sample of the wrapped dataloader after passing it through unpack_fn."""
    for raw_sample in self._dataloader:
        unpacked = self._unpack_fn(raw_sample)
        yield unpacked

__len__

__len__()

Return the number of samples in the dataloader.

Source code in model_navigator/api/utilities.py
def __len__(self):
    """Report how many samples the wrapped dataloader holds (unpack_fn does not change the count)."""
    wrapped = self._dataloader
    return len(wrapped)

find_max_batch_size_till_oom

find_max_batch_size_till_oom(
    framework, model, dataloader, batch_dim=0, max_batch_size_search_limit=None, runner_config=None
)

Find the maximum batch size for a model.

Search is performed by running the model on batches built from the first sample of the dataloader until an out-of-memory (OOM) error is encountered.

Parameters:

  • framework (Framework) –

    The framework of the model.

  • model (Any) –

    The model.

  • dataloader (SizedDataLoader) –

    A SizedDataLoader.

  • batch_dim (int, default: 0 ) –

    The batch dimension of the model.

  • max_batch_size_search_limit (Optional[int], default: None ) –

    Limit the search for the maximum batch size to this value.

  • runner_config (Optional[Dict], default: None ) –

    Additional keyword arguments passed to the runner constructor.

Source code in model_navigator/api/utilities.py
def _read_max_batch_size(results_path: pathlib.Path) -> int:
    """Read the maximum batch size recorded by MaxBatchSizeFinder.

    The last non-blank line of ``results_path`` is parsed as JSON and its
    ``batch_size`` entry is logged and returned.

    Args:
        results_path: File the finder wrote its result records to.

    Returns:
        The ``batch_size`` value from the last result record.

    Raises:
        ModelNavigatorProfilingError: If the file is missing or empty, or if the last
            record has no ``batch_size`` entry.
    """
    records = []
    if results_path.is_file():
        with open(results_path) as file:
            # Ignore blank lines (e.g. a trailing empty line) so json.loads gets a real record.
            records = [line for line in file if line.strip()]

    # NamedTemporaryFile creates the results file up front, so an existing file may
    # still be empty if the finder recorded nothing. Treat that as "not found"
    # instead of failing with an IndexError.
    if not records:
        raise ModelNavigatorProfilingError("Max batch size not found.")

    results_dict = json.loads(records[-1])
    if "batch_size" not in results_dict:
        raise ModelNavigatorProfilingError("Max batch size not found.")

    LOGGER.info(f"Max batch size: {results_dict['batch_size']}")
    return results_dict["batch_size"]


def find_max_batch_size_till_oom(
    framework: Framework,
    model: Any,
    dataloader: SizedDataLoader,
    batch_dim: int = 0,
    max_batch_size_search_limit: Optional[int] = None,
    runner_config: Optional[Dict] = None,
) -> int:
    """Find the maximum batch size for a model.

    Search is performed by running the model on batches built from the first
    sample of the dataloader until an OOM error is encountered.

    Args:
        framework: The framework of the model.
        model: The model.
        dataloader: A SizedDataLoader. Only its first sample is used.
        batch_dim: The batch dimension of the model.
        max_batch_size_search_limit: Limit the search for the maximum batch size to this value.
        runner_config: Additional keyword arguments passed to the runner constructor.

    Returns:
        The maximum batch size recorded by the search.

    Raises:
        ModelNavigatorError: If the framework is not supported or the dataloader is empty.
        ModelNavigatorProfilingError: If the search did not record a maximum batch size.
    """
    if framework == Framework.TORCH:
        runner_name = "TorchCUDA"
    elif framework == Framework.TENSORFLOW:
        runner_name = "TensorFlowCUDA"
    elif framework == Framework.ONNX:
        runner_name = "OnnxCUDA"
    elif framework == Framework.JAX:
        runner_name = "Jax"
    else:
        raise ModelNavigatorError(f"Unsupported {framework} for operation.")

    # Only the first sample is needed; an empty dataloader would otherwise leak a bare StopIteration.
    try:
        sample = next(iter(dataloader))
    except StopIteration:
        raise ModelNavigatorError("Dataloader is empty; at least one sample is required.") from None

    pytree_metadata = PyTreeMetadata.from_sample(
        sample=sample, tensor_type=FRAMEWORK_TO_TENSOR_TYPE[framework], prefix="input"
    )
    input_metadata = TensorMetadata(pytree_metadata=pytree_metadata)

    profiling_sample = input_metadata.flatten_sample(sample=sample)
    profiling_sample = {name: to_numpy(tensor, from_framework=framework) for name, tensor in profiling_sample.items()}

    # Register every flattened input with its batch dimension replaced by -1
    # (presumably marking it as dynamic for the runner — confirm against TensorMetadata).
    for name, tensor in profiling_sample.items():
        shape = list(tensor.shape)
        shape[batch_dim] = -1
        input_metadata.add(name=name, shape=shape, dtype=tensor.dtype)

    if runner_config is None:
        runner_config = {}

    # Single short trial per batch size. NOTE(review): the negative cutoff threshold
    # presumably disables throughput-based early stopping so the search runs until
    # OOM or the limit — confirm against OptimizationProfile.
    optimization_profile = OptimizationProfile(
        max_batch_size=max_batch_size_search_limit,
        window_size=1,
        stabilization_windows=1,
        min_trials=1,
        max_trials=1,
        throughput_cutoff_threshold=-2,
    )

    with tempfile.NamedTemporaryFile() as temp_file:
        results_path = pathlib.Path(temp_file.name)

        runner = get_runner(runner_name)(
            model=model,
            input_metadata=input_metadata,
            output_metadata=None,
            enable_timer=True,
            **runner_config,
        )  # pytype: disable=not-instantiable
        try:
            LOGGER.info("Starting max batch size search.")
            MaxBatchSizeFinder(
                profile=optimization_profile,
                batch_dim=batch_dim,
                results_path=results_path,
            ).run(
                runner=runner,
                profiling_sample=profiling_sample,
                sample_id=0,
            )
        finally:
            # Runs even if the search raised, so any recorded result is still logged.
            # If nothing was recorded, a ModelNavigatorProfilingError is raised here.
            max_batch_size = _read_max_batch_size(results_path)
        return max_batch_size