
Commit 2898903

[float8 moe training] FSDP support (#2413)
* fsdp support in moe training
* unwrap args and kwargs
* disable func
* fsdp working
* roll back use_triton flag
* add fsdp test for moe training
1 parent e73a142 commit 2898903
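As context for the diffs below, here is a minimal sketch of the flow the new test exercises: convert only the expert modules to float8 training with quantize_ and MoETrainingConfig, then shard the model with FSDP2's fully_shard. The convert_and_shard_moe helper name and its defaults are illustrative assumptions; the torchao and torch calls are the ones used in the test added by this commit.

from torch import nn
from torch.distributed._composable.fsdp import fully_shard

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_


def convert_and_shard_moe(model: nn.Module, target_fqns=("experts",)) -> nn.Module:
    # Swap params only in modules whose FQN matches a target, so routers/norms keep their dtype.
    def filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        return any(fqn in cur_fqn for fqn in target_fqns)

    # Wrap the expert weights in ScaledGroupedMMTensor so grouped GEMMs run in float8.
    quantize_(model, config=MoETrainingConfig(), filter_fn=filter_fn)

    # FSDP2 composable API: shard the already-converted model in place.
    fully_shard(model)
    return model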

File tree: 5 files changed, +253 −65 lines

New test file for FSDP float8 MoE training — Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
import copy
import os

import pytest
import torch
from torch import distributed as dist
from torch import nn
from torch.distributed._composable.fsdp import fully_shard
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
    pytest.skip(
        "CUDA not available or compute capability < 8.9", allow_module_level=True
    )

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.quant_api import quantize_

# this test requires torchtitan
try:
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    import warnings

    warnings.warn("torchtitan not installed, skipping MoE tests.")
    pytest.skip(allow_module_level=True)


def test_moe_float8_training_fsdp():
    assert torch.cuda.is_available()

    # setup distributed for fsdp
    setup_distributed()

    # define model args
    target_fqns = ["experts"]
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=8,
        dim=256,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # validate that only the experts were converted
    _validate_model_conversion(
        model,
        target_fqns=target_fqns,
    )

    # FSDP2
    fully_shard(model)
    fully_shard(ref_model)

    # inputs
    batch, seq, dim = 8, 2048, 256
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    # forward pass
    ref_out = ref_model(ref_x)
    out = model(x)

    # validate output
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

    # compute loss
    labels = torch.ones_like(ref_out)
    ref_loss = F.mse_loss(ref_out, labels)
    out_loss = F.mse_loss(out, labels)

    # backward pass
    ref_loss.backward()
    out_loss.backward()

    # validate input gradient
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
    assert input_grad_sqnr.item() >= 30.0, (
        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
    )

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        param_grad_sqnr = compute_error(param1.grad, param2.grad)
        assert param_grad_sqnr.item() >= 25.0, (
            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
        )

    dist.destroy_process_group()


def _validate_model_conversion(
    root_module: nn.Module,
    target_fqns: list[str],
):
    def _recursive_validate(
        module: nn.Module,
        cur_fqn: str,
    ):
        is_allowed_module = cur_fqn in target_fqns

        # check current module params
        for param_name, param in module.named_parameters(recurse=False):
            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
            if is_converted_type:
                assert is_allowed_module, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )
            if not is_allowed_module:
                assert not is_converted_type, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )

        # recursively check child modules
        for child_name, child_module in module.named_children():
            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
            _recursive_validate(child_module, child_fqn)

    _recursive_validate(root_module, "")


def setup_distributed():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py

Lines changed: 8 additions & 24 deletions
@@ -31,9 +31,7 @@ class ExperimentConfig:
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    torch_time_us: float
-    triton_time_us: bool
-    triton_speedup: float
+    time_us: float
 
 
 @dataclass(frozen=True)
@@ -98,46 +96,34 @@ def warmup(func, *args, **kwargs):
         for _ in range(10):
             func(*args, **kwargs)
 
-    def forward_backward(A, B_t, offs, use_triton=True):
+    def forward_backward(A, B_t, offs):
         out = _scaled_grouped_mm(
             A,
             B_t,
             offs=offs,
             out_dtype=torch.bfloat16,
-            use_triton_for_per_group_scales=use_triton,
         )
         out.sum().backward()
         torch.cuda.synchronize()
 
     # benchmark torch
     torch_func = torch.compile(forward_backward) if args.compile else forward_backward
-    warmup(torch_func, A, B_t, offs, use_triton=False)
+    warmup(torch_func, A, B_t, offs)
     start_time_ns = time.perf_counter_ns()
-    torch_func(A, B_t, offs, use_triton=False)
+    torch_func(A, B_t, offs)
     torch_time_ns = time.perf_counter_ns() - start_time_ns
-    torch_time_us = torch_time_ns / 1e3
-
-    # benchmark triton
-    warmup(forward_backward, A, B_t, offs, use_triton=True)
-    start_time_ns = time.perf_counter_ns()
-    forward_backward(A, B_t, offs, use_triton=True)
-    triton_time_ns = time.perf_counter_ns() - start_time_ns
-    triton_time_us = triton_time_ns / 1e3
+    time_us = torch_time_ns / 1e3
 
     return ExperimentResult(
-        torch_time_us=round(torch_time_us, 3),
-        triton_time_us=round(triton_time_us, 3),
-        triton_speedup=round(torch_time_us / triton_time_us, 3),
+        time_us=round(time_us, 3),
    )
 
 
 def print_results(experiments: List[Experiment]):
     headers = [
         "A_shape",
         "B_shape",
-        "torch_time_us",
-        "triton_time_us",
-        "triton_speedup",
+        "time_us",
     ]
     rows = []
     for experiment in experiments:
@@ -147,9 +133,7 @@ def print_results(experiments: List[Experiment]):
             [
                 A_shape,
                 B_shape,
-                experiment.result.torch_time_us,
-                experiment.result.triton_time_us,
-                experiment.result.triton_speedup,
+                experiment.result.time_us,
             ]
         )
     print(tabulate(rows, headers=headers))

torchao/prototype/moe_training/conversion_utils.py

Lines changed: 1 addition & 7 deletions
@@ -28,9 +28,6 @@ class MoETrainingConfig(AOBaseConfig):
     For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
     """
 
-    # temporary config flag for testing/benchmarking, will remove before graduating out of prototype
-    use_triton_for_per_group_scales: bool = True
-
 
 @register_quantize_module_handler(MoETrainingConfig)
 def _moe_training_transform(
@@ -71,7 +68,6 @@ def _swap_params(
     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
-    use_triton = config.use_triton_for_per_group_scales if config is not None else False
     if isinstance(module, nn.Parameter) and (
         module_filter_fn is None or module_filter_fn(module, "")
     ):
@@ -80,9 +76,7 @@ def _swap_params(
                 f"Does not support a root nn.Parameter with children: {module}"
             )
         if not isinstance(module.data, ScaledGroupedMMTensor):
-            new_data = ScaledGroupedMMTensor(
-                module.data, use_triton_for_per_group_scales=use_triton
-            )
+            new_data = ScaledGroupedMMTensor(module.data)
            return nn.Parameter(new_data, requires_grad=module.requires_grad)
        return module

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 2 additions & 18 deletions
@@ -16,8 +16,6 @@
 )
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
-    _to_2d_jagged_float8_tensor_colwise,
-    _to_2d_jagged_float8_tensor_rowwise,
 )
 
 
@@ -26,7 +24,6 @@ def _scaled_grouped_mm(
     B_t: torch.Tensor,
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
-    use_triton_for_per_group_scales: bool = True,
 ) -> torch.Tensor:
     """
     This function performs dynamic float8 quantization with row-wise scaling
@@ -143,7 +140,6 @@ def forward(
         # Store what we need for backward.
         ctx.save_for_backward(A, B_fp8_col_major, B_scales, offs)
         ctx.out_dtype = out_dtype
-        ctx.use_triton_for_per_group_scales = use_triton_for_per_group_scales
 
         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
@@ -167,7 +163,6 @@ def forward(
     def backward(ctx, grad_output: torch.Tensor):
         A, B_fp8_col_major, B_scales, offs = ctx.saved_tensors
         out_dtype = ctx.out_dtype
-        use_triton_for_per_group_scales = ctx.use_triton_for_per_group_scales
 
         # Convert grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_A: grad_output @ B
@@ -216,27 +211,16 @@ def backward(ctx, grad_output: torch.Tensor):
 
         # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups."
         # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
-        per_group_rowwise_scale_func = (
-            triton_fp8_row_major_jagged_rowwise_scales
-            if use_triton_for_per_group_scales
-            else _to_2d_jagged_float8_tensor_rowwise
-        )
-        per_group_colwise_scale_func = (
-            triton_fp8_col_major_jagged_colwise_scales
-            if use_triton_for_per_group_scales
-            else _to_2d_jagged_float8_tensor_colwise
-        )
-
         grad_output_t_fp8_row_major, grad_output_t_scales = (
-            per_group_rowwise_scale_func(
+            triton_fp8_row_major_jagged_rowwise_scales(
                 grad_output_t_row_major,
                 offs,
                 torch.float8_e4m3fn,
                 round_scales_to_power_of_2=True,
             )
         )
 
-        A_fp8_col_major, A_scales = per_group_colwise_scale_func(
+        A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
             A_col_major,
             offs,
             torch.float8_e4m3fn,

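For reference, a rough usage sketch of _scaled_grouped_mm with the simplified signature above (the use_triton_for_per_group_scales flag is gone). The call pattern mirrors the benchmark's forward_backward; the sizes, the column-major construction of B_t, and the int32 group-end offsets are illustrative assumptions, not requirements stated in this diff.

import torch

from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm

# Illustrative sizes: 2048 tokens routed evenly across 8 experts.
num_experts, total_tokens, dim, hidden = 8, 2048, 256, 512

# A is the 2D (tokens, dim) activation; B_t is the stacked expert weights,
# built via a transpose so its inner dimension is column-major (assumed layout).
A = torch.randn(
    total_tokens, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
B = torch.randn(
    num_experts, hidden, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
B_t = B.transpose(-2, -1)  # (num_experts, dim, hidden)

# Cumulative end offset of each expert's token group along the token dimension.
group = total_tokens // num_experts
offs = torch.arange(group, total_tokens + 1, group, device="cuda", dtype=torch.int32)

# Forward + backward through the dynamically quantized float8 grouped GEMM.
out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
out.sum().backward()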