From 6808186ba91be8c96ddd8e7837698c5de7ddd2c4 Mon Sep 17 00:00:00 2001 From: Tony Francis Date: Sat, 12 Apr 2025 20:23:09 -0400 Subject: [PATCH 1/2] wrap forward methods instead of monkeypatch based redirect --- src/lightning/fabric/wrappers.py | 39 +++++++++++++------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index b593c9f22ed23..1d886043674d9 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -176,28 +176,6 @@ def mark_forward_method(self, method: Union[MethodType, str]) -> None: ) self._forward_methods.add(name) - def _redirection_through_forward(self, method_name: str) -> Callable: - assert method_name != "forward" - original_forward = self._original_module.forward - - def wrapped_forward(*args: Any, **kwargs: Any) -> Any: - # Unpatch ourselves immediately before calling the method `method_name` - # because itself may want to call the real `forward` - self._original_module.forward = original_forward - # Call the actual method e.g. `.training_step(...)` - method = getattr(self._original_module, method_name) - return method(*args, **kwargs) - - # We make the caller "unknowingly" send their arguments through the forward_module's `__call__`. - # We expect that the `forward_module` will eventually call `original_module.forward`, which we - # have patched to redirect back to `original_module.method_name()`. - def call_forward_module(*args: Any, **kwargs: Any) -> Any: - # Patch the original_module's forward, so we can redirect the arguments back to the real method - self._original_module.forward = wrapped_forward - return self.forward(*args, **kwargs) - - return call_forward_module - def _wrap_method_with_module_call_tracker(self, method: Callable, name: str) -> Callable: """Tracks whether any submodule in ``self._original_module`` was called during the execution of ``method`` by registering forward hooks on all submodules.""" @@ -239,6 +217,21 @@ def _register_backward_hook(self, tensor: Tensor) -> Tensor: hook = partial(_backward_hook, (strategy_requires or precision_requires)) tensor.register_hook(hook) return tensor + + def wrap_forward_method(self, method: Callable) -> Callable: + @wraps(method) + def wrapper(*args: Any, **kwargs: Any) -> Any: + precision = self._strategy.precision + args, kwargs = precision.convert_input((args, kwargs)) + + with precision.forward_context(): + output = method(*args, **kwargs) + + output = precision.convert_output(output) + + apply_to_collection(output, dtype=Tensor, function=self._register_backward_hook) + return output + return wrapper @override def __getattr__(self, item: Any) -> Any: @@ -248,7 +241,7 @@ def __getattr__(self, item: Any) -> Any: and self._forward_module != self._original_module ): # Special support for methods marked by `mark_forward_method` to prevent bypassing DDP's forward - return self._redirection_through_forward(item) + return self.wrap_forward_method(getattr(self._original_module, item)) try: # __getattr__ gets called as a last resort if the attribute does not exist From 0798ef770b3918e7e0ee0237bb07d28f3c26a986 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 13 Apr 2025 00:42:03 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/wrappers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning/fabric/wrappers.py b/src/lightning/fabric/wrappers.py index 1d886043674d9..2b0669690caa9 100644 --- a/src/lightning/fabric/wrappers.py +++ b/src/lightning/fabric/wrappers.py @@ -217,7 +217,7 @@ def _register_backward_hook(self, tensor: Tensor) -> Tensor: hook = partial(_backward_hook, (strategy_requires or precision_requires)) tensor.register_hook(hook) return tensor - + def wrap_forward_method(self, method: Callable) -> Callable: @wraps(method) def wrapper(*args: Any, **kwargs: Any) -> Any: @@ -231,6 +231,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: apply_to_collection(output, dtype=Tensor, function=self._register_backward_hook) return output + return wrapper @override