
Commit 7f97963

defaults to torch dataloader when no batch_sampler is set
1 parent c0c519b commit 7f97963

File tree

1 file changed: +32 −22 lines changed


pytorch_forecasting/data/timeseries/_timeseries.py

Lines changed: 32 additions & 22 deletions
@@ -20,7 +20,7 @@
 import torch
 from torch.distributions import Beta
 from torch.nn.utils import rnn
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler
 from torch.utils.data.sampler import Sampler, SequentialSampler
 
 from pytorch_forecasting.data.encoders import (
@@ -2347,26 +2347,41 @@ def __item_tensor__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
             (target, weight),
         )
 
-    def __precompute__(self, batch_size, shuffle, drop_last):
+    def __precompute__(self, kwargs):
         """
         Precompute sample for model
 
         Args:
-            batch_size : int, optional, default=64
-                batch size for training model. Defaults to 64.
-            shuffle : bool
-                indicate whether to shuffle the data
-            drop_last : bool
-                indicate whether to drop last
+            **kwargs: additional arguments passed to ``DataLoader`` constructor
         """
-        sampler = TimeSynchronizedBatchSampler(
-            SequentialSampler(self),
-            batch_size=batch_size,
-            shuffle=shuffle,
-            drop_last=drop_last,
-        )
+        batch_sampler = kwargs["batch_sampler"]
+        if batch_sampler is None:
+            sampler = (
+                RandomSampler(self) if kwargs["shuffle"] else SequentialSampler(self)
+            )
+            batch_sampler = BatchSampler(
+                sampler=sampler,
+                batch_size=kwargs["batch_size"],
+                drop_last=kwargs["drop_last"],
+            )
+        else:
+            if isinstance(batch_sampler, str):
+                sampler = kwargs["batch_sampler"]
+                if sampler == "synchronized":
+                    batch_sampler = TimeSynchronizedBatchSampler(
+                        SequentialSampler(self),
+                        batch_size=kwargs["batch_size"],
+                        shuffle=kwargs["shuffle"],
+                        drop_last=kwargs["drop_last"],
+                    )
+                else:
+                    raise ValueError(
+                        f"batch_sampler '{batch_sampler}' is not recognized."
+                    )
+            else:
+                raise ValueError(f"batch_sampler '{batch_sampler}' is not recognized.")
 
-        for batch in sampler:
+        for batch in batch_sampler:
             batch_samples = []
 
             for idx in batch:
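
For context, a minimal sketch of the new fallback path: when no batch_sampler is given, batches are drawn with a plain torch BatchSampler wrapping either a RandomSampler (shuffled) or a SequentialSampler (ordered), rather than always forcing TimeSynchronizedBatchSampler. The ToyDataset below is hypothetical, purely to show how the three pieces compose; it is not part of this commit.

import torch
from torch.utils.data import BatchSampler, Dataset, RandomSampler, SequentialSampler

class ToyDataset(Dataset):
    """Hypothetical stand-in for the precomputed time-series dataset."""

    def __len__(self) -> int:
        return 10

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.tensor(idx)

ds = ToyDataset()
shuffle, batch_size, drop_last = True, 4, False

# Mirrors the fallback in __precompute__: shuffle decides the sampler type.
sampler = RandomSampler(ds) if shuffle else SequentialSampler(ds)
batch_sampler = BatchSampler(sampler=sampler, batch_size=batch_size, drop_last=drop_last)

for batch in batch_sampler:
    print(batch)  # lists of indices, e.g. [3, 7, 1, 9]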
@@ -2674,13 +2689,8 @@ def to_dataloader(
         kwargs = default_kwargs
 
         if self.precompute:
-            self.__precompute__(
-                batch_size=kwargs["batch_size"],
-                shuffle=kwargs["shuffle"],
-                drop_last=kwargs["drop_last"],
-            )
-            default_kwargs["collate_fn"] = self.__fast_collate_fn__()
-
+            kwargs["collate_fn"] = self.__fast_collate_fn__()
+            self.__precompute__(kwargs)
         if kwargs["batch_sampler"] is not None:
             sampler = kwargs["batch_sampler"]
             if isinstance(sampler, str):
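
From the caller's side, usage is unchanged; a hedged sketch of both paths (dataset construction is elided, and whether the precompute branch runs depends on how the dataset was configured):

# `dataset` is a TimeSeriesDataSet built elsewhere.
# Default: no batch_sampler, so precomputation now falls back to a plain
# torch BatchSampler instead of forcing TimeSynchronizedBatchSampler.
train_loader = dataset.to_dataloader(train=True, batch_size=64)

# Explicit time-synchronized batching is still available by name:
sync_loader = dataset.to_dataloader(
    train=True, batch_size=64, batch_sampler="synchronized"
)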
