Decorators

PyTriton provides decorators that operate on input requests, simplifying the task of passing requests to the model's inputs. We have prepared several useful decorators for converting the generic request input into common forms. You can also create custom decorators tailored to your requirements and chain them with other decorators.

Batch

In many cases, it is more convenient to receive input already batched as a NumPy array rather than as a list of separate requests. For such cases, we have prepared the @batch decorator, which adapts the generic input into batched form. It passes kwargs to the inference function, where each named input contains a NumPy array holding a batch assembled from the requests received by the Triton server.

Below, we show the difference between decorated and undecorated functions bound with Triton:

import numpy as np
from pytriton.decorators import batch
from pytriton.proxy.types import Request

# Sample input data with 2 requests - each with 2 inputs
input_data = [
    Request({'in1': np.array([[1, 1]]), 'in2': np.array([[2, 2]])}),
    Request({'in1': np.array([[1, 2]]), 'in2': np.array([[2, 3]])}),
]


def undecorated_identity_fn(requests):
    print(requests)
    # As expected, requests = [
    #     Request({'in1': np.array([[1, 1]]), 'in2': np.array([[2, 2]])}),
    #     Request({'in1': np.array([[1, 2]]), 'in2': np.array([[2, 3]])}),
    # ]
    results = requests
    return results


@batch
def decorated_identity_fn(in1, in2):
    print(in1, in2)
    # in1 = np.array([[1, 1], [1, 2]])
    # in2 = np.array([[2, 2], [2, 3]])
    # Inputs are batched by the `@batch` decorator and passed to the function as kwargs,
    # so they are automatically mapped to the in1 and in2 parameters.
    # They could also be captured generically with **kwargs instead:
    # def decorated_infer_fn(**kwargs):
    return {"out1": in1, "out2": in2}


undecorated_identity_fn(input_data)
decorated_identity_fn(input_data)
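
As a sketch of the expected round trip (assuming the default behavior of @batch, which also splits the returned arrays back into one response per request):

# decorated_identity_fn(input_data) should produce one response per request:
# [
#     {'out1': np.array([[1, 1]]), 'out2': np.array([[2, 2]])},
#     {'out1': np.array([[1, 2]]), 'out2': np.array([[2, 3]])},
# ]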

More examples using the @batch decorator with different frameworks are shown below.

Example implementation for TensorFlow model:

import numpy as np
import tensorflow as tf

from pytriton.decorators import batch


@batch
def infer_tf_fn(**inputs: np.ndarray):
    (images_batch,) = inputs.values()
    images_batch_tensor = tf.convert_to_tensor(images_batch)
    output1_batch = model.predict(images_batch_tensor)  # `model` is a previously loaded TensorFlow/Keras model
    return [output1_batch]

Example implementation for PyTorch model:

import numpy as np
import torch

from pytriton.decorators import batch


@batch
def infer_pt_fn(**inputs: np.ndarray):
    (input1_batch,) = inputs.values()
    input1_batch_tensor = torch.from_numpy(input1_batch).to("cuda")
    output1_batch_tensor = model(input1_batch_tensor)  # `model` is a previously loaded torch.nn.Module moved to CUDA
    output1_batch = output1_batch_tensor.cpu().detach().numpy()
    return [output1_batch]

Example implementation with named inputs and outputs:

import numpy as np
from pytriton.decorators import batch


@batch
def add_subtract_fn(a: np.ndarray, b: np.ndarray):
    return {"add": a + b, "sub": a - b}


@batch
def multiply_fn(**inputs: np.ndarray):
    a = inputs["a"]
    b = inputs["b"]
    return [a * b]

Example implementation with strings:

import numpy as np
from transformers import pipeline
from pytriton.decorators import batch

CLASSIFIER = pipeline("zero-shot-classification", model="facebook/bart-base", device=0)


@batch
def classify_text_fn(text_array: np.ndarray):
    text = text_array[0]  # text_array contains one string at index 0
    text = text.decode("utf-8")  # string is stored in byte array encoded in utf-8
    result = CLASSIFIER(text)
    return [np.array(result)]  # return statistics generated by classifier

Sample

@sample - takes the first request and converts it into named inputs. This decorator is useful with non-batching models. Instead of a one-element list of requests, the function receives the named inputs directly as kwargs.

from pytriton.decorators import sample


@sample
def infer_fn(sequence):
    pass
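
For illustration, here is a minimal sketch (the input name and echoing behavior are hypothetical) of how a non-batching callable receives and returns data with @sample:

import numpy as np
from pytriton.decorators import sample


@sample
def echo_fn(sequence: np.ndarray):
    # `sequence` holds the data of the single request's "sequence" input
    return {"sequence": sequence}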

Group by keys

@group_by_keys - groups requests with the same set of keys and calls the wrapped function for each group separately. This decorator is convenient to use before batching because the batching decorator requires a consistent set of inputs as it stacks them into batches.

from pytriton.decorators import batch, group_by_keys


@group_by_keys
@batch
def infer_fn(mandatory_input, optional_input=None):
    # perform inference
    pass
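
To illustrate the grouping (a sketch with hypothetical data), requests that share the same key set are dispatched together:

import numpy as np
from pytriton.proxy.types import Request

# Requests 1 and 3 carry the same key set; request 2 differs:
requests = [
    Request({'mandatory_input': np.array([[1]])}),
    Request({'mandatory_input': np.array([[2]]), 'optional_input': np.array([[10]])}),
    Request({'mandatory_input': np.array([[3]])}),
]

# @group_by_keys splits this list into two groups and calls the wrapped
# (batched) function once per group:
#   call 1: requests 1 and 3 (mandatory_input only)
#   call 2: request 2 (mandatory_input and optional_input)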

Group by values

@group_by_values(*keys) - groups requests with the same input value (for selected keys) and calls the wrapped function for each group separately. This decorator is particularly useful with models requiring dynamic parameters sent by users, such as temperature. In this case, we want to run the model only for requests with the same temperature value.

from pytriton.decorators import batch, group_by_values


@batch
@group_by_values('temperature')
def infer_fn(mandatory_input, temperature):
    # perform inference
    pass
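
Sketched on hypothetical data: with @batch applied first, @group_by_values('temperature') splits the batched input along the distinct temperature values:

# Batched input arriving at @group_by_values:
#   temperature = np.array([[0.2], [0.9], [0.2]])
# The wrapped function is called once per distinct value:
#   call 1: rows 0 and 2 (temperature 0.2)
#   call 2: row 1 (temperature 0.9)
# The partial outputs are expected to be merged back into the original row order.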

Fill optionals

@fill_optionals(**defaults) - fills missing inputs in requests with default values provided by the user. If model owners have default values for some optional parameters, it is a good idea to apply this decorator at the beginning of the chain, so that the decorators below it can form larger consistent groups to send to the inference callable.

import numpy as np
from pytriton.decorators import batch, fill_optionals, group_by_values


@fill_optionals(temperature=np.array([0.7]))
@batch
@group_by_values('temperature')
def infer_fn(mandatory_input, temperature):
    # perform inference
    pass
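
The effect on a single request can be sketched as follows (hypothetical data):

# A request missing `temperature` is completed with the default before any
# further processing:
#   Request({'mandatory_input': np.array([[1]])})
# becomes
#   Request({'mandatory_input': np.array([[1]]), 'temperature': np.array([0.7])})
# so @batch and @group_by_values below see a consistent set of inputs.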


Pad batch

@pad_batch - appends the last row to the input multiple times to achieve the desired batch size (preferred batch size or max batch size from the model config, whichever is closer to the current input size).

from pytriton.decorators import batch, pad_batch

@batch
@pad_batch
def infer_fn(mandatory_input):
    # this model requires mandatory_input batch to be the size provided in the model config
    pass
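
A sketch of the padding behavior (the batch sizes are hypothetical):

# Assume the model config allows preferred batch sizes [2, 4] and max batch size 8.
# For an incoming batch of 3 rows:
#   mandatory_input = np.array([[1], [2], [3]])
# @pad_batch repeats the last row to reach the closest allowed size (4):
#   mandatory_input = np.array([[1], [2], [3], [3]])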

First value

@first_value - this decorator takes the first element from the batch for each input selected by the keys parameter. If that value is a one-element array, it is converted to a scalar. This decorator is convenient for dynamic model parameters that users send in requests. Use @group_by_values beforehand so that every row in a batch carries the same value.

import numpy as np
from pytriton.decorators import batch, fill_optionals, first_value, group_by_values

@fill_optionals(temperature=np.array([0.7]))
@batch
@group_by_values('temperature')
@first_value('temperature')
def infer_fn(mandatory_input, temperature):
    # perform inference with a scalar temperature value
    pass
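
Sketched on hypothetical data, the reduction works like this:

# After @group_by_values, every row in the batch has the same temperature:
#   temperature = np.array([[0.7], [0.7], [0.7]])
# @first_value('temperature') takes the first element and, since it is a
# one-element array, converts it to the scalar 0.7, which the function
# can use directly as a number.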

Triton context

The @triton_context decorator provides an additional argument called triton_context, from which you can read the model config.

from pytriton.decorators import triton_context


@triton_context
def infer_fn(input_list, **kwargs):
    model_config = kwargs['triton_context'].model_config
    # perform inference using some information from model_config
    pass
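
For instance, the config can be used to adapt the callable to the batch size configured at bind time (a sketch, assuming the config object exposes a max_batch_size attribute like the ModelConfig passed to Triton's bind):

from pytriton.decorators import triton_context


@triton_context
def infer_fn(input_list, **kwargs):
    model_config = kwargs['triton_context'].model_config
    max_batch_size = model_config.max_batch_size  # assumed attribute
    # perform inference, e.g. chunking work to respect max_batch_size
    pass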

Stacking multiple decorators

Here is an example of stacking multiple decorators together. We recommend applying decorators that operate on the whole list of requests first (e.g. @fill_optionals, @group_by_keys), then those that convert requests into batched kwargs (e.g. @batch), and finally those that operate on the batched inputs (e.g. @group_by_values, @first_value). Place the @triton_context decorator last in the chain.

import numpy as np
from pytriton.decorators import batch, fill_optionals, first_value, group_by_keys, group_by_values, triton_context


@fill_optionals(temperature=np.array([0.7]))
@group_by_keys
@batch
@group_by_values('temperature')
@first_value('temperature')
@triton_context
def infer(triton_context, mandatory_input, temperature, opt1=None, opt2=None):
    model_config = triton_context.model_config
    # perform inference using:
    #   - some information from model_config
    #   - scalar temperature value
    #   - optional parameters opt1 and/or opt2
    pass
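
To put the chain to work, the decorated callable is bound to a Triton model as usual. Below is a minimal sketch; the model name, tensor dtypes, and shapes are illustrative, and marking inputs with optional=True assumes a PyTriton version that supports optional tensors:

import numpy as np
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton

with Triton() as triton:
    triton.bind(
        model_name="ExampleModel",  # hypothetical model name
        infer_func=infer,
        inputs=[
            Tensor(name="mandatory_input", dtype=np.float32, shape=(-1,)),
            Tensor(name="temperature", dtype=np.float32, shape=(1,), optional=True),
            Tensor(name="opt1", dtype=np.float32, shape=(-1,), optional=True),
            Tensor(name="opt2", dtype=np.float32, shape=(-1,), optional=True),
        ],
        outputs=[Tensor(name="output", dtype=np.float32, shape=(-1,))],
        config=ModelConfig(max_batch_size=16),
    )
    triton.serve()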