Commit e63fe82

Prompt Embeddings Support for v1 Engine
Signed-off-by: jesse <szxfml@gmail.com>
1 parent 0c04bf1 commit e63fe82

File tree

2 files changed: +119 lines, -15 lines

vllm_ascend/worker/model_runner_v1.py

Lines changed: 71 additions & 8 deletions
@@ -67,7 +67,8 @@
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LazyLoader, cdiv, get_dtype_size,
-                        is_pin_memory_available)
+                        is_pin_memory_available,
+                        length_from_prompt_token_ids_or_embeds)
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import \
     reorder_batch_to_split_decodes_and_prefills
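
The newly imported length_from_prompt_token_ids_or_embeds helper replaces the len(prompt_token_ids) calls that assumed token IDs are always present. Its actual implementation lives in vllm.utils; the sketch below only illustrates the behavior this diff relies on (the body is an assumption, not the upstream source):

from typing import Optional

import torch


def length_from_prompt_token_ids_or_embeds(
        prompt_token_ids: Optional[list[int]],
        prompt_embeds: Optional[torch.Tensor]) -> int:
    # A prompt is described either by token IDs or by an embedding tensor of
    # shape (num_prompt_tokens, hidden_size); report the length of whichever
    # representation is present.
    if prompt_token_ids is not None:
        return len(prompt_token_ids)
    if prompt_embeds is not None:
        return prompt_embeds.shape[0]
    raise ValueError("Request has neither prompt_token_ids nor prompt_embeds")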
@@ -285,11 +286,14 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
 
         self.is_multimodal_model = self.model_config.is_multimodal_model
         self.is_pooling_model = self.model_config.pooler_config is not None
+        self.enable_prompt_embeds = self.model_config.enable_prompt_embeds
         if self.is_multimodal_model:
-            self.inputs_embeds = torch.zeros(
-                (self.max_num_tokens, self.model_config.get_hidden_size()),
-                dtype=self.dtype,
-                device=self.device)
+            self.inputs_embeds = self._make_buffer(self.max_num_tokens,
+                                                   self.model_config.get_hidden_size(),
+                                                   dtype=self.dtype,
+                                                   numpy=False)
+            self.is_token_ids = self._make_buffer(self.max_num_tokens,
+                                                  dtype=torch.bool)
 
         # Set up Attention
         self.attn_backend = get_attn_backend(
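
The plain torch.zeros allocation is replaced by _make_buffer, which pairs a host staging tensor with its device mirror; later hunks use its .cpu and .gpu views and call copy_to_gpu() once per step. A minimal sketch of that buffer pattern, with an assumed class name and constructor (the real _make_buffer is provided by the model runner, not reproduced here):

import torch


class StagingBuffer:
    """Hypothetical stand-in for the object returned by _make_buffer."""

    def __init__(self, *size: int, dtype: torch.dtype,
                 device: torch.device = torch.device("cpu")):
        # Host-side staging area, filled while the batch is being prepared.
        # (The real buffer uses pinned host memory for async H2D copies.)
        self.cpu = torch.zeros(size, dtype=dtype)
        # Device-side mirror that the model actually consumes.
        self.gpu = torch.zeros(size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Only the first n rows are valid for the current step.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]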
@@ -572,6 +576,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
+                prompt_embeds=new_req_data.prompt_embeds,
                 mm_kwargs=new_req_data.mm_kwargs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
@@ -843,7 +848,8 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
                 self.input_batch.num_computed_tokens_cpu[index]
             num_scheduled_tokens = \
                 scheduler_output.num_scheduled_tokens[req_id]
-            num_prompt_tokens = len(req.prompt_token_ids)
+            num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+                req.prompt_token_ids, req.prompt_embeds)
 
             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                 prompt_part_len = max(0,
@@ -1016,6 +1022,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
             self.input_ids[:total_num_scheduled_tokens].copy_(
                 self.input_ids_cpu[:total_num_scheduled_tokens],
                 non_blocking=True)
+            self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+            self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
             return
 
         # Async scheduling case, where some decode requests from the previous
@@ -1043,6 +1051,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
         self.input_ids[:total_num_scheduled_tokens].copy_(
             self.input_ids_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
+        self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+        self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
         if num_commmon_tokens == 0:
             # No requests in common with the previous iteration
             # So input_ids_cpu will have all the input ids.
@@ -1056,6 +1066,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
                 self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
                                                         0],
                 non_blocking=True)
+            self.is_token_ids.gpu[:num_commmon_tokens] = True
             return
         # Upload the index tensors asynchronously
         # so the scatter can be non-blocking.
@@ -1195,15 +1206,60 @@ def _prepare_inputs(
         # where M is the max_model_len.
         token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])
-
+        token_indices_tensor = torch.from_numpy(token_indices)
         # Prepare input_ids.
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
         # tensors.
         torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                            0,
-                           torch.from_numpy(token_indices),
+                           token_indices_tensor,
                            out=self.input_ids_cpu[:total_num_scheduled_tokens])
+        is_token_ids = self.input_batch.is_token_ids.flatten()
+        torch.index_select(
+            is_token_ids,
+            0,
+            token_indices_tensor,
+            out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
+
+        # Because we did not pre-allocate a massive prompt_embeds CPU tensor on
+        # the InputBatch, we need to fill in the prompt embeds into the expected
+        # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
+        if self.input_batch.req_prompt_embeds:
+            output_idx = 0
+            for req_idx in range(num_reqs):
+                num_sched = num_scheduled_tokens[req_idx]
+
+                # Skip if this request doesn't have embeddings
+                if req_idx not in self.input_batch.req_prompt_embeds:
+                    output_idx += num_sched
+                    continue
+
+                # Skip if no tokens scheduled
+                if num_sched <= 0:
+                    output_idx += num_sched
+                    continue
+
+                req_embeds = self.input_batch.req_prompt_embeds[req_idx]
+                start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]
+
+                # Skip if trying to read beyond available embeddings
+                if start_pos >= req_embeds.shape[0]:
+                    output_idx += num_sched
+                    continue
+
+                # Copy available embeddings
+                end_pos = start_pos + num_sched
+                actual_end = min(end_pos, req_embeds.shape[0])
+                actual_num_sched = actual_end - start_pos
+
+                if actual_num_sched > 0:
+                    self.inputs_embeds.cpu[output_idx:output_idx +
+                                           actual_num_sched].copy_(
+                                               req_embeds[start_pos:actual_end]
+                                           )
+
+                output_idx += num_sched
 
         # Prepare some information for building Attention-Metadata
         # Compute and commit slot mapping
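
The comment in this hunk states the key design point: prompt embeddings stay per request on the InputBatch and are scattered into the runner's flat staging buffer only for the tokens scheduled this step. A small, self-contained rehearsal of that packing (plain PyTorch with made-up batch numbers, not code from the commit):

import torch

hidden_size = 4
num_scheduled_tokens = [3, 2]   # tokens scheduled this step, per request
num_computed_tokens = [0, 1]    # prefill progress, per request
req_prompt_embeds = {1: torch.randn(5, hidden_size)}  # only request 1 uses embeds

inputs_embeds_cpu = torch.zeros(sum(num_scheduled_tokens), hidden_size)

output_idx = 0
for req_idx, num_sched in enumerate(num_scheduled_tokens):
    if req_idx in req_prompt_embeds and num_sched > 0:
        req_embeds = req_prompt_embeds[req_idx]
        start = num_computed_tokens[req_idx]
        end = min(start + num_sched, req_embeds.shape[0])
        if start < end:
            # Request 1 fills rows 3..4 of the flat buffer with its prompt
            # embeddings for positions 1..2.
            inputs_embeds_cpu[output_idx:output_idx + (end - start)] = \
                req_embeds[start:end]
    output_idx += num_sched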
@@ -1985,6 +2041,7 @@ def execute_model(
 
                 self.input_batch.token_ids_cpu[req_idx,
                                                start_idx:end_idx] = sampled_ids
+                self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
                 self.input_batch.num_tokens_no_spec[req_idx] = end_idx
                 self.input_batch.num_tokens[req_idx] = end_idx
                 req_id = self.input_batch.req_ids[req_idx]
@@ -2200,6 +2257,9 @@ def _dummy_run(
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]
+        elif self.enable_prompt_embeds:
+            input_ids = None
+            inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
         else:
             input_ids = self.input_ids[:num_tokens]
             inputs_embeds = None
@@ -3070,6 +3130,9 @@ def _get_prompt_logprobs_dict(
 
             # Get metadata for this request.
             request = self.requests[req_id]
+            if request.prompt_token_ids is None:
+                # Prompt logprobs is incompatible with prompt embeddings
+                continue
             num_prompt_tokens = len(request.prompt_token_ids)
             prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                 self.device, non_blocking=True)

vllm_ascend/worker/npu_input_batch.py

Lines changed: 48 additions & 7 deletions
@@ -28,7 +28,7 @@
                                      PlaceholderRange)
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.utils import swap_dict_values
+from vllm.utils import length_from_prompt_token_ids_or_embeds,swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
@@ -45,7 +45,7 @@
 class CachedRequestState:
 
     req_id: str
-    prompt_token_ids: list[int]
+    prompt_token_ids: Optional[list[int]]
     mm_kwargs: list[MultiModalKwargsItem]
     mm_positions: list[PlaceholderRange]
     mm_hashes: list[str]
@@ -61,9 +61,11 @@ class CachedRequestState:
     mrope_position_delta: Optional[int] = None
 
     lora_request: Optional[LoRARequest] = None
+    prompt_embeds: Optional[torch.Tensor] = None
 
     def __post_init__(self):
-        self.num_prompt_tokens = len(self.prompt_token_ids)
+        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+            self.prompt_token_ids, self.prompt_embeds)
 
     @property
     def num_tokens(self) -> int:
@@ -78,6 +80,10 @@ def mm_inputs(self) -> list[MultiModalKwargs]:
 
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
+            if self.prompt_token_ids is None:
+                raise ValueError(
+                    f"Tried to access token index {idx}, but that token was "
+                    "provided via prompt_embeds, and its ID is unknown.")
             return self.prompt_token_ids[idx]
         else:
             return self.output_token_ids[idx - self.num_prompt_tokens]
@@ -122,6 +128,14 @@ def __init__(
             pin_memory=False,
         )
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
+        self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
+                                        device="cpu",
+                                        dtype=bool,
+                                        pin_memory=False)
+        # Store prompt embeddings per request to avoid OOM from large upfront
+        # allocation if max_model_len is big.
+        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
+        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
         self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
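
The OOM concern in the new comment is easy to quantify. A rough back-of-the-envelope comparison with assumed, illustrative sizes (not measurements from this commit):

# Dense layout: every request slot reserves max_model_len rows of embeddings.
max_num_reqs, max_model_len, hidden_size, bytes_per_elem = 1024, 32768, 4096, 2
dense_bytes = max_num_reqs * max_model_len * hidden_size * bytes_per_elem
print(f"pre-allocated prompt_embeds tensor: {dense_bytes / 2**30:.0f} GiB")  # 256 GiB

# Per-request dict: only the prompts actually in flight are held.
live_reqs, avg_prompt_len = 256, 1024
dict_bytes = live_reqs * avg_prompt_len * hidden_size * bytes_per_elem
print(f"per-request dict: {dict_bytes / 2**30:.0f} GiB")  # 2 GiB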
@@ -326,15 +340,23 @@ def add_request(
         self.req_id_to_index[req_id] = req_index
 
         # Copy the prompt token ids and output token ids.
-        num_prompt_tokens = len(request.prompt_token_ids)
+        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+            request.prompt_token_ids, request.prompt_embeds)
         self.num_prompt_tokens[req_index] = num_prompt_tokens
-        self.token_ids_cpu[
-            req_index, :num_prompt_tokens] = request.prompt_token_ids
         start_idx = num_prompt_tokens
         end_idx = start_idx + len(request.output_token_ids)
+        if request.prompt_token_ids is not None:
+            self.token_ids_cpu[
+                req_index, :num_prompt_tokens] = request.prompt_token_ids
+            self.is_token_ids[req_index, :num_prompt_tokens] = True
+        else:
+            self.is_token_ids[req_index, :num_prompt_tokens] = False
+        if request.prompt_embeds is not None:
+            self.req_prompt_embeds[req_index] = request.prompt_embeds
         self.token_ids_cpu[req_index,
                            start_idx:end_idx] = request.output_token_ids
-        # Number of token ids in token_ids_cpu.
+        self.is_token_ids[req_index, start_idx:end_idx] = True
+        # Number of token ids in prompt (token_ids_cpu or prompt_embeds).
         # NOTE(woosuk): This may include spec decode tokens.
         self.num_tokens[req_index] = request.num_tokens
         # Number of tokens without spec decode tokens.
@@ -534,6 +556,20 @@ def swap_states(self, i1: int, i2: int) -> None:
         self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
         self.token_ids_cpu[i2, ...] = tmp
 
+        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
+
+        # Swap prompt embeddings if they exist
+        embeds_i1 = self.req_prompt_embeds.get(i1)
+        embeds_i2 = self.req_prompt_embeds.get(i2)
+        if embeds_i1 is not None:
+            self.req_prompt_embeds[i2] = embeds_i1
+        else:
+            self.req_prompt_embeds.pop(i2, None)
+        if embeds_i2 is not None:
+            self.req_prompt_embeds[i1] = embeds_i2
+        else:
+            self.req_prompt_embeds.pop(i1, None)
+
         swap_dict_values(self.generators, i1, i2)
         swap_dict_values(self.bad_words_token_ids, i1, i2)

@@ -612,6 +648,11 @@ def condense(self) -> None:
             num_tokens = self.num_tokens[last_req_index]
             self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                 last_req_index, :num_tokens]
+            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
+                last_req_index, :num_tokens]
+            if last_req_index in self.req_prompt_embeds:
+                self.req_prompt_embeds[
+                    empty_index] = self.req_prompt_embeds.pop(last_req_index)
             self.num_tokens[empty_index] = num_tokens
             self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                 last_req_index]
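
End to end, the change lets the v1 NPU worker accept requests whose prompt arrives as an embedding tensor instead of token IDs. A hedged usage sketch against the public vLLM API (model name, shapes, and dtype are placeholders; the exact request format should be checked against the vLLM version in use):

import torch
from vllm import LLM, SamplingParams

# enable_prompt_embeds is the flag the model runner reads above.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", enable_prompt_embeds=True)

# (num_prompt_tokens, hidden_size) embeddings, e.g. from the model's own
# input embedding layer or an external encoder.
prompt_embeds = torch.randn(16, 1536, dtype=torch.bfloat16)

outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)

Note the guard added in _get_prompt_logprobs_dict above: requests supplied via prompt_embeds skip prompt logprobs, since their prompt token IDs are unknown.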
