diff --git a/slm/model_zoo/bert/README.md b/slm/model_zoo/bert/README.md
index f25ed3e07bca..6403f6f79bf1 100644
--- a/slm/model_zoo/bert/README.md
+++ b/slm/model_zoo/bert/README.md
@@ -4,7 +4,7 @@
[BERT](https://arxiv.org/abs/1810.04805) (Bidirectional Encoder Representations from Transformers) uses the [Transformer](https://arxiv.org/abs/1706.03762) encoder as its basic building block and is pre-trained on large-scale unlabeled text corpora with two objectives, Masked Language Model and Next Sentence Prediction, producing a general-purpose semantic representation model that captures bidirectional context. Built on this pre-trained representation model plus a simple task-specific output layer, the fine-tuned model can be applied to downstream NLP tasks, usually outperforming models trained directly on the downstream task. BERT previously achieved SOTA results on the [GLUE benchmark](https://gluebenchmark.com/tasks).
-This project is an open-source implementation of BERT on Paddle 2.0, including code for pre-training and for fine-tuning on the [GLUE benchmark](https://gluebenchmark.com/tasks).
+This project is an open-source implementation of BERT on Paddle 2.0, adapted to and verified against PaddlePaddle 3.0.0. It includes code for pre-training and for fine-tuning on the [GLUE benchmark](https://gluebenchmark.com/tasks).
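+
+For reference, an environment matching the verified framework version can be set up roughly as follows (whether to use the CPU or GPU build, and which PaddleNLP version to pair with it, depends on your setup):
+
+```shell
+# Install the verified framework version (use paddlepaddle-gpu instead on CUDA machines).
+python -m pip install paddlepaddle==3.0.0
+# Install PaddleNLP, which provides the models, tokenizers, and scripts used below.
+python -m pip install paddlenlp
+```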
## Quick Start
@@ -262,23 +262,22 @@ python -u ./export_model.py \
The parameters are described as follows:
- `model_type` indicates the model type; set it to bert when using a BERT model.
- `model_path` is the save path of the trained model, identical to the `output_dir` used during training.
-- `output_path` is the prefix of the exported inference model files. Suffixes (`pdiparams`, `pdiparams.info`, `pdmodel`) are appended when saving; in addition, tokenizer files are saved under the directory containing `output_path`.
+- `output_path` is the prefix of the exported inference model files. Suffixes (`pdiparams`, `pdiparams.info`, `pdmodel`/`json`) are appended when saving; in addition, tokenizer files are saved under the directory containing `output_path`. A sample of the resulting layout is sketched below.
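+
+For illustration, assuming the export prefix is `infer_model/model`, the exported directory typically looks like the sketch below; the model-structure file is `model.json` on PaddlePaddle 3.x and `model.pdmodel` on earlier versions, and the exact tokenizer file names may vary:
+
+```shell
+infer_model/
+├── model.json             # model structure (model.pdmodel on Paddle < 3.0)
+├── model.pdiparams        # model weights
+├── tokenizer_config.json  # tokenizer settings
+└── vocab.txt              # tokenizer vocabulary
+```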
-After the model is exported, you can deploy it. `deploy/python/seq_cls_infer.py` provides a Python deployment example. Run the example with the following command:
+After the model is exported, you can deploy it. `deploy/python/infer.py` provides a Python deployment example. Run the example with the following command:
```shell
-python deploy/python/seq_cls_infer.py --model_dir infer_model/ --device gpu --backend paddle
+python deploy/python/infer.py --model_dir infer_model/ --device gpu
```
After running, the prediction results are printed as follows:
```bash
-[INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::GPU.
-Batch id: 0, example id: 0, sentence1: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.0003, positive prob: 0.9997.
-Batch id: 1, example id: 0, sentence1: the situation in a well-balanced fashion, label: positive, negative prob: 0.0002, positive prob: 0.9998.
-Batch id: 2, example id: 0, sentence1: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.0017, positive prob: 0.9983.
-Batch id: 3, example id: 0, sentence1: so pat it makes your teeth hurt, label: negative, negative prob: 0.9986, positive prob: 0.0014.
-Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: negative, negative prob: 0.9806, positive prob: 0.0194.
+Batch id: 0, example id: 0, sentence: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.4623, positive prob: 0.5377.
+Batch id: 0, example id: 1, sentence: the situation in a well-balanced fashion, label: positive, negative prob: 0.3500, positive prob: 0.6500.
+Batch id: 1, example id: 0, sentence: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.4530, positive prob: 0.5470.
+Batch id: 1, example id: 1, sentence: so pat it makes your teeth hurt, label: positive, negative prob: 0.3816, positive prob: 0.6184.
+Batch id: 2, example id: 0, sentence: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: positive, negative prob: 0.3650, positive prob: 0.6350.
```
For more detailed usage, see [Python Deployment](deploy/python/README.md).
diff --git a/slm/model_zoo/bert/deploy/python/README.md b/slm/model_zoo/bert/deploy/python/README.md
index 7e78d164ac14..a7d00f405312 100644
--- a/slm/model_zoo/bert/deploy/python/README.md
+++ b/slm/model_zoo/bert/deploy/python/README.md
@@ -1,30 +1,26 @@
-# FastDeploy BERT Model Python Deployment Example
-
-Before deploying, install the FastDeploy Python SDK by following the [FastDeploy SDK installation guide](https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/cn/build_and_install/download_prebuilt_libraries.md).
-
-This directory provides `seq_cls_infer.py`, a Python deployment example for quickly running GLUE text classification tasks on CPU/GPU.
+# BERT Model Python Inference Example
+
+This directory provides `infer.py`, a Python example for quickly running GLUE text classification tasks on CPU/GPU.
## Quick Start
-The following example shows how to use the FastDeploy library to run Python inference deployment of the BERT model on the GLUE SST-2 dataset. The command-line flags `--device` and `--backend` select the hardware and the inference backend, and `--model_dir` specifies the model to run; see the [Parameter Description](#parameter-description) below for details. The model in the example is the deployment model exported following the [BERT training docs](../../README.md), with model directory `model_zoo/bert/infer_model` (adjust to your actual setup).
+The command-line flag `--device` selects the hardware to run on, and `--model_dir` specifies the model to run; see the [Parameter Description](#parameter-description) below for details. The model in this example is the deployment model exported following the [BERT training docs](../../README.md), with model directory `model_zoo/bert/infer_model` (adjust to your actual setup).
```bash
# CPU inference
-python seq_cls_infer.py --model_dir ../../infer_model/ --device cpu --backend paddle
+python infer.py --model_dir ../../infer_model/ --device cpu
# GPU inference
-python seq_cls_infer.py --model_dir ../../infer_model/ --device gpu --backend paddle
+python infer.py --model_dir ../../infer_model/ --device gpu
```
After running, the returned results are as follows:
```bash
-[INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::GPU.
-Batch id: 0, example id: 0, sentence1: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.0003, positive prob: 0.9997.
-Batch id: 1, example id: 0, sentence1: the situation in a well-balanced fashion, label: positive, negative prob: 0.0002, positive prob: 0.9998.
-Batch id: 2, example id: 0, sentence1: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.0017, positive prob: 0.9983.
-Batch id: 3, example id: 0, sentence1: so pat it makes your teeth hurt, label: negative, negative prob: 0.9986, positive prob: 0.0014.
-Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: negative, negative prob: 0.9806, positive prob: 0.0194.
+Batch id: 0, example id: 0, sentence: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.4623, positive prob: 0.5377.
+Batch id: 0, example id: 1, sentence: the situation in a well-balanced fashion, label: positive, negative prob: 0.3500, positive prob: 0.6500.
+Batch id: 1, example id: 0, sentence: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.4530, positive prob: 0.5470.
+Batch id: 1, example id: 1, sentence: so pat it makes your teeth hurt, label: positive, negative prob: 0.3816, positive prob: 0.6184.
+Batch id: 2, example id: 0, sentence: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: positive, negative prob: 0.3650, positive prob: 0.6350.
```
## Parameter Description
@@ -32,97 +28,8 @@ Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stu
| Parameter | Description |
|----------|--------------|
|--model_dir | The directory of the deployment model. |
-|--batch_size | Batch size of the input; defaults to 1. |
+|--batch_size | Batch size of the input; defaults to 2. |
|--max_length | Maximum sequence length; defaults to 128. |
|--device | Device to run on; one of ['cpu', 'gpu']; defaults to 'cpu'. |
|--device_id | ID of the device to use; defaults to 0. |
-|--cpu_threads | Number of CPU threads when running inference on CPU; defaults to 1. |
-|--backend | Inference backend; one of ['onnx_runtime', 'paddle', 'openvino', 'tensorrt', 'paddle_tensorrt']; defaults to 'paddle'. |
-|--use_fp16 | Whether to run inference in FP16 mode; can be enabled with the tensorrt and paddle_tensorrt backends; defaults to False. |
-
-## FastDeploy Advanced Usage
-
-On the Python side, FastDeploy provides the `fastdeploy.RuntimeOption.use_xxx()` and `fastdeploy.RuntimeOption.use_xxx_backend()` interfaces that let developers pick the hardware and inference engine for deployment. Deploying the BERT model on a given piece of hardware requires an inference engine that the hardware supports; the table below shows which inference engines can be used to deploy BERT on each kind of hardware.
-
-Legend: (1) ✅: supported; (2) ❔: in progress; (3) N/A: not supported;
-
-| Hardware | Hardware API | Inference engine | Engine API | Supports Paddle quantized models (new format) | Supports FP16 |
-|----------|--------------|------------------|------------|------------------------------------------------|---------------|
-| CPU | use_cpu() | Paddle Inference | use_paddle_infer_backend() | ✅ | N/A |
-| CPU | use_cpu() | ONNX Runtime | use_ort_backend() | ✅ | N/A |
-| CPU | use_cpu() | OpenVINO | use_openvino_backend() | ❔ | N/A |
-| GPU | use_gpu() | Paddle Inference | use_paddle_infer_backend() | ✅ | N/A |
-| GPU | use_gpu() | ONNX Runtime | use_ort_backend() | ✅ | ❔ |
-| GPU | use_gpu() | Paddle TensorRT | use_paddle_infer_backend() + paddle_infer_option.enable_trt = True | ✅ | ✅ |
-| GPU | use_gpu() | TensorRT | use_trt_backend() | ✅ | ✅ |
-| Kunlunxin XPU | use_kunlunxin() | Paddle Lite | use_paddle_lite_backend() | N/A | ✅ |
-| Huawei Ascend | use_ascend() | Paddle Lite | use_paddle_lite_backend() | ❔ | ✅ |
-| Graphcore IPU | use_ipu() | Paddle Inference | use_paddle_infer_backend() | ❔ | N/A |
+|--cpu_threads | Number of CPU threads when running inference on CPU; defaults to 4. |
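+
+For example, to run batched GPU inference with a longer maximum sequence length (the flag values below are illustrative, not required):
+
+```shell
+python infer.py --model_dir ../../infer_model/ --device gpu --device_id 0 --batch_size 4 --max_length 256
+```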
diff --git a/slm/model_zoo/bert/deploy/python/infer.py b/slm/model_zoo/bert/deploy/python/infer.py
new file mode 100644
index 000000000000..d8d1315f8e0b
--- /dev/null
+++ b/slm/model_zoo/bert/deploy/python/infer.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import numpy as np
+from paddle import inference
+from scipy.special import softmax
+
+from paddlenlp.transformers import AutoTokenizer
+from paddlenlp.utils.env import (
+ PADDLE_INFERENCE_MODEL_SUFFIX,
+ PADDLE_INFERENCE_WEIGHTS_SUFFIX,
+)
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_dir", required=True, help="The directory of model.")
+ parser.add_argument("--model_prefix", type=str, default="model", help="Prefix of the model file (no extension).")
+ parser.add_argument("--device", choices=["gpu", "cpu"], default="cpu", help="Device for inference.")
+ parser.add_argument("--device_id", type=int, default=0, help="GPU device ID if using GPU.")
+ parser.add_argument("--cpu_threads", type=int, default=4, help="CPU threads if using CPU.")
+ parser.add_argument("--batch_size", type=int, default=2, help="Batch size for inference.")
+ parser.add_argument("--max_length", type=int, default=128, help="Max sequence length.")
+ return parser.parse_args()
+
+
+def batchfy_text(texts, batch_size):
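+ # Split the input texts into consecutive batches of at most batch_size items.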
+ return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
+
+
+class Predictor(object):
+ def __init__(self, args):
+ self.batch_size = args.batch_size
+ self.max_length = args.max_length
+
+ self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
+
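+ # PaddleNLP resolves the inference-file suffixes for the installed Paddle version
+ # (e.g. ".json"/".pdmodel" for the model structure, ".pdiparams" for the weights).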
+ model_file = os.path.join(args.model_dir, args.model_prefix + PADDLE_INFERENCE_MODEL_SUFFIX)
+ params_file = os.path.join(args.model_dir, args.model_prefix + PADDLE_INFERENCE_WEIGHTS_SUFFIX)
+
+ if not os.path.exists(model_file):
+ raise FileNotFoundError(f"Model file not found: {model_file}")
+ if not os.path.exists(params_file):
+ raise FileNotFoundError(f"Params file not found: {params_file}")
+
+ config = inference.Config(model_file, params_file)
+ if args.device == "gpu":
+ config.enable_use_gpu(100, args.device_id)
+ else:
+ config.disable_gpu()
+ config.set_cpu_math_library_num_threads(args.cpu_threads)
+
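+ # Use zero-copy tensor handles for input/output instead of feed/fetch ops.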
+ config.switch_use_feed_fetch_ops(False)
+ self.predictor = inference.create_predictor(config)
+ self.input_handles = [self.predictor.get_input_handle(name) for name in self.predictor.get_input_names()]
+ self.output_handle = self.predictor.get_output_handle(self.predictor.get_output_names()[0])
+
+ def preprocess(self, texts):
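+ # Tokenize the batch, padding to the longest sequence and truncating at max_length.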
+ encoded = self.tokenizer(
+ texts,
+ padding=True,
+ truncation=True,
+ max_length=self.max_length,
+ return_token_type_ids=True,
+ )
+ input_ids = np.array(encoded["input_ids"], dtype="int64")
+ token_type_ids = np.array(encoded["token_type_ids"], dtype="int64")
+ return input_ids, token_type_ids
+
+ def infer(self, input_ids, token_type_ids):
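+ # Copy inputs to the device, run the forward pass, and fetch logits back to CPU.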
+ self.input_handles[0].copy_from_cpu(input_ids)
+ self.input_handles[1].copy_from_cpu(token_type_ids)
+ self.predictor.run()
+ return self.output_handle.copy_to_cpu()
+
+ def postprocess(self, logits):
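+ # Turn logits into class probabilities with a softmax over the label axis.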
+ probs = softmax(logits, axis=1)
+ return {"label": probs.argmax(axis=1), "confidence": probs}
+
+ def predict(self, texts):
+ input_ids, token_type_ids = self.preprocess(texts)
+ logits = self.infer(input_ids, token_type_ids)
+ return self.postprocess(logits)
+
+
+if __name__ == "__main__":
+ args = parse_arguments()
+ predictor = Predictor(args)
+
+ texts_ds = [
+ "against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting",
+ "the situation in a well-balanced fashion",
+ "at achieving the modest , crowd-pleasing goals it sets for itself",
+ "so pat it makes your teeth hurt",
+ "this new jangle of noise , mayhem and stupidity must be a serious contender for the title .",
+ ]
+ label_map = {0: "negative", 1: "positive"}
+
+ batch_texts = batchfy_text(texts_ds, args.batch_size)
+
+ for bs, texts in enumerate(batch_texts):
+ outputs = predictor.predict(texts)
+ for i, sentence in enumerate(texts):
+ label = outputs["label"][i]
+ confidence = outputs["confidence"][i]
+ print(
+ f"Batch id: {bs}, example id: {i}, sentence: {sentence}, "
+ f"label: {label_map[label]}, negative prob: {confidence[0]:.4f}, positive prob: {confidence[1]:.4f}."
+ )
diff --git a/slm/model_zoo/bert/deploy/python/seq_cls_infer.py b/slm/model_zoo/bert/deploy/python/seq_cls_infer.py
deleted file mode 100644
index 34105530778d..000000000000
--- a/slm/model_zoo/bert/deploy/python/seq_cls_infer.py
+++ /dev/null
@@ -1,153 +0,0 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import distutils.util
-import os
-
-import fastdeploy as fd
-import numpy as np
-
-from paddlenlp.transformers import AutoTokenizer
-
-
-def parse_arguments():
- import argparse
-
- parser = argparse.ArgumentParser()
- parser.add_argument("--model_dir", required=True, help="The directory of model.")
- parser.add_argument("--vocab_path", type=str, default="", help="The path of tokenizer vocab.")
- parser.add_argument("--model_prefix", type=str, default="model", help="The model and params file prefix.")
- parser.add_argument(
- "--device",
- type=str,
- default="cpu",
- choices=["gpu", "cpu"],
- help="Type of inference device, support 'cpu' or 'gpu'.",
- )
- parser.add_argument(
- "--backend",
- type=str,
- default="paddle",
- choices=["onnx_runtime", "paddle", "openvino", "tensorrt", "paddle_tensorrt"],
- help="The inference runtime backend.",
- )
- parser.add_argument("--cpu_threads", type=int, default=1, help="Number of threads to predict when using cpu.")
- parser.add_argument("--device_id", type=int, default=0, help="Select which gpu device to train model.")
- parser.add_argument("--batch_size", type=int, default=1, help="The batch size of data.")
- parser.add_argument("--max_length", type=int, default=128, help="The max length of sequence.")
- parser.add_argument("--log_interval", type=int, default=10, help="The interval of logging.")
- parser.add_argument("--use_fp16", type=distutils.util.strtobool, default=False, help="Wheter to use FP16 mode")
- return parser.parse_args()
-
-
-def batchfy_text(texts, batch_size):
- batch_texts = []
- batch_start = 0
- while batch_start < len(texts):
- batch_texts += [texts[batch_start : min(batch_start + batch_size, len(texts))]]
- batch_start += batch_size
- return batch_texts
-
-
-class Predictor(object):
- def __init__(self, args):
- self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
- self.runtime = self.create_fd_runtime(args)
- self.batch_size = args.batch_size
- self.max_length = args.max_length
-
- def create_fd_runtime(self, args):
- option = fd.RuntimeOption()
- model_path = os.path.join(args.model_dir, args.model_prefix + ".pdmodel")
- params_path = os.path.join(args.model_dir, args.model_prefix + ".pdiparams")
- option.set_model_path(model_path, params_path)
- if args.device == "cpu":
- option.use_cpu()
- option.set_cpu_thread_num(args.cpu_threads)
- else:
- option.use_gpu(args.device_id)
- if args.backend == "paddle":
- option.use_paddle_infer_backend()
- elif args.backend == "onnx_runtime":
- option.use_ort_backend()
- elif args.backend == "openvino":
- option.use_openvino_backend()
- else:
- option.use_trt_backend()
- if args.backend == "paddle_tensorrt":
- option.use_paddle_infer_backend()
- option.paddle_infer_option.collect_trt_shape = True
- option.paddle_infer_option.enable_trt = True
- trt_file = os.path.join(args.model_dir, "model.trt")
- option.trt_option.set_shape(
- "input_ids", [1, 1], [args.batch_size, args.max_length], [args.batch_size, args.max_length]
- )
- option.trt_option.set_shape(
- "token_type_ids", [1, 1], [args.batch_size, args.max_length], [args.batch_size, args.max_length]
- )
- if args.use_fp16:
- option.trt_option.enable_fp16 = True
- trt_file = trt_file + ".fp16"
- option.trt_option.serialize_file = trt_file
- return fd.Runtime(option)
-
- def preprocess(self, text):
- data = self.tokenizer(text, max_length=self.max_length, padding=True, truncation=True)
- input_ids_name = self.runtime.get_input_info(0).name
- token_type_ids_name = self.runtime.get_input_info(1).name
- input_map = {
- input_ids_name: np.array(data["input_ids"], dtype="int64"),
- token_type_ids_name: np.array(data["token_type_ids"], dtype="int64"),
- }
- return input_map
-
- def infer(self, input_map):
- results = self.runtime.infer(input_map)
- return results
-
- def postprocess(self, infer_data):
- logits = np.array(infer_data[0])
- max_value = np.max(logits, axis=1, keepdims=True)
- exp_data = np.exp(logits - max_value)
- probs = exp_data / np.sum(exp_data, axis=1, keepdims=True)
- out_dict = {"label": probs.argmax(axis=-1), "confidence": probs}
- return out_dict
-
- def predict(self, texts):
- input_map = self.preprocess(texts)
- infer_result = self.infer(input_map)
- output = self.postprocess(infer_result)
- return output
-
-
-if __name__ == "__main__":
- args = parse_arguments()
- predictor = Predictor(args)
- texts_ds = [
- "against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting",
- "the situation in a well-balanced fashion",
- "at achieving the modest , crowd-pleasing goals it sets for itself",
- "so pat it makes your teeth hurt",
- "this new jangle of noise , mayhem and stupidity must be a serious contender for the title .",
- ]
- label_map = {0: "negative", 1: "positive"}
- batch_texts = batchfy_text(texts_ds, args.batch_size)
- for bs, texts in enumerate(batch_texts):
- outputs = predictor.predict(texts)
- for i, sentence1 in enumerate(texts):
- print(
- f"Batch id: {bs}, example id: {i}, sentence1: {sentence1}, "
- f"label: {label_map[outputs['label'][i]]}, negative prob: {outputs['confidence'][i][0]:.4f}, "
- f"positive prob: {outputs['confidence'][i][1]:.4f}."
- )