
Commit b6bb7dc

float8 moe training conversion API prototype (#2275)
stack-info: PR: #2275, branch: danielvegamyhre/stack/1
migrate to quantize and add test
work on moe training test
1 parent ab66083 commit b6bb7dc

File tree

14 files changed: +305 −15 lines


test/prototype/scaled_grouped_mm/test_kernels.py renamed to test/prototype/moe_training/test_kernels.py

Lines changed: 2 additions & 2 deletions

@@ -19,11 +19,11 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
     _is_column_major,
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,

test/prototype/scaled_grouped_mm/test_scaled_grouped_mm.py renamed to test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@
 from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
 from torchao.float8.float8_tensor import LinearMMConfig
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
+from torchao.prototype.moe_training.scaled_grouped_mm import (
     _scaled_grouped_mm,
 )
 from torchao.testing.utils import skip_if_rocm
Lines changed: 140 additions & 0 deletions
import copy

import pytest
import torch
from torch import nn
from torch.nn import functional as F

# this feature requires CUDA and SM89+
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):
    pytest.skip(
        "CUDA not available or compute capability < 8.9", allow_module_level=True
    )

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.quant_api import quantize_

# this test requires torchtitan
try:
    from torchtitan.experiments.llama4.model.args import TransformerModelArgs
    from torchtitan.experiments.llama4.model.moe import MoE
except ImportError:
    import warnings

    warnings.warn("torchtitan not installed, skipping MoE tests.")
    pytest.skip(allow_module_level=True)


@pytest.mark.parametrize(
    "target_fqns",
    [
        ["experts"],
        ["does.not.exist"],
    ],
)
def test_moe_float8_training(target_fqns: list[str]):
    model_args = TransformerModelArgs(
        moe_enabled=True,
        num_experts=8,
        dim=256,
    )
    init_std = 0.02
    device = torch.device("cuda")

    # reference bf16 MoE
    ref_model = MoE(model_args).to(torch.bfloat16).cuda()
    torch.manual_seed(42)
    ref_model.init_weights(init_std, device)

    # target MoE for testing conversion
    model = copy.deepcopy(ref_model)

    # assert starting params are identical for both models
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        assert torch.equal(param1, param2)

    # convert MoE to float8 training
    def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
        for target_fqn in target_fqns:
            if target_fqn in cur_fqn:
                return True
        return False

    # quantize test model
    config = MoETrainingConfig()
    quantize_(model, config=config, filter_fn=moe_module_filter_fn)

    # validate that only the experts were converted
    _validate_model_conversion(
        model,
        target_fqns=target_fqns,
    )

    # inputs
    batch, seq, dim = 8, 2048, 256
    ref_x = torch.randn(
        batch, seq, dim, dtype=torch.bfloat16, requires_grad=True, device=device
    )
    x = ref_x.detach().clone().requires_grad_(True)

    # forward pass
    ref_out = ref_model(ref_x)
    out = model(x)

    # validate output
    out_sqnr = compute_error(out, ref_out)
    assert out_sqnr.item() >= 30.0, f"SQNR must be >= 30.0, got {out_sqnr.item()}."

    # compute loss
    labels = torch.ones_like(ref_out)
    ref_loss = F.mse_loss(ref_out, labels)
    out_loss = F.mse_loss(out, labels)

    # backward pass
    ref_loss.backward()
    out_loss.backward()

    # validate input gradient
    input_grad_sqnr = compute_error(x.grad, ref_x.grad)
    assert input_grad_sqnr.item() >= 30.0, (
        f"SQNR must be >= 30.0, got {input_grad_sqnr.item()}."
    )

    # validate param gradients
    for param1, param2 in zip(model.parameters(), ref_model.parameters()):
        param_grad_sqnr = compute_error(param1.grad, param2.grad)
        assert param_grad_sqnr.item() >= 25.0, (
            f"SQNR must be >= 25.0, got {param_grad_sqnr.item()}."
        )


def _validate_model_conversion(
    root_module: nn.Module,
    target_fqns: list[str],
):
    def _recursive_validate(
        module: nn.Module,
        cur_fqn: str,
    ):
        is_allowed_module = cur_fqn in target_fqns

        # check current module params
        for param_name, param in module.named_parameters(recurse=False):
            is_converted_type = isinstance(param, ScaledGroupedMMTensor)
            if is_converted_type:
                assert is_allowed_module, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )
            if not is_allowed_module:
                assert not is_converted_type, (
                    f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}."
                )

        # recursively check child modules
        for child_name, child_module in module.named_children():
            child_fqn = f"{cur_fqn}.{child_name}" if cur_fqn else child_name
            _recursive_validate(child_module, child_fqn)

    _recursive_validate(root_module, "")
Lines changed: 3 additions & 0 deletions
from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm

__all__ = ["_scaled_grouped_mm"]

torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_kernels.py renamed to torchao/prototype/moe_training/benchmarks/benchmark_kernels.py

Lines changed: 2 additions & 2 deletions

@@ -14,11 +14,11 @@
 from tabulate import tabulate
 from tqdm import tqdm

-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import (
+from torchao.prototype.moe_training.utils import (
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )

torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_scaled_grouped_mm.py renamed to torchao/prototype/moe_training/benchmarks/benchmark_scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 from tabulate import tabulate
 from tqdm import tqdm

-from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm
+from torchao.prototype.moe_training import _scaled_grouped_mm

 device = torch.device("cuda")
Lines changed: 112 additions & 0 deletions
from typing import Callable, Optional

from torch import nn

from torchao.core.config import AOBaseConfig
from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)


class MoETrainingConfig(AOBaseConfig):
    """
    The MoETrainingConfig is specifically designed to be used on MoE models using
    `torch._grouped_mm` to implement expert computation in token-choice routing,
    where expert weights are implemented as 3D nn.Parameters with `num_experts` as
    the leading dim.

    MoETrainingConfig has a module handler registered to it which will
    find all nn.Parameters whose parent module matches the module filter function,
    and swap their data tensor with a ScaledGroupedMMTensor.

    The ScaledGroupedMMTensor is a tensor subclass which overrides the
    `torch._grouped_mm` op by dispatching to a differentiable scaled grouped mm,
    which performs dynamic float8 rowwise quantization on scaled grouped GEMM
    operands in both the forward and backward pass.

    For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor.
    """

    pass


@register_quantize_module_handler(MoETrainingConfig)
def _moe_training_transform(
    module: nn.Module,
    config: MoETrainingConfig,
) -> nn.Module:
    """
    Swaps `torch.nn.Parameter` data tensors with a ScaledGroupedMMTensor.

    Args:
        module: Module to modify.
        config: MoETrainingConfig which defines how to perform the MoE training transform.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    out = _swap_params(module)
    return out


def _swap_params(
    module: nn.Module,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
) -> nn.Module:
    """
    Recurses through the nn.Module, recursively swapping the data tensor of
    each nn.Parameter with a ScaledGroupedMMTensor. Only applies if the module
    passes the module_filter_fn, if specified.

    Args:
        module: Module to modify.
        module_filter_fn: If specified, only the `torch.nn.Parameter` subclasses
            that pass the filter function will be swapped. The inputs to the
            filter function are the module instance and the FQN.

    Returns:
        nn.Module: The modified module with swapped parameters.
    """
    if isinstance(module, nn.Parameter) and (
        module_filter_fn is None or module_filter_fn(module, "")
    ):
        if len(list(module.children())) > 0:
            raise AssertionError(
                f"Does not support a root nn.Parameter with children: {module}"
            )
        if not isinstance(module.data, ScaledGroupedMMTensor):
            new_data = ScaledGroupedMMTensor(module.data)
            return nn.Parameter(new_data, requires_grad=module.requires_grad)
        return module

    root_module = module

    def post_order_traversal(
        module: nn.Module,
        cur_fqn: Optional[str] = None,
        parent_module: Optional[nn.Module] = None,
    ):
        if cur_fqn is None:
            cur_fqn = ""

        for child_module_name, child_module in module.named_children():
            if cur_fqn == "":
                new_fqn = child_module_name
            else:
                new_fqn = f"{cur_fqn}.{child_module_name}"

            post_order_traversal(child_module, new_fqn, module)

        if module_filter_fn is None or module_filter_fn(module, cur_fqn):
            for param_name, param in module.named_parameters(recurse=False):
                if not isinstance(param.data, ScaledGroupedMMTensor):
                    new_param = nn.Parameter(
                        ScaledGroupedMMTensor(param), requires_grad=param.requires_grad
                    )
                    setattr(module, param_name, new_param)
                    print(f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor")

    post_order_traversal(root_module)
    return root_module
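
For orientation, here is a minimal usage sketch of this conversion API. The ToyExperts/ToyMoE modules below are hypothetical stand-ins (the real test drives torchtitan's llama4 MoE); only the quantize_/MoETrainingConfig/filter_fn pattern is taken from this commit.

import torch
from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_


class ToyExperts(nn.Module):
    # hypothetical stand-in for an MoE experts module: a 3D weight with
    # num_experts as the leading dim, as MoETrainingConfig expects
    def __init__(self, num_experts: int, dim: int, hidden: int):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, hidden))


class ToyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = ToyExperts(num_experts=8, dim=256, hidden=512)


model = ToyMoE()

# swap param data tensors to ScaledGroupedMMTensor, but only in modules
# whose FQN contains "experts"
quantize_(
    model,
    config=MoETrainingConfig(),
    filter_fn=lambda mod, fqn: "experts" in fqn,
)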
@@ -1,6 +1,6 @@
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
 )

torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py renamed to torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 import triton
 import triton.language as tl

-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major

 EPS = 1e-12

torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py renamed to torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 6 additions & 3 deletions

@@ -10,11 +10,11 @@

 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
-from torchao.prototype.scaled_grouped_mm.kernels import (
+from torchao.prototype.moe_training.kernels import (
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
+from torchao.prototype.moe_training.utils import _is_column_major


 def _scaled_grouped_mm(
@@ -83,7 +83,10 @@ def forward(
         assert not _is_column_major(A), "A must be row-major"

         # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
-        assert _is_column_major(B_t), "B must be column-major"
+        if not _is_column_major(B_t):
+            # FSDP will complain if B_t (weights) is not contiguous, so we can't require B_t to be column-major.
+            # TODO: figure out a better solution than transposing on each forward pass.
+            B_t = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)

         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K)
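
As a standalone illustration (not part of the diff): the transpose/contiguous/transpose round-trip used above keeps the logical values intact while making the last two dims column-major, which can be verified from the strides.

import torch

B_t = torch.randn(4, 8, 16)
B_t_col = B_t.transpose(-2, -1).contiguous().transpose(-2, -1)

assert torch.equal(B_t, B_t_col)  # same logical values
print(B_t.stride())      # (128, 16, 1): last dim contiguous -> row-major
print(B_t_col.stride())  # (128, 1, 8): second-to-last dim contiguous -> column-major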
Lines changed: 35 additions & 0 deletions
import torch

from torchao.prototype.moe_training import _scaled_grouped_mm


class ScaledGroupedMMTensor(torch.Tensor):
    """
    ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
    and overrides the torch._grouped_mm op by dispatching to the
    differentiable _scaled_grouped_mm autograd function.
    """

    grouped_mm_func_name = "_grouped_mm"
    offs_arg_name = "offs"

    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs={}):
        if func.__name__ == cls.grouped_mm_func_name:
            # Use torchao scaled grouped mm with dynamic quant for
            # "2d x 3d with offsets" case (used for routed experts).
            # Otherwise, fall back to regular grouped mm.
            #
            # TODO: support "3d x 3d without offsets" case, which is
            # used for shared experts. This is basically the grouped_mm
            # kernel handling a bmm.
            A, B = args[0], args[1]
            A_is_2d = A.dim() == 2
            B_is_3d = B.dim() == 3
            has_offs = kwargs.get(cls.offs_arg_name) is not None
            if A_is_2d and B_is_3d and has_offs:
                return _scaled_grouped_mm(*args, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)
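
A rough sketch of the "2d x 3d with offsets" path this subclass intercepts. Shapes and offsets here are illustrative assumptions, and torch._grouped_mm is a private op whose availability and signature depend on the PyTorch build.

import torch

from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor

num_experts, dim, hidden, total_tokens = 4, 64, 128, 256

# 2D token activations and 3D expert weights (num_experts as the leading dim)
A = torch.randn(total_tokens, dim, dtype=torch.bfloat16, device="cuda")
B = ScaledGroupedMMTensor(
    torch.randn(num_experts, dim, hidden, dtype=torch.bfloat16, device="cuda")
)

# end offset of each expert's token group along dim 0 of A
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32, device="cuda")

# __torch_function__ sees a 2D A, a 3D B, and offs, so this call routes to the
# differentiable float8 _scaled_grouped_mm instead of the plain grouped mm
out = torch._grouped_mm(A, B, offs=offs)  # expected shape: (total_tokens, hidden)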

torchao/prototype/scaled_grouped_mm/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.
