Skip to content

Commit 43a7127

Browse files
authored
Add python deployment for squeezesegv3
1 parent d1b90f3 commit 43a7127

File tree

2 files changed

+204
-0
lines changed

2 files changed

+204
-0
lines changed

deploy/squeezesegv3/python/infer.py

+164
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
17+
import cv2
18+
import numpy as np
19+
import paddle
20+
from paddle.inference import Config, create_predictor
21+
22+
from paddle3d import transforms as T
23+
from paddle3d.sample import Sample
24+
from paddle3d.transforms.normalize import NormalizeRangeImage
25+
from paddle3d.transforms.reader import LoadSemanticKITTIRange
26+
27+
28+
def parse_args():
29+
parser = argparse.ArgumentParser()
30+
parser.add_argument(
31+
"--model_file",
32+
type=str,
33+
help="Model filename, Specify this when your model is a combined model.",
34+
required=True)
35+
parser.add_argument(
36+
"--params_file",
37+
type=str,
38+
help=
39+
"Parameter filename, Specify this when your model is a combined model.",
40+
required=True)
41+
parser.add_argument(
42+
'--lidar_file', type=str, help='The lidar path.', required=True)
43+
parser.add_argument(
44+
'--img_mean',
45+
type=str,
46+
help='The mean value of range-view image.',
47+
required=True)
48+
parser.add_argument(
49+
'--img_std',
50+
type=str,
51+
help='The variance value of range-view image.',
52+
required=True)
53+
parser.add_argument("--gpu_id", type=int, default=0, help="GPU card id.")
54+
parser.add_argument(
55+
"--use_trt",
56+
type=int,
57+
default=0,
58+
help="Whether to use tensorrt to accelerate when using gpu.")
59+
parser.add_argument(
60+
"--trt_precision",
61+
type=int,
62+
default=0,
63+
help="Precision type of tensorrt, 0: kFloat32, 1: kHalf.")
64+
parser.add_argument(
65+
"--trt_use_static",
66+
type=int,
67+
default=0,
68+
help="Whether to load the tensorrt graph optimization from a disk path."
69+
)
70+
parser.add_argument(
71+
"--trt_static_dir",
72+
type=str,
73+
help="Path of a tensorrt graph optimization directory.")
74+
75+
return parser.parse_args()
76+
77+
78+
def preprocess(file_path, img_mean, img_std):
79+
if isinstance(img_mean, str):
80+
img_mean = eval(img_mean)
81+
if isinstance(img_std, str):
82+
img_std = eval(img_std)
83+
84+
sample = Sample(path=file_path, modality="lidar")
85+
86+
transforms = T.Compose([
87+
LoadSemanticKITTIRange(project_label=False),
88+
NormalizeRangeImage(mean=img_mean, std=img_std)
89+
])
90+
91+
sample = transforms(sample)
92+
93+
if "proj_mask" in sample.meta:
94+
sample.data *= sample.meta.pop("proj_mask")
95+
return np.expand_dims(sample.data,
96+
0), sample.meta.proj_x, sample.meta.proj_y
97+
98+
99+
def init_predictor(model_file,
100+
params_file,
101+
gpu_id=0,
102+
use_trt=False,
103+
trt_precision=0,
104+
trt_use_static=False,
105+
trt_static_dir=None):
106+
config = Config(model_file, params_file)
107+
config.enable_memory_optim()
108+
config.enable_use_gpu(1000, gpu_id)
109+
if use_trt:
110+
precision_mode = paddle.inference.PrecisionType.Float32
111+
if trt_precision == 1:
112+
precision_mode = paddle.inference.PrecisionType.Half
113+
config.enable_tensorrt_engine(
114+
workspace_size=1 << 20,
115+
max_batch_size=1,
116+
min_subgraph_size=3,
117+
precision_mode=precision_mode,
118+
use_static=trt_use_static,
119+
use_calib_mode=False)
120+
if trt_use_static:
121+
config.set_optim_cache_dir(trt_static_dir)
122+
123+
predictor = create_predictor(config)
124+
return predictor
125+
126+
127+
def run(predictor, points):
128+
# copy img data to input tensor
129+
input_names = predictor.get_input_names()
130+
input_tensor = predictor.get_input_handle(input_names[0])
131+
input_tensor.reshape(points.shape)
132+
input_tensor.copy_from_cpu(points.copy())
133+
134+
# do the inference
135+
predictor.run()
136+
137+
results = []
138+
# get out data from output tensor
139+
output_names = predictor.get_output_names()
140+
output_tensor = predictor.get_output_handle(output_names[0])
141+
pred_label = output_tensor.copy_to_cpu()
142+
143+
return pred_label[0]
144+
145+
146+
def postprocess(pred_img_label, proj_x, proj_y):
147+
return pred_img_label[proj_y, proj_x]
148+
149+
150+
def main(args):
151+
predictor = init_predictor(args.model_file, args.params_file, args.gpu_id,
152+
args.use_trt, args.trt_precision,
153+
args.trt_use_static, args.trt_static_dir)
154+
range_img, proj_x, proj_y = preprocess(args.lidar_file, args.img_mean,
155+
args.img_std)
156+
pred_img_label = run(predictor, range_img)
157+
pred_point_label = postprocess(pred_img_label, proj_x, proj_y)
158+
return pred_point_label
159+
160+
161+
if __name__ == '__main__':
162+
args = parse_args()
163+
164+
main(args)

