
CUDA error when using Mamba2 with long context #503

Open

Description

@titzehong

Hi, I am benchmarking inference speed on long sequences and am hitting CUDA errors specifically with the Mamba2 models at longer sequence lengths (>200k tokens). The issue does not occur with Mamba1 models.

For example running:

python benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-1.3b" --promptlen 300000 --genlen 1

produces the error:

File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 260, in generate
output = decode(
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 221, in decode
scores.append(get_logits(sequences[-1], inference_params))
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 184, in get_logits
logits = model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 279, in forward
hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 194, in forward
hidden_states, residual = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/block.py", line 67, in forward
hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba2.py", line 242, in forward
y = mamba_chunk_scan_combined(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 581, in mamba_chunk_scan_combined
return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states)
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 540, in forward
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 312, in _mamba_chunk_scan_combined_fwd
dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 675, in _chunk_cumsum_fwd
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 326, in
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 133, in _bench
return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 108, in do_bench
torch.cuda.synchronize()
File "/usr/local/lib/python3.10/dist-packages/torch/cuda/init.py", line 801, in synchronize
return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
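
For a reproduction that doesn't go through the benchmark script, the command above should be roughly equivalent to the following sketch. This is my paraphrase of what benchmark_generation_mamba_simple.py does, so treat the exact arguments as assumptions rather than a verbatim excerpt:

import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Load the Mamba2 model the same way the benchmark script does
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba2-1.3b", device="cuda", dtype=torch.float16
)
model.eval()

promptlen, genlen = 300_000, 1
# Random token ids stand in for a real prompt, as in the benchmark script
input_ids = torch.randint(1, 1000, (1, promptlen), dtype=torch.long, device="cuda")

# Fails with the illegal-memory-access error above at this prompt length
out = model.generate(input_ids, max_length=promptlen + genlen)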

This issue seems to occur only with Mamba2 models and is present across all model sizes. Mamba1, however, works well, and I am able to run inference on prompt lengths of up to 1M tokens on the 1.4b model.

I am using a single H100 (80 GB) card.
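
Since the traceback bottoms out in mamba_chunk_scan_combined, a sketch like the one below might help isolate whether the Triton kernels alone reproduce the fault at this sequence length. The shapes here are illustrative assumptions (loosely based on my understanding of the 1.3b config), not values read from the checkpoint:

import torch
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

batch, seqlen = 1, 300_000
nheads, headdim = 64, 64            # assumed head configuration
ngroups, dstate, chunk_size = 1, 128, 256

# Shapes follow the mamba_chunk_scan_combined docstring:
# x (batch, seqlen, nheads, headdim), dt (batch, seqlen, nheads),
# A (nheads), B/C (batch, seqlen, ngroups, dstate)
x = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
dt = torch.rand(batch, seqlen, nheads, device="cuda", dtype=torch.float16)
A = -torch.rand(nheads, device="cuda", dtype=torch.float32)  # decay rates are negative
B = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.float16)
C = torch.randn(batch, seqlen, ngroups, dstate, device="cuda", dtype=torch.float16)

y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, dt_softplus=True)
torch.cuda.synchronize()  # an asynchronous kernel fault would surface here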

Thanks!
