
Commit b391f2b (1 parent: d240be0)

[torchtitan][replicate] experimenting with new replicate integration in torchtitan

ghstack-source-id: ea5b964
Pull Request resolved: #1714

File tree

3 files changed: +18 -8 lines


torchtitan/distributed/parallel_dims.py

Lines changed: 10 additions & 2 deletions
@@ -100,7 +100,13 @@ def _build_mesh_with_ep(self) -> DeviceMesh:
         ):
             # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping
             # helps the MoE layers do mixed precision training
-            if d > 1 or name == "dp_shard_mod_ep":
+            # dp_shard_in_ep is included even if it equals 1 when replicate > 1
+            # to make device_mesh compatible with replicate function
+            if (
+                d > 1
+                or name == "dp_shard_mod_ep"
+                or (name == "dp_shard_in_ep" and self.dp_replicate > 1)
+            ):
                 dims.append(d)
                 names.append(name)

@@ -151,7 +157,9 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            # Include dp_shard dimension even if it equals 1 when replicate > 1
+            # to make device_mesh compatible with replicate function
+            if d > 1 or (name == "dp_shard" and self.dp_replicate > 1):
                 dims.append(d)
                 names.append(name)
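
Both hunks apply the same rule: a data-parallel shard dimension is kept in the device mesh even when its degree is 1, as long as replication is enabled, so the replicate path can always slice out a 2D (dp_replicate, dp_shard) submesh. A tiny standalone sketch of that rule (plain Python, illustrative degrees only, not torchtitan code):

# Illustrative replay of the dimension-selection rule from the hunks above.
# Degrees assume an 8-GPU debug run: dp_replicate=8, everything else 1.
degrees = [("pp", 1), ("dp_replicate", 8), ("dp_shard", 1), ("cp", 1), ("tp", 1)]
dp_replicate = dict(degrees)["dp_replicate"]

dims, names = [], []
for name, d in degrees:
    # Old rule: keep only d > 1. New rule: also keep dp_shard at degree 1
    # whenever replication is on, so ("dp_replicate", "dp_shard") can be sliced.
    if d > 1 or (name == "dp_shard" and dp_replicate > 1):
        dims.append(d)
        names.append(name)

print(dims, names)  # [8, 1] ['dp_replicate', 'dp_shard'] -> a 2D device mesh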

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 7 additions & 5 deletions
@@ -9,7 +9,7 @@

 import torch
 import torch.nn as nn
-from torch.distributed._composable.replicate import replicate
+from torch.distributed._composable.replicate_with_fsdp import replicate

 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy

@@ -135,11 +135,13 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
+        # if world_mesh.ndim > 1:
+        #     raise RuntimeError("DDP has not supported > 1D parallelism")
+
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         apply_ddp(
             model,
-            world_mesh,
+            world_mesh[tuple(dp_mesh_dim_names)],
             enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )

@@ -328,6 +330,6 @@ def apply_ddp(
         else:
             torch._dynamo.config.optimize_ddp = "ddp_optimizer"

-    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    replicate(model, device_mesh=dp_mesh)

     logger.info("Applied DDP to the model")

torchtitan/models/llama3/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ steps = 10
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)

 [parallelism]
-data_parallel_replicate_degree = 1
+data_parallel_replicate_degree = 8
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default"  # default / never / always
 tensor_parallel_degree = 1
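
For reference, data_parallel_shard_degree = -1 means "use whatever ranks remain after the other degrees", so on 8 GPUs this debug config resolves to replicate=8, shard=1. A rough, hypothetical resolver illustrating that convention (the real logic lives in torchtitan's ParallelDims):

# Hypothetical helper, for illustration only; not torchtitan's actual resolver.
def resolve_dp_degrees(world_size: int, dp_replicate: int, dp_shard: int) -> tuple[int, int]:
    if dp_shard == -1:
        # -1: infer the shard degree from the ranks left over after replication
        # (the other parallel degrees in this debug config are all 1).
        dp_shard = world_size // dp_replicate
    assert dp_replicate * dp_shard == world_size, "degrees must multiply to world size"
    return dp_replicate, dp_shard

# Debug config above on 8 GPUs: a pure-replicate layout that still carries
# a dp_shard dimension of 1 for the new replicate path.
print(resolve_dp_degrees(world_size=8, dp_replicate=8, dp_shard=-1))  # (8, 1)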
