diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py index 975480f390..ec47126c71 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/shuffle.py @@ -25,14 +25,56 @@ def reshape( input: TRTTensor, shape: Sequence[int], ) -> TRTTensor: + # Count dynamic dimensions and check for inferred dimension (-1) + num_dynamic_dims = 0 + has_inferred_dim = False + inferred_dim_index = -1 + + # Create a mutable copy of the shape for modification + new_shape = list(shape) + + # Special case: Handle dynamic shape with inferred dimension (-1) + # This is required for ops like dynamic_block_quantize_op that requires + # dimension to be known at compile time rather than runtime + for i, s in enumerate(new_shape): + if isinstance(s, TRTTensor): + num_dynamic_dims += 1 + elif s == -1: + has_inferred_dim = True + inferred_dim_index = i + + # Only process if we have exactly one dynamic dimension and one inferred dimension + # This is a common pattern in quantization where one dimension is dynamic + # and another needs to be inferred to maintain total element count + if has_inferred_dim and num_dynamic_dims == 1: + # Calculate the inferred dimension size + # Total elements = product of all input dimensions except dynamic shape dim + total_elements = 1 + for s in input.shape: + if s != -1: + total_elements *= s + + # Divide by known dimensions in new_shape to find the inferred dimension + # This ensures the total number of elements remains the same + for s in new_shape: + if isinstance(s, int) and s != -1: + if total_elements % s != 0: + raise ValueError( + f"Cannot infer dimension: {total_elements} elements not divisible by {s}" + ) + total_elements //= s + + # Replace -1 with the calculated inferred dimension + new_shape[inferred_dim_index] = total_elements + layer = ctx.net.add_shuffle(input) - if all(isinstance(s, int) for s in shape): - layer.reshape_dims = tuple(shape) + if all(isinstance(s, int) for s in new_shape): + layer.reshape_dims = tuple(new_shape) else: # Convert all the dimensions to trt Tensors. trt_shape = [] - for i, s in enumerate(shape): + for i, s in enumerate(new_shape): if isinstance(s, TRTTensor): dim_int32 = cast_trt_tensor( ctx,