Skip to content

Commit ac095f5

Browse files
authored
mergekit gpu 1226 (#9702)
* mergekit gpu 1226 * merge model gpu * merge gpu * add lora model * change valueerror * add lora * gpu test
1 parent 7c1c9ba commit ac095f5

File tree

8 files changed

+531
-76
lines changed

8 files changed

+531
-76
lines changed

paddlenlp/mergekit/merge_config.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class MergeConfig:
3535
default="np", metadata={"help": "Tensor type to use for the merge. Choose np(CPU Only) or pd (CPU/GPU)"}
3636
)
3737
n_process: int = field(default=1, metadata={"help": "Number of processes to use for the merge."})
38-
merge_preifx: str = field(default="model", metadata={"help": "Prefix name: model or master_weights"})
38+
merge_prefix: str = field(default="model", metadata={"help": "Prefix name: model or master_weights"})
3939
merge_method: str = field(default="linear", metadata={"help": "The merge strategy."})
4040
merge_type: str = field(default="linear", metadata={"help": "The type of merge process."})
4141
sparsify_type: str = field(default=None, metadata={"help": "The type of sparsify process."})
@@ -73,12 +73,11 @@ def __post_init__(self):
7373
def config_check(self):
7474
if self.output_path is not None:
7575
os.makedirs(self.output_path, exist_ok=True)
76-
if self.tensor_type not in ["np"]:
77-
raise ValueError(f"Unsupported tensor type: {self.tensor_type}. Support 'np' only.")
78-
if self.device != "cpu":
79-
logger.warning(f"Currently only support cpu device, but got {self.device}. Setting `device` to `cpu`.")
76+
if self.tensor_type not in ["np", "pd"]:
77+
raise ValueError(f"Unsupported tensor type: {self.tensor_type}. Support 'np' and 'pd' only.")
78+
if self.device == "gpu" and self.tensor_type == "np":
79+
logger.warning("np only support cpu device, but got gpu. Setting `device` to `cpu`.")
8080
self.device = "cpu"
81-
self.tensor_type = "np"
8281

