@@ -1246,8 +1246,19 @@ def matmul_fp8_row(
1246
1246
# View inputs into proper torch fp8 dtype.
1247
1247
if torch .version .cuda :
1248
1248
assert a .dtype in (torch .float8_e4m3fn , torch .float8_e5m2 )
1249
+ elif torch .version .hip :
1250
+ if torch .cuda .get_device_capability () < (9 , 5 ):
1251
+ assert a .dtype in (
1252
+ torch .float8_e4m3fnuz ,
1253
+ torch .float8_e5m2fnuz ,
1254
+ )
1255
+ else :
1256
+ assert a .dtype in (torch .float8_e4m3fn , torch .float8_e5m2 )
1249
1257
else :
1250
- assert a .dtype in (torch .float8_e4m3fnuz , torch .float8_e5m2fnuz )
1258
+ assert a .dtype in (
1259
+ torch .float8_e4m3fnuz ,
1260
+ torch .float8_e5m2fnuz ,
1261
+ )
1251
1262
assert b .dtype == pt_fp8_dtype
1252
1263
M , N , K , m_key , n_key , k_key , c , c_dtype_triton , dot_out_dtype_triton , device = (
1253
1264
prep_matmul (a , b , dot_out_dtype )
@@ -3808,259 +3819,61 @@ def get_full_non_persistent_tuning_space():
3808
3819
3809
3820
3810
3821
MATMUL_CONFIGS_NON_PERSISTENT : list [Config ] = get_full_non_persistent_tuning_space ()
3822
+ # (BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M, SPLIT_K, waves_per_eu, matrix_instr_nonkdim, kpack, num_warps, num_stages)
3823
+ _MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K = [
3824
+ (16 , 16 , 256 , 1 , 1 , 8 , 16 , 2 , 2 , 2 ),
3825
+ (16 , 16 , 256 , 1 , 1 , 0 , 16 , 2 , 2 , 2 ),
3826
+ (32 , 64 , 512 , 1 , 1 , 2 , 16 , 2 , 8 , 2 ),
3827
+ (64 , 64 , 256 , 1 , 1 , 2 , 16 , 2 , 4 , 2 ),
3828
+ (256 , 256 , 128 , 32 , 1 , 2 , 16 , 1 , 8 , 2 ),
3829
+ (256 , 256 , 128 , 2 , 1 , 0 , 32 , 2 , 8 , 2 ),
3830
+ (256 , 256 , 128 , 1 , 1 , 0 , 32 , 2 , 8 , 2 ),
3831
+ (256 , 256 , 128 , 2 , 1 , 0 , 16 , 1 , 8 , 2 ),
3832
+ (256 , 256 , 64 , 2 , 1 , 2 , 16 , 1 , 8 , 2 ),
3833
+ (128 , 256 , 64 , 2 , 1 , 2 , 16 , 1 , 4 , 2 ),
3834
+ (256 , 128 , 128 , 4 , 1 , 0 , 16 , 1 , 8 , 2 ),
3835
+ (128 , 128 , 128 , 1 , 1 , 2 , 16 , 2 , 4 , 2 ),
3836
+ (128 , 128 , 256 , 1 , 1 , 2 , 16 , 2 , 8 , 2 ),
3837
+ (128 , 128 , 64 , 4 , 1 , 2 , 16 , 2 , 4 , 2 ),
3838
+ (128 , 128 , 64 , 1 , 1 , 2 , 16 , 2 , 4 , 2 ),
3839
+ (128 , 64 , 64 , 4 , 1 , 0 , 16 , 2 , 4 , 2 ),
3840
+ (128 , 64 , 64 , 1 , 1 , 0 , 16 , 2 , 4 , 2 ),
3841
+ (256 , 128 , 128 , 1 , 1 , 2 , 16 , 1 , 8 , 2 ),
3842
+ ]
3843
+
3844
+
3845
+ def _should_skip_config (block_k , matrix_instr_nonkdim ):
3846
+ """Skip config if BLOCK_K=64 and matrix_instr_nonkdim=16 on GFX95+"""
3847
+ try :
3848
+ return (
3849
+ block_k == 64
3850
+ and matrix_instr_nonkdim == 16
3851
+ and torch .version .hip is not None
3852
+ and torch .cuda .get_device_capability () >= (9 , 5 )
3853
+ )
3854
+ except RuntimeError :
3855
+ # If no HIP GPUs are available, we can't check device capability
3856
+ # so we don't skip any configs
3857
+ return False
3858
+
3859
+
3811
3860
# Expand every tuning tuple into a triton.Config, leaving out the entries
# that _should_skip_config rejects for the current device.
MATMUL_CONFIGS_NON_PERSISTENT_PINGPONG_4K_8K_16K = [
    triton.Config(
        {
            "BLOCK_M": block_m,
            "BLOCK_N": block_n,
            "BLOCK_K": block_k,
            "GROUP_M": group_m,
            "SPLIT_K": split_k,
            "waves_per_eu": waves_per_eu,
            "matrix_instr_nonkdim": matrix_instr_nonkdim,
            "kpack": kpack,
        },
        num_warps=num_warps,
        num_stages=num_stages,
    )
    for (
        block_m,
        block_n,
        block_k,
        group_m,
        split_k,
        waves_per_eu,
        matrix_instr_nonkdim,
        kpack,
        num_warps,
        num_stages,
    ) in _MATMUL_CONFIG_TUPLES_PINGPONG_4K_8K_16K
    if not _should_skip_config(block_k, matrix_instr_nonkdim)
]
4065
3878
4066
3879
# Set this to enable full autotuning for proper benchmarking.
0 commit comments