14
14
15
15
import paddle
16
16
17
+ from paddlenlp .utils import infohub
18
+
17
19
18
20
def matmul_hadU (X ):
19
21
@@ -31,22 +33,37 @@ def matmul_hadU(X):
31
33
return input .reshape (X .shape )
32
34
33
35
34
- def random_hadamard_matrix (size , dtype , is_block = False ):
35
- if not is_block :
36
- A = paddle .randint (low = 0 , high = 2 , shape = [size , size ]).astype ("float32" ) * 2 - 1
37
- Q , _ = paddle .linalg .qr (A )
38
- return Q .astype (dtype ), 1
36
def create_hadamard_matrix(block_size, dtype):
    """Build a single square block by running an identity matrix through ``matmul_hadU``.

    Args:
        block_size: Side length of the square block.
        dtype: dtype of the identity matrix handed to ``matmul_hadU``.

    Returns:
        The result of ``matmul_hadU`` applied to the ``block_size`` x ``block_size``
        identity matrix (presumably the normalized Hadamard matrix -- confirm
        against ``matmul_hadU``).
    """
    # The shape has to be a real sequence: ``(block_size)`` is only a
    # parenthesized int, not a 1-tuple.
    identity = paddle.diag(paddle.ones([block_size], dtype=dtype))
    return matmul_hadU(identity)
40
+
41
+
42
def hadamard_matmul(input, side, hadamard_matrix, block_size):
    """Multiply ``input`` block-wise by ``hadamard_matrix``.

    ``side == "left"`` computes ``H.T @ input``; any other value computes
    ``input @ H``, where ``H`` is the block-diagonal matrix made of copies of
    ``hadamard_matrix``.

    Args:
        input: Array/tensor with at least one dimension; it is flattened to 2-D
            on its last axis before the product.
        side: ``"left"`` to multiply from the left, otherwise from the right.
        hadamard_matrix: One ``block_size`` x ``block_size`` block.
        block_size: Width of each block along the multiplied dimension.

    Returns:
        The product, reshaped back to the shape of ``input``.

    Raises:
        ValueError: If the multiplied dimension is not a multiple of
            ``block_size``.
    """
    origin_shape = input.shape
    flat = input.reshape([-1, origin_shape[-1]])

    from_left = side == "left"
    if from_left:
        # H.T @ input is evaluated as (input.T @ H).T
        flat = flat.transpose([1, 0])

    width = flat.shape[-1]
    if width % block_size != 0:
        # Without this check the floor division below drops the remainder and
        # the reshape can silently regroup elements into the wrong blocks.
        raise ValueError(
            f"dimension {width} is not divisible by block_size {block_size}"
        )
    block_num = width // block_size

    blocks = flat.reshape([-1, block_num, block_size]) @ hadamard_matrix
    output = blocks.reshape([-1, block_num * block_size])
    if from_left:
        output = output.transpose([1, 0])

    return output.reshape(origin_shape)
57
+
58
+
59
def apply_hadamard_matmul(x, side, block_size):
    """Apply the block Hadamard product to ``x``, caching the block on ``infohub``.

    The block for a given ``block_size`` is created once via
    ``create_hadamard_matrix`` and stored in ``infohub.hadamard``; later calls
    with the same ``block_size`` reuse it.

    Args:
        x: Tensor to transform; its ``dtype`` is used when a new block is built.
        side: Forwarded to ``hadamard_matmul`` (``"left"`` or right-hand product).
        block_size: Block width; also the cache key.

    Returns:
        The result of ``hadamard_matmul`` for ``x``.
    """
    # A default is required here: two-argument getattr raises AttributeError
    # when the attribute has never been set, so the cache could not be created.
    if getattr(infohub, "hadamard", None) is None:
        setattr(infohub, "hadamard", {})

    # NOTE(review): the cache is keyed by block_size only, so the dtype of the
    # first caller is reused for later calls -- confirm callers share one dtype.
    if block_size in infohub.hadamard:
        hadamard_matrix = infohub.hadamard[block_size]
    else:
        hadamard_matrix = create_hadamard_matrix(block_size, x.dtype)
        infohub.hadamard[block_size] = hadamard_matrix

    return hadamard_matmul(x, side, hadamard_matrix, block_size)
0 commit comments