
Commit 0ffc133

[0.9.1][PromptLogprobs][V1] Support prompt logprobs to fix ceval accuracy in V1 (#2654)
### What this PR does / why we need it?
Support prompt logprobs to fix ceval accuracy in V1.

### Does this PR introduce _any_ user-facing change?
1. Users can set `prompt_logprobs` in `SamplingParams` after this PR.
2. Users can use lm_eval to evaluate accuracy.

### How was this patch tested?
CI passed with a newly added test.

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
1 parent b8164d5 commit 0ffc133
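
As background for the user-facing change above, the snippet below is a minimal usage sketch of requesting prompt logprobs through `SamplingParams` with vLLM's offline API; the model name, prompt, and parameter values are illustrative assumptions, not taken from this PR.

```python
# Minimal usage sketch: request prompt logprobs via SamplingParams.
# Model name, prompt, and parameter values are illustrative assumptions.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B-Base", enforce_eager=True)
params = SamplingParams(
    temperature=0.0,
    max_tokens=16,
    prompt_logprobs=1,  # also return the top-1 logprob for each prompt token
)

for output in llm.generate(["The capital of France is"], params):
    # With this change, prompt_logprobs is populated on Ascend V1 as well;
    # the first entry is None because the first prompt token has no context.
    print(output.prompt_logprobs)
```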

File tree

2 files changed: +122 −2 lines

tests/singlecard/test_offline_inference.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -131,3 +131,21 @@ def test_models_topk() -> None:
                     enforce_eager=True,
                     gpu_memory_utilization=0.7) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
+
+
+def test_models_prompt_logprobs() -> None:
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    with VllmRunner("Qwen/Qwen3-0.6B-Base",
+                    max_model_len=8192,
+                    dtype="float16",
+                    enforce_eager=True,
+                    gpu_memory_utilization=0.7) as vllm_model:
+        vllm_model.generate_greedy_logprobs(example_prompts,
+                                            max_tokens=50,
+                                            num_logprobs=1)
```
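
The new test only checks that greedy generation with prompt logprobs runs end to end. For context on why this fixes ceval accuracy: ceval-style multiple-choice scoring (as done by lm_eval) ranks each candidate answer by the summed log-likelihood of its tokens, which vLLM exposes through `prompt_logprobs` of the concatenated question + answer. The sketch below illustrates that scoring loop; it is an assumption for illustration, not code from this PR or from lm_eval, and the question/answer token split is approximate.

```python
# Illustrative ceval/lm_eval-style multiple-choice scoring using prompt
# logprobs. The question/answer token split below is a rough approximation;
# lm_eval handles tokenizer boundary effects more carefully.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen3-0.6B-Base", enforce_eager=True)
tokenizer = llm.get_tokenizer()

question = "The capital of France is"
choices = [" Paris", " Berlin", " Rome", " Madrid"]
params = SamplingParams(temperature=0.0, max_tokens=1, prompt_logprobs=0)

scores = []
for choice in choices:
    out = llm.generate([question + choice], params)[0]
    # prompt_logprobs[i] maps token ids to Logprob objects for prompt
    # position i (None at position 0); sum only over the answer tokens.
    num_question_tokens = len(tokenizer(question).input_ids)
    answer_logprobs = out.prompt_logprobs[num_question_tokens:]
    score = sum(next(iter(lp.values())).logprob for lp in answer_logprobs)
    scores.append(score)

print("predicted:", choices[scores.index(max(scores))])
```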

vllm_ascend/worker/model_runner_v1.py

Lines changed: 104 additions & 2 deletions
```diff
@@ -59,7 +59,8 @@
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+                             ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.eagle import EagleProposer
@@ -1474,6 +1475,12 @@ def execute_model(
         logprobs_lists = logprobs_tensors.tolists() \
             if logprobs_tensors is not None else None
 
+        # Compute prompt logprobs if needed.
+        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+            hidden_states[:num_scheduled_tokens],
+            scheduler_output,
+        )
+
         # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
         max_gen_len = sampled_token_ids.shape[-1]
@@ -1509,7 +1516,7 @@ def execute_model(
             sampled_token_ids=valid_sampled_token_ids,
             spec_token_ids=spec_token_ids,
             logprobs=logprobs_lists,
-            prompt_logprobs_dict={},
+            prompt_logprobs_dict=prompt_logprobs_dict,
             finished_sending=finished_sending,
             finished_recving=finished_recving,
         )
@@ -2315,6 +2322,101 @@ def _generate_mtp_token_ids(
         spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
 
+    def _get_prompt_logprobs_dict(
+        self,
+        hidden_states: torch.Tensor,
+        scheduler_output: "SchedulerOutput",
+    ) -> dict[str, Optional[LogprobsTensors]]:
+        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        if not num_prompt_logprobs_dict:
+            return {}
+
+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
+        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
+
+        # Since prompt logprobs are a rare feature, prioritize simple,
+        # maintainable loop over optimal performance.
+        completed_prefill_reqs = []
+        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
+
+            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+
+            # Get metadata for this request.
+            request = self.requests[req_id]
+            num_prompt_tokens = len(request.prompt_token_ids)
+            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
+                self.device, non_blocking=True)
+
+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                in_progress_dict[req_id] = logprobs_tensors
+
+            # Determine number of logits to retrieve.
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens <= num_remaining_tokens:
+                # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
+                num_logits = num_tokens
+            else:
+                # This is the last chunk of prompt tokens to return.
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue
+
+            # Get the logits corresponding to this req's prompt tokens.
+            # If this is a partial request (i.e. chunked prefill),
+            # then there is prompt logprob generated for each index.
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            offset = self.query_start_loc_np[req_idx].item()
+            prompt_hidden_states = hidden_states[offset:offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states, None)
+
+            # Get the "target" tokens for each index. For prompt at index i,
+            # the token at prompt index i+1 is the "sampled" token we want
+            # to gather the logprob for.
+            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
+
+            # Compute prompt logprobs.
+            logprobs = self.sampler.compute_logprobs(logits)
+            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
+                logprobs, num_prompt_logprobs, tgt_token_ids)
+
+            # Transfer NPU->CPU async.
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+                token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
+                                                         non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+                ranks, non_blocking=True)
+
+        # Remove requests that have completed prefill from the batch
+        # num_prompt_logprobs_dict.
+        for req_id in completed_prefill_reqs:
+            del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]
+
+        # Must synchronize the non-blocking NPU->CPU transfers.
+        if prompt_logprobs_dict:
+            torch.npu.synchronize()
+
+        return prompt_logprobs_dict
+
     def init_torchair_graph_batch_sizes(self):
         start_graph_batch_size = 4
         tp_size = get_tensor_model_parallel_world_size()
```
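
For readers unfamiliar with the sampler calls used in `_get_prompt_logprobs_dict`, the standalone sketch below shows conceptually what `compute_logprobs` plus `gather_logprobs` produce for a chunk of prompt positions: a log-softmax over the vocabulary, the logprob and 1-based rank of each actual next prompt token, and top-k candidates filling the `num_prompt_logprobs + 1` columns that `LogprobsTensors.empty_cpu` allocates. The function name and the column ordering here are illustrative assumptions, not the vLLM `Sampler` API.

```python
import torch


def gather_prompt_logprobs(logits: torch.Tensor, tgt_token_ids: torch.Tensor,
                           num_logprobs: int):
    """Conceptual stand-in for Sampler.compute_logprobs + gather_logprobs.

    logits: [num_positions, vocab_size] logits for a chunk of prompt tokens.
    tgt_token_ids: [num_positions] actual next prompt token at each position.
    """
    # Log-probabilities over the vocabulary.
    logprobs = torch.log_softmax(logits.float(), dim=-1)

    # Top-k candidate logprobs and their token ids per position.
    topk_logprobs, topk_ids = logprobs.topk(num_logprobs, dim=-1)

    # Logprob and 1-based rank of the actual next prompt token.
    tgt_logprobs = logprobs.gather(-1, tgt_token_ids.unsqueeze(-1))
    tgt_ranks = (logprobs > tgt_logprobs).sum(dim=-1) + 1

    # One possible (num_logprobs + 1)-wide layout: target token first,
    # followed by the top-k candidates.
    token_ids = torch.cat([tgt_token_ids.unsqueeze(-1), topk_ids], dim=-1)
    out_logprobs = torch.cat([tgt_logprobs, topk_logprobs], dim=-1)
    return token_ids, out_logprobs, tgt_ranks


# Tiny smoke check: vocab of 10, 3 prompt positions, top-1 candidates.
ids, lps, ranks = gather_prompt_logprobs(torch.randn(3, 10),
                                         torch.tensor([4, 1, 7]),
                                         num_logprobs=1)
print(ids.shape, lps.shape, ranks.shape)  # shapes: [3, 2], [3, 2], [3]
```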
