Commit 8d0a39a

support group-wise weight quant for qwen2 and change cpu kernel to gpu kernel
1 parent e16f4e2 · commit 8d0a39a

File tree: 2 files changed, +61 -45 lines

paddlenlp/experimental/transformers/deepseek_v2/modeling.py

Lines changed: 45 additions & 41 deletions
@@ -696,16 +696,18 @@ def set_state_dict(self, state_dict):
                 ).cast(dtype)
 
                 if self.use_weight_only:
-                    (
-                        self.transformer_block.q_a_proj_weights[idx],
-                        self.transformer_block.q_a_proj_weights_scale[idx],
-                    ) = weight_quantize(q_a_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size)
-
-                    (
-                        self.transformer_block.q_b_proj_weights[idx],
-                        self.transformer_block.q_b_proj_weights_scale[idx],
-                    ) = weight_quantize(q_b_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size)
+                    q_a_proj_quanted_weight, q_a_proj_weight_scale = weight_quantize(
+                        q_a_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                    )
+                    self.transformer_block.q_a_proj_weights[idx].set_value(q_a_proj_quanted_weight)
+                    self.transformer_block.q_a_proj_weights_scale[idx].set_value(q_a_proj_weight_scale)
+
+                    q_b_proj_quanted_weight, q_b_proj_weight_scale = weight_quantize(
+                        q_b_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                    )
+                    self.transformer_block.q_b_proj_weights[idx].set_value(q_b_proj_quanted_weight)
                     self.transformer_block.q_a_layernorm_weights[idx].set_value(q_a_layernorm_weight)
+                    self.transformer_block.q_b_proj_weights_scale[idx].set_value(q_b_proj_weight_scale)
                 elif "fp8" in self.quant_type:
                     q_a_proj_quanted_weight = (
                         paddle.to_tensor(
@@ -752,10 +754,11 @@ def set_state_dict(self, state_dict):
                 ).cast(dtype)
 
                 if self.use_weight_only:
-                    (
-                        self.transformer_block.q_proj_weights[idx],
-                        self.transformer_block.q_proj_weights_scale[idx],
-                    ) = weight_quantize(q_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size)
+                    q_proj_quanted_weight, q_proj_weight_scale = weight_quantize(
+                        q_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                    )
+                    self.transformer_block.q_proj_weights[idx].set_value(q_proj_quanted_weight)
+                    self.transformer_block.q_proj_weights_scale[idx].set_value(q_proj_weight_scale)
                 elif "fp8" in self.quant_type:
                     q_proj_quanted_weight = (
                         paddle.to_tensor(state_dict[f"{self.base_model_prefix}.layers.{idx}.self_attn.q_proj.weight"])
@@ -822,18 +825,18 @@ def set_state_dict(self, state_dict):
            self.transformer_block.v_b_proj_weights[idx].set_value(wv_b)
 
            if self.use_weight_only:
-                (
-                    self.transformer_block.kv_a_proj_with_mqa_weights[idx],
-                    self.transformer_block.kv_a_proj_with_mqa_weights_scale[idx],
-                ) = weight_quantize(
+                kv_a_proj_with_mqa_quanted_weight, kv_a_proj_with_mqa_weight_scale = weight_quantize(
                     kv_a_proj_with_mqa_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
                 )
+                self.transformer_block.kv_a_proj_with_mqa_weights[idx].set_value(kv_a_proj_with_mqa_quanted_weight)
+                self.transformer_block.kv_a_proj_with_mqa_weights_scale[idx].set_value(kv_a_proj_with_mqa_weight_scale)
 
-                (
-                    self.transformer_block.kv_b_proj_weights[idx],
-                    self.transformer_block.kv_b_proj_weights_scale[idx],
-                ) = weight_quantize(kv_b_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size)
+                kv_b_proj_quanted_weight, kv_b_proj_weight_scale = weight_quantize(
+                    kv_b_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                )
+                self.transformer_block.kv_b_proj_weights[idx].set_value(kv_b_proj_quanted_weight)
                 self.transformer_block.kv_a_layernorm_weights[idx].set_value(kv_a_layernorm_weight)
+                self.transformer_block.kv_b_proj_weights_scale[idx].set_value(kv_b_proj_weight_scale)
            elif "fp8" in self.quant_type:
                kv_a_proj_with_mqa_quanted_weight = (
                    paddle.to_tensor(
@@ -876,10 +879,11 @@ def set_state_dict(self, state_dict):
                self.transformer_block.kv_b_proj_weights[idx].set_value(kv_b_proj_weight)
 
            if self.use_weight_only:
-                (
-                    self.transformer_block.linear_weights[idx],
-                    self.transformer_block.linear_weights_scale[idx],
-                ) = weight_quantize(linear_weight, algo=self.quant_algo, group_size=self.weightonly_group_size)
+                linear_quanted_weight, linear_weight_scale = weight_quantize(
+                    linear_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                )
+                self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight)
+                self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale)
            elif "fp8" in self.quant_type:
                linear_quanted_weight = (
                    paddle.to_tensor(state_dict[f"{self.base_model_prefix}.layers.{idx}.self_attn.o_proj.weight"])
@@ -915,12 +919,11 @@ def set_state_dict(self, state_dict):
                ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight).cast(paddle.get_default_dtype())
 
                if self.use_weight_only:
-                    (
-                        self.transformer_block.ffn1_weights[idx],
-                        self.transformer_block.ffn1_weights_scale[idx],
-                    ) = weight_quantize(
+                    ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
                         ffn1_weight_tensor, algo=self.quant_algo, group_size=self.weightonly_group_size
                     )
+                    self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight_tensor)
+                    self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale_tensor)
                elif "fp8" in self.quant_type:
                    ffn1_quanted_weight_tensor = (
                        paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)).cast(paddle.float8_e4m3fn)
@@ -949,12 +952,11 @@ def set_state_dict(self, state_dict):
                        state_dict[f"{self.base_model_prefix}.layers.{idx}.mlp.down_proj.weight"]
                    ).cast(paddle.get_default_dtype())
                if self.use_weight_only:
-                    (
-                        self.transformer_block.ffn2_weights[idx],
-                        self.transformer_block.ffn2_weights_scale[idx],
-                    ) = weight_quantize(
+                    ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize(
                         ffn2_weight_tensor, algo=self.quant_algo, group_size=self.weightonly_group_size
                     )
+                    self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight_tensor)
+                    self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale_tensor)
                elif "fp8" in self.quant_type:
                    ffn2_quanted_weight_tensor = (
                        paddle.to_tensor(state_dict[f"{self.base_model_prefix}.layers.{idx}.mlp.down_proj.weight"])
@@ -1199,19 +1201,21 @@ def set_state_dict(self, state_dict):
                ).cast(dtype)
 
                if self.use_weight_only:
-                    (
-                        self.transformer_block.shared_expert_ffn1_weights[idx],
-                        self.transformer_block.shared_expert_ffn1_weights_scale[idx],
-                    ) = weight_quantize(
+                    shared_expert_ffn1_quanted_weight, shared_expert_ffn1_weight_scale = weight_quantize(
                         shared_expert_ffn1_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
                     )
+                    self.transformer_block.shared_expert_ffn1_weights[idx].set_value(shared_expert_ffn1_quanted_weight)
+                    self.transformer_block.shared_expert_ffn1_weights_scale[idx].set_value(
+                        shared_expert_ffn1_weight_scale
+                    )
 
-                    (
-                        self.transformer_block.shared_expert_ffn2_weights[idx],
-                        self.transformer_block.shared_expert_ffn2_weights_scale[idx],
-                    ) = weight_quantize(
+                    shared_expert_ffn2_quanted_weight, shared_expert_ffn2_weight_scale = weight_quantize(
                         shared_expert_ffn2_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
                     )
+                    self.transformer_block.shared_expert_ffn2_weights[idx].set_value(shared_expert_ffn2_quanted_weight)
+                    self.transformer_block.shared_expert_ffn2_weights_scale[idx].set_value(
+                        shared_expert_ffn2_weight_scale
+                    )
 
                elif "fp8" in self.quant_type:
                    shared_expert_ffn1_quanted_weight = (
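Every weight-only hunk above follows the same pattern: call weight_quantize with the new group_size=self.weightonly_group_size argument, then copy the quantized weight and its scales into the fused transformer block's pre-created parameters with set_value, instead of rebinding the list entry to the tensors returned by the quantizer. Below is a minimal, self-contained sketch of group-wise weight-only quantization with Paddle's public ops; the shapes, the group size of 64, and the weight_only_linear consumer are illustrative assumptions, not code from this commit.

import paddle
from paddle.nn.quant import weight_quantize, weight_only_linear

# Assumption: a GPU build of Paddle, so quantization runs through the GPU kernel.
paddle.set_device("gpu")

x = paddle.randn([2, 4096], dtype="float16")          # illustrative activations
weight = paddle.randn([4096, 4096], dtype="float16")  # illustrative checkpoint weight

# group_size=-1 keeps per-channel scales; 64 or 128 switches to group-wise scales,
# which is what self.weightonly_group_size threads through this diff.
quanted_weight, weight_scale = weight_quantize(weight, algo="weight_only_int8", group_size=64)

# The fused block later consumes the (weight, scale) pair with the same group size.
out = weight_only_linear(x, quanted_weight, weight_scale=weight_scale, weight_dtype="int8", group_size=64)
print(out.shape)  # [2, 4096]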

paddlenlp/experimental/transformers/qwen2/modeling.py

Lines changed: 16 additions & 4 deletions
@@ -109,12 +109,15 @@ def __init__(self, config: Qwen2Config, base_model_prefix: str):
        self.use_fake_parameter = config.get("use_fake_parameter", False)
 
        self.use_weight_only = False
+        self.weightonly_group_size = -1
        if config.quant_type == "weight_only_int8":
            self.use_weight_only = True
            self.quant_algo = "weight_only_int8"
+            self.weightonly_group_size = config.weightonly_group_size
        elif config.quant_type == "weight_only_int4":
            self.use_weight_only = True
            self.quant_algo = "weight_only_int4"
+            self.weightonly_group_size = config.weightonly_group_size
        elif "a8w8" in config.quant_type:
            self.quant_model_path = config.model_name_or_path
            self.shift = config.quantization_config.shift
@@ -312,6 +315,7 @@ def __init__(self, config: Qwen2Config, base_model_prefix: str):
            kv_num_heads=self.num_key_value_heads,
            intermediate_size=self.intermediate_size,
            quant_type=self.quant_type,
+            weightonly_group_size=self.weightonly_group_size,
            activation="swiglu",
            num_layers=config.num_hidden_layers,
            tp_degree=config.tensor_parallel_degree,
@@ -663,7 +667,9 @@ def set_state_dict(self, state_dict):
 
            if self.use_weight_only:
                qkv_weight = paddle.transpose(qkv_weight, perm=[1, 0])
-                qkv_quanted_weight, qkv_weight_scale = weight_quantize(qkv_weight, algo=self.quant_algo)
+                qkv_quanted_weight, qkv_weight_scale = weight_quantize(
+                    qkv_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                )
                self.transformer_block.qkv_weights[idx].set_value(qkv_quanted_weight)
                self.transformer_block.qkv_weights_scale[idx].set_value(qkv_weight_scale)
            elif "fp8" in self.quant_type:
@@ -701,7 +707,9 @@ def set_state_dict(self, state_dict):
                paddle.get_default_dtype()
            )
            if self.use_weight_only:
-                linear_quanted_weight, linear_weight_scale = weight_quantize(linear_weight, algo=self.quant_algo)
+                linear_quanted_weight, linear_weight_scale = weight_quantize(
+                    linear_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                )
                self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight)
                self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale)
            elif "fp8" in self.quant_type:
@@ -758,7 +766,9 @@ def set_state_dict(self, state_dict):
            ffn1_weight = paddle.to_tensor(concated_ffn1_weight).cast(paddle.get_default_dtype())
 
            if self.use_weight_only:
-                ffn1_quanted_weight, ffn1_weight_scale = weight_quantize(ffn1_weight, algo=self.quant_algo)
+                ffn1_quanted_weight, ffn1_weight_scale = weight_quantize(
+                    ffn1_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                )
                self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight)
                self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale)
            elif "fp8" in self.quant_type:
@@ -795,7 +805,9 @@ def set_state_dict(self, state_dict):
                paddle.get_default_dtype()
            )
            if self.use_weight_only:
-                ffn2_quanted_weight, ffn2_weight_scale = weight_quantize(ffn2_weight, algo=self.quant_algo)
+                ffn2_quanted_weight, ffn2_weight_scale = weight_quantize(
+                    ffn2_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                )
                self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight)
                self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale)
            elif "fp8" in self.quant_type:
