Skip to content

[PIR] Update paddle.inference infer example for Ernie-vil2.0 #10500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions slm/examples/lexical_analysis/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ python export_model.py --data_dir=./lexical_analysis_dataset_tiny --params_path=

导出模型之后,可以用于部署,deploy/predict.py 文件提供了 python 部署预测示例。运行方式:

开启 PIR(PaddlePaddle 3.0.0默认):
```shell
python deploy/predict.py --model_file=infer_model/static_graph_params.json --params_file=infer_model/static_graph_params.pdiparams --data_dir lexical_analysis_dataset_tiny
```

未开启 PIR:
```shell
python deploy/predict.py --model_file=infer_model/static_graph_params.pdmodel --params_file=infer_model/static_graph_params.pdiparams --data_dir lexical_analysis_dataset_tiny
```
Expand Down
47 changes: 20 additions & 27 deletions slm/model_zoo/ernie-vil2.0/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,10 @@ Tensor(shape=[1, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
```shell
mkdir -p data/datasets
wget https://paddlenlp.bj.bcebos.com/datasets/Flickr30k-CN.tar.gz
tar -xzvf Flickr30k-CN.tar.gz -d data/datasets/
tar -xzvf Flickr30k-CN.tar.gz -C data/datasets/
mv data/datasets/Flickr30k-CN_copy data/datasets/Flickr30k-CN

python preprocess/create_arrow_dataset.py \
--data_dir data/datasets/Flickr30k-CN \
--splits train,valid,test \
--image_dir data/datasets/Flickr30k-CN/image \
--t2i_type jsonl
python preprocess/create_arrow_dataset.py --data_dir data/datasets/Flickr30k-CN --image_dir data/datasets/Flickr30k-CN/image --splits train,valid,test
```
执行完后,data 目录应是如下结构:

Expand Down Expand Up @@ -337,30 +334,30 @@ python predict.py --resume output_pd/checkpoint-600/ --image_path examples/21285

```
......
-0.15448952, 0.72006893, 0.36882138, -0.84108782, 0.37967119,
0.12349987, -1.02212155, -0.58292383, 1.48998547, -0.46960664,
0.30193087, -0.56355256, -0.30767381, -0.34489608, 0.59651250,
-0.49545336, -0.95961350, 0.68815416, 0.47264558, -0.25057256,
-0.61301452, 0.09002528, -0.03568697]])
0.30446628, -0.40303054, -0.44902760, -0.20834517, 0.61418092,
-0.47503090, -0.90602577, 0.61230117, 0.31328726, -0.30551922,
-0.70518905, 0.02921746, -0.06500954]])
Text features
Tensor(shape=[2, 768], dtype=float32, place=Place(cpu), stop_gradient=True,
[[ 0.04250492, -0.41429815, 0.26164034, ..., 0.26221907,
0.34387457, 0.18779679],
[ 0.06672275, -0.41456315, 0.13787840, ..., 0.21791621,
0.36693257, 0.34208682]])
Label probs: Tensor(shape=[1, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
[[0.99110782, 0.00889216]])
Tensor(shape=[2, 768], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[[ 0.04464678, -0.43012181, 0.25478637, ..., 0.27861869,
0.36597741, 0.20715161],
[ 0.06647702, -0.43343985, 0.12268012, ..., 0.23637798,
0.38784462, 0.36298674]])
model temperature
Parameter containing:
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,
[4.29992294])
Label probs: Tensor(shape=[1, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[[0.99257678, 0.00742322]])
```
可以看到`猫的照片`的相似度更高,结果符合预期。

<a name="模型导出预测"></a>

## 模型导出预测

上一节是动态图的示例,下面提供了简单的导出静态图预测的示例,帮助用户将预训练模型导出成预测部署的参数。首先安装[FastDeploy](https://github.com/PaddlePaddle/FastDeploy):
上一节是动态图的示例,下面提供了简单的导出静态图预测的示例,帮助用户将预训练模型导出成预测部署的参数。

```
pip install fastdeploy-gpu-python -f https://www.paddlepaddle.org.cn/whl/fastdeploy.html
```
然后运行下面的命令:

Expand All @@ -372,15 +369,11 @@ python export_model.py --model_path=output_pd/checkpoint-600/ \

对于导出的模型,我们提供了 Python 的 infer 脚本,调用预测库对简单的例子进行预测。
```shell
python deploy/python/infer.py --model_dir ./infer_model/
python deploy/python/infer.py --model_dir ./infer_model/ --image_path examples/212855663-c0a54707-e14c-4450-b45d-0162ae76aeb8.jpeg --device gpu
```
可以得到如下输出:
```
......
-5.63553333e-01 -3.07674855e-01 -3.44897419e-01 5.96513569e-01
-4.95454431e-01 -9.59614694e-01 6.88151956e-01 4.72645760e-01
-2.50571519e-01 -6.13013864e-01 9.00242254e-02 -3.56860608e-02]]
[[0.99110764 0.00889209]]
[[0.9925795 0.00742046]]
```
可以看到输出的概率值跟前面的预测结果几乎是一致的

Expand Down
208 changes: 70 additions & 138 deletions slm/model_zoo/ernie-vil2.0/deploy/python/infer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import distutils.util
import argparse
import os

import fastdeploy as fd
import numpy as np
import paddle.inference as paddle_infer
from PIL import Image
from scipy.special import softmax

from paddlenlp.transformers import ErnieViLProcessor
from paddlenlp.utils.env import (
Expand All @@ -27,161 +28,92 @@


def parse_arguments():
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", required=True, help="The directory of model.")
parser.add_argument(
"--device",
type=str,
default="gpu",
choices=["gpu", "cpu", "kunlunxin"],
help="Type of inference device, support 'cpu', 'kunlunxin' or 'gpu'.",
)
parser.add_argument(
"--backend",
type=str,
default="onnx_runtime",
choices=["onnx_runtime", "paddle", "openvino", "tensorrt", "paddle_tensorrt"],
help="The inference runtime backend.",
)
parser.add_argument("--batch_size", type=int, default=1, help="The batch size of data.")
parser.add_argument("--temperature", type=float, default=4.30022621, help="The temperature of the model.")
parser.add_argument("--max_length", type=int, default=128, help="The max length of sequence.")
parser.add_argument("--log_interval", type=int, default=10, help="The interval of logging.")
parser.add_argument("--use_fp16", type=distutils.util.strtobool, default=False, help="Wheter to use FP16 mode")
parser.add_argument(
"--encode_type",
type=str,
default="text",
choices=[
"image",
"text",
],
help="The encoder type.",
)
parser.add_argument(
"--image_path",
default="000000039769.jpg",
type=str,
help="image_path used for prediction",
)
parser.add_argument("--model_dir", required=True, help="Directory with .json and .pdiparams")
parser.add_argument("--device", default="gpu", choices=["gpu", "cpu"], help="Device for inference")
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--temperature", type=float, default=4.3)
parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--encode_type", choices=["text", "image"], default="text")
parser.add_argument("--image_path", type=str, default="data/datasets/Flickr30k-CN/image/36979.jpg")
return parser.parse_args()


class ErnieVil2Predictor(object):
class PaddleErnieViLPredictor:
def __init__(self, args):
self.args = args
self.processor = ErnieViLProcessor.from_pretrained("PaddlePaddle/ernie_vil-2.0-base-zh")
self.runtime = self.create_fd_runtime(args)
self.batch_size = args.batch_size
self.max_length = args.max_length
self.encode_type = args.encode_type

def create_fd_runtime(self, args):
option = fd.RuntimeOption()
if args.encode_type == "text":
model_path = os.path.join(args.model_dir, f"get_text_features{PADDLE_INFERENCE_MODEL_SUFFIX}")
params_path = os.path.join(args.model_dir, f"get_text_features{PADDLE_INFERENCE_WEIGHTS_SUFFIX}")
else:
model_path = os.path.join(args.model_dir, f"get_image_features{PADDLE_INFERENCE_MODEL_SUFFIX}")
params_path = os.path.join(args.model_dir, f"get_image_features{PADDLE_INFERENCE_WEIGHTS_SUFFIX}")
option.set_model_path(model_path, params_path)
if args.device == "kunlunxin":
option.use_kunlunxin()
option.use_paddle_lite_backend()
return fd.Runtime(option)
if args.device == "cpu":
option.use_cpu()
self.predictor, self.input_names, self.output_names = self.load_predictor()

def load_predictor(self):
model_file = os.path.join(
self.args.model_dir, f"get_{self.args.encode_type}_features{PADDLE_INFERENCE_MODEL_SUFFIX}"
)
params_file = os.path.join(
self.args.model_dir, f"get_{self.args.encode_type}_features{PADDLE_INFERENCE_WEIGHTS_SUFFIX}"
)

config = paddle_infer.Config(model_file, params_file)
if self.args.device == "gpu":
config.enable_use_gpu(100, 0)
else:
option.use_gpu()
if args.backend == "paddle":
option.use_paddle_infer_backend()
elif args.backend == "onnx_runtime":
option.use_ort_backend()
elif args.backend == "openvino":
option.use_openvino_backend()
else:
option.use_trt_backend()
if args.backend == "paddle_tensorrt":
option.enable_paddle_to_trt()
option.enable_paddle_trt_collect_shape()
trt_file = os.path.join(args.model_dir, "{}_infer.trt".format(args.encode_type))
if args.encode_type == "text":
option.set_trt_input_shape(
"input_ids",
min_shape=[1, args.max_length],
opt_shape=[args.batch_size, args.max_length],
max_shape=[args.batch_size, args.max_length],
)
else:
option.set_trt_input_shape(
"pixel_values",
min_shape=[1, 3, 224, 224],
opt_shape=[args.batch_size, 3, 224, 224],
max_shape=[args.batch_size, 3, 224, 224],
)
if args.use_fp16:
option.enable_trt_fp16()
trt_file = trt_file + ".fp16"
option.set_trt_cache_file(trt_file)
return fd.Runtime(option)
config.disable_gpu()
config.disable_glog_info()
config.switch_ir_optim(True)

predictor = paddle_infer.create_predictor(config)
input_names = predictor.get_input_names()
output_names = predictor.get_output_names()
return predictor, input_names, output_names

def preprocess(self, inputs):
if self.encode_type == "text":
dataset = [np.array([self.processor(text=text)["input_ids"] for text in inputs], dtype="int64")]
if self.args.encode_type == "text":
input_ids = [self.processor(text=t)["input_ids"] for t in inputs]
input_ids = np.array(input_ids, dtype="int64")
return {"input_ids": input_ids}
else:
dataset = [np.array([self.processor(images=image)["pixel_values"][0] for image in inputs])]
input_map = {}
for input_field_id, data in enumerate(dataset):
input_field = self.runtime.get_input_info(input_field_id).name
input_map[input_field] = data
return input_map

def postprocess(self, infer_data):
logits = np.array(infer_data[0])
out_dict = {
"features": logits,
}
return out_dict

def infer(self, input_map):
results = self.runtime.infer(input_map)
return results
pixel_values = [self.processor(images=img)["pixel_values"][0] for img in inputs]
pixel_values = np.stack(pixel_values)
return {"pixel_values": pixel_values.astype("float32")}

def infer(self, input_dict):
for name in self.input_names:
input_tensor = self.predictor.get_input_handle(name)
input_tensor.copy_from_cpu(input_dict[name])
self.predictor.run()
output_tensor = self.predictor.get_output_handle(self.output_names[0])
return output_tensor.copy_to_cpu()

def predict(self, inputs):
input_map = self.preprocess(inputs)
infer_result = self.infer(input_map)
output = self.postprocess(infer_result)
output = self.infer(input_map)
return output


def main():
args = parse_arguments()
texts = [
"猫的照片",
"狗的照片",
]
args.batch_size = 2
predictor = ErnieVil2Predictor(args)
outputs = predictor.predict(texts)
print(outputs)
text_feats = outputs["features"]
image = Image.open(args.image_path)

# 文本推理
args.encode_type = "text"
predictor_text = PaddleErnieViLPredictor(args)
texts = ["猫的照片", "狗的照片"]
args.batch_size = len(texts)
text_features = predictor_text.predict(texts)

# 图像推理
args.encode_type = "image"
args.batch_size = 1
predictor = ErnieVil2Predictor(args)
images = [image]
outputs = predictor.predict(images)
image_feats = outputs["features"]
print(image_feats)
from scipy.special import softmax

image_feats = image_feats / np.linalg.norm(image_feats, ord=2, axis=-1, keepdims=True)
text_feats = text_feats / np.linalg.norm(text_feats, ord=2, axis=-1, keepdims=True)
# Get from dygraph, refer to predict.py
exp_data = np.exp(args.temperature)
m = softmax(np.matmul(exp_data * text_feats, image_feats.T), axis=0).T
print(m)
predictor_image = PaddleErnieViLPredictor(args)
image = Image.open(args.image_path).convert("RGB")
image_features = predictor_image.predict([image])

# 特征归一化 + 相似度计算
image_features = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
text_features = text_features / np.linalg.norm(text_features, axis=-1, keepdims=True)

sim_logits = softmax(np.exp(args.temperature) * np.matmul(text_features, image_features.T), axis=0).T
print("相似度矩阵(image→text):")
print(sim_logits)


if __name__ == "__main__":
Expand Down
Loading