fix: replace add_identity by add_cast for type cast #3563

Open. Wants to merge 2 commits into base: main.

9 changes: 5 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py

@@ -8,10 +8,11 @@
 import numpy as np
 import tensorrt as trt
 import torch
-import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
+
+import torch_tensorrt.dynamo.conversion.impl as impl
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR

@@ -141,9 +142,9 @@ def cast_trt_tensor(
 ) -> TRTTensor:
     """Given a TRT Tensor, convert that Tensor to the specified dtype

-    Adds an Identity layer to the network which performs the conversion
-    if the input's dtype is different from the cast type. Otherwise returns
-    input unchanged
+    Adds a Cast layer to the network to convert the input tensor to the specified dtype.
+    If the input tensor already has the desired dtype, it is returned unchanged.
+    Otherwise, a Cast layer is added to perform the conversion

     Args:
         ctx (ConversionContext): A ConversionContext containing the TensorRT network
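For illustration only, here is a minimal sketch of the behavior the updated docstring describes (this is not the repository's cast_trt_tensor, which takes a ConversionContext and additional arguments): return the tensor unchanged when it already has the requested dtype, otherwise add a Cast layer. It assumes a TensorRT version whose Python API provides INetworkDefinition.add_cast, and the helper name is hypothetical.

import tensorrt as trt

def cast_if_needed(
    network: trt.INetworkDefinition,
    tensor: trt.ITensor,
    dtype: trt.DataType,
    name: str,
) -> trt.ITensor:
    # Already the desired dtype: return the input unchanged, no layer added.
    if tensor.dtype == dtype:
        return tensor
    # Otherwise insert a Cast layer that performs the conversion.
    cast_layer = network.add_cast(tensor, dtype)
    cast_layer.name = f"{name}_cast"
    return cast_layer.get_output(0)
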
11 changes: 5 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

@@ -5,6 +5,7 @@
 import numpy as np
 import tensorrt as trt
 from torch.fx.node import Target
+
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext

@@ -13,6 +14,9 @@
     flatten_dims,
     get_positive_dim,
     get_trt_tensor,
+    has_dynamic_shape,
+    prepend_ones,
+    set_layer_name,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide

@@ -23,11 +27,6 @@
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
-from torch_tensorrt.fx.converters.converter_utils import (
-    has_dynamic_shape,
-    prepend_ones,
-    set_layer_name,
-)
 from torch_tensorrt.fx.types import Shape, TRTTensor


@@ -230,7 +229,7 @@ def expand(
     # If the rank of the input tensor is less than the shape's rank, pad with ones
     if initial_tensor_rank < shape_rank:
         input_t = prepend_ones(
-            ctx.net,
+            ctx,
             input_t,
             name + "_expand_broadcast",
             shape_rank - initial_tensor_rank,
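The expand converter above pads a lower-rank input with leading ones so it becomes broadcastable to the target shape. Purely as an illustration of that idea (the real prepend_ones helper operates on TensorRT tensors through the ConversionContext, not on plain shape tuples, and this function name is hypothetical):

def pad_shape_with_leading_ones(shape, target_rank):
    # e.g. (3, 4) with target_rank=4 becomes (1, 1, 3, 4), which broadcasts
    # against any rank-4 shape whose trailing dimensions are compatible.
    return (1,) * (target_rank - len(shape)) + tuple(shape)

assert pad_shape_with_leading_ones((3, 4), 4) == (1, 1, 3, 4)
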
3 changes: 1 addition & 2 deletions py/torch_tensorrt/fx/converters/converter_utils.py

@@ -909,7 +909,6 @@ def type_cast(
"""
This function helps to cast the input type to cast_type
"""
layer_i = network.add_identity(input)
layer_i.set_output_type(0, cast_type)
layer_i = network.add_cast(input, cast_type)
Collaborator: Can you use the cast_trt_tensor function for this instead?

Collaborator: This is a patch for FX, but it looks like cast_trt_tensor is only in the Dynamo path?

Contributor Author: @peri044 As @zewenli98 mentioned, cast_trt_tensor is in the Dynamo path, so using it here would require importing dynamo.conversion.converter_utils in the FX path. Is this what you intended? If not, would you prefer that I implement cast_trt_tensor in the FX path, just like in the Dynamo path, and use it instead of type_cast?

     set_layer_name(layer_i, target, f"{name}_dtype_change")
Collaborator: Thanks for the quick change @junstar92. LGTM as such. Just a minor change: since cast_trt_tensor in py/torch_tensorrt/dynamo/conversion/converter_utils.py now uses a Cast layer and the above change is related to that, you could change its docstring from

Adds an Identity layer to the network which performs the conversion
if the input's dtype is different from the cast type. Otherwise returns
input unchanged

to something like

Adds a Cast layer to the network to convert the input tensor to the specified dtype.
If the input tensor already has the desired dtype, it is returned unchanged.
Otherwise, a Cast layer is added to perform the conversion

Contributor Author: Thanks for the feedback. I updated the docstring of cast_trt_tensor as you suggested.

     return layer_i.get_output(0)
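
To summarize the change, a hedged sketch contrasting the two casting patterns (a standalone illustration, not the repository's code; it assumes a TensorRT version whose Python API provides INetworkDefinition.add_cast, and the function names are hypothetical):

import tensorrt as trt

def cast_via_identity(network: trt.INetworkDefinition, tensor: trt.ITensor, cast_type: trt.DataType) -> trt.ITensor:
    # Old pattern: an Identity layer whose output type is forced to cast_type.
    layer = network.add_identity(tensor)
    layer.set_output_type(0, cast_type)
    return layer.get_output(0)

def cast_via_cast_layer(network: trt.INetworkDefinition, tensor: trt.ITensor, cast_type: trt.DataType) -> trt.ITensor:
    # New pattern: a dedicated Cast layer that states the conversion directly.
    layer = network.add_cast(tensor, cast_type)
    return layer.get_output(0)

A dedicated Cast layer makes the dtype conversion explicit in the network, whereas forcing an Identity layer's output type can depend on the builder's precision settings to take effect.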