Commit 7acaf18

[LLM] Add MTP for Deepseekv3 (#9876)
* mtp
* MTP
* update default config
* update MTP
* update seq_aux_loss
* update output
* lint
* fix for qwen2moe
1 parent 9790880 commit 7acaf18

File tree

8 files changed: +444, -70 lines changed


paddlenlp/transformers/deepseek_v2/configuration.py

Lines changed: 3 additions & 1 deletion
@@ -139,7 +139,8 @@ def __init__(
         intermediate_size=11008,
         moe_intermediate_size=1407,
         num_hidden_layers=30,
-        num_nextn_predict_layers=1,
+        num_nextn_predict_layers=0,
+        num_nextn_predict_lambda=1.0,
         num_attention_heads=32,
         num_key_value_heads=32,
         n_shared_experts=None,
@@ -187,6 +188,7 @@ def __init__(
         self.moe_intermediate_size = moe_intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_nextn_predict_layers = num_nextn_predict_layers
+        self.num_nextn_predict_lambda = num_nextn_predict_lambda
         self.num_attention_heads = num_attention_heads
         self.n_shared_experts = n_shared_experts
         self.n_routed_experts = n_routed_experts
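For reference, a minimal sketch of how the two new fields might be set when enabling MTP. The class name and import path are assumed from this file's location, and the exact semantics of num_nextn_predict_lambda (a loss weight) are an assumption, not documented on this page:

    from paddlenlp.transformers.deepseek_v2.configuration import DeepseekV2Config  # assumed import path

    # MTP is off by default (num_nextn_predict_layers=0); enable one MTP depth:
    config = DeepseekV2Config(
        num_nextn_predict_layers=1,    # number of next-token prediction (MTP) layers
        num_nextn_predict_lambda=0.3,  # assumed: weight of the MTP loss term
    )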

paddlenlp/transformers/deepseek_v2/modeling.py

Lines changed: 203 additions & 49 deletions
Large diffs are not rendered by default.
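The un-rendered modeling.py diff is where DeepseekV2MTPLayer and the mtp_logits handling in DeepseekV2PretrainingCriterion live. As rough orientation only, here is a conceptual stand-in for a DeepSeek-V3-style MTP block, inferred from how DeepseekV2MTPLayerPipe calls it below (it receives the previous depth's hidden states plus the shifted embeddings for the current depth). Everything except that call shape is illustrative, not the file's actual contents:

    import paddle
    import paddle.nn as nn

    class TinyMTPBlock(nn.Layer):
        """Illustrative stand-in, not the real DeepseekV2MTPLayer."""

        def __init__(self, hidden_size):
            super().__init__()
            self.norm_h = nn.LayerNorm(hidden_size)   # the real model uses RMSNorm
            self.norm_e = nn.LayerNorm(hidden_size)
            self.proj = nn.Linear(2 * hidden_size, hidden_size)
            self.mixer = nn.Linear(hidden_size, hidden_size)  # stands in for a full decoder layer

        def forward(self, hidden_states, inputs_embeds_cur_depth):
            # combine the previous depth's hidden states with the shifted token embeddings
            x = paddle.concat(
                [self.norm_h(hidden_states), self.norm_e(inputs_embeds_cur_depth)], axis=-1
            )
            return self.mixer(self.proj(x))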

paddlenlp/transformers/deepseek_v2/modeling_pp.py

Lines changed: 156 additions & 15 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.


-from typing import OrderedDict
+from typing import OrderedDict, Tuple, Union

 import paddle
 import paddle.distributed.fleet as fleet
@@ -24,6 +24,7 @@
     SharedLayerDesc,
 )
 from paddle.distributed.fleet.recompute.recompute import recompute
+from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp

 from ...utils.tools import get_env_device
 from ..model_utils import PipelinePretrainedModel
@@ -32,6 +33,7 @@
     DeepseekV2DecoderLayer,
     DeepseekV2LMHead,
     DeepseekV2Model,
+    DeepseekV2MTPLayer,
     DeepseekV2PretrainedModel,
     DeepseekV2PretrainingCriterion,
     DeepseekV2RMSNorm,
@@ -46,6 +48,7 @@ def parse_args(args):
     if isinstance(args, tuple):
         if len(args) == 4:
             hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args
+
         elif len(args) == 3:
             hidden_states, attention_mask, attn_mask_startend_row_indices = args
             position_ids = None
@@ -119,25 +122,22 @@ def forward(self, args):
             _type_: _description_
         """
         input_ids, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
-        input_embeds = self.embed_tokens(input_ids)
-        if self.config.sequence_parallel:
-            from paddlenlp.transformers import ScatterOp
-
-            # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
-            bs, seq_len, hidden_size = input_embeds.shape
-            input_embeds = paddle.reshape_(input_embeds, [bs * seq_len, hidden_size])
-            # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
-            input_embeds = ScatterOp.apply(input_embeds)
+        inputs_embeds = self.embed_tokens(input_ids)

         batch_size, seq_length = input_ids.shape
+        if self.config.num_nextn_predict_layers > 0:
+            seq_length -= self.config.num_nextn_predict_layers
+
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, : -self.config.num_nextn_predict_layers]

         if attention_mask is not None:
             assert (
                 attn_mask_startend_row_indices is None
             ), "attention_mask and attn_mask_startend_row_indices can not be set at same time"

             attention_mask = DeepseekV2Model._prepare_decoder_attention_mask(
-                attention_mask, (batch_size, seq_length), 0, input_embeds.dtype
+                attention_mask, (batch_size, seq_length), 0, inputs_embeds.dtype
             )
             attention_mask.stop_gradient = True
             if get_env_device() == "npu":
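When MTP is enabled, the embedding stage reserves the last num_nextn_predict_layers positions as extra prediction targets, so both seq_length and the attention mask are trimmed by that amount. A minimal standalone sketch of just that trimming (toy shapes, no model involved):

    import paddle

    num_nextn_predict_layers = 1
    input_ids = paddle.randint(0, 100, shape=[2, 10])
    attention_mask = paddle.ones([2, 10], dtype="int64")

    batch_size, seq_length = input_ids.shape
    if num_nextn_predict_layers > 0:
        seq_length -= num_nextn_predict_layers
        attention_mask = attention_mask[:, :-num_nextn_predict_layers]

    print(seq_length, attention_mask.shape)  # 9 [2, 9]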
@@ -146,13 +146,53 @@ def forward(self, args):
                 attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
                 attention_mask.stop_gradient = True

-        return return_args(input_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
+        if self.config.num_nextn_predict_layers > 0:
+            inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :]  # [B, S, D]
+            inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :]
+            inputs_embeds_ori = inputs_embeds
+            batch_size, seq_length, _ = inputs_embeds.shape
+
+            if self.sequence_parallel:
+                # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
+                inputs_embeds = paddle.reshape(inputs_embeds, [-1, inputs_embeds.shape[-1]])
+                # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
+                inputs_embeds = ScatterOp.apply(inputs_embeds)
+            embeds_res = [inputs_embeds]
+            for depth in range(self.config.num_nextn_predict_layers):
+                inputs_embeds_mtp = paddle.concat(
+                    [
+                        inputs_embeds_ori[:, (depth + 1) :, :],
+                        inputs_embeds_extra[:, : (depth + 1), :],
+                    ],
+                    axis=1,
+                )
+                if self.sequence_parallel:
+                    inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]])
+                    inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp)
+                embeds_res.append(inputs_embeds_mtp)
+            # if not self.sequence_parallel
+            #     mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size]
+            # else:
+            #     mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size]
+            inputs_embeds = paddle.concat(embeds_res, axis=0)
+            return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)
+        else:
+            if self.sequence_parallel:
+                inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[-1]])
+                inputs_embeds = ScatterOp.apply(inputs_embeds)
+            return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids)


 class DeepseekV2DecoderLayerPipe(DeepseekV2DecoderLayer):
     def forward(self, args):
         hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)

+        if self.config.num_nextn_predict_layers > 0:
+            _, _, hidden_size = hidden_states.shape
+            hidden_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1)
+            inputs_embeds_mtp = hidden_states[:, :, -hidden_size_mtp:]
+            hidden_states = hidden_states[:, :, :-hidden_size_mtp]
+
         has_gradient = not hidden_states.stop_gradient

         if attention_mask is not None and attention_mask.dtype == paddle.int32:
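The loop above builds one extra embedding stream per MTP depth: stream k is the main stream shifted left by k+1 tokens, padded at the end with the reserved trailing embeddings, and all streams are concatenated along the batch axis so they flow through the pipeline as a single tensor. A small self-contained sketch of the same shifting (no sequence parallelism, toy shapes):

    import paddle

    B, S, D, N = 2, 6, 4, 1  # batch, main seq_len, hidden_size, num_nextn_predict_layers
    embeds = paddle.randn([B, S + N, D])        # embeddings for seq_len + N tokens
    extra = embeds[:, -N:, :]                   # reserved trailing tokens
    main = embeds[:, :-N, :]                    # stream seen by the main model

    streams = [main]
    for depth in range(N):
        shifted = paddle.concat(
            [main[:, depth + 1 :, :], extra[:, : depth + 1, :]], axis=1
        )                                       # depth-th MTP stream, shifted by depth + 1
        streams.append(shifted)

    packed = paddle.concat(streams, axis=0)     # [(N + 1) * B, S, D], as passed downstream
    print(packed.shape)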
@@ -193,17 +233,91 @@ def forward(self, args):
                 attn_mask_startend_row_indices=attn_mask_startend_row_indices,
             )

+        if self.config.num_nextn_predict_layers > 0:
+            hidden_states = paddle.concat([hidden_states, *inputs_embeds_mtp])
+
+        return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
+
+
+class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer):
+    def forward(self, args):
+        hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
+
+        hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1)
+        hidden_states_main_model = hidden_states_list[0]
+        inputs_embeds_cur_depth_list = hidden_states_list[1:]
+        has_gradient = not hidden_states_main_model.stop_gradient
+
+        if attention_mask is not None and attention_mask.dtype == paddle.int32:
+            attention_mask, attn_mask_startend_row_indices, position_ids = (
+                None,
+                attention_mask,
+                attn_mask_startend_row_indices,
+            )
+        elif attention_mask is not None and attention_mask.dtype == paddle.int64:
+            attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask
+        elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64:
+            attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
+
+        output_list = [hidden_states_main_model]
+        hidden_states = hidden_states_main_model
+        for depth in range(self.config.num_nextn_predict_layers):
+            inputs_embeds_cur_depth = inputs_embeds_cur_depth_list[depth]
+            if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
+                if attention_mask is not None or attn_mask_startend_row_indices is not None:
+                    hidden_states = recompute(
+                        super().forward,
+                        hidden_states,
+                        inputs_embeds_cur_depth,
+                        position_ids=position_ids,
+                        attention_mask=attention_mask,
+                        attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                        use_reentrant=False,
+                    )
+                else:
+                    # for pretrain
+                    hidden_states = recompute(
+                        super().forward,
+                        hidden_states,
+                        inputs_embeds_cur_depth,
+                        position_ids=position_ids,
+                        attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                        use_reentrant=self.config.recompute_use_reentrant,
+                    )
+            else:
+                hidden_states = super().forward(
+                    hidden_states,
+                    inputs_embeds_cur_depth,
+                    position_ids=position_ids,
+                    attention_mask=attention_mask,
+                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                )
+            output_list.append(hidden_states)
+
+        hidden_states = paddle.concat(output_list)
         return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)


 class DeepseekV2RMSNormPipe(nn.Layer):
     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.norm = DeepseekV2RMSNorm(config)

     def forward(self, args):
         hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args)
-        return self.norm(hidden_states)
+
+        if self.config.num_nextn_predict_layers > 0:
+            hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1)
+            hidden_states = hidden_states_list[0]
+            hidden_states_mtp = hidden_states_list[-self.config.num_nextn_predict_layers :]
+
+            output_list = [self.norm(hidden_states)]
+            for hidden_states in hidden_states_mtp:
+                output_list.append(self.norm(hidden_states))
+            return output_list
+        else:
+            return self.norm(hidden_states)


 class DeepseekV2LMHeadPipe(DeepseekV2LMHead):
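DeepseekV2MTPLayerPipe and DeepseekV2RMSNormPipe both rely on the convention that the main stream and the N MTP streams travel concatenated along axis 0 and can be recovered with an even paddle.split. A minimal illustration of that round trip (toy tensors, no model involved):

    import paddle

    N = 2                                        # num_nextn_predict_layers
    main = paddle.zeros([2, 4, 8])
    mtp = [paddle.full([2, 4, 8], float(i + 1)) for i in range(N)]

    packed = paddle.concat([main, *mtp], axis=0)   # what the previous stage emits
    parts = paddle.split(packed, N + 1)            # default axis=0, N + 1 equal chunks
    main_again, mtp_again = parts[0], parts[1:]
    print(main_again.shape, len(mtp_again))        # [2, 4, 8] 2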
@@ -214,6 +328,27 @@ def __init__(self, config):
     def embedding_weight(self):
         return get_attr(self, "weight")

+    def forward(self, args: Union[Tuple, paddle.Tensor]):
+        if self.config.num_nextn_predict_layers > 0:
+            logits = []
+            for _hidden_states in args:
+                logits.append(super().forward(_hidden_states))
+            return logits
+        hidden_states = args
+        logits = super().forward(hidden_states)
+        return logits
+
+
+class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterion):
+    def forward(self, logits, labels):
+        if self.config.num_nextn_predict_layers > 0:
+            mtp_logits = logits[1:]
+            logits = logits[0]
+            loss = super().forward(logits, labels, mtp_logits=mtp_logits)
+        else:
+            loss = super().forward(logits, labels)
+        return loss
+

 class DeepseekV2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     """DeepseekV2ForPretraining adapted for pipeline parallelism.
@@ -308,6 +443,12 @@ def get_hcg():
                 ),
                 f"{self._base_model.base_model_prefix}.layers.{i}",
             )
+        for i in range(config.num_nextn_predict_layers):
+            self.add_sequential_layer(
+                LayerDesc(DeepseekV2MTPLayerPipe, config=config, layer_idx=i),
+                f"{self._base_model.base_model_prefix}.layers.{i}",
+            )
+
         self.add_sequential_layer(LayerDesc(DeepseekV2RMSNormPipe, config=config), self._base_model.base_model_prefix)

         if config.tie_word_embeddings:
@@ -331,7 +472,7 @@ def get_hcg():
         ), "pp recompute interval should smaller than num layers of each pp chunk"
         recompute_interval = self.config.pp_recompute_interval

-        seg_method = "layer:DeepseekV2DecoderLayer"
+        seg_method = "layer:DeepseekV2DecoderLayer|MTPLayer"
         if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0:
             seg_method = "uniform"

@@ -355,4 +496,4 @@ def get_hcg():
         # PipelinePretrainedModel.__init__(self.super(), config=config)

     def get_loss_fn(self, config):
-        return DeepseekV2PretrainingCriterion(config)
+        return DeepseekV2PretrainingCriterionPipe(config)
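The seg_method change above lets the pipeline splitter count both decoder and MTP layers when balancing stages; roughly, the part after "layer:" is treated as a regex matched against the registered layer class names. A quick standalone check of what that pattern matches (illustration only, not the fleet splitter's actual code):

    import re

    pattern = "DeepseekV2DecoderLayer|MTPLayer"   # the part after "layer:" in seg_method
    for name in ["DeepseekV2DecoderLayerPipe", "DeepseekV2MTPLayerPipe", "DeepseekV2RMSNormPipe"]:
        print(name, bool(re.search(pattern, name)))
    # DeepseekV2DecoderLayerPipe True
    # DeepseekV2MTPLayerPipe True
    # DeepseekV2RMSNormPipe False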

paddlenlp/transformers/deepseek_v3/modeling.py

Lines changed: 4 additions & 1 deletion
@@ -142,12 +142,15 @@ def forward(
         )

         hidden_states = outputs[0]
+        mtp_outputs = outputs[-1]
+
         logits = self.lm_head(hidden_states)
+        mtp_logits = [self.lm_head(_hidden_states) for _hidden_states in mtp_outputs] if len(mtp_outputs) > 0 else []

         loss = None
         # TODO@DrownFish19: shift labels
         if labels is not None:
-            loss = self.criterion(logits, labels)
+            loss = self.criterion(logits, labels, mtp_logits=mtp_logits)

         if not return_dict:
             output = (logits,) + outputs[1:]
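The DeepSeek-V3 forward path now pulls the MTP hidden states out of the model outputs, runs each through the shared lm_head, and passes the resulting list to the criterion. A toy version of the same control flow, with a plain Linear standing in for the real DeepseekV2LMHead (illustration only):

    import paddle
    import paddle.nn as nn

    lm_head = nn.Linear(16, 100)                  # stand-in for the tied lm_head
    hidden_states = paddle.randn([2, 8, 16])      # outputs[0]
    mtp_outputs = [paddle.randn([2, 8, 16])]      # outputs[-1], one tensor per MTP depth

    logits = lm_head(hidden_states)
    mtp_logits = [lm_head(h) for h in mtp_outputs] if len(mtp_outputs) > 0 else []
    print(logits.shape, len(mtp_logits))          # [2, 8, 100] 1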

paddlenlp/transformers/model_outputs.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,49 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
676676
cum_offsets: Optional[Tuple[paddle.Tensor]] = None
677677

678678

679+
@dataclass
680+
class BaseModelOutputWithPastAndMTP(ModelOutput):
681+
"""
682+
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
683+
684+
Args:
685+
last_hidden_state (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
686+
Sequence of hidden-states at the output of the last layer of the model.
687+
688+
If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
689+
hidden_size)` is output.
690+
past_key_values (`tuple(tuple(paddle.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
691+
Tuple of `tuple(paddle.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
692+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
693+
`config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
694+
encoder_sequence_length, embed_size_per_head)`.
695+
696+
Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
697+
`config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
698+
input) to speed up sequential decoding.
699+
hidden_states (`tuple(paddle.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
700+
Tuple of `paddle.Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
701+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
702+
703+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
704+
attentions (`tuple(paddle.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
705+
Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
706+
sequence_length)`.
707+
708+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
709+
heads.
710+
mtp_outputs (`tuple(paddle.Tensor)`, *optional*):
711+
MTP Layers outputs, used to compute the mtp loss.
712+
heads.
713+
"""
714+
715+
last_hidden_state: paddle.Tensor = None
716+
past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None
717+
hidden_states: Optional[Tuple[paddle.Tensor]] = None
718+
attentions: Optional[Tuple[paddle.Tensor]] = None
719+
mtp_outputs: Optional[Tuple[paddle.Tensor]] = None
720+
721+
679722
@dataclass
680723
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
681724
"""
