Skip to content

Commit 9a03b56

Browse files
authored
[Features] Support PV_RCNN modules (#1957)
* add pvrcnn module code * add voxelsa * fix * fix comments * fix comments * fix comments * add stack sa * fix * fix comments * fix comments * fix * add ut * fix comments
1 parent 0f4ba41 commit 9a03b56

File tree

18 files changed

+2310
-11
lines changed

18 files changed

+2310
-11
lines changed
Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
_base_ = [
2+
'../_base_/datasets/kitti-3d-3class.py',
3+
'../_base_/schedules/cyclic-40e.py', '../_base_/default_runtime.py'
4+
]
5+
6+
# Voxel edge length per axis. Dividing the extents of point_cloud_range by
# these gives 1408 x 1600 x 40 cells, which lines up with the reversed
# sparse_shape=[41, 1600, 1408] passed to the middle encoder (41 vs. 40 on
# the first entry -- presumably one extra z slice; confirm against
# SparseEncoder).
voxel_size = [0.05, 0.05, 0.1]
7+
# Spatial crop shared by PointsRangeFilter, ObjectRangeFilter, the voxel
# layer and VoxelSetAbstraction below. Presumably laid out as
# [x_min, y_min, z_min, x_max, y_max, z_max] -- confirm against the filters.
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
8+
9+
data_root = 'data/kitti/'
10+
class_names = ['Pedestrian', 'Cyclist', 'Car']
11+
metainfo = dict(CLASSES=class_names)
12+
db_sampler = dict(
13+
data_root=data_root,
14+
info_path=data_root + 'kitti_dbinfos_train.pkl',
15+
rate=1.0,
16+
prepare=dict(
17+
filter_by_difficulty=[-1],
18+
filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
19+
classes=class_names,
20+
sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10),
21+
points_loader=dict(
22+
type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4))
23+
24+
train_pipeline = [
25+
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
26+
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
27+
dict(type='ObjectSample', db_sampler=db_sampler, use_ground_plane=True),
28+
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
29+
dict(
30+
type='GlobalRotScaleTrans',
31+
rot_range=[-0.78539816, 0.78539816],
32+
scale_ratio_range=[0.95, 1.05]),
33+
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
34+
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
35+
dict(type='PointShuffle'),
36+
dict(
37+
type='Pack3DDetInputs',
38+
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
39+
]
40+
test_pipeline = [
41+
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
42+
dict(
43+
type='MultiScaleFlipAug3D',
44+
img_scale=(1333, 800),
45+
pts_scale_ratio=1,
46+
flip=False,
47+
transforms=[
48+
dict(
49+
type='GlobalRotScaleTrans',
50+
rot_range=[0, 0],
51+
scale_ratio_range=[1., 1.],
52+
translation_std=[0, 0, 0]),
53+
dict(type='RandomFlip3D'),
54+
dict(
55+
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
56+
]),
57+
dict(type='Pack3DDetInputs', keys=['points'])
58+
]
59+
60+
model = dict(
61+
type='PointVoxelRCNN',
62+
data_preprocessor=dict(
63+
type='Det3DDataPreprocessor',
64+
voxel=True,
65+
voxel_layer=dict(
66+
max_num_points=5, # max_points_per_voxel
67+
point_cloud_range=point_cloud_range,
68+
voxel_size=voxel_size,
69+
max_voxels=(16000, 40000))),
70+
voxel_encoder=dict(type='HardSimpleVFE'),
71+
middle_encoder=dict(
72+
type='SparseEncoder',
73+
in_channels=4,
74+
sparse_shape=[41, 1600, 1408],
75+
order=('conv', 'norm', 'act'),
76+
encoder_paddings=((0, 0, 0), ((1, 1, 1), 0, 0), ((1, 1, 1), 0, 0),
77+
((0, 1, 1), 0, 0)),
78+
return_middle_feats=True),
79+
points_encoder=dict(
80+
type='VoxelSetAbstraction',
81+
num_keypoints=2048,
82+
fused_out_channel=128,
83+
voxel_size=voxel_size,
84+
point_cloud_range=point_cloud_range,
85+
voxel_sa_cfgs_list=[
86+
dict(
87+
type='StackedSAModuleMSG',
88+
in_channels=16,
89+
scale_factor=1,
90+
radius=(0.4, 0.8),
91+
sample_nums=(16, 16),
92+
mlp_channels=((16, 16), (16, 16)),
93+
use_xyz=True),
94+
dict(
95+
type='StackedSAModuleMSG',
96+
in_channels=32,
97+
scale_factor=2,
98+
radius=(0.8, 1.2),
99+
sample_nums=(16, 32),
100+
mlp_channels=((32, 32), (32, 32)),
101+
use_xyz=True),
102+
dict(
103+
type='StackedSAModuleMSG',
104+
in_channels=64,
105+
scale_factor=4,
106+
radius=(1.2, 2.4),
107+
sample_nums=(16, 32),
108+
mlp_channels=((64, 64), (64, 64)),
109+
use_xyz=True),
110+
dict(
111+
type='StackedSAModuleMSG',
112+
in_channels=64,
113+
scale_factor=8,
114+
radius=(2.4, 4.8),
115+
sample_nums=(16, 32),
116+
mlp_channels=((64, 64), (64, 64)),
117+
use_xyz=True)
118+
],
119+
rawpoints_sa_cfgs=dict(
120+
type='StackedSAModuleMSG',
121+
in_channels=1,
122+
radius=(0.4, 0.8),
123+
sample_nums=(16, 16),
124+
mlp_channels=((16, 16), (16, 16)),
125+
use_xyz=True),
126+
bev_feat_channel=256,
127+
bev_scale_factor=8),
128+
backbone=dict(
129+
type='SECOND',
130+
in_channels=256,
131+
layer_nums=[5, 5],
132+
layer_strides=[1, 2],
133+
out_channels=[128, 256]),
134+
neck=dict(
135+
type='SECONDFPN',
136+
in_channels=[128, 256],
137+
upsample_strides=[1, 2],
138+
out_channels=[256, 256]),
139+
rpn_head=dict(
140+
type='PartA2RPNHead',
141+
num_classes=3,
142+
in_channels=512,
143+
feat_channels=512,
144+
use_direction_classifier=True,
145+
dir_offset=0.78539,
146+
anchor_generator=dict(
147+
type='Anchor3DRangeGenerator',
148+
ranges=[[0, -40.0, -0.6, 70.4, 40.0, -0.6],
149+
[0, -40.0, -0.6, 70.4, 40.0, -0.6],
150+
[0, -40.0, -1.78, 70.4, 40.0, -1.78]],
151+
sizes=[[0.8, 0.6, 1.73], [1.76, 0.6, 1.73], [3.9, 1.6, 1.56]],
152+
rotations=[0, 1.57],
153+
reshape_out=False),
154+
diff_rad_by_sin=True,
155+
assigner_per_size=True,
156+
assign_per_class=True,
157+
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
158+
loss_cls=dict(
159+
type='mmdet.FocalLoss',
160+
use_sigmoid=True,
161+
gamma=2.0,
162+
alpha=0.25,
163+
loss_weight=1.0),
164+
loss_bbox=dict(
165+
type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
166+
loss_dir=dict(
167+
type='mmdet.CrossEntropyLoss', use_sigmoid=False,
168+
loss_weight=0.2)),
169+
roi_head=dict(
170+
type='PVRCNNRoiHead',
171+
num_classes=3,
172+
semantic_head=dict(
173+
type='ForegroundSegmentationHead',
174+
in_channels=640,
175+
extra_width=0.1,
176+
loss_seg=dict(
177+
type='mmdet.FocalLoss',
178+
use_sigmoid=True,
179+
reduction='sum',
180+
gamma=2.0,
181+
alpha=0.25,
182+
activated=True,
183+
loss_weight=1.0)),
184+
bbox_roi_extractor=dict(
185+
type='Batch3DRoIGridExtractor',
186+
grid_size=6,
187+
roi_layer=dict(
188+
type='StackedSAModuleMSG',
189+
in_channels=128,
190+
radius=(0.8, 1.6),
191+
sample_nums=(16, 16),
192+
mlp_channels=((64, 64), (64, 64)),
193+
use_xyz=True,
194+
pool_mod='max'),
195+
),
196+
bbox_head=dict(
197+
type='PVRCNNBBoxHead',
198+
in_channels=128,
199+
grid_size=6,
200+
num_classes=3,
201+
class_agnostic=True,
202+
shared_fc_channels=(256, 256),
203+
reg_channels=(256, 256),
204+
cls_channels=(256, 256),
205+
dropout_ratio=0.3,
206+
with_corner_loss=True,
207+
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
208+
loss_bbox=dict(
209+
type='mmdet.SmoothL1Loss',
210+
beta=1.0 / 9.0,
211+
reduction='sum',
212+
loss_weight=1.0),
213+
loss_cls=dict(
214+
type='mmdet.CrossEntropyLoss',
215+
use_sigmoid=True,
216+
reduction='sum',
217+
loss_weight=1.0))),
218+
# model training and testing settings
219+
train_cfg=dict(
220+
rpn=dict(
221+
assigner=[
222+
dict( # for Pedestrian
223+
type='Max3DIoUAssigner',
224+
iou_calculator=dict(type='BboxOverlapsNearest3D'),
225+
pos_iou_thr=0.5,
226+
neg_iou_thr=0.35,
227+
min_pos_iou=0.35,
228+
ignore_iof_thr=-1),
229+
dict( # for Cyclist
230+
type='Max3DIoUAssigner',
231+
iou_calculator=dict(type='BboxOverlapsNearest3D'),
232+
pos_iou_thr=0.5,
233+
neg_iou_thr=0.35,
234+
min_pos_iou=0.35,
235+
ignore_iof_thr=-1),
236+
dict( # for Car
237+
type='Max3DIoUAssigner',
238+
iou_calculator=dict(type='BboxOverlapsNearest3D'),
239+
pos_iou_thr=0.6,
240+
neg_iou_thr=0.45,
241+
min_pos_iou=0.45,
242+
ignore_iof_thr=-1)
243+
],
244+
allowed_border=0,
245+
pos_weight=-1,
246+
debug=False),
247+
rpn_proposal=dict(
248+
nms_pre=9000,
249+
nms_post=512,
250+
max_num=512,
251+
nms_thr=0.8,
252+
score_thr=0,
253+
use_rotate_nms=True),
254+
rcnn=dict(
255+
assigner=[
256+
dict( # for Pedestrian
257+
type='Max3DIoUAssigner',
258+
iou_calculator=dict(
259+
type='BboxOverlaps3D', coordinate='lidar'),
260+
pos_iou_thr=0.55,
261+
neg_iou_thr=0.55,
262+
min_pos_iou=0.55,
263+
ignore_iof_thr=-1),
264+
dict( # for Cyclist
265+
type='Max3DIoUAssigner',
266+
iou_calculator=dict(
267+
type='BboxOverlaps3D', coordinate='lidar'),
268+
pos_iou_thr=0.55,
269+
neg_iou_thr=0.55,
270+
min_pos_iou=0.55,
271+
ignore_iof_thr=-1),
272+
dict( # for Car
273+
type='Max3DIoUAssigner',
274+
iou_calculator=dict(
275+
type='BboxOverlaps3D', coordinate='lidar'),
276+
pos_iou_thr=0.55,
277+
neg_iou_thr=0.55,
278+
min_pos_iou=0.55,
279+
ignore_iof_thr=-1)
280+
],
281+
sampler=dict(
282+
type='IoUNegPiecewiseSampler',
283+
num=128,
284+
pos_fraction=0.5,
285+
neg_piece_fractions=[0.8, 0.2],
286+
neg_iou_piece_thrs=[0.55, 0.1],
287+
neg_pos_ub=-1,
288+
add_gt_as_proposals=False,
289+
return_iou=True),
290+
cls_pos_thr=0.75,
291+
cls_neg_thr=0.25)),
292+
test_cfg=dict(
293+
rpn=dict(
294+
nms_pre=1024,
295+
nms_post=100,
296+
max_num=100,
297+
nms_thr=0.7,
298+
score_thr=0,
299+
use_rotate_nms=True),
300+
rcnn=dict(
301+
use_rotate_nms=True,
302+
use_raw_score=True,
303+
nms_thr=0.1,
304+
score_thr=0.1)))
305+
train_dataloader = dict(
306+
batch_size=2,
307+
num_workers=2,
308+
dataset=dict(dataset=dict(pipeline=train_pipeline, metainfo=metainfo)))
309+
test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
310+
# mmengine's Runner builds the validation loader from the `val_dataloader`
# key; `eval_dataloader` is not a key it reads, so on its own this override
# never reaches the validation dataset and the base config's settings are
# used instead. Define `val_dataloader` so the override takes effect, and
# keep the original key untouched for anything that still references it.
val_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
eval_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
311+
lr = 0.001
312+
optim_wrapper = dict(optimizer=dict(lr=lr))
313+
param_scheduler = [
314+
# learning rate scheduler
315+
# During the first 15 epochs, learning rate increases from lr to lr * 10
316+
# during the next 25 epochs, learning rate decreases from lr * 10 to
317+
# lr * 1e-4
318+
dict(
319+
type='CosineAnnealingLR',
320+
T_max=15,
321+
eta_min=lr * 10,
322+
begin=0,
323+
end=15,
324+
by_epoch=True,
325+
convert_to_iter_based=True),
326+
dict(
327+
type='CosineAnnealingLR',
328+
T_max=25,
329+
eta_min=lr * 1e-4,
330+
begin=15,
331+
end=40,
332+
by_epoch=True,
333+
convert_to_iter_based=True),
334+
# momentum scheduler
335+
# During the first 15 epochs, momentum is annealed to 0.85 / 0.95
336+
# during the next 25 epochs, momentum increases from 0.85 / 0.95 to 1
337+
dict(
338+
type='CosineAnnealingMomentum',
339+
T_max=15,
340+
eta_min=0.85 / 0.95,
341+
begin=0,
342+
end=15,
343+
by_epoch=True,
344+
convert_to_iter_based=True),
345+
dict(
346+
type='CosineAnnealingMomentum',
347+
T_max=25,
348+
eta_min=1,
349+
begin=15,
350+
end=40,
351+
by_epoch=True,
352+
convert_to_iter_based=True)
353+
]

mmdet3d/apis/inference.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def init_model(config: Union[str, Path, Config],
6767

6868
if checkpoint is not None:
6969
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
70-
7170
dataset_meta = checkpoint['meta'].get('dataset_meta', None)
7271
# save the dataset_meta in the model for convenience
7372
if 'dataset_meta' in checkpoint.get('meta', {}):

mmdet3d/models/detectors/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .mvx_two_stage import MVXTwoStageDetector
1515
from .parta2 import PartA2
1616
from .point_rcnn import PointRCNN
17+
from .pv_rcnn import PointVoxelRCNN
1718
from .sassd import SASSD
1819
from .single_stage_mono3d import SingleStageMono3DDetector
1920
from .smoke_mono3d import SMOKEMono3D
@@ -26,5 +27,6 @@
2627
'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet',
2728
'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector',
2829
'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet', 'PointRCNN', 'SMOKEMono3D',
29-
'SASSD', 'MinkSingleStage3DDetector', 'MultiViewDfM', 'DfM'
30+
'SASSD', 'MinkSingleStage3DDetector', 'MultiViewDfM', 'DfM',
31+
'PointVoxelRCNN'
3032
]

0 commit comments

Comments
 (0)