From e043f07f4ccc890f8d3753b96d397b63213967f3 Mon Sep 17 00:00:00 2001 From: Marco Garosi Date: Wed, 15 Jan 2025 09:39:30 +0100 Subject: [PATCH] add option to disable drop last on training --- dassl/config/defaults.py | 1 + dassl/data/data_manager.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/dassl/config/defaults.py b/dassl/config/defaults.py index cd873e9..2da1656 100644 --- a/dassl/config/defaults.py +++ b/dassl/config/defaults.py @@ -107,6 +107,7 @@ # Parameter of RandomClassSampler # Number of instances per class _C.DATALOADER.TRAIN_X.N_INS = 16 +_C.DATALOADER.TRAIN_X.DISABLE_DROP_LAST = False # Setting for the train_u data-loader _C.DATALOADER.TRAIN_U = CN() diff --git a/dassl/data/data_manager.py b/dassl/data/data_manager.py index c0a4b42..fca782c 100644 --- a/dassl/data/data_manager.py +++ b/dassl/data/data_manager.py @@ -40,7 +40,7 @@ def build_data_loader( batch_size=batch_size, sampler=sampler, num_workers=cfg.DATALOADER.NUM_WORKERS, - drop_last=is_train and len(data_source) >= batch_size, + drop_last=is_train and len(data_source) >= batch_size and not cfg.DATALOADER.TRAIN_X.DISABLE_DROP_LAST, pin_memory=(torch.cuda.is_available() and cfg.USE_CUDA) ) assert len(data_loader) > 0