Commit dd087ef
Refactor prepare_inputs in model_runner_v1.py (vllm-project#2750)
### What this PR does / why we need it?
Refactor prepare_inputs in model_runner_v1.py to make it easier to read.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Passes CI.

- vLLM version: v0.10.1.1
- vLLM main: vllm-project/vllm@e599e2c

---------
Signed-off-by: ChenTaoyu-SJTU <ctynb@qq.com>
1 parent c735bb0 commit dd087ef

File tree

1 file changed: +89 −56 lines


vllm_ascend/worker/model_runner_v1.py

Lines changed: 89 additions & 56 deletions
@@ -880,6 +880,26 @@ def _gather_mm_embeddings(
             mm_embeds.append(mm_embeds_item)
         return mm_embeds
 
+    def _get_cumsum_and_arange(
+        self,
+        num_tokens: np.ndarray,
+        cumsum_dtype: Optional[np.dtype] = None,
+    ) -> tuple[np.ndarray, np.ndarray]:
+        """Get the cumulative sum and batched arange of the given array.
+        # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
+        # Equivalent to but faster than:
+        # np.concatenate([np.arange(n) for n in num_tokens])
+        """
+        # Step 1. [2, 5, 3] -> [2, 7, 10]
+        cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
+        total_num_tokens = cu_num_tokens[-1]
+        # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
+        cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
+        # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        arange = self.arange_np[:total_num_tokens] - cumsums_offsets
+
+        return cu_num_tokens, arange
+
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
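As a standalone illustration (not part of the diff), here is a minimal NumPy sketch of the cumsum-plus-repeat trick that the new `_get_cumsum_and_arange` helper encapsulates. The free function and the fresh `np.arange` below are stand-ins for the runner's preallocated `arange_np` buffer:

import numpy as np

def get_cumsum_and_arange(num_tokens: np.ndarray):
    # [2, 5, 3] -> [2, 7, 10]
    cu_num_tokens = np.cumsum(num_tokens)
    total = cu_num_tokens[-1]
    # Start offset of each request, repeated once per token:
    # [0, 2, 7] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
    offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
    # Subtracting the offsets from a flat arange yields a per-request arange:
    # [0..9] - offsets -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    arange = np.arange(total) - offsets
    return cu_num_tokens, arange

cu, ar = get_cumsum_and_arange(np.array([2, 5, 3]))
print(cu.tolist())  # [2, 7, 10]
print(ar.tolist())  # [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]

This avoids the slower equivalent of concatenating one np.arange per request.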
@@ -901,17 +921,16 @@ def _prepare_inputs(
         self.input_batch.block_table.commit_block_table(num_reqs)
 
         # Get the number of scheduled tokens for each request.
-        # TODO: The Python loop can be slow. Optimize.
-        num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
-        num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
-        max_num_scheduled_tokens = 0
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            num_scheduled_tokens[i] = num_tokens
-            num_valid_tokens[i] = num_tokens - \
-                len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
-                                           num_tokens)
+        req_ids = self.input_batch.req_ids
+        tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
+        num_scheduled_tokens = np.array(tokens, dtype=np.int32)
+        max_num_scheduled_tokens = max(tokens)
+        num_valid_tokens = np.array([
+            num_tokens -
+            len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
+            for num_tokens, i in zip(tokens, req_ids)
+        ],
+                                    dtype=np.int32)
 
         if (self.use_aclgraph and total_num_scheduled_tokens
                 <= self.aclgraph_batch_sizes[-1]):
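For a quick feel of what the vectorized bookkeeping above produces, here is a toy sketch with hypothetical stand-ins for the scheduler output and request ids; the real SchedulerOutput and InputBatch carry far more state:

import numpy as np
from types import SimpleNamespace

# Hypothetical stand-ins: two requests, the second with 2 speculative tokens.
scheduler_output = SimpleNamespace(
    num_scheduled_tokens={"req-a": 4, "req-b": 3},
    scheduled_spec_decode_tokens={"req-b": [101, 102]},
)
req_ids = ["req-a", "req-b"]

tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
num_valid_tokens = np.array([
    n - len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
    for n, i in zip(tokens, req_ids)
], dtype=np.int32)

print(num_scheduled_tokens.tolist())  # [4, 3]
print(max_num_scheduled_tokens)       # 4
print(num_valid_tokens.tolist())      # [4, 1]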
@@ -952,13 +971,15 @@ def _prepare_inputs(
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
-        # Prepare positions
+        # Get request indices.
+        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
         req_indices = np.repeat(self.arange_np[:num_reqs],
                                 num_scheduled_tokens)
-        cu_num_tokens = np.cumsum(num_scheduled_tokens)
-        cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
-                                    num_scheduled_tokens)
-        arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
+
+        # cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
+        # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        cu_num_tokens, arange = self._get_cumsum_and_arange(
+            num_scheduled_tokens)
 
         positions_np = self.positions_np[:total_num_scheduled_tokens]
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
@@ -975,50 +996,73 @@ def _prepare_inputs(
                 self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
                 non_blocking=True)
 
-        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
-        self.positions[:num_input_tokens].copy_(
-            self.positions_cpu[:num_input_tokens], non_blocking=True)
-        positions_cpu = self.positions_cpu[:num_input_tokens]
-        positions = self.positions[:num_input_tokens]
-        self.query_lens = torch.from_numpy(num_scheduled_tokens)
+        # Get token indices.
+        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
+        # where M is the max_model_len.
+        token_indices = (positions_np +
+                         req_indices * self.input_batch.token_ids_cpu.shape[1])
+
+        # Prepare input_ids.
+        # NOTE(woosuk): We use torch.index_select instead of np.take here
+        # because torch.index_select is much faster than np.take for large
+        # tensors.
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
+                           0,
+                           torch.from_numpy(token_indices),
+                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
+
+        # Prepare some information for building Attention-Metadata
+        # Compute and commit slot mapping
+        self.input_batch.block_table.compute_slot_mapping(
+            req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)
+        self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
+            self.input_batch.block_table[0].
+            slot_mapping_cpu[:total_num_scheduled_tokens])
+
+        self.query_start_loc_np[0] = 0
+        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
+        self.query_start_loc[:num_reqs + 1].copy_(
+            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
 
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
-        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
+        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
+                                       non_blocking=True)
 
-        block_table_indices = (req_indices * self.max_num_blocks_per_req +
-                               positions_np // self.block_size)
+        # Fill unused with -1. Needed for reshape_and_cache
+        self.query_start_loc[num_reqs + 1:].fill_(-1)
+        self.seq_lens[num_reqs:].fill_(0)
 
-        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
-        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
-        block_offsets = positions_np % self.block_size
-        np.add(block_numbers * self.block_size,
-               block_offsets,
-               out=self.slot_mapping_np[:total_num_scheduled_tokens])
+        self.query_lens = torch.from_numpy(num_scheduled_tokens)
 
+        # Copy the tensors to the NPU.
+        self.input_ids[:total_num_scheduled_tokens].copy_(
+            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
+
+        self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
+        self.positions[:num_input_tokens].copy_(
+            self.positions_cpu[:num_input_tokens], non_blocking=True)
+
+        # Make Attention metadata
+        positions_cpu = self.positions_cpu[:num_input_tokens]
+        positions = self.positions[:num_input_tokens]
+        seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
         attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
                                             num_valid_tokens)
-
         self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
                                                    position=positions_cpu,
                                                    attn_state=attn_state)
         self.attn_state = attn_state  # type: ignore
 
-        self.query_start_loc_np[0] = 0
-        self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
-        self.query_start_loc[:num_reqs + 1].copy_(
-            self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
-        self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
-                                       non_blocking=True)
-
-        # Fill unused with -1. Needed for reshape_and_cache
-        self.seq_lens[num_reqs:].fill_(0)
-        self.query_start_loc[num_reqs + 1:].fill_(-1)
-
         self.with_prefill = with_prefill
         self.num_tokens_across_dp = num_tokens_across_dp
         self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
+
+        # Make AscendCommonAttentionMetadata
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc[:num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
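The token_indices computation in the hunk above treats the (num_reqs, max_model_len) token buffer as a flat array: each request's row starts at req_index * max_model_len, so adding the per-token positions gives flat gather indices for torch.index_select. A self-contained sketch with made-up sizes (3 requests, max_model_len = 8; in the real runner the positions also include already-computed tokens):

import numpy as np
import torch

max_model_len = 8          # hypothetical row stride of the 2D token buffer
num_reqs = 3
# Hypothetical (num_reqs, max_model_len) buffer of token ids on the CPU.
token_ids_cpu = torch.arange(num_reqs * max_model_len,
                             dtype=torch.int32).reshape(num_reqs, max_model_len)

num_scheduled_tokens = np.array([2, 5, 3])
req_indices = np.repeat(np.arange(num_reqs), num_scheduled_tokens)
# Per-token positions within each request (all requests start at position 0 here).
positions_np = np.concatenate([np.arange(n) for n in num_scheduled_tokens])

# (position, request row) -> index into the flattened buffer.
token_indices = positions_np + req_indices * max_model_len

out = torch.empty(len(token_indices), dtype=torch.int32)
torch.index_select(token_ids_cpu.flatten(), 0,
                   torch.from_numpy(token_indices), out=out)
print(token_indices.tolist())  # [0, 1, 8, 9, 10, 11, 12, 16, 17, 18]
print(out.tolist())            # the gathered token ids, same order as the indices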
@@ -1044,19 +1088,8 @@ def _prepare_inputs(
         if self.vllm_config.model_config.use_mla:
             attn_metadata.num_input_tokens = num_input_tokens
 
-        # Prepare input_ids
-        token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
-        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
-                           0,
-                           torch.from_numpy(token_indices),
-                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
-        # Copy the tensors to the NPU.
-        self.input_ids[:total_num_scheduled_tokens].copy_(
-            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
-
-        # _prepare_inputs may reorder the batch, so we must gather multi
-        # modal outputs after that to ensure the correct order
+        # _prepare_inputs may reorder the batch, so we must gather
+        # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
             self._execute_mm_encoder(scheduler_output)
