
How to use PyTriton client to split a large input into smaller batches and send them to the server in parallel

In this article, you will learn how to use PyTriton clients to create a chunking client that can handle inputs that are larger than the maximum batch size of your model.

First, you need to create a model that can process a batch of inputs and produce a batch of outputs. For simplicity, let's assume that your model can only handle two inputs at a time. We will call this model "Batch2" and run it on a local Triton server.
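
As a rough illustration, here is a minimal sketch of how such a model could be served with PyTriton. The identity inference function and the exact Tensor shapes are assumptions made for this example, but the model name "Batch2" and the batch limit of 2 match the rest of the article:

# A minimal sketch of serving an identity "Batch2" model with max_batch_size=2 (assumed setup)
import numpy as np

from pytriton.decorators import batch
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton


@batch
def infer_fn(INPUT_1):
    # Identity model: return the input batch unchanged
    return {"OUTPUT_1": INPUT_1}


with Triton() as triton:
    triton.bind(
        model_name="Batch2",
        infer_func=infer_fn,
        inputs=[Tensor(name="INPUT_1", dtype=np.int32, shape=(1,))],
        outputs=[Tensor(name="OUTPUT_1", dtype=np.int32, shape=(1,))],
        config=ModelConfig(max_batch_size=2),
    )
    triton.serve()  # blocks and serves the model on localhost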

Next, you need to create a client that can send requests to your model. In this article, we will use the FuturesModelClient, which returns a Future object for each request. A Future object is a placeholder that can be used to get the result or check the status of the request later.
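
For example, a request that already fits within the batch size limit can be sent and collected like this (assuming the "Batch2" model above is running locally):

import numpy as np
from pytriton.client import FuturesModelClient

with FuturesModelClient("localhost", "Batch2") as client:
    input_tensor = np.zeros((2, 1), dtype=np.int32)  # a batch of 2 fits the limit
    future = client.infer_batch(INPUT_1=input_tensor)
    print(future.done())    # check the status of the request
    print(future.result())  # block until the result is available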

However, there is a problem with using the FuturesModelClient directly. If you try to send an input that is larger than the maximum batch size of your model, you will get an error. For example, the following code tries to send an input of size 4 to the "Batch2" model, which has a maximum batch size of 2:

import numpy as np
from pytriton.client import FuturesModelClient

with FuturesModelClient(f"localhost", "Batch2") as client:
    input_tensor = np.zeros((4, 1), dtype=np.int32)
    print(client.infer_batch(input_tensor).result())

This code will raise an exception like this:

PyTritonClientInferenceServerError: Error occurred during inference request. Message: [request id: 0] inference request batch-size must be <= 2 for 'Batch2'
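
The limit itself can be read from the model configuration, which FuturesModelClient also exposes as a future; this is the same call the ChunkingClient below relies on:

from pytriton.client import FuturesModelClient

with FuturesModelClient("localhost", "Batch2") as client:
    # model_config() returns a Future; its result carries the model's max_batch_size
    max_batch_size = client.model_config().result().max_batch_size
    print(max_batch_size)  # 2 for the "Batch2" model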

To solve this problem, we can use a ChunkingClient class that inherits from FuturesModelClient and overrides the infer_batch method. The ChunkingClient class takes a chunking strategy as an argument, which is a function that takes the input dictionary and the maximum batch size as parameters and yields smaller dictionaries of inputs. The default chunking strategy simply splits the input along the first dimension according to the maximum batch size. For example, if the input is {"INPUT_1": np.zeros((5, 1), dtype=np.int32)} and the maximum batch size is 2, then the default chunking strategy will yield:

{"INPUT_1": np.zeros((2, 1), dtype=np.int32)}
{"INPUT_1": np.zeros((2, 1), dtype=np.int32)}
{"INPUT_1": np.zeros((1, 1), dtype=np.int32)}

You can also define your own chunking strategy if you have more complex logic for splitting your input.
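
As a sketch, a hypothetical custom strategy with the same signature could cap each chunk at a fixed size; the fixed_size_chunking_strategy name and the chunk_size parameter below are illustrative, not part of PyTriton:

# A hypothetical custom strategy: cap each chunk at chunk_size items,
# while still respecting the model's max batch size
def fixed_size_chunking_strategy(kwargs, max_batch_size, chunk_size=1):
    step = min(chunk_size, max_batch_size)
    total = next(iter(kwargs.values())).shape[0]
    for i in range(0, total, step):
        yield {key: value[i:i+step] for key, value in kwargs.items()}

# Pass it to the client via the chunking_strategy argument:
# chunker_client = ChunkingClient("localhost", "Batch2", chunking_strategy=fixed_size_chunking_strategy)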

# Define a ChunkingClient class that inherits from FuturesModelClient and splits the input into smaller batches
import numpy as np

from pytriton.client import FuturesModelClient

class ChunkingClient(FuturesModelClient):
    def __init__(self, host, model_name, chunking_strategy=None, max_workers=None):
        super().__init__(host, model_name, max_workers=max_workers)
        self.chunking_strategy = chunking_strategy or self.default_chunking_strategy

    def default_chunking_strategy(self, kwargs, max_batch_size):
        # Split the input along the first dimension into chunks of at most max_batch_size items
        size_of_dimension_0 = self.find_size_0(kwargs)
        for i in range(0, size_of_dimension_0, max_batch_size):
            yield {key: value[i:i+max_batch_size] for key, value in kwargs.items()}

    def find_size_0(self, kwargs):
        # Check the size of the first dimension of each tensor and raise an error if the sizes are inconsistent or invalid
        size_of_dimension_0 = None
        for key, value in kwargs.items():
            if isinstance(value, np.ndarray):
                if value.ndim > 0:
                    size = value.shape[0]
                    if size_of_dimension_0 is None or size_of_dimension_0 == size:
                        size_of_dimension_0 = size
                    else:
                        raise ValueError("The tensors have different sizes in the first dimension")
                else:
                    raise ValueError("The tensor has no first dimension")
            else:
                raise TypeError("The value is not a numpy array")
        return size_of_dimension_0

    def infer_batch(self, *args, **kwargs):
        max_batch_size = self.model_config().result().max_batch_size
        # Send the smaller batches to the server in parallel and yield the futures with results
        futures = [super(ChunkingClient, self).infer_batch(*args, **chunk) for chunk in self.chunking_strategy(kwargs, max_batch_size)]
        for future in futures:
            yield future

To use the ChunkingClient class, you can create an instance of it and use it in a context manager. For example:

# Use the ChunkingClient class with the default strategy to send an input of size 5 to the "Batch2" model
import numpy as np

chunker_client = ChunkingClient("localhost", "Batch2")
results = []
with chunker_client as client:
    input_tensor = np.zeros((5, 1), dtype=np.int32)
    # Print the results of each future without concatenating them
    for future in client.infer_batch(INPUT_1=input_tensor):
        results.append(future.result())
print(results)

This code will print:

{'OUTPUT_1': array([[0],
       [0]], dtype=int32)}
{'OUTPUT_1': array([[0],
       [0]], dtype=int32)}
{'OUTPUT_1': array([[0]], dtype=int32)}

You can see that the input is split into three batches of sizes 2, 2, and 1, and each batch is sent to the server in parallel. The results are returned as futures that can be accessed individually without concatenating them.
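
If you prefer a single output tensor instead, the partial results can be concatenated afterwards along the batch dimension, for example:

import numpy as np

# Stitch the per-chunk outputs from the previous example back into one array
full_output = np.concatenate([result["OUTPUT_1"] for result in results], axis=0)
print(full_output.shape)  # (5, 1)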