PyTriton
PyTriton is a Flask/FastAPI-like interface that simplifies Triton's deployment in Python environments. The library allows serving Machine Learning models directly from Python through NVIDIA's Triton Inference Server.
How it works
In PyTriton, as in Flask or FastAPI, you can define any Python function that executes a machine learning model prediction and expose it through an HTTP/gRPC API. PyTriton installs Triton Inference Server in your environment and uses it to handle HTTP/gRPC requests and responses. The library provides a Python API for attaching a Python function to Triton, plus a communication layer that sends and receives data between Triton and the function. This lets you use the performance features of Triton Inference Server, such as dynamic batching or the response cache, without changing your model environment, which improves the performance of GPU inference for models implemented in Python. The solution is framework-agnostic and can be used with frameworks like PyTorch, TensorFlow, or JAX.
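To make the idea of batching more concrete, here is a minimal pure-Python sketch of grouping queued requests so that the function runs once per batch rather than once per request. It is only an illustration: the names run_in_batches and double_all are hypothetical, and Triton's real dynamic batcher is considerably more sophisticated than this fixed-size grouping.

```python
from typing import Callable, List, Sequence


def run_in_batches(
    requests: Sequence[float],
    batched_fn: Callable[[List[float]], List[float]],
    max_batch_size: int,
) -> List[float]:
    """Group queued requests into batches and call the function once per batch."""
    results: List[float] = []
    for start in range(0, len(requests), max_batch_size):
        batch = list(requests[start:start + max_batch_size])
        outputs = batched_fn(batch)  # one call handles the whole batch
        if len(outputs) != len(batch):
            raise ValueError("batched_fn must return one output per input")
        results.extend(outputs)  # results stay in request order
    return results


calls: List[int] = []  # records the size of every batch the function saw


def double_all(batch: List[float]) -> List[float]:
    """Hypothetical stand-in for a model that processes a whole batch at once."""
    calls.append(len(batch))
    return [2 * x for x in batch]


results = run_in_batches([1.0, 2.0, 3.0, 4.0, 5.0], double_all, max_batch_size=2)
```

Five requests are served with three function calls here; on a GPU, fewer and larger calls are usually where the performance gain comes from.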
Serving the models
PyTriton provides an option to serve your Python model using Triton Inference Server to handle HTTP/gRPC requests and pass the input/output tensors to and from the model. We use a blocking mode where the application is a long-lived process deployed in your cluster to serve the requests from clients.
Before you run the model for serving, the inference callback function has to be defined. The inference callback receives the inputs and should return the model outputs:
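    import numpy as np
    from pytriton.decorators import batch


    @batch
    def infer_fn(**inputs: np.ndarray):
        # Inputs arrive as keyword arguments holding batched arrays.
        input1, input2 = inputs.values()
        # `model` stands for your own model object, created elsewhere.
        outputs = model(input1, input2)
        return [outputs]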
The infer_fn receives the batched input data for the model and should return the batched outputs.
In the next step, you need to create a connection between Triton and the model. To do that, use the Triton class and call its bind method, which creates a dedicated connection between Triton Inference Server and the defined infer_fn.
In the blocking mode, we suggest using the Triton object as a context manager, where multiple models can be loaded in the way presented below:
    from pytriton.triton import Triton
    from pytriton.model_config import ModelConfig, Tensor

    with Triton() as triton:
        triton.bind(
            model_name="MyModel",
            infer_func=infer_fn,
            inputs=[
                Tensor(dtype=bytes, shape=(1,)),  # sample containing single bytes value
                Tensor(dtype=bytes, shape=(-1,)),  # sample containing vector of bytes
            ],
            outputs=[
                Tensor(dtype=np.float32, shape=(-1,)),
            ],
            config=ModelConfig(max_batch_size=16),
        )
At this point, you have defined how the model has to be handled by Triton and where the HTTP/gRPC requests for the model have to be directed. The last part for serving the model is to call the serve method on the Triton object:
    with Triton() as triton:
        # bind the model(s) as shown above, then start serving
        triton.serve()
When the .serve() method is called on the Triton object, inference queries can be sent to localhost:8000/v2/models/MyModel, and infer_fn is called to handle each query.
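As a rough illustration of what such a query could look like over HTTP, the sketch below builds a JSON request body in the KServe v2 inference protocol format that Triton implements, where inference requests are POSTed to the model URL followed by /infer. The tensor names INPUT_1 and INPUT_2 are an assumption, since no names were given when binding the model; the helper build_infer_request is hypothetical and not part of PyTriton.

```python
import json

# Assumption: the server runs locally on the HTTP port mentioned above.
INFER_URL = "http://localhost:8000/v2/models/MyModel/infer"


def build_infer_request(single_value: str, values: list) -> dict:
    """Build a KServe v2 style JSON body for a batch containing one sample."""
    return {
        "inputs": [
            {
                # Assumed tensor name; verify it against the model metadata.
                "name": "INPUT_1",
                "shape": [1, 1],  # batch of 1, one bytes value per sample
                "datatype": "BYTES",
                "data": [single_value],
            },
            {
                # Assumed tensor name; verify it against the model metadata.
                "name": "INPUT_2",
                "shape": [1, len(values)],  # batch of 1, variable-length vector
                "datatype": "BYTES",
                "data": list(values),
            },
        ]
    }


body = json.dumps(build_infer_request("query", ["a", "b", "c"]))
# `body` could then be sent as an HTTP POST to INFER_URL,
# for example with urllib.request or any HTTP client.
```

The actual tensor names and datatypes can be read from the model metadata, which the server returns for a GET request on the model URL itself.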
Working in the Jupyter Notebook
The package provides an option to work with your model inside a Jupyter Notebook. We call this background mode: the model is deployed on Triton Inference Server to handle HTTP/gRPC requests, while you remain free to perform other actions after the model has been loaded and serving has started.
With infer_fn defined in the same way as described in the Serving the models section, you can use the Triton object without a context manager:
    from pytriton.triton import Triton

    triton = Triton()
In the next step, the model has to be loaded for serving in Triton Inference Server (which is also the same as in the serving example):
    import numpy as np
    from pytriton.decorators import batch
    from pytriton.model_config import ModelConfig, Tensor


    @batch
    def infer_fn(**inputs: np.ndarray):
        input1, input2 = inputs.values()
        outputs = input1 + input2
        return [outputs]


    triton.bind(
        model_name="MyModel",
        infer_func=infer_fn,
        inputs=[
            Tensor(shape=(1,), dtype=np.float32),
            Tensor(shape=(-1,), dtype=np.float32),
        ],
        outputs=[Tensor(shape=(-1,), dtype=np.float32)],
        config=ModelConfig(max_batch_size=16),
    )
Finally, to run the model in background mode, use the run method:
    triton.run()
When the .run() method is called on the Triton object, inference queries can be sent to localhost:8000/v2/models/MyModel, and infer_fn is called to handle each query.
The Triton server can be stopped at any time using the stop method:
    triton.stop()
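In a notebook it is easy to leave a background server running after an error. One possible pattern, sketched here under assumptions, is to wrap the work in try/finally so that stop is always called. To keep the sketch runnable without a server, it uses StandInServer, a hypothetical class that only mimics the run and stop method names described above; it is not part of PyTriton.

```python
class StandInServer:
    """Hypothetical stand-in exposing only the run/stop method names used above."""

    def __init__(self) -> None:
        self.running = False

    def run(self) -> None:
        self.running = True

    def stop(self) -> None:
        self.running = False


def query_while_running(server) -> str:
    """Start the server, do some work, and always stop it afterwards."""
    server.run()
    try:
        if not server.running:
            raise RuntimeError("server failed to start")
        return "work done while serving"
    finally:
        server.stop()  # runs even if the work above raises


server = StandInServer()
message = query_while_running(server)
```

The same shape should carry over to any object that offers run and stop methods, since the helper relies on nothing else.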
In-depth Topics and Examples
Model Deployment
Fine-tune your model deployment strategy with our targeted documentation:
- Initialize Triton for seamless startup.
- Bind your models to Triton for enhanced communication.
- Adjust your binding configurations for improved control.
- Expand your reach by deploying on clusters.
- Master the use of Triton in remote mode.
Inference Management
Hone your understanding of inference process management through PyTriton:
- Tailor the Inference Callable to your model's requirements.
- Use decorators to simplify your inference callbacks.
- Incorporate custom parameters/headers for flexibility.
Dive into Examples
Visit the examples directory for a curated selection of use cases ranging from basic to advanced, including:
- Standard model serving scenarios with different frameworks: PyTorch, TensorFlow2, JAX.
- Advanced setups like online learning, multi-node execution, or Kubernetes deployments.
Troubleshooting
If you encounter any obstacles, our Known Issues page is a helpful resource for troubleshooting common challenges.
Streaming (alpha)
We introduced a new alpha feature in PyTriton that allows streaming partial responses from a model. It is based on the NVIDIA Triton Inference Server decoupled models feature. See the example in examples/huggingface_dialogpt_streaming_pytorch.
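The idea of partial responses can be illustrated with a plain Python generator that yields an intermediate result after every step instead of returning one final answer. This is only a conceptual sketch: stream_tokens and the partial_text key are hypothetical names, the "model" is a trivial stand-in, and the wiring into PyTriton itself is what the example directory mentioned above covers.

```python
from typing import Dict, Iterator, List


def stream_tokens(prompt: str) -> Iterator[Dict[str, List[str]]]:
    """Yield one partial response per generated token instead of one final answer."""
    generated: List[str] = []
    for token in prompt.split():  # stand-in for a real model's decoding loop
        generated.append(token.upper())
        # Each yield corresponds to one partial response sent to the client.
        yield {"partial_text": list(generated)}


partials = list(stream_tokens("hello streaming world"))
```

A client consuming such a stream can display output as soon as the first partial response arrives, which is useful for text generation models.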
Profiling model
The Perf Analyzer can be used to profile models served through PyTriton. We have prepared an example of using the Perf Analyzer to profile the BART PyTorch model. The example code can be found in examples/perf_analyzer.
OpenTelemetry is a set of APIs, libraries, agents, and instrumentation that provides observability for cloud-native software. We have prepared a guide on how to use OpenTelemetry with PyTriton.
What next?
Read more about using PyTriton in the Quick Start and Examples, and find more options for configuring Triton, models, and cluster deployment in the Deploying Models section.
Details about classes and methods can be found on the API Reference page.
If you run into issues that are difficult to investigate, you can use the pytriton-check tool. Its usage is described in the Basic Troubleshooting section.