
Commit ad548d3

[llm] support tensorwise fp8/int8 training (#10612)
1 parent fb80d7d · commit ad548d3

18 files changed: +812, -527 lines

llm/run_finetune.py

Lines changed: 10 additions & 1 deletion
@@ -166,6 +166,13 @@ def main():
         qlora_weight_blocksize=model_args.qlora_weight_blocksize,
         qlora_weight_double_quant=model_args.qlora_weight_double_quant,
         qlora_weight_double_quant_block_size=model_args.qlora_weight_double_quant_block_size,
+        apply_hadamard=model_args.apply_hadamard,
+        hadamard_block_size=model_args.hadamard_block_size,
+        quant_input_grad=model_args.quant_input_grad,
+        quant_weight_grad=model_args.quant_weight_grad,
+        apply_online_actscale_step=model_args.apply_online_actscale_step,
+        actscale_moving_rate=model_args.actscale_moving_rate,
+        fp8_format_type=model_args.fp8_format_type,
     )

     model_config = AutoConfig.from_pretrained(
@@ -447,7 +454,9 @@ def compute_metrics_do_generation(eval_preds):
         gen_args=gen_args,
         data_args=data_args,
     )
-    trainable_parameters = [p for p in model.parameters() if not p.stop_gradient]
+    trainable_parameters = [
+        p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name)
+    ]
    trainer.set_optimizer_grouped_parameters(trainable_parameters)
     if model_args.lorapro:
         optimizer = AdamWLoRAPro(
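
Note on the second hunk: the optimizer's parameter group now also includes tensors whose stop_gradient flag is set, as long as their names mark them as quantization_linear state containing "w_1" (what "w_1" holds is not shown in this hunk). A minimal, self-contained sketch of that predicate; FakeParam is a stand-in for a paddle parameter and the names below are invented, not taken from the repo:

# Illustrative only: shows the same filter expression as the new list
# comprehension in run_finetune.py, applied to dummy objects.
from dataclasses import dataclass


@dataclass
class FakeParam:
    name: str
    stop_gradient: bool


params = [
    FakeParam("llama.layers.0.mlp.weight", stop_gradient=False),        # ordinary trainable weight
    FakeParam("quantization_linear_0.w_1", stop_gradient=True),         # quantization state, no gradient
    FakeParam("quantization_linear_0.weight_scale", stop_gradient=True),
]

trainable_parameters = [
    p for p in params
    if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name)
]

print([p.name for p in trainable_parameters])
# ['llama.layers.0.mlp.weight', 'quantization_linear_0.w_1']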

paddlenlp/quantization/hadamard_utils.py

Lines changed: 35 additions & 18 deletions
@@ -14,6 +14,8 @@

 import paddle

+from paddlenlp.utils import infohub
+

 def matmul_hadU(X):

@@ -31,22 +33,37 @@ def matmul_hadU(X):
     return input.reshape(X.shape)


-def random_hadamard_matrix(size, dtype, is_block=False):
-    if not is_block:
-        A = paddle.randint(low=0, high=2, shape=[size, size]).astype("float32") * 2 - 1
-        Q, _ = paddle.linalg.qr(A)
-        return Q.astype(dtype), 1
+def create_hadamard_matrix(block_size, dtype):
+    Q = paddle.diag(paddle.ones((block_size), dtype=dtype))
+    block = matmul_hadU(Q)
+    return block
+
+
+def hadamard_matmul(input, side, hadamard_matrix, block_size):
+    # left -> H.T@input right -> input@H
+    origin_shape = input.shape
+    input = input.reshape([-1, origin_shape[-1]])
+    if side == "left":
+        # H.T@input -> (input.T@H).T
+        input = input.transpose([1, 0])
+    block_num = input.shape[-1] // block_size
+    output = input.reshape([-1, block_num, block_size]) @ hadamard_matrix
+    output = output.reshape([-1, block_num * block_size])
+    if side == "left":
+        output = output.transpose([1, 0])
+    output = output.reshape(origin_shape)
+
+    return output
+
+
+def apply_hadamard_matmul(x, side, block_size):
+    if getattr(infohub, "hadamard") is None:
+        setattr(infohub, "hadamard", {})
+
+    if block_size in infohub.hadamard:
+        hadamard_matrix = infohub.hadamard[block_size]
     else:
-        num_blocks = size
-        while not (num_blocks % 2):
-            num_blocks = num_blocks // 2
-        block_size = size // num_blocks
-        Q = paddle.diag(paddle.ones((block_size,), dtype="float32"))
-        block = matmul_hadU(Q)
-        large_matrix = paddle.zeros([size, size])
-
-        for i in range(num_blocks):
-            start_row = i * block_size
-            start_col = i * block_size
-            large_matrix[start_row : start_row + block_size, start_col : start_col + block_size] = block
-        return large_matrix.cast(dtype), block_size
+        hadamard_matrix = create_hadamard_matrix(block_size, x.dtype)
+        infohub.hadamard[block_size] = hadamard_matrix
+    target_x = hadamard_matmul(x, side, hadamard_matrix, block_size)
+    return target_x
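
For context on what hadamard_matmul computes: multiplying each block_size-wide chunk of the last dimension by a Hadamard block is equivalent to multiplying the whole row by a block-diagonal matrix built from that block, which is what the removed random_hadamard_matrix(..., is_block=True) path materialized explicitly. A standalone sketch (not part of the commit) that checks this equivalence for the side == "right" path, using a hand-written 4x4 Hadamard matrix instead of create_hadamard_matrix so it does not depend on matmul_hadU or infohub:

# Standalone check: the blocked matmul mirrors hadamard_matmul(side="right")
# and is compared against a full matmul with an explicit block-diagonal matrix.
import paddle

block_size = 4
# 4x4 Hadamard matrix written out explicitly (Sylvester construction).
H = paddle.to_tensor(
    [[1.0, 1.0, 1.0, 1.0],
     [1.0, -1.0, 1.0, -1.0],
     [1.0, 1.0, -1.0, -1.0],
     [1.0, -1.0, -1.0, 1.0]]
)

x = paddle.randn([2, 3, 8])                # last dim = 2 blocks of size 4
flat = x.reshape([-1, x.shape[-1]])

# Blocked path, as hadamard_matmul does for side == "right".
blocked = (flat.reshape([-1, 2, block_size]) @ H).reshape([-1, 2 * block_size])

# Reference path: explicit block-diagonal matrix, as the removed code built.
big = paddle.zeros([8, 8])
big[0:4, 0:4] = H
big[4:8, 4:8] = H
reference = flat @ big

print(bool(paddle.allclose(blocked, reference)))  # True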
