
Commit 5c335b1

fix test_min_sampling

1 parent 5b9ffc5 commit 5c335b1

File tree

6 files changed: +119 −112 lines changed

docs/zh/offline_inference.md

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ for output in outputs:
 * temperature(float): controls the randomness of generation; higher values give more random results, lower values more deterministic ones
 * top_p(float): cumulative probability truncation threshold; only the most likely tokens whose cumulative probability reaches this threshold are considered
 * top_k(int): number of highest-probability tokens to sample from; the k most likely tokens are considered
-* min_p(float): minimum probability threshold for a token to be considered (as a ratio of the highest-probability token's probability,set it >0 to filter low-probability tokens and improve generation quality)
+* min_p(float): minimum probability threshold for a token to be considered (as a ratio of the highest-probability token's probability; set it >0 to filter low-probability tokens and improve generation quality)
 * max_tokens(int): limits the maximum number of tokens the model generates (including input and output)
 * min_tokens(int): forces the model to generate at least this many tokens, avoiding ending too early
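
For context, these parameters are passed together at request time. A minimal offline-inference sketch (the `LLM` entry point, model path, and output attribute access are illustrative assumptions, not part of this commit):

    from fastdeploy import LLM, SamplingParams

    # min_p=0.1 keeps only tokens whose probability is at least 10% of the
    # most likely token's probability at each decoding step.
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, min_p=0.1, max_tokens=128)

    llm = LLM(model="./models/ERNIE-4.5-0.3B")  # hypothetical local model path
    outputs = llm.generate(["Hello, my name is"], sampling_params)
    for output in outputs:
        print(output.outputs.text)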

fastdeploy/engine/sampling_params.py

Lines changed: 3 additions & 3 deletions
@@ -53,7 +53,7 @@ class SamplingParams:
     top_p: Float that controls the cumulative probability of the top tokens
         to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
     top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
-    min_p:Float that represents the minimum probability for a token to be
+    min_p: Float that represents the minimum probability for a token to be
         considered, relative to the probability of the most likely token.
         Must be in [0, 1]. Set to 0 to disable this.
     seed: Random seed to use for the generation.
@@ -87,7 +87,7 @@ class SamplingParams:
     temperature: float = None
     top_p: float = None
     top_k: int = 0
-    min_p: float=0.0
+    min_p: float = 0.0
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -186,7 +186,7 @@ def _verify_args(self) -> None:
        if not isinstance(self.top_k, int):
            raise TypeError(
                f"top_k must be an integer, got {type(self.top_k).__name__}")
-        if not 0.0 <=self.min_p <= 1.0:
+        if not 0.0 <= self.min_p <= 1.0:
            raise ValueError(f"min_p must be in [0,1], got {self.min_p}")

        if self.max_tokens is not None and self.max_tokens < 1:
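
Assuming validation runs when the params object is constructed (as the `_verify_args` name suggests), the new range check behaves like this sketch:

    from fastdeploy.engine.sampling_params import SamplingParams

    params = SamplingParams(temperature=0.8, top_p=0.9, min_p=0.2)  # accepted: 0.2 is in [0, 1]

    try:
        SamplingParams(temperature=0.8, top_p=0.9, min_p=1.5)  # rejected: out of range
    except ValueError as err:
        print(err)  # min_p must be in [0,1], got 1.5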

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 1 addition & 1 deletion
@@ -269,7 +269,7 @@ def forward_cuda(

        probs = F.softmax(logits)

-        probs= min_p_sampling(probs,sampling_metadata.min_p)
+        probs = min_p_sampling(probs, sampling_metadata.min_p)

        _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
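
Conceptually, `min_p_sampling` rescales the cutoff by each row's top probability before the top-k/top-p stage. A CPU sketch of the same filter (function name illustrative; the logic mirrors `min_p_sampling_cpu` in the test below):

    import paddle

    def min_p_filter(probs: paddle.Tensor, min_p: paddle.Tensor) -> paddle.Tensor:
        # Per-row threshold: min_p * max(probs); tokens below it are zeroed out.
        top = paddle.amax(probs, axis=-1, keepdim=True)
        threshold = top * min_p.reshape([-1, 1])
        return paddle.where(probs < threshold, paddle.zeros_like(probs), probs)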

fastdeploy/worker/gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -337,7 +337,7 @@ def get_attr_from_request(request, attr, default_value=None):
            request.eos_token_ids, dtype="int64").reshape(-1, 1)
        self.share_inputs["top_p"][idx:idx + 1] = get_attr_from_request(request, "top_p", 0.7)
        self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0)
-        self.share_inputs["min_p"][idx:idx + 1] = request.get("min_p",0.0)
+        self.share_inputs["min_p"][idx:idx + 1] = request.get("min_p", 0.0)

        self.share_inputs["temperature"][idx:idx + 1] = get_attr_from_request(request, "temperature", 0.95)
        self.share_inputs["penalty_score"][idx:idx + 1] = get_attr_from_request(

test/layers/test_min_p.py

Lines changed: 0 additions & 106 deletions
This file was deleted.

test/layers/test_min_sampling.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import numpy as np
import paddle
import paddle.nn.functional as F

from fastdeploy.model_executor.ops.gpu import min_p_sampling


class TestMinPSampling(unittest.TestCase):

    def setUp(self):
        self.sample_time = 1000000
        self.vocab_size = 1000
        self.min_p_value = 0.5
        self.batch_size = 3
        self.batch_min_p_values = [0.1, 0.0, 0.9]
        self.additional_batch_min_p_values = [0.1, 0.0, 0.3]

    # CPU reference: zero out tokens whose probability is below min_p * max(probs).
    def min_p_sampling_cpu(self, min_p):
        logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32")
        logits[0][0] = 10
        logits[0][1] = 8
        low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2)
        logits[0][2:] = low_prob_tensor

        probs = F.softmax(logits)
        max_probabilities = paddle.amax(probs, axis=-1, keepdim=True)
        adjusted_min_p = max_probabilities * min_p.reshape([-1, 1])
        invalid_token_mask = probs < adjusted_min_p
        probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs)
        return probs

    # Single request (min_p = 0.5) through the FastDeploy GPU op.
    def fastdeploy_min_p_sampling(self, min_p):
        logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32")
        logits[0][0] = 10
        logits[0][1] = 8
        low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2)
        logits[0][2:] = low_prob_tensor

        probs = F.softmax(logits)
        probs = min_p_sampling(probs, min_p)
        return probs

    # Batched requests (e.g. min_p = [0.1, 0.0, 0.9]) through the FastDeploy GPU op.
    def fastdeploy_batch_min_p_sampling(self, batch_size, min_p_values):
        logits = paddle.ones(shape=[batch_size, self.vocab_size], dtype="float32")
        for b in range(batch_size):
            logits[b][0] = 10
            logits[b][1] = 8
            logits[b][2:] = paddle.linspace(2.0, 0.0, self.vocab_size - 2)

        probs = F.softmax(logits, axis=-1)
        min_p_arr = paddle.to_tensor(min_p_values, dtype="float32")

        probs = min_p_sampling(probs, min_p_arr)

        return probs

    def compare_results(self, probs, probs_cpu, atol=1e-6, rtol=1e-6):
        probs_np = probs.numpy()
        probs_cpu_np = probs_cpu.numpy()
        try:
            np.testing.assert_allclose(
                probs_np,
                probs_cpu_np,
                rtol=rtol,
                atol=atol,
            )
            print("The results are the same between fastdeploy_min_p_sampling and min_p_sampling_cpu")
        except AssertionError as e:
            raise AssertionError(
                f"The results differ between fastdeploy_min_p_sampling and min_p_sampling_cpu:\n{str(e)}")

    def test_single_min_p_sampling(self):
        min_p = paddle.to_tensor([self.min_p_value], dtype="float32")
        probs = self.fastdeploy_min_p_sampling(min_p)
        probs_cpu = self.min_p_sampling_cpu(min_p)
        self.compare_results(probs, probs_cpu)

    def test_batch_min_p_sampling(self):
        batch_min_p = paddle.to_tensor(self.batch_min_p_values, dtype="float32")
        batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, batch_min_p)
        batch_probs_cpu = self.min_p_sampling_cpu(batch_min_p)
        self.compare_results(batch_probs, batch_probs_cpu)

    def test_additional_batch_min_p_sampling(self):
        additional_batch_min_p = paddle.to_tensor(self.additional_batch_min_p_values, dtype="float32")
        additional_batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, additional_batch_min_p)
        additional_batch_probs_cpu = self.min_p_sampling_cpu(additional_batch_min_p)
        self.compare_results(additional_batch_probs, additional_batch_probs_cpu)


if __name__ == "__main__":
    if paddle.is_compiled_with_cuda():
        unittest.main()
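
To run the suite locally (standard unittest invocation; it requires a CUDA-enabled Paddle build and the compiled FastDeploy GPU ops, otherwise the `__main__` guard exits without running anything):

    python test/layers/test_min_sampling.py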

0 commit comments