
Commit faf788a

add-to-benchmarks (#2427)
1 parent 7a846d5 commit faf788a

3 files changed: +84 -25 lines changed

benchmarks/float8/bench_matmul.py (74 additions, 19 deletions)

@@ -16,6 +16,8 @@
     get_name_to_shapes_iter,
 )
 
+from torchao.ops import mx_fp4_bf16
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.testing.training.roofline_utils import get_specs
 
 
@@ -62,29 +64,38 @@ def run(
 ):
     device = "cuda"
     # TODO(future PR): this is ugly
-    assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas"), "unsupported"
+    assert recipe in (
+        "tensorwise",
+        "rowwise",
+        "mxfp8_cublas",
+        "mxfp4_cutlass",
+        "nvfp4",
+    ), "unsupported"
+    use_fp4 = recipe in ("mxfp4_cutlass", "nvfp4")
 
     specs = get_specs()
     bf16_peak_tops = specs["bf16_peak_tops"]
     fp8_peak_tops = specs["fp8_peak_tops"]
+    fp4_peak_tops = specs["fp4_peak_tops"]
     print(f"gpu_name: {torch.cuda.get_device_name(0)}")
-    print(f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}")
-
+    print(
+        f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}, fp4 {fp4_peak_tops:.2e}"
+    )
     headers = (
         "fast_accum",
         "name",
         "M",
         "K",
         "N",
-        "ref_time_s",
-        "fp8_time_s",
+        "time_s",
+        "speedup",
         "fp8_speedup",
     )
     results = []
 
     dtype = torch.bfloat16
     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
-    fast_accum_vals = [True, False]
+    fast_accum_vals = [False] if use_fp4 else [True, False]
 
     for idx, (fast_accum, (name, (M, K, N))) in enumerate(
         itertools.product(fast_accum_vals, name_to_shapes)
@@ -107,38 +118,82 @@ def run(
 
         del A
 
-        # raw float8 matmul (upper bound for what we can achieve in eager mode)
-        # TODO(future): add e5m2
-        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
-        A = torch.zeros(M, K, device=device, dtype=d1)
-        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        A_hp = torch.randn(M, K, device=device)
+        B_hp_t = torch.randn(N, K, device=device)
+
+        if recipe == "mxfp4_cutlass":
+            _, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
+            _, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
+            B = Bt.contiguous().T
+            peak_tops = fp4_peak_tops
+        elif recipe == "nvfp4":
+            from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize
+
+            A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)
+            B_scales, B_data = nvfp4_quantize(B_hp_t, block_size=16)
+            A = A_data.view(torch.float4_e2m1fn_x2)
+            B = B_data.view(torch.float4_e2m1fn_x2).T
+            peak_tops = fp4_peak_tops
+        else:
+            # raw float8 matmul (upper bound for what we can achieve in eager mode)
+            # TODO(future): add e5m2
+            d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
+            A = A_hp.to(d1)
+            B = B_hp_t.to(d2).contiguous().T
+            peak_tops = fp8_peak_tops
+
         if recipe == "tensorwise":
             scale_a = torch.tensor([1.0], device=device)
             scale_b = torch.tensor([1.0], device=device)
         elif recipe == "rowwise":
             scale_a = torch.ones(M, 1, device=device)
             scale_b = torch.ones(1, N, device=device)
-        elif recipe == "mxfp8_cublas":
+        elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
             scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
             scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        elif recipe == "nvfp4":
+            # Use the blockwise scales from nvfp4_quantize
+            scale_a = A_scales.view(torch.float8_e4m3fn)
+            scale_b = B_scales.view(torch.float8_e4m3fn)
         else:
             assert False, f"unknown recipe {recipe}"
 
-        def do_matmul(A, B):
+        def do_matmul_fp8(A, B):
             nonlocal scale_a
             nonlocal scale_b
             return torch._scaled_mm(
                 A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
             )
 
-        fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
-            tops, fp8_peak_tops, use_gpu_kernel_time, do_matmul, A, B
+        def do_matmul_mxfp4(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return mx_fp4_bf16(A, B, scale_a, scale_b)
+
+        def do_matmul_nvfp4(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=dtype)
+
+        if recipe == "mxfp4_cutlass":
+            do_matmul = do_matmul_mxfp4
+        elif recipe == "nvfp4":
+            do_matmul = do_matmul_nvfp4
+        else:
+            do_matmul = do_matmul_fp8
+
+        time_sec, tops_sec, pct_top_peak = do_benchmarks(
+            tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
         )
         print(
-            f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
+            f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
         )
 
-        del A, B, scale_a, scale_b
+        del A, B
+        if scale_a is not None:
+            del scale_a
+        if scale_b is not None:
+            del scale_b
 
         results.append(
             [
@@ -148,8 +203,8 @@ def do_matmul(A, B):
                 K,
                 N,
                 ref_time_sec,
-                fp8_time_sec,
-                ref_time_sec / fp8_time_sec,
+                time_sec,
+                ref_time_sec / time_sec,
             ]
         )

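For readers who want to try the two new fp4 paths outside the harness, below is a minimal standalone sketch assembled from the same calls the diff uses (to_mx, nvfp4_quantize, mx_fp4_bf16, torch._scaled_mm). It is an illustration, not the benchmark itself: the shape and variable names are arbitrary, the all-ones e8m0 scales copy the benchmark's speed-only mxfp4 setup, and it assumes a CUDA device plus a torchao build that ships these fp4 kernels (Blackwell-class hardware).

import torch

from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize

M, K, N = 1024, 2048, 4096
device = "cuda"
A_hp = torch.randn(M, K, device=device)
B_hp_t = torch.randn(N, K, device=device)  # B is created transposed, as in the benchmark

# mxfp4: 1x32 blocks, e8m0 scales, CUTLASS kernel exposed via torchao.ops.
# All-ones scales mirror the benchmark, which measures speed, not numerics.
_, A_mx = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
_, Bt_mx = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
out_mxfp4 = mx_fp4_bf16(A_mx, Bt_mx.contiguous().T, scale_a, scale_b)

# nvfp4: 1x16 blocks, fp8 (e4m3) scales, multiplied with torch._scaled_mm.
A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)
B_scales, B_data = nvfp4_quantize(B_hp_t, block_size=16)
out_nvfp4 = torch._scaled_mm(
    A_data.view(torch.float4_e2m1fn_x2),
    B_data.view(torch.float4_e2m1fn_x2).T,
    A_scales.view(torch.float8_e4m3fn),
    B_scales.view(torch.float8_e4m3fn),
    out_dtype=torch.bfloat16,
)

Note that, per the diff, the fp4 recipes only run with fast_accum=False, which is why the nvfp4 _scaled_mm call omits use_fast_accum.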
benchmarks/float8/utils.py (3 additions, 6 deletions)

@@ -352,9 +352,6 @@ def get_gpu_kernel_gemm_time_s(f, *args, **kwargs):
     )
     # there is only 1 key, aten::mm or aten::_scaled_mm, with unit nanoseconds
     assert len(data) == 1
-    if "aten::mm" in data:
-        return data["aten::mm"] / 1e6 / n_iter
-    elif "aten::_scaled_mm" in data:
-        return data["aten::_scaled_mm"] / 1e6 / n_iter
-    else:
-        raise AssertionError("unexpected format of data")
+    key, value = next(iter(data.items()))
+    assert key in ("aten::mm", "aten::_scaled_mm", "torchao::mx_fp4_bf16")
+    return value / 1e6 / n_iter

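The refactor above works because data is already asserted to hold exactly one entry, so a generic lookup can replace the per-kernel branches; supporting a new kernel (here torchao::mx_fp4_bf16) then only touches the allow-list. Below is a standalone sketch of the pattern; the function name and the literal dict are hypothetical stand-ins for the profiler aggregation done earlier in get_gpu_kernel_gemm_time_s.

def gemm_kernel_time_s(data: dict, n_iter: int) -> float:
    # exactly one gemm kernel is expected in the profile
    assert len(data) == 1
    key, value = next(iter(data.items()))
    # allow-list of supported gemm kernels; a new kernel only needs
    # to be added here, not as a new branch
    assert key in ("aten::mm", "aten::_scaled_mm", "torchao::mx_fp4_bf16")
    return value / 1e6 / n_iter

# hypothetical profiler output: kernel name -> total time over n_iter calls
print(gemm_kernel_time_s({"torchao::mx_fp4_bf16": 4.2e3}, n_iter=10))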
torchao/testing/training/roofline_utils.py (7 additions, 0 deletions)

@@ -54,6 +54,13 @@
         # TODO(future): run measurement on hardware
         "pct_achievable_mem_bw": 0.92,
     },
+    "NVIDIA GeForce RTX 5090": {
+        # https://images.nvidia.com/aem-dam/Solutions/geforce/blackwell/nvidia-rtx-blackwell-gpu-architecture.pdf
+        "bf16_peak_tops": 209.5e12,
+        "fp8_peak_tops": 419e12,
+        "fp4_peak_tops": 1676e12,
+        "peak_mem_bw_bytes_sec": 1.792e12,
+    },
     # TODO(future): more GPU names
 }

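These specs feed the pct_peak column printed by bench_matmul.py: achieved tops/sec are divided by the per-dtype peak from this table. Below is a back-of-the-envelope example of that arithmetic; the shape and kernel time are assumed values for illustration, not measurements.

# Roofline arithmetic against the new RTX 5090 fp4 peak (illustrative only).
fp4_peak_tops = 1676e12  # from the specs entry above

M = K = N = 4096
tops = 2 * M * K * N  # FLOPs for one M x K x N matmul
time_sec = 2.0e-4  # assumed kernel time for this example
tops_sec = tops / time_sec
pct_top_peak = tops_sec / fp4_peak_tops
print(f"tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}")  # ~0.410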