@@ -59,7 +59,8 @@
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+                             ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.eagle import EagleProposer
|
@@ -1474,6 +1475,12 @@ def execute_model(
         logprobs_lists = logprobs_tensors.tolists() \
             if logprobs_tensors is not None else None
 
+        # Compute prompt logprobs if needed.
+        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+            hidden_states[:num_scheduled_tokens],
+            scheduler_output,
+        )
+
         # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
         max_gen_len = sampled_token_ids.shape[-1]
@@ -1509,7 +1516,7 @@ def execute_model(
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=spec_token_ids,
             logprobs=logprobs_lists,
-            prompt_logprobs_dict={},
+            prompt_logprobs_dict=prompt_logprobs_dict,
            finished_sending=finished_sending,
             finished_recving=finished_recving,
         )
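
For orientation (not part of the diff): the dict passed into ModelRunnerOutput above maps each request id that finished prefill in this step to a LogprobsTensors spanning its whole prompt, sized as in the new helper added further down. A minimal sketch of that shape contract, with hypothetical sizes and placeholder dtypes:

import torch

# Hypothetical request: a 7-token prompt with 3 requested prompt logprobs.
num_prompt_tokens, num_prompt_logprobs = 7, 3

# (num_prompt_tokens - 1) rows: the first prompt token has no logprob, since
# nothing precedes it. (num_prompt_logprobs + 1) columns: the allocation in
# the helper suggests the target prompt token is stored next to the top-k
# alternatives. Dtypes here are illustrative placeholders; the real buffers
# come from LogprobsTensors.empty_cpu().
logprob_token_ids = torch.empty(num_prompt_tokens - 1,
                                num_prompt_logprobs + 1, dtype=torch.int64)
prompt_logprobs = torch.empty(num_prompt_tokens - 1,
                              num_prompt_logprobs + 1, dtype=torch.float32)
selected_token_ranks = torch.empty(num_prompt_tokens - 1, dtype=torch.int64)
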
|
@@ -2315,6 +2322,101 @@ def _generate_mtp_token_ids(
         spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
 
+    def _get_prompt_logprobs_dict(
+        self,
+        hidden_states: torch.Tensor,
+        scheduler_output: "SchedulerOutput",
+    ) -> dict[str, Optional[LogprobsTensors]]:
+        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        if not num_prompt_logprobs_dict:
+            return {}
+
+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
+        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
+
+        # Since prompt logprobs are a rare feature, prioritize simple,
+        # maintainable loop over optimal performance.
+        completed_prefill_reqs = []
+        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
+
+            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+
+            # Get metadata for this request.
+            request = self.requests[req_id]
+            num_prompt_tokens = len(request.prompt_token_ids)
+            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
+                self.device, non_blocking=True)
+
+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                in_progress_dict[req_id] = logprobs_tensors
+
+            # Determine number of logits to retrieve.
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens <= num_remaining_tokens:
+                # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
+                num_logits = num_tokens
+            else:
+                # This is the last chunk of prompt tokens to return.
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue
+
+            # Get the logits corresponding to this req's prompt tokens.
+            # If this is a partial request (i.e. chunked prefill),
+            # then there is prompt logprob generated for each index.
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            offset = self.query_start_loc_np[req_idx].item()
+            prompt_hidden_states = hidden_states[offset:offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states, None)
+
+            # Get the "target" tokens for each index. For prompt at index i,
+            # the token at prompt index i+1 is the "sampled" token we want
+            # to gather the logprob for.
+            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
+
+            # Compute prompt logprobs.
+            logprobs = self.sampler.compute_logprobs(logits)
+            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
+                logprobs, num_prompt_logprobs, tgt_token_ids)
+
+            # Transfer NPU->CPU async.
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+                token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
+                                                         non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+                ranks, non_blocking=True)
+
+        # Remove requests that have completed prefill from the batch
+        # num_prompt_logprobs_dict.
+        for req_id in completed_prefill_reqs:
+            del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]
+
+        # Must synchronize the non-blocking NPU->CPU transfers.
+        if prompt_logprobs_dict:
+            torch.npu.synchronize()
+
+        return prompt_logprobs_dict
+
     def init_torchair_graph_batch_sizes(self):
         start_graph_batch_size = 4
         tp_size = get_tensor_model_parallel_world_size()
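
As a reading aid (not part of the diff), the following standalone sketch replays the chunk bookkeeping from _get_prompt_logprobs_dict: each scheduled prefill chunk fills num_logits rows of the per-request CPU tensors starting at row num_computed_tokens, and the == comparison defers completing a request until a step that also returns newly generated tokens. The prompt length and chunk sizes below are hypothetical.

# Standalone replay of the chunk index arithmetic; the function name and the
# chunk sizes are made up for illustration only.
def prompt_logprob_chunks(num_prompt_tokens: int, chunk_sizes: list[int]):
    """Yield (start_idx, num_logits, is_last_prompt_chunk) per step."""
    num_computed_tokens = 0
    for num_tokens in chunk_sizes:
        start_idx = num_computed_tokens          # first row to fill
        start_tok = start_idx + 1                # its target prompt token
        num_remaining_tokens = num_prompt_tokens - start_tok
        if num_tokens <= num_remaining_tokens:
            # More prompt logprobs remain (or, in the == case, completion is
            # deferred to the next step, mirroring the comment in the diff).
            num_logits, is_last = num_tokens, False
        else:
            num_logits, is_last = num_remaining_tokens, True
        yield start_idx, num_logits, is_last     # num_logits == 0 copies nothing
        num_computed_tokens += num_tokens

# A 10-token prompt prefilled as 4 + 4 + 2 fills rows 0-3, 4-7, then row 8 of
# the 9-row CPU tensors and completes on the third step.
print(list(prompt_logprob_chunks(10, [4, 4, 2])))
# -> [(0, 4, False), (4, 4, False), (8, 1, True)]

# Prefilled as 4 + 5 + 1 instead, the second step already produces the last
# prompt logprobs, and the third step hits the "num_logits <= 0" branch
# (nothing left to copy) while finishing the request.
print(list(prompt_logprob_chunks(10, [4, 5, 1])))
# -> [(0, 4, False), (4, 5, False), (9, 0, True)]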