Commit 925a532

add distributed

1 parent 9a4f89e commit 925a532

3 files changed: +39 −10 lines

paddlenlp/quantization/qat_utils.py

Lines changed: 11 additions & 3 deletions
@@ -67,8 +67,8 @@ def quantize(
         if act_scale is not None:
             if training:
                 scale = paddle.max(paddle.abs(target_x)) / qmax
-                if paddle.distributed.is_initialized():
-                    paddle.distributed.all_reduce(scale, op=paddle.distributed.ReduceOp.MAX)
+                if group is not None:
+                    paddle.distributed.all_reduce(scale, op=paddle.distributed.ReduceOp.MAX, group=group, sync_op=True)
                 if state < quantization_config.apply_online_actscale_step:
                     act_scale.set_value((state * act_scale + scale) / (state + 1))
                 else:
@@ -97,7 +97,8 @@ def quantize(
             scale = scale.squeeze(0) / hadamard_scale
         elif weight_quantize_algo in ["fp8linear"]:
             scale = paddle.max(paddle.abs(target_x)) / qmax
-            paddle.distributed.all_reduce(scale, op=paddle.distributed.ReduceOp.MAX)
+            if group is not None:
+                paddle.distributed.all_reduce(scale, op=paddle.distributed.ReduceOp.MAX, group=group, sync_op=True)
             quant_x = (target_x / scale).astype(quantization_config.fp8_format[tensor_type]).view("int8").T
             scale = scale / hadamard_scale
         else:
@@ -143,6 +144,7 @@ def int8_forward(
     state=0,
     training=False,
     act_scale=None,
+    group=None,
 ):
     quant_x, scale_x = quantize(
         x=x,
@@ -154,6 +156,7 @@ def int8_forward(
         act_scale=act_scale,
         state=state,
         training=training,
+        group=group,
     )
 
     out = paddle.matmul(quant_x, quant_w.T).astype(scale_w.dtype) * (scale_x * scale_w)
@@ -201,6 +204,7 @@ def fp8_forward(
     state=0,
     training=False,
     act_scale=None,
+    group=None,
 ):
     x_fp8, x_scale = quantize(
         x,
@@ -212,6 +216,7 @@ def fp8_forward(
         act_scale=act_scale,
         state=state,
         training=training,
+        group=group,
     )
     x_fp8 = x_fp8.view(quantization_config.fp8_format["activation"])
     w_fp8 = w_fp8.view(quantization_config.fp8_format["weight"])
@@ -368,6 +373,7 @@ def forward(
         training,
         act_scale,
         weight_quantize_algo,
+        group,
     ):
         quant_x, x_scale = None, None
         if weight_quantize_algo in ["fp8linear"]:
@@ -382,6 +388,7 @@ def forward(
                 state=state,
                 training=training,
                 act_scale=act_scale,
+                group=group,
             )
         else:
             output, quant_x, x_scale = int8_forward(
@@ -394,6 +401,7 @@ def forward(
                 state=state,
                 training=training,
                 act_scale=act_scale,
+                group=group,
             )
         ctx.quantization_config = quantization_config
         ctx.weight_quantize_algo = weight_quantize_algo
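These changes thread an optional communication `group` from the QAT linear layers down into `quantize`, so the online activation scale is only all-reduced when a group is actually supplied. A minimal sketch of the collective this enables, assuming a two-rank model-parallel group (`mp_group` and the tensor shapes are illustrative, not part of the commit):

    # Sketch only: max-abs scale synchronized across an assumed model-parallel group.
    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    mp_group = dist.new_group(ranks=[0, 1])  # hypothetical 2-way MP group

    x = paddle.rand([4, 128])
    scale = paddle.max(paddle.abs(x)) / 127.0  # per-tensor int8 qmax
    if mp_group is not None:
        # Same collective the patch gates on `group is not None`.
        dist.all_reduce(scale, op=dist.ReduceOp.MAX, group=mp_group, sync_op=True)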

paddlenlp/quantization/quantization_linear.py

Lines changed: 20 additions & 5 deletions
@@ -209,7 +209,7 @@ def quant_weight_linear(
 ):
     if weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]:
 
-        state, training, act_scale = act_state
+        state, training, act_scale, group = act_state
 
         return QATFunc.apply(
             x,
@@ -237,6 +237,14 @@ def quant_weight_linear(
     )
 
 
+def get_act_scale_group(is_row=False):
+    if paddle.distributed.is_initialized():
+        group = None
+    else:
+        group = None
+    return group
+
+
 class QuantizationLinear(nn.Layer):
     """Quantization Linear layer."""
 
@@ -290,6 +298,7 @@ def __init__(
                 shape=[], dtype=self._dtype, is_bias=False, default_initializer=nn.initializer.Constant(value=0.0)
             )
             self.act_scale.stop_gradient = True
+            self.group = get_act_scale_group()
 
         elif self.weight_quantize_algo in ["fp4", "nf4"]:
             if qlora_weight_linear is None:
@@ -349,6 +358,7 @@ def __init__(
         for p in self.parameters():
             p.is_distributed = is_distributed
             p.mp_moe = mp_moe
+        self.quant_weight.weight_quantize_algo = self.weight_quantize_algo
 
     def forward(self, x):
         output = quant_weight_linear(
@@ -363,7 +373,7 @@ def forward(self, x):
             if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant)
             else None,
             bias=self.bias,
-            act_state=(self.state, self.training, self.act_scale)
+            act_state=(self.state, self.training, self.act_scale, self.group)
             if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]
             else None,
         )
@@ -455,6 +465,7 @@ def __init__(
             )
             self.act_scale.is_distributed = True if self.is_mp else False
             self.act_scale.stop_gradient = True
+            self.group = get_act_scale_group()
         else:
             raise NotImplementedError(f"Not yet support weight_quantize_algo: {self.weight_quantize_algo}")
         if bias_attr is False:
@@ -469,6 +480,7 @@ def __init__(
             self.bias.is_distributed = True if self.is_mp else False
             if self.bias.is_distributed:
                 self.bias.split_axis = 0
+        self.quant_weight.weight_quantize_algo = self.weight_quantize_algo
 
     def forward(self, x):
         if self.is_mp:
@@ -495,7 +507,7 @@ def forward(self, x):
             if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant)
             else None,
             bias=self.bias,
-            act_state=(self.state, self.training, self.act_scale)
+            act_state=(self.state, self.training, self.act_scale, self.group)
             if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]
             else None,
         )
@@ -594,6 +606,7 @@ def __init__(
             )
             self.act_scale.is_distributed = True if self.is_mp else False
             self.act_scale.stop_gradient = True
+            self.group = get_act_scale_group()
         else:
             raise NotImplementedError(f"Not yet support weight_quantize_algo: {self.weight_quantize_algo}")
 
@@ -607,6 +620,8 @@ def __init__(
                 is_bias=True,
             )
 
+        self.quant_weight.weight_quantize_algo = self.weight_quantize_algo
+
     def forward(self, x):
         if self.input_is_parallel or (not self.is_mp):
             input_parallel = x
@@ -628,7 +643,7 @@ def forward(self, x):
             if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant)
             else None,
             bias=None,
-            act_state=(self.state, self.training, self.act_scale)
+            act_state=(self.state, self.training, self.act_scale, self.group)
             if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]
             else None,
         )
@@ -656,7 +671,7 @@ def forward(self, x):
             if (self.weight_quantize_algo in ["fp4", "nf4"] and self.quantization_config.qlora_weight_double_quant)
             else None,
             bias=self.bias,
-            act_state=(self.state, self.training, self.act_scale)
+            act_state=(self.state, self.training, self.act_scale, self.group)
             if self.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]
             else None,
         )
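Note that, as committed, `get_act_scale_group` returns `None` in both branches, so the new all_reduce path stays disabled until the helper is filled in; the `is_row` argument is currently unused. A hypothetical variant that resolves the tensor-parallel group from fleet might look like the following (this is an assumption about intent, not the committed code):

    # Hypothetical get_act_scale_group: NOT the committed implementation.
    import paddle
    from paddle.distributed import fleet

    def get_act_scale_group(is_row=False):
        if not paddle.distributed.is_initialized():
            return None
        # Assumes fleet has been initialized with a hybrid (tensor-parallel) strategy.
        hcg = fleet.get_hybrid_communicate_group()
        mp_group = hcg.get_model_parallel_group()
        # Guess: row-parallel inputs are already sharded per rank, so skip the collective.
        return None if is_row else mp_group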

paddlenlp/utils/optimizer.py

Lines changed: 8 additions & 2 deletions
@@ -324,15 +324,21 @@ def _append_optimize_op(self, block, param_and_grad):
                 skip_update_param,
             )
             if skip_update_param:
-                if self.quantization_config.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]:
+                if param.weight_quantize_algo in ["a8w8linear", "a8w4linear", "fp8linear"]:
+                    if "parallel_quantization_linear" not in param.name:
+                        group = None
+                    elif param.weight_quantize_algo in ["a8w8linear", "a8w4linear"] and "row" in param.name:
+                        group = None
+                    else:
+                        group = self.mp_group
                     param[:], new_quant_scale = quantize(
                         x=master_weight.astype(quant_scale.dtype),
                         weight_quantize_algo=self.quantization_config.weight_quantize_algo,
                         tensor_type="weight",
                         quantization_config=self.quantization_config,
                         side="left",
                         apply_hadamard=self.quantization_config.apply_hadamard,
-                        group=None,
+                        group=group,
                     )
                     quant_scale.set_value(new_quant_scale)
                 else:
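The optimizer now decides the re-quantization group from metadata stamped on the parameter itself: the `weight_quantize_algo` attribute attached in quantization_linear.py plus substrings of the parameter name. A small illustration of that attribute-tagging pattern (the layer and names below are made up for the example):

    # Sketch only: tag a parameter with quantization metadata and dispatch on it.
    import paddle.nn as nn

    layer = nn.Linear(8, 8)
    layer.weight.weight_quantize_algo = "a8w8linear"  # attribute the patch attaches to quant_weight

    def pick_group(param, mp_group=None):
        # Mirrors the dispatch in _append_optimize_op: non-parallel layers and
        # row-parallel int8/int4 weights skip the collective (group=None).
        if "parallel_quantization_linear" not in param.name:
            return None
        if param.weight_quantize_algo in ["a8w8linear", "a8w4linear"] and "row" in param.name:
            return None
        return mp_group

    print(pick_group(layer.weight))  # -> None for this plain (non-parallel) Linear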
