@@ -696,16 +696,18 @@ def set_state_dict(self, state_dict):
                 ).cast(dtype)
 
                 if self.use_weight_only:
-                    (
-                        self.transformer_block.q_a_proj_weights[idx],
-                        self.transformer_block.q_a_proj_weights_scale[idx],
-                    ) = weight_quantize(q_a_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size)
-
-                    (
-                        self.transformer_block.q_b_proj_weights[idx],
-                        self.transformer_block.q_b_proj_weights_scale[idx],
-                    ) = weight_quantize(q_b_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size)
+                    q_a_proj_quanted_weight, q_a_proj_weight_scale = weight_quantize(
+                        q_a_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                    )
+                    self.transformer_block.q_a_proj_weights[idx].set_value(q_a_proj_quanted_weight)
+                    self.transformer_block.q_a_proj_weights_scale[idx].set_value(q_a_proj_weight_scale)
+
+                    q_b_proj_quanted_weight, q_b_proj_weight_scale = weight_quantize(
+                        q_b_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                    )
+                    self.transformer_block.q_b_proj_weights[idx].set_value(q_b_proj_quanted_weight)
                     self.transformer_block.q_a_layernorm_weights[idx].set_value(q_a_layernorm_weight)
+                    self.transformer_block.q_b_proj_weights_scale[idx].set_value(q_b_proj_weight_scale)
                 elif "fp8" in self.quant_type:
                     q_a_proj_quanted_weight = (
                         paddle.to_tensor(
@@ -752,10 +754,11 @@ def set_state_dict(self, state_dict):
                 ).cast(dtype)
 
                 if self.use_weight_only:
-                    (
-                        self.transformer_block.q_proj_weights[idx],
-                        self.transformer_block.q_proj_weights_scale[idx],
-                    ) = weight_quantize(q_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size)
+                    q_proj_quanted_weight, q_proj_weight_scale = weight_quantize(
+                        q_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                    )
+                    self.transformer_block.q_proj_weights[idx].set_value(q_proj_quanted_weight)
+                    self.transformer_block.q_proj_weights_scale[idx].set_value(q_proj_weight_scale)
                 elif "fp8" in self.quant_type:
                     q_proj_quanted_weight = (
                         paddle.to_tensor(state_dict[f"{self.base_model_prefix}.layers.{idx}.self_attn.q_proj.weight"])
@@ -822,18 +825,18 @@ def set_state_dict(self, state_dict):
                 self.transformer_block.v_b_proj_weights[idx].set_value(wv_b)
 
                 if self.use_weight_only:
-                    (
-                        self.transformer_block.kv_a_proj_with_mqa_weights[idx],
-                        self.transformer_block.kv_a_proj_with_mqa_weights_scale[idx],
-                    ) = weight_quantize(
+                    kv_a_proj_with_mqa_quanted_weight, kv_a_proj_with_mqa_weight_scale = weight_quantize(
                         kv_a_proj_with_mqa_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
                     )
+                    self.transformer_block.kv_a_proj_with_mqa_weights[idx].set_value(kv_a_proj_with_mqa_quanted_weight)
+                    self.transformer_block.kv_a_proj_with_mqa_weights_scale[idx].set_value(kv_a_proj_with_mqa_weight_scale)
 
-                    (
-                        self.transformer_block.kv_b_proj_weights[idx],
-                        self.transformer_block.kv_b_proj_weights_scale[idx],
-                    ) = weight_quantize(kv_b_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size)
+                    kv_b_proj_quanted_weight, kv_b_proj_weight_scale = weight_quantize(
+                        kv_b_proj_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                    )
+                    self.transformer_block.kv_b_proj_weights[idx].set_value(kv_b_proj_quanted_weight)
                     self.transformer_block.kv_a_layernorm_weights[idx].set_value(kv_a_layernorm_weight)
+                    self.transformer_block.kv_b_proj_weights_scale[idx].set_value(kv_b_proj_weight_scale)
                 elif "fp8" in self.quant_type:
                     kv_a_proj_with_mqa_quanted_weight = (
                         paddle.to_tensor(
@@ -876,10 +879,11 @@ def set_state_dict(self, state_dict):
                 self.transformer_block.kv_b_proj_weights[idx].set_value(kv_b_proj_weight)
 
                 if self.use_weight_only:
-                    (
-                        self.transformer_block.linear_weights[idx],
-                        self.transformer_block.linear_weights_scale[idx],
-                    ) = weight_quantize(linear_weight, algo=self.quant_algo, group_size=self.weightonly_group_size)
+                    linear_quanted_weight, linear_weight_scale = weight_quantize(
+                        linear_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
+                    )
+                    self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight)
+                    self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale)
                 elif "fp8" in self.quant_type:
                     linear_quanted_weight = (
                         paddle.to_tensor(state_dict[f"{self.base_model_prefix}.layers.{idx}.self_attn.o_proj.weight"])
@@ -915,12 +919,11 @@ def set_state_dict(self, state_dict):
                 ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight).cast(paddle.get_default_dtype())
 
                 if self.use_weight_only:
-                    (
-                        self.transformer_block.ffn1_weights[idx],
-                        self.transformer_block.ffn1_weights_scale[idx],
-                    ) = weight_quantize(
+                    ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize(
                         ffn1_weight_tensor, algo=self.quant_algo, group_size=self.weightonly_group_size
                     )
+                    self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight_tensor)
+                    self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale_tensor)
                 elif "fp8" in self.quant_type:
                     ffn1_quanted_weight_tensor = (
                         paddle.to_tensor(concated_ffn1_weight).transpose((1, 0)).cast(paddle.float8_e4m3fn)
@@ -949,12 +952,11 @@ def set_state_dict(self, state_dict):
                     state_dict[f"{self.base_model_prefix}.layers.{idx}.mlp.down_proj.weight"]
                 ).cast(paddle.get_default_dtype())
                 if self.use_weight_only:
-                    (
-                        self.transformer_block.ffn2_weights[idx],
-                        self.transformer_block.ffn2_weights_scale[idx],
-                    ) = weight_quantize(
+                    ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize(
                         ffn2_weight_tensor, algo=self.quant_algo, group_size=self.weightonly_group_size
                     )
+                    self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight_tensor)
+                    self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale_tensor)
                 elif "fp8" in self.quant_type:
                     ffn2_quanted_weight_tensor = (
                         paddle.to_tensor(state_dict[f"{self.base_model_prefix}.layers.{idx}.mlp.down_proj.weight"])
@@ -1199,19 +1201,21 @@ def set_state_dict(self, state_dict):
                 ).cast(dtype)
 
                 if self.use_weight_only:
-                    (
-                        self.transformer_block.shared_expert_ffn1_weights[idx],
-                        self.transformer_block.shared_expert_ffn1_weights_scale[idx],
-                    ) = weight_quantize(
+                    shared_expert_ffn1_quanted_weight, shared_expert_ffn1_weight_scale = weight_quantize(
                         shared_expert_ffn1_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
                     )
+                    self.transformer_block.shared_expert_ffn1_weights[idx].set_value(shared_expert_ffn1_quanted_weight)
+                    self.transformer_block.shared_expert_ffn1_weights_scale[idx].set_value(
+                        shared_expert_ffn1_weight_scale
+                    )
 
-                    (
-                        self.transformer_block.shared_expert_ffn2_weights[idx],
-                        self.transformer_block.shared_expert_ffn2_weights_scale[idx],
-                    ) = weight_quantize(
+                    shared_expert_ffn2_quanted_weight, shared_expert_ffn2_weight_scale = weight_quantize(
                         shared_expert_ffn2_weight, algo=self.quant_algo, group_size=self.weightonly_group_size
                     )
+                    self.transformer_block.shared_expert_ffn2_weights[idx].set_value(shared_expert_ffn2_quanted_weight)
+                    self.transformer_block.shared_expert_ffn2_weights_scale[idx].set_value(
+                        shared_expert_ffn2_weight_scale
+                    )
 
                 elif "fp8" in self.quant_type:
                     shared_expert_ffn1_quanted_weight = (
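
Note on the pattern repeated in every hunk above: instead of rebinding the *_weights[idx] / *_weights_scale[idx] list slots to the tensors returned by weight_quantize, the new code copies those tensors into the already-created parameters with .set_value(...). Below is a minimal, self-contained sketch of that pattern; the shape, dtype, and variable names are made up for illustration, and weight_quantize itself is only mentioned in a comment since its import path is not part of this diff.

import paddle

# Stand-in for a pre-created slot such as transformer_block.q_proj_weights[idx]
# (hypothetical shape and dtype, chosen only for this demo).
dst_weight = paddle.create_parameter(shape=[4, 4], dtype="float32")

# In the PR, this tensor comes from weight_quantize(...), which returns the
# quantized weight together with its scale; a random tensor stands in here.
new_value = paddle.rand([4, 4], dtype="float32")

# Copy into the existing parameter in place rather than rebinding the list slot,
# so any code already holding a reference to the parameter keeps seeing the
# updated values.
dst_weight.set_value(new_value)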