diff --git a/tests/utils/test_prepare_module.py b/tests/utils/test_prepare_module.py
index 83c569736d..0c79c1e79d 100644
--- a/tests/utils/test_prepare_module.py
+++ b/tests/utils/test_prepare_module.py
@@ -274,6 +274,32 @@ def test_check_and_convert_mp_policy_dtypes(self) -> None:
         ):
             _check_and_convert_mp_policy_dtypes(invalid_mp_policy)
 
+    @patch("torchtnt.utils.prepare_module.fully_shard")
+    def test_fsdp2_strategy_shard_predicates(self, mock_fully_shard: Mock) -> None:
+        """
+        Ensure modules_to_shard and shard_predicates are applied sequentially
+        """
+
+        class SimpleModule(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear1 = torch.nn.Linear(10, 10, device="meta")
+                self.conv = torch.nn.Conv2d(10, 10, kernel_size=3, device="meta")
+
+        module = SimpleModule()
+        strategy = FSDP2Strategy(
+            modules_to_shard=[torch.nn.Conv2d],
+            shard_predicates=[lambda n, _: "linear" in n],
+        )
+        mock_mesh = MagicMock(spec=DeviceMesh)
+        mock_global_mesh = MagicMock(spec=GlobalMeshCoordinator)
+        mock_global_mesh.dp_mesh = mock_mesh
+        module = prepare_fsdp2(
+            module, torch.device("cpu"), strategy, global_mesh=mock_global_mesh
+        )
+        # shards self.linear1 (via predicate), self.conv (via modules_to_shard), and the root module
+        self.assertEqual(mock_fully_shard.call_count, 3)
+
     @patch("torchtnt.utils.prepare_module.fully_shard")
     def test_fsdp2_mesh(self, mock_fully_shard: Mock) -> None:
         """
@@ -285,7 +311,7 @@ def test_fsdp2_mesh(self, mock_fully_shard: Mock) -> None:
         mock_global_mesh = MagicMock(spec=GlobalMeshCoordinator)
         mock_global_mesh.dp_mesh = mock_mesh
 
-        strategy = FSDP2Strategy()
+        strategy = FSDP2Strategy(modules_to_shard=[torch.nn.Linear])
         module = prepare_fsdp2(
             module,
             torch.device("cpu"),
diff --git a/torchtnt/utils/prepare_module.py b/torchtnt/utils/prepare_module.py
index c632365afa..fad31d0d4a 100644
--- a/torchtnt/utils/prepare_module.py
+++ b/torchtnt/utils/prepare_module.py
@@ -17,6 +17,7 @@
     ContextManager,
     Dict,
     Iterable,
+    List,
     Literal,
     Optional,
     Set,
@@ -27,6 +28,7 @@
 
 import torch
 import torch.distributed as dist
+from pyre_extensions import none_throws
 from torch.distributed import ProcessGroup
 from torch.distributed._composable_state import _get_module_state
 
@@ -192,8 +194,9 @@ class FSDP2Strategy(Strategy):
     For more details on the args, see the link.
 
     Args:
-        modules_to_shard: A list of modules that should be sharded across devices. Options are 'all' to shard all submodules, or a list of module names/module types.
-        reshard_after_forward: If True, reshards parameters after the forward pass to optimize memory usage.
+        modules_to_shard: A list of modules that should be sharded across devices. Options are 'all' to shard all submodules, or a list of module names/module types. Set to None to skip sharding based on this option (for example, when relying on shard_predicates instead).
+        shard_predicates: A list of predicates used to decide which modules to shard with FSDP. Each predicate takes a module name (FQN) and the module itself; if any predicate returns True, the submodule is sharded.
+        reshard_after_forward: If True, reshards parameters after the forward pass to save memory.
         mp_policy: Controls mixed precision policy. If only dtype is provided, it will be used to cast all relevant parts of model. If None, no mixed precision is used
         cpu_offload: If True, enables CPU offloading of model parameters to reduce GPU memory usage.
 
@@ -201,6 +204,8 @@ class FSDP2Strategy(Strategy):
 
     It is recommended to specify specific modules to shard to avoid unnecessary sharding of all submodules, which has communication overhead.
 
+    Note: modules_to_shard and shard_predicates are applied sequentially. If a module matches modules_to_shard, it is sharded regardless of shard_predicates, and vice versa.
+
     Example:
         >>> model
         TransformerDecoder(
@@ -222,10 +227,15 @@ class FSDP2Strategy(Strategy):
         >>> strategy = FSDP2Strategy(modules_to_shard=["TransformerSelfAttentionLayer", "Linear"])
     """
 
-    modules_to_shard: Union[
-        Literal["all"],
-        Iterable[Union[str, Type[torch.nn.Module]]],
-    ] = "all"
+    modules_to_shard: Optional[
+        Union[
+            Literal["all"],
+            Iterable[Union[str, Type[torch.nn.Module]]],
+        ]
+    ] = None
+    shard_predicates: List[Callable[[str, torch.nn.Module], bool]] = field(
+        default_factory=list
+    )
     reshard_after_forward: Union[bool, int] = True
     mp_policy: Optional[Union[str, torch.dtype, MixedPrecisionPolicy]] = None
     cpu_offload: bool = False
@@ -435,12 +445,12 @@ def prepare_fsdp2(
     shard_all = modules_to_shard == "all"
     shard_module_names: Set[str] = set()
     shard_module_types: Tuple[Type[torch.nn.Module], ...] = ()
-    if not shard_all:
+    if not shard_all and modules_to_shard is not None:
         assert (
             type(modules_to_shard) is not str
         ), f"modules_to_shard must be an iterable of modules or 'all', got {shard_all}"
 
-        for item in modules_to_shard:
+        for item in none_throws(modules_to_shard):
             if isinstance(item, str):
                 shard_module_names.add(item)
             else:
@@ -448,7 +458,7 @@ def prepare_fsdp2(
 
     # apply the fsdp2 sharding bottoms up
     num_layers_sharded = 0
-    for _, m in reversed(list(module.named_modules())):
+    for n, m in reversed(list(module.named_modules())):
         if shard_all:
             # fully_shard does not support containers that do not implement forward
             if not isinstance(m, (torch.nn.ModuleList, torch.nn.ModuleDict)):
@@ -460,6 +470,13 @@ def prepare_fsdp2(
             # if m exists in shard_module_types, then shard it
             fully_shard(m, **fsdp_kwargs)
             num_layers_sharded += 1
+        elif len(strategy.shard_predicates) > 0:
+            # if shard_predicates is not empty, shard the module if any predicate matches
+            for predicate in strategy.shard_predicates:
+                if predicate(n, m):
+                    fully_shard(m, **fsdp_kwargs)
+                    num_layers_sharded += 1
+                    break
 
     if num_layers_sharded == 0:
         raise ValueError(
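
Usage sketch (reviewer note, not part of the patch; the module types, FQN substring, and size threshold below are illustrative). Because the two filters are applied sequentially, a submodule is sharded if it matches either modules_to_shard or any entry of shard_predicates, so type-based and predicate-based selection can be combined:

import torch

from torchtnt.utils.prepare_module import FSDP2Strategy

# Shard Conv2d submodules by type, and additionally shard any submodule whose
# FQN contains "linear" via a predicate; a submodule is wrapped if it matches
# either filter.
strategy = FSDP2Strategy(
    modules_to_shard=[torch.nn.Conv2d],
    shard_predicates=[lambda fqn, _module: "linear" in fqn],
)

# Predicates can also inspect the module itself, e.g. only shard submodules
# owning more than 1M direct (non-recursive) parameters; the threshold is
# chosen purely for illustration.
size_based_strategy = FSDP2Strategy(
    shard_predicates=[
        lambda _fqn, m: sum(p.numel() for p in m.parameters(recurse=False)) > 1_000_000
    ],
)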