     get_name_to_shapes_iter,
 )

+from torchao.ops import mx_fp4_bf16
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.testing.training.roofline_utils import get_specs


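The two new imports are the building blocks of the fp4 paths: `to_mx` quantizes a high-precision tensor to an MX format, and `mx_fp4_bf16` is the CUTLASS mxfp4 GEMM entry point. A minimal sketch of how `to_mx` is used below — the exact scale shape is a prototype implementation detail, so treat this as illustrative:

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import to_mx

x = torch.randn(16, 64, device="cuda")
# to_mx returns (scales, quantized_data): one shared scale per 32-element
# block, with fp4 elements packed two per byte (torch.float4_e2m1fn_x2)
scales, x_fp4 = to_mx(x, torch.float4_e2m1fn_x2, 32)
print(scales.shape, scales.dtype)  # blockwise e8m0 scales
print(x_fp4.shape, x_fp4.dtype)    # packed fp4 payload
```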
@@ -62,29 +64,38 @@ def run(
 ):
     device = "cuda"
     # TODO(future PR): this is ugly
-    assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas"), "unsupported"
+    assert recipe in (
+        "tensorwise",
+        "rowwise",
+        "mxfp8_cublas",
+        "mxfp4_cutlass",
+        "nvfp4",
+    ), "unsupported"
+    use_fp4 = recipe in ("mxfp4_cutlass", "nvfp4")

     specs = get_specs()
     bf16_peak_tops = specs["bf16_peak_tops"]
     fp8_peak_tops = specs["fp8_peak_tops"]
+    fp4_peak_tops = specs["fp4_peak_tops"]
     print(f"gpu_name: {torch.cuda.get_device_name(0)}")
-    print(f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}")
-
+    print(
+        f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}, fp4 {fp4_peak_tops:.2e}"
+    )
     headers = (
         "fast_accum",
         "name",
         "M",
         "K",
         "N",
-        "ref_time_s",
-        "fp8_time_s",
+        "time_s",
+        "speedup",
         "fp8_speedup",
     )
     results = []

     dtype = torch.bfloat16
     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
-    fast_accum_vals = [True, False]
+    fast_accum_vals = [False] if use_fp4 else [True, False]

     for idx, (fast_accum, (name, (M, K, N))) in enumerate(
         itertools.product(fast_accum_vals, name_to_shapes)
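Two notes on this hunk: the fp4 kernels below expose no fast-accum knob, so the sweep pins `fast_accum` to `False` for the fp4 recipes, and each recipe now selects its own roofline via `peak_tops`. For reference, a sketch of the arithmetic behind the `tops/sec` and `pct_peak` numbers printed further down, assuming `do_benchmarks` follows the usual 2·M·K·N FLOP convention (its body is not part of this diff):

```python
# hypothetical shape and timing, for illustration only
M, K, N = 4096, 4096, 4096
tops = 2 * M * K * N                  # one multiply + one add per MAC
peak_tops = 2e15                      # e.g. specs["fp4_peak_tops"] from get_specs()
time_sec = 1e-4                       # measured GEMM time
tops_sec = tops / time_sec            # achieved throughput
pct_top_peak = tops_sec / peak_tops   # fraction of the hardware roofline
print(f"{tops_sec:.2E} tops/sec, {pct_top_peak:.3f} of peak")
```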
@@ -107,38 +118,82 @@ def run(

         del A

-        # raw float8 matmul (upper bound for what we can achive in eager mode)
-        # TODO(future): add e5m2
-        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
-        A = torch.zeros(M, K, device=device, dtype=d1)
-        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        A_hp = torch.randn(M, K, device=device)
+        B_hp_t = torch.randn(N, K, device=device)
+
+        if recipe == "mxfp4_cutlass":
+            _, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
+            _, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
+            B = Bt.contiguous().T
+            peak_tops = fp4_peak_tops
+        elif recipe == "nvfp4":
+            from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize
+
+            A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)
+            B_scales, B_data = nvfp4_quantize(B_hp_t, block_size=16)
+            A = A_data.view(torch.float4_e2m1fn_x2)
+            B = B_data.view(torch.float4_e2m1fn_x2).T
+            peak_tops = fp4_peak_tops
+        else:
+            # raw float8 matmul (upper bound for what we can achieve in eager mode)
+            # TODO(future): add e5m2
+            d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
+            A = A_hp.to(d1)
+            B = B_hp_t.to(d2).contiguous().T
+            peak_tops = fp8_peak_tops
+
         if recipe == "tensorwise":
             scale_a = torch.tensor([1.0], device=device)
             scale_b = torch.tensor([1.0], device=device)
         elif recipe == "rowwise":
             scale_a = torch.ones(M, 1, device=device)
             scale_b = torch.ones(1, N, device=device)
-        elif recipe == "mxfp8_cublas":
+        elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
             scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
             scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        elif recipe == "nvfp4":
+            # Use the blockwise scales from nvfp4_quantize
+            scale_a = A_scales.view(torch.float8_e4m3fn)
+            scale_b = B_scales.view(torch.float8_e4m3fn)
         else:
             assert False, f"unknown recipe {recipe}"

-        def do_matmul(A, B):
+        def do_matmul_fp8(A, B):
             nonlocal scale_a
             nonlocal scale_b
             return torch._scaled_mm(
                 A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
             )

-        fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
-            tops, fp8_peak_tops, use_gpu_kernel_time, do_matmul, A, B
+        def do_matmul_mxfp4(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return mx_fp4_bf16(A, B, scale_a, scale_b)
+
+        def do_matmul_nvfp4(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=dtype)
+
+        if recipe == "mxfp4_cutlass":
+            do_matmul = do_matmul_mxfp4
+        elif recipe == "nvfp4":
+            do_matmul = do_matmul_nvfp4
+        else:
+            do_matmul = do_matmul_fp8
+
+        time_sec, tops_sec, pct_top_peak = do_benchmarks(
+            tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
         )
         print(
-            f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
+            f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
         )

-        del A, B, scale_a, scale_b
+        del A, B
+        if scale_a is not None:
+            del scale_a
+        if scale_b is not None:
+            del scale_b

         results.append(
             [
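Pulling the nvfp4 branch above out of the loop, the end-to-end call pattern looks like the sketch below. It assumes a PyTorch build with fp4 support in `torch._scaled_mm` and compatible hardware; `nvfp4_quantize` is the prototype helper imported in the hunk above (16-element blocks, fp8-e4m3 blockwise scales):

```python
import torch

from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize

M, K, N = 256, 512, 256
A_hp = torch.randn(M, K, device="cuda")
B_hp_t = torch.randn(N, K, device="cuda")  # quantize B along K, like the benchmark

A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)
B_scales, B_data = nvfp4_quantize(B_hp_t, block_size=16)

out = torch._scaled_mm(
    A_data.view(torch.float4_e2m1fn_x2),    # (M, K) fp4, packed two per byte
    B_data.view(torch.float4_e2m1fn_x2).T,  # column-major B, as _scaled_mm expects
    A_scales.view(torch.float8_e4m3fn),
    B_scales.view(torch.float8_e4m3fn),
    out_dtype=torch.bfloat16,
)
print(out.shape)  # (M, N)
```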
@@ -148,8 +203,8 @@ def do_matmul(A, B):
                 K,
                 N,
                 ref_time_sec,
-                fp8_time_sec,
-                ref_time_sec / fp8_time_sec,
+                time_sec,
+                ref_time_sec / time_sec,
             ]
         )

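Likewise for the mxfp4_cutlass path: quantize both operands with `to_mx` and call the `mx_fp4_bf16` kernel. As in the benchmark itself, dummy ones-scales in e8m0 are passed, since the loop measures speed rather than numerics — a sketch under the same hardware and version assumptions as above:

```python
import torch

from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import to_mx

M, K, N = 256, 512, 256
_, A = to_mx(torch.randn(M, K, device="cuda"), torch.float4_e2m1fn_x2, 32)
_, Bt = to_mx(torch.randn(N, K, device="cuda"), torch.float4_e2m1fn_x2, 32)

# dummy per-32-element-block scales, matching the benchmark's speed-only setup
scale_a = torch.ones(M, K // 32, device="cuda", dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device="cuda", dtype=torch.float8_e8m0fnu)

out = mx_fp4_bf16(A, Bt.contiguous().T, scale_a, scale_b)
print(out.shape)  # (M, N) in bf16
```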