Summary
FSDP works as expected when the float8 MoE training conversion targets the routed experts. However, when it targets the shared experts, the forward pass fails with the error below.
Repro
NGPU=4 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=100 --model.converters="float8" --float8.recipe_name="rowwise" --float8.moe_fqns_prototype="shared_expert"
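Before converting, it can help to confirm which submodules the `moe_fqns_prototype` filter will actually target. The helper below is a hypothetical sketch, assuming the filter amounts to a substring match on module FQNs (which may not be exactly what torchao does); `matching_fqns` and the toy module are made up for illustration.

```python
# Hypothetical debugging helper (not a torchtitan/torchao API), assuming the FQN
# filter behaves like a substring match on module names, which may not be exact.
import torch.nn as nn


def matching_fqns(model: nn.Module, fqn_filter: str) -> list[str]:
    """Return fully qualified names of submodules whose FQN contains the filter."""
    return [name for name, _ in model.named_modules() if fqn_filter in name]


# Toy module tree standing in for a Llama4 MoE layer, for illustration only.
toy = nn.Module()
toy.add_module("experts", nn.Linear(8, 8))
toy.add_module("shared_expert", nn.Linear(8, 8))
print(matching_fqns(toy, "shared_expert"))  # ['shared_expert']
```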
Error:
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
return f(*args, **kwargs)
File "/home/danvm/torchtitan/torchtitan/train.py", line 437, in train
self.train_step(inputs, labels)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
File "/home/danvm/torchtitan/torchtitan/train.py", line 370, in train_step
pred = model_parts[0](inputs)
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1827, in inner
result = forward_call(*args, **kwargs)
File "/home/danvm/torchtitan/torchtitan/experiments/llama4/model/model.py", line 471, in forward
h = layer(h, self.freqs_cis)
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1806, in inner
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 62, in fsdp_hook_wrapper
return torch._dynamo.disable(
~~~~~~~~~~~~~~~~~~~~~~
...<2 lines>...
reason="skipping FSDP hooks since torch._dynamo.config.skip_fsdp_hooks is set",
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
)(*args, **kwargs)
~^^^^^^^^^^^^^^^^^
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 896, in _fn
return fn(*args, **kwargs)
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 248, in _pre_forward
args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 352, in pre_forward
self.wait_for_unshard()
~~~~~~~~~~~~~~~~~~~~~^^
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 311, in wait_for_unshard
fsdp_param.init_unsharded_param()
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 504, in init_unsharded_param
unsharded_param = torch.as_strided(
unsharded_tensor,
...<2 lines>...
storage_offset=0,
)
File "/home/danvm/ao/torchao/prototype/moe_training/tensor.py", line 87, in __torch_function__
return func(*args, **kwargs)
File "/home/danvm/ao/torchao/prototype/moe_training/tensor.py", line 112, in __torch_dispatch__
out = func(*args, **kwargs)
File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_ops.py", line 829, in __call__
return self._op(*args, **kwargs)
~~~~~~~~^^^^^^^^^^^^^^^^^
RuntimeError: setStorage: sizes [1, 256, 512], strides [131072, 512, 1], storage offset 0, and itemsize 4 requiring a storage size of 524288 are out of bounds for storage of size 0
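The error itself is exactly what `torch.as_strided` raises when the tensor handed to it is backed by an empty storage, which suggests the unsharded tensor that FSDP passes through the subclass has no allocated storage at this point. A minimal sketch with the sizes/strides from the traceback (no torchtitan/torchao involved):

```python
# Minimal sketch, independent of torchtitan/torchao: as_strided with the sizes and
# strides from the traceback on a float32 tensor backed by an empty storage raises
# the identical setStorage error.
import torch

t = torch.empty(0, dtype=torch.float32)  # storage size 0, itemsize 4
# Needs 1 * 256 * 512 elements * 4 bytes = 524288 bytes, but the storage holds 0.
torch.as_strided(t, (1, 256, 512), (131072, 512, 1), storage_offset=0)
# RuntimeError: setStorage: sizes [1, 256, 512], strides [131072, 512, 1], storage
# offset 0, and itemsize 4 requiring a storage size of 524288 are out of bounds for
# storage of size 0
```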
Other details
The last op intercepted by ScaledGroupedMMTensor before the error is a transpose:
[rank0]:[titan] 2025-06-26 20:54:23,410 - torchao.prototype.moe_training.tensor - INFO - transpose, args=(ScaledGroupedMMTensor(data=tensor([[[ 0.0122, -0.0187, -0.0114, ..., 0.0130, 0.0087, -0.0008],
[rank0]: [-0.0062, -0.0094, 0.0069, ..., -0.0036, -0.0066, 0.0196],
[rank0]: [ 0.0123, -0.0081, -0.0111, ..., 0.0017, 0.0075, 0.0020],
[rank0]: ...,
[rank0]: [ 0.0153, 0.0105, 0.0001, ..., -0.0139, -0.0056, 0.0094],
[rank0]: [-0.0102, 0.0026, 0.0302, ..., 0.0012, -0.0049, 0.0052],
[rank0]: [-0.0141, -0.0030, 0.0159, ..., 0.0162, -0.0038, -0.0104]]],
[rank0]: device='cuda:0'), dtype=torch.bfloat16), -2, -1), kwargs={}
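For context on how the subclass ends up in the call path: ScaledGroupedMMTensor appears to be a wrapper tensor subclass that interposes on ops via `__torch_function__`/`__torch_dispatch__`, which is where the transpose above and the failing `as_strided` are seen. The sketch below is a generic illustration of that pattern, not the torchao implementation; `LoggingWrapperTensor` and its logging are hypothetical.

```python
# Generic wrapper-subclass sketch (hypothetical, not torchao's ScaledGroupedMMTensor):
# logs every aten op it sees and forwards it to the inner tensor.
import torch
from torch.utils._pytree import tree_map_only


class LoggingWrapperTensor(torch.Tensor):
    """Minimal wrapper subclass that logs and forwards every aten op."""

    @staticmethod
    def __new__(cls, data: torch.Tensor):
        # Build a subclass instance that mirrors `data`'s metadata.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            strides=data.stride(),
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __repr__(self):
        return f"LoggingWrapperTensor(data={self._data!r})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatch: {func}")
        # Unwrap subclass instances, run the real op, and rewrap tensor outputs.
        args = tree_map_only(LoggingWrapperTensor, lambda t: t._data, args)
        kwargs = tree_map_only(LoggingWrapperTensor, lambda t: t._data, kwargs)
        out = func(*args, **kwargs)
        return tree_map_only(torch.Tensor, cls, out)


w = LoggingWrapperTensor(torch.randn(1, 256, 512))
w.transpose(-2, -1)                                   # logs aten.transpose.int
torch.as_strided(w, (1, 256, 512), (131072, 512, 1))  # logs aten.as_strided.default
```

In the real failure, FSDP2's `init_unsharded_param` makes this kind of `as_strided` call on the unsharded tensor, and the subclass's `__torch_dispatch__` forwards it to the aten op (torchao `tensor.py` line 112 in the traceback), where it blows up because the backing storage is empty.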