Commit c621f0c

BugFix: qwen model sequence parallel can not get batch size (#11147)
* feat(model): add support for Qwen model in RL PipelineParallel
* feat(model): add support for QwenMoe model in RL PipelineParallel
* feat(model): add support for QwenMoe model in RL PipelineParallel
* BugFix: qwen model sequence parallel can not get batch size
1 parent 3f7eafc commit c621f0c

4 files changed: +18 -1 lines changed


paddlenlp/transformers/qwen2/modeling.py

Lines changed: 6 additions & 1 deletion
@@ -1490,7 +1490,12 @@ def forward(self, hidden_states, tensor_parallel_output=None, batch_size=None):

        if self.config.sequence_parallel:
            hidden_states = GatherOp.apply(hidden_states)
-           hidden_states = paddle.reshape_(hidden_states, [batch_size, -1, self.config.hidden_size])
+           if batch_size is not None:
+               hidden_states = paddle.reshape_(hidden_states, [batch_size, -1, self.config.hidden_size])
+           else:
+               hidden_states = paddle.reshape_(
+                   hidden_states, [-1, self.config.max_sequence_length, self.config.hidden_size]
+               )

        if tensor_parallel_output is None:
            tensor_parallel_output = self.config.tensor_parallel_output
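Why the reshape needs batch_size: with sequence_parallel enabled, GatherOp hands back hidden_states flattened to [batch_size * seq_len, hidden_size], so the batch dimension cannot be inferred from the tensor alone. The sketch below mirrors the patched branching with hypothetical sizes; the helper name reshape_for_lm_head and the literal dimensions are illustrative, not part of this commit.

import paddle

batch_size, seq_len, hidden_size = 2, 8, 16   # hypothetical sizes
max_sequence_length = 8                       # stand-in for config.max_sequence_length

# After the sequence-parallel all-gather the activations arrive flattened.
flat = paddle.randn([batch_size * seq_len, hidden_size])

def reshape_for_lm_head(hidden_states, batch_size=None):
    # Mirror of the patched branch: use the batch_size threaded down from the
    # caller when available, otherwise fall back to max_sequence_length.
    if batch_size is not None:
        return paddle.reshape_(hidden_states, [batch_size, -1, hidden_size])
    return paddle.reshape_(hidden_states, [-1, max_sequence_length, hidden_size])

print(reshape_for_lm_head(flat, batch_size=batch_size).shape)                        # [2, 8, 16]
print(reshape_for_lm_head(paddle.randn([batch_size * seq_len, hidden_size])).shape)  # [2, 8, 16]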

paddlenlp/transformers/qwen2/modeling_pp.py

Lines changed: 4 additions & 0 deletions
@@ -172,6 +172,7 @@ def forward(self, args):
        elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64:
            attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices

+       batch_size = position_ids.shape[0]
        if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
            recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute
            if attention_mask is not None or attn_mask_startend_row_indices is not None:

@@ -182,6 +183,7 @@ def forward(self, args):
                    attention_mask=attention_mask,
                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                    use_reentrant=False,
+                   batch_size=batch_size,
                )
            else:
                # for pretrain

@@ -191,13 +193,15 @@ def forward(self, args):
                    position_ids=position_ids,
                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                    use_reentrant=self.config.recompute_use_reentrant,
+                   batch_size=batch_size,
                )
        else:
            hidden_states = super().forward(
                hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+               batch_size=batch_size,
            )

        return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
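The modeling_pp change works because position_ids still arrives with shape [batch_size, seq_len] even when hidden_states has been flattened for sequence parallelism, so the pipeline layer can read the batch size there and thread it down to the reshape above. A minimal sketch of that plumbing, assuming a generic layer_fn callable and toy tensors (all names here are illustrative, not part of the PR):

import paddle

def call_decoder_layer(layer_fn, hidden_states, position_ids, **kwargs):
    # position_ids keeps its leading batch dimension, unlike the flattened
    # sequence-parallel hidden_states, so it is a reliable source of batch_size.
    batch_size = position_ids.shape[0]
    return layer_fn(hidden_states, position_ids=position_ids, batch_size=batch_size, **kwargs)

# Toy layer that just reports what it received.
def toy_layer(hidden_states, position_ids=None, batch_size=None):
    return hidden_states.shape, batch_size

hidden_states = paddle.randn([2 * 8, 32])       # flattened, as under sequence parallel
position_ids = paddle.arange(8).tile([2, 1])    # shape [2, 8]
print(call_decoder_layer(toy_layer, hidden_states, position_ids))  # ([16, 32], 2)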

paddlenlp/transformers/qwen3/modeling_pp.py

Lines changed: 4 additions & 0 deletions
@@ -172,6 +172,7 @@ def forward(self, args):
        elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64:
            attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices

+       batch_size = position_ids.shape[0]
        if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
            recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute
            if attention_mask is not None or attn_mask_startend_row_indices is not None:

@@ -182,6 +183,7 @@ def forward(self, args):
                    attention_mask=attention_mask,
                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                    use_reentrant=False,
+                   batch_size=batch_size,
                )
            else:
                # for pretrain

@@ -191,13 +193,15 @@ def forward(self, args):
                    position_ids=position_ids,
                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                    use_reentrant=self.config.recompute_use_reentrant,
+                   batch_size=batch_size,
                )
        else:
            hidden_states = super().forward(
                hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+               batch_size=batch_size,
            )

        return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)

paddlenlp/transformers/qwen3_moe/modeling_pp.py

Lines changed: 4 additions & 0 deletions
@@ -68,6 +68,7 @@ def forward(self, args):
        elif attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64:
            attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices

+       batch_size = position_ids.shape[0]
        if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient:
            if attention_mask is not None or attn_mask_startend_row_indices is not None:
                hidden_states = recompute(

@@ -77,6 +78,7 @@ def forward(self, args):
                    attention_mask=attention_mask,
                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                    use_reentrant=False,
+                   batch_size=batch_size,
                )
            else:
                # for pretrain

@@ -86,13 +88,15 @@ def forward(self, args):
                    position_ids=position_ids,
                    attn_mask_startend_row_indices=attn_mask_startend_row_indices,
                    use_reentrant=self.config.recompute_use_reentrant,
+                   batch_size=batch_size,
                )
        else:
            hidden_states = super().forward(
                hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+               batch_size=batch_size,
            )

        return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids)
