Skip to content

Model Store API

model_navigator.triton.model_repository.add_model

add_model(model_repository_path, model_name, model_path, config, model_version=1)

Generate a model deployment inside the provided model store path.

The config must be the specialized configuration class for the backend on which the model is executed. For example:

  • ONNX model requires ONNXModelConfig
  • TensorRT model requires TensorRTModelConfig
  • TorchScript or Torch-TensorRT models require PyTorchModelConfig
  • TensorFlow SavedModel or TensorFlow-TensorRT models require TensorFlowModelConfig
  • Python model requires PythonModelConfig
  • TensorRT-LLM model requires TensorRTLLMModelConfig

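Parameters:

  • model_repository_path (Union[str, Path]) –

    Path where the deployment should be created

  • model_name (str) –

    Name under which the model is deployed in Triton Inference Server

  • model_path (Union[str, Path]) –

    Path to the model

  • config (Union[ONNXModelConfig, TensorRTModelConfig, PyTorchModelConfig, PythonModelConfig, TensorFlowModelConfig, TensorRTLLMModelConfig]) –

    Specialized configuration of the model for the backend on which the model is executed

  • model_version (int, default: 1 ) –

    Version of the model that is deployed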

Returns:

  • Path

    Path to created model store

Source code in model_navigator/triton/model_repository.py
def add_model(
    model_repository_path: Union[str, pathlib.Path],
    model_name: str,
    model_path: Union[str, pathlib.Path],
    config: Union[
        ONNXModelConfig,
        TensorRTModelConfig,
        PyTorchModelConfig,
        PythonModelConfig,
        TensorFlowModelConfig,
        TensorRTLLMModelConfig,
    ],
    model_version: int = 1,
) -> pathlib.Path:
    """Generate model deployment inside provided model store path.

    The config must be the specialized configuration class for the backend on which the model is executed.
    For example:

    - ONNX model requires ONNXModelConfig
    - TensorRT model requires TensorRTModelConfig
    - TorchScript or Torch-TensorRT models require PyTorchModelConfig
    - TensorFlow SavedModel or TensorFlow-TensorRT models require TensorFlowModelConfig
    - Python model requires PythonModelConfig
    - TensorRT-LLM model requires TensorRTLLMModelConfig

    Args:
        model_repository_path: Path where deployment should be created
        model_name: Name under which model is deployed in Triton Inference Server
        model_path: Path to model
        config: Specialized configuration of model for backend on which model is executed
        model_version: Version of model that is deployed

    Returns:
         Path to created model store

    Raises:
        ModelNavigatorWrongParameterError: When `config` is not an instance of any supported config class.
    """
    # Each entry: (config class, ModelConfigBuilder factory, keyword the factory expects for the config).
    # Entries are checked in order with isinstance and the first match wins.
    config_builders = (
        (ONNXModelConfig, ModelConfigBuilder.from_onnx_config, "onnx_config"),
        (TensorFlowModelConfig, ModelConfigBuilder.from_tensorflow_config, "tensorflow_config"),
        (PythonModelConfig, ModelConfigBuilder.from_python_config, "python_config"),
        (PyTorchModelConfig, ModelConfigBuilder.from_pytorch_config, "pytorch_config"),
        (TensorRTModelConfig, ModelConfigBuilder.from_tensorrt_config, "tensorrt_config"),
        (TensorRTLLMModelConfig, ModelConfigBuilder.from_tensorrt_llm_config, "tensorrt_llm_config"),
    )
    for config_type, build_model_config, config_keyword in config_builders:
        if isinstance(config, config_type):
            model_config = build_model_config(
                model_name=model_name,
                model_version=model_version,
                **{config_keyword: config},
            )
            break
    else:
        raise ModelNavigatorWrongParameterError(f"Unsupported model config provided: {config.__class__}")

    repository_root = pathlib.Path(model_repository_path)

    # Use the filename from the config; otherwise derive the default one from the backend (or platform).
    backend_name = model_config.backend or model_config.platform
    model_filename = model_config.default_model_filename or _get_default_filename(backend=backend_name)

    # Directory of this model version inside the repository.
    version_dir = _get_version_path(
        model_repository_path=repository_root,
        model_name=model_config.model_name,
        version=model_config.model_version,
    )

    if isinstance(config, TensorRTLLMModelConfig):
        # NOTE(review): this mutates the caller's config object, and it happens after `model_config`
        # was already built from it. Confirm that the builder/repository reads `engine_dir` lazily.
        config.engine_dir = version_dir / model_filename

    initial_state_files = _collect_initial_state_files(model_config=model_config)
    warmup_files = _collect_warmup_files(model_config=model_config)

    repository = _TritonModelRepository(
        model_repository_path=repository_root,
        model_name=model_name,
        model_version=model_version,
        model_filename=model_filename,
    )
    return repository.deploy_model(
        model_path=pathlib.Path(model_path),
        model_config=model_config,
        warmup_files=warmup_files,
        initial_state_files=initial_state_files,
    )

model_navigator.triton.model_repository.add_model_from_package

add_model_from_package(model_repository_path, model_name, package, model_version=1, strategies=None, response_cache=False, warmup=False)

Create a Triton model store entry for the best optimized model found in the package and save it to model_repository_path.

Parameters:

  • model_repository_path (Union[str, Path]) –

    Path where the model store is located

  • model_name (str) –

    Name under which model is deployed in Triton Inference Server

  • package (Package) –

    Package for which model store is created

  • model_version (int, default: 1 ) –

    Version of model that is deployed

  • strategies (Optional[List[RuntimeSearchStrategy]], default: None ) –

    List of strategies for finding the best model. Strategies are tried in the provided order; if one fails, the next strategy from the list is used. When none are provided, the strategies default to [MaxThroughputAndMinLatencyStrategy, MinLatencyStrategy]

  • response_cache (bool, default: False ) –

    Enable response cache for model

  • warmup (bool, default: False ) –

    Enable warmup for min and max batch size

Returns:

  • Path to created model store

Source code in model_navigator/triton/model_repository.py
def add_model_from_package(
    model_repository_path: Union[str, pathlib.Path],
    model_name: str,
    package: Package,
    model_version: int = 1,
    strategies: Optional[List[RuntimeSearchStrategy]] = None,
    response_cache: bool = False,
    warmup: bool = False,
) -> pathlib.Path:
    """Create the Triton Model Store with optimized model and save it to `model_repository_path`.

    Args:
        model_repository_path: Path where the model store is located
        model_name: Name under which model is deployed in Triton Inference Server
        package: Package for which model store is created
        model_version: Version of model that is deployed
        strategies: List of strategies for finding the best model. Strategies are tried in the provided order;
                    if one fails, the next strategy from the list is used. When none are provided, the
                    strategies default to [`MaxThroughputAndMinLatencyStrategy`, `MinLatencyStrategy`]
        response_cache: Enable response cache for model
        warmup: Enable warmup for min and max batch size

    Returns:
        Path to created model store

    Raises:
        ModelNavigatorEmptyPackageError: When the package contains no models.
        ModelNavigatorWrongParameterError: When the package batch dimension is neither 0 nor None.
        ModelNavigatorError: When no strategy finds an optimized model, when the selected model has no
            profiling results, or when the selected model format is not supported for Triton.
    """
    if package.is_empty():
        raise ModelNavigatorEmptyPackageError("No models available in the package. Triton deployment is not possible.")

    if package.config.batch_dim not in [0, None]:
        raise ModelNavigatorWrongParameterError(
            "Only models without batching or batch dimension on first place in shape are supported for Triton."
        )

    if strategies is None:
        strategies = [MaxThroughputAndMinLatencyStrategy(), MinLatencyStrategy()]

    batching = package.config.batch_dim == 0

    # Try strategies in order; the first one that finds a Triton-compatible runtime wins.
    runtime_result = None
    for strategy in strategies:
        try:
            runtime_result = RuntimeAnalyzer.get_runtime(
                models_status=package.status.models_status,
                strategy=strategy,
                formats=[fmt.value for fmt in TRITON_FORMATS],
                runners=[runner.name() for runner in TRITON_RUNNERS],
            )
            break
        except ModelNavigatorRuntimeAnalyzerError:
            LOGGER.debug(f"No model found with strategy: {strategy}")

    if runtime_result is None:
        raise ModelNavigatorError("No optimized model found in package.")

    model_format = runtime_result.model_status.model_config.format

    # The largest profiled batch size becomes Triton's max batch size; a missing (None) batch size counts as 0.
    profiled_batch_sizes = [
        profiling_result.batch_size if profiling_result.batch_size is not None else 0
        for profiling_result in runtime_result.runner_status.result[Performance.name]["profiling_results"]
    ]
    if not profiled_batch_sizes:
        # Without this check, max() would fail with an opaque "max() arg is an empty sequence".
        raise ModelNavigatorError(
            "No profiling results found for the selected model. "
            "Unable to determine the max batch size for Triton deployment."
        )
    max_batch_size = max(profiled_batch_sizes)

    model_warmup = {}
    if warmup:
        model_warmup = _prepare_model_warmup(
            batching=batching,
            max_batch_size=max_batch_size,
            package=package,
        )

    # TensorRT models get their tensor metadata adjusted by _prepare_tensorrt_metadata before building the config.
    if model_format == Format.TENSORRT:
        input_metadata = _prepare_tensorrt_metadata(package.status.input_metadata)
        output_metadata = _prepare_tensorrt_metadata(package.status.output_metadata)
    else:
        input_metadata, output_metadata = package.status.input_metadata, package.status.output_metadata

    inputs = _input_tensor_from_metadata(
        input_metadata,
        batching=batching,
    )
    outputs = _output_tensor_from_metadata(
        output_metadata,
        batching=batching,
    )

    if model_format == Format.ONNX:
        config = _onnx_config_from_runtime_result(
            batching=batching,
            max_batch_size=max_batch_size,
            inputs=inputs,
            outputs=outputs,
            response_cache=response_cache,
            runtime_result=runtime_result,
            warmup=model_warmup,
        )
    elif model_format in [Format.TF_SAVEDMODEL, Format.TF_TRT]:
        config = _tensorflow_config_from_runtime_result(
            batching=batching,
            max_batch_size=max_batch_size,
            inputs=inputs,
            outputs=outputs,
            response_cache=response_cache,
            runtime_result=runtime_result,
            warmup=model_warmup,
        )
    elif model_format in [Format.TORCHSCRIPT, Format.TORCH_TRT]:
        config = _pytorch_config_from_runtime_result(
            batching=batching,
            max_batch_size=max_batch_size,
            inputs=inputs,
            outputs=outputs,
            response_cache=response_cache,
            runtime_result=runtime_result,
            warmup=model_warmup,
        )
    elif model_format == Format.TENSORRT:
        config = _tensorrt_config_from_runtime_result(
            batching=batching,
            max_batch_size=max_batch_size,
            inputs=inputs,
            outputs=outputs,
            response_cache=response_cache,
            warmup=model_warmup,
        )
    else:
        raise ModelNavigatorError(f"Unsupported model format selected: {model_format}")

    return add_model(
        model_repository_path=model_repository_path,
        model_name=model_name,
        model_version=model_version,
        model_path=package.workspace.path / runtime_result.model_status.model_config.path,
        config=config,
    )