[0.9.1][PromptLogprobs][V1] Support prompt logprobs to fix ceval accuracy in V1 #2654
Conversation
… V1 (vllm-project#1483)

Support prompt logprobs in V1. This also enables lm_eval to test accuracy on V1, since lm_eval relies on prompt logprobs output. CI passed with the accuracy test.

Using lm_eval, which uses prompt logprobs as output, to test accuracy:

```bash
VLLM_USE_V1=1 lm_eval \
  --model vllm \
  --model_args pretrained=Qwen/Qwen2.5-7B-Instruct,max_model_len=4096,block_size=4 \
  --tasks ceval-valid_computer_network \
  --batch_size 8
```

After this PR, the accuracy test results of `Qwen/Qwen2.5-7B-Instruct` on V1 are:

```
|            Tasks            |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-----------------------------|------:|------|-----:|--------|---|-----:|---|-----:|
|ceval-valid_computer_network |      2|none  |     0|acc     |↑  |0.7368|±  |0.1038|
|                             |       |none  |     0|acc_norm|↑  |0.7368|±  |0.1038|
```

Closes: vllm-project#1043

Signed-off-by: MengqingCao <cmq0113@163.com>
Code Review
This pull request adds support for prompt logprobs in the V1 API, which is a valuable addition for improving evaluation accuracy on benchmarks like ceval. The implementation introduces a new method, `_get_prompt_logprobs_dict`, to handle the computation. While the logic is generally sound and a corresponding test case has been added, I've identified a significant performance issue in the new method: the current implementation processes requests serially, which could lead to bottlenecks during evaluation. My review includes a recommendation to refactor this part to use batched computation for better efficiency.
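The inline review comment further below refers to this excerpt from the new `_get_prompt_logprobs_dict` method: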
```python
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():

    num_tokens = scheduler_output.num_scheduled_tokens[req_id]

    # Get metadata for this request.
    request = self.requests[req_id]
    num_prompt_tokens = len(request.prompt_token_ids)
    prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
        self.device, non_blocking=True)

    # Set up target LogprobsTensors object.
    logprobs_tensors = in_progress_dict.get(req_id)
    if not logprobs_tensors:
        # Create empty logprobs CPU tensors for the entire prompt.
        # If chunked, we'll copy in slice by slice.
        logprobs_tensors = LogprobsTensors.empty_cpu(
            num_prompt_tokens - 1, num_prompt_logprobs + 1)
        in_progress_dict[req_id] = logprobs_tensors

    # Determine number of logits to retrieve.
    start_idx = request.num_computed_tokens
    start_tok = start_idx + 1
    num_remaining_tokens = num_prompt_tokens - start_tok
    if num_tokens <= num_remaining_tokens:
        # This is a chunk, more tokens remain.
        # In the == case, there are no more prompt logprobs to produce
        # but we want to defer returning them to the next step where we
        # have new generated tokens to return.
        num_logits = num_tokens
    else:
        # This is the last chunk of prompt tokens to return.
        num_logits = num_remaining_tokens
        completed_prefill_reqs.append(req_id)
        prompt_logprobs_dict[req_id] = logprobs_tensors

    if num_logits <= 0:
        # This can happen for the final chunk if we prefilled exactly
        # (num_prompt_tokens - 1) tokens for this request in the prior
        # step. There are no more prompt logprobs to produce.
        continue

    # Get the logits corresponding to this req's prompt tokens.
    # If this is a partial request (i.e. chunked prefill),
    # then there is prompt logprob generated for each index.
    req_idx = self.input_batch.req_id_to_index[req_id]
    offset = self.query_start_loc_np[req_idx].item()
    prompt_hidden_states = hidden_states[offset:offset + num_logits]
    logits = self.model.compute_logits(prompt_hidden_states, None)
```
The current implementation computes logits for each request serially within a `for` loop by calling `self.model.compute_logits` on sliced hidden states. This approach can become a significant performance bottleneck, especially in evaluation scenarios where many requests might require prompt logprobs simultaneously.

While the comment on line 2337 mentions prioritizing simplicity, the performance degradation from unbatched calls on the accelerator can be substantial.

I recommend refactoring this to batch the logit computation across all requests that need it (see the sketch after this list). A possible approach is:

- In a first pass, iterate over the requests to collect all `prompt_hidden_states` slices.
- Concatenate these slices into a single tensor using `torch.cat`.
- Perform a single, batched `model.compute_logits` call on the concatenated tensor.
- In a second pass, iterate over the requests again, this time processing the corresponding slice of the resulting batched logits.

This change would leverage the NPU's parallel processing capabilities more effectively and significantly improve performance.
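For illustration, a minimal, self-contained sketch of that two-pass batching is shown below. The helper name `compute_logits_batched` and its `offsets_and_counts` argument are hypothetical, introduced only to make the idea runnable; they are not part of this PR.

```python
import torch


def compute_logits_batched(compute_logits, hidden_states, offsets_and_counts):
    """Hypothetical helper: one batched logits call for many requests.

    compute_logits:     callable mapping a (num_tokens, hidden_size) tensor to logits
    hidden_states:      (total_tokens, hidden_size) tensor for the whole step
    offsets_and_counts: per-request (offset, num_logits) pairs, in request order
    Returns per-request logits tensors for requests with num_logits > 0, in order.
    """
    # First pass: gather the hidden-state slice for every request that still
    # needs prompt logprobs (skip requests with nothing left to produce).
    kept = [(offset, n) for offset, n in offsets_and_counts if n > 0]
    if not kept:
        return []
    chunks = [hidden_states[offset:offset + n] for offset, n in kept]

    # One batched call on the concatenated slices instead of one call per request.
    batched_logits = compute_logits(torch.cat(chunks, dim=0))

    # Second pass: split the batched logits back into per-request views.
    return list(torch.split(batched_logits, [n for _, n in kept], dim=0))
```

In the loop above, the callable would be something like `lambda h: self.model.compute_logits(h, None)`, and the per-request slices returned here would feed the existing logprob-gathering code unchanged.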
What this PR does / why we need it?
Support prompt logprobs to fix ceval accuracy in V1
Does this PR introduce any user-facing change?
Yes. `prompt_logprobs` can be set in `SamplingParams` after this PR.
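For completeness, a minimal sketch of the user-facing usage after this change, assuming the standard vLLM offline interface; the model name and prompt are only examples:

```python
import os

os.environ["VLLM_USE_V1"] = "1"  # run on the V1 engine, as in the accuracy test above

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", max_model_len=4096)

# prompt_logprobs=1 asks for the top-1 logprob of every prompt token.
params = SamplingParams(max_tokens=1, prompt_logprobs=1)

outputs = llm.generate(["The capital of France is Paris."], params)

# prompt_logprobs is a per-prompt-token list (None for the first token).
print(outputs[0].prompt_logprobs)
```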
How was this patch tested?
CI passed with the newly added test.