tests/utils/test_prepare_module.py (27 additions, 1 deletion)
@@ -274,6 +274,32 @@ def test_check_and_convert_mp_policy_dtypes(self) -> None:
):
_check_and_convert_mp_policy_dtypes(invalid_mp_policy)

@patch("torchtnt.utils.prepare_module.fully_shard")
def test_fsdp2_strategy_shard_predicates(self, mock_fully_shard: Mock) -> None:
"""
Ensure modules_to_shard and shard_predicates are applied sequentially
"""

class SimpleModule(torch.nn.Module):
def __init__(self):
super(SimpleModule, self).__init__()
self.linear1 = torch.nn.Linear(10, 10, device="meta")
self.conv = torch.nn.Conv2d(10, 10, kernel_size=3, device="meta")

module = SimpleModule()
strategy = FSDP2Strategy(
modules_to_shard=[torch.nn.Conv2d],
shard_predicates=[lambda n, _: "linear" in n],
)
mock_mesh = MagicMock(spec=DeviceMesh)
mock_global_mesh = MagicMock(spec=GlobalMeshCoordinator)
mock_global_mesh.dp_mesh = mock_mesh
module = prepare_fsdp2(
module, torch.device("cpu"), strategy, global_mesh=mock_global_mesh
)
        # shards self.linear1, self.conv, and the root module itself
self.assertEqual(mock_fully_shard.call_count, 3)

@patch("torchtnt.utils.prepare_module.fully_shard")
def test_fsdp2_mesh(self, mock_fully_shard: Mock) -> None:
"""
@@ -285,7 +311,7 @@ def test_fsdp2_mesh(self, mock_fully_shard: Mock) -> None:
mock_global_mesh = MagicMock(spec=GlobalMeshCoordinator)
mock_global_mesh.dp_mesh = mock_mesh

strategy = FSDP2Strategy()
strategy = FSDP2Strategy(modules_to_shard=[torch.nn.Linear])
module = prepare_fsdp2(
module,
torch.device("cpu"),
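For context on the new test above, here is a minimal sketch of how the two selection mechanisms combine from a user's perspective. `FSDP2Strategy` and `prepare_fsdp2` are defined in torchtnt/utils/prepare_module.py (the second file in this diff); the module definition below is illustrative, and the device/mesh setup used by the test is omitted.

```python
# Sketch only: mirrors the new test's setup; not part of the PR itself.
import torch

from torchtnt.utils.prepare_module import FSDP2Strategy


class SimpleModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.conv = torch.nn.Conv2d(10, 10, kernel_size=3)


# Type-based selection shards self.conv; the (fqn, module) -> bool predicate
# shards self.linear1; prepare_fsdp2 also wraps the root module, which is why
# the test expects fully_shard to be called three times.
strategy = FSDP2Strategy(
    modules_to_shard=[torch.nn.Conv2d],
    shard_predicates=[lambda name, _: "linear" in name],
)
```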
torchtnt/utils/prepare_module.py (26 additions, 9 deletions)
@@ -17,6 +17,7 @@
ContextManager,
Dict,
Iterable,
List,
Literal,
Optional,
Set,
@@ -27,6 +28,7 @@

import torch
import torch.distributed as dist
from pyre_extensions import none_throws
from torch.distributed import ProcessGroup

from torch.distributed._composable_state import _get_module_state
@@ -192,15 +194,18 @@ class FSDP2Strategy(Strategy):
For more details on the args, see the link.

Args:
modules_to_shard: A list of modules that should be sharded across devices. Options are 'all' to shard all submodules, or a list of module names/module types.
reshard_after_forward: If True, reshards parameters after the forward pass to optimize memory usage.
        modules_to_shard: A list of modules that should be sharded across devices. Options are 'all' to shard all submodules, a list of module names/module types, or None to skip sharding via this flag (shard_predicates may still apply).
        shard_predicates: A list of predicates that decide which modules to shard with FSDP. Each predicate receives a module's fully qualified name (FQN) and the module itself; if any predicate returns True, the submodule is sharded.
        reshard_after_forward: If True, reshards parameters after the forward pass to save memory.
        mp_policy: Controls the mixed precision policy. If only a dtype is provided, it is used to cast all relevant parts of the model. If None, no mixed precision is used.
cpu_offload: If True, enables CPU offloading of model parameters to reduce GPU memory usage.

Note:
        It is recommended to shard only specific modules, since unnecessarily sharding all submodules adds communication overhead.

    Note: modules_to_shard and shard_predicates are applied sequentially. A module listed in modules_to_shard is sharded regardless of shard_predicates, and a module matching any predicate is sharded even if it is not listed in modules_to_shard.

Example:
>>> model
TransformerDecoder(
@@ -222,10 +227,15 @@ class FSDP2Strategy(Strategy):
>>> strategy = FSDP2Strategy(modules_to_shard=["TransformerSelfAttentionLayer", "Linear"])
"""

modules_to_shard: Union[
Literal["all"],
Iterable[Union[str, Type[torch.nn.Module]]],
] = "all"
modules_to_shard: Optional[
Union[
Literal["all"],
Iterable[Union[str, Type[torch.nn.Module]]],
]
] = None
shard_predicates: List[Callable[[str, torch.nn.Module], bool]] = field(
default_factory=list
)
reshard_after_forward: Union[bool, int] = True
mp_policy: Optional[Union[str, torch.dtype, MixedPrecisionPolicy]] = None
cpu_offload: bool = False
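Worth noting from the field changes above: the default for modules_to_shard moves from "all" to None, so a submodule is now selected only when modules_to_shard and/or shard_predicates are supplied; if neither matches anything, prepare_fsdp2 raises the ValueError shown near the end of this file. This is also why the test earlier in the diff now constructs the strategy with modules_to_shard=[torch.nn.Linear]. A hedged sketch of the two opt-in paths (class and name choices below are illustrative):

```python
import torch

from torchtnt.utils.prepare_module import FSDP2Strategy

# Opt in by module type (or "all", or module-name strings):
by_type = FSDP2Strategy(modules_to_shard=[torch.nn.Linear])

# Opt in by predicate over (fully qualified name, module):
by_predicate = FSDP2Strategy(
    shard_predicates=[lambda name, module: name.endswith("attention")]
)
```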
@@ -435,20 +445,20 @@ def prepare_fsdp2(
shard_all = modules_to_shard == "all"
shard_module_names: Set[str] = set()
shard_module_types: Tuple[Type[torch.nn.Module], ...] = ()
if not shard_all:
if not shard_all and modules_to_shard is not None:
assert (
type(modules_to_shard) is not str
), f"modules_to_shard must be an iterable of modules or 'all', got {shard_all}"

for item in modules_to_shard:
for item in none_throws(modules_to_shard):
if isinstance(item, str):
shard_module_names.add(item)
else:
shard_module_types = shard_module_types + (item,)

    # apply the fsdp2 sharding bottom-up
num_layers_sharded = 0
for _, m in reversed(list(module.named_modules())):
for n, m in reversed(list(module.named_modules())):
if shard_all:
# fully_shard does not support containers that do not implement forward
if not isinstance(m, (torch.nn.ModuleList, torch.nn.ModuleDict)):
@@ -460,6 +470,13 @@
# if m exists in shard_module_types, then shard it
fully_shard(m, **fsdp_kwargs)
num_layers_sharded += 1
elif len(strategy.shard_predicates) > 0:
# if shard_predicates is not empty, then check if any of the conditions are true
for predicate in strategy.shard_predicates:
if predicate(n, m):
fully_shard(m, **fsdp_kwargs)
num_layers_sharded += 1
break

if num_layers_sharded == 0:
raise ValueError(
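To summarize the control flow added to prepare_fsdp2: submodules are visited bottom-up via named_modules(), the type/name match is checked first, and otherwise the module is sharded as soon as any predicate returns True for its (fqn, module) pair. Below is a hedged, standalone distillation of that predicate branch; should_shard_by_predicate is a hypothetical helper, not part of torchtnt, and it only mirrors the any-predicate-wins behavior of the loop above.

```python
from typing import Callable, List

import torch


def should_shard_by_predicate(
    name: str,
    module: torch.nn.Module,
    shard_predicates: List[Callable[[str, torch.nn.Module], bool]],
) -> bool:
    # Mirrors the loop in prepare_fsdp2: the first predicate returning True for
    # the module's fully qualified name and instance triggers fully_shard.
    return any(predicate(name, module) for predicate in shard_predicates)


# Example: an fqn-based predicate selects a submodule named "layers.0.linear".
assert should_shard_by_predicate(
    "layers.0.linear",
    torch.nn.Linear(4, 4),
    [lambda n, _: "linear" in n],
)
```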