Commit 896c2a6

Error for unsupported precision types with ModelParallelStrategy (#19902)
1 parent: c09356d · commit: 896c2a6

File tree: 5 files changed, +46 −13 lines
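
For orientation, a minimal sketch of the user-facing effect of this commit, inferred from the connector change and the new tests below. The no-op parallelize function is a placeholder modeled on the `lambda x, _: x` stub in test_connector.py and is not part of the commit itself.

# Sketch only: assumes torch >= 2.3, as required by the @RunIf markers in the tests below.
from lightning.fabric import Fabric
from lightning.fabric.strategies.model_parallel import ModelParallelStrategy

# Placeholder parallelize function; a real one would shard the module over the provided device mesh.
strategy = ModelParallelStrategy(lambda module, device_mesh: module)

# Supported precisions are "32-true", "bf16-mixed", "bf16-true", and "16-true".
# Anything else, e.g. "16-mixed", now fails fast in the connector with a ValueError:
try:
    Fabric(strategy=strategy, precision="16-mixed")
except ValueError as err:
    print(err)  # The `ModelParallelStrategy` does not support `Fabric(..., precision='16-mixed')`. ...

The Trainer-side `_AcceleratorConnector` performs the same validation against its `_precision_flag`, as shown in the second file below.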

src/lightning/fabric/connector.py

Lines changed: 7 additions & 0 deletions
@@ -62,6 +62,7 @@
 )
 from lightning.fabric.strategies.ddp import _DDP_FORK_ALIASES
 from lightning.fabric.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy
+from lightning.fabric.strategies.model_parallel import ModelParallelStrategy
 from lightning.fabric.utilities import rank_zero_info, rank_zero_warn
 from lightning.fabric.utilities.device_parser import _determine_root_gpu_device
 from lightning.fabric.utilities.imports import _IS_INTERACTIVE
@@ -460,6 +461,12 @@ def _check_and_init_precision(self) -> Precision:
             return DeepSpeedPrecision(self._precision_input)  # type: ignore
         if isinstance(self.strategy, FSDPStrategy):
             return FSDPPrecision(precision=self._precision_input)  # type: ignore[arg-type]
+        mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
+        if isinstance(self.strategy, ModelParallelStrategy) and self._precision_input not in mp_precision_supported:
+            raise ValueError(
+                f"The `ModelParallelStrategy` does not support `Fabric(..., precision={self._precision_input!r})`."
+                f" Choose a different precision among: {', '.join(mp_precision_supported)}."
+            )
         if self._precision_input in ("16-true", "bf16-true"):
             return HalfPrecision(self._precision_input)  # type: ignore
         if self._precision_input == "32-true":

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 10 additions & 0 deletions
@@ -529,6 +529,16 @@ def _validate_precision_choice(self) -> None:
             self.accelerator, CUDAAccelerator
         ):
             raise RuntimeError("Bitsandbytes is only supported on CUDA GPUs.")
+        mp_precision_supported = ("32-true", "bf16-mixed", "bf16-true", "16-true")
+        if (
+            isinstance(self._strategy_flag, ModelParallelStrategy)
+            and self._precision_flag not in mp_precision_supported
+        ):
+            raise ValueError(
+                f"The `ModelParallelStrategy` does not support `Fabric(..., precision={self._precision_flag!r})`."
+                f" Choose a different precision among: {', '.join(mp_precision_supported)}."
+            )
+
         if _habana_available_and_importable():
             from lightning_habana import HPUAccelerator

tests/tests_fabric/strategies/test_model_parallel_integration.py

Lines changed: 2 additions & 13 deletions
@@ -241,9 +241,7 @@ def _train(fabric, model=None, optimizer=None):
 @pytest.mark.parametrize(
     "precision",
     [
-        pytest.param(
-            "16-mixed", marks=pytest.mark.xfail(reason="Precision plugin does not implement ShardedGradScaler yet")
-        ),
+        pytest.param("32-true"),
         pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
     ],
 )
@@ -548,26 +546,17 @@ def _parallelize_single_linear_tp_fsdp2(model, device_mesh):
     "precision",
     [
         "32-true",
-        pytest.param("16-mixed"),
         pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True)),
     ],
 )
 @pytest.mark.parametrize(
     "clip_type",
     [
         pytest.param("norm", marks=pytest.mark.skip("Gradient clipping by norm is not correct.")),
-        pytest.param(
-            "val",
-            marks=pytest.mark.xfail(
-                raises=RecursionError, strict=False, reason="Recursion error when clipping DTensor"
-            ),
-        ),
+        "val",
     ],
 )
 def test_clip_gradients(clip_type, precision):
-    if clip_type == "norm" and precision == "16-mixed":
-        pytest.skip(reason="Clipping by norm with 16-mixed is numerically unstable.")
-
     strategy = ModelParallelStrategy(_parallelize_single_linear_tp_fsdp2)
     fabric = Fabric(accelerator="auto", devices=2, precision=precision, strategy=strategy)
     fabric.launch()

tests/tests_fabric/test_connector.py

Lines changed: 14 additions & 0 deletions
@@ -14,6 +14,7 @@
 import inspect
 import os
 import sys
+from contextlib import nullcontext
 from typing import Any, Dict
 from unittest import mock
 from unittest.mock import Mock
@@ -53,6 +54,7 @@
     DDPStrategy,
     DeepSpeedStrategy,
     FSDPStrategy,
+    ModelParallelStrategy,
     SingleDeviceStrategy,
     SingleDeviceXLAStrategy,
     XLAFSDPStrategy,
@@ -866,6 +868,18 @@ def test_precision_selection_amp_ddp(strategy, devices, is_custom_plugin, plugin
     assert isinstance(connector.precision, plugin_cls)


+@RunIf(min_torch="2.3")
+@pytest.mark.parametrize(
+    ("precision", "raises"),
+    [("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],
+)
+@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
+def test_precision_selection_model_parallel(_, precision, raises):
+    error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
+    with error_context:
+        _Connector(precision=precision, strategy=ModelParallelStrategy(lambda x, _: x))
+
+
 def test_bitsandbytes_precision_cuda_required(monkeypatch):
     monkeypatch.setattr(lightning.fabric.plugins.precision.bitsandbytes, "_BITSANDBYTES_AVAILABLE", True)
     monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 13 additions & 0 deletions
@@ -14,6 +14,7 @@
 import inspect
 import os
 import sys
+from contextlib import nullcontext
 from typing import Any, Dict
 from unittest import mock
 from unittest.mock import Mock
@@ -48,6 +49,7 @@
     DDPStrategy,
     DeepSpeedStrategy,
     FSDPStrategy,
+    ModelParallelStrategy,
     SingleDeviceStrategy,
     SingleDeviceXLAStrategy,
     XLAStrategy,
@@ -1063,3 +1065,14 @@ def test_bitsandbytes_precision_cuda_required(monkeypatch):
     monkeypatch.setitem(sys.modules, "bitsandbytes", Mock())
     with pytest.raises(RuntimeError, match="Bitsandbytes is only supported on CUDA GPUs"):
         _AcceleratorConnector(accelerator="cpu", plugins=BitsandbytesPrecision(mode="int8"))
+
+
+@RunIf(min_torch="2.3")
+@pytest.mark.parametrize(
+    ("precision", "raises"),
+    [("32-true", False), ("16-true", False), ("bf16-true", False), ("16-mixed", True), ("bf16-mixed", False)],
+)
+def test_precision_selection_model_parallel(precision, raises, mps_count_0):
+    error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
+    with error_context:
+        _AcceleratorConnector(precision=precision, strategy=ModelParallelStrategy())
