@@ -13,7 +13,7 @@
# limitations under the License.


-from typing import OrderedDict
+from typing import OrderedDict, Tuple, Union

import paddle
import paddle.distributed.fleet as fleet
@@ -24,6 +24,7 @@
    SharedLayerDesc,
)
from paddle.distributed.fleet.recompute.recompute import recompute
+from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp

from ...utils.tools import get_env_device
from ..model_utils import PipelinePretrainedModel
@@ -32,6 +33,7 @@
    DeepseekV2DecoderLayer,
    DeepseekV2LMHead,
    DeepseekV2Model,
+    DeepseekV2MTPLayer,
    DeepseekV2PretrainedModel,
    DeepseekV2PretrainingCriterion,
    DeepseekV2RMSNorm,
@@ -46,6 +48,7 @@ def parse_args(args):
    if isinstance(args, tuple):
        if len(args) == 4:
            hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args
+
        elif len(args) == 3:
            hidden_states, attention_mask, attn_mask_startend_row_indices = args
            position_ids = None
@@ -119,25 +122,22 @@ def forward(self, args):
            _type_: _description_
        """
        input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
-        input_embeds = self.embed_tokens(input_ids)
-        if self.config.sequence_parallel:
-            from paddlenlp.transformers import ScatterOp
-
-            # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
-            bs, seq_len, hidden_size = input_embeds.shape
-            input_embeds = paddle.reshape_(input_embeds, [bs * seq_len, hidden_size])
-            # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
-            input_embeds = ScatterOp.apply(input_embeds)
+        inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_ids.shape
+        if self.config.num_nextn_predict_layers > 0:
+            seq_length -= self.config.num_nextn_predict_layers
+
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, : -self.config.num_nextn_predict_layers]

        if attention_mask is not None:
            assert (
                attn_mask_startend_row_indices is None
            ), "attention_mask and attn_mask_startend_row_indices can not be set at same time"

            attention_mask = DeepseekV2Model._prepare_decoder_attention_mask(
-                attention_mask, (batch_size, seq_length), 0, input_embeds.dtype
+                attention_mask, (batch_size, seq_length), 0, inputs_embeds.dtype
            )
            attention_mask.stop_gradient = True
        if get_env_device() == "npu":
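For context, the trimming above reserves the last num_nextn_predict_layers tokens of each sequence for the MTP branch, so the main model runs on seq_len - num_nextn_predict_layers positions. A toy sketch of that trimming (my own example, not part of the patch):

import paddle

num_nextn_predict_layers = 1
input_ids = paddle.arange(10).reshape([1, 10])        # [batch_size, seq_len]
attention_mask = paddle.ones([1, 10], dtype="int64")

batch_size, seq_length = input_ids.shape
seq_length -= num_nextn_predict_layers                 # 10 -> 9 positions for the main model
attention_mask = attention_mask[:, :-num_nextn_predict_layers]

print(seq_length, attention_mask.shape)                # 9 [1, 9]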
@@ -146,13 +146,53 @@ def forward(self, args):
            attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
            attention_mask.stop_gradient = True

-        return return_args(input_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
+        if self.config.num_nextn_predict_layers > 0:
+            inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]  # [B, S, D]
+            inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :]
+            inputs_embeds_ori = inputs_embeds
+            batch_size, seq_length, _ = inputs_embeds.shape
+
+            if self.sequence_parallel:
+                # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
+                inputs_embeds = paddle.reshape(inputs_embeds, [-1, inputs_embeds.shape[-1]])
+                # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
+                inputs_embeds = ScatterOp.apply(inputs_embeds)
+            embeds_res = [inputs_embeds]
+            for depth in range(self.config.num_nextn_predict_layers):
+                inputs_embeds_mtp = paddle.concat(
+                    [
+                        inputs_embeds_ori[:, (depth + 1) :, :],
+                        inputs_embeds_extra[:, : (depth + 1), :],
+                    ],
+                    axis=1,
+                )
+                if self.sequence_parallel:
+                    inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]])
+                    inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp)
+                embeds_res.append(inputs_embeds_mtp)
+            # if not self.sequence_parallel
+            #     mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size]
+            # else:
+            #     mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size]
+            inputs_embeds = paddle.concat(embeds_res, axis=0)
+            return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
+        else:
+            if self.sequence_parallel:
+                inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]])
+                inputs_embeds = ScatterOp.apply(inputs_embeds)
+            return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)


class DeepseekV2DecoderLayerPipe(DeepseekV2DecoderLayer):
    def forward(self, args):
        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)

+        if self.config.num_nextn_predict_layers > 0:
+            _, _, hidden_size = hidden_states.shape
+            hidden_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
+            inputs_embeds_mtp = hidden_states[:, :, -hidden_size_mtp:]
+            hidden_states = hidden_states[:, :, :-hidden_size_mtp]
+
        has_gradient = not hidden_states.stop_gradient

        if attention_mask is not None and attention_mask.dtype == paddle.int32:
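The embedding loop above builds one left-shifted copy of the token embeddings per MTP depth, refilling the tail from the tokens that were held back, and stacks everything for the next stage. A standalone sketch of the shift (my own toy example, not part of the patch):

import paddle

num_nextn_predict_layers = 1
# batch 1, 6 token embeddings, hidden size 1; the last token is held back for MTP
embeds = paddle.arange(6, dtype="float32").reshape([1, 6, 1])

inputs_embeds_extra = embeds[:, -num_nextn_predict_layers:, :]   # held-back tail
inputs_embeds = embeds[:, :-num_nextn_predict_layers, :]         # what the main model sees

shifted = []
for depth in range(num_nextn_predict_layers):
    shifted.append(
        paddle.concat(
            [inputs_embeds[:, (depth + 1):, :], inputs_embeds_extra[:, : (depth + 1), :]],
            axis=1,
        )
    )

print(inputs_embeds.flatten().tolist())   # [0.0, 1.0, 2.0, 3.0, 4.0]
print(shifted[0].flatten().tolist())      # [1.0, 2.0, 3.0, 4.0, 5.0]  (shifted by one token)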
@@ -193,17 +233,91 @@ def forward(self, args):
                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
            )

+        if self.config.num_nextn_predict_layers > 0:
+            hidden_states = paddle.concat([hidden_states, *inputs_embeds_mtp])
+
+        return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
+
+
+class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer):
+    def forward(self, args):
+        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
+
+        hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1)
+        hidden_states_main_model = hidden_states_list[0]
+        inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        has_gradient = not hidden_states_main_model.stop_gradient
+
+        if attention_mask is not None and attention_mask.dtype == paddle.int32:
+            attention_mask, attn_mask_startend_row_indices, position_ids = (
+                None,
+                attention_mask,
+                attn_mask_startend_row_indices,
+            )
+        elif attention_mask is not None and attention_mask.dtype == paddle.int64:
+            attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask
+        elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64:
+            attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
+
+        output_list = [hidden_states_main_model]
+        hidden_states = hidden_states_main_model
+        for depth in range(self.config.num_nextn_predict_layers):
+            inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth]
+            if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
+                if attention_mask is not None or attn_mask_startend_row_indices is not None:
+                    hidden_states = recompute(
+                        super().forward,
+                        hidden_states,
+                        inputs_embeds_cur_depth,
+                        position_ids=position_ids,
+                        attention_mask=attention_mask,
+                        attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                        use_reentrant=False,
+                    )
+                else:
+                    # for pretrain
+                    hidden_states = recompute(
+                        super().forward,
+                        hidden_states,
+                        inputs_embeds_cur_depth,
+                        position_ids=position_ids,
+                        attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                        use_reentrant=self.config.recompute_use_reentrant,
+                    )
+            else:
+                hidden_states = super().forward(
+                    hidden_states,
+                    inputs_embeds_cur_depth,
+                    position_ids=position_ids,
+                    attention_mask=attention_mask,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                )
+            output_list.append(hidden_states)
+
+        hidden_states = paddle.concat(output_list)
        return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)


class DeepseekV2RMSNormPipe(nn.Layer):
    def __init__(self, config):
        super().__init__()
+        self.config = config
        self.norm = DeepseekV2RMSNorm(config)

    def forward(self, args):
        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
-        return self.norm(hidden_states)
+
+        if self.config.num_nextn_predict_layers > 0:
+            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1)
+            hidden_states = hidden_states_list[0]
+            hidden_states_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :]
+
+            output_list = [self.norm(hidden_states)]
+            for hidden_states in hidden_states_mtp:
+                output_list.append(self.norm(hidden_states))
+            return output_list
+        else:
+            return self.norm(hidden_states)


class DeepseekV2LMHeadPipe(DeepseekV2LMHead):
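DeepseekV2MTPLayerPipe and DeepseekV2RMSNormPipe recover the main-model activations and the per-depth MTP activations from one tensor stacked along dim 0 via paddle.split(x, num_nextn_predict_layers + 1). A minimal round-trip check (my own illustration, shapes are made up):

import paddle

num_nextn_predict_layers = 2
main = paddle.zeros([2, 4, 8])                                   # [batch, seq, hidden]
mtp = [paddle.full([2, 4, 8], float(d + 1)) for d in range(num_nextn_predict_layers)]

packed = paddle.concat([main, *mtp])                             # stack along dim 0
parts = paddle.split(packed, num_nextn_predict_layers + 1)       # undo the stacking

print(packed.shape)                       # [6, 4, 8]
print([tuple(p.shape) for p in parts])    # [(2, 4, 8), (2, 4, 8), (2, 4, 8)]
print(bool((parts[1] == 1.0).all()))      # True: depth-0 MTP slice comes back intact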
@@ -214,6 +328,27 @@ def __init__(self, config):
    def embedding_weight(self):
        return get_attr(self, "weight")

+    def forward(self, args: Union[Tuple, paddle.Tensor]):
+        if self.config.num_nextn_predict_layers > 0:
+            logits = []
+            for _hidden_states in args:
+                logits.append(super().forward(_hidden_states))
+            return logits
+        hidden_states = args
+        logits = super().forward(hidden_states)
+        return logits
+
+
+class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterion):
+    def forward(self, logits, labels):
+        if self.config.num_nextn_predict_layers > 0:
+            mtp_logits = logits[1:]
+            logits = logits[0]
+            loss = super().forward(logits, labels, mtp_logits=mtp_logits)
+        else:
+            loss = super().forward(logits, labels)
+        return loss
+


class DeepseekV2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
    """DeepseekV2ForPretraining adapted for pipeline parallelism.
@@ -308,6 +443,12 @@ def get_hcg():
                ),
                f"{self._base_model.base_model_prefix}.layers.{i}",
            )
+        for i in range(config.num_nextn_predict_layers):
+            self.add_sequential_layer(
+                LayerDesc(DeepseekV2MTPLayerPipe, config=config, layer_idx=i),
+                f"{self._base_model.base_model_prefix}.layers.{i}",
+            )
+
        self.add_sequential_layer(LayerDesc(DeepseekV2RMSNormPipe, config=config), self._base_model.base_model_prefix)

        if config.tie_word_embeddings:
@@ -331,7 +472,7 @@ def get_hcg():
            ), "pp recompute interval should smaller than num layers of each pp chunk"
            recompute_interval = self.config.pp_recompute_interval

-        seg_method = "layer:DeepseekV2DecoderLayer"
+        seg_method = "layer:DeepseekV2DecoderLayer|MTPLayer"
        if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0:
            seg_method = "uniform"
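The widened seg_method lets pipeline segmentation account for the MTP layers as well. As far as I can tell, Paddle treats the text after "layer:" as a regular expression matched against layer class names, so the alternation should match both DeepseekV2DecoderLayerPipe and DeepseekV2MTPLayerPipe; a quick check under that assumption (my own snippet):

import re

seg_pattern = "DeepseekV2DecoderLayer|MTPLayer"
layer_names = ["DeepseekV2DecoderLayerPipe", "DeepseekV2MTPLayerPipe", "DeepseekV2RMSNormPipe"]
print([name for name in layer_names if re.search(seg_pattern, name)])
# ['DeepseekV2DecoderLayerPipe', 'DeepseekV2MTPLayerPipe']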
@@ -355,4 +496,4 @@ def get_hcg():
        # PipelinePretrainedModel.__init__(self.super(), config=config)

    def get_loss_fn(self, config):
-        return DeepseekV2PretrainingCriterion(config)
+        return DeepseekV2PretrainingCriterionPipe(config)