Skip to content

Commit cc98fec

Browse files
committed
fix: use dynamo path for conversion utils instead of fx
1 parent 844b0ad commit cc98fec

File tree

2 files changed

+10
-10
lines changed

2 files changed

+10
-10
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
import numpy as np
99
import tensorrt as trt
1010
import torch
11-
import torch_tensorrt.dynamo.conversion.impl as impl
1211
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
1312
from torch.fx.node import Argument, Target
1413
from torch.fx.passes.shape_prop import TensorMetadata
14+
15+
import torch_tensorrt.dynamo.conversion.impl as impl
1516
from torch_tensorrt import _enums
1617
from torch_tensorrt.dynamo._settings import CompilationSettings
1718
from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -141,9 +142,9 @@ def cast_trt_tensor(
141142
) -> TRTTensor:
142143
"""Given a TRT Tensor, convert that Tensor to the specified dtype
143144
144-
Adds an Identity layer to the network which performs the conversion
145-
if the input's dtype is different from the cast type. Otherwise returns
146-
input unchanged
145+
Adds a Cast layer to the network to convert the input tensor to the specified dtype.
146+
If the input tensor already has the desired dtype, it is returned unchanged.
147+
Otherwise, a Cast layer is added to perform the conversion
147148
148149
Args:
149150
ctx (ConversionContext): A ConversionContext containing the TensorRT network

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import tensorrt as trt
77
from torch.fx.node import Target
8+
89
from torch_tensorrt.dynamo._SourceIR import SourceIR
910
from torch_tensorrt.dynamo.conversion import impl
1011
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -13,6 +14,9 @@
1314
flatten_dims,
1415
get_positive_dim,
1516
get_trt_tensor,
17+
has_dynamic_shape,
18+
prepend_ones,
19+
set_layer_name,
1620
)
1721
from torch_tensorrt.dynamo.conversion.impl.cat import cat
1822
from torch_tensorrt.dynamo.conversion.impl.elementwise import floor_divide
@@ -23,11 +27,6 @@
2327
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
2428
from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
2529
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
26-
from torch_tensorrt.fx.converters.converter_utils import (
27-
has_dynamic_shape,
28-
prepend_ones,
29-
set_layer_name,
30-
)
3130
from torch_tensorrt.fx.types import Shape, TRTTensor
3231

3332

@@ -230,7 +229,7 @@ def expand(
230229
# If the rank of the input tensor is less than the shape's rank, pad with ones
231230
if initial_tensor_rank < shape_rank:
232231
input_t = prepend_ones(
233-
ctx.net,
232+
ctx,
234233
input_t,
235234
name + "_expand_broadcast",
236235
shape_rank - initial_tensor_rank,

0 commit comments

Comments (0)