Skip to content

Commit 06caf31

Browse files
authored
Merge pull request #1095 from lrzpellegrini/sampler_base_dataloading
Implement sampler-based dataloading logic
2 parents 441a968 + 712b8d2 commit 06caf31

File tree

2 files changed

+310
-120
lines changed

2 files changed

+310
-120
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
################################################################################
2+
# Copyright (c) 2021 ContinualAI. #
3+
# Copyrights licensed under the MIT License. #
4+
# See the accompanying LICENSE file for terms. #
5+
# #
6+
# Date: 21-04-2022 #
7+
# Author(s): Antonio Carta, Lorenzo Pellegrini #
8+
# E-mail: contact@continualai.org #
9+
# Website: avalanche.continualai.org #
10+
################################################################################
11+
12+
import itertools
13+
from collections import defaultdict
14+
15+
import torch
16+
17+
18+
def classification_collate_mbatches_fn(mbatches):
19+
"""Combines multiple mini-batches together.
20+
21+
Concatenates each tensor in the mini-batches along dimension 0 (usually
22+
this is the batch size).
23+
24+
:param mbatches: sequence of mini-batches.
25+
:return: a single mini-batch
26+
"""
27+
batch = []
28+
for i in range(len(mbatches[0])):
29+
t = classification_single_values_collate_fn(
30+
[el[i] for el in mbatches], i)
31+
batch.append(t)
32+
return batch
33+
34+
35+
def classification_single_values_collate_fn(values_list, index):
36+
"""
37+
Collate function used to merge the single elements (x or y or t,
38+
etcetera) of a minibatch of data from a classification dataset.
39+
40+
This function assumes that all values are tensors of the same shape
41+
(excluding the first dimension).
42+
43+
:param values_list: The list of values to merge.
44+
:param index: The index of the element. 0 for x values, 1 for y values,
45+
etcetera. In this implementation, this parameter is ignored.
46+
:return: The merged values.
47+
"""
48+
return torch.cat(values_list, dim=0)
49+
50+
51+
def detection_collate_fn(batch):
52+
"""
53+
Collate function used when loading detection datasets using a DataLoader.
54+
55+
This will merge the single samples of a batch to create a minibatch.
56+
This collate function follows the torchvision format for detection tasks.
57+
"""
58+
return tuple(zip(*batch))
59+
60+
61+
def detection_collate_mbatches_fn(mbatches):
62+
"""
63+
Collate function used when loading detection datasets using a DataLoader.
64+
65+
This will merge multiple batches to create a concatenated batch.
66+
67+
Beware that merging multiple batches is different from creating a batch
68+
from single dataset elements: Batches can be created from a
69+
list of single dataset elements by using :func:`detection_collate_fn`.
70+
"""
71+
lists_dict = defaultdict(list)
72+
for mb in mbatches:
73+
for mb_elem_idx, mb_elem in enumerate(mb):
74+
lists_dict[mb_elem_idx].append(mb_elem)
75+
76+
lists = []
77+
for mb_elem_idx in range(max(lists_dict.keys()) + 1):
78+
lists.append(list(itertools.chain.from_iterable(
79+
lists_dict[mb_elem_idx]
80+
)))
81+
82+
return lists
83+
84+
85+
__all__ = [
86+
'classification_collate_mbatches_fn',
87+
'classification_single_values_collate_fn',
88+
'detection_collate_fn',
89+
'detection_collate_mbatches_fn'
90+
]

0 commit comments

Comments
 (0)