
from fastdeploy.config import FDConfig

-from .utils import get_tensor
+from .utils import _set_var_distributed, get_tensor


class VocabParallelEmbedding(nn.Layer):
@@ -66,33 +66,34 @@ def __init__(
        self.max_position_embeddings: int = fd_config.model_config.max_position_embeddings
        self.tie_word_embeddings: bool = fd_config.model_config.tie_word_embeddings
        self.params_dtype: str = params_dtype
-
-        if self.use_ep:
-            self.word_embeddings = nn.Embedding(
-                num_embeddings,
-                embedding_dim,
-            )
-        else:
-            if not self.column_cut:
-                self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding(
-                    num_embeddings,
-                    embedding_dim,
-                    mp_group=fleet.get_hybrid_communicate_group().
-                    get_model_parallel_group(),
-                    weight_attr=paddle.ParamAttr(
-                        initializer=nn.initializer.Normal(
-                            mean=0.0, std=self.initializer_range), ),
-                )
-            else:
-                # column cut embedding
-                self.word_embeddings = nn.Embedding(
-                    num_embeddings,
-                    embedding_dim // self.world_size,
-                )
-
-        self.word_embeddings.weight.is_distributed = True
-        self.word_embeddings.weight.split_axis = 1
-
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        # if self.use_ep:
+        #     self.word_embeddings = nn.Embedding(
+        #         num_embeddings,
+        #         embedding_dim,
+        #     )
+        # else:
+        #     if not self.column_cut:
+        #         self.word_embeddings = fleet.meta_parallel.VocabParallelEmbedding(
+        #             num_embeddings,
+        #             embedding_dim,
+        #             mp_group=fleet.get_hybrid_communicate_group().
+        #             get_model_parallel_group(),
+        #             weight_attr=paddle.ParamAttr(
+        #                 initializer=nn.initializer.Normal(
+        #                     mean=0.0, std=self.initializer_range), ),
+        #         )
+        #     else:
+        #         # column cut embedding
+        #         self.word_embeddings = nn.Embedding(
+        #             num_embeddings,
+        #             embedding_dim // self.world_size,
+        #         )
+
+        # self.word_embeddings.weight.is_distributed = True
+        # self.word_embeddings.weight.split_axis = 1
+        self.init_weight()
        if not self.use_rope:
            self.position_embeddings = nn.Embedding(
                self.max_position_embeddings,
@@ -103,6 +104,23 @@ def __init__(

        self.prefix = prefix
        self.dropout = nn.Dropout(self.hidden_dropout_prob)
+    def weight_loader(self, param, loaded_weight):
+        param.copy_(loaded_weight, False)
+    def init_weight(self):
+        from fastdeploy.model_executor.models.utils import set_param_attr
+        self.weight = self.create_parameter(
+            shape=[self.num_embeddings // self.world_size, self.embedding_dim],
+            dtype=paddle.get_default_dtype(),
+            is_bias=False,
+            default_initializer=paddle.nn.initializer.Constant(0),
+        )
+        if self.world_size > 0:
+            if self.column_cut:
+                _set_var_distributed(self.weight, split_axis=1)
+                set_param_attr(self.weight, {"is_column": False, "weight_loader": self.weight_loader})
+            else:
+                _set_var_distributed(self.weight, split_axis=0)
+                set_param_attr(self.weight, {"is_column": True, "weight_loader": self.weight_loader})

    def load_state_dict(self, state_dict: Dict[str,
                                               paddle.Tensor | np.ndarray]):
@@ -112,15 +130,20 @@ def load_state_dict(self, state_dict: Dict[str,
        Args:
            state_dict (dict): A dictionary containing the checkpoint weights and biases.
        """
-        a = state_dict[self.prefix + ".weight"]
+        # a = state_dict[self.prefix + ".weight"]
        if self.tie_word_embeddings:
-            self.word_embeddings.weight.set_value(
-                get_tensor(state_dict[self.prefix + ".weight"]).astype(
-                    paddle.get_default_dtype()))
+            # bh_ops.static_op_bh_copy(self.word_embeddings.weight, get_tensor(state_dict[self.prefix + ".weight"]))
+            self.weight.copy_(
+                get_tensor(state_dict[self.prefix + ".weight"]), False)
+            # .astype(
+            #     paddle.get_default_dtype()))
        else:
-            self.word_embeddings.weight.set_value(
-                get_tensor(state_dict.pop(self.prefix + ".weight")).astype(
-                    paddle.get_default_dtype()))
+            # bh_ops.static_op_bh_copy(self.word_embeddings.weight, get_tensor(state_dict.pop(self.prefix + ".weight")))
+
+            self.weight.copy_(
+                get_tensor(state_dict.pop(self.prefix + ".weight")), False)
+            # .astype(
+            #     paddle.get_default_dtype()))

    def forward(self, ids_remove_padding=None) -> paddle.Tensor:
        """
@@ -134,10 +157,10 @@ def forward(self, ids_remove_padding=None) -> paddle.Tensor:
            Tensor: Embedded tensor representation of the input IDs.
        """
        if self.use_ep:
-            input_embedings = self.word_embeddings(ids_remove_padding)
+            input_embedings = nn.functional.embedding(x=ids_remove_padding, weight=self.weight)
        else:
            if self.column_cut:
-                input_embedings = self.word_embeddings(ids_remove_padding)
+                input_embedings = nn.functional.embedding(x=ids_remove_padding, weight=self.weight)
                inputs_embeds_temp = []
                paddle.distributed.all_gather(
                    inputs_embeds_temp,
@@ -148,6 +171,6 @@ def forward(self, ids_remove_padding=None) -> paddle.Tensor:
                )
                input_embedings = paddle.concat(inputs_embeds_temp, -1)
            else:
-                input_embedings = self.word_embeddings(ids_remove_padding)
+                input_embedings = nn.functional.embedding(x=ids_remove_padding, weight=self.weight)

        return input_embedings
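
As a sanity check on the column-cut branch, looking up each shard separately and concatenating along the last axis reproduces the full embedding. In the single-process sketch below, the local paddle.split stands in for the all_gather across the model-parallel group; all names and sizes are illustrative:

import paddle
from paddle import nn

paddle.seed(0)

vocab_size, hidden_size, world_size = 32, 8, 4
full_weight = paddle.randn([vocab_size, hidden_size])
ids = paddle.to_tensor([[3, 7, 19]])

# Reference: lookup against the unsharded weight.
ref = nn.functional.embedding(x=ids, weight=full_weight)

# Column cut: each simulated rank holds hidden_size // world_size columns.
shards = paddle.split(full_weight, num_or_sections=world_size, axis=1)

# Per-rank partial lookups; the real layer exchanges these with
# paddle.distributed.all_gather instead of a local loop.
partials = [nn.functional.embedding(x=ids, weight=w) for w in shards]

# Concatenating along the last axis rebuilds the full hidden dimension.
out = paddle.concat(partials, axis=-1)
print(bool(paddle.allclose(ref, out)))  # True
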