Skip to content

Commit d6ad6a7

Browse files
authored
add the code of generating cam_sync_labels in waymo dataset (#1870)
1 parent 74117ce commit d6ad6a7

File tree

8 files changed

+198
-44
lines changed

8 files changed

+198
-44
lines changed

mmdet3d/datasets/convert_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,14 +459,12 @@ def generate_waymo_mono3d_record(ann_rec, x1, y1, x2, y2, sample_data_token,
459459
repro_rec['bbox_corners'] = [x1, y1, x2, y2]
460460
repro_rec['filename'] = filename
461461

462-
coco_rec['file_name'] = filename
463462
coco_rec['image_id'] = sample_data_token
464463
coco_rec['area'] = (y2 - y1) * (x2 - x1)
465464

466465
if repro_rec['category_name'] not in kitti_categories:
467466
return None
468467
cat_name = repro_rec['category_name']
469-
coco_rec['category_name'] = cat_name
470468
coco_rec['category_id'] = kitti_categories.index(cat_name)
471469
coco_rec['bbox_label'] = coco_rec['category_id']
472470
coco_rec['bbox_label_3d'] = coco_rec['bbox_label']

mmdet3d/structures/bbox_3d/box_3d_mode.py

Lines changed: 72 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class Box3DMode(IntEnum):
6363
DEPTH = 2
6464

6565
@staticmethod
66-
def convert(box, src, dst, rt_mat=None, with_yaw=True):
66+
def convert(box, src, dst, rt_mat=None, with_yaw=True, correct_yaw=False):
6767
"""Convert boxes from `src` mode to `dst` mode.
6868
6969
Args:
@@ -81,6 +81,7 @@ def convert(box, src, dst, rt_mat=None, with_yaw=True):
8181
with_yaw (bool, optional): If `box` is an instance of
8282
:obj:`BaseInstance3DBoxes`, whether or not it has a yaw angle.
8383
Defaults to True.
84+
correct_yaw (bool): Whether to recompute the yaw by rotating its direction vector with rt_mat. Defaults to False.
8485
8586
Returns:
8687
(tuple | list | np.ndarray | torch.Tensor |
@@ -119,41 +120,89 @@ def convert(box, src, dst, rt_mat=None, with_yaw=True):
119120
rt_mat = arr.new_tensor([[0, -1, 0], [0, 0, -1], [1, 0, 0]])
120121
xyz_size = torch.cat([x_size, z_size, y_size], dim=-1)
121122
if with_yaw:
122-
yaw = -yaw - np.pi / 2
123-
yaw = limit_period(yaw, period=np.pi * 2)
123+
if correct_yaw:
124+
yaw_vector = torch.cat([
125+
torch.cos(yaw),
126+
torch.sin(yaw),
127+
torch.zeros_like(yaw)
128+
],
129+
dim=1)
130+
else:
131+
yaw = -yaw - np.pi / 2
132+
yaw = limit_period(yaw, period=np.pi * 2)
124133
elif src == Box3DMode.CAM and dst == Box3DMode.LIDAR:
125134
if rt_mat is None:
126135
rt_mat = arr.new_tensor([[0, 0, 1], [-1, 0, 0], [0, -1, 0]])
127136
xyz_size = torch.cat([x_size, z_size, y_size], dim=-1)
128137
if with_yaw:
129-
yaw = -yaw - np.pi / 2
130-
yaw = limit_period(yaw, period=np.pi * 2)
138+
if correct_yaw:
139+
yaw_vector = torch.cat([
140+
torch.cos(-yaw),
141+
torch.zeros_like(yaw),
142+
torch.sin(-yaw)
143+
],
144+
dim=1)
145+
else:
146+
yaw = -yaw - np.pi / 2
147+
yaw = limit_period(yaw, period=np.pi * 2)
131148
elif src == Box3DMode.DEPTH and dst == Box3DMode.CAM:
132149
if rt_mat is None:
133150
rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
134151
xyz_size = torch.cat([x_size, z_size, y_size], dim=-1)
135152
if with_yaw:
136-
yaw = -yaw
153+
if correct_yaw:
154+
yaw_vector = torch.cat([
155+
torch.cos(yaw),
156+
torch.sin(yaw),
157+
torch.zeros_like(yaw)
158+
],
159+
dim=1)
160+
else:
161+
yaw = -yaw
137162
elif src == Box3DMode.CAM and dst == Box3DMode.DEPTH:
138163
if rt_mat is None:
139164
rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
140165
xyz_size = torch.cat([x_size, z_size, y_size], dim=-1)
141166
if with_yaw:
142-
yaw = -yaw
167+
if correct_yaw:
168+
yaw_vector = torch.cat([
169+
torch.cos(-yaw),
170+
torch.zeros_like(yaw),
171+
torch.sin(-yaw)
172+
],
173+
dim=1)
174+
else:
175+
yaw = -yaw
143176
elif src == Box3DMode.LIDAR and dst == Box3DMode.DEPTH:
144177
if rt_mat is None:
145178
rt_mat = arr.new_tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
146179
xyz_size = torch.cat([x_size, y_size, z_size], dim=-1)
147180
if with_yaw:
148-
yaw = yaw + np.pi / 2
149-
yaw = limit_period(yaw, period=np.pi * 2)
181+
if correct_yaw:
182+
yaw_vector = torch.cat([
183+
torch.cos(yaw),
184+
torch.sin(yaw),
185+
torch.zeros_like(yaw)
186+
],
187+
dim=1)
188+
else:
189+
yaw = yaw + np.pi / 2
190+
yaw = limit_period(yaw, period=np.pi * 2)
150191
elif src == Box3DMode.DEPTH and dst == Box3DMode.LIDAR:
151192
if rt_mat is None:
152193
rt_mat = arr.new_tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
153194
xyz_size = torch.cat([x_size, y_size, z_size], dim=-1)
154195
if with_yaw:
155-
yaw = yaw - np.pi / 2
156-
yaw = limit_period(yaw, period=np.pi * 2)
196+
if correct_yaw:
197+
yaw_vector = torch.cat([
198+
torch.cos(yaw),
199+
torch.sin(yaw),
200+
torch.zeros_like(yaw)
201+
],
202+
dim=1)
203+
else:
204+
yaw = yaw - np.pi / 2
205+
yaw = limit_period(yaw, period=np.pi * 2)
157206
else:
158207
raise NotImplementedError(
159208
f'Conversion from Box3DMode {src} to {dst} '
@@ -168,6 +217,18 @@ def convert(box, src, dst, rt_mat=None, with_yaw=True):
168217
else:
169218
xyz = arr[..., :3] @ rt_mat.t()
170219

220+
# Note: we only use rotation in rt_mat
221+
# so don't need to extend yaw_vector
222+
if with_yaw and correct_yaw:
223+
rot_yaw_vector = yaw_vector @ rt_mat[:3, :3].t()
224+
if dst == Box3DMode.CAM:
225+
yaw = torch.atan2(-rot_yaw_vector[:, [2]], rot_yaw_vector[:,
226+
[0]])
227+
elif dst in [Box3DMode.LIDAR, Box3DMode.DEPTH]:
228+
yaw = torch.atan2(rot_yaw_vector[:, [1]], rot_yaw_vector[:,
229+
[0]])
230+
yaw = limit_period(yaw, period=np.pi * 2)
231+
171232
if with_yaw:
172233
remains = arr[..., 7:]
173234
arr = torch.cat([xyz[..., :3], xyz_size, yaw, remains], dim=-1)

mmdet3d/structures/bbox_3d/cam_box3d.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def height_overlaps(cls, boxes1, boxes2, mode='iou'):
280280
overlaps_h = torch.clamp(heighest_of_bottom - lowest_of_top, min=0)
281281
return overlaps_h
282282

283-
def convert_to(self, dst, rt_mat=None):
283+
def convert_to(self, dst, rt_mat=None, correct_yaw=False):
284284
"""Convert self to ``dst`` mode.
285285
286286
Args:
@@ -291,14 +291,19 @@ def convert_to(self, dst, rt_mat=None):
291291
The conversion from ``src`` coordinates to ``dst`` coordinates
292292
usually comes along the change of sensors, e.g., from camera
293293
to LiDAR. This requires a transformation matrix.
294-
294+
correct_yaw (bool): Whether to convert the yaw angle to the target
295+
coordinate. Defaults to False.
295296
Returns:
296297
:obj:`BaseInstance3DBoxes`:
297298
The converted box of the same type in the ``dst`` mode.
298299
"""
299300
from .box_3d_mode import Box3DMode
300301
return Box3DMode.convert(
301-
box=self, src=Box3DMode.CAM, dst=dst, rt_mat=rt_mat)
302+
box=self,
303+
src=Box3DMode.CAM,
304+
dst=dst,
305+
rt_mat=rt_mat,
306+
correct_yaw=correct_yaw)
302307

303308
def points_in_boxes_part(self, points, boxes_override=None):
304309
"""Find the box in which each point is.

mmdet3d/structures/bbox_3d/lidar_box3d.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def flip(self, bev_direction='horizontal', points=None):
174174
points.flip(bev_direction)
175175
return points
176176

177-
def convert_to(self, dst, rt_mat=None):
177+
def convert_to(self, dst, rt_mat=None, correct_yaw=False):
178178
"""Convert self to ``dst`` mode.
179179
180180
Args:
@@ -185,14 +185,19 @@ def convert_to(self, dst, rt_mat=None):
185185
The conversion from ``src`` coordinates to ``dst`` coordinates
186186
usually comes along the change of sensors, e.g., from camera
187187
to LiDAR. This requires a transformation matrix.
188-
188+
correct_yaw (bool): Whether to convert the yaw angle to the target
189+
coordinate. Defaults to False.
189190
Returns:
190191
:obj:`BaseInstance3DBoxes`:
191192
The converted box of the same type in the ``dst`` mode.
192193
"""
193194
from .box_3d_mode import Box3DMode
194195
return Box3DMode.convert(
195-
box=self, src=Box3DMode.LIDAR, dst=dst, rt_mat=rt_mat)
196+
box=self,
197+
src=Box3DMode.LIDAR,
198+
dst=dst,
199+
rt_mat=rt_mat,
200+
correct_yaw=correct_yaw)
196201

197202
def enlarged_box(self, extra_width):
198203
"""Enlarge the length, width and height boxes.

tools/create_data.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,9 @@ def waymo_data_prep(root_path,
191191
"""
192192
from tools.dataset_converters import waymo_converter as waymo
193193

194-
splits = ['training', 'validation', 'testing']
194+
splits = [
195+
'training', 'validation', 'testing', 'testing_3d_camera_only_detection'
196+
]
195197
for i, split in enumerate(splits):
196198
load_dir = osp.join(root_path, 'waymo_format', split)
197199
if split == 'validation':
@@ -203,7 +205,8 @@ def waymo_data_prep(root_path,
203205
save_dir,
204206
prefix=str(i),
205207
workers=workers,
206-
test_mode=(split == 'testing'))
208+
test_mode=(split
209+
in ['testing', 'testing_3d_camera_only_detection']))
207210
converter.convert()
208211
# Generate waymo infos
209212
out_dir = osp.join(out_dir, 'kitti_format')

