[MoE training] crash with FSDP if shared_expert uses float8 in torchtitan llama4 #2453

@danielvegamyhre

Summary

Float8 MoE training with FSDP works as expected on the routed experts. However, when the shared expert is converted to float8, the forward pass crashes.

Repro

NGPU=4 CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.steps=100 --model.converters="float8" --float8.recipe_name="rowwise" --float8.moe_fqns_prototype="shared_expert"

Error:

    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
      return f(*args, **kwargs)
    File "/home/danvm/torchtitan/torchtitan/train.py", line 437, in train
      self.train_step(inputs, labels)
      ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
    File "/home/danvm/torchtitan/torchtitan/train.py", line 370, in train_step
      pred = model_parts[0](inputs)
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
      return inner()
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1827, in inner
      result = forward_call(*args, **kwargs)
    File "/home/danvm/torchtitan/torchtitan/experiments/llama4/model/model.py", line 471, in forward
      h = layer(h, self.freqs_cis)
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
      return inner()
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1806, in inner
      args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 62, in fsdp_hook_wrapper
      return torch._dynamo.disable(
             ~~~~~~~~~~~~~~~~~~~~~~
      ...<2 lines>...
          reason="skipping FSDP hooks since torch._dynamo.config.skip_fsdp_hooks is set",
          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
      )(*args, **kwargs)
      ~^^^^^^^^^^^^^^^^^
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 896, in _fn
      return fn(*args, **kwargs)
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py", line 248, in _pre_forward
      args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs)
                     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 352, in pre_forward
      self.wait_for_unshard()
      ~~~~~~~~~~~~~~~~~~~~~^^
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 311, in wait_for_unshard
      fsdp_param.init_unsharded_param()
      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 504, in init_unsharded_param
      unsharded_param = torch.as_strided(
          unsharded_tensor,
      ...<2 lines>...
          storage_offset=0,
      )
    File "/home/danvm/ao/torchao/prototype/moe_training/tensor.py", line 87, in __torch_function__
      return func(*args, **kwargs)
    File "/home/danvm/ao/torchao/prototype/moe_training/tensor.py", line 112, in __torch_dispatch__
      out = func(*args, **kwargs)
    File "/home/danvm/.conda/envs/torchtitan/lib/python3.13/site-packages/torch/_ops.py", line 829, in __call__
      return self._op(*args, **kwargs)
             ~~~~~~~~^^^^^^^^^^^^^^^^^
  RuntimeError: setStorage: sizes [1, 256, 512], strides [131072, 512, 1], storage offset 0, and itemsize 4 requiring a storage size of 524288 are out of bounds for storage of size 0
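
Per the traceback, the failing op is FSDP2 re-viewing the unsharded parameter with torch.as_strided while the backing storage is still size 0. For reference, the same RuntimeError can be reproduced standalone (a minimal sketch of the error condition only, not the actual FSDP code path):

    import torch

    # The requested view needs
    #   storage_offset + (1-1)*131072 + (256-1)*512 + (512-1)*1 + 1 = 131072 elements,
    # i.e. 524288 bytes at itemsize 4 (fp32), but the backing storage holds 0 bytes.
    empty = torch.empty(0)
    torch.as_strided(empty, (1, 256, 512), (131072, 512, 1), storage_offset=0)
    # RuntimeError: setStorage: sizes [1, 256, 512], strides [131072, 512, 1],
    # storage offset 0, and itemsize 4 requiring a storage size of 524288 are
    # out of bounds for storage of size 0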

Other details

The last op logged before the error is a transpose:

[rank0]:[titan] 2025-06-26 20:54:23,410 - torchao.prototype.moe_training.tensor - INFO - transpose, args=(ScaledGroupedMMTensor(data=tensor([[[ 0.0122, -0.0187, -0.0114,  ...,  0.0130,  0.0087, -0.0008],
[rank0]:         [-0.0062, -0.0094,  0.0069,  ..., -0.0036, -0.0066,  0.0196],
[rank0]:         [ 0.0123, -0.0081, -0.0111,  ...,  0.0017,  0.0075,  0.0020],
[rank0]:         ...,
[rank0]:         [ 0.0153,  0.0105,  0.0001,  ..., -0.0139, -0.0056,  0.0094],
[rank0]:         [-0.0102,  0.0026,  0.0302,  ...,  0.0012, -0.0049,  0.0052],
[rank0]:         [-0.0141, -0.0030,  0.0159,  ...,  0.0162, -0.0038, -0.0104]]],
[rank0]:       device='cuda:0'), dtype=torch.bfloat16), -2, -1), kwargs={}
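
For context, the op log above comes from the ScaledGroupedMMTensor subclass intercepting ops via __torch_function__ / __torch_dispatch__ (the tensor.py:87 and tensor.py:112 frames in the traceback). A hypothetical minimal subclass that forwards and logs ops the same way shows where both the transpose and FSDP's subsequent as_strided would be intercepted (illustrative sketch only, not the torchao implementation):

    import torch

    class LoggingTensor(torch.Tensor):
        # Hypothetical stand-in for a __torch_function__ override like the one in
        # torchao/prototype/moe_training/tensor.py: log the op name, then forward
        # the call to the regular implementation.
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print(f"op: {getattr(func, '__name__', func)}")
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)

    w = torch.randn(1, 256, 512).as_subclass(LoggingTensor)
    wt = w.transpose(-2, -1)                                       # logs "transpose"
    v = torch.as_strided(wt, (1, 256, 512), (131072, 512, 1), 0)   # logs "as_strided"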
