Skip to content

Commit 5d13e2b

Browse files
committed
kv-cache int8 quant
Signed-off-by: George Ohashi <george@neuralmagic.com>
1 parent 3d19401 commit 5d13e2b

File tree

2 files changed

+10
-3
lines changed

2 files changed: +10 −3 lines

src/llmcompressor/modifiers/quantization/cache.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,15 @@ def _quantize(self, tensor, kv_type, layer_idx):
151151
scales = self.v_scales
152152
zps = self.v_zps
153153

154-
# tensor
155-
scale, zp = observer(tensor)
154+
# note: key, value states are in the shape:
155+
# [batch, num_key_value_heads, seq_len, head_dim]
156+
157+
base_name = None # tensor-wise quantization, shape of [1]
158+
if self.quantization_args.strategy == "channel":
159+
# target last dim to quantize, shape of [head_dim]
160+
base_name = "kv_cache"
161+
162+
scale, zp = observer(tensor, base_name=base_name)
156163
if len(scales) <= layer_idx:
157164
scales.append(scale)
158165
zps.append(zp)

src/llmcompressor/observers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def get_qparams(
128128
self._zero_point[:, group_index] = zero_point.squeeze(1)
129129

130130
elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
131-
if base_name == "output":
131+
if base_name in ("output", "kv_cache"):
132132
# the last dimension is the hidden dimension
133133
# shape of [1,1, num_key_value_heads * head_dim]
134134
scale, zero_point = self.get_qparams_along_dim(

0 commit comments

Comments (0)