Skip to content

FabricModule: wrap forward methods instead of monkeypatch-based redirect #20711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions src/lightning/fabric/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,28 +176,6 @@ def mark_forward_method(self, method: Union[MethodType, str]) -> None:
)
self._forward_methods.add(name)

def _redirection_through_forward(self, method_name: str) -> Callable:
assert method_name != "forward"
original_forward = self._original_module.forward

def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
# Unpatch ourselves immediately before calling the method `method_name`
# because itself may want to call the real `forward`
self._original_module.forward = original_forward
# Call the actual method e.g. `.training_step(...)`
method = getattr(self._original_module, method_name)
return method(*args, **kwargs)

# We make the caller "unknowingly" send their arguments through the forward_module's `__call__`.
# We expect that the `forward_module` will eventually call `original_module.forward`, which we
# have patched to redirect back to `original_module.method_name()`.
def call_forward_module(*args: Any, **kwargs: Any) -> Any:
# Patch the original_module's forward, so we can redirect the arguments back to the real method
self._original_module.forward = wrapped_forward
return self.forward(*args, **kwargs)

return call_forward_module

def _wrap_method_with_module_call_tracker(self, method: Callable, name: str) -> Callable:
"""Tracks whether any submodule in ``self._original_module`` was called during the execution of ``method`` by
registering forward hooks on all submodules."""
Expand Down Expand Up @@ -240,6 +218,22 @@ def _register_backward_hook(self, tensor: Tensor) -> Tensor:
tensor.register_hook(hook)
return tensor

def wrap_forward_method(self, method: Callable) -> Callable:
@wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Any:
precision = self._strategy.precision
args, kwargs = precision.convert_input((args, kwargs))

with precision.forward_context():
output = method(*args, **kwargs)

output = precision.convert_output(output)

apply_to_collection(output, dtype=Tensor, function=self._register_backward_hook)
return output

return wrapper

@override
def __getattr__(self, item: Any) -> Any:
if (
Expand All @@ -248,7 +242,7 @@ def __getattr__(self, item: Any) -> Any:
and self._forward_module != self._original_module
):
# Special support for methods marked by `mark_forward_method` to prevent bypassing DDP's forward
return self._redirection_through_forward(item)
return self.wrap_forward_method(getattr(self._original_module, item))

try:
# __getattr__ gets called as a last resort if the attribute does not exist
Expand Down
Loading