diff --git a/slm/model_zoo/bert/README.md b/slm/model_zoo/bert/README.md index f25ed3e07bca..6403f6f79bf1 100644 --- a/slm/model_zoo/bert/README.md +++ b/slm/model_zoo/bert/README.md @@ -4,7 +4,7 @@ [BERT](https://arxiv.org/abs/1810.04805) (Bidirectional Encoder Representations from Transformers)以[Transformer](https://arxiv.org/abs/1706.03762) 编码器为网络基本组件,使用掩码语言模型(Masked Language Model)和邻接句子预测(Next Sentence Prediction)两个任务在大规模无标注文本语料上进行预训练(pre-train),得到融合了双向内容的通用语义表示模型。以预训练产生的通用语义表示模型为基础,结合任务适配的简单输出层,微调(fine-tune)后即可应用到下游的 NLP 任务,效果通常也较直接在下游的任务上训练的模型更优。此前 BERT 即在[GLUE 评测任务](https://gluebenchmark.com/tasks)上取得了 SOTA 的结果。 -本项目是 BERT 在 Paddle 2.0上的开源实现,包含了预训练和[GLUE 评测任务](https://gluebenchmark.com/tasks)上的微调代码。 +本项目是 BERT 在 Paddle 2.0上的开源实现,并在 PaddlePaddle 3.0.0版本进行了适配与验证,包含了预训练和[GLUE 评测任务](https://gluebenchmark.com/tasks)上的微调代码。 ## 快速开始 @@ -262,23 +262,22 @@ python -u ./export_model.py \ 其中参数释义如下: - `model_type` 指示了模型类型,使用 BERT 模型时设置为 bert 即可。 - `model_path` 表示训练模型的保存路径,与训练时的`output_dir`一致。 -- `output_path` 表示导出预测模型文件的前缀。保存时会添加后缀(`pdiparams`,`pdiparams.info`,`pdmodel`);除此之外,还会在`output_path`包含的目录下保存 tokenizer 相关内容。 +- `output_path` 表示导出预测模型文件的前缀。保存时会添加后缀(`pdiparams`,`pdiparams.info`,`pdmodel`/`json`);除此之外,还会在`output_path`包含的目录下保存 tokenizer 相关内容。 -完成模型导出后,可以开始部署。`deploy/python/seq_cls_infer.py` 文件提供了 python 部署预测示例。可执行以下命令运行部署示例: +完成模型导出后,可以开始部署。`deploy/python/infer.py` 文件提供了 python 部署预测示例。可执行以下命令运行部署示例: ```shell -python deploy/python/seq_cls_infer.py --model_dir infer_model/ --device gpu --backend paddle +python deploy/python/infer.py --model_dir infer_model/ --device gpu ``` 运行后预测结果打印如下: ```bash -[INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::GPU. -Batch id: 0, example id: 0, sentence1: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.0003, positive prob: 0.9997. -Batch id: 1, example id: 0, sentence1: the situation in a well-balanced fashion, label: positive, negative prob: 0.0002, positive prob: 0.9998. -Batch id: 2, example id: 0, sentence1: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.0017, positive prob: 0.9983. -Batch id: 3, example id: 0, sentence1: so pat it makes your teeth hurt, label: negative, negative prob: 0.9986, positive prob: 0.0014. -Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: negative, negative prob: 0.9806, positive prob: 0.0194. +Batch id: 0, example id: 0, sentence: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.4623, positive prob: 0.5377. +Batch id: 0, example id: 1, sentence: the situation in a well-balanced fashion, label: positive, negative prob: 0.3500, positive prob: 0.6500. +Batch id: 1, example id: 0, sentence: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.4530, positive prob: 0.5470. +Batch id: 1, example id: 1, sentence: so pat it makes your teeth hurt, label: positive, negative prob: 0.3816, positive prob: 0.6184. +Batch id: 2, example id: 0, sentence: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: positive, negative prob: 0.3650, positive prob: 0.6350. ``` 更多详细用法可参考 [Python 部署](deploy/python/README.md)。 diff --git a/slm/model_zoo/bert/deploy/python/README.md b/slm/model_zoo/bert/deploy/python/README.md index 7e78d164ac14..a7d00f405312 100644 --- a/slm/model_zoo/bert/deploy/python/README.md +++ b/slm/model_zoo/bert/deploy/python/README.md @@ -1,30 +1,26 @@ -# FastDeploy BERT 模型 Python 部署示例 - -在部署前,参考 [FastDeploy SDK 安装文档](https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/cn/build_and_install/download_prebuilt_libraries.md)安装 FastDeploy Python SDK。 - -本目录下分别提供 `seq_cls_infer.py` 快速完成在 CPU/GPU 的 GLUE 文本分类任务的 Python 部署示例。 +# BERT 模型 Python 推理示例 +本目录下提供 `seq_cls_infer.py` 快速完成在 CPU/GPU 的 GLUE 文本分类任务的 Python 示例。 ## 快速开始 -以下示例展示如何基于 FastDeploy 库完成 BERT 模型在 GLUE SST-2 数据集上进行自然语言推断任务的 Python 预测部署,可通过命令行参数`--device`以及`--backend`指定运行在不同的硬件以及推理引擎后端,并使用`--model_dir`参数指定运行的模型,具体参数设置可查看下面[参数说明](#参数说明)。示例中的模型是按照 [BERT 训练文档](../../README.md)导出得到的部署模型,其模型目录为`model_zoo/bert/infer_model`(用户可按实际情况设置)。 +可通过命令行参数`--device`指定运行在不同的硬件,并使用`--model_dir`参数指定运行的模型,具体参数设置可查看下面[参数说明](#参数说明)。示例中的模型是按照 [BERT 训练文档](../../README.md)导出得到的部署模型,其模型目录为`model_zoo/bert/infer_model`(用户可按实际情况设置)。 ```bash # CPU 推理 -python seq_cls_infer.py --model_dir ../../infer_model/ --device cpu --backend paddle +python infer.py --model_dir ../../infer_model/ --device cpu # GPU 推理 -python seq_cls_infer.py --model_dir ../../infer_model/ --device gpu --backend paddle +python infer.py --model_dir ../../infer_model/ --device gpu ``` 运行完成后返回的结果如下: ```bash -[INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::GPU. -Batch id: 0, example id: 0, sentence1: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.0003, positive prob: 0.9997. -Batch id: 1, example id: 0, sentence1: the situation in a well-balanced fashion, label: positive, negative prob: 0.0002, positive prob: 0.9998. -Batch id: 2, example id: 0, sentence1: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.0017, positive prob: 0.9983. -Batch id: 3, example id: 0, sentence1: so pat it makes your teeth hurt, label: negative, negative prob: 0.9986, positive prob: 0.0014. -Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: negative, negative prob: 0.9806, positive prob: 0.0194. +Batch id: 0, example id: 0, sentence: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.4623, positive prob: 0.5377. +Batch id: 0, example id: 1, sentence: the situation in a well-balanced fashion, label: positive, negative prob: 0.3500, positive prob: 0.6500. +Batch id: 1, example id: 0, sentence: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.4530, positive prob: 0.5470. +Batch id: 1, example id: 1, sentence: so pat it makes your teeth hurt, label: positive, negative prob: 0.3816, positive prob: 0.6184. +Batch id: 2, example id: 0, sentence: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: positive, negative prob: 0.3650, positive prob: 0.6350. ``` ## 参数说明 @@ -32,97 +28,8 @@ Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stu | 参数 |参数说明 | |----------|--------------| |--model_dir | 指定部署模型的目录, | -|--batch_size |输入的 batch size,默认为 1| +|--batch_size |输入的 batch size,默认为 2| |--max_length |最大序列长度,默认为 128| |--device | 运行的设备,可选范围: ['cpu', 'gpu'],默认为'cpu' | |--device_id | 运行设备的 id。默认为0。 | -|--cpu_threads | 当使用 cpu 推理时,指定推理的 cpu 线程数,默认为1。| -|--backend | 支持的推理后端,可选范围: ['onnx_runtime', 'paddle', 'openvino', 'tensorrt', 'paddle_tensorrt'],默认为'paddle' | -|--use_fp16 | 是否使用 FP16模式进行推理。使用 tensorrt 和 paddle_tensorrt 后端时可开启,默认为 False | - -## FastDeploy 高阶用法 - -FastDeploy 在 Python 端上,提供 `fastdeploy.RuntimeOption.use_xxx()` 以及 `fastdeploy.RuntimeOption.use_xxx_backend()` 接口支持开发者选择不同的硬件、不同的推理引擎进行部署。在不同的硬件上部署 BERT 模型,需要选择硬件所支持的推理引擎进行部署,下表展示如何在不同的硬件上选择可用的推理引擎部署 BERT 模型。 - -符号说明: (1) ✅: 已经支持; (2) ❔: 正在进行中; (3) N/A: 暂不支持; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
硬件 硬件对应的接口 可用的推理引擎 推理引擎对应的接口 是否支持 Paddle 新格式量化模型 是否支持 FP16 模式
CPU use_cpu() Paddle Inference use_paddle_infer_backend() N/A
ONNX Runtime use_ort_backend() N/A
OpenVINO use_openvino_backend() N/A
GPU use_gpu() Paddle Inference use_paddle_infer_backend() N/A
ONNX Runtime use_ort_backend()
Paddle TensorRT use_paddle_infer_backend() + paddle_infer_option.enable_trt = True
TensorRT use_trt_backend()
昆仑芯 XPU use_kunlunxin() Paddle Lite use_paddle_lite_backend() N/A
华为 昇腾 use_ascend() Paddle Lite use_paddle_lite_backend()
Graphcore IPU use_ipu() Paddle Inference use_paddle_infer_backend() N/A
+|--cpu_threads | 当使用 cpu 推理时,指定推理的 cpu 线程数,默认为4。| diff --git a/slm/model_zoo/bert/deploy/python/infer.py b/slm/model_zoo/bert/deploy/python/infer.py new file mode 100644 index 000000000000..d8d1315f8e0b --- /dev/null +++ b/slm/model_zoo/bert/deploy/python/infer.py @@ -0,0 +1,123 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import numpy as np +from paddle import inference +from scipy.special import softmax + +from paddlenlp.transformers import AutoTokenizer +from paddlenlp.utils.env import ( + PADDLE_INFERENCE_MODEL_SUFFIX, + PADDLE_INFERENCE_WEIGHTS_SUFFIX, +) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_dir", required=True, help="The directory of model.") + parser.add_argument("--model_prefix", type=str, default="model", help="Prefix of the model file (no extension).") + parser.add_argument("--device", choices=["gpu", "cpu"], default="cpu", help="Device for inference.") + parser.add_argument("--device_id", type=int, default=0, help="GPU device ID if using GPU.") + parser.add_argument("--cpu_threads", type=int, default=4, help="CPU threads if using CPU.") + parser.add_argument("--batch_size", type=int, default=2, help="Batch size for inference.") + parser.add_argument("--max_length", type=int, default=128, help="Max sequence length.") + return parser.parse_args() + + +def batchfy_text(texts, batch_size): + return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)] + + +class Predictor(object): + def __init__(self, args): + self.batch_size = args.batch_size + self.max_length = args.max_length + + self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) + + model_file = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_MODEL_SUFFIX}") + params_file = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_WEIGHTS_SUFFIX}") + + if not os.path.exists(model_file): + raise FileNotFoundError(f"Model file not found: {model_file}") + if not os.path.exists(params_file): + raise FileNotFoundError(f"Params file not found: {params_file}") + + config = inference.Config(model_file, params_file) + if args.device == "gpu": + config.enable_use_gpu(100, args.device_id) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(args.cpu_threads) + + config.switch_use_feed_fetch_ops(False) + self.predictor = inference.create_predictor(config) + self.input_handles = [self.predictor.get_input_handle(name) for name in self.predictor.get_input_names()] + self.output_handle = self.predictor.get_output_handle(self.predictor.get_output_names()[0]) + + def preprocess(self, texts): + encoded = self.tokenizer( + texts, + padding=True, + truncation=True, + max_length=self.max_length, + return_token_type_ids=True, + ) + input_ids = np.array(encoded["input_ids"], dtype="int64") + token_type_ids = np.array(encoded["token_type_ids"], dtype="int64") + return input_ids, token_type_ids + + def infer(self, input_ids, token_type_ids): + self.input_handles[0].copy_from_cpu(input_ids) + self.input_handles[1].copy_from_cpu(token_type_ids) + self.predictor.run() + return self.output_handle.copy_to_cpu() + + def postprocess(self, logits): + probs = softmax(logits, axis=1) + return {"label": probs.argmax(axis=1), "confidence": probs} + + def predict(self, texts): + input_ids, token_type_ids = self.preprocess(texts) + logits = self.infer(input_ids, token_type_ids) + return self.postprocess(logits) + + +if __name__ == "__main__": + args = parse_arguments() + predictor = Predictor(args) + + texts_ds = [ + "against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting", + "the situation in a well-balanced fashion", + "at achieving the modest , crowd-pleasing goals it sets for itself", + "so pat it makes your teeth hurt", + "this new jangle of noise , mayhem and stupidity must be a serious contender for the title .", + ] + label_map = {0: "negative", 1: "positive"} + + batch_texts = batchfy_text(texts_ds, args.batch_size) + + for bs, texts in enumerate(batch_texts): + outputs = predictor.predict(texts) + for i, sentence in enumerate(texts): + label = outputs["label"][i] + confidence = outputs["confidence"][i] + print( + f"Batch id: {bs}, example id: {i}, sentence: {sentence}, " + f"label: {label_map[label]}, negative prob: {confidence[0]:.4f}, positive prob: {confidence[1]:.4f}." + ) diff --git a/slm/model_zoo/bert/deploy/python/seq_cls_infer.py b/slm/model_zoo/bert/deploy/python/seq_cls_infer.py deleted file mode 100644 index 34105530778d..000000000000 --- a/slm/model_zoo/bert/deploy/python/seq_cls_infer.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import distutils.util -import os - -import fastdeploy as fd -import numpy as np - -from paddlenlp.transformers import AutoTokenizer - - -def parse_arguments(): - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--model_dir", required=True, help="The directory of model.") - parser.add_argument("--vocab_path", type=str, default="", help="The path of tokenizer vocab.") - parser.add_argument("--model_prefix", type=str, default="model", help="The model and params file prefix.") - parser.add_argument( - "--device", - type=str, - default="cpu", - choices=["gpu", "cpu"], - help="Type of inference device, support 'cpu' or 'gpu'.", - ) - parser.add_argument( - "--backend", - type=str, - default="paddle", - choices=["onnx_runtime", "paddle", "openvino", "tensorrt", "paddle_tensorrt"], - help="The inference runtime backend.", - ) - parser.add_argument("--cpu_threads", type=int, default=1, help="Number of threads to predict when using cpu.") - parser.add_argument("--device_id", type=int, default=0, help="Select which gpu device to train model.") - parser.add_argument("--batch_size", type=int, default=1, help="The batch size of data.") - parser.add_argument("--max_length", type=int, default=128, help="The max length of sequence.") - parser.add_argument("--log_interval", type=int, default=10, help="The interval of logging.") - parser.add_argument("--use_fp16", type=distutils.util.strtobool, default=False, help="Wheter to use FP16 mode") - return parser.parse_args() - - -def batchfy_text(texts, batch_size): - batch_texts = [] - batch_start = 0 - while batch_start < len(texts): - batch_texts += [texts[batch_start : min(batch_start + batch_size, len(texts))]] - batch_start += batch_size - return batch_texts - - -class Predictor(object): - def __init__(self, args): - self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) - self.runtime = self.create_fd_runtime(args) - self.batch_size = args.batch_size - self.max_length = args.max_length - - def create_fd_runtime(self, args): - option = fd.RuntimeOption() - model_path = os.path.join(args.model_dir, args.model_prefix + ".pdmodel") - params_path = os.path.join(args.model_dir, args.model_prefix + ".pdiparams") - option.set_model_path(model_path, params_path) - if args.device == "cpu": - option.use_cpu() - option.set_cpu_thread_num(args.cpu_threads) - else: - option.use_gpu(args.device_id) - if args.backend == "paddle": - option.use_paddle_infer_backend() - elif args.backend == "onnx_runtime": - option.use_ort_backend() - elif args.backend == "openvino": - option.use_openvino_backend() - else: - option.use_trt_backend() - if args.backend == "paddle_tensorrt": - option.use_paddle_infer_backend() - option.paddle_infer_option.collect_trt_shape = True - option.paddle_infer_option.enable_trt = True - trt_file = os.path.join(args.model_dir, "model.trt") - option.trt_option.set_shape( - "input_ids", [1, 1], [args.batch_size, args.max_length], [args.batch_size, args.max_length] - ) - option.trt_option.set_shape( - "token_type_ids", [1, 1], [args.batch_size, args.max_length], [args.batch_size, args.max_length] - ) - if args.use_fp16: - option.trt_option.enable_fp16 = True - trt_file = trt_file + ".fp16" - option.trt_option.serialize_file = trt_file - return fd.Runtime(option) - - def preprocess(self, text): - data = self.tokenizer(text, max_length=self.max_length, padding=True, truncation=True) - input_ids_name = self.runtime.get_input_info(0).name - token_type_ids_name = self.runtime.get_input_info(1).name - input_map = { - input_ids_name: np.array(data["input_ids"], dtype="int64"), - token_type_ids_name: np.array(data["token_type_ids"], dtype="int64"), - } - return input_map - - def infer(self, input_map): - results = self.runtime.infer(input_map) - return results - - def postprocess(self, infer_data): - logits = np.array(infer_data[0]) - max_value = np.max(logits, axis=1, keepdims=True) - exp_data = np.exp(logits - max_value) - probs = exp_data / np.sum(exp_data, axis=1, keepdims=True) - out_dict = {"label": probs.argmax(axis=-1), "confidence": probs} - return out_dict - - def predict(self, texts): - input_map = self.preprocess(texts) - infer_result = self.infer(input_map) - output = self.postprocess(infer_result) - return output - - -if __name__ == "__main__": - args = parse_arguments() - predictor = Predictor(args) - texts_ds = [ - "against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting", - "the situation in a well-balanced fashion", - "at achieving the modest , crowd-pleasing goals it sets for itself", - "so pat it makes your teeth hurt", - "this new jangle of noise , mayhem and stupidity must be a serious contender for the title .", - ] - label_map = {0: "negative", 1: "positive"} - batch_texts = batchfy_text(texts_ds, args.batch_size) - for bs, texts in enumerate(batch_texts): - outputs = predictor.predict(texts) - for i, sentence1 in enumerate(texts): - print( - f"Batch id: {bs}, example id: {i}, sentence1: {sentence1}, " - f"label: {label_map[outputs['label'][i]]}, negative prob: {outputs['confidence'][i][0]:.4f}, " - f"positive prob: {outputs['confidence'][i][1]:.4f}." - )