Commit d032388

mauvilsa and Borda authored

LightningCLI instantiator receives values applied by instantiation links to set in hparams (#20777)

* Instantiator receives values applied by instantiation links to set in hparams (#20311).
* Add cleandir to test_lightning_cli_link_arguments
* fix install...

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Jirka B <j.borovec+github@gmail.com>

1 parent cce06ec commit d032388
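The change targets CLIs that use `link_arguments` with `apply_on="instantiate"`, where the linked value only becomes available once the datamodule has been created. A minimal sketch of the scenario (class names are illustrative; the pattern mirrors the tests added below):

```python
from lightning.pytorch.cli import LightningCLI


class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        # "data.num_classes" only exists after the datamodule is instantiated,
        # so the link must be applied on instantiation rather than at parse time
        parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")
```

Previously, values set through such links could be missing from what `save_hyperparameters` recorded, so `load_from_checkpoint` could not faithfully re-instantiate the model. With this commit, the instantiator receives the applied link values and merges them into the saved hparams.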

File tree

6 files changed (+117 / -12 lines)


.azure/gpu-tests-pytorch.yml

Lines changed: 0 additions & 1 deletion

```diff
@@ -117,7 +117,6 @@ jobs:
       set -e
       extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))")
       pip install -e ".[${extra}dev]" pytest-timeout -U --extra-index-url="${TORCH_URL}"
-      pip install setuptools==75.6.0 jsonargparse==4.35.0
     displayName: "Install package & dependencies"
 
   - bash: pip uninstall -y lightning
```

dockers/base-cuda/Dockerfile

Lines changed: 1 addition & 1 deletion

```diff
@@ -34,7 +34,7 @@ ENV \
     MAKEFLAGS="-j2"
 
 RUN \
-    apt-get update && apt-get install -y wget && \
+    apt-get update --fix-missing && apt-get install -y wget && \
     apt-get update -qq --fix-missing && \
     NCCL_VER=$(dpkg -s libnccl2 | grep '^Version:' | awk -F ' ' '{print $2}' | awk -F '-' '{print $1}' | grep -ve '^\s*$') && \
     CUDA_VERSION_MM=${CUDA_VERSION%.*} && \
```

requirements/pytorch/extra.txt

Lines changed: 1 addition & 1 deletion

```diff
@@ -5,7 +5,7 @@
 matplotlib>3.1, <3.10.0
 omegaconf >=2.2.3, <2.4.0
 hydra-core >=1.2.0, <1.4.0
-jsonargparse[signatures] >=4.28.0, <=4.40.0
+jsonargparse[signatures] >=4.39.0, <4.40.0
 rich >=12.3.0, <14.1.0
 tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
 bitsandbytes >=0.45.2,<0.45.3; platform_system != "Darwin"
```
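The jsonargparse floor moves up to 4.39.0, presumably because the `applied_instantiation_links` argument the new instantiator relies on is only passed by recent jsonargparse releases (an assumption; the commit itself only shows the pin change). A quick illustrative sanity check of an installed environment against the new pin:

```python
# verify the installed jsonargparse satisfies the tightened pin (illustrative check)
import jsonargparse
from packaging.version import Version

assert Version("4.39.0") <= Version(jsonargparse.__version__) < Version("4.40.0")
```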

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 1 deletion

```diff
@@ -29,7 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-- Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/issues/20692))
+- Fixed `save_hyperparameters` not working correctly with `LightningCLI` when there are parsing links applied on instantiation ([#20777](https://github.com/Lightning-AI/pytorch-lightning/pull/20777))
+
+
+- Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/pull/20692))
 
 
 - Fix: Synchronize SIGTERM Handling in DDP to Prevent Deadlocks ([#20825](https://github.com/Lightning-AI/pytorch-lightning/pull/20825))
```

src/lightning/pytorch/cli.py

Lines changed: 37 additions & 3 deletions

```diff
@@ -327,6 +327,7 @@ def __init__(
         args: ArgsType = None,
         run: bool = True,
         auto_configure_optimizers: bool = True,
+        load_from_checkpoint_support: bool = True,
     ) -> None:
         """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
         called / instantiated using a parsed configuration file and / or command line args.
@@ -367,6 +368,11 @@ def __init__(
                 ``dict`` or ``jsonargparse.Namespace``.
             run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer`
                 method. If set to ``False``, the trainer and model classes will be instantiated only.
+            auto_configure_optimizers: Whether to automatically add default optimizer and lr_scheduler arguments.
+            load_from_checkpoint_support: Whether ``save_hyperparameters`` should save the original parsed
+                hyperparameters (instead of what ``__init__`` receives), such that it is possible for
+                ``load_from_checkpoint`` to correctly instantiate classes even when using complex nesting and
+                dependency injection.
 
         """
         self.save_config_callback = save_config_callback
@@ -396,7 +402,8 @@ def __init__(
 
         self._set_seed()
 
-        self._add_instantiators()
+        if load_from_checkpoint_support:
+            self._add_instantiators()
         self.before_instantiate_classes()
         self.instantiate_classes()
         self.after_instantiate_classes()
@@ -544,11 +551,14 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
         else:
             self.config = parser.parse_args(args)
 
-    def _add_instantiators(self) -> None:
+    def _dump_config(self) -> None:
+        if hasattr(self, "config_dump"):
+            return
         self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False))
         if "subcommand" in self.config:
             self.config_dump = self.config_dump[self.config.subcommand]
 
+    def _add_instantiators(self) -> None:
         self.parser.add_instantiator(
             _InstantiatorFn(cli=self, key="model"),
             _get_module_type(self._model_class),
@@ -799,12 +809,27 @@ def _get_module_type(value: Union[Callable, type]) -> type:
     return value
 
 
+def _set_dict_nested(data: dict, key: str, value: Any) -> None:
+    keys = key.split(".")
+    for k in keys[:-1]:
+        assert k in data, f"Expected key {key} to be in data"
+        data = data[k]
+    data[keys[-1]] = value
+
+
 class _InstantiatorFn:
     def __init__(self, cli: LightningCLI, key: str) -> None:
         self.cli = cli
         self.key = key
 
-    def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
+    def __call__(
+        self,
+        class_type: type[ModuleType],
+        *args: Any,
+        applied_instantiation_links: dict,
+        **kwargs: Any,
+    ) -> ModuleType:
+        self.cli._dump_config()
         hparams = self.cli.config_dump.get(self.key, {})
         if "class_path" in hparams:
             # To make hparams backwards compatible, and so that it is the same irrespective of subclass_mode, the
@@ -815,6 +840,15 @@ def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
             **hparams.get("init_args", {}),
             **hparams.get("dict_kwargs", {}),
         }
+        # get instantiation link target values from kwargs
+        for key, value in applied_instantiation_links.items():
+            if not key.startswith(f"{self.key}."):
+                continue
+            key = key[len(f"{self.key}.") :]
+            if key.startswith("init_args."):
+                key = key[len("init_args.") :]
+            _set_dict_nested(hparams, key, value)
+
         with _given_hyperparameters_context(
             hparams=hparams,
             instantiator="lightning.pytorch.cli.instantiate_module",
```
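For reference, the new `_set_dict_nested` helper writes a value into an already-nested dict along a dotted key path; `_InstantiatorFn.__call__` calls it after stripping the `<key>.` and optional `init_args.` prefixes from each applied link. A small illustrative usage (the values are made up):

```python
hparams = {"optimizer": {"init_args": {"num_classes": None}}}

# walks "optimizer" -> "init_args", then sets the "num_classes" leaf;
# equivalent to: hparams["optimizer"]["init_args"]["num_classes"] = 5
_set_dict_nested(hparams, "optimizer.init_args.num_classes", 5)
assert hparams == {"optimizer": {"init_args": {"num_classes": 5}}}
```

Note that the helper asserts every intermediate key already exists, so it only overwrites leaves already present in the dumped config rather than creating new branches.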

tests/tests_pytorch/test_cli.py

Lines changed: 74 additions & 5 deletions

```diff
@@ -550,6 +550,7 @@ def __init__(self, activation: torch.nn.Module = None, transform: Optional[list[
 class BoringModelRequiredClasses(BoringModel):
     def __init__(self, num_classes: int, batch_size: int = 8):
         super().__init__()
+        self.save_hyperparameters()
         self.num_classes = num_classes
         self.batch_size = batch_size
 
@@ -561,35 +562,103 @@ def __init__(self, batch_size: int = 8):
         self.num_classes = 5  # only available after instantiation
 
 
-def test_lightning_cli_link_arguments():
+def test_lightning_cli_link_arguments(cleandir):
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.link_arguments("data.batch_size", "model.batch_size")
             parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")
 
-    cli_args = ["--data.batch_size=12"]
+    cli_args = ["--data.batch_size=12", "--trainer.max_epochs=1"]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)
 
     assert cli.model.batch_size == 12
     assert cli.model.num_classes == 5
 
-    class MyLightningCLI(LightningCLI):
+    cli.trainer.fit(cli.model)
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    hparams.pop("_instantiator")
+    assert hparams == {"batch_size": 12, "num_classes": 5}
+
+    class MyLightningCLI2(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.link_arguments("data.batch_size", "model.init_args.batch_size")
             parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")
 
-    cli_args[-1] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
+    cli_args[0] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
-        cli = MyLightningCLI(
+        cli = MyLightningCLI2(
             BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, subclass_mode_model=True, run=False
         )
 
     assert cli.model.batch_size == 8
     assert cli.model.num_classes == 5
 
+    cli.trainer.fit(cli.model)
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    hparams.pop("_instantiator")
+    assert hparams == {"batch_size": 8, "num_classes": 5}
+
+
+class CustomAdam(torch.optim.Adam):
+    def __init__(self, params, num_classes: Optional[int] = None, **kwargs):
+        super().__init__(params, **kwargs)
+
+
+class DeepLinkTargetModel(BoringModel):
+    def __init__(
+        self,
+        optimizer: OptimizerCallable = torch.optim.Adam,
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.optimizer = optimizer
+
+    def configure_optimizers(self):
+        optimizer = self.optimizer(self.parameters())
+        return {"optimizer": optimizer}
+
+
+def test_lightning_cli_link_arguments_subcommands_nested_target(cleandir):
+    class MyLightningCLI(LightningCLI):
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments(
+                "data.num_classes",
+                "model.init_args.optimizer.init_args.num_classes",
+                apply_on="instantiate",
+            )
+
+    cli_args = [
+        "fit",
+        "--data.batch_size=12",
+        "--trainer.max_epochs=1",
+        "--model=tests_pytorch.test_cli.DeepLinkTargetModel",
+        "--model.optimizer=tests_pytorch.test_cli.CustomAdam",
+    ]
+
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = MyLightningCLI(
+            DeepLinkTargetModel,
+            BoringDataModuleBatchSizeAndClasses,
+            subclass_mode_model=True,
+            auto_configure_optimizers=False,
+        )
+
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    assert hparams["optimizer"]["class_path"] == "tests_pytorch.test_cli.CustomAdam"
+    assert hparams["optimizer"]["init_args"]["num_classes"] == 5
+
 
 class EarlyExitTestModel(BoringModel):
     def on_fit_start(self):
```