Module(module, name=None, input_mapping=None, output_mapping=None, timer=None, offload_parameters_to_cpu=False, forward_func=None)
Bases: ObjectProxy
Inplace Optimize module wrapper.
This class wraps a torch module and provides inplace optimization functionality.
Depending on the configuration set in config, the module will be
optimized, recorded, or passed through.
This wrapper can be used in place of a torch module, and will behave
identically to the original module.
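Parameters:
- module (Module) – the torch module to wrap.
- name (Optional[str], default: None) – name of the wrapped module; if None, a name is derived from the module object.
- input_mapping (Optional[Callable], default: None) – function to map module inputs to the expected input; if None, inputs are passed through unchanged.
- output_mapping (Optional[Callable], default: None) – function to map module outputs to the expected output; if None, outputs are passed through unchanged.
- timer (Optional[Timer], default: None) – timer to register the module with; while the timer is enabled, calls to the module are timed.
- offload_parameters_to_cpu (bool, default: False) – offload parameters to CPU.
- forward_func (Optional[str], default: None) – name of the forwarding function used by the module; if given, that method must exist on the module. If None, the module call is used.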
Example
import torch
import model_navigator as nav

# Wrap any torch.nn.Module; the resulting nav.Module is used in place of the original.
model = nav.Module(torch.nn.Linear(10, 10))
Initialize Module.
Source code in model_navigator/inplace/wrapper.py
| def __init__(
self,
module: "torch.nn.Module",
name: Optional[str] = None,
input_mapping: Optional[Callable] = None,
output_mapping: Optional[Callable] = None,
timer: Optional[Timer] = None,
offload_parameters_to_cpu: bool = False,
forward_func: Optional[str] = None,
) -> None:
"""Initialize Module.

Wraps ``module`` in a ``RecordAndOptimizeModule`` and registers this
wrapper in ``module_registry`` under its name.

Args:
    module: Torch module to wrap.
    name: Registry/display name; when falsy, ``get_object_name(module)``
        is used instead.
    input_mapping: Function mapping module inputs to the expected input;
        when not given, an identity function is used.
    output_mapping: Function mapping module outputs to the expected output;
        when not given, an identity function is used.
    timer: When given, the module is registered with it via ``add_timer``;
        otherwise no module timer is set.
    offload_parameters_to_cpu: NOTE(review): not referenced anywhere in
        this body - confirm whether it is handled elsewhere or ignored.
    forward_func: Name of the module method used for the forward pass.
        When given, the method must exist on ``module``; it is replaced on
        ``module`` by this wrapper's ``__call__`` and the original bound
        method is passed to ``RecordAndOptimizeModule`` as ``forward``.

Raises:
    ModelNavigatorUserInputError: If ``forward_func`` names an attribute
        that does not exist on ``module``.
"""
super().__init__(module)
# NOTE(review): with an ObjectProxy base, these underscore attribute
# assignments may be forwarded to the wrapped module rather than stored on
# the proxy itself - confirm against the proxy's __setattr__ rules.
self._name = name or get_object_name(module)
# Fall back to identity mappings when none are supplied.
self._input_mapping = input_mapping or (lambda x: x)
self._output_mapping = output_mapping or (lambda x: x)
self._optimize_config = None
if timer:
# add_timer stores the registered handle in self._module_timer.
self.add_timer(timer=timer)
else:
self._module_timer = None
# Original forward method, captured before it is replaced on the module.
current_forward = None
if forward_func:
try:
current_forward = getattr(module, forward_func)
except AttributeError as e:
raise ModelNavigatorUserInputError(f"Forward method must exist, got {forward_func}.") from e
# Route calls to module.<forward_func> through this wrapper's __call__.
setattr(module, forward_func, self.__call__)
self._wrapper = RecordAndOptimizeModule(
module,
# OptimizeConfig(),
self._name,
self._input_mapping,
self._output_mapping,
forward=current_forward,
)
# The wrapper (not the raw module) is what gets registered.
module_registry.register(self._name, self)
|
Properties:
- is_optimized (property) – Check if the module is optimized.
- is_ready_for_optimization (property) – Check if the module is ready for optimization.
- optimize_config (property, writable)
- wrapper (property) – Return the wrapper module.
__call__
__call__(*args, **kwargs)
Call the wrapped module.
This method overrides the `__call__` method of the wrapped module.
If the module is already optimized it is replaced with the optimized one.
Source code in model_navigator/inplace/wrapper.py
def __call__(self, *args, **kwargs) -> Any:
    """Call the wrapped module.

    This method overrides the ``__call__`` method of the wrapped module.
    If the module is already optimized it is replaced with the optimized one.

    When a module timer is attached and enabled, the call runs inside the
    timer context. For torch modules, pending CUDA work is synchronized
    before the context exits so the measurement covers it. The sync is
    skipped when CUDA is not available, because ``torch.cuda.synchronize()``
    raises on hosts without a CUDA device.

    Args:
        *args: Positional arguments forwarded to the current wrapper.
        **kwargs: Keyword arguments forwarded to the current wrapper.

    Returns:
        Whatever the current wrapper returns.
    """
    timer = self._module_timer
    if timer and timer.enabled:
        with timer:
            output = self._wrapper(*args, **kwargs)
            # CUDA kernels run asynchronously: wait for them inside the timer
            # context so the measured time is real. Guarded because
            # torch.cuda.synchronize() raises when CUDA is unavailable.
            if isinstance(self, torch.nn.Module) and torch.cuda.is_available():
                torch.cuda.synchronize()
        return output
    return self._wrapper(*args, **kwargs)
|
add_timer
add_timer(timer)
Add timer to module.
Source code in model_navigator/inplace/wrapper.py
def add_timer(self, timer: Timer) -> None:
    """Register this module with ``timer``.

    The handle returned by ``timer.register_module`` is kept in
    ``self._module_timer``.
    """
    handle = timer.register_module(self._name)
    self._module_timer = handle
|
load_optimized
load_optimized(strategy=None, activate_runners=True)
Load optimized module.
Source code in model_navigator/inplace/wrapper.py
def load_optimized(self, strategy: Optional[RuntimeSearchStrategy] = None, activate_runners: bool = True) -> None:
    """Swap the current wrapper for an ``OptimizedModule``.

    The wrapped torch module, the name and both input/output mappings are
    carried over from the current wrapper; ``strategy`` and
    ``activate_runners`` are forwarded to ``OptimizedModule`` unchanged.
    """
    # TODO: Consider another validation for optimization status here. The
    # is_optimized property is modified by loading passthrough, so it cannot
    # simply be used to reject an unoptimized module.
    shared = dict(
        module=self._wrapper._module,
        name=self._name,
        input_mapping=self._input_mapping,
        output_mapping=self._output_mapping,
    )
    self._wrapper = OptimizedModule(**shared, strategy=strategy, activate_runners=activate_runners)
|
load_passthrough
Load passthrough module.
Source code in model_navigator/inplace/wrapper.py
def load_passthrough(self) -> None:
    """Replace the current wrapper with a ``PassthroughModule``.

    The wrapped torch module, the name, both mappings and the current
    ``_optimize_config`` are handed over to the new wrapper.
    """
    handover = dict(
        module=self._wrapper._module,
        name=self._name,
        input_mapping=self._input_mapping,
        output_mapping=self._output_mapping,
        optimize_config=self._optimize_config,
    )
    self._wrapper = PassthroughModule(**handover)
|
load_recorded
Load recorded module.
Source code in model_navigator/inplace/wrapper.py
def load_recorded(self) -> None:
    """Replace the current wrapper with a ``RecordModule``.

    The wrapped torch module, the name, both mappings and the current
    ``_optimize_config`` are handed over to the new wrapper.
    """
    handover = dict(
        module=self._wrapper._module,
        name=self._name,
        input_mapping=self._input_mapping,
        output_mapping=self._output_mapping,
        optimize_config=self._optimize_config,
    )
    self._wrapper = RecordModule(**handover)
|
optimize
Optimize the module.
Source code in model_navigator/inplace/wrapper.py
def optimize(self) -> None:
    """Optimize the module.

    Preconditions: the current wrapper is a ``RecordModule`` (recording
    mode), the module is not optimized yet, and the wrapper exposes an
    ``optimize`` method, which is then called.

    Raises:
        AssertionError: If any of the preconditions is not met.
    """
    # Explicit checks instead of ``assert`` so validation still runs under
    # ``python -O``; AssertionError is kept for callers that already catch it.
    # Messages use ``self._name`` (set in ``__init__`` and used by the other
    # methods): ``self.name`` is not defined on this wrapper and would
    # presumably be looked up on the wrapped module, raising AttributeError
    # instead of the intended message.
    wrapper = self.wrapper
    if not isinstance(wrapper, RecordModule):
        raise AssertionError(f"Module {self._name} must be in recording mode to optimize.")
    if self.is_optimized:
        raise AssertionError(f"Module {self._name} is already optimized.")
    if not hasattr(wrapper, "optimize"):
        raise AssertionError(f"Module {self._name} does not have an optimize method.")
    wrapper.optimize()
|