# limitations under the License.

import numpy as np
+import paddle


class MergeMethod:
@@ -46,8 +47,14 @@ def linear(self, tensor_list):
        if self.merge_config.tensor_type == "np":
            tensor_output = sum(weight * tensor for weight, tensor in zip(weight_list, tensor_list))
            return tensor_output
+        elif self.merge_config.tensor_type == "pd":
+            stacked_tensors = paddle.stack(tensor_list, axis=0)
+            weights = paddle.to_tensor(weight_list, dtype=stacked_tensors.dtype)
+            weights = weights.reshape([-1] + [1] * (len(stacked_tensors.shape) - 1))
+            weighted_sum = paddle.sum(stacked_tensors * weights, axis=0)
+            return weighted_sum
        else:
-            raise NotImplementedError("Paddle Tensor is not supported yet.")
+            raise ValueError(f"Unknown tensor type {self.merge_config.tensor_type}")

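The new `pd` branch reproduces the NumPy weighted average by stacking the inputs along a new leading axis and broadcasting the merge weights over them. A minimal standalone sketch of that broadcasting pattern in plain Paddle (hypothetical inputs, not tied to `MergeMethod`):

```python
import paddle

# Hypothetical inputs: three "models", each contributing a 2x2 parameter tensor.
tensor_list = [paddle.full([2, 2], v, dtype="float32") for v in (1.0, 2.0, 3.0)]
weight_list = [0.2, 0.3, 0.5]

stacked = paddle.stack(tensor_list, axis=0)                       # shape [3, 2, 2]
weights = paddle.to_tensor(weight_list, dtype=stacked.dtype)      # shape [3]
weights = weights.reshape([-1] + [1] * (len(stacked.shape) - 1))  # shape [3, 1, 1]
merged = paddle.sum(stacked * weights, axis=0)                    # shape [2, 2]
print(merged)  # every entry equals 0.2*1 + 0.3*2 + 0.5*3 = 2.3
```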
    def slerp(self, tensor_list):
        """
@@ -85,17 +92,45 @@ def slerp(self, tensor_list):
            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0

+            return s0 * t0_copy + s1 * t1_copy
+        elif self.merge_config.tensor_type == "pd":
+            t0, t1 = tensor_list
+            # Copy the tensors to reuse them later
+            t0_copy = t0.clone()
+            t1_copy = t1.clone()
+
+            # Normalize the tensors to get the directions and angles
+            t0 = self.normalize(t0)
+            t1 = self.normalize(t1)
+
+            # Dot product with the normalized tensors
+            dot = paddle.sum(t0 * t1)
+            # If the absolute value of the dot product is almost 1, the vectors are ~colinear, so use lerp
+            if paddle.abs(dot) > self.merge_config.slerp_dot_threshold:
+                return (1 - self.merge_config.slerp_alpha) * t0_copy + self.merge_config.slerp_alpha * t1_copy
+
+            # Calculate the initial angle between t0 and t1
+            theta_0 = paddle.acos(dot)
+            sin_theta_0 = paddle.sin(theta_0)
+
+            # Angle at timestep t
+            theta_t = theta_0 * self.merge_config.slerp_alpha
+            sin_theta_t = paddle.sin(theta_t)
+
+            # Finish the slerp algorithm
+            s0 = paddle.sin(theta_0 - theta_t) / sin_theta_0
+            s1 = sin_theta_t / sin_theta_0
+
            return s0 * t0_copy + s1 * t1_copy
        else:
-            raise NotImplementedError("Paddle Tensor is not supported yet.")
+            raise ValueError(f"Unknown tensor type {self.merge_config.tensor_type}")

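The added `pd` SLERP path mirrors the NumPy version: normalize both tensors, fall back to linear interpolation when they are nearly colinear, otherwise interpolate along the arc between them. A self-contained sketch with hard-coded `alpha` and threshold values standing in for the `merge_config` fields (assumed defaults, for illustration only):

```python
import paddle

def slerp_pd(t0, t1, alpha=0.5, dot_threshold=0.9995, eps=1e-10):
    # Keep the original tensors; interpolation is done on these at the end.
    t0_copy, t1_copy = t0.clone(), t1.clone()
    # Work with unit vectors to measure the angle between the two tensors.
    t0 = t0 / paddle.clip(paddle.norm(t0, p=2), min=eps)
    t1 = t1 / paddle.clip(paddle.norm(t1, p=2), min=eps)
    dot = paddle.sum(t0 * t1)
    # Nearly colinear tensors: plain linear interpolation is numerically safer.
    if paddle.abs(dot) > dot_threshold:
        return (1 - alpha) * t0_copy + alpha * t1_copy
    theta_0 = paddle.acos(dot)
    theta_t = theta_0 * alpha
    s0 = paddle.sin(theta_0 - theta_t) / paddle.sin(theta_0)
    s1 = paddle.sin(theta_t) / paddle.sin(theta_0)
    # Interpolate on the original (un-normalized) tensors, as in the branch above.
    return s0 * t0_copy + s1 * t1_copy

a = paddle.to_tensor([1.0, 0.0])
b = paddle.to_tensor([0.0, 1.0])
print(slerp_pd(a, b, alpha=0.5))  # ~[0.7071, 0.7071]
```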
    def ties(self, tensor_list):
        if self.merge_config.tensor_type == "np":
            # Get weight tensor
            mask_dtype = tensor_list[0].dtype
            weight_list = self.merge_config.weight_list
            tensor_list = [weight * tensor for (weight, tensor) in zip(weight_list, tensor_list)]
-
            # Elect majority sign
            sign_tensor_list = [np.sign(tensor).astype(mask_dtype) for tensor in tensor_list]
            if self.merge_config.ties_elect_type == "sum":
@@ -117,14 +152,51 @@ def ties(self, tensor_list):
                divisor[np.abs(divisor) < 1e-8] = 1
                merge_tensor /= divisor
            return merge_tensor
+
+        elif self.merge_config.tensor_type == "pd":
+            mask_dtype = tensor_list[0].dtype
+            weight_list = self.merge_config.weight_list
+            stacked_tensors = paddle.stack(tensor_list, axis=0)
+            weights = paddle.to_tensor(weight_list, dtype=stacked_tensors.dtype)
+            weights = weights.reshape([-1] + [1] * (len(stacked_tensors.shape) - 1))
+            weighted_tensors = stacked_tensors * weights
+            # Elect majority sign
+            if self.merge_config.ties_elect_type == "sum":
+                majority_sign = (paddle.sum(weighted_tensors, axis=0) >= 0).astype(mask_dtype) * 2 - 1
+            elif self.merge_config.ties_elect_type == "count":
+                stacked_signs = paddle.sign(stacked_tensors).astype(mask_dtype)
+                majority_sign = (paddle.sum(stacked_signs, axis=0) >= 0).astype(mask_dtype) * 2 - 1
+            else:
+                raise NotImplementedError(f"ties_elect_type: {self.merge_config.ties_elect_type} is unknown.")
+
+            # Merge
+            stacked_masks = (paddle.sign(weighted_tensors) == majority_sign).astype(mask_dtype)
+            masked_tensors = stacked_masks * weighted_tensors
+            merge_tensor = paddle.sum(masked_tensors, axis=0)
+            # Normalize
+            if self.merge_config.normalize:
+                weight_masks = stacked_masks * weights
+                divisor = paddle.sum(weight_masks, axis=0)
+                divisor = paddle.where(paddle.abs(divisor) < 1e-8, paddle.ones_like(divisor), divisor)
+                merge_tensor /= divisor
+
+            return merge_tensor
        else:
-            raise NotImplementedError("Paddle Tensor is not supported yet.")
+            raise ValueError(f"Unknown tensor type {self.merge_config.tensor_type}")

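The `pd` TIES branch above elects a per-parameter majority sign and keeps only the entries that agree with it before summing. A tiny hand-checkable illustration of the `sum` election rule, using made-up, already weighted delta tensors:

```python
import paddle

# Three models, two parameters each (weights already applied).
weighted = paddle.to_tensor([[0.4, -0.2],
                             [0.1,  0.3],
                             [-0.2, 0.2]])  # shape [3, 2]
# Column sums are [0.3, 0.3], so the elected sign is +1 for both parameters.
majority_sign = (paddle.sum(weighted, axis=0) >= 0).astype(weighted.dtype) * 2 - 1
# Entries whose sign disagrees with the elected sign are masked out.
mask = (paddle.sign(weighted) == majority_sign).astype(weighted.dtype)
merged = paddle.sum(mask * weighted, axis=0)
print(majority_sign.numpy(), merged.numpy())  # [1. 1.] [0.5 0.5]
```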
    def normalize(self, t):
        """
        Normalize a vector by its L2 norm.
        """
-        norm_t = np.linalg.norm(t)
-        if norm_t > self.merge_config.slerp_normalize_eps:
-            t = t / norm_t
-        return t
+        if self.merge_config.tensor_type == "np":
+            norm_t = np.linalg.norm(t)
+            if norm_t > self.merge_config.slerp_normalize_eps:
+                t = t / norm_t
+            return t
+        elif self.merge_config.tensor_type == "pd":
+            norm_t = paddle.norm(t, p=2)
+            if norm_t > self.merge_config.slerp_normalize_eps:
+                t = t / norm_t
+            return t
+        else:
+            raise ValueError(f"Unknown tensor type {self.merge_config.tensor_type}")
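For reference, a rough usage sketch of the new `pd` path. The `MergeMethod` constructor and the merge-config class are not part of this diff, so the snippet below attaches a stand-in config object carrying only the attributes the new branches read (`tensor_type`, `weight_list`, `normalize`, `ties_elect_type`); treat it as illustrative, not as the project's API:

```python
from types import SimpleNamespace

import paddle
# Assumes MergeMethod is importable from this module.

# Stand-in config with only the fields referenced by the "pd" branches above.
cfg = SimpleNamespace(
    tensor_type="pd",
    weight_list=[0.5, 0.5],
    normalize=True,
    ties_elect_type="sum",
)

# The constructor is outside this diff, so attach the config directly (illustration only).
method = MergeMethod.__new__(MergeMethod)
method.merge_config = cfg

tensors = [paddle.randn([4, 4]), paddle.randn([4, 4])]
# Assumes linear() reads weight_list from merge_config, as the diff context suggests.
print(method.linear(tensors).shape)  # [4, 4]
print(method.ties(tensors).shape)    # [4, 4]
```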