@@ -272,15 +272,18 @@ def get_cutlass_build_flags():
272
272
raise ValueError ("No CUDA version found" )
273
273
274
274
major , minor = map (int , cuda_version .split ("." )[:2 ])
275
- build_sm90a = major > 12 or (major == 12 and minor >= 6 )
276
- build_sm100a = major > 12 or (major == 12 and minor >= 8 )
275
+ build_sm90a = (major , minor ) >= (12 , 6 )
276
+ build_sm100a = (major , minor ) >= (12 , 8 )
277
+ build_sm120a = (major , minor ) >= (12 , 8 )
277
278
278
279
if build_sm90a :
279
280
print (f"CUDA { cuda_version } : Enabling SM90a CUTLASS kernels" )
280
281
if build_sm100a :
281
282
print (f"CUDA { cuda_version } : Enabling SM100a CUTLASS kernels" )
283
+ if build_sm120a :
284
+ print (f"CUDA { cuda_version } : Enabling SM120a CUTLASS kernels" )
282
285
283
- return build_sm90a , build_sm100a
286
+ return build_sm90a , build_sm100a , build_sm120a
284
287
except :
285
288
# Fallback to architecture flags
286
289
cuda_arch_flags = _get_cuda_arch_flags ()
@@ -340,6 +343,11 @@ def __init__(
340
343
self .cmake_args = cmake_args
341
344
342
345
346
def remove_items(a: list, b: list) -> list:
    """Return the items of ``a`` that do not appear in ``b``.

    The relative order of the surviving items (and any duplicates among
    them) is preserved, and neither input list is modified.

    Args:
        a: List to filter.
        b: Items to drop from ``a``.

    Returns:
        A new list containing every element of ``a`` that is not in ``b``.
    """
    try:
        # Build the lookup once so each membership test is O(1) instead of
        # rescanning ``b`` for every element of ``a``.
        excluded = set(b)
        return [x for x in a if x not in excluded]
    except TypeError:
        # Some element of ``a`` or ``b`` is unhashable; fall back to the
        # plain list scan so such inputs keep working.
        return [x for x in a if x not in b]
349
+
350
+
343
351
def get_extensions ():
344
352
# Skip building C++ extensions if USE_CPP is set to "0"
345
353
if use_cpp == "0" :
@@ -454,7 +462,7 @@ def get_extensions():
454
462
excluded_sources = list (
455
463
glob .glob (os .path .join (extensions_dir , "cpu/*.cpp" ), recursive = True )
456
464
)
457
- sources = [ s for s in sources if s not in excluded_sources ]
465
+ sources = remove_items ( sources , excluded_sources )
458
466
459
467
# Collect CUDA source files
460
468
extensions_cuda_dir = os .path .join (extensions_dir , "cuda" )
@@ -498,22 +506,24 @@ def get_extensions():
498
506
rocm_sources = list (
499
507
glob .glob (os .path .join (extensions_rocm_dir , "**/*.cpp" ), recursive = True )
500
508
)
501
- sources = [ s for s in sources if s not in rocm_sources ]
509
+ sources = remove_items ( sources , rocm_sources )
502
510
503
- use_cutlass = False
511
+ use_cutlass = use_cuda and not IS_WINDOWS
504
512
cutlass_90a_sources = None
505
513
cutlass_100a_sources = None
514
+ cutlass_120a_sources = None
506
515
build_for_sm90a = False
507
516
build_for_sm100a = False
508
- if use_cuda and not IS_WINDOWS :
509
- use_cutlass = True
517
+ build_for_sm120a = False
518
+
519
+ if use_cutlass :
510
520
cutlass_dir = os .path .join (third_party_path , "cutlass" )
511
521
cutlass_include_dir = os .path .join (cutlass_dir , "include" )
512
522
cutlass_tools_include_dir = os .path .join (
513
523
cutlass_dir , "tools" , "util" , "include"
514
524
)
515
525
cutlass_extensions_include_dir = os .path .join (cwd , extensions_cuda_dir )
516
- if use_cutlass :
526
+
517
527
extra_compile_args ["nvcc" ].extend (
518
528
[
519
529
"-DTORCHAO_USE_CUTLASS" ,
@@ -533,7 +543,7 @@ def get_extensions():
533
543
]
534
544
)
535
545
536
- build_for_sm90a , build_for_sm100a = get_cutlass_build_flags ()
546
+ build_for_sm90a , build_for_sm100a , build_for_sm120a = get_cutlass_build_flags ()
537
547
# Define sm90a sources
538
548
cutlass_90a_sources = [
539
549
os .path .join (
@@ -557,40 +567,40 @@ def get_extensions():
557
567
"rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu" ,
558
568
)
559
569
)
560
- # Always remove sm90a sources from main sources
561
- sources = [s for s in sources if s not in cutlass_90a_sources ]
570
+ sources = remove_items (sources , cutlass_90a_sources )
562
571
563
572
# Always compile mx_fp_cutlass_kernels_sm100a.cu ONLY with sm100a architecture
564
573
cutlass_100a_sources = [
565
574
os .path .join (
566
575
extensions_cuda_dir ,
567
576
"mx_kernels" ,
568
- "mx_fp_cutlass_kernels .cu" ,
577
+ "mx_fp_cutlass_kernels_sm100a .cu" ,
569
578
),
570
579
]
571
- # Remove from main sources to prevent compilation with other architectures
572
- sources = [
573
- s for s in sources if os .path .basename (s ) != "mx_fp_cutlass_kernels.cu"
580
+ sources = remove_items (sources , cutlass_100a_sources )
581
+
582
+ # Always compile mx_fp_cutlass_kernels_sm120a.cu ONLY with sm120a architecture
583
+ cutlass_120a_sources = [
584
+ os .path .join (
585
+ extensions_cuda_dir ,
586
+ "mx_kernels" ,
587
+ "mx_fp_cutlass_kernels_sm120a.cu" ,
588
+ ),
574
589
]
590
+ sources = remove_items (sources , cutlass_120a_sources )
575
591
576
592
else :
577
- # Remove CUTLASS-based kernels from the sources list. An
578
- # assumption is that these files will have "cutlass" in its
579
- # name.
593
+ # Remove CUTLASS-based kernels from the sources list. An assumption is that
594
+ # these files will have "cutlass" in its name.
580
595
cutlass_sources = list (
581
596
glob .glob (
582
597
os .path .join (extensions_cuda_dir , "**/*cutlass*.cu" ), recursive = True
583
598
)
584
599
)
585
- sources = [ s for s in sources if s not in cutlass_sources ]
600
+ sources = remove_items ( sources , cutlass_sources )
586
601
587
602
ext_modules = []
588
603
if len (sources ) > 0 :
589
- # Double-check to ensure mx_fp_cutlass_kernels.cu is not in sources
590
- sources = [
591
- s for s in sources if os .path .basename (s ) != "mx_fp_cutlass_kernels.cu"
592
- ]
593
-
594
604
ext_modules .append (
595
605
extension (
596
606
"torchao._C" ,
@@ -643,6 +653,27 @@ def get_extensions():
643
653
)
644
654
)
645
655
656
+ # Only build the cutlass_120a extension if sm120a is in the architecture flags
657
+ if (
658
+ cutlass_120a_sources is not None
659
+ and len (cutlass_120a_sources ) > 0
660
+ and build_for_sm120a
661
+ ):
662
+ cutlass_120a_extra_compile_args = copy .deepcopy (extra_compile_args )
663
+ # Only use sm120a architecture for these sources, ignoring cuda_arch_flags
664
+ cutlass_120a_extra_compile_args ["nvcc" ].append (
665
+ "-gencode=arch=compute_120a,code=sm_120a"
666
+ )
667
+ ext_modules .append (
668
+ extension (
669
+ "torchao._C_cutlass_120a" ,
670
+ cutlass_120a_sources ,
671
+ py_limited_api = True ,
672
+ extra_compile_args = cutlass_120a_extra_compile_args ,
673
+ extra_link_args = extra_link_args ,
674
+ )
675
+ )
676
+
646
677
# Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
647
678
if build_macos_arm_auto or os .getenv ("BUILD_TORCHAO_EXPERIMENTAL" ) == "1" :
648
679
build_options = BuildOptions ()
0 commit comments