CUDA error when using Mamba2 with long context #503

Open
titzehong opened this issue Jul 30, 2024 · 7 comments · May be fixed by #708

Comments

@titzehong

Hi, I am benchmarking inference speed on long sequences and encountering CUDA-related errors specifically with the Mamba2 models at longer sequence lengths (>200k). This issue does not occur with Mamba1 models.

For example, running:

python benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-1.3b" --promptlen 300000 --genlen 1

produces the error:

File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 260, in generate
output = decode(
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 221, in decode
scores.append(get_logits(sequences[-1], inference_params))
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 184, in get_logits
logits = model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 279, in forward
hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 194, in forward
hidden_states, residual = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/block.py", line 67, in forward
hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba2.py", line 242, in forward
y = mamba_chunk_scan_combined(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 581, in mamba_chunk_scan_combined
return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, cu_seqlens, dt_softplus, dt_limit, return_final_states, return_varlen_states)
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 540, in forward
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, cu_seqlens=cu_seqlens, dt_softplus=dt_softplus, dt_limit=dt_limit)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 312, in _mamba_chunk_scan_combined_fwd
dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 675, in _chunk_cumsum_fwd
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 326, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 133, in _bench
return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 108, in do_bench
torch.cuda.synchronize()
File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 801, in synchronize
return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered

This issue seems to occur only with Mamba2 models, and it is present across all model sizes. Mamba1, however, works well, and I am able to run inference on prompt lengths of up to 1M on the 1.4b model.

I am using a single H100 (80 GB) card.

Thanks!

@tridao
Collaborator

tridao commented Jul 30, 2024

Are any tensors of size >= 2GB?
We use int32 for indexing, so it's possible that an index wraps around the int32 maximum and becomes negative, causing an illegal memory access (IMA).
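The wraparound described above can be illustrated in plain Python. This is only a sketch: `wrap_int32` is a hypothetical helper that mimics 32-bit signed arithmetic, and the `program_index` and `stride` values are made up for illustration rather than taken from the Mamba2 kernels.

```python
INT32_MIN = -2**31
INT32_MAX = 2**31 - 1


def wrap_int32(value: int) -> int:
    """Return the value a 32-bit signed integer would hold after wraparound."""
    return (value - INT32_MIN) % 2**32 + INT32_MIN


# Hypothetical numbers, chosen only so that the product exceeds INT32_MAX.
program_index = 1200
stride = 2_000_000

true_offset = program_index * stride    # what 64-bit arithmetic would give
int32_offset = wrap_int32(true_offset)  # what 32-bit arithmetic would give

print(true_offset, int32_offset)
```

Once the product passes `INT32_MAX`, the 32-bit result is negative, so a pointer computed from it would land before the start of the buffer instead of inside it.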

@titzehong
Author

Yes, I have several intermediate output tensors exceeding 2GB in size. However, this is also true for shorter context lengths that do not produce the error. For example, a context length of 250k has almost all layer outputs exceeding 2GB, yet it runs fine.

Noted on indexing being the likely cause. Would one solution be to lower the model's dimension?
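One possible reading of the 250k observation is that byte size and element count are different limits: Triton pointer offsets count elements, so a 2-byte tensor larger than 2GB can still have fewer than 2^31 elements. The sketch below compares the two under assumed values; the shape `(1, seqlen, 4096)` and the 2-byte element size are hypothetical and are not confirmed anywhere in this thread, and the sketch does not identify which index actually overflows.

```python
INT32_MAX = 2**31 - 1


def tensor_stats(shape, bytes_per_element):
    """Return (number of elements, size in bytes) for a dense tensor."""
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    return num_elements, num_elements * bytes_per_element


for seqlen in (250_000, 300_000):
    shape = (1, seqlen, 4096)  # hypothetical (batch, seqlen, width)
    num_elements, num_bytes = tensor_stats(shape, bytes_per_element=2)
    # The largest element offset is num_elements - 1.
    overflows = (num_elements - 1) > INT32_MAX
    print(seqlen, num_elements, num_bytes, overflows)
```

Under these assumptions both tensors exceed 2GB in bytes while their largest element offset still fits in int32, so an overflow would have to come from some other product of indices and strides.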

@Hprairie
Contributor

I think I found the problem and have submitted a PR.

@iofu728

iofu728 commented Sep 26, 2024

+1.

@Hprairie
Contributor

Try the fix in my PR; it's a four-line change. Let me know if that works.

@serendipityCoding

+1. I tried the fix in the PR, but it did not work for me either.

@LuJunru

LuJunru commented Oct 25, 2024

I found that this is related to a Triton issue. Changing tl.program_id(*) to tl.program_id(*).to(tl.int64) avoids this error: triton-lang/triton#1058.
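The substitution described above can be applied mechanically to kernel source text. The following is a minimal sketch, not part of mamba_ssm or Triton: `cast_program_ids` is a hypothetical helper, and `kernel_snippet` is an invented fragment (including the names `base_ptr`, `stride_batch`, and `stride_chunk`) used only to show the before and after. Whether the cast resolves the error in a given setup is not confirmed in this thread.

```python
import re

# Match tl.program_id(...) calls that are not already followed by .to(...).
_PROGRAM_ID = re.compile(r"tl\.program_id\(([^()]*)\)(?!\s*\.to\()")


def cast_program_ids(source: str) -> str:
    """Append .to(tl.int64) to every uncast tl.program_id(...) call."""
    return _PROGRAM_ID.sub(r"tl.program_id(\1).to(tl.int64)", source)


# Invented kernel fragment for illustration; not copied from mamba_ssm.
kernel_snippet = """\
pid_b = tl.program_id(axis=0)
pid_c = tl.program_id(axis=1).to(tl.int64)
ptr = base_ptr + pid_b * stride_batch + pid_c * stride_chunk
"""

patched = cast_program_ids(kernel_snippet)
print(patched)
```

With the program IDs widened to int64, products such as `pid_b * stride_batch` are intended to be computed in 64-bit arithmetic rather than wrapping at the int32 limit. The helper skips calls that already carry a cast, so running it twice leaves the source unchanged.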

@younesbelkada younesbelkada linked a pull request Mar 19, 2025 that will close this issue
6 participants