8382
elif self.merge_method not in ["linear", "ties", "slerp", "della_linear", "della", "dare_linear", "dare_ties"]:
8483
raise ValueError(

paddlenlp/mergekit/merge_method.py

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import numpy as np
16+
import paddle
1617

1718

1819
class MergeMethod:
@@ -46,8 +47,14 @@ def linear(self, tensor_list):
4647
if self.merge_config.tensor_type == "np":
4748
tensor_output = sum(weight * tensor for weight, tensor in zip(weight_list, tensor_list))
4849
return tensor_output
50+
elif self.merge_config.tensor_type == "pd":
51+
stacked_tensors = paddle.stack(tensor_list, axis=0)
52+
weights = paddle.to_tensor(weight_list, dtype=stacked_tensors.dtype)
53+
weights = weights.reshape([-1] + [1] * (len(stacked_tensors.shape) - 1))
54+
weighted_sum = paddle.sum(stacked_tensors * weights, axis=0)
55+
return weighted_sum
4956
else:
50-
raise NotImplementedError("Paddle Tensor is not supported yet.")
57+
raise ValueError(f"Unkonwn tensor type {self.merge_config.tensor_type}")
5158

5259
def slerp(self, tensor_list):
5360
"""
@@ -85,17 +92,45 @@ def slerp(self, tensor_list):
8592
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
8693
s1 = sin_theta_t / sin_theta_0
8794

95+
return s0 * t0_copy + s1 * t1_copy
96+
elif self.merge_config.tensor_type == "pd":
97+
t0, t1 = tensor_list
98+
# Copy the tensors to reuse them later
99+
t0_copy = t0.clone()
100+
t1_copy = t1.clone()
101+
102+
# Normalize the tensors to get the directions and angles
103+
t0 = self.normalize(t0)
104+
t1 = self.normalize(t1)
105+
106+
# Dot product with the normalized tensors
107+
dot = paddle.sum(t0 * t1)
108+
# If absolute value of dot product is almost 1, vectors are ~colinear, so use lerp
109+
if paddle.abs(dot) > self.merge_config.slerp_dot_threshold:
110+
return (1 - self.merge_config.slerp_alpha) * t0_copy + self.merge_config.slerp_alpha * t1_copy
111+
112+
# Calculate initial angle between t0 and t1
113+
theta_0 = paddle.acos(dot)
114+
sin_theta_0 = paddle.sin(theta_0)
115+
116+
# Angle at timestep t
117+
theta_t = theta_0 * self.merge_config.slerp_alpha
118+
sin_theta_t = paddle.sin(theta_t)
119+
120+
# Finish the slerp algorithm
121+
s0 = paddle.sin(theta_0 - theta_t) / sin_theta_0
122+
s1 = sin_theta_t / sin_theta_0
123+
88124
return s0 * t0_copy + s1 * t1_copy
89125
else:
90-
raise NotImplementedError("Paddle Tensor is not supported yet.")
126+
raise ValueError(f"Unkonwn tensor type {self.merge_config.tensor_type}")
91127

92128
def ties(self, tensor_list):
93129
if self.merge_config.tensor_type == "np":
94130
# Get weight tensor
95131
mask_dtype = tensor_list[0].dtype
96132
weight_list = self.merge_config.weight_list
97133
tensor_list = [weight * tensor for (weight, tensor) in zip(weight_list, tensor_list)]
98-
99134
# Elect majority sign
100135
sign_tensor_list = [np.sign(tensor).astype(mask_dtype) for tensor in tensor_list]
101136
if self.merge_config.ties_elect_type == "sum":
@@ -117,14 +152,51 @@ def ties(self, tensor_list):
117152
divisor[np.abs(divisor) < 1e-8] = 1
118153
merge_tensor /= divisor
119154
return merge_tensor
155+
156+
elif self.merge_config.tensor_type == "pd":
157+
mask_dtype = tensor_list[0].dtype
158+
weight_list = self.merge_config.weight_list
159+
stacked_tensors = paddle.stack(tensor_list, axis=0)
160+
weights = paddle.to_tensor(weight_list, dtype=stacked_tensors.dtype)
161+
weights = weights.reshape([-1] + [1] * (len(stacked_tensors.shape) - 1))
162+
weighted_tensors = stacked_tensors * weights
163+
# Elect majority sign
164+
if self.merge_config.ties_elect_type == "sum":
165+
majority_sign = (paddle.sum(weighted_tensors, axis=0) >= 0).astype(mask_dtype) * 2 - 1
166+
elif self.merge_config.ties_elect_type == "count":
167+
stacked_signs = paddle.sign(stacked_tensors).astype(mask_dtype)
168+
majority_sign = (paddle.sum(stacked_signs, axis=0) >= 0).astype(mask_dtype) * 2 - 1
169+
else:
170+
raise NotImplementedError(f"ties_elect_type: {self.merge_config.ties_elect_type} is unknown.")
171+
172+
# Merge
173+
stacked_masks = (paddle.sign(weighted_tensors) == majority_sign).astype(mask_dtype)
174+
masked_tensors = stacked_masks * weighted_tensors
175+
merge_tensor = paddle.sum(masked_tensors, axis=0)
176+
# Normalize
177+
if self.merge_config.normalize:
178+
weight_masks = stacked_masks * weights
179+
divisor = paddle.sum(weight_masks, axis=0)
180+
divisor = paddle.where(paddle.abs(divisor) < 1e-8, paddle.ones_like(divisor), divisor)
181+
merge_tensor /= divisor
182+
183+
return merge_tensor
120184
else:
121-
raise NotImplementedError("Paddle Tensor is not supported yet.")
185+
raise ValueError(f"Unkonwn tensor type {self.merge_config.tensor_type}")
122186

123187
def normalize(self, t):
124188
"""
125189
Normalize a vector by its L2 norm.
126190
"""
127-
norm_t = np.linalg.norm(t)
128-
if norm_t > self.merge_config.slerp_normalize_eps:
129-
t = t / norm_t
130-
return t
191+
if self.merge_config.tensor_type == "np":
192+
norm_t = np.linalg.norm(t)
193+
if norm_t > self.merge_config.slerp_normalize_eps:
194+
t = t / norm_t
195+
return t
196+
elif self.merge_config.tensor_type == "pd":
197+
norm_t = paddle.norm(t, p=2)
198+
if norm_t > self.merge_config.slerp_normalize_eps:
199+
t = t / norm_t
200+
return t
201+
else:
202+
raise ValueError(f"Unkonwn tensor type {self.merge_config.tensor_type}")

0 commit comments

Comments
 (0)