src/lightning/pytorch/loops/optimization/automatic.py (8 changes: 5 additions & 3 deletions)

@@ -23,13 +23,14 @@
 from typing_extensions import override
 
 import lightning.pytorch as pl
+from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch.loops.loop import _Loop
 from lightning.pytorch.loops.optimization.closure import AbstractClosure, OutputResult
 from lightning.pytorch.loops.progress import _OptimizationProgress
 from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior
 from lightning.pytorch.trainer import call
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.rank_zero import WarningCache
+from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
 from lightning.pytorch.utilities.types import STEP_OUTPUT


@@ -320,10 +321,11 @@ def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
         self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
 
         if training_step_output is None and trainer.world_size > 1:
-            raise RuntimeError(
+            rank_zero_warn(
                 "Skipping the `training_step` by returning None in distributed training is not supported."
                 " It is recommended that you rewrite your training logic to avoid having to skip the step in the first"
-                " place."
+                " place.",
+                category=PossibleUserWarning,
             )
 
         return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)
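
For context on the behavioral change: returning None from training_step when world_size > 1 now emits a PossibleUserWarning on rank zero and training proceeds, instead of raising a RuntimeError. A minimal sketch (illustrative, not part of this diff) of how a user who skips steps intentionally could silence the new warning with Python's standard warnings machinery:

import warnings

from lightning.fabric.utilities.warnings import PossibleUserWarning

# Silence the new warning; the message pattern below assumes the exact
# wording introduced in this diff.
warnings.filterwarnings(
    "ignore",
    category=PossibleUserWarning,
    message="Skipping the `training_step`.*",
)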
tests/tests_pytorch/loops/optimization/test_automatic_loop.py (25 changes: 0 additions & 25 deletions)

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections.abc import Iterator, Mapping
-from contextlib import nullcontext
 from typing import Generic, TypeVar
 
 import pytest
@@ -84,27 +83,3 @@ def training_step(self, batch, batch_idx):

     with pytest.raises(MisconfigurationException, match=match):
         trainer.fit(model)
-
-
-@pytest.mark.parametrize("world_size", [1, 2])

Review comment (Member): convert this test to validate the warning

-def test_skip_training_step_not_allowed(world_size, tmp_path):
-    """Test that skipping the training_step in distributed training is not allowed."""
-
-    class TestModel(BoringModel):
-        def training_step(self, batch, batch_idx):
-            return None
-
-    model = TestModel()
-    trainer = Trainer(
-        default_root_dir=tmp_path,
-        max_steps=1,
-        barebones=True,
-    )
-    trainer.strategy.world_size = world_size  # mock world size without launching processes
-    error_context = (
-        pytest.raises(RuntimeError, match="Skipping the `training_step` .* is not supported")
-        if world_size > 1
-        else nullcontext()
-    )
-    with error_context:
-        trainer.fit(model)
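
Following the review suggestion above, a sketch of how the deleted test could be converted to validate the warning rather than the error. It assumes the warning fires in a single-process test run (rank_zero_warn warns on rank zero, which is the current process here), re-adds the nullcontext import removed above, and imports PossibleUserWarning; the test name is illustrative:

from contextlib import nullcontext

import pytest

from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


@pytest.mark.parametrize("world_size", [1, 2])
def test_skip_training_step_warns(world_size, tmp_path):
    """Skipping the training_step in distributed training should emit a PossibleUserWarning."""

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            return None  # skip the step

    model = TestModel()
    trainer = Trainer(default_root_dir=tmp_path, max_steps=1, barebones=True)
    trainer.strategy.world_size = world_size  # mock world size without launching processes
    warning_context = (
        pytest.warns(PossibleUserWarning, match="Skipping the `training_step` .* is not supported")
        if world_size > 1
        else nullcontext()
    )
    with warning_context:
        trainer.fit(model)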