
Commit 5392bdb

[PaddleNLP No.18] Create new infer example validated on pd3.0 (#10422)
1 parent fd22e8f commit 5392bdb


4 files changed, +144 -268 lines changed


slm/model_zoo/bert/README.md

Lines changed: 9 additions & 10 deletions
@@ -4,7 +4,7 @@
 
 [BERT](https://arxiv.org/abs/1810.04805) (Bidirectional Encoder Representations from Transformers) uses the [Transformer](https://arxiv.org/abs/1706.03762) encoder as its basic network component. It is pre-trained on large-scale unlabeled text corpora with two tasks, Masked Language Model and Next Sentence Prediction, to obtain a general-purpose semantic representation model that fuses bidirectional context. Building on this pre-trained representation model with a simple task-specific output layer, fine-tuning adapts it to downstream NLP tasks, usually outperforming models trained directly on the downstream task. BERT previously achieved SOTA results on the [GLUE benchmark](https://gluebenchmark.com/tasks).
 
-This project is an open-source implementation of BERT on Paddle 2.0, including code for pre-training and for fine-tuning on the [GLUE benchmark](https://gluebenchmark.com/tasks).
+This project is an open-source implementation of BERT on Paddle 2.0, adapted to and validated on PaddlePaddle 3.0.0, including code for pre-training and for fine-tuning on the [GLUE benchmark](https://gluebenchmark.com/tasks).
 
 ## Quick Start

@@ -262,23 +262,22 @@ python -u ./export_model.py \
 The parameters are described as follows:
 - `model_type`: the model type; set it to bert when using the BERT model.
 - `model_path`: the save path of the trained model, the same as the `output_dir` used during training.
-- `output_path`: the prefix of the exported inference model files. Suffixes (`pdiparams`, `pdiparams.info`, `pdmodel`) are appended on save; in addition, tokenizer files are saved under the directory containing `output_path`.
+- `output_path`: the prefix of the exported inference model files. Suffixes (`pdiparams`, `pdiparams.info`, `pdmodel`/`json`) are appended on save; in addition, tokenizer files are saved under the directory containing `output_path`.
 
-After the model is exported, deployment can begin. `deploy/python/seq_cls_infer.py` provides a Python deployment and prediction example, which can be run with the following command:
+After the model is exported, deployment can begin. `deploy/python/infer.py` provides a Python deployment and prediction example, which can be run with the following command:
 
 ```shell
-python deploy/python/seq_cls_infer.py --model_dir infer_model/ --device gpu --backend paddle
+python deploy/python/infer.py --model_dir infer_model/ --device gpu
 ```
 
 After running, the predictions are printed as follows:
 
 ```bash
-[INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::GPU.
-Batch id: 0, example id: 0, sentence1: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.0003, positive prob: 0.9997.
-Batch id: 1, example id: 0, sentence1: the situation in a well-balanced fashion, label: positive, negative prob: 0.0002, positive prob: 0.9998.
-Batch id: 2, example id: 0, sentence1: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.0017, positive prob: 0.9983.
-Batch id: 3, example id: 0, sentence1: so pat it makes your teeth hurt, label: negative, negative prob: 0.9986, positive prob: 0.0014.
-Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: negative, negative prob: 0.9806, positive prob: 0.0194.
+Batch id: 0, example id: 0, sentence: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.4623, positive prob: 0.5377.
+Batch id: 0, example id: 1, sentence: the situation in a well-balanced fashion, label: positive, negative prob: 0.3500, positive prob: 0.6500.
+Batch id: 1, example id: 0, sentence: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.4530, positive prob: 0.5470.
+Batch id: 1, example id: 1, sentence: so pat it makes your teeth hurt, label: positive, negative prob: 0.3816, positive prob: 0.6184.
+Batch id: 2, example id: 0, sentence: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: positive, negative prob: 0.3650, positive prob: 0.6350.
 ```
 
 For more detailed usage, see [Python Deployment](deploy/python/README.md).
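
A hedged sketch of the suffix convention described above, tying the `--output_path` prefix to the files the new `infer.py` expects (paths are illustrative; the suffix constants are the same ones `deploy/python/infer.py` imports from `paddlenlp.utils.env`):

```python
import os

from paddlenlp.utils.env import (
    PADDLE_INFERENCE_MODEL_SUFFIX,  # graph file extension, ".json" or ".pdmodel" depending on the Paddle version
    PADDLE_INFERENCE_WEIGHTS_SUFFIX,  # weights file extension, ".pdiparams"
)

output_path = "infer_model/model"  # illustrative --output_path prefix
model_file = output_path + PADDLE_INFERENCE_MODEL_SUFFIX  # e.g. infer_model/model.json
params_file = output_path + PADDLE_INFERENCE_WEIGHTS_SUFFIX  # e.g. infer_model/model.pdiparams
print(model_file, os.path.exists(model_file))
print(params_file, os.path.exists(params_file))
```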
slm/model_zoo/bert/deploy/python/README.md

Lines changed: 12 additions & 105 deletions
@@ -1,128 +1,35 @@
-# FastDeploy BERT Model Python Deployment Example
-
-Before deploying, install the FastDeploy Python SDK as described in the [FastDeploy SDK installation docs](https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/cn/build_and_install/download_prebuilt_libraries.md).
-
-This directory provides `seq_cls_infer.py`, a Python deployment example that quickly runs the GLUE text classification task on CPU/GPU.
+# BERT Model Python Inference Example
+This directory provides `infer.py`, a Python example that quickly runs the GLUE text classification task on CPU/GPU.
 
 ## Quick Start
 
-The following example shows how to use the FastDeploy library to run Python inference deployment of a BERT model on the GLUE SST-2 dataset. The command-line arguments `--device` and `--backend` select the hardware and the inference engine backend, and `--model_dir` specifies the model to run; see the [parameter description](#parameter-description) below for details. The model in the example is the deployment model exported per the [BERT training docs](../../README.md), with model directory `model_zoo/bert/infer_model` (adjust to your actual setup).
+The command-line argument `--device` selects the hardware, and `--model_dir` specifies the model to run; see the [parameter description](#parameter-description) below for details. The model in the example is the deployment model exported per the [BERT training docs](../../README.md), with model directory `model_zoo/bert/infer_model` (adjust to your actual setup).
 
 
 ```bash
 # CPU inference
-python seq_cls_infer.py --model_dir ../../infer_model/ --device cpu --backend paddle
+python infer.py --model_dir ../../infer_model/ --device cpu
 # GPU inference
-python seq_cls_infer.py --model_dir ../../infer_model/ --device gpu --backend paddle
+python infer.py --model_dir ../../infer_model/ --device gpu
 ```
 
 The results returned after running are as follows:
 
 ```bash
-[INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::GPU.
-Batch id: 0, example id: 0, sentence1: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.0003, positive prob: 0.9997.
-Batch id: 1, example id: 0, sentence1: the situation in a well-balanced fashion, label: positive, negative prob: 0.0002, positive prob: 0.9998.
-Batch id: 2, example id: 0, sentence1: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.0017, positive prob: 0.9983.
-Batch id: 3, example id: 0, sentence1: so pat it makes your teeth hurt, label: negative, negative prob: 0.9986, positive prob: 0.0014.
-Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: negative, negative prob: 0.9806, positive prob: 0.0194.
+Batch id: 0, example id: 0, sentence: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.4623, positive prob: 0.5377.
+Batch id: 0, example id: 1, sentence: the situation in a well-balanced fashion, label: positive, negative prob: 0.3500, positive prob: 0.6500.
+Batch id: 1, example id: 0, sentence: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.4530, positive prob: 0.5470.
+Batch id: 1, example id: 1, sentence: so pat it makes your teeth hurt, label: positive, negative prob: 0.3816, positive prob: 0.6184.
+Batch id: 2, example id: 0, sentence: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: positive, negative prob: 0.3650, positive prob: 0.6350.
 ```
 
 ## Parameter Description
 
 | Parameter | Description |
 |----------|--------------|
 |--model_dir | Directory of the deployment model. |
-|--batch_size | Input batch size; defaults to 1. |
+|--batch_size | Input batch size; defaults to 2. |
 |--max_length | Maximum sequence length; defaults to 128. |
 |--device | Device to run on; choices: ['cpu', 'gpu']; defaults to 'cpu'. |
 |--device_id | ID of the device to run on; defaults to 0. |
-|--cpu_threads | Number of CPU threads when running inference on CPU; defaults to 1. |
-|--backend | Inference backend; choices: ['onnx_runtime', 'paddle', 'openvino', 'tensorrt', 'paddle_tensorrt']; defaults to 'paddle'. |
-|--use_fp16 | Whether to run inference in FP16 mode; available with the tensorrt and paddle_tensorrt backends; defaults to False. |
-
-## Advanced FastDeploy Usage
-
-On the Python side, FastDeploy provides the `fastdeploy.RuntimeOption.use_xxx()` and `fastdeploy.RuntimeOption.use_xxx_backend()` interfaces for developers to pick the hardware and inference engine used for deployment. Deploying the BERT model on a given piece of hardware requires an inference engine supported by that hardware; the table below shows the available inference engines for the BERT model on each hardware.
-
-Legend: (1) ✅: supported; (2) ❔: in progress; (3) N/A: not supported.
-
-<table>
-<tr>
-<td align=center> Hardware </td>
-<td align=center> Hardware API </td>
-<td align=center> Available inference engine </td>
-<td align=center> Inference engine API </td>
-<td align=center> Supports Paddle new-format quantized models </td>
-<td align=center> Supports FP16 mode </td>
-</tr>
-<tr>
-<td rowspan=3 align=center> CPU </td>
-<td rowspan=3 align=center> use_cpu() </td>
-<td align=center> Paddle Inference </td>
-<td align=center> use_paddle_infer_backend() </td>
-<td align=center> ✅ </td>
-<td align=center> N/A </td>
-</tr>
-<tr>
-<td align=center> ONNX Runtime </td>
-<td align=center> use_ort_backend() </td>
-<td align=center> ✅ </td>
-<td align=center> N/A </td>
-</tr>
-<tr>
-<td align=center> OpenVINO </td>
-<td align=center> use_openvino_backend() </td>
-<td align=center> ❔ </td>
-<td align=center> N/A </td>
-</tr>
-<tr>
-<td rowspan=4 align=center> GPU </td>
-<td rowspan=4 align=center> use_gpu() </td>
-<td align=center> Paddle Inference </td>
-<td align=center> use_paddle_infer_backend() </td>
-<td align=center> ✅ </td>
-<td align=center> N/A </td>
-</tr>
-<tr>
-<td align=center> ONNX Runtime </td>
-<td align=center> use_ort_backend() </td>
-<td align=center> ✅ </td>
-<td align=center> ❔ </td>
-</tr>
-<tr>
-<td align=center> Paddle TensorRT </td>
-<td align=center> use_paddle_infer_backend() + paddle_infer_option.enable_trt = True </td>
-<td align=center> ✅ </td>
-<td align=center> ✅ </td>
-</tr>
-<tr>
-<td align=center> TensorRT </td>
-<td align=center> use_trt_backend() </td>
-<td align=center> ✅ </td>
-<td align=center> ✅ </td>
-</tr>
-<tr>
-<td align=center> Kunlunxin XPU </td>
-<td align=center> use_kunlunxin() </td>
-<td align=center> Paddle Lite </td>
-<td align=center> use_paddle_lite_backend() </td>
-<td align=center> N/A </td>
-<td align=center> ✅ </td>
-</tr>
-<tr>
-<td align=center> Huawei Ascend </td>
-<td align=center> use_ascend() </td>
-<td align=center> Paddle Lite </td>
-<td align=center> use_paddle_lite_backend() </td>
-<td align=center> ❔ </td>
-<td align=center> ✅ </td>
-</tr>
-<tr>
-<td align=center> Graphcore IPU </td>
-<td align=center> use_ipu() </td>
-<td align=center> Paddle Inference </td>
-<td align=center> use_paddle_infer_backend() </td>
-<td align=center> ❔ </td>
-<td align=center> N/A </td>
-</tr>
-</table>
+|--cpu_threads | Number of CPU threads when running inference on CPU; defaults to 4. |
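
Since `--batch_size` now defaults to 2, the sample log above groups two sentences per batch (example id 0 and 1). A minimal sketch of the batching this flag controls, mirroring the `batchfy_text` helper in `infer.py`:

```python
def batchfy_text(texts, batch_size):
    # Consecutive chunks of batch_size; the final chunk may be smaller.
    return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]

print(batchfy_text(["s0", "s1", "s2", "s3", "s4"], 2))
# [['s0', 's1'], ['s2', 's3'], ['s4']] -> Batch ids 0-2, example ids 0-1
```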
slm/model_zoo/bert/deploy/python/infer.py (new file)

Lines changed: 123 additions & 0 deletions
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import numpy as np
from paddle import inference
from scipy.special import softmax

from paddlenlp.transformers import AutoTokenizer
from paddlenlp.utils.env import (
    PADDLE_INFERENCE_MODEL_SUFFIX,
    PADDLE_INFERENCE_WEIGHTS_SUFFIX,
)
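# The suffix constants above hold the exported-file extensions: the graph as
# ".json" or ".pdmodel" (depending on the Paddle version) and the weights as
# ".pdiparams", matching the suffix note in the README.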


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", required=True, help="The directory of model.")
    parser.add_argument("--model_prefix", type=str, default="model", help="Prefix of the model file (no extension).")
    parser.add_argument("--device", choices=["gpu", "cpu"], default="cpu", help="Device for inference.")
    parser.add_argument("--device_id", type=int, default=0, help="GPU device ID if using GPU.")
    parser.add_argument("--cpu_threads", type=int, default=4, help="CPU threads if using CPU.")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for inference.")
    parser.add_argument("--max_length", type=int, default=128, help="Max sequence length.")
    return parser.parse_args()


def batchfy_text(texts, batch_size):
    return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]


class Predictor(object):
    def __init__(self, args):
        self.batch_size = args.batch_size
        self.max_length = args.max_length

        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)

        model_file = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_MODEL_SUFFIX}")
        params_file = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_WEIGHTS_SUFFIX}")

        if not os.path.exists(model_file):
            raise FileNotFoundError(f"Model file not found: {model_file}")
        if not os.path.exists(params_file):
            raise FileNotFoundError(f"Params file not found: {params_file}")

        config = inference.Config(model_file, params_file)
        if args.device == "gpu":
            config.enable_use_gpu(100, args.device_id)
        else:
            config.disable_gpu()
            config.set_cpu_math_library_num_threads(args.cpu_threads)
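
        # Feed/fetch ops are disabled below so that inputs and outputs are exchanged
        # through explicit handles (copy_from_cpu / copy_to_cpu) rather than graph ops.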
        config.switch_use_feed_fetch_ops(False)
        self.predictor = inference.create_predictor(config)
        self.input_handles = [self.predictor.get_input_handle(name) for name in self.predictor.get_input_names()]
        self.output_handle = self.predictor.get_output_handle(self.predictor.get_output_names()[0])

    def preprocess(self, texts):
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_token_type_ids=True,
        )
        input_ids = np.array(encoded["input_ids"], dtype="int64")
        token_type_ids = np.array(encoded["token_type_ids"], dtype="int64")
        return input_ids, token_type_ids

    def infer(self, input_ids, token_type_ids):
        self.input_handles[0].copy_from_cpu(input_ids)
        self.input_handles[1].copy_from_cpu(token_type_ids)
        self.predictor.run()
        return self.output_handle.copy_to_cpu()

    def postprocess(self, logits):
        probs = softmax(logits, axis=1)
        return {"label": probs.argmax(axis=1), "confidence": probs}

    def predict(self, texts):
        input_ids, token_type_ids = self.preprocess(texts)
        logits = self.infer(input_ids, token_type_ids)
        return self.postprocess(logits)


if __name__ == "__main__":
    args = parse_arguments()
    predictor = Predictor(args)

    texts_ds = [
        "against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting",
        "the situation in a well-balanced fashion",
        "at achieving the modest , crowd-pleasing goals it sets for itself",
        "so pat it makes your teeth hurt",
        "this new jangle of noise , mayhem and stupidity must be a serious contender for the title .",
    ]
    label_map = {0: "negative", 1: "positive"}

    batch_texts = batchfy_text(texts_ds, args.batch_size)

    for bs, texts in enumerate(batch_texts):
        outputs = predictor.predict(texts)
        for i, sentence in enumerate(texts):
            label = outputs["label"][i]
            confidence = outputs["confidence"][i]
            print(
                f"Batch id: {bs}, example id: {i}, sentence: {sentence}, "
                f"label: {label_map[label]}, negative prob: {confidence[0]:.4f}, positive prob: {confidence[1]:.4f}."
            )
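
For programmatic use, the `Predictor` above can also be driven without the CLI. A hedged sketch, assuming `infer.py` is on the import path and an exported model exists under `infer_model/` (paths illustrative):

```python
from argparse import Namespace

from infer import Predictor  # assumes infer.py is importable as a module

# Namespace stands in for the object parse_arguments() would return.
args = Namespace(
    model_dir="infer_model/",  # illustrative path to the exported model
    model_prefix="model",
    device="cpu",
    device_id=0,
    cpu_threads=4,
    batch_size=2,
    max_length=128,
)
predictor = Predictor(args)
outputs = predictor.predict(["so pat it makes your teeth hurt"])
print(outputs["label"], outputs["confidence"])  # class indices and softmax probabilities
```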
