
Commit d641fdd

Author: wangxiaoxin (A)
Commit message: add optimze of dsv3.
Signed-off-by: wangxiaoxin (A) <w00664509@china.huawei.com>
1 parent 543380c, commit d641fdd

File tree: 9 files changed, +327 -3 lines changed


tests/multicard/test_offline_inference_distributed.py

Lines changed: 22 additions & 0 deletions
@@ -21,8 +21,10 @@
 Run `pytest tests/test_offline_inference.py`.
 """
 import os
+from unittest.mock import patch

 import vllm  # noqa: F401
+from vllm import SamplingParams

 from tests.conftest import VllmRunner

@@ -61,3 +63,23 @@ def test_models_distributed_DeepSeek():
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+@patch.dict(os.environ, {"VLLM_ENABLE_TOPK_OPTIMZE": "1"})
+def test_models_distributed_topk() -> None:
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
+        "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
+        "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner("deepseek-ai/DeepSeek-V2-Lite",
+                    max_model_len=8192,
+                    dtype="float16",
+                    enforce_eager=True,
+                    gpu_memory_utilization=0.7) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)

tests/singlecard/test_offline_inference.py

Lines changed: 23 additions & 0 deletions
@@ -21,9 +21,11 @@
 Run `pytest tests/test_offline_inference.py`.
 """
 import os
+from unittest.mock import patch

 import pytest
 import vllm  # noqa: F401
+from vllm import SamplingParams
 from vllm.assets.image import ImageAsset

 import vllm_ascend  # noqa: F401

@@ -81,3 +83,24 @@ def test_multimodal(model, prompt_template, vllm_runner):
         vllm_model.generate_greedy(prompts=prompts,
                                    images=images,
                                    max_tokens=64)
+
+
+@patch.dict(os.environ, {"VLLM_ENABLE_TOPK_OPTIMZE": "1"})
+def test_models_topk() -> None:
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(max_tokens=5,
+                                     temperature=0.0,
+                                     top_k=50,
+                                     top_p=0.9)
+
+    with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct",
+                    max_model_len=8192,
+                    dtype="float16",
+                    enforce_eager=True,
+                    gpu_memory_utilization=0.7) as vllm_model:
+        vllm_model.generate(example_prompts, sampling_params)

tests/singlecard/test_sampler.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Optional
+
+import torch
+from vllm.v1.sample.sampler import Sampler  # noqa: F401
+
+# Set tolerance to 1 for quant ops
+DEFAULT_ATOL = 1e-3
+DEFAULT_RTOL = 1e-3
+
+
+def apply_min_p_new(
+    logits: torch.Tensor,
+    min_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Filters logits using adaptive probability thresholding.
+    """
+    if min_p == 0:
+        return logits
+    # Convert logits to probability distribution
+    probability_values = torch.nn.functional.softmax(logits, dim=-1)
+    # Calculate maximum probabilities per sequence
+    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
+    # Reshape min_p for broadcasting
+    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+    # Identify valid tokens using threshold comparison
+    # Apply mask using boolean indexing
+    logits = logits.masked_fill(probability_values < adjusted_min_p,
+                                -float('inf'))
+    return logits
+
+
+def apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """Apply top-k and top-p masks to the logits.
+
+    If a top-p is used, this function will sort the logits tensor,
+    which can be slow for large batches.
+
+    The logits tensor may be updated in-place.
+    """
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if k is not None:
+        # Apply top-k.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
+        # Get all the top_k values.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if p is not None:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # at least one
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Re-sort the probabilities.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
+
+
+def apply_top_k_top_p_new(
+    logits: torch.Tensor,
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    batch_size, vocab_size = logits.shape
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    # Apply top-k.
+    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
+    top_k_mask = logits_sort < boundary
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if p is not None:
+        # Apply top-p.
+        cutoff = top_k_mask.sum(dim=-1).min()
+        probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = True
+        strides = torch.arange(0,
+                               batch_size * vocab_size,
+                               vocab_size,
+                               device=logits.device)
+        flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
+        valid_idx = torch.masked_select(flatten_idx, top_p_mask)
+        logits_flatten = logits.flatten()
+        valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
+        logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
+        logits[valid_idx] = valid_logits
+    return logits.reshape(batch_size, vocab_size)
+
+
+# test with leading dimension and merge seqlen and batch_size as num_tokens
+@torch.inference_mode()
+def test_apply_min_p() -> None:
+    logits = torch.randn((128, 7168)).npu()
+    min_p = torch.Tensor([0.01]).npu()
+    logits_new = apply_min_p_new(logits, min_p)
+    sampler = Sampler()
+    logits_old = sampler.apply_min_p(logits, min_p)
+    # Compare the results.
+    torch.testing.assert_close(logits_new,
+                               logits_old,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)
+
+
+# test with leading dimension and merge seqlen and batch_size as num_tokens
+@torch.inference_mode()
+def test_apply_top_k_top_p() -> None:
+    logits = torch.randn((128, 7168)).npu()
+    k = torch.Tensor([-1]).int().npu()
+    p = torch.Tensor([1]).int().npu()
+    logits_new = apply_top_k_top_p_new(logits, k, p)
+    logits_old = apply_top_k_top_p(logits, k, p)
+    # Compare the results.
+    torch.testing.assert_close(logits_new,
+                               logits_old,
+                               atol=DEFAULT_ATOL,
+                               rtol=DEFAULT_RTOL)

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_ENABLE_TOPK_OPTIMZE":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_TOPK_OPTIMZE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

vllm_ascend/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
@@ -230,8 +230,7 @@ def forward(
         enable_force_load_balance = False
         num_tokens, hidden_dim = hidden_states.shape

-        if self.n_shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states)
+        old_hidden_states = hidden_states.detach()

         if self.tp_size > 1:
             # pass
@@ -270,6 +269,9 @@ def forward(
         else:
             final_hidden_states = router_hidden_states

+        if self.n_shared_experts is not None:
+            shared_output = self.shared_experts(old_hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
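
The two hunks above move the shared-expert computation from before routing to after it, running it on a detached snapshot of the input taken before the routed experts are applied. A standalone sketch of the resulting control flow, with hypothetical helper names rather than the actual DeepseekV2MoE module:

import torch


def moe_forward(hidden_states: torch.Tensor, routed_experts, shared_experts=None):
    # Snapshot the pre-routing activations, as in the first hunk above.
    old_hidden_states = hidden_states.detach()

    # Routed experts run first and produce the main output.
    final_hidden_states = routed_experts(hidden_states)

    # The shared-expert branch now runs afterwards, on the snapshot.
    if shared_experts is not None:
        final_hidden_states = final_hidden_states + shared_experts(old_hidden_states)
    return final_hidden_states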

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -362,7 +362,7 @@ def fused_experts(
                                        num_experts)).to(topk_ids.dtype)

         # Sort by local expert IDs
-        sort_indices = torch.argsort(filtered_experts)
+        sort_indices = torch.argsort(filtered_experts.view(torch.float32))
         sorted_token_indices = token_indices[sort_indices]
         sorted_weights = filtered_weights[sort_indices]
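
The one-line change above sorts the expert IDs through a float32 reinterpretation of the tensor (`Tensor.view(dtype)` reinterprets the underlying bits, it does not cast). For small non-negative int32 values the IEEE-754 bit patterns are monotonically increasing, so the sort order is preserved. A small CPU sanity check of that assumption, illustrative and not part of the commit:

import torch

# Expert IDs are small non-negative int32 values, so reinterpreting them as
# float32 keeps their relative order under sorting.
expert_ids = torch.randint(0, 64, (1024,), dtype=torch.int32)

order_int = torch.argsort(expert_ids)
order_float = torch.argsort(expert_ids.view(torch.float32))

# The permutations may differ on ties, but the sorted sequences agree.
assert torch.equal(expert_ids[order_int], expert_ids[order_float])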

vllm_ascend/patch/__init__.py

Lines changed: 27 additions & 0 deletions
@@ -166,3 +166,30 @@
 #   Future Plan:
 #       Revert it when the ascend support triton kernel.
 #
+# ** File: v1/sample/sampler.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.sample.sampler.Sampler.apply_top_k_top_p`
+#    Why:
+#       We need to use the patched `apply_top_k_top_p` in `sample`.
+#       The main reason to overwrite `apply_top_k_top_p` is
+#       to improve performance.
+#    How:
+#       Re-implement the `apply_top_k_top_p` function with PyTorch.
+#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+#       - https://github.com/vllm-project/vllm-ascend/pull/970
+#    Future Plan:
+#       Revert it when the ascend scatter performance improves.
+#
+# ** File: v1/sample/sampler.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.sample.sampler.Sampler.apply_min_p`
+#    Why:
+#       We need to use the patched `apply_min_p` in `sample`.
+#       The main reason to overwrite `apply_min_p` is
+#       to improve performance.
+#    How:
+#       Re-implement the `apply_min_p` function with PyTorch.
+#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+#       - https://github.com/vllm-project/vllm-ascend/pull/970
+#    Future Plan:
+#       Revert it when the ascend indexput performance improves.

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -23,4 +23,5 @@
 import vllm_ascend.patch.worker.patch_common.patch_metrics  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker  # noqa
vllm_ascend/patch/worker/patch_common/patch_sampler.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+import torch
+from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
+from vllm.v1.sample.sampler import Sampler
+
+from vllm_ascend import envs
+
+
+def apply_min_p(
+    self,
+    logits: torch.Tensor,
+    min_p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Filters logits using adaptive probability thresholding.
+    """
+    # Convert logits to probability distribution
+    probability_values = torch.nn.functional.softmax(logits, dim=-1)
+    # Calculate maximum probabilities per sequence
+    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
+    # Reshape min_p for broadcasting
+    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+    # Identify valid tokens using threshold comparison
+    # Apply mask using boolean indexing
+    logits = logits.masked_fill(probability_values < adjusted_min_p,
+                                -float('inf'))
+    return logits
+
+
+def _apply_top_k_top_p(
+    logits: torch.Tensor,
+    k: torch.Tensor,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    batch_size, vocab_size = logits.shape
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    # Apply top-k.
+    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
+    top_k_mask = logits_sort < boundary
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    # Apply top-p.
+    cutoff = top_k_mask.sum(dim=-1).min()
+    probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)
+    top_p_mask[:, -1] = True
+    strides = torch.arange(0,
+                           batch_size * vocab_size,
+                           vocab_size,
+                           device=logits.device)
+    flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
+    valid_idx = torch.masked_select(flatten_idx, top_p_mask)
+    logits_flatten = logits.flatten()
+    valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
+    logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
+    logits[valid_idx] = valid_logits
+    return logits.reshape(batch_size, vocab_size)
+
+
+def topk_topp_forward_native(
+    self,
+    logits: torch.Tensor,
+    generators: dict[int, torch.Generator],
+    k: Optional[torch.Tensor],
+    p: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """
+    PyTorch-native implementation of top-k and top-p sampling.
+
+    The logits tensor may be updated in-place.
+    """
+    logits = _apply_top_k_top_p(logits, k, p)
+    probs = logits.softmax(dim=-1, dtype=torch.float32)
+    return random_sample(probs, generators)
+
+
+Sampler.apply_min_p = apply_min_p
+if envs.VLLM_ENABLE_TOPK_OPTIMZE:
+    TopKTopPSampler.forward_native = topk_topp_forward_native
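
Because the module applies its patches as an import side effect, one quick way to confirm they took hold is to compare the bound functions after importing it. A small illustrative check, not part of the commit, assuming VLLM_ENABLE_TOPK_OPTIMZE is set before any vllm_ascend import:

import os
os.environ["VLLM_ENABLE_TOPK_OPTIMZE"] = "1"

from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.sampler import Sampler

import vllm_ascend.patch.worker.patch_common.patch_sampler as patch_sampler

# Both assignments at the bottom of the file above should now be in effect.
assert Sampler.apply_min_p is patch_sampler.apply_min_p
assert TopKTopPSampler.forward_native is patch_sampler.topk_topp_forward_native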
