@@ -15,14 +15,17 @@
 import paddle
 import paddle.nn as nn
 from paddle.autograd import PyLayer
+from paddle.distributed import fleet
 from paddle.distributed.fleet.base import topology as tp
 from paddle.distributed.fleet.layers.mpu import mp_ops
 from paddle.distributed.fleet.utils.sequence_parallel_utils import (
     ColumnSequenceParallelLinear,
     RowSequenceParallelLinear,
 )
 from paddle.nn.quant import llm_int8_linear, weight_dequantize, weight_only_linear
 
+from paddlenlp.utils import infohub
+
 from .qat_utils import QATFunc
 
 try:
@@ -222,6 +225,7 @@ def quant_weight_linear(
             training,
             act_scale,
             weight_quantize_algo,
+            group,
         )
     else:
         return QuantizationLinearFunc.apply(
@@ -238,10 +242,15 @@ def quant_weight_linear(
 
 
 def get_act_scale_group(is_row=False):
-    if paddle.distributed.is_initialized():
-        group = None
+    if not paddle.distributed.is_initialized() or not is_row:
+        return None
+
+    if getattr(infohub, "scale_group", None) is None:
+        hcg = fleet.get_hybrid_communicate_group()
+        group = hcg.get_model_parallel_group()
+        setattr(infohub, "scale_group", group)
     else:
-        group = None
+        group = infohub.scale_group
     return group
 
 
@@ -606,7 +615,7 @@ def __init__(
             )
             self.act_scale.is_distributed = True if self.is_mp else False
             self.act_scale.stop_gradient = True
-            self.group = get_act_scale_group()
+            self.group = get_act_scale_group(is_row=True)
         else:
             raise NotImplementedError(f"Not yet support weight_quantize_algo: {self.weight_quantize_algo}")
 
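Note: the diff only resolves and caches the model-parallel group (via get_act_scale_group(is_row=True)) and plumbs it through to QATFunc; how the group is consumed is not shown here. Below is a minimal sketch of the likely pattern, assuming a hypothetical sync_act_scale helper (not part of this change) that keeps the running activation scale identical across tensor-parallel ranks.

import paddle.distributed as dist

def sync_act_scale(act_scale, group=None):
    # Hypothetical helper (assumption, not from the diff): reduce the running
    # activation scale with an element-wise MAX over the model-parallel group
    # returned by get_act_scale_group(is_row=True), so every rank of a
    # row-parallel layer quantizes activations with the same scale.
    if group is not None:
        dist.all_reduce(act_scale, op=dist.ReduceOp.MAX, group=group)
    return act_scale

# Usage mirroring the diff: the layer caches the group once in __init__
# (self.group = get_act_scale_group(is_row=True)), and the QAT forward could
# then call sync_act_scale(act_scale, group) before quantizing activations.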