
🐛 [Bug] Dynamic Shape Type Mismatch Error When Using Static Shape #3876

@henrymmorton


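Bug Description

Compiling a model with `torch_tensorrt.compile` fails when the example input has a static shape. Tracing with `torch.export` raises a `UserError` about a mismatch between the structure of the inputs and `dynamic_shapes`, even though no dynamic shapes were requested.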

To Reproduce

Steps to reproduce the behavior:

  1. Install the torch_tensorrt 2.8 wheel for CUDA 12.6 from https://pypi.jetson-ai-lab.io/jp6/cu126 on a Jetson Orin Nano running JetPack 6.2
  2. Compile a model with `torch_tensorrt.compile`, passing an example input with a static shape

Code Sample

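```python
import torch
import torch_tensorrt

# `model` is the already-loaded network (ERFNet in this report), an nn.Module on CUDA

# Static input shape: batch 1, 3 channels, 544 x 960, FP16
dummy_input = torch.randn(1, 3, 544, 960, dtype=torch.float16).cuda()

trt_model = torch_tensorrt.compile(
    model,
    inputs=[dummy_input],
    enabled_precisions={torch.float16},  # allow FP16 TensorRT kernels
    workspace_size=1 << 27,  # 128 MiB
)
```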

Stack Trace

```
Traceback (most recent call last):
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 614, in _tree_map_with_path
    return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 2076, in tree_map_with_path
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  File "/home/henry/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 2076, in <listcomp>
    all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
  File "/home/henry/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1192, in flatten_up_to
    helper(self, tree, subtrees)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1189, in helper
    helper(subspec, subtree, subtrees)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1145, in helper
    raise ValueError(
ValueError: Node type mismatch; expected <class 'tuple'>, but got <class 'dict'>.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/henry/git/aeolus_dev/core/autonomy.py", line 442, in <module>
    autonomy(speed_matrix=HARD_SPEED_MATRIX, spoofing = run_with_spoofing, recording_path = hard_recording_path)
  File "/home/henry/git/aeolus_dev/core/autonomy.py", line 377, in autonomy
    aeolus, video_stream, teensy_state_store = spin_up_spoofed_autonomy(speed_matrix, recording_path)
  File "/home/henry/git/aeolus_dev/core/autonomy.py", line 306, in spin_up_spoofed_autonomy
    prime_autonomy_pipeline(my_aeolus_run, image_stream)
  File "/home/henry/git/aeolus_dev/core/autonomy.py", line 224, in prime_autonomy_pipeline
    predictor = ERFNetPredictor(aeolus.model_weights)
  File "/home/henry/git/aeolus_dev/core/cv/nueral_net/vision_pipeline/ERFNetPredictor.py", line 38, in __init__
    trt_model = torch_tensorrt.compile(
  File "/home/henry/.virtualenvs/aeolus-system/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 286, in compile
    exp_program = dynamo_trace(
  File "/home/henry/.virtualenvs/aeolus-system/lib/python3.10/site-packages/torch_tensorrt/dynamo/_tracer.py", line 79, in trace
    exp_program = export(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 304, in export
    raise e
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 271, in export
    return _export(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1150, in wrapper
    raise e
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1116, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2163, in _export
    ep = _export_for_training(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1150, in wrapper
    raise e
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1116, in wrapper
    ep = fn(*args, **kwargs)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 2026, in _export_for_training
    export_artifact = export_func(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/_trace.py", line 1923, in _non_strict_export
    ) = make_fake_inputs(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/_export/non_strict_utils.py", line 356, in make_fake_inputs
    _check_dynamic_shapes(combined_args, dynamic_shapes)
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 1031, in _check_dynamic_shapes
    _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs")
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 686, in _tree_map_with_path
    _compare(tree_spec, other_tree_spec, [])
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 677, in _compare
    _compare(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 652, in _compare
    raise_mismatch_error(
  File "/home/henry/.local/lib/python3.10/site-packages/torch/export/dynamic_shapes.py", line 634, in raise_mismatch_error
    raise UserError(
torch._dynamo.exc.UserError: Detected mismatch between the structure of inputs and dynamic_shapes: inputs['inputs'] is a <class 'tuple'>, but dynamic_shapes['inputs'] is a <class 'dict'>
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
```

Expected behavior

I would expect the model to compile without errors. Since the input has a static shape, no dynamic-shape handling should be involved.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version: 2.8
  • PyTorch Version: 2.8
  • CPU Architecture: ARM
  • OS (e.g., Linux): Linux (JetPack 6.2)
  • How you installed PyTorch: Jetson-AI-Lab wheels
  • Python version: 3.10
  • CUDA version: 12.6
  • GPU models and configuration: Jetson Orin Nano
  • Any other relevant information:

Additional context

I believe this error comes from the following call on line 77 of `torch_tensorrt/dynamo/_tracer.py`:

`dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)`

The subsequent call to `export` expects `dynamic_shapes` to be `None` when all inputs have static shapes, but `get_dynamic_shapes_args` returns an empty dictionary instead.

Replacing that call with `dynamic_shapes = None` makes the model compile successfully.
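Hard-coding `dynamic_shapes = None` would also discard genuinely dynamic shape specs. A narrower variant is sketched below: pass the spec through only when at least one leaf marks a dynamic dimension, and otherwise return `None`. `has_dynamic_dims` and `normalize_dynamic_shapes` are hypothetical helpers, not part of Torch-TensorRT, and the sketch assumes the spec is a nested dict/tuple/list whose leaves are `torch.export.Dim`-like objects for dynamic dimensions and `None` or empty containers for static ones.

```python
from typing import Any, Optional


def has_dynamic_dims(spec: Any) -> bool:
    """Return True if any leaf of a dynamic_shapes spec marks a dynamic dimension.

    Assumed layout: nested dicts/tuples/lists whose leaves are Dim-like objects
    for dynamic dimensions, and None or empty containers for static ones.
    """
    if spec is None:
        return False
    if isinstance(spec, dict):
        # Covers both {arg_name: per_tensor_spec} and {dim_index: Dim}.
        return any(has_dynamic_dims(v) for v in spec.values())
    if isinstance(spec, (list, tuple)):
        return any(has_dynamic_dims(v) for v in spec)
    # Any other leaf (e.g. a torch.export.Dim) is conservatively treated as dynamic.
    return True


def normalize_dynamic_shapes(spec: Any) -> Optional[Any]:
    """Hypothetical helper: keep the spec only if something in it is dynamic."""
    return spec if has_dynamic_dims(spec) else None


if __name__ == "__main__":
    # Static inputs: empty specs collapse to None.
    print(normalize_dynamic_shapes({}))              # None
    print(normalize_dynamic_shapes({"inputs": {}}))  # None
    # A string stands in for a torch.export.Dim so the sketch runs without torch.
    print(normalize_dynamic_shapes({"x": {0: "batch"}}))  # {'x': {0: 'batch'}}
```

In `_tracer.py` this would wrap the existing call, e.g. `dynamic_shapes = normalize_dynamic_shapes(get_dynamic_shapes_args(mod, arg_inputs))`; this has not been tested against Torch-TensorRT itself.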

Labels: bug