@@ -20,7 +20,7 @@
 import torch
 from torch.distributions import Beta
 from torch.nn.utils import rnn
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler
 from torch.utils.data.sampler import Sampler, SequentialSampler
 
 from pytorch_forecasting.data.encoders import (
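The import change pulls in two stock `torch.utils.data` utilities used below: `RandomSampler` draws dataset indices in random order, and `BatchSampler` wraps an index sampler and yields lists of indices. A minimal sketch of how they compose (the toy dataset is illustrative, not part of the change):

    import torch
    from torch.utils.data import BatchSampler, RandomSampler

    data = torch.arange(10)  # toy stand-in for any dataset with __len__

    # shuffle=True -> RandomSampler; shuffle=False would use SequentialSampler
    for batch in BatchSampler(RandomSampler(data), batch_size=4, drop_last=False):
        print(batch)  # lists of dataset indices, e.g. [7, 2, 9, 1]; last batch shorter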
@@ -2347,26 +2347,41 @@ def __item_tensor__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
             (target, weight),
         )
 
-    def __precompute__(self, batch_size, shuffle, drop_last):
+    def __precompute__(self, kwargs):
         """
         Precompute sample for model
 
         Args:
-            batch_size : int, optional, default=64
-                batch size for training model. Defaults to 64.
-            shuffle : bool
-                indicate whether to shuffle the data
-            drop_last : bool
-                indicate whether to drop last
+            **kwargs: additional arguments passed to ``DataLoader`` constructor
         """
-        sampler = TimeSynchronizedBatchSampler(
-            SequentialSampler(self),
-            batch_size=batch_size,
-            shuffle=shuffle,
-            drop_last=drop_last,
-        )
+        batch_sampler = kwargs["batch_sampler"]
+        if batch_sampler is None:
+            sampler = (
+                RandomSampler(self) if kwargs["shuffle"] else SequentialSampler(self)
+            )
+            batch_sampler = BatchSampler(
+                sampler=sampler,
+                batch_size=kwargs["batch_size"],
+                drop_last=kwargs["drop_last"],
+            )
+        else:
+            if isinstance(batch_sampler, str):
+                sampler = kwargs["batch_sampler"]
+                if sampler == "synchronized":
+                    batch_sampler = TimeSynchronizedBatchSampler(
+                        SequentialSampler(self),
+                        batch_size=kwargs["batch_size"],
+                        shuffle=kwargs["shuffle"],
+                        drop_last=kwargs["drop_last"],
+                    )
+                else:
+                    raise ValueError(
+                        f"batch_sampler '{batch_sampler}' is not recognized."
+                    )
+            else:
+                raise ValueError(f"batch_sampler '{batch_sampler}' is not recognized.")
 
-        for batch in sampler:
+        for batch in batch_sampler:
             batch_samples = []
 
             for idx in batch:
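`__precompute__` now receives the whole `kwargs` dict that `to_dataloader` assembles rather than three separate arguments, and builds its own batch sampler from it: a plain `BatchSampler` (random or sequential order) by default, or a `TimeSynchronizedBatchSampler` when the string `"synchronized"` is passed. A hedged sketch of the keys it reads; the dict shown is illustrative, the real one comes from `to_dataloader`'s defaults:

    # Illustrative kwargs mirroring the keys read above; `dataset` is a
    # TimeSeriesDataSet. The real dict is assembled inside to_dataloader.
    kwargs = {
        "batch_sampler": None,  # or "synchronized"; anything else raises ValueError
        "shuffle": True,        # True -> RandomSampler, False -> SequentialSampler
        "batch_size": 64,
        "drop_last": False,
    }
    dataset.__precompute__(kwargs)  # precomputes samples in the chosen batch order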
@@ -2674,13 +2689,8 @@ def to_dataloader(
         kwargs = default_kwargs
 
         if self.precompute:
-            self.__precompute__(
-                batch_size=kwargs["batch_size"],
-                shuffle=kwargs["shuffle"],
-                drop_last=kwargs["drop_last"],
-            )
-            default_kwargs["collate_fn"] = self.__fast_collate_fn__()
-
+            kwargs["collate_fn"] = self.__fast_collate_fn__()
+            self.__precompute__(kwargs)
         if kwargs["batch_sampler"] is not None:
             sampler = kwargs["batch_sampler"]
             if isinstance(sampler, str):