
Commit 14a244a

shanjiaz, mgoin, and kylesayrs authored
Support DeepSeekV3-style block FP8 quantization (clean) (#1675)
SUMMARY: Fixes [#1475](#1475). This was originally PR [#1607](#1607), but the commit history got messy. I cherry-picked Michael's original commit 451219a and updated from there. TEST PLAN: Tested locally and generated the model. --------- Signed-off-by: mgoin <michael@neuralmagic.com> Signed-off-by: shanjiaz <zsjwpianpian@gmail.com> Co-authored-by: mgoin <michael@neuralmagic.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 0851638 commit 14a244a

File tree: 5 files changed, +119 −11 lines changed


docs/guides/compression_schemes.md

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,9 @@ PTQ is performed to reduce the precision of quantizable weights (e.g., linear la
 - Useful for speed ups in high QPS regimes or offline serving on vLLM.
 - Recommended for NVIDIA GPUs with compute capability >=9.0 (Hopper and Blackwell).

+### [W8A8-FP8_BLOCK](../examples/quantization_w8a8_fp8/fp8_block_example.py)
+- Uses block-wise quantization to compress weights to FP8 (commonly in 128×128 tiles) and dynamic per-token-group (128) quantization for activations. Does not require a calibration dataset. Activation quantization is carried out during inference on vLLM.
+
 ## Sparsification
 Sparsification reduces model complexity by pruning selected weight values to zero while retaining essential weights in a subset of parameters. Supported formats include:
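For intuition (not part of the diff): a minimal sketch of the bookkeeping this scheme implies, assuming a hypothetical 4096×11008 Linear weight and the default 128×128 block size; each tile gets one FP8 weight scale.

import math

import torch

# Hypothetical weight shape, chosen for illustration only.
rows, cols = 4096, 11008
block_rows, block_cols = 128, 128

# One scale per 128x128 tile of the weight matrix.
num_block_rows = math.ceil(rows / block_rows)  # 32
num_block_cols = math.ceil(cols / block_cols)  # 86
weight_scales = torch.empty(num_block_rows, num_block_cols)
print(weight_scales.shape)  # torch.Size([32, 86])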

examples/quantization_w8a8_fp8/fp8_block_example.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from llmcompressor import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+MODEL_ID = "Qwen/Qwen3-0.6B"
+
+# Load model.
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_ID, device_map="auto", torch_dtype="auto"
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+# Configure the quantization algorithm and scheme.
+# In this case, we:
+# * quantize the weights to fp8 in 128x128 blocks via ptq
+# * quantize the activations to fp8 with dynamic per-token-group (128) quantization
+recipe = QuantizationModifier(
+    targets="Linear", scheme="FP8_BLOCK", ignore=["lm_head"]
+)
+
+# Apply quantization.
+oneshot(model=model, recipe=recipe)
+
+# Confirm generations of the quantized model look sane.
+print("========== SAMPLE GENERATION ==============")
+input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
+output = model.generate(input_ids, max_new_tokens=20)
+print(tokenizer.decode(output[0]))
+print("==========================================")
+
+# Save to disk in compressed-tensors format.
+SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-BLOCK"
+model.save_pretrained(SAVE_DIR)
+tokenizer.save_pretrained(SAVE_DIR)
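As a usage follow-up (not part of the diff), a minimal sketch of loading the saved checkpoint for inference, assuming a vLLM build that supports the compressed-tensors block-FP8 format and that the directory name matches SAVE_DIR above:

from vllm import LLM, SamplingParams

# Local directory produced by the example above (hypothetical path).
llm = LLM(model="Qwen3-0.6B-FP8-BLOCK")
outputs = llm.generate(["Hello my name is"], SamplingParams(max_tokens=20))
print(outputs[0].outputs[0].text)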

src/llmcompressor/modifiers/quantization/calibration.py

Lines changed: 9 additions & 6 deletions
@@ -10,7 +10,7 @@
 )
 from compressed_tensors.quantization.lifecycle.forward import forward_quantize
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils import align_module_device, update_parameter_data
+from compressed_tensors.utils import align_module_device, update_offload_parameter
 from loguru import logger
 from torch.nn import Module

@@ -116,16 +116,19 @@ def call_observer(
                 value,
                 should_calculate_gparam=True,
             )
-            update_parameter_data(module, global_scale, f"{base_name}_global_scale")
+            update_offload_parameter(module, f"{base_name}_global_scale", global_scale)
         else:
             global_scale = getattr(module, f"{base_name}_global_scale", None)

         if should_calculate_qparams:
             updated_scale, updated_zero_point = observer(
                 value, g_idx=g_idx, global_scale=global_scale
             )
-            update_parameter_data(module, updated_scale, f"{base_name}_scale")
-            update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+            # register or update scale & zero_point parameters (supports block shapes)
+            scale_name = f"{base_name}_scale"
+            zp_name = f"{base_name}_zero_point"
+            update_offload_parameter(module, scale_name, updated_scale)
+            update_offload_parameter(module, zp_name, updated_zero_point)


 def update_weight_global_scale(module: Module):

@@ -256,8 +259,8 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Te
     kv_cache = getattr(module, "kv_cache")
     k_scale = kv_cache.k_scales[module.layer_idx]
     v_scale = kv_cache.v_scales[module.layer_idx]
-    update_parameter_data(module, k_scale, KVCacheScaleType.KEY.value)
-    update_parameter_data(module, v_scale, KVCacheScaleType.VALUE.value)
+    update_offload_parameter(module, KVCacheScaleType.KEY.value, k_scale)
+    update_offload_parameter(module, KVCacheScaleType.VALUE.value, v_scale)


 def initialize_quantized_kv_cache(module: Module):

src/llmcompressor/observers/base.py

Lines changed: 44 additions & 5 deletions
@@ -193,13 +193,52 @@ def get_qparams(
             )

         elif self.quantization_args.strategy == QuantizationStrategy.BLOCK:
-            # TODO (#1475) add support for block-wise quantization
-            raise NotImplementedError(
-                "Block-wise quantization is not yet supported, "
-                "consider group-wise quantization instead. More info at "
-                "https://github.com/vllm-project/llm-compressor/issues/1475"
+            # Block-wise quantization: one scale/zero_point per block of shape
+            # [block_rows, block_cols]
+            rows, cols = observed.shape[:2]
+            bs = self.quantization_args.block_structure
+            if not (
+                isinstance(bs, (list, tuple))
+                and len(bs) == 2
+                and all(isinstance(x, int) for x in bs)
+            ):
+                raise ValueError(
+                    f"Invalid block_structure '{bs}'. "
+                    f"Must be a list of two ints [rows, cols]."
+                )
+            block_rows, block_cols = bs
+            num_br = int(ceil(rows / block_rows))
+            num_bc = int(ceil(cols / block_cols))
+
+            # allocate per-block scale and zero_point
+            self._scale = torch.empty(
+                (num_br, num_bc), dtype=observed.dtype, device=observed.device
+            )
+
+            # Use same dtype logic as GROUP strategy for zero_point
+            if is_fp4(quantization_args=self.quantization_args):
+                zp_dtype = FP8_E4M3_DATA.dtype
+            else:
+                zp_dtype = self.quantization_args.pytorch_dtype()
+
+            self._zero_point = torch.empty(
+                (num_br, num_bc), dtype=zp_dtype, device=observed.device
             )

+            # compute qparams for each block
+            for i in range(num_br):
+                r0 = i * block_rows
+                r1 = min((i + 1) * block_rows, rows)
+                for j in range(num_bc):
+                    c0 = j * block_cols
+                    c1 = min((j + 1) * block_cols, cols)
+                    # reduce across both dims to get one scale and zp per block
+                    scale_bp, zp_bp = self.calculate_qparams(
+                        observed[r0:r1, c0:c1], reduce_dims=(0, 1)
+                    )
+                    self._scale[i, j] = scale_bp
+                    self._zero_point[i, j] = zp_bp
+
         return self._scale, self._zero_point

     def get_qparams_along_dim(
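For illustration (not the library observer), a standalone sketch of what the block loop above computes, assuming symmetric FP8-E4M3 quantization where each tile's scale is its max-abs value divided by the FP8 maximum (448):

from math import ceil

import torch

FP8_E4M3_MAX = 448.0  # torch.finfo(torch.float8_e4m3fn).max


def block_scales(observed: torch.Tensor, block_rows: int = 128, block_cols: int = 128):
    """Return one symmetric FP8 scale per [block_rows, block_cols] tile."""
    rows, cols = observed.shape
    num_br, num_bc = ceil(rows / block_rows), ceil(cols / block_cols)
    scales = torch.empty(num_br, num_bc, dtype=observed.dtype, device=observed.device)
    for i in range(num_br):
        r0, r1 = i * block_rows, min((i + 1) * block_rows, rows)
        for j in range(num_bc):
            c0, c1 = j * block_cols, min((j + 1) * block_cols, cols)
            # max-abs over both dims of the tile -> one scale per block
            scales[i, j] = observed[r0:r1, c0:c1].abs().amax() / FP8_E4M3_MAX
    return scales


weight = torch.randn(300, 200)  # toy shape that is not a multiple of 128
print(block_scales(weight).shape)  # torch.Size([3, 2])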

tests/llmcompressor/modifiers/quantization/test_base.py

Lines changed: 28 additions & 0 deletions
@@ -35,6 +35,34 @@ def q_config_kwargs(config_0, config_1):
 )


+@pytest.fixture
+def block_q_config_kwargs():
+    return dict(
+        config_groups=dict(
+            group_block=dict(
+                targets=["Linear"],
+                input_activations=dict(
+                    num_bits=8, symmetric=True, strategy="group", group_size=128
+                ),
+                weights=dict(
+                    num_bits=8,
+                    symmetric=True,
+                    strategy="block",
+                    block_structure=[128, 128],
+                ),
+            ),
+        )
+    )
+
+
+def test_block_strategy_parsing(block_q_config_kwargs):
+    modifier = GPTQModifier(**block_q_config_kwargs)
+    resolved = modifier.resolve_quantization_config()
+    w_scheme = resolved.config_groups["group_block"].weights
+    assert w_scheme.strategy == "block"
+    assert w_scheme.block_structure == [128, 128]
+
+
 @pytest.mark.parametrize(
     "has_actorder,actorder,config_0,config_1,expected_0,expected_1",
     [
