Commit e73a142

Build mxfp4 kernel for sm120a (#2285)
1 parent eb86177 commit e73a142

File tree

8 files changed: +320 -50 lines

benchmarks/float8/bench_matmul.py

Lines changed: 43 additions & 17 deletions
```diff
@@ -16,6 +16,8 @@
     get_name_to_shapes_iter,
 )
 
+from torchao.ops import mx_fp4_bf16
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.testing.float8.roofline_utils import get_specs
 
 
@@ -62,13 +64,19 @@ def run(
 ):
     device = "cuda"
     # TODO(future PR): this is ugly
-    assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas"), "unsupported"
+    assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas", "mxfp4_cutlass"), (
+        "unsupported"
+    )
+    use_fp4 = recipe == "mxfp4_cutlass"
 
     specs = get_specs()
     bf16_peak_tops = specs["bf16_peak_tops"]
     fp8_peak_tops = specs["fp8_peak_tops"]
+    fp4_peak_tops = specs["fp4_peak_tops"]
     print(f"gpu_name: {torch.cuda.get_device_name(0)}")
-    print(f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}")
+    print(
+        f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}, fp4 {fp4_peak_tops:.2e}"
+    )
 
     headers = (
         "fast_accum",
@@ -77,14 +85,14 @@ def run(
         "K",
         "N",
         "ref_time_s",
-        "fp8_time_s",
-        "fp8_speedup",
+        "time_s",
+        "speedup",
     )
     results = []
 
     dtype = torch.bfloat16
     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
-    fast_accum_vals = [True, False]
+    fast_accum_vals = [False] if use_fp4 else [True, False]
 
     for idx, (fast_accum, (name, (M, K, N))) in enumerate(
         itertools.product(fast_accum_vals, name_to_shapes)
@@ -107,35 +115,53 @@ def run(
 
         del A
 
-        # raw float8 matmul (upper bound for what we can achieve in eager mode)
-        # TODO(future): add e5m2
-        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
-        A = torch.zeros(M, K, device=device, dtype=d1)
-        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        A_hp = torch.randn(M, K, device=device)
+        B_hp_t = torch.randn(N, K, device=device)
+
+        if use_fp4:
+            _, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
+            _, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
+            B = Bt.contiguous().T
+            peak_tops = fp4_peak_tops
+        else:
+            # raw float8 matmul (upper bound for what we can achieve in eager mode)
+            # TODO(future): add e5m2
+            d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
+            A = A_hp.to(d1)
+            B = B_hp_t.to(d2).contiguous().T
+            peak_tops = fp8_peak_tops
 
         if recipe == "tensorwise":
             scale_a = torch.tensor([1.0], device=device)
             scale_b = torch.tensor([1.0], device=device)
         elif recipe == "rowwise":
             scale_a = torch.ones(M, 1, device=device)
             scale_b = torch.ones(1, N, device=device)
-        elif recipe == "mxfp8_cublas":
+        elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
             scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
             scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         else:
             assert False, f"unknown recipe {recipe}"
 
-        def do_matmul(A, B):
+        def do_matmul_fp8(A, B):
             nonlocal scale_a
             nonlocal scale_b
             return torch._scaled_mm(
                 A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
             )
 
-        fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
-            tops, fp8_peak_tops, use_gpu_kernel_time, do_matmul, A, B
+        def do_matmul_mxfp4(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return mx_fp4_bf16(A, B, scale_a, scale_b)
+
+        do_matmul = do_matmul_mxfp4 if use_fp4 else do_matmul_fp8
+
+        time_sec, tops_sec, pct_top_peak = do_benchmarks(
+            tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
         )
         print(
-            f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
+            f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
         )
 
         del A, B, scale_a, scale_b
@@ -148,8 +174,8 @@ def do_matmul(A, B):
                 K,
                 N,
                 ref_time_sec,
-                fp8_time_sec,
-                ref_time_sec / fp8_time_sec,
+                time_sec,
+                ref_time_sec / time_sec,
             ]
         )
```
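For readers who want to exercise the new path outside the benchmark harness: the mxfp4 branch quantizes each operand to the packed `float4_e2m1fn_x2` dtype with `to_mx` (block size 32), then invokes the `torchao::mx_fp4_bf16` CUTLASS kernel. A minimal sketch, assuming a GPU and torchao build where the kernel is available (sm100a or sm120a), and following the diff's use of dummy all-ones e8m0 scales rather than the real scales returned by `to_mx`:

```python
# Minimal sketch of the mxfp4 matmul path this benchmark exercises.
# Assumes torchao was built with the mx_fp4_bf16 kernel and the GPU matches.
import torch
from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import to_mx

M, K, N = 1024, 1024, 1024
A_hp = torch.randn(M, K, device="cuda")
B_hp_t = torch.randn(N, K, device="cuda")

# to_mx returns (scales, quantized data); like the benchmark above, we
# discard the real scales and pass dummy all-ones e8m0 block scales
_, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
_, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
scale_a = torch.ones(M, K // 32, device="cuda", dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device="cuda", dtype=torch.float8_e8m0fnu)

out = mx_fp4_bf16(A, Bt.contiguous().T, scale_a, scale_b)
print(out.shape, out.dtype)  # (M, N), torch.bfloat16
```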

benchmarks/float8/utils.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -352,9 +352,6 @@ def get_gpu_kernel_gemm_time_s(f, *args, **kwargs):
     )
     # there is only 1 key, aten::mm or aten::_scaled_mm, with unit nanoseconds
     assert len(data) == 1
-    if "aten::mm" in data:
-        return data["aten::mm"] / 1e6 / n_iter
-    elif "aten::_scaled_mm" in data:
-        return data["aten::_scaled_mm"] / 1e6 / n_iter
-    else:
-        raise AssertionError("unexpected format of data")
+    key, value = next(iter(data.items()))
+    assert key in ("aten::mm", "aten::_scaled_mm", "torchao::mx_fp4_bf16")
+    return value / 1e6 / n_iter
```
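This rewrite collapses the per-kernel if/elif chain into a generic single-entry lookup, so supporting another GEMM kernel only means extending the allow-list. A small sketch of the pattern, with hypothetical profiler numbers:

```python
# Sketch of the single-entry extraction pattern used above: the profiler
# output is expected to hold exactly one GEMM kernel, whose name just
# needs to be on the allow-list. Numbers are hypothetical nanoseconds.
n_iter = 10
data = {"torchao::mx_fp4_bf16": 2_500_000.0}  # ns summed over n_iter calls

assert len(data) == 1
key, value = next(iter(data.items()))
assert key in ("aten::mm", "aten::_scaled_mm", "torchao::mx_fp4_bf16")
print(value / 1e6 / n_iter)  # 0.25
```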

setup.py

Lines changed: 56 additions & 25 deletions
```diff
@@ -272,15 +272,18 @@ def get_cutlass_build_flags():
             raise ValueError("No CUDA version found")
 
         major, minor = map(int, cuda_version.split(".")[:2])
-        build_sm90a = major > 12 or (major == 12 and minor >= 6)
-        build_sm100a = major > 12 or (major == 12 and minor >= 8)
+        build_sm90a = (major, minor) >= (12, 6)
+        build_sm100a = (major, minor) >= (12, 8)
+        build_sm120a = (major, minor) >= (12, 8)
 
         if build_sm90a:
             print(f"CUDA {cuda_version}: Enabling SM90a CUTLASS kernels")
         if build_sm100a:
             print(f"CUDA {cuda_version}: Enabling SM100a CUTLASS kernels")
+        if build_sm120a:
+            print(f"CUDA {cuda_version}: Enabling SM120a CUTLASS kernels")
 
-        return build_sm90a, build_sm100a
+        return build_sm90a, build_sm100a, build_sm120a
     except:
         # Fallback to architecture flags
         cuda_arch_flags = _get_cuda_arch_flags()
```
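The version checks now rely on Python's lexicographic tuple ordering, which is equivalent to the old `major > 12 or (major == 12 and minor >= 6)` form but reads as a single comparison. For example:

```python
# Lexicographic tuple comparison compares major first, then minor,
# matching "major > 12 or (major == 12 and minor >= 8)".
for version in [(12, 6), (12, 8), (12, 9), (13, 0)]:
    print(version, version >= (12, 8))
# (12, 6) False
# (12, 8) True
# (12, 9) True
# (13, 0) True
```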
```diff
@@ -340,6 +343,11 @@ def __init__(
         self.cmake_args = cmake_args
 
 
+def remove_items(a: list, b: list) -> list:
+    """Remove items in list b from list a"""
+    return [x for x in a if x not in b]
+
+
 def get_extensions():
     # Skip building C++ extensions if USE_CPP is set to "0"
     if use_cpp == "0":
@@ -454,7 +462,7 @@ def get_extensions():
     excluded_sources = list(
         glob.glob(os.path.join(extensions_dir, "cpu/*.cpp"), recursive=True)
     )
-    sources = [s for s in sources if s not in excluded_sources]
+    sources = remove_items(sources, excluded_sources)
 
     # Collect CUDA source files
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
@@ -498,22 +506,24 @@ def get_extensions():
         rocm_sources = list(
             glob.glob(os.path.join(extensions_rocm_dir, "**/*.cpp"), recursive=True)
         )
-        sources = [s for s in sources if s not in rocm_sources]
+        sources = remove_items(sources, rocm_sources)
 
-    use_cutlass = False
+    use_cutlass = use_cuda and not IS_WINDOWS
     cutlass_90a_sources = None
     cutlass_100a_sources = None
+    cutlass_120a_sources = None
     build_for_sm90a = False
     build_for_sm100a = False
-    if use_cuda and not IS_WINDOWS:
-        use_cutlass = True
+    build_for_sm120a = False
+
+    if use_cutlass:
         cutlass_dir = os.path.join(third_party_path, "cutlass")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
         cutlass_tools_include_dir = os.path.join(
             cutlass_dir, "tools", "util", "include"
         )
         cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
-    if use_cutlass:
+
         extra_compile_args["nvcc"].extend(
             [
                 "-DTORCHAO_USE_CUTLASS",
@@ -533,7 +543,7 @@ def get_extensions():
             ]
         )
 
-        build_for_sm90a, build_for_sm100a = get_cutlass_build_flags()
+        build_for_sm90a, build_for_sm100a, build_for_sm120a = get_cutlass_build_flags()
         # Define sm90a sources
         cutlass_90a_sources = [
             os.path.join(
@@ -557,40 +567,40 @@ def get_extensions():
                     "rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
                 )
             )
-        # Always remove sm90a sources from main sources
-        sources = [s for s in sources if s not in cutlass_90a_sources]
+        sources = remove_items(sources, cutlass_90a_sources)
 
         # Always compile mx_fp_cutlass_kernels.cu ONLY with sm100a architecture
         cutlass_100a_sources = [
             os.path.join(
                 extensions_cuda_dir,
                 "mx_kernels",
-                "mx_fp_cutlass_kernels.cu",
+                "mx_fp_cutlass_kernels_sm100a.cu",
             ),
         ]
-        # Remove from main sources to prevent compilation with other architectures
-        sources = [
-            s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"
-        ]
+        sources = remove_items(sources, cutlass_100a_sources)
+
+        # Always compile mx_fp_cutlass_kernels.cu ONLY with sm120a architecture
+        cutlass_120a_sources = [
+            os.path.join(
+                extensions_cuda_dir,
+                "mx_kernels",
+                "mx_fp_cutlass_kernels_sm120a.cu",
+            ),
+        ]
+        sources = remove_items(sources, cutlass_120a_sources)
 
     else:
-        # Remove CUTLASS-based kernels from the sources list. An
-        # assumption is that these files will have "cutlass" in its
-        # name.
+        # Remove CUTLASS-based kernels from the sources list. An assumption is that
+        # these files will have "cutlass" in its name.
         cutlass_sources = list(
             glob.glob(
                 os.path.join(extensions_cuda_dir, "**/*cutlass*.cu"), recursive=True
             )
         )
-        sources = [s for s in sources if s not in cutlass_sources]
+        sources = remove_items(sources, cutlass_sources)
 
     ext_modules = []
     if len(sources) > 0:
-        # Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources
-        sources = [
-            s for s in sources if os.path.basename(s) != "mx_fp_cutlass_kernels.cu"
-        ]
-
         ext_modules.append(
             extension(
                 "torchao._C",
@@ -643,6 +653,27 @@ def get_extensions():
             )
         )
 
+    # Only build the cutlass_120a extension if sm120a is in the architecture flags
+    if (
+        cutlass_120a_sources is not None
+        and len(cutlass_120a_sources) > 0
+        and build_for_sm120a
+    ):
+        cutlass_120a_extra_compile_args = copy.deepcopy(extra_compile_args)
+        # Only use sm120a architecture for these sources, ignoring cuda_arch_flags
+        cutlass_120a_extra_compile_args["nvcc"].append(
+            "-gencode=arch=compute_120a,code=sm_120a"
+        )
+        ext_modules.append(
+            extension(
+                "torchao._C_cutlass_120a",
+                cutlass_120a_sources,
+                py_limited_api=True,
+                extra_compile_args=cutlass_120a_extra_compile_args,
+                extra_link_args=extra_link_args,
+            )
+        )
+
     # Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
     if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
         build_options = BuildOptions()
```
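The new `torchao._C_cutlass_120a` extension follows the same pattern as the existing sm90a and sm100a ones: each architecture-specific source set is compiled into its own shared library with exactly one `-gencode` flag, so the kernel is never built for architectures that cannot run it. A condensed sketch of that pattern; `make_arch_extension` is a hypothetical helper, not part of setup.py:

```python
# Hypothetical helper condensing the per-architecture extension pattern
# used in get_extensions() above (a sketch, not the actual setup.py code).
import copy

from torch.utils.cpp_extension import CUDAExtension


def make_arch_extension(name, sources, base_nvcc_args, arch):
    # e.g. arch="120a" -> "-gencode=arch=compute_120a,code=sm_120a";
    # these sources are restricted to that single target architecture
    nvcc_args = copy.deepcopy(base_nvcc_args)
    nvcc_args.append(f"-gencode=arch=compute_{arch},code=sm_{arch}")
    return CUDAExtension(
        name,
        sources,
        py_limited_api=True,
        extra_compile_args={"nvcc": nvcc_args},
    )
```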

test/prototype/mx_formats/test_mx_mm.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -14,7 +14,7 @@
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
-    is_sm_at_least_100,
+    is_sm_version,
 )
 
 if not TORCH_VERSION_AT_LEAST_2_8:
@@ -59,7 +59,8 @@ def run_matrix_test(M: int, K: int, N: int, format) -> float:
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
+    not (is_sm_version(10, 0) or is_sm_version(12, 0)),
+    reason="CUDA capability 10.0 or 12.0 is required for mxfloat8",
 )
 @pytest.mark.parametrize(
     "size",
```

torchao/__init__.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -25,8 +25,21 @@
 
 so_files = list(Path(__file__).parent.glob("_C*.so"))
 if len(so_files) > 0:
+    compute_capability = (
+        torch.cuda.get_device_capability() if torch.cuda.is_available() else None
+    )
+
     for file in so_files:
+        # only load architecture-specific target if the current GPU matches that target
+        if (
+            ("cutlass_90a" in file.name and compute_capability != (9, 0))
+            or ("cutlass_100a" in file.name and compute_capability != (10, 0))
+            or ("cutlass_120a" in file.name and compute_capability != (12, 0))
+        ):
+            continue
+
         torch.ops.load_library(str(file))
+
     from . import ops
 
 # The following library contains CPU kernels from torchao/experimental
```
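With this gating, a wheel can ship `_C_cutlass_90a`, `_C_cutlass_100a`, and `_C_cutlass_120a` side by side and register only the one matching the local GPU, while generic `_C*.so` files always load. A standalone sketch of the same logic, with hypothetical file names:

```python
# Standalone sketch of the arch-gated loader above; file names are
# hypothetical examples of what a multi-arch wheel might contain.
import torch

ARCH_TAGS = {
    "cutlass_90a": (9, 0),
    "cutlass_100a": (10, 0),
    "cutlass_120a": (12, 0),
}


def should_load(so_name: str, capability) -> bool:
    for tag, required in ARCH_TAGS.items():
        if tag in so_name and capability != required:
            return False  # arch-specific library on a non-matching GPU
    return True  # generic libraries always load


cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
for name in ["_C.abi3.so", "_C_cutlass_90a.abi3.so", "_C_cutlass_120a.abi3.so"]:
    print(name, should_load(name, cap))
```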
