You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
import torch
from mamba_ssm import Mamba2
# Minimal reproduction: a torch.compile'd Mamba2 block fed inputs whose
# sequence length changes between calls.
dim = 256
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
    layer_idx=0,
)
# Compile the module's forward in place via nn.Module.compile(), then move
# the parameters to the GPU.
model.compile()
model.to("cuda")

batch = 10
# Per the report, a single fixed length works, but varying the length across
# calls makes Inductor fail (InductorError: KeyError: 'op3').
# NOTE(review): model.compile(dynamic=False) would force a separate
# static-shape compile per length, which presumably sidesteps the crash at
# the cost of one recompile per distinct length — untested, confirm.
for length in [20, 30, 40]:
    # Allocate directly on the GPU rather than creating on CPU and copying.
    x = torch.randn(batch, length, dim, device="cuda")
    y1 = model(x)
If the sequence length is fixed, it works as intended, but with varying sequence lengths it crashes.
I get this stack trace:
TORCHDYNAMO_VERBOSE=1 python timeseries_models/mwe.py
/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:253: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
Traceback (most recent call last):
File "/home/peter/code/timeseries/timeseries_models/mwe.py", line 17, in <module>
y1 = model(x)
^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1765, in _wrapped_call_impl
return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 712, in compile_wrapper
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 699, in compile_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/mamba_ssm/modules/mamba2.py", line 185, in forward
out = mamba_split_conv1d_scan_combined(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 947, in mamba_split_conv1d_scan_combined
return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/autograd/function.py", line 579, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 517, in decorate_fwd
return fwd(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 792, in forward
causal_conv1d_fwd_function(rearrange_and_update_stride(xBC, "b s d -> b d s"),
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1469, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1248, in __call__
result = self._inner_convert(
^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 625, in __call__
return _compile(
^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1092, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_utils_internal.py", line 97, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 779, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 818, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1424, in transform_code_object
transformations(instructions, code_options)
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 265, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 743, in transform
tracer.run()
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3531, in run
super().run()
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1359, in run
while self.step():
^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1263, in step
self.dispatch_table[inst.opcode](self, inst)
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3699, in RETURN_VALUE
self._return(inst)
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3680, in _return
all_stack_locals_metadata = self.output.compile_subgraph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1405, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1678, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1733, in call_user_compiler
return self._call_user_compiler(gm)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1765, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/__init__.py", line 2365, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2389, in compile_fx
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2375, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 106, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1189, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 1057, in load
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1174, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 836, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 240, in aot_dispatch_base
compiled_fw = compiler(fw_module, updated_flat_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 483, in __call__
return self.compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 2207, in fw_compiler_base
return inner_compile(
^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 717, in compile_fx_inner
return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
inner_compiled_fn = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 887, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 871, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1524, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1402, in codegen_and_compile
compiled_module = graph.compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2285, in compile_to_module
return self._compile_to_module()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2291, in _compile_to_module
self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2226, in codegen
self._update_scheduler()
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/graph.py", line 2220, in _update_scheduler
self.scheduler = Scheduler(self.operations)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/scheduler.py", line 2025, in __init__
self._init(nodes)
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/scheduler.py", line 2091, in _init
self.compute_ancestors()
File "/home/peter/code/timeseries/.venv/lib/python3.12/site-packages/torch/_inductor/scheduler.py", line 2572, in compute_ancestors
ancestors |= name_to_ancestors[dep_node_name]
~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
torch._inductor.exc.InductorError: KeyError: 'op3'
The text was updated successfully, but these errors were encountered:
Using the git version:
If the sequence length is fixed, it works as intended, but with varying sequence lengths it crashes.
I get this stack trace:
The text was updated successfully, but these errors were encountered: