From aac67593c3de59a1ad34846c60f4f15d31750189 Mon Sep 17 00:00:00 2001 From: HinGwenWoong Date: Sat, 3 Dec 2022 18:41:50 +0800 Subject: [PATCH 1/5] Add FSD doc --- configs/fsd/README.md | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 configs/fsd/README.md diff --git a/configs/fsd/README.md b/configs/fsd/README.md new file mode 100644 index 0000000000..957a82f888 --- /dev/null +++ b/configs/fsd/README.md @@ -0,0 +1,34 @@ +# FSD: Fully Sparse 3D Object Detection & SST: Single-stride Sparse Transformer + +> [Fully Sparse 3D Object Detection](https://arxiv.org/abs/2207.10035) + + + +## Abstract + +## Introduction + + +## Usage + +### Modify config + + +## Results and models + +### PointPillars + +| Backbone | Lr schd | Mem (GB) | Inf time (fps) | mAP | NDS | Download | +|:--------:| :-----: | :------: | :------------: | :---: | :---: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [x](x) | 2x | 17.1 | | 40.0 | 53.3 | [model]() \| [log]() | + +## Citation + +```latex +@article{fan2022fully, + title={{Fully Sparse 3D Object Detection}}, + author={Fan, Lue and Wang, Feng and Wang, Naiyan and Zhang, Zhaoxiang}, + journal={arXiv preprint arXiv:2207.10035}, + year={2022} +} +``` From 8d0d7b5b369987ea730589b4cdbb07c3f9ec25d6 Mon Sep 17 00:00:00 2001 From: HinGwenWoong Date: Tue, 6 Dec 2022 23:33:52 +0800 Subject: [PATCH 2/5] Add FSD code, Network is done, but dataset read still need to debug --- configs/_base_/models/fsd.py | 1 + configs/_base_/schedules/cosine_2x.py | 24 + configs/fsd/fsd_waymoD1_1x.py | 408 ++++++ configs/fsd/metafile.yml | 0 mmdet3d/engine/hooks/__init__.py | 4 +- mmdet3d/engine/hooks/fsd_hooks.py | 89 ++ mmdet3d/models/backbones/__init__.py | 3 +- mmdet3d/models/backbones/sir.py | 89 ++ mmdet3d/models/decode_heads/__init__.py | 3 +- .../models/decode_heads/segmentation_head.py | 264 ++++ mmdet3d/models/dense_heads/__init__.py | 5 +- .../models/dense_heads/sparse_cluster_head.py | 576 ++++++++ .../dense_heads/sparse_cluster_head_v2.py | 578 ++++++++ mmdet3d/models/detectors/__init__.py | 4 +- mmdet3d/models/detectors/single_stage_fsd.py | 1162 +++++++++++++++++ mmdet3d/models/detectors/two_stage_fsd.py | 258 ++++ mmdet3d/models/layers/sst/__init__.py | 6 + mmdet3d/models/layers/sst/sst_ops.py | 391 ++++++ mmdet3d/models/middle_encoders/__init__.py | 3 +- mmdet3d/models/middle_encoders/sparse_unet.py | 89 ++ .../middle_encoders/sst_input_layer_v2.py | 328 +++++ mmdet3d/models/necks/__init__.py | 3 +- mmdet3d/models/necks/voxel2point_neck.py | 62 + mmdet3d/models/roi_heads/__init__.py | 4 +- .../models/roi_heads/bbox_heads/__init__.py | 3 +- .../roi_heads/bbox_heads/fsd_bbox_head.py | 791 +++++++++++ mmdet3d/models/roi_heads/fsd_roi_head.py | 309 +++++ .../roi_heads/roi_extractors/__init__.py | 4 +- .../roi_extractors/dynamic_point_pool_op.py | 58 + .../dynamic_point_roi_extractor.py | 100 ++ .../models/task_modules/coders/__init__.py | 3 +- .../coders/base_point_bbox_coder.py | 83 ++ mmdet3d/models/voxel_encoders/__init__.py | 4 +- 
mmdet3d/models/voxel_encoders/utils.py | 87 ++ .../models/voxel_encoders/voxel_encoder.py | 284 +++- mmdet3d/structures/ops/iou3d_calculator.py | 52 + 36 files changed, 6113 insertions(+), 19 deletions(-) create mode 100644 configs/_base_/models/fsd.py create mode 100644 configs/_base_/schedules/cosine_2x.py create mode 100644 configs/fsd/fsd_waymoD1_1x.py create mode 100644 configs/fsd/metafile.yml create mode 100644 mmdet3d/engine/hooks/fsd_hooks.py create mode 100644 mmdet3d/models/backbones/sir.py create mode 100644 mmdet3d/models/decode_heads/segmentation_head.py create mode 100644 mmdet3d/models/dense_heads/sparse_cluster_head.py create mode 100644 mmdet3d/models/dense_heads/sparse_cluster_head_v2.py create mode 100644 mmdet3d/models/detectors/single_stage_fsd.py create mode 100644 mmdet3d/models/detectors/two_stage_fsd.py create mode 100644 mmdet3d/models/layers/sst/__init__.py create mode 100644 mmdet3d/models/layers/sst/sst_ops.py create mode 100644 mmdet3d/models/middle_encoders/sst_input_layer_v2.py create mode 100644 mmdet3d/models/necks/voxel2point_neck.py create mode 100644 mmdet3d/models/roi_heads/bbox_heads/fsd_bbox_head.py create mode 100644 mmdet3d/models/roi_heads/fsd_roi_head.py create mode 100644 mmdet3d/models/roi_heads/roi_extractors/dynamic_point_pool_op.py create mode 100644 mmdet3d/models/roi_heads/roi_extractors/dynamic_point_roi_extractor.py create mode 100644 mmdet3d/models/task_modules/coders/base_point_bbox_coder.py diff --git a/configs/_base_/models/fsd.py b/configs/_base_/models/fsd.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/configs/_base_/models/fsd.py @@ -0,0 +1 @@ + diff --git a/configs/_base_/schedules/cosine_2x.py b/configs/_base_/schedules/cosine_2x.py new file mode 100644 index 0000000000..b5ca4aacc6 --- /dev/null +++ b/configs/_base_/schedules/cosine_2x.py @@ -0,0 +1,24 @@ +lr = 1e-5 + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict( + type='AdamW', + lr=lr, + betas=(0.9, 0.999), # the momentum is change during training + weight_decay=0.05, + ), + paramwise_cfg=dict(custom_keys={'norm': dict(decay_mult=0.)}), + clip_grad=dict(grad_clip=dict(max_norm=10, norm_type=2)) +) + +lr_config = dict( + policy='cyclic', + target_ratio=(100, 1e-3), + cyclic_times=1, + step_ratio_up=0.1, +) +momentum_config = None +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=24, val_interval=1) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') \ No newline at end of file diff --git a/configs/fsd/fsd_waymoD1_1x.py b/configs/fsd/fsd_waymoD1_1x.py new file mode 100644 index 0000000000..61f248eb76 --- /dev/null +++ b/configs/fsd/fsd_waymoD1_1x.py @@ -0,0 +1,408 @@ +_base_ = [ + '../_base_/datasets/waymoD5-3d-3class.py', + # '../_base_/models/fsd.py', + '../_base_/schedules/cosine_2x.py', + '../_base_/default_runtime.py', +] + +class_names = ['Car', 'Pedestrian', 'Cyclist'] +num_classes = len(class_names) + +point_cloud_range = [-80, -80, -2, 80, 80, 4] + +dataset_type = 'WaymoDataset' +data_root = 'data/waymo/kitti_format/' +file_client_args = dict(backend='disk') + + +# ==================== model ../_base_/models/fsd.py + +seg_voxel_size = (0.25, 0.25, 0.2) +seg_score_thresh = (0.3, 0.25, 0.25) + +segmentor = dict( + type='VoteSegmentor', + voxel_layer=dict( # need to refactor to Det3DDataPreprocessor + voxel_size=seg_voxel_size, + max_num_points=-1, + point_cloud_range=point_cloud_range, + max_voxels=(-1, -1) + ), + + voxel_encoder=dict( # need to refactor to pts_voxel_encoder + type='DynamicScatterVFE', + 
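+        # max_num_points=-1 and max_voxels=(-1, -1) in the voxel_layer above
+        # request dynamic voxelization: every point is kept and scattered
+        # into its voxel instead of being sampled to a fixed budget, which is
+        # what DynamicScatterVFE expects. The 5 input channels below are
+        # assumed to correspond to the Waymo point features selected by
+        # use_dim=5 (x, y, z, intensity, elongation).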
in_channels=5, + feat_channels=[64, 64], + voxel_size=seg_voxel_size, + with_cluster_center=True, + with_voxel_center=True, + point_cloud_range=point_cloud_range, + norm_cfg=dict(type='naiveSyncBN1d', eps=1e-3, momentum=0.01), + unique_once=True, + ), + + middle_encoder=dict( + type='PseudoMiddleEncoderForSpconvFSD', + ), + + backbone=dict( + type='SimpleSparseUNet', + in_channels=64, + sparse_shape=[32, 640, 640], + order=('conv', 'norm', 'act'), + norm_cfg=dict(type='naiveSyncBN1d', eps=1e-3, momentum=0.01), + base_channels=64, + output_channels=128, + encoder_channels=((64, ), (64, 64, 64), (64, 64, 64), (128, 128, 128), (256, 256, 256)), + encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, 1), (1, 1, 1)), + decoder_channels=((256, 256, 128), (128, 128, 64), (64, 64, 64), (64, 64, 64), (64, 64, 64)), + decoder_paddings=((1, 1), (1, 0), (1, 0), (0, 0), (0, 1)), # decoder paddings seem useless in SubMConv + ), + + decode_neck=dict( + type='Voxel2PointScatterNeck', + voxel_size=seg_voxel_size, + point_cloud_range=point_cloud_range, + ), + + segmentation_head=dict( + type='VoteSegHead', + in_channel=67, + hidden_dims=[128, 128], + num_classes=num_classes, + dropout_ratio=0.0, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='naiveSyncBN1d'), + act_cfg=dict(type='ReLU'), + loss_decode=dict( + type='mmdet.FocalLoss', + use_sigmoid=True, + gamma=3.0, + alpha=0.8, + loss_weight=1.0), + loss_vote=dict( + type='mmdet.L1Loss', + loss_weight=1.0), + ), + train_cfg=dict( + point_loss=True, + score_thresh=seg_score_thresh, # for training log + class_names=('Car', 'Ped', 'Cyc'), # for training log + centroid_offset=False, + ), +) + +model = dict( + type='FSD', + data_preprocessor=dict(type='Det3DDataPreprocessor'), # hin added + segmentor=segmentor, + backbone=dict( + type='SIR', + num_blocks=3, + in_channels=[84,] + [133, ] * 2, + feat_channels=[[128, 128], ] * 3, + rel_mlp_hidden_dims=[[16, 32],] * 3, + norm_cfg=dict(type='LN', eps=1e-3), + mode='max', + xyz_normalizer=[20, 20, 4], + act='gelu', + unique_once=True, + ), + + bbox_head=dict( + type='SparseClusterHeadV2', + num_classes=num_classes, + bbox_coder=dict(type='BasePointBBoxCoder'), + loss_cls=dict( + type='mmdet.FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0), + loss_center=dict(type='mmdet.L1Loss', loss_weight=0.5), + loss_size=dict(type='mmdet.L1Loss', loss_weight=0.5), + loss_rot=dict(type='mmdet.L1Loss', loss_weight=0.2), + in_channel=128 * 3 * 2, + shared_mlp_dims=[1024, 1024], + train_cfg=None, + test_cfg=None, + norm_cfg=dict(type='LN'), + tasks=[ + dict(class_names=['Car',]), + dict(class_names=['Pedestrian',]), + dict(class_names=['Cyclist',]), + ], + class_names=class_names, + common_attrs=dict( + center=(3, 2, 128), dim=(3, 2, 128), rot=(2, 2, 128), # (out_dim, num_layers, hidden_dim) + ), + num_cls_layer=2, + cls_hidden_dim=128, + separate_head=dict( + type='FSDSeparateHead', + norm_cfg=dict(type='LN'), + act='relu', + ), + as_rpn=True, + ), + roi_head=dict( + type='GroupCorrectionHead', + num_classes=num_classes, + roi_extractor=dict( + type='DynamicPointROIExtractor', + extra_wlh=[0.5, 0.5, 0.5], + max_inbox_point=256, + debug=False, + ), + bbox_head=dict( + type='FullySparseBboxHead', + num_classes=num_classes, + num_blocks=6, + in_channels=[213, 146, 146, 146, 146, 146], + feat_channels=[[128, 128], ] * 6, + rel_mlp_hidden_dims=[[16, 32],] * 6, + rel_mlp_in_channels=[13, ] * 6, + reg_mlp=[512, 512], + cls_mlp=[512, 512], + mode='max', + xyz_normalizer=[20, 20, 4], + act='gelu', + 
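+            # The second-stage bbox head stacks `num_blocks` SIR (sparse
+            # instance recognition) blocks over the points pooled inside each
+            # proposal; `in_channels[i]` is the per-point input width of
+            # block i and `rel_mlp_in_channels` the input width of the
+            # relative-geometry MLP. The channel sizes here are kept as-is
+            # from the original FSD implementation.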
geo_input=True, + with_corner_loss=True, + corner_loss_weight=1.0, + bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), + norm_cfg=dict(type='LN', eps=1e-3), + unique_once=True, + + loss_bbox=dict( + type='mmdet.L1Loss', + reduction='mean', + loss_weight=2.0), + + loss_cls=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=True, + reduction='mean', + loss_weight=1.0), + cls_dropout=0.1, + reg_dropout=0.1, + ), + train_cfg=None, + test_cfg=None, + init_cfg=None + ), + + train_cfg=dict( + score_thresh=seg_score_thresh, + sync_reg_avg_factor=True, + pre_voxelization_size=(0.1, 0.1, 0.1), + disable_pretrain=True, + disable_pretrain_topks=[600, 200, 200], + rpn=dict( + use_rotate_nms=True, + nms_pre=-1, + nms_thr=None, + score_thr=0.1, + min_bbox_size=0, + max_num=500, + ), + rcnn=dict( + assigner=[ + dict( # Car + type='Max3DIoUAssigner', + iou_calculator=dict( + type='BboxOverlaps3D', coordinate='lidar'), + pos_iou_thr=0.45, + neg_iou_thr=0.45, + min_pos_iou=0.45, + ignore_iof_thr=-1 + ), + dict( # Ped + type='Max3DIoUAssigner', + iou_calculator=dict( + type='BboxOverlaps3D', coordinate='lidar'), + pos_iou_thr=0.35, + neg_iou_thr=0.35, + min_pos_iou=0.35, + ignore_iof_thr=-1 + ), + dict( # Cyc + type='Max3DIoUAssigner', + iou_calculator=dict( + type='BboxOverlaps3D', coordinate='lidar'), + pos_iou_thr=0.35, + neg_iou_thr=0.35, + min_pos_iou=0.35, + ignore_iof_thr=-1 + ), + ], + + sampler=dict( + type='IoUNegPiecewiseSampler', + num=256, + pos_fraction=0.55, + neg_piece_fractions=[0.8, 0.2], + neg_iou_piece_thrs=[0.55, 0.1], + neg_pos_ub=-1, + add_gt_as_proposals=False, + return_iou=True + ), + cls_pos_thr=(0.8, 0.65, 0.65), + cls_neg_thr=(0.2, 0.15, 0.15), + sync_reg_avg_factor=True, + sync_cls_avg_factor=True, + corner_loss_only_car=True, + class_names=class_names, + ) + ), + test_cfg=dict( + score_thresh=seg_score_thresh, + pre_voxelization_size=(0.1, 0.1, 0.1), + skip_rcnn=False, + rpn=dict( + use_rotate_nms=True, + nms_pre=-1, + nms_thr=0.25, + score_thr=0.1, + min_bbox_size=0, + max_num=500, + ), + rcnn=dict( + use_rotate_nms=True, + nms_pre=-1, + nms_thr=0.25, + score_thr=0.1, + min_bbox_size=0, + max_num=500, + ), + ), + cluster_assigner=dict( + cluster_voxel_size=dict( + Car=(0.3, 0.3, 6), + Cyclist=(0.2, 0.2, 6), + Pedestrian=(0.05, 0.05, 6), + ), + min_points=2, + point_cloud_range=point_cloud_range, + connected_dist=dict( + Car=0.6, + Cyclist=0.4, + Pedestrian=0.1, + ), # xy-plane distance + class_names=class_names, + ), +) + +# ==================== setting + +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'waymo_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)), + classes=class_names, + sample_groups=dict(Car=5, Pedestrian=5, Cyclist=3), + points_loader=dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=6, # 5 + use_dim=5)) # [0,1,2,3,4] + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=6, + use_dim=5), + dict( + type='LoadAnnotations3D', + with_bbox_3d=True, + with_label_3d=True), + dict(type='ObjectSample', db_sampler=db_sampler), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.78539816, 0.78539816], + scale_ratio_range=[0.95, 1.05], + translation_std=[0, 0, 0.2]), + dict(type='PointsRangeFilter', point_cloud_range=_base_.point_cloud_range), + dict(type='ObjectRangeFilter', 
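+         # `_base_.point_cloud_range` is inherited from the base dataset
+         # config and is not necessarily identical to the local
+         # `point_cloud_range` defined at the top of this file.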
point_cloud_range=_base_.point_cloud_range), + dict(type='PointShuffle'), + dict(type='Pack3DDetInputs', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] + +# construct a pipeline for data and gt loading in show function +# please keep its loading function consistent with test_pipeline (e.g. client) +eval_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=6, + use_dim=5, + file_client_args=file_client_args), + dict(type='Pack3DDetInputs', keys=['points']) +] + +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=6, + use_dim=5, + file_client_args=file_client_args), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='PointsRangeFilter', point_cloud_range=_base_.point_cloud_range), + dict(type='Pack3DDetInputs', keys=['points']) + ]) +] + +train_dataloader = dict( + batch_size=2, + num_workers=4, + dataset=dict( + type='RepeatDataset', + times=1, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='waymo_infos_train.pkl', + data_prefix=dict( + pts='training/velodyne', sweeps='training/velodyne'), + pipeline=train_pipeline, + load_interval=1))) + + +val_dataloader = dict( + dataset=dict(pipeline=test_pipeline, metainfo=dict(classes=class_names))) +test_dataloader = dict( + dataset=dict(pipeline=test_pipeline, metainfo=dict(classes=class_names))) + +custom_hooks = [ + dict(type='DisableAugmentationHook', num_last_epochs=1, skip_type_keys=('ObjectSample', 'RandomFlip3D', 'GlobalRotScaleTrans')), + dict(type='EnableFSDDetectionHookIter', enable_after_iter=4000, threshold_buffer=0.3, buffer_iter=8000) +] + +optim_wrapper = dict(optimizer=dict(lr=3e-5)) + +# runtime settings +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=12) + +evaluation = dict(interval=12, pipeline=eval_pipeline) diff --git a/configs/fsd/metafile.yml b/configs/fsd/metafile.yml new file mode 100644 index 0000000000..e69de29bb2 diff --git a/mmdet3d/engine/hooks/__init__.py b/mmdet3d/engine/hooks/__init__.py index 1d47e4d549..22aa81155a 100644 --- a/mmdet3d/engine/hooks/__init__.py +++ b/mmdet3d/engine/hooks/__init__.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .benchmark_hook import BenchmarkHook from .visualization_hook import Det3DVisualizationHook +from .fsd_hooks import DisableAugmentationHook, EnableFSDDetectionHookIter -__all__ = ['Det3DVisualizationHook', 'BenchmarkHook'] +__all__ = ['Det3DVisualizationHook', 'BenchmarkHook', + 'DisableAugmentationHook', 'EnableFSDDetectionHookIter'] diff --git a/mmdet3d/engine/hooks/fsd_hooks.py b/mmdet3d/engine/hooks/fsd_hooks.py new file mode 100644 index 0000000000..1fc1985e9c --- /dev/null +++ b/mmdet3d/engine/hooks/fsd_hooks.py @@ -0,0 +1,89 @@ +from mmengine.hooks import Hook +from mmengine.hooks.hook import DATA_BATCH + +from mmdet3d.registry import HOOKS + + +@HOOKS.register_module() +class DisableAugmentationHook(Hook): + """Switch the mode of YOLOX during training. + This hook turns off the mosaic and mixup data augmentation and switches + to use L1 loss in bbox_head. + Args: + num_last_epochs (int): The number of latter epochs in the end of the + training to close the data augmentation and switch to L1 loss. + Default: 15. + skip_type_keys (list[str], optional): Sequence of type string to be + skip pipeline. 
Default: ('Mosaic', 'RandomAffine', 'MixUp') + """ + + def __init__(self, + num_last_epochs=10, + skip_type_keys=('ObjectSample')): + self.num_last_epochs = num_last_epochs + self.skip_type_keys = skip_type_keys + self._restart_dataloader = False + + def before_train_epoch(self, runner): + epoch = runner.epoch # begin from 0 + train_loader = runner.train_dataloader + if epoch == runner.max_epochs - self.num_last_epochs: + runner.logger.info(f'Disable augmentations: {self.skip_type_keys}') + # The dataset pipeline cannot be updated when persistent_workers + # is True, so we need to force the dataloader's multi-process + # restart. This is a very hacky approach. + train_loader.dataset.dataset.update_skip_type_keys(self.skip_type_keys) + if hasattr(train_loader, 'persistent_workers' + ) and train_loader.persistent_workers is True: + + train_loader._DataLoader__initialized = False + train_loader._iterator = None + self._restart_dataloader = True + print('has persistent workers') + else: + # Once the restart is complete, we need to restore + # the initialization flag. + if self._restart_dataloader: + train_loader._DataLoader__initialized = True + +@HOOKS.register_module() +class EnableFSDDetectionHook(Hook): + + def __init__(self, + enable_after_epoch=1, + ): + self.enable_after_epoch = enable_after_epoch + + def before_train_epoch(self, runner): + epoch = runner.epoch # begin from 0 + if epoch == self.enable_after_epoch: + runner.logger.info(f'Enable FSD Detection from now.') + runner.model.module.runtime_info['enable_detection'] = True + +@HOOKS.register_module() +class EnableFSDDetectionHookIter(Hook): + + def __init__(self, + enable_after_iter=5000, + threshold_buffer=0, + buffer_iter=2000, + ): + self.enable_after_iter = enable_after_iter + self.buffer_iter = buffer_iter + self.delta = threshold_buffer / buffer_iter + self.threshold_buffer = threshold_buffer + + def before_train_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None): + cur_iter = runner.iter # begin from 0 + if cur_iter == self.enable_after_iter: + runner.logger.info(f'Enable FSD Detection from now.') + if cur_iter >= self.enable_after_iter: # keep the sanity when resuming model + runner.model.module.runtime_info['enable_detection'] = True + if self.threshold_buffer > 0 and cur_iter > self.enable_after_iter and cur_iter < self.enable_after_iter + self.buffer_iter: + runner.model.module.runtime_info['threshold_buffer'] = (self.enable_after_iter + self.buffer_iter - cur_iter) * self.delta + else: + # runner.hook.runtime_info['threshold_buffer'] = 0 + self.threshold_buffer = 0 \ No newline at end of file diff --git a/mmdet3d/models/backbones/__init__.py b/mmdet3d/models/backbones/__init__.py index 009a06947a..85a5ebfd66 100644 --- a/mmdet3d/models/backbones/__init__.py +++ b/mmdet3d/models/backbones/__init__.py @@ -9,9 +9,10 @@ from .pointnet2_sa_msg import PointNet2SAMSG from .pointnet2_sa_ssg import PointNet2SASSG from .second import SECOND +from .sir import SIR __all__ = [ 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet', 'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG', - 'MultiBackbone', 'DLANet', 'MinkResNet' + 'MultiBackbone', 'DLANet', 'MinkResNet', 'SIR' ] diff --git a/mmdet3d/models/backbones/sir.py b/mmdet3d/models/backbones/sir.py new file mode 100644 index 0000000000..f7def176df --- /dev/null +++ b/mmdet3d/models/backbones/sir.py @@ -0,0 +1,89 @@ +import copy + +import torch +import torch.nn as nn + +from mmdet3d.registry import MODELS +from .. 
import builder + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +@MODELS.register_module() +class SIR(nn.Module): + + def __init__( + self, + num_blocks=5, + in_channels=[], + feat_channels=[], + rel_mlp_hidden_dims=[], + with_rel_mlp=True, + with_distance=False, + with_cluster_center=False, + norm_cfg=dict(type='LN', eps=1e-3), + mode='max', + xyz_normalizer=[1.0, 1.0, 1.0], + act='relu', + dropout=0, + unique_once=False, + ): + super().__init__() + + self.num_blocks = num_blocks + self.unique_once = unique_once + + block_list = [] + for i in range(num_blocks): + return_point_feats = i != num_blocks - 1 + kwargs = dict( + type='SIRLayer', + in_channels=in_channels[i], + feat_channels=feat_channels[i], + with_distance=with_distance, + with_cluster_center=with_cluster_center, + with_rel_mlp=with_rel_mlp, + rel_mlp_hidden_dims=rel_mlp_hidden_dims[i], + with_voxel_center=False, + voxel_size=[0.1, 0.1, 0.1], # not used, placeholder + point_cloud_range=[-74.88, -74.88, -2, 74.88, 74.88, 4], # not used, placeholder + norm_cfg=norm_cfg, + mode=mode, + fusion_layer=None, + return_point_feats=return_point_feats, + return_inv=False, + rel_dist_scaler=10.0, + xyz_normalizer=xyz_normalizer, + act=act, + dropout=dropout, + ) + encoder = builder.build_voxel_encoder(kwargs) + block_list.append(encoder) + self.block_list = nn.ModuleList(block_list) + + def forward(self, points, features, coors, f_cluster=None): + + if self.unique_once: + new_coors, unq_inv = torch.unique(coors, return_inverse=True, return_counts=False, dim=0) + else: + new_coors = unq_inv = None + + out_feats = features + + cluster_feat_list = [] + for i, block in enumerate(self.block_list): + in_feats = torch.cat([points, out_feats], 1) + if i < self.num_blocks - 1: + out_feats, out_cluster_feats = block(in_feats, coors, f_cluster, unq_inv_once=unq_inv, + new_coors_once=new_coors) + cluster_feat_list.append(out_cluster_feats) + if i == self.num_blocks - 1: + out_feats, out_cluster_feats, out_coors = block(in_feats, coors, f_cluster, return_both=True, + unq_inv_once=unq_inv, new_coors_once=new_coors) + cluster_feat_list.append(out_cluster_feats) + + final_cluster_feats = torch.cat(cluster_feat_list, dim=1) + + return out_feats, final_cluster_feats, out_coors diff --git a/mmdet3d/models/decode_heads/__init__.py b/mmdet3d/models/decode_heads/__init__.py index 2e86c7c8a9..7535c7eb1b 100644 --- a/mmdet3d/models/decode_heads/__init__.py +++ b/mmdet3d/models/decode_heads/__init__.py @@ -2,5 +2,6 @@ from .dgcnn_head import DGCNNHead from .paconv_head import PAConvHead from .pointnet2_head import PointNet2Head +from .segmentation_head import VoteSegHead -__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead'] +__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead','VoteSegHead'] diff --git a/mmdet3d/models/decode_heads/segmentation_head.py b/mmdet3d/models/decode_heads/segmentation_head.py new file mode 100644 index 0000000000..8c81dc3db3 --- /dev/null +++ b/mmdet3d/models/decode_heads/segmentation_head.py @@ -0,0 +1,264 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn as nn + +from .decode_head import Base3DDecodeHead +from mmengine.model import normal_init + +from mmdet3d.models.layers.sst import build_mlp, scatter_v2 +from torch.utils.checkpoint import checkpoint + +from mmdet3d.registry import MODELS +from .. 
import LOSSES + + +@MODELS.register_module() +class VoteSegHead(Base3DDecodeHead): + + def __init__(self, + in_channel, + num_classes, + hidden_dims=[], + dropout_ratio=0.5, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='naiveSyncBN1d'), + act_cfg=dict(type='ReLU'), + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, + loss_weight=1.0), + loss_vote=dict( + type='L1Loss', + ), + loss_aux=None, + ignore_index=255, + logit_scale=1, + checkpointing=False, + init_bias=None, + init_cfg=None): + end_channel = hidden_dims[-1] if len(hidden_dims) > 0 else in_channel + super(VoteSegHead, self).__init__( + end_channel, + num_classes, + dropout_ratio, + conv_cfg, + norm_cfg, + act_cfg, + loss_decode, + ignore_index, + init_cfg + ) + + self.pre_seg_conv = None + if len(hidden_dims) > 0: + self.pre_seg_conv = build_mlp(in_channel, hidden_dims, norm_cfg, act=act_cfg['type']) + + self.use_sigmoid = loss_decode.get('use_sigmoid', False) + self.bg_label = self.num_classes + if not self.use_sigmoid: + self.num_classes += 1 + + + self.logit_scale = logit_scale + self.conv_seg = nn.Linear(end_channel, self.num_classes) + self.voting = nn.Linear(end_channel, self.num_classes * 3) + self.fp16_enabled = False + self.checkpointing = checkpointing + self.init_bias = init_bias + + if loss_aux is not None: + self.loss_aux = LOSSES.build(loss_aux) + else: + self.loss_aux = None + if loss_decode['type'] == 'FocalLoss': + self.loss_decode = LOSSES.build(loss_decode) # mmdet has a better focal loss supporting single class + + self.loss_vote = LOSSES.build(loss_vote) + + def init_weights(self): + """Initialize weights.""" + super().init_weights() + if self.init_bias is not None: + self.conv_seg.bias.data.fill_(self.init_bias) + print(f'Segmentation Head bias is initialized to {self.init_bias}') + else: + normal_init(self.conv_seg, mean=0, std=0.01) + + # @auto_fp16(apply_to=('voxel_feat',)) + def forward(self, voxel_feat): + """Forward pass. + + """ + + output = voxel_feat + if self.pre_seg_conv is not None: + if self.checkpointing: + output = checkpoint(self.pre_seg_conv, voxel_feat) + else: + output = self.pre_seg_conv(voxel_feat) + logits = self.cls_seg(output) + vote_preds = self.voting(output) + + return logits, vote_preds + + # @force_fp32(apply_to=('seg_logit', 'vote_preds')) + def losses(self, seg_logit, vote_preds, seg_label, vote_targets, vote_mask): + """Compute semantic segmentation loss. + + Args: + seg_logit (torch.Tensor): Predicted per-point segmentation logits \ + of shape [B, num_classes, N]. + seg_label (torch.Tensor): Ground-truth segmentation label of \ + shape [B, N]. 
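+            vote_preds (torch.Tensor): Predicted per-point voting offsets,
+                reshaped to [N, num_classes, 3] inside this function.
+            vote_targets (torch.Tensor): Ground-truth voting offsets of
+                shape [N, 3] produced by ``get_targets``.
+            vote_mask (torch.Tensor): Bool mask of shape [N] selecting the
+                foreground points that carry a valid vote target.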
+ """ + seg_logit = seg_logit * self.logit_scale + loss = dict() + loss['loss_sem_seg'] = self.loss_decode(seg_logit, seg_label) + if self.loss_aux is not None: + loss['loss_aux'] = self.loss_aux(seg_logit, seg_label) + + vote_preds = vote_preds.reshape(-1, self.num_classes, 3) + if not self.use_sigmoid: + assert seg_label.max().item() == self.num_classes - 1 + else: + assert seg_label.max().item() == self.num_classes + valid_vote_preds = vote_preds[vote_mask] # [n_valid, num_cls, 3] + valid_vote_preds = valid_vote_preds.reshape(-1, 3) + num_valid = vote_mask.sum() + + valid_label = seg_label[vote_mask] + + if num_valid > 0: + assert valid_label.max().item() < self.num_classes + assert valid_label.min().item() >= 0 + + indices = torch.arange(num_valid, device=valid_label.device) * self.num_classes + valid_label + valid_vote_preds = valid_vote_preds[indices, :] #[n_valid, 3] + + valid_vote_targets = vote_targets[vote_mask] + + loss['loss_vote'] = self.loss_vote(valid_vote_preds, valid_vote_targets) + else: + loss['loss_vote'] = vote_preds.sum() * 0 + + train_cfg = self.train_cfg + if train_cfg.get('score_thresh', None) is not None: + score_thresh = train_cfg['score_thresh'] + if self.use_sigmoid: + scores = seg_logit.sigmoid() + for i in range(len(score_thresh)): + thr = score_thresh[i] + name = train_cfg['class_names'][i] + this_scores = scores[:, i] + pred_true = this_scores > thr + real_true = seg_label == i + tp = (pred_true & real_true).sum().float() + loss[f'recall_{name}'] = tp / (real_true.sum().float() + 1e-5) + else: + score = seg_logit.softmax(1) + group_lens = train_cfg['group_lens'] + group_score = self.gather_group(score[:, :-1], group_lens) + num_fg = score.new_zeros(1) + for gi in range(len(group_lens)): + pred_true = group_score[:, gi] > score_thresh[gi] + num_fg += pred_true.sum().float() + for i in range(group_lens[gi]): + name = train_cfg['group_names'][gi][i] + real_true = seg_label == train_cfg['class_names'].index(name) + tp = (pred_true & real_true).sum().float() + loss[f'recall_{name}'] = tp / (real_true.sum().float() + 1e-5) + loss[f'num_fg'] = num_fg + + return loss + + def forward_train(self, inputs, img_metas, pts_semantic_mask, vote_targets, vote_mask, return_preds=False): + + seg_logits, vote_preds = self.forward(inputs) + losses = self.losses(seg_logits, vote_preds, pts_semantic_mask, vote_targets, vote_mask) + if return_preds: + return losses, dict(seg_logits=seg_logits, vote_preds=vote_preds) + else: + return losses + + def gather_group(self, scores, group_lens): + assert (scores >= 0).all() + score_per_group = [] + beg = 0 + for group_len in group_lens: + end = beg + group_len + score_this_g = scores[:, beg:end].sum(1) + score_per_group.append(score_this_g) + beg = end + assert end == scores.size(1) == sum(group_lens) + gathered_score = torch.stack(score_per_group, dim=1) + assert gathered_score.size(1) == len(group_lens) + return gathered_score + + def get_targets(self, points_list, gt_bboxes_list, gt_labels_list): + bsz = len(points_list) + label_list = [] + vote_target_list = [] + vote_mask_list = [] + + for i in range(bsz): + + points = points_list[i][:, :3] + bboxes = gt_bboxes_list[i] + bbox_labels = gt_labels_list[i] + + # if self.num_classes < 3: # I don't know why there are some -1 labels when train car-only model. 
+ valid_gt_mask = bbox_labels >= 0 + bboxes = bboxes[valid_gt_mask] + bbox_labels = bbox_labels[valid_gt_mask] + + if len(bbox_labels) == 0: + this_label = torch.ones(len(points), device=points.device, dtype=torch.long) * self.bg_label + this_vote_target = torch.zeros_like(points) + vote_mask = torch.zeros_like(this_label).bool() + else: + extra_width = self.train_cfg.get('extra_width', None) + if extra_width is not None: + bboxes = bboxes.enlarged_box_hw(extra_width) + inbox_inds = bboxes.points_in_boxes(points).long() + this_label = self.get_point_labels(inbox_inds, bbox_labels) + this_vote_target, vote_mask = self.get_vote_target(inbox_inds, points, bboxes) + + label_list.append(this_label) + vote_target_list.append(this_vote_target) + vote_mask_list.append(vote_mask) + + labels = torch.cat(label_list, dim=0) + vote_targets = torch.cat(vote_target_list, dim=0) + vote_mask = torch.cat(vote_mask_list, dim=0) + + return labels, vote_targets, vote_mask + + + def get_point_labels(self, inbox_inds, bbox_labels): + + bg_mask = inbox_inds < 0 + label = -1 * torch.ones(len(inbox_inds), dtype=torch.long, device=inbox_inds.device) + class_labels = bbox_labels[inbox_inds] + class_labels[bg_mask] = self.bg_label + return class_labels + + def get_vote_target(self, inbox_inds, points, bboxes): + + bg_mask = inbox_inds < 0 + if self.train_cfg.get('centroid_offset', False): + centroid, _, inv = scatter_v2(points, inbox_inds, mode='avg', return_inv=True) + center_per_point = centroid[inv] + else: + center_per_point = bboxes.gravity_center[inbox_inds] + delta = center_per_point.to(points.device) - points + delta[bg_mask] = 0 + target = self.encode_vote_targets(delta) + vote_mask = ~bg_mask + return target, vote_mask + + def encode_vote_targets(self, delta): + return torch.sign(delta) * (delta.abs() ** 0.5) + + def decode_vote_targets(self, preds): + return preds * preds.abs() diff --git a/mmdet3d/models/dense_heads/__init__.py b/mmdet3d/models/dense_heads/__init__.py index 4d6385341b..7f5e39c6c3 100644 --- a/mmdet3d/models/dense_heads/__init__.py +++ b/mmdet3d/models/dense_heads/__init__.py @@ -17,11 +17,14 @@ from .smoke_mono3d_head import SMOKEMono3DHead from .ssd_3d_head import SSD3DHead from .vote_head import VoteHead +from .sparse_cluster_head import SparseClusterHead +from .sparse_cluster_head_v2 import SparseClusterHeadV2 __all__ = [ 'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead', 'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead', 'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead', 'GroupFree3DHead', 'PointRPNHead', 'SMOKEMono3DHead', 'PGDHead', - 'MonoFlexHead', 'Base3DDenseHead', 'FCAF3DHead' + 'MonoFlexHead', 'Base3DDenseHead', 'FCAF3DHead', 'SparseClusterHead', + 'SparseClusterHeadV2' ] diff --git a/mmdet3d/models/dense_heads/sparse_cluster_head.py b/mmdet3d/models/dense_heads/sparse_cluster_head.py new file mode 100644 index 0000000000..756460abd1 --- /dev/null +++ b/mmdet3d/models/dense_heads/sparse_cluster_head.py @@ -0,0 +1,576 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
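+# SparseClusterHead predicts one box per point cluster (a vote-clustered
+# instance) instead of per anchor or per BEV pixel; regression targets are
+# encoded relative to `cluster_xyz`, the representative location of each
+# cluster.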
+import numpy as np +import torch +import torch.nn as nn +from mmdet.models.task_modules import AssignResult +from mmdet.models.utils import multi_apply +from mmdet.utils import reduce_mean +from mmengine.model import BaseModule + +from mmdet3d.models.builder import build_loss +from mmdet3d.models.layers import box3d_multiclass_nms +from mmdet3d.models.layers.sst import build_mlp +from mmdet3d.models.task_modules import PseudoSampler +from mmdet3d.models.task_modules.builder import build_bbox_coder +from mmdet3d.registry import MODELS +from mmdet3d.structures import xywhr2xyxyr, LiDARInstance3DBoxes, bbox_overlaps_3d + + +@MODELS.register_module() +class SparseClusterHead(BaseModule): + + def __init__(self, + num_classes, + bbox_coder, + loss_cls, + loss_center, + loss_size, + loss_rot, + in_channel, + shared_mlp_dims, + shared_dropout=0, + cls_mlp=None, + reg_mlp=None, + iou_mlp=None, + train_cfg=None, + test_cfg=None, + norm_cfg=dict(type='LN'), + loss_iou=None, + act='relu', + corner_loss_cfg=None, + enlarge_width=None, + as_rpn=False, + init_cfg=None): + super(SparseClusterHead, self).__init__(init_cfg=init_cfg) + + self.print_info = {} + self.loss_center = build_loss(loss_center) + self.loss_size = build_loss(loss_size) + self.loss_rot = build_loss(loss_rot) + self.loss_cls = build_loss(loss_cls) + self.bbox_coder = build_bbox_coder(bbox_coder) + self.box_code_size = self.bbox_coder.code_size + self.corner_loss_cfg = corner_loss_cfg + self.num_classes = num_classes + self.enlarge_width = enlarge_width + self.sampler = PseudoSampler() + self.sync_reg_avg_factor = False if train_cfg is None else train_cfg.get('sync_reg_avg_factor', True) + self.sync_cls_avg_factor = False if train_cfg is None else train_cfg.get('sync_cls_avg_factor', False) + self.as_rpn = as_rpn + if train_cfg is not None: + self.cfg = self.train_cfg = train_cfg + if test_cfg is not None: + self.cfg = self.test_cfg = test_cfg + + self.num_anchors = num_anchors = 1 # deprecated due to removing assign twice + + + if loss_iou is not None: + self.loss_iou = build_loss(loss_iou) + # self.loss_iou = nn.binary_cross_entropy_with_logits + else: + self.loss_iou = None + + self.fp16_enabled = False + + # Bbox classification and regression + self.shared_mlp = None + if len(shared_mlp_dims) > 0: + self.shared_mlp = build_mlp(in_channel, shared_mlp_dims, norm_cfg, act=act, dropout=shared_dropout) + + + end_channel = shared_mlp_dims[-1] if len(shared_mlp_dims) > 0 else in_channel + + if cls_mlp is not None: + self.conv_cls = build_mlp(end_channel, cls_mlp + [num_classes * num_anchors,], norm_cfg, True, act=act) + else: + self.conv_cls = nn.Linear(end_channel, num_classes * num_anchors) + + if reg_mlp is not None: + self.conv_reg = build_mlp(end_channel, reg_mlp + [self.box_code_size * num_anchors,], norm_cfg, True, act=act) + else: + self.conv_reg = nn.Linear(end_channel, self.box_code_size * num_anchors) + + if loss_iou is not None: + if iou_mlp is not None: + self.conv_iou = build_mlp(end_channel, iou_mlp + [1,], norm_cfg, True, act=act) + else: + self.conv_iou = nn.Linear(end_channel, 1) + + self.save_list = [] + + + def forward(self, feats, pts_xyz=None, pts_inds=None): + + if self.shared_mlp is not None: + feats = self.shared_mlp(feats) + + cls_logits = self.conv_cls(feats) + reg_preds = self.conv_reg(feats) + outs = dict( + cls_logits=cls_logits, + reg_preds=reg_preds, + ) + if self.loss_iou is not None: + outs['iou_logits'] = self.conv_iou(feats) + + return outs + + # @force_fp32(apply_to=('cls_logits', 'reg_preds', 
'cluster_xyz')) + def loss(self, + cls_logits, + reg_preds, + cluster_xyz, + cluster_inds, + gt_bboxes_3d, + gt_labels_3d, + img_metas=None, + iou_logits=None, + gt_bboxes_ignore=None, + ): + + if iou_logits is not None and iou_logits.dtype == torch.float16: + iou_logits = iou_logits.to(torch.float) + + cluster_batch_idx = cluster_inds[:, 1] + num_total_samples = len(reg_preds) + + targets = self.get_targets(cluster_xyz, cluster_batch_idx, gt_bboxes_3d, gt_labels_3d, reg_preds) + labels, label_weights, bbox_targets, bbox_weights, iou_labels = targets + assert (label_weights == 1).all(), 'for now' + + cls_avg_factor = num_total_samples * 1.0 + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + bbox_weights.new_tensor([cls_avg_factor])) + + loss_cls = self.loss_cls( + cls_logits, labels, label_weights, avg_factor=cls_avg_factor) + + # regression loss + pos_inds = ((labels >= 0)& (labels < self.num_classes)).nonzero(as_tuple=False).reshape(-1) + num_pos = len(pos_inds) + assert num_pos == bbox_weights.sum() / self.box_code_size + + pos_reg_preds = reg_preds[pos_inds] + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_weights = bbox_weights[pos_inds] + assert (pos_bbox_weights > 0).all() + + reg_avg_factor = num_pos * 1.0 + if self.sync_reg_avg_factor: + reg_avg_factor = reduce_mean( + bbox_weights.new_tensor([reg_avg_factor])) + + if num_pos > 0: + code_weight = self.train_cfg.get('code_weight', None) + if code_weight: + pos_bbox_weights = pos_bbox_weights * bbox_weights.new_tensor( + code_weight)[None, :] + + + loss_center = self.loss_center( + pos_reg_preds[:, :3], + pos_bbox_targets[:, :3], + pos_bbox_weights[:, :3], + avg_factor=reg_avg_factor) + loss_size = self.loss_size( + pos_reg_preds[:, 3:6], + pos_bbox_targets[:, 3:6], + pos_bbox_weights[:, 3:6], + avg_factor=reg_avg_factor) + loss_rot = self.loss_rot( + pos_reg_preds[:, 6:8], + pos_bbox_targets[:, 6:8], + pos_bbox_weights[:, 6:8], + avg_factor=reg_avg_factor) + else: + loss_center = pos_reg_preds.sum() * 0 + loss_size = pos_reg_preds.sum() * 0 + loss_rot = pos_reg_preds.sum() * 0 + + losses = dict( + loss_cls=loss_cls, + loss_center=loss_center, + loss_size=loss_size, + loss_rot=loss_rot, + ) + + if self.corner_loss_cfg is not None: + losses['loss_corner'] = self.get_corner_loss(pos_reg_preds, pos_bbox_targets, cluster_xyz[pos_inds], reg_avg_factor) + + if self.loss_iou is not None: + losses['loss_iou'] = self.loss_iou(iou_logits.reshape(-1), iou_labels, label_weights, avg_factor=cls_avg_factor) + losses['max_iou'] = iou_labels.max() + losses['mean_iou'] = iou_labels[iou_labels > 0].mean() + + return losses + + def get_corner_loss(self, reg_preds, bbox_targets, base_points, reg_avg_factor): + if len(base_points) == 0: + return base_points.new_zeros(1).sum() + dets = self.bbox_coder.decode(reg_preds, base_points, self.corner_loss_cfg.get('detach_yaw', True)) + gts = self.bbox_coder.decode(bbox_targets, base_points) + corner_loss = self.corner_loss(dets, gts, self.corner_loss_cfg['delta']).sum() + corner_loss = corner_loss.sum() / reg_avg_factor * self.corner_loss_cfg['loss_weight'] + return corner_loss + + def corner_loss(self, pred_bbox3d, gt_bbox3d, delta=1): + """Calculate corner loss of given boxes. + + Args: + pred_bbox3d (torch.FloatTensor): Predicted boxes in shape (N, 7). + gt_bbox3d (torch.FloatTensor): Ground truth boxes in shape (N, 7). + + Returns: + torch.FloatTensor: Calculated corner loss in shape (N). 
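+
+        Note:
+            ``delta`` is the transition point of the Huber (smooth-L1)
+            penalty on the corner distances. The ground-truth box is also
+            compared against a heading-flipped copy so that predictions
+            differing only by a pi rotation are not over-penalized.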
+ """ + assert pred_bbox3d.shape[0] == gt_bbox3d.shape[0] + + gt_boxes_structure = LiDARInstance3DBoxes(gt_bbox3d) + pred_box_corners = LiDARInstance3DBoxes(pred_bbox3d).corners + gt_box_corners = gt_boxes_structure.corners + + # This flip only changes the heading direction of GT boxes + gt_bbox3d_flip = gt_boxes_structure.clone() + gt_bbox3d_flip.tensor[:, 6] += np.pi + gt_box_corners_flip = gt_bbox3d_flip.corners + + corner_dist = torch.min( + torch.norm(pred_box_corners - gt_box_corners, dim=2), + torch.norm(pred_box_corners - gt_box_corners_flip, + dim=2)) # (N, 8) + # huber loss + abs_error = torch.abs(corner_dist) + quadratic = torch.clamp(abs_error, max=delta) + linear = (abs_error - quadratic) + corner_loss = 0.5 * quadratic**2 + delta * linear + + return corner_loss.mean(1) + + + def get_targets(self, + cluster_xyz, + batch_idx, + gt_bboxes_3d, + gt_labels_3d, + reg_preds=None): + batch_size = len(gt_bboxes_3d) + cluster_xyz_list = self.split_by_batch(cluster_xyz, batch_idx, batch_size) + + if reg_preds is not None: + reg_preds_list = self.split_by_batch(reg_preds, batch_idx, batch_size) + else: + reg_preds_list = [None,] * len(cluster_xyz_list) + + target_list_per_sample = multi_apply(self.get_targets_single, cluster_xyz_list, gt_bboxes_3d, gt_labels_3d, reg_preds_list) + targets = [self.combine_by_batch(t, batch_idx, batch_size) for t in target_list_per_sample] + # targets == [labels, label_weights, bbox_targets, bbox_weights] + return targets + + def split_by_batch(self, data, batch_idx, batch_size): + if self.training: + assert batch_idx.max().item() + 1 <= batch_size + if batch_size == 1: + return [data, ] + data_list = [] + for i in range(batch_size): + sample_mask = batch_idx == i + data_list.append(data[sample_mask]) + return data_list + + def combine_by_batch(self, data_list, batch_idx, batch_size): + assert len(data_list) == batch_size + if data_list[0] is None: + return None + data_shape = (len(batch_idx),) + data_list[0].shape[1:] + full_data = data_list[0].new_zeros(data_shape) + for i, data in enumerate(data_list): + sample_mask = batch_idx == i + full_data[sample_mask] = data + return full_data + + + def get_targets_single(self, + cluster_xyz, + gt_bboxes_3d, + gt_labels_3d, + reg_preds=None): + """Generate targets of vote head for single batch. + + """ + valid_gt_mask = gt_labels_3d >= 0 + gt_bboxes_3d = gt_bboxes_3d[valid_gt_mask] + gt_labels_3d = gt_labels_3d[valid_gt_mask] + + gt_bboxes_3d = gt_bboxes_3d.to(cluster_xyz.device) + if self.train_cfg.get('assign_by_dist', False): + assign_result = self.assign_by_dist_single(cluster_xyz, gt_bboxes_3d, gt_labels_3d) + else: + assign_result = self.assign_single(cluster_xyz, gt_bboxes_3d, gt_labels_3d) + + # Do not put this before assign + + sample_result = self.sampler.sample(assign_result, cluster_xyz, gt_bboxes_3d.tensor) # Pseudo Sampler, use cluster_xyz as pseudo bbox here. 
+ + pos_inds = sample_result.pos_inds + neg_inds = sample_result.neg_inds + + # label targets + num_cluster = len(cluster_xyz) + labels = gt_labels_3d.new_full((num_cluster, ), self.num_classes, dtype=torch.long) + labels[pos_inds] = gt_labels_3d[sample_result.pos_assigned_gt_inds] + assert (labels >= 0).all() + label_weights = cluster_xyz.new_ones(num_cluster) + + # bbox targets + bbox_targets = cluster_xyz.new_zeros((num_cluster, self.box_code_size)) + + bbox_weights = cluster_xyz.new_zeros((num_cluster, self.box_code_size)) + bbox_weights[pos_inds] = 1.0 + + bbox_targets[pos_inds] = self.bbox_coder.encode(sample_result.pos_gt_bboxes, cluster_xyz[pos_inds]) + + if self.loss_iou is not None: + iou_labels = self.get_iou_labels(reg_preds, cluster_xyz, gt_bboxes_3d.tensor, pos_inds) + else: + iou_labels = None + + return labels, label_weights, bbox_targets, bbox_weights, iou_labels + + def get_iou_labels(self, reg_preds, cluster_xyz, gt_bboxes_3d, pos_inds): + assert reg_preds is not None + num_pos = len(pos_inds) + num_preds = len(reg_preds) + if num_pos == 0: + return cluster_xyz.new_zeros(num_preds) + bbox_preds = self.bbox_coder.decode(reg_preds, cluster_xyz) + ious = bbox_overlaps_3d(bbox_preds, gt_bboxes_3d, mode='iou', coordinate='lidar') #[num_preds, num_gts] + ious = ious.max(1)[0] + if not ((ious >= 0) & (ious <= 1)).all(): + print(f'*************** Got illegal iou:{ious.min()} or {ious.max()}') + ious = torch.clamp(ious, min=0, max=1) + + iou_bg_thresh = self.train_cfg.iou_bg_thresh + iou_fg_thresh = self.train_cfg.iou_fg_thresh + fg_mask = ious > iou_fg_thresh + bg_mask = ious < iou_bg_thresh + interval_mask = (fg_mask == 0) & (bg_mask == 0) + + iou_labels = (fg_mask > 0).float() + iou_labels[interval_mask] = \ + (ious[interval_mask] - iou_bg_thresh) / (iou_fg_thresh - iou_bg_thresh) + return iou_labels + + + def assign_single(self, + cluster_xyz, + gt_bboxes_3d, + gt_labels_3d, + ): + """Generate targets of vote head for single batch. + + """ + + num_cluster = cluster_xyz.size(0) + num_gts = gt_bboxes_3d.tensor.size(0) + + # initialize as all background + assigned_gt_inds = cluster_xyz.new_zeros((num_cluster, ), dtype=torch.long) # 0 indicates assign to backgroud + assigned_labels = cluster_xyz.new_full((num_cluster, ), -1, dtype=torch.long) + + if num_gts == 0 or num_cluster == 0: + # No ground truth or cluster, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult( + num_gts, assigned_gt_inds, None, labels=assigned_labels) + + enlarged_box = self.enlarge_gt_bboxes(gt_bboxes_3d) + inbox_inds = enlarged_box.points_in_boxes(cluster_xyz).long() + inbox_inds = self.dist_constrain(inbox_inds, cluster_xyz, gt_bboxes_3d, gt_labels_3d) + pos_cluster_mask = inbox_inds > -1 + + if pos_cluster_mask.any(): + assigned_gt_inds[pos_cluster_mask] = inbox_inds[pos_cluster_mask] + 1 + assigned_labels[pos_cluster_mask] = gt_labels_3d[inbox_inds[pos_cluster_mask]] + + return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels) + + def assign_by_dist_single(self, + cluster_xyz, + gt_bboxes_3d, + gt_labels_3d, + ): + """Generate targets of vote head for single batch. 
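+
+        Unlike ``assign_single``, clusters are matched to the ground-truth
+        box with the closest BEV center, and matches farther than
+        ``train_cfg['max_dist']`` are discarded.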
+ + """ + + num_cluster = cluster_xyz.size(0) + num_gts = gt_bboxes_3d.tensor.size(0) + + # initialize as all background + assigned_gt_inds = cluster_xyz.new_zeros((num_cluster, ), dtype=torch.long) # 0 indicates assign to backgroud + assigned_labels = cluster_xyz.new_full((num_cluster, ), -1, dtype=torch.long) + + if num_gts == 0 or num_cluster == 0: + # No ground truth or cluster, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult( + num_gts, assigned_gt_inds, None, labels=assigned_labels) + + gt_centers = gt_bboxes_3d.gravity_center[None, :, :2] + pd_xy = cluster_xyz[None, :, :2] + dist_mat = torch.cdist(pd_xy, gt_centers).squeeze(0) + max_dist = self.train_cfg['max_dist'] + min_dist_v, matched_gt_inds = torch.min(dist_mat, dim=1) + + dist_mat[list(range(num_cluster//2)), matched_gt_inds] = 1e6 + + matched_gt_inds[min_dist_v >= max_dist] = -1 + pos_cluster_mask = matched_gt_inds > -1 + + # log + num_matched_gt = len(torch.unique(matched_gt_inds)) - 1 + num_matched_gt = torch.tensor(num_matched_gt, dtype=torch.float, device=cluster_xyz.device) + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.all_reduce(num_matched_gt) + + num_gts_t = torch.tensor(num_gts, dtype=torch.float, device=cluster_xyz.device) + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.all_reduce(num_gts_t) + self.print_info['assign_recall'] = num_matched_gt / (num_gts_t + 1 + 1e-5) + # end log + + + if pos_cluster_mask.any(): + assigned_gt_inds[pos_cluster_mask] = matched_gt_inds[pos_cluster_mask] + 1 + assigned_labels[pos_cluster_mask] = gt_labels_3d[matched_gt_inds[pos_cluster_mask]] + + return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels) + + # generate votes target + def enlarge_gt_bboxes(self, gt_bboxes_3d, gt_labels_3d=None): + if self.enlarge_width is not None: + return gt_bboxes_3d.enlarged_box(self.enlarge_width) + else: + return gt_bboxes_3d + + def dist_constrain(self, inbox_inds, cluster_xyz, gt_bboxes_3d, gt_labels_3d): + + inbox_inds = inbox_inds.clone() + max_dist = self.train_cfg.get('max_assign_dist', None) + if max_dist is None: + return inbox_inds + + if not (inbox_inds > -1).any(): + return inbox_inds + + pos_mask = inbox_inds > -1 + pos_inds = inbox_inds[pos_mask].clone() + pos_xyz = cluster_xyz[pos_mask] + pos_labels = gt_labels_3d[pos_inds] + pos_box_center = gt_bboxes_3d.gravity_center[pos_inds] + rel_dist = torch.linalg.norm(pos_xyz[:, :2] - pos_box_center[:, :2], ord=2, dim=1) # only xy-dist + thresh = torch.zeros_like(rel_dist) + assert len(max_dist) == self.num_classes + for i in range(self.num_classes): + thresh[pos_labels == i] = max_dist[i] + + pos_inds[rel_dist > thresh] = -1 + inbox_inds[pos_mask] = pos_inds + return inbox_inds + + + @torch.no_grad() + def get_bboxes(self, + cls_logits, + reg_preds, + cluster_xyz, + cluster_inds, + input_metas, + iou_logits=None, + rescale=False, + ): + + + batch_inds = cluster_inds[:, 1] + batch_size = len(input_metas) + cls_logits_list = self.split_by_batch(cls_logits, batch_inds, batch_size) + reg_preds_list = self.split_by_batch(reg_preds, batch_inds, batch_size) + cluster_xyz_list = self.split_by_batch(cluster_xyz, batch_inds, batch_size) + + if iou_logits is not None: + iou_logits_list = self.split_by_batch(iou_logits, batch_inds, batch_size) + else: + iou_logits_list = [None,] * len(cls_logits_list) + + multi_results = multi_apply( + 
self._get_bboxes_single, + cls_logits_list, + iou_logits_list, + reg_preds_list, + cluster_xyz_list, + input_metas + ) + # out_bboxes_list, out_scores_list, out_labels_list = multi_results + results_list = [(b, s, l) for b, s, l in zip(*multi_results)] + return results_list + + + def _get_bboxes_single( + self, + cls_logits, + iou_logits, + reg_preds, + cluster_xyz, + input_meta, + ): + ''' + Get bboxes of a sample + ''' + + if self.as_rpn: + cfg = self.train_cfg.rpn if self.training else self.test_cfg.rpn + else: + cfg = self.test_cfg + + assert cls_logits.size(0) == reg_preds.size(0) == cluster_xyz.size(0) + assert cls_logits.size(1) == self.num_classes + assert reg_preds.size(1) == self.box_code_size + + scores = cls_logits.sigmoid() + + if iou_logits is not None: + iou_scores = iou_logits.sigmoid() + a = cfg.get('iou_score_weight', 0.5) + scores = (scores ** (1 - a)) * (iou_scores ** a) + + nms_pre = cfg.get('nms_pre', -1) + if nms_pre > 0 and scores.shape[0] > nms_pre: + max_scores, _ = scores.max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + reg_preds = reg_preds[topk_inds, :] + scores = scores[topk_inds, :] + cluster_xyz = cluster_xyz[topk_inds, :] + + bboxes = self.bbox_coder.decode(reg_preds, cluster_xyz) + bboxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d'](bboxes).bev) + + # Add a dummy background class to the front when using sigmoid + padding = scores.new_zeros(scores.shape[0], 1) + scores = torch.cat([scores, padding], dim=1) + + score_thr = cfg.get('score_thr', 0) + results = box3d_multiclass_nms(bboxes, bboxes_for_nms, + scores, score_thr, cfg.max_num, + cfg) + + out_bboxes, out_scores, out_labels = results + + out_bboxes = input_meta['box_type_3d'](out_bboxes) + + return (out_bboxes, out_scores, out_labels) diff --git a/mmdet3d/models/dense_heads/sparse_cluster_head_v2.py b/mmdet3d/models/dense_heads/sparse_cluster_head_v2.py new file mode 100644 index 0000000000..8cca1f1e75 --- /dev/null +++ b/mmdet3d/models/dense_heads/sparse_cluster_head_v2.py @@ -0,0 +1,578 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
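+# SparseClusterHeadV2 splits prediction across per-task separate heads
+# (FSDSeparateHead), one head per group of classes, similar in spirit to
+# CenterPoint-style multi-task heads; losses and decoded boxes from all
+# tasks are merged afterwards.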
+import copy + +import torch +import torch.nn as nn +from mmengine.model import BaseModule +from mmdet.models.utils import multi_apply +from mmdet.utils import reduce_mean +from mmdet3d.structures import xywhr2xyxyr, LiDARInstance3DBoxes +from mmdet3d.models.layers.sst import build_mlp +from mmdet3d.models.layers import box3d_multiclass_nms + +from mmdet3d.models import builder +from mmdet3d.models.builder import build_loss +from mmdet3d.registry import MODELS +from .sparse_cluster_head import SparseClusterHead + + +@MODELS.register_module() +class FSDSeparateHead(BaseModule): + + def __init__( + self, + in_channels, + attrs, + norm_cfg=dict(type='LN'), + act='relu', + init_cfg=None, + ): + super().__init__(init_cfg=init_cfg) + self.attrs = attrs + for attr_name in self.attrs: + out_dim, num_layer, hidden_dim = self.attrs[attr_name] + mlp = build_mlp(in_channels, [hidden_dim,] * num_layer + [out_dim,], norm_cfg, is_head=True, act=act) + self.__setattr__(attr_name, mlp) + + + def forward(self, x): + ret_dict = dict() + for attr_name in self.attrs: + ret_dict[attr_name] = self.__getattr__(attr_name)(x) + + return ret_dict + + +@MODELS.register_module() +class SparseClusterHeadV2(SparseClusterHead): + + def __init__(self, + num_classes, + bbox_coder, + loss_cls, + loss_center, + loss_size, + loss_rot, + in_channel, + shared_mlp_dims, + tasks, + class_names, + common_attrs, + num_cls_layer, + cls_hidden_dim, + separate_head, + cls_mlp=None, + reg_mlp=None, + iou_mlp=None, + train_cfg=None, + test_cfg=None, + norm_cfg=dict(type='LN'), + loss_iou=None, + act='relu', + corner_loss_cfg=None, + enlarge_width=None, + as_rpn=False, + init_cfg=None, + shared_dropout=0, + loss_vel=None, + ): + super().__init__( + num_classes, + bbox_coder, + loss_cls, + loss_center, + loss_size, + loss_rot, + in_channel, + shared_mlp_dims, + shared_dropout, + cls_mlp, + reg_mlp, + iou_mlp, + train_cfg, + test_cfg, + norm_cfg, + loss_iou, + act, + corner_loss_cfg, + enlarge_width, + as_rpn, + init_cfg + ) + + # override + self.conv_cls = None + self.conv_reg = None + + if self.shared_mlp is not None: + sep_head_in_channels = shared_mlp_dims[-1] + else: + sep_head_in_channels = in_channel + self.tasks = tasks + self.task_heads = nn.ModuleList() + + for t in tasks: + num_cls = len(t['class_names']) + attrs = copy.deepcopy(common_attrs) + attrs.update(dict(score=(num_cls, num_cls_layer, cls_hidden_dim), )) + separate_head.update( + in_channels=sep_head_in_channels, attrs=attrs) + self.task_heads.append(builder.build_head(separate_head)) + + self.class_names = class_names + all_names = [] + for t in tasks: + all_names += t['class_names'] + + assert all_names == class_names + + if loss_vel is not None: + self.loss_vel = build_loss(loss_vel) + else: + self.loss_vel = None + + + + def forward(self, feats, pts_xyz=None, pts_inds=None): + + if self.shared_mlp is not None: + feats = self.shared_mlp(feats) + + cls_logit_list = [] + reg_pred_list = [] + for h in self.task_heads: + ret_dict = h(feats) + + # keep consistent with v1, combine the regression prediction + cls_logit = ret_dict['score'] + if 'vel' in ret_dict: + reg_pred = torch.cat([ret_dict['center'], ret_dict['dim'], ret_dict['rot'], ret_dict['vel']], dim=-1) + else: + reg_pred = torch.cat([ret_dict['center'], ret_dict['dim'], ret_dict['rot']], dim=-1) + cls_logit_list.append(cls_logit) + reg_pred_list.append(reg_pred) + + outs = dict( + cls_logits=cls_logit_list, + reg_preds=reg_pred_list, + ) + + return outs + + # @force_fp32(apply_to=('cls_logits', 'reg_preds', 
'cluster_xyz')) + def loss( + self, + cls_logits, + reg_preds, + cluster_xyz, + cluster_inds, + gt_bboxes_3d, + gt_labels_3d, + img_metas=None, + iou_logits=None, + gt_bboxes_ignore=None, + ): + assert isinstance(cls_logits, list) + assert isinstance(reg_preds, list) + assert len(cls_logits) == len(reg_preds) == len(self.tasks) + all_task_losses = {} + for i in range(len(self.tasks)): + losses_this_task = self.loss_single_task( + i, + cls_logits[i], + reg_preds[i], + cluster_xyz, + cluster_inds, + gt_bboxes_3d, + gt_labels_3d, + iou_logits, + ) + all_task_losses.update(losses_this_task) + return all_task_losses + + + def loss_single_task( + self, + task_id, + cls_logits, + reg_preds, + cluster_xyz, + cluster_inds, + gt_bboxes_3d, + gt_labels_3d, + iou_logits=None, + ): + + + gt_bboxes_3d, gt_labels_3d = self.modify_gt_for_single_task(gt_bboxes_3d, gt_labels_3d, task_id) + + if iou_logits is not None and iou_logits.dtype == torch.float16: + iou_logits = iou_logits.to(torch.float) + + if cluster_inds.ndim == 1: + cluster_batch_idx = cluster_inds + else: + cluster_batch_idx = cluster_inds[:, 1] + + num_total_samples = len(reg_preds) + + num_task_classes = len(self.tasks[task_id]['class_names']) + targets = self.get_targets(num_task_classes, cluster_xyz, cluster_batch_idx, gt_bboxes_3d, gt_labels_3d, reg_preds) + labels, label_weights, bbox_targets, bbox_weights, iou_labels = targets + assert (label_weights == 1).all(), 'for now' + + cls_avg_factor = num_total_samples * 1.0 + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + bbox_weights.new_tensor([cls_avg_factor])) + + loss_cls = self.loss_cls( + cls_logits, labels, label_weights, avg_factor=cls_avg_factor) + + # regression loss + pos_inds = ((labels >= 0)& (labels < num_task_classes)).nonzero(as_tuple=False).reshape(-1) + num_pos = len(pos_inds) + + pos_reg_preds = reg_preds[pos_inds] + pos_bbox_targets = bbox_targets[pos_inds] + pos_bbox_weights = bbox_weights[pos_inds] + + reg_avg_factor = num_pos * 1.0 + if self.sync_reg_avg_factor: + reg_avg_factor = reduce_mean( + bbox_weights.new_tensor([reg_avg_factor])) + + if num_pos > 0: + code_weight = self.train_cfg.get('code_weight', None) + if code_weight: + pos_bbox_weights = pos_bbox_weights * bbox_weights.new_tensor( + code_weight) + + + loss_center = self.loss_center( + pos_reg_preds[:, :3], + pos_bbox_targets[:, :3], + pos_bbox_weights[:, :3], + avg_factor=reg_avg_factor) + loss_size = self.loss_size( + pos_reg_preds[:, 3:6], + pos_bbox_targets[:, 3:6], + pos_bbox_weights[:, 3:6], + avg_factor=reg_avg_factor) + loss_rot = self.loss_rot( + pos_reg_preds[:, 6:8], + pos_bbox_targets[:, 6:8], + pos_bbox_weights[:, 6:8], + avg_factor=reg_avg_factor) + if self.loss_vel is not None: + loss_vel = self.loss_vel( + pos_reg_preds[:, 8:10], + pos_bbox_targets[:, 8:10], + pos_bbox_weights[:, 8:10], + ) + else: + loss_center = pos_reg_preds.sum() * 0 + loss_size = pos_reg_preds.sum() * 0 + loss_rot = pos_reg_preds.sum() * 0 + if self.loss_vel is not None: + loss_vel = pos_reg_preds.sum() * 0 + + losses = dict( + loss_cls=loss_cls, + loss_center=loss_center, + loss_size=loss_size, + loss_rot=loss_rot, + ) + if self.loss_vel is not None: + losses['loss_vel'] = loss_vel + + if self.corner_loss_cfg is not None: + losses['loss_corner'] = self.get_corner_loss(pos_reg_preds, pos_bbox_targets, cluster_xyz[pos_inds], reg_avg_factor) + + if self.loss_iou is not None: + losses['loss_iou'] = self.loss_iou(iou_logits.reshape(-1), iou_labels, label_weights, avg_factor=cls_avg_factor) + losses['max_iou'] 
= iou_labels.max() + losses['mean_iou'] = iou_labels[iou_labels > 0].mean() + + losses_with_task_id = {k + '.task' + str(task_id): v for k, v in losses.items()} + + return losses_with_task_id + + def modify_gt_for_single_task(self, gt_bboxes_3d, gt_labels_3d, task_id): + out_bboxes_list, out_labels_list = [], [] + for gts_b, gts_l in zip(gt_bboxes_3d, gt_labels_3d): + out_b, out_l = self.modify_gt_for_single_task_single_sample(gts_b, gts_l, task_id) + out_bboxes_list.append(out_b) + out_labels_list.append(out_l) + return out_bboxes_list, out_labels_list + + def modify_gt_for_single_task_single_sample(self, gt_bboxes_3d, gt_labels_3d, task_id): + assert gt_bboxes_3d.tensor.size(0) == gt_labels_3d.size(0) + if gt_labels_3d.size(0) == 0: + return gt_bboxes_3d, gt_labels_3d + assert (gt_labels_3d >= 0).all() # I don't want -1 in gt_labels_3d + + class_names_this_task = self.tasks[task_id]['class_names'] + num_classes_this_task = len(class_names_this_task) + out_gt_bboxes_list = [] + out_labels_list = [] + for i, name in enumerate(class_names_this_task): + cls_id = self.class_names.index(name) + this_cls_mask = gt_labels_3d == cls_id + out_gt_bboxes_list.append(gt_bboxes_3d[this_cls_mask]) + out_labels_list.append(gt_labels_3d.new_ones(this_cls_mask.sum()) * i) + out_gt_bboxes_3d = gt_bboxes_3d.cat(out_gt_bboxes_list) + out_labels = torch.cat(out_labels_list, dim=0) + if len(out_labels) > 0: + assert out_labels.max().item() < num_classes_this_task + return out_gt_bboxes_3d, out_labels + + def get_targets(self, + num_task_classes, + cluster_xyz, + batch_idx, + gt_bboxes_3d, + gt_labels_3d, + reg_preds=None): + batch_size = len(gt_bboxes_3d) + cluster_xyz_list = self.split_by_batch(cluster_xyz, batch_idx, batch_size) + + if reg_preds is not None: + reg_preds_list = self.split_by_batch(reg_preds, batch_idx, batch_size) + else: + reg_preds_list = [None,] * len(cluster_xyz_list) + + num_task_class_list = [num_task_classes,] * len(cluster_xyz_list) + target_list_per_sample = multi_apply(self.get_targets_single, num_task_class_list, cluster_xyz_list, gt_bboxes_3d, gt_labels_3d, reg_preds_list) + targets = [self.combine_by_batch(t, batch_idx, batch_size) for t in target_list_per_sample] + # targets == [labels, label_weights, bbox_targets, bbox_weights] + return targets + + def get_targets_single(self, + num_task_classes, + cluster_xyz, + gt_bboxes_3d, + gt_labels_3d, + reg_preds=None): + """Generate targets of vote head for single batch. 
+ + """ + num_cluster = len(cluster_xyz) + labels = gt_labels_3d.new_full((num_cluster, ), num_task_classes, dtype=torch.long) + label_weights = cluster_xyz.new_ones(num_cluster) + bbox_targets = cluster_xyz.new_zeros((num_cluster, self.box_code_size)) + bbox_weights = cluster_xyz.new_zeros((num_cluster, self.box_code_size)) + if num_cluster == 0: + iou_labels = None + if self.loss_iou is not None: + iou_labels = cluster_xyz.new_zeros(0) + return labels, label_weights, bbox_targets, bbox_weights, iou_labels + + valid_gt_mask = gt_labels_3d >= 0 + gt_bboxes_3d = gt_bboxes_3d[valid_gt_mask] + gt_labels_3d = gt_labels_3d[valid_gt_mask] + + gt_bboxes_3d = gt_bboxes_3d.to(cluster_xyz.device) + if self.train_cfg.get('assign_by_dist', False): + assign_result = self.assign_by_dist_single(cluster_xyz, gt_bboxes_3d, gt_labels_3d) + else: + assign_result = self.assign_single(cluster_xyz, gt_bboxes_3d, gt_labels_3d) + + # Do not put this before assign + + sample_result = self.sampler.sample(assign_result, cluster_xyz, gt_bboxes_3d.tensor) # Pseudo Sampler, use cluster_xyz as pseudo bbox here. + + pos_inds = sample_result.pos_inds + neg_inds = sample_result.neg_inds + + # label targets + labels[pos_inds] = gt_labels_3d[sample_result.pos_assigned_gt_inds] + assert (labels >= 0).all() + bbox_weights[pos_inds] = 1.0 + + if len(pos_inds) > 0: + bbox_targets[pos_inds] = self.bbox_coder.encode(sample_result.pos_gt_bboxes, cluster_xyz[pos_inds]) + if sample_result.pos_gt_bboxes.size(1) == 10: + # zeros velocity loss weight for pasted objects + assert sample_result.pos_gt_bboxes[:, 9].max().item() in (0, 1) + assert sample_result.pos_gt_bboxes[:, 9].min().item() in (0, 1) + assert bbox_weights.size(1) == 10, 'It is not safe to use -2: as follows if size(1) != 10' + bbox_weights[pos_inds, -2:] = sample_result.pos_gt_bboxes[:, [9]] + + if self.loss_iou is not None: + iou_labels = self.get_iou_labels(reg_preds, cluster_xyz, gt_bboxes_3d.tensor, pos_inds) + else: + iou_labels = None + + return labels, label_weights, bbox_targets, bbox_weights, iou_labels + + + # generate votes target + def enlarge_gt_bboxes(self, gt_bboxes_3d, gt_labels_3d=None): + if self.enlarge_width is not None: + return gt_bboxes_3d.enlarged_box(self.enlarge_width) + else: + return gt_bboxes_3d + + @torch.no_grad() + def get_bboxes(self, + cls_logits, + reg_preds, + cluster_xyz, + cluster_inds, + input_metas, + iou_logits=None, + rescale=False, + ): + + + assert isinstance(cls_logits, list) + assert isinstance(reg_preds, list) + + assert len(cls_logits) == len(reg_preds) == len(self.tasks) + alltask_result_list = [] + for i in range(len(self.tasks)): + res_this_task = self.get_bboxes_single_task( + i, + cls_logits[i], + reg_preds[i], + cluster_xyz, + cluster_inds, + input_metas, + iou_logits, + rescale, + ) + alltask_result_list.append(res_this_task) + + + # concat results, I guess len of return list should equal to batch_size + batch_size = len(input_metas) + real_batch_size = len(alltask_result_list[0]) + assert real_batch_size <= batch_size # may less than batch_size if no + concat_list = [] + + + for b_idx in range(batch_size): + boxes = LiDARInstance3DBoxes.cat([task_res[b_idx][0] for task_res in alltask_result_list]) + score = torch.cat([task_res[b_idx][1] for task_res in alltask_result_list], dim=0) + label = torch.cat([task_res[b_idx][2] for task_res in alltask_result_list], dim=0) + concat_list.append((boxes, score, label)) + + return concat_list + + + @torch.no_grad() + def get_bboxes_single_task( + self, + task_id, + cls_logits, + 
reg_preds, + cluster_xyz, + cluster_inds, + input_metas, + iou_logits=None, + rescale=False, + ): + + if cluster_inds.ndim == 1: + batch_inds = cluster_inds + else: + batch_inds = cluster_inds[:, 1] + + batch_size = len(input_metas) + cls_logits_list = self.split_by_batch(cls_logits, batch_inds, batch_size) + reg_preds_list = self.split_by_batch(reg_preds, batch_inds, batch_size) + cluster_xyz_list = self.split_by_batch(cluster_xyz, batch_inds, batch_size) + + if iou_logits is not None: + iou_logits_list = self.split_by_batch(iou_logits, batch_inds, batch_size) + else: + iou_logits_list = [None,] * len(cls_logits_list) + + task_id_repeat = [task_id, ] * len(cls_logits_list) + multi_results = multi_apply( + self._get_bboxes_single, + task_id_repeat, + cls_logits_list, + iou_logits_list, + reg_preds_list, + cluster_xyz_list, + input_metas + ) + # out_bboxes_list, out_scores_list, out_labels_list = multi_results + results_list = [(b, s, l) for b, s, l in zip(*multi_results)] + return results_list + + + def _get_bboxes_single( + self, + task_id, + cls_logits, + iou_logits, + reg_preds, + cluster_xyz, + input_meta, + ): + ''' + Get bboxes of a single sample + ''' + + if self.as_rpn: + cfg = self.train_cfg.rpn if self.training else self.test_cfg.rpn + else: + cfg = self.test_cfg + + assert cls_logits.size(0) == reg_preds.size(0) == cluster_xyz.size(0) + assert cls_logits.size(1) == len(self.tasks[task_id]['class_names']) + assert reg_preds.size(1) == self.box_code_size + + if len(cls_logits) == 0: + out_bboxes = reg_preds.new_zeros((0, 7)) + out_bboxes = input_meta['box_type_3d'](out_bboxes) + out_scores = reg_preds.new_zeros(0) + out_labels = reg_preds.new_zeros(0) + return (out_bboxes, out_scores, out_labels) + + scores = cls_logits.sigmoid() + + if iou_logits is not None: + iou_scores = iou_logits.sigmoid() + a = cfg.get('iou_score_weight', 0.5) + scores = (scores ** (1 - a)) * (iou_scores ** a) + + nms_pre = cfg.get('nms_pre', -1) + if nms_pre > 0 and scores.shape[0] > nms_pre: + max_scores, _ = scores.max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + reg_preds = reg_preds[topk_inds, :] + scores = scores[topk_inds, :] + cluster_xyz = cluster_xyz[topk_inds, :] + + bboxes = self.bbox_coder.decode(reg_preds, cluster_xyz) + bboxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d'](bboxes, box_dim=bboxes.size(1)).bev) + + # Add a dummy background class to the front when using sigmoid + padding = scores.new_zeros(scores.shape[0], 1) + scores = torch.cat([scores, padding], dim=1) + + score_thr = cfg.get('score_thr', 0) + results = box3d_multiclass_nms(bboxes, bboxes_for_nms, + scores, score_thr, cfg.max_num, + cfg) + + out_bboxes, out_scores, out_labels = results + + out_bboxes = input_meta['box_type_3d'](out_bboxes, out_bboxes.size(1)) + + # modify task labels to global label indices + new_labels = torch.zeros_like(out_labels) - 1 # all -1 + if len(out_labels) > 0: + for i, name in enumerate(self.tasks[task_id]['class_names']): + global_cls_ind = self.class_names.index(name) + new_labels[out_labels == i] = global_cls_ind + + assert (new_labels >= 0).all() + + out_labels = new_labels + + return (out_bboxes, out_scores, out_labels) \ No newline at end of file diff --git a/mmdet3d/models/detectors/__init__.py b/mmdet3d/models/detectors/__init__.py index c95e00ca0d..5cb9cef970 100644 --- a/mmdet3d/models/detectors/__init__.py +++ b/mmdet3d/models/detectors/__init__.py @@ -21,6 +21,8 @@ from .ssd3dnet import SSD3DNet from .votenet import VoteNet from .voxelnet import VoxelNet +from .single_stage_fsd 
import SingleStageFSD, VoteSegmentor +from .two_stage_fsd import FSD __all__ = [ 'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector', @@ -28,5 +30,5 @@ 'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector', 'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet', 'PointRCNN', 'SMOKEMono3D', 'SASSD', 'MinkSingleStage3DDetector', 'MultiViewDfM', 'DfM', - 'PointVoxelRCNN' + 'PointVoxelRCNN', 'SingleStageFSD', 'VoteSegmentor', 'FSD' ] diff --git a/mmdet3d/models/detectors/single_stage_fsd.py b/mmdet3d/models/detectors/single_stage_fsd.py new file mode 100644 index 0000000000..55e80bb119 --- /dev/null +++ b/mmdet3d/models/detectors/single_stage_fsd.py @@ -0,0 +1,1162 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple, List + +import torch +from mmcv.ops import Voxelization, furthest_point_sample +from mmdet.models import build_detector +from mmdet.models.utils import multi_apply +from scipy.sparse.csgraph import connected_components +from torch import Tensor +from torch.nn import functional as F + +from mmdet3d.models.layers.sst import scatter_v2, get_inner_win_inds +from mmdet3d.models.segmentors.base import Base3DSegmentor +from mmdet3d.registry import MODELS +from mmdet3d.structures.ops import bbox3d2result +from .single_stage import SingleStage3DDetector +from .. import builder +from ..builder import build_backbone, build_head, build_neck +from ...structures.det3d_data_sample import OptSampleList +from ...utils.typing import SampleList + +try: + from torchex import connected_components as cc_gpu +except ImportError: + cc_gpu = None + +def fps(points, N): + idx = furthest_point_sample(points.unsqueeze(0), N) + idx = idx.squeeze(0).long() + points = points[idx] + return points + +def filter_almost_empty(coors, min_points): + new_coors, unq_inv, unq_cnt = torch.unique(coors, return_inverse=True, return_counts=True, dim=0) + cnt_per_point = unq_cnt[unq_inv] + valid_mask = cnt_per_point >= min_points + return valid_mask + +def find_connected_componets_gpu(points, batch_idx, dist): + + assert len(points) > 0 + assert cc_gpu is not None + components_inds = cc_gpu(points, batch_idx, dist, 100, 2, False) + assert len(torch.unique(components_inds)) == components_inds.max().item() + 1 + return components_inds + +def find_connected_componets(points, batch_idx, dist): + + device = points.device + bsz = batch_idx.max().item() + 1 + base = 0 + components_inds = torch.zeros_like(batch_idx) - 1 + + for i in range(bsz): + batch_mask = batch_idx == i + if batch_mask.any(): + this_points = points[batch_mask] + dist_mat = this_points[:, None, :2] - this_points[None, :, :2] # only care about xy + dist_mat = (dist_mat ** 2).sum(2) ** 0.5 + adj_mat = dist_mat < dist + adj_mat = adj_mat.cpu().numpy() + c_inds = connected_components(adj_mat, directed=False)[1] + c_inds = torch.from_numpy(c_inds).to(device).int() + base + base = c_inds.max().item() + 1 + components_inds[batch_mask] = c_inds + + assert len(torch.unique(components_inds)) == components_inds.max().item() + 1 + + return components_inds + +def find_connected_componets_single_batch(points, batch_idx, dist): + + device = points.device + + this_points = points + dist_mat = this_points[:, None, :2] - this_points[None, :, :2] # only care about xy + dist_mat = (dist_mat ** 2).sum(2) ** 0.5 + # dist_mat = torch.cdist(this_points[:, :2], this_points[:, :2], p=2) + adj_mat = dist_mat < dist + adj_mat = adj_mat.cpu().numpy() + c_inds = connected_components(adj_mat, directed=False)[1] + c_inds = 
torch.from_numpy(c_inds).to(device).int() + + return c_inds + +def ssg(points, batch_idx, num_fps, radius): + device = points.device + bsz = batch_idx.max().item() + 1 + base = 0 + components_inds = torch.zeros_like(batch_idx) - 2 + for i in range(bsz): + batch_mask = batch_idx == i + if batch_mask.any(): + this_points = points[batch_mask] + this_inds = ssg_single_sample(this_points, num_fps, radius) + this_inds[this_inds > -1] += base # keep -1 + base = this_inds.max().item() + 1 + components_inds[batch_mask] = this_inds + assert (components_inds > -2).all() + return components_inds + +def ssg_single_sample(points, num_fps, radius): + """ + a little complicated + """ + if num_fps >= len(points): + key_points = points + else: + key_points = fps(points, num_fps) + + k_dist_mat = key_points[:, None, :2] - key_points[None, :, :2] + k_dist_mat = (k_dist_mat ** 2).sum(2) ** 0.5 #[k, k] + dist_mask = k_dist_mat < radius * 2 + 0.01 + + triangle1 = torch.arange(len(key_points))[None, :].expand(len(key_points), -1) #[[0,1,2], [0, 1, 2]] + triangle2 = triangle1.T #[[0, 0, 0], [1, 1, 1]] + triangle_mask = triangle1 <= triangle2 + dist_mask[triangle_mask] = False + invalid_keypoints_mask = dist_mask.any(0) + + key_points = key_points[~invalid_keypoints_mask] + + dist_mat = key_points[:, None, :2] - points[None, :, :2] #[K, N] + dist_mat = (dist_mat ** 2).sum(2) ** 0.5 + + in_radius_mask = dist_mat < radius + + assert (in_radius_mask.sum(0) <= 1).all() + + valid_centers_mask = in_radius_mask.sum(0) == 1 # if a point falls into multiple balls or does not fall into any ball, it is invalid. + assert valid_centers_mask.any() + + pos = torch.nonzero(in_radius_mask) + cluster_inds = pos[:, 0] + + col_inds = pos[:, 1] + sorted_col_inds, order = torch.sort(col_inds) + cluster_inds = cluster_inds[order] + assert (sorted_col_inds == torch.nonzero(valid_centers_mask).reshape(-1)).all() + + cluster_inds_full = cluster_inds.new_zeros(len(points)) - 1 + + cluster_inds_full[valid_centers_mask] = cluster_inds + + return cluster_inds_full + + +def modify_cluster_by_class(cluster_inds_list): + new_list = [] + for i, inds in enumerate(cluster_inds_list): + cls_pad = inds.new_ones((len(inds),)) * i + inds = torch.cat([cls_pad[:, None], inds], 1) + # inds = F.pad(inds, (1, 0), 'constant', i) + new_list.append(inds) + return new_list + + +@MODELS.register_module() +class VoteSegmentor(Base3DSegmentor): + + def __init__(self, + voxel_layer, + voxel_encoder, + middle_encoder, + backbone, + segmentation_head, + decode_neck=None, + auxiliary_head=None, + voxel_downsampling_size=None, + train_cfg=None, + test_cfg=None, + init_cfg=None, + pretrained=None, + tanh_dims=None, + **extra_kwargs): + super().__init__(init_cfg=init_cfg) + + self.voxel_layer = Voxelization(**voxel_layer) + + self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder) + self.middle_encoder = builder.build_middle_encoder(middle_encoder) + self.backbone = build_backbone(backbone) + self.segmentation_head = build_head(segmentation_head) + self.segmentation_head.train_cfg = train_cfg + self.segmentation_head.test_cfg = test_cfg + self.decode_neck = build_neck(decode_neck) + + assert voxel_encoder['type'] == 'DynamicScatterVFE' + + + self.print_info = {} + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.cfg = train_cfg if train_cfg is not None else test_cfg + self.num_classes = segmentation_head['num_classes'] + self.save_list = [] + self.point_cloud_range = voxel_layer['point_cloud_range'] + self.voxel_size = voxel_layer['voxel_size'] + 
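+        # Optional input preprocessing: when voxel_downsampling_size is set,
+        # the raw points are averaged per voxel before segmentation (see
+        # voxel_downsample); tanh_dims selects intensity-like channels that
+        # are squashed with tanh in loss()/simple_test().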
self.voxel_downsampling_size = voxel_downsampling_size + self.tanh_dims = tanh_dims + + def encode_decode(self): + return None + + def aug_test(self, points, img_metas, imgs=None, rescale=False): + """Test function with augmentaiton.""" + return NotImplementedError + + @torch.no_grad() + # @force_fp32() + def voxelize(self, points): + """Apply dynamic voxelization to points. + Args: + points (list[torch.Tensor]): Points of each sample. + Returns: + tuple[torch.Tensor]: Concatenated points and coordinates. + """ + coors = [] + # dynamic voxelization only provide a coors mapping + for res in points: + res_coors = self.voxel_layer(res) + coors.append(res_coors) + points = torch.cat(points, dim=0) + coors_batch = [] + for i, coor in enumerate(coors): + coor_pad = F.pad(coor, (1, 0), mode='constant', value=i) + coors_batch.append(coor_pad) + coors_batch = torch.cat(coors_batch, dim=0) + return points, coors_batch + + def extract_feat(self, points, img_metas): + """Extract features from points.""" + batch_points, coors = self.voxelize(points) + coors = coors.long() + voxel_features, voxel_coors, voxel2point_inds = self.voxel_encoder(batch_points, coors, return_inv=True) + voxel_info = self.middle_encoder(voxel_features, voxel_coors) + x = self.backbone(voxel_info)[0] + padding = -1 + voxel_coors_dropped = x['voxel_feats'] # bug, leave it for feature modification + if 'shuffle_inds' not in voxel_info: + voxel_feats_reorder = x['voxel_feats'] + else: + # this branch only used in SST-based FSD + voxel_feats_reorder = self.reorder(x['voxel_feats'], voxel_info['shuffle_inds'], voxel_info['voxel_keep_inds'], padding) #'not consistent with voxel_coors any more' + + out = self.decode_neck(batch_points, coors, voxel_feats_reorder, voxel2point_inds, padding) + + return out, coors, batch_points + + def reorder(self, data, shuffle_inds, keep_inds, padding=-1): + ''' + Padding dropped voxel and reorder voxels. voxel length and order will be consistent with the output of voxel_encoder. 
+ ''' + num_voxel_no_drop = len(shuffle_inds) + data_dim = data.size(1) + + temp_data = padding * data.new_ones((num_voxel_no_drop, data_dim)) + out_data = padding * data.new_ones((num_voxel_no_drop, data_dim)) + + temp_data[keep_inds] = data + out_data[shuffle_inds] = temp_data + + return out_data + + def voxel_downsample(self, points_list): + device = points_list[0].device + out_points_list = [] + voxel_size = torch.tensor(self.voxel_downsampling_size, device=device) + pc_range = torch.tensor(self.point_cloud_range, device=device) + + for points in points_list: + coors = torch.div(points[:, :3] - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').long() + out_points, new_coors = scatter_v2(points, coors, mode='avg', return_inv=False) + out_points_list.append(out_points) + return out_points_list + + def loss(self, + points, + img_metas, + gt_bboxes_3d, + gt_labels_3d, + as_subsegmentor=False, + ): + if self.tanh_dims is not None: + for p in points: + p[:, self.tanh_dims] = torch.tanh(p[:, self.tanh_dims]) + elif points[0].size(1) in (4,5): + # a hack way to scale the intensity and elongation in WOD + points = [torch.cat([p[:, :3], torch.tanh(p[:, 3:])], dim=1) for p in points] + + if self.voxel_downsampling_size is not None: + points = self.voxel_downsample(points) + + labels, vote_targets, vote_mask = self.segmentation_head.get_targets(points, gt_bboxes_3d, gt_labels_3d) + + neck_out, pts_coors, points = self.extract_feat(points, img_metas) + + losses = dict() + + feats = neck_out[0] + valid_pts_mask = neck_out[1] + points = points[valid_pts_mask] + pts_coors = pts_coors[valid_pts_mask] + labels = labels[valid_pts_mask] + vote_targets = vote_targets[valid_pts_mask] + vote_mask = vote_mask[valid_pts_mask] + + assert feats.size(0) == labels.size(0) + + if as_subsegmentor: + loss_decode, preds_dict = self.segmentation_head.loss(feats, img_metas, labels, vote_targets, vote_mask, return_preds=True) + losses.update(loss_decode) + + seg_logits = preds_dict['seg_logits'] + vote_preds = preds_dict['vote_preds'] + + offsets = self.segmentation_head.decode_vote_targets(vote_preds) + + output_dict = dict( + seg_points=points, + seg_logits=preds_dict['seg_logits'], + seg_vote_preds=preds_dict['vote_preds'], + offsets=offsets, + seg_feats=feats, + batch_idx=pts_coors[:, 0], + losses=losses + ) + else: + loss_decode = self.segmentation_head.loss(feats, img_metas, labels, vote_targets, vote_mask, return_preds=False) + losses.update(loss_decode) + output_dict = losses + + return output_dict + + def simple_test(self, points, img_metas, gt_bboxes_3d=None, gt_labels_3d=None, rescale=False): + + if self.tanh_dims is not None: + for p in points: + p[:, self.tanh_dims] = torch.tanh(p[:, self.tanh_dims]) + elif points[0].size(1) in (4,5): + points = [torch.cat([p[:, :3], torch.tanh(p[:, 3:])], dim=1) for p in points] + + if self.voxel_downsampling_size is not None: + points = self.voxel_downsample(points) + + seg_pred = [] + x, pts_coors, points = self.extract_feat(points, img_metas) + feats = x[0] + valid_pts_mask = x[1] + points = points[valid_pts_mask] + pts_coors = pts_coors[valid_pts_mask] + + seg_logits, vote_preds = self.segmentation_head.forward_test(feats, img_metas, self.test_cfg) + + offsets = self.segmentation_head.decode_vote_targets(vote_preds) + + output_dict = dict( + seg_points=points, + seg_logits=seg_logits, + seg_vote_preds=vote_preds, + offsets=offsets, + seg_feats=feats, + batch_idx=pts_coors[:, 0], + ) + + return output_dict + + def predict(self, batch_inputs: Tensor, 
batch_data_samples: SampleList) -> SampleList: + raise NotImplementedError + + def _forward(self, batch_inputs: Tensor, batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: + raise NotImplementedError + + +@MODELS.register_module() +class SingleStageFSD(SingleStage3DDetector): + + def __init__(self, + backbone, + segmentor, + voxel_layer=None, + voxel_encoder=None, + middle_encoder=None, + neck=None, + bbox_head=None, + train_cfg=None, + test_cfg=None, + cluster_assigner=None, + data_preprocessor=None, + init_cfg=None): + super(SingleStageFSD, self).__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + if voxel_layer is not None: + self.voxel_layer = Voxelization(**voxel_layer) + if voxel_encoder is not None: + self.voxel_encoder = builder.build_voxel_encoder(voxel_encoder) + if middle_encoder is not None: + self.middle_encoder = builder.build_middle_encoder(middle_encoder) + + self.segmentor = build_detector(segmentor) + self.head_type = bbox_head['type'] + self.num_classes = bbox_head['num_classes'] + + self.cfg = self.train_cfg if self.train_cfg else self.test_cfg + if 'radius' in cluster_assigner: + self.cluster_assigner = SSGAssigner(**cluster_assigner) + elif 'hybrid' in cluster_assigner: + cluster_assigner.pop('hybrid') + self.cluster_assigner = HybridAssigner(**cluster_assigner) + else: + self.cluster_assigner = ClusterAssigner(**cluster_assigner) + self.cluster_assigner.num_classes = self.num_classes + self.print_info = {} + self.as_rpn = bbox_head.get('as_rpn', False) + + @torch.no_grad() + # @force_fp32() + def voxelize(self, points): + """Apply dynamic voxelization to points. + + """ + raise ValueError('This function should not be called in FSD') + device = points[0].device + voxel_size = torch.tensor(self.voxel_layer.voxel_size, device=device) + pc_range = torch.tensor(self.voxel_layer.point_cloud_range, device=device) + + coors = [] + for res in points: + res_coors = torch.div(res[:, :3] - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').long() + res_coors = res_coors[:, [2, 1, 0]] # to zyx order + coors.append(res_coors) + + points = torch.cat(points, dim=0) + + coors_batch = [] + for i, coor in enumerate(coors): + coor_pad = F.pad(coor, (1, 0), mode='constant', value=i) + coors_batch.append(coor_pad) + coors_batch = torch.cat(coors_batch, dim=0) + + return points, coors_batch + + def extract_feat(self, points, pts_feats, pts_cluster_inds, img_metas, center_preds): + """Extract features from points.""" + cluster_xyz, _, inv_inds = scatter_v2(center_preds, pts_cluster_inds, mode='avg', return_inv=True) + + f_cluster = points[:, :3] - cluster_xyz[inv_inds] + + out_pts_feats, cluster_feats, out_coors = self.backbone(points, pts_feats, pts_cluster_inds, f_cluster) + out_dict = dict( + cluster_feats=cluster_feats, + cluster_xyz=cluster_xyz, + cluster_inds=out_coors + ) + if self.as_rpn: + out_dict['cluster_pts_feats'] = out_pts_feats + out_dict['cluster_pts_xyz'] = points + + return out_dict + # + # def loss(self, + # points, + # img_metas, + # gt_bboxes_3d, + # gt_labels_3d, + # gt_bboxes_ignore=None, + # runtime_info=None): + + def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList, + **kwargs) -> dict: + self.runtime_info = runtime_info # stupid way to get arguements from children class + losses = {} + gt_bboxes_3d = [b[l>=0] for b, l in zip(gt_bboxes_3d, gt_labels_3d)] + gt_labels_3d = [l[l>=0] for l in 
gt_labels_3d] + + seg_out_dict = self.segmentor(points=points, img_metas=img_metas, gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d, as_subsegmentor=True) + + seg_feats = seg_out_dict['seg_feats'] + if self.train_cfg.get('detach_segmentor', False): + seg_feats = seg_feats.detach() + seg_loss = seg_out_dict['losses'] + losses.update(seg_loss) + + dict_to_sample = dict( + seg_points=seg_out_dict['seg_points'], + seg_logits=seg_out_dict['seg_logits'].detach(), + seg_vote_preds=seg_out_dict['seg_vote_preds'].detach(), + seg_feats=seg_feats, + batch_idx=seg_out_dict['batch_idx'], + vote_offsets=seg_out_dict['offsets'].detach(), + ) + if self.cfg.get('pre_voxelization_size', None) is not None: + dict_to_sample = self.pre_voxelize(dict_to_sample) + sampled_out = self.sample(dict_to_sample, dict_to_sample['vote_offsets'], gt_bboxes_3d, gt_labels_3d) # per cls list in sampled_out + + # we filter almost empty voxel in clustering, so here is a valid_mask + cluster_inds_list, valid_mask_list = self.cluster_assigner(sampled_out['center_preds'], sampled_out['batch_idx'], gt_bboxes_3d, gt_labels_3d, origin_points=sampled_out['seg_points']) # per cls list + pts_cluster_inds = torch.cat(cluster_inds_list, dim=0) #[N, 3], (cls_id, batch_idx, cluster_id) + + num_clusters = len(torch.unique(pts_cluster_inds, dim=0)) * torch.ones((1,), device=pts_cluster_inds.device).float() + losses['num_clusters'] = num_clusters + + sampled_out = self.update_sample_results_by_mask(sampled_out, valid_mask_list) + + combined_out = self.combine_classes(sampled_out, ['seg_points', 'seg_logits', 'seg_vote_preds', 'seg_feats', 'center_preds']) + + points = combined_out['seg_points'] + pts_feats = torch.cat([combined_out['seg_logits'], combined_out['seg_vote_preds'], combined_out['seg_feats']], dim=1) + assert len(pts_cluster_inds) == len(points) == len(pts_feats) + losses['num_fg_points'] = torch.ones((1,), device=points.device).float() * len(points) + + extracted_outs = self.extract_feat(points, pts_feats, pts_cluster_inds, img_metas, combined_out['center_preds']) + cluster_feats = extracted_outs['cluster_feats'] + cluster_xyz = extracted_outs['cluster_xyz'] + cluster_inds = extracted_outs['cluster_inds'] # [class, batch, groups] + + assert (cluster_inds[:, 0]).max().item() < self.num_classes + + outs = self.bbox_head(cluster_feats, cluster_xyz, cluster_inds) + loss_inputs = (outs['cls_logits'], outs['reg_preds']) + (cluster_xyz, cluster_inds) + (gt_bboxes_3d, gt_labels_3d, img_metas) + det_loss = self.bbox_head.loss( + *loss_inputs, iou_logits=outs.get('iou_logits', None), gt_bboxes_ignore=gt_bboxes_ignore) + + if hasattr(self.bbox_head, 'print_info'): + self.print_info.update(self.bbox_head.print_info) + losses.update(det_loss) + losses.update(self.print_info) + + if self.as_rpn: + output_dict = dict( + rpn_losses=losses, + cls_logits=outs['cls_logits'], + reg_preds=outs['reg_preds'], + cluster_xyz=cluster_xyz, + cluster_inds=cluster_inds, + all_input_points=dict_to_sample['seg_points'], + valid_pts_feats=extracted_outs['cluster_pts_feats'], + valid_pts_xyz=extracted_outs['cluster_pts_xyz'], + seg_feats=dict_to_sample['seg_feats'], + pts_mask=sampled_out['fg_mask_list'], + pts_batch_inds=dict_to_sample['batch_idx'], + ) + return output_dict + else: + return losses + + def update_sample_results_by_mask(self, sampled_out, valid_mask_list): + for k in sampled_out: + old_data = sampled_out[k] + if len(old_data[0]) == len(valid_mask_list[0]) or 'fg_mask' in k: + if 'fg_mask' in k: + new_data_list = [] + for data, mask in 
zip(old_data, valid_mask_list): + new_data = data.clone() + new_data[data] = mask + assert new_data.sum() == mask.sum() + new_data_list.append(new_data) + sampled_out[k] = new_data_list + else: + new_data_list = [data[mask] for data, mask in zip(old_data, valid_mask_list)] + sampled_out[k] = new_data_list + return sampled_out + + def combine_classes(self, data_dict, name_list): + out_dict = {} + for name in data_dict: + if name in name_list: + out_dict[name] = torch.cat(data_dict[name], 0) + return out_dict + + def pre_voxelize(self, data_dict): + batch_idx = data_dict['batch_idx'] + points = data_dict['seg_points'] + + voxel_size = torch.tensor(self.cfg.pre_voxelization_size, device=batch_idx.device) + pc_range = torch.tensor(self.cluster_assigner.point_cloud_range, device=points.device) + coors = torch.div(points[:, :3] - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').long() + coors = coors[:, [2, 1, 0]] # to zyx order + coors = torch.cat([batch_idx[:, None], coors], dim=1) + + new_coors, unq_inv = torch.unique(coors, return_inverse=True, return_counts=False, dim=0) + + voxelized_data_dict = {} + for data_name in data_dict: + data = data_dict[data_name] + if data.dtype in (torch.float, torch.float16): + voxelized_data, voxel_coors = scatter_v2(data, coors, mode='avg', return_inv=False, new_coors=new_coors, unq_inv=unq_inv) + voxelized_data_dict[data_name] = voxelized_data + + voxelized_data_dict['batch_idx'] = voxel_coors[:, 0] + return voxelized_data_dict + + def simple_test(self, points, img_metas, imgs=None, rescale=False, gt_bboxes_3d=None, gt_labels_3d=None): + """Test function without augmentaiton.""" + if gt_bboxes_3d is not None: + gt_bboxes_3d = gt_bboxes_3d[0] + gt_labels_3d = gt_labels_3d[0] + assert isinstance(gt_bboxes_3d, list) + assert isinstance(gt_labels_3d, list) + assert len(gt_bboxes_3d) == len(gt_labels_3d) == 1, 'assuming single sample testing' + + seg_out_dict = self.segmentor.simple_test(points, img_metas, rescale=False) + + seg_feats = seg_out_dict['seg_feats'] + + dict_to_sample = dict( + seg_points=seg_out_dict['seg_points'], + seg_logits=seg_out_dict['seg_logits'], + seg_vote_preds=seg_out_dict['seg_vote_preds'], + seg_feats=seg_feats, + batch_idx=seg_out_dict['batch_idx'], + vote_offsets = seg_out_dict['offsets'] + ) + if self.cfg.get('pre_voxelization_size', None) is not None: + dict_to_sample = self.pre_voxelize(dict_to_sample) + sampled_out = self.sample(dict_to_sample, dict_to_sample['vote_offsets'], gt_bboxes_3d, gt_labels_3d) # per cls list in sampled_out + + # we filter almost empty voxel in clustering, so here is a valid_mask + cluster_inds_list, valid_mask_list = self.cluster_assigner(sampled_out['center_preds'], sampled_out['batch_idx'], gt_bboxes_3d, gt_labels_3d, origin_points=sampled_out['seg_points']) # per cls list + + pts_cluster_inds = torch.cat(cluster_inds_list, dim=0) #[N, 3], (cls_id, batch_idx, cluster_id) + + sampled_out = self.update_sample_results_by_mask(sampled_out, valid_mask_list) + + combined_out = self.combine_classes(sampled_out, ['seg_points', 'seg_logits', 'seg_vote_preds', 'seg_feats', 'center_preds']) + + points = combined_out['seg_points'] + pts_feats = torch.cat([combined_out['seg_logits'], combined_out['seg_vote_preds'], combined_out['seg_feats']], dim=1) + assert len(pts_cluster_inds) == len(points) == len(pts_feats) + + extracted_outs = self.extract_feat(points, pts_feats, pts_cluster_inds, img_metas, combined_out['center_preds']) + cluster_feats = extracted_outs['cluster_feats'] + cluster_xyz = 
extracted_outs['cluster_xyz'] + cluster_inds = extracted_outs['cluster_inds'] + assert (cluster_inds[:, 1] == 0).all() + + outs = self.bbox_head(cluster_feats, cluster_xyz, cluster_inds) + + bbox_list = self.bbox_head.get_bboxes( + outs['cls_logits'], outs['reg_preds'], + cluster_xyz, cluster_inds, img_metas, + rescale=rescale, + iou_logits=outs.get('iou_logits', None)) + + if self.as_rpn: + output_dict = dict( + all_input_points=dict_to_sample['seg_points'], + valid_pts_feats=extracted_outs['cluster_pts_feats'], + valid_pts_xyz=extracted_outs['cluster_pts_xyz'], + seg_feats=dict_to_sample['seg_feats'], + pts_mask=sampled_out['fg_mask_list'], + pts_batch_inds=dict_to_sample['batch_idx'], + proposal_list=bbox_list + ) + return output_dict + else: + bbox_results = [ + bbox3d2result(bboxes, scores, labels) + for bboxes, scores, labels in bbox_list + ] + return bbox_results + + def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> SampleList: + raise NotImplementedError + + def _forward(self, batch_inputs: Tensor, batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: + raise NotImplementedError + + def aug_test(self, points, img_metas, imgs=None, rescale=False): + """Test function with augmentaiton.""" + return NotImplementedError + + + def sample(self, dict_to_sample, offset, gt_bboxes_3d=None, gt_labels_3d=None): + + if self.cfg.get('group_sample', False): + return self.group_sample(dict_to_sample, offset) + + cfg = self.train_cfg if self.training else self.test_cfg + + seg_logits = dict_to_sample['seg_logits'] + assert (seg_logits < 0).any() # make sure no sigmoid applied + + if seg_logits.size(1) == self.num_classes: + seg_scores = seg_logits.sigmoid() + else: + raise NotImplementedError + + offset = offset.reshape(-1, self.num_classes, 3) + seg_points = dict_to_sample['seg_points'][:, :3] + fg_mask_list = [] # fg_mask of each cls + center_preds_list = [] # fg_mask of each cls + + batch_idx = dict_to_sample['batch_idx'] + batch_size = batch_idx.max().item() + 1 + for cls in range(self.num_classes): + cls_score_thr = cfg['score_thresh'][cls] + + fg_mask = self.get_fg_mask(seg_scores, seg_points, cls, batch_idx, gt_bboxes_3d, gt_labels_3d) + + if len(torch.unique(batch_idx[fg_mask])) < batch_size: + one_random_pos_per_sample = self.get_sample_beg_position(batch_idx, fg_mask) + fg_mask[one_random_pos_per_sample] = True # at least one point per sample + + fg_mask_list.append(fg_mask) + + this_offset = offset[fg_mask, cls, :] + this_points = seg_points[fg_mask, :] + this_centers = this_points + this_offset + center_preds_list.append(this_centers) + + + output_dict = {} + for data_name in dict_to_sample: + data = dict_to_sample[data_name] + cls_data_list = [] + for fg_mask in fg_mask_list: + cls_data_list.append(data[fg_mask]) + + output_dict[data_name] = cls_data_list + output_dict['fg_mask_list'] = fg_mask_list + output_dict['center_preds'] = center_preds_list + + return output_dict + + def get_sample_beg_position(self, batch_idx, fg_mask): + assert batch_idx.shape == fg_mask.shape + inner_inds = get_inner_win_inds(batch_idx.contiguous()) + pos = torch.where(inner_inds == 0)[0] + return pos + + def get_fg_mask(self, seg_scores, seg_points, cls_id, batch_inds, gt_bboxes_3d, gt_labels_3d): + if self.training and self.train_cfg.get('disable_pretrain', False) and not self.runtime_info.get('enable_detection', False): + seg_scores = seg_scores[:, cls_id] + topks = self.train_cfg.get('disable_pretrain_topks', [100, 100, 100]) + k = min(topks[cls_id], len(seg_scores)) 
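+            # Segmentation-only pretraining: no reliable score threshold is
+            # available yet, so the top-k scoring points of this class are
+            # kept as foreground instead of thresholding.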
+ top_inds = torch.topk(seg_scores, k)[1] + fg_mask = torch.zeros_like(seg_scores, dtype=torch.bool) + fg_mask[top_inds] = True + else: + seg_scores = seg_scores[:, cls_id] + cls_score_thr = self.cfg['score_thresh'][cls_id] + if self.training: + buffer_thr = self.runtime_info.get('threshold_buffer', 0) + else: + buffer_thr = 0 + fg_mask = seg_scores > cls_score_thr + buffer_thr + + # add fg points + cfg = self.train_cfg if self.training else self.test_cfg + + if cfg.get('add_gt_fg_points', False): + bsz = len(gt_bboxes_3d) + assert len(seg_scores) == len(seg_points) == len(batch_inds) + point_list = self.split_by_batch(seg_points, batch_inds, bsz) + gt_fg_mask_list = [] + + for i, points in enumerate(point_list): + + gt_mask = gt_labels_3d[i] == cls_id + gts = gt_bboxes_3d[i][gt_mask] + + if not gt_mask.any() or len(points) == 0: + gt_fg_mask_list.append(gt_mask.new_zeros(len(points), dtype=torch.bool)) + continue + + gt_fg_mask_list.append(gts.points_in_boxes(points) > -1) + + gt_fg_mask = self.combine_by_batch(gt_fg_mask_list, batch_inds, bsz) + fg_mask = fg_mask | gt_fg_mask + + + return fg_mask + + def split_by_batch(self, data, batch_idx, batch_size): + assert batch_idx.max().item() + 1 <= batch_size + data_list = [] + for i in range(batch_size): + sample_mask = batch_idx == i + data_list.append(data[sample_mask]) + return data_list + + def combine_by_batch(self, data_list, batch_idx, batch_size): + assert len(data_list) == batch_size + if data_list[0] is None: + return None + data_shape = (len(batch_idx),) + data_list[0].shape[1:] + full_data = data_list[0].new_zeros(data_shape) + for i, data in enumerate(data_list): + sample_mask = batch_idx == i + full_data[sample_mask] = data + return full_data + + def group_sample(self, dict_to_sample, offset): + + """ + For argoverse 2 dataset, where the number of classes is large + """ + + bsz = dict_to_sample['batch_idx'].max().item() + 1 + assert bsz == 1, "Maybe some codes need to be modified if bsz > 1" + # combine all classes as fg class. 
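+        # Classes are partitioned into groups (cfg['group_lens']); a group's
+        # foreground score is the sum of the softmax scores of its classes
+        # (gather_group) and is thresholded with the per-group
+        # cfg['score_thresh']. The vote offset of a selected point comes from
+        # its highest-scoring class within the group (get_offset_weight).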
+ cfg = self.train_cfg if self.training else self.test_cfg + + seg_logits = dict_to_sample['seg_logits'] + assert (seg_logits < 0).any() # make sure no sigmoid applied + + assert seg_logits.size(1) == self.num_classes + 1 # we have background class + seg_scores = seg_logits.softmax(1) + + offset = offset.reshape(-1, self.num_classes + 1, 3) + seg_points = dict_to_sample['seg_points'][:, :3] + fg_mask_list = [] # fg_mask of each cls + center_preds_list = [] # fg_mask of each cls + + + cls_score_thrs = cfg['score_thresh'] + group_lens = cfg['group_lens'] + num_groups = len(group_lens) + assert num_groups == len(cls_score_thrs) + assert isinstance(cls_score_thrs, (list, tuple)) + grouped_score = self.gather_group(seg_scores[:, :-1], group_lens) # without background score + + beg = 0 + for i, group_len in enumerate(group_lens): + end = beg + group_len + + fg_mask = grouped_score[:, i] > cls_score_thrs[i] + + if not fg_mask.any(): + fg_mask[0] = True # at least one point + + fg_mask_list.append(fg_mask) + + this_offset = offset[fg_mask, beg:end, :] + offset_weight = self.get_offset_weight(seg_logits[fg_mask, beg:end]) + assert torch.isclose(offset_weight.sum(1), offset_weight.new_ones(len(offset_weight))).all() + this_offset = (this_offset * offset_weight[:, :, None]).sum(dim=1) + this_points = seg_points[fg_mask, :] + this_centers = this_points + this_offset + center_preds_list.append(this_centers) + beg = end + assert end == 26, 'for 26class argo' + + + output_dict = {} + for data_name in dict_to_sample: + data = dict_to_sample[data_name] + cls_data_list = [] + for fg_mask in fg_mask_list: + cls_data_list.append(data[fg_mask]) + + output_dict[data_name] = cls_data_list + output_dict['fg_mask_list'] = fg_mask_list + output_dict['center_preds'] = center_preds_list + + return output_dict + + def get_offset_weight(self, seg_logit): + mode = self.cfg['offset_weight'] + if mode == 'max': + weight = ((seg_logit - seg_logit.max(1)[0][:, None]).abs() < 1e-6).float() + assert ((weight == 1).any(1)).all() + weight = weight / weight.sum(1)[:, None] # in case of two max values + return weight + else: + raise NotImplementedError + + def gather_group(self, scores, group_lens): + assert (scores >= 0).all() + score_per_group = [] + beg = 0 + for group_len in group_lens: + end = beg + group_len + score_this_g = scores[:, beg:end].sum(1) + score_per_group.append(score_this_g) + beg = end + assert end == scores.size(1) == sum(group_lens) + gathered_score = torch.stack(score_per_group, dim=1) + assert gathered_score.size(1) == len(group_lens) + return gathered_score + + +class ClusterAssigner(torch.nn.Module): + ''' Generating cluster centers for each class and assign each point to cluster centers + ''' + + def __init__( + self, + cluster_voxel_size, + min_points, + point_cloud_range, + connected_dist, + class_names=['Car', 'Cyclist', 'Pedestrian'], + gpu_clustering=(False, False), + ): + super().__init__() + self.cluster_voxel_size = cluster_voxel_size + self.min_points = min_points + self.connected_dist = connected_dist + self.point_cloud_range = point_cloud_range + self.class_names = class_names + self.gpu_clustering = gpu_clustering + + @torch.no_grad() + def forward(self, points_list, batch_idx_list, gt_bboxes_3d=None, gt_labels_3d=None, origin_points=None): + gt_bboxes_3d = None + gt_labels_3d = None + assert self.num_classes == len(self.class_names) + cluster_inds_list, valid_mask_list = \ + multi_apply(self.forward_single_class, points_list, batch_idx_list, self.class_names, origin_points) + 
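+        # Prepend the class id so every row becomes
+        # (cls_id, batch_idx, cluster_id), the layout expected by
+        # SingleStageFSD when it concatenates pts_cluster_inds.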
cluster_inds_list = modify_cluster_by_class(cluster_inds_list) + return cluster_inds_list, valid_mask_list + + def forward_single_class(self, points, batch_idx, class_name, origin_points): + batch_idx = batch_idx.int() + + if isinstance(self.cluster_voxel_size, dict): + cluster_vsize = self.cluster_voxel_size[class_name] + elif isinstance(self.cluster_voxel_size, list): + cluster_vsize = self.cluster_voxel_size[self.class_names.index(class_name)] + else: + cluster_vsize = self.cluster_voxel_size + + voxel_size = torch.tensor(cluster_vsize, device=points.device) + pc_range = torch.tensor(self.point_cloud_range, device=points.device) + coors = torch.div(points - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').int() + # coors = coors[:, [2, 1, 0]] # to zyx order + coors = torch.cat([batch_idx[:, None], coors], dim=1) + + valid_mask = filter_almost_empty(coors, min_points=self.min_points) + if not valid_mask.any(): + valid_mask = ~valid_mask + # return coors.new_zeros((3,0)), valid_mask + + points = points[valid_mask] + batch_idx = batch_idx[valid_mask] + coors = coors[valid_mask] + # elif len(points) + + sampled_centers, voxel_coors, inv_inds = scatter_v2(points, coors, mode='avg', return_inv=True) + + if isinstance(self.connected_dist, dict): + dist = self.connected_dist[class_name] + elif isinstance(self.connected_dist, list): + dist = self.connected_dist[self.class_names.index(class_name)] + else: + dist = self.connected_dist + + if self.training: + cluster_inds = find_connected_componets(sampled_centers, voxel_coors[:, 0], dist) + else: + if self.gpu_clustering[1]: + cluster_inds = find_connected_componets_gpu(sampled_centers, voxel_coors[:, 0], dist) + else: + cluster_inds = find_connected_componets_single_batch(sampled_centers, voxel_coors[:, 0], dist) + assert len(cluster_inds) == len(sampled_centers) + + cluster_inds_per_point = cluster_inds[inv_inds] + cluster_inds_per_point = torch.stack([batch_idx, cluster_inds_per_point], 1) + return cluster_inds_per_point, valid_mask + + +class SSGAssigner(torch.nn.Module): + ''' Generating cluster centers for each class and assign each point to cluster centers + ''' + + def __init__( + self, + cluster_voxel_size, + point_cloud_range, + radius, + num_fps, + class_names=['Car', 'Cyclist', 'Pedestrian'], + ): + super().__init__() + self.cluster_voxel_size = cluster_voxel_size + self.radius = radius + self.num_fps = num_fps + self.point_cloud_range = point_cloud_range + self.class_names = class_names + + @torch.no_grad() + def forward(self, points_list, batch_idx_list, gt_bboxes_3d=None, gt_labels_3d=None, origin_points=None): + gt_bboxes_3d = None + gt_labels_3d = None + assert self.num_classes == len(self.class_names) + cluster_inds_list, valid_mask_list = \ + multi_apply(self.forward_single_class, points_list, batch_idx_list, self.class_names, origin_points) + cluster_inds_list = modify_cluster_by_class(cluster_inds_list) + return cluster_inds_list, valid_mask_list + + def forward_single_class(self, points, batch_idx, class_name, origin_points): + + if isinstance(self.cluster_voxel_size, dict): + cluster_vsize = self.cluster_voxel_size[class_name] + elif isinstance(self.cluster_voxel_size, list): + cluster_vsize = self.cluster_voxel_size[self.class_names.index(class_name)] + else: + cluster_vsize = self.cluster_voxel_size + + if isinstance(self.radius, dict): + radius = self.radius[class_name] + elif isinstance(self.radius, list): + radius = self.radius[self.class_names.index(class_name)] + else: + radius = self.radius + + 
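+        # Sampling-based grouping: points are averaged into voxels, FPS picks
+        # candidate centers, centers within 2 * radius of an earlier one are
+        # dropped, and each voxel is assigned to the unique remaining center
+        # within `radius`; voxels covered by no center keep index -1 and are
+        # removed via valid_pts_mask.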
voxel_size = torch.tensor(cluster_vsize, device=points.device) + pc_range = torch.tensor(self.point_cloud_range, device=points.device) + coors = torch.div(points - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').long() + coors = coors[:, [2, 1, 0]] # to zyx order + coors = torch.cat([batch_idx[:, None], coors], dim=1) + + voxels, _, inv_inds = scatter_v2(points, coors, mode='avg', return_inv=True) + + num_fps = self.num_fps[class_name] + if num_fps >= len(voxels): + key_points = voxels + else: + key_points = fps(voxels, self.num_fps[class_name]) + + k_dist_mat = key_points[:, None, :2] - key_points[None, :, :2] + k_dist_mat = (k_dist_mat ** 2).sum(2) ** 0.5 #[k, k] + dist_mask = k_dist_mat < radius * 2 + 0.01 + + triangle1 = torch.arange(len(key_points))[None, :].expand(len(key_points), -1) #[[0,1,2], [0, 1, 2]] + triangle2 = triangle1.T #[[0, 0, 0], [1, 1, 1]] + triangle_mask = triangle1 <= triangle2 + dist_mask[triangle_mask] = False + invalid_keypoints_mask = dist_mask.any(0) + + key_points = key_points[~invalid_keypoints_mask] + + dist_mat = key_points[:, None, :2] - voxels[None, :, :2] #[K, N] + dist_mat = (dist_mat ** 2).sum(2) ** 0.5 + + in_radius_mask = dist_mat < radius + + assert (in_radius_mask.sum(0) <= 1).all() + + valid_centers_mask = in_radius_mask.sum(0) == 1 + assert valid_centers_mask.any() + + pos = torch.nonzero(in_radius_mask) + cluster_inds = pos[:, 0] + + col_inds = pos[:, 1] + sorted_col_inds, order = torch.sort(col_inds) + cluster_inds = cluster_inds[order] + assert (sorted_col_inds == torch.nonzero(valid_centers_mask).reshape(-1)).all() + + cluster_inds_full = cluster_inds.new_zeros(len(voxels)) - 1 + + cluster_inds_full[valid_centers_mask] = cluster_inds + + cluster_inds_per_point = cluster_inds_full[inv_inds] + valid_pts_mask = cluster_inds_per_point > -1 + + cluster_inds_per_point = torch.stack([batch_idx, cluster_inds_per_point], 1) + cluster_inds_per_point = cluster_inds_per_point[valid_pts_mask] + + return cluster_inds_per_point, valid_pts_mask + + +class HybridAssigner(torch.nn.Module): + ''' Generating cluster centers for each class and assign each point to cluster centers + ''' + + def __init__( + self, + point_cloud_range, + cfg_per_class, + class_names=['Car', 'Cyclist', 'Pedestrian'], + ): + super().__init__() + self.point_cloud_range = point_cloud_range + self.class_names = class_names + self.cfg_per_class = cfg_per_class + + @torch.no_grad() + def forward(self, points_list, batch_idx_list, gt_bboxes_3d=None, gt_labels_3d=None, origin_points=None): + gt_bboxes_3d = None + gt_labels_3d = None + assert self.num_classes == len(self.class_names) + cluster_inds_list, valid_mask_list = \ + multi_apply(self.forward_single_class, points_list, batch_idx_list, self.class_names, origin_points) + cluster_inds_list = modify_cluster_by_class(cluster_inds_list) + return cluster_inds_list, valid_mask_list + + def forward_single_class(self, points, batch_idx, class_name, origin_points): + """ + Dispatcher + """ + assigner_type = self.cfg_per_class[class_name]['assigner_type'] + if assigner_type == 'ssg': + return self.forward_ssg(points, batch_idx, class_name, origin_points) + elif assigner_type == 'ccl': + return self.forward_ccl(points, batch_idx, class_name, origin_points) + + def forward_ssg(self, points, batch_idx, class_name, origin_points): + + cluster_vsize = self.cfg_per_class[class_name]['cluster_voxel_size'] + radius = self.cfg_per_class[class_name]['radius'] + num_fps = self.cfg_per_class[class_name]['num_fps'] + + voxel_size = 
torch.tensor(cluster_vsize, device=points.device) + pc_range = torch.tensor(self.point_cloud_range, device=points.device) + coors = torch.div(points - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').long() + coors = coors[:, [2, 1, 0]] # to zyx order + coors = torch.cat([batch_idx[:, None], coors], dim=1) + + voxels, voxel_coors, inv_inds = scatter_v2(points, coors, mode='avg', return_inv=True) + + cluster_inds_full = ssg(voxels, voxel_coors[:, 0], num_fps, radius) + + cluster_inds_per_point = cluster_inds_full[inv_inds] + valid_pts_mask = cluster_inds_per_point > -1 + + cluster_inds_per_point = torch.stack([batch_idx, cluster_inds_per_point], 1) + cluster_inds_per_point = cluster_inds_per_point[valid_pts_mask] + + return cluster_inds_per_point, valid_pts_mask + + + def forward_ccl(self, points, batch_idx, class_name, origin_points): + + cluster_vsize = self.cfg_per_class[class_name]['cluster_voxel_size'] + min_points = self.cfg_per_class[class_name]['min_points'] + dist = self.cfg_per_class[class_name]['connected_dist'] + + voxel_size = torch.tensor(cluster_vsize, device=points.device) + pc_range = torch.tensor(self.point_cloud_range, device=points.device) + coors = torch.div(points - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').long() + coors = coors[:, [2, 1, 0]] # to zyx order + coors = torch.cat([batch_idx[:, None], coors], dim=1) + + valid_mask = filter_almost_empty(coors, min_points=min_points) + if not valid_mask.any(): + valid_mask = ~valid_mask + # return coors.new_zeros((3,0)), valid_mask + + points = points[valid_mask] + batch_idx = batch_idx[valid_mask] + coors = coors[valid_mask] + # elif len(points) + + sampled_centers, voxel_coors, inv_inds = scatter_v2(points, coors, mode='avg', return_inv=True) + + + cluster_inds = find_connected_componets(sampled_centers, voxel_coors[:, 0], dist) + assert len(cluster_inds) == len(sampled_centers) + + cluster_inds_per_point = cluster_inds[inv_inds] + cluster_inds_per_point = torch.stack([batch_idx, cluster_inds_per_point], 1) + return cluster_inds_per_point, valid_mask \ No newline at end of file diff --git a/mmdet3d/models/detectors/two_stage_fsd.py b/mmdet3d/models/detectors/two_stage_fsd.py new file mode 100644 index 0000000000..1ed69efece --- /dev/null +++ b/mmdet3d/models/detectors/two_stage_fsd.py @@ -0,0 +1,258 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +from .single_stage_fsd import SingleStageFSD +import torch +from mmdet3d.structures import bbox3d2result +from .. 
import builder + +from mmdet3d.registry import MODELS +from ...structures.det3d_data_sample import SampleList + + +@MODELS.register_module() +class FSD(SingleStageFSD): + + def __init__(self, + backbone, + segmentor, + voxel_layer=None, + voxel_encoder=None, + middle_encoder=None, + neck=None, + bbox_head=None, + roi_head=None, + train_cfg=None, + test_cfg=None, + cluster_assigner=None, + data_preprocessor=dict(type='Det3DDataPreprocessor'), + init_cfg=None): + super().__init__( + backbone=backbone, + segmentor=segmentor, + voxel_layer=voxel_layer, + voxel_encoder=voxel_encoder, + middle_encoder=middle_encoder, + neck=neck, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + cluster_assigner=cluster_assigner, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg, + ) + + # update train and test cfg here for now + rcnn_train_cfg = train_cfg.rcnn if train_cfg else None + roi_head.update(train_cfg=rcnn_train_cfg) + roi_head.update(test_cfg=test_cfg.rcnn) + roi_head.pretrained = None + self.roi_head = builder.build_head(roi_head) + self.num_classes = self.bbox_head.num_classes + self.runtime_info = dict() + + # def loss(self, + # points, + # img_metas, + # gt_bboxes_3d, + # gt_labels_3d, + # gt_bboxes_ignore=None): + + def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList, + **kwargs) -> dict: + + gt_bboxes_3d = [b[l>=0] for b, l in zip(gt_bboxes_3d, gt_labels_3d)] + gt_labels_3d = [l[l>=0] for l in gt_labels_3d] + + losses = {} + rpn_outs = super().loss( + points=points, + img_metas=img_metas, + gt_bboxes_3d=gt_bboxes_3d, + gt_labels_3d=gt_labels_3d, + gt_bboxes_ignore=gt_bboxes_ignore, + runtime_info=self.runtime_info + ) + losses.update(rpn_outs['rpn_losses']) + + proposal_list = self.bbox_head.get_bboxes( + rpn_outs['cls_logits'], rpn_outs['reg_preds'], rpn_outs['cluster_xyz'], rpn_outs['cluster_inds'], img_metas + ) + + assert len(proposal_list) == len(gt_bboxes_3d) + + pts_xyz, pts_feats, pts_batch_inds = self.prepare_multi_class_roi_input( + rpn_outs['all_input_points'], + rpn_outs['valid_pts_feats'], + rpn_outs['seg_feats'], + rpn_outs['pts_mask'], + rpn_outs['pts_batch_inds'], + rpn_outs['valid_pts_xyz'] + ) + + roi_losses = self.roi_head.loss( + pts_xyz, + pts_feats, + pts_batch_inds, + img_metas, + proposal_list, + gt_bboxes_3d, + gt_labels_3d, + ) + + losses.update(roi_losses) + return losses + + def prepare_roi_input(self, points, cluster_pts_feats, pts_seg_feats, pts_mask, pts_batch_inds, cluster_pts_xyz): + assert isinstance(pts_mask, list) + pts_mask = pts_mask[0] + assert points.shape[0] == pts_seg_feats.shape[0] == pts_mask.shape[0] == pts_batch_inds.shape[0] + + if self.training and self.train_cfg.get('detach_seg_feats', False): + pts_seg_feats = pts_seg_feats.detach() + + if self.training and self.train_cfg.get('detach_cluster_feats', False): + cluster_pts_feats = cluster_pts_feats.detach() + + pad_feats = cluster_pts_feats.new_zeros(points.shape[0], cluster_pts_feats.shape[1]) + pad_feats[pts_mask] = cluster_pts_feats + assert torch.isclose(points[pts_mask], cluster_pts_xyz).all() + + cat_feats = torch.cat([pad_feats, pts_seg_feats], dim=1) + + return points, cat_feats, pts_batch_inds + + def prepare_multi_class_roi_input(self, points, cluster_pts_feats, pts_seg_feats, pts_mask, pts_batch_inds, cluster_pts_xyz): + assert isinstance(pts_mask, list) + bg_mask = sum(pts_mask) == 0 + assert points.shape[0] == pts_seg_feats.shape[0] == bg_mask.shape[0] == pts_batch_inds.shape[0] + + if self.training and 
self.train_cfg.get('detach_seg_feats', False): + pts_seg_feats = pts_seg_feats.detach() + + if self.training and self.train_cfg.get('detach_cluster_feats', False): + cluster_pts_feats = cluster_pts_feats.detach() + + + ##### prepare points for roi head + fg_points_list = [points[m] for m in pts_mask] + all_fg_points = torch.cat(fg_points_list, dim=0) + + assert torch.isclose(all_fg_points, cluster_pts_xyz).all() + + bg_pts_xyz = points[bg_mask] + all_points = torch.cat([bg_pts_xyz, all_fg_points], dim=0) + ##### + + ##### prepare features for roi head + fg_seg_feats_list = [pts_seg_feats[m] for m in pts_mask] + all_fg_seg_feats = torch.cat(fg_seg_feats_list, dim=0) + bg_seg_feats = pts_seg_feats[bg_mask] + all_seg_feats = torch.cat([bg_seg_feats, all_fg_seg_feats], dim=0) + + num_out_points = len(all_points) + assert num_out_points == len(all_seg_feats) + + pad_feats = cluster_pts_feats.new_zeros(bg_mask.sum(), cluster_pts_feats.shape[1]) + all_cluster_pts_feats = torch.cat([pad_feats, cluster_pts_feats], dim=0) + ##### + + ##### prepare batch inds for roi head + bg_batch_inds = pts_batch_inds[bg_mask] + fg_batch_inds_list = [pts_batch_inds[m] for m in pts_mask] + fg_batch_inds = torch.cat(fg_batch_inds_list, dim=0) + all_batch_inds = torch.cat([bg_batch_inds, fg_batch_inds], dim=0) + + + # pad_feats[pts_mask] = cluster_pts_feats + + cat_feats = torch.cat([all_cluster_pts_feats, all_seg_feats], dim=1) + + # sort for roi extractor + all_batch_inds, inds = all_batch_inds.sort() + all_points = all_points[inds] + cat_feats = cat_feats[inds] + + return all_points, cat_feats, all_batch_inds + + def simple_test(self, points, img_metas, imgs=None, rescale=False, gt_bboxes_3d=None, gt_labels_3d=None): + + + rpn_outs = super().simple_test( + points=points, + img_metas=img_metas, + gt_bboxes_3d=gt_bboxes_3d, + gt_labels_3d=gt_labels_3d, + ) + + proposal_list = rpn_outs['proposal_list'] + + if self.test_cfg.get('skip_rcnn', False): + bbox_results = [ + bbox3d2result(bboxes, scores, labels) + for bboxes, scores, labels in proposal_list + ] + return bbox_results + + if self.num_classes > 1 or self.test_cfg.get('enable_multi_class_test', False): + prepare_func = self.prepare_multi_class_roi_input + else: + prepare_func = self.prepare_roi_input + + pts_xyz, pts_feats, pts_batch_inds = prepare_func( + rpn_outs['all_input_points'], + rpn_outs['valid_pts_feats'], + rpn_outs['seg_feats'], + rpn_outs['pts_mask'], + rpn_outs['pts_batch_inds'], + rpn_outs['valid_pts_xyz'] + ) + + results = self.roi_head.simple_test( + pts_xyz, + pts_feats, + pts_batch_inds, + img_metas, + proposal_list, + gt_bboxes_3d, + gt_labels_3d, + ) + + return results + + + def extract_fg_by_gt(self, point_list, gt_bboxes_3d, gt_labels_3d, extra_width): + if isinstance(gt_bboxes_3d[0], list): + assert len(gt_bboxes_3d) == 1 + assert len(gt_labels_3d) == 1 + gt_bboxes_3d = gt_bboxes_3d[0] + gt_labels_3d = gt_labels_3d[0] + + bsz = len(point_list) + + new_point_list = [] + for i in range(bsz): + points = point_list[i] + gts = gt_bboxes_3d[i].to(points.device) + if len(gts) == 0: + this_fg_mask = points.new_zeros(len(points), dtype=torch.bool) + this_fg_mask[:min(1000, len(points))] = True + else: + if isinstance(extra_width, dict): + this_labels = gt_labels_3d[i] + enlarged_gts_list = [] + for cls in range(self.num_classes): + cls_mask = this_labels == cls + if cls_mask.any(): + this_enlarged_gts = gts[cls_mask].enlarged_box(extra_width[cls]) + enlarged_gts_list.append(this_enlarged_gts) + enlarged_gts = gts.cat(enlarged_gts_list) + else: + 
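+                    # extra_width is a single value here, so boxes of all
+                    # classes are enlarged by the same margin.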
enlarged_gts = gts.enlarged_box(extra_width) + pts_inds = enlarged_gts.points_in_boxes(points[:, :3]) + this_fg_mask = pts_inds > -1 + if not this_fg_mask.any(): + this_fg_mask[:min(1000, len(points))] = True + + new_point_list.append(points[this_fg_mask]) + return new_point_list diff --git a/mmdet3d/models/layers/sst/__init__.py b/mmdet3d/models/layers/sst/__init__.py new file mode 100644 index 0000000000..851e7f9c07 --- /dev/null +++ b/mmdet3d/models/layers/sst/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .sst_ops import build_mlp, scatter_v2, get_inner_win_inds, flat2window_v2, window2flat_v2, get_flat2win_inds_v2, get_window_coors + +__all__ = [ + 'build_mlp', 'scatter_v2', 'get_inner_win_inds', 'flat2window_v2', 'window2flat_v2', 'get_flat2win_inds_v2', 'get_window_coors' +] diff --git a/mmdet3d/models/layers/sst/sst_ops.py b/mmdet3d/models/layers/sst/sst_ops.py new file mode 100644 index 0000000000..e54946469e --- /dev/null +++ b/mmdet3d/models/layers/sst/sst_ops.py @@ -0,0 +1,391 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import random +import traceback + +import numpy as np +import torch +import torch.nn as nn +import torch_scatter +from mmcv.cnn import build_norm_layer + + +def scatter_nd(indices, updates, shape): + """pytorch edition of tensorflow scatter_nd. + + this function don't contain except handle code. so use this carefully when + indice repeats, don't support repeat add which is supported in tensorflow. + """ + ret = torch.zeros(*shape, dtype=updates.dtype, device=updates.device) + ndim = indices.shape[-1] + output_shape = list(indices.shape[:-1]) + shape[indices.shape[-1]:] + flatted_indices = indices.view(-1, ndim) + slices = [flatted_indices[:, i] for i in range(ndim)] + slices += [Ellipsis] + ret[slices] = updates.view(*output_shape) + return ret + +@torch.no_grad() +def get_flat2win_inds(batch_win_inds, voxel_drop_lvl, drop_info, debug=True): + ''' + Args: + batch_win_inds: shape=[N, ]. Indicates which window a voxel belongs to. Window inds is unique is the whole batch. + voxel_drop_lvl: shape=[N, ]. Indicates batching_level of the window the voxel belongs to. + Returns: + flat2window_inds_dict: contains flat2window_inds of each voxel, shape=[N,] + Determine the voxel position in range [0, num_windows * max_tokens) of each voxel. + ''' + device = batch_win_inds.device + + flat2window_inds_dict = {} + + for dl in drop_info: # dl: short for drop level + + dl_mask = voxel_drop_lvl == dl + if not dl_mask.any(): + continue + + conti_win_inds = make_continuous_inds(batch_win_inds[dl_mask]) + + max_tokens = drop_info[dl]['max_tokens'] + + inner_win_inds = get_inner_win_inds(conti_win_inds) + + flat2window_inds = conti_win_inds * max_tokens + inner_win_inds + + flat2window_inds_dict[dl] = (flat2window_inds, torch.where(dl_mask)) + + if debug: + num_windows = len(torch.unique(conti_win_inds)) + assert inner_win_inds.max() < max_tokens, f'Max inner inds({inner_win_inds.max()}) larger(equal) than {max_tokens}' + assert (flat2window_inds >= 0).all() + max_ind = flat2window_inds.max().item() + assert max_ind < num_windows * max_tokens, f'max_ind({max_ind}) larger than upper bound({num_windows * max_tokens})' + assert max_ind >= (num_windows-1) * max_tokens, f'max_ind({max_ind}) less than lower bound({(num_windows-1) * max_tokens})' + + return flat2window_inds_dict + + +def flat2window(feat, voxel_drop_lvl, flat2win_inds_dict, drop_info, padding=0): + ''' + Args: + feat: shape=[N, C], N is the voxel num in the batch. 
+ voxel_drop_lvl: shape=[N, ]. Indicates drop_level of the window the voxel belongs to. + Returns: + feat_3d_dict: contains feat_3d of each drop level. Shape of feat_3d is [num_windows, num_max_tokens, C]. + + drop_info: + {1:{'max_tokens':50, 'range':(0, 50)}, } + ''' + dtype = feat.dtype + device = feat.device + feat_dim = feat.shape[-1] + + feat_3d_dict = {} + + for dl in drop_info: + + dl_mask = voxel_drop_lvl == dl + if not dl_mask.any(): + continue + + feat_this_dl = feat[dl_mask] + + this_inds = flat2win_inds_dict[dl][0] + + max_tokens = drop_info[dl]['max_tokens'] + num_windows = (this_inds // max_tokens).max().item() + 1 + padding = torch.tensor(padding, dtype=dtype, device=device) + feat_3d = torch.ones((num_windows * max_tokens, feat_dim), dtype=dtype, device=device) * padding + # if this_inds.max() >= num_windows * max_tokens: + # set_trace() + feat_3d[this_inds] = feat_this_dl + feat_3d = feat_3d.reshape((num_windows, max_tokens, feat_dim)) + feat_3d_dict[dl] = feat_3d + + return feat_3d_dict + +def window2flat(feat_3d_dict, inds_dict): + flat_feat_list = [] + + num_all_voxel = 0 + for dl in inds_dict: + num_all_voxel += inds_dict[dl][0].shape[0] + + dtype = feat_3d_dict[list(feat_3d_dict.keys())[0]].dtype + + device = feat_3d_dict[list(feat_3d_dict.keys())[0]].device + feat_dim = feat_3d_dict[list(feat_3d_dict.keys())[0]].shape[-1] + + all_flat_feat = torch.zeros((num_all_voxel, feat_dim), device=device, dtype=dtype) + # check_feat = -torch.ones((num_all_voxel,), device=device, dtype=torch.long) + + for dl in feat_3d_dict: + feat = feat_3d_dict[dl] + feat_dim = feat.shape[-1] + inds, flat_pos = inds_dict[dl] + feat = feat.reshape(-1, feat_dim) + flat_feat = feat[inds] + all_flat_feat[flat_pos] = flat_feat + # check_feat[flat_pos] = 0 + # flat_feat_list.append(flat_feat) + # assert (check_feat == 0).all() + + return all_flat_feat + +def get_flat2win_inds_v2(batch_win_inds, voxel_drop_lvl, drop_info, debug=True): + transform_dict = get_flat2win_inds(batch_win_inds, voxel_drop_lvl, drop_info, debug) + # add voxel_drop_lvl and batching_info into transform_dict for better wrapping + transform_dict['voxel_drop_level'] = voxel_drop_lvl + transform_dict['batching_info'] = drop_info + return transform_dict + +def window2flat_v2(feat_3d_dict, inds_dict): + inds_v1 = {k:inds_dict[k] for k in inds_dict if not isinstance(k, str)} + return window2flat(feat_3d_dict, inds_v1) + +def flat2window_v2(feat, inds_dict, padding=0): + assert 'voxel_drop_level' in inds_dict, 'voxel_drop_level should be in inds_dict in v2 function' + inds_v1 = {k:inds_dict[k] for k in inds_dict if not isinstance(k, str)} + batching_info = inds_dict['batching_info'] + return flat2window(feat, inds_dict['voxel_drop_level'], inds_v1, batching_info, padding=padding) + +def scatter_v2(feat, coors, mode, return_inv=True, min_points=0, unq_inv=None, new_coors=None): + assert feat.size(0) == coors.size(0) + if mode == 'avg': + mode = 'mean' + + + if unq_inv is None and min_points > 0: + new_coors, unq_inv, unq_cnt = torch.unique(coors, return_inverse=True, return_counts=True, dim=0) + elif unq_inv is None: + new_coors, unq_inv = torch.unique(coors, return_inverse=True, return_counts=False, dim=0) + else: + assert new_coors is not None, 'please pass new_coors for interface consistency, caller: {}'.format(traceback.extract_stack()[-2][2]) + + + if min_points > 0: + cnt_per_point = unq_cnt[unq_inv] + valid_mask = cnt_per_point >= min_points + feat = feat[valid_mask] + coors = coors[valid_mask] + new_coors, unq_inv, unq_cnt = 
torch.unique(coors, return_inverse=True, return_counts=True, dim=0) + + if mode == 'max': + new_feat, argmax = torch_scatter.scatter_max(feat, unq_inv, dim=0) + elif mode in ('mean', 'sum'): + new_feat = torch_scatter.scatter(feat, unq_inv, dim=0, reduce=mode) + else: + raise NotImplementedError + + if not return_inv: + return new_feat, new_coors + else: + return new_feat, new_coors, unq_inv + +def filter_almost_empty(pts_coors, min_points=5): + if min_points > 0: + new_coors, unq_inv, unq_cnt = torch.unique(coors, return_inverse=True, return_counts=True, dim=0) + cnt_per_point = unq_cnt[unq_inv] + valid_mask = cnt_per_point >= min_points + else: + valid_mask = torch.ones(len(pts_coors), device=pts_coors.device, dtype=torch.bool) + return valid_mask + + +@torch.no_grad() +def get_inner_win_inds_deprecated(win_inds): + ''' + Args: + win_inds indicates which windows a voxel belongs to. Voxels share a window have same inds. + shape = [N,] + Return: + inner_inds: shape=[N,]. Indicates voxel's id in a window. if M voxels share a window, their inner_inds would + be torch.arange(m, dtype=torch.long) + Note that this function might output different results from get_inner_win_inds_slow due to the unstable pytorch sort. + ''' + + sort_inds, order = win_inds.sort() #sort_inds is like [0,0,0, 1, 2,2] -> [0,1, 2, 0, 0, 1] + roll_inds_left = torch.roll(sort_inds, -1) # [0,0, 1, 2,2,0] + + diff = sort_inds - roll_inds_left #[0, 0, -1, -1, 0, 2] + end_pos_mask = diff != 0 + + bincount = torch.bincount(win_inds) + # assert bincount.max() <= max_tokens + unique_sort_inds, _ = torch.sort(torch.unique(win_inds)) + num_tokens_each_win = bincount[unique_sort_inds] #[3, 1, 2] + + template = torch.ones_like(win_inds) #[1,1,1, 1, 1,1] + template[end_pos_mask] = (num_tokens_each_win-1) * -1 #[1,1,-2, 0, 1,-1] + + inner_inds = torch.cumsum(template, 0) #[1,2,0, 0, 1,0] + inner_inds[end_pos_mask] = num_tokens_each_win #[1,2,3, 1, 1,2] + inner_inds -= 1 #[0,1,2, 0, 0,1] + + + #recover the order + inner_inds_reorder = -torch.ones_like(win_inds) + inner_inds_reorder[order] = inner_inds + + ##sanity check + assert (inner_inds >= 0).all() + assert (inner_inds == 0).sum() == len(unique_sort_inds) + assert (num_tokens_each_win > 0).all() + random_win = unique_sort_inds[random.randint(0, len(unique_sort_inds)-1)] + random_mask = win_inds == random_win + num_voxel_this_win = bincount[random_win].item() + random_inner_inds = inner_inds_reorder[random_mask] + + assert len(torch.unique(random_inner_inds)) == num_voxel_this_win + assert random_inner_inds.max() == num_voxel_this_win - 1 + assert random_inner_inds.min() == 0 + + return inner_inds_reorder + +import ingroup_indices +from torch.autograd import Function +class IngroupIndicesFunction(Function): + + @staticmethod + def forward(ctx, group_inds): + + out_inds = torch.zeros_like(group_inds) - 1 + + ingroup_indices.forward(group_inds, out_inds) + + ctx.mark_non_differentiable(out_inds) + + return out_inds + + @staticmethod + def backward(ctx, g): + + return None + +get_inner_win_inds = IngroupIndicesFunction.apply + +@torch.no_grad() +def get_window_coors(coors, sparse_shape, window_shape, do_shift): + + if len(window_shape) == 2: + win_shape_x, win_shape_y = window_shape + win_shape_z = sparse_shape[-1] + else: + win_shape_x, win_shape_y, win_shape_z = window_shape + + sparse_shape_x, sparse_shape_y, sparse_shape_z = sparse_shape + assert sparse_shape_z < sparse_shape_x, 'Usually holds... 
in case of wrong order' + + max_num_win_x = int(np.ceil((sparse_shape_x / win_shape_x)) + 1) # plus one here to meet the needs of shift. + max_num_win_y = int(np.ceil((sparse_shape_y / win_shape_y)) + 1) # plus one here to meet the needs of shift. + max_num_win_z = int(np.ceil((sparse_shape_z / win_shape_z)) + 1) # plus one here to meet the needs of shift. + max_num_win_per_sample = max_num_win_x * max_num_win_y * max_num_win_z + + if do_shift: + shift_x, shift_y, shift_z = win_shape_x // 2, win_shape_y // 2, win_shape_z // 2 + else: + shift_x, shift_y, shift_z = win_shape_x, win_shape_y, win_shape_z + + # compatibility between 2D window and 3D window + if sparse_shape_z == win_shape_z: + shift_z = 0 + + shifted_coors_x = coors[:, 3] + shift_x + shifted_coors_y = coors[:, 2] + shift_y + shifted_coors_z = coors[:, 1] + shift_z + + win_coors_x = shifted_coors_x // win_shape_x + win_coors_y = shifted_coors_y // win_shape_y + win_coors_z = shifted_coors_z // win_shape_z + + if len(window_shape) == 2: + assert (win_coors_z == 0).all() + + batch_win_inds = coors[:, 0] * max_num_win_per_sample + \ + win_coors_x * max_num_win_y * max_num_win_z + \ + win_coors_y * max_num_win_z + \ + win_coors_z + + coors_in_win_x = shifted_coors_x % win_shape_x + coors_in_win_y = shifted_coors_y % win_shape_y + coors_in_win_z = shifted_coors_z % win_shape_z + coors_in_win = torch.stack([coors_in_win_z, coors_in_win_y, coors_in_win_x], dim=-1) + # coors_in_win = torch.stack([coors_in_win_x, coors_in_win_y], dim=-1) + + return batch_win_inds, coors_in_win + +@torch.no_grad() +def make_continuous_inds(inds): + + ### make batch_win_inds continuous + dtype = inds.dtype + device = inds.device + + unique_inds, _ = torch.sort(torch.unique(inds)) + num_valid_inds = len(unique_inds) + max_origin_inds = unique_inds.max().item() + canvas = -torch.ones((max_origin_inds+1,), dtype=dtype, device=device) + canvas[unique_inds] = torch.arange(num_valid_inds, dtype=dtype, device=device) + + conti_inds = canvas[inds] + + return conti_inds + + +def build_mlp(in_channel, hidden_dims, norm_cfg, is_head=False, act='relu', bias=False, dropout=0): + layer_list = [] + last_channel = in_channel + for i, c in enumerate(hidden_dims): + act_layer = get_activation_layer(act, c) + + norm_layer = build_norm_layer(norm_cfg, c)[1] + if i == len(hidden_dims) - 1 and is_head: + layer_list.append(nn.Linear(last_channel, c, bias=True),) + else: + sq = [ + nn.Linear(last_channel, c, bias=bias), + norm_layer, + act_layer, + ] + if dropout > 0: + sq.append(nn.Dropout(dropout)) + layer_list.append( + nn.Sequential( + *sq + ) + ) + + last_channel = c + mlp = nn.Sequential(*layer_list) + return mlp + +def get_activation(activation): + """Return an activation function given a string""" + if activation == "relu": + return torch.nn.functional.relu + if activation == "gelu": + return torch.nn.functional.gelu + if activation == "glu": + return torch.nn.functional.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + +def get_activation_layer(act, dim=None): + """Return an activation function given a string""" + act = act.lower() + if act == 'relu': + act_layer = nn.ReLU(inplace=True) + elif act == 'gelu': + act_layer = nn.GELU() + elif act == 'leakyrelu': + act_layer = nn.LeakyReLU(inplace=True) + elif act == 'prelu': + act_layer = nn.PReLU(num_parameters=dim) + elif act == 'swish' or act == 'silu': + act_layer = nn.SiLU(inplace=True) + elif act == 'glu': + act_layer = nn.GLU() + elif act == 'elu': + act_layer = nn.ELU(inplace=True) + 
else: + raise NotImplementedError + return act_layer \ No newline at end of file diff --git a/mmdet3d/models/middle_encoders/__init__.py b/mmdet3d/models/middle_encoders/__init__.py index 96f5d2019d..3c84368ddc 100644 --- a/mmdet3d/models/middle_encoders/__init__.py +++ b/mmdet3d/models/middle_encoders/__init__.py @@ -3,8 +3,9 @@ from .sparse_encoder import SparseEncoder, SparseEncoderSASSD from .sparse_unet import SparseUNet from .voxel_set_abstraction import VoxelSetAbstraction +from .sst_input_layer_v2 import PseudoMiddleEncoderForSpconvFSD __all__ = [ 'PointPillarsScatter', 'SparseEncoder', 'SparseEncoderSASSD', 'SparseUNet', - 'VoxelSetAbstraction' + 'VoxelSetAbstraction', 'PseudoMiddleEncoderForSpconvFSD' ] diff --git a/mmdet3d/models/middle_encoders/sparse_unet.py b/mmdet3d/models/middle_encoders/sparse_unet.py index 2d13507719..a0828157a0 100644 --- a/mmdet3d/models/middle_encoders/sparse_unet.py +++ b/mmdet3d/models/middle_encoders/sparse_unet.py @@ -297,3 +297,92 @@ def make_decoder_layers(self, make_block, norm_cfg, in_channels): indice_key='subm1', conv_type='SubMConv3d')) in_channels = block_channels[2] + + +@MODELS.register_module() +class SimpleSparseUNet(SparseUNet): + r""" A simpler SparseUNet, removing the densify part + """ + + def __init__(self, + in_channels, + sparse_shape, + order=('conv', 'norm', 'act'), + norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), + base_channels=16, + output_channels=128, + ndim=3, + encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64, + 64)), + encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1, + 1)), + decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16), + (16, 16, 16)), + decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1)), + keep_coors_dims=None, + act_type='relu', + init_cfg=None): + super().__init__( + in_channels=in_channels, + sparse_shape=sparse_shape, + order=order, + norm_cfg=norm_cfg, + base_channels=base_channels, + output_channels=output_channels, + encoder_channels=encoder_channels, + encoder_paddings=encoder_paddings, + decoder_channels=decoder_channels, + decoder_paddings=decoder_paddings, + # ndim=ndim, + # act_type=act_type, + init_cfg=init_cfg, + ) + self.conv_out = None # override + self.ndim = ndim + self.keep_coors_dims = keep_coors_dims + + # @auto_fp16(apply_to=('voxel_features', )) + def forward(self, voxel_info): + """Forward of SparseUNet. + + Args: + voxel_features (torch.float32): Voxel features in shape [N, C]. + coors (torch.int32): Coordinates in shape [N, 4], + the columns in the order of (batch_idx, z_idx, y_idx, x_idx). + batch_size (int): Batch size. + + Returns: + dict[str, torch.Tensor]: Backbone features. 
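+                The dict is wrapped in a one-element list, i.e.
+                ``[dict(voxel_feats=...)]``, to keep the return format
+                consistent with SSTv2.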
+ """ + coors = voxel_info['voxel_coors'] + if self.ndim == 2: + assert (coors[:, 1] == 0).all() + coors = coors[:, [0, 2, 3]] # remove the z-axis indices + if self.keep_coors_dims is not None: + coors = coors[:, self.keep_coors_dims] + voxel_features = voxel_info['voxel_feats'] + coors = coors.int() + batch_size = coors[:, 0].max().item() + 1 + input_sp_tensor = SparseConvTensor(voxel_features, coors, + self.sparse_shape, + batch_size) + x = self.conv_input(input_sp_tensor) + + encode_features = [] + for encoder_layer in self.encoder_layers: + x = encoder_layer(x) + encode_features.append(x) + + x = encode_features[-1] + for i in range(self.stage_num, 0, -1): + x = self.decoder_layer_forward(encode_features[i - 1], x, + getattr(self, f'lateral_layer{i}'), + getattr(self, f'merge_layer{i}'), + getattr(self, f'upsample_layer{i}')) + # decode_features.append(x) + + seg_features = x.features + ret = {'voxel_feats':x.features} + ret = [ret,] # keep consistent with SSTv2 + + return ret \ No newline at end of file diff --git a/mmdet3d/models/middle_encoders/sst_input_layer_v2.py b/mmdet3d/models/middle_encoders/sst_input_layer_v2.py new file mode 100644 index 0000000000..eda6f4c82f --- /dev/null +++ b/mmdet3d/models/middle_encoders/sst_input_layer_v2.py @@ -0,0 +1,328 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# move the computation of position embeding and mask in middle_encoder_layer + +import torch +from mmdet3d.models.layers.sst import flat2window_v2, window2flat_v2, get_inner_win_inds, \ + get_flat2win_inds_v2, get_window_coors +from torch import nn + +from mmdet3d.registry import MODELS + + +@MODELS.register_module() +class PseudoMiddleEncoderForSpconvFSD(nn.Module): + + def __init__(self,): + super().__init__() + + # @auto_fp16(apply_to=('voxel_feat', )) + def forward(self, voxel_feats, voxel_coors, batch_size=None): + ''' + Args: + voxel_feats: shape=[N, C], N is the voxel num in the batch. + coors: shape=[N, 4], [b, z, y, x] + Returns: + feat_3d_dict: contains region features (feat_3d) of each region batching level. Shape of feat_3d is [num_windows, num_max_tokens, C]. + flat2win_inds_list: two dict containing transformation information for non-shifted grouping and shifted grouping, respectively. The two dicts are used in function flat2window and window2flat. + voxel_info: dict containing extra information of each voxel for usage in the backbone. + ''' + + voxel_info = {} + voxel_info['voxel_feats'] = voxel_feats + voxel_info['voxel_coors'] = voxel_coors + + return voxel_info + + +@MODELS.register_module() +class SSTInputLayerV2(nn.Module): + """ + This is one of the core class of SST, converting the output of voxel_encoder to sst input. + There are 3 things to be done in this class: + 1. Reginal Grouping : assign window indices to each voxel. + 2. Voxel drop and region batching: see our paper for detail + 3. Pre-computing the transfomation information for converting flat features ([N x C]) to region features ([R, T, C]). + R is the number of regions containing at most T tokens (voxels). See function flat2window and window2flat for details. + + Main args: + drop_info (dict): drop configuration for region batching. + window_shape (tuple[int]): (num_x, num_y). Each window is divided to num_x * num_y pillars (including empty pillars). + shift_list (list[tuple]): [(shift_x, shift_y), ]. shift_x = 5 means all windonws will be shifted for 5 voxels along positive direction of x-aixs. + debug: apply strong assertion for developing. 
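+
+    An illustrative ``drop_info`` (the values here are placeholders, not the
+    configuration shipped with FSD/SST) groups windows by how many voxels they
+    contain and pads each group to a fixed token number:
+
+        drop_info = {
+            0: {'max_tokens': 30, 'drop_range': (0, 30)},
+            1: {'max_tokens': 60, 'drop_range': (30, 100000)},
+        }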
+ """ + + def __init__(self, + drop_info, + window_shape, + sparse_shape, + shuffle_voxels=True, + debug=True, + normalize_pos=False, + pos_temperature=10000, + mute=False, + ): + super().__init__() + self.fp16_enabled = False + self.meta_drop_info = drop_info + self.sparse_shape = sparse_shape + self.shuffle_voxels = shuffle_voxels + self.debug = debug + self.window_shape = window_shape + self.normalize_pos = normalize_pos + self.pos_temperature = pos_temperature + self.mute = mute + + + # @auto_fp16(apply_to=('voxel_feat', )) + def forward(self, voxel_feats, voxel_coors, batch_size=None): + ''' + Args: + voxel_feats: shape=[N, C], N is the voxel num in the batch. + coors: shape=[N, 4], [b, z, y, x] + Returns: + feat_3d_dict: contains region features (feat_3d) of each region batching level. Shape of feat_3d is [num_windows, num_max_tokens, C]. + flat2win_inds_list: two dict containing transformation information for non-shifted grouping and shifted grouping, respectively. The two dicts are used in function flat2window and window2flat. + voxel_info: dict containing extra information of each voxel for usage in the backbone. + ''' + self.set_drop_info() + voxel_coors = voxel_coors.long() + + if self.shuffle_voxels: + # shuffle the voxels to make the drop process uniform. + shuffle_inds = torch.randperm(len(voxel_feats)) + voxel_feats = voxel_feats[shuffle_inds] + voxel_coors = voxel_coors[shuffle_inds] + + voxel_info = self.window_partition(voxel_coors) + voxel_info['voxel_feats'] = voxel_feats + voxel_info['voxel_coors'] = voxel_coors + voxel_info = self.drop_voxel(voxel_info, 2) # voxel_info is updated in this function + + voxel_feats = voxel_info['voxel_feats'] # after dropping + voxel_coors = voxel_info['voxel_coors'] + + for i in range(2): + + voxel_info[f'flat2win_inds_shift{i}'] = \ + get_flat2win_inds_v2(voxel_info[f'batch_win_inds_shift{i}'], voxel_info[f'voxel_drop_level_shift{i}'], self.drop_info, debug=True) + + voxel_info[f'pos_dict_shift{i}'] = \ + self.get_pos_embed(voxel_info[f'flat2win_inds_shift{i}'], voxel_info[f'coors_in_win_shift{i}'], voxel_feats.size(1), voxel_feats.dtype) + + voxel_info[f'key_mask_shift{i}'] = \ + self.get_key_padding_mask(voxel_info[f'flat2win_inds_shift{i}']) + + if self.debug: + coors_3d_dict_shift0 = flat2window_v2(voxel_coors, voxel_info['flat2win_inds_shift0']) + coors_2d = window2flat_v2(coors_3d_dict_shift0, voxel_info['flat2win_inds_shift0']) + assert (coors_2d == voxel_coors).all() + + if self.shuffle_voxels: + voxel_info['shuffle_inds'] = shuffle_inds + + return voxel_info + + def drop_single_shift(self, batch_win_inds): + drop_info = self.drop_info + drop_lvl_per_voxel = -torch.ones_like(batch_win_inds) + inner_win_inds = get_inner_win_inds(batch_win_inds) + bincount = torch.bincount(batch_win_inds) + num_per_voxel_before_drop = bincount[batch_win_inds] # + target_num_per_voxel = torch.zeros_like(batch_win_inds) + + for dl in drop_info: + max_tokens = drop_info[dl]['max_tokens'] + lower, upper = drop_info[dl]['drop_range'] + range_mask = (num_per_voxel_before_drop >= lower) & (num_per_voxel_before_drop < upper) + target_num_per_voxel[range_mask] = max_tokens + drop_lvl_per_voxel[range_mask] = dl + + if self.debug: + assert (target_num_per_voxel > 0).all() + assert (drop_lvl_per_voxel >= 0).all() + + keep_mask = inner_win_inds < target_num_per_voxel + return keep_mask, drop_lvl_per_voxel + + def drop_voxel(self, voxel_info, num_shifts): + ''' + To make it clear and easy to follow, we do not use loop to process two shifts. 
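+        Shift 0 is handled first; the voxels kept there are then re-batched with
+        the shift-1 window indices, and the shift-0 bookkeeping tensors are
+        filtered once more by the shift-1 keep mask so both shifts stay aligned.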
+ ''' + + batch_win_inds_s0 = voxel_info['batch_win_inds_shift0'] + num_all_voxel = batch_win_inds_s0.shape[0] + + voxel_keep_inds = torch.arange(num_all_voxel, device=batch_win_inds_s0.device, dtype=torch.long) + + keep_mask_s0, drop_lvl_s0 = self.drop_single_shift(batch_win_inds_s0) + if self.debug: + assert (drop_lvl_s0 >= 0).all() + + drop_lvl_s0 = drop_lvl_s0[keep_mask_s0] + voxel_keep_inds = voxel_keep_inds[keep_mask_s0] + batch_win_inds_s0 = batch_win_inds_s0[keep_mask_s0] + + if num_shifts == 1: + voxel_info['voxel_keep_inds'] = voxel_keep_inds + voxel_info['voxel_drop_level_shift0'] = drop_lvl_s0 + voxel_info['batch_win_inds_shift0'] = batch_win_inds_s0 + return voxel_info + + batch_win_inds_s1 = voxel_info['batch_win_inds_shift1'] + batch_win_inds_s1 = batch_win_inds_s1[keep_mask_s0] + + keep_mask_s1, drop_lvl_s1 = self.drop_single_shift(batch_win_inds_s1) + if self.debug: + assert (drop_lvl_s1 >= 0).all() + + # drop data in first shift again + drop_lvl_s0 = drop_lvl_s0[keep_mask_s1] + voxel_keep_inds = voxel_keep_inds[keep_mask_s1] + batch_win_inds_s0 = batch_win_inds_s0[keep_mask_s1] + + drop_lvl_s1 = drop_lvl_s1[keep_mask_s1] + batch_win_inds_s1 = batch_win_inds_s1[keep_mask_s1] + + voxel_info['voxel_keep_inds'] = voxel_keep_inds + voxel_info['voxel_drop_level_shift0'] = drop_lvl_s0 + voxel_info['batch_win_inds_shift0'] = batch_win_inds_s0 + voxel_info['voxel_drop_level_shift1'] = drop_lvl_s1 + voxel_info['batch_win_inds_shift1'] = batch_win_inds_s1 + voxel_keep_inds = voxel_info['voxel_keep_inds'] + + voxel_num_before_drop = len(voxel_info['voxel_coors']) + voxel_info['voxel_feats'] = voxel_info['voxel_feats'][voxel_keep_inds] + voxel_info['voxel_coors'] = voxel_info['voxel_coors'][voxel_keep_inds] + + # Some other variables need to be dropped. 
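+        # (e.g. coors_in_win_shift0/1 produced by window_partition still have the
+        # pre-drop length, so any per-voxel tensor of that length is sliced with
+        # the same keep indices below)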
+ for k, v in voxel_info.items(): + if isinstance(v, torch.Tensor) and len(v) == voxel_num_before_drop: + voxel_info[k] = v[voxel_keep_inds] + + ### sanity check + if self.debug and self.training: + for dl in self.drop_info: + max_tokens = self.drop_info[dl]['max_tokens'] + + mask_s0 = drop_lvl_s0 == dl + if not mask_s0.any(): + if not self.mute: + print(f'No voxel belongs to drop_level:{dl} in shift 0') + continue + real_max = torch.bincount(batch_win_inds_s0[mask_s0]).max() + assert real_max <= max_tokens, f'real_max({real_max}) > {max_tokens} in shift0' + + mask_s1 = drop_lvl_s1 == dl + if not mask_s1.any(): + if not self.mute: + print(f'No voxel belongs to drop_level:{dl} in shift 1') + continue + real_max = torch.bincount(batch_win_inds_s1[mask_s1]).max() + assert real_max <= max_tokens, f'real_max({real_max}) > {max_tokens} in shift1' + ### + return voxel_info + + @torch.no_grad() + def window_partition(self, coors): + voxel_info = {} + for i in range(2): + batch_win_inds, coors_in_win = get_window_coors(coors, self.sparse_shape, self.window_shape, i == 1) + voxel_info[f'batch_win_inds_shift{i}'] = batch_win_inds + voxel_info[f'coors_in_win_shift{i}'] = coors_in_win + + return voxel_info + + @torch.no_grad() + def get_pos_embed(self, inds_dict, coors_in_win, feat_dim, dtype): + ''' + Args: + coors_in_win: shape=[N, 3], order: z, y, x + ''' + + # [N,] + window_shape = self.window_shape + if len(window_shape) == 2: + ndim = 2 + win_x, win_y = window_shape + win_z = 0 + elif window_shape[-1] == 1: + ndim = 2 + win_x, win_y = window_shape[:2] + win_z = 0 + else: + win_x, win_y, win_z = window_shape + ndim = 3 + + assert coors_in_win.size(1) == 3 + z, y, x = coors_in_win[:, 0] - win_z/2, coors_in_win[:, 1] - win_y/2, coors_in_win[:, 2] - win_x/2 + assert (x >= -win_x/2 - 1e-4).all() + assert (x <= win_x/2-1 + 1e-4).all() + + if self.normalize_pos: + x = x / win_x * 2 * 3.1415 #[-pi, pi] + y = y / win_y * 2 * 3.1415 #[-pi, pi] + z = z / win_z * 2 * 3.1415 #[-pi, pi] + + pos_length = feat_dim // ndim + # [pos_length] + inv_freq = torch.arange( + pos_length, dtype=torch.float32, device=coors_in_win.device) + inv_freq = self.pos_temperature ** (2 * (inv_freq // 2) / pos_length) + + # [num_tokens, pos_length] + embed_x = x[:, None] / inv_freq[None, :] + embed_y = y[:, None] / inv_freq[None, :] + if ndim == 3: + embed_z = z[:, None] / inv_freq[None, :] + + # [num_tokens, pos_length] + embed_x = torch.stack([embed_x[:, ::2].sin(), embed_x[:, 1::2].cos()], dim=-1).flatten(1) + embed_y = torch.stack([embed_y[:, ::2].sin(), embed_y[:, 1::2].cos()], dim=-1).flatten(1) + if ndim == 3: + embed_z = torch.stack([embed_z[:, ::2].sin(), embed_z[:, 1::2].cos()], dim=-1).flatten(1) + + # [num_tokens, c] + if ndim == 3: + pos_embed_2d = torch.cat([embed_x, embed_y, embed_z], dim=-1).to(dtype) + else: + pos_embed_2d = torch.cat([embed_x, embed_y], dim=-1).to(dtype) + + gap = feat_dim - pos_embed_2d.size(1) + assert gap >= 0 + if gap > 0: + assert ndim == 3 + padding = torch.zeros((pos_embed_2d.size(0), gap), dtype=dtype, device=coors_in_win.device) + pos_embed_2d = torch.cat([pos_embed_2d, padding], dim=1) + else: + assert ndim == 2 + + pos_embed_dict = flat2window_v2( + pos_embed_2d, inds_dict) + + return pos_embed_dict + + @torch.no_grad() + def get_key_padding_mask(self, ind_dict): + num_all_voxel = len(ind_dict['voxel_drop_level']) + key_padding = torch.ones((num_all_voxel, 1)).to(ind_dict['voxel_drop_level'].device).bool() + + window_key_padding_dict = flat2window_v2(key_padding, ind_dict) + + # logical 
not. True means masked + for key, value in window_key_padding_dict.items(): + window_key_padding_dict[key] = value.logical_not().squeeze(2) + + return window_key_padding_dict + + def set_drop_info(self): + if hasattr(self, 'drop_info'): + return + meta = self.meta_drop_info + if isinstance(meta, tuple): + if self.training: + self.drop_info = meta[0] + else: + self.drop_info = meta[1] + else: + self.drop_info = meta + print(f'drop_info is set to {self.drop_info}, in input_layer') \ No newline at end of file diff --git a/mmdet3d/models/necks/__init__.py b/mmdet3d/models/necks/__init__.py index 0fb3a42360..cfb372b8b0 100644 --- a/mmdet3d/models/necks/__init__.py +++ b/mmdet3d/models/necks/__init__.py @@ -5,7 +5,8 @@ from .imvoxel_neck import OutdoorImVoxelNeck from .pointnet2_fp_neck import PointNetFPNeck from .second_fpn import SECONDFPN +from .voxel2point_neck import Voxel2PointScatterNeck __all__ = [ - 'FPN', 'SECONDFPN', 'OutdoorImVoxelNeck', 'PointNetFPNeck', 'DLANeck' + 'FPN', 'SECONDFPN', 'OutdoorImVoxelNeck', 'PointNetFPNeck', 'DLANeck', 'Voxel2PointScatterNeck' ] diff --git a/mmdet3d/models/necks/voxel2point_neck.py b/mmdet3d/models/necks/voxel2point_neck.py new file mode 100644 index 0000000000..8e54c3b569 --- /dev/null +++ b/mmdet3d/models/necks/voxel2point_neck.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import nn + +from mmdet3d.registry import MODELS + + +@MODELS.register_module() +class Voxel2PointScatterNeck(nn.Module): + """ + A memory-efficient voxel2point with torch_scatter + """ + + def __init__( + self, + point_cloud_range=None, + voxel_size=None, + with_xyz=True, + normalize_local_xyz=False, + ): + super().__init__() + self.point_cloud_range = point_cloud_range + self.voxel_size = voxel_size + self.with_xyz = with_xyz + self.normalize_local_xyz = normalize_local_xyz + + def forward(self, points, pts_coors, voxel_feats, voxel2point_inds, voxel_padding=-1): + """Forward function. + + Args: + points (torch.Tensor): of shape (N, C_point). + pts_coors (torch.Tensor): of shape (N, 4). + voxel_feats (torch.Tensor): of shape (M, C_feature), should be padded and reordered. + voxel2point_inds: (N,) + + Returns: + torch.Tensor: of shape (N, C_feature+C_point). + """ + assert points.size(0) == pts_coors.size(0) == voxel2point_inds.size(-1) + dtype = voxel_feats.dtype + device = voxel_feats.device + pts_feats = voxel_feats[voxel2point_inds] # voxel_feats must be the output of torch_scatter, voxel2point_inds is the input of torch_scatter + pts_mask = ~((pts_feats == voxel_padding).all(1)) # some dropped voxels are padded + if self.with_xyz: + pts_feats = pts_feats[pts_mask] + pts_coors = pts_coors[pts_mask] + points = points[pts_mask] + + voxel_size = torch.tensor(self.voxel_size, dtype=dtype, device=device).reshape(1,3) + pc_min_range = torch.tensor(self.point_cloud_range[:3], dtype=dtype, device=device).reshape(1,3) + voxel_center_each_pts = (pts_coors[:, [3,2,1]].to(dtype).to(device) + 0.5) * voxel_size + pc_min_range# x y z order + local_xyz = points[:, :3] - voxel_center_each_pts + if self.normalize_local_xyz: + local_xyz = local_xyz / (voxel_size / 2) + + if self.training and not self.normalize_local_xyz: + assert (local_xyz.abs() < voxel_size / 2 + 1e-3).all(), 'Holds in training. 
However, in test, this is not always True because of lack of point range clip' + results = torch.cat([pts_feats, local_xyz], 1) + else: + results = pts_feats[pts_mask] + + return results, pts_mask \ No newline at end of file diff --git a/mmdet3d/models/roi_heads/__init__.py b/mmdet3d/models/roi_heads/__init__.py index 0e90b1a755..dfbff8d7b4 100644 --- a/mmdet3d/models/roi_heads/__init__.py +++ b/mmdet3d/models/roi_heads/__init__.py @@ -7,9 +7,11 @@ from .point_rcnn_roi_head import PointRCNNRoIHead from .pv_rcnn_roi_head import PVRCNNRoiHead from .roi_extractors import Single3DRoIAwareExtractor, SingleRoIExtractor +from .fsd_roi_head import GroupCorrectionHead __all__ = [ 'Base3DRoIHead', 'PartAggregationROIHead', 'PointwiseSemanticHead', 'Single3DRoIAwareExtractor', 'PartA2BboxHead', 'SingleRoIExtractor', - 'H3DRoIHead', 'PrimitiveHead', 'PointRCNNRoIHead', 'PVRCNNRoiHead' + 'H3DRoIHead', 'PrimitiveHead', 'PointRCNNRoIHead', 'PVRCNNRoiHead', + 'GroupCorrectionHead' ] diff --git a/mmdet3d/models/roi_heads/bbox_heads/__init__.py b/mmdet3d/models/roi_heads/bbox_heads/__init__.py index 994465ed8d..f0065468ec 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/__init__.py +++ b/mmdet3d/models/roi_heads/bbox_heads/__init__.py @@ -8,9 +8,10 @@ from .parta2_bbox_head import PartA2BboxHead from .point_rcnn_bbox_head import PointRCNNBboxHead from .pv_rcnn_bbox_head import PVRCNNBBoxHead +from .fsd_bbox_head import FullySparseBboxHead __all__ = [ 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead', 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'PartA2BboxHead', - 'H3DBboxHead', 'PointRCNNBboxHead', 'PVRCNNBBoxHead' + 'H3DBboxHead', 'PointRCNNBboxHead', 'PVRCNNBBoxHead', 'FullySparseBboxHead' ] diff --git a/mmdet3d/models/roi_heads/bbox_heads/fsd_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/fsd_bbox_head.py new file mode 100644 index 0000000000..8ff65a7d98 --- /dev/null +++ b/mmdet3d/models/roi_heads/bbox_heads/fsd_bbox_head.py @@ -0,0 +1,791 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
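+# RoI bbox head used by the two-stage FSD: the raw points falling inside each
+# proposal are aggregated by a stack of SIR layers and decoded into refined
+# boxes and confidence scores.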
+import numpy as np +import torch +from mmengine.model import BaseModule +from torch import nn as nn +import torch.nn.functional as F + +from mmdet3d.structures import LiDARInstance3DBoxes, rotation_3d_in_axis +from mmdet3d.structures import xywhr2xyxyr +from mmdet3d.models.builder import build_loss +from mmdet3d.models.layers.sst import build_mlp + +from mmdet3d.structures.ops.iou3d_calculator import nms_gpu, nms_normal_gpu +from mmdet3d.models.task_modules.builder import build_bbox_coder +from mmdet.models.utils import multi_apply +from mmdet.utils import reduce_mean + +from mmdet3d.models import builder +from mmdet3d.registry import MODELS + + +@MODELS.register_module() +class FullySparseBboxHead(BaseModule): + + def __init__(self, + num_classes, + num_blocks, + in_channels, + feat_channels, + rel_mlp_hidden_dims, + rel_mlp_in_channels, + reg_mlp, + cls_mlp, + with_rel_mlp=True, + with_cluster_center=False, + with_distance=False, + mode='max', + xyz_normalizer=[20, 20, 4], + act='gelu', + geo_input=True, + with_corner_loss=False, + bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), + norm_cfg=dict(type='LN', eps=1e-3, momentum=0.01), + corner_loss_weight=1.0, + loss_bbox=dict( + type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='none', + loss_weight=1.0), + dropout=0, + cls_dropout=0, + reg_dropout=0, + unique_once=False, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.num_classes = num_classes + self.with_corner_loss = with_corner_loss + self.bbox_coder = build_bbox_coder(bbox_coder) + self.box_code_size = self.bbox_coder.code_size + self.loss_bbox = build_loss(loss_bbox) + self.loss_cls = build_loss(loss_cls) + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + self.geo_input = geo_input + self.corner_loss_weight = corner_loss_weight + + self.num_blocks = num_blocks + self.print_info = {} + self.unique_once = unique_once + + block_list = [] + for i in range(num_blocks): + return_point_feats = i != num_blocks-1 + kwargs = dict( + type='SIRLayer', + in_channels=in_channels[i], + feat_channels=feat_channels[i], + with_distance=with_distance, + with_cluster_center=with_cluster_center, + with_rel_mlp=with_rel_mlp, + rel_mlp_hidden_dims=rel_mlp_hidden_dims[i], + rel_mlp_in_channel=rel_mlp_in_channels[i], + with_voxel_center=False, + voxel_size=[0.1, 0.1, 0.1], # not used, placeholder + point_cloud_range=[-74.88, -74.88, -2, 74.88, 74.88, 4], # not used, placeholder + norm_cfg=norm_cfg, + mode=mode, + fusion_layer=None, + return_point_feats=return_point_feats, + return_inv=False, + rel_dist_scaler=10.0, + xyz_normalizer=xyz_normalizer, + act=act, + dropout=dropout, + ) + encoder = builder.build_voxel_encoder(kwargs) + block_list.append(encoder) + self.block_list = nn.ModuleList(block_list) + + end_channel = 0 + for c in feat_channels: + end_channel += sum(c) + + if cls_mlp is not None: + self.conv_cls = build_mlp(end_channel, cls_mlp + [1,], norm_cfg, True, act=act, dropout=cls_dropout) + else: + self.conv_cls = nn.Linear(end_channel, 1) + + if reg_mlp is not None: + self.conv_reg = build_mlp(end_channel, reg_mlp + [self.box_code_size,], norm_cfg, True, act=act, dropout=reg_dropout) + else: + self.conv_reg = nn.Linear(end_channel, self.box_code_size) + + + def init_weights(self): + super().init_weights() + + # @force_fp32(apply_to=('pts_features', 'rois')) + def forward(self, pts_xyz, pts_features, pts_info, roi_inds, rois): + """Forward pass. 
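+
+        Point features inside each RoI are processed by a stack of SIR layers;
+        the per-RoI pooled outputs of all layers are concatenated and fed to the
+        classification and regression branches.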
+ + Args: + seg_feats (torch.Tensor): Point-wise semantic features. + part_feats (torch.Tensor): Point-wise part prediction features. + + Returns: + tuple[torch.Tensor]: Score of class and bbox predictions. + """ + assert pts_features.size(0) > 0 + + rois_batch_idx = rois[:, 0] + rois = rois[:, 1:] + roi_centers = rois[:, :3] + rel_xyz = pts_xyz[:, :3] - roi_centers[roi_inds] + + if self.unique_once: + new_coors, unq_inv = torch.unique(roi_inds, return_inverse=True, return_counts=False, dim=0) + else: + new_coors = unq_inv = None + + + out_feats = pts_features + f_cluster = torch.cat([pts_info['local_xyz'], pts_info['boundary_offset'], pts_info['is_in_margin'][:, None], rel_xyz], dim=-1) + + cluster_feat_list = [] + for i, block in enumerate(self.block_list): + + in_feats = torch.cat([pts_xyz, out_feats], 1) + + if self.geo_input: + in_feats = torch.cat([in_feats, f_cluster / 10], 1) + + if i < self.num_blocks - 1: + # return point features + out_feats, out_cluster_feats = block(in_feats, roi_inds, f_cluster, unq_inv_once=unq_inv, new_coors_once=new_coors) + cluster_feat_list.append(out_cluster_feats) + if i == self.num_blocks - 1: + # return group features + out_cluster_feats, out_coors = block(in_feats, roi_inds, f_cluster, unq_inv_once=unq_inv, new_coors_once=new_coors) + cluster_feat_list.append(out_cluster_feats) + + final_cluster_feats = torch.cat(cluster_feat_list, dim=1) + + if self.training and (out_coors == -1).any(): + assert out_coors[0].item() == -1, 'This should hold due to sorted=True in torch.unique' + + nonempty_roi_mask = self.get_nonempty_roi_mask(out_coors, len(rois)) + + cls_score = self.conv_cls(final_cluster_feats) + bbox_pred = self.conv_reg(final_cluster_feats) + + cls_score = self.align_roi_feature_and_rois(cls_score, out_coors, len(rois)) + bbox_pred = self.align_roi_feature_and_rois(bbox_pred, out_coors, len(rois)) + + return cls_score, bbox_pred, nonempty_roi_mask + + def get_nonempty_roi_mask(self, out_coors, num_rois): + if self.training: + assert out_coors.max() + 1 <= num_rois + assert out_coors.ndim == 1 + assert torch.unique(out_coors).size(0) == out_coors.size(0) + assert (out_coors == torch.sort(out_coors)[0]).all() + out_coors = out_coors[out_coors >= 0] + nonempty_roi_mask = torch.zeros(num_rois, dtype=torch.bool, device=out_coors.device) + nonempty_roi_mask[out_coors] = True + return nonempty_roi_mask + + def align_roi_feature_and_rois(self, features, out_coors, num_rois): + """ + 1. The order of roi features obtained by dynamic pooling may not align with rois + 2. Even if we set sorted=True in torch.unique, the empty group (with idx -1) will be the first feature, causing misaligned + So here we explicitly align them to make sure the sanity + """ + new_feature = features.new_zeros((num_rois, features.size(1))) + coors_mask = out_coors >= 0 + + if not coors_mask.any(): + new_feature[:len(features), :] = features * 0 # pseudo gradient, avoid unused_parameters + return new_feature + + nonempty_coors = out_coors[coors_mask] + nonempty_feats = features[coors_mask] + + new_feature[nonempty_coors] = nonempty_feats + + return new_feature + + + def loss(self, cls_score, bbox_pred, nonempty_roi_mask, rois, labels, bbox_targets, pos_batch_idx, + pos_gt_bboxes, pos_gt_labels, reg_mask, label_weights, bbox_weights): + """Coumputing losses. + + Args: + cls_score (torch.Tensor): Scores of each roi. + bbox_pred (torch.Tensor): Predictions of bboxes. + rois (torch.Tensor): Roi bboxes. + labels (torch.Tensor): Labels of class. 
+ bbox_targets (torch.Tensor): Target of positive bboxes. + pos_gt_bboxes (torch.Tensor): Ground truths of positive bboxes. + reg_mask (torch.Tensor): Mask for positive bboxes. + label_weights (torch.Tensor): Weights of class loss. + bbox_weights (torch.Tensor): Weights of bbox loss. + + Returns: + dict: Computed losses. + + - loss_cls (torch.Tensor): Loss of classes. + - loss_bbox (torch.Tensor): Loss of bboxes. + - loss_corner (torch.Tensor): Loss of corners. + """ + losses = dict() + num_total_samples = rcnn_batch_size = cls_score.shape[0] + assert num_total_samples > 0 + + # calculate class loss + cls_flat = cls_score.view(-1) # only to classify foreground and background + + label_weights[~nonempty_roi_mask] = 0 # do not calculate cls loss for empty rois + label_weights[nonempty_roi_mask] = 1 # we use avg_factor in loss_cls, so we need to set it to 1 + bbox_weights[...] = 1 # we use avg_factor in loss_bbox, so we need to set it to 1 + + reg_mask[~nonempty_roi_mask] = 0 # do not calculate loss for empty rois + + cls_avg_factor = num_total_samples * 1.0 + if self.train_cfg.get('sync_cls_avg_factor', False): + cls_avg_factor = reduce_mean( + bbox_weights.new_tensor([cls_avg_factor])) + + loss_cls = self.loss_cls(cls_flat, labels, label_weights, avg_factor=cls_avg_factor) + losses['loss_rcnn_cls'] = loss_cls + + # calculate regression loss + pos_inds = (reg_mask > 0) + losses['num_pos_rois'] = pos_inds.sum().float() + losses['num_neg_rois'] = (reg_mask <= 0).sum().float() + + reg_avg_factor = pos_inds.sum().item() + if self.train_cfg.get('sync_reg_avg_factor', False): + reg_avg_factor = reduce_mean( + bbox_weights.new_tensor([reg_avg_factor])) + + if pos_inds.any() == 0: + # fake a bbox loss + losses['loss_rcnn_bbox'] = bbox_pred.sum() * 0 + if self.with_corner_loss: + losses['loss_rcnn_corner'] = bbox_pred.sum() * 0 + else: + pos_bbox_pred = bbox_pred[pos_inds] + # bbox_targets should have same size with pos_bbox_pred in normal case. But reg_mask is modified by nonempty_roi_mask. So it could be different. 
+ # filter bbox_targets per sample + + bbox_targets = self.filter_pos_assigned_but_empty_rois(bbox_targets, pos_batch_idx, pos_inds, rois[:, 0].int()) + + assert not (pos_bbox_pred == -1).all(1).any() + bbox_weights_flat = bbox_weights[pos_inds].view(-1, 1).repeat(1, pos_bbox_pred.shape[-1]) + + + if pos_bbox_pred.size(0) != bbox_targets.size(0): + raise ValueError('Impossible after filtering bbox_targets') + # I don't know why this happens + losses['loss_rcnn_bbox'] = bbox_pred.sum() * 0 + if self.with_corner_loss: + losses['loss_rcnn_corner'] = bbox_pred.sum() * 0 + return losses + + assert bbox_targets.numel() > 0 + loss_bbox = self.loss_bbox(pos_bbox_pred, bbox_targets, bbox_weights_flat, avg_factor=reg_avg_factor) + losses['loss_rcnn_bbox'] = loss_bbox + + if self.with_corner_loss: + code_size = self.bbox_coder.code_size + pos_roi_boxes3d = rois[..., 1:code_size + 1].view(-1, code_size)[pos_inds] + pos_roi_boxes3d = pos_roi_boxes3d.view(-1, code_size) + batch_anchors = pos_roi_boxes3d.clone().detach() + pos_rois_rotation = pos_roi_boxes3d[..., 6].view(-1) + roi_xyz = pos_roi_boxes3d[..., 0:3].view(-1, 3) + batch_anchors[..., 0:3] = 0 + # decode boxes + pred_boxes3d = self.bbox_coder.decode( + batch_anchors, + pos_bbox_pred.view(-1, code_size)).view(-1, code_size) + + pred_boxes3d[..., 0:3] = rotation_3d_in_axis( + pred_boxes3d[..., 0:3].unsqueeze(1), + (pos_rois_rotation + np.pi / 2), + axis=2).squeeze(1) + + pred_boxes3d[:, 0:3] += roi_xyz + + # calculate corner loss + assert pos_gt_bboxes.size(0) == pos_gt_labels.size(0) + pos_gt_bboxes = self.filter_pos_assigned_but_empty_rois(pos_gt_bboxes, pos_batch_idx, pos_inds, rois[:, 0].int()) + pos_gt_labels = self.filter_pos_assigned_but_empty_rois(pos_gt_labels, pos_batch_idx, pos_inds, rois[:, 0].int()) + if self.train_cfg.get('corner_loss_only_car', True): + car_type_index = self.train_cfg['class_names'].index('Car') + car_mask = pos_gt_labels == car_type_index + pos_gt_bboxes = pos_gt_bboxes[car_mask] + pred_boxes3d = pred_boxes3d[car_mask] + if len(pos_gt_bboxes) > 0: + loss_corner = self.get_corner_loss_lidar( + pred_boxes3d, pos_gt_bboxes) * self.corner_loss_weight + else: + loss_corner = bbox_pred.sum() * 0 + + losses['loss_rcnn_corner'] = loss_corner + + return losses + + def filter_pos_assigned_but_empty_rois(self, pos_data, pos_batch_idx, filtered_pos_mask, roi_batch_idx): + real_bsz = roi_batch_idx.max().item() + 1 + filter_data_list = [] + for b_idx in range(real_bsz): + roi_batch_mask = roi_batch_idx == b_idx + data_batch_mask = pos_batch_idx == b_idx + filter_data = pos_data[data_batch_mask][torch.nonzero(filtered_pos_mask[roi_batch_mask]).reshape(-1)] + filter_data_list.append(filter_data) + out = torch.cat(filter_data_list, 0) + return out + + def get_targets(self, sampling_results, rcnn_train_cfg, concat=True): + """Generate targets. + + Args: + sampling_results (list[:obj:`SamplingResult`]): + Sampled results from rois. + rcnn_train_cfg (:obj:`ConfigDict`): Training config of rcnn. + concat (bool): Whether to concatenate targets between batches. + + Returns: + tuple[torch.Tensor]: Targets of boxes and class prediction. 
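+                When ``concat=True``, ``bbox_target_batch_idx`` in the returned
+                tuple records the sample index of each concatenated box target.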
+ """ + pos_bboxes_list = [res.pos_bboxes for res in sampling_results] + pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results] + iou_list = [res.iou for res in sampling_results] + pos_label_list = [res.pos_gt_labels for res in sampling_results] + targets = multi_apply( + self._get_target_single, + pos_bboxes_list, + pos_gt_bboxes_list, + iou_list, + pos_label_list, + cfg=rcnn_train_cfg) + + (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, + bbox_weights) = targets + + pos_gt_labels = pos_label_list + bbox_target_batch_idx = [] + + if concat: + label = torch.cat(label, 0) + bbox_target_batch_idx = torch.cat([t.new_ones(len(t), dtype=torch.int) * i for i, t in enumerate(bbox_targets)]) + bbox_targets = torch.cat(bbox_targets, 0) + pos_gt_bboxes = torch.cat(pos_gt_bboxes, 0) + pos_gt_labels = torch.cat(pos_gt_labels, 0) + reg_mask = torch.cat(reg_mask, 0) + + label_weights = torch.cat(label_weights, 0) + label_weights /= torch.clamp(label_weights.sum(), min=1.0) + + bbox_weights = torch.cat(bbox_weights, 0) + bbox_weights /= torch.clamp(bbox_weights.sum(), min=1.0) + + return (label, bbox_targets, bbox_target_batch_idx, pos_gt_bboxes, pos_gt_labels, reg_mask, label_weights, + bbox_weights) + + def _get_target_single(self, pos_bboxes, pos_gt_bboxes, ious, pos_labels, cfg): + """Generate training targets for a single sample. + + Args: + pos_bboxes (torch.Tensor): Positive boxes with shape + (N, 7). + pos_gt_bboxes (torch.Tensor): Ground truth boxes with shape + (M, 7). + ious (torch.Tensor): IoU between `pos_bboxes` and `pos_gt_bboxes` + in shape (N, M). + cfg (dict): Training configs. + + Returns: + tuple[torch.Tensor]: Target for positive boxes. + (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, + bbox_weights) + """ + assert pos_gt_bboxes.size(1) in (7, 9, 10) + if pos_gt_bboxes.size(1) in (9, 10): + pos_bboxes = pos_bboxes[:, :7] + pos_gt_bboxes = pos_gt_bboxes[:, :7] + + if self.num_classes > 1 or self.train_cfg.get('enable_multi_class_test', False): + label, label_weights = self.get_multi_class_soft_label(ious, pos_labels, cfg) + else: + label, label_weights = self.get_single_class_soft_label(ious, cfg) + + # box regression target + reg_mask = pos_bboxes.new_zeros(ious.size(0)).long() + reg_mask[0:pos_gt_bboxes.size(0)] = 1 + bbox_weights = (reg_mask > 0).float() + bbox_weights = self.get_class_wise_box_weights(bbox_weights, pos_labels, cfg) + + if reg_mask.bool().any(): + pos_gt_bboxes_ct = pos_gt_bboxes.clone().detach() + roi_center = pos_bboxes[..., 0:3] + roi_ry = pos_bboxes[..., 6] % (2 * np.pi) + + # canonical transformation + pos_gt_bboxes_ct[..., 0:3] -= roi_center + pos_gt_bboxes_ct[..., 6] -= roi_ry + pos_gt_bboxes_ct[..., 0:3] = rotation_3d_in_axis( + pos_gt_bboxes_ct[..., 0:3].unsqueeze(1), + -(roi_ry + np.pi / 2), + axis=2).squeeze(1) + + # flip orientation if rois have opposite orientation + ry_label = pos_gt_bboxes_ct[..., 6] % (2 * np.pi) # 0 ~ 2pi + opposite_flag = (ry_label > np.pi * 0.5) & (ry_label < np.pi * 1.5) + ry_label[opposite_flag] = (ry_label[opposite_flag] + np.pi) % ( + 2 * np.pi) # (0 ~ pi/2, 3pi/2 ~ 2pi) + flag = ry_label > np.pi + ry_label[flag] = ry_label[flag] - np.pi * 2 # (-pi/2, pi/2) + ry_label = torch.clamp(ry_label, min=-np.pi / 2, max=np.pi / 2) + pos_gt_bboxes_ct[..., 6] = ry_label + + rois_anchor = pos_bboxes.clone().detach() + rois_anchor[:, 0:3] = 0 + rois_anchor[:, 6] = 0 + bbox_targets = self.bbox_coder.encode(rois_anchor, + pos_gt_bboxes_ct) + else: + # no fg bbox + bbox_targets = 
pos_gt_bboxes.new_empty((0, 7)) + + return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, + bbox_weights) + + def get_class_wise_box_weights(self, weights, gt_labels, cfg): + class_wise_weight = cfg.get('class_wise_box_weights', None) + if class_wise_weight is None: + return weights + + num_samples = len(weights) + num_pos = len(gt_labels) + all_gt_labels = torch.cat([gt_labels, gt_labels.new_full((num_samples - num_pos,), -1)], dim=0) + for i in range(self.num_classes): + this_cls_mask = (all_gt_labels == i) + weights[this_cls_mask] *= class_wise_weight[i] + + return weights + + def get_single_class_soft_label(self, ious, cfg): + + cls_pos_mask = ious > cfg.cls_pos_thr + cls_neg_mask = ious < cfg.cls_neg_thr + interval_mask = (cls_pos_mask == 0) & (cls_neg_mask == 0) + + # iou regression target + label = (cls_pos_mask > 0).float() + # label[interval_mask] = ious[interval_mask] * 2 - 0.5 + label[interval_mask] = (ious[interval_mask] - cfg.cls_neg_thr) / (cfg.cls_pos_thr - cfg.cls_neg_thr) + assert (label >= 0).all() + # label weights + label_weights = (label >= 0).float() + return label, label_weights + + def get_multi_class_soft_label(self, ious, pos_gt_labels, cfg): + pos_thrs = cfg.cls_pos_thr + neg_thrs = cfg.cls_neg_thr + + if isinstance(pos_thrs, float): + assert isinstance(neg_thrs, float) + pos_thrs = [pos_thrs] * self.num_classes + neg_thrs = [neg_thrs] * self.num_classes + else: + assert isinstance(pos_thrs, (list, tuple)) and isinstance(neg_thrs, (list, tuple)) + + assert (pos_gt_labels >= 0).all() + assert (pos_gt_labels < self.num_classes).all() + num_samples = ious.size(0) + num_pos = pos_gt_labels.size(0) + + # if num_pos > 0 and num_pos < len(ious): + # # all pos samples are in the left of ious array + # if not ious[num_pos-1].item() >= ious[num_pos].item(): + # try: + # assert pos_gt_labels[-1].item() in (1, 2), 'The only resonable case is iou of positive Ped or Cyc less than positive Car' + # except AssertionError as e: + # print('Something werid happened') + # print('All ious: \n', ious) + # print('All labels: \n', pos_gt_labels) + + + + all_gt_labels = torch.cat([pos_gt_labels, pos_gt_labels.new_full((num_samples - num_pos,), -1)], dim=0) + + check = pos_gt_labels.new_zeros(ious.size(0)) - 1 + all_label = ious.new_zeros(ious.size(0)) + for i in range(self.num_classes): + pos_thr_i = pos_thrs[i] + neg_thr_i = neg_thrs[i] + this_cls_mask = (all_gt_labels == i) + check[this_cls_mask] += 1 + + this_ious = ious[this_cls_mask] + pos_mask = this_ious > pos_thr_i + neg_mask = this_ious < neg_thr_i + interval_mask = (pos_mask == 0) & (neg_mask == 0) + this_label = (pos_mask > 0).float() + this_label[interval_mask] = (this_ious[interval_mask] - neg_thr_i) / (pos_thr_i - neg_thr_i) + all_label[this_cls_mask] = this_label + + + assert (all_label >= 0).all() + # label weights + label_weights = (all_label >= 0).float() + + class_wise_weight = cfg.get('class_wise_cls_weights', None) + if class_wise_weight is not None: + for i in range(self.num_classes): + this_cls_mask = (all_gt_labels == i) + label_weights[this_cls_mask] *= class_wise_weight[i] + + assert (check[:num_pos] == 0).all() + assert (check[num_pos:] == -1).all() + return all_label, label_weights + + + def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1): + """Calculate corner loss of given boxes. + + Args: + pred_bbox3d (torch.FloatTensor): Predicted boxes in shape (N, 7). + gt_bbox3d (torch.FloatTensor): Ground truth boxes in shape (N, 7). 
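+            delta (float): Transition point of the Huber-style smoothing applied
+                to the corner distance. Defaults to 1.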
+ + Returns: + torch.FloatTensor: Calculated corner loss in shape (N). + """ + assert pred_bbox3d.shape[0] == gt_bbox3d.shape[0] + + # This is a little bit hack here because we assume the box for + # Part-A2 is in LiDAR coordinates + gt_boxes_structure = LiDARInstance3DBoxes(gt_bbox3d) + pred_box_corners = LiDARInstance3DBoxes(pred_bbox3d).corners + gt_box_corners = gt_boxes_structure.corners + + # This flip only changes the heading direction of GT boxes + gt_bbox3d_flip = gt_boxes_structure.clone() + gt_bbox3d_flip.tensor[:, 6] += np.pi + gt_box_corners_flip = gt_bbox3d_flip.corners + + corner_dist = torch.min( + torch.norm(pred_box_corners - gt_box_corners, dim=2), + torch.norm(pred_box_corners - gt_box_corners_flip, + dim=2)) # (N, 8) + # huber loss + abs_error = torch.abs(corner_dist) + quadratic = torch.clamp(abs_error, max=delta) + linear = (abs_error - quadratic) + corner_loss = 0.5 * quadratic**2 + delta * linear + + return corner_loss.mean() + + def get_bboxes( + self, + rois, + cls_score, + bbox_pred, + valid_roi_mask, + class_labels, + class_pred, + img_metas, + cfg=None + ): + """Generate bboxes from bbox head predictions. + + Args: + rois (torch.Tensor): Roi bounding boxes. + cls_score (torch.Tensor): Scores of bounding boxes. + bbox_pred (torch.Tensor): Bounding boxes predictions + class_labels (torch.Tensor): Label of classes, from rpn. + class_pred (torch.Tensor): Score for nms. From rpn + img_metas (list[dict]): Point cloud and image's meta info. + cfg (:obj:`ConfigDict`): Testing config. + + Returns: + list[tuple]: Decoded bbox, scores and labels after nms. + """ + assert rois.size(0) == cls_score.size(0) == bbox_pred.size(0) + assert isinstance(class_labels, list) and isinstance(class_pred, list) and len(class_labels) == len(class_pred) == 1 + + cls_score = cls_score.sigmoid() + assert (class_pred[0] >= 0).all() + + if self.test_cfg.get('rcnn_score_nms', False): + # assert class_pred[0].shape == cls_score.shape + class_pred[0] = cls_score.squeeze(1) + + # regard empty bboxes as false positive + rois = rois[valid_roi_mask] + cls_score = cls_score[valid_roi_mask] + bbox_pred = bbox_pred[valid_roi_mask] + + + for i in range(len(class_labels)): + class_labels[i] = class_labels[i][valid_roi_mask] + class_pred[i] = class_pred[i][valid_roi_mask] + + if rois.numel() == 0: + return [( + img_metas[0]['box_type_3d'](rois[:, 1:], rois.size(1) - 1), + class_pred[0], + class_labels[0] + ),] + + + roi_batch_id = rois[..., 0] + roi_boxes = rois[..., 1:] # boxes without batch id + batch_size = int(roi_batch_id.max().item() + 1) + + # decode boxes + roi_ry = roi_boxes[..., 6].view(-1) + roi_xyz = roi_boxes[..., 0:3].view(-1, 3) + local_roi_boxes = roi_boxes.clone().detach() + local_roi_boxes[..., 0:3] = 0 + + assert local_roi_boxes.size(1) in (7, 9) # with or without velocity + if local_roi_boxes.size(1) == 9: + # fake zero predicted velocity, which means rcnn do not refine the velocity + bbox_pred = F.pad(bbox_pred, (0, 2), 'constant', 0) + + rcnn_boxes3d = self.bbox_coder.decode(local_roi_boxes, bbox_pred) + rcnn_boxes3d[..., 0:3] = rotation_3d_in_axis( + rcnn_boxes3d[..., 0:3].unsqueeze(1), (roi_ry + np.pi / 2), + axis=2).squeeze(1) + rcnn_boxes3d[:, 0:3] += roi_xyz + + # post processing + result_list = [] + if cfg.get('multi_class_nms', False) or self.num_classes > 1: + nms_func = self.multi_class_nms + else: + nms_func = self.single_class_nms + + for batch_id in range(batch_size): + cur_class_labels = class_labels[batch_id] + if batch_size == 1: + cur_cls_score = cls_score.view(-1) + 
cur_rcnn_boxes3d = rcnn_boxes3d + else: + roi_batch_mask = roi_batch_id == batch_id + cur_cls_score = cls_score[roi_batch_mask].view(-1) + cur_rcnn_boxes3d = rcnn_boxes3d[roi_batch_mask] + + cur_box_prob = class_pred[batch_id] + selected = nms_func(cur_box_prob, cur_class_labels, cur_rcnn_boxes3d, + cfg.score_thr, cfg.nms_thr, + img_metas[batch_id], + cfg.use_rotate_nms) + selected_bboxes = cur_rcnn_boxes3d[selected] + selected_label_preds = cur_class_labels[selected] + selected_scores = cur_cls_score[selected] + + result_list.append( + (img_metas[batch_id]['box_type_3d'](selected_bboxes, selected_bboxes.size(1)), + selected_scores, selected_label_preds)) + return result_list + + def multi_class_nms(self, + box_probs, + labels, # labels from rpn + box_preds, + score_thr, + nms_thr, + input_meta, + use_rotate_nms=True): + """Multi-class NMS for box head. + + Note: + This function has large overlap with the `box3d_multiclass_nms` + implemented in `mmdet3d.core.post_processing`. We are considering + merging these two functions in the future. + + Args: + box_probs (torch.Tensor): Predicted boxes probabitilies in + shape (N,). + box_preds (torch.Tensor): Predicted boxes in shape (N, 7+C). + score_thr (float): Threshold of scores. + nms_thr (float): Threshold for NMS. + input_meta (dict): Meta informations of the current sample. + use_rotate_nms (bool, optional): Whether to use rotated nms. + Defaults to True. + + Returns: + torch.Tensor: Selected indices. + """ + if use_rotate_nms: + nms_func = nms_gpu + else: + nms_func = nms_normal_gpu + + assert box_probs.ndim == 1 + + selected_list = [] + boxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d']( + box_preds, box_preds.size(1)).bev) + + score_thresh = score_thr if isinstance( + score_thr, (list, tuple)) else [score_thr for x in range(self.num_classes)] + nms_thresh = nms_thr if isinstance( + nms_thr, (list, tuple)) else [nms_thr for x in range(self.num_classes)] + + for k in range(0, self.num_classes): + class_scores_keep = (box_probs >= score_thresh[k]) & (labels == k) + + if class_scores_keep.any(): + original_idxs = class_scores_keep.nonzero( + as_tuple=False).view(-1) + cur_boxes_for_nms = boxes_for_nms[class_scores_keep] + cur_rank_scores = box_probs[class_scores_keep] + + cur_selected = nms_func(cur_boxes_for_nms, cur_rank_scores, + nms_thresh[k]) + + if cur_selected.shape[0] == 0: + continue + selected_list.append(original_idxs[cur_selected]) + + selected = torch.cat( + selected_list, dim=0) if len(selected_list) > 0 else [] + return selected + + def single_class_nms(self, + box_probs, + labels, # labels from rpn + box_preds, + score_thr, + nms_thr, + input_meta, + use_rotate_nms=True): + + if use_rotate_nms: + nms_func = nms_gpu + else: + nms_func = nms_normal_gpu + + assert box_probs.ndim == 1 + + selected_list = [] + boxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d']( + box_preds, box_preds.size(1)).bev) + + assert isinstance(score_thr, float) + score_thresh = score_thr + nms_thresh = nms_thr + class_scores_keep = box_probs >= score_thresh + + if class_scores_keep.int().sum() > 0: + original_idxs = class_scores_keep.nonzero( + as_tuple=False).view(-1) + cur_boxes_for_nms = boxes_for_nms[class_scores_keep] + cur_rank_scores = box_probs[class_scores_keep] + + if nms_thresh is not None: + cur_selected = nms_func(cur_boxes_for_nms, cur_rank_scores, nms_thresh) + else: + cur_selected = torch.arange(len(original_idxs), device=original_idxs.device, dtype=torch.long) + + if len(cur_selected) > 0: + 
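+                # map the kept indices back to indices into the full box set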
selected_list.append(original_idxs[cur_selected]) + + selected = torch.cat(selected_list, dim=0) if len(selected_list) > 0 else [] + return selected \ No newline at end of file diff --git a/mmdet3d/models/roi_heads/fsd_roi_head.py b/mmdet3d/models/roi_heads/fsd_roi_head.py new file mode 100644 index 0000000000..2f6fa9b076 --- /dev/null +++ b/mmdet3d/models/roi_heads/fsd_roi_head.py @@ -0,0 +1,309 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Tuple + +import torch +from mmdet.structures import SampleList +from mmdet.utils import InstanceList +from torch import Tensor +from torch.nn import functional as F + +from mmdet.models.task_modules import AssignResult +from mmdet3d.structures.ops import bbox3d2result, bbox3d2roi +from mmdet3d.structures import LiDARInstance3DBoxes +from mmdet3d.models.task_modules.builder import build_assigner, build_sampler +from ..builder import build_head, build_roi_extractor +from .base_3droi_head import Base3DRoIHead +from mmdet3d.registry import MODELS + + +@MODELS.register_module() +class GroupCorrectionHead(Base3DRoIHead): + """Part aggregation roi head for PartA2. + + Args: + semantic_head (ConfigDict): Config of semantic head. + num_classes (int): The number of classes. + seg_roi_extractor (ConfigDict): Config of seg_roi_extractor. + part_roi_extractor (ConfigDict): Config of part_roi_extractor. + bbox_head (ConfigDict): Config of bbox_head. + train_cfg (ConfigDict): Training config. + test_cfg (ConfigDict): Testing config. + """ + + def __init__(self, + num_classes=3, + roi_extractor=None, + bbox_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super().__init__( + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg) + self.num_classes = num_classes + + self.roi_extractor = build_roi_extractor(roi_extractor) + + self.init_assigner_sampler() + + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be setting at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is a deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + + def init_mask_head(self): + pass + + def init_bbox_head(self, bbox_roi_extractor: dict = None, + bbox_head: dict = None) -> None: + """Initialize box head and box roi extractor. + + Args: + bbox_roi_extractor (dict or ConfigDict): Config of box + roi extractor. + bbox_head (dict or ConfigDict): Config of box in box head. 
+ """ + # self.bbox_roi_extractor = MODELS.build(bbox_roi_extractor) + self.bbox_head = MODELS.build(bbox_head) + self.bbox_head.train_cfg = self.train_cfg + self.bbox_head.test_cfg = self.test_cfg + + def init_assigner_sampler(self): + """Initialize assigner and sampler.""" + self.bbox_assigner = None + self.bbox_sampler = None + if self.train_cfg: + if isinstance(self.train_cfg.assigner, dict): + self.bbox_assigner = build_assigner(self.train_cfg.assigner) + elif isinstance(self.train_cfg.assigner, list): + self.bbox_assigner = [ + build_assigner(res) for res in self.train_cfg.assigner + ] + self.bbox_sampler = build_sampler(self.train_cfg.sampler) + + def forward_train( + self, + pts_xyz, + pts_feats, + pts_batch_idx, + img_metas, + proposal_list, + gt_bboxes_3d, + gt_labels_3d + ): + + losses = dict() + + sample_results = self._assign_and_sample(proposal_list, gt_bboxes_3d, + gt_labels_3d) + + bbox_results = self._bbox_forward_train( + pts_xyz, + pts_feats, + pts_batch_idx, + sample_results + ) + + losses.update(bbox_results['loss_bbox']) + + return losses + + def simple_test( + self, + pts_xyz, + pts_feats, + pts_batch_inds, + img_metas, + proposal_list, + gt_bboxes_3d, + gt_labels_3d, + **kwargs): + + """Simple testing forward function of PartAggregationROIHead. + + Note: + This function assumes that the batch size is 1 + + Args: + feats_dict (dict): Contains features from the first stage. + voxels_dict (dict): Contains information of voxels. + img_metas (list[dict]): Meta info of each image. + proposal_list (list[dict]): Proposal information from rpn. + + Returns: + dict: Bbox results of one frame. + """ + + + assert len(proposal_list) == 1, 'only support bsz==1 to make cls_preds and labels_3d consistent with bbox_results' + rois = bbox3d2roi([res[0].tensor for res in proposal_list]) + cls_preds = [res[1] for res in proposal_list] + labels_3d = [res[2] for res in proposal_list] + + if len(rois) == 0: + # fake prediction without velocity dims + rois = torch.tensor([[0,0,0,5,1,1,1,0]], dtype=rois.dtype, device=rois.device) + cls_preds = [torch.tensor([0.0], dtype=torch.float32, device=rois.device)] + labels_3d = [torch.tensor([0], dtype=torch.int64, device=rois.device)] + + + # cls_preds = cls_preds[0] + # labels_3d = labels_3d[0] + + bbox_results = self._bbox_forward(pts_xyz, pts_feats, pts_batch_inds, rois) + + bbox_list = self.bbox_head.get_bboxes( + rois, + bbox_results['cls_score'], + bbox_results['bbox_pred'], + bbox_results['valid_roi_mask'], + labels_3d, + cls_preds, + img_metas, + cfg=self.test_cfg) + + bbox_results = [ + bbox3d2result(bboxes, scores, labels) + for bboxes, scores, labels in bbox_list + ] + return bbox_results + + def _bbox_forward_train(self, pts_xyz, pts_feats, batch_idx, sampling_results): + + rois = bbox3d2roi([res.bboxes for res in sampling_results]) + + bbox_results = self._bbox_forward(pts_xyz, pts_feats, batch_idx, rois) + + bbox_targets = self.bbox_head.get_targets(sampling_results, self.train_cfg) + + loss_bbox = self.bbox_head.loss( + bbox_results['cls_score'], + bbox_results['bbox_pred'], + bbox_results['valid_roi_mask'], + rois, + *bbox_targets + ) + + bbox_results.update(loss_bbox=loss_bbox) + return bbox_results + + def _bbox_forward(self, pts_xyz, pts_feats, batch_idx, rois): + + assert pts_xyz.size(0) == pts_feats.size(0) == batch_idx.size(0) + + ext_pts_inds, ext_pts_roi_inds, ext_pts_info = self.roi_extractor( + pts_xyz[:, :3], # intensity might be in pts_xyz + batch_idx, + rois[:, :8], + ) + + new_pts_feats = pts_feats[ext_pts_inds] + 
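+        # ext_pts_inds indexes the original point set, so the gathered features
+        # and coordinates stay aligned point-by-point; ext_pts_roi_inds maps
+        # each pooled point to its RoI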
new_pts_xyz = pts_xyz[ext_pts_inds] + + # def forward(self, pts_xyz, pts_features, pts_info, roi_inds, rois): + + cls_score, bbox_pred, valid_roi_mask = self.bbox_head( + new_pts_xyz, + new_pts_feats, + ext_pts_info, + ext_pts_roi_inds, + rois, + ) + + bbox_results = dict( + cls_score=cls_score, + bbox_pred=bbox_pred, + valid_roi_mask=valid_roi_mask, + ) + + return bbox_results + + def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d): + """Assign and sample proposals for training. + + Args: + proposal_list (list[dict]): Proposals produced by RPN. + gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth + boxes. + gt_labels_3d (list[torch.Tensor]): Ground truth labels + + Returns: + list[:obj:`SamplingResult`]: Sampled results of each training + sample. + """ + assert len(proposal_list) == len(gt_bboxes_3d) + sampling_results = [] + # bbox assign + for batch_idx in range(len(proposal_list)): + cur_boxes, cur_scores, cur_pd_labels = proposal_list[batch_idx] + # fake a box if no real proposal + no_proposal = len(cur_boxes) == 0 + if no_proposal: + # print('*******fake a box*******') + cur_boxes = LiDARInstance3DBoxes(torch.tensor([[0,0,5,1,1,1,0]], dtype=torch.float32, device=cur_boxes.device)) + cur_scores = torch.tensor([0.0], dtype=torch.float32, device=cur_boxes.device) + cur_pd_labels = torch.tensor([0], dtype=torch.int64, device=cur_boxes.device) + + cur_gt_bboxes = gt_bboxes_3d[batch_idx].to(cur_boxes.device) + cur_gt_labels = gt_labels_3d[batch_idx] + + batch_num_gts = 0 + # 0 is bg + batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0) + batch_max_overlaps = cur_boxes.tensor.new_zeros(len(cur_boxes)) + # -1 is bg + batch_gt_labels = cur_gt_labels.new_full((len(cur_boxes), ), -1) + + # each class may have its own assigner + if isinstance(self.bbox_assigner, list): + for i, assigner in enumerate(self.bbox_assigner): + gt_cls_mask = (cur_gt_labels == i) + pred_cls_mask = (cur_pd_labels == i) + cur_assign_res = assigner.assign( + cur_boxes.tensor[pred_cls_mask, :7], + cur_gt_bboxes.tensor[gt_cls_mask, :7], + gt_labels=cur_gt_labels[gt_cls_mask]) + # gather assign_results in different class into one result + batch_num_gts += cur_assign_res.num_gts + # gt inds (1-based) + gt_inds_arange_pad = gt_cls_mask.nonzero( + as_tuple=False).view(-1) + 1 + # pad 0 for indice unassigned + gt_inds_arange_pad = F.pad( + gt_inds_arange_pad, (1, 0), mode='constant', value=0) + # pad -1 for indice ignore + gt_inds_arange_pad = F.pad( + gt_inds_arange_pad, (1, 0), mode='constant', value=-1) + # convert to 0~gt_num+2 for indices + # gt_inds_arange_pad += 1 + # now 0 is bg, >1 is fg in batch_gt_indis + batch_gt_indis[pred_cls_mask] = gt_inds_arange_pad[ + cur_assign_res.gt_inds + 1] # - 1 + batch_max_overlaps[ + pred_cls_mask] = cur_assign_res.max_overlaps + batch_gt_labels[pred_cls_mask] = cur_assign_res.labels + + assign_result = AssignResult(batch_num_gts, batch_gt_indis, + batch_max_overlaps, + batch_gt_labels) + else: # for single class + assign_result = self.bbox_assigner.assign( + cur_boxes.tensor[:, :7], + cur_gt_bboxes.tensor[:, :7], + gt_labels=cur_gt_labels) + # sample boxes + sampling_result = self.bbox_sampler.sample(assign_result, + cur_boxes.tensor, + cur_gt_bboxes.tensor, + cur_gt_labels) + sampling_results.append(sampling_result) + return sampling_results + + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, batch_data_samples: SampleList): + raise NotImplementedError diff --git a/mmdet3d/models/roi_heads/roi_extractors/__init__.py 
b/mmdet3d/models/roi_heads/roi_extractors/__init__.py index f10e7179c7..7f1b565762 100644 --- a/mmdet3d/models/roi_heads/roi_extractors/__init__.py +++ b/mmdet3d/models/roi_heads/roi_extractors/__init__.py @@ -4,8 +4,10 @@ from .batch_roigridpoint_extractor import Batch3DRoIGridExtractor from .single_roiaware_extractor import Single3DRoIAwareExtractor from .single_roipoint_extractor import Single3DRoIPointExtractor +from .dynamic_point_roi_extractor import DynamicPointROIExtractor __all__ = [ 'SingleRoIExtractor', 'Single3DRoIAwareExtractor', - 'Single3DRoIPointExtractor', 'Batch3DRoIGridExtractor' + 'Single3DRoIPointExtractor', 'Batch3DRoIGridExtractor', + 'DynamicPointROIExtractor' ] diff --git a/mmdet3d/models/roi_heads/roi_extractors/dynamic_point_pool_op.py b/mmdet3d/models/roi_heads/roi_extractors/dynamic_point_pool_op.py new file mode 100644 index 0000000000..e0da4a131f --- /dev/null +++ b/mmdet3d/models/roi_heads/roi_extractors/dynamic_point_pool_op.py @@ -0,0 +1,58 @@ +import torch +from torch.autograd import Function +import dynamic_point_pool_ext + + +class DynamicPointPoolFunction(Function): + + @staticmethod + def forward(ctx, rois, pts, extra_wlh, max_inbox_point, max_all_pts=50000): + """RoIAwarePool3d function forward. + + Args: + rois (torch.Tensor): [N, 7], in LiDAR coordinate, + (x, y, z) is the bottom center of rois + pts (torch.Tensor): [npoints, 3] + pts_feature (torch.Tensor): [npoints, C] + out_size (int or tuple): n or [n1, n2, n3] + max_pts_per_voxel (int): m + mode (int): 0 (max pool) or 1 (average pool) + + Returns: + pooled_features (torch.Tensor): [N, out_x, out_y, out_z, C] + """ + + # pts_inds, roi_inds, pts_norm_xyz, pts_offset = dynamic_point_pool_ext.forward(rois, pts) + out_pts_idx = -1 * pts.new_ones(max_all_pts, dtype=torch.long) + out_roi_idx = -1 * pts.new_ones(max_all_pts, dtype=torch.long) + out_pts_feats = pts.new_zeros(max_all_pts, 13, dtype=torch.float) + + assert len(rois) > 0 + dynamic_point_pool_ext.forward(rois, pts, extra_wlh, max_inbox_point, out_pts_idx, out_roi_idx, out_pts_feats) + # Because of cuda block layout, the out_roi_idx is automatically sorted, but not strictly guaranteed. + valid_mask = out_pts_idx >= 0 + + if not valid_mask.any(): + # fake a non-empty input + out_pts_idx = out_pts_idx[0:1] + out_roi_idx = out_roi_idx[0:1] + out_pts_feats = out_pts_feats[0:1, :] + else: + out_pts_idx = out_pts_idx[valid_mask] + out_roi_idx = out_roi_idx[valid_mask] + out_pts_feats = out_pts_feats[valid_mask] + unique_roi_idx = torch.unique(out_roi_idx) + + ctx.mark_non_differentiable(out_pts_idx) + ctx.mark_non_differentiable(out_roi_idx) + ctx.mark_non_differentiable(out_pts_feats) + + return out_pts_idx, out_roi_idx, out_pts_feats + + @staticmethod + def backward(ctx, g1, g2, g3): + + return None, None, None, None, None + + +dynamic_point_pool = DynamicPointPoolFunction.apply diff --git a/mmdet3d/models/roi_heads/roi_extractors/dynamic_point_roi_extractor.py b/mmdet3d/models/roi_heads/roi_extractors/dynamic_point_roi_extractor.py new file mode 100644 index 0000000000..818e1504be --- /dev/null +++ b/mmdet3d/models/roi_heads/roi_extractors/dynamic_point_roi_extractor.py @@ -0,0 +1,100 @@ +import torch +from mmengine.model import BaseModule + +from mmdet3d.models.roi_heads.roi_extractors.dynamic_point_pool_op import dynamic_point_pool +from mmdet3d.registry import MODELS + + +@MODELS.register_module() +class DynamicPointROIExtractor(BaseModule): + """Point-wise roi-aware Extractor. + + Extract Point-wise roi features. 
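+    Points falling inside each (optionally enlarged) RoI are gathered by the
+    ``dynamic_point_pool`` op, so every RoI keeps a variable number of raw
+    points rather than a fixed-size grid.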
+ + Args: + roi_layer (dict): The config of roi layer. + """ + + def __init__(self, + init_cfg=None, + debug=True, + extra_wlh=[0, 0, 0], + max_inbox_point=512,): + super().__init__(init_cfg=init_cfg) + self.debug = debug + self.extra_wlh = extra_wlh + self.max_inbox_point = max_inbox_point + + + def forward(self, pts_xyz, batch_inds, rois): + + # assert batch_inds is sorted + assert len(pts_xyz) > 0 + assert len(batch_inds) > 0 + assert len(rois) > 0 + + if not (batch_inds == 0).all(): + assert (batch_inds.sort()[0] == batch_inds).all() + + all_inds, all_pts_info, all_roi_inds = [], [], [] + + roi_inds_base = 0 + pts_inds_base = 0 + + for batch_idx in range(int(batch_inds.max()) + 1): + roi_batch_mask = (rois[..., 0].int() == batch_idx) + pts_batch_mask = (batch_inds.int() == batch_idx) + + num_roi_this_batch = roi_batch_mask.sum().item() + num_pts_this_batch = pts_batch_mask.sum().item() + assert num_roi_this_batch > 0 + assert num_pts_this_batch > 0 + + ext_pts_inds, roi_inds, ext_pts_info = dynamic_point_pool( + rois[..., 1:][roi_batch_mask], + pts_xyz[pts_batch_mask], + self.extra_wlh, + self.max_inbox_point, + ) + # append returns to all_inds, all_local_xyz, all_offset + if len(ext_pts_inds) == 1 and ext_pts_inds[0].item() == -1: + assert roi_inds[0].item() == -1 + all_inds.append(ext_pts_inds) # keep -1 and do not plus the base + all_pts_info.append(ext_pts_info) + all_roi_inds.append(roi_inds) # keep -1 and do not plus the base + else: + all_inds.append(ext_pts_inds + pts_inds_base) + all_pts_info.append(ext_pts_info) + all_roi_inds.append(roi_inds + roi_inds_base) + + pts_inds_base += num_pts_this_batch + roi_inds_base += num_roi_this_batch + + all_inds = torch.cat(all_inds, dim=0) + all_pts_info = torch.cat(all_pts_info, dim=0) + all_roi_inds = torch.cat(all_roi_inds, dim=0) + + all_out_xyz = all_pts_info[:, :3] + all_local_xyz = all_pts_info[:, 3:6] + all_offset = all_pts_info[:, 6:-1] + is_in_margin = all_pts_info[:, -1] + + if self.debug: + roi_per_pts = rois[..., 1:][all_roi_inds] + in_box_pts = pts_xyz[all_inds] + assert torch.isclose(in_box_pts, all_out_xyz).all() + assert torch.isclose(all_offset[:, 0] + all_offset[:, 3], roi_per_pts[:, 4]).all() + assert torch.isclose(all_offset[:, 1] + all_offset[:, 4], roi_per_pts[:, 3]).all() + assert torch.isclose(all_offset[:, 2] + all_offset[:, 5], roi_per_pts[:, 5]).all() + assert (all_local_xyz[:, 0].abs() < roi_per_pts[:, 4] + self.extra_wlh[0] + 1e-5).all() + assert (all_local_xyz[:, 1].abs() < roi_per_pts[:, 3] + self.extra_wlh[1] + 1e-5).all() + assert (all_local_xyz[:, 2].abs() < roi_per_pts[:, 5] + self.extra_wlh[2] + 1e-5).all() + + ext_pts_info = dict( + local_xyz=all_local_xyz, + boundary_offset=all_offset, + is_in_margin=is_in_margin, + ) + + return all_inds, all_roi_inds, ext_pts_info + diff --git a/mmdet3d/models/task_modules/coders/__init__.py b/mmdet3d/models/task_modules/coders/__init__.py index b22e725be7..03cac80a28 100644 --- a/mmdet3d/models/task_modules/coders/__init__.py +++ b/mmdet3d/models/task_modules/coders/__init__.py @@ -9,10 +9,11 @@ from .pgd_bbox_coder import PGDBBoxCoder from .point_xyzwhlr_bbox_coder import PointXYZWHLRBBoxCoder from .smoke_bbox_coder import SMOKECoder +from .base_point_bbox_coder import BasePointBBoxCoder __all__ = [ 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder', 'CenterPointBBoxCoder', 'AnchorFreeBBoxCoder', 'GroupFree3DBBoxCoder', 'PointXYZWHLRBBoxCoder', 'FCOS3DBBoxCoder', 'PGDBBoxCoder', 'SMOKECoder', - 'MonoFlexCoder' + 'MonoFlexCoder', 'BasePointBBoxCoder' ] diff --git 
a/mmdet3d/models/task_modules/coders/base_point_bbox_coder.py b/mmdet3d/models/task_modules/coders/base_point_bbox_coder.py new file mode 100644 index 0000000000..68762ba069 --- /dev/null +++ b/mmdet3d/models/task_modules/coders/base_point_bbox_coder.py @@ -0,0 +1,83 @@ +import torch + +from mmdet.models.task_modules import BaseBBoxCoder + +from mmdet3d.models.task_modules.builder import BBOX_CODERS + + +@BBOX_CODERS.register_module() +class BasePointBBoxCoder(BaseBBoxCoder): + """Bbox coder for CenterPoint. + Args: + pc_range (list[float]): Range of point cloud. + out_size_factor (int): Downsample factor of the model. + voxel_size (list[float]): Size of voxel. + post_center_range (list[float]): Limit of the center. + Default: None. + max_num (int): Max number to be kept. Default: 100. + score_threshold (float): Threshold to filter boxes based on score. + Default: None. + code_size (int): Code size of bboxes. Default: 9 + """ + + def __init__(self, + post_center_range=None, + score_thresh=0.1, + num_classes=3, + max_num=500, + code_size=8): + + self.post_center_range = post_center_range + self.code_size = code_size + self.EPS = 1e-6 + self.score_thresh=score_thresh + self.num_classes = num_classes + self.max_num = max_num + + def encode(self, bboxes, base_points): + """ + Get regress target given bboxes and corresponding base_points + """ + dtype = bboxes.dtype + device = bboxes.device + + assert bboxes.size(1) in (7, 9, 10), f'bboxes shape: {bboxes.shape}' + assert bboxes.size(0) == base_points.size(0) + xyz = bboxes[:,:3] + dims = bboxes[:, 3:6] + yaw = bboxes[:, 6:7] + + log_dims = (dims + self.EPS).log() + + dist2center = xyz - base_points + + delta = dist2center # / self.window_size_meter + reg_target = torch.cat([delta, log_dims, yaw.sin(), yaw.cos()], dim=1) + if bboxes.size(1) in (9, 10): # with velocity or copypaste flag + assert self.code_size == 10 + reg_target = torch.cat([reg_target, bboxes[:, [7, 8]]], dim=1) + return reg_target + + def decode(self, reg_preds, base_points, detach_yaw=False): + + assert reg_preds.size(1) in (8, 10) + assert reg_preds.size(1) == self.code_size + + if self.code_size == 10: + velo = reg_preds[:, -2:] + reg_preds = reg_preds[:, :8] # remove the velocity + + dist2center = reg_preds[:, :3] # * self.window_size_meter + xyz = dist2center + base_points + + dims = reg_preds[:, 3:6].exp() - self.EPS + + sin = reg_preds[:, 6:7] + cos = reg_preds[:, 7:8] + yaw = torch.atan2(sin, cos) + if detach_yaw: + yaw = yaw.clone().detach() + bboxes = torch.cat([xyz, dims, yaw], dim=1) + if self.code_size == 10: + bboxes = torch.cat([bboxes, velo], dim=1) + return bboxes diff --git a/mmdet3d/models/voxel_encoders/__init__.py b/mmdet3d/models/voxel_encoders/__init__.py index 2926a83422..d3efb7a1af 100644 --- a/mmdet3d/models/voxel_encoders/__init__.py +++ b/mmdet3d/models/voxel_encoders/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
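To make the 8-dim point-based box encoding of `BasePointBBoxCoder` above concrete, here is a minimal round-trip sketch; the box and point values are invented for illustration and only the arithmetic mirrors `encode`/`decode`:

```python
import torch

EPS = 1e-6  # same epsilon the coder uses to keep log() finite

def encode(bboxes, base_points):
    delta = bboxes[:, :3] - base_points          # offset from the voting point
    log_dims = (bboxes[:, 3:6] + EPS).log()      # log-encoded box sizes
    yaw = bboxes[:, 6:7]
    return torch.cat([delta, log_dims, yaw.sin(), yaw.cos()], dim=1)  # (N, 8)

def decode(reg, base_points):
    xyz = reg[:, :3] + base_points
    dims = reg[:, 3:6].exp() - EPS
    yaw = torch.atan2(reg[:, 6:7], reg[:, 7:8])  # recover heading from sin/cos
    return torch.cat([xyz, dims, yaw], dim=1)    # (N, 7)

boxes = torch.tensor([[10.0, 2.0, -1.0, 4.5, 1.9, 1.6, 0.3]])
points = torch.tensor([[9.0, 2.5, -1.2]])
assert torch.allclose(decode(encode(boxes, points), points), boxes, atol=1e-5)
```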
from .pillar_encoder import DynamicPillarFeatureNet, PillarFeatureNet -from .voxel_encoder import DynamicSimpleVFE, DynamicVFE, HardSimpleVFE, HardVFE +from .voxel_encoder import DynamicSimpleVFE, DynamicVFE, HardSimpleVFE, HardVFE, DynamicScatterVFE, SIRLayer __all__ = [ 'PillarFeatureNet', 'DynamicPillarFeatureNet', 'HardVFE', 'DynamicVFE', - 'HardSimpleVFE', 'DynamicSimpleVFE' + 'HardSimpleVFE', 'DynamicSimpleVFE', 'DynamicScatterVFE', 'SIRLayer' ] diff --git a/mmdet3d/models/voxel_encoders/utils.py b/mmdet3d/models/voxel_encoders/utils.py index 9b9e7afc59..aff32aee88 100644 --- a/mmdet3d/models/voxel_encoders/utils.py +++ b/mmdet3d/models/voxel_encoders/utils.py @@ -4,6 +4,8 @@ from torch import nn from torch.nn import functional as F +from mmdet3d.models.layers.sst.sst_ops import get_activation_layer + def get_paddings_indicator(actual_num, max_num, axis=0): """Create boolean mask by actually number of a padded tensor. @@ -102,6 +104,91 @@ def forward(self, inputs): return concatenated +class DynamicVFELayer(nn.Module): + """Replace the Voxel Feature Encoder layer in VFE layers. + + This layer has the same utility as VFELayer above + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + norm_cfg (dict): Config dict of normalization layers + """ + + def __init__(self, + in_channels, + out_channels, + norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01) + ): + super(DynamicVFELayer, self).__init__() + self.fp16_enabled = False + # self.units = int(out_channels / 2) + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + self.linear = nn.Linear(in_channels, out_channels, bias=False) + + # @auto_fp16(apply_to=('inputs'), out_fp32=True) + def forward(self, inputs): + """Forward function. + + Args: + inputs (torch.Tensor): Voxels features of shape (M, C). + M is the number of points, C is the number of channels of point features. + + Returns: + torch.Tensor: point features in shape (M, C). + """ + # [K, T, 7] tensordot [7, units] = [K, T, units] + x = self.linear(inputs) + x = self.norm(x) + pointwise = F.relu(x) + return pointwise + + +class DynamicVFELayerV2(nn.Module): + """Replace the Voxel Feature Encoder layer in VFE layers. + This layer has the same utility as VFELayer above + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + norm_cfg (dict): Config dict of normalization layers + """ + + def __init__(self, + in_channels, + out_channels, + norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), + act='relu', + dropout=0.0, + ): + super(DynamicVFELayerV2, self).__init__() + self.fp16_enabled = False + # self.units = int(out_channels / 2) + self.norm = build_norm_layer(norm_cfg, out_channels)[1] + self.linear = nn.Linear(in_channels, out_channels, bias=False) + self.act = get_activation_layer(act, out_channels) + if dropout > 0: + self.dropout = nn.Dropout(p=dropout) + else: + self.dropout = None + + # @auto_fp16(apply_to=('inputs'), out_fp32=True) + def forward(self, inputs): + """Forward function. + Args: + inputs (torch.Tensor): Voxels features of shape (M, C). + M is the number of points, C is the number of channels of point features. + Returns: + torch.Tensor: point features in shape (M, C). + """ + # [K, T, 7] tensordot [7, units] = [K, T, units] + if self.dropout is not None: + inputs = self.dropout(inputs) + x = self.linear(inputs) + x = self.norm(x) + pointwise = self.act(x) + return pointwise + + class PFNLayer(nn.Module): """Pillar Feature Net Layer. 
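A small usage sketch of the point-wise `DynamicVFELayerV2` added above; the input width, the `'relu'` activation string and the import path are assumptions for illustration:

```python
import torch
from mmdet3d.models.voxel_encoders.utils import DynamicVFELayerV2

# Linear -> BN1d -> activation (with optional dropout) over (M, C) point features.
layer = DynamicVFELayerV2(in_channels=10, out_channels=64, act='relu', dropout=0.1)
layer.eval()                      # inference mode: dropout off, running BN stats
out = layer(torch.randn(5, 10))
print(out.shape)                  # torch.Size([5, 64])
```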
diff --git a/mmdet3d/models/voxel_encoders/voxel_encoder.py b/mmdet3d/models/voxel_encoders/voxel_encoder.py index 867818329d..c2ffed3881 100644 --- a/mmdet3d/models/voxel_encoders/voxel_encoder.py +++ b/mmdet3d/models/voxel_encoders/voxel_encoder.py @@ -6,7 +6,8 @@ from mmdet3d.registry import MODELS from .. import builder -from .utils import VFELayer, get_paddings_indicator +from .utils import VFELayer, get_paddings_indicator, DynamicVFELayerV2, DynamicVFELayer +from mmdet3d.models.layers.sst import build_mlp @MODELS.register_module() @@ -134,7 +135,7 @@ def __init__(self, if with_voxel_center: in_channels += 3 if with_distance: - in_channels += 1 + in_channels += 3 # 1 self.in_channels = in_channels self._with_distance = with_distance self._with_cluster_center = with_cluster_center @@ -159,11 +160,17 @@ def __init__(self, out_filters = feat_channels[i + 1] if i > 0: in_filters *= 2 - norm_name, norm_layer = build_norm_layer(norm_cfg, out_filters) + # norm_name, norm_layer = build_norm_layer(norm_cfg, out_filters) + # vfe_layers.append( + # nn.Sequential( + # nn.Linear(in_filters, out_filters, bias=False), norm_layer, + # nn.ReLU(inplace=True))) + vfe_layers.append( - nn.Sequential( - nn.Linear(in_filters, out_filters, bias=False), norm_layer, - nn.ReLU(inplace=True))) + DynamicVFELayer( + in_filters, + out_filters, + norm_cfg)) self.vfe_layers = nn.ModuleList(vfe_layers) self.num_vfe = len(vfe_layers) self.vfe_scatter = DynamicScatter(voxel_size, point_cloud_range, @@ -487,3 +494,268 @@ def fusion_with_mask(self, features, mask, voxel_feats, coors, img_feats, out = torch.max(voxel_canvas, dim=1)[0] return out + + +@MODELS.register_module() +class DynamicScatterVFE(DynamicVFE): + """ Same with DynamicVFE but use torch_scatter to avoid construct canvas in map_voxel_center_to_point. + The canvas is very memory-consuming when use tiny voxel size (5cm * 5cm * 5cm) in large 3D space. 
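+    Instead, points are grouped by their voxel coordinates with ``torch.unique``
+    and reduced with ``scatter_v2``; per-point features are recovered by a
+    simple gather on the inverse indices.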
+ """ + + def __init__(self, + in_channels=4, + feat_channels=[], + with_distance=False, + with_cluster_center=False, + with_voxel_center=False, + voxel_size=(0.2, 0.2, 4), + point_cloud_range=(0, -40, -3, 70.4, 40, 1), + norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), + mode='max', + fusion_layer=None, + return_point_feats=False, + return_inv=True, + rel_dist_scaler=1.0, + unique_once=False, + ): + super(DynamicScatterVFE, self).__init__( + in_channels, + feat_channels, + with_distance, + with_cluster_center, + with_voxel_center, + voxel_size, + point_cloud_range, + norm_cfg, + mode, + fusion_layer, + return_point_feats, + ) + # overwrite + self.scatter = None + self.vfe_scatter = None + self.cluster_scatter = None + self.rel_dist_scaler = rel_dist_scaler + self.mode = mode + self.unique_once = unique_once + + def map_voxel_center_to_point(self, voxel_mean, voxel2point_inds): + + return voxel_mean[voxel2point_inds] + + # if out_fp16=True, the large numbers of points + # lead to overflow error in following layers + # @force_fp32(out_fp16=False) + def forward(self, + features, + coors, + points=None, + img_feats=None, + img_metas=None, + return_inv=False): + + if self.unique_once: + new_coors, unq_inv_once = torch.unique(coors, return_inverse=True, return_counts=False, dim=0) + else: + new_coors = unq_inv_once = None + + features_ls = [features] + origin_point_coors = features[:, :3] + # Find distance of x, y, and z from cluster center + if self._with_cluster_center: + voxel_mean, _, unq_inv = scatter_v2(features[:, :3], coors, mode='avg', new_coors=new_coors, + unq_inv=unq_inv_once) + points_mean = self.map_voxel_center_to_point(voxel_mean, unq_inv) + # TODO: maybe also do cluster for reflectivity + f_cluster = features[:, :3] - points_mean[:, :3] + features_ls.append(f_cluster / self.rel_dist_scaler) + + # Find distance of x, y, and z from pillar center + if self._with_voxel_center: + f_center = features.new_zeros(size=(features.size(0), 3)) + f_center[:, 0] = features[:, 0] - ( + coors[:, 3].type_as(features) * self.vx + self.x_offset) + f_center[:, 1] = features[:, 1] - ( + coors[:, 2].type_as(features) * self.vy + self.y_offset) + f_center[:, 2] = features[:, 2] - ( + coors[:, 1].type_as(features) * self.vz + self.z_offset) + features_ls.append(f_center) + + if self._with_distance: + points_dist = torch.norm(features[:, :3], 2, 1, keepdim=True) + features_ls.append(points_dist) + + # Combine together feature decorations + features = torch.cat(features_ls, dim=-1) + + for i, vfe in enumerate(self.vfe_layers): + point_feats = vfe(features) + + if (i == len(self.vfe_layers) - 1 and self.fusion_layer is not None + and img_feats is not None): + point_feats = self.fusion_layer(img_feats, points, point_feats, + img_metas) + voxel_feats, voxel_coors, unq_inv = scatter_v2(point_feats, coors, mode=self.mode, new_coors=new_coors, + unq_inv=unq_inv_once) + if i != len(self.vfe_layers) - 1: + # need to concat voxel feats if it is not the last vfe + feat_per_point = self.map_voxel_center_to_point(voxel_feats, unq_inv) + features = torch.cat([point_feats, feat_per_point], dim=1) + if self.return_point_feats: + return point_feats + + if return_inv: + return voxel_feats, voxel_coors, unq_inv + else: + return voxel_feats, voxel_coors + + +@MODELS.register_module() +class SIRLayer(DynamicVFE): + + def __init__(self, + in_channels=4, + feat_channels=[], + with_distance=False, + with_cluster_center=False, + with_rel_mlp=True, + rel_mlp_hidden_dims=[16, ], + rel_mlp_in_channel=3, + 
with_voxel_center=False, + voxel_size=(0.2, 0.2, 4), + point_cloud_range=(0, -40, -3, 70.4, 40, 1), + norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), + mode='max', + fusion_layer=None, + return_point_feats=False, + return_inv=True, + rel_dist_scaler=1.0, + with_shortcut=True, + xyz_normalizer=[1.0, 1.0, 1.0], + act='relu', + dropout=0.0, + ): + super().__init__( + in_channels, + feat_channels, + with_distance, + with_cluster_center, + with_voxel_center, + voxel_size, + point_cloud_range, + norm_cfg, + mode, + fusion_layer, + return_point_feats, + ) + # overwrite + self.scatter = None + self.vfe_scatter = None + self.cluster_scatter = None + self.rel_dist_scaler = rel_dist_scaler + self.mode = mode + self.with_shortcut = with_shortcut + self._with_rel_mlp = with_rel_mlp + self.xyz_normalizer = xyz_normalizer + if with_rel_mlp: + rel_mlp_hidden_dims.append(in_channels) # not self.in_channels + self.rel_mlp = build_mlp(rel_mlp_in_channel, rel_mlp_hidden_dims, norm_cfg, act=act) + + if act != 'relu' or dropout > 0: # do not double in_filter + feat_channels = [self.in_channels] + list(feat_channels) + vfe_layers = [] + for i in range(len(feat_channels) - 1): + in_filters = feat_channels[i] + out_filters = feat_channels[i + 1] + if i > 0: + in_filters *= 2 + + vfe_layers.append( + DynamicVFELayerV2( + in_filters, + out_filters, + norm_cfg, + act=act, + dropout=dropout, + ) + ) + self.vfe_layers = nn.ModuleList(vfe_layers) + self.num_vfe = len(vfe_layers) + + def map_voxel_center_to_point(self, voxel_mean, voxel2point_inds): + + return voxel_mean[voxel2point_inds] + + # if out_fp16=True, the large numbers of points + # lead to overflow error in following layers + # @force_fp32(out_fp16=False) + def forward(self, + features, + coors, + f_cluster=None, + points=None, + img_feats=None, + img_metas=None, + return_inv=False, + return_both=False, + unq_inv_once=None, + new_coors_once=None, + ): + + xyz_normalizer = torch.tensor(self.xyz_normalizer, device=features.device, dtype=features.dtype) + features_ls = [torch.cat([features[:, :3] / xyz_normalizer[None, :], features[:, 3:]], dim=1)] + # origin_point_coors = features[:, :3] + if self.with_shortcut: + shortcut = features[:, 3:] + if f_cluster is None: + # Find distance of x, y, and z from cluster center + voxel_mean, mean_coors, unq_inv = scatter_v2(features[:, :3], coors, mode='avg', unq_inv=unq_inv_once, + new_coors=new_coors_once) + points_mean = self.map_voxel_center_to_point( + voxel_mean, unq_inv) + # TODO: maybe also do cluster for reflectivity + f_cluster = (features[:, :3] - points_mean[:, :3]) / self.rel_dist_scaler + else: + f_cluster = f_cluster / self.rel_dist_scaler + + if self._with_cluster_center: + features_ls.append(f_cluster / 10.0) + + if self._with_rel_mlp: + features_ls[0] = features_ls[0] * self.rel_mlp(f_cluster) + + if self._with_distance: + points_dist = torch.norm(features[:, :3], 2, 1, keepdim=True) + features_ls.append(points_dist) + + # Combine together feature decorations + features = torch.cat(features_ls, dim=-1) + + voxel_feats_list = [] + for i, vfe in enumerate(self.vfe_layers): + point_feats = vfe(features) + + voxel_feats, voxel_coors, unq_inv = scatter_v2(point_feats, coors, mode=self.mode, unq_inv=unq_inv_once, + new_coors=new_coors_once) + voxel_feats_list.append(voxel_feats) + if i != len(self.vfe_layers) - 1: + # need to concat voxel feats if it is not the last vfe + feat_per_point = self.map_voxel_center_to_point(voxel_feats, unq_inv) + features = torch.cat([point_feats, feat_per_point], dim=1) + + 
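+        # concatenate the group features from every SIR layer along the channel
+        # dim so the output carries multi-level information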
voxel_feats = torch.cat(voxel_feats_list, dim=1) + + if return_both: + if self.with_shortcut and point_feats.shape == shortcut.shape: + point_feats = point_feats + shortcut + return point_feats, voxel_feats, voxel_coors + + if self.return_point_feats: + if self.with_shortcut and point_feats.shape == shortcut.shape: + point_feats = point_feats + shortcut + return point_feats, voxel_feats + + if return_inv: + return voxel_feats, voxel_coors, unq_inv + else: + return voxel_feats, voxel_coors diff --git a/mmdet3d/structures/ops/iou3d_calculator.py b/mmdet3d/structures/ops/iou3d_calculator.py index baec1cbe45..6aad188e18 100644 --- a/mmdet3d/structures/ops/iou3d_calculator.py +++ b/mmdet3d/structures/ops/iou3d_calculator.py @@ -327,3 +327,55 @@ def axis_aligned_bbox_overlaps_3d(bboxes1, enclose_area = torch.max(enclose_area, eps) gious = ious - (enclose_area - union) / enclose_area return gious + + +# from . import iou3d_cuda + +def nms_gpu(boxes, scores, thresh, pre_maxsize=None, post_max_size=None): + """Nms function with gpu implementation. + + Args: + boxes (torch.Tensor): Input boxes with the shape of [N, 5] + ([x1, y1, x2, y2, ry]). + scores (torch.Tensor): Scores of boxes with the shape of [N]. + thresh (int): Threshold. + pre_maxsize (int): Max size of boxes before nms. Default: None. + post_maxsize (int): Max size of boxes after nms. Default: None. + + Returns: + torch.Tensor: Indexes after nms. + """ + raise NotImplementedError + # order = scores.sort(0, descending=True)[1] + # + # if pre_maxsize is not None: + # order = order[:pre_maxsize] + # boxes = boxes[order].contiguous() + # + # keep = torch.zeros(boxes.size(0), dtype=torch.long) + # num_out = iou3d_cuda.nms_gpu(boxes, keep, thresh, boxes.device.index) + # keep = order[keep[:num_out].cuda(boxes.device)].contiguous() + # if post_max_size is not None: + # keep = keep[:post_max_size] + # return keep + +def nms_normal_gpu(boxes, scores, thresh): + """Normal non maximum suppression on GPU. + + Args: + boxes (torch.Tensor): Input boxes with shape (N, 5). + scores (torch.Tensor): Scores of predicted boxes with shape (N). + thresh (torch.Tensor): Threshold of non maximum suppression. + + Returns: + torch.Tensor: Remaining indices with scores in descending order. 
+ """ + raise NotImplementedError + # order = scores.sort(0, descending=True)[1] + # + # boxes = boxes[order].contiguous() + # + # keep = torch.zeros(boxes.size(0), dtype=torch.long) + # num_out = iou3d_cuda.nms_normal_gpu(boxes, keep, thresh, + # boxes.device.index) + # return order[keep[:num_out].cuda(boxes.device)].contiguous() \ No newline at end of file From 9137a42fee46583272dd8377a643224e2b64642d Mon Sep 17 00:00:00 2001 From: HinGwenWoong Date: Mon, 19 Dec 2022 15:47:20 +0800 Subject: [PATCH 3/5] Test pipeline can run --- .../dense_heads/sparse_cluster_head_v2.py | 4 +- mmdet3d/models/detectors/single_stage_fsd.py | 78 ++++---------- mmdet3d/models/detectors/two_stage_fsd.py | 38 ++++--- .../roi_heads/bbox_heads/fsd_bbox_head.py | 19 ++-- mmdet3d/models/roi_heads/fsd_roi_head.py | 6 +- .../models/voxel_encoders/voxel_encoder.py | 4 +- mmdet3d/structures/ops/iou3d_calculator.py | 102 +++++++++--------- 7 files changed, 108 insertions(+), 143 deletions(-) diff --git a/mmdet3d/models/dense_heads/sparse_cluster_head_v2.py b/mmdet3d/models/dense_heads/sparse_cluster_head_v2.py index 8cca1f1e75..e5f586aa2b 100644 --- a/mmdet3d/models/dense_heads/sparse_cluster_head_v2.py +++ b/mmdet3d/models/dense_heads/sparse_cluster_head_v2.py @@ -549,7 +549,7 @@ def _get_bboxes_single( cluster_xyz = cluster_xyz[topk_inds, :] bboxes = self.bbox_coder.decode(reg_preds, cluster_xyz) - bboxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d'](bboxes, box_dim=bboxes.size(1)).bev) + bboxes_for_nms = xywhr2xyxyr(input_meta.box_type_3d(bboxes, box_dim=bboxes.size(1)).bev) # Add a dummy background class to the front when using sigmoid padding = scores.new_zeros(scores.shape[0], 1) @@ -562,7 +562,7 @@ def _get_bboxes_single( out_bboxes, out_scores, out_labels = results - out_bboxes = input_meta['box_type_3d'](out_bboxes, out_bboxes.size(1)) + out_bboxes = input_meta.box_type_3d(out_bboxes, out_bboxes.size(1)) # modify task labels to global label indices new_labels = torch.zeros_like(out_labels) - 1 # all -1 diff --git a/mmdet3d/models/detectors/single_stage_fsd.py b/mmdet3d/models/detectors/single_stage_fsd.py index 55e80bb119..7f21429f86 100644 --- a/mmdet3d/models/detectors/single_stage_fsd.py +++ b/mmdet3d/models/detectors/single_stage_fsd.py @@ -236,7 +236,7 @@ def extract_feat(self, points, img_metas): voxel_info = self.middle_encoder(voxel_features, voxel_coors) x = self.backbone(voxel_info)[0] padding = -1 - voxel_coors_dropped = x['voxel_feats'] # bug, leave it for feature modification + # voxel_coors_dropped = x['voxel_feats'] # bug, leave it for feature modification if 'shuffle_inds' not in voxel_info: voxel_feats_reorder = x['voxel_feats'] else: @@ -332,7 +332,10 @@ def loss(self, return output_dict - def simple_test(self, points, img_metas, gt_bboxes_3d=None, gt_labels_3d=None, rescale=False): + def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> SampleList: + + points = batch_inputs['points'] + img_metas = batch_data_samples[0].metainfo if self.tanh_dims is not None: for p in points: @@ -343,14 +346,14 @@ def simple_test(self, points, img_metas, gt_bboxes_3d=None, gt_labels_3d=None, r if self.voxel_downsampling_size is not None: points = self.voxel_downsample(points) - seg_pred = [] + # seg_pred = [] x, pts_coors, points = self.extract_feat(points, img_metas) feats = x[0] valid_pts_mask = x[1] points = points[valid_pts_mask] pts_coors = pts_coors[valid_pts_mask] - seg_logits, vote_preds = self.segmentation_head.forward_test(feats, img_metas, self.test_cfg) + seg_logits, 
vote_preds = self.segmentation_head.forward(feats) offsets = self.segmentation_head.decode_vote_targets(vote_preds) @@ -365,9 +368,6 @@ def simple_test(self, points, img_metas, gt_bboxes_3d=None, gt_labels_3d=None, r return output_dict - def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> SampleList: - raise NotImplementedError - def _forward(self, batch_inputs: Tensor, batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: raise NotImplementedError @@ -596,16 +596,9 @@ def pre_voxelize(self, data_dict): voxelized_data_dict['batch_idx'] = voxel_coors[:, 0] return voxelized_data_dict - def simple_test(self, points, img_metas, imgs=None, rescale=False, gt_bboxes_3d=None, gt_labels_3d=None): - """Test function without augmentaiton.""" - if gt_bboxes_3d is not None: - gt_bboxes_3d = gt_bboxes_3d[0] - gt_labels_3d = gt_labels_3d[0] - assert isinstance(gt_bboxes_3d, list) - assert isinstance(gt_labels_3d, list) - assert len(gt_bboxes_3d) == len(gt_labels_3d) == 1, 'assuming single sample testing' - - seg_out_dict = self.segmentor.simple_test(points, img_metas, rescale=False) + # def simple_test(self, points, img_metas, imgs=None, rescale=False, gt_bboxes_3d=None, gt_labels_3d=None): + def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> SampleList: + seg_out_dict = self.segmentor.predict(batch_inputs, batch_data_samples) seg_feats = seg_out_dict['seg_feats'] @@ -617,14 +610,16 @@ def simple_test(self, points, img_metas, imgs=None, rescale=False, gt_bboxes_3d= batch_idx=seg_out_dict['batch_idx'], vote_offsets = seg_out_dict['offsets'] ) + if self.cfg.get('pre_voxelization_size', None) is not None: dict_to_sample = self.pre_voxelize(dict_to_sample) - sampled_out = self.sample(dict_to_sample, dict_to_sample['vote_offsets'], gt_bboxes_3d, gt_labels_3d) # per cls list in sampled_out + + sampled_out = self.sample(dict_to_sample, dict_to_sample['vote_offsets']) # per cls list in sampled_out # we filter almost empty voxel in clustering, so here is a valid_mask - cluster_inds_list, valid_mask_list = self.cluster_assigner(sampled_out['center_preds'], sampled_out['batch_idx'], gt_bboxes_3d, gt_labels_3d, origin_points=sampled_out['seg_points']) # per cls list + cluster_inds_list, valid_mask_list = self.cluster_assigner(sampled_out['center_preds'], sampled_out['batch_idx'], origin_points=sampled_out['seg_points']) # per cls list - pts_cluster_inds = torch.cat(cluster_inds_list, dim=0) #[N, 3], (cls_id, batch_idx, cluster_id) + pts_cluster_inds = torch.cat(cluster_inds_list, dim=0) # [N, 3], (cls_id, batch_idx, cluster_id) sampled_out = self.update_sample_results_by_mask(sampled_out, valid_mask_list) @@ -634,7 +629,7 @@ def simple_test(self, points, img_metas, imgs=None, rescale=False, gt_bboxes_3d= pts_feats = torch.cat([combined_out['seg_logits'], combined_out['seg_vote_preds'], combined_out['seg_feats']], dim=1) assert len(pts_cluster_inds) == len(points) == len(pts_feats) - extracted_outs = self.extract_feat(points, pts_feats, pts_cluster_inds, img_metas, combined_out['center_preds']) + extracted_outs = self.extract_feat(points, pts_feats, pts_cluster_inds, batch_data_samples, combined_out['center_preds']) cluster_feats = extracted_outs['cluster_feats'] cluster_xyz = extracted_outs['cluster_xyz'] cluster_inds = extracted_outs['cluster_inds'] @@ -644,8 +639,8 @@ def simple_test(self, points, img_metas, imgs=None, rescale=False, gt_bboxes_3d= bbox_list = self.bbox_head.get_bboxes( outs['cls_logits'], outs['reg_preds'], - cluster_xyz, cluster_inds, 
img_metas, - rescale=rescale, + cluster_xyz, cluster_inds, batch_data_samples, + rescale=False, iou_logits=outs.get('iou_logits', None)) if self.as_rpn: @@ -666,9 +661,6 @@ def simple_test(self, points, img_metas, imgs=None, rescale=False, gt_bboxes_3d= ] return bbox_results - def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> SampleList: - raise NotImplementedError - def _forward(self, batch_inputs: Tensor, batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]: raise NotImplementedError @@ -682,10 +674,10 @@ def sample(self, dict_to_sample, offset, gt_bboxes_3d=None, gt_labels_3d=None): if self.cfg.get('group_sample', False): return self.group_sample(dict_to_sample, offset) - cfg = self.train_cfg if self.training else self.test_cfg + # cfg = self.train_cfg if self.training else self.test_cfg seg_logits = dict_to_sample['seg_logits'] - assert (seg_logits < 0).any() # make sure no sigmoid applied + assert (seg_logits < 0).any() # make sure no sigmoid applied if seg_logits.size(1) == self.num_classes: seg_scores = seg_logits.sigmoid() @@ -700,9 +692,7 @@ def sample(self, dict_to_sample, offset, gt_bboxes_3d=None, gt_labels_3d=None): batch_idx = dict_to_sample['batch_idx'] batch_size = batch_idx.max().item() + 1 for cls in range(self.num_classes): - cls_score_thr = cfg['score_thresh'][cls] - - fg_mask = self.get_fg_mask(seg_scores, seg_points, cls, batch_idx, gt_bboxes_3d, gt_labels_3d) + fg_mask = self.get_fg_mask(seg_scores, cls) if len(torch.unique(batch_idx[fg_mask])) < batch_size: one_random_pos_per_sample = self.get_sample_beg_position(batch_idx, fg_mask) @@ -735,7 +725,7 @@ def get_sample_beg_position(self, batch_idx, fg_mask): pos = torch.where(inner_inds == 0)[0] return pos - def get_fg_mask(self, seg_scores, seg_points, cls_id, batch_inds, gt_bboxes_3d, gt_labels_3d): + def get_fg_mask(self, seg_scores, cls_id): if self.training and self.train_cfg.get('disable_pretrain', False) and not self.runtime_info.get('enable_detection', False): seg_scores = seg_scores[:, cls_id] topks = self.train_cfg.get('disable_pretrain_topks', [100, 100, 100]) @@ -752,30 +742,6 @@ def get_fg_mask(self, seg_scores, seg_points, cls_id, batch_inds, gt_bboxes_3d, buffer_thr = 0 fg_mask = seg_scores > cls_score_thr + buffer_thr - # add fg points - cfg = self.train_cfg if self.training else self.test_cfg - - if cfg.get('add_gt_fg_points', False): - bsz = len(gt_bboxes_3d) - assert len(seg_scores) == len(seg_points) == len(batch_inds) - point_list = self.split_by_batch(seg_points, batch_inds, bsz) - gt_fg_mask_list = [] - - for i, points in enumerate(point_list): - - gt_mask = gt_labels_3d[i] == cls_id - gts = gt_bboxes_3d[i][gt_mask] - - if not gt_mask.any() or len(points) == 0: - gt_fg_mask_list.append(gt_mask.new_zeros(len(points), dtype=torch.bool)) - continue - - gt_fg_mask_list.append(gts.points_in_boxes(points) > -1) - - gt_fg_mask = self.combine_by_batch(gt_fg_mask_list, batch_inds, bsz) - fg_mask = fg_mask | gt_fg_mask - - return fg_mask def split_by_batch(self, data, batch_idx, batch_size): diff --git a/mmdet3d/models/detectors/two_stage_fsd.py b/mmdet3d/models/detectors/two_stage_fsd.py index 1ed69efece..c722b31342 100644 --- a/mmdet3d/models/detectors/two_stage_fsd.py +++ b/mmdet3d/models/detectors/two_stage_fsd.py @@ -1,6 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. 
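+# Two-stage FSD: the single-stage detector provides per-cluster proposals and
+# the RoI head refines them with the raw points pooled inside each proposal.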
from typing import Union +from mmengine.structures import InstanceData +from torch import Tensor + from .single_stage_fsd import SingleStageFSD import torch from mmdet3d.structures import bbox3d2result @@ -50,7 +53,7 @@ def __init__(self, self.roi_head = builder.build_head(roi_head) self.num_classes = self.bbox_head.num_classes self.runtime_info = dict() - + # def loss(self, # points, # img_metas, @@ -102,7 +105,7 @@ def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList, losses.update(roi_losses) return losses - + def prepare_roi_input(self, points, cluster_pts_feats, pts_seg_feats, pts_mask, pts_batch_inds, cluster_pts_xyz): assert isinstance(pts_mask, list) pts_mask = pts_mask[0] @@ -113,7 +116,7 @@ def prepare_roi_input(self, points, cluster_pts_feats, pts_seg_feats, pts_mask, if self.training and self.train_cfg.get('detach_cluster_feats', False): cluster_pts_feats = cluster_pts_feats.detach() - + pad_feats = cluster_pts_feats.new_zeros(points.shape[0], cluster_pts_feats.shape[1]) pad_feats[pts_mask] = cluster_pts_feats assert torch.isclose(points[pts_mask], cluster_pts_xyz).all() @@ -174,16 +177,11 @@ def prepare_multi_class_roi_input(self, points, cluster_pts_feats, pts_seg_feats cat_feats = cat_feats[inds] return all_points, cat_feats, all_batch_inds - - def simple_test(self, points, img_metas, imgs=None, rescale=False, gt_bboxes_3d=None, gt_labels_3d=None): + # def predict(self, points, img_metas, imgs=None, rescale=False, gt_bboxes_3d=None, gt_labels_3d=None): + def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList) -> SampleList: - rpn_outs = super().simple_test( - points=points, - img_metas=img_metas, - gt_bboxes_3d=gt_bboxes_3d, - gt_labels_3d=gt_labels_3d, - ) + rpn_outs = super().predict(batch_inputs, batch_data_samples) proposal_list = rpn_outs['proposal_list'] @@ -212,14 +210,20 @@ def simple_test(self, points, img_metas, imgs=None, rescale=False, gt_bboxes_3d= pts_xyz, pts_feats, pts_batch_inds, - img_metas, - proposal_list, - gt_bboxes_3d, - gt_labels_3d, + batch_data_samples, + proposal_list ) + results_3d = [] + for res in results: + pred_instances_3d = InstanceData() + pred_instances_3d.bboxes_3d = res['bboxes_3d'] + pred_instances_3d.scores_3d = res['scores_3d'] + pred_instances_3d.labels_3d = res['labels_3d'] + results_3d.append(pred_instances_3d) + + results = self.add_pred_to_datasample(batch_data_samples, results_3d) return results - def extract_fg_by_gt(self, point_list, gt_bboxes_3d, gt_labels_3d, extra_width): if isinstance(gt_bboxes_3d[0], list): @@ -253,6 +257,6 @@ def extract_fg_by_gt(self, point_list, gt_bboxes_3d, gt_labels_3d, extra_width): this_fg_mask = pts_inds > -1 if not this_fg_mask.any(): this_fg_mask[:min(1000, len(points))] = True - + new_point_list.append(points[this_fg_mask]) return new_point_list diff --git a/mmdet3d/models/roi_heads/bbox_heads/fsd_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/fsd_bbox_head.py index 8ff65a7d98..ae4313c996 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/fsd_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/fsd_bbox_head.py @@ -10,12 +10,11 @@ from mmdet3d.models.builder import build_loss from mmdet3d.models.layers.sst import build_mlp -from mmdet3d.structures.ops.iou3d_calculator import nms_gpu, nms_normal_gpu from mmdet3d.models.task_modules.builder import build_bbox_coder from mmdet.models.utils import multi_apply from mmdet.utils import reduce_mean -from mmdet3d.models import builder +from mmdet3d.models import builder, nms_bev, nms_normal_bev from mmdet3d.registry 
import MODELS @@ -626,7 +625,7 @@ def get_bboxes( if rois.numel() == 0: return [( - img_metas[0]['box_type_3d'](rois[:, 1:], rois.size(1) - 1), + img_metas[0].box_type_3d(rois[:, 1:], rois.size(1) - 1), class_pred[0], class_labels[0] ),] @@ -680,7 +679,7 @@ def get_bboxes( selected_scores = cur_cls_score[selected] result_list.append( - (img_metas[batch_id]['box_type_3d'](selected_bboxes, selected_bboxes.size(1)), + (img_metas[batch_id].box_type_3d(selected_bboxes, selected_bboxes.size(1)), selected_scores, selected_label_preds)) return result_list @@ -713,14 +712,14 @@ def multi_class_nms(self, torch.Tensor: Selected indices. """ if use_rotate_nms: - nms_func = nms_gpu + nms_func = nms_bev else: - nms_func = nms_normal_gpu + nms_func = nms_normal_bev assert box_probs.ndim == 1 selected_list = [] - boxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d']( + boxes_for_nms = xywhr2xyxyr(input_meta.box_type_3d( box_preds, box_preds.size(1)).bev) score_thresh = score_thr if isinstance( @@ -758,14 +757,14 @@ def single_class_nms(self, use_rotate_nms=True): if use_rotate_nms: - nms_func = nms_gpu + nms_func = nms_bev else: - nms_func = nms_normal_gpu + nms_func = nms_normal_bev assert box_probs.ndim == 1 selected_list = [] - boxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d']( + boxes_for_nms = xywhr2xyxyr(input_meta.box_type_3d( box_preds, box_preds.size(1)).bev) assert isinstance(score_thr, float) diff --git a/mmdet3d/models/roi_heads/fsd_roi_head.py b/mmdet3d/models/roi_heads/fsd_roi_head.py index 2f6fa9b076..2215866fa1 100644 --- a/mmdet3d/models/roi_heads/fsd_roi_head.py +++ b/mmdet3d/models/roi_heads/fsd_roi_head.py @@ -120,10 +120,7 @@ def simple_test( pts_feats, pts_batch_inds, img_metas, - proposal_list, - gt_bboxes_3d, - gt_labels_3d, - **kwargs): + proposal_list): """Simple testing forward function of PartAggregationROIHead. @@ -140,7 +137,6 @@ def simple_test( dict: Bbox results of one frame. """ - assert len(proposal_list) == 1, 'only support bsz==1 to make cls_preds and labels_3d consistent with bbox_results' rois = bbox3d2roi([res[0].tensor for res in proposal_list]) cls_preds = [res[1] for res in proposal_list] diff --git a/mmdet3d/models/voxel_encoders/voxel_encoder.py b/mmdet3d/models/voxel_encoders/voxel_encoder.py index c2ffed3881..c7c21fcbca 100644 --- a/mmdet3d/models/voxel_encoders/voxel_encoder.py +++ b/mmdet3d/models/voxel_encoders/voxel_encoder.py @@ -7,7 +7,7 @@ from mmdet3d.registry import MODELS from .. import builder from .utils import VFELayer, get_paddings_indicator, DynamicVFELayerV2, DynamicVFELayer -from mmdet3d.models.layers.sst import build_mlp +from mmdet3d.models.layers.sst import build_mlp, scatter_v2 @MODELS.register_module() @@ -560,7 +560,7 @@ def forward(self, new_coors = unq_inv_once = None features_ls = [features] - origin_point_coors = features[:, :3] + # origin_point_coors = features[:, :3] # Find distance of x, y, and z from cluster center if self._with_cluster_center: voxel_mean, _, unq_inv = scatter_v2(features[:, :3], coors, mode='avg', new_coors=new_coors, diff --git a/mmdet3d/structures/ops/iou3d_calculator.py b/mmdet3d/structures/ops/iou3d_calculator.py index 6aad188e18..ff5b5cee1d 100644 --- a/mmdet3d/structures/ops/iou3d_calculator.py +++ b/mmdet3d/structures/ops/iou3d_calculator.py @@ -328,54 +328,54 @@ def axis_aligned_bbox_overlaps_3d(bboxes1, gious = ious - (enclose_area - union) / enclose_area return gious - -# from . 
import iou3d_cuda - -def nms_gpu(boxes, scores, thresh, pre_maxsize=None, post_max_size=None): - """Nms function with gpu implementation. - - Args: - boxes (torch.Tensor): Input boxes with the shape of [N, 5] - ([x1, y1, x2, y2, ry]). - scores (torch.Tensor): Scores of boxes with the shape of [N]. - thresh (int): Threshold. - pre_maxsize (int): Max size of boxes before nms. Default: None. - post_maxsize (int): Max size of boxes after nms. Default: None. - - Returns: - torch.Tensor: Indexes after nms. - """ - raise NotImplementedError - # order = scores.sort(0, descending=True)[1] - # - # if pre_maxsize is not None: - # order = order[:pre_maxsize] - # boxes = boxes[order].contiguous() - # - # keep = torch.zeros(boxes.size(0), dtype=torch.long) - # num_out = iou3d_cuda.nms_gpu(boxes, keep, thresh, boxes.device.index) - # keep = order[keep[:num_out].cuda(boxes.device)].contiguous() - # if post_max_size is not None: - # keep = keep[:post_max_size] - # return keep - -def nms_normal_gpu(boxes, scores, thresh): - """Normal non maximum suppression on GPU. - - Args: - boxes (torch.Tensor): Input boxes with shape (N, 5). - scores (torch.Tensor): Scores of predicted boxes with shape (N). - thresh (torch.Tensor): Threshold of non maximum suppression. - - Returns: - torch.Tensor: Remaining indices with scores in descending order. - """ - raise NotImplementedError - # order = scores.sort(0, descending=True)[1] - # - # boxes = boxes[order].contiguous() - # - # keep = torch.zeros(boxes.size(0), dtype=torch.long) - # num_out = iou3d_cuda.nms_normal_gpu(boxes, keep, thresh, - # boxes.device.index) - # return order[keep[:num_out].cuda(boxes.device)].contiguous() \ No newline at end of file +# +# # from . import iou3d_cuda +# +# def nms_gpu(boxes, scores, thresh, pre_maxsize=None, post_max_size=None): +# """Nms function with gpu implementation. +# +# Args: +# boxes (torch.Tensor): Input boxes with the shape of [N, 5] +# ([x1, y1, x2, y2, ry]). +# scores (torch.Tensor): Scores of boxes with the shape of [N]. +# thresh (int): Threshold. +# pre_maxsize (int): Max size of boxes before nms. Default: None. +# post_maxsize (int): Max size of boxes after nms. Default: None. +# +# Returns: +# torch.Tensor: Indexes after nms. +# """ +# raise NotImplementedError +# # order = scores.sort(0, descending=True)[1] +# # +# # if pre_maxsize is not None: +# # order = order[:pre_maxsize] +# # boxes = boxes[order].contiguous() +# # +# # keep = torch.zeros(boxes.size(0), dtype=torch.long) +# # num_out = iou3d_cuda.nms_gpu(boxes, keep, thresh, boxes.device.index) +# # keep = order[keep[:num_out].cuda(boxes.device)].contiguous() +# # if post_max_size is not None: +# # keep = keep[:post_max_size] +# # return keep +# +# def nms_normal_gpu(boxes, scores, thresh): +# """Normal non maximum suppression on GPU. +# +# Args: +# boxes (torch.Tensor): Input boxes with shape (N, 5). +# scores (torch.Tensor): Scores of predicted boxes with shape (N). +# thresh (torch.Tensor): Threshold of non maximum suppression. +# +# Returns: +# torch.Tensor: Remaining indices with scores in descending order. 
+# """ +# raise NotImplementedError +# # order = scores.sort(0, descending=True)[1] +# # +# # boxes = boxes[order].contiguous() +# # +# # keep = torch.zeros(boxes.size(0), dtype=torch.long) +# # num_out = iou3d_cuda.nms_normal_gpu(boxes, keep, thresh, +# # boxes.device.index) +# # return order[keep[:num_out].cuda(boxes.device)].contiguous() From 709f88d1a04118871e718a5a497560bffae3d64d Mon Sep 17 00:00:00 2001 From: HinGwenWoong Date: Mon, 19 Dec 2022 17:16:58 +0800 Subject: [PATCH 4/5] Fix even install spconv but using mmcv.spconv --- mmdet3d/models/layers/spconv/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mmdet3d/models/layers/spconv/__init__.py b/mmdet3d/models/layers/spconv/__init__.py index 98db14a7ac..d5872bd0f9 100644 --- a/mmdet3d/models/layers/spconv/__init__.py +++ b/mmdet3d/models/layers/spconv/__init__.py @@ -3,6 +3,7 @@ try: import spconv + from spconv import pytorch except ImportError: IS_SPCONV2_AVAILABLE = False else: From 220b1a048c5b1b22779334dda63f1e47c2685056 Mon Sep 17 00:00:00 2001 From: HinGwenWoong Date: Wed, 11 Jan 2023 10:20:01 +0800 Subject: [PATCH 5/5] Improve code --- configs/fsd/fsd_waymoD1_1x.py | 33 ++++++++++++----------- mmdet3d/models/detectors/two_stage_fsd.py | 2 +- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/configs/fsd/fsd_waymoD1_1x.py b/configs/fsd/fsd_waymoD1_1x.py index 61f248eb76..bcc5fe0220 100644 --- a/configs/fsd/fsd_waymoD1_1x.py +++ b/configs/fsd/fsd_waymoD1_1x.py @@ -357,21 +357,24 @@ use_dim=5, file_client_args=file_client_args), dict( - type='MultiScaleFlipAug3D', - img_scale=(1333, 800), - pts_scale_ratio=1, - flip=False, - transforms=[ - dict( - type='GlobalRotScaleTrans', - rot_range=[0, 0], - scale_ratio_range=[1., 1.], - translation_std=[0, 0, 0]), - dict(type='RandomFlip3D'), - dict( - type='PointsRangeFilter', point_cloud_range=_base_.point_cloud_range), - dict(type='Pack3DDetInputs', keys=['points']) - ]) + type='PointsRangeFilter', point_cloud_range=_base_.point_cloud_range), + dict(type='Pack3DDetInputs', keys=['points']) + # dict( + # type='MultiScaleFlipAug3D', + # img_scale=(1333, 800), + # pts_scale_ratio=1, + # flip=False, + # transforms=[ + # dict( + # type='GlobalRotScaleTrans', + # rot_range=[0, 0], + # scale_ratio_range=[1., 1.], + # translation_std=[0, 0, 0]), + # dict(type='RandomFlip3D'), + # dict( + # type='PointsRangeFilter', point_cloud_range=_base_.point_cloud_range), + # dict(type='Pack3DDetInputs', keys=['points']) + # ]) ] train_dataloader = dict( diff --git a/mmdet3d/models/detectors/two_stage_fsd.py b/mmdet3d/models/detectors/two_stage_fsd.py index c722b31342..be02261cf8 100644 --- a/mmdet3d/models/detectors/two_stage_fsd.py +++ b/mmdet3d/models/detectors/two_stage_fsd.py @@ -10,6 +10,7 @@ from .. import builder from mmdet3d.registry import MODELS + from ...structures.det3d_data_sample import SampleList @@ -136,7 +137,6 @@ def prepare_multi_class_roi_input(self, points, cluster_pts_feats, pts_seg_feats if self.training and self.train_cfg.get('detach_cluster_feats', False): cluster_pts_feats = cluster_pts_feats.detach() - ##### prepare points for roi head fg_points_list = [points[m] for m in pts_mask] all_fg_points = torch.cat(fg_points_list, dim=0)
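For reference, a minimal sketch of the canvas-free pooling idea behind `DynamicScatterVFE` and `SIRLayer`: points are grouped by voxel coordinate with `torch.unique` and reduced with `torch_scatter`, so no dense canvas is ever allocated. The point values, voxel size and the `scatter(..., reduce='max')` call below are assumptions for illustration, not project code:

```python
import torch
from torch_scatter import scatter  # the encoders above rely on torch_scatter

points = torch.randn(8, 4)                        # (x, y, z, intensity)
voxel_size = torch.tensor([0.25, 0.25, 0.2])
coors = torch.div(points[:, :3], voxel_size, rounding_mode='floor').long()

# negative coordinates are fine: grouping is by unique key, not by a dense grid
unique_coors, inv = torch.unique(coors, return_inverse=True, dim=0)
voxel_feats = scatter(points, inv, dim=0, reduce='max')  # one feature per voxel
point_feats = voxel_feats[inv]                           # scatter back to points
print(voxel_feats.shape, point_feats.shape)
```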