
try:
    from waymo_open_dataset import dataset_pb2 as open_dataset
+    from waymo_open_dataset import label_pb2
+    from waymo_open_dataset.protos import metrics_pb2
+    from waymo_open_dataset.protos.metrics_pb2 import Objects
except ImportError:
+    Objects = None
    raise ImportError(
        'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" '
        'to install the official devkit first.')

from glob import glob
from os.path import join
+from typing import List, Optional

import mmengine
import numpy as np
import tensorflow as tf
-from waymo_open_dataset import label_pb2
-from waymo_open_dataset.protos import metrics_pb2


-class KITTI2Waymo(object):
-    """KITTI predictions to Waymo converter.
+class Prediction2Waymo(object):
+    """Predictions to Waymo converter. The format of the prediction results
+    may be the original format or KITTI format.

    This class serves as the converter to change predictions from KITTI to
    Waymo format.

    Args:
-        kitti_result_files (list[dict]): Predictions in KITTI format.
+        results (list[dict]): Prediction results.
        waymo_tfrecords_dir (str): Directory to load waymo raw data.
        waymo_results_save_dir (str): Directory to save converted predictions
            in waymo format (.bin files).
        waymo_results_final_path (str): Path to save combined
            predictions in waymo format (.bin file), like 'a/b/c.bin'.
        prefix (str): Prefix of filename. In general, 0 for training, 1 for
            validation and 2 for testing.
-        workers (str): Number of parallel processes.
+        classes (dict): A list of class names.
+        workers (int): Number of parallel processes. Defaults to 2.
+        file_client_args (dict): File client for reading gt in waymo format.
+            Defaults to ``dict(backend='disk')``.
+        from_kitti_format (bool, optional): Whether the results are in KITTI
+            format. Defaults to False.
+        idx2metainfo (Optional[dict], optional): The mapping from sample_idx
+            to metainfo. The metainfo must contain the keys:
+            'idx2contextname' and 'idx2timestamp'. Defaults to None.
    """

    def __init__(self,
-                 kitti_result_files,
-                 waymo_tfrecords_dir,
-                 waymo_results_save_dir,
-                 waymo_results_final_path,
-                 prefix,
-                 workers=64,
-                 file_client_args=dict(backend='disk')):
-
-        self.kitti_result_files = kitti_result_files
+                 results: List[dict],
+                 waymo_tfrecords_dir: str,
+                 waymo_results_save_dir: str,
+                 waymo_results_final_path: str,
+                 prefix: str,
+                 classes: dict,
+                 workers: int = 2,
+                 file_client_args: dict = dict(backend='disk'),
+                 from_kitti_format: bool = False,
+                 idx2metainfo: Optional[dict] = None):
+
+        self.results = results
        self.waymo_tfrecords_dir = waymo_tfrecords_dir
        self.waymo_results_save_dir = waymo_results_save_dir
        self.waymo_results_final_path = waymo_results_final_path
        self.prefix = prefix
+        self.classes = classes
        self.workers = int(workers)
        self.file_client_args = file_client_args
-        self.name2idx = {}
-        for idx, result in enumerate(kitti_result_files):
-            if len(result['sample_id']) > 0:
-                self.name2idx[str(result['sample_id'][0])] = idx
+        self.from_kitti_format = from_kitti_format
+        if idx2metainfo is not None:
+            self.idx2metainfo = idx2metainfo
+            # If ``fast_eval``, the metainfo does not need to be read from
+            # original data online. It's preprocessed offline.
+            self.fast_eval = True
+        else:
+            self.fast_eval = False
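        # A hedged sketch of the ``idx2metainfo`` layout consumed by
        # ``convert_one_fast`` below; the sample index and values are
        # made-up placeholders:
        #   idx2metainfo = {
        #       '1000000': {'contextname': '<waymo context name>',
        #                   'timestamp': <frame timestamp in microseconds>},
        #       ...
        #   }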

-        # turn on eager execution for older tensorflow versions
-        if int(tf.__version__.split('.')[0]) < 2:
-            tf.enable_eager_execution()
+        self.name2idx = {}

        self.k2w_cls_map = {
            'Car': label_pb2.Label.TYPE_VEHICLE,
@@ -70,12 +88,28 @@ def __init__(self,
            'Cyclist': label_pb2.Label.TYPE_CYCLIST,
        }

-        self.T_ref_to_front_cam = np.array([[0.0, 0.0, 1.0, 0.0],
-                                            [-1.0, 0.0, 0.0, 0.0],
-                                            [0.0, -1.0, 0.0, 0.0],
-                                            [0.0, 0.0, 0.0, 1.0]])
+        if self.from_kitti_format:
+            self.T_ref_to_front_cam = np.array([[0.0, 0.0, 1.0, 0.0],
+                                                [-1.0, 0.0, 0.0, 0.0],
+                                                [0.0, -1.0, 0.0, 0.0],
+                                                [0.0, 0.0, 0.0, 1.0]])
+            # ``sample_idx`` of the sample in kitti-format is an array
+            for idx, result in enumerate(results):
+                if len(result['sample_idx']) > 0:
+                    self.name2idx[str(result['sample_idx'][0])] = idx
+        else:
+            # ``sample_idx`` of the sample in the original prediction
+            # is an int value.
+            for idx, result in enumerate(results):
+                self.name2idx[str(result['sample_idx'])] = idx
+
+        if not self.fast_eval:
+            # need to read original '.tfrecord' file
+            self.get_file_names()
+            # turn on eager execution for older tensorflow versions
+            if int(tf.__version__.split('.')[0]) < 2:
+                tf.enable_eager_execution()

-        self.get_file_names()
        self.create_folder()

    def get_file_names(self):
@@ -207,22 +241,30 @@ def convert_one(self, file_idx):

            filename = f'{self.prefix}{file_idx:03d}{frame_num:03d}'

-            for camera in frame.context.camera_calibrations:
-                # FRONT = 1, see dataset.proto for details
-                if camera.name == 1:
-                    T_front_cam_to_vehicle = np.array(
-                        camera.extrinsic.transform).reshape(4, 4)
-
-            T_k2w = T_front_cam_to_vehicle @ self.T_ref_to_front_cam
-
            context_name = frame.context.name
            frame_timestamp_micros = frame.timestamp_micros

            if filename in self.name2idx:
-                kitti_result = \
-                    self.kitti_result_files[self.name2idx[filename]]
-                objects = self.parse_objects(kitti_result, T_k2w, context_name,
-                                             frame_timestamp_micros)
+                if self.from_kitti_format:
+                    for camera in frame.context.camera_calibrations:
+                        # FRONT = 1, see dataset.proto for details
+                        if camera.name == 1:
+                            T_front_cam_to_vehicle = np.array(
+                                camera.extrinsic.transform).reshape(4, 4)
+
+                    T_k2w = T_front_cam_to_vehicle @ self.T_ref_to_front_cam
+
+                    kitti_result = \
+                        self.results[self.name2idx[filename]]
+                    objects = self.parse_objects(kitti_result, T_k2w,
+                                                 context_name,
+                                                 frame_timestamp_micros)
+                else:
+                    index = self.name2idx[filename]
+                    objects = self.parse_objects_from_origin(
+                        self.results[index], context_name,
+                        frame_timestamp_micros)
+
            else:
                print(filename, 'not found.')
                objects = metrics_pb2.Objects()
@@ -232,11 +274,100 @@ def convert_one(self, file_idx):
                    'wb') as f:
                f.write(objects.SerializeToString())

+    def convert_one_fast(self, res_index: int):
+        """Convert action for a single file. It reads the metainfo from the
+        preprocessed file offline, so it is faster.
+
+        Args:
+            res_index (int): The index of the result.
+        """
+        sample_idx = self.results[res_index]['sample_idx']
+        if len(self.results[res_index]['pred_instances_3d']) > 0:
+            objects = self.parse_objects_from_origin(
+                self.results[res_index],
+                self.idx2metainfo[str(sample_idx)]['contextname'],
+                self.idx2metainfo[str(sample_idx)]['timestamp'])
+        else:
+            print(sample_idx, 'not found.')
+            objects = metrics_pb2.Objects()
+
+        with open(
+                join(self.waymo_results_save_dir, f'{sample_idx}.bin'),
+                'wb') as f:
+            f.write(objects.SerializeToString())
+
+    def parse_objects_from_origin(self, result: dict, contextname: str,
+                                  timestamp: str) -> Objects:
+        """Parse objects from the original prediction results.
+
+        Args:
+            result (dict): The original prediction results.
+            contextname (str): The ``contextname`` of the sample in waymo.
+            timestamp (str): The ``timestamp`` of the sample in waymo.
+
+        Returns:
+            metrics_pb2.Objects: The parsed objects.
+        """
+        lidar_boxes = result['pred_instances_3d']['bboxes_3d'].tensor
+        scores = result['pred_instances_3d']['scores_3d']
+        labels = result['pred_instances_3d']['labels_3d']
+
+        def parse_one_object(index):
+            class_name = self.classes[labels[index].item()]
+
+            box = label_pb2.Label.Box()
+            height = lidar_boxes[index][5].item()
+            heading = lidar_boxes[index][6].item()
+
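+            # wrap the heading angle into the range [-pi, pi]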
+            while heading < -np.pi:
+                heading += 2 * np.pi
+            while heading > np.pi:
+                heading -= 2 * np.pi
+
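+            # the predicted z is the box bottom, so half the height is added
+            # below to obtain the center-based z that Waymo expects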
+            box.center_x = lidar_boxes[index][0].item()
+            box.center_y = lidar_boxes[index][1].item()
+            box.center_z = lidar_boxes[index][2].item() + height / 2
+            box.length = lidar_boxes[index][3].item()
+            box.width = lidar_boxes[index][4].item()
+            box.height = height
+            box.heading = heading
+
+            o = metrics_pb2.Object()
+            o.object.box.CopyFrom(box)
+            o.object.type = self.k2w_cls_map[class_name]
+            o.score = scores[index].item()
+            o.context_name = contextname
+            o.frame_timestamp_micros = timestamp
+
+            return o
+
+        objects = metrics_pb2.Objects()
+        for i in range(len(lidar_boxes)):
+            objects.objects.append(parse_one_object(i))
+
+        return objects
+
    def convert(self):
        """Convert action."""
        print('Start converting ...')
-        mmengine.track_parallel_progress(self.convert_one, range(len(self)),
-                                         self.workers)
+        convert_func = self.convert_one_fast if self.fast_eval else \
+            self.convert_one
+
+        # from torch.multiprocessing import set_sharing_strategy
+        # # Force using "file_system" sharing strategy for stability
+        # set_sharing_strategy("file_system")
+
+        # mmengine.track_parallel_progress(convert_func, range(len(self)),
+        #                                  self.workers)
+
+        # TODO: Support multiprocessing. Now, multiprocessing evaluation will
+        # cause shared memory error in torch-1.10 and torch-1.11. Details can
+        # be seen in https://github.com/pytorch/pytorch/issues/67864.
+        prog_bar = mmengine.ProgressBar(len(self))
+        for i in range(len(self)):
+            convert_func(i)
+            prog_bar.update()
+
        print('\nFinished ...')

        # combine all files into one .bin
@@ -248,7 +379,8 @@ def convert(self):

    def __len__(self):
        """Length of the filename list."""
-        return len(self.waymo_tfrecord_pathnames)
+        return len(self.results) if self.fast_eval else len(
+            self.waymo_tfrecord_pathnames)

    def transform(self, T, x, y, z):
        """Transform the coordinates with matrix T.