
Commit 710ee87

Add triton_tutorial_flash_v2_on_host_tma_ws_oss_blackwell

Summary: Works with a hacky compiler implementation of autoWS.

1 parent 608f961 commit 710ee87

File tree: 2 files changed (+195, -0 lines)

tritonbench/kernels/triton_fused_attention.py

Lines changed: 171 additions & 0 deletions
@@ -1960,6 +1960,7 @@ def _attn_fwd_tma_ws_persistent(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
 ]
 
 
+# on-device TMA
 @triton.autotune(list(filter(keep, configsCutlassBlackwell)), key=["N_CTX"])
 @triton.jit
 def _attn_fwd_tma_ws_persistent_with_dp(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
@@ -2118,6 +2119,146 @@ def _attn_fwd_tma_ws_persistent_with_dp(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
         tile_idx += num_progs
 
 
+@triton.jit
+def _attn_fwd_subtile(q, k, offs_m, start_n, offs_n, qk_scale, l_i, m_i, acc, v, dtype: tl.constexpr, STAGE: tl.constexpr):
+    qk = tl.dot(q, k)
+    if STAGE == 2:
+        mask = offs_m[:, None] >= (start_n + offs_n[None, :])
+        qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
+        m_ij = tl.maximum(m_i, tl.max(qk, 1))
+        qk -= m_ij[:, None]
+    else:
+        m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
+        qk = qk * qk_scale - m_ij[:, None]
+    p = tl.math.exp2(qk)
+    # -- compute correction factor
+    alpha = tl.math.exp2(m_i - m_ij)
+    l_ij = tl.sum(p, 1)
+
+    # -- update output accumulator --
+    BM: tl.constexpr = acc.shape[0]
+    BN: tl.constexpr = acc.shape[1]
+
+    acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
+    acc0 = acc0 * alpha[:, None]
+    acc1 = acc1 * alpha[:, None]
+    acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
+
+    # prepare p and v for the dot
+    p = p.to(dtype)
+    # note that this non-transposed v for FP8 is only supported on Blackwell
+    acc = tl.dot(p, v, acc)
+    # update m_i and l_i
+    # place this at the end of the loop to reduce register pressure
+    l_i = l_i * alpha + l_ij
+    m_i = m_ij
+
+    return l_i, m_i, acc
+
+
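For reference, a minimal PyTorch sketch of the online-softmax update that _attn_fwd_subtile performs on one K/V tile (non-masked stage; names and shapes are illustrative, and the reshape/permute/split of acc, which is just a layout hint for the compiler, and the cast of p to the compute dtype are omitted):

import torch

def subtile_update_reference(q, k, v, qk_scale, l_i, m_i, acc):
    # q: [BM, D], k: [D, BN], v: [BN, D]; l_i, m_i: [BM]; acc: [BM, D]
    qk = q @ k                                     # tl.dot(q, k)
    m_ij = torch.maximum(m_i, qk.max(dim=1).values * qk_scale)
    p = torch.exp2(qk * qk_scale - m_ij[:, None])  # tl.math.exp2
    alpha = torch.exp2(m_i - m_ij)                 # correction for the old running max
    acc = acc * alpha[:, None] + p @ v             # rescale old acc, add this tile
    l_i = l_i * alpha + p.sum(dim=1)               # running softmax denominator
    return l_i, m_ij, acc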
+@triton.jit
+def _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
+                           desc_k, desc_v,  #
+                           offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
+                           BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
+                           STAGE: tl.constexpr, offs_m0: tl.constexpr, offs_m1: tl.constexpr,  #
+                           offs_n: tl.constexpr,  #
+                           N_CTX: tl.constexpr, warp_specialize: tl.constexpr):
+    # range of values handled by this stage
+    if STAGE == 1:
+        lo, hi = 0, start_m * BLOCK_M
+    elif STAGE == 2:
+        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
+        lo = tl.multiple_of(lo, BLOCK_M)
+    # causal = False
+    else:
+        lo, hi = 0, N_CTX
+    offsetkv_y = offset_y + lo
+
+    # loop over k, v and update accumulator
+    for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize, disallow_acc_multi_buffer=True):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+
+        k = desc_k.load([offsetkv_y, 0]).T
+        v = desc_v.load([offsetkv_y, 0])
+
+        l_i0, m_i0, acc0 = _attn_fwd_subtile(q0, k, offs_m0, start_n, offs_n, qk_scale, l_i0, m_i0, acc0, v, dtype, STAGE)
+        l_i1, m_i1, acc1 = _attn_fwd_subtile(q1, k, offs_m1, start_n, offs_n, qk_scale, l_i1, m_i1, acc1, v, dtype, STAGE)
+
+        offsetkv_y += BLOCK_N
+
+    return acc0, acc1, l_i0, l_i1, m_i0, m_i1
+
+
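The inner loop is the data-parallel part of the variant: the two query subtiles q0 and q1 (the upper and lower halves of the BLOCK_M query rows) consume the same k/v tile, so each TMA load feeds two MMA pipelines. Splitting the query rows this way is numerically a no-op because softmax is applied row by row; a quick PyTorch check of that claim, with purely illustrative shapes:

import torch

BM, BN, D = 256, 128, 128
q = torch.randn(BM, D)
k = torch.randn(BN, D)
v = torch.randn(BN, D)

full = torch.softmax(q @ k.T, dim=-1) @ v
q0, q1 = q[: BM // 2], q[BM // 2 :]
halves = torch.cat([
    torch.softmax(q0 @ k.T, dim=-1) @ v,
    torch.softmax(q1 @ k.T, dim=-1) @ v,
])
torch.testing.assert_close(full, halves)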
+#@triton.autotune(configs=list(filter(keep_tma, configs_tma_dp)),
+#                 key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"])
+@triton.jit
+def _attn_fwd_tma_oss_dp(sm_scale, M,  #
+                         Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX,  #
+                         HEAD_DIM: tl.constexpr,  #
+                         BLOCK_M: tl.constexpr,  #
+                         BLOCK_N: tl.constexpr,  #
+                         FP8_OUTPUT: tl.constexpr,  #
+                         STAGE: tl.constexpr,  #
+                         warp_specialize: tl.constexpr,  #
+                         ENABLE_TMA: tl.constexpr,
+                         ):
+    dtype = tl.float8e5 if FP8_OUTPUT else tl.bfloat16
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
+    start_m = tl.program_id(0)
+    off_hz = tl.program_id(1)
+    off_z = off_hz // H
+    off_h = off_hz % H
+
+    offset_y = off_z + off_h * N_CTX
+    qo_offset_y = offset_y + start_m * BLOCK_M
+    # initialize offsets
+    offs_m0 = start_m * BLOCK_M + tl.arange(0, BLOCK_M // 2)
+    offs_m1 = start_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+
+    m_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
+    l_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
+    acc0 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)
+
+    m_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
+    l_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
+    acc1 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)
+
+    qk_scale = sm_scale
+    qk_scale *= 1.44269504  # 1/log(2)
+
+    q0 = desc_q.load([qo_offset_y, 0])
+    q1 = desc_q.load([qo_offset_y + BLOCK_M // 2, 0])
+
+    if STAGE & 1:
+        acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
+                                                                    desc_k, desc_v,  #
+                                                                    offset_y, dtype, start_m, qk_scale,  #
+                                                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
+                                                                    4 - STAGE, offs_m0, offs_m1, offs_n, N_CTX,  #
+                                                                    warp_specialize)
+    if STAGE & 2:
+        acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
+                                                                    desc_k, desc_v,  #
+                                                                    offset_y, dtype, start_m, qk_scale,  #
+                                                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
+                                                                    2, offs_m0, offs_m1, offs_n, N_CTX,  #
+                                                                    warp_specialize)
+
+    m_i0 += tl.math.log2(l_i0)
+    acc0 = acc0 / l_i0[:, None]
+    m_ptrs0 = M + off_hz * N_CTX + offs_m0
+    tl.store(m_ptrs0, m_i0)
+    desc_o.store([qo_offset_y, 0], acc0.to(dtype))
+
+    m_i1 += tl.math.log2(l_i1)
+    acc1 = acc1 / l_i1[:, None]
+    m_ptrs1 = M + off_hz * N_CTX + offs_m1
+    tl.store(m_ptrs1, m_i1)
+    desc_o.store([qo_offset_y + BLOCK_M // 2, 0], acc1.to(dtype))
+
+
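Two details of this kernel are easy to miss: sm_scale is multiplied by log2(e) ≈ 1.44269504 so the whole softmax can run on exp2 instead of exp, and the epilogue writes M = m_i + log2(l_i), a base-2 log-sum-exp of the scaled scores, alongside O = acc / l_i. A small PyTorch check of those identities on a single row block (shapes illustrative, non-causal, loose tolerances for the float math):

import torch

s = torch.randn(8, 64)          # sm_scale-scaled attention scores for one row block
v = torch.randn(64, 16)
LOG2E = 1.4426950408889634      # the 1.44269504 constant folded into qk_scale

# base-2 formulation used by the kernel: exp(x) == exp2(x * log2(e))
m = (s * LOG2E).max(dim=-1).values
p = torch.exp2(s * LOG2E - m[:, None])
l = p.sum(dim=-1)
out = (p @ v) / l[:, None]
M = m + torch.log2(l)           # what the kernel stores into M

torch.testing.assert_close(out, torch.softmax(s, dim=-1) @ v, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(M, torch.logsumexp(s, dim=-1) * LOG2E, rtol=1e-4, atol=1e-4)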
 @triton.jit
 def _attn_bwd_preprocess(
     O,
@@ -2480,6 +2621,7 @@ def forward(ctx, q, k, v, causal, sm_scale, baseVariant):
 
         # no autotune with fixed BLOCK_N
         if HAS_TMA_DESC is True and torch.version.hip is None:
+            # Legacy on-host grid constant TMA
             desc_helper = TmaAutoTuneHelper()
             desc_helper.init_tma_descriptor("k")
             desc_helper.init_tma_descriptor("v")
@@ -2636,6 +2778,17 @@ def grid_tma_persistent(META):
         desc_v = desc_helper.get_tma_descriptor_kernel_param("v")
         desc_o = desc_helper.get_tma_descriptor_kernel_param("o")
 
+        # For variants using new on-host TMA
+        if baseVariant == "on_host_tma_ws_oss":
+            from triton.tools.tensor_descriptor import TensorDescriptor
+            y_dim = q.shape[0] * q.shape[1] * q.shape[2]
+            BLOCK_M = 256
+            BLOCK_N = 128
+            desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_M // 2, HEAD_DIM_K])
+            desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
+            desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
+            desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_M // 2, HEAD_DIM_K])
+
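The TensorDescriptor arguments treat each (Z, H, N_CTX, HEAD_DIM) input as a flat, row-major 2-D array of y_dim rows by HEAD_DIM_K columns, and block_shape is what each desc.load/desc.store in the kernel moves per TMA transaction. A small illustration of that view (example sizes only, and it assumes the inputs are contiguous):

import torch

Z, H, N_CTX, HEAD_DIM_K = 4, 16, 1024, 128        # example sizes only
q = torch.randn(Z, H, N_CTX, HEAD_DIM_K)          # assumed contiguous

y_dim = q.shape[0] * q.shape[1] * q.shape[2]      # same expression as in the diff
q2d = q.view(y_dim, HEAD_DIM_K)                   # the 2-D layout the descriptor describes
assert q2d.stride() == (HEAD_DIM_K, 1)            # matches strides=[HEAD_DIM_K, 1]
# each (batch, head) pair owns a contiguous band of N_CTX rows in this view,
# and a [BLOCK_M // 2, HEAD_DIM_K] block is one q subtile load in the kernel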
         M = torch.empty(
             (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
         )
@@ -2818,6 +2971,24 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
                 ENABLE_WS=True,
                 **extra_kern_args,
             )
+        elif baseVariant == "on_host_tma_ws_oss":
+            BLOCK_M = 256
+            BLOCK_N = 128
+            _attn_fwd_tma_oss_dp[grid_tma](
+                sm_scale, M,  #
+                q.shape[0], q.shape[1],  #
+                desc_q, desc_k, desc_v, desc_o,  #
+                N_CTX=q.shape[2],  #
+                HEAD_DIM=HEAD_DIM_K,  #
+                FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
+                STAGE=stage,  #
+                warp_specialize=True,  #
+                ENABLE_TMA=True,
+                BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M,  #
+                num_warps=4,
+                num_stages=2,
+                #maxnreg=64,
+                **extra_kern_args)
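grid_tma itself is not part of this diff. Given that the kernel takes start_m from tl.program_id(0) and a fused batch*head index from tl.program_id(1), the launch grid presumably has the shape sketched below; this is an inference from the kernel, not the actual grid_tma definition, and the helper name is made up:

import triton

def make_oss_dp_grid(q):
    # q: (Z, H, N_CTX, HEAD_DIM); one program per BLOCK_M query rows per (batch, head)
    def grid(META):
        return (
            triton.cdiv(q.shape[2], META["BLOCK_M"]),  # start_m = tl.program_id(0)
            q.shape[0] * q.shape[1],                   # off_hz = tl.program_id(1)
            1,
        )
    return grid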
 
         ctx.save_for_backward(q, k, v, o, M)
         ctx.grid = grid_tma

tritonbench/operators/blackwell_attentions/operator.py

Lines changed: 24 additions & 0 deletions
@@ -285,6 +285,7 @@ def causal_mask(b, h, q_idx, kv_idx):
 
         return lambda: flex_attention(q, k, v, block_mask=block_mask)
 
+    # use Meta's warpspec + on device TMA + persistent
     @register_benchmark(enabled=False)
     def triton_tutorial_flash_v2_tma_ws_persistent_blackwell(
         self,
@@ -296,6 +297,29 @@ def triton_tutorial_flash_v2_tma_ws_persistent_blackwell(
             q, k, v, self.causal, self.sm_scale, "tma_ws_persistent_blackwell"
         )
 
+    @register_benchmark(enabled=False)
+    def triton_tutorial_flash_v2_blackwell(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        return lambda: triton_tutorial_FA2_opt(
+            q, k, v, self.causal, self.sm_scale, "base_opt"
+        )
+
+    # use OSS warpspec + on host TMA
+    @register_benchmark(enabled=False)
+    def triton_tutorial_flash_v2_on_host_tma_ws_oss_blackwell(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        return lambda: triton_tutorial_FA2_opt(
+            q, k, v, self.causal, self.sm_scale, "on_host_tma_ws_oss"
+        )
+
     # Only works with triton main, forward only.
     @register_benchmark(enabled=False)
     def gluon_blackwell_tutorial_fwd(
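A hypothetical smoke test for the new variant; the import path, tensor shapes, and tolerances are assumptions (N_CTX is a multiple of the hard-coded BLOCK_M = 256, and HEAD_DIM = 128 matches HEAD_DIM_K), and sm_scale follows the usual 1/sqrt(HEAD_DIM) convention:

import math
import torch
import torch.nn.functional as F

# Import path assumed to be the kernels module touched by this commit.
from tritonbench.kernels.triton_fused_attention import triton_tutorial_FA2_opt

Z, H, N_CTX, HEAD_DIM = 4, 16, 1024, 128
q, k, v = (torch.randn(Z, H, N_CTX, HEAD_DIM, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))
sm_scale = 1.0 / math.sqrt(HEAD_DIM)

out = triton_tutorial_FA2_opt(q, k, v, True, sm_scale, "on_host_tma_ws_oss")
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=sm_scale)
torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)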
