From 06be93e67abed8c1bcc874e32d459fc11c7aaa39 Mon Sep 17 00:00:00 2001 From: MengqingCao Date: Thu, 10 Jul 2025 07:12:26 +0000 Subject: [PATCH 1/2] [AscendScheduler][Bugfix] Remove num_draft_tokens while allocating slots Signed-off-by: MengqingCao --- docs/source/tutorials/multi_node.md | 2 +- vllm_ascend/core/scheduler.py | 26 +++++++++++++++++--------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/docs/source/tutorials/multi_node.md b/docs/source/tutorials/multi_node.md index 64475c32bc..865cc4089c 100644 --- a/docs/source/tutorials/multi_node.md +++ b/docs/source/tutorials/multi_node.md @@ -54,7 +54,7 @@ hccn_tool -i 0 -ping -g address 10.20.0.20 ``` ## Run with docker -Assume you have two Altas 800 A2(64G*8) nodes, and want to deploy the `deepseek-v3-w8a8` quantitative model across multi-node. +Assume you have two Atlas 800 A2(64G*8) nodes, and want to deploy the `deepseek-v3-w8a8` quantitative model across multi-node. ```shell # Define the image and container name diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index 00a17ddfc3..47340680bf 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -32,6 +32,8 @@ from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager +from vllm_ascend.utils import vllm_version_is + class AscendScheduler(Scheduler): """This Scheduler extends vllm's original v1 scheduler @@ -281,17 +283,23 @@ def skip_cur_request(): # allow the lower-priority requests to be scheduled. req_index += 1 continue - - num_draft_tokens = max( - num_new_tokens + request.num_computed_tokens - - request.num_tokens, 0) + if vllm_version_is("0.9.2"): + num_draft_tokens = max( + num_new_tokens + request.num_computed_tokens - + request.num_tokens, 0) while True: - new_blocks = self.kv_cache_manager.allocate_slots( - request, - num_new_tokens, - num_draft_tokens=num_draft_tokens, - num_lookahead_tokens=self.num_lookahead_tokens) + if vllm_version_is("0.9.2"): + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_draft_tokens=num_draft_tokens, + num_lookahead_tokens=self.num_lookahead_tokens) + else: + new_blocks = self.kv_cache_manager.allocate_slots( + request, + num_new_tokens, + num_lookahead_tokens=self.num_lookahead_tokens) if new_blocks is None: # The request cannot be scheduled. # Preempt the lowest-priority request. From 5eacc1b305e193a42e2ae5c22c3d9f0040b2c9be Mon Sep 17 00:00:00 2001 From: MengqingCao Date: Thu, 10 Jul 2025 08:49:33 +0000 Subject: [PATCH 2/2] skip test on vllm-ascend/Qwen3-30B-A3B-Puring with aclgraph Signed-off-by: MengqingCao --- tests/e2e/singlecard/test_aclgraph.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/e2e/singlecard/test_aclgraph.py b/tests/e2e/singlecard/test_aclgraph.py index 4fc23aa7b3..89dfa08e41 100644 --- a/tests/e2e/singlecard/test_aclgraph.py +++ b/tests/e2e/singlecard/test_aclgraph.py @@ -29,7 +29,11 @@ from tests.conftest import VllmRunner from tests.model_utils import check_outputs_equal -MODELS = ["Qwen/Qwen2.5-0.5B-Instruct", "vllm-ascend/Qwen3-30B-A3B-Puring"] +MODELS = [ + "Qwen/Qwen2.5-0.5B-Instruct", + # TODO: REVERT ME when oom is fixed + # "vllm-ascend/Qwen3-30B-A3B-Puring" +] @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",