Commit 2f6bab6

[torchtitan][replicate] experiment with the new replicate integration in torchtitan
ghstack-source-id: 7bba1f6
Pull Request resolved: #1714
Parent: d240be0

3 files changed: 11 additions & 7 deletions
torchtitan/distributed/parallel_dims.py

Lines changed: 3 additions & 1 deletion
@@ -151,7 +151,9 @@ def _build_mesh_without_ep(self) -> DeviceMesh:
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            # Include dp_shard dimension even if it equals 1 when replicate > 1
+            # to make device_mesh compatible with replicate function
+            if d > 1 or (name == "dp_shard" and self.dp_replicate > 1):
                 dims.append(d)
                 names.append(name)
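For illustration, a standalone sketch of the new dimension-selection logic, outside of ParallelDims and with illustrative degrees (e.g. 8 ranks with dp_replicate=8 and every other degree 1): keeping dp_shard at size 1 produces a 2D ("dp_replicate", "dp_shard") mesh rather than a 1D one.

# Standalone sketch of the hunk above; degree values are illustrative.
pp, dp_replicate, dp_shard, cp, tp = 1, 8, 1, 1, 1

dims, names = [], []
for d, name in zip(
    [pp, dp_replicate, dp_shard, cp, tp],
    ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
):
    # Keep dp_shard even at degree 1 when dp_replicate > 1 so the
    # resulting device mesh stays 2D for the new replicate API.
    if d > 1 or (name == "dp_shard" and dp_replicate > 1):
        dims.append(d)
        names.append(name)

print(dims, names)  # [8, 1] ['dp_replicate', 'dp_shard']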

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 7 additions & 5 deletions
@@ -9,7 +9,7 @@

 import torch
 import torch.nn as nn
-from torch.distributed._composable.replicate import replicate
+from torch.distributed._composable.replicate_with_fsdp import replicate

 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy
@@ -135,11 +135,13 @@ def parallelize_llama(
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
-        if world_mesh.ndim > 1:
-            raise RuntimeError("DDP has not supported > 1D parallelism")
+        # if world_mesh.ndim > 1:
+        #     raise RuntimeError("DDP has not supported > 1D parallelism")
+
+        dp_mesh_dim_names = ("dp_replicate", "dp_shard")
         apply_ddp(
             model,
-            world_mesh,
+            world_mesh[tuple(dp_mesh_dim_names)],
             enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
@@ -328,6 +330,6 @@ def apply_ddp(
         else:
             torch._dynamo.config.optimize_ddp = "ddp_optimizer"

-    replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    replicate(model, device_mesh=dp_mesh)

     logger.info("Applied DDP to the model")

torchtitan/models/llama3/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ steps = 10
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

 [parallelism]
-data_parallel_replicate_degree = 1
+data_parallel_replicate_degree = 8
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
 tensor_parallel_degree = 1
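For reference, a small sketch of how data_parallel_shard_degree = -1 could resolve for an 8-GPU debug run with the other degrees at 1 (illustrative arithmetic only, not torchtitan's actual implementation): the replicate degree of 8 consumes the full world size, so dp_shard becomes 1, matching the 2D (8, 1) mesh sketched above.

# Illustrative resolution of data_parallel_shard_degree = -1; world size and
# degrees mirror the debug config, but this is not torchtitan's actual code.
world_size = 8
pp, dp_replicate, cp, tp = 1, 8, 1, 1
dp_shard = -1

if dp_shard == -1:
    # -1 means "use whatever is left of the world size".
    dp_shard = world_size // (pp * dp_replicate * cp * tp)

assert pp * dp_replicate * dp_shard * cp * tp == world_size
print(dp_shard)  # 1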
