diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 75a6347c95356..1e72394da4fc4 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -685,13 +685,17 @@ def get_automatic(
                 f"`{self.__class__.__name__}.configure_optimizers`."
             )
 
-        optimizer = instantiate_class(self.model.parameters(), optimizer_init)
+        optimizer = instantiate_class(self._get_model_parameters(), optimizer_init)
         lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None
         fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
         update_wrapper(fn, self.configure_optimizers)  # necessary for `is_overridden`
         # override the existing method
         self.model.configure_optimizers = MethodType(fn, self.model)
 
+    def _get_model_parameters(self):
+        assert hasattr(self, "model"), "The model is not instantiated yet; instantiate it before accessing its parameters."
+        return self.model.parameters()
+
     def _get(self, config: Namespace, key: str, default: Optional[Any] = None) -> Any:
         """Utility to get a config value which might be inside a subcommand."""
         return config.get(str(self.subcommand), config).get(key, default)
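
The diff replaces the hard-coded `self.model.parameters()` call with the new `_get_model_parameters()` hook, so subclasses can control which parameters are handed to the automatically configured optimizer. Below is a minimal sketch of how a subclass might use the hook; `MyCLI` and the frozen-backbone filtering are illustrative, not part of the diff:

```python
from lightning.pytorch.cli import LightningCLI


class MyCLI(LightningCLI):
    def _get_model_parameters(self):
        # Illustrative override: hand only trainable parameters to the
        # automatically configured optimizer (e.g. when a backbone is frozen).
        return [p for p in self.model.parameters() if p.requires_grad]
```

Because the hook's return value is passed straight to `instantiate_class` as the optimizer's first constructor argument, any iterable of parameters, or a list of parameter-group dicts, that the optimizer class accepts would work here.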