[torchtitan][replicate] experimenting new replicate integration with torchtitan #1714
base: gh/anshul-si/1/base
Conversation
torch._dynamo.config.optimize_ddp = "ddp_optimizer"
- replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+ replicate(model, device_mesh=dp_mesh)
replicate is meant to be equivalent to FSDP1 NO_SHARD, instead of DDP.
The key difference is that DDP has bucket_cap_mb to overlap all_reduce with compute; replicate does not have such overlapping.
The big assumption is that the user is on FSDP2 but wants to avoid all-gathers for small modules; that's when we use replicate.
cc @tianyu-l in case we are counting on replicate for DDP + EP
replicate is meant to be equivalent to FSDP1 NO_SHARD, instead of DDP
What does this mean? I thought we were doing replicate only because we can replace DDP with replicate, so that we can do DDP + TP or other parallelisms.
For overlapping in replicate, we should use nested wrapping:
for layer in model.layers:
    replicate(layer)
replicate(model)
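As a minimal sketch, this is roughly what the per-layer wrapping could look like in torchtitan's data-parallel path. The import path for the new replicate and the helper name apply_replicate_per_layer are assumptions, not the PR's actual code, and model.layers is assumed to be a ModuleDict as in torchtitan's llama3 model (hence .values()):

from torch.distributed._composable.replicate_with_fsdp import replicate  # import path assumed

def apply_replicate_per_layer(model, dp_mesh):
    # Wrap each TransformerBlock first so each one gets its own gradient
    # all-reduce, which can then overlap with the backward of the next block.
    for layer in model.layers.values():
        replicate(layer, device_mesh=dp_mesh)
    # Wrap the root last to cover the remaining params (embeddings, norm, output).
    replicate(model, device_mesh=dp_mesh)
    return model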
replicate is meant to be equivalent to FSDP1 NO_SHARD, instead of DDP
What does this mean? I thought we were doing replicate only because we can replace DDP with replicate, so that we can do DDP + TP or other parallelisms.
Synced in chat: we can replace DDP with replicate in titan; we just need the per-layer wrapping for replicate.
the big assumption is user uses fsdp2, but want to avoid all-gathers for small modules, that's the time we use replicate
Does this mean we could wrap some of the small modules in the model with replicate(), while wrapping other modules with fully_shard()? That's very flexible.
that's right
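As a hedged illustration of that flexibility (module names follow torchtitan's llama3 model; the replicate import path is assumed, and mesh arguments are omitted since replicate and fully_shard expect different mesh layouts in this PR):

from torch.distributed.fsdp import fully_shard
from torch.distributed._composable.replicate_with_fsdp import replicate  # import path assumed

def apply_mixed_data_parallel(model):
    # Small modules: keep fully replicated so they incur no all-gather.
    replicate(model.tok_embeddings)
    replicate(model.norm)
    # Large transformer blocks and the output projection: shard with FSDP2.
    for layer in model.layers.values():
        fully_shard(layer)
    fully_shard(model.output)
    return model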
Thanks! Had some comments
reduce_dtype: torch.dtype,
enable_compile: bool,
enable_compiled_autograd: bool,
+ cpu_offload: bool = False,
Is there any scenario where we'd use CPU offloading with replicate?
CPU offloading is a memory-saving technique with significant slowdown. If users are memory bound, the first thing they should try is probably FSDP (which would incur more comm, such as all-gather in forward)?
The scenario I can think of is a world where communication is so slow that it's better to do GPU/CPU transfers instead of cross-GPU ones.
It's OK to leave this option here if it's supported, but I would like to get some understanding.
cc @weifengpy @fegin
I don't think there is a good use case for this. cpu_offload is usually the last choice when people scale their models.
Besides Tianyu's example, I can't really think of any other examples that make sense. It seems more logical to remove cpu offloading from replicate/ddp.
When people enable replicate, do they have to turn off FSDP? If yes, then we won't need cpu_offloading for replicate.
When FSDP + replicate are both applied to different parts of the model, we need cpu_offloading to be consistent across FSDP and replicate; otherwise optimizer.step sees a mixture of CPU and GPU tensors and throws errors.
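A small sketch of that constraint, assuming the new replicate exposes an offload_policy keyword like fully_shard does (that keyword on replicate is an assumption; fully_shard and CPUOffloadPolicy are the landed FSDP2 API):

from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed._composable.replicate_with_fsdp import replicate  # import path assumed

offload = CPUOffloadPolicy()  # one policy, reused for every wrapped module

for layer in model.layers.values():
    fully_shard(layer, offload_policy=offload)
# Assumed keyword: if only the sharded parts were offloaded, optimizer.step()
# would see a mixture of CPU and GPU parameters and raise.
replicate(model.norm, offload_policy=offload)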
@weifengpy
I think this PR is only about replacing DDP with replicate.
I understand the mix of replicate and fully_shard gives more power, but it's not set up today in torchtitan. If we support such a policy, we probably need to call it apply_data_parallel instead of apply_ddp / apply_fsdp.
BTW, I think with this PR, apply_ddp is a bad name -- should we call it apply_replicate or something?
apply_replicate is a good name
elif parallel_dims.dp_replicate_enabled:
- if world_mesh.ndim > 1:
-     raise RuntimeError("DDP has not supported > 1D parallelism")
+ # if world_mesh.ndim > 1:
We always have a >= 2D mesh now, since dp_shard == 1 gets included anyway?
We should:
- verify that DDP+TP and DDP+PP work with correct numerics (see instructions at https://github.com/pytorch/torchtitan/blob/main/docs/debugging.md#seed-checkpoint-based-reproducibility)
- remove this comment
- change all other occurrences, as they depend on this function: https://github.com/search?q=repo%3Apytorch%2Ftorchtitan%20apply_ddp&type=code
In this case, Mixed precision of
Great point!
apply_ddp(
    model,
-   world_mesh,
+   world_mesh[tuple(dp_mesh_dim_names)],
Why is this dp_mesh 2D? For user-facing APIs, it should be a 1D mesh or the default 1D world mesh.
Inside the replicate API, we can do 2D.
Currently the replicate API doesn't have a method to convert the 1D mesh to 2D. I can leave a TODO to change this when the unflatten method for device mesh is complete, but until then I think it makes more sense to leave it like this. Also, the user technically believes they are creating a 1D mesh, because they are only changing the replicate dim; from their perspective, they are still only creating a 1D mesh.
I remember @fduwjj mentioned a mesh API to convert 1D to 2D. This needs to be done before landing; it's a user contract. If we change 2D back to 1D later, that becomes BC-breaking.
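For reference, the two mesh shapes being debated; init_device_mesh is the existing public API, while the 1D-to-2D expansion inside replicate is hypothetical until the unflatten work lands:

from torch.distributed.device_mesh import init_device_mesh

# What the PR builds today at the call site: an explicit 2-D dp mesh with a
# trivial shard dim, e.g. 8 replicate groups x 1 shard.
mesh_2d = init_device_mesh("cuda", (8, 1), mesh_dim_names=("dp_replicate", "dp_shard"))

# What reviewers are asking for: the user passes a plain 1-D mesh, and replicate
# expands it to (n, 1) internally once DeviceMesh can unflatten, so the public
# contract stays 1-D.
mesh_1d = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_replicate",))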
@EquationWalker good catch!
My "request changes" is mainly about the 2D mesh. We should target a 1D mesh for landing; it's a user contract in a public-facing API.
I think the use of 2D mesh has something to do with the
Summary: For this experiment integrating the new replicate function into torchtitan, I used pytorch/pytorch#162021, which has not yet landed. However, since that PR is more about making replicate efficient than about changing replicate's core code, pytorch/pytorch#160135, which has landed, should be fine. pytorch/pytorch#160133 is the last PR that touched replicate_with_fsdp.py and its replicate API.
In order to enable the new replicate, which uses a 2D device mesh (since it is a specialized version of HSDP), I changed the parallelism code to include a dp_shard dim of 1 only if dp_replicate > 1, and created a device mesh that I pass down to apply_ddp.
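A rough sketch of that change (illustrative, not the actual diff; the names approximate torchtitan's parallelize code):

# Include a trivial dp_shard dim only when replicate is enabled, since the new
# replicate (a specialized HSDP) expects a 2-D (dp_replicate, dp_shard) mesh.
if parallel_dims.dp_replicate_enabled and not parallel_dims.dp_shard_enabled:
    dp_mesh_dim_names = ("dp_replicate", "dp_shard")  # e.g. mesh shape [8, 1]
    apply_ddp(
        model,
        world_mesh[tuple(dp_mesh_dim_names)],  # 2-D dp mesh passed down
        enable_compile=enable_compile,
        enable_compiled_autograd=enable_compiled_autograd,
    )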
Below is a link comparing the loss curves for two Llama3.1-8B runs: one configured with sharding degree 2 and tensor parallelism degree 4, and the other with replication degree 2 and sharding degree 4.
https://fburl.com/mlhub/btkos8ok
Test Case
1. CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh
Expected output of this experiment should be something like:
[rank0]:[titan] 2025-09-15 17:38:26,676 - root - INFO - Starting job: Llama 3 debug training
[rank0]:[titan] 2025-09-15 17:38:29,094 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-09-15 17:38:29,097 - root - INFO - Building 2-D device mesh with ['dp_replicate', 'dp_shard'], [8, 1]
[rank0]:[titan] 2025-09-15 17:38:29,104 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank0]:NCCL version 2.27.5+cuda12.6
[rank0]:[titan] 2025-09-15 17:38:35,439 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-09-15 17:38:35,441 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-09-15 17:38:35,894 - root - INFO - Building llama3 debugmodel with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=256, n_layers=6, n_heads=16, n_kv_heads=None, vocab_size=2000, multiple_of=256, ffn_dim_multiplier=None, norm_eps=1e-05, rope_theta=500000, max_seq_len=2048, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=0)
[rank0]:[titan] 2025-09-15 17:38:35,931 - root - INFO - CUDA capacity: NVIDIA H100 with 94.99GiB memory
[rank0]:[titan] 2025-09-15 17:38:35,950 - root - INFO - Model llama3 debugmodel size: 6,139,136 total parameters
[rank0]:[titan] 2025-09-15 17:38:35,951 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-09-15 17:38:35,972 - root - INFO - Applied DDP to the model
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-09-15 17:38:36,153 - root - INFO - CUDA memory usage for model: 0.04GiB(0.04%)
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - WARNING - model.safetensors.index.json not found at hf_assets_path: ./tests/assets/tokenizer/model.safetensors.index.json. Defaulting to saving a single safetensors file if checkpoint is saved in HF format
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Mixed precision training is handled by AMP
[rank0]:[titan] 2025-09-15 17:38:36,154 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2)
Stack from ghstack (oldest at bottom):