Skip to content

Commit b59d8c0

Browse files
[Fix] fix instance statistics when only detecting a single class (#2003)
1 parent bd2a49d commit b59d8c0

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

configs/centerpoint/centerpoint_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@
8989
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
9090
dict(type='ObjectNameFilter', classes=class_names),
9191
dict(type='PointShuffle'),
92-
dict(type='DefaultFormatBundle3D', class_names=class_names),
9392
dict(
9493
type='Pack3DDetInputs',
9594
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])

mmdet3d/datasets/det3d_dataset.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,9 @@ def parse_ann_info(self, info: dict) -> Optional[dict]:
255255
ann_info['instances'] = info['instances']
256256

257257
for label in ann_info['gt_labels_3d']:
258-
cat_name = self.metainfo['classes'][label]
259-
self.num_ins_per_cat[cat_name] += 1
258+
if label != -1:
259+
cat_name = self.metainfo['classes'][label]
260+
self.num_ins_per_cat[cat_name] += 1
260261

261262
return ann_info
262263

@@ -336,12 +337,16 @@ def _show_ins_var(self, old_labels: np.ndarray, new_labels: torch.Tensor):
336337
"""
337338
ori_num_per_cat = dict()
338339
for label in old_labels:
339-
cat_name = self.metainfo['classes'][label]
340-
ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name, 0) + 1
340+
if label != -1:
341+
cat_name = self.metainfo['classes'][label]
342+
ori_num_per_cat[cat_name] = ori_num_per_cat.get(cat_name,
343+
0) + 1
341344
new_num_per_cat = dict()
342345
for label in new_labels:
343-
cat_name = self.metainfo['classes'][label]
344-
new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name, 0) + 1
346+
if label != -1:
347+
cat_name = self.metainfo['classes'][label]
348+
new_num_per_cat[cat_name] = new_num_per_cat.get(cat_name,
349+
0) + 1
345350
content_show = [['category', 'new number', 'ori number']]
346351
for cat_name, num in ori_num_per_cat.items():
347352
new_num = new_num_per_cat.get(cat_name, 0)
@@ -387,9 +392,16 @@ def prepare_data(self, index: int) -> Optional[dict]:
387392
return None
388393

389394
if self.show_ins_var:
390-
self._show_ins_var(
391-
ori_input_dict['ann_info']['gt_labels_3d'],
392-
example['data_samples'].gt_instances_3d.labels_3d)
395+
if 'ann_info' in ori_input_dict:
396+
self._show_ins_var(
397+
ori_input_dict['ann_info']['gt_labels_3d'],
398+
example['data_samples'].gt_instances_3d.labels_3d)
399+
else:
400+
print_log(
401+
"'ann_info' is not in the input dict. It's probably that "
402+
'the data is not in training mode',
403+
'current',
404+
level=30)
393405

394406
return example
395407

0 commit comments

Comments
 (0)