Skip to content

Commit 798c791

Browse files
authored
Merge pull request #231 from BobQC/main
fix bug for pull request #227: no skip when mask_dtype is float
2 parents 3c8f3b1 + 34ea987 commit 798c791

File tree

1 file changed

+0
-2
lines changed

1 file changed

+0
-2
lines changed

sageattention/triton/attn_qk_int8_per_block.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ def _attn_fwd_inner(acc, l_i, m_i, q, q_scale, qo_len, kv_len,
3737
skip = True
3838
else:
3939
mask_block = tl.load(mask_ptrs + start_n * stride_maskn, mask=(offs_m[:, None] < qo_len) & (offs_n[None, :] < kv_len - start_n), other=-1.0e6)
40-
if tl.max(mask_block) == 0 and tl.min(mask_block) == 0:
41-
skip = True
4240
if not skip:
4341
k_mask = offs_n[None, :] < (kv_len - start_n)
4442
k = tl.load(K_ptrs, mask=k_mask)

0 commit comments

Comments
 (0)