
Conversation

@jiangpeng36 (Contributor) commented Sep 5, 2025

This PR is based on top of vllm-project/vllm#23569 and vllm-project/vllm#24219.

What this PR does / why we need it?

This PR allows the model runner to run asynchronously when async scheduling is used, enabling full overlap of CPU operations (including prepare_inputs) with the model forward pass. The change is functional but does not yet support speculative decoding, pipeline parallelism (PP), or guided decoding.

Expected speedup is 5-10% over the current async scheduling.
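Conceptually, the pattern looks like the sketch below (a minimal illustration with made-up helper names such as `prepare_next_inputs` and `sample`, not the actual vllm-ascend code in model_runner_v1.py): the host enqueues the device forward pass, does the next step's CPU work while the device is busy, and defers synchronization until the sampled token IDs are actually read.

```python
import torch


def run_step_overlapped(model, device_inputs, prepare_next_inputs, sample):
    # Kernel launches are asynchronous, so these calls return to the host
    # quickly while the device keeps executing the forward pass and sampling.
    hidden_states = model(**device_inputs)
    sampled_token_ids = sample(hidden_states)  # still resident on the device

    # CPU-side work for the next step (scheduler bookkeeping, prepare_inputs)
    # runs here and overlaps with the in-flight device computation.
    next_inputs = prepare_next_inputs()

    # Start a non-blocking device-to-host copy and record an event so the
    # caller synchronizes only when it actually needs the token IDs.
    sampled_cpu = sampled_token_ids.to("cpu", non_blocking=True)
    done_event = torch.cuda.Event()  # torch.npu.Event() on Ascend / torch_npu
    done_event.record()
    return sampled_cpu, done_event, next_inputs
```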


Does this PR introduce any user-facing change?

How was this patch tested?

server

```
python -m vllm.entrypoints.openai.api_server --model=Qwen3-32B \
    --trust-remote-code --enforce-eager \
    --distributed-executor-backend=mp \
    -tp=4 \
    --port 8006 \
    --max-model-len 32000 \
    --block-size 128 \
    --gpu-memory-utilization 0.99
```

client

```
python $TEST_PY --backend vllm --trust-remote-code --model Qwen3-32B \
  --dataset-name random --random-input-len 2048 --random-output-len 2048 \
  --ignore-eos \
  --num-prompts 48 --max-concurrency 48 --request-rate inf --temperature 0 \
  --metric-percentiles 90 --base-url http://localhost:8006 --save-result \
  --result-dir $PROFILER_DIR
```

Benchmark test based on Qwen3-32B, TPOT results (ms):

| | forward async | scheduler async | sync |
|-|-|-|-|
| avg | 41.73 | 41.86 | 44.20 |
| improve0 (vs scheduler async) | 0.3% | 0 | 0 |
| improve1 (vs sync) | 5.58% | 0 | 0 |

Benchmark test based on Qwen2.5-VL-7B-Instruct, TPOT results (ms):

| | forward async | sync |
|-|-|-|
| avg | 23.22 | 29.16 |
| improve (vs sync) | 20.3% | 0 |
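Reading the improvement rows as relative TPOT reductions of forward async against each baseline, the numbers check out:

```python
# Relative TPOT reduction (%) of forward async vs. a baseline, using the
# average values (ms) from the tables above.
def improvement(baseline_ms: float, new_ms: float) -> float:
    return (baseline_ms - new_ms) / baseline_ms * 100

print(f"{improvement(41.86, 41.73):.2f}%")  # vs scheduler async, ~0.31%
print(f"{improvement(44.20, 41.73):.2f}%")  # vs sync, ~5.59%
print(f"{improvement(29.16, 23.22):.2f}%")  # Qwen2.5-VL, vs sync, ~20.37%
```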


github-actions bot commented Sep 5, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces asynchronous model execution to overlap CPU and NPU operations for a performance boost. The core changes involve a new AsyncNPUModelRunnerOutput class to handle non-blocking data transfers and modifications to the model execution pipeline to support this. While the changes are promising for performance, I've identified a critical issue with state management in the asynchronous path that could lead to incorrect model outputs, an unused attribute that should be removed, and a hardcoded path in an example file that hinders usability. Addressing these points will be crucial for the stability and correctness of this new feature.
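For context, a wrapper in the spirit of AsyncNPUModelRunnerOutput typically looks something like the hedged sketch below (illustrative only; the actual class in vllm_ascend/worker/model_runner_v1.py may differ): the device-to-host copy of the sampled token IDs is issued on a side stream at construction time, and synchronization is deferred until get_output() is called. On Ascend the torch.npu stream/event APIs would take the place of torch.cuda.

```python
import torch


class AsyncOutputSketch:
    """Illustrative async output wrapper, not the class added by this PR."""

    def __init__(self, output, sampled_token_ids: torch.Tensor,
                 copy_stream: torch.cuda.Stream):
        self._output = output
        self._copy_event = torch.cuda.Event()
        # Make the copy stream wait for sampling on the default stream, then
        # issue the device-to-host copy there so it overlaps with compute.
        copy_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(copy_stream):
            self._sampled_cpu = sampled_token_ids.to("cpu", non_blocking=True)
            self._copy_event.record()

    def get_output(self):
        # Block only at the point where the CPU actually needs the tokens.
        self._copy_event.synchronize()
        self._output.sampled_token_ids = self._sampled_cpu.tolist()
        return self._output
```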

Comment on lines 1679 to +1558 (the execute_model signature change):

```
 self,
 scheduler_output: "SchedulerOutput",
 intermediate_tensors: Optional[IntermediateTensors] = None,
-) -> Union[ModelRunnerOutput, torch.Tensor]:
+) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
```

critical

To correctly handle asynchronous scheduling, the worker's CPU-side state must be updated with the actual token IDs from the previous step. This should happen at the beginning of the current step, before preparing inputs. Without this, features like repetition penalty will use stale or incorrect token history.

Please add state update logic at the start of execute_model to synchronize prev_sampled_token_ids and update the CPU-side token history. Here is a code snippet to illustrate the required logic:

```python
if self.use_async_scheduling and self.input_batch.prev_sampled_token_ids is not None:
    # Sync and update state from previous async step
    prev_sampled_token_ids_cpu = self.input_batch.prev_sampled_token_ids.tolist()
    prev_req_id_to_index = self.input_batch.prev_req_id_to_index
    assert prev_req_id_to_index is not None

    for req_id, prev_req_idx in prev_req_id_to_index.items():
        if req_id not in self.requests:
            continue
        req_state = self.requests[req_id]
        req_idx = self.input_batch.req_id_to_index.get(req_id)
        if req_idx is None:
            continue

        sampled_ids = prev_sampled_token_ids_cpu[prev_req_idx]
        if not sampled_ids:
            continue

        req_state.output_token_ids.extend(sampled_ids)

        start_idx = self.input_batch.num_tokens_no_spec[req_idx]
        end_idx = start_idx + len(sampled_ids)
        assert end_idx <= self.model_config.max_model_len
        self.input_batch.token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
        self.input_batch.num_tokens_no_spec[req_idx] = end_idx
        self.input_batch.num_tokens[req_idx] = end_idx

    # Clear the prev step's data
    self.input_batch.prev_sampled_token_ids = None
    self.input_batch.prev_req_id_to_index = None
```

```
 sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
 # Create an LLM.
-llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
+llm = LLM(model="/home/jp/model/Qwen2.5-0.5B-Instruct",
```

high

The model path is hardcoded to a local directory. This will cause the example to fail for other users. Please use a model identifier from a public hub, like Hugging Face Hub, so that the example is runnable out of the box.

Suggested change:

```
-llm = LLM(model="/home/jp/model/Qwen2.5-0.5B-Instruct",
+llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct",
```

Comment on lines +1874 to +1761:

```python
self.input_batch.prev_sampled_token_ids_invalid_indices = \
    invalid_req_indices_set
```

high

The attribute self.input_batch.prev_sampled_token_ids_invalid_indices is assigned here but is never read or used anywhere. This appears to be dead code and should be removed along with its definition in InputBatch to improve clarity and reduce maintenance overhead.


```python
# Cached reference to the GPU tensor of previously sampled tokens
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
```

high

The attribute prev_sampled_token_ids_invalid_indices is defined here but is never read or used anywhere in the codebase. This appears to be dead code and should be removed to avoid confusion and reduce maintenance overhead.


github-actions bot commented Sep 5, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.

@jiangpeng36 force-pushed the pr23569 branch 5 times, most recently from 19e7592 to 703b715 on September 8, 2025 06:44.
@jiangpeng36 force-pushed the pr23569 branch 3 times, most recently from e083282 to 74abecc on September 8, 2025 11:30.

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Comment on a newly added file (diff hunk @@ -0,0 +1,189 @@):

Collaborator:

I think this file could be removed? plz remove this file in a single commit and will merge soon

Contributor Author (@jiangpeng36) replied:

ok, already removed

Signed-off-by: jiangpeng36 <jiangpeng36@huawei.com>
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Co-authored-by: Ronald1995 <ronaldautomobile@163.com>

codecov bot commented Sep 11, 2025

Codecov Report

❌ Patch coverage is 13.09524% with 73 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.16%. Comparing base (1bbb20e) to head (0718412).
⚠️ Report is 18 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|-|-|-|
| vllm_ascend/worker/model_runner_v1.py | 6.41% | 73 Missing ⚠️ |

❌ Your patch status has failed because the patch coverage (13.09%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #2783      +/-   ##
==========================================
+ Coverage   74.76%   75.16%   +0.39%
==========================================
  Files         150      155       +5
  Lines       20891    21195     +304
==========================================
+ Hits        15620    15932     +312
+ Misses       5271     5263       -8
```

| Flag | Coverage Δ |
|-|-|
| unittests | 75.16% <13.09%> (+0.39%) ⬆️ |

Flags with carried forward coverage won't be shown.


@MengqingCao (Collaborator) commented:
Merging as all the tests passed in https://github.com/vllm-project/vllm-ascend/actions/runs/17606628607 and https://github.com/vllm-project/vllm-ascend/actions/runs/17631881467, except for some unrelated UT issues. For code coverage, please write UTs for the async scheduler in the next PR, thanks!

@MengqingCao merged commit 2b9269b into vllm-project:main on Sep 11, 2025; 21 of 22 checks passed.
yiz-liu pushed a commit to linfeng-yuan/vllm-ascend that referenced this pull request Sep 12, 2025
offline893 pushed a commit to offline893/vllm-ascend that referenced this pull request Sep 16, 2025
wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025
Labels: module:tests, ready (read for review), ready-for-test (start test by label for PR)