[Feature] support min_p_sampling #2872

Merged · 19 commits · Jul 21, 2025
65 changes: 65 additions & 0 deletions custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu
@@ -0,0 +1,65 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/phi/backends/context_pool.h"
#include "sample_kernels/sampling.cuh"

std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
const paddle::Tensor &min_p) {
std::vector<int64_t> probs_shape = probs.shape();
unsigned int batch_size = probs_shape[0];
unsigned int vocab_size = probs_shape[1];
auto cu_stream = probs.stream();

auto renorm_probs =
GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());

cudaError_t status;

status = sampling::MinPSamplingFromProb<float, int>(
const_cast<float *>(probs.data<float>()),
const_cast<float *>(min_p.data<float>()),
renorm_probs.data<float>(),
batch_size,
vocab_size,
true, // deterministic
cu_stream);


PD_CHECK(status == cudaSuccess, "MinPSamplingFromProbs failed with error code " +
std::string(cudaGetErrorString(status)));

return {renorm_probs};
}

std::vector<std::vector<int64_t>>
MinPSamplingFromProbsInferShape(const std::vector<int64_t> &probs_shape,
const paddle::optional<std::vector<int64_t>> &min_p_shape) {
return {probs_shape};
}

std::vector<paddle::DataType>
MinPSamplingFromProbsInferDtype(const paddle::DataType &probs_dtype,
const paddle::optional<paddle::DataType> &min_p_dtype) {
return {probs_dtype};
}


PD_BUILD_STATIC_OP(min_p_sampling)
.Inputs({"probs", "min_p"})
.Outputs({"renorm_probs"})
.SetKernelFn(PD_KERNEL(MinPSamplingFromProbs))
.SetInferShapeFn(PD_INFER_SHAPE(MinPSamplingFromProbsInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(MinPSamplingFromProbsInferDtype));
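For orientation, here is a minimal sketch of invoking the op registered above from Python once the custom ops are built. The import path matches the wrapper added later in this PR; the shapes and values are illustrative assumptions.

```python
import paddle

# Assumes the compiled custom ops are available; this is the same import
# the sampler-side wrapper in this PR uses.
from fastdeploy.model_executor.ops.gpu import min_p_sampling

probs = paddle.nn.functional.softmax(paddle.randn([4, 32000]), axis=-1)  # [batch, vocab], float32
min_p = paddle.full([4, 1], 0.05, dtype="float32")                       # per-request min_p

# The op zeroes every probability below min_p * row_max; it does not
# renormalize, which is left to the downstream top-k/top-p sampler.
masked_probs = min_p_sampling(probs, min_p)
print(masked_probs.shape)  # [4, 32000]
```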
73 changes: 73 additions & 0 deletions custom_ops/gpu_ops/sample_kernels/sampling.cuh
@@ -276,6 +276,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
aggregate += aggregate_local;
}




template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType, typename IdType>
@@ -391,6 +394,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
}
}



template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
bool DETERMINISTIC, typename DType, typename IdType>
@@ -553,6 +558,47 @@ struct RenormTempStorage {
};
};

template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
typename DType, typename IdType>
__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
DType* renormed_prob, uint32_t d) {
const uint32_t bx = blockIdx.x, tx = threadIdx.x;
float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx];
const uint32_t row_idx = bx;

extern __shared__ __align__(
alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
uint8_t smem_sampling[];
auto& temp_storage =
reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
smem_sampling);

float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
probs, row_idx, d, temp_storage);
float pivot = max_val * p;

vec_t<float, VEC_SIZE> probs_vec;
#pragma unroll 2
for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
probs_vec.fill(0);
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
}

#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; ++j) {
probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
}
if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
}

}
}


template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
typename DType, typename IdType>
__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) {
@@ -705,6 +751,33 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
return cudaSuccess;
}

template <typename T, typename IdType>
cudaError_t MinPSamplingFromProb(T *probs, const T *min_p_arr, T *renormed_prob,
uint32_t batch_size,
uint32_t d, bool deterministic,
cudaStream_t stream = 0) {
constexpr uint32_t BLOCK_THREADS = 1024;
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
dim3 nblks(batch_size);
dim3 nthrs(BLOCK_THREADS);
void *args[] = {&probs, &min_p_arr, &renormed_prob, &d};
DISPATCH_ALIGNED_VEC_SIZE(
vec_size, VEC_SIZE,
{DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
auto kernel =
MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
VEC_SIZE, DETERMINISTIC, T, IdType>;
CUDA_CALL(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
smem_size, stream));
})});
return cudaSuccess;
}


template <typename T, typename IdType>
cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
uint32_t batch_size, const T *top_p_val, const IdType *top_k_val,
1 change: 1 addition & 0 deletions custom_ops/setup_ops.py
@@ -287,6 +287,7 @@ def find_end_files(directory, end_str):
"gpu_ops/text_image_gather_scatter.cu",
"gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
"gpu_ops/sample_kernels/top_k_renorm_probs.cu",
"gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
"gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
"gpu_ops/fused_rotary_position_encoding.cu",
"gpu_ops/noaux_tc.cu",
1 change: 1 addition & 0 deletions docs/offline_inference.md
@@ -180,6 +180,7 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
* temperature(float): Controls randomness (higher = more random)
* top_p(float): Probability threshold for token selection
* top_k(int): Number of tokens considered for sampling
* min_p(float): Minimum probability relative to the maximum probability for a token to be considered (>0 filters low-probability tokens to improve quality)
* max_tokens(int): Maximum generated tokens (input + output)
* min_tokens(int): Minimum forced generation length
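As a usage sketch, the new `min_p` knob plugs into `SamplingParams` alongside the existing parameters above. The import path and model path below are assumptions for illustration, not taken from this PR.

```python
# Minimal offline-inference sketch, assuming the usual FastDeploy entry points.
from fastdeploy import LLM, SamplingParams  # import path assumed

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    min_p=0.05,  # drop tokens whose probability is below 5% of the top token's
)

llm = LLM(model="./your-model-path")  # placeholder model path
outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output)
```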

1 change: 1 addition & 0 deletions docs/zh/offline_inference.md
@@ -180,6 +180,7 @@ for output in outputs:
* temperature(float): Controls generation randomness; higher values give more random results, lower values more deterministic ones
* top_p(float): Cumulative probability truncation threshold; only the most probable tokens whose cumulative probability reaches this threshold are considered
* top_k(int): Number of highest-probability tokens considered for sampling
* min_p(float): Minimum probability threshold for a token to be kept (as a ratio to the highest-probability token; setting it >0 filters out low-probability tokens and can improve generation quality)
* max_tokens(int): Maximum number of tokens the model may generate (input plus output)
* min_tokens(int): Minimum number of tokens the model must generate, preventing premature termination

8 changes: 8 additions & 0 deletions fastdeploy/engine/sampling_params.py
@@ -53,6 +53,9 @@ class SamplingParams:
top_p: Float that controls the cumulative probability of the top tokens
to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
min_p: Float that represents the minimum probability for a token to be
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
stop: list of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
@@ -84,6 +87,7 @@ class SamplingParams:
temperature: float = None
top_p: float = None
top_k: int = 0
min_p: float = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -114,6 +118,7 @@ def from_optional(
temperature,
top_p,
top_k,
min_p,
seed=None,
stop=None,
stop_token_ids=None,
@@ -133,6 +138,7 @@
temperature=temperature if temperature is not None else 1.0,
top_p=top_p,
top_k=top_k if top_k is not None else 0,
min_p=min_p if min_p is not None else 0.0,
seed=seed,
stop=stop,
stop_token_ids=stop_token_ids,
@@ -170,6 +176,8 @@ def _verify_args(self) -> None:
raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.")
if not isinstance(self.top_k, int):
raise TypeError(f"top_k must be an integer, got {type(self.top_k).__name__}")
if not 0.0 <= self.min_p <= 1.0:
raise ValueError("min_p must be in [0,1],got f{self.min_p}")

if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
2 changes: 2 additions & 0 deletions fastdeploy/entrypoints/openai/protocol.py
@@ -339,6 +339,7 @@ class CompletionRequest(BaseModel):
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
min_p: Optional[float] = None
user: Optional[str] = None

response_format: Optional[AnyResponseFormat] = None
@@ -460,6 +461,7 @@ class ChatCompletionRequest(BaseModel):
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
min_p: Optional[float] = None
user: Optional[str] = None
metadata: Optional[dict] = None
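Since `min_p` is now a first-class field on both request models, it can be passed directly in an OpenAI-compatible request body. A hedged sketch follows; the host, port, and model name are placeholders, not taken from this PR.

```python
import requests

# Placeholder endpoint; substitute the host/port your FastDeploy server exposes.
url = "http://localhost:8000/v1/chat/completions"

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "Write a haiku about sampling."}],
    "temperature": 0.8,
    "top_p": 0.95,
    "min_p": 0.1,  # accepted now that ChatCompletionRequest defines min_p
}

resp = requests.post(url, json=payload, timeout=60)
print(resp.json()["choices"][0]["message"]["content"])
```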

1 change: 1 addition & 0 deletions fastdeploy/model_executor/layers/sample/meta_data.py
@@ -42,6 +42,7 @@ class SamplingMetadata:

top_p: paddle.Tensor
top_k: Optional[paddle.Tensor] = None
min_p: Optional[paddle.Tensor] = None
max_num_logprobs: Optional[int] = None
prompt_ids: Optional[paddle.Tensor] = None
prompt_lens: Optional[paddle.Tensor] = None
3 changes: 2 additions & 1 deletion fastdeploy/model_executor/layers/sample/ops/__init__.py
@@ -18,10 +18,11 @@
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
)
from .top_k_top_p_sampling import top_k_top_p_sampling
from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling

__all__ = [
"apply_penalty_multi_scores",
"apply_speculative_penalty_multi_scores",
"top_k_top_p_sampling",
"min_p_sampling",
]
fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py
@@ -60,6 +60,7 @@ def top_k_top_p_sampling(

"""
top_p_class = envs.FD_SAMPLING_CLASS.lower()

if top_p_class == "air":
_, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed, k=k, mode=mode)
elif top_p_class == "rejection":
@@ -154,3 +155,25 @@ def rejection_top_p_sampling(
except ImportError:
raise RuntimeError("Cannot import rejection_top_p_sampling op.")
return ids


def min_p_sampling(
probs: paddle.Tensor,
min_p_arr: Optional[paddle.Tensor],
) -> paddle.Tensor:
"""
Apply min-p filtering: zero out probabilities below min_p * max_prob in each row.
"""
if paddle.count_nonzero(min_p_arr) == 0:
return probs
else:
if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import min_p_sampling

probs = min_p_sampling(probs, min_p_arr)
else:
max_probabilities = paddle.amax(probs, axis=-1, keepdim=True)
adjusted_min_p = max_probabilities * min_p_arr
invalid_token_mask = probs < adjusted_min_p.reshape([-1, 1])
probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs)
return probs
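A small worked example of the helper above (values picked for illustration): with `min_p = 0.2` and a row maximum of 0.5, the pivot is 0.1, so any probability below 0.1 is zeroed while the rest pass through unnormalized. On CUDA devices the custom op branch produces the same mask.

```python
import paddle

from fastdeploy.model_executor.layers.sample.ops import min_p_sampling

probs = paddle.to_tensor([[0.50, 0.30, 0.15, 0.05]], dtype="float32")
min_p = paddle.to_tensor([[0.20]], dtype="float32")

# pivot = 0.20 * 0.50 = 0.10, so only the 0.05 entry falls below it.
print(min_p_sampling(probs, min_p))
# Expected: [[0.50, 0.30, 0.15, 0.00]]
```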
4 changes: 4 additions & 0 deletions fastdeploy/model_executor/layers/sample/sampler.py
@@ -30,6 +30,7 @@
from fastdeploy.model_executor.layers.sample.ops import (
apply_penalty_multi_scores,
apply_speculative_penalty_multi_scores,
min_p_sampling,
top_k_top_p_sampling,
)
from fastdeploy.platforms import current_platform
@@ -266,6 +267,8 @@ def forward_cuda(

probs = F.softmax(logits)

probs = min_p_sampling(probs, sampling_metadata.min_p)

_, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)

logprobs_tensors = (
@@ -281,6 +284,7 @@
sampled_token_ids=next_tokens,
logprobs_tensors=logprobs_tensors,
)

return sampler_output


4 changes: 4 additions & 0 deletions fastdeploy/worker/gpu_model_runner.py
@@ -320,6 +320,8 @@ def get_attr_from_request(request, attr, default_value=None):
self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)

self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
request, "repetition_penalty", 1.0
@@ -430,6 +432,7 @@ def _init_share_inputs(self, max_num_seqs: int):
self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype="float32"
)
@@ -626,6 +629,7 @@ def _prepare_inputs(self) -> None:
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
top_k=self.share_inputs["top_k"],
min_p=self.share_inputs["min_p"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
prompt_ids=self.share_inputs["prompt_ids"],
3 changes: 3 additions & 0 deletions fastdeploy/worker/xpu_model_runner.py
@@ -304,6 +304,7 @@ def process_prefill_inputs(self, req_dicts: List[Request]):
self.share_inputs["pre_ids"][idx : idx + 1] = -1
self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
@@ -363,6 +364,7 @@ def _init_share_inputs(self, max_num_seqs: int):
self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
self.share_inputs["temperature"] = paddle.full(
[max_num_seqs, 1], self.model_config.temperature, dtype="float32"
)
@@ -473,6 +475,7 @@ def _prepare_inputs(self) -> None:
temperature=self.share_inputs["temperature"],
top_p=self.share_inputs["top_p"],
top_k=self.share_inputs["top_k"],
min_p=self.share_inputs["min_p"],
step_idx=self.share_inputs["step_idx"],
pre_token_ids=self.share_inputs["pre_ids"],
frequency_penalties=self.share_inputs["frequency_score"],