Skip to content

Inplace Optimize Module

model_navigator.Module

Module(module, name=None, input_mapping=None, output_mapping=None, timer=None, offload_parameters_to_cpu=False, forward_func=None)

Bases: ObjectProxy

Inplace Optimize module wrapper.

This class wraps a torch module and provides inplace optimization functionality. Depending on the active configuration, the module will be optimized, recorded, or passed through.

This wrapper can be used in place of a torch module, and will behave identically to the original module.

Parameters:

  • module (Module) –

    torch module to wrap.

  • name (Optional[str], default: None ) –

    module name.

  • input_mapping (Optional[Callable], default: None ) –

    function to map module inputs to the expected input.

  • output_mapping (Optional[Callable], default: None ) –

    function to map module outputs to the expected output.

  • offload_parameters_to_cpu (bool, default: False ) –

    offload parameters to cpu.

  • forward_func (Optional[str], default: None ) –

    forwarding function name used by the module, if None, the module call is used.

Example

import torch
import model_navigator as nav

model = torch.nn.Linear(10, 10)
model = nav.Module(model)

Initialize Module.

Source code in model_navigator/inplace/wrapper.py
def __init__(
    self,
    module: "torch.nn.Module",
    name: Optional[str] = None,
    input_mapping: Optional[Callable] = None,
    output_mapping: Optional[Callable] = None,
    timer: Optional[Timer] = None,
    offload_parameters_to_cpu: bool = False,
    forward_func: Optional[str] = None,
) -> None:
    """Initialize Module.

    Wraps ``module`` in a proxy, installs a recording wrapper, and registers
    the module in the global module registry under its name.

    Args:
        module: torch module to wrap.
        name: module name; derived from the module object when None.
        input_mapping: maps module inputs to the expected input; identity when None.
        output_mapping: maps module outputs to the expected output; identity when None.
        timer: optional timer; when given, this module is registered with it.
        offload_parameters_to_cpu: offload parameters to cpu.
            NOTE(review): not referenced anywhere in this body — presumably
            consumed elsewhere; confirm.
        forward_func: name of an attribute on ``module`` to use as the forward
            function; the module call is used when None.

    Raises:
        ModelNavigatorUserInputError: if ``forward_func`` is given but does not
            exist on ``module``.
    """
    super().__init__(module)
    self._name = name or get_object_name(module)
    # Default both mappings to the identity function so they can always be called.
    self._input_mapping = input_mapping or (lambda x: x)
    self._output_mapping = output_mapping or (lambda x: x)
    self._optimize_config = None
    if timer:
        self.add_timer(timer=timer)
    else:
        self._module_timer = None

    # Keep a reference to the original forward attribute (if any) so the
    # wrapper can still invoke it, then redirect the attribute through this
    # proxy's __call__ so external calls go via the active wrapper.
    current_forward = None
    if forward_func:
        try:
            current_forward = getattr(module, forward_func)
        except AttributeError as e:
            raise ModelNavigatorUserInputError(f"Forward method must exist, got {forward_func}.") from e
        setattr(module, forward_func, self.__call__)

    # Start in record-and-optimize mode; other modes are loaded later via
    # load_optimized / load_passthrough / load_recorded.
    self._wrapper = RecordAndOptimizeModule(
        module,
        # OptimizeConfig(),
        self._name,
        self._input_mapping,
        self._output_mapping,
        forward=current_forward,
    )
    module_registry.register(self._name, self)

is_optimized property

is_optimized

Check if the module is optimized.

is_ready_for_optimization property

is_ready_for_optimization

Check if the module is ready for optimization.

name property

name

Module name.

optimize_config property writable

optimize_config

Module optimize config.

wrapper property

wrapper

Return the wrapper module.

__call__

__call__(*args, **kwargs)

Call the wrapped module.

This method overrides the `__call__` method of the wrapped module. If the module is already optimized it is replaced with the optimized one.

