@@ -42,9 +42,15 @@ struct SM90ArchSpec {
4242
4343 // Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
4444 // Or too many register spills
45- if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192 ))
46- return false ;
4745
46+ if (get_env<int >(" ENABLE_SWAPAB" )){
47+ if (block_n != 64 and block_n != 128 and block_n != 256 )
48+ return false ;
49+ }else {
50+ if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192 ))
51+ return false ;
52+ }
53+
4854 // Avoid bank conflicts for FP32 output
4955 if (cd_dtype == torch::kFloat and block_n % 16 == 0 )
5056 return false ;
@@ -77,7 +83,13 @@ struct SM90ArchSpec {
7783
7884 static ThreadConfig get_thread_config (const KernelType& kernel_type,
7985 const int & block_m, const int & block_n) {
80- return ThreadConfig::sm90 (128 , (block_m == 64 ? 1 : 2 ) * 128 );
86+ int tile = 64 ;
87+ if (get_env<int >(" ENABLE_SWAPAB" )){
88+ tile = block_n;
89+ }else {
90+ tile = block_m;
91+ }
92+ return ThreadConfig::sm90 (128 , (tile > 64 ? 2 : 1 ) * 128 );
8193 }
8294
8395 static int get_smem_cd_size (const KernelType& kernel_type,
@@ -102,7 +114,8 @@ struct SM90ArchSpec {
102114
103115 static int get_extra_sfb_smem_size (const int & m, const int & n, const int & k,
104116 const int & block_m, const int & block_n, const int & block_k) {
105- const auto & use_uniform_sfb = block_k % block_n == 0 ? 1 : 2 ;
117+ const auto & use_uniform_sfb = get_env<int >(" ENABLE_SWAPAB" ) ? (block_n / 64 ):(block_k % block_n == 0 ? 1 : 2 );
118+
106119 return align<int >(ceil_div (k, block_k) * static_cast <int >(sizeof (float )) * use_uniform_sfb, 8 );
107120 }
108121
0 commit comments