Skip to content

Commit 5ef34dc

Browse files
author
evian
committed
[KV Cache] fix ci
Signed-off-by: evian <eviantai@u.nus.edu>
1 parent 954009a commit 5ef34dc

File tree

2 files changed

+26
-16
lines changed

2 files changed

+26
-16
lines changed

src/llmcompressor/modifiers/quantization/cache.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,14 @@ def update(
9494
_pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
9595
_pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)
9696

97-
# reshape for per channel scenario
98-
num_heads = key_states.shape[1]
99-
head_dim = key_states.shape[-1]
100-
# from [batch_size, num_heads, seq_len - residual_length, head_dim]
101-
# to [batch_size, seq_len - residual_length, num_heads * head_dim]
102-
key_states = key_states.transpose(1, 2).flatten(2)
103-
value_states = value_states.transpose(1, 2).flatten(2)
97+
if key_states.dim() == 4:
98+
# reshape for per channel scenario
99+
num_heads = key_states.shape[1]
100+
head_dim = key_states.shape[-1]
101+
# from [batch_size, num_heads, seq_len - residual_length, head_dim]
102+
# to [batch_size, seq_len - residual_length, num_heads * head_dim]
103+
key_states = key_states.transpose(1, 2).flatten(2)
104+
value_states = value_states.transpose(1, 2).flatten(2)
104105

105106
q_key_states = self._quantize(
106107
key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
@@ -114,13 +115,18 @@ def update(
114115
q_value_states, KVCacheScaleType.VALUE, layer_idx
115116
)
116117

117-
# reshape for per channel scenario
118-
# from [batch_size, seq_len - residual_length, num_heads * head_dim]
119-
# to [batch_size, num_heads, seq_len - residual_length, head_dim]
120-
qdq_key_states = qdq_key_states.view(
121-
qdq_key_states.shape[0], qdq_key_states.shape[1], num_heads, head_dim).transpose(1, 2)
122-
qdq_value_states = qdq_value_states.view(
123-
qdq_value_states.shape[0], qdq_value_states.shape[1], num_heads, head_dim).transpose(1, 2)
118+
if key_states.dim() == 4:
119+
# reshape for per channel scenario
120+
# from [batch_size, seq_len - residual_length, num_heads * head_dim]
121+
# to [batch_size, num_heads, seq_len - residual_length, head_dim]
122+
qdq_key_states = qdq_key_states.view(
123+
qdq_key_states.shape[0], qdq_key_states.shape[1],
124+
num_heads, head_dim
125+
).transpose(1, 2).contiguous()
126+
qdq_value_states = qdq_value_states.view(
127+
qdq_value_states.shape[0], qdq_value_states.shape[1],
128+
num_heads, head_dim
129+
).transpose(1, 2).contiguous()
124130

125131
keys_to_return, values_to_return = qdq_key_states, qdq_value_states
126132

src/llmcompressor/observers/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,13 @@ def get_qparams(
183183
elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
184184
# 1. dim=2 scenario: in kv cache quant scenario which is
185185
# [batch_size, seq_len - residual_length, num_heads * head_dim]
186-
# 2. dim=0 scenario: assume observed is transposed, because its the output, hence use dim 0
186+
# 2. dim=0 scenario: assume observed is transposed,
187+
# because it's the output, hence use dim 0
187188
dim = 2 if observed.dim() == 3 else 0
188-
self._scale, self._zero_point = self.get_qparams_along_dim(observed, dim)
189+
self._scale, self._zero_point = self.get_qparams_along_dim(
190+
observed,
191+
dim
192+
)
189193

190194
elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
191195
# use dim 1, assume the observed.shape = [batch, token, hidden]

0 commit comments

Comments (0)