Source code in model_navigator/inplace/wrapper.py
def __call__(self, *args, **kwargs) -> Any:
    """Invoke the wrapped module through the active wrapper.

    When a module timer is attached and enabled, the call runs inside the
    timer context; for torch modules the CUDA device is synchronized inside
    that context so asynchronous kernels are included in the measurement.
    """
    timer = self._module_timer
    if not (timer and timer.enabled):
        # No timing requested — plain delegation to the active wrapper.
        return self._wrapper(*args, **kwargs)

    with timer:
        result = self._wrapper(*args, **kwargs)
        if isinstance(self, torch.nn.Module):
            torch.cuda.synchronize()
    return result

add_timer

add_timer(timer)

Add timer to module.

Source code in model_navigator/inplace/wrapper.py
def add_timer(self, timer: Timer) -> None:
    """Attach a timer to this module.

    Registers this module (by name) with ``timer`` and keeps the returned
    per-module timer for use in ``__call__``.
    """
    module_timer = timer.register_module(self._name)
    self._module_timer = module_timer

load_optimized

load_optimized(strategy=None, activate_runners=True)

Load optimized module.

Source code in model_navigator/inplace/wrapper.py
def load_optimized(self, strategy: Optional[RuntimeSearchStrategy] = None, activate_runners: bool = True) -> None:
    """Load optimized module.

    Replaces the active wrapper with an ``OptimizedModule`` built around the
    underlying module, forwarding ``strategy`` and ``activate_runners``.
    """
    # TODO: Consider another validation for optimization status here. is_optimized property is modified by loading passthrough.
    # if not self.is_optimized:
    #     raise ModelNavigatorModuleNotOptimizedError(f"Module {self.name} is not optimized.")
    inner_module = self._wrapper._module
    self._wrapper = OptimizedModule(
        module=inner_module,
        name=self._name,
        input_mapping=self._input_mapping,
        output_mapping=self._output_mapping,
        strategy=strategy,
        activate_runners=activate_runners,
    )

load_passthrough

load_passthrough()

Load passthrough module.

Source code in model_navigator/inplace/wrapper.py
def load_passthrough(self) -> None:
    """Load passthrough module.

    Replaces the active wrapper with a ``PassthroughModule`` built around the
    underlying module, keeping the configured mappings and optimize config.
    """
    inner_module = self._wrapper._module
    self._wrapper = PassthroughModule(
        name=self._name,
        module=inner_module,
        input_mapping=self._input_mapping,
        output_mapping=self._output_mapping,
        optimize_config=self._optimize_config,
    )

load_recorded

load_recorded()

Load recorded module.

Source code in model_navigator/inplace/wrapper.py
def load_recorded(self) -> None:
    """Load recorded module.

    Replaces the active wrapper with a ``RecordModule`` built around the
    underlying module, keeping the configured mappings and optimize config.
    """
    inner_module = self._wrapper._module
    self._wrapper = RecordModule(
        name=self._name,
        module=inner_module,
        input_mapping=self._input_mapping,
        output_mapping=self._output_mapping,
        optimize_config=self._optimize_config,
    )

optimize

optimize()

Optimize the module.

Source code in model_navigator/inplace/wrapper.py
def optimize(self) -> None:
    """Optimize the module.

    The module must currently be in recording mode (wrapped by a
    ``RecordModule``) and not yet optimized; the optimization itself is
    delegated to the wrapper.

    Raises:
        ModelNavigatorUserInputError: if the module is not in recording mode,
            is already optimized, or its wrapper has no ``optimize`` method.
    """
    # Validate with explicit exceptions instead of `assert`, which is
    # silently stripped when Python runs with the -O flag.
    if not isinstance(self.wrapper, RecordModule):
        raise ModelNavigatorUserInputError(f"Module {self.name} must be in recording mode to optimize.")
    if self.is_optimized:
        raise ModelNavigatorUserInputError(f"Module {self.name} is already optimized.")
    if not hasattr(self.wrapper, "optimize"):
        raise ModelNavigatorUserInputError(f"Module {self.name} does not have an optimize method.")
    self.wrapper.optimize()