@@ -129,6 +129,75 @@ def early_config_prune(configs, named_args, dtsize=None, dtype=None, **kwargs):
    if dtype is None:
        dtype = named_args["c_ptr"].dtype

+    pruned_configs = []
+    for config in configs:
+        kw = config.kwargs
+        (
+            BLOCK_M,
+            BLOCK_N,
+            BLOCK_K,
+            num_stages,
+            use_tma_load_on_scales,
+        ) = (
+            kw["BLOCK_SIZE_M"],
+            kw["BLOCK_SIZE_N"],
+            kw["BLOCK_SIZE_K"],
+            config.num_stages,
+            kw.get("USE_TMA_LOAD_ON_SCALES", False),
+        )
+        G, M, N = (
+            named_args["G"],
+            named_args["M_BUCKET"],
+            named_args["N"],
+        )
+
+        # 1. make sure we have enough smem
+        max_shared_memory = driver.active.utils.get_device_properties(device)[
+            "max_shared_mem"
+        ]
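+        # Shared-memory estimate for the pipelined tiles: the ROCm path only
+        # counts the B tile (BLOCK_N x BLOCK_K); the CUDA path counts both the
+        # A and B tiles, each scaled by num_stages and the element size.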
+        if torch.version.hip:
+            required_shared_memory = BLOCK_N * BLOCK_K * num_stages * dtsize
+        else:
+            required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
+        if required_shared_memory > max_shared_memory:
+            continue
+
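+        # M comes from M_BUCKET, so M_PER_GROUP is the bucketed average number
+        # of rows per group; checks 2 and 3 size BLOCK_M against it.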
+        M_PER_GROUP = M // G
+        MIN_M_TILES = 32 if torch.version.hip else 64
+        # 2. make sure we don't load M tiles that are too big
+        if BLOCK_M > MIN_M_TILES and BLOCK_M > (M_PER_GROUP * 2):
+            continue
+        # 3. make sure we don't load M tiles that are too small
+        if BLOCK_M < 128 and BLOCK_M < (M_PER_GROUP // 2):
+            continue
+
+        num_sm = driver.active.utils.get_device_properties(device)[
+            "multiprocessor_count"
+        ]
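+        # Checks 4 and 5 weigh the N tiling against the SM count: keep large
+        # BLOCK_N only when there is enough work to cover the SMs, and drop
+        # small BLOCK_N when M * N_TILES already exceeds twice the SM count.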
+        N_TILES = (N + BLOCK_N - 1) // BLOCK_N
+        MIN_N_TILES = 32 if torch.version.hip else 64
+        # 4. make sure we don't load N tiles that are too big
+        if BLOCK_N > MIN_N_TILES and M * N_TILES < num_sm:
+            continue
+        # 5. make sure we don't load N tiles that are too small
+        if BLOCK_N < 128 and M * N_TILES > 2 * num_sm:
+            continue
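+        # 6. drop configs that enable TMA loads on the scales unless the dtype
+        # is a 1-byte (FP8) type.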
+        if dtsize >= 2:
+            if use_tma_load_on_scales:
+                continue
+        pruned_configs.append(config)
+
+    return pruned_configs
+
+
+def early_config_prune_ws(configs, named_args, dtsize=None, dtype=None, **kwargs):
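+    # "_ws" variant wired below to the _NV_WS_CONFIGS autotuner and, with
+    # dtype=TT_FP8_DTYPE / dtsize=1, to the FP8 rowwise kernel.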
+    device = torch.cuda.current_device()
+    # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
+    if dtsize is None:
+        dtsize = named_args["c_ptr"].element_size()
+    if dtype is None:
+        dtype = named_args["c_ptr"].dtype
+
    pruned_configs = []
    for config in configs:
        kw = config.kwargs
@@ -384,7 +453,7 @@ def _fbgemm_grouped_gemm(
@triton.autotune(
    configs=_NV_WS_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
-    prune_configs_by={"early_config_prune": early_config_prune},
+    prune_configs_by={"early_config_prune": early_config_prune_ws},
    restore_value=["c_ptr"],  # restore for scatter_add fusion
)
@triton.jit
@@ -712,7 +781,7 @@ def _fbgemm_grouped_gemm_fp8_rowwise(
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={
        "early_config_prune": functools.partial(
-            early_config_prune, dtype=TT_FP8_DTYPE, dtsize=1
+            early_config_prune_ws, dtype=TT_FP8_DTYPE, dtsize=1
        )
    },
    restore_value=["c_ptr"],  # restore for scatter_add fusion