From da68f6df4f0566d4f30ee042fabcc9b99e09f537 Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Mon, 14 Jul 2025 08:27:34 +0000 Subject: [PATCH 01/16] Fastdeploy support min_p --- .../min_p_sampling_from_probs.cu | 64 ++++++++++ .../gpu_ops/sample_kernels/sampling.cuh | 112 ++++++++++++++++++ custom_ops/setup_ops.py | 1 + docs/offline_inference.md | 1 + docs/zh/offline_inference.md | 1 + fastdeploy/engine/sampling_params.py | 8 ++ fastdeploy/entrypoints/openai/protocol.py | 2 + .../model_executor/layers/sample/meta_data.py | 1 + .../layers/sample/ops/top_k_top_p_sampling.py | 15 +++ .../model_executor/layers/sample/sampler.py | 6 +- fastdeploy/worker/gpu_model_runner.py | 5 + fastdeploy/worker/vl_gpu_model_runner.py | 4 +- test/layers/test_sampler.py | 1 + 13 files changed, 219 insertions(+), 2 deletions(-) create mode 100644 custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu diff --git a/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu b/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu new file mode 100644 index 0000000000..9802e9b512 --- /dev/null +++ b/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu @@ -0,0 +1,64 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
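+//
+// min_p_sampling custom op: for each row of a [batch_size, vocab_size]
+// probability matrix, tokens whose probability is below min_p * max_prob
+// are excluded and one token id is sampled from the remaining mass
+// (see sampling::MinPSamplingFromProb in sample_kernels/sampling.cuh).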
+ +#include "helper.h" +#include "paddle/phi/backends/context_pool.h" +#include "sample_kernels/sampling.cuh" + +std::vector MinPSamplingFromProbs(const paddle::Tensor &probs, + const paddle::Tensor &min_p, + int seed) { + std::vector probs_shape = probs.shape(); + unsigned int batch_size = probs_shape[0]; + unsigned int vocab_size = probs_shape[1]; + uint64_t philox_seed = seed; + uint64_t philox_offset = 0; + auto cu_stream = probs.stream(); + + auto samples = + paddle::empty({batch_size, 1}, paddle::DataType::INT64, probs.place()); + + cudaError_t status; + + status = sampling::MinPSamplingFromProb( + const_cast(probs.data()),samples.data(), + batch_size,min_p.data(),vocab_size,true,philox_seed,philox_offset,cu_stream); + + PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " + + std::string(cudaGetErrorString(status))); + + return {samples}; +} + +std::vector> +MinPSamplingFromProbsInferShape(const std::vector &probs_shape, + const paddle::optional> &min_p_shape) { + int64_t bs = probs_shape[0]; + return {{bs, 1}}; +} + +std::vector +MinPSamplingFromProbsInferDtype(const paddle::DataType &probs_dtype, + const paddle::optional &min_p_dtype) { + return {paddle::DataType::INT64}; +} + + +PD_BUILD_STATIC_OP(min_p_sampling) + .Inputs({"probs", "min_p"}) + .Outputs({"samples"}) + .Attrs({"seed: int"}) + .SetKernelFn(PD_KERNEL(MinPSamplingFromProbs)) + .SetInferShapeFn(PD_INFER_SHAPE(MinPSamplingFromProbsInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(MinPSamplingFromProbsInferDtype)); diff --git a/custom_ops/gpu_ops/sample_kernels/sampling.cuh b/custom_ops/gpu_ops/sample_kernels/sampling.cuh index 7102c73d60..b6b48b838f 100644 --- a/custom_ops/gpu_ops/sample_kernels/sampling.cuh +++ b/custom_ops/gpu_ops/sample_kernels/sampling.cuh @@ -276,6 +276,9 @@ __device__ __forceinline__ void DeviceSamplingFromProb( aggregate += aggregate_local; } + + + template @@ -391,6 +394,8 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* output, } } + + template @@ -553,6 +558,85 @@ struct RenormTempStorage { }; }; +template +__global__ void MinPSamplingFromProbKernel(DType* probs, IdType* output, + float* min_p_arr, uint32_t d, + uint64_t philox_seed, uint64_t philox_offset) { + const uint32_t bx = blockIdx.x, tx = threadIdx.x; + float p = (min_p_arr == nullptr) ? 1e-6 : min_p_arr[bx]; + curandStatePhilox4_32_10_t state; + curand_init(philox_seed, bx, philox_offset, &state); + const uint32_t row_idx = bx; + + extern __shared__ __align__( + alignof(SamplingTempStorage)) + uint8_t smem_sampling[]; + auto& temp_storage = + reinterpret_cast&>( + smem_sampling); + + float max_val = GetMaxValue>( + probs, row_idx, d, temp_storage); + float pivot = max_val * p; + + vec_t probs_vec; + float aggregate_gt_pivot = 0; +#pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + float probs_gt_pivot[VEC_SIZE]; +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + probs_gt_pivot[j] = (probs_vec[j] >= pivot) ? 
probs_vec[j] : 0; + } + + aggregate_gt_pivot += BlockReduce(temp_storage.block_prim.reduce) + .Sum(probs_gt_pivot); + if (tx == 0) { + temp_storage.block_aggregate.value = aggregate_gt_pivot; + } + __syncthreads(); + } + + float aggregate = 0; + float q = temp_storage.block_aggregate.value; + + int sampled_id; + temp_storage.sampled_id = d; + __syncthreads(); + float u = curand_uniform(&state) * q; +#pragma unroll 2 + for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { + probs_vec.fill(0); + if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { + probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + } + + DeviceSamplingFromProb( + i, d, [&](float x) { return x >= pivot; }, u, probs_vec, aggregate, &temp_storage); + if (aggregate > u) { + break; + } + } + sampled_id = temp_storage.sampled_id; + if (sampled_id == d) { + // NOTE(Zihao): this would happen when u is very close to 1 + // and the sum of probabilities is smaller than u + // In this case, we use the last valid index as the sampled id + sampled_id = temp_storage.last_valid_id; + } + output[bx] = sampled_id; +} + + template __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t d) { @@ -705,6 +789,34 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output, return cudaSuccess; } +template +cudaError_t MinPSamplingFromProb(T *probs, IdType *output, + uint32_t batch_size, const T *min_p_val, + uint32_t d, bool deterministic, + uint64_t philox_seed, uint64_t philox_offset, + cudaStream_t stream = 0){ + constexpr uint32_t BLOCK_THREADS = 1024; + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t smem_size = sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &output, &min_p_val,&d,&philox_seed,&philox_offset}; + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, + {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = + MinPSamplingFromProbKernel; + CUDA_CALL(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args, + smem_size, stream)); + })}); + return cudaSuccess; +} + + template cudaError_t TopKTopPSamplingFromProb(T *probs, IdType *output, uint32_t batch_size, const T *top_p_val, const IdType *top_k_val, diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index c002beeb66..b4c9939ac6 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -283,6 +283,7 @@ def find_end_files(directory, end_str): "gpu_ops/text_image_gather_scatter.cu", "gpu_ops/sample_kernels/rejection_top_p_sampling.cu", "gpu_ops/sample_kernels/top_k_renorm_probs.cu", + "gpu_ops/sample_kernels/min_p_sampling_from_probs.cu", "gpu_ops/get_position_ids_and_mask_encoder_batch.cu", "gpu_ops/fused_rotary_position_encoding.cu", "gpu_ops/noaux_tc.cu", diff --git a/docs/offline_inference.md b/docs/offline_inference.md index e1cdfb088f..45a77615a7 100644 --- a/docs/offline_inference.md +++ b/docs/offline_inference.md @@ -180,6 +180,7 @@ For ```LLM``` configuration, refer to [Parameter Documentation](parameters.md). 
* temperature(float): Controls randomness (higher = more random) * top_p(float): Probability threshold for token selection * top_k(int): Number of tokens considered for sampling +* min_p(float): Minimum probability relative to the maximum probability for a token to be considered (>0 filters low-probability tokens to improve quality) * max_tokens(int): Maximum generated tokens (input + output) * min_tokens(int): Minimum forced generation length diff --git a/docs/zh/offline_inference.md b/docs/zh/offline_inference.md index 828eae886c..04760e45f7 100644 --- a/docs/zh/offline_inference.md +++ b/docs/zh/offline_inference.md @@ -180,6 +180,7 @@ for output in outputs: * temperature(float): 控制生成随机性的参数,值越高结果越随机,值越低结果越确定 * top_p(float): 概率累积分布截断阈值,仅考虑累计概率达到此阈值的最可能token集合 * top_k(int): 采样概率最高的token数量,考虑概率最高的k个token进行采样 +* min_p(float): token入选的最小概率阈值(相对于最高概率token的比值,设为>0可通过过滤低概率token来提升文本生成质量) * max_tokens(int): 限制模型生成的最大token数量(包括输入和输出) * min_tokens(int): 强制模型生成的最少token数量,避免过早结束 diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index a7912407a8..bdca405dfc 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -53,6 +53,9 @@ class SamplingParams: top_p: Float that controls the cumulative probability of the top tokens to consider. Must be in [0, 1]. Set to 1 to consider all tokens. top_k: Int that controls the number of top tokens to consider. Must be a positive integer. + min_p:Float that represents the minimum probability for a token to be + considered, relative to the probability of the most likely token. + Must be in [0, 1]. Set to 0 to disable this. seed: Random seed to use for the generation. stop: list of strings that stop the generation when they are generated. The returned output will not contain the stop strings. 
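A minimal usage sketch for the new parameter (values are illustrative, not defaults): `min_p` is passed to `SamplingParams` alongside the existing sampling controls, and tokens whose probability falls below `min_p` times the top token's probability are filtered out.

```python
from fastdeploy.engine.sampling_params import SamplingParams

# Keep only tokens whose probability is at least 5% of the most likely
# token's probability at each sampling step.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, top_k=20, min_p=0.05)
```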
@@ -84,6 +87,7 @@ class SamplingParams:
     temperature: float = None
     top_p: float = 1.0
     top_k: int = 0
+    min_p: float = 0.0
     seed: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None
@@ -114,6 +118,7 @@ def from_optional(cls,
                       temperature,
                       top_p,
                       top_k,
+                      min_p,
                       seed=None,
                       stop=None,
                       stop_token_ids=None,
@@ -134,6 +139,7 @@ def from_optional(cls,
             temperature=temperature if temperature is not None else 1.0,
             top_p=top_p if top_p is not None else 1.0,
             top_k=top_k if top_k is not None else 0,
+            min_p=min_p if min_p is not None else 0.0,
             seed=seed,
             stop=stop,
             stop_token_ids=stop_token_ids,
@@ -180,6 +186,8 @@ def _verify_args(self) -> None:
         if not isinstance(self.top_k, int):
             raise TypeError(
                 f"top_k must be an integer, got {type(self.top_k).__name__}")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}")
 
         if self.max_tokens is not None and self.max_tokens < 1:
             raise ValueError(
diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index d4391e567c..d55bfb7731 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -310,6 +310,7 @@ class CompletionRequest(BaseModel):
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     top_k: Optional[int] = None
+    min_p: Optional[float] = None
     user: Optional[str] = None
     response_format: Optional[AnyResponseFormat] = None
@@ -426,6 +427,7 @@ class ChatCompletionRequest(BaseModel):
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     top_k: Optional[int] = None
+    min_p: Optional[float] = None
     user: Optional[str] = None
     metadata: Optional[dict] = None
diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py
index 41a96ee1e8..83090ba976 100644
--- a/fastdeploy/model_executor/layers/sample/meta_data.py
+++ b/fastdeploy/model_executor/layers/sample/meta_data.py
@@ -42,4 +42,5 @@ class SamplingMetadata:
 
     top_p: paddle.Tensor
     top_k: Optional[paddle.Tensor] = None
+    min_p: Optional[paddle.Tensor] = None
     max_num_logprobs: Optional[int] = None
diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py
index e364b13f21..22ead4ddba 100644
--- a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py
+++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py
@@ -154,3 +154,18 @@ def rejection_top_p_sampling(
     except ImportError:
         raise RuntimeError("Cannot import rejection_top_p_sampling op.")
     return ids
+
+def min_p_sampling(
+    x:paddle.tensor,
+    min_p:paddle.Tensor,
+    seed:int=-1
+)->paddle.Tensor:
+    """
+    min_p_sampling
+    """
+    try:
+        from fastdeploy.model_executor.ops.gpu import min_p_sampling
+        ids=min_p_sampling(x,min_p,seed)
+    except ImportError:
+        raise RuntimeError("Cannot import min_p_sampling op.")
+    return ids
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index 162bbc347f..a46c2a1aef 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -27,7 +27,7 @@
 from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.ops import (
     apply_penalty_multi_scores, apply_speculative_penalty_multi_scores,
-    top_k_top_p_sampling)
+    min_p_sampling,
top_k_top_p_sampling) from fastdeploy.platforms import current_platform from fastdeploy.worker.output import LogprobsTensors, SamplerOutput @@ -251,6 +251,7 @@ def forward_cuda( logits = self.processor.apply_token_mask(logits, skip_idx_list) + logits = apply_penalty_multi_scores( sampling_metadata.pre_token_ids, logits, @@ -266,6 +267,9 @@ def forward_cuda( probs = F.softmax(logits) + if hasattr(sampling_metadata,"min_p") and sampling_metadata.min_p > 0.0: + probs = min_p_sampling(probs, sampling_metadata.min_p) + _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k) logprobs_tensors = None if num_logprobs is None else \ diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index bb1080f75e..610670350d 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -247,6 +247,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request]): request.eos_token_ids, dtype="int64").reshape(-1, 1) self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 1.0) self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0) + self.share_inputs["min_p"][idx:idx + 1] = request.get("min_p",0.0) self.share_inputs["temperature"][idx:idx + 1] = request.get( "temperature", 0.95) self.share_inputs["penalty_score"][idx:idx + 1] = request.get( @@ -357,6 +358,9 @@ def _init_share_inputs(self, max_num_seqs: int): self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype='int64') + self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], + 0.0, + dtype='float32') self.share_inputs["temperature"] = paddle.full( [max_num_seqs, 1], self.model_config.temperature, dtype='float32') self.share_inputs["penalty_score"] = paddle.full( @@ -581,6 +585,7 @@ def _prepare_inputs(self) -> None: temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], + min_p=self.share_inputs["min_p"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], diff --git a/fastdeploy/worker/vl_gpu_model_runner.py b/fastdeploy/worker/vl_gpu_model_runner.py index 5ad5c0f724..b37bcd17c9 100644 --- a/fastdeploy/worker/vl_gpu_model_runner.py +++ b/fastdeploy/worker/vl_gpu_model_runner.py @@ -31,6 +31,7 @@ SpeculativeConfig) from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer from fastdeploy.input.mm_processor import DataProcessor +from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata @@ -46,7 +47,6 @@ from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ( ScatterOp, VariableResolutionResamplerModel) from fastdeploy.platforms import current_platform -from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.worker.output import SamplerOutput from fastdeploy.worker.utils import check_safetensors_model from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase @@ -679,6 +679,8 @@ def get_numeric_value(task, key, default_value): get_numeric_value(task, "temperature", 0.2), "top_k": get_numeric_value(task, "top_k", 0), + "min_p": + get_numeric_value(task,"min_p",0.0), "penalty_score": get_numeric_value(task, "repetition_penalty", 1.0), "frequency_score": diff --git 
a/test/layers/test_sampler.py b/test/layers/test_sampler.py index 2887400d06..d9172b9ff4 100644 --- a/test/layers/test_sampler.py +++ b/test/layers/test_sampler.py @@ -91,5 +91,6 @@ def test_sampler(): print(next_tokens) + if __name__ == "__main__": test_sampler() From 4df179f2444ef725c5c84d41006f16baf48b091c Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Mon, 14 Jul 2025 10:42:17 +0000 Subject: [PATCH 02/16] add test_min_p --- test/layers/test_min_p.py | 211 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 test/layers/test_min_p.py diff --git a/test/layers/test_min_p.py b/test/layers/test_min_p.py new file mode 100644 index 0000000000..559b1dbcd6 --- /dev/null +++ b/test/layers/test_min_p.py @@ -0,0 +1,211 @@ +import flashinfer.sampling +import matplotlib.pyplot as plt +import numpy as np +import paddle +import paddle.nn.functional as F +import torch +from tqdm import tqdm + +from fastdeploy.model_executor.ops.gpu import min_p_sampling + +# 设置测试参数 +sample_time = 1000000 +vocab_size = 1000 +min_p_value = 0.5 + +# 定义压缩函数 +def compress(data): + new_data = np.array([0, 0, 0], dtype=float) + new_data[0] = data[0] + new_data[1] = data[1] + new_data[2] = np.sum(data[2:]) + return new_data + +# FastDeploy 采样函数 +def fastdeploy_min_p_sampling(): + # 创建 logits + logits = paddle.ones(shape=[1, vocab_size], dtype="float32") + logits[0][0] = 10 + logits[0][1] = 8 + low_prob_tensor = paddle.linspace(2.0, 0.0, vocab_size - 2) + logits[0][2:] = low_prob_tensor + + # 计算概率 + probs = F.softmax(logits) + min_p = paddle.to_tensor([min_p_value], dtype="float32") + + # 计算允许的 token + max_prob = probs.max().item() + threshold = max_prob * min_p.item() + allowed_tokens = paddle.where(probs[0] >= threshold)[0].numpy() + + # 初始化统计变量 + sample_freq = [0] * vocab_size + low_prob_token_times = 0 + low_prob_token_probs = [] + + # 执行采样 + for i in tqdm(range(sample_time), desc="FastDeploy Sampling"): + ids = min_p_sampling(probs, min_p, seed=-1) + sample_freq[ids.item()] += 1 + if ids.item() >= 2: + low_prob_token_times += 1 + low_prob_token_probs.append(low_prob_token_times / (i + 1)) + + # 处理采样结果 + sample_freq = np.array(sample_freq, dtype=float) / sample_time + low_prob_token_probs = np.array(low_prob_token_probs, dtype=float) + + # 压缩数据 + ori_data1 = probs.numpy().reshape(-1) + data1 = compress(ori_data1) # 原始概率 + data2 = compress(sample_freq) # 采样概率 + + # 计算理论归一化概率 + allowed_probs = probs[0, allowed_tokens].numpy() + norm_scale = np.sum(allowed_probs) + data3 = np.zeros_like(data1) + for idx in allowed_tokens: + if idx < 2: + data3[idx] = ori_data1[idx] / norm_scale + else: + data3[2] += ori_data1[idx] / norm_scale + + # 绘制柱状图 + plot_bar_chart(data1, data2, data3, "FastDeploy[min_p_sampling]") + # 绘制曲线图 + plot_low_prob_curve(low_prob_token_probs, sample_time, "FastDeploy[min_p_sampling]") + + return data2, data3 + +# vLLM (FlashInfer) 采样函数 +def vllm_min_p_sampling(): + # 创建 logits + logits = torch.ones((1, vocab_size), dtype=torch.float32).cuda() + logits[0][0] = 10 + logits[0][1] = 8 + low_prob_tensor = torch.linspace(2.0, 0.0, vocab_size - 2).cuda() + logits[0][2:] = low_prob_tensor + + # 计算概率 + probs = torch.softmax(logits, dim=-1) + min_p = torch.tensor([min_p_value], dtype=torch.float32).cuda() + + # 计算允许的 token + max_prob = probs.max().item() + threshold = max_prob * min_p.item() + allowed_tokens = torch.where(probs[0] >= threshold)[0].cpu().numpy() + + # 初始化统计变量 + sample_freq = [0] * vocab_size + low_prob_token_times = 0 + low_prob_token_probs = [] + 
+ # 执行采样 + for i in tqdm(range(sample_time), desc="FlashInfer Sampling"): + ids = flashinfer.sampling.min_p_sampling_from_probs(probs, min_p, deterministic=False) + sample_freq[ids.item()] += 1 + if ids.item() >= 2: + low_prob_token_times += 1 + low_prob_token_probs.append(low_prob_token_times / (i + 1)) + + # 处理采样结果 + sample_freq = np.array(sample_freq, dtype=float) / sample_time + low_prob_token_probs = np.array(low_prob_token_probs, dtype=float) + + # 压缩数据 + ori_data1 = probs.cpu().numpy().reshape(-1) + data1 = compress(ori_data1) # 原始概率 + data2 = compress(sample_freq) # 采样概率 + + # 计算理论归一化概率 + allowed_probs = probs[0, allowed_tokens].cpu().numpy() + norm_scale = np.sum(allowed_probs) + data3 = np.zeros_like(data1) + for idx in allowed_tokens: + if idx < 2: + data3[idx] = ori_data1[idx] / norm_scale + else: + data3[2] += ori_data1[idx] / norm_scale + + # 绘制柱状图 + plot_bar_chart(data1, data2, data3, "vLLM[min_p_sampling]") + # 绘制曲线图 + plot_low_prob_curve(low_prob_token_probs, sample_time, "vLLM[min_p_sampling]") + + return data2, data3 + +# 绘制柱状图的函数 +def plot_bar_chart(data1, data2, data3, title): + plt.figure(figsize=(6, 6)) + bar_width = 0.2 + idx = np.arange(len(data1)).astype(float) + + bars1 = plt.bar(idx - bar_width, data1, width=bar_width, color='salmon', label='原始概率', alpha=0.9) + bars2 = plt.bar(idx, data2, width=bar_width, color='skyblue', label='采样概率', alpha=0.9) + bars3 = plt.bar(idx + bar_width, data3, width=bar_width, color='orange', label='归一化原始概率', alpha=0.9) + + plt.bar_label(bars1, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='black') + plt.bar_label(bars2, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='red') + plt.bar_label(bars3, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='blue') + + plt.title(title, fontsize=14) + plt.xlabel("索引", fontsize=12) + plt.ylabel("概率", fontsize=12) + plt.ylim(0, 1.1) + plt.xlim(-1, 3) + plt.xticks(range(0, 3, 1)) + plt.legend(fontsize=10) + plt.grid(axis='y', linestyle='--', alpha=0.5) + output_path = f"{title.replace(' ', '_')}.png" + plt.savefig(output_path, dpi=300, bbox_inches='tight') + plt.clf() + +# 绘制低概率 token 概率曲线的函数 +def plot_low_prob_curve(low_prob_token_probs, sample_time, title): + plt.figure(figsize=(6, 6)) + plt.plot(np.arange(0, sample_time), low_prob_token_probs, marker='', linestyle='-', linewidth=1, color='blue') + plt.xlabel('采样次数') + plt.ylabel('概率') + plt.title('低概率 token 的概率') + plt.grid(alpha=0.3) + output_path = f"{title.replace(' ', '_')}_low_prob.png" + plt.savefig(output_path, dpi=300, bbox_inches='tight') + plt.clf() + +# 主函数 +def main(): + print("运行 FastDeploy 采样...") + data2_fastdeploy, data3_fastdeploy = fastdeploy_min_p_sampling() + print("运行 vLLM (FlashInfer) 采样...") + data2_vllm, data3_vllm = vllm_min_p_sampling() + + # 计算误差 + error_fastdeploy = np.abs(data2_fastdeploy - data3_fastdeploy) + error_vllm = np.abs(data2_vllm - data3_vllm) + + # 输出对比结果 + print("\nFastDeploy 对比结果:") + print(f"采样概率 (data2): {data2_fastdeploy}") + print(f"理论归一化概率 (data3): {data3_fastdeploy}") + print(f"误差: {error_fastdeploy}") + print(f"最大误差: {np.max(error_fastdeploy)}, 是否小于 1e-5: {np.max(error_fastdeploy) < 1e-5}") + + print("\nvLLM (FlashInfer) 对比结果:") + print(f"采样概率 (data2): {data2_vllm}") + print(f"理论归一化概率 (data3): {data3_vllm}") + print(f"误差: {error_vllm}") + print(f"最大误差: {np.max(error_vllm)}, 是否小于 1e-5: {np.max(error_vllm) < 1e-5}") + + # 检查误差是否满足要求 + if np.max(error_fastdeploy) < 1e-5 and np.max(error_vllm) < 1e-5: + print("\n结论:FastDeploy 和 vLLM 的采样结果均与理论值一致,误差小于 1e-5。") + else: 
+ print("\n结论:存在误差大于 1e-5 的情况,建议增加采样次数或检查实现细节。") + +# 运行程序 +if __name__ == "__main__": + if paddle.device.is_compiled_with_cuda() and torch.cuda.is_available(): + main() + else: + print("GPU 不可用,请检查环境配置(需要支持 PaddlePaddle 和 PyTorch 的 CUDA)。") From f8e391dc71e07d246d95d9c9e1abc1610ab16445 Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Mon, 14 Jul 2025 11:02:13 +0000 Subject: [PATCH 03/16] fix --- fastdeploy/worker/xpu_model_runner.py | 7 +- test/layers/test_min_p.py | 96 +++++++++++---------------- 2 files changed, 44 insertions(+), 59 deletions(-) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 7fb585f8a5..d8fad53965 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -23,6 +23,7 @@ from fastdeploy.config import FDConfig from fastdeploy.engine.request import Request +from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta from fastdeploy.model_executor.layers.attention import get_attention_backend from fastdeploy.model_executor.layers.attention.base_attention_backend import \ AttentionBackend @@ -31,7 +32,6 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler from fastdeploy.model_executor.model_loader import get_model_from_loader from fastdeploy.utils import get_logger -from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput @@ -297,6 +297,7 @@ def process_prefill_inputs(self, req_dicts: List[Request]): self.share_inputs["pre_ids"][idx:idx + 1] = -1 self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 1.0) self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0) + self.share_inputs["min_p"][idx:idx+1] = request.get("min_p", 0.0) self.share_inputs["temperature"][idx:idx + 1] = request.get( "temperature", 0.95) self.share_inputs["penalty_score"][idx:idx + 1] = request.get( @@ -369,6 +370,9 @@ def _init_share_inputs(self, max_num_seqs: int): self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype='int64') + self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], + 0.0, + dtype='float32') self.share_inputs["temperature"] = paddle.full( [max_num_seqs, 1], self.model_config.temperature, dtype='float32') self.share_inputs["penalty_score"] = paddle.full( @@ -519,6 +523,7 @@ def _prepare_inputs(self) -> None: temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], + min_p=self.share_inputs["min_p"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], diff --git a/test/layers/test_min_p.py b/test/layers/test_min_p.py index 559b1dbcd6..8076d7b01d 100644 --- a/test/layers/test_min_p.py +++ b/test/layers/test_min_p.py @@ -8,12 +8,11 @@ from fastdeploy.model_executor.ops.gpu import min_p_sampling -# 设置测试参数 sample_time = 1000000 vocab_size = 1000 min_p_value = 0.5 -# 定义压缩函数 + def compress(data): new_data = np.array([0, 0, 0], dtype=float) new_data[0] = data[0] @@ -21,30 +20,25 @@ def compress(data): new_data[2] = np.sum(data[2:]) return new_data -# FastDeploy 采样函数 + def fastdeploy_min_p_sampling(): - # 创建 logits logits = paddle.ones(shape=[1, vocab_size], dtype="float32") logits[0][0] = 10 logits[0][1] = 8 low_prob_tensor = paddle.linspace(2.0, 0.0, vocab_size - 2) logits[0][2:] = low_prob_tensor - # 计算概率 probs = 
F.softmax(logits) min_p = paddle.to_tensor([min_p_value], dtype="float32") - # 计算允许的 token max_prob = probs.max().item() threshold = max_prob * min_p.item() allowed_tokens = paddle.where(probs[0] >= threshold)[0].numpy() - # 初始化统计变量 sample_freq = [0] * vocab_size low_prob_token_times = 0 low_prob_token_probs = [] - # 执行采样 for i in tqdm(range(sample_time), desc="FastDeploy Sampling"): ids = min_p_sampling(probs, min_p, seed=-1) sample_freq[ids.item()] += 1 @@ -52,16 +46,13 @@ def fastdeploy_min_p_sampling(): low_prob_token_times += 1 low_prob_token_probs.append(low_prob_token_times / (i + 1)) - # 处理采样结果 sample_freq = np.array(sample_freq, dtype=float) / sample_time low_prob_token_probs = np.array(low_prob_token_probs, dtype=float) - # 压缩数据 ori_data1 = probs.numpy().reshape(-1) - data1 = compress(ori_data1) # 原始概率 - data2 = compress(sample_freq) # 采样概率 + data1 = compress(ori_data1) + data2 = compress(sample_freq) - # 计算理论归一化概率 allowed_probs = probs[0, allowed_tokens].numpy() norm_scale = np.sum(allowed_probs) data3 = np.zeros_like(data1) @@ -71,37 +62,32 @@ def fastdeploy_min_p_sampling(): else: data3[2] += ori_data1[idx] / norm_scale - # 绘制柱状图 plot_bar_chart(data1, data2, data3, "FastDeploy[min_p_sampling]") - # 绘制曲线图 plot_low_prob_curve(low_prob_token_probs, sample_time, "FastDeploy[min_p_sampling]") return data2, data3 -# vLLM (FlashInfer) 采样函数 + def vllm_min_p_sampling(): - # 创建 logits logits = torch.ones((1, vocab_size), dtype=torch.float32).cuda() logits[0][0] = 10 logits[0][1] = 8 low_prob_tensor = torch.linspace(2.0, 0.0, vocab_size - 2).cuda() logits[0][2:] = low_prob_tensor - # 计算概率 probs = torch.softmax(logits, dim=-1) min_p = torch.tensor([min_p_value], dtype=torch.float32).cuda() - # 计算允许的 token max_prob = probs.max().item() threshold = max_prob * min_p.item() allowed_tokens = torch.where(probs[0] >= threshold)[0].cpu().numpy() - # 初始化统计变量 + sample_freq = [0] * vocab_size low_prob_token_times = 0 low_prob_token_probs = [] - # 执行采样 + for i in tqdm(range(sample_time), desc="FlashInfer Sampling"): ids = flashinfer.sampling.min_p_sampling_from_probs(probs, min_p, deterministic=False) sample_freq[ids.item()] += 1 @@ -109,16 +95,13 @@ def vllm_min_p_sampling(): low_prob_token_times += 1 low_prob_token_probs.append(low_prob_token_times / (i + 1)) - # 处理采样结果 sample_freq = np.array(sample_freq, dtype=float) / sample_time low_prob_token_probs = np.array(low_prob_token_probs, dtype=float) - # 压缩数据 ori_data1 = probs.cpu().numpy().reshape(-1) - data1 = compress(ori_data1) # 原始概率 - data2 = compress(sample_freq) # 采样概率 + data1 = compress(ori_data1) + data2 = compress(sample_freq) - # 计算理论归一化概率 allowed_probs = probs[0, allowed_tokens].cpu().numpy() norm_scale = np.sum(allowed_probs) data3 = np.zeros_like(data1) @@ -128,30 +111,27 @@ def vllm_min_p_sampling(): else: data3[2] += ori_data1[idx] / norm_scale - # 绘制柱状图 plot_bar_chart(data1, data2, data3, "vLLM[min_p_sampling]") - # 绘制曲线图 plot_low_prob_curve(low_prob_token_probs, sample_time, "vLLM[min_p_sampling]") return data2, data3 -# 绘制柱状图的函数 def plot_bar_chart(data1, data2, data3, title): plt.figure(figsize=(6, 6)) bar_width = 0.2 idx = np.arange(len(data1)).astype(float) - bars1 = plt.bar(idx - bar_width, data1, width=bar_width, color='salmon', label='原始概率', alpha=0.9) - bars2 = plt.bar(idx, data2, width=bar_width, color='skyblue', label='采样概率', alpha=0.9) - bars3 = plt.bar(idx + bar_width, data3, width=bar_width, color='orange', label='归一化原始概率', alpha=0.9) + bars1 = plt.bar(idx - bar_width, data1, width=bar_width, color='salmon', 
label='Original Probability', alpha=0.9) + bars2 = plt.bar(idx, data2, width=bar_width, color='skyblue', label='Sampled Probability', alpha=0.9) + bars3 = plt.bar(idx + bar_width, data3, width=bar_width, color='orange', label='Normalized Original Probability', alpha=0.9) plt.bar_label(bars1, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='black') plt.bar_label(bars2, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='red') plt.bar_label(bars3, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='blue') plt.title(title, fontsize=14) - plt.xlabel("索引", fontsize=12) - plt.ylabel("概率", fontsize=12) + plt.xlabel("Index", fontsize=12) + plt.ylabel("Probability", fontsize=12) plt.ylim(0, 1.1) plt.xlim(-1, 3) plt.xticks(range(0, 3, 1)) @@ -161,51 +141,51 @@ def plot_bar_chart(data1, data2, data3, title): plt.savefig(output_path, dpi=300, bbox_inches='tight') plt.clf() -# 绘制低概率 token 概率曲线的函数 +# Function to plot low-probability token probability curve def plot_low_prob_curve(low_prob_token_probs, sample_time, title): plt.figure(figsize=(6, 6)) plt.plot(np.arange(0, sample_time), low_prob_token_probs, marker='', linestyle='-', linewidth=1, color='blue') - plt.xlabel('采样次数') - plt.ylabel('概率') - plt.title('低概率 token 的概率') + plt.xlabel('Sample Times') + plt.ylabel('Probability') + plt.title('Probability of Low-Probability Tokens') plt.grid(alpha=0.3) output_path = f"{title.replace(' ', '_')}_low_prob.png" plt.savefig(output_path, dpi=300, bbox_inches='tight') plt.clf() -# 主函数 +# Main function def main(): - print("运行 FastDeploy 采样...") + print("Running FastDeploy sampling...") data2_fastdeploy, data3_fastdeploy = fastdeploy_min_p_sampling() - print("运行 vLLM (FlashInfer) 采样...") + print("Running vLLM (FlashInfer) sampling...") data2_vllm, data3_vllm = vllm_min_p_sampling() - # 计算误差 + # Calculate errors error_fastdeploy = np.abs(data2_fastdeploy - data3_fastdeploy) error_vllm = np.abs(data2_vllm - data3_vllm) - # 输出对比结果 - print("\nFastDeploy 对比结果:") - print(f"采样概率 (data2): {data2_fastdeploy}") - print(f"理论归一化概率 (data3): {data3_fastdeploy}") - print(f"误差: {error_fastdeploy}") - print(f"最大误差: {np.max(error_fastdeploy)}, 是否小于 1e-5: {np.max(error_fastdeploy) < 1e-5}") + # Print comparison results + print("\nFastDeploy Comparison Results:") + print(f"Sampled Probability (data2): {data2_fastdeploy}") + print(f"Theoretical Normalized Probability (data3): {data3_fastdeploy}") + print(f"Error: {error_fastdeploy}") + print(f"Maximum Error: {np.max(error_fastdeploy)}, Is Less Than 1e-5: {np.max(error_fastdeploy) < 1e-5}") - print("\nvLLM (FlashInfer) 对比结果:") - print(f"采样概率 (data2): {data2_vllm}") - print(f"理论归一化概率 (data3): {data3_vllm}") - print(f"误差: {error_vllm}") - print(f"最大误差: {np.max(error_vllm)}, 是否小于 1e-5: {np.max(error_vllm) < 1e-5}") + print("\nvLLM (FlashInfer) Comparison Results:") + print(f"Sampled Probability (data2): {data2_vllm}") + print(f"Theoretical Normalized Probability (data3): {data3_vllm}") + print(f"Error: {error_vllm}") + print(f"Maximum Error: {np.max(error_vllm)}, Is Less Than 1e-5: {np.max(error_vllm) < 1e-5}") - # 检查误差是否满足要求 + # Check if errors meet the requirement if np.max(error_fastdeploy) < 1e-5 and np.max(error_vllm) < 1e-5: - print("\n结论:FastDeploy 和 vLLM 的采样结果均与理论值一致,误差小于 1e-5。") + print("\nConclusion: Both FastDeploy and vLLM sampling results are consistent with theoretical values, with errors less than 1e-5.") else: - print("\n结论:存在误差大于 1e-5 的情况,建议增加采样次数或检查实现细节。") + print("\nConclusion: There are cases where the error is greater than 1e-5. 
It is recommended to increase the number of samples or check the implementation details.") -# 运行程序 +# Run the program if __name__ == "__main__": if paddle.device.is_compiled_with_cuda() and torch.cuda.is_available(): main() else: - print("GPU 不可用,请检查环境配置(需要支持 PaddlePaddle 和 PyTorch 的 CUDA)。") + print("GPU is not available. Please check the environment configuration (requires support for PaddlePaddle and PyTorch with CUDA).") From edb4202d8d55eafdfb3fe0ba245cde2fb9652b38 Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Wed, 16 Jul 2025 13:25:22 +0000 Subject: [PATCH 04/16] min_p_sampling --- .../min_p_sampling_from_probs.cu | 4 +- .../gpu_ops/sample_kernels/sampling.cuh | 12 +- fastdeploy/demo/offline_demo.py | 13 +- fastdeploy/demo/openai_demo.py | 77 +++--- .../model_executor/layers/sample/meta_data.py | 4 +- .../layers/sample/ops/__init__.py | 3 +- .../layers/sample/ops/top_k_top_p_sampling.py | 30 ++- .../model_executor/layers/sample/sampler.py | 14 +- fastdeploy/worker/gpu_model_runner.py | 7 +- test/layers/test_min_p.py | 236 +++++++++--------- 10 files changed, 218 insertions(+), 182 deletions(-) diff --git a/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu b/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu index 9802e9b512..e8944421dd 100644 --- a/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu +++ b/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu @@ -32,8 +32,8 @@ std::vector MinPSamplingFromProbs(const paddle::Tensor &probs, cudaError_t status; status = sampling::MinPSamplingFromProb( - const_cast(probs.data()),samples.data(), - batch_size,min_p.data(),vocab_size,true,philox_seed,philox_offset,cu_stream); + const_cast(probs.data()),min_p.data(),samples.data(), + batch_size,vocab_size,true,philox_seed,philox_offset,cu_stream); PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); diff --git a/custom_ops/gpu_ops/sample_kernels/sampling.cuh b/custom_ops/gpu_ops/sample_kernels/sampling.cuh index b6b48b838f..1ac7d3b670 100644 --- a/custom_ops/gpu_ops/sample_kernels/sampling.cuh +++ b/custom_ops/gpu_ops/sample_kernels/sampling.cuh @@ -561,11 +561,11 @@ struct RenormTempStorage { template -__global__ void MinPSamplingFromProbKernel(DType* probs, IdType* output, - float* min_p_arr, uint32_t d, +__global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, + IdType* output,uint32_t d, uint64_t philox_seed, uint64_t philox_offset) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; - float p = (min_p_arr == nullptr) ? 1e-6 : min_p_arr[bx]; + float p = (min_p_arr == nullptr) ? 
0 : min_p_arr[bx]; curandStatePhilox4_32_10_t state; curand_init(philox_seed, bx, philox_offset, &state); const uint32_t row_idx = bx; @@ -790,8 +790,8 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output, } template -cudaError_t MinPSamplingFromProb(T *probs, IdType *output, - uint32_t batch_size, const T *min_p_val, +cudaError_t MinPSamplingFromProb(T *probs, const T* min_p_arr,IdType *output, + uint32_t batch_size, uint32_t d, bool deterministic, uint64_t philox_seed, uint64_t philox_offset, cudaStream_t stream = 0){ @@ -801,7 +801,7 @@ cudaError_t MinPSamplingFromProb(T *probs, IdType *output, const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &output, &min_p_val,&d,&philox_seed,&philox_offset}; + void* args[] = {&probs, &min_p_arr,&output,&d,&philox_seed,&philox_offset}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { diff --git a/fastdeploy/demo/offline_demo.py b/fastdeploy/demo/offline_demo.py index 856757aa00..d4044510e9 100644 --- a/fastdeploy/demo/offline_demo.py +++ b/fastdeploy/demo/offline_demo.py @@ -17,13 +17,14 @@ from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.entrypoints.llm import LLM -model_name_or_path = "./models/llama-7b" +model_name_or_path = "/home/zexuli/Models/Qwen3-0.6B" # 超参设置 -sampling_params = SamplingParams(temperature=0.1, max_tokens=30) -llm = LLM(model=model_name_or_path, tensor_parallel_size=1) -output = llm.generate(prompts="who are you?", - use_tqdm=True, - sampling_params=sampling_params) +sampling_params = SamplingParams(temperature=0.1,top_p=0.7,top_k=10,min_p=0.1) +llm = LLM(model=model_name_or_path, tensor_parallel_size=1,reasoning_parser="qwen3") +prompt = "北京天安门在哪里?" 
+messages = [{"role": "user", "content": prompt}] +output = llm.chat([messages], + sampling_params) print(output) diff --git a/fastdeploy/demo/openai_demo.py b/fastdeploy/demo/openai_demo.py index 1b8b5862af..d7557e9ea7 100644 --- a/fastdeploy/demo/openai_demo.py +++ b/fastdeploy/demo/openai_demo.py @@ -18,44 +18,43 @@ import openai ip = "0.0.0.0" -service_http_port = "9809" # 服务配置的 +service_http_port = "9012" # 服务配置的 client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY") # 非流式返回 -response = client.completions.create( - model="default", - prompt="Hello, how are you?", - max_tokens=64, - stream=False, -) +# response = client.completions.create( +# model="default", +# prompt="Hello, how are you?", +# max_tokens=64, +# stream=False, +# ) -print(response.choices[0].text) -print("\n") +# print(response.choices[0].text) +# print("\n") -# 流式返回 -response = client.completions.create( - model="default", - prompt="Hello, how are you?", - max_tokens=100, - stream=True, -) +# # 流式返回 +# response = client.completions.create( +# model="default", +# prompt="Hello, how are you?", +# max_tokens=100, +# stream=True, +# ) -for chunk in response: - print(chunk.choices[0].text, end='') -print("\n") +# for chunk in response: +# print(chunk.choices[0].text, end='') +# print("\n") # Chat completion # 非流式返回 response = client.chat.completions.create( model="default", messages=[ - {"role": "user", "content": "Hello, who are you"}, - {"role": "assistant", "content": "I'm a helpful AI assistant."}, - {"role": "user", "content": "List 3 countries and their capitals."}, + {"role": "user", "content": "北京天安门在哪里?"}, ], - temperature=1, - max_tokens=64, + temperature=0.1, + top_p=0.7, + metadata={"min_p":0.1}, stream=False, ) @@ -64,19 +63,19 @@ # # 流式返回 -response = client.chat.completions.create( - model="default", - messages=[ - {"role": "user", "content": "Hello, who are you"}, - {"role": "assistant", "content": "I'm a helpful AI assistant."}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=1, - max_tokens=64, - stream=True, -) +# response = client.chat.completions.create( +# model="default", +# messages=[ +# {"role": "user", "content": "Hello, who are you"}, +# {"role": "assistant", "content": "I'm a helpful AI assistant."}, +# {"role": "user", "content": "List 3 countries and their capitals."}, +# ], +# temperature=1, +# max_tokens=64, +# stream=True, +# ) -for chunk in response: - if chunk.choices[0].delta is not None: - print(chunk.choices[0].delta.content, end='') -print("\n") +# for chunk in response: +# if chunk.choices[0].delta is not None: +# print(chunk.choices[0].delta.content, end='') +# print("\n") diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py index 83090ba976..9fd69f608d 100644 --- a/fastdeploy/model_executor/layers/sample/meta_data.py +++ b/fastdeploy/model_executor/layers/sample/meta_data.py @@ -15,7 +15,7 @@ """ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union import paddle @@ -42,5 +42,5 @@ class SamplingMetadata: top_p: paddle.Tensor top_k: Optional[paddle.Tensor] = None - min_p: Optional[paddle.Tensor] = None + min_p: Optional[Union[float, paddle.Tensor]] = None max_num_logprobs: Optional[int] = None diff --git a/fastdeploy/model_executor/layers/sample/ops/__init__.py b/fastdeploy/model_executor/layers/sample/ops/__init__.py index 37c803ca3e..1a77651f49 100644 --- 
a/fastdeploy/model_executor/layers/sample/ops/__init__.py
+++ b/fastdeploy/model_executor/layers/sample/ops/__init__.py
@@ -16,10 +16,11 @@
 
 from .apply_penalty_multi_scores import (
     apply_penalty_multi_scores, apply_speculative_penalty_multi_scores)
-from .top_k_top_p_sampling import top_k_top_p_sampling
+from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling
 
 __all__ = [
     "apply_penalty_multi_scores",
     "apply_speculative_penalty_multi_scores",
     "top_k_top_p_sampling",
+    "min_p_sampling",
 ]
diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py
index 22ead4ddba..0a50b0ed0c 100644
--- a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py
+++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py
@@ -156,16 +156,29 @@ def rejection_top_p_sampling(
     return ids
 
 def min_p_sampling(
-    x:paddle.tensor,
-    min_p:paddle.Tensor,
+    logits: paddle.Tensor,
+    min_p_arr: Optional[paddle.Tensor],
     seed:int=-1
-)->paddle.Tensor:
+) -> tuple[paddle.Tensor, paddle.Tensor]:
     """
     min_p_sampling
    """
-    try:
+    if current_platform.is_cuda():
         from fastdeploy.model_executor.ops.gpu import min_p_sampling
-        ids=min_p_sampling(x,min_p,seed)
-    except ImportError:
-        raise RuntimeError("Cannot import min_p_sampling op.")
-    return ids
+        # The CUDA op samples one token id per request directly from the
+        # min_p-filtered distribution.
+        ids = min_p_sampling(logits, min_p_arr, seed)
+        return ids, None
+    else:
+        # Fallback: mask out tokens whose probability is below
+        # min_p * max_prob and let the caller sample from the result.
+        probability_values = paddle.nn.functional.softmax(logits, axis=-1)
+        max_probabilities = paddle.amax(probability_values,
+                                        axis=-1,
+                                        keepdim=True)
+        adjusted_min_p = max_probabilities * min_p_arr
+        invalid_token_mask = probability_values < adjusted_min_p
+        logits = paddle.where(invalid_token_mask,
+                              paddle.full_like(logits, -float('inf')),
+                              logits)
+        return None, logits
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index a46c2a1aef..f8489b214c 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -265,12 +265,15 @@ def forward_cuda(
             sampling_metadata.eos_token_ids,
         )
 
         probs = F.softmax(logits)
-
-        if hasattr(sampling_metadata,"min_p") and sampling_metadata.min_p > 0.0:
-            probs = min_p_sampling(probs, sampling_metadata.min_p)
-
-        _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
+        if sampling_metadata.min_p is not None:
+            # On CUDA, min_p_sampling returns sampled ids directly; otherwise
+            # it only masks the distribution and we fall through to top-k/top-p.
+            next_tokens, probs = min_p_sampling(probs, sampling_metadata.min_p)
+            if next_tokens is None:
+                _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
+        else:
+            _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k)
 
         logprobs_tensors = None if num_logprobs is None else \
             self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens)
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index bb1080f75e..4f18c46bdd 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -78,6 +78,7 @@ def __init__(
         else:
             self.sampler = SpeculativeSampler(fd_config)
 
+
         # Lazy initialize kv cache after model loading
         # self.kv_caches: list[paddle.Tensor]
= [] @@ -580,12 +581,16 @@ def _prepare_inputs(self) -> None: # Initialize forward meta data self.initialize_forward_meta() + num_reqs = int((self.share_inputs["seq_lens_this_time"] > 0).sum()) + min_p_slice = self.share_inputs["min_p"][:num_reqs] + no_min_p = paddle.all(min_p_slice == 0.0).item() + # Get sampling metadata self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], - min_p=self.share_inputs["min_p"], + min_p=None if no_min_p else self.share_inputs["min_p"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], diff --git a/test/layers/test_min_p.py b/test/layers/test_min_p.py index 8076d7b01d..15f5b508aa 100644 --- a/test/layers/test_min_p.py +++ b/test/layers/test_min_p.py @@ -1,4 +1,18 @@ -import flashinfer.sampling +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import matplotlib.pyplot as plt import numpy as np import paddle @@ -11,6 +25,9 @@ sample_time = 1000000 vocab_size = 1000 min_p_value = 0.5 +batch_size = 3 +batch_min_p_values = [0.1, 0.5, 0.9] +batch_min_p_values2=[0,3,0,0,0.4] def compress(data): @@ -21,6 +38,45 @@ def compress(data): return new_data +def plot_bar_chart(data1, data2, data3, title, request_idx=None): + plt.figure(figsize=(6, 6)) + bar_width = 0.2 + idx = np.arange(len(data1)).astype(float) + + bars1 = plt.bar(idx - bar_width, data1, width=bar_width, color='salmon', label='Original Probability', alpha=0.9) + bars2 = plt.bar(idx, data2, width=bar_width, color='skyblue', label='Sampled Probability', alpha=0.9) + bars3 = plt.bar(idx + bar_width, data3, width=bar_width, color='orange', label='Normalized Original Probability', alpha=0.9) + + plt.bar_label(bars1, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='black') + plt.bar_label(bars2, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='red') + plt.bar_label(bars3, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='blue') + + full_title = title if request_idx is None else f"{title} (min_p={batch_min_p_values[request_idx]})" + plt.title(full_title, fontsize=14) + plt.xlabel("Index", fontsize=12) + plt.ylabel("Probability", fontsize=12) + plt.ylim(0, 1.1) + plt.xlim(-1, 3) + plt.xticks(range(0, 3, 1)) + plt.legend(fontsize=10) + plt.grid(axis='y', linestyle='--', alpha=0.5) + output_path = f"{title.replace(' ', '_')}{'' if request_idx is None else f'_req{request_idx}'}.png" + plt.savefig(output_path, dpi=300, bbox_inches='tight') + plt.clf() + +def plot_low_prob_curve(low_prob_token_probs, sample_time, title, request_idx=None): + plt.figure(figsize=(6, 6)) + plt.plot(np.arange(0, sample_time), low_prob_token_probs, marker='', linestyle='-', linewidth=1, color='blue') + plt.xlabel('Sample Times') + plt.ylabel('Probability') + full_title = 'Probability of Low-Probability Tokens' if request_idx is None else 
f"Low-Probability Tokens (min_p={batch_min_p_values[request_idx]})" + plt.title(full_title) + plt.grid(alpha=0.3) + output_path = f"{title.replace(' ', '_')}_low_prob{'' if request_idx is None else f'_req{request_idx}'}.png" + plt.savefig(output_path, dpi=300, bbox_inches='tight') + plt.clf() + +# min_p:0.5:FastDeploy def fastdeploy_min_p_sampling(): logits = paddle.ones(shape=[1, vocab_size], dtype="float32") logits[0][0] = 10 @@ -68,124 +124,80 @@ def fastdeploy_min_p_sampling(): return data2, data3 -def vllm_min_p_sampling(): - logits = torch.ones((1, vocab_size), dtype=torch.float32).cuda() - logits[0][0] = 10 - logits[0][1] = 8 - low_prob_tensor = torch.linspace(2.0, 0.0, vocab_size - 2).cuda() - logits[0][2:] = low_prob_tensor - - probs = torch.softmax(logits, dim=-1) - min_p = torch.tensor([min_p_value], dtype=torch.float32).cuda() - - max_prob = probs.max().item() - threshold = max_prob * min_p.item() - allowed_tokens = torch.where(probs[0] >= threshold)[0].cpu().numpy() - - - sample_freq = [0] * vocab_size - low_prob_token_times = 0 - low_prob_token_probs = [] +# batch:[0.1.0,5,0.9]:FastDeploy +def fastdeploy_batch_min_p_sampling(batch_size, min_p_values): + logits = paddle.ones(shape=[batch_size, vocab_size], dtype="float32") + for b in range(batch_size): + logits[b][0] = 10 + logits[b][1] = 8 + logits[b][2:] = paddle.linspace(2.0, 0.0, vocab_size - 2) + + probs = F.softmax(logits, axis=-1) + min_p_arr = paddle.to_tensor(min_p_values, dtype="float32") + + allowed_tokens_list = [] + for b in range(batch_size): + max_prob = probs[b].max().item() + threshold = max_prob * min_p_values[b] + allowed_tokens = paddle.where(probs[b] >= threshold)[0].numpy() + allowed_tokens_list.append(allowed_tokens) + + sample_freq = [np.zeros(vocab_size, dtype=float) for _ in range(batch_size)] + low_prob_token_times = [0] * batch_size + low_prob_token_probs = [[] for _ in range(batch_size)] + + for i in tqdm(range(sample_time), desc="FastDeploy Batch Sampling"): + ids = min_p_sampling(probs, min_p_arr, seed=-1) + for b in range(batch_size): + sample_freq[b][ids[b].item()] += 1 + if ids[b].item() >= 2: + low_prob_token_times[b] += 1 + low_prob_token_probs[b].append(low_prob_token_times[b] / (i + 1)) + + data2_list = [] + data3_list = [] + for b in range(batch_size): + sample_freq_b = sample_freq[b] / sample_time + low_prob_token_probs[b] = np.array(low_prob_token_probs[b], dtype=float) + + ori_data1 = probs[b].numpy() + data1 = compress(ori_data1) + data2 = compress(sample_freq_b) + data2_list.append(data2) + + allowed_probs = probs[b, allowed_tokens_list[b]].numpy() + norm_scale = np.sum(allowed_probs) + data3 = np.zeros_like(data1) + for idx in allowed_tokens_list[b]: + if idx < 2: + data3[idx] = ori_data1[idx] / norm_scale + else: + data3[2] += ori_data1[idx] / norm_scale + data3_list.append(data3) + + plot_bar_chart(data1, data2, data3, "FastDeploy[min_p_batch_sampling]", b) + plot_low_prob_curve(low_prob_token_probs[b], sample_time, "FastDeploy[min_p_batch_sampling]", b) + + return data2_list, data3_list - for i in tqdm(range(sample_time), desc="FlashInfer Sampling"): - ids = flashinfer.sampling.min_p_sampling_from_probs(probs, min_p, deterministic=False) - sample_freq[ids.item()] += 1 - if ids.item() >= 2: - low_prob_token_times += 1 - low_prob_token_probs.append(low_prob_token_times / (i + 1)) - - sample_freq = np.array(sample_freq, dtype=float) / sample_time - low_prob_token_probs = np.array(low_prob_token_probs, dtype=float) - - ori_data1 = probs.cpu().numpy().reshape(-1) - data1 = 
compress(ori_data1) - data2 = compress(sample_freq) - - allowed_probs = probs[0, allowed_tokens].cpu().numpy() - norm_scale = np.sum(allowed_probs) - data3 = np.zeros_like(data1) - for idx in allowed_tokens: - if idx < 2: - data3[idx] = ori_data1[idx] / norm_scale - else: - data3[2] += ori_data1[idx] / norm_scale +def main(): + print("Running single min_p sampling (min_p=0.5)...") + data2_fastdeploy, data3_fastdeploy = fastdeploy_min_p_sampling() - plot_bar_chart(data1, data2, data3, "vLLM[min_p_sampling]") - plot_low_prob_curve(low_prob_token_probs, sample_time, "vLLM[min_p_sampling]") + print("\nFastDeploy Single Request Results:") + print(f"Sampled Probability: {data2_fastdeploy}") + print(f"Theoretical Normalized Probability: {data3_fastdeploy}") - return data2, data3 + print("\nRunning batch min_p sampling (min_p=[0.1, 0.5, 0.9])...") + data2_fd_batch, data3_fd_batch = fastdeploy_batch_min_p_sampling(batch_size, batch_min_p_values) -def plot_bar_chart(data1, data2, data3, title): - plt.figure(figsize=(6, 6)) - bar_width = 0.2 - idx = np.arange(len(data1)).astype(float) + data2_fd_batch,data3_fd_batch = fastdeploy_batch_min_p_sampling(batch_size,batch_min_p_values2) - bars1 = plt.bar(idx - bar_width, data1, width=bar_width, color='salmon', label='Original Probability', alpha=0.9) - bars2 = plt.bar(idx, data2, width=bar_width, color='skyblue', label='Sampled Probability', alpha=0.9) - bars3 = plt.bar(idx + bar_width, data3, width=bar_width, color='orange', label='Normalized Original Probability', alpha=0.9) + for b in range(batch_size): + print(f"\nBatch Request {b} (min_p={batch_min_p_values[b]}):") + print(f"FastDeploy - Sampled: {data2_fd_batch[b]}, Normalized: {data3_fd_batch[b]}") - plt.bar_label(bars1, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='black') - plt.bar_label(bars2, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='red') - plt.bar_label(bars3, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='blue') - - plt.title(title, fontsize=14) - plt.xlabel("Index", fontsize=12) - plt.ylabel("Probability", fontsize=12) - plt.ylim(0, 1.1) - plt.xlim(-1, 3) - plt.xticks(range(0, 3, 1)) - plt.legend(fontsize=10) - plt.grid(axis='y', linestyle='--', alpha=0.5) - output_path = f"{title.replace(' ', '_')}.png" - plt.savefig(output_path, dpi=300, bbox_inches='tight') - plt.clf() - -# Function to plot low-probability token probability curve -def plot_low_prob_curve(low_prob_token_probs, sample_time, title): - plt.figure(figsize=(6, 6)) - plt.plot(np.arange(0, sample_time), low_prob_token_probs, marker='', linestyle='-', linewidth=1, color='blue') - plt.xlabel('Sample Times') - plt.ylabel('Probability') - plt.title('Probability of Low-Probability Tokens') - plt.grid(alpha=0.3) - output_path = f"{title.replace(' ', '_')}_low_prob.png" - plt.savefig(output_path, dpi=300, bbox_inches='tight') - plt.clf() - -# Main function -def main(): - print("Running FastDeploy sampling...") - data2_fastdeploy, data3_fastdeploy = fastdeploy_min_p_sampling() - print("Running vLLM (FlashInfer) sampling...") - data2_vllm, data3_vllm = vllm_min_p_sampling() - - # Calculate errors - error_fastdeploy = np.abs(data2_fastdeploy - data3_fastdeploy) - error_vllm = np.abs(data2_vllm - data3_vllm) - - # Print comparison results - print("\nFastDeploy Comparison Results:") - print(f"Sampled Probability (data2): {data2_fastdeploy}") - print(f"Theoretical Normalized Probability (data3): {data3_fastdeploy}") - print(f"Error: {error_fastdeploy}") - print(f"Maximum Error: 
{np.max(error_fastdeploy)}, Is Less Than 1e-5: {np.max(error_fastdeploy) < 1e-5}") - - print("\nvLLM (FlashInfer) Comparison Results:") - print(f"Sampled Probability (data2): {data2_vllm}") - print(f"Theoretical Normalized Probability (data3): {data3_vllm}") - print(f"Error: {error_vllm}") - print(f"Maximum Error: {np.max(error_vllm)}, Is Less Than 1e-5: {np.max(error_vllm) < 1e-5}") - - # Check if errors meet the requirement - if np.max(error_fastdeploy) < 1e-5 and np.max(error_vllm) < 1e-5: - print("\nConclusion: Both FastDeploy and vLLM sampling results are consistent with theoretical values, with errors less than 1e-5.") - else: - print("\nConclusion: There are cases where the error is greater than 1e-5. It is recommended to increase the number of samples or check the implementation details.") - -# Run the program if __name__ == "__main__": if paddle.device.is_compiled_with_cuda() and torch.cuda.is_available(): main() - else: - print("GPU is not available. Please check the environment configuration (requires support for PaddlePaddle and PyTorch with CUDA).") From 53bdfb2d71f3c4f46df3f1f5d3d3dc69d4aa7460 Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Wed, 16 Jul 2025 13:30:25 +0000 Subject: [PATCH 05/16] update --- fastdeploy/demo/offline_demo.py | 13 +++--- fastdeploy/demo/openai_demo.py | 77 +++++++++++++++++---------------- 2 files changed, 45 insertions(+), 45 deletions(-) diff --git a/fastdeploy/demo/offline_demo.py b/fastdeploy/demo/offline_demo.py index d4044510e9..856757aa00 100644 --- a/fastdeploy/demo/offline_demo.py +++ b/fastdeploy/demo/offline_demo.py @@ -17,14 +17,13 @@ from fastdeploy.engine.sampling_params import SamplingParams from fastdeploy.entrypoints.llm import LLM -model_name_or_path = "/home/zexuli/Models/Qwen3-0.6B" +model_name_or_path = "./models/llama-7b" # 超参设置 -sampling_params = SamplingParams(temperature=0.1,top_p=0.7,top_k=10,min_p=0.1) -llm = LLM(model=model_name_or_path, tensor_parallel_size=1,reasoning_parser="qwen3") -prompt = "北京天安门在哪里?" 
-messages = [{"role": "user", "content": prompt}] -output = llm.chat([messages], - sampling_params) +sampling_params = SamplingParams(temperature=0.1, max_tokens=30) +llm = LLM(model=model_name_or_path, tensor_parallel_size=1) +output = llm.generate(prompts="who are you?", + use_tqdm=True, + sampling_params=sampling_params) print(output) diff --git a/fastdeploy/demo/openai_demo.py b/fastdeploy/demo/openai_demo.py index d7557e9ea7..1b8b5862af 100644 --- a/fastdeploy/demo/openai_demo.py +++ b/fastdeploy/demo/openai_demo.py @@ -18,43 +18,44 @@ import openai ip = "0.0.0.0" -service_http_port = "9012" # 服务配置的 +service_http_port = "9809" # 服务配置的 client = openai.Client(base_url=f"http://{ip}:{service_http_port}/v1", api_key="EMPTY_API_KEY") # 非流式返回 -# response = client.completions.create( -# model="default", -# prompt="Hello, how are you?", -# max_tokens=64, -# stream=False, -# ) +response = client.completions.create( + model="default", + prompt="Hello, how are you?", + max_tokens=64, + stream=False, +) -# print(response.choices[0].text) -# print("\n") +print(response.choices[0].text) +print("\n") -# # 流式返回 -# response = client.completions.create( -# model="default", -# prompt="Hello, how are you?", -# max_tokens=100, -# stream=True, -# ) +# 流式返回 +response = client.completions.create( + model="default", + prompt="Hello, how are you?", + max_tokens=100, + stream=True, +) -# for chunk in response: -# print(chunk.choices[0].text, end='') -# print("\n") +for chunk in response: + print(chunk.choices[0].text, end='') +print("\n") # Chat completion # 非流式返回 response = client.chat.completions.create( model="default", messages=[ - {"role": "user", "content": "北京天安门在哪里?"}, + {"role": "user", "content": "Hello, who are you"}, + {"role": "assistant", "content": "I'm a helpful AI assistant."}, + {"role": "user", "content": "List 3 countries and their capitals."}, ], - temperature=0.1, - top_p=0.7, - metadata={"min_p":0.1}, + temperature=1, + max_tokens=64, stream=False, ) @@ -63,19 +64,19 @@ # # 流式返回 -# response = client.chat.completions.create( -# model="default", -# messages=[ -# {"role": "user", "content": "Hello, who are you"}, -# {"role": "assistant", "content": "I'm a helpful AI assistant."}, -# {"role": "user", "content": "List 3 countries and their capitals."}, -# ], -# temperature=1, -# max_tokens=64, -# stream=True, -# ) +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "user", "content": "Hello, who are you"}, + {"role": "assistant", "content": "I'm a helpful AI assistant."}, + {"role": "user", "content": "List 3 countries and their capitals."}, + ], + temperature=1, + max_tokens=64, + stream=True, +) -# for chunk in response: -# if chunk.choices[0].delta is not None: -# print(chunk.choices[0].delta.content, end='') -# print("\n") +for chunk in response: + if chunk.choices[0].delta is not None: + print(chunk.choices[0].delta.content, end='') +print("\n") From 8d33bd4e74daec61019b5bb656a038ed294c7298 Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Wed, 16 Jul 2025 13:44:57 +0000 Subject: [PATCH 06/16] delete vl_gpu_model_runner.py --- fastdeploy/worker/vl_gpu_model_runner.py | 1270 ---------------------- 1 file changed, 1270 deletions(-) delete mode 100644 fastdeploy/worker/vl_gpu_model_runner.py diff --git a/fastdeploy/worker/vl_gpu_model_runner.py b/fastdeploy/worker/vl_gpu_model_runner.py deleted file mode 100644 index b37bcd17c9..0000000000 --- a/fastdeploy/worker/vl_gpu_model_runner.py +++ /dev/null @@ -1,1270 +0,0 @@ -""" -# Copyright (c) 
2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License" -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -import argparse -import json -import os -import random -from typing import Optional - -import numpy as np -import paddle -import paddle.distributed.fleet as fleet -from paddleformers.transformers.model_utils import load_tp_checkpoint -from safetensors import safe_open - -from fastdeploy.config import (DeviceConfig, FDConfig, GraphOptimizationConfig, - KVCacheConfig, LoadConfig, ModelConfig, - MoEConfig, MoEPhase, ParallelConfig, - SpeculativeConfig) -from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer -from fastdeploy.input.mm_processor import DataProcessor -from fastdeploy.model_executor.forward_meta import ForwardMeta -from fastdeploy.model_executor.layers.attention import get_attention_backend -from fastdeploy.model_executor.layers.rotary_embedding import get_rope_3d -from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata -from fastdeploy.model_executor.layers.sample.sampler import Sampler -from fastdeploy.model_executor.models.ernie4_5_moe import \ - Ernie4_5_PretrainedModel -from fastdeploy.model_executor.models.ernie4_5_vl.configuration import \ - Ernie4_5_VLMoeConfig -from fastdeploy.model_executor.models.ernie4_5_vl.dfnrope import \ - DFNRopeVisionTransformerConfig -from fastdeploy.model_executor.models.ernie4_5_vl.dfnrope.modeling import \ - DFNRopeVisionTransformerPretrainedModel -from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ( - ScatterOp, VariableResolutionResamplerModel) -from fastdeploy.platforms import current_platform -from fastdeploy.worker.output import SamplerOutput -from fastdeploy.worker.utils import check_safetensors_model -from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase - -if current_platform.is_cuda() and current_platform.available(): - from fastdeploy.model_executor.layers.utils import ( - remove_padding, speculate_remove_padding) - -from fastdeploy.model_executor.ops.gpu import (save_output, save_output_topk, - set_stop_value_multi_ends, - set_value_by_flags_and_idx, - update_inputs) - - -class GPUVLModelRunner(VLModelRunnerBase): - """ - The GPUVLModelRunner class for vision-language tasks on GPU. 
- """ - - def __init__( - self, - config: ModelConfig, - args: argparse.Namespace, - nranks: int, - rank: int, - ) -> None: - """ - GPUVLModelRunner init - """ - self.nranks = nranks - self.rank = rank - - hcg = fleet.get_hybrid_communicate_group() - self.tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), - 1) - self.tensor_parallel_rank = hcg.get_model_parallel_rank() - self.mp_src_rank = hcg.get_model_parallel_group_src_rank() - self.mp_group = hcg.get_model_parallel_group() - self.is_safetensors_model = check_safetensors_model( - args.model_name_or_path) - self.enable_logprob = args.enable_logprob - - model_path = os.path.dirname(args.model_name_or_path) - args.llm_model_name_or_path = args.model_name_or_path - if not self.is_safetensors_model: - args.tokenizer = args.image_preprocessor = model_path - else: - args.tokenizer = args.image_preprocessor = args.model_name_or_path - args.vision_model_name_or_path = os.path.join( - model_path, "DFNRopeVisionTransformer") - - self.amp_black = [ - "reduce_sum", - "c_softmax_with_cross_entropy", - "elementwise_div", - "sin", - "cos", - "sort", - "multinomial", - ] - self.amp_white = [ - "lookup_table", - "lookup_table_v2", - "flash_attn", - "matmul", - "matmul_v2", - "fused_gemm_epilogue", - ] - - super().__init__(config, args) - self.init_extra_input(config, args) - - self._reset_paddle_env() - - self.sampler = Sampler() - - def _reset_paddle_env(self): - pass - - def update_chunked_prefill(self, tasks: list[any]) -> None: - """ - update chunked prefill - """ - if not self.args.enable_chunked_prefill: - return - - for task in tasks: - if task.chunk_idx > len(task.prefill_chunk_info): - continue - - idx = task.idx - if task.chunk_idx == len(task.prefill_chunk_info): - self.share_inputs["seq_lens_this_time"][idx:idx + 1] = 1 - self.share_inputs['seq_lens_encoder'][idx:idx + 1] = 0 - self.share_inputs["seq_lens_decoder"][idx:idx + - 1] = task.start_idx - self.share_inputs["step_idx"][idx:idx + 1] = 1 - else: - inputs = self._preprocess_task( - task.prefill_chunk_info[task.chunk_idx]) - if inputs.get("images") is not None: - self.share_inputs[ - "image_features"] = self.extract_vision_features( - inputs) - else: - # Compatible with the situation that lacks images and videos - self.share_inputs["image_features"] = None - - token_chunk_size = inputs["input_ids"].shape[1] - self.share_inputs["input_ids"][ - idx:idx + 1, :token_chunk_size] = inputs["input_ids"] - self.share_inputs["seq_lens_this_time"][idx:idx + - 1] = token_chunk_size - self.share_inputs['seq_lens_encoder'][idx:idx + - 1] = token_chunk_size - self.share_inputs["seq_lens_decoder"][idx:idx + - 1] = task.start_idx - self.share_inputs["step_idx"][idx:idx + 1] = 0 - - task.start_idx += token_chunk_size - task.chunk_idx += 1 - - def _load_model( - self, - model_name: str, - dynamic_load_weight: int = 0, - ) -> None: - """ - Load the model from the given model name. 
- """ - - vocab_file_names = [ - "tokenizer.model", "spm.model", "ernie_token_100k.model" - ] - for i in range(len(vocab_file_names)): - if os.path.exists( - os.path.join(self.args.tokenizer, vocab_file_names[i])): - ErnieBotTokenizer.resource_files_names[ - "vocab_file"] = vocab_file_names[i] - break - - tokenizer = ErnieBotTokenizer.from_pretrained( - self.args.tokenizer, - model_max_length=self.args.max_model_len, - padding_side="right", - use_fast=False, - ) - tokenizer.ignored_index = -100 - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.unk_token - - config = Ernie4_5_VLMoeConfig.from_pretrained( - self.args.llm_model_name_or_path, - tensor_parallel_degree=self.tensor_parallel_degree, - tensor_parallel_rank=self.tensor_parallel_rank, - moe_group="dummy", - ) - self.model_cfg = config - if self.is_safetensors_model: - meta_json = os.path.join(self.args.model_name_or_path, - "model.safetensors.index.json") - if os.path.exists(meta_json): - with open( - os.path.join(self.args.model_name_or_path, - "model.safetensors.index.json"), - "r") as f: - self.weight_map = json.load(f)["weight_map"] - else: - self.weight_map = {} - with safe_open(os.path.join(self.args.model_name_or_path, - "model.safetensors"), - framework="np") as f: - keys = f.keys() - for k in keys: - self.weight_map[k] = "model.safetensors" - - if self.is_safetensors_model: - vision_config = config.vision_config - vision_config.tensor_parallel_degree = self.tensor_parallel_degree - vision_config.tensor_parallel_rank = self.tensor_parallel_rank - vision_config.attn_sep = False - vision_config.dtype = "bfloat16" - else: - vision_config = DFNRopeVisionTransformerConfig.from_pretrained( - self.args.vision_model_name_or_path, - tensor_parallel_degree=self.tensor_parallel_degree, - tensor_parallel_rank=self.tensor_parallel_rank, - attn_sep=False, - dtype="bfloat16", - ) - config.vision_config = vision_config - self.vision_config = vision_config - config.pixel_hidden_size = config.vision_config.hidden_size - config.im_patch_id = tokenizer.get_vocab()["<|IMAGE_PLACEHOLDER|>"] - config.think_end_id = tokenizer.get_vocab()[""] - config.max_text_id = config.im_patch_id - - config.sequence_parallel = False - - self.dtype = self.args.dtype - paddle.set_default_dtype(self.dtype) - - self.vision_model, self.resampler_model = self.inject_pp_vision_model( - self.args, config) - - processor = DataProcessor( - tokenizer_name=self.args.tokenizer, - image_preprocessor_name=str(self.args.image_preprocessor), - ) - processor.eval() - image_preprocess = processor.image_preprocessor - image_preprocess.image_mean_tensor = paddle.to_tensor( - image_preprocess.image_mean, dtype="float32").reshape([1, 3, 1, 1]) - image_preprocess.image_std_tensor = paddle.to_tensor( - image_preprocess.image_std, dtype="float32").reshape([1, 3, 1, 1]) - image_preprocess.rescale_factor = paddle.to_tensor( - image_preprocess.rescale_factor, dtype="float32") - image_preprocess.image_mean_tensor = image_preprocess.image_mean_tensor.squeeze( - [-2, -1]).repeat_interleave(config.vision_config.patch_size**2 * 1, - -1) - image_preprocess.image_std_tensor = image_preprocess.image_std_tensor.squeeze( - [-2, -1]).repeat_interleave(config.vision_config.patch_size**2 * 1, - -1) - self.image_preprocess = image_preprocess - - graph_opt_config = GraphOptimizationConfig( - self.args.enable_static_graph_inference, self.args.use_cudagraph, - self.args.max_capture_batch_size) - - fd_config, self.model = build_stream_line_model( - self.args.model_name_or_path, - 
self.args.dtype, - self.args.block_size, - max_model_len=self.args.max_model_len, - tokenizer=tokenizer, - quantization=self.args.quantization, - graph_opt_config=graph_opt_config, - ) - self.model.eval() - self.set_state_dict(self.args) - - fd_config.parallel_config.max_model_len = fd_config.model_config.max_seq_len - self.fd_config = fd_config - attn_backend_cls = get_attention_backend() - num_heads = self.fd_config.model_config.num_attention_heads // \ - self.fd_config.parallel_config.tensor_parallel_degree - self.fd_config.model_config.kv_num_heads = int( - self.fd_config.model_config.num_key_value_heads - ) // self.fd_config.parallel_config.tensor_parallel_degree - head_dim = self.fd_config.model_config.head_dim - self.attn_backend = attn_backend_cls( - self.fd_config, - kv_num_heads=self.fd_config.model_config.kv_num_heads, - num_heads=num_heads, - head_dim=head_dim) - self._init_kvcache() - - def init_extra_input(self, config: ModelConfig, args: argparse.Namespace) -> None: - """ - Initialize extra input tensors. - """ - head_dim = self.model_cfg.head_dim - self.share_inputs.update({ - "rope_emb": - paddle.full(shape=[ - args.max_num_seqs, 2, 1, self.max_length, 1, head_dim // 2 - ], - fill_value=0, - dtype="float32") - }) - self.share_inputs.update({"image_features": None}) - self.share_inputs.update({ - "need_think_end": - paddle.full(shape=[args.max_num_seqs, 1], - fill_value=0, - dtype="int32") - }) - self.share_inputs.update({ - "enable_thinking": - paddle.full(shape=[1], fill_value=True, dtype="bool") - }) - self.share_inputs.update({ - "reasoning_index": - paddle.full(shape=[args.max_num_seqs, 1], - fill_value=0, - dtype="int32") - }) - - def init_rotary_position_embedding(self, max_model_len: int) -> None: - """ - Init rotary position embedding - """ - pass - - def _init_kvcache(self): - """ - Init kv cache - """ - cache_kvs = {} - total_block_num = self.num_gpu_blocks - num_layers = self.model_cfg.get("num_layers", - None) or self.model_cfg.get( - "num_hidden_layers", None) - - kv_num_head = self.model_cfg.get( - "num_key_value_heads", - self.model_cfg.num_attention_heads, - ) - kv_num_head = kv_num_head // self.tensor_parallel_degree - self.model_cfg.kv_num_head = kv_num_head - - for i in range(num_layers): - cache_type = self.args.dtype - cache_kvs["key_caches_{}".format(i)] = paddle.full( - shape=[ - total_block_num, - kv_num_head, - self.args.block_size, - self.model_cfg.head_dim, - ], - fill_value=0, - dtype=cache_type, - ) - cache_kvs["value_caches_{}".format(i)] = paddle.full( - shape=[ - total_block_num, - kv_num_head, - self.args.block_size, - self.model_cfg.head_dim, - ], - fill_value=0, - dtype=cache_type, - ) - - self.share_inputs["caches"] = list(cache_kvs.values()) - for value in cache_kvs.values(): - del value - paddle.device.cuda.empty_cache() - - def clear_parameters(self, pid: int) -> None: - """ clear_parameters """ - if "caches" in self.share_inputs: - self.model.clear_parameters(pid) - del self.share_inputs["caches"] - paddle.device.cuda.empty_cache() - self.model.log_memory_usage("clear all memory") - - def update_parameters(self, pid: int) -> None: - """ update_parameters """ - if "caches" not in self.share_inputs: - self.model.update_parameters(pid) - self._init_kvcache() - self.model.log_memory_usage("update all memory") - - @paddle.no_grad() - def set_state_dict(self, args: argparse.Namespace) -> None: - """set_state_dict""" - if not self.is_safetensors_model: - rank_model_paths = [] - for root, dirs, files in 
os.walk(self.args.llm_model_name_or_path): - for file in files: - if file == f"model_state.tp0{self.tensor_parallel_rank}.pdparams": - rank_model_paths.append(os.path.join(root, file)) - elif file == "model_state.pdparams": - rank_model_paths.append(os.path.join(root, file)) - state_dict = {} - for path in rank_model_paths: - loaded_dict = paddle.load(path, return_numpy=True) - state_dict.update(loaded_dict) - - resampler_state = {} - for key in list(state_dict.keys()): - if "vision" in key: - state_dict.pop(key) - if key.startswith("ernie.resampler_model."): - value = state_dict.pop(key) - value = paddle.to_tensor(value).cast("bfloat16") - value = value.numpy() - resampler_state[ - key[len("ernie.resampler_model."):]] = value - elif key.startswith("resampler_model."): - value = state_dict.pop(key) - value = paddle.to_tensor(value).cast("bfloat16") - value = value.numpy() - resampler_state[key[len("resampler_model."):]] = value - self.model.set_state_dict(state_dict) - self.resampler_model.set_state_dict(resampler_state) - else: - state_dict = load_tp_checkpoint( - args.model_name_or_path, - Ernie4_5_PretrainedModel, - self.model_cfg, - return_numpy=True, - ) - for key in list(state_dict.keys()): - if key.startswith("vision_model.") or key.startswith( - "ernie.resampler_model."): - state_dict.pop(key) - self.model.set_state_dict(state_dict) - - @paddle.no_grad() - def vit_load( - self, - model_path: str, - tensor_parallel_degree: int, - tensor_parallel_rank: int, - ) -> None: - """ - Load vit tp weight - """ - if tensor_parallel_degree == 1: - rank_model_path = os.path.join(model_path, "model_state.pdparams") - else: - rank_model_path = os.path.join( - model_path, f"model_state_tp0{tensor_parallel_rank}.pdparams") - if os.path.exists(rank_model_path): - return paddle.load(rank_model_path, return_numpy=True) - else: - raise ValueError(f"No such a file {rank_model_path}") - - @paddle.no_grad() - def inject_pp_vision_model(self, args: argparse.Namespace, cfg: Ernie4_5_VLMoeConfig): - """ - Inject pp vision model - """ - - def set_vision_state_dict(model, - tensor_parallel_degree: int=8, - tensor_parallel_rank: int=0, - name: str=""): - """ - Set vision model weight - """ - model_state_dict = model.state_dict() - compat_keys = [name + k for k in model_state_dict.keys()] - model_files = set() - for k in compat_keys: - if k in self.weight_map.keys(): - model_files.add( - os.path.join(args.model_name_or_path, - self.weight_map[k])) - state_dict = {} - for model_file in model_files: - with safe_open(model_file, framework="np") as f: - for k in f.keys(): - if k in compat_keys: - new_k = k.replace(name, "") - tensor = f.get_tensor(k) - if tensor_parallel_degree > 1: - if "resampler_model" in name and new_k == "spatial_linear.0.weight": - tensor = np.split( - tensor, tensor_parallel_degree, - axis=0)[tensor_parallel_rank] - elif name == "vision_model.": - if "attn.proj.weight" in new_k or "fc2.weight" in new_k: - tensor = np.split( - tensor, - tensor_parallel_degree, - axis=0)[tensor_parallel_rank] - elif "fc1.weight" in new_k or "fc1.bias" in new_k: - tensor = np.split( - tensor, - tensor_parallel_degree, - axis=-1)[tensor_parallel_rank] - elif "qkv.weight" in new_k: - head_dim = self.vision_config.hidden_size // self.vision_config.num_heads - tensor = tensor.reshape([ - self.vision_config.hidden_size, 3, - self.vision_config.num_heads, - head_dim - ]) - tensor = np.split( - tensor, - tensor_parallel_degree, - axis=-2 - )[tensor_parallel_rank].reshape([ - self.vision_config.hidden_size, -1 - ]) - 
elif "qkv.bias" in new_k: - head_dim = self.vision_config.hidden_size // self.vision_config.num_heads - tensor = tensor.reshape([ - 3, self.vision_config.num_heads, - head_dim - ]) - tensor = np.split( - tensor, - tensor_parallel_degree, - axis=-2 - )[tensor_parallel_rank].reshape([-1]) - state_dict[new_k] = tensor - model.set_state_dict(state_dict) - - vision_model = DFNRopeVisionTransformerPretrainedModel( - cfg.vision_config) - vision_model = paddle.amp.decorate(models=vision_model, - level="O2", - dtype="bfloat16") - vision_model.eval() - if not self.is_safetensors_model: - vit_state_dict = self.vit_load(args.vision_model_name_or_path, - self.tensor_parallel_degree, - self.tensor_parallel_rank) - vision_model.set_state_dict(vit_state_dict) - else: - set_vision_state_dict( - vision_model, - tensor_parallel_degree=self.tensor_parallel_degree, - tensor_parallel_rank=self.tensor_parallel_rank, - name="vision_model.", - ) - - resampler_model = VariableResolutionResamplerModel( - cfg.pixel_hidden_size, - cfg.hidden_size, - cfg.spatial_conv_size, - cfg.temporal_conv_size, - config=cfg, - ) - resampler_model = paddle.amp.decorate(models=resampler_model, - level="O2", - dtype="bfloat16") - resampler_model.eval() - if self.is_safetensors_model: - is_ernie_begin = False - for k in self.weight_map.keys(): - if k.startswith("ernie.resampler_model."): - is_ernie_begin = True - set_vision_state_dict( - resampler_model, - tensor_parallel_degree=self.tensor_parallel_degree, - tensor_parallel_rank=self.tensor_parallel_rank, - name="ernie.resampler_model." - if is_ernie_begin else "resampler_model.", - ) - return vision_model, resampler_model - - @paddle.no_grad() - def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor: - """extract_vision_features""" - assert inputs["images"] is not None - grid_thw = inputs["grid_thw"] - - images = inputs["images"].cast("float32") - images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor - images = images / self.image_preprocess.image_std_tensor - images = images.cast("bfloat16") - - token_type_ids = inputs["token_type_ids"] - token_type_ids_w_video = token_type_ids - input_ids = inputs["input_ids"] - # convert to img patch id - image_mask = input_ids == self.model_cfg.im_patch_id - image_type_ids = inputs["image_type_ids"] - with paddle.amp.auto_cast( - True, - custom_black_list=self.amp_black, - custom_white_list=self.amp_white, - level="O2", - dtype=self.dtype, - ): - image_features = self.vision_model.extract_feature( - images, grid_thw) - if self.tensor_parallel_degree > 1: - S, C = image_features.shape - image_features = image_features.reshape( - [-1, C * self.model_cfg.spatial_conv_size**2]) - image_features = ScatterOp.apply(image_features, - axis=-1) # mp 切 Fea - image_features = image_features.reshape([S, -1]) - image_features = self.resampler_model( - image_features, - image_mask, - token_type_ids_w_video, - image_type_ids, - grid_thw, - ) - return image_features - - @paddle.no_grad() - def prepare_rope3d(self, position_ids: paddle.Tensor, **kwargs) -> paddle.Tensor: - """prepare_rope3d""" - - prefix_max_position_ids = paddle.max(position_ids) + 1 - dec_pos_ids = paddle.tile( - paddle.arange(kwargs["max_length"], - dtype="int64").unsqueeze(0).unsqueeze(-1), [1, 1, 3]) - dec_pos_ids = dec_pos_ids + prefix_max_position_ids - position_ids_3d_real = paddle.concat([position_ids, dec_pos_ids], - axis=1) - - rope_emb = get_rope_3d( - position_ids=position_ids_3d_real, - 
rotary_dim=self.model_cfg.head_dim, - paritial_rotary_factor=1.0, - base=self.model_cfg.rope_theta, - max_position=self.args.max_model_len, - freq_allocation=self.model_cfg.freq_allocation, - ) - return rope_emb - - def prefill_finished(self): - """ - Verify prefill operation completion - """ - prefill_statue = (self.share_inputs["seq_lens_this_time"] != 0) & ( - self.share_inputs["seq_lens_this_time"] != 1) - return not paddle.any(prefill_statue).numpy() - - def dy_input_preprocess(self, tasks: list[any]) -> None: - """ - dynamic insertion - """ - - def get_numeric_value(task, key, default_value): - if task.get(key, None) is not None: - return task.get(key) - else: - return default_value - - for i in range(len(tasks)): - task = tasks[i] - idx = task.idx - - kwargs = { - "max_length": - get_numeric_value(task, "max_tokens", 2048), - "top_p": - get_numeric_value(task, "top_p", 0.8), - "temperature": - get_numeric_value(task, "temperature", 0.2), - "top_k": - get_numeric_value(task, "top_k", 0), - "min_p": - get_numeric_value(task,"min_p",0.0), - "penalty_score": - get_numeric_value(task, "repetition_penalty", 1.0), - "frequency_score": - get_numeric_value(task, "frequency_penalty", 0.0), - "presence_score": - get_numeric_value(task, "presence_penalty", 0.0), - "decode_strategy": - "sampling", - "pad_token_id": - self.args.pad_token_id, - "enable_thinking": - get_numeric_value(task, "enable_thinking", True), - "reasoning_max_tokens": - get_numeric_value(task, "reasoning_max_tokens", 2048), - } - - if self.args.enable_chunked_prefill: - task.set("chunk_idx", 1) - inputs = self._preprocess_task(task.prefill_chunk_info[0]) - if inputs.get("images") is not None: - self.share_inputs[ - "image_features"] = self.extract_vision_features( - inputs) - else: - # Compatible with the situation that lacks images and videos - self.share_inputs["image_features"] = None - if task.multimodal_inputs["position_ids"] is not None: - position_ids = paddle.to_tensor( - task.multimodal_inputs["position_ids"], - dtype="int64").unsqueeze([0]) - else: - position_ids = None - - token_chunk_size = inputs["input_ids"].shape[1] - task.set("start_idx", token_chunk_size) - self.share_inputs["input_ids"][ - idx:idx + 1, :token_chunk_size] = inputs["input_ids"] - self.share_inputs["seq_lens_this_time"][idx:idx + - 1] = token_chunk_size - self.share_inputs["seq_lens_encoder"][idx:idx + - 1] = token_chunk_size - self.share_inputs["step_seq_lens_encoder"][ - idx:idx + 1] = token_chunk_size - else: - inputs = self._preprocess_task(task.multimodal_inputs) - if inputs.get("images") is not None: - self.share_inputs[ - "image_features"] = self.extract_vision_features( - inputs) - else: - # Compatible with the situation that lacks images and videos - self.share_inputs["image_features"] = None - position_ids = inputs["position_ids"] - - length = inputs["input_ids"].shape[1] - self.share_inputs["input_ids"][ - idx:idx + 1, :length] = inputs["input_ids"] - self.share_inputs["seq_lens_this_time"][idx:idx + 1] = length - self.share_inputs["seq_lens_encoder"][idx:idx + 1] = length - self.share_inputs["step_seq_lens_encoder"][idx:idx + - 1] = length - - # force - self.share_inputs["enable_thinking"][:] = kwargs["enable_thinking"] - self.share_inputs["need_think_end"][ - idx:idx + 1, :] = 1 if kwargs["enable_thinking"] else 0 - - self.share_inputs["reasoning_index"][ - idx:idx + 1, :] = kwargs["reasoning_max_tokens"] - - self.share_inputs["rope_emb"][idx:idx + - 1, :] = self.prepare_rope3d( - position_ids, **kwargs) - - 
self.share_inputs["top_p"][idx:idx + 1] = kwargs["top_p"] - self.share_inputs["temperature"][idx:idx + - 1] = kwargs["temperature"] - self.share_inputs["eos_token_id"][:] = np.array( - task.eos_token_ids).astype("int64").reshape(-1, 1) - self.share_inputs["penalty_score"][idx:idx + - 1] = kwargs["penalty_score"] - self.share_inputs["frequency_score"][idx:idx + - 1] = kwargs["frequency_score"] - self.share_inputs["presence_score"][idx:idx + - 1] = kwargs["presence_score"] - self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0 - self.share_inputs["step_idx"][idx:idx + 1] = 0 - self.share_inputs["min_dec_len"][idx:idx + 1] = 1 - self.share_inputs["max_dec_len"][idx:idx + - 1] = kwargs["max_length"] - self.share_inputs["stop_flags"][idx:idx + 1] = False - self.share_inputs["pre_ids"][idx:idx + 1] = -1 - encoder_block_num = len(task.get("block_tables")) - self.share_inputs["encoder_block_lens"][idx:idx + - 1] = encoder_block_num - self.share_inputs["block_tables"][idx:idx + 1, :] = -1 - self.share_inputs["block_tables"][ - idx:idx + 1, :encoder_block_num] = np.array(task.block_tables, - dtype="int32") - - def pre_process(self) -> None: - """ - pre_process - """ - if current_platform.is_cuda(): - if self.args.speculative_method is not None: - ( - ids_remove_padding, - padding_offset, - cum_offsets, - cu_seqlens_q, - cu_seqlens_k, - ) = speculate_remove_padding( - max_len=self.args.max_model_len, - input_ids=self.share_inputs["input_ids"], - seq_lens_this_time=self.share_inputs["seq_lens_this_time"], - draft_tokens=self.share_inputs["draft_tokens"], - seq_lens_encoder=self.share_inputs["seq_lens_encoder"]) - else: - ( - ids_remove_padding, - padding_offset, - cum_offsets, - cu_seqlens_q, - cu_seqlens_k, - ) = remove_padding( - max_len=self.args.max_model_len, - input_ids=self.share_inputs["input_ids"], - seq_lens_this_time=self.share_inputs["seq_lens_this_time"]) - self.share_inputs["ids_remove_padding"] = ids_remove_padding - self.share_inputs["padding_offset"] = padding_offset - self.share_inputs["cum_offsets"] = cum_offsets - self.share_inputs["cu_seqlens_q"] = cu_seqlens_q - self.share_inputs["cu_seqlens_k"] = cu_seqlens_k - self.share_inputs["decoder_batch_ids"] = paddle.full( - [self.fd_config.parallel_config.max_num_seqs, 1], 0, dtype='int32') - self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full( - [self.fd_config.parallel_config.max_num_seqs, 1], 0, dtype='int32') - # initialize_forward_meta - self.forward_meta = ForwardMeta( - input_ids=self.share_inputs["input_ids"], - ids_remove_padding=self.share_inputs["ids_remove_padding"], - rotary_embs=self.share_inputs["rope_emb"], - attn_backend=self.attn_backend, - decoder_batch_ids=self.share_inputs["decoder_batch_ids"], - decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], - seq_lens_encoder=self.share_inputs["seq_lens_encoder"], - seq_lens_decoder=self.share_inputs["seq_lens_decoder"], - seq_lens_this_time=self.share_inputs["seq_lens_this_time"], - cum_offsets=self.share_inputs["cum_offsets"], - padding_offset=self.share_inputs["padding_offset"], - cu_seqlens_q=self.share_inputs["cu_seqlens_q"], - cu_seqlens_k=self.share_inputs["cu_seqlens_k"], - block_tables=self.share_inputs["block_tables"], - caches=self.share_inputs["caches"] - ) - self.attn_backend.init_attention_metadata(self.forward_meta) - - self.sampling_metadata = SamplingMetadata( - temperature=self.share_inputs["temperature"], - top_p=self.share_inputs["top_p"], - step_idx=self.share_inputs["step_idx"], - 
pre_token_ids=self.share_inputs["pre_ids"], - frequency_penalties=self.share_inputs["frequency_score"], - presence_penalties=self.share_inputs["presence_score"], - repetition_penalties=self.share_inputs["penalty_score"], - min_dec_lens=self.share_inputs["min_dec_len"], - bad_words_token_ids=self.share_inputs["bad_tokens"], - eos_token_ids=self.share_inputs["eos_token_id"], - max_num_logprobs=20 if self.enable_logprob else None, - ) - - def generate(self) -> None: - """ - generate - """ - self.pre_process() - hiddden_states = self.model(self.share_inputs["ids_remove_padding"], - self.share_inputs["image_features"], - self.forward_meta) - logits = self.model.compute_logits(hiddden_states) - set_value_by_flags_and_idx( - self.share_inputs["pre_ids"], - self.share_inputs["input_ids"], - self.share_inputs["seq_lens_this_time"], - self.share_inputs["seq_lens_encoder"], - self.share_inputs["seq_lens_decoder"], - self.share_inputs["step_idx"], - self.share_inputs["stop_flags"], - ) - # sampler & save_output - sampler_output = self.sampler(logits, self.sampling_metadata) - if self.fd_config.parallel_config.tensor_parallel_degree > 1: - paddle.distributed.broadcast(sampler_output.sampled_token_ids, 0) - self.post_process(sampler_output) - - def post_process(self, sampler_output: SamplerOutput) -> None: - """ - post_process - """ - if self.share_inputs["enable_thinking"]: - exists_think_end = sampler_output.sampled_token_ids == self.model_cfg.think_end_id - paddle.assign( - paddle.where( - exists_think_end, - self.share_inputs["need_think_end"] - 1, - self.share_inputs["need_think_end"], - ), self.share_inputs["need_think_end"]) - - paddle.assign( - paddle.where( - self.share_inputs["need_think_end"].cast("bool"), - self.share_inputs["reasoning_index"] - 1, - self.share_inputs["reasoning_index"], - ), self.share_inputs["reasoning_index"]) - - stop_wo_think = ( - (sampler_output.sampled_token_ids == self.share_inputs["eos_token_id"]) | - (self.share_inputs["reasoning_index"] == 0)) & ( - self.share_inputs["need_think_end"] > 0) - sampler_output.sampled_token_ids = paddle.where(stop_wo_think, - self.model_cfg.think_end_id, - sampler_output.sampled_token_ids) - paddle.assign( - paddle.where( - stop_wo_think, - self.share_inputs["need_think_end"] - 1, - self.share_inputs["need_think_end"], - ), self.share_inputs["need_think_end"]) - paddle.assign( - paddle.where( - self.share_inputs["stop_flags"], - self.share_inputs["step_idx"], - self.share_inputs["step_idx"] + 1, - ), - self.share_inputs["step_idx"], - ) - length_cond = paddle.greater_equal(self.share_inputs["step_idx"], - self.share_inputs["max_dec_len"]) - paddle.assign( - paddle.logical_or(self.share_inputs["stop_flags"], length_cond), - self.share_inputs["stop_flags"], - ) - - set_stop_value_multi_ends( - sampler_output.sampled_token_ids, - self.share_inputs["stop_flags"], - self.share_inputs["seq_lens_this_time"], - self.share_inputs["eos_token_id"], - self.share_inputs["next_tokens"], - False, - ) # multi ends - # update inputs - update_inputs( - self.share_inputs["stop_flags"], - self.share_inputs["not_need_stop"], - self.share_inputs["seq_lens_this_time"], - self.share_inputs["seq_lens_encoder"], - self.share_inputs["seq_lens_decoder"], - self.share_inputs["input_ids"], - self.share_inputs["stop_nums"], - sampler_output.sampled_token_ids, - self.share_inputs["is_block_step"], - ) - if sampler_output.logprobs_tensors is None: - save_output( - sampler_output.sampled_token_ids, - self.share_inputs["not_need_stop"], - self.rank, - False, # use_ep 
- ) - else: - save_output_topk( - sampler_output.sampled_token_ids, - sampler_output.logprobs_tensors.logprob_token_ids, - sampler_output.logprobs_tensors.logprobs, - sampler_output.logprobs_tensors.selected_token_ranks, - self.share_inputs["not_need_stop"], - self.rank, - ) - - def _cal_theortical_kvcache(self): - """ - Calculate the size of kvcache for computational theory - """ - num_layers = self.model_cfg.get("num_layers", - None) or self.model_cfg.get( - "num_hidden_layers", None) - byte_of_cache = 2 - # support c8 c4 - - hidden_dim = self.model_cfg.head_dim * self.model_cfg.kv_num_head - theoretical_kv_cache_memory = (2 * byte_of_cache * - self.args.block_size * num_layers * - hidden_dim) - return theoretical_kv_cache_memory - - def _update_share_input_block_num(self): - """ - Update share_inputs['block_tables'] and share_inputs['free_list'] - """ - num_gpu_blocks = self.num_gpu_blocks - - del self.share_inputs["caches"] - self._init_kvcache() - - del self.share_inputs["block_tables"] - self.share_inputs["block_tables"] = paddle.full( - [self.args.max_num_seqs, num_gpu_blocks], -1, dtype="int32") - - # Init free list - free_list = list( - range(num_gpu_blocks - 1, - int(num_gpu_blocks * self.args.kv_cache_ratio) - 1, -1)) - self.free_list_len = len(free_list) - self.share_inputs.update({ - "free_list": - paddle.to_tensor(free_list, dtype="int32"), - "free_list_len": - paddle.full([1], self.free_list_len, dtype="int32"), - }) - - def dummy_input(self, num_total_tokens: int, number_of_tasks: int) -> None: - """ - fake input to profile - """ - input_length = min(num_total_tokens // number_of_tasks, - self.args.max_model_len - 10) - block_num = (input_length + self.args.block_size - 1 ) // self.args.block_size \ - + self.args.enc_dec_block_num - self.share_inputs["free_list"] = paddle.to_tensor([], dtype="int32") - self.share_inputs["free_list_len"][0] = 0 - - for i in range(number_of_tasks): - idx = i - self.share_inputs["input_ids"][idx:idx + - 1, :input_length] = np.array( - [5] * input_length) - self.share_inputs["eos_token_id"][:] = np.array( - [2], dtype="int64").reshape(-1, 1) - self.share_inputs["seq_lens_this_time"][idx:idx + 1] = input_length - self.share_inputs["step_seq_lens_encoder"][idx:idx + - 1] = input_length - self.share_inputs["seq_lens_encoder"][idx:idx + 1] = input_length - self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0 - self.share_inputs["step_idx"][idx:idx + 1] = 0 - self.share_inputs["max_dec_len"][idx:idx + 1] = 10 - self.share_inputs["stop_flags"][idx:idx + 1] = False - - self.share_inputs["first_token_ids"][ - idx:idx + 1] = self.share_inputs["input_ids"][idx:idx + 1, :1] - self.share_inputs["ori_seq_lens_encoder"][idx:idx + - 1] = input_length - - self.share_inputs["infer_seed"][idx:idx + 1] = random.randint( - 0, 922337203685477580) - self.share_inputs["encoder_block_lens"][idx:idx + 1] = block_num - self.share_inputs["block_tables"][idx : idx + 1, :block_num] = np.arange(idx * block_num, \ - (idx + 1) * block_num, 1) - - def _preprocess_task(self, one: dict) -> None: - """process batch""" - - input_ids = one["input_ids"][np.newaxis, :] - input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64) - token_type_ids = one["token_type_ids"][np.newaxis, :] - token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64) - - if one["images"] is not None: - image_type_ids = one["image_type_ids"][np.newaxis, :] - images = one["images"] - image_type_ids = paddle.to_tensor(image_type_ids, - dtype=paddle.int64) - images = paddle.to_tensor(images, 
dtype="uint8") - grid_thw = paddle.to_tensor(one["grid_thw"], dtype="int64") - else: - image_type_ids = None - images = None - grid_thw = None - - if one["position_ids"] is not None: - position_ids = paddle.to_tensor(one["position_ids"], - dtype="int64").unsqueeze([0]) - else: - position_ids = None - - result = dict( - input_ids=input_ids, - image_type_ids=image_type_ids, - token_type_ids=token_type_ids, - position_ids=position_ids, - grid_thw=grid_thw, - images=images, - ) - return result - - -def build_stream_line_model( - model_path: str, - dtype: str, - block_size: int, - max_model_len: int, - tokenizer: ErnieBotTokenizer, - quantization: str = "None", - graph_opt_config: Optional[GraphOptimizationConfig] = None -) -> tuple[FDConfig, paddle.nn.layer]: - """ - build model - """ - import contextlib - - from paddleformers.transformers.configuration_utils import PretrainedConfig - from paddleformers.trl import llm_utils - from paddleformers.utils.log import logger - - from fastdeploy.model_executor.layers.quantization import \ - get_quantization_config - from fastdeploy.model_executor.models.model_base import ModelRegistry - - config, _ = PretrainedConfig.get_config_dict(model_path) - config["head_dim"] = config.get( - "head_dim", config["hidden_size"] // config["num_attention_heads"]) - config["rope_theta"] = config.get("rope_theta", 10000.0) - rope_theta = config["rope_theta"] - model_config = ModelConfig.from_dict(config) - model_config.head_dim = config["head_dim"] - - parallel_config = ParallelConfig() - speculative_config = SpeculativeConfig() - device_config = DeviceConfig() - load_config = LoadConfig() - moe_config = MoEConfig() - kv_cache_config = KVCacheConfig() - kv_cache_config.cache_quant_dtype = "none" - - tensor_parallel_rank, tensor_parallel_degree = llm_utils.init_dist_env() - parallel_config.tensor_parallel_rank = tensor_parallel_rank - parallel_config.tensor_parallel_degree = tensor_parallel_degree - parallel_config.tensor_parallel_degree = tensor_parallel_degree - parallel_config.expert_parallel_degree = 1 - parallel_config.expert_parallel_rank = int(tensor_parallel_rank / - tensor_parallel_degree) - parallel_config.column_cut = False - - speculative_config.is_mtp = False - speculative_config.draft_type = "None" - - # Note(tangbinhan): used for load_checkpoint - model_config.tensor_parallel_rank = parallel_config.tensor_parallel_rank - model_config.tensor_parallel_degree = parallel_config.tensor_parallel_degree - model_config.is_mtp = speculative_config.is_mtp - moe_config.num_experts = None - - # use the length of tokenizer as the origin vocab size - ori_vocab_size = len(tokenizer) - moe_intermediate_size = (config.get("moe_intermediate_size", None), ) - if isinstance(moe_intermediate_size, list) or isinstance( - moe_intermediate_size, tuple): - moe_intermediate_size = moe_intermediate_size[0] - - num_key_value_heads = config.get("num_key_value_heads", -1) - if num_key_value_heads is None: - num_key_value_heads = -1 - - # RL need, some model num_key_value_heads less tensor_parallel_degree, need copy - if num_key_value_heads < tensor_parallel_degree: - logger.warning( - f"key value heads num is {num_key_value_heads}, tensor parallel degree is {tensor_parallel_degree}" - ) - num_key_value_heads = tensor_parallel_degree - - if config.get("ffn_hidden_size", None) is not None: - ffn_hidden_size = config["ffn_hidden_size"] - elif config.get("intermediate_size", None) is not None: - ffn_hidden_size = config["intermediate_size"] - else: - ffn_hidden_size = 4 * 
config["hidden_size"] - if config["hidden_act"].lower() == "swiglu": - if paddle.distributed.get_world_size() > 1: - multiple_of = 8 * config["num_attention_heads"] - else: - multiple_of = 4 * config["num_attention_heads"] - ffn_hidden_size = multiple_of * ( - (int(2 * ffn_hidden_size / 3) + multiple_of - 1) // - multiple_of) - - num_layers = config.get("num_layers", None) or config.get( - "num_hidden_layers", None) - if num_layers is None: - raise ValueError(f"num_layers<{num_layers}> is invalid") - - remove_tail_layer = config.get("remove_tail_layer") - if remove_tail_layer is True: - num_layers -= 1 - elif isinstance(remove_tail_layer, int): - num_layers -= remove_tail_layer - - moe_num_experts = config.get("moe_num_experts", 0) - if isinstance(moe_num_experts, list): - moe_num_experts = max(moe_num_experts) - use_moe = moe_num_experts > 0 - - context = contextlib.nullcontext() - - if config["hidden_act"].lower() == "swiglu": - model_config.hidden_act = "swiglu" - model_config.ffn_hidden_size = ffn_hidden_size - model_config.max_seq_len = max_model_len - model_config.num_layers = num_layers - model_config.dtype = dtype - parallel_config.block_size = block_size - - parallel_config.msg_queue_id = None - model_config.num_key_value_heads = num_key_value_heads - model_config.return_all_hidden_states = False - speculative_config.draft_type = "None" - model_config.start_layer_index = 0 - if use_moe: - moe_config.num_experts = config.get("moe_num_experts", None) - moe_config.moe_intermediate_size = config.get("moe_intermediate_size", - None) - moe_config.top_k = config.get("moe_topk", 8) - moe_config.moe_num_shared_experts = config.get( - "moe_num_shared_experts", 0) - moe_config.moe_layer_start_index = config.get("moe_layer_start_index", - None) - moe_config.moe_layer_end_index = config.get("moe_layer_end_index", - None) - - model_config.moe_phase = MoEPhase.PREFILL - model_config.ori_vocab_size = ori_vocab_size - - quantization_config = config.get("quantization_config", None) - - quant_config_name = None - if quantization_config is not None and quantization_config.get( - "quantization", None) is None: - raise ValueError( - "quantization_config should have a key named 'quantization' for specify quant config." - ) - - if quantization_config is not None: - quant_config_name = quantization_config["quantization"] - quant_cls = get_quantization_config(quant_config_name) - quant_config = quant_cls.from_config(quantization_config) - elif quantization != "None": - quantization_config = {} - if use_moe and quantization == "wint4": - quantization_config["dense_quant_type"] = "wint8" - quantization_config["moe_quant_type"] = "wint4" - quant_config_name = "mix_quant" - else: - quant_config_name = quantization - quant_cls = get_quantization_config(quant_config_name) - quant_config = quant_cls.from_config(quantization_config) - else: - quant_config = None - - logger.info("===========quantization_config==============") - if quant_config is not None: - logger.info(f"{quantization_config}") - else: - logger.info( - "No quantization config found and use original weight and act dtype." 
- ) - logger.info("============================================") - - fd_config = FDConfig( - model_config=model_config, - parallel_config=parallel_config, - speculative_config=speculative_config, - device_config=device_config, - load_config=load_config, - moe_config=moe_config, - quant_config=quant_config, - kv_cache_config=kv_cache_config, - graph_opt_config=graph_opt_config, - ) - fd_config.parallel_config.max_model_len = max_model_len - fd_config.model_config.rope_theta = rope_theta - - with context: - model_cls = ModelRegistry.get_class(model_config.architectures[0]) - model = model_cls(fd_config) - - model.eval() - return fd_config, model From 3e8cd8511a1619ac306b8c649fa32f1664159463 Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Wed, 16 Jul 2025 13:46:42 +0000 Subject: [PATCH 07/16] fix --- fastdeploy/worker/xpu_model_runner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 0c6564afbc..23cb4edbb0 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -518,12 +518,16 @@ def _prepare_inputs(self) -> None: self.forward_meta.attn_backend = self.attn_backends[0] self.initialize_attention_backend() + num_reqs = int((self.share_inputs["seq_lens_this_time"] > 0).sum()) + min_p_slice = self.share_inputs["min_p"][:num_reqs] + no_min_p = paddle.all(min_p_slice == 0.0).item() + # Get sampling metadata self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], - min_p=self.share_inputs["min_p"], + min_p=None if no_min_p else self.share_inputs["min_p"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], From 4618992f12c429eb72c86ffd437f7ee22cc4cd3b Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Fri, 18 Jul 2025 04:10:28 +0000 Subject: [PATCH 08/16] Align usage of min_p with vLLM --- .../min_p_sampling_from_probs.cu | 31 +++++----- .../gpu_ops/sample_kernels/sampling.cuh | 57 +++---------------- .../model_executor/layers/sample/meta_data.py | 4 +- .../layers/sample/ops/top_k_top_p_sampling.py | 34 +++++------ .../model_executor/layers/sample/sampler.py | 15 ++--- fastdeploy/worker/gpu_model_runner.py | 10 +--- test/layers/test_min_p.py | 3 +- 7 files changed, 50 insertions(+), 104 deletions(-) diff --git a/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu b/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu index e8944421dd..c44c16b430 100644 --- a/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu +++ b/custom_ops/gpu_ops/sample_kernels/min_p_sampling_from_probs.cu @@ -17,48 +17,49 @@ #include "sample_kernels/sampling.cuh" std::vector MinPSamplingFromProbs(const paddle::Tensor &probs, - const paddle::Tensor &min_p, - int seed) { + const paddle::Tensor &min_p) { std::vector probs_shape = probs.shape(); unsigned int batch_size = probs_shape[0]; unsigned int vocab_size = probs_shape[1]; - uint64_t philox_seed = seed; - uint64_t philox_offset = 0; auto cu_stream = probs.stream(); - auto samples = - paddle::empty({batch_size, 1}, paddle::DataType::INT64, probs.place()); + auto renorm_probs = + GetEmptyTensor(probs.dims(), paddle::DataType::FLOAT32, probs.place()); cudaError_t status; - status = sampling::MinPSamplingFromProb( - const_cast(probs.data()),min_p.data(),samples.data(), - 
batch_size,vocab_size,true,philox_seed,philox_offset,cu_stream); + status = sampling::MinPSamplingFromProb( + const_cast(probs.data()), + const_cast(min_p.data()), + renorm_probs.data(), + batch_size, + vocab_size, + true, // deterministic + cu_stream); + PD_CHECK(status == cudaSuccess, "SamplingFromProbs failed with error code " + std::string(cudaGetErrorString(status))); - return {samples}; + return {renorm_probs}; } std::vector> MinPSamplingFromProbsInferShape(const std::vector &probs_shape, const paddle::optional> &min_p_shape) { - int64_t bs = probs_shape[0]; - return {{bs, 1}}; + return {probs_shape}; } std::vector MinPSamplingFromProbsInferDtype(const paddle::DataType &probs_dtype, const paddle::optional &min_p_dtype) { - return {paddle::DataType::INT64}; + return {probs_dtype}; } PD_BUILD_STATIC_OP(min_p_sampling) .Inputs({"probs", "min_p"}) - .Outputs({"samples"}) - .Attrs({"seed: int"}) + .Outputs({"renorm_probs"}) .SetKernelFn(PD_KERNEL(MinPSamplingFromProbs)) .SetInferShapeFn(PD_INFER_SHAPE(MinPSamplingFromProbsInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(MinPSamplingFromProbsInferDtype)); diff --git a/custom_ops/gpu_ops/sample_kernels/sampling.cuh b/custom_ops/gpu_ops/sample_kernels/sampling.cuh index 1ac7d3b670..f14694fa1b 100644 --- a/custom_ops/gpu_ops/sample_kernels/sampling.cuh +++ b/custom_ops/gpu_ops/sample_kernels/sampling.cuh @@ -560,14 +560,11 @@ struct RenormTempStorage { template -__global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, - IdType* output,uint32_t d, - uint64_t philox_seed, uint64_t philox_offset) { + typename DType,typename IdType> +__global__ void MinPSamplingFromProbKernel(DType* probs, const float* min_p_arr, + DType* renormed_prob,uint32_t d) { const uint32_t bx = blockIdx.x, tx = threadIdx.x; float p = (min_p_arr == nullptr) ? 0 : min_p_arr[bx]; - curandStatePhilox4_32_10_t state; - curand_init(philox_seed, bx, philox_offset, &state); const uint32_t row_idx = bx; extern __shared__ __align__( @@ -583,7 +580,6 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, float pivot = max_val * p; vec_t probs_vec; - float aggregate_gt_pivot = 0; #pragma unroll 2 for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { probs_vec.fill(0); @@ -591,49 +587,15 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); } - float probs_gt_pivot[VEC_SIZE]; #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { - probs_gt_pivot[j] = (probs_vec[j] >= pivot) ? probs_vec[j] : 0; - } - - aggregate_gt_pivot += BlockReduce(temp_storage.block_prim.reduce) - .Sum(probs_gt_pivot); - if (tx == 0) { - temp_storage.block_aggregate.value = aggregate_gt_pivot; + probs_vec[j] = (probs_vec[j] >= pivot) ? 
probs_vec[j] : 0; } - __syncthreads(); - } - - float aggregate = 0; - float q = temp_storage.block_aggregate.value; - - int sampled_id; - temp_storage.sampled_id = d; - __syncthreads(); - float u = curand_uniform(&state) * q; -#pragma unroll 2 - for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) { - probs_vec.fill(0); if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { - probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); + probs_vec.store(renormed_prob + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE); } - DeviceSamplingFromProb( - i, d, [&](float x) { return x >= pivot; }, u, probs_vec, aggregate, &temp_storage); - if (aggregate > u) { - break; - } - } - sampled_id = temp_storage.sampled_id; - if (sampled_id == d) { - // NOTE(Zihao): this would happen when u is very close to 1 - // and the sum of probabilities is smaller than u - // In this case, we use the last valid index as the sampled id - sampled_id = temp_storage.last_valid_id; } - output[bx] = sampled_id; } @@ -789,11 +751,10 @@ cudaError_t TopPSamplingFromProb(T *probs, IdType *output, return cudaSuccess; } -template -cudaError_t MinPSamplingFromProb(T *probs, const T* min_p_arr,IdType *output, +template +cudaError_t MinPSamplingFromProb(T *probs, const T* min_p_arr,T *renormed_prob, uint32_t batch_size, uint32_t d, bool deterministic, - uint64_t philox_seed, uint64_t philox_offset, cudaStream_t stream = 0){ constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); @@ -801,13 +762,13 @@ cudaError_t MinPSamplingFromProb(T *probs, const T* min_p_arr,IdType *output, const uint32_t smem_size = sizeof(SamplingTempStorage); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &min_p_arr,&output,&d,&philox_seed,&philox_offset}; + void* args[] = {&probs, &min_p_arr,&renormed_prob,&d}; DISPATCH_ALIGNED_VEC_SIZE( vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { auto kernel = MinPSamplingFromProbKernel; + VEC_SIZE, DETERMINISTIC, T,IdType>; CUDA_CALL(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); CUDA_CALL(cudaLaunchKernel((void *)kernel, nblks, nthrs, args, diff --git a/fastdeploy/model_executor/layers/sample/meta_data.py b/fastdeploy/model_executor/layers/sample/meta_data.py index 9fd69f608d..83090ba976 100644 --- a/fastdeploy/model_executor/layers/sample/meta_data.py +++ b/fastdeploy/model_executor/layers/sample/meta_data.py @@ -15,7 +15,7 @@ """ from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional import paddle @@ -42,5 +42,5 @@ class SamplingMetadata: top_p: paddle.Tensor top_k: Optional[paddle.Tensor] = None - min_p: Optional[Union[float, paddle.Tensor]] = None + min_p: Optional[paddle.Tensor] = None max_num_logprobs: Optional[int] = None diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py index 0a50b0ed0c..3572f2604a 100644 --- a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py +++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py @@ -60,6 +60,7 @@ def top_k_top_p_sampling( """ top_p_class = envs.FD_SAMPLING_CLASS.lower() + if top_p_class == "air": _, ids = air_top_p_sampling(x, top_p, @@ -137,8 +138,6 @@ def rejection_top_p_sampling( ) else: if order == "top_k_first": - print("走的这里吧?",x) - print("top_k",top_k) renorm_probs = top_k_renorm_probs(x, top_k) ids = 
rejection_top_p_sampling( renorm_probs, @@ -158,28 +157,21 @@ def rejection_top_p_sampling( return ids def min_p_sampling( - logits:paddle.tensor, + probs:paddle.tensor, min_p_arr:Optional[paddle.Tensor], - seed:int=-1 )-> tuple[paddle.Tensor, paddle.Tensor]: """ min_p_sampling """ - _ = None - - if current_platform.is_cuda(): - from fastdeploy.model_executor.ops.gpu import min_p_sampling - ids=min_p_sampling(logits,min_p_arr,seed) - - return ids,_ + if paddle.count_nonzero(min_p_arr)==0: + return probs else: - probability_values= paddle.nn.functional.softmax(logits,axis=-1) - max_probabilities = paddle.amax(probability_values, - axis=-1, - keepdim=True) - adjusted_min_p = max_probabilities * min_p_arr - invalid_token_mask = probability_values < adjusted_min_p - logits = paddle.where(invalid_token_mask, - paddle.full_like(logits, -float('inf')), - logits) - return _,logits + if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import min_p_sampling + probs=min_p_sampling(probs,min_p_arr) + else: + max_probabilities = paddle.amax(probs,axis=-1,keepdim=True) + adjusted_min_p = max_probabilities * min_p_arr + invalid_token_mask = probs < adjusted_min_p + probs= paddle.where(invalid_token_mask,paddle.full_like(probs,0.0),probs) + return probs diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index f8489b214c..44d20f7257 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -176,6 +176,7 @@ def __init__(self): self.forward = self.forward_cuda else: raise NotImplementedError() + self.step=0 self.processor = SamplerProcessor() @@ -265,16 +266,11 @@ def forward_cuda( sampling_metadata.eos_token_ids, ) - # print("sampling_metadata.min_p",sampling_metadata.min_p) probs = F.softmax(logits) - if sampling_metadata.min_p is not None: - next_tokens,probs= min_p_sampling(probs,sampling_metadata.min_p) - if next_tokens is not None: - pass - else: - _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k) - else: - _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k) + + probs= min_p_sampling(probs,sampling_metadata.min_p) + + _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k) logprobs_tensors = None if num_logprobs is None else \ self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens) @@ -288,6 +284,7 @@ def forward_cuda( sampled_token_ids=next_tokens, logprobs_tensors=logprobs_tensors, ) + self.step+=1 return sampler_output diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index fb0459636f..2ca0274a17 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -248,7 +248,7 @@ def insert_prefill_inputs(self, req_dicts: List[Request]): self.share_inputs["input_ids"][idx:idx + 1, :length] = np.array( request.prompt_token_ids) - + print("self.share_inputs['input_ids']",self.share_inputs["input_ids"]) # Use chunked prefill if self.parallel_config.enable_chunked_prefill: request.set("chunk_idx", 1) @@ -341,7 +341,7 @@ def get_attr_from_request(request, attr, default_value=None): self.share_inputs["top_p"][idx:idx + 1] = get_attr_from_request(request, "top_p", 0.7) self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0) self.share_inputs["min_p"][idx:idx + 1] = request.get("min_p",0.0) - + self.share_inputs["temperature"][idx:idx + 
1] = get_attr_from_request(request,"temperature", 0.95) self.share_inputs["penalty_score"][idx:idx + 1] = get_attr_from_request( request, "repetition_penalty", 1.0) @@ -698,16 +698,12 @@ def _prepare_inputs(self) -> None: # Initialize forward meta data self.initialize_forward_meta() - num_reqs = int((self.share_inputs["seq_lens_this_time"] > 0).sum()) - min_p_slice = self.share_inputs["min_p"][:num_reqs] - no_min_p = paddle.all(min_p_slice == 0.0).item() - # Get sampling metadata self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], - min_p=None if no_min_p else self.share_inputs["min_p"], + min_p=self.share_inputs["min_p"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], diff --git a/test/layers/test_min_p.py b/test/layers/test_min_p.py index 15f5b508aa..58854c6a1c 100644 --- a/test/layers/test_min_p.py +++ b/test/layers/test_min_p.py @@ -17,7 +17,6 @@ import numpy as np import paddle import paddle.nn.functional as F -import torch from tqdm import tqdm from fastdeploy.model_executor.ops.gpu import min_p_sampling @@ -199,5 +198,5 @@ def main(): print(f"FastDeploy - Sampled: {data2_fd_batch[b]}, Normalized: {data3_fd_batch[b]}") if __name__ == "__main__": - if paddle.device.is_compiled_with_cuda() and torch.cuda.is_available(): + if paddle.device.is_compiled_with_cuda(): main() From 12ebeb24b1b101d4fc167774807b324d6be9ff19 Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Fri, 18 Jul 2025 04:14:27 +0000 Subject: [PATCH 09/16] fix --- fastdeploy/worker/gpu_model_runner.py | 1 - fastdeploy/worker/xpu_model_runner.py | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 2ca0274a17..4567f11cf7 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -248,7 +248,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request]): self.share_inputs["input_ids"][idx:idx + 1, :length] = np.array( request.prompt_token_ids) - print("self.share_inputs['input_ids']",self.share_inputs["input_ids"]) # Use chunked prefill if self.parallel_config.enable_chunked_prefill: request.set("chunk_idx", 1) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 23cb4edbb0..c334fe634c 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -518,16 +518,13 @@ def _prepare_inputs(self) -> None: self.forward_meta.attn_backend = self.attn_backends[0] self.initialize_attention_backend() - num_reqs = int((self.share_inputs["seq_lens_this_time"] > 0).sum()) - min_p_slice = self.share_inputs["min_p"][:num_reqs] - no_min_p = paddle.all(min_p_slice == 0.0).item() # Get sampling metadata self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], top_p=self.share_inputs["top_p"], top_k=self.share_inputs["top_k"], - min_p=None if no_min_p else self.share_inputs["min_p"], + min_p=self.share_inputs["min_p"], step_idx=self.share_inputs["step_idx"], pre_token_ids=self.share_inputs["pre_ids"], frequency_penalties=self.share_inputs["frequency_score"], From 5b9ffc522ac65417f2ae76c792bef8117622e7e3 Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Fri, 18 Jul 2025 07:43:53 +0000 Subject: [PATCH 10/16] modified unit test --- 
.../layers/sample/ops/top_k_top_p_sampling.py | 4 +- .../model_executor/layers/sample/sampler.py | 3 +- test/layers/test_min_p.py | 188 +++++------------- 3 files changed, 49 insertions(+), 146 deletions(-) diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py index 3572f2604a..8908c862be 100644 --- a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py +++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py @@ -163,7 +163,7 @@ def min_p_sampling( """ min_p_sampling """ - if paddle.count_nonzero(min_p_arr)==0: + if paddle.count_nonzero(min_p_arr) == 0: return probs else: if current_platform.is_cuda(): @@ -172,6 +172,6 @@ def min_p_sampling( else: max_probabilities = paddle.amax(probs,axis=-1,keepdim=True) adjusted_min_p = max_probabilities * min_p_arr - invalid_token_mask = probs < adjusted_min_p + invalid_token_mask = probs < adjusted_min_p.reshape([-1, 1]) probs= paddle.where(invalid_token_mask,paddle.full_like(probs,0.0),probs) return probs diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 7b1abcf45d..8cd6dcbe99 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -176,7 +176,6 @@ def __init__(self): self.forward = self.forward_cuda else: raise NotImplementedError() - self.step=0 self.processor = SamplerProcessor() @@ -286,7 +285,7 @@ def forward_cuda( sampled_token_ids=next_tokens, logprobs_tensors=logprobs_tensors, ) - self.step+=1 + return sampler_output diff --git a/test/layers/test_min_p.py b/test/layers/test_min_p.py index 58854c6a1c..6b40a7154c 100644 --- a/test/layers/test_min_p.py +++ b/test/layers/test_min_p.py @@ -13,11 +13,9 @@ # limitations under the License. 
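Editorial note on the `reshape([-1, 1])` change above: when `min_p` arrives as a flat `[batch]` tensor, the per-row threshold must be shaped `[batch, 1]` before it is compared against the `[batch, vocab]` probability matrix, otherwise the comparison broadcasts along the wrong axis. A minimal, self-contained sketch of that broadcast follows; shapes and values are illustrative only and are not taken from the patch.

```python
# Minimal sketch of the per-row min_p cutoff used by the CPU fallback.
import paddle

probs = paddle.to_tensor([[0.6, 0.3, 0.1],
                          [0.5, 0.45, 0.05]], dtype="float32")   # [batch=2, vocab=3]
min_p = paddle.to_tensor([0.5, 0.2], dtype="float32")            # [batch]

row_max = paddle.amax(probs, axis=-1, keepdim=True)              # [batch, 1]
threshold = row_max * min_p.reshape([-1, 1])                     # [batch, 1] per-row cutoff
filtered = paddle.where(probs < threshold, paddle.zeros_like(probs), probs)

print(filtered.numpy())
# row 0: cutoff 0.6 * 0.5 = 0.30 -> keeps 0.6 and 0.3, zeroes 0.1
# row 1: cutoff 0.5 * 0.2 = 0.10 -> drops only 0.05
```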
-import matplotlib.pyplot as plt import numpy as np import paddle import paddle.nn.functional as F -from tqdm import tqdm from fastdeploy.model_executor.ops.gpu import min_p_sampling @@ -25,105 +23,38 @@ vocab_size = 1000 min_p_value = 0.5 batch_size = 3 -batch_min_p_values = [0.1, 0.5, 0.9] -batch_min_p_values2=[0,3,0,0,0.4] - - -def compress(data): - new_data = np.array([0, 0, 0], dtype=float) - new_data[0] = data[0] - new_data[1] = data[1] - new_data[2] = np.sum(data[2:]) - return new_data - - -def plot_bar_chart(data1, data2, data3, title, request_idx=None): - plt.figure(figsize=(6, 6)) - bar_width = 0.2 - idx = np.arange(len(data1)).astype(float) - - bars1 = plt.bar(idx - bar_width, data1, width=bar_width, color='salmon', label='Original Probability', alpha=0.9) - bars2 = plt.bar(idx, data2, width=bar_width, color='skyblue', label='Sampled Probability', alpha=0.9) - bars3 = plt.bar(idx + bar_width, data3, width=bar_width, color='orange', label='Normalized Original Probability', alpha=0.9) - - plt.bar_label(bars1, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='black') - plt.bar_label(bars2, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='red') - plt.bar_label(bars3, label_type='edge', padding=3, fmt='%.3f', fontsize=5, color='blue') - - full_title = title if request_idx is None else f"{title} (min_p={batch_min_p_values[request_idx]})" - plt.title(full_title, fontsize=14) - plt.xlabel("Index", fontsize=12) - plt.ylabel("Probability", fontsize=12) - plt.ylim(0, 1.1) - plt.xlim(-1, 3) - plt.xticks(range(0, 3, 1)) - plt.legend(fontsize=10) - plt.grid(axis='y', linestyle='--', alpha=0.5) - output_path = f"{title.replace(' ', '_')}{'' if request_idx is None else f'_req{request_idx}'}.png" - plt.savefig(output_path, dpi=300, bbox_inches='tight') - plt.clf() - -def plot_low_prob_curve(low_prob_token_probs, sample_time, title, request_idx=None): - plt.figure(figsize=(6, 6)) - plt.plot(np.arange(0, sample_time), low_prob_token_probs, marker='', linestyle='-', linewidth=1, color='blue') - plt.xlabel('Sample Times') - plt.ylabel('Probability') - full_title = 'Probability of Low-Probability Tokens' if request_idx is None else f"Low-Probability Tokens (min_p={batch_min_p_values[request_idx]})" - plt.title(full_title) - plt.grid(alpha=0.3) - output_path = f"{title.replace(' ', '_')}_low_prob{'' if request_idx is None else f'_req{request_idx}'}.png" - plt.savefig(output_path, dpi=300, bbox_inches='tight') - plt.clf() +batch_min_p_values = [0.1, 0.0, 0.9] + # min_p:0.5:FastDeploy -def fastdeploy_min_p_sampling(): +def min_p_sampling_cpu(min_p): logits = paddle.ones(shape=[1, vocab_size], dtype="float32") logits[0][0] = 10 logits[0][1] = 8 low_prob_tensor = paddle.linspace(2.0, 0.0, vocab_size - 2) logits[0][2:] = low_prob_tensor - probs = F.softmax(logits) - min_p = paddle.to_tensor([min_p_value], dtype="float32") - - max_prob = probs.max().item() - threshold = max_prob * min_p.item() - allowed_tokens = paddle.where(probs[0] >= threshold)[0].numpy() - - sample_freq = [0] * vocab_size - low_prob_token_times = 0 - low_prob_token_probs = [] - - for i in tqdm(range(sample_time), desc="FastDeploy Sampling"): - ids = min_p_sampling(probs, min_p, seed=-1) - sample_freq[ids.item()] += 1 - if ids.item() >= 2: - low_prob_token_times += 1 - low_prob_token_probs.append(low_prob_token_times / (i + 1)) - - sample_freq = np.array(sample_freq, dtype=float) / sample_time - low_prob_token_probs = np.array(low_prob_token_probs, dtype=float) + probs=F.softmax(logits) + max_probabilities = 
paddle.amax(probs, axis=-1, keepdim=True) + adjusted_min_p = max_probabilities * min_p.reshape([-1, 1]) + invalid_token_mask = probs < adjusted_min_p + probs = paddle.where(invalid_token_mask,paddle.full_like(probs,0.0), probs) + return probs - ori_data1 = probs.numpy().reshape(-1) - data1 = compress(ori_data1) - data2 = compress(sample_freq) - - allowed_probs = probs[0, allowed_tokens].numpy() - norm_scale = np.sum(allowed_probs) - data3 = np.zeros_like(data1) - for idx in allowed_tokens: - if idx < 2: - data3[idx] = ori_data1[idx] / norm_scale - else: - data3[2] += ori_data1[idx] / norm_scale - - plot_bar_chart(data1, data2, data3, "FastDeploy[min_p_sampling]") - plot_low_prob_curve(low_prob_token_probs, sample_time, "FastDeploy[min_p_sampling]") +# min_p:0.5:FastDeploy +def fastdeploy_min_p_sampling(min_p): + logits = paddle.ones(shape=[1, vocab_size], dtype="float32") + logits[0][0] = 10 + logits[0][1] = 8 + low_prob_tensor = paddle.linspace(2.0, 0.0, vocab_size - 2) + logits[0][2:] = low_prob_tensor - return data2, data3 + probs = F.softmax(logits) + probs= min_p_sampling(probs, min_p) + return probs -# batch:[0.1.0,5,0.9]:FastDeploy +# batch:[0.1.0.0,0.9]:FastDeploy def fastdeploy_batch_min_p_sampling(batch_size, min_p_values): logits = paddle.ones(shape=[batch_size, vocab_size], dtype="float32") for b in range(batch_size): @@ -134,68 +65,41 @@ def fastdeploy_batch_min_p_sampling(batch_size, min_p_values): probs = F.softmax(logits, axis=-1) min_p_arr = paddle.to_tensor(min_p_values, dtype="float32") - allowed_tokens_list = [] - for b in range(batch_size): - max_prob = probs[b].max().item() - threshold = max_prob * min_p_values[b] - allowed_tokens = paddle.where(probs[b] >= threshold)[0].numpy() - allowed_tokens_list.append(allowed_tokens) - - sample_freq = [np.zeros(vocab_size, dtype=float) for _ in range(batch_size)] - low_prob_token_times = [0] * batch_size - low_prob_token_probs = [[] for _ in range(batch_size)] - - for i in tqdm(range(sample_time), desc="FastDeploy Batch Sampling"): - ids = min_p_sampling(probs, min_p_arr, seed=-1) - for b in range(batch_size): - sample_freq[b][ids[b].item()] += 1 - if ids[b].item() >= 2: - low_prob_token_times[b] += 1 - low_prob_token_probs[b].append(low_prob_token_times[b] / (i + 1)) - - data2_list = [] - data3_list = [] - for b in range(batch_size): - sample_freq_b = sample_freq[b] / sample_time - low_prob_token_probs[b] = np.array(low_prob_token_probs[b], dtype=float) - - ori_data1 = probs[b].numpy() - data1 = compress(ori_data1) - data2 = compress(sample_freq_b) - data2_list.append(data2) + probs = min_p_sampling(probs, min_p_arr) - allowed_probs = probs[b, allowed_tokens_list[b]].numpy() - norm_scale = np.sum(allowed_probs) - data3 = np.zeros_like(data1) - for idx in allowed_tokens_list[b]: - if idx < 2: - data3[idx] = ori_data1[idx] / norm_scale - else: - data3[2] += ori_data1[idx] / norm_scale - data3_list.append(data3) + return probs - plot_bar_chart(data1, data2, data3, "FastDeploy[min_p_batch_sampling]", b) - plot_low_prob_curve(low_prob_token_probs[b], sample_time, "FastDeploy[min_p_batch_sampling]", b) +def compare_results(probs,probs_cpu,atol=1e-6,rtol=1e-6): + probs_np = probs.numpy() + probs_cpu_np = probs_cpu.numpy() + try: + np.testing.assert_allclose( + probs_np, + probs_cpu_np, + rtol=rtol, + atol=atol, + ) + print("The results are same between fastdeploy_min_p_sampling and min_p_sampling_cpu") + except AssertionError as e: + raise AssertionError( + f"The results are different between fastdeploy_min_p_sampling and 
min_p_sampling_cpu:\n{str(e)}") - return data2_list, data3_list def main(): + # min_p:0.5:FastDeploy + min_p = paddle.to_tensor([min_p_value],dtype="float32") print("Running single min_p sampling (min_p=0.5)...") - data2_fastdeploy, data3_fastdeploy = fastdeploy_min_p_sampling() - - print("\nFastDeploy Single Request Results:") - print(f"Sampled Probability: {data2_fastdeploy}") - print(f"Theoretical Normalized Probability: {data3_fastdeploy}") + probs = fastdeploy_min_p_sampling(min_p) + probs_cpu = min_p_sampling_cpu(min_p) + compare_results(probs,probs_cpu) - print("\nRunning batch min_p sampling (min_p=[0.1, 0.5, 0.9])...") - data2_fd_batch, data3_fd_batch = fastdeploy_batch_min_p_sampling(batch_size, batch_min_p_values) + # batch:[0.1.0.0,0.9]:FastDeploy + batch_min_p = paddle.to_tensor(batch_min_p_values,dtype="float32") + batch_probs = fastdeploy_batch_min_p_sampling(batch_size,batch_min_p) + batch_probs_cpu = min_p_sampling_cpu(batch_min_p) + compare_results(batch_probs,batch_probs_cpu) - data2_fd_batch,data3_fd_batch = fastdeploy_batch_min_p_sampling(batch_size,batch_min_p_values2) - - for b in range(batch_size): - print(f"\nBatch Request {b} (min_p={batch_min_p_values[b]}):") - print(f"FastDeploy - Sampled: {data2_fd_batch[b]}, Normalized: {data3_fd_batch[b]}") if __name__ == "__main__": if paddle.device.is_compiled_with_cuda(): From 5c335b18e9356bcfdeba2eaf55376701796607c3 Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Fri, 18 Jul 2025 09:50:15 +0000 Subject: [PATCH 11/16] fix test_min_sampling --- docs/zh/offline_inference.md | 2 +- fastdeploy/engine/sampling_params.py | 6 +- .../model_executor/layers/sample/sampler.py | 2 +- fastdeploy/worker/gpu_model_runner.py | 2 +- test/layers/test_min_p.py | 106 ---------------- test/layers/test_min_sampling.py | 113 ++++++++++++++++++ 6 files changed, 119 insertions(+), 112 deletions(-) delete mode 100644 test/layers/test_min_p.py create mode 100644 test/layers/test_min_sampling.py diff --git a/docs/zh/offline_inference.md b/docs/zh/offline_inference.md index 04760e45f7..015fc7b720 100644 --- a/docs/zh/offline_inference.md +++ b/docs/zh/offline_inference.md @@ -180,7 +180,7 @@ for output in outputs: * temperature(float): 控制生成随机性的参数,值越高结果越随机,值越低结果越确定 * top_p(float): 概率累积分布截断阈值,仅考虑累计概率达到此阈值的最可能token集合 * top_k(int): 采样概率最高的token数量,考虑概率最高的k个token进行采样 -* min_p(float): token入选的最小概率阈值(相对于最高概率token的比值,设为>0可通过过滤低概率token来提升文本生成质量) +* min_p(float): token入选的最小概率阈值(相对于最高概率token的比值,设为>0可通过过滤低概率token来提升文本生成质量) * max_tokens(int): 限制模型生成的最大token数量(包括输入和输出) * min_tokens(int): 强制模型生成的最少token数量,避免过早结束 diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index eeacc33fc8..f2f58938af 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -53,7 +53,7 @@ class SamplingParams: top_p: Float that controls the cumulative probability of the top tokens to consider. Must be in [0, 1]. Set to 1 to consider all tokens. top_k: Int that controls the number of top tokens to consider. Must be a positive integer. - min_p:Float that represents the minimum probability for a token to be + min_p: Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this. seed: Random seed to use for the generation. 
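For readers skimming the docstring above, a hedged usage sketch of the new knob follows. Only `SamplingParams` itself comes from the file this hunk touches (`fastdeploy/engine/sampling_params.py`); how the params object is handed to an engine or LLM instance is covered by the offline-inference docs and is omitted here.

```python
# Minimal sketch: min_p rides alongside temperature / top_p / top_k.
# min_p defaults to 0.0 (filter disabled) and must stay within [0, 1].
from fastdeploy.engine.sampling_params import SamplingParams

params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    top_k=20,
    min_p=0.1,      # drop tokens whose probability is below 10% of the top token's
    max_tokens=128,
)
print(params.min_p)  # 0.1
```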
@@ -87,7 +87,7 @@ class SamplingParams: temperature: float = None top_p: float = None top_k: int = 0 - min_p: float=0.0 + min_p: float = 0.0 seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = None stop_token_ids: Optional[Union[List[List[int]], List[int]]] = None @@ -186,7 +186,7 @@ def _verify_args(self) -> None: if not isinstance(self.top_k, int): raise TypeError( f"top_k must be an integer, got {type(self.top_k).__name__}") - if not 0.0 <=self.min_p <= 1.0: + if not 0.0 <= self.min_p <= 1.0: raise ValueError("min_p must be in [0,1],got f{self.min_p}") if self.max_tokens is not None and self.max_tokens < 1: diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 8cd6dcbe99..209c680a12 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -269,7 +269,7 @@ def forward_cuda( probs = F.softmax(logits) - probs= min_p_sampling(probs,sampling_metadata.min_p) + probs = min_p_sampling(probs,sampling_metadata.min_p) _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 57a7634409..c0b21fa0ec 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -337,7 +337,7 @@ def get_attr_from_request(request, attr, default_value=None): request.eos_token_ids, dtype="int64").reshape(-1, 1) self.share_inputs["top_p"][idx:idx + 1] = get_attr_from_request(request, "top_p", 0.7) self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0) - self.share_inputs["min_p"][idx:idx + 1] = request.get("min_p",0.0) + self.share_inputs["min_p"][idx:idx + 1] = request.get("min_p", 0.0) self.share_inputs["temperature"][idx:idx + 1] = get_attr_from_request(request,"temperature", 0.95) self.share_inputs["penalty_score"][idx:idx + 1] = get_attr_from_request( diff --git a/test/layers/test_min_p.py b/test/layers/test_min_p.py deleted file mode 100644 index 6b40a7154c..0000000000 --- a/test/layers/test_min_p.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import numpy as np -import paddle -import paddle.nn.functional as F - -from fastdeploy.model_executor.ops.gpu import min_p_sampling - -sample_time = 1000000 -vocab_size = 1000 -min_p_value = 0.5 -batch_size = 3 -batch_min_p_values = [0.1, 0.0, 0.9] - - -# min_p:0.5:FastDeploy -def min_p_sampling_cpu(min_p): - logits = paddle.ones(shape=[1, vocab_size], dtype="float32") - logits[0][0] = 10 - logits[0][1] = 8 - low_prob_tensor = paddle.linspace(2.0, 0.0, vocab_size - 2) - logits[0][2:] = low_prob_tensor - - probs=F.softmax(logits) - max_probabilities = paddle.amax(probs, axis=-1, keepdim=True) - adjusted_min_p = max_probabilities * min_p.reshape([-1, 1]) - invalid_token_mask = probs < adjusted_min_p - probs = paddle.where(invalid_token_mask,paddle.full_like(probs,0.0), probs) - return probs - -# min_p:0.5:FastDeploy -def fastdeploy_min_p_sampling(min_p): - logits = paddle.ones(shape=[1, vocab_size], dtype="float32") - logits[0][0] = 10 - logits[0][1] = 8 - low_prob_tensor = paddle.linspace(2.0, 0.0, vocab_size - 2) - logits[0][2:] = low_prob_tensor - - probs = F.softmax(logits) - probs= min_p_sampling(probs, min_p) - return probs - - -# batch:[0.1.0.0,0.9]:FastDeploy -def fastdeploy_batch_min_p_sampling(batch_size, min_p_values): - logits = paddle.ones(shape=[batch_size, vocab_size], dtype="float32") - for b in range(batch_size): - logits[b][0] = 10 - logits[b][1] = 8 - logits[b][2:] = paddle.linspace(2.0, 0.0, vocab_size - 2) - - probs = F.softmax(logits, axis=-1) - min_p_arr = paddle.to_tensor(min_p_values, dtype="float32") - - probs = min_p_sampling(probs, min_p_arr) - - return probs - -def compare_results(probs,probs_cpu,atol=1e-6,rtol=1e-6): - probs_np = probs.numpy() - probs_cpu_np = probs_cpu.numpy() - try: - np.testing.assert_allclose( - probs_np, - probs_cpu_np, - rtol=rtol, - atol=atol, - ) - print("The results are same between fastdeploy_min_p_sampling and min_p_sampling_cpu") - except AssertionError as e: - raise AssertionError( - f"The results are different between fastdeploy_min_p_sampling and min_p_sampling_cpu:\n{str(e)}") - - - -def main(): - # min_p:0.5:FastDeploy - min_p = paddle.to_tensor([min_p_value],dtype="float32") - print("Running single min_p sampling (min_p=0.5)...") - probs = fastdeploy_min_p_sampling(min_p) - probs_cpu = min_p_sampling_cpu(min_p) - compare_results(probs,probs_cpu) - - # batch:[0.1.0.0,0.9]:FastDeploy - batch_min_p = paddle.to_tensor(batch_min_p_values,dtype="float32") - batch_probs = fastdeploy_batch_min_p_sampling(batch_size,batch_min_p) - batch_probs_cpu = min_p_sampling_cpu(batch_min_p) - compare_results(batch_probs,batch_probs_cpu) - - -if __name__ == "__main__": - if paddle.device.is_compiled_with_cuda(): - main() diff --git a/test/layers/test_min_sampling.py b/test/layers/test_min_sampling.py new file mode 100644 index 0000000000..d8b7ee50bb --- /dev/null +++ b/test/layers/test_min_sampling.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import unittest + +import numpy as np +import paddle +import paddle.nn.functional as F + +from fastdeploy.model_executor.ops.gpu import min_p_sampling + + +class TestMinPSampling(unittest.TestCase): + def setUp(self): + self.sample_time = 1000000 + self.vocab_size = 1000 + self.min_p_value = 0.5 + self.batch_size = 3 + self.batch_min_p_values = [0.1, 0.0, 0.9] + self.additional_batch_min_p_values = [0.1, 0.0, 0.3] + + + # min_p:0.5:FastDeploy + def min_p_sampling_cpu(self,min_p): + logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32") + logits[0][0] = 10 + logits[0][1] = 8 + low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2) + logits[0][2:] = low_prob_tensor + + probs = F.softmax(logits) + max_probabilities = paddle.amax(probs, axis=-1, keepdim=True) + adjusted_min_p = max_probabilities * min_p.reshape([-1, 1]) + invalid_token_mask = probs < adjusted_min_p + probs = paddle.where(invalid_token_mask,paddle.full_like(probs,0.0), probs) + return probs + + # min_p:0.5:FastDeploy + def fastdeploy_min_p_sampling(self,min_p): + logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32") + logits[0][0] = 10 + logits[0][1] = 8 + low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2) + logits[0][2:] = low_prob_tensor + + probs = F.softmax(logits) + probs = min_p_sampling(probs, min_p) + return probs + + + # batch:[0.1.0.0,0.9]:FastDeploy + def fastdeploy_batch_min_p_sampling(self,batch_size, min_p_values): + logits = paddle.ones(shape=[batch_size, self.vocab_size], dtype="float32") + for b in range(batch_size): + logits[b][0] = 10 + logits[b][1] = 8 + logits[b][2:] = paddle.linspace(2.0, 0.0, self.vocab_size - 2) + + probs = F.softmax(logits, axis=-1) + min_p_arr = paddle.to_tensor(min_p_values, dtype="float32") + + probs = min_p_sampling(probs, min_p_arr) + + return probs + + def compare_results(self, probs, probs_cpu, atol=1e-6, rtol=1e-6): + probs_np = probs.numpy() + probs_cpu_np = probs_cpu.numpy() + try: + np.testing.assert_allclose( + probs_np, + probs_cpu_np, + rtol=rtol, + atol=atol, + ) + print("The results are same between fastdeploy_min_p_sampling and min_p_sampling_cpu") + except AssertionError as e: + raise AssertionError( + f"The results are different between fastdeploy_min_p_sampling and min_p_sampling_cpu:\n{str(e)}") + + def test_single_min_p_sampling(self): + min_p = paddle.to_tensor([self.min_p_value], dtype="float32") + probs = self.fastdeploy_min_p_sampling(min_p) + probs_cpu = self.min_p_sampling_cpu(min_p) + self.compare_results(probs, probs_cpu) + + def test_batch_min_p_sampling(self): + batch_min_p = paddle.to_tensor(self.batch_min_p_values, dtype="float32") + batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, batch_min_p) + batch_probs_cpu = self.min_p_sampling_cpu(batch_min_p) + self.compare_results(batch_probs, batch_probs_cpu) + + def test_additional_batch_min_p_sampling(self): + additional_batch_min_p = paddle.to_tensor(self.additional_batch_min_p_values, dtype="float32") + additional_batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, additional_batch_min_p) + additional_batch_probs_cpu = self.min_p_sampling_cpu(additional_batch_min_p) + self.compare_results(additional_batch_probs, additional_batch_probs_cpu) + +if __name__ == "__main__": + if paddle.is_compiled_with_cuda(): + unittest.main() From 417dea6f52d35f41ca75b0a5ea183fafd4f8571b Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Mon, 21 Jul 2025 03:20:03 +0000 Subject: [PATCH 12/16] pre-commit all files --- 
fastdeploy/engine/sampling_params.py | 80 ++++++------- .../layers/sample/ops/__init__.py | 2 +- .../layers/sample/ops/top_k_top_p_sampling.py | 14 ++- .../model_executor/layers/sample/sampler.py | 10 +- fastdeploy/worker/gpu_model_runner.py | 49 ++++---- fastdeploy/worker/xpu_model_runner.py | 57 ++++----- test/layers/test_min_p.py | 113 ++++++++++++++++++ test/layers/test_min_sampling.py | 14 +-- 8 files changed, 217 insertions(+), 122 deletions(-) create mode 100644 test/layers/test_min_p.py diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index 7e0eb98fdd..564ca266e4 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -108,45 +108,46 @@ def from_dict(cls, req_dict: dict[str, Any]) -> SamplingParams: ) @classmethod - def from_optional(cls, - n, - best_of, - presence_penalty, - frequency_penalty, - repetition_penalty, - temperature, - top_p, - top_k, - min_p, - seed=None, - stop=None, - stop_token_ids=None, - max_tokens=None, - reasoning_max_tokens=None, - min_tokens=1, - logprobs=None, - bad_words=None) -> "SamplingParams": + def from_optional( + cls, + n, + best_of, + presence_penalty, + frequency_penalty, + repetition_penalty, + temperature, + top_p, + top_k, + min_p, + seed=None, + stop=None, + stop_token_ids=None, + max_tokens=None, + reasoning_max_tokens=None, + min_tokens=1, + logprobs=None, + bad_words=None, + ) -> "SamplingParams": """Create instance from command line arguments""" - return cls(n=1 if n is None else n, - best_of=best_of, - presence_penalty=presence_penalty - if presence_penalty is not None else 0.0, - frequency_penalty=frequency_penalty - if frequency_penalty is not None else 0.0, - repetition_penalty=repetition_penalty - if repetition_penalty is not None else 1.0, - temperature=temperature if temperature is not None else 1.0, - top_p=top_p, - top_k=top_k if top_k is not None else 0, - min_p=min_p if min_p is not None else 0.0, - seed=seed, - stop=stop, - stop_token_ids=stop_token_ids, - max_tokens=max_tokens if max_tokens is not None else 8192, - reasoning_max_tokens=reasoning_max_tokens, - min_tokens=min_tokens, - logprobs=logprobs, - bad_words=bad_words) + return cls( + n=1 if n is None else n, + best_of=best_of, + presence_penalty=presence_penalty if presence_penalty is not None else 0.0, + frequency_penalty=frequency_penalty if frequency_penalty is not None else 0.0, + repetition_penalty=repetition_penalty if repetition_penalty is not None else 1.0, + temperature=temperature if temperature is not None else 1.0, + top_p=top_p, + top_k=top_k if top_k is not None else 0, + min_p=min_p if min_p is not None else 0.0, + seed=seed, + stop=stop, + stop_token_ids=stop_token_ids, + max_tokens=max_tokens if max_tokens is not None else 8192, + reasoning_max_tokens=reasoning_max_tokens, + min_tokens=min_tokens, + logprobs=logprobs, + bad_words=bad_words, + ) def __post_init__(self): if self.seed is None: @@ -174,8 +175,7 @@ def _verify_args(self) -> None: if self.top_k < -1: raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.") if not isinstance(self.top_k, int): - raise TypeError( - f"top_k must be an integer, got {type(self.top_k).__name__}") + raise TypeError(f"top_k must be an integer, got {type(self.top_k).__name__}") if not 0.0 <= self.min_p <= 1.0: raise ValueError("min_p must be in [0,1],got f{self.min_p}") diff --git a/fastdeploy/model_executor/layers/sample/ops/__init__.py b/fastdeploy/model_executor/layers/sample/ops/__init__.py index 
142566baab..09834b305a 100644 --- a/fastdeploy/model_executor/layers/sample/ops/__init__.py +++ b/fastdeploy/model_executor/layers/sample/ops/__init__.py @@ -18,7 +18,7 @@ apply_penalty_multi_scores, apply_speculative_penalty_multi_scores, ) -from .top_k_top_p_sampling import top_k_top_p_sampling,min_p_sampling +from .top_k_top_p_sampling import min_p_sampling, top_k_top_p_sampling __all__ = [ "apply_penalty_multi_scores", diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py index a03a5a91a2..0edd4f42f9 100644 --- a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py +++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py @@ -156,10 +156,11 @@ def rejection_top_p_sampling( raise RuntimeError("Cannot import rejection_top_p_sampling op.") return ids + def min_p_sampling( - probs:paddle.tensor, - min_p_arr:Optional[paddle.Tensor], -)-> tuple[paddle.Tensor, paddle.Tensor]: + probs: paddle.tensor, + min_p_arr: Optional[paddle.Tensor], +) -> tuple[paddle.Tensor, paddle.Tensor]: """ min_p_sampling """ @@ -168,10 +169,11 @@ def min_p_sampling( else: if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import min_p_sampling - probs=min_p_sampling(probs,min_p_arr) + + probs = min_p_sampling(probs, min_p_arr) else: - max_probabilities = paddle.amax(probs,axis=-1,keepdim=True) + max_probabilities = paddle.amax(probs, axis=-1, keepdim=True) adjusted_min_p = max_probabilities * min_p_arr invalid_token_mask = probs < adjusted_min_p.reshape([-1, 1]) - probs= paddle.where(invalid_token_mask,paddle.full_like(probs,0.0),probs) + probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs) return probs diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index b045f4e899..ff32b154ae 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -28,8 +28,11 @@ ) from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata from fastdeploy.model_executor.layers.sample.ops import ( - apply_penalty_multi_scores, apply_speculative_penalty_multi_scores, - min_p_sampling, top_k_top_p_sampling) + apply_penalty_multi_scores, + apply_speculative_penalty_multi_scores, + min_p_sampling, + top_k_top_p_sampling, +) from fastdeploy.platforms import current_platform from fastdeploy.worker.output import LogprobsTensors, SamplerOutput @@ -247,7 +250,6 @@ def forward_cuda( logits = self.processor.apply_token_mask(logits, skip_idx_list) - logits = apply_penalty_multi_scores( sampling_metadata.pre_token_ids, sampling_metadata.prompt_ids, @@ -265,7 +267,7 @@ def forward_cuda( probs = F.softmax(logits) - probs = min_p_sampling(probs,sampling_metadata.min_p) + probs = min_p_sampling(probs, sampling_metadata.min_p) _, next_tokens = top_k_top_p_sampling(probs, sampling_metadata.top_p, sampling_metadata.top_k) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 0cf54438b3..a86eed570a 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -109,7 +109,6 @@ def __init__( else: self.sampler = SpeculativeSampler(fd_config) - # Lazy initialize kv cache after model loading # self.kv_caches: list[paddle.Tensor] = [] @@ -318,19 +317,21 @@ def get_attr_from_request(request, attr, default_value=None): if len(request.eos_token_ids) < 
self.parallel_config.eos_tokens_lens: request.eos_token_ids.append(request.eos_token_ids[0]) - self.share_inputs["eos_token_id"][:] = np.array( - request.eos_token_ids, dtype="int64").reshape(-1, 1) - self.share_inputs["top_p"][idx:idx + 1] = get_attr_from_request(request, "top_p", 0.7) - self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0) - self.share_inputs["min_p"][idx:idx + 1] = request.get("min_p", 0.0) - - self.share_inputs["temperature"][idx:idx + 1] = get_attr_from_request(request,"temperature", 0.95) - self.share_inputs["penalty_score"][idx:idx + 1] = get_attr_from_request( - request, "repetition_penalty", 1.0) - self.share_inputs["frequency_score"][idx:idx + 1] = get_attr_from_request( - request, "frequency_penalty", 0.0) - self.share_inputs["presence_score"][idx:idx + 1] = get_attr_from_request( - request, "presence_penalty", 0.0) + self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) + self.share_inputs["top_p"][idx : idx + 1] = get_attr_from_request(request, "top_p", 0.7) + self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0) + self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0) + + self.share_inputs["temperature"][idx : idx + 1] = get_attr_from_request(request, "temperature", 0.95) + self.share_inputs["penalty_score"][idx : idx + 1] = get_attr_from_request( + request, "repetition_penalty", 1.0 + ) + self.share_inputs["frequency_score"][idx : idx + 1] = get_attr_from_request( + request, "frequency_penalty", 0.0 + ) + self.share_inputs["presence_score"][idx : idx + 1] = get_attr_from_request( + request, "presence_penalty", 0.0 + ) self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( @@ -424,20 +425,12 @@ def _init_share_inputs(self, max_num_seqs: int): dtype="int64", ) self.share_inputs["prompt_ids"] = paddle.full( - [max_num_seqs, self.parallel_config.max_model_len], - self.parallel_config.pad_token_id, - dtype='int64') - self.share_inputs["eos_token_id"] = paddle.full( - [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64') - self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], - self.model_config.top_p, - dtype='float32') - self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], - 0, - dtype='int64') - self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], - 0.0, - dtype='float32') + [max_num_seqs, self.parallel_config.max_model_len], self.parallel_config.pad_token_id, dtype="int64" + ) + self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64") + self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32") + self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") + self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32") self.share_inputs["temperature"] = paddle.full( [max_num_seqs, 1], self.model_config.temperature, dtype="float32" ) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index bb906f06d9..bc2c7a88c2 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -300,27 +300,21 @@ def process_prefill_inputs(self, req_dicts: List[Request]): self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(request.prompt_token_ids) if len(request.eos_token_ids) < self.parallel_config.eos_tokens_lens: 
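A side note on the pattern visible in this hunk and in the GPU-runner hunk above, shown as a simplified sketch rather than the runner code itself: sampling knobs such as `min_p` live in pre-allocated `[max_num_seqs, 1]` buffers, each request writes only its own row, and untouched rows keep the 0.0 default, so the min_p path stays a no-op for requests that never set it.

```python
import paddle

max_num_seqs = 4
min_p_buf = paddle.full([max_num_seqs, 1], 0.0, dtype="float32")  # shared buffer, one row per slot

# Hypothetical requests; only slot 0 opts into min_p filtering.
requests = [{"idx": 0, "min_p": 0.2}, {"idx": 2}]
for req in requests:
    idx = req["idx"]
    min_p_buf[idx : idx + 1] = req.get("min_p", 0.0)  # mirrors request.get("min_p", 0.0)

print(min_p_buf.numpy().ravel())  # [0.2, 0.0, 0.0, 0.0]
```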
request.eos_token_ids.append(request.eos_token_ids[0]) - self.share_inputs["eos_token_id"][:] = np.array( - request.eos_token_ids, dtype="int64").reshape(-1, 1) - self.share_inputs["pre_ids"][idx:idx + 1] = -1 - self.share_inputs["top_p"][idx:idx + 1] = request.get("top_p", 0.7) - self.share_inputs["top_k"][idx:idx + 1] = request.get("top_k", 0) - self.share_inputs["min_p"][idx:idx+1] = request.get("min_p", 0.0) - self.share_inputs["temperature"][idx:idx + 1] = request.get( - "temperature", 0.95) - self.share_inputs["penalty_score"][idx:idx + 1] = request.get( - "repetition_penalty", 1.0) - self.share_inputs["frequency_score"][idx:idx + 1] = request.get( - "frequency_penalty", 0.0) - self.share_inputs["presence_score"][idx:idx + 1] = request.get( - "presence_penalty", 0.0) - self.share_inputs["seq_lens_this_time"][idx:idx + 1] = length - self.share_inputs["step_seq_lens_encoder"][idx:idx + 1] = length - self.share_inputs["seq_lens_encoder"][idx:idx + 1] = length - self.share_inputs["seq_lens_decoder"][idx:idx + 1] = 0 - self.share_inputs["step_idx"][idx:idx + 1] = 0 - self.share_inputs["min_dec_len"][idx:idx + 1] = request.get( - "min_tokens", 1) + self.share_inputs["eos_token_id"][:] = np.array(request.eos_token_ids, dtype="int64").reshape(-1, 1) + self.share_inputs["pre_ids"][idx : idx + 1] = -1 + self.share_inputs["top_p"][idx : idx + 1] = request.get("top_p", 0.7) + self.share_inputs["top_k"][idx : idx + 1] = request.get("top_k", 0) + self.share_inputs["min_p"][idx : idx + 1] = request.get("min_p", 0.0) + self.share_inputs["temperature"][idx : idx + 1] = request.get("temperature", 0.95) + self.share_inputs["penalty_score"][idx : idx + 1] = request.get("repetition_penalty", 1.0) + self.share_inputs["frequency_score"][idx : idx + 1] = request.get("frequency_penalty", 0.0) + self.share_inputs["presence_score"][idx : idx + 1] = request.get("presence_penalty", 0.0) + self.share_inputs["seq_lens_this_time"][idx : idx + 1] = length + self.share_inputs["step_seq_lens_encoder"][idx : idx + 1] = length + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = length + self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 + self.share_inputs["step_idx"][idx : idx + 1] = 0 + self.share_inputs["min_dec_len"][idx : idx + 1] = request.get("min_tokens", 1) self.share_inputs["max_dec_len"][idx : idx + 1] = request.get( "max_tokens", self.model_config.max_model_len @@ -363,20 +357,12 @@ def _init_share_inputs(self, max_num_seqs: int): dtype="int64", ) self.share_inputs["input_ids"] = paddle.full( - [max_num_seqs, self.parallel_config.max_model_len], - self.parallel_config.pad_token_id, - dtype='int64') - self.share_inputs["eos_token_id"] = paddle.full( - [self.parallel_config.eos_tokens_lens, 1], 0, dtype='int64') - self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], - self.model_config.top_p, - dtype='float32') - self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], - 0, - dtype='int64') - self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], - 0.0, - dtype='float32') + [max_num_seqs, self.parallel_config.max_model_len], self.parallel_config.pad_token_id, dtype="int64" + ) + self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64") + self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32") + self.share_inputs["top_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") + self.share_inputs["min_p"] = paddle.full([max_num_seqs, 1], 0.0, dtype="float32") self.share_inputs["temperature"] = 
paddle.full( [max_num_seqs, 1], self.model_config.temperature, dtype="float32" ) @@ -482,7 +468,6 @@ def _prepare_inputs(self) -> None: self.forward_meta.attn_backend = self.attn_backends[0] self.initialize_attention_backend() - # Get sampling metadata self.sampling_metadata = SamplingMetadata( temperature=self.share_inputs["temperature"], diff --git a/test/layers/test_min_p.py b/test/layers/test_min_p.py new file mode 100644 index 0000000000..624e00e125 --- /dev/null +++ b/test/layers/test_min_p.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +import paddle +import paddle.nn.functional as F + +from fastdeploy.model_executor.ops.gpu import min_p_sampling + + +class TestMinPSampling(unittest.TestCase): + def setUp(self): + self.sample_time = 1000000 + self.vocab_size = 1000 + self.min_p_value = 0.5 + self.batch_size = 3 + self.batch_min_p_values = [0.1, 0.0, 0.9] + self.additional_batch_min_p_values = [0.1, 0.0, 0.3] + + # min_p:0.5:FastDeploy + def min_p_sampling_cpu(self, min_p): + logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32") + logits[0][0] = 10 + logits[0][1] = 8 + low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2) + logits[0][2:] = low_prob_tensor + + probs = F.softmax(logits) + max_probabilities = paddle.amax(probs, axis=-1, keepdim=True) + adjusted_min_p = max_probabilities * min_p.reshape([-1, 1]) + invalid_token_mask = probs < adjusted_min_p + probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs) + return probs + + # min_p:0.5:FastDeploy + def fastdeploy_min_p_sampling(self, min_p): + logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32") + logits[0][0] = 10 + logits[0][1] = 8 + low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2) + logits[0][2:] = low_prob_tensor + + probs = F.softmax(logits) + probs = min_p_sampling(probs, min_p) + return probs + + # batch:[0.1.0.0,0.9]:FastDeploy + def fastdeploy_batch_min_p_sampling(self, batch_size, min_p_values): + logits = paddle.ones(shape=[batch_size, self.vocab_size], dtype="float32") + for b in range(batch_size): + logits[b][0] = 10 + logits[b][1] = 8 + logits[b][2:] = paddle.linspace(2.0, 0.0, self.vocab_size - 2) + + probs = F.softmax(logits, axis=-1) + min_p_arr = paddle.to_tensor(min_p_values, dtype="float32") + + probs = min_p_sampling(probs, min_p_arr) + + return probs + + def compare_results(self, probs, probs_cpu, atol=1e-6, rtol=1e-6): + probs_np = probs.numpy() + probs_cpu_np = probs_cpu.numpy() + try: + np.testing.assert_allclose( + probs_np, + probs_cpu_np, + rtol=rtol, + atol=atol, + ) + print("The results are same between fastdeploy_min_p_sampling and min_p_sampling_cpu") + except AssertionError as e: + raise AssertionError( + f"The results are different between fastdeploy_min_p_sampling and min_p_sampling_cpu:\n{str(e)}" + ) + + def test_single_min_p_sampling(self): + min_p = paddle.to_tensor([self.min_p_value], 
dtype="float32") + probs = self.fastdeploy_min_p_sampling(min_p) + probs_cpu = self.min_p_sampling_cpu(min_p) + self.compare_results(probs, probs_cpu) + + def test_batch_min_p_sampling(self): + batch_min_p = paddle.to_tensor(self.batch_min_p_values, dtype="float32") + batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, batch_min_p) + batch_probs_cpu = self.min_p_sampling_cpu(batch_min_p) + self.compare_results(batch_probs, batch_probs_cpu) + + def test_additional_batch_min_p_sampling(self): + additional_batch_min_p = paddle.to_tensor(self.additional_batch_min_p_values, dtype="float32") + additional_batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, additional_batch_min_p) + additional_batch_probs_cpu = self.min_p_sampling_cpu(additional_batch_min_p) + self.compare_results(additional_batch_probs, additional_batch_probs_cpu) + + +if __name__ == "__main__": + if paddle.is_compiled_with_cuda(): + unittest.main() diff --git a/test/layers/test_min_sampling.py b/test/layers/test_min_sampling.py index d8b7ee50bb..624e00e125 100644 --- a/test/layers/test_min_sampling.py +++ b/test/layers/test_min_sampling.py @@ -31,9 +31,8 @@ def setUp(self): self.batch_min_p_values = [0.1, 0.0, 0.9] self.additional_batch_min_p_values = [0.1, 0.0, 0.3] - # min_p:0.5:FastDeploy - def min_p_sampling_cpu(self,min_p): + def min_p_sampling_cpu(self, min_p): logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32") logits[0][0] = 10 logits[0][1] = 8 @@ -44,11 +43,11 @@ def min_p_sampling_cpu(self,min_p): max_probabilities = paddle.amax(probs, axis=-1, keepdim=True) adjusted_min_p = max_probabilities * min_p.reshape([-1, 1]) invalid_token_mask = probs < adjusted_min_p - probs = paddle.where(invalid_token_mask,paddle.full_like(probs,0.0), probs) + probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs) return probs # min_p:0.5:FastDeploy - def fastdeploy_min_p_sampling(self,min_p): + def fastdeploy_min_p_sampling(self, min_p): logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32") logits[0][0] = 10 logits[0][1] = 8 @@ -59,9 +58,8 @@ def fastdeploy_min_p_sampling(self,min_p): probs = min_p_sampling(probs, min_p) return probs - # batch:[0.1.0.0,0.9]:FastDeploy - def fastdeploy_batch_min_p_sampling(self,batch_size, min_p_values): + def fastdeploy_batch_min_p_sampling(self, batch_size, min_p_values): logits = paddle.ones(shape=[batch_size, self.vocab_size], dtype="float32") for b in range(batch_size): logits[b][0] = 10 @@ -88,7 +86,8 @@ def compare_results(self, probs, probs_cpu, atol=1e-6, rtol=1e-6): print("The results are same between fastdeploy_min_p_sampling and min_p_sampling_cpu") except AssertionError as e: raise AssertionError( - f"The results are different between fastdeploy_min_p_sampling and min_p_sampling_cpu:\n{str(e)}") + f"The results are different between fastdeploy_min_p_sampling and min_p_sampling_cpu:\n{str(e)}" + ) def test_single_min_p_sampling(self): min_p = paddle.to_tensor([self.min_p_value], dtype="float32") @@ -108,6 +107,7 @@ def test_additional_batch_min_p_sampling(self): additional_batch_probs_cpu = self.min_p_sampling_cpu(additional_batch_min_p) self.compare_results(additional_batch_probs, additional_batch_probs_cpu) + if __name__ == "__main__": if paddle.is_compiled_with_cuda(): unittest.main() From 132eececefa59602139e6a855f7cdc22c1502ad6 Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Mon, 21 Jul 2025 03:35:22 +0000 Subject: [PATCH 13/16] fix --- test/layers/test_min_p.py | 113 
-------------------------------------- 1 file changed, 113 deletions(-) delete mode 100644 test/layers/test_min_p.py diff --git a/test/layers/test_min_p.py b/test/layers/test_min_p.py deleted file mode 100644 index 624e00e125..0000000000 --- a/test/layers/test_min_p.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import unittest - -import numpy as np -import paddle -import paddle.nn.functional as F - -from fastdeploy.model_executor.ops.gpu import min_p_sampling - - -class TestMinPSampling(unittest.TestCase): - def setUp(self): - self.sample_time = 1000000 - self.vocab_size = 1000 - self.min_p_value = 0.5 - self.batch_size = 3 - self.batch_min_p_values = [0.1, 0.0, 0.9] - self.additional_batch_min_p_values = [0.1, 0.0, 0.3] - - # min_p:0.5:FastDeploy - def min_p_sampling_cpu(self, min_p): - logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32") - logits[0][0] = 10 - logits[0][1] = 8 - low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2) - logits[0][2:] = low_prob_tensor - - probs = F.softmax(logits) - max_probabilities = paddle.amax(probs, axis=-1, keepdim=True) - adjusted_min_p = max_probabilities * min_p.reshape([-1, 1]) - invalid_token_mask = probs < adjusted_min_p - probs = paddle.where(invalid_token_mask, paddle.full_like(probs, 0.0), probs) - return probs - - # min_p:0.5:FastDeploy - def fastdeploy_min_p_sampling(self, min_p): - logits = paddle.ones(shape=[1, self.vocab_size], dtype="float32") - logits[0][0] = 10 - logits[0][1] = 8 - low_prob_tensor = paddle.linspace(2.0, 0.0, self.vocab_size - 2) - logits[0][2:] = low_prob_tensor - - probs = F.softmax(logits) - probs = min_p_sampling(probs, min_p) - return probs - - # batch:[0.1.0.0,0.9]:FastDeploy - def fastdeploy_batch_min_p_sampling(self, batch_size, min_p_values): - logits = paddle.ones(shape=[batch_size, self.vocab_size], dtype="float32") - for b in range(batch_size): - logits[b][0] = 10 - logits[b][1] = 8 - logits[b][2:] = paddle.linspace(2.0, 0.0, self.vocab_size - 2) - - probs = F.softmax(logits, axis=-1) - min_p_arr = paddle.to_tensor(min_p_values, dtype="float32") - - probs = min_p_sampling(probs, min_p_arr) - - return probs - - def compare_results(self, probs, probs_cpu, atol=1e-6, rtol=1e-6): - probs_np = probs.numpy() - probs_cpu_np = probs_cpu.numpy() - try: - np.testing.assert_allclose( - probs_np, - probs_cpu_np, - rtol=rtol, - atol=atol, - ) - print("The results are same between fastdeploy_min_p_sampling and min_p_sampling_cpu") - except AssertionError as e: - raise AssertionError( - f"The results are different between fastdeploy_min_p_sampling and min_p_sampling_cpu:\n{str(e)}" - ) - - def test_single_min_p_sampling(self): - min_p = paddle.to_tensor([self.min_p_value], dtype="float32") - probs = self.fastdeploy_min_p_sampling(min_p) - probs_cpu = self.min_p_sampling_cpu(min_p) - self.compare_results(probs, probs_cpu) - - def test_batch_min_p_sampling(self): - batch_min_p = 
paddle.to_tensor(self.batch_min_p_values, dtype="float32") - batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, batch_min_p) - batch_probs_cpu = self.min_p_sampling_cpu(batch_min_p) - self.compare_results(batch_probs, batch_probs_cpu) - - def test_additional_batch_min_p_sampling(self): - additional_batch_min_p = paddle.to_tensor(self.additional_batch_min_p_values, dtype="float32") - additional_batch_probs = self.fastdeploy_batch_min_p_sampling(self.batch_size, additional_batch_min_p) - additional_batch_probs_cpu = self.min_p_sampling_cpu(additional_batch_min_p) - self.compare_results(additional_batch_probs, additional_batch_probs_cpu) - - -if __name__ == "__main__": - if paddle.is_compiled_with_cuda(): - unittest.main() From 30ff611d6ae54af0a7c343982135441b83310d4d Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Mon, 21 Jul 2025 03:39:55 +0000 Subject: [PATCH 14/16] fix --- fastdeploy/engine/sampling_params.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index 564ca266e4..688351883b 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -127,14 +127,14 @@ def from_optional( min_tokens=1, logprobs=None, bad_words=None, - ) -> "SamplingParams": + ) -> SamplingParams: """Create instance from command line arguments""" return cls( n=1 if n is None else n, best_of=best_of, - presence_penalty=presence_penalty if presence_penalty is not None else 0.0, - frequency_penalty=frequency_penalty if frequency_penalty is not None else 0.0, - repetition_penalty=repetition_penalty if repetition_penalty is not None else 1.0, + presence_penalty=(presence_penalty if presence_penalty is not None else 0.0), + frequency_penalty=(frequency_penalty if frequency_penalty is not None else 0.0), + repetition_penalty=(repetition_penalty if repetition_penalty is not None else 1.0), temperature=temperature if temperature is not None else 1.0, top_p=top_p, top_k=top_k if top_k is not None else 0, From e21671b193c50ef0ac891960e36720e1fb5a98ee Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Mon, 21 Jul 2025 03:42:17 +0000 Subject: [PATCH 15/16] fix --- fastdeploy/worker/gpu_model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index a86eed570a..f2cd2af786 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -425,7 +425,9 @@ def _init_share_inputs(self, max_num_seqs: int): dtype="int64", ) self.share_inputs["prompt_ids"] = paddle.full( - [max_num_seqs, self.parallel_config.max_model_len], self.parallel_config.pad_token_id, dtype="int64" + [max_num_seqs, self.parallel_config.max_model_len], + self.parallel_config.pad_token_id, + dtype="int64", ) self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64") self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32") From 25b6f934cb57c57dab324aed9a29d61ab300a0fd Mon Sep 17 00:00:00 2001 From: lizexu123 <2694294196@qq.com> Date: Mon, 21 Jul 2025 03:49:50 +0000 Subject: [PATCH 16/16] fix xpu_model_runner.py --- fastdeploy/worker/xpu_model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index bc2c7a88c2..601d7f264e 100644 --- 
a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -357,7 +357,9 @@ def _init_share_inputs(self, max_num_seqs: int):
             dtype="int64",
         )
         self.share_inputs["input_ids"] = paddle.full(
-            [max_num_seqs, self.parallel_config.max_model_len], self.parallel_config.pad_token_id, dtype="int64"
+            [max_num_seqs, self.parallel_config.max_model_len],
+            self.parallel_config.pad_token_id,
+            dtype="int64",
         )
         self.share_inputs["eos_token_id"] = paddle.full([self.parallel_config.eos_tokens_lens, 1], 0, dtype="int64")
         self.share_inputs["top_p"] = paddle.full([max_num_seqs, 1], self.model_config.top_p, dtype="float32")
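Closing note for reviewers: after this series, the sampler applies the min_p filter between the softmax and the top-k/top-p step. The sketch below is a pure-paddle reference of that ordering for readability only; the actual `Sampler.forward_cuda` dispatches to the custom ops (`min_p_sampling`, `top_k_top_p_sampling`) rather than the naive renormalize-and-multinomial step shown here.

```python
# Reference ordering: softmax -> min_p filter -> sample from the surviving tokens.
import paddle
import paddle.nn.functional as F

def sample_with_min_p(logits: paddle.Tensor, min_p: paddle.Tensor) -> paddle.Tensor:
    probs = F.softmax(logits, axis=-1)                         # [batch, vocab]
    row_max = paddle.amax(probs, axis=-1, keepdim=True)
    threshold = row_max * min_p.reshape([-1, 1])
    probs = paddle.where(probs < threshold, paddle.zeros_like(probs), probs)
    probs = probs / probs.sum(axis=-1, keepdim=True)           # renormalize survivors
    return paddle.multinomial(probs, num_samples=1)            # [batch, 1] token ids

logits = paddle.randn([2, 1000])
next_tokens = sample_with_min_p(logits, paddle.to_tensor([0.3, 0.0]))
print(next_tokens.shape)  # [2, 1]
```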