@@ -209,7 +209,7 @@ def quant_weight_linear(
 ):
     if weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]:

-        state, training, act_scale = act_state
+        state, training, act_scale, group = act_state

         return QATFunc.apply(
             x,
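QATFunc.apply is outside this diff, so how the new fourth element of act_state gets consumed is not shown. A minimal sketch of the likely intent — keeping the running activation scale consistent across ranks during training — assuming a hypothetical sync_act_scale helper and a max-reduction (one common choice for scales):

```python
import paddle.distributed as dist


def sync_act_scale(act_scale, group=None):
    # Hypothetical helper, not part of this diff: reduce the running
    # activation scale across ranks so every shard quantizes with the
    # same scale. group=None falls back to the default (global) group.
    if dist.is_initialized():
        dist.all_reduce(act_scale, op=dist.ReduceOp.MAX, group=group)
    return act_scale
```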
@@ -237,6 +237,14 @@ def quant_weight_linear(
     )


+def get_act_scale_group(is_row=False):
+    if paddle.distributed.is_initialized():
+        group = None
+    else:
+        group = None
+    return group
+
+
 class QuantizationLinear(nn.Layer):
     """Quantization Linear layer."""

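As written, both branches of get_act_scale_group assign group = None, so collectives receiving this value would fall back to the default global process group, and the is_row argument is accepted but unused. In a hybrid-parallel run one could instead resolve the model-parallel group explicitly. A sketch of that variant, assuming fleet's hybrid communicate group has been initialized (an illustration, not this PR's implementation):

```python
import paddle
from paddle.distributed import fleet


def get_act_scale_group_mp(is_row=False):
    # Illustrative variant, not this PR's code: return the model-parallel
    # group so act_scale is reduced only across the ranks sharding this layer.
    if not paddle.distributed.is_initialized():
        return None
    try:
        hcg = fleet.get_hybrid_communicate_group()
    except Exception:
        return None  # fleet is not running in hybrid-parallel mode
    if hcg.get_model_parallel_world_size() > 1:
        return hcg.get_model_parallel_group()
    return None
```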
@@ -290,6 +298,7 @@ def __init__(
                 shape=[], dtype=self._dtype, is_bias=False, default_initializer=nn.initializer.Constant(value=0.0)
             )
             self.act_scale.stop_gradient = True
+            self.group = get_act_scale_group()

         elif self.weight_quantize_algo in ["fp4", "nf4"]:
             if qlora_weight_linear is None:
@@ -349,6 +358,7 @@ def __init__(
             for p in self.parameters():
                 p.is_distributed = is_distributed
                 p.mp_moe = mp_moe
+        self.quant_weight.weight_quantize_algo = self.weight_quantize_algo

     def forward(self, x):
         output = quant_weight_linear(
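Stamping quant_weight with its weight_quantize_algo means code that only ever sees parameter objects (checkpoint converters, custom optimizers, debugging hooks) can branch on the algorithm without tracking which layer a parameter came from. A hypothetical consumer:

```python
def is_act_quantized_param(p):
    # Hypothetical downstream check, not part of this diff: the tag set in
    # __init__ above travels with the parameter object itself.
    return getattr(p, "weight_quantize_algo", None) in ("a8w8linear", "a8w4linear", "fp8linear")
```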
@@ -363,7 +373,7 @@ def forward(self, x):
             if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant)
             else None,
             bias=self.bias,
-            act_state=(self.state, self.training, self.act_scale)
+            act_state=(self.state, self.training, self.act_scale, self.group)
             if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]
             else None,
         )
@@ -455,6 +465,7 @@ def __init__(
             )
             self.act_scale.is_distributed = True if self.is_mp else False
             self.act_scale.stop_gradient = True
+            self.group = get_act_scale_group()
         else:
             raise NotImplementedError(f"Not yet support weight_quantize_algo: {self.weight_quantize_algo}")
         if bias_attr is False:
@@ -469,6 +480,7 @@ def __init__(
             self.bias.is_distributed = True if self.is_mp else False
             if self.bias.is_distributed:
                 self.bias.split_axis = 0
+        self.quant_weight.weight_quantize_algo = self.weight_quantize_algo

     def forward(self, x):
         if self.is_mp:
@@ -495,7 +507,7 @@ def forward(self, x):
             if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant)
             else None,
             bias=self.bias,
-            act_state=(self.state, self.training, self.act_scale)
+            act_state=(self.state, self.training, self.act_scale, self.group)
             if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]
             else None,
         )
@@ -594,6 +606,7 @@ def __init__(
             )
             self.act_scale.is_distributed = True if self.is_mp else False
             self.act_scale.stop_gradient = True
+            self.group = get_act_scale_group()
         else:
             raise NotImplementedError(f"Not yet support weight_quantize_algo: {self.weight_quantize_algo}")

@@ -607,6 +620,8 @@ def __init__(
                 is_bias=True,
             )

+        self.quant_weight.weight_quantize_algo = self.weight_quantize_algo
+
     def forward(self, x):
         if self.input_is_parallel or (not self.is_mp):
             input_parallel = x
@@ -628,7 +643,7 @@ def forward(self, x):
             if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant)
             else None,
             bias=None,
-            act_state=(self.state, self.training, self.act_scale)
+            act_state=(self.state, self.training, self.act_scale, self.group)
             if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]
             else None,
         )
@@ -656,7 +671,7 @@ def forward(self, x):
             if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant)
             else None,
             bias=self.bias,
-            act_state=(self.state, self.training, self.act_scale)
+            act_state=(self.state, self.training, self.act_scale, self.group)
             if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]
             else None,
         )