Mamba2 not compilable with dynamic sequence length #740

Open
peterbjorgensen opened this issue May 30, 2025 · 0 comments

Using the git version of mamba_ssm:

import torch
from mamba_ssm import Mamba2
dim = 256
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
    layer_idx=0,
)
model.compile()
model.to("cuda")
batch = 10
for length in [20, 30, 40]:
    x = torch.randn(batch, length, dim).to("cuda")
    y1 = model(x)

If the sequence length is fixed, this works as intended, but with varying sequence lengths it crashes.
I get this stack trace:

TORCHDYNAMO_VERBOSE=1 python timeseries_models/mwe.py
/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:253: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
Traceback (most recent call last):
  File "/home/peter/code/timeseries/timeseries_models/mwe.py", line 17, in <module>
    y1 = model(x)
         ^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1765, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 712, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 699, in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/mamba_ssm/modules/mamba2.py", line 185, in forward
    out = mamba_split_conv1d_scan_combined(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 947, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/autograd/function.py", line 579, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 517, in decorate_fwd
    return fwd(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 792, in forward
    causal_conv1d_fwd_function(rearrange_and_update_stride(xBC, "b s d -> b d s"),
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1469, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1248, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 625, in __call__
    return _compile(
           ^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1092, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 779, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 818, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
    transformations(instructions, code_options)
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 265, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 743, in transform
    tracer.run()
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3531, in run
    super().run()
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1359, in run
    while self.step():
          ^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1263, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3699, in RETURN_VALUE
    self._return(inst)
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3680, in _return
    all_stack_locals_metadata = self.output.compile_subgraph(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1405, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1678, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1733, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1765, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/__init__.py", line 2365, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2389, in compile_fx
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2375, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 106, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1189, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 1057, in load
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1174, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 836, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 240, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 483, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2207, in fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 717, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 887, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 871, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
                        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1524, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1402, in codegen_and_compile
    compiled_module = graph.compile_to_module()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2285, in compile_to_module
    return self._compile_to_module()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2291, in _compile_to_module
    self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
                                                             ^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2226, in codegen
    self._update_scheduler()
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2220, in _update_scheduler
    self.scheduler = Scheduler(self.operations)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/scheduler.py", line 2025, in __init__
    self._init(nodes)
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/scheduler.py", line 2091, in _init
    self.compute_ancestors()
  File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/scheduler.py", line 2572, in compute_ancestors
    ancestors |= name_to_ancestors[dep_node_name]
                 ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
torch._inductor.exc.InductorError: KeyError: 'op3'
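
Since the fixed-length case compiles, one possible stopgap is to right-pad every input to a common bucket length so the compiled module only ever sees one shape. The sketch below is an assumption, not a fix for the Inductor error: the helper names `bucket_length` and `pad_to_bucket` and the default bucket size of 64 are hypothetical. It also relies on the layer being causal, so that zero-padding on the right should leave the first `length` outputs unchanged. This has not been verified against the setup above.

```python
def bucket_length(length: int, multiple: int = 64) -> int:
    """Round `length` up to the next multiple of `multiple` (hypothetical helper)."""
    if length <= 0 or multiple <= 0:
        raise ValueError("length and multiple must be positive")
    return -(-length // multiple) * multiple  # ceiling division, then scale back up


def pad_to_bucket(x, multiple: int = 64):
    """Right-pad a (batch, length, dim) tensor with zeros along the sequence axis.

    Returns the padded tensor and the original length, so the caller can
    slice the model output back to the unpadded length.
    """
    batch, length, dim = x.shape
    target = bucket_length(length, multiple)
    padded = x.new_zeros((batch, target, dim))  # same dtype and device as x
    padded[:, :length] = x
    return padded, length


# Hypothetical usage with the compiled model from the snippet above:
#   x_padded, length = pad_to_bucket(x, multiple=64)
#   y1 = model(x_padded)[:, :length]
```

With the lengths 20, 30 and 40 from the example, every input is padded to 64, so only a single sequence length reaches the compiled module. The cost is wasted compute on the padded positions, and the underlying dynamic-shape failure remains.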