[llm] support tensorwise fp8/int8 training #10612
Changes from all commits
@@ -14,6 +14,8 @@
 import paddle
 
+from paddlenlp.utils import infohub
+
 
 def matmul_hadU(X):

@@ -31,22 +33,37 @@ def matmul_hadU(X):
     return input.reshape(X.shape)
 
 
-def random_hadamard_matrix(size, dtype, is_block=False):
-    if not is_block:
-        A = paddle.randint(low=0, high=2, shape=[size, size]).astype("float32") * 2 - 1
-        Q, _ = paddle.linalg.qr(A)
-        return Q.astype(dtype), 1
-    else:
-        num_blocks = size
-        while not (num_blocks % 2):
-            num_blocks = num_blocks // 2
-        block_size = size // num_blocks
-        Q = paddle.diag(paddle.ones((block_size,), dtype="float32"))
-        block = matmul_hadU(Q)
-        large_matrix = paddle.zeros([size, size])
-
-        for i in range(num_blocks):
-            start_row = i * block_size
-            start_col = i * block_size
-            large_matrix[start_row : start_row + block_size, start_col : start_col + block_size] = block
-        return large_matrix.cast(dtype), block_size
+def create_hadamard_matrix(block_size, dtype):
+    Q = paddle.diag(paddle.ones((block_size), dtype=dtype))
+    block = matmul_hadU(Q)
+    return block
+
+
+def hadamard_matmul(input, side, hadamard_matrix, block_size):
+    # left -> H.T@input  right -> input@H
+    origin_shape = input.shape
+    input = input.reshape([-1, origin_shape[-1]])
+    if side == "left":
+        # H.T@input -> (input.T@H).T
+        input = input.transpose([1, 0])
+    block_num = input.shape[-1] // block_size
+    output = input.reshape([-1, block_num, block_size]) @ hadamard_matrix
+    output = output.reshape([-1, block_num * block_size])
+    if side == "left":
+        output = output.transpose([1, 0])
+    output = output.reshape(origin_shape)
+
+    return output
+
+
+def apply_hadamard_matmul(x, side, block_size):
+    if getattr(infohub, "hadamard") is None:
+        setattr(infohub, "hadamard", {})
+
+    if block_size in infohub.hadamard:
+        hadamard_matrix = infohub.hadamard[block_size]
+    else:
+        hadamard_matrix = create_hadamard_matrix(block_size, x.dtype)
+        infohub.hadamard[block_size] = hadamard_matrix
+    target_x = hadamard_matmul(x, side, hadamard_matrix, block_size)
+    return target_x

Review thread on `def create_hadamard_matrix(block_size, dtype):`
Comment: Same as the earlier one.
Reply: Deleted.

Review thread on `if block_size in infohub.hadamard:`
Comment: If hadamard_matrix has no default value, there will be a problem whenever this branch is not hit.
Reply: The default value of infohub.hadamard is {}.
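A minimal usage sketch of the new helpers (an illustration, not code from this PR): it shows what side="left"/"right" mean and how the per-block matrix is cached. The tensor shapes, block_size=32, and the direct call into the functions defined in the diff above are assumptions.

import paddle

# Assumes apply_hadamard_matmul / create_hadamard_matrix from the diff above are in
# scope, and that infohub.hadamard defaults to {} (as stated in the review thread).
x = paddle.randn([8, 128], dtype="float32")   # activations, last dim divisible by 32
w = paddle.randn([128, 64], dtype="float32")  # weights

x_h = apply_hadamard_matmul(x, "right", 32)   # blockwise x @ H along the last axis
w_h = apply_hadamard_matmul(w, "left", 32)    # blockwise H.T @ w along the first axis

# The 32x32 block is built once by create_hadamard_matrix and cached in
# infohub.hadamard[32]; later calls with the same block_size reuse it.
# If the block Hadamard matrix is orthogonal (the usual normalized construction),
# x_h @ w_h matches x @ w up to numerical error, which is why the rotation can be
# wrapped around a quantized matmul.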
Review thread:
Comment: Can the hardcode here be avoided? Or how can we make sure it actually takes effect? At the very least there should be a log message.
Reply: There is no better way to write it for now, because the scale has stop_gradient set, yet it still needs to be passed in as an optimizer parameter.
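The reply above describes a pattern in code not shown in this file, so here is a minimal sketch of that pattern (my illustration, not the PR's code): a quantization scale that is excluded from backprop via stop_gradient but is still passed into the optimizer's parameter list. The names (linear, scale) and the AdamW setup are assumptions.

import paddle

linear = paddle.nn.Linear(16, 16)

# Hypothetical per-tensor quantization scale: never updated by gradients, but it
# still needs to be visible to the optimizer alongside the trainable weights.
scale = paddle.create_parameter(
    shape=[1],
    dtype="float32",
    default_initializer=paddle.nn.initializer.Constant(1.0),
)
scale.stop_gradient = True  # no gradient flows into the scale

# The scale is handed to the optimizer together with the layer parameters; the
# optimizer is expected to skip it during the update because it carries no gradient.
opt = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=list(linear.parameters()) + [scale],
)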