7 changes: 6 additions & 1 deletion paddlenlp/transformers/qwen2/modeling.py
@@ -1490,7 +1490,12 @@ def forward(self, hidden_states, tensor_parallel_output=None, batch_size=None):
 
         if self.config.sequence_parallel:
             hidden_states = GatherOp.apply(hidden_states)
-            hidden_states = paddle.reshape_(hidden_states, [batch_size, -1, self.config.hidden_size])
+            if batch_size is not None:
+                hidden_states = paddle.reshape_(hidden_states, [batch_size, -1, self.config.hidden_size])
+            else:
+                hidden_states = paddle.reshape_(
+                    hidden_states, [-1, self.config.max_sequence_length, self.config.hidden_size]
+                )
 
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output
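
The patched branch makes the sequence-parallel reshape tolerate a missing batch size: when the caller threads `batch_size` through (as the pipeline-parallel wrappers below now do), the gathered activations are restored to `[batch_size, -1, hidden_size]`; otherwise the code falls back to assuming every sequence is padded to `config.max_sequence_length`. A minimal sketch of that shape logic, with assumed dimension values standing in for the config:

import paddle

hidden_size, max_sequence_length = 64, 8  # assumed config values for illustration

# Under sequence parallelism, GatherOp yields a flattened [batch * seq, hidden] tensor.
flat = paddle.randn([4 * max_sequence_length, hidden_size])

def reshape_hidden(hidden_states, batch_size=None):
    # Mirrors the patched branch: prefer the explicit batch size, else
    # assume sequences of length max_sequence_length.
    if batch_size is not None:
        return paddle.reshape(hidden_states, [batch_size, -1, hidden_size])
    return paddle.reshape(hidden_states, [-1, max_sequence_length, hidden_size])

print(reshape_hidden(flat, batch_size=4).shape)  # [4, 8, 64]
print(reshape_hidden(flat).shape)  # [4, 8, 64], correct only when seq == max_sequence_length

The fallback yields a wrong batch dimension whenever the actual padded length differs from `max_sequence_length`, which is why the pipeline stages below pass an explicit `batch_size` instead of relying on it.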
5 changes: 5 additions & 0 deletions paddlenlp/transformers/qwen2/modeling_pp.py
@@ -172,6 +172,7 @@ def forward(self, args):
         elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64:
             attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
 
+        batch_size = position_ids.shape[0]
         if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
             recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute
             if attention_mask is not None or attn_mask_startend_row_indices is not None:
@@ -182,6 +183,7 @@ def forward(self, args):
                     attention_mask=attention_mask,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=False,
+                    batch_size=batch_size,
                 )
             else:
                 # for pretrain
@@ -191,13 +193,15 @@ def forward(self, args):
                     position_ids=position_ids,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=self.config.recompute_use_reentrant,
+                    batch_size=batch_size,
                 )
         else:
             hidden_states = super().forward(
                 hidden_states,
                 position_ids=position_ids,
                 attention_mask=attention_mask,
                 attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                batch_size=batch_size,
             )
 
         return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
@@ -232,6 +236,7 @@ class Qwen2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     config_class = Qwen2Config
 
     _get_tensor_parallel_mappings = Qwen2PretrainedModel._get_tensor_parallel_mappings
+    _get_fuse_or_split_param_mappings = Qwen2PretrainedModel._get_fuse_or_split_param_mappings
     _init_weights = Qwen2PretrainedModel._init_weights
     _keys_to_ignore_on_load_unexpected = Qwen2PretrainedModel._keys_to_ignore_on_load_unexpected
     _get_model_flops = Qwen2PretrainedModel._get_model_flops
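
Each pipeline stage now recovers the batch dimension from `position_ids`, whose leading axis is the batch (`position_ids.shape[0]`), and threads it as a keyword through every forward path, including the recomputed ones, so the LM-head reshape above receives it. A minimal sketch of passing an extra keyword through `paddle.distributed.fleet.utils.recompute`, with a toy layer standing in for a decoder stage (the layer and sizes are assumptions for illustration):

import paddle
from paddle.distributed.fleet.utils import recompute

class ToyBlock(paddle.nn.Layer):
    # Stand-in for a pipeline decoder stage whose forward accepts batch_size.
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(16, 16)

    def forward(self, hidden_states, batch_size=None):
        # batch_size rides along as a plain keyword, as in the patch above.
        return self.linear(hidden_states)

block = ToyBlock()
position_ids = paddle.arange(8).tile([2, 1])  # shape [batch_size=2, seq_len=8]
batch_size = position_ids.shape[0]

hidden = paddle.randn([2, 8, 16])
hidden.stop_gradient = False  # recompute needs an input that requires grad
out = recompute(block, hidden, batch_size=batch_size, use_reentrant=False)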
1 change: 1 addition & 0 deletions paddlenlp/transformers/qwen2_moe/modeling_pp.py
@@ -227,6 +227,7 @@ class Qwen2MoeForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     config_class = Qwen2MoeConfig
 
     _get_tensor_parallel_mappings = Qwen2MoePretrainedModel._get_tensor_parallel_mappings
+    _get_fuse_or_split_param_mappings = Qwen2MoePretrainedModel._get_fuse_or_split_param_mappings
     _init_weights = Qwen2MoePretrainedModel._init_weights
     _keys_to_ignore_on_load_unexpected = Qwen2MoePretrainedModel._keys_to_ignore_on_load_unexpected
     _tied_weights_keys = ["lm_head.weight"]
5 changes: 5 additions & 0 deletions paddlenlp/transformers/qwen3/modeling_pp.py
@@ -172,6 +172,7 @@ def forward(self, args):
         elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64:
             attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
 
+        batch_size = position_ids.shape[0]
         if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
             recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute
             if attention_mask is not None or attn_mask_startend_row_indices is not None:
@@ -182,6 +183,7 @@ def forward(self, args):
                     attention_mask=attention_mask,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=False,
+                    batch_size=batch_size,
                 )
             else:
                 # for pretrain
@@ -191,13 +193,15 @@ def forward(self, args):
                     position_ids=position_ids,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=self.config.recompute_use_reentrant,
+                    batch_size=batch_size,
                 )
         else:
             hidden_states = super().forward(
                 hidden_states,
                 position_ids=position_ids,
                 attention_mask=attention_mask,
                 attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                batch_size=batch_size,
             )
 
         return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
@@ -232,6 +236,7 @@ class Qwen3ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     config_class = Qwen3Config
 
     _get_tensor_parallel_mappings = Qwen3PretrainedModel._get_tensor_parallel_mappings
+    _get_fuse_or_split_param_mappings = Qwen3PretrainedModel._get_fuse_or_split_param_mappings
     _init_weights = Qwen3PretrainedModel._init_weights
     _keys_to_ignore_on_load_unexpected = Qwen3PretrainedModel._keys_to_ignore_on_load_unexpected
     _get_model_flops = Qwen3PretrainedModel._get_model_flops
31 changes: 31 additions & 0 deletions paddlenlp/transformers/qwen3_moe/modeling_pp.py
@@ -68,6 +68,7 @@ def forward(self, args):
         elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64:
             attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
 
+        batch_size = position_ids.shape[0]
         if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
             if attention_mask is not None or attn_mask_startend_row_indices is not None:
                 hidden_states = recompute(
@@ -77,6 +78,7 @@ def forward(self, args):
                     attention_mask=attention_mask,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=False,
+                    batch_size=batch_size,
                 )
             else:
                 # for pretrain
@@ -86,13 +88,15 @@ def forward(self, args):
                     position_ids=position_ids,
                     attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                     use_reentrant=self.config.recompute_use_reentrant,
+                    batch_size=batch_size,
                 )
         else:
             hidden_states = super().forward(
                 hidden_states,
                 position_ids=position_ids,
                 attention_mask=attention_mask,
                 attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                batch_size=batch_size,
             )
 
         return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
@@ -127,11 +131,38 @@ class Qwen3MoeForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     config_class = Qwen3MoeConfig
 
     _get_tensor_parallel_mappings = Qwen3MoePretrainedModel._get_tensor_parallel_mappings
+    _get_fuse_or_split_param_mappings = Qwen3MoePretrainedModel._get_fuse_or_split_param_mappings
     _init_weights = Qwen3MoePretrainedModel._init_weights
     _keys_to_ignore_on_load_unexpected = Qwen3MoePretrainedModel._keys_to_ignore_on_load_unexpected
     _tied_weights_keys = ["lm_head.weight"]
 
     # DONOT Add base_model_prefix !!!!
+    @classmethod
+    def get_tensor_parallel_convert_actions(
+        cls, config, loaded_state_dict_keys, is_split=True, ignore_error=False, base_model_prefix=None
+    ):
+        """
+        Get the tensor parallel convert actions for the model.
+        This function is overridden to handle the case where MoE experts are grouped and should not be split across TP ranks.
+        """
+        # Get the default tensor parallel actions from the base class by calling super() with the exact same arguments.
+        tp_actions = super().get_tensor_parallel_convert_actions(
+            config,
+            loaded_state_dict_keys,
+            is_split=is_split,
+            ignore_error=ignore_error,
+            base_model_prefix=base_model_prefix,
+        )
+
+        # If moe_group is set, expert parameters should not be split.
+        # We remove them from the tp_actions dictionary.
+        if "Qwen3MoeForCausalLM" in config.architectures and config.moe_group == "tp":
+            # Iterate over a copy of the keys to safely modify the dictionary
+            for key in list(tp_actions.keys()):
+                if "mlp.experts" in key:
+                    del tp_actions[key]
+
+        return tp_actions
+
     @classmethod
     def _prepare_pipeline_inputs_func(cls, inputs):
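
The override keeps the base class's mapping from parameter names to tensor-parallel split/merge actions and simply drops the expert weights, so that when `config.moe_group == "tp"` the experts stay whole on every rank instead of being sliced. A self-contained sketch of that filtering step, using a hypothetical `tp_actions` dict (the keys and action labels are illustrative, not the exact names PaddleNLP generates):

# Hypothetical mapping of parameter name -> TP conversion action.
tp_actions = {
    "layers.0.self_attn.q_proj.weight": "split_columns",
    "layers.0.mlp.experts.0.gate_proj.weight": "split_columns",
    "layers.0.mlp.experts.1.down_proj.weight": "split_rows",
}

# Iterate over a snapshot of the keys so deleting entries mid-loop is safe.
for key in list(tp_actions.keys()):
    if "mlp.experts" in key:
        del tp_actions[key]  # expert weights stay replicated across TP ranks

print(sorted(tp_actions))  # only the attention projection remains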