
Commit 2072a9d

Xiangxu-0103 authored and ZwwWayne committed

[Refactor] Refactor voxelization for faster speed (#2062)

* refactor voxelization for faster speed
* fix doc typo

1 parent 13ba0dc commit 2072a9d
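
In short, the speed-up comes from `voxelize`: each sample's batch index is now prepended to its voxel coordinates with a single `F.pad` call inside the main per-sample loop, removing the separate second loop over `coors` in both the 'hard' and 'dynamic' branches; `forward` also now returns a plain `dict` (or `List[dict]`) rather than a tuple.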

File tree

2 files changed: +56 -53 lines changed


mmdet3d/models/data_preprocessors/data_preprocessor.py

Lines changed: 50 additions & 47 deletions
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
 from numbers import Number
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Union

 import numpy as np
 import torch
@@ -28,24 +28,25 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
     - 1) For image data:
       - Pad images in inputs to the maximum size of current batch with defined
         ``pad_value``. The padding size can be divisible by a defined
-        ``pad_size_divisor``
+        ``pad_size_divisor``.
     - Stack images in inputs to batch_imgs.
     - Convert images in inputs from bgr to rgb if the shape of input is
-        (3, H, W).
+      (3, H, W).
     - Normalize images in inputs with defined std and mean.
     - Do batch augmentations during training.

     - 2) For point cloud data:
-    - if no voxelization, directly return list of point cloud data.
-    - if voxelization is applied, voxelize point cloud according to
+      - If no voxelization, directly return list of point cloud data.
+      - If voxelization is applied, voxelize point cloud according to
         ``voxel_type`` and obtain ``voxels``.

     Args:
-        voxel (bool): Whether to apply voxelziation to point cloud.
+        voxel (bool): Whether to apply voxelization to point cloud.
+            Defaults to False.
         voxel_type (str): Voxelization type. Two voxelization types are
             provided: 'hard' and 'dynamic', respectively for hard
             voxelization and dynamic voxelization. Defaults to 'hard'.
-        voxel_layer (:obj:`ConfigDict`, optional): Voxelization layer
+        voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
            config. Defaults to None.
         mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
             Defaults to None.
@@ -54,11 +55,21 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
         pad_size_divisor (int): The size of padded image should be
             divisible by ``pad_size_divisor``. Defaults to 1.
         pad_value (Number): The padded pixel value. Defaults to 0.
-        bgr_to_rgb (bool): whether to convert image from BGR to RGB.
+        pad_mask (bool): Whether to pad instance masks. Defaults to False.
+        mask_pad_value (int): The padded pixel value for instance masks.
+            Defaults to 0.
+        pad_seg (bool): Whether to pad semantic segmentation maps.
+            Defaults to False.
+        seg_pad_value (int): The padded pixel value for semantic
+            segmentation maps. Defaults to 255.
+        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
             Defaults to False.
-        rgb_to_bgr (bool): whether to convert image from RGB to RGB.
+        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
             Defaults to False.
-        batch_augments (list[dict], optional): Batch-level augmentations
+        boxtype2tensor (bool): Whether to keep the ``BaseBoxes`` type of
+            bboxes data or not. Defaults to True.
+        batch_augments (List[dict], optional): Batch-level augmentations.
+            Defaults to None.
     """

     def __init__(self,
@@ -76,8 +87,8 @@ def __init__(self,
                  bgr_to_rgb: bool = False,
                  rgb_to_bgr: bool = False,
                  boxtype2tensor: bool = True,
-                 batch_augments: Optional[List[dict]] = None):
-        super().__init__(
+                 batch_augments: Optional[List[dict]] = None) -> None:
+        super(Det3DDataPreprocessor, self).__init__(
             mean=mean,
             std=std,
             pad_size_divisor=pad_size_divisor,
@@ -94,24 +105,21 @@ def __init__(self,
         if voxel:
             self.voxel_layer = Voxelization(**voxel_layer)

-    def forward(
-            self,
-            data: Union[dict, List[dict]],
-            training: bool = False
-    ) -> Tuple[Union[dict, List[dict]], Optional[list]]:
-        """Perform normalization、padding and bgr2rgb conversion based on
+    def forward(self,
+                data: Union[dict, List[dict]],
+                training: bool = False) -> Union[dict, List[dict]]:
+        """Perform normalization, padding and bgr2rgb conversion based on
         ``BaseDataPreprocessor``.

         Args:
-            data (dict | List[dict]): data from dataloader.
+            data (dict or List[dict]): Data from dataloader.
                 The dict contains the whole batch data, when it is
                 a list[dict], the list indicate test time augmentation.
-
             training (bool): Whether to enable training time augmentation.
                 Defaults to False.

         Returns:
-            Dict | List[Dict]: Data in the same format as the model input.
+            dict or List[dict]: Data in the same format as the model input.
         """
         if isinstance(data, list):
             num_augs = len(data)
@@ -126,7 +134,7 @@ def forward(
             return self.simple_process(data, training)

     def simple_process(self, data: dict, training: bool = False) -> dict:
-        """Perform normalization、padding and bgr2rgb conversion for img data
+        """Perform normalization, padding and bgr2rgb conversion for img data
         based on ``BaseDataPreprocessor``, and voxelize point cloud if `voxel`
         is set to be True.

@@ -188,7 +196,7 @@ def simple_process(self, data: dict, training: bool = False) -> dict:

         return {'inputs': batch_inputs, 'data_samples': data_samples}

-    def preprocess_img(self, _batch_img):
+    def preprocess_img(self, _batch_img: torch.Tensor) -> torch.Tensor:
         # channel transform
         if self._channel_conversion:
             _batch_img = _batch_img[[2, 1, 0], ...]
@@ -206,7 +214,7 @@ def preprocess_img(self, _batch_img):
         return _batch_img

     def collate_data(self, data: dict) -> dict:
-        """Copying data to the target device and Performs normalization
+        """Copying data to the target device and Performs normalization,
         padding and bgr2rgb conversion and stack based on
         ``BaseDataPreprocessor``.

@@ -273,7 +281,7 @@ def collate_data(self, data: dict) -> dict:
                 raise TypeError(
                     'Output of `cast_data` should be a list of dict '
                     'or a tuple with inputs and data_samples, but got'
-                    f'{type(data)} {data}')
+                    f'{type(data)}: {data}')

         data['inputs']['imgs'] = batch_imgs

@@ -284,14 +292,14 @@ def collate_data(self, data: dict) -> dict:
     def _get_pad_shape(self, data: dict) -> List[tuple]:
         """Get the pad_shape of each image based on data and
         pad_size_divisor."""
-        # rewrite `_get_pad_shape` for obaining image inputs.
+        # rewrite `_get_pad_shape` for obtaining image inputs.
         _batch_inputs = data['inputs']['img']
         # Process data with `pseudo_collate`.
         if is_list_of(_batch_inputs, torch.Tensor):
             batch_pad_shape = []
             for ori_input in _batch_inputs:
                 if ori_input.dim() == 4:
-                    # mean multiivew input, select ont of the
+                    # mean multiview input, select one of the
                     # image to calculate the pad shape
                     ori_input = ori_input[0]
                 pad_h = int(
@@ -316,24 +324,24 @@ def _get_pad_shape(self, data: dict) -> List[tuple]:
             batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
         else:
             raise TypeError('Output of `cast_data` should be a list of dict '
-                            'or a tuple with inputs and data_samples, but got'
+                            'or a tuple with inputs and data_samples, but got '
                             f'{type(data)}: {data}')
         return batch_pad_shape

     @torch.no_grad()
-    def voxelize(self, points: List[torch.Tensor]) -> Dict:
+    def voxelize(self, points: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
         """Apply voxelization to point cloud.

         Args:
             points (List[Tensor]): Point cloud in one data batch.

         Returns:
-            dict[str, Tensor]: Voxelization information.
+            Dict[str, Tensor]: Voxelization information.

-            - voxels (Tensor): Features of voxels, shape is MXNxC for hard
-              voxelization, NXC for dynamic voxelization.
-            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
-              where 1 represents the batch index.
+            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
+              voxelization, NxC for dynamic voxelization.
+            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
+              where 1 represents the batch index.
             - num_points (Tensor, optional): Number of points in each voxel.
             - voxel_centers (Tensor, optional): Centers of voxels.
         """
@@ -342,43 +350,38 @@ def voxelize(self, points: List[torch.Tensor]) -> Dict:

         if self.voxel_type == 'hard':
             voxels, coors, num_points, voxel_centers = [], [], [], []
-            for res in points:
+            for i, res in enumerate(points):
                 res_voxels, res_coors, res_num_points = self.voxel_layer(res)
                 res_voxel_centers = (
                     res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                         self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                             self.voxel_layer.point_cloud_range[0:3])
+                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                 voxels.append(res_voxels)
                 coors.append(res_coors)
                 num_points.append(res_num_points)
                 voxel_centers.append(res_voxel_centers)

             voxels = torch.cat(voxels, dim=0)
+            coors = torch.cat(coors, dim=0)
             num_points = torch.cat(num_points, dim=0)
             voxel_centers = torch.cat(voxel_centers, dim=0)
-            coors_batch = []
-            for i, coor in enumerate(coors):
-                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
-                coors_batch.append(coor_pad)
-            coors_batch = torch.cat(coors_batch, dim=0)
+
             voxel_dict['num_points'] = num_points
             voxel_dict['voxel_centers'] = voxel_centers
         elif self.voxel_type == 'dynamic':
             coors = []
             # dynamic voxelization only provide a coors mapping
-            for res in points:
+            for i, res in enumerate(points):
                 res_coors = self.voxel_layer(res)
+                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                 coors.append(res_coors)
             voxels = torch.cat(points, dim=0)
-            coors_batch = []
-            for i, coor in enumerate(coors):
-                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
-                coors_batch.append(coor_pad)
-            coors_batch = torch.cat(coors_batch, dim=0)
+            coors = torch.cat(coors, dim=0)
         else:
             raise ValueError(f'Invalid voxelization type {self.voxel_type}')

         voxel_dict['voxels'] = voxels
-        voxel_dict['coors'] = coors_batch
+        voxel_dict['coors'] = coors

         return voxel_dict
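
To make the batching change above concrete, here is a minimal standalone sketch (not mmdet3d code; the toy coordinate tensors are made up) of the pattern the refactored `voxelize` adopts: one `F.pad` call inside the main per-sample loop prepends the batch index to each sample's `(N, 3)` voxel coordinates, so the separate second loop over `coors` is no longer needed before `torch.cat`.

import torch
import torch.nn.functional as F

# Toy per-sample voxel coordinates, shape (num_voxels, 3) as (z, y, x).
points_coors = [
    torch.tensor([[0, 1, 2], [3, 4, 5]]),
    torch.tensor([[6, 7, 8]]),
]

coors = []
for i, res_coors in enumerate(points_coors):
    # F.pad with (1, 0) adds one column on the left filled with the batch
    # index i, turning (N, 3) coords into (N, 4) as (batch, z, y, x).
    coors.append(F.pad(res_coors, (1, 0), mode='constant', value=i))
coors = torch.cat(coors, dim=0)
print(coors)
# tensor([[0, 0, 1, 2],
#         [0, 3, 4, 5],
#         [1, 6, 7, 8]])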

mmdet3d/models/data_preprocessors/utils.py

Lines changed: 6 additions & 6 deletions
@@ -12,7 +12,7 @@ def multiview_img_stack_batch(
     """
     Compared to the stack_batch in mmengine.model.utils,
     multiview_img_stack_batch further handle the multiview images.
-    see diff of padded_sizes[:, :-2] = 0 vs padded_sizees[:, 0] = 0 in line 47
+    see diff of padded_sizes[:, :-2] = 0 vs padded_sizes[:, 0] = 0 in line 47
     Stack multiple tensors to form a batch and pad the tensor to the max
     shape use the right bottom padding mode in these images. If
     ``pad_size_divisor > 0``, add padding to ensure the shape of each dim is
@@ -23,20 +23,20 @@ def multiview_img_stack_batch(
         pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding
             to ensure the shape of each dim is divisible by
             ``pad_size_divisor``. This depends on the model, and many
-            models need to be divisible by 32. Defaults to 1
-        pad_value (int, float): The padding value. Defaults to 0.
+            models need to be divisible by 32. Defaults to 1.
+        pad_value (int or float): The padding value. Defaults to 0.

     Returns:
         Tensor: The n dim tensor.
     """
     assert isinstance(
         tensor_list,
-        list), (f'Expected input type to be list, but got {type(tensor_list)}')
+        list), f'Expected input type to be list, but got {type(tensor_list)}'
     assert tensor_list, '`tensor_list` could not be an empty list'
     assert len({
         tensor.ndim
         for tensor in tensor_list
-    }) == 1, (f'Expected the dimensions of all tensors must be the same, '
+    }) == 1, ('Expected the dimensions of all tensors must be the same, '
               f'but got {[tensor.ndim for tensor in tensor_list]}')

     dim = tensor_list[0].dim()
@@ -46,7 +46,7 @@ def multiview_img_stack_batch(
     max_sizes = torch.ceil(
         torch.max(all_sizes, dim=0)[0] / pad_size_divisor) * pad_size_divisor
     padded_sizes = max_sizes - all_sizes
-    # The first dim normally means channel, which should not be padded.
+    # The first dim normally means channel, which should not be padded.
     padded_sizes[:, :-2] = 0
     if padded_sizes.sum() == 0:
         return torch.stack(tensor_list)
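
As a side note on the `padded_sizes[:, :-2] = 0` line referenced above, here is a minimal sketch (toy shapes, not mmdet3d code) of why zeroing every dim except the last two is the right masking for multiview inputs: for `(num_views, C, H, W)` tensors only H and W get padded up to the batch max before stacking, whereas `padded_sizes[:, 0] = 0` would only protect the view dim.

import torch
import torch.nn.functional as F

# Two multiview image tensors: (num_views, C, H, W) with different H/W.
tensor_list = [torch.zeros(6, 3, 30, 40), torch.zeros(6, 3, 32, 36)]
all_sizes = torch.Tensor([tensor.shape for tensor in tensor_list])
max_sizes = torch.max(all_sizes, dim=0)[0]
padded_sizes = max_sizes - all_sizes
padded_sizes[:, :-2] = 0  # pad only the last two dims (H, W)

padded = []
for tensor, pad in zip(tensor_list, padded_sizes):
    # F.pad takes pads as (w_left, w_right, h_top, h_bottom, ...);
    # right-bottom padding means zeros for the left/top entries.
    padded.append(F.pad(tensor, (0, int(pad[-1]), 0, int(pad[-2]))))
print(torch.stack(padded).shape)  # torch.Size([2, 6, 3, 32, 40])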
