From dcf5a9fd2cdbf1b52fe08b911f8f21c3d045490a Mon Sep 17 00:00:00 2001
From: Varchas Gopalaswamy
Date: Mon, 13 May 2024 19:28:55 -0400
Subject: [PATCH 1/4] added some additional ways to do learning rate tuning

---
 src/lightning/pytorch/callbacks/lr_finder.py |  6 +-
 src/lightning/pytorch/tuner/lr_finder.py     | 83 +++++++++++++++++---
 src/lightning/pytorch/tuner/tuning.py        |  2 +
 3 files changed, 78 insertions(+), 13 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/lr_finder.py b/src/lightning/pytorch/callbacks/lr_finder.py
index f667b5c501a10..437893519a582 100644
--- a/src/lightning/pytorch/callbacks/lr_finder.py
+++ b/src/lightning/pytorch/callbacks/lr_finder.py
@@ -18,7 +18,7 @@
 Finds optimal learning rate
 """
 
-from typing import Optional
+from typing import Optional, Literal
 
 from typing_extensions import override
 
@@ -92,6 +92,7 @@ def __init__(
         early_stop_threshold: Optional[float] = 4.0,
         update_attr: bool = True,
         attr_name: str = "",
+        opt_method: Literal["gradient", "slide", "valley"] = "gradient",
     ) -> None:
         mode = mode.lower()
         if mode not in self.SUPPORTED_MODES:
@@ -104,7 +105,7 @@ def __init__(
         self._early_stop_threshold = early_stop_threshold
         self._update_attr = update_attr
         self._attr_name = attr_name
-
+        self._opt_method = opt_method
         self._early_exit = False
         self.lr_finder: Optional[_LRFinder] = None
 
@@ -120,6 +121,7 @@ def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Non
             early_stop_threshold=self._early_stop_threshold,
             update_attr=self._update_attr,
             attr_name=self._attr_name,
+            opt_method=self._opt_method,
         )
 
         if self._early_exit:
diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index 8eebd3cd7f974..080583e706057 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -16,7 +16,7 @@
 import os
 import uuid
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Literal
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
@@ -29,6 +29,7 @@
 from lightning.pytorch.utilities.parsing import lightning_hasattr, lightning_setattr
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from lightning.pytorch.utilities.types import STEP_OUTPUT, LRSchedulerConfig
+import numpy as np
 
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
@@ -93,17 +94,30 @@ class _LRFinder:
 
     """
 
-    def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None:
+    def __init__(
+        self,
+        mode: str,
+        lr_min: float,
+        lr_max: float,
+        num_training: int,
+        opt_method: Literal["gradient", "slide", "valley"] = "gradient",
+        opt_parameters: Dict[str, float | int] = None,
+    ) -> None:
         assert mode in ("linear", "exponential"), "mode should be either `linear` or `exponential`"
 
         self.mode = mode
         self.lr_min = lr_min
         self.lr_max = lr_max
         self.num_training = num_training
-
+        self.opt_method = opt_method
         self.results: Dict[str, Any] = {}
         self._total_batch_idx = 0  # for debug purpose
 
+        assert self.opt_method in ["gradient", "slide", "valley"]
+        self.opt_parameters = opt_parameters
+        if self.opt_parameters is None:
+            self.opt_parameters = {}
+
     def _exchange_scheduler(self, trainer: "pl.Trainer") -> None:
         # TODO: update docs here
         """Decorate `trainer.strategy.setup_optimizers` method such that it sets the user's originally specified
@@ -167,6 +181,8 @@ def plot(self, suggest: bool = False, show: bool = False, ax: Optional["Axes"] =
         _ = self.suggestion()
         if self._optimal_idx:
             ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx], markersize=10, marker="o", color="red")
+        elif self._optimal_lr:
+            ax.axvline(self._optimal_lr, linestyle="--")
 
         if show:
             plt.show()
@@ -188,8 +204,10 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
         """
         losses = torch.tensor(self.results["loss"][skip_begin:-skip_end])
         losses = losses[torch.isfinite(losses)]
+        lrs = self.results["lr"][skip_begin:-skip_end]
 
-        if len(losses) < 2:
+        self._optimal_lr = None
+        if self.opt_method == "gradient" and len(losses) < 2:
             # computing np.gradient requires at least 2 points
             log.error(
                 "Failed to compute suggestion for learning rate because there are not enough points. Increase the loop"
@@ -198,13 +216,55 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
             self._optimal_idx = None
             return None
 
-        # TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
-        # incorrectly shifted by an offset
-        gradients = torch.gradient(losses)[0]  # Unpack the tuple
-        min_grad = torch.argmin(gradients).item()
+        if self.opt_method == "gradient":
+            # TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be
+            # incorrectly shifted by an offset
+            gradients = torch.gradient(losses)[0]  # Unpack the tuple
+            min_grad = torch.argmin(gradients).item()
+
+            self._optimal_idx = min_grad + skip_begin
+            opt_lr = self.results["lr"][self._optimal_idx]
+        elif self.opt_method == "slide":
+
+            # See https://forums.fast.ai/t/automated-learning-rate-suggester/44199 "slide" method
+            loss_t = self.opt_parameters.get("loss_threshold", 0.5)
+            lr_diff = self.opt_parameters.get("lr_diff", 15)
+            adjust_value = self.opt_parameters.get("adjust_value", 1.0)
+            r_idx = -1
+            l_idx = r_idx - lr_diff
+            gradients = torch.gradient(losses)[0]  # Unpack the tuple
+
+            while (l_idx >= -len(losses)) and (abs(gradients[r_idx] - gradients[l_idx]) > loss_t):
+                local_min_lr = lrs[l_idx]
+                r_idx -= 1
+                l_idx -= 1
+            opt_lr = local_min_lr * adjust_value
+        else:
+            # See https://forums.fast.ai/t/automated-learning-rate-suggester/44199 "valley" method
+            n = len(losses)
+            max_start = 0
+            max_end = 0
+
+            # finding the longest valley.
+            lds = [1] * n
+
+            for i in range(1, n):
+                for j in range(0, i):
+                    if losses[i] < losses[j] and lds[i] < lds[j] + 1:
+                        lds[i] = lds[j] + 1
+                    if lds[max_end] < lds[i]:
+                        max_end = i
+                        max_start = max_end - lds[max_end]
+
+            sections = (max_end - max_start) / 3
+            self._optimal_idx = (
+                max_start + int(sections) + int(sections / 2)
+            ) + skip_begin  # pick something midway, or 2/3rd of the way to be more aggressive
+            opt_lr = self.results["lr"][self._optimal_idx]
+
+        self._optimal_lr = opt_lr
 
-        self._optimal_idx = min_grad + skip_begin
-        return self.results["lr"][self._optimal_idx]
+        return opt_lr
 
 
 def _lr_find(
     trainer: "pl.Trainer",
     model: "pl.LightningModule",
     min_lr: float = 1e-8,
     max_lr: float = 1,
     num_training: int = 100,
     mode: str = "exponential",
     early_stop_threshold: Optional[float] = 4.0,
     update_attr: bool = False,
     attr_name: str = "",
+    opt_method: Literal["gradient", "slide", "valley"] = "gradient",
 ) -> Optional[_LRFinder]:
     """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
     picking a good starting learning rate.
@@ -266,7 +327,7 @@ def _lr_find(
         trainer.progress_bar_callback.disable()
 
     # Initialize lr finder object (stores results)
-    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
+    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training, opt_method=opt_method)
 
     # Configure optimizer and scheduler
     lr_finder._exchange_scheduler(trainer)
diff --git a/src/lightning/pytorch/tuner/tuning.py b/src/lightning/pytorch/tuner/tuning.py
index 8b9b423619bd2..8486d0a1b124f 100644
--- a/src/lightning/pytorch/tuner/tuning.py
+++ b/src/lightning/pytorch/tuner/tuning.py
@@ -120,6 +120,7 @@ def lr_find(
         early_stop_threshold: Optional[float] = 4.0,
         update_attr: bool = True,
         attr_name: str = "",
+        opt_method: Literal["gradient", "slide", "valley"] = "gradient",
     ) -> Optional["_LRFinder"]:
         """Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
         picking a good starting learning rate.
@@ -172,6 +173,7 @@ def lr_find(
             early_stop_threshold=early_stop_threshold,
             update_attr=update_attr,
             attr_name=attr_name,
+            opt_method=opt_method,
         )
 
         lr_finder_callback._early_exit = True

From b7f01260eed0f77616b2a8bc35c9f1b28e3f70bf Mon Sep 17 00:00:00 2001
From: Varchas Gopalaswamy
Date: Mon, 13 May 2024 19:41:36 -0400
Subject: [PATCH 2/4] added doc entry

---
 src/lightning/pytorch/callbacks/lr_finder.py | 1 +
 src/lightning/pytorch/tuner/lr_finder.py     | 4 +++-
 src/lightning/pytorch/tuner/tuning.py        | 2 +-
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/lr_finder.py b/src/lightning/pytorch/callbacks/lr_finder.py
index 437893519a582..89ef04b77c511 100644
--- a/src/lightning/pytorch/callbacks/lr_finder.py
+++ b/src/lightning/pytorch/callbacks/lr_finder.py
@@ -50,6 +50,7 @@ class LearningRateFinder(Callback):
         update_attr: Whether to update the learning rate attribute or not.
         attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
             automatically detected. Otherwise, set the name here.
+        opt_method: Chooses how the optimum learning rate is determined. It can be any of ``("gradient", "slide", "valley")``.
 
     Example::
 
diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index 080583e706057..c138efbb3ac6f 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -79,6 +79,8 @@ class _LRFinder:
 
     num_training: number of steps to take between lr_min and lr_max
 
+    opt_method: Chooses how the optimum learning rate is determined. It can be any of ``("gradient", "slide", "valley")``.
+
     Example::
         # Run lr finder
         lr_finder = trainer.lr_find(model)
@@ -299,7 +301,7 @@ def _lr_find(
         update_attr: Whether to update the learning rate attribute or not.
         attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
             automatically detected. Otherwise, set the name here.
-
+        opt_method: Chooses how the optimum learning rate is determined. It can be any of ``("gradient", "slide", "valley")``.
     """
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping learning rate finder since `fast_dev_run` is enabled.")
diff --git a/src/lightning/pytorch/tuner/tuning.py b/src/lightning/pytorch/tuner/tuning.py
index 8486d0a1b124f..437d65407b823 100644
--- a/src/lightning/pytorch/tuner/tuning.py
+++ b/src/lightning/pytorch/tuner/tuning.py
@@ -149,7 +149,7 @@ def lr_find(
             update_attr: Whether to update the learning rate attribute or not.
             attr_name: Name of the attribute which stores the learning rate.
                 The names 'learning_rate' or 'lr' get automatically detected. Otherwise, set the name here.
-
+            opt_method: Chooses how the optimum learning rate is determined. It can be any of ``("gradient", "slide", "valley")``.
         Raises:
             MisconfigurationException:
                 If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden,

From 42fcbb0582eea9e4e4ea26c073e686af6f224a7c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 13 May 2024 23:47:29 +0000
Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/callbacks/lr_finder.py | 2 +-
 src/lightning/pytorch/tuner/lr_finder.py     | 5 ++---
 src/lightning/pytorch/tuner/tuning.py        | 2 +-
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/lr_finder.py b/src/lightning/pytorch/callbacks/lr_finder.py
index 89ef04b77c511..39a9a97eb7b7c 100644
--- a/src/lightning/pytorch/callbacks/lr_finder.py
+++ b/src/lightning/pytorch/callbacks/lr_finder.py
@@ -18,7 +18,7 @@
 Finds optimal learning rate
 """
 
-from typing import Optional, Literal
+from typing import Literal, Optional
 
 from typing_extensions import override
 
diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index c138efbb3ac6f..4ccab837d5d96 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -16,7 +16,7 @@
 import os
 import uuid
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Literal
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
@@ -29,7 +29,6 @@
 from lightning.pytorch.utilities.parsing import lightning_hasattr, lightning_setattr
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 from lightning.pytorch.utilities.types import STEP_OUTPUT, LRSchedulerConfig
-import numpy as np
 
 # check if ipywidgets is installed before importing tqdm.auto
 # to ensure it won't fail and a progress bar is displayed
@@ -227,7 +226,6 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
             self._optimal_idx = min_grad + skip_begin
             opt_lr = self.results["lr"][self._optimal_idx]
         elif self.opt_method == "slide":
-
             # See https://forums.fast.ai/t/automated-learning-rate-suggester/44199 "slide" method
             loss_t = self.opt_parameters.get("loss_threshold", 0.5)
             lr_diff = self.opt_parameters.get("lr_diff", 15)
@@ -302,6 +300,7 @@ def _lr_find(
         attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
             automatically detected. Otherwise, set the name here.
         opt_method: Chooses how the optimum learning rate is determined. It can be any of ``("gradient", "slide", "valley")``.
+
     """
     if trainer.fast_dev_run:
         rank_zero_warn("Skipping learning rate finder since `fast_dev_run` is enabled.")
diff --git a/src/lightning/pytorch/tuner/tuning.py b/src/lightning/pytorch/tuner/tuning.py
index 437d65407b823..9e13aab001716 100644
--- a/src/lightning/pytorch/tuner/tuning.py
+++ b/src/lightning/pytorch/tuner/tuning.py
@@ -149,7 +149,7 @@ def lr_find(
             update_attr: Whether to update the learning rate attribute or not.
             attr_name: Name of the attribute which stores the learning rate.
                 The names 'learning_rate' or 'lr' get automatically detected. Otherwise, set the name here.
-            opt_method: Chooses how the optimum learning rate is determined. It can be any of ``("gradient", "slide", "valley")``.
+            opt_method: Chooses how the optimum learning rate is determined. It can be any of ``("gradient", "slide", "valley")``.
         Raises:
             MisconfigurationException:
                 If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden,

From 3c2fc6de694905d0603d58e4b4c2710315229b04 Mon Sep 17 00:00:00 2001
From: Varchas Gopalaswamy
Date: Thu, 16 May 2024 13:05:37 -0400
Subject: [PATCH 4/4] added a constrained gradient option

---
 src/lightning/pytorch/tuner/lr_finder.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index 4ccab837d5d96..bd151bd7cdccb 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -101,7 +101,7 @@ def __init__(
         lr_min: float,
         lr_max: float,
         num_training: int,
-        opt_method: Literal["gradient", "slide", "valley"] = "gradient",
+        opt_method: Literal["gradient", "slide", "valley", "valley_grad"] = "gradient",
         opt_parameters: Dict[str, float | int] = None,
     ) -> None:
         assert mode in ("linear", "exponential"), "mode should be either `linear` or `exponential`"
@@ -114,7 +114,6 @@ def __init__(
         self.results: Dict[str, Any] = {}
         self._total_batch_idx = 0  # for debug purpose
 
-        assert self.opt_method in ["gradient", "slide", "valley"]
         self.opt_parameters = opt_parameters
         if self.opt_parameters is None:
             self.opt_parameters = {}
@@ -239,7 +238,7 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
                 r_idx -= 1
                 l_idx -= 1
             opt_lr = local_min_lr * adjust_value
-        else:
+        elif self.opt_method in ["valley", "valley_grad"]:
             # See https://forums.fast.ai/t/automated-learning-rate-suggester/44199 "valley" method
             n = len(losses)
             max_start = 0
@@ -257,9 +256,17 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float]
                     max_start = max_end - lds[max_end]
 
             sections = (max_end - max_start) / 3
-            self._optimal_idx = (
+            valley_lip_idx = (
                 max_start + int(sections) + int(sections / 2)
             ) + skip_begin  # pick something midway, or 2/3rd of the way to be more aggressive
+            if self.opt_method == "valley":
+                self._optimal_idx = valley_lip_idx
+            # Look for grad minimum inside the feasible region
+            else:
+                feasible_region = slice(valley_lip_idx, valley_lip_idx + losses[valley_lip_idx:].argmin())
+                gradients = torch.gradient(losses)[0]  # Unpack the tuple
+                self._optimal_idx = gradients[feasible_region].argmin() + valley_lip_idx
+
             opt_lr = self.results["lr"][self._optimal_idx]
 
         self._optimal_lr = opt_lr
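
Usage sketch (not part of the patches): the new ``opt_method`` argument is exercised through the existing ``Tuner.lr_find`` API or through the ``LearningRateFinder`` callback that this series extends. ``MyModel`` and its ``learning_rate`` hyperparameter are placeholders for any LightningModule that exposes ``lr``/``learning_rate``; everything else uses calls visible in the diffs or in the current Lightning API::

    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import LearningRateFinder
    from lightning.pytorch.tuner import Tuner

    model = MyModel(learning_rate=1e-3)  # placeholder LightningModule with an `lr`/`learning_rate` attribute

    # One-off range test through the tuner, picking the suggestion with the "valley" heuristic.
    trainer = pl.Trainer(max_epochs=10)
    tuner = Tuner(trainer)
    lr_finder = tuner.lr_find(model, opt_method="valley")
    if lr_finder is not None:  # lr_find returns None when skipped, e.g. under fast_dev_run
        print(lr_finder.suggestion())
        fig = lr_finder.plot(suggest=True)

    # Or re-run the search during training via the callback, here with the "slide" heuristic.
    trainer = pl.Trainer(callbacks=[LearningRateFinder(opt_method="slide")])
    trainer.fit(model)

Since ``opt_method`` defaults to ``"gradient"``, callers that do not pass the argument keep the previous suggestion behaviour.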
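
For reference, a standalone sketch of what the new "valley" branch of ``_LRFinder.suggestion`` computes, runnable outside Lightning on synthetic data. The helper name ``valley_suggestion`` and the synthetic loss curve are invented for illustration; the patched method additionally drops non-finite losses and stores ``_optimal_idx``/``_optimal_lr`` on the finder object::

    import torch

    def valley_suggestion(lrs, losses, skip_begin: int = 10, skip_end: int = 1) -> float:
        """Return the LR at a point partway down the longest descending "valley" of the loss curve."""
        losses = losses[skip_begin:-skip_end]  # the patched method also filters non-finite losses here
        n = len(losses)

        # lds[i]: length of the longest strictly-decreasing chain of losses ending at index i,
        # tracked exactly as in the patch; (max_start, max_end) bound the longest valley found.
        lds = [1] * n
        max_start, max_end = 0, 0
        for i in range(1, n):
            for j in range(i):
                if losses[i] < losses[j] and lds[i] < lds[j] + 1:
                    lds[i] = lds[j] + 1
                if lds[max_end] < lds[i]:
                    max_end = i
                    max_start = max_end - lds[max_end]

        # Pick a point midway into the valley (the patch comment notes 2/3 as a more aggressive choice).
        sections = (max_end - max_start) / 3
        idx = max_start + int(sections) + int(sections / 2)
        return lrs[skip_begin + idx]

    # Synthetic range test: a flat plateau, a clean descent, then divergence.
    lrs = torch.logspace(-6, 0, 100).tolist()
    losses = torch.cat([torch.full((40,), 1.0), torch.linspace(1.0, 0.2, 40), torch.linspace(0.2, 5.0, 20)])
    losses = losses + 0.01 * torch.randn(100)
    print(f"suggested lr: {valley_suggestion(lrs, losses):.2e}")

On this curve the longest strictly-decreasing chain sits on the descent, so the suggestion lands on a learning rate partway down it rather than at the loss minimum or the divergence point.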