Commit fe356c3: add hadamard

1 parent 43ec0ac · commit fe356c3

5 files changed: +193 -42 lines changed


paddlenlp/quantization/hadamard_utils.py
Lines changed: 80 additions & 18 deletions

@@ -12,8 +12,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# import paddle
+
+
+# def matmul_hadU(X):
+
+#     input = X.clone().reshape((-1, X.shape[-1], 1))
+#     output = input.clone()
+#     while input.shape[1] > 1:
+#         input = input.reshape((input.shape[0], input.shape[1] // 2, 2, input.shape[2]))
+#         output = output.reshape(input.shape)
+#         output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
+#         output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
+#         output = output.reshape((input.shape[0], input.shape[1], -1))
+#         (input, output) = (output, input)
+#     del output
+
+#     return input.reshape(X.shape)
+
+
+# def random_hadamard_matrix(size, dtype, is_block=False):
+#     if not is_block:
+#         A = paddle.randint(low=0, high=2, shape=[size, size]).astype("float32") * 2 - 1
+#         Q, _ = paddle.linalg.qr(A)
+#         return Q.astype(dtype), 1
+#     else:
+#         num_blocks = size
+#         while not (num_blocks % 2):
+#             num_blocks = num_blocks // 2
+#         block_size = size // num_blocks
+#         Q = paddle.diag(paddle.ones((block_size,), dtype="float32"))
+#         block = matmul_hadU(Q)
+#         large_matrix = paddle.zeros([size, size])
+
+#         for i in range(num_blocks):
+#             start_row = i * block_size
+#             start_col = i * block_size
+#             large_matrix[start_row : start_row + block_size, start_col : start_col + block_size] = block
+#         return large_matrix.cast(dtype), block_size
+
 import paddle

+from paddlenlp.utils import infohub
+

 def matmul_hadU(X):

@@ -31,22 +72,43 @@ def matmul_hadU(X):
     return input.reshape(X.shape)


-def random_hadamard_matrix(size, dtype, is_block=False):
-    if not is_block:
-        A = paddle.randint(low=0, high=2, shape=[size, size]).astype("float32") * 2 - 1
-        Q, _ = paddle.linalg.qr(A)
-        return Q.astype(dtype), 1
+def random_hadamard_matrix(block_size, dtype):
+    Q = paddle.diag(paddle.ones((block_size), dtype=dtype))
+    block = matmul_hadU(Q)
+    return block
+
+
+def create_hadamard_matrix(block_size, dtype):
+    Q = paddle.diag(paddle.ones((block_size), dtype=dtype))
+    block = matmul_hadU(Q)
+    return block
+
+
+def hadamard_matmul(input, side, hadamard_matrix, block_size):
+    # left -> H.T@input    right -> input@H
+    origin_shape = input.shape
+    input = input.reshape([-1, origin_shape[-1]])
+    if side == "left":
+        # H.T@input -> (input.T@H).T
+        input = input.transpose([1, 0])
+    block_num = input.shape[-1] // block_size
+    output = input.reshape([-1, block_num, block_size]) @ hadamard_matrix
+    output = output.reshape([-1, block_num * block_size])
+    if side == "left":
+        output = output.transpose([1, 0])
+    output = output.reshape(origin_shape)
+
+    return output
+
+
+def apply_hadamard_matmul(x, side, block_size):
+    if getattr(infohub, "hadamard") is None:
+        setattr(infohub, "hadamard", {})
+
+    if block_size in infohub.hadamard:
+        hadamard_matrix = infohub.hadamard[block_size]
     else:
-        num_blocks = size
-        while not (num_blocks % 2):
-            num_blocks = num_blocks // 2
-        block_size = size // num_blocks
-        Q = paddle.diag(paddle.ones((block_size,), dtype="float32"))
-        block = matmul_hadU(Q)
-        large_matrix = paddle.zeros([size, size])
-
-        for i in range(num_blocks):
-            start_row = i * block_size
-            start_col = i * block_size
-            large_matrix[start_row : start_row + block_size, start_col : start_col + block_size] = block
-        return large_matrix.cast(dtype), block_size
+        hadamard_matrix = create_hadamard_matrix(block_size, x.dtype)
+        infohub.hadamard[block_size] = hadamard_matrix
+    target_x = hadamard_matmul(x, side, hadamard_matrix, block_size)
+    return target_x, block_size
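
For context on the new helpers above: matmul_hadU applies the unnormalized fast Walsh-Hadamard transform via a pairwise sum/difference butterfly, create_hadamard_matrix materializes the resulting block matrix H, and hadamard_matmul applies that block from the left (H.T @ x) or right (x @ H). A minimal pure-Python sketch, not part of the commit (the helpers hadamard and butterfly are illustrative only), checking that the butterfly matches multiplication by an explicit Sylvester Hadamard matrix:

# Illustrative only, not part of this commit: verify that the sum/difference
# butterfly used by matmul_hadU equals multiplication by the Sylvester
# Hadamard matrix H_n (entries +1/-1, with H_n @ H_n.T == n * I).

def hadamard(n):
    # Explicit H_n for a power-of-two n (Sylvester construction).
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    h = [[1]]
    while len(h) < n:
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    return h

def butterfly(x):
    # Same pairwise sum/difference recursion as matmul_hadU, on a plain list.
    out = list(x)
    step = 1
    while step < len(out):
        for i in range(0, len(out), 2 * step):
            for j in range(i, i + step):
                a, b = out[j], out[j + step]
                out[j], out[j + step] = a + b, a - b
        step *= 2
    return out

x = [1.0, 2.0, 3.0, 4.0]
H = hadamard(4)
explicit = [sum(H[r][c] * x[c] for c in range(4)) for r in range(4)]
assert butterfly(x) == explicit == [10.0, -2.0, -4.0, 0.0]

Because the transform is unnormalized, H @ H.T equals block_size times the identity, which is why apply_hadamard_matmul returns block_size as its second value alongside the transformed tensor.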

paddlenlp/quantization/qat_utils.py
Lines changed: 90 additions & 11 deletions

@@ -19,7 +19,7 @@

 from paddlenlp.utils import infohub

-from .hadamard_utils import random_hadamard_matrix
+from .hadamard_utils import apply_hadamard_matmul, random_hadamard_matrix

 try:
     from transformer_engine import transformer_engine_paddle as tex

@@ -35,6 +35,13 @@
 except ImportError:
     USE_FP8_GEMM = False

+QMIN_QMAX_MAPPING = {
+    "a8w8linear_activation": (-128, 127),
+    "a8w4linear_activation": (-128, 127),
+    "a8w8linear_weight": (-128, 127),
+    "a8w4linear_weight": (-8, 7),
+}
+

 def quantize_tensorwise(x, quantization_config=None, bit_length=8, state=0, training=False, act_scale=None):
     qmax = (1 << (bit_length - 1)) - 1

@@ -154,16 +161,87 @@ def dequantize_channelwise(w_int8, scale, apply_hadamard=False):
     return w


-def a8w8_forward(
-    x, w_int8, w_scale=None, bias=None, dtype=None, quantization_config=None, state=0, training=False, act_scale=None
+def quantize(
+    x,
+    weight_quantize_algo,
+    tensor_type,
+    quantization_config,
+    apply_hadamard=False,
+    side="right",
+    act_scale=None,
+    state=0,
+    training=False,
+    group=None,
+):
+    if apply_hadamard:
+        target_x, hadamard_scale = apply_hadamard_matmul(x, side, quantization_config.hadamard_block_size)
+    else:
+        target_x = x
+        hadamard_scale = 1
+    qmin, qmax = QMIN_QMAX_MAPPING[weight_quantize_algo + "_" + tensor_type]
+    if tensor_type == "activation":
+        if act_scale is not None:
+            if training:
+                scale = paddle.max(paddle.abs(target_x)) / qmax
+                if state < quantization_config.apply_online_actscale_step:
+                    act_scale.set_value((state * act_scale + scale) / (state + 1))
+                else:
+                    act_scale.set_value(
+                        (1 - quantization_config.moving_rate) * act_scale + quantization_config.moving_rate * scale
+                    )
+                    scale = act_scale
+            else:
+                # scale = act_scale
+                scale = paddle.max(paddle.abs(target_x)) / qmax
+        else:
+            scale = paddle.max(paddle.abs(target_x)) / qmax
+        if weight_quantize_algo in ["a8w8linear", "a8w4linear"]:
+            quant_x = paddle.clip((target_x / scale).round(), qmin, qmax).astype("int8")
+        else:
+            raise NotImplementedError(f"Unknown {weight_quantize_algo}.")
+    elif tensor_type == "weight":
+        if weight_quantize_algo in ["a8w8linear", "a8w4linear"]:
+            # channelwise
+            scale = paddle.max(paddle.abs(target_x), axis=0, keepdim=True) / qmax
+            if group is not None:
+                paddle.distributed.all_reduce(scale, op=paddle.distributed.ReduceOp.MAX, group=group, sync_op=True)
+            quant_x = paddle.clip((target_x / scale).round(), qmin, qmax).astype("int8").T
+            scale.stop_gradient = True
+            scale = scale.squeeze(0) / hadamard_scale
+        else:
+            raise NotImplementedError(f"Unknown {weight_quantize_algo}.")
+    else:
+        raise NotImplementedError(f"Unknown {tensor_type}.")
+    return quant_x, scale
+
+
+def int8_forward(
+    x,
+    quant_w,
+    scale_w,
+    weight_quantize_algo,
+    bias=None,
+    quantization_config=None,
+    state=0,
+    training=False,
+    act_scale=None,
 ):
-    x_int8, x_scale = quantize_tensorwise(
-        x, quantization_config, bit_length=8, state=state, training=training, act_scale=act_scale
+    quant_x, scale_x = quantize(
+        x=x,
+        weight_quantize_algo=weight_quantize_algo,
+        tensor_type="activation",
+        quantization_config=quantization_config,
+        apply_hadamard=quantization_config.apply_hadamard,
+        side="right",
+        act_scale=act_scale,
+        state=state,
+        training=training,
     )
-    out = paddle.matmul(x_int8, w_int8.T).astype(dtype) * (x_scale * w_scale.unsqueeze(0))
+
+    out = paddle.matmul(quant_x, quant_w.T).astype(scale_w.dtype) * (scale_x * scale_w)
     if bias is not None:
         out += bias
-    return out, x_int8, x_scale
+    return out, quant_x, scale_x


 def a8w8_backward(ctx, x, grad_output, quant_weight, quant_scale, quant_x, x_scale):
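
The new quantize() above applies a single abs-max rule for both activations and weights: scale = max(|x|) / qmax, then clip(round(x / scale), qmin, qmax), with (qmin, qmax) looked up from QMIN_QMAX_MAPPING by algorithm and tensor type. A plain-Python sketch of that arithmetic, illustrative only and not part of the commit (fake_quant is a hypothetical helper; real tensors go through the paddle ops shown above):

# Illustrative only, not part of this commit: the abs-max quantization rule
# used by quantize(), written with plain Python floats.

QMIN_QMAX_MAPPING = {
    "a8w8linear_activation": (-128, 127),
    "a8w4linear_activation": (-128, 127),
    "a8w8linear_weight": (-128, 127),
    "a8w4linear_weight": (-8, 7),
}

def fake_quant(values, weight_quantize_algo, tensor_type):
    qmin, qmax = QMIN_QMAX_MAPPING[weight_quantize_algo + "_" + tensor_type]
    scale = max(abs(v) for v in values) / qmax                       # abs-max scale
    quant = [min(max(round(v / scale), qmin), qmax) for v in values]  # clip(round(x / scale))
    dequant = [q * scale for q in quant]                              # what rescaling recovers
    return quant, dequant

w = [0.5, -1.25, 0.02, 1.0]
print(fake_quant(w, "a8w8linear", "weight"))  # int8 weights: qmax = 127
print(fake_quant(w, "a8w4linear", "weight"))  # int4 weights: qmax = 7, much coarser grid

int8_forward then multiplies the integer matmul result back by scale_x * scale_w to return to the floating-point domain.
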
@@ -352,6 +430,7 @@ def forward(
         state,
         training,
         act_scale,
+        weight_quantize_algo,
     ):
         quant_x, x_scale = None, None
         if quantization_config.weight_quantize_algo in ["fp8linear"]:

@@ -367,12 +446,12 @@
                 act_scale=act_scale,
             )
         else:
-            output, quant_x, x_scale = a8w8_forward(
+            output, quant_x, x_scale = int8_forward(
                 x,
-                quant_weight,
-                w_scale=quant_scale,
+                quant_w=quant_weight,
+                scale_w=quant_scale,
+                weight_quantize_algo=weight_quantize_algo,
                 bias=bias,
-                dtype=dtype,
                 quantization_config=quantization_config,
                 state=state,
                 training=training,
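
A note on the activation branch of quantize() shown earlier: during training, act_scale is first filled with a running average of the per-step abs-max scales and afterwards tracked with an exponential moving average. A plain-float sketch of that schedule, illustrative only (update_act_scale is hypothetical, and the default values used for apply_online_actscale_step and moving_rate are placeholders, not the config's actual defaults):

# Illustrative only, not part of this commit: the act_scale schedule from
# quantize()'s training path, with placeholder default values.

def update_act_scale(act_scale, batch_scale, state, apply_online_actscale_step=20, moving_rate=0.01):
    if state < apply_online_actscale_step:
        # warm-up: running average of the per-step abs-max scales
        return (state * act_scale + batch_scale) / (state + 1)
    # afterwards: exponential moving average controlled by moving_rate
    return (1 - moving_rate) * act_scale + moving_rate * batch_scale

act_scale = 1.0  # matches the paddle.ones([]) initialization added in this commit
for state, batch_scale in enumerate([0.8, 1.2, 1.0, 0.9]):
    act_scale = update_act_scale(act_scale, batch_scale, state, apply_online_actscale_step=2)
print(act_scale)  # averages the first two steps, then drifts slowly toward later scales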

paddlenlp/quantization/quantization_config.py
Lines changed: 2 additions & 0 deletions

@@ -65,6 +65,7 @@ def __init__(
         ignore_modules=None,
         group_size=-1,
         apply_hadamard=False,
+        hadamard_block_size=32,
         quant_input_grad=False,
         quant_weight_grad=False,
         skip_first_act_scale_step=20,

@@ -139,6 +140,7 @@ def __init__(
         self.ignore_modules = ignore_modules
         self.group_size = group_size
         self.apply_hadamard = apply_hadamard
+        self.hadamard_block_size = hadamard_block_size
         self.quant_input_grad = quant_input_grad
         self.quant_weight_grad = quant_weight_grad
         self.skip_first_act_scale_step = skip_first_act_scale_step
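
The new hadamard_block_size option defaults to 32 and is read by quantize() as quantization_config.hadamard_block_size. A hedged usage sketch, assuming the constructor shown above belongs to QuantizationConfig and also accepts weight_quantize_algo (its use as an attribute elsewhere in this commit suggests so); other constructor arguments are omitted:

# Hedged sketch: class name, import path and the weight_quantize_algo keyword
# are assumptions based on the surrounding code, not shown in this diff.
from paddlenlp.quantization.quantization_config import QuantizationConfig

qcfg = QuantizationConfig(
    weight_quantize_algo="a8w4linear",  # int8 activations, int4 weights
    apply_hadamard=True,                # rotate tensors before quantizing
    hadamard_block_size=32,             # new in this commit, default 32
)
assert qcfg.apply_hadamard and qcfg.hadamard_block_size == 32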

paddlenlp/quantization/quantization_linear.py
Lines changed: 10 additions & 1 deletion

@@ -212,7 +212,16 @@ def quant_weight_linear(
         state, training, act_scale = act_state

         return QATFunc.apply(
-            x, quant_weight, bias, quant_scale, quantization_config, dtype, state, training, act_scale
+            x,
+            quant_weight,
+            bias,
+            quant_scale,
+            quantization_config,
+            dtype,
+            state,
+            training,
+            act_scale,
+            weight_quantize_algo,
        )
     else:
         return QuantizationLinearFunc.apply(

paddlenlp/quantization/quantization_utils.py
Lines changed: 11 additions & 12 deletions

@@ -34,7 +34,7 @@
     qlora_weight_quantize = None

 from ..utils.log import logger
-from .qat_utils import fp8_quantize_tensorwise, quantize_channelwise
+from .qat_utils import fp8_quantize_tensorwise, quantize
 from .quantization_linear import (
     ColumnParallelQuantizationLinear,
     QuantizationLinear,

@@ -155,18 +155,17 @@ def convert_to_weight_quantize_state_dict(state_dict, name, quantization_config,
     if weight_name in state_dict:
         # gpu weight_quantize will fix in future
         target_weight = state_dict.pop(weight_name).cast(dtype).cuda()
-        if weight_quantize_algo in ["a8w8linear"]:
-            quant_weight, quant_scale = quantize_channelwise(
-                target_weight, quantization_config.apply_hadamard, bit_length=8
-            )
-            act_scale = paddle.zeros([], dtype="bfloat16").cuda()
-            act_scale.stop_gradient = True
-            state_dict[act_scale_name] = act_scale
-        elif weight_quantize_algo in ["a8w4linear"]:
-            quant_weight, quant_scale = quantize_channelwise(
-                target_weight, quantization_config.apply_hadamard, bit_length=4
+
+        if weight_quantize_algo in ["a8w8linear", "a8w4linear"]:
+            quant_weight, quant_scale = quantize(
+                target_weight,
+                weight_quantize_algo,
+                "weight",
+                quantization_config,
+                apply_hadamard=quantization_config.apply_hadamard,
+                side="left",
             )
-            act_scale = paddle.zeros([], dtype="bfloat16").cuda()
+            act_scale = paddle.ones([], dtype=dtype).cuda()
             act_scale.stop_gradient = True
             state_dict[act_scale_name] = act_scale
         elif weight_quantize_algo in ["fp8linear"]:
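
Why the weight path above passes side="left" while int8_forward quantizes activations with side="right": with an unnormalized Hadamard block H, (x @ H) @ (H.T @ W) equals block_size * (x @ W), and quantize() divides the weight scale by the returned hadamard_scale to absorb that factor. A small pure-Python check, illustrative only (matmul and transpose are throwaway helpers, not repo code):

# Illustrative only, not part of this commit: rotating activations from the
# right and weights from the left cancels up to a factor of block_size,
# because H @ H.T == block_size * I for an unnormalized Hadamard block.

H = [[1, 1], [1, -1]]            # 2x2 Sylvester Hadamard block
x = [[1.0, 2.0]]                 # one activation row
W = [[0.5, -1.0], [0.25, 2.0]]   # a 2x2 weight matrix

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

rotated = matmul(matmul(x, H), matmul(transpose(H), W))  # (x @ H) @ (H.T @ W)
plain = matmul(x, W)
block_size = 2
assert rotated == [[block_size * v for v in row] for row in plain]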
