Skip to content

fix: Inferred dimensions at build time in reshape #3746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,56 @@ def reshape(
input: TRTTensor,
shape: Sequence[int],
) -> TRTTensor:
# Count dynamic dimensions and check for inferred dimension (-1)
num_dynamic_dims = 0
has_inferred_dim = False
inferred_dim_index = -1

# Create a mutable copy of the shape for modification
new_shape = list(shape)

# Special case: Handle dynamic shape with inferred dimension (-1)
# This is required for ops like dynamic_block_quantize_op that requires
# dimension to be known at compile time rather than runtime
for i, s in enumerate(new_shape):
if isinstance(s, TRTTensor):
num_dynamic_dims += 1
elif s == -1:
has_inferred_dim = True
inferred_dim_index = i

# Only process if we have exactly one dynamic dimension and one inferred dimension
# This is a common pattern in quantization where one dimension is dynamic
# and another needs to be inferred to maintain total element count
if has_inferred_dim and num_dynamic_dims == 1:
# Calculate the inferred dimension size
# Total elements = product of all input dimensions except dynamic shape dim
total_elements = 1
for s in input.shape:
if s != -1:
total_elements *= s

# Divide by known dimensions in new_shape to find the inferred dimension
# This ensures the total number of elements remains the same
for s in new_shape:
if isinstance(s, int) and s != -1:
if total_elements % s != 0:
raise ValueError(
f"Cannot infer dimension: {total_elements} elements not divisible by {s}"
)
total_elements //= s

# Replace -1 with the calculated inferred dimension
new_shape[inferred_dim_index] = total_elements

layer = ctx.net.add_shuffle(input)
if all(isinstance(s, int) for s in shape):
layer.reshape_dims = tuple(shape)
if all(isinstance(s, int) for s in new_shape):
layer.reshape_dims = tuple(new_shape)
else:
# Convert all the dimensions to trt Tensors.
trt_shape = []

for i, s in enumerate(shape):
for i, s in enumerate(new_shape):
if isinstance(s, TRTTensor):
dim_int32 = cast_trt_tensor(
ctx,
Expand Down
Loading