From c7fdcb404c32ac9d405d6fa6e4b27fbfb2b68596 Mon Sep 17 00:00:00 2001 From: YuanmingZhang <641994329@qq.com> Date: Mon, 21 Apr 2025 15:26:16 +0800 Subject: [PATCH] Update cli.py Make a minor change during automatic optimizer configuration - replace self.model.parameters() with an additional method - The behaviour is stay unchanged, but this change may allow users customizing their own CLI, with different parameters given to to optimizer. E.g. If one want to add weight decay to `weight` groups only (L2 regularization), he/she can usually use named_parameters to iterate over the model parameter to determine if one parameter should be put into the `weight`ed group, or remain unregulated. However, this seems impossible when using CLI for automatic configuration. Although one can still write the `configure_optimizers` him/herself, but I think making this minor change would give users a faster path to do such things without creating a bunch of codes. --- src/lightning/pytorch/cli.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 75a6347c95356..1e72394da4fc4 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -685,13 +685,17 @@ def get_automatic( f"`{self.__class__.__name__}.configure_optimizers`." ) - optimizer = instantiate_class(self.model.parameters(), optimizer_init) + optimizer = instantiate_class(self._get_model_parameters(), optimizer_init) lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler) update_wrapper(fn, self.configure_optimizers) # necessary for `is_overridden` # override the existing method self.model.configure_optimizers = MethodType(fn, self.model) + def _get_model_parameters(self): + assert hasattr(self,'model'), "model not instantiated yet. You have to instantiate a model object in order to access its parameters" + return self.model.parameters() + def _get(self, config: Namespace, key: str, default: Optional[Any] = None) -> Any: """Utility to get a config value which might be inside a subcommand.""" return config.get(str(self.subcommand), config).get(key, default)