diff --git a/avalanche/benchmarks/utils/collate_functions.py b/avalanche/benchmarks/utils/collate_functions.py new file mode 100644 index 000000000..be94d4ac7 --- /dev/null +++ b/avalanche/benchmarks/utils/collate_functions.py @@ -0,0 +1,90 @@ +################################################################################ +# Copyright (c) 2021 ContinualAI. # +# Copyrights licensed under the MIT License. # +# See the accompanying LICENSE file for terms. # +# # +# Date: 21-04-2022 # +# Author(s): Antonio Carta, Lorenzo Pellegrini # +# E-mail: contact@continualai.org # +# Website: avalanche.continualai.org # +################################################################################ + +import itertools +from collections import defaultdict + +import torch + + +def classification_collate_mbatches_fn(mbatches): + """Combines multiple mini-batches together. + + Concatenates each tensor in the mini-batches along dimension 0 (usually + this is the batch size). + + :param mbatches: sequence of mini-batches. + :return: a single mini-batch + """ + batch = [] + for i in range(len(mbatches[0])): + t = classification_single_values_collate_fn( + [el[i] for el in mbatches], i) + batch.append(t) + return batch + + +def classification_single_values_collate_fn(values_list, index): + """ + Collate function used to merge the single elements (x or y or t, + etcetera) of a minibatch of data from a classification dataset. + + This function assumes that all values are tensors of the same shape + (excluding the first dimension). + + :param values_list: The list of values to merge. + :param index: The index of the element. 0 for x values, 1 for y values, + etcetera. In this implementation, this parameter is ignored. + :return: The merged values. + """ + return torch.cat(values_list, dim=0) + + +def detection_collate_fn(batch): + """ + Collate function used when loading detection datasets using a DataLoader. + + This will merge the single samples of a batch to create a minibatch. + This collate function follows the torchvision format for detection tasks. + """ + return tuple(zip(*batch)) + + +def detection_collate_mbatches_fn(mbatches): + """ + Collate function used when loading detection datasets using a DataLoader. + + This will merge multiple batches to create a concatenated batch. + + Beware that merging multiple batches is different from creating a batch + from single dataset elements: Batches can be created from a + list of single dataset elements by using :func:`detection_collate_fn`. + """ + lists_dict = defaultdict(list) + for mb in mbatches: + for mb_elem_idx, mb_elem in enumerate(mb): + lists_dict[mb_elem_idx].append(mb_elem) + + lists = [] + for mb_elem_idx in range(max(lists_dict.keys()) + 1): + lists.append(list(itertools.chain.from_iterable( + lists_dict[mb_elem_idx] + ))) + + return lists + + +__all__ = [ + 'classification_collate_mbatches_fn', + 'classification_single_values_collate_fn', + 'detection_collate_fn', + 'detection_collate_mbatches_fn' +] diff --git a/avalanche/benchmarks/utils/data_loader.py b/avalanche/benchmarks/utils/data_loader.py index f0b1a9dd1..3ac88a1c7 100644 --- a/avalanche/benchmarks/utils/data_loader.py +++ b/avalanche/benchmarks/utils/data_loader.py @@ -14,57 +14,26 @@ support for balanced dataloading between different tasks or balancing between the current data and the replay memory. 
""" -import itertools -from collections import defaultdict from itertools import chain -from typing import Dict, Sequence +from typing import Dict, Sequence, Union import torch -from torch.utils.data import RandomSampler +from torch.utils.data import RandomSampler, DistributedSampler from torch.utils.data.dataloader import DataLoader from avalanche.benchmarks.utils import AvalancheDataset +from avalanche.benchmarks.utils.collate_functions import \ + classification_collate_mbatches_fn +from avalanche.benchmarks.utils.collate_functions import detection_collate_fn \ + as _detection_collate_fn +from avalanche.benchmarks.utils.collate_functions import \ + detection_collate_mbatches_fn as _detection_collate_mbatches_fn +_default_collate_mbatches_fn = classification_collate_mbatches_fn -def _default_collate_mbatches_fn(mbatches): - """Combines multiple mini-batches together. +detection_collate_fn = _detection_collate_fn - Concatenates each tensor in the mini-batches along dimension 0 (usually this - is the batch size). - - :param mbatches: sequence of mini-batches. - :return: a single mini-batch - """ - batch = [] - for i in range(len(mbatches[0])): - t = torch.cat([el[i] for el in mbatches], dim=0) - batch.append(t) - return batch - - -def detection_collate_fn(batch): - """ - Collate function used when loading detection datasets using a DataLoader. - """ - return tuple(zip(*batch)) - - -def detection_collate_mbatches_fn(mbatches): - """ - Collate function used when loading detection datasets using a DataLoader. - """ - lists_dict = defaultdict(list) - for mb in mbatches: - for mb_elem_idx, mb_elem in enumerate(mb): - lists_dict[mb_elem_idx].append(mb_elem) - - lists = [] - for mb_elem_idx in range(max(lists_dict.keys()) + 1): - lists.append( - list(itertools.chain.from_iterable(lists_dict[mb_elem_idx])) - ) - - return lists +detection_collate_mbatches_fn = _detection_collate_mbatches_fn def collate_from_data_or_kwargs(data, kwargs): @@ -105,7 +74,7 @@ def __init__( each task separately. See pytorch :class:`DataLoader`. """ self.data = data - self.dataloaders: Dict[int, DataLoader] = {} + self.dataloaders: Dict[int, DataLoader] = dict() self.oversample_small_tasks = oversample_small_tasks self.collate_mbatches = collate_mbatches @@ -142,6 +111,7 @@ def __init__( oversample_small_groups: bool = False, collate_mbatches=_default_collate_mbatches_fn, batch_size: int = 32, + distributed_sampling: bool = True, **kwargs ): """Data loader that balances data from multiple datasets. @@ -166,9 +136,11 @@ def __init__( each group separately. See pytorch :class:`DataLoader`. 
""" self.datasets = datasets - self.dataloaders = [] + self.batch_sizes = [] self.oversample_small_groups = oversample_small_groups self.collate_mbatches = collate_mbatches + self.distributed_sampling = distributed_sampling + self.loader_kwargs = kwargs # check if batch_size is larger than or equal to the number of datasets assert batch_size >= len(datasets) @@ -177,49 +149,75 @@ def __init__( ds_batch_size = batch_size // len(datasets) remaining = batch_size % len(datasets) - for data in self.datasets: + for _ in self.datasets: bs = ds_batch_size if remaining > 0: bs += 1 remaining -= 1 - collate_from_data_or_kwargs(data, kwargs) - self.dataloaders.append(DataLoader( - data, batch_size=bs, **kwargs)) - self.max_len = max([len(d) for d in self.dataloaders]) + self.batch_sizes.append(bs) + + loaders_for_len_estimation = [ + _make_data_loader( + dataset, + distributed_sampling, + kwargs, + mb_size, + force_no_workers=True)[0] + for dataset, mb_size in zip(self.datasets, self.batch_sizes)] + + self.max_len = max([len(d) for d in loaders_for_len_estimation]) def __iter__(self): + dataloaders = [] + samplers = [] + for dataset, mb_size in zip(self.datasets, self.batch_sizes): + data_l, data_l_sampler = _make_data_loader( + dataset, + self.distributed_sampling, + self.loader_kwargs, + mb_size) + + dataloaders.append(data_l) + samplers.append(data_l_sampler) + iter_dataloaders = [] - for dl in self.dataloaders: + for dl in dataloaders: iter_dataloaders.append(iter(dl)) - max_num_mbatches = max([len(d) for d in iter_dataloaders]) + max_num_mbatches = max([len(d) for d in dataloaders]) for it in range(max_num_mbatches): mb_curr = [] - is_removed_dataloader = False + removed_dataloaders_idxs = [] # copy() is necessary because we may remove keys from the # dictionary. This would break the generator. - for tid, t_loader in enumerate(iter_dataloaders): + for tid, (t_loader, t_loader_sampler) in \ + enumerate(zip(iter_dataloaders, samplers)): try: batch = next(t_loader) except StopIteration: # StopIteration is thrown if dataset ends. if self.oversample_small_groups: # reinitialize data loader - iter_dataloaders[tid] = iter(self.dataloaders[tid]) + if isinstance(t_loader_sampler, DistributedSampler): + # Manage shuffling in DistributedSampler + t_loader_sampler.set_epoch(t_loader_sampler.epoch+1) + + iter_dataloaders[tid] = iter(dataloaders[tid]) batch = next(iter_dataloaders[tid]) else: # We iteratated over all the data from this group # and we don't need the iterator anymore. 
iter_dataloaders[tid] = None - is_removed_dataloader = True + samplers[tid] = None + removed_dataloaders_idxs.append(tid) continue mb_curr.append(batch) yield self.collate_mbatches(mb_curr) # clear empty data-loaders - if is_removed_dataloader: - while None in iter_dataloaders: - iter_dataloaders.remove(None) + for tid in reversed(removed_dataloaders_idxs): + del iter_dataloaders[tid] + del samplers[tid] def __len__(self): return self.max_len @@ -233,6 +231,7 @@ def __init__( self, datasets: Sequence[AvalancheDataset], collate_mbatches=_default_collate_mbatches_fn, + distributed_sampling: bool = True, **kwargs ): """Data loader that balances data from multiple datasets emitting an @@ -254,12 +253,23 @@ def __init__( self.collate_mbatches = collate_mbatches for data in self.datasets: + if _DistributedHelper.is_distributed and distributed_sampling: + seed = torch.randint( + 0, + 2 ** 32 - 1 - _DistributedHelper.world_size, + (1,), + dtype=torch.int64) + seed += _DistributedHelper.rank + generator = torch.Generator() + generator.manual_seed(int(seed)) + else: + generator = None # Default infinite_sampler = RandomSampler( - data, replacement=True, num_samples=10 ** 10 + data, replacement=True, num_samples=10 ** 10, + generator=generator ) collate_from_data_or_kwargs(data, kwargs) - dl = DataLoader(data, sampler=infinite_sampler, - **kwargs) + dl = DataLoader(data, sampler=infinite_sampler, **kwargs) self.dataloaders.append(dl) self.max_len = 10 ** 10 @@ -291,12 +301,13 @@ def __init__( batch_size: int = 32, batch_size_mem: int = 32, task_balanced_dataloader: bool = False, + distributed_sampling: bool = True, **kwargs ): - """Custom data loader for rehearsal strategies. + """ Custom data loader for rehearsal strategies. - The iterates in parallel two datasets, the current `data` and the - rehearsal `memory`, which are used to create mini-batches by + This dataloader iterates in parallel two datasets, the current `data` + and the rehearsal `memory`, which are used to create mini-batches by concatenating their data together. Mini-batches from both of them are balanced using the task label (i.e. each mini-batch contains a balanced number of examples from all the tasks in the `data` and `memory`). @@ -325,10 +336,13 @@ def __init__( self.data = data self.memory = memory - self.loader_data: Sequence[DataLoader] = {} - self.loader_memory: Sequence[DataLoader] = {} self.oversample_small_tasks = oversample_small_tasks + self.task_balanced_dataloader = task_balanced_dataloader self.collate_mbatches = collate_mbatches + self.data_batch_sizes: Union[int, Dict[int, int]] = dict() + self.memory_batch_sizes: Union[int, Dict[int, int]] = dict() + self.distributed_sampling = distributed_sampling + self.loader_kwargs = kwargs num_keys = len(self.memory.task_set) if task_balanced_dataloader: @@ -338,10 +352,8 @@ def __init__( "and current data." 
) - # Create dataloader for data items - self.loader_data, _ = self._create_dataloaders( - data, batch_size, 0, False, **kwargs - ) + self.data_batch_sizes, _ = self._get_batch_sizes( + data, batch_size, 0, False) # Create dataloader for memory items if task_balanced_dataloader: @@ -351,49 +363,85 @@ def __init__( single_group_batch_size = batch_size_mem remaining_example = 0 - self.loader_memory, remaining_example = self._create_dataloaders( - memory, - single_group_batch_size, - remaining_example, - task_balanced_dataloader, - **kwargs - ) + self.memory_batch_sizes, _ = self._get_batch_sizes( + memory, single_group_batch_size, remaining_example, + task_balanced_dataloader) - self.max_len = max( - [ - len(d) - for d in chain( - self.loader_data.values(), self.loader_memory.values() - ) - ] - ) + loaders_for_len_estimation = [] + + if isinstance(self.data_batch_sizes, int): + loaders_for_len_estimation.append(_make_data_loader( + data, distributed_sampling, kwargs, self.data_batch_sizes, + force_no_workers=True + )[0]) + else: + # Task balanced + for task_id in data.task_set: + dataset = data.task_set[task_id] + mb_sz = self.data_batch_sizes[task_id] + + loaders_for_len_estimation.append(_make_data_loader( + dataset, distributed_sampling, kwargs, mb_sz, + force_no_workers=True + )[0]) + + if isinstance(self.memory_batch_sizes, int): + loaders_for_len_estimation.append(_make_data_loader( + memory, distributed_sampling, kwargs, self.memory_batch_sizes, + force_no_workers=True + )[0]) + else: + for task_id in memory.task_set: + dataset = memory.task_set[task_id] + mb_sz = self.memory_batch_sizes[task_id] + + loaders_for_len_estimation.append(_make_data_loader( + dataset, distributed_sampling, kwargs, mb_sz, + force_no_workers=True + )[0]) + + self.max_len = max([len(d) for d in loaders_for_len_estimation]) def __iter__(self): + loader_data, sampler_data = self._create_loaders_and_samplers( + self.data, self.data_batch_sizes) + + loader_memory, sampler_memory = self._create_loaders_and_samplers( + self.memory, self.memory_batch_sizes) + iter_data_dataloaders = {} iter_buffer_dataloaders = {} - for t in self.loader_data.keys(): - iter_data_dataloaders[t] = iter(self.loader_data[t]) - for t in self.loader_memory.keys(): - iter_buffer_dataloaders[t] = iter(self.loader_memory[t]) + for t in loader_data.keys(): + iter_data_dataloaders[t] = iter(loader_data[t]) + for t in loader_memory.keys(): + iter_buffer_dataloaders[t] = iter(loader_memory[t]) - max_len = max([len(d) for d in iter_data_dataloaders.values()]) + max_len = max( + [ + len(d) + for d in chain( + loader_data.values(), + loader_memory.values(), + ) + ] + ) try: for it in range(max_len): mb_curr = [] - self._get_mini_batch_from_data_dict( - self.data, + ReplayDataLoader._get_mini_batch_from_data_dict( iter_data_dataloaders, - self.loader_data, - False, + sampler_data, + loader_data, + self.oversample_small_tasks, mb_curr, ) - self._get_mini_batch_from_data_dict( - self.memory, + ReplayDataLoader._get_mini_batch_from_data_dict( iter_buffer_dataloaders, - self.loader_memory, + sampler_memory, + loader_memory, self.oversample_small_tasks, mb_curr, ) @@ -405,10 +453,10 @@ def __iter__(self): def __len__(self): return self.max_len + @staticmethod def _get_mini_batch_from_data_dict( - self, - data, iter_dataloaders, + iter_samplers, loaders_dict, oversample_small_tasks, mb_curr, @@ -417,6 +465,7 @@ def _get_mini_batch_from_data_dict( # dictionary. This would break the generator. 
for t in list(iter_dataloaders.keys()): t_loader = iter_dataloaders[t] + t_sampler = iter_samplers[t] try: tbatch = next(t_loader) except StopIteration: @@ -424,42 +473,93 @@ def _get_mini_batch_from_data_dict( # reinitialize data loader if oversample_small_tasks: # reinitialize data loader + if isinstance(t_sampler, DistributedSampler): + # Manage shuffling in DistributedSampler + t_sampler.set_epoch(t_sampler.epoch + 1) + iter_dataloaders[t] = iter(loaders_dict[t]) tbatch = next(iter_dataloaders[t]) else: del iter_dataloaders[t] + del iter_samplers[t] continue mb_curr.append(tbatch) - def _create_dataloaders( - self, - data_dict, - single_exp_batch_size, - remaining_example, - task_balanced_dataloader, - **kwargs - ): - loaders_dict: Dict[int, DataLoader] = {} + def _create_loaders_and_samplers(self, data, batch_sizes): + loaders = dict() + samplers = dict() + + if isinstance(batch_sizes, int): + loader, sampler = _make_data_loader( + data, self.distributed_sampling, self.loader_kwargs, + batch_sizes, + ) + loaders[0] = loader + samplers[0] = sampler + else: + for task_id in data.task_set: + dataset = data.task_set[task_id] + mb_sz = batch_sizes[task_id] + + loader, sampler = _make_data_loader( + dataset, self.distributed_sampling, + self.loader_kwargs, mb_sz) + + loaders[task_id] = loader + samplers[task_id] = sampler + return loaders, samplers + + @staticmethod + def _get_batch_sizes(data_dict, single_exp_batch_size, remaining_example, + task_balanced_dataloader): + batch_sizes = dict() if task_balanced_dataloader: for task_id in data_dict.task_set: - data = data_dict.task_set[task_id] current_batch_size = single_exp_batch_size if remaining_example > 0: current_batch_size += 1 remaining_example -= 1 - collate_from_data_or_kwargs(data, kwargs) - loaders_dict[task_id] = DataLoader( - data, batch_size=current_batch_size, - **kwargs - ) + batch_sizes[task_id] = current_batch_size else: - collate_from_data_or_kwargs(data_dict, kwargs) - loaders_dict[0] = DataLoader( - data_dict, batch_size=single_exp_batch_size, - **kwargs - ) + # Current data is loaded without task balancing + batch_sizes = single_exp_batch_size + return batch_sizes, remaining_example + + +def _make_data_loader( + dataset, distributed_sampling, data_loader_args, + batch_size, force_no_workers=False): + data_loader_args = data_loader_args.copy() + + collate_from_data_or_kwargs(dataset, data_loader_args) + + if force_no_workers: + data_loader_args['num_workers'] = 0 + + if _DistributedHelper.is_distributed and distributed_sampling: + sampler = DistributedSampler( + dataset, + shuffle=data_loader_args.pop('shuffle', False), + drop_last=data_loader_args.pop('drop_last', False) + ) + data_loader = DataLoader( + dataset, sampler=sampler, batch_size=batch_size, + **data_loader_args) + else: + sampler = None + data_loader = DataLoader( + dataset, batch_size=batch_size, **data_loader_args) + + return data_loader, sampler + + +class __DistributedHelperPlaceholder: + is_distributed = False + world_size = 1 + rank = 0 + - return loaders_dict, remaining_example +_DistributedHelper = __DistributedHelperPlaceholder() __all__ = [
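Usage sketch (not part of the patch): the snippet below illustrates how the collate helpers introduced in collate_functions.py compose, assuming the module is importable as avalanche.benchmarks.utils.collate_functions as shown in the diff above; the tensor shapes and the target dictionaries are illustrative only.

import torch

from avalanche.benchmarks.utils.collate_functions import (
    classification_collate_mbatches_fn,
    detection_collate_fn,
    detection_collate_mbatches_fn,
)

# Two classification mini-batches of (x, y, t) tensors are concatenated
# along dim 0, so the merged x has shape (4 + 2, 3) = (6, 3).
mb_a = [torch.zeros(4, 3), torch.zeros(4, dtype=torch.long),
        torch.zeros(4, dtype=torch.long)]
mb_b = [torch.ones(2, 3), torch.ones(2, dtype=torch.long),
        torch.ones(2, dtype=torch.long)]
x, y, t = classification_collate_mbatches_fn([mb_a, mb_b])
assert x.shape == (6, 3)

# Detection samples carry variable-sized targets, so single samples are
# zipped into per-field tuples instead of being concatenated...
samples = [
    (torch.rand(3, 32, 32), {"boxes": torch.rand(2, 4)}, 0),
    (torch.rand(3, 32, 32), {"boxes": torch.rand(5, 4)}, 0),
]
images, targets, task_labels = detection_collate_fn(samples)
assert len(images) == 2

# ...and multiple such detection batches are merged by chaining the
# per-field lists, yielding 2 + 2 = 4 images in the first field.
merged = detection_collate_mbatches_fn(
    [detection_collate_fn(samples), detection_collate_fn(samples)]
)
assert len(merged[0]) == 4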