File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -1404,9 +1404,6 @@ def execute_model(
1404
1404
scheduler_output , intermediate_tensors ))
1405
1405
1406
1406
with ProfileExecuteDuration ().capture_async ("post process" ):
1407
- if self .input_batch .pooling_params :
1408
- return self ._pool (hidden_states , num_scheduled_tokens ,
1409
- num_scheduled_tokens_np )
1410
1407
# Broadcast PP output for external_launcher (torchrun)
1411
1408
# to make sure we are synced across pp ranks
1412
1409
# TODO: Support overlapping mirco-batches
@@ -1423,6 +1420,9 @@ def execute_model(
1423
1420
hidden_states .tensors , all_gather_group = get_tp_group ())
1424
1421
logits = None
1425
1422
else :
1423
+ if self .input_batch .pooling_params :
1424
+ return self ._pool (hidden_states , num_scheduled_tokens ,
1425
+ num_scheduled_tokens_np )
1426
1426
sample_hidden_states = hidden_states [logits_indices ]
1427
1427
logits = self .model .compute_logits (sample_hidden_states , None )
1428
1428
if broadcast_pp_output :
You can’t perform that action at this time.
0 commit comments