Commit 844dde7

Add triton_tutorial_flash_v2_on_host_tma_ws_oss_blackwell

Summary: Works with a hacky compiler implementation of autoWS.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

1 parent 7740c6d · commit 844dde7

File tree

2 files changed: +195 -0 lines changed

tritonbench/kernels/triton_fused_attention.py

Lines changed: 171 additions & 0 deletions
@@ -1952,6 +1952,7 @@ def _attn_fwd_tma_ws_persistent( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
     ]


+# on-device TMA
 @triton.autotune(list(filter(keep, configsCutlassBlackwell)), key=["N_CTX"])
 @triton.jit
 def _attn_fwd_tma_ws_persistent_with_dp( # Q, V, desc_k, desc_v, sm_scale, M, Out, #
@@ -2110,6 +2111,146 @@ def _attn_fwd_tma_ws_persistent_with_dp( # Q, V, desc_k, desc_v, sm_scale, M, O
             tile_idx += num_progs


+@triton.jit
+def _attn_fwd_subtile(q, k, offs_m, start_n, offs_n, qk_scale, l_i, m_i, acc, v, dtype: tl.constexpr, STAGE: tl.constexpr):
+    qk = tl.dot(q, k)
+    if STAGE == 2:
+        mask = offs_m[:, None] >= (start_n + offs_n[None, :])
+        qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
+        m_ij = tl.maximum(m_i, tl.max(qk, 1))
+        qk -= m_ij[:, None]
+    else:
+        m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
+        qk = qk * qk_scale - m_ij[:, None]
+    p = tl.math.exp2(qk)
+    # -- compute correction factor
+    alpha = tl.math.exp2(m_i - m_ij)
+    l_ij = tl.sum(p, 1)
+
+    # -- update output accumulator --
+    BM: tl.constexpr = acc.shape[0]
+    BN: tl.constexpr = acc.shape[1]
+
+    acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
+    acc0 = acc0 * alpha[:, None]
+    acc1 = acc1 * alpha[:, None]
+    acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
+
+    # prepare p and v for the dot
+    p = p.to(dtype)
+    # note that this non transposed v for FP8 is only supported on Blackwell
+    acc = tl.dot(p, v, acc)
+    # update m_i and l_i
+    # place this at the end of the loop to reduce register pressure
+    l_i = l_i * alpha + l_ij
+    m_i = m_ij
+
+    return l_i, m_i, acc
+
+
+@triton.jit
+def _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
+                           desc_k, desc_v,  #
+                           offset_y, dtype: tl.constexpr, start_m, qk_scale,  #
+                           BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,  #
+                           STAGE: tl.constexpr, offs_m0: tl.constexpr, offs_m1: tl.constexpr,  #
+                           offs_n: tl.constexpr,  #
+                           N_CTX: tl.constexpr, warp_specialize: tl.constexpr):
+    # range of values handled by this stage
+    if STAGE == 1:
+        lo, hi = 0, start_m * BLOCK_M
+    elif STAGE == 2:
+        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
+        lo = tl.multiple_of(lo, BLOCK_M)
+    # causal = False
+    else:
+        lo, hi = 0, N_CTX
+    offsetkv_y = offset_y + lo
+
+    # loop over k, v and update accumulator
+    for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize, disallow_acc_multi_buffer=True):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+
+        k = desc_k.load([offsetkv_y, 0]).T
+        v = desc_v.load([offsetkv_y, 0])
+
+        l_i0, m_i0, acc0 = _attn_fwd_subtile(q0, k, offs_m0, start_n, offs_n, qk_scale, l_i0, m_i0, acc0, v, dtype, STAGE)
+        l_i1, m_i1, acc1 = _attn_fwd_subtile(q1, k, offs_m1, start_n, offs_n, qk_scale, l_i1, m_i1, acc1, v, dtype, STAGE)
+
+        offsetkv_y += BLOCK_N
+
+    return acc0, acc1, l_i0, l_i1, m_i0, m_i1
+
+
+# @triton.autotune(configs=list(filter(keep_tma, configs_tma_dp)),
+#                  key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"])
+@triton.jit
+def _attn_fwd_tma_oss_dp(sm_scale, M,  #
+                         Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX,  #
+                         HEAD_DIM: tl.constexpr,  #
+                         BLOCK_M: tl.constexpr,  #
+                         BLOCK_N: tl.constexpr,  #
+                         FP8_OUTPUT: tl.constexpr,  #
+                         STAGE: tl.constexpr,  #
+                         warp_specialize: tl.constexpr,  #
+                         ENABLE_TMA: tl.constexpr,
+                         ):
+    dtype = tl.float8e5 if FP8_OUTPUT else tl.bfloat16
+    tl.static_assert(BLOCK_N <= HEAD_DIM)
+    start_m = tl.program_id(0)
+    off_hz = tl.program_id(1)
+    off_z = off_hz // H
+    off_h = off_hz % H
+
+    offset_y = off_z + off_h * N_CTX
+    qo_offset_y = offset_y + start_m * BLOCK_M
+    # initialize offsets
+    offs_m0 = start_m * BLOCK_M + tl.arange(0, BLOCK_M // 2)
+    offs_m1 = start_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+
+    m_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
+    l_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
+    acc0 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)
+
+    m_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
+    l_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
+    acc1 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)
+
+    qk_scale = sm_scale
+    qk_scale *= 1.44269504  # 1/log(2)
+
+    q0 = desc_q.load([qo_offset_y, 0])
+    q1 = desc_q.load([qo_offset_y + BLOCK_M // 2, 0])
+
+    if STAGE & 1:
+        acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
+                                                                    desc_k, desc_v,  #
+                                                                    offset_y, dtype, start_m, qk_scale,  #
+                                                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
+                                                                    4 - STAGE, offs_m0, offs_m1, offs_n, N_CTX,  #
+                                                                    warp_specialize)
+    if STAGE & 2:
+        acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(acc0, acc1, l_i0, l_i1, m_i0, m_i1, q0, q1,  #
+                                                                    desc_k, desc_v,  #
+                                                                    offset_y, dtype, start_m, qk_scale,  #
+                                                                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
+                                                                    2, offs_m0, offs_m1, offs_n, N_CTX,  #
+                                                                    warp_specialize)
+
+    m_i0 += tl.math.log2(l_i0)
+    acc0 = acc0 / l_i0[:, None]
+    m_ptrs0 = M + off_hz * N_CTX + offs_m0
+    tl.store(m_ptrs0, m_i0)
+    desc_o.store([qo_offset_y, 0], acc0.to(dtype))
+
+    m_i1 += tl.math.log2(l_i1)
+    acc1 = acc1 / l_i1[:, None]
+    m_ptrs1 = M + off_hz * N_CTX + offs_m1
+    tl.store(m_ptrs1, m_i1)
+    desc_o.store([qo_offset_y + BLOCK_M // 2, 0], acc1.to(dtype))
+
+
 @triton.jit
 def _attn_bwd_preprocess(
     O,
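
Note on the new `_attn_fwd_subtile` helper above: the reshape/permute/split/join sequence that rescales the accumulator is numerically just `acc * alpha[:, None]`; splitting into two interleaved [BM, BN//2] halves only reshapes the rescale into smaller sub-tiles. A minimal PyTorch sketch of that equivalence (illustrative, not part of this commit; torch.unbind/torch.stack stand in for tl.split/tl.join):

import torch

BM, BN = 128, 128
acc = torch.randn(BM, BN)
alpha = torch.rand(BM)

# reshape/permute/split as in the kernel: two [BM, BN // 2] halves
a0, a1 = acc.reshape(BM, 2, BN // 2).permute(0, 2, 1).unbind(dim=-1)
a0 = a0 * alpha[:, None]
a1 = a1 * alpha[:, None]
# join/permute/reshape back to the original [BM, BN] layout
rejoined = torch.stack((a0, a1), dim=-1).permute(0, 2, 1).reshape(BM, BN)

torch.testing.assert_close(rejoined, acc * alpha[:, None])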
@@ -2472,6 +2613,7 @@ def forward(ctx, q, k, v, causal, sm_scale, baseVariant):

         # no autotune with fixed BLOCK_N
         if HAS_TMA_DESC is True and torch.version.hip is None:
+            # Legacy on-host grid constant TMA
             desc_helper = TmaAutoTuneHelper()
             desc_helper.init_tma_descriptor("k")
             desc_helper.init_tma_descriptor("v")
@@ -2628,6 +2770,17 @@ def grid_tma_persistent(META):
         desc_v = desc_helper.get_tma_descriptor_kernel_param("v")
         desc_o = desc_helper.get_tma_descriptor_kernel_param("o")

+        # For variants using new on-host TMA
+        if baseVariant == "on_host_tma_ws_oss":
+            from triton.tools.tensor_descriptor import TensorDescriptor
+            y_dim = q.shape[0] * q.shape[1] * q.shape[2]
+            BLOCK_M = 256
+            BLOCK_N = 128
+            desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_M // 2, HEAD_DIM_K])
+            desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
+            desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
+            desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_M // 2, HEAD_DIM_K])
+
         M = torch.empty(
             (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
         )
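
The `on_host_tma_ws_oss` variant above builds host-side TensorDescriptor objects over a flattened (Z*H*N_CTX, HEAD_DIM) view and passes them directly as kernel arguments; the kernel then loads and stores whole tiles through them (`desc_q.load([row, 0])`, `desc_o.store([row, 0], tile)`). A minimal standalone sketch of that API, not part of this commit and assuming a Triton build and GPU where host-side tensor descriptors are supported (recent Triton on Hopper/Blackwell):

import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor


@triton.jit
def copy_rows_kernel(desc_in, desc_out, BLOCK_M: tl.constexpr):
    pid = tl.program_id(0)
    row = pid * BLOCK_M
    tile = desc_in.load([row, 0])    # TMA load of one [BLOCK_M, D] tile
    desc_out.store([row, 0], tile)   # TMA store of the same tile


M, D = 1024, 128
x = torch.randn(M, D, device="cuda", dtype=torch.bfloat16)
y = torch.empty_like(x)
BLOCK_M = 128
desc_x = TensorDescriptor(x, shape=[M, D], strides=[D, 1], block_shape=[BLOCK_M, D])
desc_y = TensorDescriptor(y, shape=[M, D], strides=[D, 1], block_shape=[BLOCK_M, D])
copy_rows_kernel[(M // BLOCK_M,)](desc_x, desc_y, BLOCK_M=BLOCK_M)
torch.testing.assert_close(x, y)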
@@ -2810,6 +2963,24 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
                 ENABLE_WS=True,
                 **extra_kern_args,
             )
+        elif baseVariant == "on_host_tma_ws_oss":
+            BLOCK_M = 256
+            BLOCK_N = 128
+            _attn_fwd_tma_oss_dp[grid_tma](
+                sm_scale, M,  #
+                q.shape[0], q.shape[1],  #
+                desc_q, desc_k, desc_v, desc_o,  #
+                N_CTX=q.shape[2],  #
+                HEAD_DIM=HEAD_DIM_K,  #
+                FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
+                STAGE=stage,  #
+                warp_specialize=True,  #
+                ENABLE_TMA=True,
+                BLOCK_N=BLOCK_N, BLOCK_M=BLOCK_M,  #
+                num_warps=4,
+                num_stages=2,
+                # maxnreg=64,
+                **extra_kern_args)

         ctx.save_for_backward(q, k, v, o, M)
         ctx.grid = grid_tma
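
Per sub-tile, the kernel runs the standard base-2 online-softmax recurrence: `qk_scale` folds in log2(e) so `exp2` can be used, `alpha` rescales the running accumulator whenever the row max grows, and the epilogue (`m_i += tl.math.log2(l_i)`, `acc / l_i[:, None]`) yields the log-sum-exp stored in M and the normalized output. A small CPU reference checking that recurrence against a plain softmax (illustrative, not part of this commit):

import torch


def online_softmax_attention(q, k, v, sm_scale, block_n):
    # Base-2 online softmax over K/V blocks, mirroring the non-causal
    # branch of _attn_fwd_subtile and the kernel epilogue.
    m = q.shape[0]
    qk_scale = sm_scale * 1.44269504  # fold in log2(e) so exp2 can be used
    m_i = torch.full((m,), float("-inf"))
    l_i = torch.ones(m)
    acc = torch.zeros(m, v.shape[1])
    for start in range(0, k.shape[0], block_n):
        kb = k[start:start + block_n]
        vb = v[start:start + block_n]
        qk = q @ kb.T
        m_ij = torch.maximum(m_i, qk.max(dim=1).values * qk_scale)
        p = torch.exp2(qk * qk_scale - m_ij[:, None])
        alpha = torch.exp2(m_i - m_ij)          # correction factor
        l_i = l_i * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ vb
        m_i = m_ij
    lse = m_i + torch.log2(l_i)                 # what the kernel stores into M
    return acc / l_i[:, None], lse


torch.manual_seed(0)
q, k, v = (torch.randn(128, 64) for _ in range(3))
sm_scale = 0.125
out, _ = online_softmax_attention(q, k, v, sm_scale, block_n=32)
ref = torch.softmax((q @ k.T) * sm_scale, dim=-1) @ v
torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-4)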

tritonbench/operators/blackwell_attentions/operator.py

Lines changed: 24 additions & 0 deletions
@@ -282,6 +282,7 @@ def causal_mask(b, h, q_idx, kv_idx):

         return lambda: flex_attention(q, k, v, block_mask=block_mask)

+    # use Meta's warpspec + on device TMA + persistent
     @register_benchmark(enabled=False)
     def triton_tutorial_flash_v2_tma_ws_persistent_blackwell(
         self,
@@ -293,6 +294,29 @@ def triton_tutorial_flash_v2_tma_ws_persistent_blackwell(
             q, k, v, self.causal, self.sm_scale, "tma_ws_persistent_blackwell"
         )

+    @register_benchmark(enabled=False)
+    def triton_tutorial_flash_v2_blackwell(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        return lambda: triton_tutorial_FA2_opt(
+            q, k, v, self.causal, self.sm_scale, "base_opt"
+        )
+
+    # use OSS warpspec + on host TMA
+    @register_benchmark(enabled=False)
+    def triton_tutorial_flash_v2_on_host_tma_ws_oss_blackwell(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> Callable:
+        return lambda: triton_tutorial_FA2_opt(
+            q, k, v, self.causal, self.sm_scale, "on_host_tma_ws_oss"
+        )
+
     # Only works with triton main, forward only.
     @register_benchmark(enabled=False)
     def gluon_blackwell_tutorial_fwd(
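
For a quick smoke test outside the benchmark harness, the new variant can be exercised the same way the registered benchmark does, by calling `triton_tutorial_FA2_opt` with the `"on_host_tma_ws_oss"` base variant. A hedged sketch, not part of this commit: the import path is assumed from the tritonbench tree, and shapes are chosen to fit the kernel's fixed tiling (N_CTX a multiple of BLOCK_M=256, HEAD_DIM >= BLOCK_N=128):

import math

import torch

# Assumed import path; the operator above imports triton_tutorial_FA2_opt
# from the tritonbench kernels module touched by this commit.
from tritonbench.kernels.triton_fused_attention import triton_tutorial_FA2_opt

Z, H, N_CTX, HEAD_DIM = 4, 8, 2048, 128
q, k, v = (
    torch.randn(Z, H, N_CTX, HEAD_DIM, device="cuda", dtype=torch.bfloat16)
    for _ in range(3)
)
sm_scale = 1.0 / math.sqrt(HEAD_DIM)
out = triton_tutorial_FA2_opt(q, k, v, True, sm_scale, "on_host_tma_ws_oss")
print(out.shape)  # (Z, H, N_CTX, HEAD_DIM)

Since the benchmark is registered with enabled=False, running it through the harness would typically require selecting it explicitly (e.g. tritonbench's --op blackwell_attentions with an --only filter; flags assumed from tritonbench's usual CLI).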
