From 2fc2fb7939f190ae0f1a2274103704a43f98db57 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 15:09:17 -0500 Subject: [PATCH 01/10] Update fsdp.py --- src/lightning/pytorch/plugins/precision/fsdp.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index c41199adb480e..57c5bf9c9aba1 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -76,9 +76,6 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradSca @override def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ - # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect with FSDP. - # To overcome this we need to call root_sharded_module.clip_grad_norm(clip_val), but we don't have a reference - # to the root module raise MisconfigurationException( f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`" ) From c36f40cb0b066b3ebd22c76afbdd196f3edc53b5 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 15:52:46 -0500 Subject: [PATCH 02/10] Support gradient norm clipping for FSDP --- src/lightning/pytorch/core/module.py | 4 +++- src/lightning/pytorch/plugins/precision/amp.py | 6 +++++- src/lightning/pytorch/plugins/precision/fsdp.py | 9 ++++----- src/lightning/pytorch/plugins/precision/precision.py | 5 +++-- tests/tests_pytorch/plugins/precision/test_amp.py | 8 +++++--- 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index b8624daac3fa3..d1b0cca4feeae 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1207,7 +1207,9 @@ def clip_gradients( ) gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm) - self.trainer.precision_plugin.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm) + self.trainer.precision_plugin.clip_gradients( + self.trainer.model, optimizer, gradient_clip_val, gradient_clip_algorithm + ) def configure_gradient_clipping( self, diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 75e792af46b90..6746b5dcd2585 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from torch.nn import Module from torch.optim import LBFGS, Optimizer from typing_extensions import override @@ -100,6 +101,7 @@ def optimizer_step( # type: ignore[override] @override def clip_gradients( self, + module: Module, optimizer: Optimizer, clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, @@ -109,7 +111,9 @@ def clip_gradients( f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping" " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?" 
) - super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm) + super().clip_gradients( + module=module, optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm + ) def autocast_context_manager(self) -> torch.autocast: return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half)) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index cd05cda985df5..bc4b3c0185a85 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import AbstractContextManager -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Union import torch from lightning_utilities import apply_to_collection from torch import Tensor from torch.nn import Module +from torch.optim import Optimizer from typing_extensions import get_args, override import lightning.pytorch as pl @@ -81,11 +82,9 @@ def convert_module(self, module: Module) -> Module: return module @override - def clip_grad_by_norm(self, *_: Any, **__: Any) -> None: + def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ - raise MisconfigurationException( - f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`" - ) + module.clip_grad_norm_(clip_val) @property def mixed_precision_config(self) -> "TorchMixedPrecision": diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py index 327fb2d4f5a27..08655fafca758 100644 --- a/src/lightning/pytorch/plugins/precision/precision.py +++ b/src/lightning/pytorch/plugins/precision/precision.py @@ -143,6 +143,7 @@ def _clip_gradients( def clip_gradients( self, + module: Module, optimizer: Optimizer, clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, @@ -153,14 +154,14 @@ def clip_gradients( if gradient_clip_algorithm == GradClipAlgorithmType.VALUE: self.clip_grad_by_value(optimizer, clip_val) elif gradient_clip_algorithm == GradClipAlgorithmType.NORM: - self.clip_grad_by_norm(optimizer, clip_val) + self.clip_grad_by_norm(module, optimizer, clip_val) def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """Clip gradients by value.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) - def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: + def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """Clip gradients by norm.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_norm_(parameters, clip_val) diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index cb061c540b2be..809d9f19d0706 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -14,6 +14,7 @@ from unittest.mock import Mock import pytest +from torch.nn import Module from torch.optim import 
Optimizer from lightning.pytorch.plugins import MixedPrecision @@ -22,22 +23,23 @@ def test_clip_gradients(): """Test that `.clip_gradients()` is a no-op when clipping is disabled.""" + module = Mock(spec=Module) optimizer = Mock(spec=Optimizer) precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) precision.clip_grad_by_value = Mock() precision.clip_grad_by_norm = Mock() - precision.clip_gradients(optimizer) + precision.clip_gradients(module, optimizer) precision.clip_grad_by_value.assert_not_called() precision.clip_grad_by_norm.assert_not_called() - precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE) + precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.VALUE) precision.clip_grad_by_value.assert_called_once() precision.clip_grad_by_norm.assert_not_called() precision.clip_grad_by_value.reset_mock() precision.clip_grad_by_norm.reset_mock() - precision.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM) + precision.clip_gradients(module, optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM) precision.clip_grad_by_value.assert_not_called() precision.clip_grad_by_norm.assert_called_once() From 8fad4235fdbcac1819d6582b40c385584adbe02d Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 16:08:08 -0500 Subject: [PATCH 03/10] Update CHANGELOG.md --- src/lightning/pytorch/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 5616defeffc8a..c794603990737 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Added - Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593)) +- Support `grad_clip_norm_()` for FSDP ([#20784](https://github.com/Lightning-AI/pytorch-lightning/pull/20784)) ### Changed From 04fbaf1f996f49cfffe60da6c16a97534ffcd1d6 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 16:17:43 -0500 Subject: [PATCH 04/10] Fix args for certain precisions --- src/lightning/pytorch/plugins/precision/deepspeed.py | 1 + src/lightning/pytorch/plugins/precision/precision.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/plugins/precision/deepspeed.py b/src/lightning/pytorch/plugins/precision/deepspeed.py index 9225e3bb9e7be..e09eb67f4fecf 100644 --- a/src/lightning/pytorch/plugins/precision/deepspeed.py +++ b/src/lightning/pytorch/plugins/precision/deepspeed.py @@ -141,6 +141,7 @@ def optimizer_step( # type: ignore[override] @override def clip_gradients( self, + module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, diff --git a/src/lightning/pytorch/plugins/precision/precision.py b/src/lightning/pytorch/plugins/precision/precision.py index 08655fafca758..a11182db68f97 100644 --- a/src/lightning/pytorch/plugins/precision/precision.py +++ b/src/lightning/pytorch/plugins/precision/precision.py @@ -143,7 +143,7 @@ def _clip_gradients( def clip_gradients( self, - module: Module, + module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, @@ -161,7 +161,7 @@ def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) - def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: Union[int, float]) -> None: + def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None: """Clip gradients by norm.""" parameters = self.main_params(optimizer) torch.nn.utils.clip_grad_norm_(parameters, clip_val) From bce69ca26a1290653b4fd89edc06f5f506631000 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 16:24:10 -0500 Subject: [PATCH 05/10] Standardize precision args --- src/lightning/pytorch/plugins/precision/amp.py | 2 +- src/lightning/pytorch/plugins/precision/fsdp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 6746b5dcd2585..f6ec37e7d4edb 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -101,7 +101,7 @@ def optimizer_step( # type: ignore[override] @override def clip_gradients( self, - module: Module, + module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float] = 0.0, gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index bc4b3c0185a85..aec60f4529740 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -82,7 +82,7 @@ def convert_module(self, module: Module) -> Module: return module @override - def clip_grad_by_norm(self, module: Module, optimizer: Optimizer, clip_val: 
Union[int, float]) -> None: + def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ module.clip_grad_norm_(clip_val) From 0df38f54022088e524c3cbfaffd0bc50a8999afb Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 16:33:03 -0500 Subject: [PATCH 06/10] Guard for typing --- src/lightning/pytorch/plugins/precision/fsdp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index aec60f4529740..899dc1d623564 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -84,6 +84,8 @@ def convert_module(self, module: Module) -> Module: @override def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ + if module is None or not hasattr(module, "clip_grad_norm_") or not isinstance(module.clip_grad_norm_, Callable): + return module.clip_grad_norm_(clip_val) @property From a42b974389d20eb9553cb8982bf7aa66d6459556 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 16:38:54 -0500 Subject: [PATCH 07/10] Fix argument typing --- src/lightning/pytorch/plugins/precision/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 899dc1d623564..1facad738ae85 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -84,7 +84,7 @@ def convert_module(self, module: Module) -> Module: @override def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ - if module is None or not hasattr(module, "clip_grad_norm_") or not isinstance(module.clip_grad_norm_, Callable): + if module is None or not hasattr(module, "clip_grad_norm_") or not isinstance(module.clip_grad_norm_, callable): return module.clip_grad_norm_(clip_val) From ed2fe05ad04c43873d8998b8c7782f3a249430fd Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 16:41:36 -0500 Subject: [PATCH 08/10] Wrap AMP test module in FSDP --- tests/tests_pytorch/plugins/precision/test_amp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index 809d9f19d0706..b009a900446dd 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -14,6 +14,7 @@ from unittest.mock import Mock import pytest +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.nn import Module from torch.optim import Optimizer @@ -23,7 +24,7 @@ def test_clip_gradients(): """Test that `.clip_gradients()` is a no-op when clipping is disabled.""" - module = Mock(spec=Module) + module = FSDP(Mock(spec=Module)) optimizer = Mock(spec=Optimizer) precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) precision.clip_grad_by_value = Mock() From 2f62a0a1b7f0c462b693d19b3b7ec3b680c410aa Mon Sep 17 00:00:00 2001 
From: Alex Morehead Date: Sat, 3 May 2025 16:51:14 -0500 Subject: [PATCH 09/10] Simplify guard --- src/lightning/pytorch/plugins/precision/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/plugins/precision/fsdp.py b/src/lightning/pytorch/plugins/precision/fsdp.py index 1facad738ae85..280bc4351f237 100644 --- a/src/lightning/pytorch/plugins/precision/fsdp.py +++ b/src/lightning/pytorch/plugins/precision/fsdp.py @@ -84,7 +84,7 @@ def convert_module(self, module: Module) -> Module: @override def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None: # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_ - if module is None or not hasattr(module, "clip_grad_norm_") or not isinstance(module.clip_grad_norm_, callable): + if module is None: return module.clip_grad_norm_(clip_val) From 7f7987e5225b807196ee2dd878c9f7b095e4505e Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Sat, 3 May 2025 17:08:37 -0500 Subject: [PATCH 10/10] Remove FSDP traces in AMP precision unit test --- tests/tests_pytorch/plugins/precision/test_amp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index b009a900446dd..900892fad5fdd 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -14,7 +14,6 @@ from unittest.mock import Mock import pytest -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.nn import Module from torch.optim import Optimizer @@ -24,7 +23,7 @@ def test_clip_gradients(): """Test that `.clip_gradients()` is a no-op when clipping is disabled.""" - module = FSDP(Mock(spec=Module)) + module = Mock(spec=Module) optimizer = Mock(spec=Optimizer) precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) precision.clip_grad_by_value = Mock() @@ -49,8 +48,9 @@ def test_optimizer_amp_scaling_support_in_step_method(): """Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with gradient clipping (example: fused Adam).""" + module = Mock(spec=Module) optimizer = Mock(_step_supports_amp_scaling=True) precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock()) with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"): - precision.clip_gradients(optimizer, clip_val=1.0) + precision.clip_gradients(module, optimizer, clip_val=1.0)
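Taken together, the series threads the root module from `LightningModule.clip_gradients()` through `Precision.clip_gradients()` down to `clip_grad_by_norm()`, so the FSDP precision plugin can delegate norm clipping to FSDP's own collective-aware `clip_grad_norm_()` instead of raising a `MisconfigurationException`. The sketch below is a condensed illustration of that dispatch, not the actual Lightning source; class and method names are simplified stand-ins for the plugins touched in the diffs above.

```python
# Condensed, illustrative sketch of the post-series clipping dispatch.
# The real implementations live in lightning/pytorch/plugins/precision/.
from typing import Iterator, Optional, Union

import torch
from torch.nn import Module
from torch.optim import Optimizer


class _PrecisionSketch:
    def clip_gradients(
        self,
        module: Optional[Module],
        optimizer: Optimizer,
        clip_val: Union[int, float] = 0.0,
        algorithm: str = "norm",
    ) -> None:
        # Clipping is a no-op unless a positive clip value is configured.
        if clip_val <= 0:
            return
        if algorithm == "value":
            torch.nn.utils.clip_grad_value_(self.main_params(optimizer), clip_value=clip_val)
        else:  # "norm" — now also receives the module so subclasses can use it
            self.clip_grad_by_norm(module, optimizer, clip_val)

    def clip_grad_by_norm(
        self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]
    ) -> None:
        # Default behavior: clip the optimizer's parameters on this rank.
        torch.nn.utils.clip_grad_norm_(self.main_params(optimizer), clip_val)

    def main_params(self, optimizer: Optimizer) -> Iterator[torch.nn.Parameter]:
        for group in optimizer.param_groups:
            yield from group["params"]


class _FSDPPrecisionSketch(_PrecisionSketch):
    def clip_grad_by_norm(
        self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]
    ) -> None:
        # With FSDP each rank holds only a shard of every gradient, so the
        # per-rank torch.nn.utils.clip_grad_norm_ would compute the wrong norm.
        # The root FullyShardedDataParallel wrapper all-reduces the shard norms
        # before scaling, which is why the module is threaded down to here.
        if module is None:
            return
        module.clip_grad_norm_(clip_val)
```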
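From the user's side, the practical effect is that norm-based clipping can be requested through the `Trainer` while training with the FSDP strategy, a combination that previously raised a `MisconfigurationException`. A minimal configuration is sketched below, assuming the series is applied and at least two CUDA devices are available; `ToyModule` and the random-tensor `DataLoader` are illustrative stand-ins, not part of the PR.

```python
import torch
from torch.utils.data import DataLoader

import lightning.pytorch as pl
from lightning.pytorch.strategies import FSDPStrategy


class ToyModule(pl.LightningModule):
    """Throwaway module used only to demonstrate the clipping flags."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        return self.layer(batch).sum()

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy=FSDPStrategy(),
        precision="16-mixed",
        # With this series applied, these two options dispatch to the root FSDP
        # module's clip_grad_norm_() instead of raising a MisconfigurationException.
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        max_steps=4,
        logger=False,
        enable_checkpointing=False,
    )
    # Random data keeps the example self-contained; replace with a real DataLoader.
    trainer.fit(ToyModule(), DataLoader(torch.randn(64, 32), batch_size=8))
```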