Skip to content

Commit bb3a82a

Browse files
committed
Fix mask in _chunk_scan_chunk_state_bwd_dx that could cause NaN
Only affect cases where sequence length is not a multiple of 256
1 parent 8f42a5e commit bb3a82a

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

mamba_ssm/ops/triton/ssd_combined.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,11 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(
170170
dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
171171
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
172172
cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
173-
mask = k + offs_k[None, :] >= offs_m[:, None]
173+
# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
174+
# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
175+
# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
176+
# This will cause NaN in acc, and hence NaN in dx and ddt.
177+
mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
174178
cb = tl.where(mask, cb, 0.0)
175179
cb = cb.to(dout_ptr.dtype.element_ty)
176180
acc += tl.dot(cb, dout)

0 commit comments

Comments
 (0)