Skip to content

Commit 162361f

Browse files
committed
Mamba test relieve bfloat16 tolerance constraint to match update with
update, and small fix in causal_conv1d kernel Signed-off-by: mzusman <mor.zusmann@gmail.com>
1 parent 211fe91 commit 162361f

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

csrc/mamba/causal_conv1d/causal_conv1d.cu

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,9 +446,12 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
446446
}
447447
else {
448448
// in case the final state is in between the threads data
449-
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
450-
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
451449
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
450+
if ((offset + kWidth - 2) >= kNElts){
451+
// do not load to index 1 if we're not gonna read from there
452+
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
453+
}
454+
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
452455
#pragma unroll
453456
for (int w = 0; w < kWidth - 1; ++w){
454457
conv_states[w] = x_vals_load[offset + w ];

tests/kernels/test_mamba_ssm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
555555
device = "cuda"
556556
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
557557
if itype == torch.bfloat16:
558-
rtol, atol = 7e-2, 7e-2
558+
rtol, atol = 1e-1, 1e-1
559559
if torch.version.hip:
560560
atol *= 2
561561
# set seed
@@ -610,8 +610,8 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
610610
dt_bias=dt_bias,
611611
dt_softplus=True)
612612

613-
print("Output diff max", (out - out_ref[0]).max())
614-
print("Output diff mean", (out - out_ref[0]).mean())
613+
print("Output diff max", (out[:batch_size] - out_ref).max())
614+
print("Output diff mean", (out[:batch_size] - out_ref).mean())
615615
print("Output state diff max", (state[state_indices, :] - state_ref).max())
616616
print("Output state diff mean",
617617
(state[state_indices, :] - state_ref).mean())

0 commit comments

Comments
 (0)