From 1375af8ace8cd82653344b87b7315f6b537cad6e Mon Sep 17 00:00:00 2001
From: Yashwant Bezawada
Date: Wed, 5 Nov 2025 20:03:02 -0600
Subject: [PATCH] Fix empty tensor shape issue in DynamicCache for
 torch.compile

Fixes #42027

This commit fixes a regression where torch.cat receives incorrectly
shaped empty tensors during model tracing with torch.compile. The issue
was introduced in commit dc11a3cbb2c, where empty cache tensors were
initialized as 1D tensors with shape [0] using torch.tensor([]). When
these are concatenated with 4D key/value tensors [batch, heads, seq,
dim] along dim=-2, torch.compile's tracing fails.

Changes:
- Modified DynamicLayer.lazy_initialization() to create properly shaped
  4D empty tensors [batch, heads, 0, dim] instead of 1D [0]
- Modified QuantizedLayer.update() to reset the cache with the proper
  4D shape
- Used torch.zeros() with an explicit shape matching the key_states
  dimensions

This ensures torch.cat operations work correctly in both eager and
compiled modes.
---
 src/transformers/cache_utils.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 28f40952f2cd..6043f5ddeb4d 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -91,8 +91,11 @@ class DynamicLayer(CacheLayerMixin):
 
     def lazy_initialization(self, key_states: torch.Tensor):
         self.dtype, self.device = key_states.dtype, key_states.device
-        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
-        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
+        # Initialize with proper 4D shape: [batch_size, num_heads, 0, head_dim]
+        # This ensures torch.cat works correctly in torch.compile mode
+        batch_size, num_heads, _, head_dim = key_states.shape
+        self.keys = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=self.dtype, device=self.device)
+        self.values = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=self.dtype, device=self.device)
         self.is_initialized = True
 
     def update(
@@ -545,8 +548,14 @@ def update(
         if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length:
             self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
             self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
-            self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
-            self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
+            # Reset to proper 4D empty tensors to ensure torch.cat works correctly in torch.compile mode
+            batch_size, num_heads, _, head_dim = key_states.shape
+            self.keys = torch.zeros(
+                (batch_size, num_heads, 0, head_dim), dtype=key_states.dtype, device=key_states.device
+            )
+            self.values = torch.zeros(
+                (batch_size, num_heads, 0, head_dim), dtype=key_states.dtype, device=key_states.device
+            )
         else:
             self.keys = torch.cat([self.keys, key_states], dim=-2)
             self.values = torch.cat([self.values, value_states], dim=-2)
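
Note (illustration only, not part of the patch): below is a minimal
sketch of the failure mode the commit message describes, assuming the
[batch, heads, seq, head_dim] cache layout named above. The tensor
sizes and the step() function are hypothetical, chosen only to mirror
the torch.cat call inside DynamicLayer.update().

    import torch

    def step(cache: torch.Tensor, key_states: torch.Tensor) -> torch.Tensor:
        # Mirrors the cache update: append new keys along the sequence dim.
        return torch.cat([cache, key_states], dim=-2)

    key_states = torch.randn(1, 8, 4, 64)  # [batch, heads, seq, head_dim]

    # Old initialization: a 1D empty tensor of shape [0]. Eager torch.cat
    # has a legacy special case that skips such tensors, but tracing the
    # same call under torch.compile can fail on the rank mismatch.
    old_cache = torch.tensor([])

    # New initialization: a 4D empty tensor with sequence length 0, so
    # the rank and non-concat dims match key_states and there is no
    # special case for the tracer to handle.
    batch, heads, _, head_dim = key_states.shape
    new_cache = torch.zeros((batch, heads, 0, head_dim))

    compiled_step = torch.compile(step)
    out = compiled_step(new_cache, key_states)
    assert out.shape == (1, 8, 4, 64)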