Skip to content

Commit 3854ad8

Browse files
committed
Fix dt softplus in update function (h/t Junxiong Wang)
1 parent 009bec5 commit 3854ad8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

mamba_ssm/ops/triton/selective_state_update.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _selective_scan_update_kernel(
7575
if HAS_DT_BIAS:
7676
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
7777
if DT_SOFTPLUS:
78-
dt = tl.log(1.0 + tl.exp(dt))
78+
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
7979
A = tl.load(A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
8080
dA = tl.exp(A * dt[:, None])
8181
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)

0 commit comments

Comments
 (0)