Conversation
ea3e98a to
98555fe
Compare
98555fe to
ab88b54
Compare
narendasan
left a comment
There was a problem hiding this comment.
@apbose let me know if I added redundant infrastructure anywhere, I kind of worked from scratch
beda370 to
d6f189b
Compare
| weight_refit_map[engine_weight_name].dtype, | ||
| ] | ||
|
|
||
| # Stage 3: Slice matching for unmatched non-scalar CONSTANT weights. |
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py 2026-03-11 18:43:16.001645+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py 2026-03-11 18:43:36.562015+00:00
@@ -880,11 +880,11 @@
pass
elif isinstance(output_meta, (FakeTensor, torch.Tensor)):
out_dtype = output_meta.dtype
if out_dtype in COMPLEX_TO_REAL_DTYPE:
out_dtype = COMPLEX_TO_REAL_DTYPE[out_dtype]
-
+
if truncate_double and out_dtype == torch.float64:
output_dtypes.append(dtype.float32)
else:
output_dtypes.append(dtype._from(out_dtype))
elif isinstance(output_meta, torch.SymInt):
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_complex_rewrite.py 2026-03-11 18:43:16.028018+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_complex_rewrite.py 2026-03-11 18:43:39.054297+00:00
@@ -60,12 +60,15 @@
def _real_inputs(inputs: Tuple[Any, ...]) -> Tuple[Any, ...]:
"""Convert complex tensors to [..., 2] real layout."""
return tuple(
- torch.view_as_real(x).contiguous() if isinstance(x, torch.Tensor) and x.is_complex()
- else x
+ (
+ torch.view_as_real(x).contiguous()
+ if isinstance(x, torch.Tensor) and x.is_complex()
+ else x
+ )
for x in inputs
)
def _to_complex_if_needed(out: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
| "Consider adding a rewrite in complex_graph_rewrite.py.", | ||
| node.target, | ||
| ) | ||
|
|
There was a problem hiding this comment.
if handler is not None and handler(node):
modified = True
elif node.target in _ELEMENTWISE_SAFE:
We should change to this, so that a False return from handler(node) correctly logs the passthrough. E.g., add.Tensor called as add.Tensor(tensor, tensor) should go to _ELEMENTWISE_SAFE, requiring no rewrite.
| # can identify it as a complex-layout [..., 2] tensor later when | ||
| # processing ops that use this buffer. | ||
| if fake_mode is not None: | ||
| try: |
There was a problem hiding this comment.
Don't we need a similar fake_mode=FakeTensorMode() in the buffer case, as we did in the placeholder case?
| return ComplexSubGraphInfo([anchor_node], ordered_subgraph, list(input_nodes)) | ||
|
|
||
| def find_complex_op_subgraphs( | ||
| self, gm: GraphModule, anchor_target: str |
There was a problem hiding this comment.
I assume this is no longer used?
| trt_model = torch_tensorrt.dynamo.compile( | ||
| ep, | ||
| inputs=trt_inputs, | ||
| enabled_precisions={torch.float32}, |
There was a problem hiding this comment.
need to set use_explicit_typing=False or remove enabled_precisions
| ) | ||
| else: | ||
| pass # call_function inputs are rewritten in-place by the op handlers | ||
| input_node.replace_all_uses_with(new_node) |
There was a problem hiding this comment.
new_node will not exist if input_node.op is not placeholder or get_attr
| @_complex_unpacker(torch.ops.aten.pow.Tensor_Tensor) | ||
| def _rewrite_pow_tensor_tensor(self, node: Node) -> bool: | ||
| # z1**z2 = exp(z2 * log(z1)) | ||
| z1_inp, z2_inp = node.args[0], node.args[1] | ||
| with SubgraphBuilder(self.gm.graph, node) as b: | ||
| re1, im1 = self._inline_select_re_im(b, z1_inp) | ||
| re2 = b(torch.ops.aten.select.int, z2_inp, -1, 0) | ||
| im2 = b(torch.ops.aten.select.int, z2_inp, -1, 1) | ||
| log_re, log_im = self._inline_complex_log(b, re1, im1) | ||
| mul_re, mul_im = self._inline_complex_mul(b, re2, im2, log_re, log_im) | ||
| exp_re, exp_im = self._inline_complex_exp(b, mul_re, mul_im) | ||
| out = self._inline_cat_re_im(b, exp_re, exp_im) | ||
| node.replace_all_uses_with(out) | ||
| self.gm.graph.erase_node(node) | ||
| return True |
There was a problem hiding this comment.
Are these rewrite ops going to be applied only to complex values, keeping real values untouched?
There was a problem hiding this comment.
yes, part of what the pass does is it identifies the boundaries of complex computation and only applies on that subgraph
| exp_program1, | ||
| tuple(inputs), | ||
| use_python_runtime=True, | ||
| enabled_precisions={torch.float}, |
| exp_program1, | ||
| tuple(inputs), | ||
| use_python_runtime=True, | ||
| enabled_precisions={torch.float}, |
| exp_program1, | ||
| tuple(inputs), | ||
| use_python_runtime=True, | ||
| enabled_precisions={torch.float}, |
| ) | ||
| from torch_tensorrt.dynamo.utils import ( | ||
| COMPLEX_TO_REAL_DTYPE, | ||
| deallocate_module, |
There was a problem hiding this comment.
I believe this is not used in this file?
| if not isinstance(src, torch.fx.Node): | ||
| continue | ||
| if src.op == "call_module" and "_run_on_acc" in str(src.target): | ||
| with partitioned_module.graph.inserting_before(output_node): |
There was a problem hiding this comment.
Shouldn't we have this check for run_on_gpu too? Since we detect the complex graph over the whole graph, the reshaping and shape modification (new shape = [..., 2]) would occur on all of them, so wouldn't we need to insert view_as_complex at the end for those as well?
There was a problem hiding this comment.
For run_on_gpu, it should just stay in complex, right? Only TRT boundaries have the issue where we need this alternative representation for complex.
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py 2026-03-12 20:52:03.519505+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py 2026-03-12 20:52:25.755519+00:00
@@ -880,11 +880,11 @@
pass
elif isinstance(output_meta, (FakeTensor, torch.Tensor)):
out_dtype = output_meta.dtype
if out_dtype in COMPLEX_TO_REAL_DTYPE:
out_dtype = COMPLEX_TO_REAL_DTYPE[out_dtype]
-
+
if truncate_double and out_dtype == torch.float64:
output_dtypes.append(dtype.float32)
else:
output_dtypes.append(dtype._from(out_dtype))
elif isinstance(output_meta, torch.SymInt):
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/hlo/test_complex_graph_break.py 2026-03-12 20:52:03.536251+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/hlo/test_complex_graph_break.py 2026-03-12 20:52:26.973354+00:00
@@ -121,48 +121,49 @@
nodes_by_target: dict = {}
for n in gm.graph.nodes:
nodes_by_target.setdefault(n.target, []).append(n)
# view_as_complex must be present (inserted by the fallback wrapper)
- assert torch.ops.aten.view_as_complex.default in nodes_by_target, (
- "Expected view_as_complex to be inserted before cumsum, but it was not found"
- )
+ assert (
+ torch.ops.aten.view_as_complex.default in nodes_by_target
+ ), "Expected view_as_complex to be inserted before cumsum, but it was not found"
# cumsum must still be present (it was NOT removed)
- assert torch.ops.aten.cumsum.default in nodes_by_target, (
- "cumsum should remain in the graph (runs as PyTorch fallback)"
- )
+ assert (
+ torch.ops.aten.cumsum.default in nodes_by_target
+ ), "cumsum should remain in the graph (runs as PyTorch fallback)"
# The view_as_complex output feeds directly into cumsum
vc_node = nodes_by_target[torch.ops.aten.view_as_complex.default][0]
cumsum_node = nodes_by_target[torch.ops.aten.cumsum.default][0]
- assert cumsum_node.args[0] is vc_node, (
- f"cumsum's first arg should be the view_as_complex node, got {cumsum_node.args[0]}"
- )
+ assert (
+ cumsum_node.args[0] is vc_node
+ ), f"cumsum's first arg should be the view_as_complex node, got {cumsum_node.args[0]}"
# The view_as_complex input is a real-layout (is_complex_layout) node
vc_input = vc_node.args[0]
assert isinstance(vc_input, torch.fx.Node), "view_as_complex input must be a Node"
- assert vc_input.meta.get("is_complex_layout", False), (
- "view_as_complex input should be a real-layout complex node (is_complex_layout=True)"
- )
+ assert vc_input.meta.get(
+ "is_complex_layout", False
+ ), "view_as_complex input should be a real-layout complex node (is_complex_layout=True)"
# view_as_real must follow cumsum
- assert torch.ops.aten.view_as_real.default in nodes_by_target, (
- "Expected view_as_real to be inserted after cumsum, but it was not found"
- )
+ assert (
+ torch.ops.aten.view_as_real.default in nodes_by_target
+ ), "Expected view_as_real to be inserted after cumsum, but it was not found"
vr_node = nodes_by_target[torch.ops.aten.view_as_real.default][0]
- assert vr_node.args[0] is cumsum_node, (
- f"view_as_real's arg should be the cumsum node, got {vr_node.args[0]}"
- )
+ assert (
+ vr_node.args[0] is cumsum_node
+ ), f"view_as_real's arg should be the cumsum node, got {vr_node.args[0]}"
# After metadata propagation, cumsum receives a complex-dtype tensor
vc_val = vc_node.meta.get("val")
if vc_val is not None:
- assert vc_val.dtype in (torch.complex64, torch.complex128), (
- f"view_as_complex output should be complex, got {vc_val.dtype}"
- )
+ assert vc_val.dtype in (
+ torch.complex64,
+ torch.complex128,
+ ), f"view_as_complex output should be complex, got {vc_val.dtype}"
# ===========================================================================
# Test 2 — lowerable ops TRT, unsupported op PyTorch (with complex input),
# lowerable ops TRT again; end-to-end numerical correctness
@@ -219,13 +220,14 @@
gm = _export_and_lower(model, inputs)
for n in gm.graph.nodes:
if n.target == torch.ops.aten.cumsum.default:
vc_val = n.args[0].meta.get("val")
if vc_val is not None:
- assert vc_val.dtype in (torch.complex64, torch.complex128), (
- f"cumsum should receive a complex tensor, got {vc_val.dtype}"
- )
+ assert vc_val.dtype in (
+ torch.complex64,
+ torch.complex128,
+ ), f"cumsum should receive a complex tensor, got {vc_val.dtype}"
break
# End-to-end: compile and verify numerical correctness
ep = torch.export.export(model, inputs)
trt_model = torchtrt.dynamo.compile(
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_complex_rewrite.py 2026-03-12 20:52:03.536666+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_complex_rewrite.py 2026-03-12 20:52:27.977057+00:00
@@ -60,12 +60,15 @@
def _real_inputs(inputs: Tuple[Any, ...]) -> Tuple[Any, ...]:
"""Convert complex tensors to [..., 2] real layout."""
return tuple(
- torch.view_as_real(x).contiguous() if isinstance(x, torch.Tensor) and x.is_complex()
- else x
+ (
+ torch.view_as_real(x).contiguous()
+ if isinstance(x, torch.Tensor) and x.is_complex()
+ else x
+ )
for x in inputs
)
def _to_complex_if_needed(out: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
257fae2 to
68523e2
Compare
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py 2026-03-16 19:08:52.549532+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/utils.py 2026-03-16 19:09:13.244259+00:00
@@ -836,13 +836,13 @@
Copy the metadata from anchor node to the replacement node. This should be used
if the anchor node is replaced with only a single replacement node i.e one-one replacement.
"""
for match_and_replacement in match_and_replacements:
anchor_node = match_and_replacement.nodes_map[match_and_replacement.anchor]
- assert len(match_and_replacement.replacements) == 1, (
- "Found more than 1 replacements for the anchor node."
- )
+ assert (
+ len(match_and_replacement.replacements) == 1
+ ), "Found more than 1 replacements for the anchor node."
replacement_node = match_and_replacement.replacements[0]
replacement_node.meta = anchor_node.meta
def flatten_nodes(nodes: Any) -> List[torch.fx.node.Node]:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/llm/test_llm_models.py 2026-03-16 19:08:52.575235+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/llm/test_llm_models.py 2026-03-16 19:09:14.613988+00:00
@@ -5,10 +5,11 @@
import pytest
import torch
import torch_tensorrt
import importlib
+
if importlib.util.find_spec("transformers"):
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../tools/llm"))
import argparse
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/hlo/test_complex_graph_break.py 2026-03-16 19:08:52.575061+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/hlo/test_complex_graph_break.py 2026-03-16 19:09:14.625039+00:00
@@ -121,48 +121,49 @@
nodes_by_target: dict = {}
for n in gm.graph.nodes:
nodes_by_target.setdefault(n.target, []).append(n)
# view_as_complex must be present (inserted by the fallback wrapper)
- assert torch.ops.aten.view_as_complex.default in nodes_by_target, (
- "Expected view_as_complex to be inserted before cumsum, but it was not found"
- )
+ assert (
+ torch.ops.aten.view_as_complex.default in nodes_by_target
+ ), "Expected view_as_complex to be inserted before cumsum, but it was not found"
# cumsum must still be present (it was NOT removed)
- assert torch.ops.aten.cumsum.default in nodes_by_target, (
- "cumsum should remain in the graph (runs as PyTorch fallback)"
- )
+ assert (
+ torch.ops.aten.cumsum.default in nodes_by_target
+ ), "cumsum should remain in the graph (runs as PyTorch fallback)"
# The view_as_complex output feeds directly into cumsum
vc_node = nodes_by_target[torch.ops.aten.view_as_complex.default][0]
cumsum_node = nodes_by_target[torch.ops.aten.cumsum.default][0]
- assert cumsum_node.args[0] is vc_node, (
- f"cumsum's first arg should be the view_as_complex node, got {cumsum_node.args[0]}"
- )
+ assert (
+ cumsum_node.args[0] is vc_node
+ ), f"cumsum's first arg should be the view_as_complex node, got {cumsum_node.args[0]}"
# The view_as_complex input is a real-layout (is_complex_layout) node
vc_input = vc_node.args[0]
assert isinstance(vc_input, torch.fx.Node), "view_as_complex input must be a Node"
- assert vc_input.meta.get("is_complex_layout", False), (
- "view_as_complex input should be a real-layout complex node (is_complex_layout=True)"
- )
+ assert vc_input.meta.get(
+ "is_complex_layout", False
+ ), "view_as_complex input should be a real-layout complex node (is_complex_layout=True)"
# view_as_real must follow cumsum
- assert torch.ops.aten.view_as_real.default in nodes_by_target, (
- "Expected view_as_real to be inserted after cumsum, but it was not found"
- )
+ assert (
+ torch.ops.aten.view_as_real.default in nodes_by_target
+ ), "Expected view_as_real to be inserted after cumsum, but it was not found"
vr_node = nodes_by_target[torch.ops.aten.view_as_real.default][0]
- assert vr_node.args[0] is cumsum_node, (
- f"view_as_real's arg should be the cumsum node, got {vr_node.args[0]}"
- )
+ assert (
+ vr_node.args[0] is cumsum_node
+ ), f"view_as_real's arg should be the cumsum node, got {vr_node.args[0]}"
# After metadata propagation, cumsum receives a complex-dtype tensor
vc_val = vc_node.meta.get("val")
if vc_val is not None:
- assert vc_val.dtype in (torch.complex64, torch.complex128), (
- f"view_as_complex output should be complex, got {vc_val.dtype}"
- )
+ assert vc_val.dtype in (
+ torch.complex64,
+ torch.complex128,
+ ), f"view_as_complex output should be complex, got {vc_val.dtype}"
# ===========================================================================
# Test 2 — lowerable ops TRT, unsupported op PyTorch (with complex input),
# lowerable ops TRT again; end-to-end numerical correctness
@@ -219,13 +220,14 @@
gm = _export_and_lower(model, inputs)
for n in gm.graph.nodes:
if n.target == torch.ops.aten.cumsum.default:
vc_val = n.args[0].meta.get("val")
if vc_val is not None:
- assert vc_val.dtype in (torch.complex64, torch.complex128), (
- f"cumsum should receive a complex tensor, got {vc_val.dtype}"
- )
+ assert vc_val.dtype in (
+ torch.complex64,
+ torch.complex128,
+ ), f"cumsum should receive a complex tensor, got {vc_val.dtype}"
break
# End-to-end: compile and verify numerical correctness
ep = torch.export.export(model, inputs)
trt_model = torchtrt.dynamo.compile(
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_complex_rewrite.py 2026-03-16 19:08:52.575618+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/lowering/test_complex_rewrite.py 2026-03-16 19:09:15.729911+00:00
@@ -60,12 +60,15 @@
def _real_inputs(inputs: Tuple[Any, ...]) -> Tuple[Any, ...]:
"""Convert complex tensors to [..., 2] real layout."""
return tuple(
- torch.view_as_real(x).contiguous() if isinstance(x, torch.Tensor) and x.is_complex()
- else x
+ (
+ torch.view_as_real(x).contiguous()
+ if isinstance(x, torch.Tensor) and x.is_complex()
+ else x
+ )
for x in inputs
)
def _to_complex_if_needed(out: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py 2026-03-16 19:08:52.576235+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models_export.py 2026-03-16 19:09:16.197167+00:00
@@ -51,10 +51,11 @@
msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
# Clean up model env
torch._dynamo.reset()
+
@unittest.skipIf(
not importlib.util.find_spec("torchvision"), "torchvision not installed"
)
@pytest.mark.unit
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models.py 2026-03-16 19:08:52.576235+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/models/test_models.py 2026-03-16 19:09:16.416436+00:00
@@ -332,10 +332,11 @@
msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
# Clean up model env
torch._dynamo.reset()
+
@unittest.skipIf(
not importlib.util.find_spec("transformers"), "torchvision not installed"
)
@pytest.mark.unit
68523e2 to
4871e1a
Compare
…mplex numerics, including complex tensor I/O. Introduce a new infrastructure in the replace-complex pass to handle a number of cases where simply unpacking complex tensors is not sufficient for supporting the numerics correctly. This pass also now captures metadata about the original call signature, so that during graph construction the original calling convention is preserved and the runtimes do not need any specialization to support complex types.
…ment that marks nodes that are complex
… pytorch rather than fail to build
4871e1a to
00d0389
Compare
There was a problem hiding this comment.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_bool_split_aten.py 2026-03-18 17:11:47.969832+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_bool_split_aten.py 2026-03-18 17:12:09.444254+00:00
@@ -4,10 +4,11 @@
1. `index_has_bool_indices` validator correctly distinguishes bool vs int indices.
2. Integer-indexed `aten.index.Tensor` routes to the converter WITHOUT output allocator.
3. Boolean-indexed `aten.index.Tensor` routes to the converter WITH output allocator.
4. Both paths produce correct results.
"""
+
import unittest
from unittest.mock import MagicMock
import torch
import torch.nn as nn
@@ -58,13 +59,11 @@
node = _make_index_node([None, torch.tensor([True, False])])
self.assertTrue(index_has_bool_indices(node))
def test_mixed_int_and_bool_returns_true(self):
"""If any index is bool, the function should return True."""
- node = _make_index_node(
- [torch.tensor([0, 1]), torch.tensor([True, False])]
- )
+ node = _make_index_node([torch.tensor([0, 1]), torch.tensor([True, False])])
self.assertTrue(index_has_bool_indices(node))
def test_all_none_returns_false(self):
node = _make_index_node([None, None])
self.assertFalse(index_has_bool_indices(node))
Description
Adding comprehensive support for decomposing complex numerics and running computation through Torch-TensorRT
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: