@@ -880,6 +880,26 @@ def _gather_mm_embeddings(
                mm_embeds.append(mm_embeds_item)
        return mm_embeds

+    def _get_cumsum_and_arange(
+        self,
+        num_tokens: np.ndarray,
+        cumsum_dtype: Optional[np.dtype] = None,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Get the cumulative sum and batched arange of the given array.
+        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
+        # Equivalent to but faster than:
+        # np.concatenate([np.arange(n) for n in num_tokens])
+        """
+        # Step 1. [2, 5, 3] -> [2, 7, 10]
+        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
+        total_num_tokens = cu_num_tokens[-1]
+        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
+        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
+        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        arange = self.arange_np[:total_num_tokens] - cumsums_offsets
+
+        return cu_num_tokens, arange
+
    def _prepare_inputs(
        self,
        scheduler_output: "SchedulerOutput",
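A standalone NumPy sketch of what the new _get_cumsum_and_arange helper computes (illustrative only; the helper slices the preallocated self.arange_np buffer instead of calling np.arange):

import numpy as np

num_tokens = np.array([2, 5, 3], dtype=np.int32)

cu_num_tokens = np.cumsum(num_tokens)        # [2, 7, 10]
total_num_tokens = cu_num_tokens[-1]         # 10
offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
# offsets == [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
arange = np.arange(total_num_tokens) - offsets
# arange == [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]

# Same result as the slower per-request loop the docstring mentions:
assert (arange == np.concatenate([np.arange(n) for n in num_tokens])).all()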
@@ -901,17 +921,16 @@ def _prepare_inputs(
        self.input_batch.block_table.commit_block_table(num_reqs)

        # Get the number of scheduled tokens for each request.
-        # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
-        num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
-        max_num_scheduled_tokens = 0
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens[i] = num_tokens
-            num_valid_tokens[i] = num_tokens - \
-                len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
-                                           num_tokens)
+        req_ids = self.input_batch.req_ids
+        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
+        max_num_scheduled_tokens = max(tokens)
+        num_valid_tokens = np.array([
+            num_tokens -
+            len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
+            for num_tokens, i in zip(tokens, req_ids)
+        ],
+                                    dtype=np.int32)

        if (self.use_aclgraph and total_num_scheduled_tokens
                <= self.aclgraph_batch_sizes[-1]):
@@ -952,13 +971,15 @@ def _prepare_inputs(
        if self.lora_config:
            self.set_active_loras(self.input_batch, num_scheduled_tokens)

-        # Prepare positions
+        # Get request indices.
+        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        req_indices = np.repeat(self.arange_np[:num_reqs],
                                num_scheduled_tokens)
-        cu_num_tokens = np.cumsum(num_scheduled_tokens)
-        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
-                                    num_scheduled_tokens)
-        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
+
+        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
+        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        cu_num_tokens, arange = self._get_cumsum_and_arange(
+            num_scheduled_tokens)

        positions_np = self.positions_np[:total_num_scheduled_tokens]
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
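A small NumPy sketch of how per-token positions fall out of req_indices and arange in the np.add(...) call this hunk truncates; the num_computed_tokens values are made up for illustration:

import numpy as np

num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
num_computed_tokens = np.array([10, 0, 4], dtype=np.int32)   # hypothetical

num_reqs = len(num_scheduled_tokens)
req_indices = np.repeat(np.arange(num_reqs), num_scheduled_tokens)
# req_indices == [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]

cu_num_tokens = np.cumsum(num_scheduled_tokens)
offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, num_scheduled_tokens)
arange = np.arange(cu_num_tokens[-1]) - offsets
# arange == [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]

positions = num_computed_tokens[req_indices] + arange
# positions == [10, 11, 0, 1, 2, 3, 4, 4, 5, 6]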
@@ -975,50 +996,73 @@ def _prepare_inputs(
                self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                non_blocking=True)

-        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
-        self.positions[:num_input_tokens].copy_(
-            self.positions_cpu[:num_input_tokens], non_blocking=True)
-        positions_cpu = self.positions_cpu[:num_input_tokens]
-        positions = self.positions[:num_input_tokens]
-        self.query_lens = torch.from_numpy(num_scheduled_tokens)
+        # Get token indices.
+        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
+        # where M is the max_model_len.
+        token_indices = (positions_np +
+                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+
+        # Prepare input_ids.
+        # NOTE(woosuk): We use torch.index_select instead of np.take here
+        # because torch.index_select is much faster than np.take for large
+        # tensors.
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices),
+                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
+
+        # Prepare some information for building Attention-Metadata
+        # Compute and commit slot mapping
+        self.input_batch.block_table.compute_slot_mapping(
+            req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)
+        self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
+            self.input_batch.block_table[0].
+            slot_mapping_cpu[:total_num_scheduled_tokens])
+
+        self.query_start_loc_np[0] = 0
+        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)

        self.seq_lens_np[:num_reqs] = (
            self.input_batch.num_computed_tokens_cpu[:num_reqs] +
            num_scheduled_tokens)
-        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)

-        block_table_indices = (req_indices * self.max_num_blocks_per_req +
-                               positions_np // self.block_size)
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)

-        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
-        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
-        block_offsets = positions_np % self.block_size
-        np.add(block_numbers * self.block_size,
-               block_offsets,
-               out=self.slot_mapping_np[:total_num_scheduled_tokens])
+        self.query_lens = torch.from_numpy(num_scheduled_tokens)

+        # Copy the tensors to the NPU.
+        self.input_ids[:total_num_scheduled_tokens].copy_(
+            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
+
+        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
+        self.positions[:num_input_tokens].copy_(
+            self.positions_cpu[:num_input_tokens], non_blocking=True)
+
+        # Make Attention metadata
+        positions_cpu = self.positions_cpu[:num_input_tokens]
+        positions = self.positions[:num_input_tokens]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
        attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                            num_valid_tokens)
-
        self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
                                                   position=positions_cpu,
                                                   attn_state=attn_state)
        self.attn_state = attn_state  # type: ignore

-        self.query_start_loc_np[0] = 0
-        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
-        self.query_start_loc[:num_reqs + 1].copy_(
-            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
-        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
-                                       non_blocking=True)
-
-        # Fill unused with -1. Needed for reshape_and_cache
-        self.seq_lens[num_reqs:].fill_(0)
-        self.query_start_loc[num_reqs + 1:].fill_(-1)
-
        self.with_prefill = with_prefill
        self.num_tokens_across_dp = num_tokens_across_dp
        self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
+
+        # Make AscendCommonAttentionMetadata
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=self.query_start_loc[:num_reqs + 1],
            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
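A tiny sketch of the flattened gather used above to build input_ids: token_indices are row-major offsets into the (num_reqs, max_model_len) token table, so a single torch.index_select over the flattened table replaces a 2-D gather. Shapes and values are made up for illustration:

import numpy as np
import torch

max_model_len = 8                      # stands in for token_ids_cpu.shape[1]
token_ids_cpu = torch.arange(3 * max_model_len, dtype=torch.int32).reshape(3, -1)

req_indices = np.array([0, 0, 1, 1, 1], dtype=np.int64)
positions_np = np.array([3, 4, 0, 1, 2], dtype=np.int64)

token_indices = positions_np + req_indices * max_model_len
# token_indices == [3, 4, 8, 9, 10]

input_ids = torch.index_select(token_ids_cpu.flatten(), 0,
                               torch.from_numpy(token_indices))
# Same as token_ids_cpu[req_indices, positions_np], but as one flat gather.
assert torch.equal(
    input_ids, token_ids_cpu[torch.from_numpy(req_indices),
                             torch.from_numpy(positions_np)])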
@@ -1044,19 +1088,8 @@ def _prepare_inputs(
        if self.vllm_config.model_config.use_mla:
            attn_metadata.num_input_tokens = num_input_tokens

-        # Prepare input_ids
-        token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
-        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
-                           0,
-                           torch.from_numpy(token_indices),
-                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
-        # Copy the tensors to the NPU.
-        self.input_ids[:total_num_scheduled_tokens].copy_(
-            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
-
-        # _prepare_inputs may reorder the batch, so we must gather multi
-        # modal outputs after that to ensure the correct order
+        # _prepare_inputs may reorder the batch, so we must gather
+        # multi-modal outputs after that to ensure the correct order
        if self.is_multimodal_model:
            # Run the multimodal encoder if any.
            self._execute_mm_encoder(scheduler_output)
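For context, a sketch of the slot-mapping arithmetic that the deleted block_table_indices / block_numbers / block_offsets lines (old lines 990-998 above) computed by hand, and which block_table.compute_slot_mapping presumably now handles internally; block-table contents and positions are made up for illustration:

import numpy as np

block_size = 4
block_table = np.array([[7, 2],        # request 0 owns physical blocks 7, 2
                        [5, 9]])       # request 1 owns physical blocks 5, 9

req_indices = np.array([0, 0, 1, 1, 1])
positions_np = np.array([3, 4, 0, 1, 2])

block_numbers = block_table[req_indices, positions_np // block_size]
block_offsets = positions_np % block_size
slot_mapping = block_numbers * block_size + block_offsets
# slot_mapping == [31, 8, 20, 21, 22]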