Commit 0f4ba41

[Enhance] Speed up evaluation on waymo (#2008)
* support fast eval on waymo
* support waymo evaluation more flexible and faster
* support waymo evaluation more flexible and faster
* renames
* add docs
* add guides for multi-thread evaluation toolkit
* fix docstring
* add download link for idx2metainfo
* add docstring
* set convert_kitti_format=False in Lidar-based methods
* fix docs
* add docstring
1 parent bee069b commit 0f4ba41

10 files changed, 273 additions and 78 deletions

configs/_base_/datasets/waymoD5-3d-3class.py

Lines changed: 2 additions & 1 deletion
@@ -151,7 +151,8 @@
     ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
     waymo_bin_file='./data/waymo/waymo_format/gt.bin',
     data_root='./data/waymo/waymo_format',
-    file_client_args=file_client_args)
+    file_client_args=file_client_args,
+    convert_kitti_format=False)
 test_evaluator = val_evaluator

 vis_backends = [dict(type='LocalVisBackend')]

configs/_base_/datasets/waymoD5-3d-car.py

Lines changed: 2 additions & 1 deletion
@@ -135,7 +135,8 @@
     type='WaymoMetric',
     ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
     waymo_bin_file='./data/waymo/waymo_format/gt.bin',
-    data_root='./data/waymo/waymo_format')
+    data_root='./data/waymo/waymo_format',
+    convert_kitti_format=False)
 test_evaluator = val_evaluator

 vis_backends = [dict(type='LocalVisBackend')]

docs/en/advanced_guides/datasets/waymo_det.md

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ For evaluation on Waymo, please follow the [instruction](https://github.com/waym
 ```shell
 # download the code and enter the base directory
 git clone https://github.com/waymo-research/waymo-open-dataset.git waymo-od
+# git clone https://github.com/Abyssaledge/waymo-open-dataset-master waymo-od # if you want to use faster multi-thread version.
 cd waymo-od
 git checkout remotes/origin/master

docs/en/user_guides/dataset_prepare.md

Lines changed: 19 additions & 1 deletion
@@ -110,7 +110,25 @@ Download Waymo open dataset V1.2 [HERE](https://waymo.com/open/download/) and it
 python tools/create_data.py waymo --root-path ./data/waymo/ --out-dir ./data/waymo/ --workers 128 --extra-tag waymo
 ```

-Note that if your local disk does not have enough space for saving converted data, you can change the `out-dir` to anywhere else. Just remember to create folders and prepare data there in advance and link them back to `data/waymo/kitti_format` after the data conversion.
+Note that:
+
+- If your local disk does not have enough space for saving converted data, you can change the `out-dir` to anywhere else. Just remember to create folders and prepare data there in advance and link them back to `data/waymo/kitti_format` after the data conversion.
+
+- If you want faster evaluation on Waymo, you can download the preprocessed [metainfo](https://download.openmmlab.com/mmdetection3d/data/waymo/idx2metainfo.pkl) containing `contextname` and `timestamp` to the directory `data/waymo/waymo_format/`. Then modify the dataset config as follows:
+
+  ```python
+  val_evaluator = dict(
+      type='WaymoMetric',
+      ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
+      waymo_bin_file='./data/waymo/waymo_format/gt.bin',
+      data_root='./data/waymo/waymo_format',
+      file_client_args=file_client_args,
+      convert_kitti_format=True,
+      idx2metainfo='data/waymo/waymo_format/idx2metainfo.pkl'
+  )
+  ```
+
+  Currently, this trick only applies to LiDAR-based detection methods.

 ### NuScenes
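
Note on the metainfo file: the fast path added in this commit (`convert_one_fast`) reads `contextname` and `timestamp` through `self.idx2metainfo[str(sample_idx)]`. The snippet below is a minimal sketch of that lookup, assuming the pickle deserializes to a plain dict keyed by the stringified sample index. The helper name `lookup_metainfo` and the sample entry are hypothetical and only illustrate the expected shape.

```python
import pickle
import tempfile
from pathlib import Path


def lookup_metainfo(idx2metainfo: dict, sample_idx: int) -> tuple:
    """Return (contextname, timestamp) for one sample.

    Mirrors the lookup in ``convert_one_fast``: keys are ``str(sample_idx)``.
    """
    info = idx2metainfo[str(sample_idx)]
    return info['contextname'], info['timestamp']


# Hypothetical entry; real values come from the downloaded idx2metainfo.pkl.
fake_metainfo = {
    '1000000': {
        'contextname': 'segment-example',
        'timestamp': 1550083467346370,
    },
}

# Round-trip through a pickle file to imitate loading the preprocessed file.
with tempfile.TemporaryDirectory() as tmp_dir:
    pkl_path = Path(tmp_dir) / 'idx2metainfo.pkl'
    pkl_path.write_bytes(pickle.dumps(fake_metainfo))
    loaded = pickle.loads(pkl_path.read_bytes())

contextname, timestamp = lookup_metainfo(loaded, 1000000)
```

A quick check like this before a long evaluation run can catch a missing or mis-keyed entry early, since the converter itself would raise a `KeyError` only when it reaches that sample.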

docs/zh_cn/advanced_guides/datasets/waymo_det.md

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@ mmdetection3d
 ```shell
 # download the code and enter the base directory
 git clone https://github.com/waymo-research/waymo-open-dataset.git waymo-od
+# git clone https://github.com/Abyssaledge/waymo-open-dataset-master waymo-od # if you want to use faster multi-thread version.
 cd waymo-od
 git checkout remotes/origin/master

docs/zh_cn/user_guides/dataset_prepare.md

Lines changed: 17 additions & 2 deletions
@@ -104,8 +104,23 @@ python tools/create_data.py kitti --root-path ./data/kitti --out-dir ./data/kitt
 python tools/create_data.py waymo --root-path ./data/waymo/ --out-dir ./data/waymo/ --workers 128 --extra-tag waymo
 ```

-Note that if your disk does not have enough space to store the converted data, you can set the `out-dir` argument to another path.
-Just remember to create the folders and download the data under that path, then link them back to `data/waymo/kitti_format` after data preprocessing finishes.
+Note:
+
+- If your disk does not have enough space to store the converted data, you can set the `out-dir` argument to another path.
+  Just remember to create the folders and download the data under that path, then link them back to `data/waymo/kitti_format` after data preprocessing finishes.
+- If you want faster evaluation on Waymo, you can download the preprocessed [metainfo file](https://download.openmmlab.com/mmdetection3d/data/waymo/idx2metainfo.pkl) and place it in the `data/waymo/waymo_format/` directory. Then change the dataset config as shown below:
+  ```python
+  val_evaluator = dict(
+      type='WaymoMetric',
+      ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
+      waymo_bin_file='./data/waymo/waymo_format/gt.bin',
+      data_root='./data/waymo/waymo_format',
+      file_client_args=file_client_args,
+      convert_kitti_format=True,
+      idx2metainfo='data/waymo/waymo_format/idx2metainfo.pkl'
+  )
+  ```
+  Currently, this approach is limited to point-cloud-only tasks.

 ### NuScenes

mmdet3d/datasets/waymo_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def __init__(self,
         self.max_sweeps = max_sweeps
         # we do not provide file_client_args to custom_3d init
         # because we want disk loading for info
-        # while ceph loading for KITTI2Waymo
+        # while ceph loading for Prediction2Waymo
         super().__init__(
             data_root=data_root,
             ann_file=ann_file,
Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.

-from .prediction_kitti_to_waymo import KITTI2Waymo
+from .prediction_to_waymo import Prediction2Waymo

-__all__ = ['KITTI2Waymo']
+__all__ = ['Prediction2Waymo']

mmdet3d/evaluation/functional/waymo_utils/prediction_kitti_to_waymo.py renamed to mmdet3d/evaluation/functional/waymo_utils/prediction_to_waymo.py

Lines changed: 174 additions & 42 deletions
@@ -5,63 +5,81 @@

 try:
     from waymo_open_dataset import dataset_pb2 as open_dataset
+    from waymo_open_dataset import label_pb2
+    from waymo_open_dataset.protos import metrics_pb2
+    from waymo_open_dataset.protos.metrics_pb2 import Objects
 except ImportError:
+    Objects = None
     raise ImportError(
         'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" '
         'to install the official devkit first.')

 from glob import glob
 from os.path import join
+from typing import List, Optional

 import mmengine
 import numpy as np
 import tensorflow as tf
-from waymo_open_dataset import label_pb2
-from waymo_open_dataset.protos import metrics_pb2


-class KITTI2Waymo(object):
-    """KITTI predictions to Waymo converter.
+class Prediction2Waymo(object):
+    """Predictions to Waymo converter. The prediction results can be in the
+    original format or in kitti-format.

     This class serves as the converter to change predictions from KITTI to
     Waymo format.

     Args:
-        kitti_result_files (list[dict]): Predictions in KITTI format.
+        results (list[dict]): Prediction results.
         waymo_tfrecords_dir (str): Directory to load waymo raw data.
         waymo_results_save_dir (str): Directory to save converted predictions
             in waymo format (.bin files).
         waymo_results_final_path (str): Path to save combined
             predictions in waymo format (.bin file), like 'a/b/c.bin'.
         prefix (str): Prefix of filename. In general, 0 for training, 1 for
             validation and 2 for testing.
-        workers (str): Number of parallel processes.
+        classes (dict): A list of class names.
+        workers (int): Number of parallel processes. Defaults to 2.
+        file_client_args (dict): File client for reading gt in waymo format.
+            Defaults to ``dict(backend='disk')``.
+        from_kitti_format (bool, optional): Whether the results are in kitti
+            format. Defaults to False.
+        idx2metainfo (Optional[dict], optional): The mapping from sample_idx to
+            metainfo. The metainfo must contain the keys: 'contextname' and
+            'timestamp'. Defaults to None.
     """

     def __init__(self,
-                 kitti_result_files,
-                 waymo_tfrecords_dir,
-                 waymo_results_save_dir,
-                 waymo_results_final_path,
-                 prefix,
-                 workers=64,
-                 file_client_args=dict(backend='disk')):
-
-        self.kitti_result_files = kitti_result_files
+                 results: List[dict],
+                 waymo_tfrecords_dir: str,
+                 waymo_results_save_dir: str,
+                 waymo_results_final_path: str,
+                 prefix: str,
+                 classes: dict,
+                 workers: int = 2,
+                 file_client_args: dict = dict(backend='disk'),
+                 from_kitti_format: bool = False,
+                 idx2metainfo: Optional[dict] = None):
+
+        self.results = results
         self.waymo_tfrecords_dir = waymo_tfrecords_dir
         self.waymo_results_save_dir = waymo_results_save_dir
         self.waymo_results_final_path = waymo_results_final_path
         self.prefix = prefix
+        self.classes = classes
         self.workers = int(workers)
         self.file_client_args = file_client_args
-        self.name2idx = {}
-        for idx, result in enumerate(kitti_result_files):
-            if len(result['sample_id']) > 0:
-                self.name2idx[str(result['sample_id'][0])] = idx
+        self.from_kitti_format = from_kitti_format
+        if idx2metainfo is not None:
+            self.idx2metainfo = idx2metainfo
+            # If ``fast_eval``, the metainfo does not need to be read from
+            # original data online. It's preprocessed offline.
+            self.fast_eval = True
+        else:
+            self.fast_eval = False

-        # turn on eager execution for older tensorflow versions
-        if int(tf.__version__.split('.')[0]) < 2:
-            tf.enable_eager_execution()
+        self.name2idx = {}

         self.k2w_cls_map = {
             'Car': label_pb2.Label.TYPE_VEHICLE,
@@ -70,12 +88,28 @@ def __init__(self,
             'Cyclist': label_pb2.Label.TYPE_CYCLIST,
         }

-        self.T_ref_to_front_cam = np.array([[0.0, 0.0, 1.0, 0.0],
-                                            [-1.0, 0.0, 0.0, 0.0],
-                                            [0.0, -1.0, 0.0, 0.0],
-                                            [0.0, 0.0, 0.0, 1.0]])
+        if self.from_kitti_format:
+            self.T_ref_to_front_cam = np.array([[0.0, 0.0, 1.0, 0.0],
+                                                [-1.0, 0.0, 0.0, 0.0],
+                                                [0.0, -1.0, 0.0, 0.0],
+                                                [0.0, 0.0, 0.0, 1.0]])
+            # ``sample_idx`` of the sample in kitti-format is an array
+            for idx, result in enumerate(results):
+                if len(result['sample_idx']) > 0:
+                    self.name2idx[str(result['sample_idx'][0])] = idx
+        else:
+            # ``sample_idx`` of the sample in the original prediction
+            # is an int value.
+            for idx, result in enumerate(results):
+                self.name2idx[str(result['sample_idx'])] = idx
+
+        if not self.fast_eval:
+            # need to read original '.tfrecord' file
+            self.get_file_names()
+            # turn on eager execution for older tensorflow versions
+            if int(tf.__version__.split('.')[0]) < 2:
+                tf.enable_eager_execution()

-        self.get_file_names()
         self.create_folder()

     def get_file_names(self):
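
Note on `name2idx`: the hunk above builds the name-to-position map differently for the two result formats. The snippet below is a minimal standalone sketch of that branching, assuming KITTI-format results carry `sample_idx` as a sequence (empty when a frame has no boxes) and original-format results carry a single int. The function name `build_name2idx` and the sample data are hypothetical, and plain lists stand in for arrays.

```python
def build_name2idx(results: list, from_kitti_format: bool) -> dict:
    """Map stringified sample index -> position in ``results``."""
    name2idx = {}
    for idx, result in enumerate(results):
        sample_idx = result['sample_idx']
        if from_kitti_format:
            # KITTI-format: ``sample_idx`` is a sequence; frames whose
            # sequence is empty are skipped, as in the commit.
            if len(sample_idx) > 0:
                name2idx[str(sample_idx[0])] = idx
        else:
            # Original format: ``sample_idx`` is a single int per frame.
            name2idx[str(sample_idx)] = idx
    return name2idx


# Hypothetical inputs for illustration only.
kitti_results = [
    {'sample_idx': [1000000, 1000000]},
    {'sample_idx': []},  # frame without any predicted box
]
origin_results = [{'sample_idx': 1000000}, {'sample_idx': 1000001}]

kitti_map = build_name2idx(kitti_results, from_kitti_format=True)
origin_map = build_name2idx(origin_results, from_kitti_format=False)
```

In the KITTI-format branch a frame without boxes never enters the map, which is why `convert_one` later prints `not found.` for it and writes an empty `Objects` message.
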
@@ -207,22 +241,30 @@ def convert_one(self, file_idx):

             filename = f'{self.prefix}{file_idx:03d}{frame_num:03d}'

-            for camera in frame.context.camera_calibrations:
-                # FRONT = 1, see dataset.proto for details
-                if camera.name == 1:
-                    T_front_cam_to_vehicle = np.array(
-                        camera.extrinsic.transform).reshape(4, 4)
-
-            T_k2w = T_front_cam_to_vehicle @ self.T_ref_to_front_cam
-
             context_name = frame.context.name
             frame_timestamp_micros = frame.timestamp_micros

             if filename in self.name2idx:
-                kitti_result = \
-                    self.kitti_result_files[self.name2idx[filename]]
-                objects = self.parse_objects(kitti_result, T_k2w, context_name,
-                                             frame_timestamp_micros)
+                if self.from_kitti_format:
+                    for camera in frame.context.camera_calibrations:
+                        # FRONT = 1, see dataset.proto for details
+                        if camera.name == 1:
+                            T_front_cam_to_vehicle = np.array(
+                                camera.extrinsic.transform).reshape(4, 4)
+
+                    T_k2w = T_front_cam_to_vehicle @ self.T_ref_to_front_cam
+
+                    kitti_result = \
+                        self.results[self.name2idx[filename]]
+                    objects = self.parse_objects(kitti_result, T_k2w,
+                                                 context_name,
+                                                 frame_timestamp_micros)
+                else:
+                    index = self.name2idx[filename]
+                    objects = self.parse_objects_from_origin(
+                        self.results[index], context_name,
+                        frame_timestamp_micros)
+
             else:
                 print(filename, 'not found.')
                 objects = metrics_pb2.Objects()
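
Note on the lookup key: `convert_one` matches frames to predictions through the string `f'{self.prefix}{file_idx:03d}{frame_num:03d}'`. A small sketch of how that key is composed (the helper name `make_sample_name` and the example values are hypothetical):

```python
def make_sample_name(prefix: str, file_idx: int, frame_num: int) -> str:
    """Compose the per-frame key: prefix + 3-digit file index + 3-digit frame."""
    return f'{prefix}{file_idx:03d}{frame_num:03d}'


# Validation split (prefix '1'), 6th tfrecord file, 13th frame.
name = make_sample_name('1', 5, 12)
```

Here `name` is `'1005012'`, which has to equal `str(sample_idx)` of the corresponding prediction for the `filename in self.name2idx` test to succeed.
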
@@ -232,11 +274,100 @@ def convert_one(self, file_idx):
                     'wb') as f:
                 f.write(objects.SerializeToString())

+    def convert_one_fast(self, res_index: int):
+        """Convert action for single file. It reads the metainfo from the
+        preprocessed file offline and will be faster.
+
+        Args:
+            res_index (int): The index of the result.
+        """
+        sample_idx = self.results[res_index]['sample_idx']
+        if len(self.results[res_index]['pred_instances_3d']) > 0:
+            objects = self.parse_objects_from_origin(
+                self.results[res_index],
+                self.idx2metainfo[str(sample_idx)]['contextname'],
+                self.idx2metainfo[str(sample_idx)]['timestamp'])
+        else:
+            print(sample_idx, 'not found.')
+            objects = metrics_pb2.Objects()
+
+        with open(
+                join(self.waymo_results_save_dir, f'{sample_idx}.bin'),
+                'wb') as f:
+            f.write(objects.SerializeToString())
+
+    def parse_objects_from_origin(self, result: dict, contextname: str,
+                                  timestamp: str) -> Objects:
+        """Parse objects from the original prediction results.
+
+        Args:
+            result (dict): The original prediction results.
+            contextname (str): The ``contextname`` of sample in waymo.
+            timestamp (str): The ``timestamp`` of sample in waymo.
+
+        Returns:
+            metrics_pb2.Objects: The parsed object.
+        """
+        lidar_boxes = result['pred_instances_3d']['bboxes_3d'].tensor
+        scores = result['pred_instances_3d']['scores_3d']
+        labels = result['pred_instances_3d']['labels_3d']
+
+        def parse_one_object(index):
+            class_name = self.classes[labels[index].item()]
+
+            box = label_pb2.Label.Box()
+            height = lidar_boxes[index][5].item()
+            heading = lidar_boxes[index][6].item()
+
+            while heading < -np.pi:
+                heading += 2 * np.pi
+            while heading > np.pi:
+                heading -= 2 * np.pi
+
+            box.center_x = lidar_boxes[index][0].item()
+            box.center_y = lidar_boxes[index][1].item()
+            box.center_z = lidar_boxes[index][2].item() + height / 2
+            box.length = lidar_boxes[index][3].item()
+            box.width = lidar_boxes[index][4].item()
+            box.height = height
+            box.heading = heading
+
+            o = metrics_pb2.Object()
+            o.object.box.CopyFrom(box)
+            o.object.type = self.k2w_cls_map[class_name]
+            o.score = scores[index].item()
+            o.context_name = contextname
+            o.frame_timestamp_micros = timestamp
+
+            return o
+
+        objects = metrics_pb2.Objects()
+        for i in range(len(lidar_boxes)):
+            objects.objects.append(parse_one_object(i))
+
+        return objects
+
     def convert(self):
         """Convert action."""
         print('Start converting ...')
-        mmengine.track_parallel_progress(self.convert_one, range(len(self)),
-                                         self.workers)
+        convert_func = self.convert_one_fast if self.fast_eval else \
+            self.convert_one
+
+        # from torch.multiprocessing import set_sharing_strategy
+        # # Force using "file_system" sharing strategy for stability
+        # set_sharing_strategy("file_system")
+
+        # mmengine.track_parallel_progress(convert_func, range(len(self)),
+        #                                  self.workers)
+
+        # TODO: Support multiprocessing. Now, multiprocessing evaluation will
+        # cause shared memory error in torch-1.10 and torch-1.11. Details can
+        # be seen in https://github.com/pytorch/pytorch/issues/67864.
+        prog_bar = mmengine.ProgressBar(len(self))
+        for i in range(len(self)):
+            convert_func(i)
+            prog_bar.update()
+
         print('\nFinished ...')

         # combine all files into one .bin
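
Note on the box conversion: `parse_one_object` wraps the heading into `[-pi, pi]` and shifts `z` up by half the height before filling the Waymo box. The snippet below sketches the same arithmetic without protobuf or torch, assuming the seven box values are ordered `[x, y, z, length, width, height, yaw]` and that `z` refers to the bottom face (which is what the `+ height / 2` shift implies). The function names and the sample box are hypothetical.

```python
import math


def normalize_heading(heading: float) -> float:
    """Wrap an angle in radians into [-pi, pi] with the same loops as the commit."""
    while heading < -math.pi:
        heading += 2 * math.pi
    while heading > math.pi:
        heading -= 2 * math.pi
    return heading


def lidar_box_to_waymo_fields(box7: list) -> dict:
    """Convert [x, y, z_bottom, l, w, h, yaw] into Waymo-style box fields."""
    x, y, z_bottom, length, width, height, yaw = box7
    return {
        'center_x': x,
        'center_y': y,
        # Move from the bottom face to the geometric center of the box.
        'center_z': z_bottom + height / 2,
        'length': length,
        'width': width,
        'height': height,
        'heading': normalize_heading(yaw),
    }


# Hypothetical box with a yaw of 270 degrees, outside [-pi, pi].
box = lidar_box_to_waymo_fields([10.0, -2.0, 0.5, 4.0, 2.0, 1.5, 1.5 * math.pi])
```

For this input the center height becomes `0.5 + 1.5 / 2` and the heading wraps to roughly `-pi / 2`, the same direction expressed inside the accepted range.
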
@@ -248,7 +379,8 @@ def convert(self):

     def __len__(self):
         """Length of the filename list."""
-        return len(self.waymo_tfrecord_pathnames)
+        return len(self.results) if self.fast_eval else len(
+            self.waymo_tfrecord_pathnames)

     def transform(self, T, x, y, z):
         """Transform the coordinates with matrix T.
