Commit b63e06c

fix: Fix a perf regression due to weights being ITensors (#3568)
1 parent 60863a3 commit b63e06c

File tree

6 files changed: +85 −38 lines changed

py/torch_tensorrt/dynamo/_refit.py
py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
py/torch_tensorrt/dynamo/conversion/converter_utils.py
py/torch_tensorrt/dynamo/conversion/impl/conv.py
py/torch_tensorrt/dynamo/conversion/impl/deconv.py
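As the diffs below show, convolution and deconvolution weights were previously turned into network ITensors through get_trt_tensor, and bf16 constants additionally took an fp32 numpy copy plus a cast layer. The fix hands TensorRT the original torch.Tensor buffers as trt.Weights instead. A minimal sketch of that zero-copy wrapping, assuming the tensorrt and torch packages are installed and using a simplified dtype table in place of the repo's _enums.dtype helper:

import tensorrt as trt
import torch

# Simplified stand-in for torch_tensorrt's _enums.dtype mapping (assumption:
# only a few dtypes listed; the real helper covers the full matrix, and
# trt.DataType.BF16 requires a recent TensorRT release).
_TORCH_TO_TRT = {
    torch.float32: trt.DataType.FLOAT,
    torch.float16: trt.DataType.HALF,
    torch.bfloat16: trt.DataType.BF16,
    torch.int32: trt.DataType.INT32,
}

def tensor_as_trt_weights(value: torch.Tensor) -> trt.Weights:
    # trt.Weights requires a single contiguous host buffer.
    value = value.contiguous()
    # The Weights object borrows the raw pointer: no host copy is made, but the
    # tensor must stay alive until the TensorRT engine is built.
    return trt.Weights(_TORCH_TO_TRT[value.dtype], value.data_ptr(), value.nelement())

Because the Weights object only borrows the pointer, the commit also pins each tensor in cpu_weights_reference_holder for the lifetime of the build.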

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def construct_refit_mapping(
         )
     interpreter._construct_trt_network_def()

-    return interpreter.ctx.mapping
+    return interpreter.ctx.weight_refit_map


 @needs_refit

py/torch_tensorrt/dynamo/conversion/_ConversionContext.py

Lines changed: 4 additions & 2 deletions
@@ -15,15 +15,17 @@ class ConversionContext:
         net: TensorRT Network being built
         compilation_settings: Settings selected by the user for compilation
         requires_output_allocator: Boolean flag indicating if the converter creates operators which require an Output Allocator to run (e.g. data dependent operators)
+        weight_refit_map: Dictionary mapping weight names to their corresponding np.array
+        cpu_weights_reference_holder: Dictionary mapping weight names to their corresponding torch.Tensor
     """

     net: TRTNetwork
     compilation_settings: CompilationSettings = field(
         default_factory=CompilationSettings
     )
     requires_output_allocator: bool = False
-    mapping: dict[str, np.array] = field(default_factory=dict)
-    cpu_weights_reference_holder: dict[str, Union[torch.Tensor, np.array]] = field(
+    weight_refit_map: dict[str, np.array] = field(default_factory=dict)
+    cpu_weights_reference_holder: dict[str, Union[torch.Tensor]] = field(
        default_factory=dict
     )

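The rename from mapping to weight_refit_map makes the field's purpose explicit, and cpu_weights_reference_holder now pins the torch.Tensor objects whose memory backs the trt.Weights handed to TensorRT. A hypothetical helper mirroring what create_constant does in this commit (names follow the fields above; ctx is assumed to be a live ConversionContext):

import torch

def record_constant(ctx, name: str, torch_value: torch.Tensor) -> None:
    # Refit still consumes numpy arrays (see the TODO in converter_utils.py),
    # so keep a flattened copy; bf16 has no numpy equivalent and goes via fp32.
    if torch_value.dtype == torch.bfloat16:
        numpy_value = torch_value.to(torch.float32).numpy()
    else:
        numpy_value = torch_value.numpy()
    ctx.weight_refit_map[name + " CONSTANT"] = numpy_value.reshape(-1)
    # trt.Weights only borrows a raw pointer, so hold a reference here to keep
    # the tensor's storage alive for the duration of TRT compilation.
    ctx.cpu_weights_reference_holder[name] = torch_value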

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 1 addition & 1 deletion
@@ -499,7 +499,7 @@ def _save_weight_mapping(self) -> None:
             for k, v in self.module.state_dict().items()
         }
         weight_name_map: dict[str, Any] = {}
-        np_map = self.ctx.mapping
+        np_map = self.ctx.weight_refit_map
         constant_mapping = {k: v for k, v in np_map.items() if v.size == 1}
         net = self.ctx.net
         for i in range(net.num_layers):

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 42 additions & 22 deletions
@@ -320,6 +320,37 @@ def cast_int_or_float_to_bool(
     return tensor


+def to_trt_weights(
+    value: Any, target_quantized_type: Optional[trt.DataType] = None
+) -> trt.Weights:
+    """
+    Convert a PyTorch tensor or NumPy array to TensorRT weights.
+
+    Args:
+        value (Union[torch.Tensor, np.ndarray]): The tensor or array to convert to TRT weights
+
+    Returns:
+        trt.Weights: TensorRT weights object with appropriate data type
+
+    Note:
+        - Input tensors are made contiguous before conversion
+        - Data type is preserved from the original tensor/array
+    """
+    if isinstance(value, torch.Tensor):
+        # Tensor must be contiguous before conversion
+        value = value.contiguous()
+        value_trt_dtype = _enums.dtype._from(value.dtype).to(trt.DataType)
+        return trt.Weights(value_trt_dtype, value.data_ptr(), value.nelement())
+    elif isinstance(value, np.ndarray):
+        value = np.ascontiguousarray(value)
+        value_np_dtype = _enums.dtype._from(value.dtype).to(np.dtype, use_default=True)
+        return trt.Weights(value_np_dtype, value.data, value.size)
+    else:
+        raise AssertionError(
+            f"to_trt_weights can only be called on torch.Tensor or np.ndarray, got an object of type: {type(value)}"
+        )
+
+
 def create_constant(
     ctx: ConversionContext,
     value: Union[int, float, bool, np.ndarray, torch.Tensor],
@@ -363,19 +394,6 @@ def create_constant(
         shape = list(torch_value.shape)

     if torch_value is not None:
-        if torch_value.dtype == torch.float8_e4m3fn:
-            weights = trt.Weights(
-                type=trt.DataType.FP8,
-                ptr=torch_value.data_ptr(),
-                count=torch_value.numel(),
-            )
-            constant = ctx.net.add_constant(
-                shape,
-                weights,
-            )
-            constant.name = name
-            ctx.cpu_weights_reference_holder[name + " FP8_CONSTANT"] = torch_value
-            return constant.get_output(0)

         if torch_value.dtype == torch.uint8:
             if (
@@ -400,25 +418,27 @@ def create_constant(
             ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
             return constant.get_output(0)

+        # TODO: Refit map uses numpy arrays. Remove this once refit is updated to use torch.Tensor
         if torch_value.dtype == torch.bfloat16:
             torch_value_fp32 = torch_value.to(torch.float32)
             numpy_value = torch_value_fp32.numpy()
         else:
             numpy_value = torch_value.numpy()
-        ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
+
+        # Used for refit
+        ctx.weight_refit_map[name + " CONSTANT"] = numpy_value.reshape(-1)
+
+        # This is a buffer to hold the torch.Tensor so that they are alive during the course of TRT compilation.
+        ctx.cpu_weights_reference_holder[name] = torch_value
+
+        # Convert the torch.Tensor to a trt.Weights object
+        trt_weights = to_trt_weights(torch_value)
         constant = ctx.net.add_constant(
             shape,
-            numpy_value,
+            trt_weights,
         )
         constant.name = name

-        if torch_value.dtype == torch.bfloat16:
-            return cast_trt_tensor(
-                ctx,
-                constant.get_output(0),
-                trt.DataType.BF16,
-                name + "_bf16_cast",
-            )
         return constant.get_output(0)
     else:
         raise ValueError(
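A sketch of how a converter would call the new helper, assuming a live ctx with an INetworkDefinition (both hypothetical here); note that the shape is taken from the tensor, since trt.Weights carries only a flat buffer:

import torch
from torch_tensorrt.dynamo.conversion.converter_utils import to_trt_weights

w = torch.randn(64, 32, 3, 3)              # example conv weight, OIHW layout
shape = list(w.shape)                      # capture shape first: trt.Weights is flat
ctx.cpu_weights_reference_holder["w"] = w  # pin the buffer for the build's lifetime
const_layer = ctx.net.add_constant(shape, to_trt_weights(w))
const_layer.name = "w_constant"
w_itensor = const_layer.get_output(0)      # ITensor view, when one is really needed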

py/torch_tensorrt/dynamo/conversion/impl/conv.py

Lines changed: 18 additions & 5 deletions
@@ -13,13 +13,14 @@
     cast_trt_tensor,
     extend_attr_to_tuple,
     get_trt_tensor,
+    has_dynamic_shape,
+    set_layer_name,
     to_torch,
+    to_trt_weights,
 )
 from torch_tensorrt.fx.converters.converter_utils import (
     get_dyn_range,
-    has_dynamic_shape,
     mark_as_int8_layer,
-    set_layer_name,
 )
 from torch_tensorrt.fx.types import TRTTensor

@@ -64,6 +65,8 @@ def convNd(
             f"Convolution {name} has bias of type {type(bias)}, Expected Torch Tensor or TRT Tensor"
         )

+    num_output_maps = 0
+    kernel_shape = ()
     # Process weight terms
     if isinstance(weight, TRTTensor):
         weight = get_trt_tensor(ctx, weight, f"{name}_weight")
@@ -72,23 +75,33 @@ def convNd(
         weight = impl.unsqueeze.unsqueeze(
             ctx, target, source_ir, weight.name + "_unsqueeze_conv1d", weight, -1
         )
+        num_output_maps = weight.shape[0]
+        kernel_shape = weight.shape[2:]
     elif isinstance(weight, (torch.Tensor, np.ndarray)):
         weight = to_torch(weight, dtype=input.dtype)
         # Append new dimension (unsqueeze) if the convolution is 1d
         if is_conv1d:
             weight = torch.unsqueeze(weight, -1)
-        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+
+        num_output_maps = weight.shape[0]
+        kernel_shape = weight.shape[2:]
+        weight = to_trt_weights(weight)

     else:
         raise RuntimeError(
             f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
         )

+    assert (
+        num_output_maps > 0
+    ), "Number of output channels in convolution must be greater than 0"
+    assert len(kernel_shape) > 0, "Convolution kernel shape must be non-empty"
+
     # add conv layer
     conv_layer = ctx.net.add_convolution_nd(
         input=input,
-        num_output_maps=weight.shape[0],
-        kernel_shape=weight.shape[2:],
+        num_output_maps=num_output_maps,
+        kernel_shape=kernel_shape,
         kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
         bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
     )
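Because trt.Weights is a flat, shape-less buffer, the channel count and kernel shape must be read from the torch tensor before to_trt_weights is called; that ordering is why the diff introduces the num_output_maps and kernel_shape locals. A small standalone illustration of the constraint:

import torch
from torch_tensorrt.dynamo.conversion.converter_utils import to_trt_weights

# Convolution weight layout: (out_channels, in_channels // groups, *kernel)
weight = torch.randn(16, 8, 3, 3)
num_output_maps = weight.shape[0]       # read metadata while a shape still exists
kernel_shape = tuple(weight.shape[2:])  # (3, 3)
trt_weights = to_trt_weights(weight)    # flat buffer: no .shape from here on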

py/torch_tensorrt/dynamo/conversion/impl/deconv.py

Lines changed: 19 additions & 7 deletions
@@ -9,14 +9,15 @@
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
     extend_attr_to_tuple,
     get_trt_tensor,
+    has_dynamic_shape,
     to_torch,
+    to_trt_weights,
 )
 from torch_tensorrt.fx.converters.converter_utils import (
-    SourceIR,
     get_dyn_range,
-    has_dynamic_shape,
     mark_as_int8_layer,
     set_layer_name,
 )
@@ -40,6 +41,7 @@ def deconvNd(
     scale: Optional[Union[torch.Tensor, float]] = None,
     zero_point: Optional[Union[torch.Tensor, float]] = None,
 ) -> TRTTensor:
+
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for deconvolution."

@@ -64,32 +66,42 @@ def deconvNd(
         )

     # Process weight terms
+    num_output_maps = 0
+    kernel_shape = ()
     if isinstance(weight, TRTTensor):
         weight = get_trt_tensor(ctx, weight, f"{name}_weight")
         # Append new dimension (unsqueeze) if the deconvolution is 1d
         if is_deconv1d:
             input = impl.unsqueeze.unsqueeze(
                 ctx, target, source_ir, name + "_unsqueeze_weight", weight, -1
             )
+        num_output_maps = weight.shape[1]
+        kernel_shape = weight.shape[2:]

     elif isinstance(weight, (torch.Tensor, np.ndarray)):
         weight = to_torch(weight, dtype=input.dtype)
         # Append new dimension (unsqueeze) if the deconvolution is 1d
         if is_deconv1d:
             weight = torch.unsqueeze(weight, -1)
-
-        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+        num_output_maps = weight.shape[1]
+        kernel_shape = weight.shape[2:]
+        weight = to_trt_weights(weight)

     else:
         raise RuntimeError(
-            f"Convolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
+            f"Deconvolution {name} has weight of type {type(weight)}, Expect Optional[Tensor]"
         )

+    assert (
+        num_output_maps > 0
+    ), "Number of output channels in deconvolution must be greater than 0"
+    assert len(kernel_shape) > 0, "Deconvolution kernel shape must be non-empty"
+
     # add deconv layer
     deconv_layer = ctx.net.add_deconvolution_nd(
         input=input,
-        num_output_maps=weight.shape[1] * groups,
-        kernel_shape=weight.shape[2:],
+        num_output_maps=num_output_maps * groups,
+        kernel_shape=kernel_shape,
         kernel=trt.Weights() if isinstance(weight, TRTTensor) else weight,
         bias=trt.Weights() if isinstance(bias, TRTTensor) else bias,
     )
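Deconvolution differs from convolution in one detail worth calling out: transposed-convolution weights are laid out (in_channels, out_channels // groups, *kernel), so the channel count comes from dim 1 and is scaled back up by groups. A small sketch of that layout assumption:

import torch

groups = 2
# IOHW layout for transposed convolution: dim 0 = in, dim 1 = out // groups
weight = torch.randn(8, 16 // groups, 3, 3)
num_output_maps = weight.shape[1] * groups  # 16 output channels overall
kernel_shape = tuple(weight.shape[2:])      # (3, 3)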
