Skip to content

Commit 1375af8

Browse files
author
Yashwant Bezawada
committed
Fix empty tensor shape issue in DynamicCache for torch.compile
Fixes #42027 This commit fixes a regression where torch.cat receives incorrectly shaped empty tensors during model tracing with torch.compile. The issue was introduced in commit dc11a3c where empty cache tensors were initialized as 1D tensors with shape [0] using torch.tensor([]). When these are concatenated with 4D key/value tensors [batch, heads, seq, dim] along dim=-2, torch.compile's tracing fails. Changes: - Modified DynamicLayer.lazy_initialization() to create properly shaped 4D empty tensors [batch, heads, 0, dim] instead of 1D [0] - Modified QuantizedLayer.update() to reset cache with proper 4D shape - Used torch.zeros() with explicit shape matching key_states dimensions This ensures torch.cat operations work correctly in both eager and compiled modes.
1 parent bb65d2d commit 1375af8

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

src/transformers/cache_utils.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,11 @@ class DynamicLayer(CacheLayerMixin):
9191

9292
def lazy_initialization(self, key_states: torch.Tensor):
    """Lazily initialize the cache buffers from the first incoming key states.

    Records the dtype and device of ``key_states`` and creates empty key/value
    buffers shaped ``[batch_size, num_heads, 0, head_dim]``. The explicit 4D
    shape (rather than a 1D empty tensor of shape ``[0]``) is required so that
    later ``torch.cat(..., dim=-2)`` calls against 4D key/value states are
    shape-consistent in both eager mode and under ``torch.compile`` tracing.

    Args:
        key_states: Incoming key tensor; assumed to be 4D with layout
            ``[batch_size, num_heads, seq_len, head_dim]``.
    """
    self.dtype, self.device = key_states.dtype, key_states.device
    # Mirror key_states' batch/head/dim sizes with a zero-length sequence axis
    # so concatenation along dim=-2 is valid from the very first update.
    batch_size, num_heads, _, head_dim = key_states.shape
    self.keys = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=self.dtype, device=self.device)
    self.values = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=self.dtype, device=self.device)
    self.is_initialized = True
97100

98101
def update(
@@ -545,8 +548,14 @@ def update(
545548
if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length:
546549
self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
547550
self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
548-
self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
549-
self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
551+
# Reset to proper 4D empty tensors to ensure torch.cat works correctly in torch.compile mode
552+
batch_size, num_heads, _, head_dim = key_states.shape
553+
self.keys = torch.zeros(
554+
(batch_size, num_heads, 0, head_dim), dtype=key_states.dtype, device=key_states.device
555+
)
556+
self.values = torch.zeros(
557+
(batch_size, num_heads, 0, head_dim), dtype=key_states.dtype, device=key_states.device
558+
)
550559
else:
551560
self.keys = torch.cat([self.keys, key_states], dim=-2)
552561
self.values = torch.cat([self.values, value_states], dim=-2)

0 commit comments

Comments
 (0)