@@ -94,13 +94,14 @@ def update(
94
94
_pad_and_append_at_idx_ (self .k_observers , layer_idx , k_observer )
95
95
_pad_and_append_at_idx_ (self .v_observers , layer_idx , v_observer )
96
96
97
- # reshape for per channel scenario
98
- num_heads = key_states .shape [1 ]
99
- head_dim = key_states .shape [- 1 ]
100
- # from [batch_size, num_heads, seq_len - residual_length, head_dim]
101
- # to [batch_size, seq_len - residual_length, num_heads * head_dim]
102
- key_states = key_states .transpose (1 , 2 ).flatten (2 )
103
- value_states = value_states .transpose (1 , 2 ).flatten (2 )
97
+ if key_states .dim () == 4 :
98
+ # reshape for per channel scenario
99
+ num_heads = key_states .shape [1 ]
100
+ head_dim = key_states .shape [- 1 ]
101
+ # from [batch_size, num_heads, seq_len - residual_length, head_dim]
102
+ # to [batch_size, seq_len - residual_length, num_heads * head_dim]
103
+ key_states = key_states .transpose (1 , 2 ).flatten (2 )
104
+ value_states = value_states .transpose (1 , 2 ).flatten (2 )
104
105
105
106
q_key_states = self ._quantize (
106
107
key_states .contiguous (), KVCacheScaleType .KEY , layer_idx
@@ -114,13 +115,18 @@ def update(
114
115
q_value_states , KVCacheScaleType .VALUE , layer_idx
115
116
)
116
117
117
- # reshape for per channel scenario
118
- # from [batch_size, seq_len - residual_length, num_heads * head_dim]
119
- # to [batch_size, num_heads, seq_len - residual_length, head_dim]
120
- qdq_key_states = qdq_key_states .view (
121
- qdq_key_states .shape [0 ], qdq_key_states .shape [1 ], num_heads , head_dim ).transpose (1 , 2 )
122
- qdq_value_states = qdq_value_states .view (
123
- qdq_value_states .shape [0 ], qdq_value_states .shape [1 ], num_heads , head_dim ).transpose (1 , 2 )
118
+ if key_states .dim () == 4 :
119
+ # reshape for per channel scenario
120
+ # from [batch_size, seq_len - residual_length, num_heads * head_dim]
121
+ # to [batch_size, num_heads, seq_len - residual_length, head_dim]
122
+ qdq_key_states = qdq_key_states .view (
123
+ qdq_key_states .shape [0 ], qdq_key_states .shape [1 ],
124
+ num_heads , head_dim
125
+ ).transpose (1 , 2 ).contiguous ()
126
+ qdq_value_states = qdq_value_states .view (
127
+ qdq_value_states .shape [0 ], qdq_value_states .shape [1 ],
128
+ num_heads , head_dim
129
+ ).transpose (1 , 2 ).contiguous ()
124
130
125
131
keys_to_return , values_to_return = qdq_key_states , qdq_value_states
126
132
0 commit comments