@@ -1952,6 +1952,7 @@ def _attn_fwd_tma_ws_persistent( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
1952
1952
]
1953
1953
1954
1954
1955
+ # on-device TMA
1955
1956
@triton .autotune (list (filter (keep , configsCutlassBlackwell )), key = ["N_CTX" ])
1956
1957
@triton .jit
1957
1958
def _attn_fwd_tma_ws_persistent_with_dp ( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
@@ -2110,6 +2111,146 @@ def _attn_fwd_tma_ws_persistent_with_dp( # Q, V, desc_k, desc_v, sm_scale, M, O
2110
2111
tile_idx += num_progs
2111
2112
2112
2113
2114
@triton.jit
def _attn_fwd_subtile(q, k, offs_m, start_n, offs_n, qk_scale, l_i, m_i, acc, v, dtype: tl.constexpr,
                      STAGE: tl.constexpr):
    """Apply one online-softmax step of a query sub-tile against one K/V block.

    The scores ``q @ k`` are scaled by ``qk_scale``, the running row maximum
    ``m_i`` and running row sum ``l_i`` are refreshed, ``acc`` is rescaled by
    the correction factor, and ``p @ v`` is accumulated into ``acc``.
    Exponentials are taken in base 2 (``tl.math.exp2``).

    When ``STAGE == 2`` every position where ``offs_m < start_n + offs_n``
    receives an additive bias of ``-1.0e6`` before the row maximum is taken;
    for any other ``STAGE`` no mask is applied.

    Returns:
        The updated ``(l_i, m_i, acc)`` triple.
    """
    scores = tl.dot(q, k)
    if STAGE == 2:
        # Masked branch: bias out columns whose index exceeds the row index.
        keep = offs_m[:, None] >= (start_n + offs_n[None, :])
        scores = scores * qk_scale + tl.where(keep, 0, -1.0e6)
        m_new = tl.maximum(m_i, tl.max(scores, 1))
        scores -= m_new[:, None]
    else:
        # Unmasked branch: take the max on raw scores, then scale it.
        m_new = tl.maximum(m_i, tl.max(scores, 1) * qk_scale)
        scores = scores * qk_scale - m_new[:, None]
    probs = tl.math.exp2(scores)
    # Factor that rebases previously accumulated values onto the new maximum.
    rescale = tl.math.exp2(m_i - m_new)
    row_sum = tl.sum(probs, 1)

    ROWS: tl.constexpr = acc.shape[0]
    COLS: tl.constexpr = acc.shape[1]

    # Rescale the accumulator as two column halves, then stitch them back together.
    lo_half, hi_half = acc.reshape([ROWS, 2, COLS // 2]).permute(0, 2, 1).split()
    lo_half = lo_half * rescale[:, None]
    hi_half = hi_half * rescale[:, None]
    acc = tl.join(lo_half, hi_half).permute(0, 2, 1).reshape([ROWS, COLS])

    # Cast the probabilities to the compute dtype for the second dot.
    # note that this non transposed v for FP8 is only supported on Blackwell
    acc = tl.dot(probs.to(dtype), v, acc)
    # Running statistics are updated last to reduce register pressure.
    l_i = l_i * rescale + row_sum
    m_i = m_new

    return l_i, m_i, acc
2149
+
2150
+
2151
@triton.jit
def _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
                           desc_k, desc_v,  #
                           offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
                           BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
                           STAGE: tl.constexpr, offs_m0: tl.constexpr, offs_m1: tl.constexpr,  #
                           offs_n: tl.constexpr,  #
                           N_CTX: tl.constexpr, warp_specialize: tl.constexpr):
    """Sweep K/V blocks for one stage, updating two query sub-tiles per block.

    Each iteration loads one K block (transposed) and one V block through
    ``desc_k`` / ``desc_v`` and feeds the same pair to ``_attn_fwd_subtile``
    twice: once for ``(q0, acc0, l_i0, m_i0)`` and once for
    ``(q1, acc1, l_i1, m_i1)``.

    Key range covered, by ``STAGE``:
        * ``1``: ``[0, start_m * BLOCK_M)``
        * ``2``: ``[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)``
        * otherwise: ``[0, N_CTX)``

    Returns:
        ``(acc0, acc1, l_i0, l_i1, m_i0, m_i1)`` after the sweep.
    """
    # Select the key range handled by this stage.
    if STAGE == 1:
        kv_begin = 0
        kv_end = start_m * BLOCK_M
    elif STAGE == 2:
        kv_begin = start_m * BLOCK_M
        kv_end = (start_m + 1) * BLOCK_M
        kv_begin = tl.multiple_of(kv_begin, BLOCK_M)
    else:
        # causal = False
        kv_begin = 0
        kv_end = N_CTX
    # Row into the flattened K/V descriptors where this sweep starts.
    kv_row = offset_y + kv_begin

    # Walk the key range one BLOCK_N-sized block at a time.
    for blk_start in tl.range(kv_begin, kv_end, BLOCK_N, warp_specialize=warp_specialize,
                              disallow_acc_multi_buffer=True):
        blk_start = tl.multiple_of(blk_start, BLOCK_N)

        k_t = desc_k.load([kv_row, 0]).T
        v_blk = desc_v.load([kv_row, 0])

        # Both query halves share the K/V block that was just loaded.
        l_i0, m_i0, acc0 = _attn_fwd_subtile(q0, k_t, offs_m0, blk_start, offs_n, qk_scale, l_i0, m_i0, acc0,
                                             v_blk, dtype, STAGE)
        l_i1, m_i1, acc1 = _attn_fwd_subtile(q1, k_t, offs_m1, blk_start, offs_n, qk_scale, l_i1, m_i1, acc1,
                                             v_blk, dtype, STAGE)

        kv_row += BLOCK_N

    return acc0, acc1, l_i0, l_i1, m_i0, m_i1
2183
+
2184
+
2185
+ #@triton.autotune(configs=list(filter(keep_tma, configs_tma_dp)),
2186
+ # key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"])
2187
@triton.jit
def _attn_fwd_tma_oss_dp(sm_scale, M,  #
                         Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX,  #
                         HEAD_DIM: tl.constexpr,  #
                         BLOCK_M: tl.constexpr,  #
                         BLOCK_N: tl.constexpr,  #
                         FP8_OUTPUT: tl.constexpr,  #
                         STAGE: tl.constexpr,  #
                         warp_specialize: tl.constexpr,  #
                         ENABLE_TMA: tl.constexpr,
                         ):
    """Attention forward kernel that handles one BLOCK_M query tile as two halves.

    Each program loads two query sub-tiles of ``BLOCK_M // 2`` rows (``q0`` and
    ``q1``) through ``desc_q``, runs the K/V sweep in
    ``_attn_fwd_inner_oss_dp`` for both at once, normalizes the accumulators
    and stores them through ``desc_o``. Per row, ``m + log2(l)`` is written to
    ``M``.

    Args:
        sm_scale: Softmax scale; multiplied by 1/log(2) so exponentials can be
            evaluated with ``exp2``.
        M: Pointer to the float32 per-row statistics buffer, indexed as
            ``off_hz * N_CTX + row``.
        Z: Batch size. Not read by the kernel body.
        H: Number of heads; used to split ``program_id(1)`` into batch/head.
        desc_q, desc_k, desc_v, desc_o: Tensor descriptors over tensors viewed
            as 2D ``[rows, HEAD_DIM]``. NOTE(review): the row layout is
            assumed to be ``(z * H + h) * N_CTX + m``, matching the ``M``
            indexing -- confirm against the host-side descriptor setup.
        N_CTX: Sequence length.
        HEAD_DIM, BLOCK_M, BLOCK_N: Compile-time tile sizes.
        FP8_OUTPUT: Selects ``tl.float8e5`` instead of ``tl.bfloat16`` as the
            compute/output dtype.
        STAGE: Bit flags; bit 0 triggers a sweep with stage ``4 - STAGE``,
            bit 1 triggers a sweep with stage ``2``.
        warp_specialize: Forwarded to ``tl.range`` in the inner loop.
        ENABLE_TMA: Not read by the kernel body.
    """
    dtype = tl.float8e5 if FP8_OUTPUT else tl.bfloat16
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H

    # First row of this (batch, head) slice in the flattened descriptors.
    # The batch index must be scaled by H * N_CTX; the previous
    # `off_z + off_h * N_CTX` made different batches alias the same rows
    # and disagreed with the `off_hz * N_CTX` indexing used for M below.
    offset_y = off_z * (N_CTX * H) + off_h * N_CTX
    qo_offset_y = offset_y + start_m * BLOCK_M
    # initialize offsets
    offs_m0 = start_m * BLOCK_M + tl.arange(0, BLOCK_M // 2)
    offs_m1 = start_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)

    # Running softmax state for the first half: max, sum, and output accumulator.
    m_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
    l_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
    acc0 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)

    # Same state for the second half.
    m_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
    l_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
    acc1 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)

    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)

    q0 = desc_q.load([qo_offset_y, 0])
    q1 = desc_q.load([qo_offset_y + BLOCK_M // 2, 0])

    # Bit 0: sweep with stage 4 - STAGE (stage 1 when STAGE == 3, stage 3 when STAGE == 1).
    if STAGE & 1:
        acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
                                                                    desc_k, desc_v,  #
                                                                    offset_y, dtype, start_m, qk_scale,  #
                                                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                                                    4 - STAGE, offs_m0, offs_m1, offs_n, N_CTX,  #
                                                                    warp_specialize)
    # Bit 1: sweep the block range [start_m * BLOCK_M, (start_m + 1) * BLOCK_M) with masking.
    if STAGE & 2:
        acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
                                                                    desc_k, desc_v,  #
                                                                    offset_y, dtype, start_m, qk_scale,  #
                                                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                                                    2, offs_m0, offs_m1, offs_n, N_CTX,  #
                                                                    warp_specialize)

    # Epilogue, first half: fold the row sum into the stored statistic and normalize.
    m_i0 += tl.math.log2(l_i0)
    acc0 = acc0 / l_i0[:, None]
    m_ptrs0 = M + off_hz * N_CTX + offs_m0
    tl.store(m_ptrs0, m_i0)
    desc_o.store([qo_offset_y, 0], acc0.to(dtype))

    # Epilogue, second half.
    m_i1 += tl.math.log2(l_i1)
    acc1 = acc1 / l_i1[:, None]
    m_ptrs1 = M + off_hz * N_CTX + offs_m1
    tl.store(m_ptrs1, m_i1)
    desc_o.store([qo_offset_y + BLOCK_M // 2, 0], acc1.to(dtype))
2252
+
2253
+
2113
2254
@triton .jit
2114
2255
def _attn_bwd_preprocess (
2115
2256
O ,
@@ -2472,6 +2613,7 @@ def forward(ctx, q, k, v, causal, sm_scale, baseVariant):
2472
2613
2473
2614
# no autotune with fixed BLOCK_N
2474
2615
if HAS_TMA_DESC is True and torch .version .hip is None :
2616
+ # Legacy on-host grid constant TMA
2475
2617
desc_helper = TmaAutoTuneHelper ()
2476
2618
desc_helper .init_tma_descriptor ("k" )
2477
2619
desc_helper .init_tma_descriptor ("v" )
@@ -2628,6 +2770,17 @@ def grid_tma_persistent(META):
2628
2770
desc_v = desc_helper .get_tma_descriptor_kernel_param ("v" )
2629
2771
desc_o = desc_helper .get_tma_descriptor_kernel_param ("o" )
2630
2772
2773
+ # For variants using new on-host TMA
2774
+ if baseVariant == "on_host_tma_ws_oss" :
2775
+ from triton .tools .tensor_descriptor import TensorDescriptor
2776
+ y_dim = q .shape [0 ] * q .shape [1 ] * q .shape [2 ]
2777
+ BLOCK_M = 256
2778
+ BLOCK_N = 128
2779
+ desc_q = TensorDescriptor (q , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = [BLOCK_M // 2 , HEAD_DIM_K ])
2780
+ desc_v = TensorDescriptor (v , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = [BLOCK_N , HEAD_DIM_K ])
2781
+ desc_k = TensorDescriptor (k , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = [BLOCK_N , HEAD_DIM_K ])
2782
+ desc_o = TensorDescriptor (o , shape = [y_dim , HEAD_DIM_K ], strides = [HEAD_DIM_K , 1 ], block_shape = [BLOCK_M // 2 , HEAD_DIM_K ])
2783
+
2631
2784
M = torch .empty (
2632
2785
(q .shape [0 ], q .shape [1 ], q .shape [2 ]), device = q .device , dtype = torch .float32
2633
2786
)
@@ -2810,6 +2963,24 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
2810
2963
ENABLE_WS = True ,
2811
2964
** extra_kern_args ,
2812
2965
)
2966
+ elif baseVariant == "on_host_tma_ws_oss" :
2967
+ BLOCK_M = 256
2968
+ BLOCK_N = 128
2969
+ _attn_fwd_tma_oss_dp [grid_tma ](
2970
+ sm_scale , M , #
2971
+ q .shape [0 ], q .shape [1 ], #
2972
+ desc_q , desc_k , desc_v , desc_o , #
2973
+ N_CTX = q .shape [2 ], #
2974
+ HEAD_DIM = HEAD_DIM_K , #
2975
+ FP8_OUTPUT = q .dtype == torch .float8_e5m2 , #
2976
+ STAGE = stage , #
2977
+ warp_specialize = True , #
2978
+ ENABLE_TMA = True ,
2979
+ BLOCK_N = BLOCK_N , BLOCK_M = BLOCK_M , #
2980
+ num_warps = 4 ,
2981
+ num_stages = 2 ,
2982
+ #maxnreg=64,
2983
+ ** extra_kern_args )
2813
2984
2814
2985
ctx .save_for_backward (q , k , v , o , M )
2815
2986
ctx .grid = grid_tma
0 commit comments