From 70351b809e39681c5c634f6f53e816f816d9f952 Mon Sep 17 00:00:00 2001 From: George Wang Date: Mon, 4 Aug 2025 09:31:08 -0700 Subject: [PATCH] Create Callback to Identify Empty Batches (#1020) Summary: # This Diff: This diff implements Step 1 of T232701473 by creating EmptyDataloaderDetectorCallback, a TNT callback that detects consecutive empty training epochs and implements a fail-fast strategy to surface dataloader issues early. # Callback Feature: The callback helps identify cases where dataloaders return empty batches, which can cause confusing downstream issues that manifest as red herrings (e.g., apparent checkpointing errors that are actually rapid step progression due to empty data). # Next Diff: Add to Mitra's default callbacks (Step 2 of T232701473), and will enable e2e test with Mitra Differential Revision: D79212756 --- .../test_empty_dataloader_detector.py | 210 ++++++++++++++++++ .../callbacks/empty_dataloader_detector.py | 56 +++++ 2 files changed, 266 insertions(+) create mode 100644 tests/framework/callbacks/test_empty_dataloader_detector.py create mode 100644 torchtnt/framework/callbacks/empty_dataloader_detector.py diff --git a/tests/framework/callbacks/test_empty_dataloader_detector.py b/tests/framework/callbacks/test_empty_dataloader_detector.py new file mode 100644 index 0000000000..5fdbc35bc9 --- /dev/null +++ b/tests/framework/callbacks/test_empty_dataloader_detector.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from unittest.mock import patch + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset + +from torchtnt.framework._test_utils import Batch, DummyTrainUnit, get_dummy_train_state +from torchtnt.framework.callbacks.empty_dataloader_detector import ( + EmptyDataloaderDetectorCallback, +) +from torchtnt.framework.state import State +from torchtnt.framework.train import train +from torchtnt.framework.unit import TrainUnit + + +class MockTrainUnit(DummyTrainUnit): + """Mock train unit for testing that extends DummyTrainUnit with step control functionality.""" + + def __init__(self) -> None: + super().__init__(input_dim=2) # Use a default input dimension + self._steps_completed_in_prev_epoch = 0 + + def set_steps_completed_in_prev_epoch(self, steps: int) -> None: + """Set the number of steps completed in the previous epoch.""" + self._steps_completed_in_prev_epoch = steps + self.train_progress._num_steps_completed_in_prev_epoch = steps + + +class EmptyDataloaderDetectorCallbackTest(unittest.TestCase): + def test_init_invalid_threshold(self) -> None: + """Test that invalid threshold values raise ValueError.""" + with self.assertRaisesRegex(ValueError, "threshold must be a positive integer"): + EmptyDataloaderDetectorCallback(threshold=0) + + with self.assertRaisesRegex(ValueError, "threshold must be a positive integer"): + EmptyDataloaderDetectorCallback(threshold=-1) + + def test_init_valid_threshold(self) -> None: + """Test that valid threshold values are accepted.""" + callback = EmptyDataloaderDetectorCallback(threshold=1) + self.assertEqual(callback._threshold, 1) + + callback = EmptyDataloaderDetectorCallback(threshold=5) + self.assertEqual(callback._threshold, 5) + + def test_train_empty_epoch_detection_with_exception(self) -> None: + """Test that consecutive empty train epochs trigger exception when threshold is reached.""" + callback = EmptyDataloaderDetectorCallback(threshold=2) + state = get_dummy_train_state() + unit = MockTrainUnit() + + # First empty epoch - should not raise + unit.set_steps_completed_in_prev_epoch(0) + callback.on_train_epoch_end(state, unit) + self.assertEqual(callback._consecutive_empty_train_epochs, 1) + + # Second empty epoch - should raise exception + unit.set_steps_completed_in_prev_epoch(0) + with self.assertRaisesRegex( + RuntimeError, + "Detected 2 consecutive empty train epochs, which exceeds the threshold of 2", + ): + callback.on_train_epoch_end(state, unit) + + def test_train_reset_counter_on_non_empty_epoch(self) -> None: + """Test that consecutive empty epoch counter resets when a non-empty epoch occurs.""" + callback = EmptyDataloaderDetectorCallback(threshold=3) + state = get_dummy_train_state() + unit = MockTrainUnit() + + # First empty epoch + unit.set_steps_completed_in_prev_epoch(0) + callback.on_train_epoch_end(state, unit) + self.assertEqual(callback._consecutive_empty_train_epochs, 1) + + # Second empty epoch + unit.set_steps_completed_in_prev_epoch(0) + callback.on_train_epoch_end(state, unit) + self.assertEqual(callback._consecutive_empty_train_epochs, 2) + + # Non-empty epoch - should reset counter + unit.set_steps_completed_in_prev_epoch(5) + callback.on_train_epoch_end(state, unit) + self.assertEqual(callback._consecutive_empty_train_epochs, 0) + + # Another empty epoch - counter should start from 1 again + unit.set_steps_completed_in_prev_epoch(0) + callback.on_train_epoch_end(state, unit) + self.assertEqual(callback._consecutive_empty_train_epochs, 1) + + def test_threshold_one(self) -> None: + """Test that threshold=1 triggers immediately on first empty epoch.""" + callback = EmptyDataloaderDetectorCallback(threshold=1) + state = get_dummy_train_state() + unit = MockTrainUnit() + + # First empty epoch should immediately trigger exception + unit.set_steps_completed_in_prev_epoch(0) + with self.assertRaisesRegex( + RuntimeError, + "Detected 1 consecutive empty train epochs, which exceeds the threshold of 1", + ): + callback.on_train_epoch_end(state, unit) + + def test_high_threshold(self) -> None: + """Test that high threshold values work correctly.""" + callback = EmptyDataloaderDetectorCallback(threshold=5) + state = get_dummy_train_state() + unit = MockTrainUnit() + + # Four empty epochs should not trigger + for i in range(4): + unit.set_steps_completed_in_prev_epoch(0) + callback.on_train_epoch_end(state, unit) + self.assertEqual(callback._consecutive_empty_train_epochs, i + 1) + + # Fifth empty epoch should trigger exception + unit.set_steps_completed_in_prev_epoch(0) + with self.assertRaisesRegex( + RuntimeError, + "Detected 5 consecutive empty train epochs, which exceeds the threshold of 5", + ): + callback.on_train_epoch_end(state, unit) + + def test_warning_logged_for_each_empty_epoch(self) -> None: + """Test that a warning is logged for each empty epoch.""" + callback = EmptyDataloaderDetectorCallback(threshold=3) + state = get_dummy_train_state() + unit = MockTrainUnit() + + with patch( + "torchtnt.framework.callbacks.empty_dataloader_detector.logger" + ) as mock_logger: + # First empty epoch + unit.set_steps_completed_in_prev_epoch(0) + callback.on_train_epoch_end(state, unit) + + # Second empty epoch + unit.set_steps_completed_in_prev_epoch(0) + callback.on_train_epoch_end(state, unit) + + # Verify warnings were logged for each empty epoch + self.assertEqual(mock_logger.warning.call_count, 2) + warning_calls = mock_logger.warning.call_args_list + self.assertTrue( + any("Empty train epoch detected" in str(call) for call in warning_calls) + ) + + def test_non_empty_epochs_do_not_trigger_warnings(self) -> None: + """Test that non-empty epochs do not trigger any warnings or exceptions.""" + callback = EmptyDataloaderDetectorCallback(threshold=2) + state = get_dummy_train_state() + unit = MockTrainUnit() + + with patch( + "torchtnt.framework.callbacks.empty_dataloader_detector.logger" + ) as mock_logger: + # Multiple non-empty epochs + for steps in [1, 5, 10, 100]: + unit.set_steps_completed_in_prev_epoch(steps) + callback.on_train_epoch_end(state, unit) + + # No warnings should be logged + mock_logger.warning.assert_not_called() + + # Counter should remain at 0 + self.assertEqual(callback._consecutive_empty_train_epochs, 0) + + def test_empty_dataloader_detection_with_real_training_loop(self) -> None: + """ + Test that simulates the real scenario from failed MAST job f762746046-pviolatingquery_cse. + Tests EmptyDataloaderDetectorCallback with actual training loop and empty dataloaders. + """ + + class EmptyDataset(Dataset[Batch]): + """Dataset that returns no data to simulate empty dataloader scenario.""" + + def __len__(self) -> int: + return 0 + + def __getitem__(self, idx: int) -> Batch: + raise IndexError("Empty dataset") + + callback_with_exception = EmptyDataloaderDetectorCallback(threshold=2) + + train_unit = DummyTrainUnit(input_dim=2) + empty_dataloader = DataLoader(EmptyDataset(), batch_size=1) + + # This should raise an exception after 2 empty epochs + with self.assertRaisesRegex( + RuntimeError, + "Detected 2 consecutive empty train epochs, which exceeds the threshold of 2", + ): + train( + train_unit, + empty_dataloader, + max_epochs=50, # Try to run 50 epochs but should fail at 2 + callbacks=[callback_with_exception], + ) + + self.assertEqual(callback_with_exception._consecutive_empty_train_epochs, 2) diff --git a/torchtnt/framework/callbacks/empty_dataloader_detector.py b/torchtnt/framework/callbacks/empty_dataloader_detector.py new file mode 100644 index 0000000000..67351d4325 --- /dev/null +++ b/torchtnt/framework/callbacks/empty_dataloader_detector.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging + +from torchtnt.framework.callback import Callback +from torchtnt.framework.state import State +from torchtnt.framework.unit import TTrainUnit + +logger: logging.Logger = logging.getLogger(__name__) + + +class EmptyDataloaderDetectorCallback(Callback): + """ + A callback that detects consecutive empty epochs and raises an error when a threshold is reached. + + This callback helps identify issues where dataloaders return empty batches, which can cause confusing + downstream problems that are hard to debug. It implements a fail-fast strategy to surface these issues early. + """ + + def __init__( + self, + threshold: int = 2, + ) -> None: + if threshold <= 0: + raise ValueError("threshold must be a positive integer") + + self._threshold = threshold + self._consecutive_empty_train_epochs = 0 + + def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None: + num_steps = unit.train_progress.num_steps_completed_in_prev_epoch + epoch_num = unit.train_progress.num_epochs_completed + + if num_steps == 0: + self._consecutive_empty_train_epochs += 1 + logger.warning( + f"Empty train epoch detected! Epoch {epoch_num} completed 0 steps. " + f"Consecutive empty train epochs: {self._consecutive_empty_train_epochs}" + ) + + if self._consecutive_empty_train_epochs >= self._threshold: + error_msg = ( + f"Detected {self._consecutive_empty_train_epochs} consecutive empty train epochs, " + f"which exceeds the threshold of {self._threshold}. This indicates that the " + f"dataloader is returning empty batches, which could be due to an empty " + f"training table or infrastructure issues with the dataloader." + ) + raise RuntimeError(error_msg) + else: + self._consecutive_empty_train_epochs = 0