From b96de4bc7138b7568a28b48c3013a4fd9a1872b8 Mon Sep 17 00:00:00 2001 From: hanlintang Date: Wed, 23 Apr 2025 10:00:43 +0000 Subject: [PATCH] [PIR]Fix ernie-3.0 deploy&add training without hf --- slm/model_zoo/ernie-3.0/README.md | 14 +- .../ernie-3.0/deploy/python/README.md | 188 +++------------- .../ernie-3.0/deploy/python/seq_cls_infer.py | 137 +++++------- .../deploy/python/token_cls_infer.py | 145 +++++------- slm/model_zoo/ernie-3.0/infer.py | 13 +- .../ernie-3.0/run_token_cls_without_hf.py | 211 ++++++++++++++++++ 6 files changed, 372 insertions(+), 336 deletions(-) create mode 100644 slm/model_zoo/ernie-3.0/run_token_cls_without_hf.py diff --git a/slm/model_zoo/ernie-3.0/README.md b/slm/model_zoo/ernie-3.0/README.md index 5fd86eae0966..dfaec475a1c1 100644 --- a/slm/model_zoo/ernie-3.0/README.md +++ b/slm/model_zoo/ernie-3.0/README.md @@ -1293,6 +1293,7 @@ batch_size=32 和 1,预测精度为 FP16 时,GPU 下的效果-时延图: ├── compress_token_cls.py # 序列标注任务的压缩脚本 ├── compress_qa.py # 阅读理解任务的压缩脚本 ├── utils.py # 训练工具脚本 +├── infer.py # 推理脚本 ├── configs # 压缩配置文件夹 │ └── default.yml # 默认配置文件 ├── deploy # 部署目录 @@ -1375,6 +1376,9 @@ python run_seq_cls.py --model_name_or_path ernie-3.0-medium-zh --dataset afqmc # 序列标注任务 python run_token_cls.py --model_name_or_path ernie-3.0-medium-zh --dataset msra_ner --output_dir ./best_models --export_model_dir best_models/ --do_train --do_eval --do_export --config=configs/default.yml +# 如果无法连接huggingface +python run_token_cls_without_hf.py --model_name_or_path ernie-3.0-medium-zh --dataset msra_ner --output_dir ./best_models --export_model_dir best_models/ --do_train --do_eval --do_export --config=configs/default.yml + # 阅读理解任务 python run_qa.py --model_name_or_path ernie-3.0-medium-zh --dataset cmrc2018 --output_dir ./best_models --export_model_dir best_models/ --do_train --do_eval --do_export --config=configs/default.yml ``` @@ -1527,6 +1531,14 @@ python compress_qa.py --model_name_or_path best_models/cmrc2018/ --dataset cmrc2 三类任务(分类、序列标注、阅读理解)经过裁剪 + 量化后加速比均达到 3 倍左右,所有任务上平均精度损失可控制在 0.5 以内(0.46)。 + + +## 推理 +目录中的 ```infer.py```提供了使用导出模型进行推理的样例。运行命令: +```shell +python infer.py --model_name_or_path ernie-3.0-medium-zh --model_path ./best_models/afqmc/export/ +``` + ## 部署 @@ -1549,7 +1561,7 @@ python compress_qa.py --model_name_or_path best_models/cmrc2018/ --dataset cmrc2 -#### Python 部署 +### Python 部署 Python 部署请参考:[Python 部署指南](./deploy/python/README.md) diff --git a/slm/model_zoo/ernie-3.0/deploy/python/README.md b/slm/model_zoo/ernie-3.0/deploy/python/README.md index 79728e26f07d..ba0921c9d56b 100644 --- a/slm/model_zoo/ernie-3.0/deploy/python/README.md +++ b/slm/model_zoo/ernie-3.0/deploy/python/README.md @@ -8,51 +8,37 @@ ### 快速开始 -以下示例展示如何基于 FastDeploy 库完成 ERNIE 3.0 Medium 模型在 CLUE Benchmark 的 [AFQMC 数据集](https://github.com/CLUEbenchmark/CLUE)上进行文本分类任务的 Python 预测部署,可通过命令行参数`--device`以及`--backend`指定运行在不同的硬件以及推理引擎后端,并使用`--model_dir`参数指定运行的模型,具体参数设置可查看下面[参数说明](#参数说明)。示例中的模型是按照 [ERNIE 3.0 训练文档](../../README.md)导出得到的部署模型,其模型目录为`model_zoo/ernie-3.0/best_models/afqmc/export`(用户可按实际情况设置)。 +以下示例展示如何完成 ERNIE 3.0 Medium 模型在 CLUE Benchmark 的 [AFQMC 数据集](https://github.com/CLUEbenchmark/CLUE)上进行文本分类任务的 Python 预测部署,可通过命令行参数`--device`指定运行在不同的硬件,并使用`--model_dir`参数指定运行的模型,具体参数设置可查看下面[参数说明](#参数说明)。示例中的模型是按照 [ERNIE 3.0 训练文档](../../README.md)导出得到的部署模型,其模型目录为`model_zoo/ernie-3.0/best_models/afqmc/export`(用户可按实际情况设置)。 ```bash # CPU 推理 -python seq_cls_infer.py --model_dir ../../best_models/afqmc/export --device cpu --backend paddle +python seq_cls_infer.py --model_dir 
../../best_models/afqmc/export --device cpu # GPU 推理 -python seq_cls_infer.py --model_dir ../../best_models/afqmc/export --device gpu --backend paddle +python seq_cls_infer.py --model_dir ../../best_models/afqmc/export --device gpu ``` 运行完成后返回的结果如下: ```bash +I0423 05:00:21.622229 8408 print_statistics.cc:44] --- detected [85, 273] subgraphs! +--- Running PIR pass [dead_code_elimination_pass] +I0423 05:00:21.622710 8408 print_statistics.cc:50] --- detected [113] subgraphs! +--- Running PIR pass [replace_fetch_with_shadow_output_pass] +I0423 05:00:21.622859 8408 print_statistics.cc:50] --- detected [1] subgraphs! +--- Running PIR pass [remove_shadow_feed_pass] +I0423 05:00:21.626749 8408 print_statistics.cc:50] --- detected [2] subgraphs! +--- Running PIR pass [inplace_pass] +I0423 05:00:21.631474 8408 print_statistics.cc:50] --- detected [2] subgraphs! +I0423 05:00:21.631560 8408 analysis_predictor.cc:1186] ======= pir optimization completed ======= +I0423 05:00:21.641817 8408 pir_interpreter.cc:1640] pir interpreter is running by trace mode ... +Batch 0, example 0 | s1: 花呗收款额度限制 | s2: 收钱码,对花呗支付的金额有限制吗 | label: 1 | score: 0.5175 +Batch 1, example 0 | s1: 花呗支持高铁票支付吗 | s2: 为什么友付宝不支持花呗付款 | label: 0 | score: 0.9873 -[INFO] fastdeploy/runtime.cc(596)::Init Runtime initialized with Backend::PDINFER in Device::CPU. -Batch id:0, example id:0, sentence1:花呗收款额度限制, sentence2:收钱码,对花呗支付的金额有限制吗, label:0, similarity:0.5099 -Batch id:1, example id:0, sentence1:花呗支持高铁票支付吗, sentence2:为什么友付宝不支持花呗付款, label:0, similarity:0.9862 - -``` - -### 量化模型部署 - -该示例支持部署 Paddle INT8 新格式量化模型,仅需在`--model_dir`参数传入量化模型路径,并且在对应硬件上选择可用的推理引擎后端,即可完成量化模型部署。在 GPU 上部署量化模型时,可选后端为`paddle_tensorrt`、`tensorrt`;在 CPU 上部署量化模型时,可选后端为`paddle`、`onnx_runtime`。下面将展示如何使用该示例完成量化模型部署,示例中的模型是按照 [ERNIE 3.0 训练文档](../../README.md) 压缩量化后导出得到的量化模型。 - -```bash - -# 在GPU上使用 tensorrt 后端,模型目录可按照实际模型路径设置 -python seq_cls_infer.py --model_dir ../../best_models/afqmc/width_mult_0.75/mse16_1/ --device gpu --backend tensorrt --model_prefix int8 - -# 在CPU上使用paddle_inference后端,模型目录可按照实际模型路径设置 -python seq_cls_infer.py --model_dir ../../best_models/afqmc/width_mult_0.75/mse16_1/ --device cpu --backend paddle --model_prefix int8 - -``` - -运行完成后返回的结果如下: - -```bash -[INFO] fastdeploy/runtime/runtime.cc(101)::Init Runtime initialized with Backend::PDINFER in Device::GPU. 
-Batch id:0, example id:0, sentence1:花呗收款额度限制, sentence2:收钱码,对花呗支付的金额有限制吗, label:0, similarity:0.5224 -Batch id:1, example id:0, sentence1:花呗支持高铁票支付吗, sentence2:为什么友付宝不支持花呗付款, label:0, similarity:0.9856 ``` - ### 参数说明 `seq_cls_infer.py` 除了以上示例的命令行参数,还支持更多命令行参数的设置。以下为各命令行参数的说明。 @@ -63,66 +49,32 @@ Batch id:1, example id:0, sentence1:花呗支持高铁票支付吗, sentence2: |--batch_size |输入的 batch size,默认为 1| |--max_length |最大序列长度,默认为 128| |--device | 运行的设备,可选范围: ['cpu', 'gpu'],默认为'cpu' | -|--backend | 支持的推理后端,可选范围: ['onnx_runtime', 'paddle', 'openvino', 'tensorrt', 'paddle_tensorrt'],默认为'paddle' | -|--use_fp16 | 是否使用 FP16模式进行推理。使用 tensorrt 和 paddle_tensorrt 后端时可开启,默认为 False | ## 序列标注任务 ### 快速开始 -以下示例展示如何基于 FastDeploy 库完成 ERNIE 3.0 Medium 模型在 CLUE Benchmark 的[ MSRA_NER 数据集](https://github.com/lemonhu/NER-BERT-pytorch/tree/master/data/msra)上进行序列标注任务的 Python 预测部署,可通过命令行参数`--device`以及`--backend`指定运行在不同的硬件以及推理引擎后端,并使用`--model_dir`参数指定运行的模型,具体参数设置可查看下面[参数说明](#参数说明)。示例中的模型是按照 [ERNIE 3.0 训练文档](../../README.md)导出得到的部署模型,其模型目录为`model_zoo/ernie-3.0/best_models/msra_ner/export`(用户可按实际情况设置)。 +以下示例展示如何完成 ERNIE 3.0 Medium 模型在 CLUE Benchmark 的[ MSRA_NER 数据集](https://github.com/lemonhu/NER-BERT-pytorch/tree/master/data/msra)上进行序列标注任务的 Python 预测部署,可通过命令行参数`--device`指定运行在不同的硬件,并使用`--model_dir`参数指定运行的模型,具体参数设置可查看下面[参数说明](#参数说明)。示例中的模型是按照 [ERNIE 3.0 训练文档](../../README.md)导出得到的部署模型,其模型目录为`model_zoo/ernie-3.0/best_models/msra_ner/export`(用户可按实际情况设置)。 ```bash # CPU 推理 -python token_cls_infer.py --model_dir ../../best_models/msra_ner/export/ --device cpu --backend paddle +python token_cls_infer.py --model_dir ../../best_models/msra_ner/export/ --device cpu # GPU 推理 -python token_cls_infer.py --model_dir ../../best_models/msra_ner/export/ --device gpu --backend paddle - -``` - -运行完成后返回的结果如下: - -```bash - -[INFO] fastdeploy/runtime.cc(500)::Init Runtime initialized with Backend::PDINFER in Device::CPU. -input data: 北京的涮肉,重庆的火锅,成都的小吃都是极具特色的美食。 -The model detects all entities: -entity: 北京 label: LOC pos: [0, 1] -entity: 重庆 label: LOC pos: [6, 7] -entity: 成都 label: LOC pos: [12, 13] ------------------------------ -input data: 乔丹、科比、詹姆斯和姚明都是篮球界的标志性人物。 -The model detects all entities: -entity: 乔丹 label: PER pos: [0, 1] -entity: 科比 label: PER pos: [3, 4] -entity: 詹姆斯 label: PER pos: [6, 8] -entity: 姚明 label: PER pos: [10, 11] ------------------------------ - -``` - -### 量化模型部署 - -该示例支持部署 Paddle INT8 新格式量化模型,仅需在`--model_dir`参数传入量化模型路径,并且在对应硬件上选择可用的推理引擎后端,即可完成量化模型部署。在 GPU 上部署量化模型时,可选后端为`paddle_tensorrt`、`tensorrt`;在 CPU 上部署量化模型时,可选后端为`paddle`、`onnx_runtime`。下面将展示如何使用该示例完成量化模型部署,示例中的模型是按照 [ERNIE 3.0 训练文档](../../README.md) 压缩量化后导出得到的量化模型。 - -```bash - -# 在GPU上使用 tensorrt 后端,模型目录可按照实际模型路径设置 -python token_cls_infer.py --model_dir ../../best_models/msra_ner/width_mult_0.75/mse16_1/ --device gpu --backend tensorrt --model_prefix int8 - -# 在CPU上使用paddle_inference后端,模型目录可按照实际模型路径设置 -python token_cls_infer.py --model_dir ../../best_models/msra_ner/width_mult_0.75/mse16_1/ --device cpu --backend paddle --model_prefix int8 +python token_cls_infer.py --model_dir ../../best_models/msra_ner/export/ --device gpu ``` 运行完成后返回的结果如下: ```bash - -[INFO] fastdeploy/runtime.cc(500)::Init Runtime initialized with Backend::PDINFER in Device::CPU. +...... +--- Running PIR pass [inplace_pass] +I0423 09:51:42.250245 4644 print_statistics.cc:50] --- detected [1] subgraphs! +I0423 09:51:42.250334 4644 analysis_predictor.cc:1186] ======= pir optimization completed ======= +I0423 09:51:42.261358 4644 pir_interpreter.cc:1640] pir interpreter is running by trace mode ... 
input data: 北京的涮肉,重庆的火锅,成都的小吃都是极具特色的美食。 The model detects all entities: entity: 北京 label: LOC pos: [0, 1] @@ -148,97 +100,7 @@ entity: 姚明 label: PER pos: [10, 11] |--batch_size |输入的 batch size,默认为 1| |--max_length |最大序列长度,默认为 128| |--device | 运行的设备,可选范围: ['cpu', 'gpu'],默认为'cpu' | -|--backend | 支持的推理后端,可选范围: ['onnx_runtime', 'paddle', 'openvino', 'tensorrt', 'paddle_tensorrt'],默认为'paddle' | -|--use_fp16 | 是否使用 FP16模式进行推理。使用 tensorrt 和 paddle_tensorrt 后端时可开启,默认为 False | -|--model_prefix| 模型文件前缀。前缀会分别与'.pdmodel'和'.pdiparams'拼接得到模型文件名和参数文件名。默认为 'model'| - - -## FastDeploy 高阶用法 - -FastDeploy 在 Python 端上,提供 `fastdeploy.RuntimeOption.use_xxx()` 以及 `fastdeploy.RuntimeOption.use_xxx_backend()` 接口支持开发者选择不同的硬件、不同的推理引擎进行部署。在不同的硬件上部署 ERNIE 3.0 模型,需要选择硬件所支持的推理引擎进行部署,下表展示如何在不同的硬件上选择可用的推理引擎部署 ERNIE 3.0 模型。 - -符号说明: (1) ✅: 已经支持; (2) ❔: 正在进行中; (3) N/A: 暂不支持; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-| 硬件 | 硬件对应的接口 | 可用的推理引擎 | 推理引擎对应的接口 | 是否支持 Paddle 新格式量化模型 | 是否支持 FP16 模式 |
-|------|----------------|----------------|--------------------|---------------------------------|--------------------|
-| CPU | use_cpu() | Paddle Inference | use_paddle_infer_backend() | ✅ | N/A |
-| | | ONNX Runtime | use_ort_backend() | ✅ | N/A |
-| | | OpenVINO | use_openvino_backend() | ❔ | N/A |
-| GPU | use_gpu() | Paddle Inference | use_paddle_infer_backend() | ✅ | N/A |
-| | | ONNX Runtime | use_ort_backend() | ✅ | ✅ |
-| | | Paddle TensorRT | use_trt_backend() + enable_paddle_to_trt() | ✅ | ✅ |
-| | | TensorRT | use_trt_backend() | ✅ | ✅ |
-| 昆仑芯 XPU | use_kunlunxin() | Paddle Lite | use_paddle_lite_backend() | N/A | ✅ |
-| 华为 昇腾 | use_ascend() | Paddle Lite | use_paddle_lite_backend() | ❔ | ✅ |
-| Graphcore IPU | use_ipu() | Paddle Inference | use_paddle_infer_backend() | ❔ | N/A |
+|--model_prefix| 模型文件前缀。前缀会分别与'PADDLE_INFERENCE_MODEL_SUFFIX'和'PADDLE_INFERENCE_WEIGHTS_SUFFIX'拼接得到模型文件名和参数文件名。默认为 'model'| ## 相关文档 diff --git a/slm/model_zoo/ernie-3.0/deploy/python/seq_cls_infer.py b/slm/model_zoo/ernie-3.0/deploy/python/seq_cls_infer.py index 8d8a11505b07..e46e70007e1a 100644 --- a/slm/model_zoo/ernie-3.0/deploy/python/seq_cls_infer.py +++ b/slm/model_zoo/ernie-3.0/deploy/python/seq_cls_infer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,40 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import distutils.util + +import argparse import os -import fastdeploy as fd import numpy as np +import paddle.inference as paddle_infer from paddlenlp.transformers import AutoTokenizer +from paddlenlp.utils.env import ( + PADDLE_INFERENCE_MODEL_SUFFIX, + PADDLE_INFERENCE_WEIGHTS_SUFFIX, +) def parse_arguments(): - import argparse - parser = argparse.ArgumentParser() parser.add_argument("--model_dir", required=True, help="The directory of model.") parser.add_argument("--vocab_path", type=str, default="", help="The path of tokenizer vocab.") parser.add_argument("--model_prefix", type=str, default="model", help="The model and params file prefix.") - parser.add_argument( - "--device", - type=str, - default="cpu", - choices=["gpu", "cpu"], - help="Type of inference device, support 'cpu' or 'gpu'.", - ) - parser.add_argument( - "--backend", - type=str, - default="paddle", - choices=["onnx_runtime", "paddle", "openvino", "tensorrt", "paddle_tensorrt"], - help="The inference runtime backend.", - ) - parser.add_argument("--batch_size", type=int, default=1, help="The batch size of data.") - parser.add_argument("--max_length", type=int, default=128, help="The max length of sequence.") - parser.add_argument("--log_interval", type=int, default=10, help="The interval of logging.") - parser.add_argument("--use_fp16", type=distutils.util.strtobool, default=False, help="Wheter to use FP16 mode") + parser.add_argument("--device", type=str, default="cpu", choices=["gpu", "cpu"]) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--max_length", type=int, default=128) + parser.add_argument("--log_interval", type=int, default=10) return parser.parse_args() @@ -52,97 +41,79 @@ def batchfy_text(texts, batch_size): batch_texts = [] batch_start = 0 while batch_start < len(texts): - batch_texts += [texts[batch_start : min(batch_start + batch_size, len(texts))]] + batch_texts.append(texts[batch_start : batch_start + batch_size]) batch_start += batch_size return batch_texts -class Predictor(object): +class Predictor: def __init__(self, args): self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) - self.runtime = self.create_fd_runtime(args) + self.predictor = self.create_predictor(args) + self.input_names = self.predictor.get_input_names() + self.output_names = self.predictor.get_output_names() self.batch_size = args.batch_size self.max_length = args.max_length - def create_fd_runtime(self, args): - option = fd.RuntimeOption() - model_path = os.path.join(args.model_dir, args.model_prefix + ".pdmodel") - params_path = os.path.join(args.model_dir, args.model_prefix + ".pdiparams") - 
option.set_model_path(model_path, params_path) - if args.device == "cpu": - option.use_cpu() - else: - option.use_gpu() - if args.backend == "paddle": - option.use_paddle_infer_backend() - elif args.backend == "onnx_runtime": - option.use_ort_backend() - elif args.backend == "openvino": - option.use_openvino_backend() + def create_predictor(self, args): + model_path = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_MODEL_SUFFIX}") + params_path = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_WEIGHTS_SUFFIX}") + config = paddle_infer.Config(model_path, params_path) + + if args.device == "gpu": + config.enable_use_gpu(100, 0) else: - option.use_trt_backend() - if args.backend == "paddle_tensorrt": - option.enable_paddle_to_trt() - option.enable_paddle_trt_collect_shape() - trt_file = os.path.join(args.model_dir, "model.trt") - option.set_trt_input_shape( - "input_ids", - min_shape=[1, 1], - opt_shape=[args.batch_size, args.max_length], - max_shape=[args.batch_size, args.max_length], - ) - option.set_trt_input_shape( - "token_type_ids", - min_shape=[1, 1], - opt_shape=[args.batch_size, args.max_length], - max_shape=[args.batch_size, args.max_length], - ) - if args.use_fp16: - option.enable_trt_fp16() - trt_file = trt_file + ".fp16" - option.set_trt_cache_file(trt_file) - return fd.Runtime(option) + config.disable_gpu() + config.switch_use_feed_fetch_ops(False) + config.enable_memory_optim() + return paddle_infer.create_predictor(config) def preprocess(self, text, text_pair): - data = self.tokenizer(text, text_pair, max_length=self.max_length, padding=True, truncation=True) - input_ids_name = self.runtime.get_input_info(0).name - token_type_ids_name = self.runtime.get_input_info(1).name - input_map = { - input_ids_name: np.array(data["input_ids"], dtype="int64"), - token_type_ids_name: np.array(data["token_type_ids"], dtype="int64"), + encoded = self.tokenizer( + text, text_pair, max_length=self.max_length, padding=True, truncation=True, return_tensors="np" + ) + return { + "input_ids": encoded["input_ids"].astype("int64"), + "token_type_ids": encoded["token_type_ids"].astype("int64"), } - return input_map def infer(self, input_map): - results = self.runtime.infer(input_map) - return results + input_ids_handle = self.predictor.get_input_handle(self.input_names[0]) + token_type_ids_handle = self.predictor.get_input_handle(self.input_names[1]) - def postprocess(self, infer_data): - logits = np.array(infer_data[0]) + input_ids_handle.copy_from_cpu(input_map["input_ids"]) + token_type_ids_handle.copy_from_cpu(input_map["token_type_ids"]) + + self.predictor.run() + + output_handle = self.predictor.get_output_handle(self.output_names[0]) + return output_handle.copy_to_cpu() + + def postprocess(self, logits): max_value = np.max(logits, axis=1, keepdims=True) - exp_data = np.exp(logits - max_value) - probs = exp_data / np.sum(exp_data, axis=1, keepdims=True) - out_dict = {"label": probs.argmax(axis=-1), "confidence": probs.max(axis=-1)} - return out_dict + exp = np.exp(logits - max_value) + probs = exp / np.sum(exp, axis=1, keepdims=True) + return {"label": np.argmax(probs, axis=1), "confidence": np.max(probs, axis=1)} def predict(self, texts, texts_pair=None): input_map = self.preprocess(texts, texts_pair) - infer_result = self.infer(input_map) - output = self.postprocess(infer_result) - return output + logits = self.infer(input_map) + return self.postprocess(logits) if __name__ == "__main__": args = parse_arguments() predictor = Predictor(args) + texts_ds 
= ["花呗收款额度限制", "花呗支持高铁票支付吗"] texts_pair_ds = ["收钱码,对花呗支付的金额有限制吗", "为什么友付宝不支持花呗付款"] + batch_texts = batchfy_text(texts_ds, args.batch_size) batch_texts_pair = batchfy_text(texts_pair_ds, args.batch_size) for bs, (texts, texts_pair) in enumerate(zip(batch_texts, batch_texts_pair)): outputs = predictor.predict(texts, texts_pair) - for i, (sentence1, sentence2) in enumerate(zip(texts, texts_pair)): + for i, (s1, s2) in enumerate(zip(texts, texts_pair)): print( - f"Batch id:{bs}, example id:{i}, sentence1:{sentence1}, sentence2:{sentence2}, label:{outputs['label'][i]}, similarity:{outputs['confidence'][i]:.4f}" + f"Batch {bs}, example {i} | s1: {s1} | s2: {s2} | label: {outputs['label'][i]} | score: {outputs['confidence'][i]:.4f}" ) diff --git a/slm/model_zoo/ernie-3.0/deploy/python/token_cls_infer.py b/slm/model_zoo/ernie-3.0/deploy/python/token_cls_infer.py index b7da79eb5ab0..4d2ecc6d0029 100644 --- a/slm/model_zoo/ernie-3.0/deploy/python/token_cls_infer.py +++ b/slm/model_zoo/ernie-3.0/deploy/python/token_cls_infer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,40 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import distutils.util + +import argparse import os -import fastdeploy as fd import numpy as np +import paddle.inference as paddle_infer from paddlenlp.transformers import AutoTokenizer +from paddlenlp.utils.env import ( + PADDLE_INFERENCE_MODEL_SUFFIX, + PADDLE_INFERENCE_WEIGHTS_SUFFIX, +) def parse_arguments(): - import argparse - parser = argparse.ArgumentParser() parser.add_argument("--model_dir", required=True, help="The directory of model.") parser.add_argument("--vocab_path", type=str, default="", help="The path of tokenizer vocab.") parser.add_argument("--model_prefix", type=str, default="model", help="The model and params file prefix.") - parser.add_argument( - "--device", - type=str, - default="cpu", - choices=["gpu", "cpu"], - help="Type of inference device, support 'cpu' or 'gpu'.", - ) - parser.add_argument( - "--backend", - type=str, - default="paddle", - choices=["onnx_runtime", "paddle", "openvino", "tensorrt", "paddle_tensorrt"], - help="The inference runtime backend.", - ) - parser.add_argument("--batch_size", type=int, default=1, help="The batch size of data.") - parser.add_argument("--max_length", type=int, default=128, help="The max length of sequence.") - parser.add_argument("--log_interval", type=int, default=10, help="The interval of logging.") - parser.add_argument("--use_fp16", type=distutils.util.strtobool, default=False, help="Wheter to use FP16 mode") + parser.add_argument("--device", type=str, default="cpu", choices=["gpu", "cpu"]) + parser.add_argument("--batch_size", type=int, default=1) + parser.add_argument("--max_length", type=int, default=128) + parser.add_argument("--log_interval", type=int, default=10) return parser.parse_args() @@ -52,79 +41,63 @@ def batchfy_text(texts, batch_size): batch_texts = [] batch_start = 0 while batch_start < len(texts): - batch_texts += [texts[batch_start : min(batch_start + batch_size, len(texts))]] + batch_texts.append(texts[batch_start : batch_start + batch_size]) batch_start += batch_size return batch_texts -class 
ErnieForTokenClassificationPredictor(object): +class ErnieForTokenClassificationPredictor: def __init__(self, args): self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) - self.runtime = self.create_fd_runtime(args) + self.predictor = self.create_predictor(args) + self.input_names = self.predictor.get_input_names() + self.output_names = self.predictor.get_output_names() self.batch_size = args.batch_size self.max_length = args.max_length self.label_names = ["B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "O"] - def create_fd_runtime(self, args): - option = fd.RuntimeOption() - model_path = os.path.join(args.model_dir, args.model_prefix + ".pdmodel") - params_path = os.path.join(args.model_dir, args.model_prefix + ".pdiparams") - option.set_model_path(model_path, params_path) - if args.device == "cpu": - option.use_cpu() - else: - option.use_gpu() - if args.backend == "paddle": - option.use_paddle_infer_backend() - elif args.backend == "onnx_runtime": - option.use_ort_backend() - elif args.backend == "openvino": - option.use_openvino_backend() + def create_predictor(self, args): + model_path = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_MODEL_SUFFIX}") + params_path = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_WEIGHTS_SUFFIX}") + config = paddle_infer.Config(model_path, params_path) + + if args.device == "gpu": + config.enable_use_gpu(100, 0) else: - option.use_trt_backend() - if args.backend == "paddle_tensorrt": - option.enable_paddle_to_trt() - option.enable_paddle_trt_collect_shape() - trt_file = os.path.join(args.model_dir, "infer.trt") - option.set_trt_input_shape( - "input_ids", - min_shape=[1, 1], - opt_shape=[args.batch_size, args.max_length], - max_shape=[args.batch_size, args.max_length], - ) - option.set_trt_input_shape( - "token_type_ids", - min_shape=[1, 1], - opt_shape=[args.batch_size, args.max_length], - max_shape=[args.batch_size, args.max_length], - ) - if args.use_fp16: - option.enable_trt_fp16() - trt_file = trt_file + ".fp16" - option.set_trt_cache_file(trt_file) - return fd.Runtime(option) + config.disable_gpu() + + config.switch_use_feed_fetch_ops(False) + config.enable_memory_optim() + return paddle_infer.create_predictor(config) def preprocess(self, texts): - is_split_into_words = False - if isinstance(texts[0], list): - is_split_into_words = True - data = self.tokenizer( - texts, max_length=self.max_length, padding=True, truncation=True, is_split_into_words=is_split_into_words + is_split_into_words = isinstance(texts[0], list) + encoded = self.tokenizer( + texts, + max_length=self.max_length, + padding=True, + truncation=True, + is_split_into_words=is_split_into_words, + return_tensors="np", ) - input_ids_name = self.runtime.get_input_info(0).name - token_type_ids_name = self.runtime.get_input_info(1).name - input_map = { - input_ids_name: np.array(data["input_ids"], dtype="int64"), - token_type_ids_name: np.array(data["token_type_ids"], dtype="int64"), + return { + "input_ids": encoded["input_ids"].astype("int64"), + "token_type_ids": encoded["token_type_ids"].astype("int64"), } - return input_map def infer(self, input_map): - results = self.runtime.infer(input_map) - return results + input_ids_handle = self.predictor.get_input_handle(self.input_names[0]) + token_type_ids_handle = self.predictor.get_input_handle(self.input_names[1]) + + input_ids_handle.copy_from_cpu(input_map["input_ids"]) + token_type_ids_handle.copy_from_cpu(input_map["token_type_ids"]) + + self.predictor.run() + output_handle = 
self.predictor.get_output_handle(self.output_names[0]) + return output_handle.copy_to_cpu() def postprocess(self, infer_data, input_data): - result = np.array(infer_data[0]) + result = np.array(infer_data) tokens_label = result.argmax(axis=-1).tolist() value = [] for batch, token_label in enumerate(tokens_label): @@ -132,7 +105,8 @@ def postprocess(self, infer_data, input_data): label_name = "" items = [] for i, label in enumerate(token_label): - if (self.label_names[label] == "O" or "B-" in self.label_names[label]) and start >= 0: + label_str = self.label_names[label] + if (label_str == "O" or "B-" in label_str) and start >= 0: entity = input_data[batch][start : i - 1] if isinstance(entity, list): entity = "".join(entity) @@ -146,19 +120,17 @@ def postprocess(self, infer_data, input_data): } ) start = -1 - if "B-" in self.label_names[label]: + if "B-" in label_str: start = i - 1 - label_name = self.label_names[label][2:] + label_name = label_str[2:] value.append(items) - out_dict = {"value": value, "tokens_label": tokens_label} - return out_dict + return {"value": value, "tokens_label": tokens_label} def predict(self, texts): input_map = self.preprocess(texts) infer_result = self.infer(input_map) - output = self.postprocess(infer_result, texts) - return output + return self.postprocess(infer_result, texts) def token_cls_print_ret(infer_result, input_data): @@ -166,8 +138,8 @@ def token_cls_print_ret(infer_result, input_data): for i, ret in enumerate(rets): print("input data:", input_data[i]) print("The model detects all entities:") - for iterm in ret: - print("entity:", iterm["entity"], " label:", iterm["label"], " pos:", iterm["pos"]) + for item in ret: + print("entity:", item["entity"], " label:", item["label"], " pos:", item["pos"]) print("-----------------------------") @@ -176,6 +148,7 @@ def token_cls_print_ret(infer_result, input_data): predictor = ErnieForTokenClassificationPredictor(args) texts = ["北京的涮肉,重庆的火锅,成都的小吃都是极具特色的美食。", "乔丹、科比、詹姆斯和姚明都是篮球界的标志性人物。"] batch_data = batchfy_text(texts, args.batch_size) + for data in batch_data: outputs = predictor.predict(data) token_cls_print_ret(outputs, data) diff --git a/slm/model_zoo/ernie-3.0/infer.py b/slm/model_zoo/ernie-3.0/infer.py index 0941a2db9b8a..524fb0801bc4 100755 --- a/slm/model_zoo/ernie-3.0/infer.py +++ b/slm/model_zoo/ernie-3.0/infer.py @@ -30,6 +30,10 @@ from paddlenlp.metrics.squad import compute_prediction, squad_evaluate from paddlenlp.trainer.argparser import strtobool from paddlenlp.transformers import AutoTokenizer +from paddlenlp.utils.env import ( + PADDLE_INFERENCE_MODEL_SUFFIX, + PADDLE_INFERENCE_WEIGHTS_SUFFIX, +) METRIC_CLASSES = { "afqmc": Accuracy, @@ -183,8 +187,8 @@ def create_predictor(cls, args): import paddle2onnx onnx_model = paddle2onnx.command.c_paddle_to_onnx( - model_file=args.model_path + ".pdmodel", - params_file=args.model_path + ".pdiparams", + model_file=args.model_path + f"model{PADDLE_INFERENCE_MODEL_SUFFIX}", + params_file=args.model_path + f"model{PADDLE_INFERENCE_WEIGHTS_SUFFIX}", opset_version=13, enable_onnx_checker=True, ) @@ -210,7 +214,10 @@ def create_predictor(cls, args): input_handles = [input_name1, input_name2] return cls(predictor, input_handles, []) - config = paddle.inference.Config(args.model_path + ".pdmodel", args.model_path + ".pdiparams") + config = paddle.inference.Config( + args.model_path + f"model{PADDLE_INFERENCE_MODEL_SUFFIX}", + args.model_path + f"model{PADDLE_INFERENCE_WEIGHTS_SUFFIX}", + ) if args.device == "gpu": # set GPU configs accordingly 
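+            # enable_use_gpu(memory_pool_init_size_mb, device_id): pre-allocate a ~100 MB
+            # initial memory pool on GPU 0.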
config.enable_use_gpu(100, 0) diff --git a/slm/model_zoo/ernie-3.0/run_token_cls_without_hf.py b/slm/model_zoo/ernie-3.0/run_token_cls_without_hf.py new file mode 100644 index 000000000000..61ead7050ae1 --- /dev/null +++ b/slm/model_zoo/ernie-3.0/run_token_cls_without_hf.py @@ -0,0 +1,211 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from functools import partial + +import numpy as np +import paddle +import paddle.nn as nn + +# 🔁 替换这行: +# from evaluate import load as load_metric +from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score +from utils import DataArguments, ModelArguments, load_config, token_convert_example + +import paddlenlp +from paddlenlp.data import DataCollatorForTokenClassification +from paddlenlp.datasets import load_dataset +from paddlenlp.trainer import ( + PdArgumentParser, + Trainer, + TrainingArguments, + get_last_checkpoint, +) +from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer +from paddlenlp.utils.log import logger + + +def main(): + parser = PdArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args, data_args, training_args = load_config( + model_args.config, "TokenClassification", data_args.dataset, model_args, data_args, training_args + ) + + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + paddle.set_device(training_args.device) + + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + + data_args.dataset = data_args.dataset.strip() + training_args.output_dir = os.path.join(training_args.output_dir, data_args.dataset) + + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." 
+ ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.") + + raw_datasets = load_dataset(data_args.dataset) + label_list = raw_datasets["train"].label_list + data_args.label_list = label_list + data_args.ignore_label = -100 + data_args.no_entity_id = 0 + + num_classes = len(label_list) + + tokenizer = ErnieTokenizer.from_pretrained(model_args.model_name_or_path) + model = ErnieForTokenClassification.from_pretrained(model_args.model_name_or_path, num_classes=num_classes) + + class criterion(nn.Layer): + def __init__(self): + super(criterion, self).__init__() + self.loss_fn = paddle.nn.loss.CrossEntropyLoss(ignore_index=data_args.ignore_label) + + def forward(self, *args, **kwargs): + return paddle.mean(self.loss_fn(*args, **kwargs)) + + loss_fct = criterion() + + trans_fn = partial( + token_convert_example, + tokenizer=tokenizer, + no_entity_id=data_args.no_entity_id, + max_seq_length=data_args.max_seq_length, + dynamic_max_length=data_args.dynamic_max_length, + ) + data_collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=data_args.ignore_label) + + logger.info("Data Preprocessing...") + if training_args.do_train: + train_dataset = raw_datasets["train"].map(trans_fn, lazy=training_args.lazy_data_processing) + if training_args.do_eval: + eval_dataset = raw_datasets["test"].map(trans_fn, lazy=training_args.lazy_data_processing) + if training_args.do_predict: + test_dataset = raw_datasets["test"].map(trans_fn, lazy=training_args.lazy_data_processing) + + # ✅ 替换为原生 seqeval 的 compute_metrics 函数 + def compute_metrics(p): + predictions, labels = p + predictions = np.argmax(predictions, axis=2) + + true_predictions = [ + [label_list[p] for (p, l) in zip(pred, label) if l != -100] for pred, label in zip(predictions, labels) + ] + true_labels = [ + [label_list[l] for (p, l) in zip(pred, label) if l != -100] for pred, label in zip(predictions, labels) + ] + + return { + "precision": precision_score(true_labels, true_predictions), + "recall": recall_score(true_labels, true_predictions), + "f1": f1_score(true_labels, true_predictions), + "accuracy": accuracy_score(true_labels, true_predictions), + } + + trainer = Trainer( + model=model, + criterion=loss_fct, + args=training_args, + data_collator=data_collator, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + compute_metrics=compute_metrics, + ) + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + trainer.save_model() + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + if training_args.do_eval: + eval_metrics = trainer.evaluate() + trainer.log_metrics("eval", eval_metrics) + + if training_args.do_predict: + test_ret = trainer.predict(test_dataset) + trainer.log_metrics("test", test_ret.metrics) + tokens_label = test_ret.predictions.argmax(axis=-1).tolist() + value = [] + for batch, token_label in enumerate(tokens_label): + start = -1 + label_name = "" + items = [] + input_data = tokenizer.convert_ids_to_tokens(test_dataset[batch]["input_ids"])[1:-1] + for i, label in enumerate(token_label): + 
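+                    # A tag of "O" or a new "B-" closes the entity opened at `start`;
+                    # indices are shifted by one because input_data drops the leading [CLS] token.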
if (data_args.label_list[label] == "O" or "B-" in data_args.label_list[label]) and start >= 0: + entity = input_data[start : i - 1] + if isinstance(entity, list): + entity = "".join(entity) + items.append( + { + "pos": [start, i - 2], + "entity": entity, + "label": label_name, + } + ) + start = -1 + if "B-" in data_args.label_list[label]: + start = i - 1 + label_name = data_args.label_list[label][2:] + if start >= 0: + items.append( + { + "pos": [start, len(token_label) - 1], + "entity": input_data[start : len(token_label) - 1], + "label": "", + } + ) + value.append(items) + + out_dict = {"value": value, "tokens_label": tokens_label} + out_file = open(os.path.join(training_args.output_dir, "test_results.json"), "w") + json.dump(out_dict, out_file, ensure_ascii=True) + + if training_args.do_export: + input_spec = [ + paddle.static.InputSpec(shape=[None, None], dtype="int64"), + paddle.static.InputSpec(shape=[None, None], dtype="int64"), + ] + model_args.export_model_dir = os.path.join(model_args.export_model_dir, data_args.dataset, "export") + paddlenlp.transformers.export_model( + model=trainer.model, input_spec=input_spec, path=model_args.export_model_dir + ) + trainer.tokenizer.save_pretrained(model_args.export_model_dir) + + +if __name__ == "__main__": + main()
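
The deploy scripts in this patch build the inference file names as `model_prefix + PADDLE_INFERENCE_MODEL_SUFFIX` and `model_prefix + PADDLE_INFERENCE_WEIGHTS_SUFFIX` instead of hard-coding `.pdmodel`/`.pdiparams`. Below is a minimal sketch of how the resolved paths can be checked before constructing a predictor; it assumes only that `paddlenlp.utils.env` exposes these two constants (the patch itself imports them) and that an exported directory such as `./best_models/afqmc/export` exists. The concrete suffix values depend on the installed Paddle (a PIR-style program file vs. the legacy `.pdmodel` format), which is why the extensions are no longer hard-coded.

```python
# Minimal sketch: resolve the inference model/params file names the same way
# the deploy scripts in this patch do, then verify the export step produced them.
import os

from paddlenlp.utils.env import (
    PADDLE_INFERENCE_MODEL_SUFFIX,
    PADDLE_INFERENCE_WEIGHTS_SUFFIX,
)

model_dir = "./best_models/afqmc/export"  # hypothetical export directory
prefix = "model"

model_path = os.path.join(model_dir, prefix + PADDLE_INFERENCE_MODEL_SUFFIX)
params_path = os.path.join(model_dir, prefix + PADDLE_INFERENCE_WEIGHTS_SUFFIX)

print("model file :", model_path)
print("params file:", params_path)
assert os.path.exists(model_path) and os.path.exists(params_path), "run the export step first"
```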