@@ -1960,6 +1960,7 @@ def _attn_fwd_tma_ws_persistent( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
1960
1960
]
1961
1961
1962
1962
1963
+ # on-device TMA
1963
1964
@triton .autotune (list (filter (keep , configsCutlassBlackwell )), key = ["N_CTX" ])
1964
1965
@triton .jit
1965
1966
def _attn_fwd_tma_ws_persistent_with_dp ( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
@@ -2118,6 +2119,146 @@ def _attn_fwd_tma_ws_persistent_with_dp( # Q, V, desc_k, desc_v, sm_scale, M, O
2118
2119
tile_idx += num_progs
2119
2120
2120
2121
2122
@triton.jit
def _attn_fwd_subtile(q, k, offs_m, start_n, offs_n, qk_scale, l_i, m_i, acc, v, dtype: tl.constexpr,
                      STAGE: tl.constexpr):
    """Fold one K/V tile into the running softmax state of one Q sub-tile.

    Computes ``q @ k``, rescales it by ``qk_scale``, updates the running row
    maximum ``m_i`` and running row sum ``l_i`` (both in the base-2 exponent
    domain), rescales ``acc`` by the resulting correction factor and adds
    ``p @ v`` to it.

    When ``STAGE == 2`` an additive mask is applied first: positions where
    ``offs_m[row] < start_n + offs_n[col]`` receive ``-1.0e6`` before the
    row maximum is taken. For any other ``STAGE`` no mask is applied.

    Returns the updated ``(l_i, m_i, acc)`` triple.
    """
    scores = tl.dot(q, k)
    if STAGE == 2:
        # Keep only entries whose query index is >= the key index.
        keep = offs_m[:, None] >= (start_n + offs_n[None, :])
        scores = scores * qk_scale + tl.where(keep, 0, -1.0e6)
        m_new = tl.maximum(m_i, tl.max(scores, 1))
        scores = scores - m_new[:, None]
    else:
        # Unmasked tile: the scale is applied to the row max and to the
        # scores separately, exactly as in the masked branch's net effect.
        m_new = tl.maximum(m_i, tl.max(scores, 1) * qk_scale)
        scores = scores * qk_scale - m_new[:, None]
    p = tl.math.exp2(scores)
    # Correction factor that re-bases previously accumulated values onto
    # the new running maximum.
    alpha = tl.math.exp2(m_i - m_new)
    row_sum = tl.sum(p, 1)

    # Rescale the accumulator in two column halves, then stitch the halves
    # back together in their original column order.
    ROWS: tl.constexpr = acc.shape[0]
    COLS: tl.constexpr = acc.shape[1]

    left, right = acc.reshape([ROWS, 2, COLS // 2]).permute(0, 2, 1).split()
    left = left * alpha[:, None]
    right = right * alpha[:, None]
    acc = tl.join(left, right).permute(0, 2, 1).reshape([ROWS, COLS])

    # Cast the probabilities to the matmul dtype and accumulate p @ v.
    # note that this non transposed v for FP8 is only supported on Blackwell
    acc = tl.dot(p.to(dtype), v, acc)

    # The running statistics are deliberately updated last: the original
    # places this at the end of the loop body to reduce register pressure.
    l_i = l_i * alpha + row_sum
    m_i = m_new

    return l_i, m_i, acc
2157
+
2158
+
2159
@triton.jit
def _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
                           desc_k, desc_v,  #
                           offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
                           BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                           STAGE: tl.constexpr, offs_m0: tl.constexpr, offs_m1: tl.constexpr,  #
                           offs_n: tl.constexpr,  #
                           N_CTX: tl.constexpr, warp_specialize: tl.constexpr):
    """Sweep K/V tiles for two Q sub-tiles that share every K/V load.

    The key range ``[lo, hi)`` walked in steps of ``BLOCK_N`` depends on
    ``STAGE``:

    * ``STAGE == 1``: ``[0, start_m * BLOCK_M)``
    * ``STAGE == 2``: ``[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)``
    * otherwise:      ``[0, N_CTX)``

    Each iteration loads one K tile (transposed) and one V tile through the
    descriptors and feeds them to ``_attn_fwd_subtile`` once per sub-tile.

    Returns ``(acc0, acc1, l_i0, l_i1, m_i0, m_i1)`` after the sweep.
    """
    # Select the range of key positions handled by this stage.
    if STAGE == 1:
        lo = 0
        hi = start_m * BLOCK_M
    elif STAGE == 2:
        lo = tl.multiple_of(start_m * BLOCK_M, BLOCK_M)
        hi = (start_m + 1) * BLOCK_M
    else:
        # causal = False
        lo = 0
        hi = N_CTX

    # Row cursor into the K/V descriptors; advanced by BLOCK_N per iteration.
    kv_row = offset_y + lo

    for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize, disallow_acc_multi_buffer=True):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # One K/V load per iteration, reused by both sub-tiles.
        k = desc_k.load([kv_row, 0]).T
        v = desc_v.load([kv_row, 0])

        l_i0, m_i0, acc0 = _attn_fwd_subtile(q0, k, offs_m0, start_n, offs_n, qk_scale, l_i0, m_i0, acc0, v, dtype,
                                             STAGE)
        l_i1, m_i1, acc1 = _attn_fwd_subtile(q1, k, offs_m1, start_n, offs_n, qk_scale, l_i1, m_i1, acc1, v, dtype,
                                             STAGE)

        kv_row = kv_row + BLOCK_N

    return acc0, acc1, l_i0, l_i1, m_i0, m_i1
2191
+
2192
+
2193
+ #@triton.autotune(configs=list(filter(keep_tma, configs_tma_dp)),
2194
+ # key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"])
2195
@triton.jit
def _attn_fwd_tma_oss_dp(sm_scale, M,  #
                         Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX,  #
                         HEAD_DIM: tl.constexpr,  #
                         BLOCK_M: tl.constexpr,  #
                         BLOCK_N: tl.constexpr,  #
                         FP8_OUTPUT: tl.constexpr,  #
                         STAGE: tl.constexpr,  #
                         warp_specialize: tl.constexpr,  #
                         ENABLE_TMA: tl.constexpr,
                         ):
    """Attention forward kernel that processes each BLOCK_M query block as two halves.

    Program axis 0 selects the query block (``start_m``); program axis 1 is a
    fused batch/head index that is split into ``off_z`` and ``off_h`` with
    ``H``. The two ``BLOCK_M // 2`` halves of the query block keep separate
    accumulators and softmax statistics but share every K/V tile load.

    The descriptors are indexed as 2-D ``[row, column]`` views in which the
    row for (batch ``z``, head ``h``, position ``m``) is
    ``(z * H + h) * N_CTX + m``. This matches both the flattened
    ``[Z * H * N_CTX, HEAD_DIM]`` descriptors built by the host-side launch
    and the ``off_hz * N_CTX`` addressing used for ``M`` below.

    Args:
        sm_scale: softmax scale; multiplied by 1/ln(2) so exp2 can be used.
        M: base pointer receiving one float32 value per query row.
        Z: batch size (not read by the kernel body).
        H: number of heads, used to split the fused program index.
        desc_q, desc_k, desc_v, desc_o: tensor descriptors for Q, K, V, Out.
        N_CTX: sequence length.
        HEAD_DIM, BLOCK_M, BLOCK_N: tile sizes; ``BLOCK_N <= HEAD_DIM`` is
            statically asserted.
        FP8_OUTPUT: selects ``tl.float8e5`` instead of ``tl.bfloat16`` for the
            probability/output dtype.
        STAGE: bit 0 triggers a sweep with inner stage ``4 - STAGE``; bit 1
            triggers a sweep with inner stage 2 (the masked block).
        warp_specialize: forwarded to the inner ``tl.range`` loop.
        ENABLE_TMA: accepted for interface compatibility; not read here.
    """
    dtype = tl.float8e5 if FP8_OUTPUT else tl.bfloat16
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H

    # First row of this (batch, head) slice in the flattened descriptors.
    # The batch index must be scaled by the rows per batch (H * N_CTX);
    # adding `off_z` unscaled addresses the wrong rows whenever off_z > 0.
    offset_y = off_z * (N_CTX * H) + off_h * N_CTX
    qo_offset_y = offset_y + start_m * BLOCK_M
    # initialize offsets
    offs_m0 = start_m * BLOCK_M + tl.arange(0, BLOCK_M // 2)
    offs_m1 = start_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)

    # Running max (m_i), running sum (l_i) and output accumulator per half.
    m_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
    l_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
    acc0 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)

    m_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
    l_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
    acc1 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)

    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)

    # Load both halves of the query block once; they are reused for every
    # K/V tile.
    q0 = desc_q.load([qo_offset_y, 0])
    q1 = desc_q.load([qo_offset_y + BLOCK_M // 2, 0])

    # Bit 0 of STAGE: sweep with inner stage `4 - STAGE`
    # (inner stage 1 when STAGE == 3, inner stage 3 when STAGE == 1).
    if STAGE & 1:
        acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
                                                                    desc_k, desc_v,  #
                                                                    offset_y, dtype, start_m, qk_scale,  #
                                                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                                                    4 - STAGE, offs_m0, offs_m1, offs_n, N_CTX,  #
                                                                    warp_specialize)
    # Bit 1 of STAGE: sweep the masked block (inner stage 2).
    if STAGE & 2:
        acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
                                                                    desc_k, desc_v,  #
                                                                    offset_y, dtype, start_m, qk_scale,  #
                                                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                                                    2, offs_m0, offs_m1, offs_n, N_CTX,  #
                                                                    warp_specialize)

    # Epilogue, first half: store m_i + log2(l_i) per row into M, then
    # normalize the accumulator by l_i and write it through desc_o.
    m_i0 += tl.math.log2(l_i0)
    acc0 = acc0 / l_i0[:, None]
    m_ptrs0 = M + off_hz * N_CTX + offs_m0
    tl.store(m_ptrs0, m_i0)
    desc_o.store([qo_offset_y, 0], acc0.to(dtype))

    # Epilogue, second half: same steps, offset by BLOCK_M // 2 rows.
    m_i1 += tl.math.log2(l_i1)
    acc1 = acc1 / l_i1[:, None]
    m_ptrs1 = M + off_hz * N_CTX + offs_m1
    tl.store(m_ptrs1, m_i1)
    desc_o.store([qo_offset_y + BLOCK_M // 2, 0], acc1.to(dtype))
2260
+
2261
+
2121
2262
@triton .jit
2122
2263
def _attn_bwd_preprocess (
2123
2264
O ,
@@ -2480,6 +2621,7 @@ def forward(ctx, q, k, v, causal, sm_scale, baseVariant):
2480
2621
2481
2622
# no autotune with fixed BLOCK_N
2482
2623
if HAS_TMA_DESC is True and torch .version .hip is None :
2624
+ # Legacy on-host grid constant TMA
2483
2625
desc_helper = TmaAutoTuneHelper ()
2484
2626
desc_helper .init_tma_descriptor ("k" )
2485
2627
desc_helper .init_tma_descriptor ("v" )
@@ -2636,6 +2778,17 @@ def grid_tma_persistent(META):
2636
2778
desc_v = desc_helper .get_tma_descriptor_kernel_param ("v" )
2637
2779
desc_o = desc_helper .get_tma_descriptor_kernel_param ("o" )
2638
2780
2781
+ # For variants using new on-host TMA
2782
+ if baseVariant == "on_host_tma_ws_oss" :
2783
+ from triton .tools .tensor_descriptor import TensorDescriptor
2784
+ y_dim = q .shape [0 ] * q .shape [1 ] * q .shape [2 ]
2785
+ BLOCK_M = 256
2786
+ BLOCK_N = 128
2787
+ desc_q = TensorDescriptor (q , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = [BLOCK_M // 2 , HEAD_DIM_K ])
2788
+ desc_v = TensorDescriptor (v , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = [BLOCK_N , HEAD_DIM_K ])
2789
+ desc_k = TensorDescriptor (k , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = [BLOCK_N , HEAD_DIM_K ])
2790
+ desc_o = TensorDescriptor (o , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = [BLOCK_M // 2 , HEAD_DIM_K ])
2791
+
2639
2792
M = torch .empty (
2640
2793
(q .shape [0 ], q .shape [1 ], q .shape [2 ]), device = q .device , dtype = torch .float32
2641
2794
)
@@ -2818,6 +2971,24 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
2818
2971
ENABLE_WS = True ,
2819
2972
** extra_kern_args ,
2820
2973
)
2974
+ elif baseVariant == "on_host_tma_ws_oss" :
2975
+ BLOCK_M = 256
2976
+ BLOCK_N = 128
2977
+ _attn_fwd_tma_oss_dp [grid_tma ](
2978
+ sm_scale , M , #
2979
+ q .shape [0 ], q .shape [1 ], #
2980
+ desc_q , desc_k , desc_v , desc_o , #
2981
+ N_CTX = q .shape [2 ], #
2982
+ HEAD_DIM = HEAD_DIM_K , #
2983
+ FP8_OUTPUT = q .dtype == torch .float8_e5m2 , #
2984
+ STAGE = stage , #
2985
+ warp_specialize = True , #
2986
+ ENABLE_TMA = True ,
2987
+ BLOCK_N = BLOCK_N , BLOCK_M = BLOCK_M , #
2988
+ num_warps = 4 ,
2989
+ num_stages = 2 ,
2990
+ #maxnreg=64,
2991
+ ** extra_kern_args )
2821
2992
2822
2993
ctx .save_for_backward (q , k , v , o , M )
2823
2994
ctx .grid = grid_tma
0 commit comments