Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tests/singlecard/test_offline_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,21 @@ def test_models_topk() -> None:
enforce_eager=True,
gpu_memory_utilization=0.7) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


def test_models_prompt_logprobs() -> None:
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]

with VllmRunner("Qwen/Qwen3-0.6B-Base",
max_model_len=8192,
dtype="float16",
enforce_eager=True,
gpu_memory_utilization=0.7) as vllm_model:
vllm_model.generate_greedy_logprobs(example_prompts,
max_tokens=50,
num_logprobs=1)
106 changes: 104 additions & 2 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.eagle import EagleProposer
Expand Down Expand Up @@ -1474,6 +1475,12 @@ def execute_model(
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None

# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
scheduler_output,
)

# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
Expand Down Expand Up @@ -1509,7 +1516,7 @@ def execute_model(
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict={},
prompt_logprobs_dict=prompt_logprobs_dict,
finished_sending=finished_sending,
finished_recving=finished_recving,
)
Expand Down Expand Up @@ -2315,6 +2322,101 @@ def _generate_mtp_token_ids(
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids

def _get_prompt_logprobs_dict(
self,
hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput",
) -> dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}

in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

# Since prompt logprobs are a rare feature, prioritize simple,
# maintainable loop over optimal performance.
completed_prefill_reqs = []
for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():

num_tokens = scheduler_output.num_scheduled_tokens[req_id]

# Get metadata for this request.
request = self.requests[req_id]
num_prompt_tokens = len(request.prompt_token_ids)
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True)

# Set up target LogprobsTensors object.
logprobs_tensors = in_progress_dict.get(req_id)
if not logprobs_tensors:
# Create empty logprobs CPU tensors for the entire prompt.
# If chunked, we'll copy in slice by slice.
logprobs_tensors = LogprobsTensors.empty_cpu(
num_prompt_tokens - 1, num_prompt_logprobs + 1)
in_progress_dict[req_id] = logprobs_tensors

# Determine number of logits to retrieve.
start_idx = request.num_computed_tokens
start_tok = start_idx + 1
num_remaining_tokens = num_prompt_tokens - start_tok
if num_tokens <= num_remaining_tokens:
# This is a chunk, more tokens remain.
# In the == case, there are no more prompt logprobs to produce
# but we want to defer returning them to the next step where we
# have new generated tokens to return.
num_logits = num_tokens
else:
# This is the last chunk of prompt tokens to return.
num_logits = num_remaining_tokens
completed_prefill_reqs.append(req_id)
prompt_logprobs_dict[req_id] = logprobs_tensors

if num_logits <= 0:
# This can happen for the final chunk if we prefilled exactly
# (num_prompt_tokens - 1) tokens for this request in the prior
# step. There are no more prompt logprobs to produce.
continue

# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
# then there is prompt logprob generated for each index.
req_idx = self.input_batch.req_id_to_index[req_id]
offset = self.query_start_loc_np[req_idx].item()
prompt_hidden_states = hidden_states[offset:offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states, None)
Comment on lines +2340 to +2387
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation computes logits for each request serially within a for loop by calling self.model.compute_logits on sliced hidden states. This approach can become a significant performance bottleneck, especially in evaluation scenarios where many requests might require prompt logprobs simultaneously.

While the comment on line 2337 mentions prioritizing simplicity, the performance degradation from unbatched calls on the accelerator can be substantial.

I recommend refactoring this to batch the logit computation across all requests that need it. A possible approach is:

  1. In a first pass, iterate over the requests to collect all prompt_hidden_states slices.
  2. Concatenate these slices into a single tensor using torch.cat.
  3. Perform a single, batched model.compute_logits call on the concatenated tensor.
  4. In a second pass, iterate over the requests again, this time processing the corresponding slice of the resulting batched logits.

This change would leverage the NPU's parallel processing capabilities more effectively and significantly improve performance.


# Get the "target" tokens for each index. For prompt at index i,
# the token at prompt index i+1 is the "sampled" token we want
# to gather the logprob for.
tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]

# Compute prompt logprobs.
logprobs = self.sampler.compute_logprobs(logits)
token_ids, logprobs, ranks = self.sampler.gather_logprobs(
logprobs, num_prompt_logprobs, tgt_token_ids)

# Transfer NPU->CPU async.
chunk_slice = slice(start_idx, start_idx + num_logits)
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
token_ids, non_blocking=True)
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
non_blocking=True)
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
ranks, non_blocking=True)

# Remove requests that have completed prefill from the batch
# num_prompt_logprobs_dict.
for req_id in completed_prefill_reqs:
del num_prompt_logprobs_dict[req_id]
del in_progress_dict[req_id]

# Must synchronize the non-blocking NPU->CPU transfers.
if prompt_logprobs_dict:
torch.npu.synchronize()

return prompt_logprobs_dict

def init_torchair_graph_batch_sizes(self):
start_graph_batch_size = 4
tp_size = get_tensor_model_parallel_world_size()
Expand Down