Environment
- mamba-ssm==2.2.4
- causal-conv1d==1.5.0.post8
- torch==2.5.1
- GPU: NVIDIA A6000 (single GPU)
- CUDA version: 12.2
Problem Description
I'm encountering `RuntimeError: CUDA error: an illegal memory access was encountered` with specific batch/sequence-length combinations when they are executed directly as the first forward pass. However, the same tensor sizes work fine when a smaller tensor is run through the module first.
Code to Reproduce
Case 1: Works fine - all tensor sizes succeed when a small tensor is executed first
```python
if __name__ == "__main__":
    from mamba_ssm import Mamba2
    import torch

    mamba = Mamba2(
        d_model=384,
        d_state=256,
        d_conv=4,
        expand=4,
    ).cuda()

    # Execute small tensor first
    x = torch.randn(26, 2048, 384).cuda()
    x = mamba(x)
    print(x.shape)  # (26, 2048, 384) - Works

    # Then execute larger tensors
    x = torch.randn(27, 32768, 384).cuda()
    x = mamba(x)
    print(x.shape)  # (27, 32768, 384) - Works

    x = torch.randn(26, 32768, 384).cuda()
    x = mamba(x)
    print(x.shape)  # (26, 32768, 384) - Works
```
Case 2: Works fine - direct execution with a different batch size
```python
if __name__ == "__main__":
    from mamba_ssm import Mamba2
    import torch

    mamba = Mamba2(
        d_model=384,
        d_state=256,
        d_conv=4,
        expand=4,
    ).cuda()

    x = torch.randn(32, 32768, 384).cuda()
    x = mamba(x)
    print(x.shape)  # (32, 32768, 384) - Works
```
Case 3: FAILS - specific batch sizes fail when executed without the small tensor first
```python
if __name__ == "__main__":
    from mamba_ssm import Mamba2
    import torch

    mamba = Mamba2(
        d_model=384,
        d_state=256,
        d_conv=4,
        expand=4,
    ).cuda()

    # Skip the small-tensor execution - these lines are commented out
    # x = torch.randn(26, 2048, 384).cuda()
    # x = mamba(x)
    # print(x.shape)

    # These specific sizes fail without the small-tensor execution above
    x = torch.randn(27, 32768, 384).cuda()
    x = mamba(x)  # RuntimeError: CUDA error: an illegal memory access was encountered
    print(x.shape)

    x = torch.randn(26, 32768, 384).cuda()
    x = mamba(x)  # RuntimeError: CUDA error: an illegal memory access was encountered
    print(x.shape)
```
Key Observations
- Tensor sizes (27, 32768, 384) and (26, 32768, 384) work perfectly when executed after a smaller tensor (26, 2048, 384) (see the minimal sketch below)
- The same tensor sizes fail with a CUDA illegal memory access when executed directly as the first operation
- A different batch size, (32, 32768, 384), works fine even without the prior small-tensor execution
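
For reference, here is the warm-up pattern from Case 1 reduced to a minimal sketch. Whether the small pass helps because the Triton autotuner selects a kernel configuration on the small input first is only my guess from the stack trace, not something I have confirmed.

```python
# Minimal sketch of the workaround observed above (not a fix): run one small
# forward pass before the large inputs. The mechanism is unconfirmed; the stack
# trace only shows the failure occurring during Triton autotuning of the
# chunk-cumsum kernel.
import torch
from mamba_ssm import Mamba2

mamba = Mamba2(d_model=384, d_state=256, d_conv=4, expand=4).cuda()

# Warm-up pass with a short sequence length
warmup = torch.randn(26, 2048, 384).cuda()
mamba(warmup)

# The previously failing size now runs on my setup
x = torch.randn(27, 32768, 384).cuda()
y = mamba(x)
print(y.shape)  # torch.Size([27, 32768, 384])
```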
Full Stack Trace
```
Traceback (most recent call last):
File "/data2/matsumoto/Helixer_original/src/core/model/modules/mamba.py", line 132, in <module>
x = mamba(x)
^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/mamba_ssm/modules/mamba2.py", line 185, in forward
out = mamba_split_conv1d_scan_combined(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 930, in mamba_split_conv1d_scan_combined
return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 465, in decorate_fwd
return fwd(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 795, in forward
out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 312, in _mamba_chunk_scan_combined_fwd
dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 675, in _chunk_cumsum_fwd
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 156, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 133, in _bench
return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/triton/testing.py", line 107, in do_bench
di.synchronize()
File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py", line 954, in synchronize
return torch._C._cuda_synchronize()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
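
In case it helps triage, this is how I would re-run the failing case with synchronous kernel launches, as the error message itself suggests; I have not yet verified that it changes the reported call site here.

```python
# Hedged debugging sketch: set CUDA_LAUNCH_BLOCKING before CUDA is initialized
# so the illegal memory access is reported at the launching call rather than at
# a later synchronize. This does not fix the bug; it only helps localize it.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before CUDA initializes

import torch
from mamba_ssm import Mamba2

mamba = Mamba2(d_model=384, d_state=256, d_conv=4, expand=4).cuda()
x = torch.randn(27, 32768, 384).cuda()
y = mamba(x)  # expected to reproduce the failure, hopefully with a more precise traceback
torch.cuda.synchronize()
```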