Commit 336eb31

add new
1 parent 925a532 commit 336eb31

File tree

1 file changed: +13, -4 lines changed


paddlenlp/quantization/quantization_linear.py

Lines changed: 13 additions & 4 deletions
@@ -15,6 +15,7 @@
 import paddle
 import paddle.nn as nn
 from paddle.autograd import PyLayer
+from paddle.distributed import fleet
 from paddle.distributed.fleet.base import topology as tp
 from paddle.distributed.fleet.layers.mpu import mp_ops
 from paddle.distributed.fleet.utils.sequence_parallel_utils import (
@@ -23,6 +24,8 @@
 )
 from paddle.nn.quant import llm_int8_linear, weight_dequantize, weight_only_linear
 
+from paddlenlp.utils import infohub
+
 from .qat_utils import QATFunc
 
 try:
@@ -222,6 +225,7 @@ def quant_weight_linear(
             training,
             act_scale,
             weight_quantize_algo,
+            group,
         )
     else:
         return QuantizationLinearFunc.apply(
@@ -238,10 +242,15 @@ def quant_weight_linear(
 
 
 def get_act_scale_group(is_row=False):
-    if paddle.distributed.is_initialized():
-        group = None
+    if not paddle.distributed.is_initialized() or not is_row:
+        return None
+
+    if getattr(infohub, "scale_group") is None:
+        hcg = fleet.get_hybrid_communicate_group()
+        group = hcg.get_model_parallel_group()
+        setattr(infohub, "scale_group", group)
     else:
-        group = None
+        group = infohub.scale_group
     return group
 
 
@@ -606,7 +615,7 @@ def __init__(
             )
             self.act_scale.is_distributed = True if self.is_mp else False
             self.act_scale.stop_gradient = True
-            self.group = get_act_scale_group()
+            self.group = get_act_scale_group(is_row=True)
         else:
             raise NotImplementedError(f"Not yet support weight_quantize_algo: {self.weight_quantize_algo}")
 
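A minimal sketch of the pattern this commit introduces, assuming a Paddle hybrid-parallel setup: the model-parallel communication group used to synchronize activation scales is looked up once from the hybrid communicate group and cached on a process-wide holder, so every row-parallel quantized layer reuses the same group. ScaleGroupHub below is a hypothetical stand-in for paddlenlp.utils.infohub; the fleet calls match the ones in the diff.

# Sketch of lazy caching of the model-parallel group for activation scales.
# ScaleGroupHub is a hypothetical stand-in for paddlenlp.utils.infohub.
import paddle
from paddle.distributed import fleet


class ScaleGroupHub:
    """Process-wide holder for the cached scale group."""

    scale_group = None


_hub = ScaleGroupHub()


def get_act_scale_group(is_row=False):
    # Single-process runs and non-row-parallel layers never need the group.
    if not paddle.distributed.is_initialized() or not is_row:
        return None

    if _hub.scale_group is None:
        # First call: fetch the model-parallel group from the hybrid
        # communicate group and cache it for all later layers.
        hcg = fleet.get_hybrid_communicate_group()
        _hub.scale_group = hcg.get_model_parallel_group()

    return _hub.scale_group

With this caching in place, the call site changed in the last hunk, self.group = get_act_scale_group(is_row=True), hands the same cached group to every layer that needs it, while non-distributed runs simply receive None.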