tools/dataset_converters/kitti_data_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,18 @@ def gather_single(self, idx):
394394
self.relative_path,
395395
info_type='label_all',
396396
use_prefix_id=True)
397+
cam_sync_label_path = get_label_path(
398+
idx,
399+
self.path,
400+
self.training,
401+
self.relative_path,
402+
info_type='cam_sync_label_all',
403+
use_prefix_id=True)
397404
if self.relative_path:
398405
label_path = str(root_path / label_path)
406+
cam_sync_label_path = str(root_path / cam_sync_label_path)
399407
annotations = get_label_anno(label_path)
408+
cam_sync_annotations = get_label_anno(cam_sync_label_path)
400409
info['image'] = image_info
401410
info['point_cloud'] = pc_info
402411
if self.calib:
@@ -437,16 +446,37 @@ def gather_single(self, idx):
437446
Tr_velo_to_cam = np.array([
438447
float(info) for info in lines[6].split(' ')[1:13]
439448
]).reshape([3, 4])
449+
Tr_velo_to_cam1 = np.array([
450+
float(info) for info in lines[7].split(' ')[1:13]
451+
]).reshape([3, 4])
452+
Tr_velo_to_cam2 = np.array([
453+
float(info) for info in lines[8].split(' ')[1:13]
454+
]).reshape([3, 4])
455+
Tr_velo_to_cam3 = np.array([
456+
float(info) for info in lines[9].split(' ')[1:13]
457+
]).reshape([3, 4])
458+
Tr_velo_to_cam4 = np.array([
459+
float(info) for info in lines[10].split(' ')[1:13]
460+
]).reshape([3, 4])
440461
if self.extend_matrix:
441462
Tr_velo_to_cam = _extend_matrix(Tr_velo_to_cam)
463+
Tr_velo_to_cam1 = _extend_matrix(Tr_velo_to_cam1)
464+
Tr_velo_to_cam2 = _extend_matrix(Tr_velo_to_cam2)
465+
Tr_velo_to_cam3 = _extend_matrix(Tr_velo_to_cam3)
466+
Tr_velo_to_cam4 = _extend_matrix(Tr_velo_to_cam4)
442467
calib_info['P0'] = P0
443468
calib_info['P1'] = P1
444469
calib_info['P2'] = P2
445470
calib_info['P3'] = P3
446471
calib_info['P4'] = P4
447472
calib_info['R0_rect'] = rect_4x4
448473
calib_info['Tr_velo_to_cam'] = Tr_velo_to_cam
474+
calib_info['Tr_velo_to_cam1'] = Tr_velo_to_cam1
475+
calib_info['Tr_velo_to_cam2'] = Tr_velo_to_cam2
476+
calib_info['Tr_velo_to_cam3'] = Tr_velo_to_cam3
477+
calib_info['Tr_velo_to_cam4'] = Tr_velo_to_cam4
449478
info['calib'] = calib_info
479+
450480
if self.pose:
451481
pose_path = get_pose_path(
452482
idx,
@@ -460,6 +490,13 @@ def gather_single(self, idx):
460490
info['annos'] = annotations
461491
info['annos']['camera_id'] = info['annos'].pop('score')
462492
add_difficulty_to_annos(info)
493+
info['cam_sync_annos'] = cam_sync_annotations
494+
# NOTE: the 2D labels do not have strict correspondence with
495+
# the projected 2D lidar labels
496+
# e.g.: the projected 2D labels can be in camera 2
497+
# while the most_visible_camera can have id 4
498+
info['cam_sync_annos']['camera_id'] = info['cam_sync_annos'].pop(
499+
'score')
463500

464501
sweeps = []
465502
prev_idx = idx

tools/dataset_converters/update_infos_to_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def clear_data_info_unused_keys(data_info):
195195
empty_flag = True
196196
for key in keys:
197197
# we allow no annotations in datainfo
198-
if key == 'instances':
198+
if key in ['instances', 'cam_sync_instances', 'cam_instances']:
199199
empty_flag = False
200200
continue
201201
if isinstance(data_info[key], list):
@@ -1057,4 +1057,4 @@ def update_pkl_infos(dataset, out_dir, pkl_path):
10571057
if args.out_dir is None:
10581058
args.out_dir = args.root_dir
10591059
update_pkl_infos(
1060-
dataset=args.dataset, out_dir=args.out_dir, pkl_path=args.pkl_path)
1060+
dataset=args.dataset, out_dir=args.out_dir, pkl_path=args.pkl)

0 commit comments

Comments
 (0)