diff --git a/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu b/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu
new file mode 100644
index 0000000000..c44c16b430
--- /dev/null
+++ b/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu
@@ -0,0 +1,65 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "helper.h"
+#include "paddle/phi/backends/context_pool.h"
+#include "sample_kernels/sampling.cuh"
+
+std::vector<paddle::Tensor> MinPSamplingFromProbs(const paddle::Tensor &probs,
+                                                  const paddle::Tensor &min_p) {
+  std::vector<int64_t> probs_shape = probs.shape();
+  unsigned int batch_size = probs_shape[0];
+  unsigned int vocab_size = probs_shape[1];
+  auto cu_stream = probs.stream();
+
+  auto renorm_probs =
+      GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place());
+
+  cudaError_t status;
+
+  status = sampling::MinPSamplingFromProb(
+      const_cast<float *>(probs.data<float>()),
+      const_cast<float *>(min_p.data<float>()),
+      renorm_probs.data<float>(),
+      batch_size,
+      vocab_size,
+      true,  // deterministic
+      cu_stream);
+
+  PD_CHECK(status == cudaSuccess, "MinPSamplingFromProb failed with error code " +
+                                      std::string(cudaGetErrorString(status)));
+
+  return {renorm_probs};
+}
+
+std::vector<std::vector<int64_t>>
+MinPSamplingFromProbsInferShape(const std::vector<int64_t> &probs_shape,
+                                const paddle::optional<std::vector<int64_t>> &min_p_shape) {
+  return {probs_shape};
+}
+
+std::vector<paddle::DataType>
+MinPSamplingFromProbsInferDtype(const paddle::DataType &probs_dtype,
+                                const paddle::optional<paddle::DataType> &min_p_dtype) {
+  return {probs_dtype};
+}
+
+PD_BUILD_STATIC_OP(min_p_sampling)
+    .Inputs({"probs", "min_p"})
+    .Outputs({"renorm_probs"})
+    .SetKernelFn(PD_KERNEL(MinPSamplingFromProbs))
+    .SetInferShapeFn(PD_INFER_SHAPE(MinPSamplingFromProbsInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(MinPSamplingFromProbsInferDtype));
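For orientation: the new op zeroes out every token whose probability is below `min_p` times the row maximum, and returns the masked distribution as-is (the output is named `renorm_probs`, but the kernel does not renormalize). A minimal NumPy sketch of that rule, purely illustrative and not part of the patch:

```python
import numpy as np

def min_p_filter(probs: np.ndarray, min_p: np.ndarray) -> np.ndarray:
    # Keep a token only if it is at least min_p[row] times as likely as the
    # most likely token in its row; everything else is set to 0.
    pivot = probs.max(axis=-1, keepdims=True) * min_p.reshape(-1, 1)
    return np.where(probs >= pivot, probs, 0.0)

# Example: with min_p = 0.5, only tokens at least half as likely as the top token survive.
probs = np.array([[0.50, 0.30, 0.15, 0.05]])
print(min_p_filter(probs, np.array([0.5])))   # -> [[0.5 0.3 0.  0. ]]
```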
diff --git a/custom_ops/gpu_ops/sample_kernels/sampling.cuh b/custom_ops/gpu_ops/sample_kernels/sampling.cuh
index 7102c73d60..f14694fa1b 100644
--- a/custom_ops/gpu_ops/sample_kernels/sampling.cuh
+++ b/custom_ops/gpu_ops/sample_kernels/sampling.cuh
@@ -276,6 +276,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb(
     aggregate += aggregate_local;
   }
+
+
+
 template
@@ -391,6 +394,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output,
   }
 }
+
+
 template
@@ -553,6 +558,47 @@ struct RenormTempStorage {
   };
 };
 
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType>
+__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr,
+                                           DType* renormed_prob, uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx];
+  const uint32_t row_idx = bx;
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
+      uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
+          smem_sampling);
+
+  float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
+                              SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>>(
+      probs, row_idx, d, temp_storage);
+  float pivot = max_val * p;
+
+  vec_t<float, VEC_SIZE> probs_vec;
+#pragma unroll 2
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(0);
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+    }
+
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_vec[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0;
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
+
 template
 __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob,
                                      IdType* top_k_arr, uint32_t d) {
@@ -705,6 +751,33 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output,
   return cudaSuccess;
 }
 
+template <typename T>
+cudaError_t MinPSamplingFromProb(T *probs, const T *min_p_arr, T *renormed_prob,
+                                 uint32_t batch_size, uint32_t d,
+                                 bool deterministic, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &min_p_arr, &renormed_prob, &d};
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel = MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                                 VEC_SIZE, DETERMINISTIC, T>;
+        CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args,
+                                   smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
 template <typename T, typename IdType>
 cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output,
                                      uint32_t batch_size, const T *top_p_val,
                                      const IdType *top_k_val,
diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py
index de49ab4ead..6e25b6b13a 100644
--- a/custom_ops/setup_ops.py
+++ b/custom_ops/setup_ops.py
@@ -287,6 +287,7 @@ def find_end_files(directory, end_str):
             "gpu_ops/text_image_gather_scatter.cu",
             "gpu_ops/sample_kernels/rejection_top_p_sampling.cu",
             "gpu_ops/sample_kernels/top_k_renorm_probs.cu",
+            "gpu_ops/sample_kernels/min_p_sampling_from_probs.cu",
             "gpu_ops/get_position_ids_and_mask_encoder_batch.cu",
             "gpu_ops/fused_rotary_position_encoding.cu",
             "gpu_ops/noaux_tc.cu",
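The host launcher picks the widest 16-byte-aligned per-thread vector load that still divides the vocabulary size, via `vec_size = std::gcd(16 / sizeof(T), d)`. A quick illustration of that choice (vocabulary sizes below are examples, not taken from the patch):

```python
# Mirror of the vec_size computation in MinPSamplingFromProb for T = float (4 bytes).
from math import gcd

for vocab_size in (1000, 32000, 151936):
    vec_size = gcd(16 // 4, vocab_size)  # 16-byte loads hold 4 floats
    print(vocab_size, "->", vec_size)    # all three print 4: each is divisible by 4
```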
diff --git a/docs/offline_inference.md b/docs/offline_inference.md
index e1cdfb088f..45a77615a7 100644
--- a/docs/offline_inference.md
+++ b/docs/offline_inference.md
@@ -180,6 +180,7 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md).
 * temperature(float): Controls randomness (higher = more random)
 * top_p(float): Probability threshold for token selection
 * top_k(int): Number of tokens considered for sampling
+* min_p(float): Minimum probability relative to the maximum probability for a token to be considered (>0 filters low-probability tokens to improve quality)
 * max_tokens(int): Maximum generated tokens (input + output)
 * min_tokens(int): Minimum forced generation length
diff --git a/docs/zh/offline_inference.md b/docs/zh/offline_inference.md
index 828eae886c..015fc7b720 100644
--- a/docs/zh/offline_inference.md
+++ b/docs/zh/offline_inference.md
@@ -180,6 +180,7 @@ for output in outputs:
 * temperature(float): 控制生成随机性的参数,值越高结果越随机,值越低结果越确定
 * top_p(float): 概率累积分布截断阈值,仅考虑累计概率达到此阈值的最可能token集合
 * top_k(int): 采样概率最高的token数量,考虑概率最高的k个token进行采样
+* min_p(float): token入选的最小概率阈值(相对于最高概率token的比值,设为>0可通过过滤低概率token来提升文本生成质量)
 * max_tokens(int): 限制模型生成的最大token数量(包括输入和输出)
 * min_tokens(int): 强制模型生成的最少token数量,避免过早结束
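With the documented parameters above, `min_p` is used like any other sampling knob in offline inference. A hypothetical sketch (model path and prompt are placeholders; it assumes the `LLM`/`SamplingParams` entry points these docs describe):

```python
from fastdeploy import LLM, SamplingParams

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    min_p=0.05,  # drop tokens less than 5% as likely as the most likely token
)
llm = LLM(model="./path/to/model")  # placeholder path
outputs = llm.generate(["Write a haiku about sampling."], sampling_params)
```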
diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py
index 336afcc8de..688351883b 100644
--- a/fastdeploy/engine/sampling_params.py
+++ b/fastdeploy/engine/sampling_params.py
@@ -53,6 +53,9 @@ class SamplingParams:
         top_p: Float that controls the cumulative probability of the top tokens
             to consider. Must be in [0, 1]. Set to 1 to consider all tokens.
         top_k: Int that controls the number of top tokens to consider. Must be a positive integer.
+        min_p: Float that represents the minimum probability for a token to be
+            considered, relative to the probability of the most likely token.
+            Must be in [0, 1]. Set to 0 to disable this.
         seed: Random seed to use for the generation.
         stop: list of strings that stop the generation when they are generated.
             The returned output will not contain the stop strings.
@@ -84,6 +87,7 @@ class SamplingParams:
     temperature: float = None
     top_p: float = None
     top_k: int = 0
+    min_p: float = 0.0
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -114,6 +118,7 @@ def from_optional(
         temperature,
         top_p,
         top_k,
+        min_p,
         seed=None,
         stop=None,
         stop_token_ids=None,
@@ -133,6 +138,7 @@ def from_optional(
             temperature=temperature if temperature is not None else 1.0,
             top_p=top_p,
             top_k=top_k if top_k is not None else 0,
+            min_p=min_p if min_p is not None else 0.0,
             seed=seed,
             stop=stop,
             stop_token_ids=stop_token_ids,
@@ -170,6 +176,8 @@ def _verify_args(self) -> None:
                 raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.")
             if not isinstance(self.top_k, int):
                 raise TypeError(f"top_k must be an integer, got {type(self.top_k).__name__}")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
 
         if self.max_tokens is not None and self.max_tokens < 1:
             raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.")
diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index febb3dd0ca..fd8c970bbe 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -339,6 +339,7 @@ class CompletionRequest(BaseModel):
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     top_k: Optional[int] = None
+    min_p: Optional[float] = None
     user: Optional[str] = None
     response_format: Optional[AnyResponseFormat] = None
 
@@ -460,6 +461,7 @@ class ChatCompletionRequest(BaseModel):
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     top_k: Optional[int] = None
+    min_p: Optional[float] = None
     user: Optional[str] = None
     metadata: Optional[dict] = None
 
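Because `CompletionRequest` and `ChatCompletionRequest` now carry `min_p`, an OpenAI-compatible client can set it per request. A hedged sketch against a locally deployed FastDeploy server (host, port, and model name are placeholders):

```python
import requests

payload = {
    "model": "default",  # placeholder model name
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.8,
    "min_p": 0.1,  # forwarded into ChatCompletionRequest.min_p
}
resp = requests.post("http://127.0.0.1:8188/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```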
diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py
index 7f841f2cc7..69b2e3e198 100644
--- a/fastdeploy/model_executor/layers/sample/meta_data.py
+++ b/fastdeploy/model_executor/layers/sample/meta_data.py
@@ -42,6 +42,7 @@ class SamplingMetadata:
 
     top_p: paddle.Tensor
     top_k: Optional[paddle.Tensor] = None
+    min_p: Optional[paddle.Tensor] = None
     max_num_logprobs: Optional[int] = None
     prompt_ids: Optional[paddle.Tensor] = None
     prompt_lens: Optional[paddle.Tensor] = None
diff --git a/fastdeploy/model_executor/layers/sample/ops/__init__.py b/fastdeploy/model_executor/layers/sample/ops/__init__.py
index 16eb320b47..09834b305a 100644
--- a/fastdeploy/model_executor/layers/sample/ops/__init__.py
+++ b/fastdeploy/model_executor/layers/sample/ops/__init__.py
@@ -18,10 +18,11 @@
     apply_penalty_multi_scores,
     apply_speculative_penalty_multi_scores,
 )
-from .top_k_top_p_sampling import top_k_top_p_sampling
+from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling
 
 __all__ = [
     "apply_penalty_multi_scores",
     "apply_speculative_penalty_multi_scores",
     "top_k_top_p_sampling",
+    "min_p_sampling",
 ]
diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py
index 63da37802a..0edd4f42f9 100644
--- a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py
+++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py
@@ -60,6 +60,7 @@ def top_k_top_p_sampling(
     """
     top_p_class = envs.FD_SAMPLING_CLASS.lower()
+
     if top_p_class == "air":
         _, ids = air_top_p_sampling(x, top_p, threshold, topp_seed, seed=seed,
                                     k=k, mode=mode)
     elif top_p_class == "rejection":
@@ -154,3 +155,25 @@ def rejection_top_p_sampling(
     except ImportError:
         raise RuntimeError("Cannot import rejection_top_p_sampling op.")
     return ids
+
+
+def min_p_sampling(
+    probs: paddle.Tensor,
+    min_p_arr: Optional[paddle.Tensor],
+) -> paddle.Tensor:
+    """
+    Zero out tokens whose probability is below min_p * max_prob for each row.
+    """
+    if paddle.count_nonzero(min_p_arr) == 0:
+        return probs
+    else:
+        if current_platform.is_cuda():
+            from fastdeploy.model_executor.ops.gpu import min_p_sampling
+
+            probs = min_p_sampling(probs, min_p_arr)
+        else:
+            max_probabilities = paddle.amax(probs, axis=-1, keepdim=True)
+            adjusted_min_p = max_probabilities * min_p_arr
+            invalid_token_mask = probs < adjusted_min_p.reshape([-1, 1])
+            probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs)
+    return probs
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index e814f21dcf..ff32b154ae 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -30,6 +30,7 @@
 from fastdeploy.model_executor.layers.sample.ops import (
     apply_penalty_multi_scores,
     apply_speculative_penalty_multi_scores,
+    min_p_sampling,
     top_k_top_p_sampling,
 )
 from fastdeploy.platforms import current_platform
@@ -266,6 +267,8 @@ def forward_cuda(
 
         probs = F.softmax(logits)
 
+        probs = min_p_sampling(probs, sampling_metadata.min_p)
+
         _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
 
         logprobs_tensors = (
@@ -281,6 +284,7 @@ def forward_cuda(
             sampled_token_ids=next_tokens,
             logprobs_tensors=logprobs_tensors,
         )
+
         return sampler_output
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index b3c046e1db..f2cd2af786 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -320,6 +320,8 @@ def get_attr_from_request(request, attr, default_value=None):
             self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1)
             self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7)
             self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
+            self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
+
             self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95)
             self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request(
                 request, "repetition_penalty", 1.0
@@ -430,6 +432,7 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
         self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
+        self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
         self.share_inputs["temperature"] = paddle.full(
             [max_num_seqs, 1], self.model_config.temperature, dtype="float32"
         )
@@ -626,6 +629,7 @@ def _prepare_inputs(self) -> None:
             temperature=self.share_inputs["temperature"],
             top_p=self.share_inputs["top_p"],
             top_k=self.share_inputs["top_k"],
+            min_p=self.share_inputs["min_p"],
             step_idx=self.share_inputs["step_idx"],
             pre_token_ids=self.share_inputs["pre_ids"],
             prompt_ids=self.share_inputs["prompt_ids"],
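Note the cheap default path: `min_p_sampling` only touches the probabilities when at least one request sets a non-zero `min_p`; with the default `share_inputs["min_p"]` of zeros, the tensor passes through untouched. A small sketch of that gate (shapes follow the runner's `[max_num_seqs, 1]` layout):

```python
import paddle

# share_inputs["min_p"] defaults to zeros([max_num_seqs, 1], "float32"),
# so the common case skips the extra filtering kernel entirely.
min_p = paddle.zeros([4, 1], dtype="float32")
if paddle.count_nonzero(min_p) == 0:
    print("min-p filtering skipped")  # same early return taken in min_p_sampling()
```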
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index a7e30d6fe7..601d7f264e 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -304,6 +304,7 @@ def process_prefill_inputs(self, req_dicts: List[Request]):
             self.share_inputs["pre_ids"][idx : idx + 1] = -1
             self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7)
             self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0)
+            self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0)
             self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95)
             self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0)
             self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0)
@@ -363,6 +364,7 @@ def _init_share_inputs(self, max_num_seqs: int):
         self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
         self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
+        self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")
         self.share_inputs["temperature"] = paddle.full(
             [max_num_seqs, 1], self.model_config.temperature, dtype="float32"
         )
@@ -473,6 +475,7 @@ def _prepare_inputs(self) -> None:
             temperature=self.share_inputs["temperature"],
             top_p=self.share_inputs["top_p"],
             top_k=self.share_inputs["top_k"],
+            min_p=self.share_inputs["min_p"],
             step_idx=self.share_inputs["step_idx"],
             pre_token_ids=self.share_inputs["pre_ids"],
             frequency_penalties=self.share_inputs["frequency_score"],
diff --git a/test/layers/test_min_sampling.py b/test/layers/test_min_sampling.py
new file mode 100644
index 0000000000..624e00e125
--- /dev/null
+++ b/test/layers/test_min_sampling.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+import paddle
+import paddle.nn.functional as F
+
+from fastdeploy.model_executor.ops.gpu import min_p_sampling
+
+
+class TestMinPSampling(unittest.TestCase):
+    def setUp(self):
+        self.sample_time = 1000000
+        self.vocab_size = 1000
+        self.min_p_value = 0.5
+        self.batch_size = 3
+        self.batch_min_p_values = [0.1, 0.0, 0.9]
+        self.additional_batch_min_p_values = [0.1, 0.0, 0.3]
+
+    # Reference min-p filtering implemented with plain Paddle ops (CPU path).
+    def min_p_sampling_cpu(self, min_p):
+        logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32")
+        logits[0][0] = 10
+        logits[0][1] = 8
+        low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2)
+        logits[0][2:] = low_prob_tensor
+
+        probs = F.softmax(logits)
+        max_probabilities = paddle.amax(probs, axis=-1, keepdim=True)
+        adjusted_min_p = max_probabilities * min_p.reshape([-1, 1])
+        invalid_token_mask = probs < adjusted_min_p
+        probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs)
+        return probs
+
+    # Single request (min_p = 0.5) through the FastDeploy min_p_sampling op.
+    def fastdeploy_min_p_sampling(self, min_p):
+        logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32")
+        logits[0][0] = 10
+        logits[0][1] = 8
+        low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2)
+        logits[0][2:] = low_prob_tensor
+
+        probs = F.softmax(logits)
+        probs = min_p_sampling(probs, min_p)
+        return probs
+
+    # Batched requests (e.g. min_p = [0.1, 0.0, 0.9]) through the FastDeploy min_p_sampling op.
+    def fastdeploy_batch_min_p_sampling(self, batch_size, min_p_values):
+        logits = paddle.ones(shape=[batch_size, self.vocab_size], dtype="float32")
+        for b in range(batch_size):
+            logits[b][0] = 10
+            logits[b][1] = 8
+            logits[b][2:] = paddle.linspace(2.0, 0.0, self.vocab_size - 2)
+
+        probs = F.softmax(logits, axis=-1)
+        min_p_arr = paddle.to_tensor(min_p_values, dtype="float32")
+
+        probs = min_p_sampling(probs, min_p_arr)
+
+        return probs
+
+    def compare_results(self, probs, probs_cpu, atol=1e-6, rtol=1e-6):
+        probs_np = probs.numpy()
+        probs_cpu_np = probs_cpu.numpy()
+        try:
+            np.testing.assert_allclose(
+                probs_np,
+                probs_cpu_np,
+                rtol=rtol,
+                atol=atol,
+            )
+            print("The results are the same for fastdeploy_min_p_sampling and min_p_sampling_cpu")
+        except AssertionError as e:
+            raise AssertionError(
+                f"The results differ between fastdeploy_min_p_sampling and min_p_sampling_cpu:\n{str(e)}"
+            )
+
+    def test_single_min_p_sampling(self):
+        min_p = paddle.to_tensor([self.min_p_value], dtype="float32")
+        probs = self.fastdeploy_min_p_sampling(min_p)
+        probs_cpu = self.min_p_sampling_cpu(min_p)
+        self.compare_results(probs, probs_cpu)
+
+    def test_batch_min_p_sampling(self):
+        batch_min_p = paddle.to_tensor(self.batch_min_p_values, dtype="float32")
+        batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, batch_min_p)
+        batch_probs_cpu = self.min_p_sampling_cpu(batch_min_p)
+        self.compare_results(batch_probs, batch_probs_cpu)
+
+    def test_additional_batch_min_p_sampling(self):
+        additional_batch_min_p = paddle.to_tensor(self.additional_batch_min_p_values, dtype="float32")
+        additional_batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, additional_batch_min_p)
+        additional_batch_probs_cpu = self.min_p_sampling_cpu(additional_batch_min_p)
+        self.compare_results(additional_batch_probs, additional_batch_probs_cpu)
+
+
+if __name__ == "__main__":
+    if paddle.is_compiled_with_cuda():
+        unittest.main()
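As a sanity check on the fixture: with logits of (10, 8, 2 → 0) and min_p = 0.5, a token survives only if its logit is within ln(2) of the maximum, so both the CUDA op and the CPU reference should zero out everything except token 0. A quick verification of that arithmetic (illustrative, not part of the test):

```python
import math

max_logit, runner_up = 10.0, 8.0
cutoff = max_logit - math.log(2.0)   # token kept iff logit >= cutoff ≈ 9.31
print(runner_up >= cutoff)           # False -> only token 0 survives min_p = 0.5
```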
diff --git a/test/layers/test_sampler.py b/test/layers/test_sampler.py
index 65a6bfbe68..c46e2c8bdd 100644
--- a/test/layers/test_sampler.py
+++ b/test/layers/test_sampler.py
@@ -73,5 +73,6 @@ def test_sampler():
     print(next_tokens)
 
+
 if __name__ == "__main__":
     test_sampler()