
Commit 60c583e

Add float8 MoE training readme and runnable example (#2353)
* add moe training readme and runnable example
* mention software requirements
1 parent b6bb7dc commit 60c583e

2 files changed (+166, -0 lines)

Lines changed: 101 additions & 0 deletions

# Float8 MoE Training

This prototype feature provides a way to use float8 rowwise training on MoE layers.

Below is a simple runnable example of how to use this feature, using the MoE layer
from the [torchtitan](https://github.com/pytorch/torchtitan) Llama4 implementation for demonstration.

```python
import torch
from torch import nn
from torch.nn import functional as F

# this feature requires CUDA and SM89+
assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

# this example uses the torchtitan llama4 MoE; see installation instructions
# at https://github.com/pytorch/torchtitan
try:
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError as e:
    raise ImportError(
        "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
    ) from e


# initialize model
device = torch.device("cuda")
model_args = TransformerModelArgs(
    moe_enabled=True,
    num_experts=8,
    dim=256,
)
model = MoE(model_args).to(torch.bfloat16).to(device)
init_std = 0.02
model.init_weights(init_std, device)

# module filter function to define which modules to quantize
target_fqns = ["experts"]


def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    for target_fqn in target_fqns:
        if target_fqn in cur_fqn:
            return True
    return False


# quantize the model
config = MoETrainingConfig()
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for step in range(10):
    batch, seq, dim = 8, 2048, 256
    x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )

    # forward pass
    out = model(x)

    # compute loss
    labels = torch.ones_like(out)
    out_loss = F.mse_loss(out, labels)
    print(f"step {step} loss: {out_loss.item()}")

    # backward pass
    out_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

## Requirements
- torchao nightly build
- CUDA compute capability 8.9+ (SM89+)

## Modeling requirements
This prototype is specifically designed for MoE models that use
`torch._grouped_mm` to implement expert computation in token-choice routing,
where expert weights are implemented as 3D nn.Parameters with `num_experts` as
the leading dim.
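
To make these requirements concrete, here is a minimal sketch of such an expert module. This is not the torchtitan implementation: the `GroupedExpertsSketch` name, the `offs=` keyword form of `torch._grouped_mm`, the bfloat16 weights, and the SiLU activation are illustrative assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn


class GroupedExpertsSketch(nn.Module):
    """Expert weights stored as 3D nn.Parameters with num_experts as the leading dim."""

    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        # (num_experts, dim, hidden_dim) and (num_experts, hidden_dim, dim)
        self.w_in = nn.Parameter(torch.randn(num_experts, dim, hidden_dim, dtype=torch.bfloat16) * 0.02)
        self.w_out = nn.Parameter(torch.randn(num_experts, hidden_dim, dim, dtype=torch.bfloat16) * 0.02)

    def forward(self, x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
        # x: (total_tokens, dim) in bfloat16, with each expert's tokens contiguous
        # offs: cumulative end offsets of each expert's token block, shape (num_experts,)
        h = torch._grouped_mm(x, self.w_in, offs=offs)      # (total_tokens, hidden_dim)
        h = F.silu(h)
        return torch._grouped_mm(h, self.w_out, offs=offs)  # (total_tokens, dim)
```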

The `MoETrainingConfig` has a module handler registered to it which will
find all nn.Parameters whose parent module matches the module filter function,
and swap their data tensor with a ScaledGroupedMMTensor.
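
Continuing the runnable example above, one quick sanity check (an assumption for illustration, not an official API) is to inspect the expert parameters after `quantize_` and verify that the data tensors of matched modules were swapped:

```python
# run after quantize_(model, config=config, filter_fn=moe_module_filter_fn)
for name, param in model.named_parameters():
    if "experts" in name:
        # expect a ScaledGroupedMMTensor-backed parameter for matched modules
        print(f"{name}: {type(param.data).__name__}, shape={tuple(param.shape)}")
```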

The ScaledGroupedMMTensor is a tensor subclass which overrides the
`torch._grouped_mm` op by dispatching to a differentiable scaled grouped mm,
which performs dynamic float8 rowwise quantization on the scaled grouped GEMM
operands in both the forward and backward pass.

For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
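
For intuition, below is a conceptual sketch of dynamic float8 rowwise quantization, not torchao's internal code: the choice of `torch.float8_e4m3fn`, the clamp epsilon, and the helper name are illustrative assumptions. Each row of a GEMM operand gets its own scale derived from that row's max absolute value, and the reciprocal scale is kept for dequantization.

```python
import torch


def quantize_fp8_rowwise(x: torch.Tensor):
    """Conceptual rowwise float8 quantization: one scale per row of the operand."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max                     # ~448 for e4m3fn
    amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32).clamp(min=1e-12)
    scale = fp8_max / amax                                             # per-row scale
    x_fp8 = (x.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8, scale.reciprocal()                                   # fp8 data + dequant scale


# round-trip check on a bfloat16 activation
x = torch.randn(16, 256, dtype=torch.bfloat16)
x_fp8, inv_scale = quantize_fp8_rowwise(x)
x_approx = x_fp8.to(torch.float32) * inv_scale
print("max abs round-trip error:", (x.to(torch.float32) - x_approx).abs().max().item())
```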

## Limitations
- Only tested with eager mode and single-GPU training so far.
- Composability with parallelisms and with `torch.compile` are next steps.
Lines changed: 65 additions & 0 deletions

import torch
from torch import nn
from torch.nn import functional as F

# this feature requires CUDA and SM89+
assert torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

# this example uses the torchtitan llama4 MoE; see installation instructions
# at https://github.com/pytorch/torchtitan
try:
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError as e:
    raise ImportError(
        "torchtitan not installed, see installation instructions at https://github.com/pytorch/torchtitan"
    ) from e


# initialize model
device = torch.device("cuda")
model_args = TransformerModelArgs(
    moe_enabled=True,
    num_experts=8,
    dim=256,
)
model = MoE(model_args).to(torch.bfloat16).to(device)
init_std = 0.02
model.init_weights(init_std, device)

# module filter function to define which modules to quantize
target_fqns = ["experts"]


def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    for target_fqn in target_fqns:
        if target_fqn in cur_fqn:
            return True
    return False


# quantize the model
config = MoETrainingConfig()
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for step in range(10):
    batch, seq, dim = 8, 2048, 256
    x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )

    # forward pass
    out = model(x)

    # compute loss
    labels = torch.ones_like(out)
    out_loss = F.mse_loss(out, labels)
    print(f"step {step} loss: {out_loss.item()}")

    # backward pass
    out_loss.backward()
    optimizer.step()
    optimizer.zero_grad()
