File tree Expand file tree Collapse file tree 2 files changed +10
-10
lines changed
py/torch_tensorrt/dynamo/conversion Expand file tree Collapse file tree 2 files changed +10
-10
lines changed Original file line number Diff line number Diff line change 8
8
import numpy as np
9
9
import tensorrt as trt
10
10
import torch
11
- import torch_tensorrt .dynamo .conversion .impl as impl
12
11
from torch .fx .experimental .proxy_tensor import unset_fake_temporarily
13
12
from torch .fx .node import Argument , Target
14
13
from torch .fx .passes .shape_prop import TensorMetadata
14
+
15
+ import torch_tensorrt .dynamo .conversion .impl as impl
15
16
from torch_tensorrt import _enums
16
17
from torch_tensorrt .dynamo ._settings import CompilationSettings
17
18
from torch_tensorrt .dynamo ._SourceIR import SourceIR
@@ -141,9 +142,9 @@ def cast_trt_tensor(
141
142
) -> TRTTensor :
142
143
"""Given a TRT Tensor, convert that Tensor to the specified dtype
143
144
144
- Adds an Identity layer to the network which performs the conversion
145
- if the input's dtype is different from the cast type. Otherwise returns
146
- input unchanged
145
+ Adds a Cast layer to the network to convert the input tensor to the specified dtype.
146
+ If the input tensor already has the desired dtype, it is returned unchanged.
147
+ Otherwise, a Cast layer is added to perform the conversion
147
148
148
149
Args:
149
150
ctx (ConversionContext): A ConversionContext containing the TensorRT network
Original file line number Diff line number Diff line change 5
5
import numpy as np
6
6
import tensorrt as trt
7
7
from torch .fx .node import Target
8
+
8
9
from torch_tensorrt .dynamo ._SourceIR import SourceIR
9
10
from torch_tensorrt .dynamo .conversion import impl
10
11
from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
13
14
flatten_dims ,
14
15
get_positive_dim ,
15
16
get_trt_tensor ,
17
+ has_dynamic_shape ,
18
+ prepend_ones ,
19
+ set_layer_name ,
16
20
)
17
21
from torch_tensorrt .dynamo .conversion .impl .cat import cat
18
22
from torch_tensorrt .dynamo .conversion .impl .elementwise import floor_divide
23
27
from torch_tensorrt .dynamo .conversion .impl .shape import shape as get_shape
24
28
from torch_tensorrt .dynamo .conversion .impl .slice .base import slice
25
29
from torch_tensorrt .dynamo .utils import DYNAMIC_DIM
26
- from torch_tensorrt .fx .converters .converter_utils import (
27
- has_dynamic_shape ,
28
- prepend_ones ,
29
- set_layer_name ,
30
- )
31
30
from torch_tensorrt .fx .types import Shape , TRTTensor
32
31
33
32
@@ -230,7 +229,7 @@ def expand(
230
229
# If the rank of the input tensor is less than the shape's rank, pad with ones
231
230
if initial_tensor_rank < shape_rank :
232
231
input_t = prepend_ones (
233
- ctx . net ,
232
+ ctx ,
234
233
input_t ,
235
234
name + "_expand_broadcast" ,
236
235
shape_rank - initial_tensor_rank ,
You can’t perform that action at this time.
0 commit comments