
RuntimeError: CUDA error: an illegal memory access was encountered #732

Open
@calliope-pro

Description

Environment

  • mamba-ssm==2.2.4
  • causal-conv1d==1.5.0.post8
  • torch==2.5.1
  • GPU: NVIDIA A6000 (single GPU)
  • CUDA version: 12.2

Problem Description

I'm encountering RuntimeError: CUDA error: an illegal memory access was encountered for specific batch/sequence-length combinations when they are the first inputs run through the module. The same tensor sizes work fine when a smaller input is run through the module first.

Code to Reproduce

Case 1: Works fine - all tensor sizes work when a small tensor is executed first

if __name__ == "__main__":
    from mamba_ssm import Mamba2
    import torch
    mamba = Mamba2(
        d_model=384,
        d_state=256,
        d_conv=4,
        expand=4,
    ).cuda()
    
    # Execute small tensor first
    x = torch.randn(26, 2048, 384).cuda()
    x = mamba(x)
    print(x.shape)  # (26, 2048, 384) - Works
    
    # Then execute larger tensors
    x = torch.randn(27, 32768, 384).cuda()
    x = mamba(x)
    print(x.shape)  # (27, 32768, 384) - Works
    
    x = torch.randn(26, 32768, 384).cuda()
    x = mamba(x)
    print(x.shape)  # (26, 32768, 384) - Works

Case 2: Works fine - direct execution with a different batch size

if __name__ == "__main__":
    from mamba_ssm import Mamba2
    import torch
    mamba = Mamba2(
        d_model=384,
        d_state=256,
        d_conv=4,
        expand=4,
    ).cuda()
    
    x = torch.randn(32, 32768, 384).cuda()
    x = mamba(x)
    print(x.shape)  # (32, 32768, 384) - Works

Case 3: FAILS - specific batch sizes fail when executed without the small tensor first

if __name__ == "__main__":
    from mamba_ssm import Mamba2
    import torch
    mamba = Mamba2(
        d_model=384,
        d_state=256,
        d_conv=4,
        expand=4,
    ).cuda()
    
    # Small tensor execution skipped (the warm-up lines from Case 1 are commented out)
    # x = torch.randn(26, 2048, 384).cuda()
    # x = mamba(x)
    # print(x.shape)
    
    # These specific sizes fail without the small tensor execution above
    x = torch.randn(27, 32768, 384).cuda()
    x = mamba(x)  # RuntimeError: CUDA error: an illegal memory access was encountered
    print(x.shape)
    
    x = torch.randn(26, 32768, 384).cuda()
    x = mamba(x)  # RuntimeError: CUDA error: an illegal memory access was encountered
    print(x.shape)

Key Observations

  1. Tensor sizes (27, 32768, 384) and (26, 32768, 384) work perfectly when executed after a smaller tensor (26, 2048, 384) (a warm-up workaround based on this is sketched after this list)
  2. The same tensor sizes fail with a CUDA illegal memory access when executed directly as the first operation
  3. A different batch size (32, 32768, 384) works fine even without a prior small-tensor execution
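
Workaround I'm using for now (my own sketch, not something from the mamba_ssm docs): warm the module up with the same small input from Case 1 before feeding the real inputs. This just mirrors observation 1 above.

if __name__ == "__main__":
    from mamba_ssm import Mamba2
    import torch
    mamba = Mamba2(
        d_model=384,
        d_state=256,
        d_conv=4,
        expand=4,
    ).cuda()

    # Warm-up pass with the small input from Case 1
    _ = mamba(torch.randn(26, 2048, 384).cuda())

    # The previously failing size now runs, as in Case 1
    x = torch.randn(27, 32768, 384).cuda()
    x = mamba(x)
    print(x.shape)  # (27, 32768, 384)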

Full Stack Trace

Traceback (most recent call last):
  File "/data2/matsumoto/Helixer_original/src/core/model/modules/mamba.py", line 132, in <module>
    x = mamba(x)
        ^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/mamba_ssm/modules/mamba2.py", line 185, in forward
    out = mamba_split_conv1d_scan_combined(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 930, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 465, in decorate_fwd
    return fwd(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 795, in forward
    out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 312, in _mamba_chunk_scan_combined_fwd
    dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 675, in _chunk_cumsum_fwd
    _chunk_cumsum_fwd_kernel[grid_chunk_cs](
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 156, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 133, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/triton/testing.py", line 107, in do_bench
    di.synchronize()
  File "/data2/matsumoto/Helixer_original/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py", line 954, in synchronize
    return torch._C._cuda_synchronize()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
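
For completeness, this is how I plan to rerun the failing case with the CUDA_LAUNCH_BLOCKING=1 hint from the error message, so the traceback points at the kernel that actually faults rather than a later synchronize call. A sketch; the environment variable has to be set before torch initializes CUDA (setting it in the shell when launching the script is the safest option).

import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before the CUDA context is created
# (equivalently, launch with: CUDA_LAUNCH_BLOCKING=1 python <script>)

import torch
from mamba_ssm import Mamba2

mamba = Mamba2(d_model=384, d_state=256, d_conv=4, expand=4).cuda()
x = torch.randn(27, 32768, 384).cuda()
x = mamba(x)  # with blocking launches the traceback should point at the failing kernel
print(x.shape)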
