Commit 35076f8

speedup load

1 parent 5fc659b

File tree

10 files changed: +528 -201 lines


fastdeploy/model_executor/layers/embeddings.py

Lines changed: 61 additions & 38 deletions
@@ -23,7 +23,7 @@
 
 from fastdeploy.config import FDConfig
 
-from .utils import get_tensor
+from .utils import _set_var_distributed, get_tensor
 
 
 class VocabParallelEmbedding(nn.Layer):
@@ -66,33 +66,34 @@ def __init__(
         self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
         self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
         self.params_dtype: str = params_dtype
-
-        if self.use_ep:
-            self.word_embeddings = nn.Embedding(
-                num_embeddings,
-                embedding_dim,
-            )
-        else:
-            if not self.column_cut:
-                self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding(
-                    num_embeddings,
-                    embedding_dim,
-                    mp_group=fleet.get_hybrid_communicate_group().
-                    get_model_parallel_group(),
-                    weight_attr=paddle.ParamAttr(
-                        initializer=nn.initializer.Normal(
-                            mean=0.0, std=self.initializer_range), ),
-                )
-            else:
-                # column cut embedding
-                self.word_embeddings = nn.Embedding(
-                    num_embeddings,
-                    embedding_dim // self.world_size,
-                )
-
-                self.word_embeddings.weight.is_distributed = True
-                self.word_embeddings.weight.split_axis = 1
-
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        # if self.use_ep:
+        #     self.word_embeddings = nn.Embedding(
+        #         num_embeddings,
+        #         embedding_dim,
+        #     )
+        # else:
+        #     if not self.column_cut:
+        #         self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding(
+        #             num_embeddings,
+        #             embedding_dim,
+        #             mp_group=fleet.get_hybrid_communicate_group().
+        #             get_model_parallel_group(),
+        #             weight_attr=paddle.ParamAttr(
+        #                 initializer=nn.initializer.Normal(
+        #                     mean=0.0, std=self.initializer_range), ),
+        #         )
+        #     else:
+        #         # column cut embedding
+        #         self.word_embeddings = nn.Embedding(
+        #             num_embeddings,
+        #             embedding_dim // self.world_size,
+        #         )
+
+        #         self.word_embeddings.weight.is_distributed = True
+        #         self.word_embeddings.weight.split_axis = 1
+        self.init_weight()
         if not self.use_rope:
             self.position_embeddings = nn.Embedding(
                 self.max_position_embeddings,
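The hunk above swaps the eager nn.Embedding / fleet.meta_parallel.VocabParallelEmbedding sub-layers for a single bare parameter that init_weight() creates and forward() consumes functionally. Below is a minimal single-process sketch of that pattern; the class name, sizes, and shapes are illustrative, not the FastDeploy configuration.

import paddle
import paddle.nn as nn

vocab_size, hidden_size = 1024, 64

class BareEmbedding(nn.Layer):
    def __init__(self):
        super().__init__()
        # The table is an ordinary parameter, so a checkpoint loader can write
        # into it directly instead of going through an nn.Embedding wrapper.
        self.weight = self.create_parameter(
            shape=[vocab_size, hidden_size],
            dtype=paddle.get_default_dtype(),
            is_bias=False,
            default_initializer=nn.initializer.Constant(0),
        )

    def forward(self, ids):
        # Functional lookup; equivalent to nn.Embedding for a full, unsharded table.
        return nn.functional.embedding(x=ids, weight=self.weight)

layer = BareEmbedding()
ids = paddle.randint(0, vocab_size, shape=[2, 8])
print(layer(ids).shape)  # [2, 8, 64]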
@@ -103,6 +104,23 @@ def __init__(
 
         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
+    def weight_loader(self, param, loaded_weight):
+        param.copy_(loaded_weight, False)
+    def init_weight(self):
+        from fastdeploy.model_executor.models.utils import set_param_attr
+        self.weight = self.create_parameter(
+            shape=[self.num_embeddings // self.world_size, self.embedding_dim],
+            dtype=paddle.get_default_dtype(),
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        if self.world_size > 0:
+            if self.column_cut:
+                _set_var_distributed(self.weight, split_axis=1)
+                set_param_attr(self.weight, {"is_column": False, "weight_loader": self.weight_loader})
+            else:
+                _set_var_distributed(self.weight, split_axis=0)
+                set_param_attr(self.weight, {"is_column": True, "weight_loader": self.weight_loader})
 
     def load_state_dict(self, state_dict: Dict[str,
                                                 paddle.Tensor | np.ndarray]):
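init_weight() registers a per-parameter weight_loader through set_param_attr, a helper touched elsewhere in this commit. The sketch below shows how such a registration could be consumed by a generic loader loop; it is an assumption about the intent, and the real set_param_attr in fastdeploy.model_executor.models.utils may behave differently.

import paddle

def set_param_attr(param, attrs):
    # Assumption: the real helper may do more; here it only stashes extra
    # attributes (e.g. weight_loader, is_column) on the parameter object.
    for key, value in attrs.items():
        setattr(param, key, value)

def default_weight_loader(param, loaded_weight):
    # Same call the diff registers: non-blocking in-place copy into the parameter.
    param.copy_(loaded_weight, False)

def load_checkpoint(model, state_dict):
    # Generic loader loop: each parameter pulls its tensor through whatever
    # weight_loader was registered on it, falling back to a plain copy.
    for name, param in model.named_parameters():
        if name not in state_dict:
            continue
        loader = getattr(param, "weight_loader", default_weight_loader)
        loader(param, paddle.to_tensor(state_dict[name]))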
@@ -112,15 +130,20 @@ def load_state_dict(self, state_dict: Dict[str,
         Args:
             state_dict (dict): A dictionary containing the checkpoint weights and biases.
         """
-        a = state_dict[self.prefix + ".weight"]
+        # a = state_dict[self.prefix + ".weight"]
         if self.tie_word_embeddings:
-            self.word_embeddings.weight.set_value(
-                get_tensor(state_dict[self.prefix + ".weight"]).astype(
-                    paddle.get_default_dtype()))
+            # bh_ops.static_op_bh_copy(self.word_embeddings.weight, get_tensor(state_dict[self.prefix + ".weight"]))
+            self.weight.weight.copy_(
+                get_tensor(state_dict[self.prefix + ".weight"]), False)
+            # .astype(
+            #     paddle.get_default_dtype()))
         else:
-            self.word_embeddings.weight.set_value(
-                get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
-                    paddle.get_default_dtype()))
+            # bh_ops.static_op_bh_copy(self.word_embeddings.weight, get_tensor(state_dict.pop(self.prefix + ".weight")))
+
+            self.weight.weight.copy_(
+                get_tensor(state_dict.pop(self.prefix + ".weight")), False)
+            # .astype(
+            #     paddle.get_default_dtype()))
 
     def forward(self, ids_remove_padding=None) -> paddle.Tensor:
         """
@@ -134,10 +157,10 @@ def forward(self, ids_remove_padding=None) -> paddle.Tensor:
             Tensor: Embedded tensor representation of the input IDs.
         """
         if self.use_ep:
-            input_embedings = self.word_embeddings(ids_remove_padding)
+            input_embedings = self.weight(ids_remove_padding)
         else:
             if self.column_cut:
-                input_embedings = self.word_embeddings(ids_remove_padding)
+                input_embedings = nn.functional.embedding(x=ids_remove_padding, weight=self.weight)
                 inputs_embeds_temp = []
                 paddle.distributed.all_gather(
                     inputs_embeds_temp,
@@ -148,6 +171,6 @@ def forward(self, ids_remove_padding=None) -> paddle.Tensor:
                 )
                 input_embedings = paddle.concat(inputs_embeds_temp, -1)
             else:
-                input_embedings = self.word_embeddings(ids_remove_padding)
+                input_embedings = nn.functional.embedding(x=ids_remove_padding, weight=self.weight)
 
         return input_embedings
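In the column-cut branch of forward(), each rank looks up its slice of the embedding columns with nn.functional.embedding, and the full hidden size is rebuilt by all_gather plus concat on the last axis. Below is a single-process sketch of that split/lookup/concat equivalence; the real code gathers across ranks with paddle.distributed.all_gather, while here the shards simply live in one process.

import paddle
import paddle.nn.functional as F

vocab_size, hidden_size, world_size = 128, 32, 4
full_weight = paddle.randn([vocab_size, hidden_size])
ids = paddle.randint(0, vocab_size, shape=[6])

# Reference lookup against the full, unsharded table.
reference = F.embedding(x=ids, weight=full_weight)

# Column cut: each shard holds hidden_size // world_size columns of the table.
shards = paddle.split(full_weight, num_or_sections=world_size, axis=1)

# Local lookup per shard, then concat along the last axis to rebuild the
# full hidden size (the distributed code uses all_gather for this step).
gathered = [F.embedding(x=ids, weight=shard) for shard in shards]
rebuilt = paddle.concat(gathered, axis=-1)

assert paddle.allclose(reference, rebuilt).item()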
