Commit c1c7b32

Author: wangxiaoxin (A)
Commit message: xx
Parent: ba413b4

4 files changed: +130 −3 lines

tests/sample/test_sampler.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
# Copyright 2023 The vLLM team.

# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py

from typing import Optional

import pytest
import torch

from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.sampler import Sampler

# Only the Neox-style (True) scenario is supported for now
IS_NEOX_STYLE = [True]
DTYPES = [torch.half]
HEAD_SIZES = [64, 96, 128, 256]
ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
NUM_HEADS = [17]  # Arbitrary values for testing
BATCH_SIZES = [5]  # Arbitrary values for testing
SEQ_LENS = [11, 4096]  # Arbitrary values for testing
SEEDS = [0]
DEVICES = [f"npu:{0}"]
# Set tolerance to 1e-3 for quant ops
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3


def apply_min_p_new(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Filters logits using adaptive probability thresholding.
    """
    # Nothing to filter when every request's min_p is zero.
    if not min_p.any():
        return logits
    # Convert logits to a probability distribution.
    probability_values = torch.nn.functional.softmax(logits, dim=-1)
    # Calculate the maximum probability per sequence.
    max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
    # Reshape min_p for broadcasting.
    adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
    # Mask out tokens whose probability falls below the threshold.
    logits = logits.masked_fill(probability_values < adjusted_min_p,
                                -float('inf'))
    return logits


def apply_top_k_top_p_new(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    batch_size, vocab_size = logits.shape
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1))
    top_k_mask = logits_sort < boundary
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    cutoff = top_k_mask.sum(dim=-1).min()
    probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:]
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1)

    top_p_mask[:, -1] = True
    strides = torch.arange(0,
                           batch_size * vocab_size,
                           vocab_size,
                           device=logits.device)
    flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1)
    valid_idx = torch.masked_select(flatten_idx, top_p_mask)
    logits_flatten = logits.flatten()
    valid_logits = torch.index_select(logits_flatten, 0, valid_idx)
    logits = torch.empty_like(logits_flatten).fill_(-float("inf"))
    logits[valid_idx] = valid_logits
    return logits.reshape(batch_size, vocab_size)


@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_apply_min_p(device: str) -> None:
    # Arbitrary illustrative inputs; the commit did not pin specific values.
    torch.manual_seed(SEEDS[0])
    batch_size, vocab_size = 32, 1024
    logits = torch.randn(batch_size, vocab_size, device=device)
    min_p = torch.full((batch_size, ), 0.05, device=device)
    logits_new = apply_min_p_new(logits, min_p)
    # Reference path: apply_min_p is assumed to live on the Sampler class,
    # matching the patch target in patch_sampler.py below.
    logits_old = Sampler().apply_min_p(logits.clone(), min_p)
    # Compare the results.
    torch.testing.assert_close(logits_new,
                               logits_old,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)


@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_apply_top_k_top_p(device: str) -> None:
    # Arbitrary illustrative inputs; the commit did not pin specific values.
    torch.manual_seed(SEEDS[0])
    batch_size, vocab_size = 32, 1024
    logits = torch.randn(batch_size, vocab_size, device=device)
    k = torch.randint(1, 64, (batch_size, ), device=device)
    p = torch.rand(batch_size, device=device).clamp(min=0.1)
    logits_new = apply_top_k_top_p_new(logits, k, p)
    logits_old = apply_top_k_top_p(logits.clone(), k, p)
    # Compare the results.
    torch.testing.assert_close(logits_new,
                               logits_old,
                               atol=DEFAULT_ATOL,
                               rtol=DEFAULT_RTOL)
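
Note: as a sanity check on what the two reference implementations above compute, here is a minimal, self-contained sketch of min-p and top-p filtering in plain PyTorch (CPU only, no vLLM or NPU required). All tensor values are illustrative and not taken from this commit.

import torch

# Toy batch: one sequence over a 4-token vocabulary.
logits = torch.tensor([[2.0, 1.0, 0.0, -3.0]])
probs = torch.softmax(logits, dim=-1)  # roughly [[0.66, 0.24, 0.09, 0.004]]

# min-p keeps tokens whose probability is at least min_p * max(probs).
min_p = torch.tensor([0.2])
threshold = min_p.unsqueeze(1) * probs.amax(dim=-1, keepdim=True)
print(logits.masked_fill(probs < threshold, -float("inf")))
# tensor([[2., 1., -inf, -inf]])

# top-p sorts ascending (as the kernel above does) and keeps the smallest
# suffix whose cumulative probability exceeds 1 - p.
p = torch.tensor([0.8])
probs_sort, idx = probs.sort(dim=-1, descending=False)
keep = probs_sort.cumsum(dim=-1) > 1 - p.unsqueeze(1)
print(idx[keep])  # tensor([1, 0]): tokens 0 and 1 survive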

vllm_ascend/ops/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -23,7 +23,6 @@
 import vllm_ascend.ops.layernorm  # noqa
 import vllm_ascend.ops.rotary_embedding  # noqa
 import vllm_ascend.ops.vocab_parallel_embedding  # noqa
-import vllm_ascend.ops.utils  # noqa


 class dummyFusionOp:

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,3 +24,4 @@
 import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_sampler  # noqa
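
Note: these patch modules take effect purely through import side effects: importing the module executes its top-level assignments, which rebind attributes on vLLM classes. A minimal sketch of the pattern, using made-up names rather than vLLM's:

class Greeter:
    def greet(self) -> str:
        return "hello"

def patched_greet(self) -> str:
    return "hello from the patch"

# This line runs at import time in a real patch module; every existing
# and future Greeter instance now uses the replacement.
Greeter.greet = patched_greet

assert Greeter().greet() == "hello from the patch"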

vllm_ascend/patch/worker/patch_common/patch_sampler.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,22 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the vllm-ascend project.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
117

218
import torch
3-
import vllm.v1.sample.sampler as s1
19+
import vllm.v1.sample import Sampler
420
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
521
from vllm import envs
622
from typing import Callable, Optional
@@ -73,6 +89,6 @@ def topk_topp_forward_native(
7389
probs = logits.softmax(dim=-1, dtype=torch.float32)
7490
return random_sample(probs, generators)
7591

76-
s1.apply_min_p = apply_min_p
92+
Sampler.apply_min_p = apply_min_p
7793
if envs.VLLM_ENABLE_TOPK_OPTIMZE:
7894
TopKTopPSampler.forward_native = topk_topp_forward_native
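
Note: the patch target changes from the module attribute s1.apply_min_p to the class attribute Sampler.apply_min_p, which suggests upstream vLLM moved apply_min_p into the Sampler class. Rebinding a class attribute reaches every instance, including ones created before the patch ran, because method lookup resolves through the class. A minimal sketch with a stand-in class (not vLLM's):

class Sampler:
    def apply_min_p(self, logits, min_p):
        return "reference implementation"

def apply_min_p(self, logits, min_p):
    return "patched implementation"

instance = Sampler()  # created before the patch is applied
Sampler.apply_min_p = apply_min_p
# Attribute lookup on the instance goes through the class, so the
# pre-existing instance sees the patched version too.
assert instance.apply_min_p(None, None) == "patched implementation"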
