From 844b0ad0ffced2c4ba7ec45d291f0c1cd9fb8a97 Mon Sep 17 00:00:00 2001 From: "changjun.lee" Date: Mon, 9 Jun 2025 12:55:20 +0900 Subject: [PATCH 1/2] fix: replace add_identity by add_cast for type cast --- py/torch_tensorrt/fx/converters/converter_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 510d4ef69b..78ea125424 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -909,7 +909,6 @@ def type_cast( """ This function helps to cast the input type to cast_type """ - layer_i = network.add_identity(input) - layer_i.set_output_type(0, cast_type) + layer_i = network.add_cast(input, cast_type) set_layer_name(layer_i, target, f"{name}_dtype_change") return layer_i.get_output(0) From cc98fec4e3594425c5af7626bd3c9d5778aa7091 Mon Sep 17 00:00:00 2001 From: "changjun.lee" Date: Tue, 10 Jun 2025 09:32:45 +0900 Subject: [PATCH 2/2] fix: use dynamo path for conversion utils instead of fx --- .../dynamo/conversion/converter_utils.py | 9 +++++---- py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py | 11 +++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 2df2f0f31b..26c5bb5126 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -8,10 +8,11 @@ import numpy as np import tensorrt as trt import torch -import torch_tensorrt.dynamo.conversion.impl as impl from torch.fx.experimental.proxy_tensor import unset_fake_temporarily from torch.fx.node import Argument, Target from torch.fx.passes.shape_prop import TensorMetadata + +import torch_tensorrt.dynamo.conversion.impl as impl from torch_tensorrt import _enums from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo._SourceIR import SourceIR @@ -141,9 +142,9 @@ def cast_trt_tensor( ) -> TRTTensor: """Given a TRT Tensor, convert that Tensor to the specified dtype - Adds an Identity layer to the network which performs the conversion - if the input's dtype is different from the cast type. Otherwise returns - input unchanged + Adds a Cast layer to the network to convert the input tensor to the specified dtype. + If the input tensor already has the desired dtype, it is returned unchanged. + Otherwise, a Cast layer is added to perform the conversion Args: ctx (ConversionContext): A ConversionContext containing the TensorRT network diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py index 990b01eb70..203bb03553 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -5,6 +5,7 @@ import numpy as np import tensorrt as trt from torch.fx.node import Target + from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -13,6 +14,9 @@ flatten_dims, get_positive_dim, get_trt_tensor, + has_dynamic_shape, + prepend_ones, + set_layer_name, ) from torch_tensorrt.dynamo.conversion.impl.cat import cat from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide @@ -23,11 +27,6 @@ from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape from torch_tensorrt.dynamo.conversion.impl.slice.base import slice from torch_tensorrt.dynamo.utils import DYNAMIC_DIM -from torch_tensorrt.fx.converters.converter_utils import ( - has_dynamic_shape, - prepend_ones, - set_layer_name, -) from torch_tensorrt.fx.types import Shape, TRTTensor @@ -230,7 +229,7 @@ def expand( # If the rank of the input tensor is less than the shape's rank, pad with ones if initial_tensor_rank < shape_rank: input_t = prepend_ones( - ctx.net, + ctx, input_t, name + "_expand_broadcast", shape_rank - initial_tensor_rank,