docs/models/squeezesegv3/README.md

+40
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
* [训练](#h3-id52h3)
1111
* [评估](#h3-id53h3)
1212
* [模型导出](#h3-id54h3)
13+
* [模型部署](#h3-id55h3)
1314

1415
## <h2 id="1">引用</h2>
1516

@@ -125,3 +126,42 @@ python tools/export.py \
125126
| model | 待导出模型参数`model.pdparams`路径 || - |
126127
| input_shape | 指定模型的输入尺寸,支持`N, C, H, W``H, W`格式 || - |
127128
| save_dir | 保存导出模型的路径,`save_dir`下将会生成三个文件:`squeezesegv3.pdiparams ``squeezesegv3.pdiparams.info``squeezesegv3.pdmodel` || `deploy` |
129+
130+
131+
132+
### <h3 id="55">模型部署</h3>
133+
134+
#### C++部署
135+
136+
Coming soon...
137+
138+
#### Python部署
139+
140+
命令参数说明如下:
141+
142+
| 参数 | 说明 |
143+
| -- | -- |
144+
| model_file | 导出模型的结构文件`squeezesegv3.pdmodel`所在路径 |
145+
| params_file | 导出模型的参数文件`squeezesegv3.pdiparams`所在路径 |
146+
| lidar_file | 待预测的点云文件所在路径 |
147+
| img_mean | 点云投影到range-view后所成图像的均值,例如为`12.12,10.88,0.23,-1.04,0.21` |
148+
| img_std | 点云投影到range-view后所成图像的方差,例如为`12.32,11.47,6.91,0.86,0.16` |
149+
| use_trt | 是否使用TensorRT进行加速,默认0|
150+
| trt_precision | 当use_trt设置为1时,模型精度可设置0或1,0表示fp32, 1表示fp16。默认0 |
151+
| trt_use_static | 当trt_use_static设置为1时,**在首次运行程序的时候会将TensorRT的优化信息进行序列化到磁盘上,下次运行时直接加载优化的序列化信息而不需要重新生成**。默认0 |
152+
| trt_static_dir | 当trt_use_static设置为1时,保存优化信息的路径 |
153+
154+
155+
运行以下命令,执行预测:
156+
157+
```
158+
python infer.py --model_file /path/to/squeezesegv3.pdmodel --params_file /path/to/squeezesegv3.pdiparams --lidar_file /path/to/lidar.pcd.bin --img_mean 12.12,10.88,0.23,-1.04,0.21 --img_std 12.32,11.47,6.91,0.86,0.16
159+
```
160+
161+
如果要开启TensorRT的话,请卸载掉原有的`paddlepaddel_gpu`,至[Paddle官网](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html#python)下载与TensorRT连编的预编译Paddle Inferece安装包,选择符合本地环境CUDA/cuDNN/TensorRT版本的安装包完成安装即可。
162+
163+
运行以下命令,开启TensorRT加速模型预测:
164+
165+
```
166+
python infer.py --model_file /path/to/squeezesegv3.pdmodel --params_file /path/to/squeezesegv3.pdiparams --lidar_file /path/to/lidar.pcd.bin --img_mean 12.12,10.88,0.23,-1.04,0.21 --img_std 12.32,11.47,6.91,0.86,0.16 --use_trt 1
167+
```

0 commit comments

Comments
 (0)