Triton Model Store API

model_navigator.triton.model_repository.add_model(model_repository_path, model_name, model_path, config, model_version=1)

Generate a model deployment inside the provided model store path.

The `config` argument requires the specialized configuration class matching the backend on which the model is executed:

- an ONNX model requires `ONNXModelConfig`
- a TensorRT model requires `TensorRTModelConfig`
- TorchScript or Torch-TensorRT models require `PyTorchModelConfig`
- TensorFlow SavedModel or TensorFlow-TensorRT models require `TensorFlowModelConfig`
- a Python model requires `PythonModelConfig`

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model_repository_path` | `Union[str, pathlib.Path]` | Path where the deployment should be created | *required* |
| `model_name` | `str` | Name under which the model is deployed in Triton Inference Server | *required* |
| `model_path` | `Union[str, pathlib.Path]` | Path to the model | *required* |
| `config` | `Union[ONNXModelConfig, TensorRTModelConfig, PyTorchModelConfig, PythonModelConfig, TensorFlowModelConfig]` | Specialized configuration of the model for the backend on which it is executed | *required* |
| `model_version` | `int` | Version of the model that is deployed | `1` |

Returns:

| Type | Description |
|------|-------------|
| `pathlib.Path` | Path to created model store |
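
For illustration, a minimal usage sketch deploying an ONNX model file, assuming `ONNXModelConfig` is importable from `model_navigator.triton` and that its default constructor arguments are acceptable (both are assumptions; the exact import location may differ between releases):

```python
import pathlib

from model_navigator.triton import ONNXModelConfig            # assumed import location
from model_navigator.triton.model_repository import add_model

# Backend-specific configuration for an ONNX model; the default constructor
# arguments are assumed to be acceptable here.
config = ONNXModelConfig()

model_store_path = add_model(
    model_repository_path=pathlib.Path("model_repository"),  # example repository path
    model_name="identity_onnx",                               # example model name
    model_path=pathlib.Path("workspace/model.onnx"),          # example model file
    config=config,
    model_version=1,
)
print(model_store_path)  # pathlib.Path pointing at the created model entry
```

On success the call returns the path to the created deployment, which follows the standard Triton model repository layout: a `config.pbtxt` next to a numbered version directory holding the model file (exact file names depend on the backend).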

Source code in model_navigator/triton/model_repository.py
def add_model(
    model_repository_path: Union[str, pathlib.Path],
    model_name: str,
    model_path: Union[str, pathlib.Path],
    config: Union[
        ONNXModelConfig,
        TensorRTModelConfig,
        PyTorchModelConfig,
        PythonModelConfig,
        TensorFlowModelConfig,
    ],
    model_version: int = 1,
) -> pathlib.Path:
    """Generate model deployment inside provided model store path.

    The config requires specialized configuration to be passed for backend on which model is executed. Example:
    - ONNX model requires ONNXModelConfig
    - TensorRT model requires TensorRTModelConfig
    - TorchScript or Torch-TensorRT models require PyTorchModelConfig
    - TensorFlow SavedModel or TensorFlow-TensorRT models require TensorFlowModelConfig
    - Python model requires PythonModelConfig

    Args:
        model_repository_path: Path where deployment should be created
        model_name: Name under which model is deployed in Triton Inference Server
        model_path: Path to model
        config: Specialized configuration of model for backend on which model is executed
        model_version: Version of model that is deployed

    Returns:
         Path to created model store
    """
    if isinstance(config, ONNXModelConfig):
        model_config = ModelConfigBuilder.from_onnx_config(
            model_name=model_name,
            model_version=model_version,
            onnx_config=config,
        )
    elif isinstance(config, TensorFlowModelConfig):
        model_config = ModelConfigBuilder.from_tensorflow_config(
            model_name=model_name,
            model_version=model_version,
            tensorflow_config=config,
        )
    elif isinstance(config, PythonModelConfig):
        model_config = ModelConfigBuilder.from_python_config(
            model_name=model_name,
            model_version=model_version,
            python_config=config,
        )
    elif isinstance(config, PyTorchModelConfig):
        model_config = ModelConfigBuilder.from_pytorch_config(
            model_name=model_name,
            model_version=model_version,
            pytorch_config=config,
        )
    elif isinstance(config, TensorRTModelConfig):
        model_config = ModelConfigBuilder.from_tensorrt_config(
            model_name=model_name,
            model_version=model_version,
            tensorrt_config=config,
        )
    else:
        raise ModelNavigatorWrongParameterError(f"Unsupported model config provided: {config.__class__}")

    triton_model_repository = _TritonModelRepository(model_repository_path=pathlib.Path(model_repository_path))
    return triton_model_repository.deploy_model(
        model_path=pathlib.Path(model_path),
        model_config=model_config,
    )

model_navigator.triton.model_repository.add_model_from_package(model_repository_path, model_name, package, model_version=1, strategy=None, response_cache=False)

Create the Triton Model Store with the optimized model and save it to `model_repository_path`.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model_repository_path` | `Union[str, pathlib.Path]` | Path where the model store is located | *required* |
| `model_name` | `str` | Name under which the model is deployed in Triton Inference Server | *required* |
| `model_version` | `int` | Version of the model that is deployed | `1` |
| `package` | `Package` | Package for which the model store is created | *required* |
| `strategy` | `Optional[RuntimeSearchStrategy]` | Strategy for finding the best runtime. When not set, `MaxThroughputAndMinLatencyStrategy` is used. | `None` |
| `response_cache` | `bool` | Enable the response cache for the model | `False` |

Returns:

| Type | Description |
|------|-------------|
| `pathlib.Path` | Path to created model store |
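
A minimal usage sketch creating a model store entry from a previously optimized package, assuming `nav.package.load` is the package-loading helper and that the default runtime-selection strategy is acceptable (both are assumptions):

```python
import model_navigator as nav
from model_navigator.triton.model_repository import add_model_from_package

# "linear.nav" is a hypothetical package file produced by an earlier
# optimization run; nav.package.load is assumed to be the loading helper.
package = nav.package.load("linear.nav")

model_store_path = add_model_from_package(
    model_repository_path="model_repository",   # example repository path
    model_name="linear",                         # example model name
    package=package,
    model_version=1,
    strategy=None,            # None -> MaxThroughputAndMinLatencyStrategy is used
    response_cache=False,
)
```

As the source below shows, leaving `strategy` as `None` selects the runtime using `MaxThroughputAndMinLatencyStrategy`; an empty package raises `ModelNavigatorEmptyPackageError`, and only models without batching or with the batch dimension in the first position are accepted.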

Source code in model_navigator/triton/model_repository.py
def add_model_from_package(
    model_repository_path: Union[str, pathlib.Path],
    model_name: str,
    package: Package,
    model_version: int = 1,
    strategy: Optional[RuntimeSearchStrategy] = None,
    response_cache: bool = False,
):
    """Create the Triton Model Store with optimized model and save it to `model_repository_path`.

    Args:
        model_repository_path: Path where the model store is located
        model_name: Name under which model is deployed in Triton Inference Server
        model_version: Version of model that is deployed
        package: Package for which model store is created
        strategy: Strategy for finding the best runtime.
                  When not set the `MaxThroughputAndMinLatencyStrategy` is used.
        response_cache: Enable response cache for model

    Returns:
        Path to created model store
    """
    if package.is_empty():
        raise ModelNavigatorEmptyPackageError("No models available in the package. Triton deployment is not possible.")

    if package.config.batch_dim not in [0, None]:
        raise ModelNavigatorWrongParameterError(
            "Only models without batching or batch dimension on first place in shape are supported for Triton."
        )

    if strategy is None:
        strategy = MaxThroughputAndMinLatencyStrategy()

    batching = package.config.batch_dim == 0

    runtime_result = RuntimeAnalyzer.get_runtime(
        models_status=package.status.models_status,
        strategy=strategy,
        formats=[fmt.value for fmt in TRITON_FORMATS],
        runners=[runner.name() for runner in TRITON_RUNNERS],
    )
    max_batch_size = max(
        profiling_results.batch_size
        for profiling_results in runtime_result.runner_status.result[Performance.name()]["profiling_results"]
    )

    if runtime_result.model_status.model_config.format == Format.ONNX:
        config = _onnx_config_from_runtime_result(
            batching=batching,
            max_batch_size=max_batch_size,
            response_cache=response_cache,
            runtime_result=runtime_result,
        )

    elif runtime_result.model_status.model_config.format in [Format.TF_SAVEDMODEL, Format.TF_TRT]:
        config = _tensorflow_config_from_runtime_result(
            batching=batching,
            max_batch_size=max_batch_size,
            response_cache=response_cache,
            runtime_result=runtime_result,
        )
    elif runtime_result.model_status.model_config.format in [Format.TORCHSCRIPT, Format.TORCH_TRT]:
        inputs = input_tensor_from_metadata(
            package.status.input_metadata,
            batching=batching,
        )
        outputs = output_tensor_from_metadata(
            package.status.output_metadata,
            batching=batching,
        )

        config = _pytorch_config_from_runtime_result(
            batching=batching,
            max_batch_size=max_batch_size,
            inputs=inputs,
            outputs=outputs,
            response_cache=response_cache,
            runtime_result=runtime_result,
        )
    elif runtime_result.model_status.model_config.format == Format.TENSORRT:
        config = _tensorrt_config_from_runtime_result(
            batching=batching,
            max_batch_size=max_batch_size,
            response_cache=response_cache,
        )
    else:
        raise ModelNavigatorError(
            f"Unsupported model format selected: {runtime_result.model_status.model_config.format}"
        )

    return add_model(
        model_repository_path=model_repository_path,
        model_name=model_name,
        model_version=model_version,
        model_path=package.workspace / runtime_result.model_status.model_config.path,
        config=config,
    )