Conversation


@anshul-si anshul-si commented Sep 15, 2025

Summary: For this experiment integrating the new replicate function into torchtitan, I used pytorch/pytorch#162021, which has not landed yet. However, since that PR is about making replicate more efficient rather than changing replicate's core code, pytorch/pytorch#160135, which has landed, should be sufficient. pytorch/pytorch#160133 is the last PR that touched replicate_with_fsdp.py and its replicate API.

To enable the new replicate, which uses a 2D device mesh (it is a specialized version of HSDP), I changed the parallelism code to include a dp_shard dimension of size 1 only when dp_replicate > 1, and created the device mesh that I pass down to apply_ddp.
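
A minimal sketch of the mesh construction described above (hypothetical: `init_device_mesh` is the public PyTorch API, but the variable names and the way the mesh is threaded into `apply_ddp` are assumptions for illustration):

```python
from torch.distributed.device_mesh import init_device_mesh

# Assumes torchrun has already initialized the default process group on 8 ranks.
dp_replicate, dp_shard = 8, 1  # dp_shard is pinned to 1 when only replication is requested
dp_mesh = init_device_mesh(
    "cuda",
    (dp_replicate, dp_shard),
    mesh_dim_names=("dp_replicate", "dp_shard"),
)
# apply_ddp then receives this 2D mesh and calls the new API, as in this PR's diff:
#     replicate(model, device_mesh=dp_mesh)
```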

Below is a comparison of the loss curves for two Llama3.1-8B runs: one configured with a shard dimension of 2 and a tensor parallel dimension of 4, and the other with a replicate dimension of 2 and a shard dimension of 4.

<img width="1266" height="483" alt="image" src="https://github.com/user-attachments/assets/40198bc5-5e3f-486b-be56-12111e010e0c" />

https://fburl.com/mlhub/btkos8ok

Test Case

  1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh

Expected output of this experiment should be something like:
[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json. Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)

Stack from ghstack (oldest at bottom):

anshul-si added a commit that referenced this pull request Sep 15, 2025
…torchtitan

ghstack-source-id: ea5b964
Pull Request resolved: #1714
@meta-cla bot added the CLA Signed label Sep 15, 2025
@anshul-si anshul-si marked this pull request as draft September 15, 2025 23:12
@anshul-si anshul-si requested review from mori360 and removed request for fegin September 15, 2025 23:12
anshul-si added a commit that referenced this pull request Sep 16, 2025
…torchtitan

ghstack-source-id: 19a48b7
Pull Request resolved: #1714
anshul-si added a commit that referenced this pull request Sep 16, 2025
…torchtitan

ghstack-source-id: 7bba1f6
Pull Request resolved: #1714
torch._dynamo.config.optimize_ddp = "ddp_optimizer"

replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
replicate(model, device_mesh=dp_mesh)

replicate is meant to be equivalent to FSDP1 NO_SHARD, instead of DDP

the key difference is that DDP has bucket_cap_mb to overlap all_reduce with compute; replicate does not have such overlapping

the big assumption is that the user uses fsdp2 but wants to avoid all-gathers for small modules; that's when we use replicate

cc @tianyu-l in case we are counting replicate for DDP + EP


replicate is meant to be equivalent to FSDP1 NO_SHARD, instead of DDP

What does this mean? I thought we are doing replicate only because we can replace DDP with replicate so that we can do DDP + TP or other parallelisms.


for overlapping in replicate, we should use nested wrapping:

# wrap each transformer block so its gradient all-reduce can overlap with compute
for layer in model.layers:
    replicate(layer)

# wrap the root module for any remaining parameters
replicate(model)


replicate is meant to be equivalent to FSDP1 NO_SHARD, instead of DDP

What does this mean? I thought we are doing replicate only because we can replace DDP with replicate so that we can do DDP + TP or other parallelisms.

synced in chat: we can replace DDP with replicate in titan; we just need the per-layer wrapping for replicate


the big assumption is user uses fsdp2, but want to avoid all-gathers for small modules, that's the time we use replicate

Does this mean we could wrap some of the small modules in the model with replicate(), and other modules with fully_shard()? That's very flexible


that's right
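
A minimal sketch of what that mixed wrapping could look like (hypothetical: it assumes `fully_shard` and the new `replicate` compose this way, and the import paths may vary by PyTorch version):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # FSDP2; path may vary by version
from torch.distributed._composable.replicate import replicate  # assumed path for replicate

# Assumes torchrun has already initialized the default process group on 8 ranks.
mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard",))

model = nn.Sequential(
    nn.Embedding(1000, 64),   # small module: replicate it to avoid per-step all-gathers
    nn.Linear(64, 4096),      # larger modules: shard them to save memory
    nn.Linear(4096, 1000),
)

replicate(model[0])
fully_shard(model[1], mesh=mesh)
fully_shard(model[2], mesh=mesh)
```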

anshul-si added a commit that referenced this pull request Sep 23, 2025
…torchtitan

ghstack-source-id: bfe9ee3
Pull Request resolved: #1714
@anshul-si anshul-si marked this pull request as ready for review September 23, 2025 20:16

@tianyu-l tianyu-l left a comment


Thanks! Had some comments

reduce_dtype: torch.dtype,
enable_compile: bool,
enable_compiled_autograd: bool,
cpu_offload: bool = False,

Is there any scenario where we'd use cpu offloading with replicate?
CPU offloading is a memory-saving technique with significant slowdown. If users are memory bound, the first thing they should try is probably FSDP (which would incur more comm, such as AG in forward)?
The scenario I can think of is a world where communication is so slow that it's better to transfer between GPU and CPU instead of across GPUs.

It's OK to leave this option here if it's supported, but I would like to get some understanding.

cc @weifengpy @fegin


I don't think there is a good use case for this. cpu_offload is usually the last resort when people scale their models.


Besides Tianyu's example, I can't really think of any other examples that make sense. It seems more logical to remove cpu offloading from replicate/ddp.


when people enable replicate, do they have to turn off fsdp? if yes, then we won't need cpu_offloading for replicate

when fsdp+replicate are both applied to different parts of the model, we need cpu_offloading to be consistent across fsdp and replicate. otherwise optimizer.step sees a mixture of cpu and gpu tensors and throws errors


@weifengpy
I think this PR is only about replacing DDP with replicate.

I understand that the mix of replicate and fully_shard gives more power, but it's not set up today in torchtitan. If we support such a policy, we probably need to call it apply_data_parallel instead of apply_ddp / apply_fsdp.

BTW, I think with this PR, apply_ddp is a bad name -- should we call it apply_replicate or something?


apply_replicate is a good name

elif parallel_dims.dp_replicate_enabled:
    if world_mesh.ndim > 1:
        raise RuntimeError("DDP has not supported > 1D parallelism")
    # if world_mesh.ndim > 1:

We always have a >= 2D mesh, as you would enable dp_shard==1 anyway?
We should:

  1. verify DDP+TP and DDP+PP work with correct numerics (see instructions at https://github.com/pytorch/torchtitan/blob/main/docs/debugging.md#seed-checkpoint-based-reproducibility)
  2. remove this comment
  3. change all other occurrences as they depend on this function https://github.com/search?q=repo%3Apytorch%2Ftorchtitan%20apply_ddp&type=code

anshul-si added a commit that referenced this pull request Sep 23, 2025
…torchtitan

ghstack-source-id: 1ab4103
Pull Request resolved: #1714

EquationWalker commented Sep 24, 2025

In this case, mixed precision for replicate_with_fsdp should be handled by fully_shard instead of AMP. This means that we need to modify torchtitan/distributed/utils.py/maybe_enable_amp() to accommodate replicate_with_fsdp.
By the way, DistributedDataParallel has experimentally supported native mixed precision, similar to the MixedPrecisionPolicy of FSDP2. This means that perhaps torchtitan can remove torchtitan/distributed/utils.py/maybe_enable_amp() completely. See DDP native mixed precision #92882.
cc @weifengpy @tianyu-l

@tianyu-l

@EquationWalker

In this case, mixed precision for replicate_with_fsdp should be handled by fully_shard instead of AMP. This means that we need to modify torchtitan/distributed/utils.py/maybe_enable_amp() to accommodate replicate_with_fsdp.

Great point!
@anshul-si Let's accommodate.
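
A minimal sketch of one way that accommodation could look (the real `maybe_enable_amp` signature and config plumbing in torchtitan are assumptions here; `parallel_dims.dp_replicate_enabled` is borrowed from the snippet quoted earlier in this thread):

```python
import contextlib
import torch

def maybe_enable_amp(parallel_dims, param_dtype: torch.dtype, device_type: str):
    # When the new replicate (built on fully_shard) handles data parallelism,
    # mixed precision is applied via MixedPrecisionPolicy, so skip autocast.
    if parallel_dims.dp_replicate_enabled:
        return contextlib.nullcontext()
    return torch.autocast(device_type, dtype=param_dtype)
```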

apply_ddp(
    model,
    world_mesh,
    world_mesh[tuple(dp_mesh_dim_names)],

why is this dp_mesh 2d? for user-facing APIs, it should be a 1d mesh or the default 1d world mesh

inside the replicate API, we can do 2d


@anshul-si anshul-si Sep 24, 2025


Currently the replicate API doesn't have a method to convert a 1d mesh to 2d. I can leave a TODO to change this once the unflatten method for device mesh is complete, but until then I think it makes more sense to leave it like this. Also, the user technically believes they are creating a 1d mesh because they are only changing the replicate dim, so from their perspective they are still only creating a 1d mesh.


I remember @fduwjj mentioned a mesh API to convert 1d to 2d. This needs to be done before landing; it's a user contract. If we change 2d back to 1d later, that becomes BC-breaking.

@weifengpy

In this case, mixed precision for replicate_with_fsdp should be handled by fully_shard instead of AMP. This means that we need to modify torchtitan/distributed/utils.py/maybe_enable_amp() to accommodate replicate_with_fsdp.
By the way, DistributedDataParallel has experimentally supported native mixed precision, similar to the MixedPrecisionPolicy of FSDP2.

@EquationWalker good catch!


@weifengpy weifengpy left a comment


My requested changes are mainly about the 2d mesh. We should target a 1d mesh for landing; it's a user contract in a public-facing API.


EquationWalker commented Sep 25, 2025

My requested changes are mainly about the 2d mesh. We should target a 1d mesh for landing; it's a user contract in a public-facing API.

I think the use of a 2D mesh has something to do with the FSDPParamGroup user contract. When passed a 2D mesh, FSDPParamGroup treats it as HSDP: it shards parameters along the second dimension and replicates them along the first. If you pass a 1D mesh, FSDPParamGroup will shard parameters on that mesh instead of replicating them.
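
A small illustration of that contract (a sketch, assuming FSDP2's public `fully_shard` and an already-initialized process group of 8 ranks; the import path may vary by PyTorch version):

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # FSDP2

model_a = nn.Linear(1024, 1024)
model_b = nn.Linear(1024, 1024)

# 1D mesh: parameters are sharded across all 8 ranks (plain FSDP).
mesh_1d = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard",))
fully_shard(model_a, mesh=mesh_1d)

# 2D mesh: treated as HSDP -- replicated along dim 0, sharded along dim 1.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp_replicate", "dp_shard"))
fully_shard(model_b, mesh=mesh_2d)
```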
