From 526a9be6edcad1483092b635f5531faaa69edf52 Mon Sep 17 00:00:00 2001 From: hanlintang Date: Wed, 16 Apr 2025 09:26:22 +0000 Subject: [PATCH] Create ernie1.0 infer example to fit pd3.0.0 --- slm/model_zoo/ernie-1.0/README.md | 22 +-- .../ernie-1.0/finetune/deploy/README.md | 116 ++------------ .../ernie-1.0/finetune/deploy/infer.py | 112 +++++++++++++ .../finetune/deploy/seq_cls_infer.py | 149 ------------------ 4 files changed, 141 insertions(+), 258 deletions(-) create mode 100644 slm/model_zoo/ernie-1.0/finetune/deploy/infer.py delete mode 100644 slm/model_zoo/ernie-1.0/finetune/deploy/seq_cls_infer.py diff --git a/slm/model_zoo/ernie-1.0/README.md b/slm/model_zoo/ernie-1.0/README.md index c5b10b8137da..fb78a22ff274 100644 --- a/slm/model_zoo/ernie-1.0/README.md +++ b/slm/model_zoo/ernie-1.0/README.md @@ -616,7 +616,7 @@ python run_qa.py \ ## 4. 预测部署 -以中文文本情感分类问题为例,介绍一下从模型 finetune 到部署的过程。 +以中文文本情感分类问题为例,介绍一下从模型 finetune 到部署的过程(已经在 PaddlePaddle 3.0.0版本验证)。 与之前的 finetune 参数配置稍有区别,此处加入了一些配置选项。 @@ -644,20 +644,24 @@ python run_seq_cls.py \ --save_total_limit 3 \ ``` -训练完导出模型之后,可以用于部署,`deploy/seq_cls_infer.py`文件提供了 python 部署预测示例。可执行以下命令运行部署示例: +训练完导出模型之后,可以用于部署,`deploy/infer.py`文件提供了 python 部署预测示例。可执行以下命令运行部署示例: ```shell -python deploy/seq_cls_infer.py --model_dir tmp/chnsenticorp_v2/export/ --device cpu --backend paddle +python deploy/infer.py --model_dir tmp/chnsenticorp_v2/export/ --device gpu ``` 运行后预测结果打印如下: ```text -WARNING: Logging before InitGoogleLogging() is written to STDERR -W0301 08:25:37.617117 58742 analysis_config.cc:958] It is detected that mkldnn and memory_optimize_pass are enabled at the same time, but they are not supported yet. Currently, memory_optimize_pass is explicitly disabled -[INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::CPU. -Batch id: 0, example id: 0, sentence: 这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般, label: negative, negative prob: 0.9999, positive prob: 0.0001. -Batch id: 1, example id: 0, sentence: 怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一集算什么??简直是画蛇添足!!, label: negative, negative prob: 0.9998, positive prob: 0.0002. -Batch id: 2, example id: 0, sentence: 还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。, label: negative, negative prob: 0.9999, positive prob: 0.0001. +Batch id: 1189, example id: 0, sentence: 作为五星级 酒店的硬件是差了点 装修很久 电视很小 只是位置很好 楼下是DFS 对面是海港城 但性价比不高, label: positive, negative prob: 0.0001, positive prob: 0.9999. +Batch id: 1190, example id: 0, sentence: 最好别去,很差,看完很差想换酒店,他们竟跟我要服务费.也没待那房间2分种,居然解决了问题,可觉的下次不能去的,, label: negative, negative prob: 1.0000, positive prob: 0.0000. +Batch id: 1191, example id: 0, sentence: 看了一半就看不下去了,后半本犹豫几次都放下没有继续看的激情,故事平淡的连个波折起伏都没有,职场里那点事儿也学得太模糊,没有具体描述,而且杜拉拉就做一个行政而已,是个人都会做的没有技术含量的工作 也能描写的这么有技术含量 真是为难作者了本来冲着畅销排行第一买来看看,觉得总不至于大部分人都没品味吧?结果证明这个残酷的事实,一本让人如同嚼蜡的“畅销书”......, label: negative, negative prob: 0.9999, positive prob: 0.0001. +Batch id: 1192, example id: 0, sentence: 酒店环境很好 就是有一点点偏 交通不是很便利 去哪都需要达车 关键是不好打 酒店应该想办法解决一下, label: positive, negative prob: 0.0003, positive prob: 0.9997. +Batch id: 1193, example id: 0, sentence: 价格在这个地段属于适中, 附近有早餐店,小饭店, 比较方便,无早也无所, label: positive, negative prob: 0.1121, positive prob: 0.8879. +Batch id: 1194, example id: 0, sentence: 酒店的位置不错,附近都靠近购物中心和写字楼区。以前来大连一直都住,但感觉比较陈旧了。住的期间,酒店在进行装修,翻新和升级房间设备。好是好,希望到时房价别涨太多了。, label: positive, negative prob: 0.0000, positive prob: 1.0000. +Batch id: 1195, example id: 0, sentence: 位置不很方便,周围乱哄哄的,卫生条件也不如其他如家的店。以后绝不会再住在这里。, label: negative, negative prob: 1.0000, positive prob: 0.0000. +Batch id: 1196, example id: 0, sentence: 抱着很大兴趣买的,买来粗粗一翻排版很不错,姐姐还说快看吧,如果好我也买一本。可是真的看了,实在不怎么样。就是中文里夹英文单词说话,才翻了2页实在不想勉强自己了。我想说的是,练习英文单词,靠这本书肯定没有效果,其它好的方法比这强多了。, label: negative, negative prob: 1.0000, positive prob: 0.0000. +Batch id: 1197, example id: 0, sentence: 东西不错,不过有人不太喜欢镜面的,我个人比较喜欢,总之还算满意。, label: positive, negative prob: 0.0001, positive prob: 0.9999. +Batch id: 1198, example id: 0, sentence: 房间不错,只是上网速度慢得无法忍受,打开一个网页要等半小时,连邮件都无法收。另前台工作人员服务态度是很好,只是效率有得改善。, label: positive, negative prob: 0.0001, positive prob: 0.9999. ...... ``` diff --git a/slm/model_zoo/ernie-1.0/finetune/deploy/README.md b/slm/model_zoo/ernie-1.0/finetune/deploy/README.md index bf75b84958e9..ab01803aa3e0 100644 --- a/slm/model_zoo/ernie-1.0/finetune/deploy/README.md +++ b/slm/model_zoo/ernie-1.0/finetune/deploy/README.md @@ -1,28 +1,33 @@ -# FastDeploy ERNIE 1.0 模型 Python 部署示例 +# ERNIE 1.0 模型 Python 推理示例 -在部署前,参考 [FastDeploy SDK 安装文档](https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/cn/build_and_install/download_prebuilt_libraries.md)安装 FastDeploy Python SDK。 - -本目录下提供 `seq_cls_infer.py` 快速完成在 CPU/GPU 的中文情感分类任务的 Python 部署示例。 +本目录下提供 `infer.py` 快速完成在 CPU/GPU 的中文情感分类任务的 Python 推理示例。 ## 快速开始 -以下示例展示如何基于 FastDeploy 库完成 ERNIE 1.0 模型在 ChnSenticorp 数据集上进行文本分类任务的 Python 预测部署,可通过命令行参数`--device`以及`--backend`指定运行在不同的硬件以及推理引擎后端,并使用`--model_dir`参数指定运行的模型,具体参数设置可查看下面[参数说明](#参数说明)。示例中的模型是按照 [ERNIE 1.0 训练文档](../../README.md)导出得到的部署模型,其模型目录为`model_zoo/ernie-1.0/finetune/tmp/export`(用户可按实际情况设置)。 +以下示例展示 ERNIE 1.0 模型在 ChnSenticorp 数据集上进行文本分类任务的 Python 预测部署,可通过命令行参数`--device`指定运行在不同的硬件,并使用`--model_dir`参数指定运行的模型,具体参数设置可查看下面[参数说明](#参数说明)。示例中的模型是按照 [ERNIE 1.0 训练文档](../../README.md)导出得到的部署模型,其模型目录为`model_zoo/ernie-1.0/finetune/tmp/export`(用户可按实际情况设置)。 ```bash # CPU 推理 -python seq_cls_infer.py --model_dir ../tmp/chnsenticorp_v2/export/ --device cpu --backend paddle +python infer.py --model_dir ../tmp/chnsenticorp_v2/export/ --device cpu # GPU 推理 -python seq_cls_infer.py --model_dir ../tmp/chnsenticorp_v2/export/ --device gpu --backend paddle +python infer.py --model_dir ../tmp/chnsenticorp_v2/export/ --device gpu ``` 运行完成后返回的结果如下: ```bash -[INFO] fastdeploy/runtime/runtime.cc(266)::CreatePaddleBackend Runtime initialized with Backend::PDINFER in Device::GPU. -Batch id: 0, example id: 0, sentence: 这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般, label: negative, negative prob: 0.9999, positive prob: 0.0001. -Batch id: 1, example id: 0, sentence: 怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一集算什么??简直是画蛇添足!!, label: negative, negative prob: 0.9998, positive prob: 0.0002. -Batch id: 2, example id: 0, sentence: 还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。, label: negative, negative prob: 0.9999, positive prob: 0.0001. +...... +Batch id: 1189, example id: 0, sentence: 作为五星级 酒店的硬件是差了点 装修很久 电视很小 只是位置很好 楼下是DFS 对面是海港城 但性价比不高, label: positive, negative prob: 0.0001, positive prob: 0.9999. +Batch id: 1190, example id: 0, sentence: 最好别去,很差,看完很差想换酒店,他们竟跟我要服务费.也没待那房间2分种,居然解决了问题,可觉的下次不能去的,, label: negative, negative prob: 1.0000, positive prob: 0.0000. +Batch id: 1191, example id: 0, sentence: 看了一半就看不下去了,后半本犹豫几次都放下没有继续看的激情,故事平淡的连个波折起伏都没有,职场里那点事儿也学得太模糊,没有具体描述,而且杜拉拉就做一个行政而已,是个人都会做的没有技术含量的工作 也能描写的这么有技术含量 真是为难作者了本来冲着畅销排行第一买来看看,觉得总不至于大部分人都没品味吧?结果证明这个残酷的事实,一本让人如同嚼蜡的“畅销书”......, label: negative, negative prob: 0.9999, positive prob: 0.0001. +Batch id: 1192, example id: 0, sentence: 酒店环境很好 就是有一点点偏 交通不是很便利 去哪都需要达车 关键是不好打 酒店应该想办法解决一下, label: positive, negative prob: 0.0003, positive prob: 0.9997. +Batch id: 1193, example id: 0, sentence: 价格在这个地段属于适中, 附近有早餐店,小饭店, 比较方便,无早也无所, label: positive, negative prob: 0.1121, positive prob: 0.8879. +Batch id: 1194, example id: 0, sentence: 酒店的位置不错,附近都靠近购物中心和写字楼区。以前来大连一直都住,但感觉比较陈旧了。住的期间,酒店在进行装修,翻新和升级房间设备。好是好,希望到时房价别涨太多了。, label: positive, negative prob: 0.0000, positive prob: 1.0000. +Batch id: 1195, example id: 0, sentence: 位置不很方便,周围乱哄哄的,卫生条件也不如其他如家的店。以后绝不会再住在这里。, label: negative, negative prob: 1.0000, positive prob: 0.0000. +Batch id: 1196, example id: 0, sentence: 抱着很大兴趣买的,买来粗粗一翻排版很不错,姐姐还说快看吧,如果好我也买一本。可是真的看了,实在不怎么样。就是中文里夹英文单词说话,才翻了2页实在不想勉强自己了。我想说的是,练习英文单词,靠这本书肯定没有效果,其它好的方法比这强多了。, label: negative, negative prob: 1.0000, positive prob: 0.0000. +Batch id: 1197, example id: 0, sentence: 东西不错,不过有人不太喜欢镜面的,我个人比较喜欢,总之还算满意。, label: positive, negative prob: 0.0001, positive prob: 0.9999. +Batch id: 1198, example id: 0, sentence: 房间不错,只是上网速度慢得无法忍受,打开一个网页要等半小时,连邮件都无法收。另前台工作人员服务态度是很好,只是效率有得改善。, label: positive, negative prob: 0.0001, positive prob: 0.9999. ...... ``` @@ -36,92 +41,3 @@ Batch id: 2, example id: 0, sentence: 还稍微重了点,可能是硬盘大的 |--device | 运行的设备,可选范围: ['cpu', 'gpu'],默认为'cpu' | |--device_id | 运行设备的 id。默认为0。 | |--cpu_threads | 当使用 cpu 推理时,指定推理的 cpu 线程数,默认为1。| -|--backend | 支持的推理后端,可选范围: ['onnx_runtime', 'paddle', 'openvino', 'tensorrt', 'paddle_tensorrt'],默认为'paddle' | -|--use_fp16 | 是否使用 FP16模式进行推理。使用 tensorrt 和 paddle_tensorrt 后端时可开启,默认为 False | - -## FastDeploy 高阶用法 - -FastDeploy 在 Python 端上,提供 `fastdeploy.RuntimeOption.use_xxx()` 以及 `fastdeploy.RuntimeOption.use_xxx_backend()` 接口支持开发者选择不同的硬件、不同的推理引擎进行部署。在不同的硬件上部署 ERNIE 1.0 模型,需要选择硬件所支持的推理引擎进行部署,下表展示如何在不同的硬件上选择可用的推理引擎部署 ERNIE 1.0 模型。 - -符号说明: (1) ✅: 已经支持; (2) ❔: 正在进行中; (3) N/A: 暂不支持; - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
硬件 硬件对应的接口 可用的推理引擎 推理引擎对应的接口 是否支持 Paddle 新格式量化模型 是否支持 FP16 模式
CPU use_cpu() Paddle Inference use_paddle_infer_backend() N/A
ONNX Runtime use_ort_backend() N/A
OpenVINO use_openvino_backend() N/A
GPU use_gpu() Paddle Inference use_paddle_infer_backend() N/A
ONNX Runtime use_ort_backend()
Paddle TensorRT use_paddle_infer_backend() + paddle_infer_option.enable_trt = True
TensorRT use_trt_backend()
昆仑芯 XPU use_kunlunxin() Paddle Lite use_paddle_lite_backend() N/A
华为 昇腾 use_ascend() Paddle Lite use_paddle_lite_backend()
Graphcore IPU use_ipu() Paddle Inference use_paddle_infer_backend() N/A
diff --git a/slm/model_zoo/ernie-1.0/finetune/deploy/infer.py b/slm/model_zoo/ernie-1.0/finetune/deploy/infer.py new file mode 100644 index 000000000000..fa83e0eb535c --- /dev/null +++ b/slm/model_zoo/ernie-1.0/finetune/deploy/infer.py @@ -0,0 +1,112 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import numpy as np +from paddle import inference +from scipy.special import softmax + +from paddlenlp.datasets import load_dataset +from paddlenlp.transformers import AutoTokenizer +from paddlenlp.utils.env import ( + PADDLE_INFERENCE_MODEL_SUFFIX, + PADDLE_INFERENCE_WEIGHTS_SUFFIX, +) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_dir", required=True, help="The directory of model.") + parser.add_argument("--vocab_path", type=str, default="", help="The path of tokenizer vocab.") + parser.add_argument("--model_prefix", type=str, default="model", help="The model and params file prefix.") + parser.add_argument("--device", type=str, default="cpu", choices=["gpu", "cpu"], help="Device type.") + parser.add_argument("--batch_size", type=int, default=1, help="The batch size of data.") + parser.add_argument("--max_length", type=int, default=128, help="The max length of sequence.") + parser.add_argument("--cpu_threads", type=int, default=1, help="Number of threads for CPU.") + parser.add_argument("--device_id", type=int, default=0, help="GPU device id.") + return parser.parse_args() + + +def batchfy_text(texts, batch_size): + return [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)] + + +class Predictor(object): + def __init__(self, args): + if args.vocab_path and os.path.isdir(args.vocab_path): + self.tokenizer = AutoTokenizer.from_pretrained(args.vocab_path) + else: + self.tokenizer = AutoTokenizer.from_pretrained("ernie-1.0") + + model_file = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_MODEL_SUFFIX}") + params_file = os.path.join(args.model_dir, args.model_prefix + f"{PADDLE_INFERENCE_WEIGHTS_SUFFIX}") + + config = inference.Config(model_file, params_file) + if args.device == "gpu": + config.enable_use_gpu(100, args.device_id) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(args.cpu_threads) + + config.switch_use_feed_fetch_ops(False) + self.predictor = inference.create_predictor(config) + self.input_handles = [self.predictor.get_input_handle(name) for name in self.predictor.get_input_names()] + self.output_handle = self.predictor.get_output_handle(self.predictor.get_output_names()[0]) + + self.batch_size = args.batch_size + self.max_length = args.max_length + + def preprocess(self, texts): + encoded = self.tokenizer( + texts, max_length=self.max_length, padding=True, truncation=True, return_token_type_ids=True + ) + input_ids = np.array(encoded["input_ids"], dtype="int64") + token_type_ids = np.array(encoded["token_type_ids"], dtype="int64") + return input_ids, token_type_ids + + def infer(self, input_ids, token_type_ids): + self.input_handles[0].copy_from_cpu(input_ids) + self.input_handles[1].copy_from_cpu(token_type_ids) + self.predictor.run() + return self.output_handle.copy_to_cpu() + + def postprocess(self, logits): + probs = softmax(logits, axis=1) + return {"label": probs.argmax(axis=1), "confidence": probs} + + def predict(self, texts): + input_ids, token_type_ids = self.preprocess(texts) + logits = self.infer(input_ids, token_type_ids) + return self.postprocess(logits) + + +if __name__ == "__main__": + args = parse_arguments() + predictor = Predictor(args) + test_ds = load_dataset("chnsenticorp", splits=["test"]) + texts_ds = [d["text"] for d in test_ds] + label_map = {0: "negative", 1: "positive"} + batch_texts = batchfy_text(texts_ds, args.batch_size) + + for bs, texts in enumerate(batch_texts): + outputs = predictor.predict(texts) + for i, sentence in enumerate(texts): + label = outputs["label"][i] + confidence = outputs["confidence"][i] + print( + f"Batch id: {bs}, example id: {i}, sentence: {sentence}, " + f"label: {label_map[label]}, negative prob: {confidence[0]:.4f}, positive prob: {confidence[1]:.4f}." + ) diff --git a/slm/model_zoo/ernie-1.0/finetune/deploy/seq_cls_infer.py b/slm/model_zoo/ernie-1.0/finetune/deploy/seq_cls_infer.py deleted file mode 100644 index 2ba53e76b24c..000000000000 --- a/slm/model_zoo/ernie-1.0/finetune/deploy/seq_cls_infer.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import distutils.util -import os - -import fastdeploy as fd -import numpy as np - -from paddlenlp.datasets import load_dataset -from paddlenlp.transformers import AutoTokenizer - - -def parse_arguments(): - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument("--model_dir", required=True, help="The directory of model.") - parser.add_argument("--vocab_path", type=str, default="", help="The path of tokenizer vocab.") - parser.add_argument("--model_prefix", type=str, default="model", help="The model and params file prefix.") - parser.add_argument( - "--device", - type=str, - default="cpu", - choices=["gpu", "cpu"], - help="Type of inference device, support 'cpu' or 'gpu'.", - ) - parser.add_argument( - "--backend", - type=str, - default="paddle", - choices=["onnx_runtime", "paddle", "openvino", "tensorrt", "paddle_tensorrt"], - help="The inference runtime backend.", - ) - parser.add_argument("--batch_size", type=int, default=1, help="The batch size of data.") - parser.add_argument("--max_length", type=int, default=128, help="The max length of sequence.") - parser.add_argument("--log_interval", type=int, default=10, help="The interval of logging.") - parser.add_argument("--use_fp16", type=distutils.util.strtobool, default=False, help="Wheter to use FP16 mode") - parser.add_argument("--cpu_threads", type=int, default=1, help="Number of threads to predict when using cpu.") - parser.add_argument("--device_id", type=int, default=0, help="Select which gpu device to train model.") - return parser.parse_args() - - -def batchfy_text(texts, batch_size): - batch_texts = [] - batch_start = 0 - while batch_start < len(texts): - batch_texts += [texts[batch_start : min(batch_start + batch_size, len(texts))]] - batch_start += batch_size - return batch_texts - - -class Predictor(object): - def __init__(self, args): - self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir) - self.runtime = self.create_fd_runtime(args) - self.batch_size = args.batch_size - self.max_length = args.max_length - - def create_fd_runtime(self, args): - option = fd.RuntimeOption() - model_path = os.path.join(args.model_dir, args.model_prefix + ".pdmodel") - params_path = os.path.join(args.model_dir, args.model_prefix + ".pdiparams") - option.set_model_path(model_path, params_path) - if args.device == "cpu": - option.use_cpu() - option.set_cpu_thread_num(args.cpu_threads) - else: - option.use_gpu(args.device_id) - if args.backend == "paddle": - option.use_paddle_infer_backend() - elif args.backend == "onnx_runtime": - option.use_ort_backend() - elif args.backend == "openvino": - option.use_openvino_backend() - else: - option.use_trt_backend() - if args.backend == "paddle_tensorrt": - option.use_paddle_infer_backend() - option.paddle_infer_option.collect_trt_shape = True - option.paddle_infer_option.enable_trt = True - trt_file = os.path.join(args.model_dir, "model.trt") - option.trt_option.set_shape( - "input_ids", [1, 1], [args.batch_size, args.max_length], [args.batch_size, args.max_length] - ) - option.trt_option.set_shape( - "token_type_ids", [1, 1], [args.batch_size, args.max_length], [args.batch_size, args.max_length] - ) - if args.use_fp16: - option.trt_option.enable_fp16 = True - trt_file = trt_file + ".fp16" - option.trt_option.serialize_file = trt_file - return fd.Runtime(option) - - def preprocess(self, text): - data = self.tokenizer(text, max_length=self.max_length, padding=True, truncation=True) - input_ids_name = self.runtime.get_input_info(0).name - token_type_ids_name = self.runtime.get_input_info(1).name - input_map = { - input_ids_name: np.array(data["input_ids"], dtype="int64"), - token_type_ids_name: np.array(data["token_type_ids"], dtype="int64"), - } - return input_map - - def infer(self, input_map): - results = self.runtime.infer(input_map) - return results - - def postprocess(self, infer_data): - logits = np.array(infer_data[0]) - max_value = np.max(logits, axis=1, keepdims=True) - exp_data = np.exp(logits - max_value) - probs = exp_data / np.sum(exp_data, axis=1, keepdims=True) - out_dict = {"label": probs.argmax(axis=-1), "confidence": probs} - return out_dict - - def predict(self, texts): - input_map = self.preprocess(texts) - infer_result = self.infer(input_map) - output = self.postprocess(infer_result) - return output - - -if __name__ == "__main__": - args = parse_arguments() - predictor = Predictor(args) - test_ds = load_dataset("chnsenticorp", splits=["test"]) - texts_ds = [d["text"] for d in test_ds] - label_map = {0: "negative", 1: "positive"} - batch_texts = batchfy_text(texts_ds, args.batch_size) - for bs, texts in enumerate(batch_texts): - outputs = predictor.predict(texts) - for i, sentence1 in enumerate(texts): - print( - f"Batch id: {bs}, example id: {i}, sentence: {sentence1}, " - f"label: {label_map[outputs['label'][i]]}, negative prob: {outputs['confidence'][i][0]:.4f}, " - f"positive prob: {outputs['confidence'][i][1]:.4f}." - )