
Commit 92cdf63

add DeepseekV3 AWQ mapping (#1619)
SUMMARY: Add an AWQ activation-smoothing mapping for `DeepseekV3ForCausalLM`.

TEST PLAN: [examples/quantizing_moe/deepseek_r1_example.py](./examples/quantizing_moe/deepseek_r1_example.py), but with the recipe adapted to use `AWQModifier` instead:

```python
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from llmcompressor.modeling import prepare_for_calibration
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.transformers import oneshot

# Select model and load it.
# This script takes about 48 hours on 1xA100 to complete.
# Future improvements will reduce this runtime (#1561, #1558).
# For DeepSeek-R1, we require a full-precision model in order to properly calibrate.
# `DeepSeek-R1-0528-BF16` is a DeepSeek-V3 FP8 model which has been converted to BF16.
model_id = "unsloth/DeepSeek-R1-0528-BF16"
config = AutoConfig.from_pretrained(model_id)
del config.quantization_config  # fp8 qconfig no longer applies to bf16 model
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype="auto", config=config
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = prepare_for_calibration(model)

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# Since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision.
recipe = AWQModifier(
    targets="Linear", scheme="W4A16", ignore=["lm_head", "re:.*mlp.gate$"]
)

# Apply algorithms.
# Due to the large size of DeepseekV3, we specify sequential targets such that
# only one MLP is loaded into GPU memory at a time.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    sequential_targets=["DeepseekV3Attention", "DeepseekV3MLP"],
)

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```

---------

Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
Co-authored-by: Brian Dellabetta <brian-dellabetta@users.noreply.github.com>
Co-authored-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 9c44074 commit 92cdf63

File tree

3 files changed: +26 -5 lines changed


src/llmcompressor/modeling/llama4.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -1,8 +1,10 @@
 from typing import Tuple
 
 import torch
-from transformers.models import Llama4Config
-from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
+from transformers.models.llama4.configuration_llama4 import (
+    Llama4Config,
+    Llama4TextConfig,
+)
 from transformers.models.llama4.modeling_llama4 import (
     Llama4TextExperts,
     Llama4TextMLP,
```

src/llmcompressor/modifiers/awq/base.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -465,11 +465,13 @@ def _apply_smoothing(self, model: Module) -> None:
             # Calculates the relative magnitude of the weights within
             # each of the quantization groups, and rescales each group
             # individually so that each group has weights on a 0-1 scale.
-            w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6)
+            weight.abs_()
+            weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6)
             # Resizes the rescaled weight matrix back up to its original dimensions
-            w_scale = w_scale.view(org_shape)
+            weight = weight.view(org_shape)
             # Gets the average rescaled magnitude for each output channel
-            w_mean = w_scale.mean(0)
+            w_mean = weight.mean(0)
+            del weight
 
             with calibration_forward_context(model), HooksMixin.disable_hooks():
                 # [STEP 3]: Compute output of module
```
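The `base.py` change above swaps the out-of-place `w_scale` computation for in-place ops, so the per-group rescaling never allocates a second weight-sized tensor, which is presumably the motivation for the rewrite on a model the size of DeepseekV3. A minimal standalone sketch of the same arithmetic (the toy shapes and the grouping step are assumptions for illustration; in the modifier, `weight` is the concatenated balance-layer weight already reshaped into quantization groups):

```python
import torch

# Toy stand-in for a weight tensor that has been reshaped into quantization groups:
# each row is one group of `group_size` elements.
org_shape = (8, 64)  # assumed original (out_features, in_features) shape
group_size = 16
weight = torch.randn(org_shape).view(-1, group_size)

# In-place variant used by the diff: no extra weight-sized allocation.
weight.abs_()                                          # |w|, in place
weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6)  # rescale each group to a 0-1 range
weight = weight.view(org_shape)                        # back to the original layout
w_mean = weight.mean(0)                                # average rescaled magnitude per channel
del weight                                             # release the buffer early

print(w_mean.shape)  # torch.Size([64])
```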

src/llmcompressor/modifiers/awq/mappings.py

Lines changed: 17 additions & 0 deletions
```diff
@@ -116,9 +116,26 @@ class AWQMapping:
     ),
 ]
 
+# DeepseekV3
+_deepseek_mappings = [
+    AWQMapping(
+        "re:.*input_layernorm$",
+        # Some models use q_proj instead of q_a_proj
+        ["re:.*(q|q_a)_proj$", "re:.*kv_a_proj_with_mqa$"],
+    ),
+    AWQMapping("re:.*q_a_layernorm$", ["re:.*q_b_proj$"]),
+    AWQMapping("re:.*kv_a_layernorm$", ["re:.*kv_b_proj$"]),
+    AWQMapping(
+        "re:.*post_attention_layernorm$",
+        ["re:.*gate_proj$", "re:.*up_proj$"],
+    ),
+    AWQMapping("re:.*up_proj$", ["re:.*down_proj$"]),
+]
+
 AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = {
     "CohereForCausalLM": _cohere_mappings,
     "Cohere2ForCausalLM": _cohere_mappings,
+    "DeepseekV3ForCausalLM": _deepseek_mappings,
     "Gemma2ForCausalLM": _gemma_mappings,
     "Gemma3ForCausalLM": _gemma_mappings,
     "Gemma3ForConditionalGeneration": _gemma_mappings,
```
