Skip to content

Commit 68181f9

Browse files
mzusman authored and s.kochetkov committed
[BugFix][Kernel] Fix Illegal memory access in causal_conv1d in H100 (vllm-project#9838)
Signed-off-by: mzusman <mor.zusmann@gmail.com> Signed-off-by: s.kochetkov <s.m.kochetkov@tcsbank.ru>
1 parent b124e9c commit 68181f9

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

csrc/mamba/causal_conv1d/causal_conv1d.cu

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,31 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
418418
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
419419
}
420420
out += kChunkSize;
421+
422+
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
423+
// in case the final state is separated between the last "smem_exchange" and
424+
// the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
425+
// (which occurs when `final_state_position` is a negative index)
426+
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
427+
if (final_state_position < 0 && seqlen > kWidth){
428+
input_t vals_load[kNElts] = {0};
429+
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
430+
// chunk = n_chunks - 2, a segment of the final state sits in the last index
431+
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
432+
#pragma unroll
433+
for (int w = 0; w < -final_state_position; ++w){
434+
conv_states[w] = vals_load[kNElts + final_state_position + w];
435+
}
436+
}
437+
if ((chunk == n_chunks - 1) && tidx == 0){
438+
// chunk = n_chunks - 1, the second segment of the final state sits in the first positions
439+
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
440+
for (int w = -final_state_position; w < kWidth - 1; ++w){
441+
conv_states[w] = vals_load[w + final_state_position];
442+
}
443+
return;
444+
}
445+
}
421446
}
422447
// Final state is stored in the smem_exchange last token slot,
423448
// in case seqlen < kWidth, we would need to take the final state from the
@@ -446,9 +471,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
446471
}
447472
else {
448473
// in case the final state is in between the threads data
449-
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
450-
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
451474
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
475+
if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
476+
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in an
477+
// illegal access error on H100.
478+
// Therefore, we access last_thread + 1, only if the final state data sits there
479+
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
480+
}
481+
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
452482
#pragma unroll
453483
for (int w = 0; w < kWidth - 1; ++w){
454484
conv_states[w] = x_vals_load[offset + w ];

tests/kernels/test_causal_conv1d.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor,
151151
@pytest.mark.parametrize("has_bias", [True])
152152
@pytest.mark.parametrize("width", [4])
153153
@pytest.mark.parametrize(
154-
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
154+
'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096])
155155
@pytest.mark.parametrize('dim', [64])
156156
@pytest.mark.parametrize('batch', [1])
157157
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
@@ -420,7 +420,10 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
420420

421421
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
422422
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
423-
assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
423+
assert torch.allclose(final_states[state_indices],
424+
final_states_ref[state_indices],
425+
rtol=rtol,
426+
atol=atol)
424427

425428
causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
426429
padded_state_indices, has_initial_states,

tests/kernels/test_mamba_ssm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
555555
device = "cuda"
556556
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
557557
if itype == torch.bfloat16:
558-
rtol, atol = 7e-2, 7e-2
558+
rtol, atol = 1e-1, 1e-1
559559
if torch.version.hip:
560560
atol *= 2
561561
# set seed
@@ -610,8 +610,8 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate,
610610
dt_bias=dt_bias,
611611
dt_softplus=True)
612612

613-
print("Output diff max", (out - out_ref[0]).max())
614-
print("Output diff mean", (out - out_ref[0]).mean())
613+
print("Output diff max", (out[:batch_size] - out_ref).max())
614+
print("Output diff mean", (out[:batch_size] - out_ref).mean())
615615
print("Output state diff max", (state[state_indices, :] - state_ref).max())
616616
print("Output state diff mean",
617617
(state[state_indices, :] - state_ref).mean())

0 commit comments

Comments
 (0)