【PaddleNLP No.18】Create new infer example for BERT validated on pd3.0 #10422


Merged · 1 commit · Apr 23, 2025
19 changes: 9 additions & 10 deletions slm/model_zoo/bert/README.md
@@ -4,7 +4,7 @@

[BERT](https://arxiv.org/abs/1810.04805) (Bidirectional Encoder Representations from Transformers) uses the [Transformer](https://arxiv.org/abs/1706.03762) encoder as its basic network component and is pre-trained on large-scale unlabeled text corpora with two tasks, Masked Language Model and Next Sentence Prediction, yielding a general-purpose semantic representation model that fuses bidirectional context. Starting from this pre-trained representation model and adding a simple task-specific output layer, fine-tuning adapts it to downstream NLP tasks, usually outperforming models trained directly on the downstream task. BERT previously achieved SOTA results on the [GLUE benchmark](https://gluebenchmark.com/tasks).
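
To make the pre-train/fine-tune workflow above concrete, here is a minimal sketch that loads a pre-trained BERT with PaddleNLP and attaches a task-specific classification head; the checkpoint name and `num_classes` value are illustrative assumptions, not part of this repository's scripts:

```python
# Minimal fine-tuning setup sketch: a pre-trained BERT body plus a simple
# classification output layer. Checkpoint name and num_classes are assumptions.
from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_classes=2)
```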

- This project is an open-source implementation of BERT on Paddle 2.0, including pre-training code and fine-tuning code for the [GLUE benchmark](https://gluebenchmark.com/tasks).
+ This project is an open-source implementation of BERT on Paddle 2.0, adapted to and validated on PaddlePaddle 3.0.0, including pre-training code and fine-tuning code for the [GLUE benchmark](https://gluebenchmark.com/tasks).

## Quick Start

@@ -262,23 +262,22 @@ python -u ./export_model.py \
The parameters are described as follows:
- `model_type` indicates the model type; set it to bert when using a BERT model.
- `model_path` is the save path of the trained model, the same as `output_dir` used during training.
- - `output_path` is the prefix for the exported inference model files. Suffixes (`pdiparams`, `pdiparams.info`, `pdmodel`) are appended on save; in addition, tokenizer files are saved to the directory containing `output_path`. A full invocation is sketched below.
+ - `output_path` is the prefix for the exported inference model files. Suffixes (`pdiparams`, `pdiparams.info`, `pdmodel`/`json`) are appended on save; in addition, tokenizer files are saved to the directory containing `output_path`. A full invocation is sketched below.
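
For reference, a complete export invocation might look like the following sketch; the `--model_path` value is an illustrative assumption and should point at your own training `output_dir`:

```shell
python -u ./export_model.py \
    --model_type bert \
    --model_path ./tmp/output/ \
    --output_path ./infer_model/model
```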

- After exporting the model, deployment can begin. `deploy/python/seq_cls_infer.py` provides a Python deployment example. Run the example with the following command:
+ After exporting the model, deployment can begin. `deploy/python/infer.py` provides a Python deployment example. Run the example with the following command:

```shell
- python deploy/python/seq_cls_infer.py --model_dir infer_model/ --device gpu --backend paddle
+ python deploy/python/infer.py --model_dir infer_model/ --device gpu
```

After running, the predictions are printed as follows:

```bash
- [INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::GPU.
- Batch id: 0, example id: 0, sentence1: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.0003, positive prob: 0.9997.
- Batch id: 1, example id: 0, sentence1: the situation in a well-balanced fashion, label: positive, negative prob: 0.0002, positive prob: 0.9998.
- Batch id: 2, example id: 0, sentence1: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.0017, positive prob: 0.9983.
- Batch id: 3, example id: 0, sentence1: so pat it makes your teeth hurt, label: negative, negative prob: 0.9986, positive prob: 0.0014.
- Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: negative, negative prob: 0.9806, positive prob: 0.0194.
+ Batch id: 0, example id: 0, sentence: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.4623, positive prob: 0.5377.
+ Batch id: 0, example id: 1, sentence: the situation in a well-balanced fashion, label: positive, negative prob: 0.3500, positive prob: 0.6500.
+ Batch id: 1, example id: 0, sentence: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.4530, positive prob: 0.5470.
+ Batch id: 1, example id: 1, sentence: so pat it makes your teeth hurt, label: positive, negative prob: 0.3816, positive prob: 0.6184.
+ Batch id: 2, example id: 0, sentence: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: positive, negative prob: 0.3650, positive prob: 0.6350.
```

For more detailed usage, see [Python Deployment](deploy/python/README.md).
117 changes: 12 additions & 105 deletions slm/model_zoo/bert/deploy/python/README.md
@@ -1,128 +1,35 @@
- # FastDeploy BERT Model Python Deployment Example
-
- Before deploying, install the FastDeploy Python SDK by following the [FastDeploy SDK installation docs](https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/cn/build_and_install/download_prebuilt_libraries.md).
-
- This directory provides `seq_cls_infer.py`, a Python deployment example for quickly running GLUE text classification tasks on CPU/GPU.
+ # BERT Model Python Inference Example
+ This directory provides `infer.py`, a Python example for quickly running GLUE text classification tasks on CPU/GPU.

## Quick Start

- The following example shows how to use the FastDeploy library to run Python deployment of a BERT model for the text classification task on the GLUE SST-2 dataset. The command-line arguments `--device` and `--backend` select the hardware and the inference-engine backend, and `--model_dir` specifies the model to run; see [Parameter Description](#parameter-description) below for details. The model in this example is the deployment model exported following the [BERT training docs](../../README.md), with model directory `model_zoo/bert/infer_model` (adjust to your actual setup).
+ The command-line argument `--device` selects the hardware, and `--model_dir` specifies the model to run; see [Parameter Description](#parameter-description) below for details. The model in this example is the deployment model exported following the [BERT training docs](../../README.md), with model directory `model_zoo/bert/infer_model` (adjust to your actual setup).


```bash
# CPU inference
- python seq_cls_infer.py --model_dir ../../infer_model/ --device cpu --backend paddle
+ python infer.py --model_dir ../../infer_model/ --device cpu
# GPU inference
- python seq_cls_infer.py --model_dir ../../infer_model/ --device gpu --backend paddle
+ python infer.py --model_dir ../../infer_model/ --device gpu
```

The results returned after running are as follows:

```bash
- [INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::GPU.
- Batch id: 0, example id: 0, sentence1: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.0003, positive prob: 0.9997.
- Batch id: 1, example id: 0, sentence1: the situation in a well-balanced fashion, label: positive, negative prob: 0.0002, positive prob: 0.9998.
- Batch id: 2, example id: 0, sentence1: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.0017, positive prob: 0.9983.
- Batch id: 3, example id: 0, sentence1: so pat it makes your teeth hurt, label: negative, negative prob: 0.9986, positive prob: 0.0014.
- Batch id: 4, example id: 0, sentence1: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: negative, negative prob: 0.9806, positive prob: 0.0194.
+ Batch id: 0, example id: 0, sentence: against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting, label: positive, negative prob: 0.4623, positive prob: 0.5377.
+ Batch id: 0, example id: 1, sentence: the situation in a well-balanced fashion, label: positive, negative prob: 0.3500, positive prob: 0.6500.
+ Batch id: 1, example id: 0, sentence: at achieving the modest , crowd-pleasing goals it sets for itself, label: positive, negative prob: 0.4530, positive prob: 0.5470.
+ Batch id: 1, example id: 1, sentence: so pat it makes your teeth hurt, label: positive, negative prob: 0.3816, positive prob: 0.6184.
+ Batch id: 2, example id: 0, sentence: this new jangle of noise , mayhem and stupidity must be a serious contender for the title ., label: positive, negative prob: 0.3650, positive prob: 0.6350.
```

## Parameter Description

| Parameter | Description |
|----------|--------------|
|--model_dir | Directory of the deployment model |
- |--batch_size | Input batch size; defaults to 1 |
+ |--batch_size | Input batch size; defaults to 2 |
|--max_length | Maximum sequence length; defaults to 128 |
|--device | Device to run on; one of ['cpu', 'gpu']; defaults to 'cpu' |
|--device_id | ID of the device to run on; defaults to 0 |
- |--cpu_threads | Number of CPU threads for inference when running on CPU; defaults to 1 |
- |--backend | Inference backend; one of ['onnx_runtime', 'paddle', 'openvino', 'tensorrt', 'paddle_tensorrt']; defaults to 'paddle' |
- |--use_fp16 | Whether to run inference in FP16 mode; available with the tensorrt and paddle_tensorrt backends; defaults to False |

- ## FastDeploy Advanced Usage
-
- On the Python side, FastDeploy provides the `fastdeploy.RuntimeOption.use_xxx()` and `fastdeploy.RuntimeOption.use_xxx_backend()` interfaces so developers can choose different hardware and inference engines for deployment. Deploying a BERT model on a given piece of hardware requires selecting an inference engine that the hardware supports; the table below shows which inference engines are available for deploying BERT on each kind of hardware.
-
- Legend: (1) ✅: supported; (2) ❔: in progress; (3) N/A: not yet supported.
-
- | Hardware | Hardware API | Available inference engine | Inference engine API | Supports Paddle new-format quantized models | Supports FP16 mode |
- |---|---|---|---|---|---|
- | CPU | use_cpu() | Paddle Inference | use_paddle_infer_backend() | ✅ | N/A |
- | CPU | use_cpu() | ONNX Runtime | use_ort_backend() | ✅ | N/A |
- | CPU | use_cpu() | OpenVINO | use_openvino_backend() | ❔ | N/A |
- | GPU | use_gpu() | Paddle Inference | use_paddle_infer_backend() | ✅ | N/A |
- | GPU | use_gpu() | ONNX Runtime | use_ort_backend() | ✅ | ❔ |
- | GPU | use_gpu() | Paddle TensorRT | use_paddle_infer_backend() + paddle_infer_option.enable_trt = True | ✅ | ✅ |
- | GPU | use_gpu() | TensorRT | use_trt_backend() | ✅ | ✅ |
- | Kunlunxin XPU | use_kunlunxin() | Paddle Lite | use_paddle_lite_backend() | N/A | ✅ |
- | Huawei Ascend | use_ascend() | Paddle Lite | use_paddle_lite_backend() | ❔ | ✅ |
- | Graphcore IPU | use_ipu() | Paddle Inference | use_paddle_infer_backend() | ❔ | N/A |
+ |--cpu_threads | Number of CPU threads for inference when running on CPU; defaults to 4 |
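
Combining these parameters, a CPU run with non-default settings might look like the following sketch (the values are illustrative):

```bash
python infer.py --model_dir ../../infer_model/ --device cpu --cpu_threads 8 --batch_size 4 --max_length 64
```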
123 changes: 123 additions & 0 deletions slm/model_zoo/bert/deploy/python/infer.py
@@ -0,0 +1,123 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import numpy as np
from paddle import inference
from scipy.special import softmax

from paddlenlp.transformers import AutoTokenizer
from paddlenlp.utils.env import (
    PADDLE_INFERENCE_MODEL_SUFFIX,
    PADDLE_INFERENCE_WEIGHTS_SUFFIX,
)


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", required=True, help="The directory of model.")
    parser.add_argument("--model_prefix", type=str, default="model", help="Prefix of the model file (no extension).")
    parser.add_argument("--device", choices=["gpu", "cpu"], default="cpu", help="Device for inference.")
    parser.add_argument("--device_id", type=int, default=0, help="GPU device ID if using GPU.")
    parser.add_argument("--cpu_threads", type=int, default=4, help="CPU threads if using CPU.")
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size for inference.")
    parser.add_argument("--max_length", type=int, default=128, help="Max sequence length.")
    return parser.parse_args()


def batchfy_text(texts, batch_size):
    return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]


class Predictor(object):
    def __init__(self, args):
        self.batch_size = args.batch_size
        self.max_length = args.max_length

        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)

        model_file = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_MODEL_SUFFIX}")
        params_file = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_WEIGHTS_SUFFIX}")

        if not os.path.exists(model_file):
            raise FileNotFoundError(f"Model file not found: {model_file}")
        if not os.path.exists(params_file):
            raise FileNotFoundError(f"Params file not found: {params_file}")

        config = inference.Config(model_file, params_file)
        if args.device == "gpu":
            config.enable_use_gpu(100, args.device_id)
        else:
            config.disable_gpu()
            config.set_cpu_math_library_num_threads(args.cpu_threads)

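        # Zero-copy I/O: exchange tensors through handles instead of feed/fetch ops.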
        config.switch_use_feed_fetch_ops(False)
        self.predictor = inference.create_predictor(config)
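        # Handle order must match the exported model's inputs: input_ids, then token_type_ids.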
        self.input_handles = [self.predictor.get_input_handle(name) for name in self.predictor.get_input_names()]
        self.output_handle = self.predictor.get_output_handle(self.predictor.get_output_names()[0])

    def preprocess(self, texts):
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_token_type_ids=True,
        )
        input_ids = np.array(encoded["input_ids"], dtype="int64")
        token_type_ids = np.array(encoded["token_type_ids"], dtype="int64")
        return input_ids, token_type_ids

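    # Copy the encoded batch to the input handles, run the model, and fetch the logits.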
    def infer(self, input_ids, token_type_ids):
        self.input_handles[0].copy_from_cpu(input_ids)
        self.input_handles[1].copy_from_cpu(token_type_ids)
        self.predictor.run()
        return self.output_handle.copy_to_cpu()

    def postprocess(self, logits):
        probs = softmax(logits, axis=1)
        return {"label": probs.argmax(axis=1), "confidence": probs}

    def predict(self, texts):
        input_ids, token_type_ids = self.preprocess(texts)
        logits = self.infer(input_ids, token_type_ids)
        return self.postprocess(logits)


if __name__ == "__main__":
    args = parse_arguments()
    predictor = Predictor(args)

    texts_ds = [
        "against shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting",
        "the situation in a well-balanced fashion",
        "at achieving the modest , crowd-pleasing goals it sets for itself",
        "so pat it makes your teeth hurt",
        "this new jangle of noise , mayhem and stupidity must be a serious contender for the title .",
    ]
    label_map = {0: "negative", 1: "positive"}

    batch_texts = batchfy_text(texts_ds, args.batch_size)

    for bs, texts in enumerate(batch_texts):
        outputs = predictor.predict(texts)
        for i, sentence in enumerate(texts):
            label = outputs["label"][i]
            confidence = outputs["confidence"][i]
            print(
                f"Batch id: {bs}, example id: {i}, sentence: {sentence}, "
                f"label: {label_map[label]}, negative prob: {confidence[0]:.4f}, positive prob: {confidence[1]:.4f}."
            )