Description
Hi, I am benchmarking inference speed on long sequences and am encountering CUDA errors specifically with Mamba2 models at long sequence lengths (>200k tokens). The issue does not occur with Mamba1 models.
For example, running:
python benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-1.3b" --promptlen 300000 --genlen 1
produces the following error:
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 260, in generate
output = decode(
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 221, in decode
scores.append(get_logits(sequences[-1], inference_params))
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 184, in get_logits
logits = model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 279, in forward
hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 194, in forward
hidden_states, residual = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/block.py", line 67, in forward
hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba2.py", line 242, in forward
y = mamba_chunk_scan_combined(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 581, in mamba_chunk_scan_combined
return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states)
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 540, in forward
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 312, in _mamba_chunk_scan_combined_fwd
dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 675, in _chunk_cumsum_fwd
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 326, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 133, in _bench
return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 108, in do_bench
torch.cuda.synchronize()
File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 801, in synchronize
return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
This issue seems to occur only with Mamba2 models and is present across all model sizes. Mamba1, however, works well: I am able to run inference on prompt lengths of up to 1M tokens with the 1.4B model.
I am using a single H100 (80 GB) GPU.
Thanks!