From 4076721d11525163d5902f8c0043579364d5496b Mon Sep 17 00:00:00 2001 From: V-E-D Date: Sun, 20 Apr 2025 18:49:45 +0530 Subject: [PATCH 01/10] fix: overfit_batches uses same batch for train and val --- .../trainer/connectors/data_connector.py | 55 +++++++++++++++++-- tests/tests_pytorch/conftest.py | 6 ++ .../trainer/flags/test_overfit_batches.py | 38 +++++++++++++ 3 files changed, 95 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 3e5273085ed2b..3aef6fdd8a93e 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -244,19 +244,66 @@ def _get_distributed_sampler( def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None: + """Resolve overfit batches by ensuring the same batch is used for both training and validation.""" all_have_sequential_sampler = all( isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler") ) if all_have_sequential_sampler: return + rank_zero_warn( f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling." f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you." ) - updated = [ - _update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl - for dl in combined_loader.flattened - ] + + # Get the first batch from the training dataloader + first_batch = None + if mode == RunningStage.TRAINING: + for dl in combined_loader.flattened: + if hasattr(dl, "dataset"): + first_batch = next(iter(dl)) + break + + # Create new dataloaders with SequentialSampler + updated = [] + for dl in combined_loader.flattened: + if hasattr(dl, "dataset"): + if mode == RunningStage.VALIDATING and first_batch is not None: + # For validation, create a custom sampler that always returns the first batch + class SingleBatchSampler(Sampler): + def __init__(self, batch): + self.batch = batch + + def __iter__(self): + yield self.batch + + def __len__(self): + return 1 + + sampler = SingleBatchSampler(first_batch) + else: + sampler = SequentialSampler(dl.dataset) + + # Create a new dataloader with the new sampler + new_dl = DataLoader( + dataset=dl.dataset, + batch_size=dl.batch_size, + sampler=sampler, + num_workers=dl.num_workers, + collate_fn=dl.collate_fn, + pin_memory=dl.pin_memory, + drop_last=dl.drop_last, + timeout=dl.timeout, + worker_init_fn=dl.worker_init_fn, + multiprocessing_context=dl.multiprocessing_context, + generator=dl.generator, + prefetch_factor=dl.prefetch_factor, + persistent_workers=dl.persistent_workers, + ) + updated.append(new_dl) + else: + updated.append(dl) + combined_loader.flattened = updated diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index b02d9d089a354..fb5f4b04400e6 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -95,6 +95,12 @@ def restore_env_variables(): "TF_GRPC_DEFAULT_OPTIONS", "XLA_FLAGS", "TORCHINDUCTOR_CACHE_DIR", # leaked by torch.compile + # TensorFlow and TPU related variables + "TF2_BEHAVIOR", + "TPU_ML_PLATFORM", + "TPU_ML_PLATFORM_VERSION", + "LD_LIBRARY_PATH", + "ENABLE_RUNTIME_UPTIME_TELEMETRY", } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py 
b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py index 050818287ba45..652e3db5aca3f 100644 --- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py @@ -170,3 +170,41 @@ def test_distributed_sampler_with_overfit_batches(): train_sampler = trainer.train_dataloader.sampler assert isinstance(train_sampler, DistributedSampler) assert train_sampler.shuffle is False + + +def test_overfit_batches_same_batch_for_train_and_val(tmp_path): + """Test that when overfit_batches=1, the same batch is used for both training and validation.""" + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.train_batches = [] + self.val_batches = [] + + def training_step(self, batch, batch_idx): + self.train_batches.append(batch) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + self.val_batches.append(batch) + return super().validation_step(batch, batch_idx) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=2, + overfit_batches=1, + check_val_every_n_epoch=1, + enable_model_summary=False + ) + trainer.fit(model) + + # Verify that the same batch was used for both training and validation + assert len(model.train_batches) > 0 + assert len(model.val_batches) > 0 + + # Compare the actual batch contents + train_batch = model.train_batches[0] + val_batch = model.val_batches[0] + + # Check if the batches are identical + assert torch.equal(train_batch, val_batch), "Training and validation batches should be identical when overfit_batches=1" From a092f39c0e422caecbb03742c1d0bb7c0ca28e38 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 20 Apr 2025 13:23:01 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/trainer/connectors/data_connector.py | 10 +++++----- .../trainer/flags/test_overfit_batches.py | 11 +++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 3aef6fdd8a93e..26d7681db6757 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -273,17 +273,17 @@ def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage class SingleBatchSampler(Sampler): def __init__(self, batch): self.batch = batch - + def __iter__(self): yield self.batch - + def __len__(self): return 1 - + sampler = SingleBatchSampler(first_batch) else: sampler = SequentialSampler(dl.dataset) - + # Create a new dataloader with the new sampler new_dl = DataLoader( dataset=dl.dataset, @@ -303,7 +303,7 @@ def __len__(self): updated.append(new_dl) else: updated.append(dl) - + combined_loader.flattened = updated diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py index 652e3db5aca3f..6322698ef3b73 100644 --- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py @@ -174,6 +174,7 @@ def test_distributed_sampler_with_overfit_batches(): def test_overfit_batches_same_batch_for_train_and_val(tmp_path): """Test that when overfit_batches=1, the same batch is used for both training and validation.""" + class 
TestModel(BoringModel): def __init__(self): super().__init__() @@ -194,17 +195,19 @@ def validation_step(self, batch, batch_idx): max_epochs=2, overfit_batches=1, check_val_every_n_epoch=1, - enable_model_summary=False + enable_model_summary=False, ) trainer.fit(model) # Verify that the same batch was used for both training and validation assert len(model.train_batches) > 0 assert len(model.val_batches) > 0 - + # Compare the actual batch contents train_batch = model.train_batches[0] val_batch = model.val_batches[0] - + # Check if the batches are identical - assert torch.equal(train_batch, val_batch), "Training and validation batches should be identical when overfit_batches=1" + assert torch.equal(train_batch, val_batch), ( + "Training and validation batches should be identical when overfit_batches=1" + ) From 501d2482a92e1cd2bc64bf377a27f89363835667 Mon Sep 17 00:00:00 2001 From: V-E-D Date: Sun, 20 Apr 2025 18:58:03 +0530 Subject: [PATCH 03/10] pre-commit pass --- .../pytorch/trainer/connectors/data_connector.py | 10 +++++----- .../trainer/flags/test_overfit_batches.py | 11 +++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 3aef6fdd8a93e..26d7681db6757 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -273,17 +273,17 @@ def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage class SingleBatchSampler(Sampler): def __init__(self, batch): self.batch = batch - + def __iter__(self): yield self.batch - + def __len__(self): return 1 - + sampler = SingleBatchSampler(first_batch) else: sampler = SequentialSampler(dl.dataset) - + # Create a new dataloader with the new sampler new_dl = DataLoader( dataset=dl.dataset, @@ -303,7 +303,7 @@ def __len__(self): updated.append(new_dl) else: updated.append(dl) - + combined_loader.flattened = updated diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py index 652e3db5aca3f..6322698ef3b73 100644 --- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py @@ -174,6 +174,7 @@ def test_distributed_sampler_with_overfit_batches(): def test_overfit_batches_same_batch_for_train_and_val(tmp_path): """Test that when overfit_batches=1, the same batch is used for both training and validation.""" + class TestModel(BoringModel): def __init__(self): super().__init__() @@ -194,17 +195,19 @@ def validation_step(self, batch, batch_idx): max_epochs=2, overfit_batches=1, check_val_every_n_epoch=1, - enable_model_summary=False + enable_model_summary=False, ) trainer.fit(model) # Verify that the same batch was used for both training and validation assert len(model.train_batches) > 0 assert len(model.val_batches) > 0 - + # Compare the actual batch contents train_batch = model.train_batches[0] val_batch = model.val_batches[0] - + # Check if the batches are identical - assert torch.equal(train_batch, val_batch), "Training and validation batches should be identical when overfit_batches=1" + assert torch.equal(train_batch, val_batch), ( + "Training and validation batches should be identical when overfit_batches=1" + ) From 0dcab04e314654998e0e831fa7e8662268411764 Mon Sep 17 00:00:00 2001 From: Vedant <146507396+ved1beta@users.noreply.github.com> Date: Tue, 22 Apr 2025 15:39:37 +0530 
Subject: [PATCH 04/10] Update src/lightning/pytorch/trainer/connectors/data_connector.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/lightning/pytorch/trainer/connectors/data_connector.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 26d7681db6757..84c6aceb95c26 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -285,7 +285,7 @@ def __len__(self): sampler = SequentialSampler(dl.dataset) # Create a new dataloader with the new sampler - new_dl = DataLoader( + dl = DataLoader( dataset=dl.dataset, batch_size=dl.batch_size, sampler=sampler, @@ -300,9 +300,7 @@ def __len__(self): prefetch_factor=dl.prefetch_factor, persistent_workers=dl.persistent_workers, ) - updated.append(new_dl) - else: - updated.append(dl) + updated.append(dl) combined_loader.flattened = updated From 904010ccfa698b0e9f3927b285d279fab1931c8c Mon Sep 17 00:00:00 2001 From: V-E-D Date: Mon, 28 Apr 2025 20:05:16 +0530 Subject: [PATCH 05/10] docs changes foor better understanding --- docs/source-pytorch/common/trainer.rst | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index 339c59771001a..6848996406611 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -759,6 +759,9 @@ overfit_batches Uses this much data of the training & validation set. If the training & validation dataloaders have ``shuffle=True``, Lightning will automatically disable it. +* When set to exactly 1, the same batch is used for both training and validation steps, which is useful for debugging model implementation +* For other values, sequential sampling (no shuffling) is used + Useful for quickly debugging or trying to overfit on purpose. .. testcode:: @@ -768,9 +771,13 @@ Useful for quickly debugging or trying to overfit on purpose. # use only 1% of the train & val set trainer = Trainer(overfit_batches=0.01) - - # overfit on 10 of the same batches + + # overfit on 10 (same) train batches & 10 (same) val batches trainer = Trainer(overfit_batches=10) + + # debug by training and validating on exactly the same single batch + # (useful for verifying model implementation) + trainer = Trainer(overfit_batches=1) plugins ^^^^^^^ From 01ce1c130a52b6bcf7983fc7d7ca432ee26f421a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Apr 2025 14:36:10 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source-pytorch/common/trainer.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index 6848996406611..ad99026a50814 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -771,10 +771,10 @@ Useful for quickly debugging or trying to overfit on purpose. 
# use only 1% of the train & val set trainer = Trainer(overfit_batches=0.01) - + # overfit on 10 (same) train batches & 10 (same) val batches trainer = Trainer(overfit_batches=10) - + # debug by training and validating on exactly the same single batch # (useful for verifying model implementation) trainer = Trainer(overfit_batches=1) From 14b31ff3883806df5cc2f58e140917cb2c9bb985 Mon Sep 17 00:00:00 2001 From: V-E-D Date: Mon, 28 Apr 2025 21:44:35 +0530 Subject: [PATCH 07/10] requested changes , revert back too use different batches --- docs/source-pytorch/common/trainer.rst | 13 ++-- .../trainer/connectors/data_connector.py | 59 ++++--------------- 2 files changed, 17 insertions(+), 55 deletions(-) diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index 6848996406611..4c3a4358428f5 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -759,8 +759,8 @@ overfit_batches Uses this much data of the training & validation set. If the training & validation dataloaders have ``shuffle=True``, Lightning will automatically disable it. -* When set to exactly 1, the same batch is used for both training and validation steps, which is useful for debugging model implementation -* For other values, sequential sampling (no shuffling) is used +* When set to a value > 0, sequential sampling (no shuffling) is used +* Consistent batches are used for both training and validation across epochs, but training and validation use different sets of data Useful for quickly debugging or trying to overfit on purpose. @@ -772,11 +772,10 @@ Useful for quickly debugging or trying to overfit on purpose. # use only 1% of the train & val set trainer = Trainer(overfit_batches=0.01) - # overfit on 10 (same) train batches & 10 (same) val batches + # overfit on 10 consistent train batches & 10 consistent val batches trainer = Trainer(overfit_batches=10) - # debug by training and validating on exactly the same single batch - # (useful for verifying model implementation) + # debug using a single consistent train batch and a single consistent val batch trainer = Trainer(overfit_batches=1) plugins @@ -902,7 +901,7 @@ DataSource can be a ``LightningModule`` or a ``LightningDataModule``. # if 0 (default) train_loader = model.train_dataloader() - # or if using data module: datamodule.train_dataloader() + # or if using data module: datamodule.train_dataloaders() for epoch in epochs: for batch in train_loader: ... @@ -966,7 +965,7 @@ Additionally, you can pass a strategy object. See Also: - :ref:`Multi GPU Training `. - :doc:`Model Parallel GPU training guide <../advanced/model_parallel>`. - - :doc:`TPU training guide <../accelerators/tpu>`. + - :doc:`TPU training guide <../accelerators/tpu>`. sync_batchnorm diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 84c6aceb95c26..b7d48fa01a7d2 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -244,7 +244,12 @@ def _get_distributed_sampler( def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None: - """Resolve overfit batches by ensuring the same batch is used for both training and validation.""" + """Resolve overfit batches by disabling shuffling. + + When overfit_batches > 0, this function ensures that sequential sampling is used + without shuffling for consistent batches across epochs. 
Training and validation + use different sets of data. + """ all_have_sequential_sampler = all( isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler") ) @@ -255,53 +260,11 @@ def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling." f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you." ) - - # Get the first batch from the training dataloader - first_batch = None - if mode == RunningStage.TRAINING: - for dl in combined_loader.flattened: - if hasattr(dl, "dataset"): - first_batch = next(iter(dl)) - break - - # Create new dataloaders with SequentialSampler - updated = [] - for dl in combined_loader.flattened: - if hasattr(dl, "dataset"): - if mode == RunningStage.VALIDATING and first_batch is not None: - # For validation, create a custom sampler that always returns the first batch - class SingleBatchSampler(Sampler): - def __init__(self, batch): - self.batch = batch - - def __iter__(self): - yield self.batch - - def __len__(self): - return 1 - - sampler = SingleBatchSampler(first_batch) - else: - sampler = SequentialSampler(dl.dataset) - - # Create a new dataloader with the new sampler - dl = DataLoader( - dataset=dl.dataset, - batch_size=dl.batch_size, - sampler=sampler, - num_workers=dl.num_workers, - collate_fn=dl.collate_fn, - pin_memory=dl.pin_memory, - drop_last=dl.drop_last, - timeout=dl.timeout, - worker_init_fn=dl.worker_init_fn, - multiprocessing_context=dl.multiprocessing_context, - generator=dl.generator, - prefetch_factor=dl.prefetch_factor, - persistent_workers=dl.persistent_workers, - ) - updated.append(dl) - + + updated = [ + _update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl + for dl in combined_loader.flattened + ] combined_loader.flattened = updated From 7642ff2a461837d253fc0ddfc76a19fe84ef0afe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Apr 2025 16:17:12 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source-pytorch/common/trainer.rst | 6 +++--- .../pytorch/trainer/connectors/data_connector.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index 2ac872ba1b686..d4bf2eee89796 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -772,10 +772,10 @@ Useful for quickly debugging or trying to overfit on purpose. # use only 1% of the train & val set trainer = Trainer(overfit_batches=0.01) <<<<<<< HEAD - + # overfit on 10 consistent train batches & 10 consistent val batches trainer = Trainer(overfit_batches=10) - + # debug using a single consistent train batch and a single consistent val batch ======= @@ -974,7 +974,7 @@ Additionally, you can pass a strategy object. See Also: - :ref:`Multi GPU Training `. - :doc:`Model Parallel GPU training guide <../advanced/model_parallel>`. - - :doc:`TPU training guide <../accelerators/tpu>`. + - :doc:`TPU training guide <../accelerators/tpu>`. 
sync_batchnorm diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index b7d48fa01a7d2..841d78b457d48 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -245,10 +245,10 @@ def _get_distributed_sampler( def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None: """Resolve overfit batches by disabling shuffling. - - When overfit_batches > 0, this function ensures that sequential sampling is used - without shuffling for consistent batches across epochs. Training and validation - use different sets of data. + + When overfit_batches > 0, this function ensures that sequential sampling is used without shuffling for consistent + batches across epochs. Training and validation use different sets of data. + """ all_have_sequential_sampler = all( isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler") @@ -260,7 +260,7 @@ def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling." f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you." ) - + updated = [ _update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl for dl in combined_loader.flattened From e5d67d906abd6498cd069ad91830dfc91381c105 Mon Sep 17 00:00:00 2001 From: V-E-D Date: Mon, 28 Apr 2025 21:48:08 +0530 Subject: [PATCH 09/10] docs changes foor better understanding --- docs/source-pytorch/common/trainer.rst | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index 2ac872ba1b686..107134b1a3a67 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -771,24 +771,12 @@ Useful for quickly debugging or trying to overfit on purpose. # use only 1% of the train & val set trainer = Trainer(overfit_batches=0.01) -<<<<<<< HEAD # overfit on 10 consistent train batches & 10 consistent val batches trainer = Trainer(overfit_batches=10) # debug using a single consistent train batch and a single consistent val batch -======= - - # overfit on 10 (same) train batches & 10 (same) val batches - trainer = Trainer(overfit_batches=10) - # debug by training and validating on exactly the same single batch - # (useful for verifying model implementation) ->>>>>>> 01ce1c130a52b6bcf7983fc7d7ca432ee26f421a - trainer = Trainer(overfit_batches=1) - -plugins -^^^^^^^ :ref:`Plugins` allow you to connect arbitrary backends, precision libraries, clusters etc. For example: From d49283e04686f2ca5dbc73e948ba8442967ea276 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Apr 2025 16:19:42 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source-pytorch/common/trainer.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index ed01b4510719d..6a8a8135a1843 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -771,7 +771,7 @@ Useful for quickly debugging or trying to overfit on purpose. 
# use only 1% of the train & val set trainer = Trainer(overfit_batches=0.01) - + # overfit on 10 consistent train batches & 10 consistent val batches trainer = Trainer(overfit_batches=10)
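
As a quick sanity check of the behaviour the series converges on (PATCH 07/10: sequential sampling with shuffling disabled, consistent batches across epochs, training and validation drawn from different data), the following standalone sketch records the training batch across epochs and asserts it stays the same. It is not part of the patches above; the TrackingModel name and the exact Trainer flags are illustrative assumptions.

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class TrackingModel(BoringModel):
    def __init__(self):
        super().__init__()
        # Store a copy of every training batch so epochs can be compared.
        self.train_batches = []

    def training_step(self, batch, batch_idx):
        self.train_batches.append(batch.detach().clone())
        return super().training_step(batch, batch_idx)


if __name__ == "__main__":
    model = TrackingModel()
    # overfit_batches=1 limits training and validation to a single batch each;
    # _resolve_overfit_batches swaps any shuffling sampler for SequentialSampler.
    trainer = Trainer(
        max_epochs=2,
        overfit_batches=1,
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
    )
    trainer.fit(model)
    # Sequential sampling means epoch 0 and epoch 1 saw the identical train batch.
    assert torch.equal(model.train_batches[0], model.train_batches[-1])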