Module(module, name=None, input_mapping=None, output_mapping=None, timer=None, forward_func=None, batching=None, precision='fp32', model_path=None)

Bases: ObjectProxy

Inplace Optimize module wrapper.

This class wraps a torch module and provides inplace optimization functionality.
Depending on the configuration set in the optimize config, the module is
optimized, recorded, or passed through. The wrapper can be used in place of a
torch module and behaves identically to the original module.
Parameters:

- module (Module) – torch module to wrap.
- name (Optional[str], default: None) – module name; when None, the name is derived from the module object.
- input_mapping (Optional[Callable], default: None) – function to map module inputs to the expected input.
- output_mapping (Optional[Callable], default: None) – function to map module outputs to the expected output.
- timer (Optional[Timer], default: None) – optional timer used to record module execution time.
- forward_func (Optional[str], default: None) – name of the forwarding function used by the module; when None, the module call is used.
- batching (Optional[bool], default: None) – enable or disable batching on the first (index 0) dimension of the model.
- precision (PrecisionType, default: 'fp32') – precision used for module optimization.
- model_path (Optional[Union[str, Path]], default: None) – optional path to an ONNX or TensorRT model file; when provided, the model is loaded from the file instead of being converted.
Note

If specified, batching takes precedence over the corresponding value in the
configuration passed to nav.profile.
 
Example

```python
import torch
import model_navigator as nav

model = torch.nn.Linear(10, 10)
model = nav.Module(model)
```
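A slightly fuller sketch of the constructor options; the explicit name, mapping functions, and model path below are illustrative assumptions, not values required by the API:

```python
import torch
import model_navigator as nav

model = torch.nn.Linear(10, 10)
model = nav.Module(
    model,
    name="linear",              # hypothetical explicit name
    input_mapping=lambda x: x,  # pass inputs through unchanged
    output_mapping=lambda x: x, # pass outputs through unchanged
    batching=True,              # overrides the nav.profile configuration
    model_path="linear.onnx",   # assumption: a pre-exported ONNX file exists
)
```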
 
Initialize Module.

Source code in model_navigator/inplace/wrapper.py

```python
def __init__(
    self,
    module: "torch.nn.Module",
    name: Optional[str] = None,
    input_mapping: Optional[Callable] = None,
    output_mapping: Optional[Callable] = None,
    timer: Optional[Timer] = None,
    forward_func: Optional[str] = None,
    batching: Optional[bool] = None,
    precision: PrecisionType = "fp32",
    model_path: Optional[Union[str, pathlib.Path]] = None,
) -> None:
    """Initialize Module."""
    super().__init__(module)
    if not isinstance(module, torch.nn.Module):
        raise ModelNavigatorUserInputError("Only torch modules are supported.")
    self._name = name or get_object_name(module)
    self._input_mapping = input_mapping or (lambda x: x)
    self._output_mapping = output_mapping or (lambda x: x)
    self._optimize_config = None
    if timer:
        self.add_timer(timer=timer)
    else:
        self._module_timer = None
    current_forward = None
    if forward_func:
        try:
            current_forward = getattr(module, forward_func)
        except AttributeError as e:
            raise ModelNavigatorUserInputError(f"Forward method must exist, got {forward_func}.") from e
        setattr(module, forward_func, lambda *args, **kwargs: Module.__call__(self, *args, **kwargs))
    self.batching = batching
    self.precision = precision
    if isinstance(model_path, str):
        self.model_path = pathlib.Path(model_path)
    else:
        self.model_path = model_path
    if self.model_path is not None and self.model_path.suffix not in [
        ".onnx",
        ".plan",
    ]:  # pytype: disable=attribute-error
        raise ModelNavigatorUserInputError(
            f"model_path must be either ONNX or TensorRT model file with .onnx or .plan extension, got {self.model_path}."
        )
    self._device = get_module_device(module) or torch.device("cpu")
    self._wrapper = RecordingModule(
        module,
        # OptimizeConfig(),
        self._name,
        self._input_mapping,
        self._output_mapping,
        forward=current_forward,
    )
    module_registry.register(self._name, self)
```
 
                   
  
is_optimized (property)

    Check if the module is optimized.

is_ready_for_optimization (property)

    Check if the module is ready for optimization.

optimize_config (property, writable)

    Optimize configuration for the module.

wrapper (property)

    Return the wrapper module.
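A small sketch of how these properties could gate the flow; hypothetical usage, consistent with the assertions in optimize() below:

```python
# Optimize only when recording is complete and no optimization has run yet.
if model.is_ready_for_optimization and not model.is_optimized:
    model.optimize()
```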
     
 
__call__

__call__(*args, **kwargs)

    Call the wrapped module.

    This method overrides the __call__ method of the wrapped module. If the
    module is already optimized, the call is dispatched to the optimized one.

Source code in model_navigator/inplace/wrapper.py

```python
def __call__(self, *args, **kwargs) -> Any:
    """Call the wrapped module.

    This method overrides the __call__ method of the wrapped module.
    If the module is already optimized it is replaced with the optimized one.
    """
    if self._module_timer and self._module_timer.enabled:
        with self._module_timer:
            output = self._wrapper(*args, **kwargs)
            if isinstance(self, torch.nn.Module) and torch.cuda.is_available():
                torch.cuda.synchronize()
    else:
        output = self._wrapper(*args, **kwargs)
    return output
```
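Usage is identical to calling the original module; a minimal sketch:

```python
import torch

x = torch.randn(2, 10)  # example input for the Linear(10, 10) module above
output = model(x)       # dispatched to the recording, eager, or optimized wrapper
```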
 
             
     
 
add_timer

add_timer(timer)

    Add timer to module.

Source code in model_navigator/inplace/wrapper.py

```python
def add_timer(self, timer: Timer) -> None:
    """Add timer to module."""
    self._module_timer = timer.register_module(self._name)
```
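A sketch of attaching a timer; the import path and constructor are assumptions based on the Timer type referenced in the source above, not something this page confirms:

```python
# Assumption: Timer is importable from the inplace subpackage and takes no
# required constructor arguments.
from model_navigator.inplace.timer import Timer

timer = Timer()
model.add_timer(timer=timer)  # subsequent calls are timed while the timer is enabled
```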
 
             
     
 
load_eager

load_eager(device=None)

    Load eager module.

Source code in model_navigator/inplace/wrapper.py

```python
@deactivate_wrapper
def load_eager(self, device: Optional[str] = None) -> None:
    """Load eager module."""
    self._wrapper = EagerModule(
        module=self._wrapper.module,
        name=self._name,
        input_mapping=self._input_mapping,
        output_mapping=self._output_mapping,
        optimize_config=self._optimize_config,
        device=device or self._device,
        forward=self._wrapper.forward_call,
    )
```
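Reverting to eager execution, for example to compare against an optimized run (a minimal sketch; the device string follows torch conventions):

```python
model.load_eager(device="cpu")  # run the original torch module on CPU
```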
 
             
     
 
load_optimized

load_optimized(strategies=None, device='cuda', activate_runners=True)

    Load optimized module.

Parameters:

- strategies (Optional[List[RuntimeSearchStrategy]], default: None) – list of strategies for finding the best model. Strategies are selected in the provided order; when the first fails, the next strategy from the list is used. When no strategies are provided, defaults to [MaxThroughputAndMinLatencyStrategy, MinLatencyStrategy].
- device (Union[str, device], default: 'cuda') – device on which optimized modules are loaded. Defaults to "cuda".
- activate_runners (bool, default: True) – activate the models, i.e. load them on the device. Defaults to True.
            
Source code in model_navigator/inplace/wrapper.py

```python
@deactivate_wrapper
def load_optimized(
    self,
    strategies: Optional[List[RuntimeSearchStrategy]] = None,
    device: Union[str, "torch.device"] = "cuda",
    activate_runners: bool = True,
) -> None:
    """Load optimized module.

    Args:
        strategies: List of strategies for finding the best model. Strategies are selected in provided order. When
                    first fails, next strategy from the list is used. When no strategies have been provided it
                    defaults to [`MaxThroughputAndMinLatencyStrategy`, `MinLatencyStrategy`]
        device: Device on which optimized modules would be loaded. Defaults to "cuda".
        activate_runners: Activate models - load on device. Defaults to True.
    """
    self._wrapper = OptimizedModule(
        module=self._wrapper.module,
        # self._optimize_config,
        name=self._name,
        input_mapping=self._input_mapping,
        output_mapping=self._output_mapping,
        strategies=strategies,
        activate_runners=activate_runners,
        device=str(device),
        forward=self._wrapper.forward_call,
    )
```
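A sketch of loading with an explicit strategy order; it assumes the strategy classes named in the defaults above are exported from the top-level model_navigator namespace:

```python
import model_navigator as nav

# Assumption: strategies are instantiated and tried in the given order.
model.load_optimized(
    strategies=[nav.MaxThroughputAndMinLatencyStrategy(), nav.MinLatencyStrategy()],
    device="cuda",
    activate_runners=True,
)
```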
 
             
     
 
load_recorded

load_recorded()

    Load recorded module.

Source code in model_navigator/inplace/wrapper.py

```python
@deactivate_wrapper
def load_recorded(self) -> None:
    """Load recorded module."""
    self._wrapper = RecordingModule(
        module=self._wrapper.module,
        name=self._name,
        input_mapping=self._input_mapping,
        output_mapping=self._output_mapping,
        optimize_config=self._optimize_config,
        forward=self._wrapper.forward_call,
    )
```
 
             
     
 
optimize

optimize()

    Optimize the module.

Source code in model_navigator/inplace/wrapper.py

```python
def optimize(self) -> None:
    """Optimize the module."""
    assert isinstance(self.wrapper, RecordingModule), f"Module {self.name} must be in recording mode to optimize."
    assert not self.is_optimized, f"Module {self.name} is already optimized."
    assert hasattr(self.wrapper, "optimize"), f"Module {self.name} does not have an optimize method."
    self._wrapper.optimize()
    self.load_optimized(activate_runners=False)
```
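End to end, the record-then-optimize flow might look like this (a minimal sketch; the sample input is an assumption, and nav.profile configuration is omitted):

```python
import torch
import model_navigator as nav

model = nav.Module(torch.nn.Linear(10, 10))
model(torch.randn(2, 10))  # call in recording mode to capture sample inputs
model.optimize()           # runs optimization, then loads the optimized wrapper
```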
 
             
     
 
triton_model_store

triton_model_store(model_repository_path, strategies=None, model_name=None, model_version=1, response_cache=False, warmup=False, package_idx=-1)

    Store the optimized module in the Triton model store.

Parameters:

- model_repository_path (Path) – path to store the optimized module.
- strategies (Optional[List[RuntimeSearchStrategy]], default: None) – list of strategies for finding the best model. Strategies are selected in the provided order; when the first fails, the next strategy from the list is used. When no strategies are provided, defaults to [MaxThroughputAndMinLatencyStrategy, MinLatencyStrategy].
- model_name (Optional[str], default: None) – name of the module to use in the Triton model store; by default the module name is used.
- model_version (int, default: 1) – version of the model that is deployed.
- response_cache (bool, default: False) – enable response cache for the model.
- warmup (bool, default: False) – enable warmup for min and max batch size.
- package_idx (int, default: -1) – index of the package (pipeline execution status) to use for storing in the Triton model store. Default is -1, which means the last package.
            
Source code in model_navigator/inplace/wrapper.py

```python
def triton_model_store(
    self,
    model_repository_path: pathlib.Path,
    strategies: Optional[List[RuntimeSearchStrategy]] = None,
    model_name: Optional[str] = None,
    model_version: int = 1,
    response_cache: bool = False,
    warmup: bool = False,
    package_idx: int = -1,
):
    """Store the optimized module in the Triton model store.

    Args:
        model_repository_path (pathlib.Path): Path to store the optimized module.
        strategies (Optional[List[RuntimeSearchStrategy]]): List of strategies for finding the best model.
                Strategies are selected in provided order. When first fails, next strategy from the list is used.
                When no strategies have been provided it defaults to [`MaxThroughputAndMinLatencyStrategy`, `MinLatencyStrategy`]
        model_name (Optional[str]): Name of the module to use in the Triton model store, by default the module name is used.
        model_version (int): Version of model that is deployed
        response_cache (bool): Enable response cache for model
        warmup (bool): Enable warmup for min and max batch size
        package_idx (int): Index of package - pipeline execution status - to use for storing in Triton model store. Default is -1, which means the last package.
    """
    if not isinstance(self._wrapper, OptimizedModule):
        raise ModelNavigatorUserInputError(
            f"Module {self.name} must be optimized to store in Triton model store. Did you load_optimized()?"
        )
    if len(self._wrapper.packages) == 0:
        raise ModelNavigatorUserInputError(
            f"Module {self.name} must have packages to store in Triton model store. Did you optimize the module?"
        )
    try:
        package = self._wrapper.packages[package_idx]
    except IndexError as e:
        raise ModelNavigatorUserInputError(
            f"Incorrect package index {package_idx=} for module {self.name}. There are only {len(self._wrapper.packages)} packages."
        ) from e
    model_name = model_name or self.name
    model_repository.add_model_from_package(
        model_repository_path,
        model_name,
        package,
        strategies=strategies,
        model_version=model_version,
        response_cache=response_cache,
        warmup=warmup,
    )
```
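A sketch of publishing the optimized module to a Triton model repository; the repository path and model name are illustrative assumptions:

```python
import pathlib

model.load_optimized()  # triton_model_store requires an OptimizedModule wrapper
model.triton_model_store(
    model_repository_path=pathlib.Path("model_repository"),  # hypothetical path
    model_name="linear",                                     # hypothetical name
    model_version=1,
)
```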