 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LazyLoader, cdiv, get_dtype_size,
-                        is_pin_memory_available)
+                        is_pin_memory_available,
+                        length_from_prompt_token_ids_or_embeds)
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import \
     reorder_batch_to_split_decodes_and_prefills
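
The newly imported `length_from_prompt_token_ids_or_embeds` helper resolves a prompt length from whichever representation a request carries. A minimal sketch of such a utility, under the assumption that embeds are shaped `(num_prompt_tokens, hidden_size)` (the real implementation lives in `vllm.utils` and may differ):

```python
from typing import Optional

import torch


def length_from_prompt_token_ids_or_embeds(
    prompt_token_ids: Optional[list[int]],
    prompt_embeds: Optional[torch.Tensor],
) -> int:
    """Prompt length from token IDs if present, else from the embeds tensor."""
    if prompt_token_ids is not None:
        return len(prompt_token_ids)
    if prompt_embeds is not None:
        # Assumed layout: (num_prompt_tokens, hidden_size).
        return prompt_embeds.shape[0]
    raise ValueError("Either prompt_token_ids or prompt_embeds must be set.")
```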
@@ -285,11 +286,14 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
 
         self.is_multimodal_model = self.model_config.is_multimodal_model
         self.is_pooling_model = self.model_config.pooler_config is not None
+        self.enable_prompt_embeds = self.model_config.enable_prompt_embeds
         if self.is_multimodal_model:
-            self.inputs_embeds = torch.zeros(
-                (self.max_num_tokens, self.model_config.get_hidden_size()),
-                dtype=self.dtype,
-                device=self.device)
+            self.inputs_embeds = self._make_buffer(self.max_num_tokens,
+                                                   self.model_config.get_hidden_size(),
+                                                   dtype=self.dtype,
+                                                   numpy=False)
+            self.is_token_ids = self._make_buffer(self.max_num_tokens,
+                                                  dtype=torch.bool)
 
         # Set up Attention
         self.attn_backend = get_attn_backend(
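
The plain `torch.zeros` allocation is replaced by `self._make_buffer(...)`, and later hunks access the result through `.cpu`, `.gpu`, and `copy_to_gpu(n)`. A sketch of the paired-buffer object this interface appears to assume (names and constructor arguments here are illustrative, not vLLM's actual `_make_buffer`):

```python
import torch


class CpuGpuBuffer:
    """Pinned CPU staging tensor mirrored by a GPU tensor of the same shape."""

    def __init__(self, *size: int, dtype: torch.dtype, device: torch.device,
                 pin_memory: bool = False):
        self.cpu = torch.zeros(size, dtype=dtype)
        if pin_memory:
            self.cpu = self.cpu.pin_memory()
        self.gpu = torch.zeros(size, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Stage only the first n rows; non_blocking pairs with pinned memory.
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
```

The later `_prepare_input_ids` hunks then fill `self.inputs_embeds.cpu` / `self.is_token_ids.cpu` on the host and issue a single `copy_to_gpu(total_num_scheduled_tokens)` per step.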
@@ -572,6 +576,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
+                prompt_embeds=new_req_data.prompt_embeds,
                 mm_kwargs=new_req_data.mm_kwargs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
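
`CachedRequestState` now also carries the request's prompt embeddings. A rough sketch of the relevant fields, assuming exactly one of the two prompt representations is set per request (the real dataclass in `vllm/v1/worker` has many more fields):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class CachedRequestState:
    req_id: str
    # Exactly one of the two prompt representations is expected to be set.
    prompt_token_ids: Optional[list[int]] = None
    prompt_embeds: Optional[torch.Tensor] = None  # (num_prompt_tokens, hidden)

    @property
    def num_prompt_tokens(self) -> int:
        if self.prompt_token_ids is not None:
            return len(self.prompt_token_ids)
        assert self.prompt_embeds is not None
        return self.prompt_embeds.shape[0]
```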
@@ -843,7 +848,8 @@ def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
                 self.input_batch.num_computed_tokens_cpu[index]
             num_scheduled_tokens = \
                 scheduler_output.num_scheduled_tokens[req_id]
-            num_prompt_tokens = len(req.prompt_token_ids)
+            num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+                req.prompt_token_ids, req.prompt_embeds)
 
             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
                 prompt_part_len = max(0,
@@ -1016,6 +1022,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
             self.input_ids[:total_num_scheduled_tokens].copy_(
                 self.input_ids_cpu[:total_num_scheduled_tokens],
                 non_blocking=True)
+            self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+            self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
             return
 
         # Async scheduling case, where some decode requests from the previous
@@ -1043,6 +1051,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
         self.input_ids[:total_num_scheduled_tokens].copy_(
             self.input_ids_cpu[:total_num_scheduled_tokens],
             non_blocking=True)
+        self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+        self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
         if num_commmon_tokens == 0:
             # No requests in common with the previous iteration
             # So input_ids_cpu will have all the input ids.
@@ -1056,6 +1066,7 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
                 self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
                                                         0],
                 non_blocking=True)
+            self.is_token_ids.gpu[:num_commmon_tokens] = True
             return
         # Upload the index tensors asynchronously
         # so the scatter can be non-blocking.
@@ -1195,15 +1206,60 @@ def _prepare_inputs(
         # where M is the max_model_len.
         token_indices = (positions_np +
                          req_indices * self.input_batch.token_ids_cpu.shape[1])
-
+        token_indices_tensor = torch.from_numpy(token_indices)
         # Prepare input_ids.
         # NOTE(woosuk): We use torch.index_select instead of np.take here
         # because torch.index_select is much faster than np.take for large
         # tensors.
         torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                            0,
-                           torch.from_numpy(token_indices),
+                           token_indices_tensor,
                            out=self.input_ids_cpu[:total_num_scheduled_tokens])
+        is_token_ids = self.input_batch.is_token_ids.flatten()
+        torch.index_select(
+            is_token_ids,
+            0,
+            token_indices_tensor,
+            out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
+
+        # Because we did not pre-allocate a massive prompt_embeds CPU tensor on
+        # the InputBatch, we need to fill in the prompt embeds into the expected
+        # spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
+        if self.input_batch.req_prompt_embeds:
+            output_idx = 0
+            for req_idx in range(num_reqs):
+                num_sched = num_scheduled_tokens[req_idx]
+
+                # Skip if this request doesn't have embeddings
+                if req_idx not in self.input_batch.req_prompt_embeds:
+                    output_idx += num_sched
+                    continue
+
+                # Skip if no tokens scheduled
+                if num_sched <= 0:
+                    output_idx += num_sched
+                    continue
+
+                req_embeds = self.input_batch.req_prompt_embeds[req_idx]
+                start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]
+
+                # Skip if trying to read beyond available embeddings
+                if start_pos >= req_embeds.shape[0]:
+                    output_idx += num_sched
+                    continue
+
+                # Copy available embeddings
+                end_pos = start_pos + num_sched
+                actual_end = min(end_pos, req_embeds.shape[0])
+                actual_num_sched = actual_end - start_pos
+
+                if actual_num_sched > 0:
+                    self.inputs_embeds.cpu[output_idx:output_idx +
+                                           actual_num_sched].copy_(
+                                               req_embeds[start_pos:actual_end]
+                                           )
+
+                output_idx += num_sched
 
         # Prepare some information for building Attention-Metadata
         # Compute and commit slot mapping
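
The new loop walks the flattened scheduled-token layout: `output_idx` advances by the number of scheduled tokens for every request, whether or not it has embeddings, so each request's embeds land in the same rows of `inputs_embeds.cpu` that `input_ids_cpu` uses for that request. A small self-contained illustration of that indexing (the variable names are illustrative, not the InputBatch API):

```python
import torch

hidden_size = 4
# Two requests scheduled this step: req 0 gets 3 tokens, req 1 gets 2 tokens.
num_scheduled_tokens = [3, 2]
num_computed_tokens_cpu = [1, 0]          # req 0 already has 1 token computed
req_prompt_embeds = {0: torch.randn(5, hidden_size)}  # only req 0 uses embeds

inputs_embeds_cpu = torch.zeros(sum(num_scheduled_tokens), hidden_size)

output_idx = 0
for req_idx, num_sched in enumerate(num_scheduled_tokens):
    if req_idx in req_prompt_embeds and num_sched > 0:
        req_embeds = req_prompt_embeds[req_idx]
        start_pos = num_computed_tokens_cpu[req_idx]
        actual_end = min(start_pos + num_sched, req_embeds.shape[0])
        if start_pos < actual_end:
            inputs_embeds_cpu[output_idx:output_idx + actual_end -
                              start_pos].copy_(req_embeds[start_pos:actual_end])
    # Rows for requests without embeds stay zero; their token IDs are used.
    output_idx += num_sched

# Req 0's embed rows 1..3 occupy output rows 0..2; req 1 occupies rows 3..4.
```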
@@ -1985,6 +2041,7 @@ def execute_model(
 
             self.input_batch.token_ids_cpu[req_idx,
                                            start_idx:end_idx] = sampled_ids
+            self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
             self.input_batch.num_tokens_no_spec[req_idx] = end_idx
             self.input_batch.num_tokens[req_idx] = end_idx
             req_id = self.input_batch.req_ids[req_idx]
@@ -2200,6 +2257,9 @@ def _dummy_run(
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = self.inputs_embeds[:num_tokens]
+        elif self.enable_prompt_embeds:
+            input_ids = None
+            inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
         else:
             input_ids = self.input_ids[:num_tokens]
             inputs_embeds = None
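
With `enable_prompt_embeds`, the dummy run drives the model through `inputs_embeds` rather than `input_ids`; at real execution time, positions that arrived as token IDs (per the `is_token_ids` mask) still need an embedding lookup before they can share the `inputs_embeds` tensor with the user-provided embeddings. A hedged sketch of one plausible merge, which is an assumption about the surrounding execute_model logic rather than the exact code:

```python
import torch


def merge_token_and_prompt_embeds(
    input_ids: torch.Tensor,        # (num_tokens,) int64
    inputs_embeds: torch.Tensor,    # (num_tokens, hidden), embeds pre-scattered
    is_token_ids: torch.Tensor,     # (num_tokens,) bool
    embed_tokens: torch.nn.Module,  # the model's token embedding layer
) -> torch.Tensor:
    if is_token_ids.any():
        # Look up embeddings only for positions that came in as token IDs and
        # overwrite those rows; rows filled from prompt_embeds stay as-is.
        inputs_embeds[is_token_ids] = embed_tokens(input_ids[is_token_ids])
    return inputs_embeds
```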
@@ -3070,6 +3130,9 @@ def _get_prompt_logprobs_dict(
 
             # Get metadata for this request.
             request = self.requests[req_id]
+            if request.prompt_token_ids is None:
+                # Prompt logprobs is incompatible with prompt embeddings
+                continue
             num_prompt_tokens = len(request.prompt_token_ids)
             prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                 self.device, non_blocking=True)