
Commit 67990e0

[Feature] support min_p_sampling (#2872)
* Fastdeploy support min_p
* add test_min_p
* fix
* min_p_sampling
* update
* delete vl_gpu_model_runner.py
* fix
* Align usage of min_p with vLLM
* fix
* modified unit test
* fix test_min_sampling
* pre-commit all files
* fix
* fix
* fix
* fix xpu_model_runner.py
1 parent 95a214a commit 67990e0

File tree

15 files changed: +302 additions, -1 deletion
custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/phi/backends/context_pool.h"
#include "sample_kernels/sampling.cuh"

std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
                                                  const paddle::Tensor &min_p) {
  std::vector<int64_t> probs_shape = probs.shape();
  unsigned int batch_size = probs_shape[0];
  unsigned int vocab_size = probs_shape[1];
  auto cu_stream = probs.stream();

  auto renorm_probs =
      GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());

  cudaError_t status;

  status = sampling::MinPSamplingFromProb<float, int>(
      const_cast<float *>(probs.data<float>()),
      const_cast<float *>(min_p.data<float>()),
      renorm_probs.data<float>(),
      batch_size,
      vocab_size,
      true,  // deterministic
      cu_stream);

  PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " +
                                      std::string(cudaGetErrorString(status)));

  return {renorm_probs};
}

std::vector<std::vector<int64_t>>
MinPSamplingFromProbsInferShape(const std::vector<int64_t> &probs_shape,
                                const paddle::optional<std::vector<int64_t>> &min_p_shape) {
  return {probs_shape};
}

std::vector<paddle::DataType>
MinPSamplingFromProbsInferDtype(const paddle::DataType &probs_dtype,
                                const paddle::optional<paddle::DataType> &min_p_dtype) {
  return {probs_dtype};
}

PD_BUILD_STATIC_OP(min_p_sampling)
    .Inputs({"probs", "min_p"})
    .Outputs({"renorm_probs"})
    .SetKernelFn(PD_KERNEL(MinPSamplingFromProbs))
    .SetInferShapeFn(PD_INFER_SHAPE(MinPSamplingFromProbsInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(MinPSamplingFromProbsInferDtype));
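The op above consumes a probability matrix `probs` of shape [batch_size, vocab_size] and a per-request `min_p` vector, and writes `renorm_probs`, where every entry below `min_p[row] * max(probs[row])` is zeroed. A minimal usage sketch, assuming the custom ops have been compiled via setup_ops.py and the op is imported from fastdeploy.model_executor.ops.gpu (the import used later in this commit); the tensor values are illustrative only:

# Hedged sketch: drive the custom CUDA op registered above on a GPU device.
import paddle
from fastdeploy.model_executor.ops.gpu import min_p_sampling

paddle.set_device("gpu")
probs = paddle.to_tensor([[0.50, 0.30, 0.15, 0.05],
                          [0.40, 0.40, 0.10, 0.10]], dtype="float32")
min_p = paddle.to_tensor([0.2, 0.0], dtype="float32")  # 0.0 disables filtering for row 1

renorm_probs = min_p_sampling(probs, min_p)
# Row 0: pivot = 0.2 * 0.50 = 0.10, so the 0.05 entry is zeroed.
# Row 1: pivot = 0.0, so the row is unchanged.
print(renorm_probs.numpy())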

custom_ops/gpu_ops/sample_kernels/sampling.cuh

Lines changed: 73 additions & 0 deletions
@@ -276,6 +276,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
   aggregate += aggregate_local;
 }
 
+
+
+
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
@@ -391,6 +394,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
   }
 }
 
+
+
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
           bool DETERMINISTIC, typename DType, typename IdType>
@@ -553,6 +558,47 @@ struct RenormTempStorage {
 };
 };
 
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
+          typename DType, typename IdType>
+__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
+                                           DType* renormed_prob, uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx];
+  const uint32_t row_idx = bx;
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
+      uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
+          smem_sampling);
+
+  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                              SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
+      probs, row_idx, d, temp_storage);
+  float pivot = max_val * p;
+
+  vec_t<float, VEC_SIZE> probs_vec;
+#pragma unroll 2
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(0);
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+    }
+
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
+
 template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
           typename DType, typename IdType>
 __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
@@ -705,6 +751,33 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
   return cudaSuccess;
 }
 
+template <typename T, typename IdType>
+cudaError_t MinPSamplingFromProb(T *probs, const T *min_p_arr, T *renormed_prob,
+                                 uint32_t batch_size,
+                                 uint32_t d, bool deterministic,
+                                 cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &min_p_arr, &renormed_prob, &d};
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                       VEC_SIZE, DETERMINISTIC, T, IdType>;
+        CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
+                                   smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
 template <typename T, typename IdType>
 cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
     uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
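For one row of `probs`, MinPSamplingFromProbKernel reduces to a pivot-and-filter step: take the row maximum, scale it by the request's min_p, and zero every probability below that pivot (the kernel itself does not renormalize the survivors). A plain NumPy sketch of the same semantics, added here for illustration and not part of the commit:

# Reference semantics of MinPSamplingFromProbKernel for a single row.
import numpy as np

def min_p_filter_row(probs_row: np.ndarray, min_p: float) -> np.ndarray:
    """Zero out entries smaller than min_p * max(probs_row)."""
    pivot = probs_row.max() * min_p  # pivot = max_val * p in the kernel
    return np.where(probs_row >= pivot, probs_row, 0.0)

row = np.array([0.50, 0.30, 0.15, 0.05], dtype=np.float32)
print(min_p_filter_row(row, 0.2))  # pivot = 0.10 -> [0.5, 0.3, 0.15, 0.0]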

custom_ops/setup_ops.py

Lines changed: 1 addition & 0 deletions
@@ -287,6 +287,7 @@ def find_end_files(directory, end_str):
     "gpu_ops/text_image_gather_scatter.cu",
     "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
     "gpu_ops/sample_kernels/top_k_renorm_probs.cu",
+    "gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
     "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
     "gpu_ops/fused_rotary_position_encoding.cu",
     "gpu_ops/noaux_tc.cu",

docs/offline_inference.md

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
 * temperature(float): Controls randomness (higher = more random)
 * top_p(float): Probability threshold for token selection
 * top_k(int): Number of tokens considered for sampling
+* min_p(float): Minimum probability relative to the maximum probability for a token to be considered (>0 filters low-probability tokens to improve quality)
 * max_tokens(int): Maximum generated tokens (input + output)
 * min_tokens(int): Minimum forced generation length

docs/zh/offline_inference.md

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@ for output in outputs:
 * temperature(float): Controls generation randomness; higher values give more random results, lower values more deterministic ones
 * top_p(float): Cumulative probability truncation threshold; only the most likely tokens whose cumulative probability reaches this threshold are considered
 * top_k(int): Number of highest-probability tokens considered for sampling (the top k tokens)
+* min_p(float): Minimum probability threshold for a token to be kept, as a ratio of the highest-probability token; setting it >0 filters low-probability tokens and improves generation quality
 * max_tokens(int): Maximum number of tokens the model may generate (including input and output)
 * min_tokens(int): Minimum number of tokens the model is forced to generate, preventing premature termination

fastdeploy/engine/sampling_params.py

Lines changed: 8 additions & 0 deletions
@@ -53,6 +53,9 @@ class SamplingParams:
     top_p: Float that controls the cumulative probability of the top tokens
         to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
     top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
+    min_p: Float that represents the minimum probability for a token to be
+        considered, relative to the probability of the most likely token.
+        Must be in [0, 1]. Set to 0 to disable this.
     seed: Random seed to use for the generation.
     stop: list of strings that stop the generation when they are generated.
         The returned output will not contain the stop strings.
@@ -84,6 +87,7 @@
     temperature: float = None
     top_p: float = None
     top_k: int = 0
+    min_p: float = 0.0
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -114,6 +118,7 @@ def from_optional(
         temperature,
         top_p,
         top_k,
+        min_p,
         seed=None,
         stop=None,
         stop_token_ids=None,
@@ -133,6 +138,7 @@ def from_optional(
             temperature=temperature if temperature is not None else 1.0,
             top_p=top_p,
             top_k=top_k if top_k is not None else 0,
+            min_p=min_p if min_p is not None else 0.0,
             seed=seed,
             stop=stop,
             stop_token_ids=stop_token_ids,
@@ -170,6 +176,8 @@ def _verify_args(self) -> None:
             raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.")
         if not isinstance(self.top_k, int):
             raise TypeError(f"top_k must be an integer, got {type(self.top_k).__name__}")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
 
         if self.max_tokens is not None and self.max_tokens < 1:
             raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 2 additions & 0 deletions
@@ -339,6 +339,7 @@ class CompletionRequest(BaseModel):
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     top_k: Optional[int] = None
+    min_p: Optional[float] = None
     user: Optional[str] = None
 
     response_format: Optional[AnyResponseFormat] = None
@@ -460,6 +461,7 @@ class ChatCompletionRequest(BaseModel):
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     top_k: Optional[int] = None
+    min_p: Optional[float] = None
     user: Optional[str] = None
     metadata: Optional[dict] = None
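With min_p accepted on both request models, an OpenAI-compatible client can forward it to the server. A hedged sketch using the openai Python SDK, where the base URL, port, API key, and model name are placeholders rather than values defined by this commit:

# Hedged sketch: passing min_p through the OpenAI-compatible chat endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")  # placeholder endpoint
response = client.chat.completions.create(
    model="default",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    temperature=0.8,
    extra_body={"min_p": 0.05, "top_k": 20},  # fields accepted by ChatCompletionRequest
)
print(response.choices[0].message.content)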

fastdeploy/model_executor/layers/sample/meta_data.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ class SamplingMetadata:
 
     top_p: paddle.Tensor
     top_k: Optional[paddle.Tensor] = None
+    min_p: Optional[paddle.Tensor] = None
     max_num_logprobs: Optional[int] = None
     prompt_ids: Optional[paddle.Tensor] = None
     prompt_lens: Optional[paddle.Tensor] = None

fastdeploy/model_executor/layers/sample/ops/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -18,10 +18,11 @@
     apply_penalty_multi_scores,
     apply_speculative_penalty_multi_scores,
 )
-from .top_k_top_p_sampling import top_k_top_p_sampling
+from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling
 
 __all__ = [
     "apply_penalty_multi_scores",
     "apply_speculative_penalty_multi_scores",
     "top_k_top_p_sampling",
+    "min_p_sampling",
 ]

fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py

Lines changed: 23 additions & 0 deletions
@@ -60,6 +60,7 @@ def top_k_top_p_sampling(
 
     """
     top_p_class = envs.FD_SAMPLING_CLASS.lower()
+
     if top_p_class == "air":
         _, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
     elif top_p_class == "rejection":
@@ -154,3 +155,25 @@ def rejection_top_p_sampling(
     except ImportError:
         raise RuntimeError("Cannot import rejection_top_p_sampling op.")
     return ids
+
+
+def min_p_sampling(
+    probs: paddle.Tensor,
+    min_p_arr: Optional[paddle.Tensor],
+) -> paddle.Tensor:
+    """
+    Filter ``probs`` row-wise with min-p: entries smaller than
+    ``min_p_arr[row] * max(probs[row])`` are set to zero.
+    """
+    if paddle.count_nonzero(min_p_arr) == 0:
+        return probs
+    else:
+        if current_platform.is_cuda():
+            # Custom CUDA op registered as min_p_sampling in this commit.
+            from fastdeploy.model_executor.ops.gpu import min_p_sampling
+
+            probs = min_p_sampling(probs, min_p_arr)
+        else:
+            # Pure-Paddle fallback: mask out tokens below the per-row pivot.
+            max_probabilities = paddle.amax(probs, axis=-1, keepdim=True)
+            adjusted_min_p = max_probabilities * min_p_arr
+            invalid_token_mask = probs < adjusted_min_p.reshape([-1, 1])
+            probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs)
+    return probs
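min_p_sampling keeps a fast path (an all-zero min_p batch skips filtering), dispatches to the custom CUDA op when the platform is CUDA, and otherwise masks in pure Paddle. A short sketch, assuming a CPU-only environment so that current_platform.is_cuda() is false and the Paddle fallback runs; values are illustrative:

# Hedged sketch: exercising the pure-Paddle fallback branch of min_p_sampling.
import paddle
from fastdeploy.model_executor.layers.sample.ops import min_p_sampling

probs = paddle.to_tensor([[0.50, 0.30, 0.15, 0.05]], dtype="float32")
min_p = paddle.to_tensor([0.2], dtype="float32")

filtered = min_p_sampling(probs, min_p)
print(filtered.numpy())  # [[0.5, 0.3, 0.15, 0.0]]: 0.05 < 0.2 * 0.50 is masked to zero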
