|
| 1 | +# |
| 2 | +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. |
| 3 | +# This file is a part of the vllm-ascend project. |
| 4 | +# Adapted from vllm/tests/entrypoints/llm/test_guided_generate.py |
| 5 | +# Copyright 2023 The vLLM team. |
| 6 | +# |
| 7 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 8 | +# you may not use this file except in compliance with the License. |
| 9 | +# You may obtain a copy of the License at |
| 10 | +# |
| 11 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 12 | +# |
| 13 | +# Unless required by applicable law or agreed to in writing, software |
| 14 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 15 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 16 | +# See the License for the specific language governing permissions and |
| 17 | +# limitations under the License. |
| 18 | +# |
| 19 | +from typing import Optional |
| 20 | + |
| 21 | +import torch |
| 22 | +from vllm.v1.sample.sampler import Sampler # noqa: F401 |
| 23 | + |
| 24 | +# Set tolerance to 1 for quant ops |
| 25 | +DEFAULT_ATOL = 1e-3 |
| 26 | +DEFAULT_RTOL = 1e-3 |
| 27 | + |
| 28 | + |
| 29 | +def apply_min_p_new( |
| 30 | + logits: torch.Tensor, |
| 31 | + min_p: torch.Tensor, |
| 32 | +) -> torch.Tensor: |
| 33 | + """ |
| 34 | + Filters logits using adaptive probability thresholding. |
| 35 | + """ |
| 36 | + if min_p == 0: |
| 37 | + return logits |
| 38 | + # Convert logits to probability distribution |
| 39 | + probability_values = torch.nn.functional.softmax(logits, dim=-1) |
| 40 | + # Calculate maximum probabilities per sequence |
| 41 | + max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) |
| 42 | + # Reshape min_p for broadcasting |
| 43 | + adjusted_min_p = min_p.unsqueeze(1) * max_probabilities |
| 44 | + # Identify valid tokens using threshold comparison |
| 45 | + # Apply mask using boolean indexing |
| 46 | + logits = logits.masked_fill(probability_values < adjusted_min_p, |
| 47 | + -float('inf')) |
| 48 | + return logits |
| 49 | + |
| 50 | + |
| 51 | +def apply_top_k_top_p( |
| 52 | + logits: torch.Tensor, |
| 53 | + k: Optional[torch.Tensor], |
| 54 | + p: Optional[torch.Tensor], |
| 55 | +) -> torch.Tensor: |
| 56 | + """Apply top-k and top-p masks to the logits. |
| 57 | +
|
| 58 | + If a top-p is used, this function will sort the logits tensor, |
| 59 | + which can be slow for large batches. |
| 60 | +
|
| 61 | + The logits tensor may be updated in-place. |
| 62 | + """ |
| 63 | + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) |
| 64 | + |
| 65 | + if k is not None: |
| 66 | + # Apply top-k. |
| 67 | + top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B |
| 68 | + # Get all the top_k values. |
| 69 | + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) |
| 70 | + top_k_mask = logits_sort < top_k_mask |
| 71 | + logits_sort.masked_fill_(top_k_mask, -float("inf")) |
| 72 | + |
| 73 | + if p is not None: |
| 74 | + # Apply top-p. |
| 75 | + probs_sort = logits_sort.softmax(dim=-1) |
| 76 | + probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) |
| 77 | + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) |
| 78 | + # at least one |
| 79 | + top_p_mask[:, -1] = False |
| 80 | + logits_sort.masked_fill_(top_p_mask, -float("inf")) |
| 81 | + |
| 82 | + # Re-sort the probabilities. |
| 83 | + logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) |
| 84 | + return logits |
| 85 | + |
| 86 | + |
| 87 | +def apply_top_k_top_p_new( |
| 88 | + logits: torch.Tensor, |
| 89 | + k: Optional[torch.Tensor], |
| 90 | + p: Optional[torch.Tensor], |
| 91 | +) -> torch.Tensor: |
| 92 | + batch_size, vocab_size = logits.shape |
| 93 | + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) |
| 94 | + |
| 95 | + # Apply top-k. |
| 96 | + boundary = logits_sort.gather(1, (vocab_size - k).unsqueeze(dim=1)) |
| 97 | + top_k_mask = logits_sort < boundary |
| 98 | + logits_sort.masked_fill_(top_k_mask, -float("inf")) |
| 99 | + |
| 100 | + if p is not None: |
| 101 | + # Apply top-p. |
| 102 | + cutoff = top_k_mask.sum(dim=-1).min() |
| 103 | + probs_sort = logits_sort.softmax(dim=-1)[:, cutoff:] |
| 104 | + probs_sum = probs_sort.cumsum(dim=-1) |
| 105 | + top_p_mask = probs_sum > 1 - p.unsqueeze(dim=1) |
| 106 | + top_p_mask[:, -1] = True |
| 107 | + strides = torch.arange(0, |
| 108 | + batch_size * vocab_size, |
| 109 | + vocab_size, |
| 110 | + device=logits.device) |
| 111 | + flatten_idx = logits_idx[:, cutoff:] + strides.unsqueeze(dim=1) |
| 112 | + valid_idx = torch.masked_select(flatten_idx, top_p_mask) |
| 113 | + logits_flatten = logits.flatten() |
| 114 | + valid_logits = torch.index_select(logits_flatten, 0, valid_idx) |
| 115 | + logits = torch.empty_like(logits_flatten).fill_(-float("inf")) |
| 116 | + logits[valid_idx] = valid_logits |
| 117 | + return logits.reshape(batch_size, vocab_size) |
| 118 | + |
| 119 | + |
| 120 | +# test with leading dimension and merge seqlen and batch_size as num_tokens |
| 121 | +@torch.inference_mode() |
| 122 | +def test_apply_min_p() -> None: |
| 123 | + logits = torch.randn((128, 7168)).npu() |
| 124 | + min_p = torch.Tensor([0.01]).npu() |
| 125 | + logits_new = apply_min_p_new(logits, min_p) |
| 126 | + sampler = Sampler() |
| 127 | + logits_old = sampler.apply_min_p(logits, min_p) |
| 128 | + # Compare the results. |
| 129 | + torch.testing.assert_close(logits_new, |
| 130 | + logits_old, |
| 131 | + atol=DEFAULT_ATOL, |
| 132 | + rtol=DEFAULT_RTOL) |
| 133 | + |
| 134 | + |
| 135 | +# test with leading dimension and merge seqlen and batch_size as num_tokens |
| 136 | +@torch.inference_mode() |
| 137 | +def test_apply_top_k_top_p() -> None: |
| 138 | + logits = torch.randn((128, 7168)).npu() |
| 139 | + k = torch.Tensor([-1]).int().npu() |
| 140 | + p = torch.Tensor([1]).int().npu() |
| 141 | + logits_new = apply_top_k_top_p_new(logits, k, p) |
| 142 | + logits_old = apply_top_k_top_p(logits, k, p) |
| 143 | + # Compare the results. |
| 144 | + torch.testing.assert_close(logits_new, |
| 145 | + logits_old, |
| 146 | + atol=DEFAULT_ATOL, |
| 147 | + rtol=DEFAULT_RTOL) |
0 commit comments