Skip to content

Utilities

model_navigator.api.utilities

Public utilities for the Model Navigator API.

UnpackedDataloader

UnpackedDataloader(dataloader, unpack_fn)

A wrapper around a SizedDataLoader that applies a function to each sample.

Parameters:

  • dataloader (SizedDataLoader) –

    A SizedDataLoader.

  • unpack_fn (Callable) –

    A function that takes a sample and returns a new sample.

Returns:

  • An iterator over the samples in the dataloader with the unpack_fn applied.

Example

dataloader = [1, 2, 3]
unpacked_dataloader = UnpackedDataloader(dataloader, lambda x: x + 1)

Iterating over unpacked_dataloader now yields 2, 3, 4.

Initialize the UnpackedDataloader.

Source code in model_navigator/api/utilities.py
def __init__(self, dataloader: SizedDataLoader, unpack_fn: Callable):
    """Store the wrapped dataloader and the function applied to its samples.

    Args:
        dataloader: The dataloader whose samples are iterated.
        unpack_fn: Callable invoked with each sample during iteration.
    """
    self._dataloader, self._unpack_fn = dataloader, unpack_fn

__iter__

__iter__()

Return an iterator over the samples in the dataloader with the unpack_fn applied.

Source code in model_navigator/api/utilities.py
def __iter__(self):
    """Lazily yield ``unpack_fn(sample)`` for every sample of the wrapped dataloader."""
    # Generator expression keeps the per-sample, on-demand evaluation of the original loop.
    yield from (self._unpack_fn(item) for item in self._dataloader)

__len__

__len__()

Return the number of samples in the dataloader.

Source code in model_navigator/api/utilities.py
def __len__(self):
    """Report the length of the wrapped dataloader."""
    sample_count = len(self._dataloader)
    return sample_count

find_max_batch_size_till_oom

find_max_batch_size_till_oom(framework, model, dataloader, batch_dim=0, max_batch_size_search_limit=None)

Find the maximum batch size for a model.

Search is performed by running the model on the dataloader until an OOM error is encountered.

Parameters:

  • framework (Framework) –

    The framework of the model.

  • model (Any) –

    The model.

  • dataloader (SizedDataLoader) –

    A SizedDataLoader.

  • batch_dim (int, default: 0 ) –

    The batch dimension of the model.

  • max_batch_size_search_limit (Optional[int], default: None ) –

    Limit the search for the maximum batch size to this value.

Source code in model_navigator/api/utilities.py
def find_max_batch_size_till_oom(
    framework: Framework,
    model: Any,
    dataloader: SizedDataLoader,
    batch_dim: int = 0,
    max_batch_size_search_limit: Optional[int] = None,
):
    """Find the maximum batch size for a model.

    Search is performed by running the model on the dataloader until an OOM error is encountered.
    Only the first sample of the dataloader is used as the profiling sample.

    Args:
        framework: The framework of the model.
        model: The model.
        dataloader: A SizedDataLoader. Must yield at least one sample.
        batch_dim: Index of the batch dimension in every input tensor shape.
        max_batch_size_search_limit: Limit the search for the maximum batch size to this value.

    Returns:
        The ``batch_size`` value of the last record written to the search results file.

    Raises:
        ValueError: If the framework is not supported or the dataloader yields no sample.
        ModelNavigatorProfilingError: If the search did not record any batch size.
    """
    if framework == Framework.TORCH:
        runner_name = "TorchCUDA"
    elif framework == Framework.TENSORFLOW:
        runner_name = "TensorFlowCUDA"
    elif framework == Framework.ONNX:
        runner_name = "OnnxCUDA"
    elif framework == Framework.JAX:
        runner_name = "Jax"
    else:
        # Without this branch an unsupported framework surfaced later as an
        # UnboundLocalError on runner_name.
        raise ValueError(f"Unsupported framework: {framework}")

    try:
        sample = next(iter(dataloader))
    except StopIteration:
        # A raw StopIteration is confusing for callers and turns into a RuntimeError
        # when this function is called from inside a generator.
        raise ValueError("Dataloader is empty; at least one sample is required.") from None

    pytree_metadata = PyTreeMetadata.from_sample(
        sample=sample, tensor_type=FRAMEWORK_TO_TENSOR_TYPE[framework], prefix="input"
    )
    input_metadata = TensorMetadata(pytree_metadata=pytree_metadata)

    profiling_sample = input_metadata.flatten_sample(sample=sample)
    profiling_sample = {name: to_numpy(tensor, from_framework=framework) for name, tensor in profiling_sample.items()}

    for name, tensor in profiling_sample.items():
        shape = list(tensor.shape)
        # NOTE(review): -1 presumably marks the batch dimension as dynamic - confirm against TensorMetadata.
        shape[batch_dim] = -1
        input_metadata.add(name=name, shape=shape, dtype=tensor.dtype)

    # The temporary file only provides a path for the finder results; it is removed on exit.
    with tempfile.NamedTemporaryFile() as temp_file:
        results_path = pathlib.Path(temp_file.name)

        runner = get_runner(runner_name)(
            model=model,
            input_metadata=input_metadata,
            output_metadata=TensorMetadata(),
        )  # pytype: disable=not-instantiable
        # Pre-bind the inference flags so every call made during the search uses them.
        runner.infer = functools.partial(runner.infer, return_raw_outputs=True, check_inputs=False)
        try:
            LOGGER.info("Starting max batch size search.")
            MaxBatchSizeFinder(
                profile=OptimizationProfile(max_batch_size=max_batch_size_search_limit),
                batch_dim=batch_dim,
                results_path=results_path,
            ).run(
                runner=runner,
                profiling_sample=profiling_sample,
                sample_id=0,
            )
        finally:
            # Read the results even when the search raised, so the last recorded
            # batch size is still reported before the exception propagates.
            max_batch_size = _read_max_batch_size(results_path)
            LOGGER.info(f"Max batch size: {max_batch_size}")
        return max_batch_size


def _read_max_batch_size(results_path):
    """Return the ``batch_size`` stored in the last non-blank JSON line of the results file.

    Args:
        results_path: Path of the file with one JSON record per line.

    Raises:
        ModelNavigatorProfilingError: If the file is missing, holds no record, or the
            last record has no ``batch_size`` entry.
    """
    if not results_path.is_file():
        raise ModelNavigatorProfilingError("Max batch size not found.")

    with open(results_path) as file:
        # Skip blank lines so an empty file or a trailing newline cannot break parsing.
        records = [line for line in file if line.strip()]

    if not records:
        # The temporary file always exists, so "nothing recorded" shows up as an empty
        # file; indexing the last line directly would raise IndexError here.
        raise ModelNavigatorProfilingError("Max batch size not found.")

    results_dict = json.loads(records[-1])
    if "batch_size" not in results_dict:
        raise ModelNavigatorProfilingError("Max batch size not found.")
    return results_dict["batch_size"]