"""
- Hyperparameters can be efficiently tuned with `optuna <https://optuna.readthedocs.io/>`_.
+ Module for hyperparameter optimization.
+
+ Hyperparameters can be efficiently tuned with `optuna <https://optuna.readthedocs.io/>`.
"""


__all__ = ["optimize_hyperparameters"]
@@ -87,8 +89,9 @@ def optimize_hyperparameters(
**kwargs: Any,
) -> optuna.Study:
"""
- Optimize hyperparameters. Run hyperparameter optimization. Learning rate for is determined with the
- PyTorch Lightning learning rate finder.
+ Optimize hyperparameters.
+
+ Run hyperparameter optimization. The learning rate is determined with the PyTorch Lightning learning rate finder.

Args:
train_dataloaders (DataLoader):
@@ -98,65 +101,68 @@ def optimize_hyperparameters(
model_path (str):
Folder to which model checkpoints are saved.
monitor (str):
- Metric to return. The hyper-parameter (HP) tuner trains a model for a certain HP config, and
- reads this metric to score configuration. By default, the lower the better.
+ Metric to return. The hyper-parameter (HP) tuner trains a model for a certain HP config,
+ and reads this metric to score the configuration. By default, the lower the better.
direction (str):
- By default, direction is "minimize", meaning that lower values of the specified `monitor` are
- better. You can change this, e.g. to "maximize".
+ By default, direction is "minimize", meaning that lower values of the specified
+ ``monitor`` are better. You can change this, e.g. to "maximize".
max_epochs (int, optional):
Maximum number of epochs to run training. Defaults to 20.
n_trials (int, optional):
Number of hyperparameter trials to run. Defaults to 100.
timeout (float, optional):
- Time in seconds after which training is stopped regardless of number of epochs or validation
- metric. Defaults to 3600*8.0.
+ Time in seconds after which training is stopped regardless of number of epochs or
+ validation metric. Defaults to 3600*8.0.
input_params (dict, optional):
- A dictionary, where each `key` contains another dictionary with two keys: `"method"` and
- `"ranges"`. Example:
- >>> {"hidden_size": {
- >>>     "method": "suggest_int",
- >>>     "ranges": (16, 265),
- >>> }}
- The method key has to be a method of the `optuna.Trial` object. The ranges key are the input
- ranges for the specified method.
+ A dictionary, where each ``key`` contains another dictionary with two keys: ``"method"``
+ and ``"ranges"``. Example:
+ >>> {"hidden_size": {
+     "method": "suggest_int",
+     "ranges": (16, 265),
+ }}
+ The method key has to be a method of the ``optuna.Trial`` object.
+ The ranges key holds the input ranges for the specified method.
input_params_generator (Callable, optional):
- A function with the following signature: `fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]
- `, returning the parameter values to set up your model for the current trial/run.
+ A function with the following signature:
+ ``fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any]``,
+ returning the parameter values to set up your model for the current trial/run.
Example:
- >>> def fn(trial: optuna.Trial, param_ranges: Tuple[int, int] = (16, 265)) -> Dict[str, Any]:
- >>>     param = trial.suggest_int("param", *param_ranges, log=True)
- >>>     model_params = {"param": param}
- >>>     return model_params
- Then, when your model is created (before training it and report the metrics for the current
- combination of hyperparameters), these dictionary is used as follows:
- >>> model = YourModelClass.from_dataset(
- >>>     train_dataloaders.dataset,
- >>>     log_interval=-1,
- >>>     **model_params,
- >>> )
+ >>> def fn(trial, param_ranges=(16, 265)) -> Dict[str, Any]:
+     param = trial.suggest_int("param", *param_ranges, log=True)
+     model_params = {"param": param}
+     return model_params
+ Then, when your model is created (before training it and reporting the metrics for
+ the current combination of hyperparameters), this dictionary is used as follows:
+ >>> model = YourModelClass.from_dataset(
+     train_dataloaders.dataset,
+     log_interval=-1,
+     **model_params,
+ )
generator_params (dict, optional):
- The additional parameters to be passed to the `input_params_generator` function, if required.
+ The additional parameters to be passed to the ``input_params_generator`` function,
+ if required.
learning_rate_range (Tuple[float, float], optional):
Learning rate range. Defaults to (1e-5, 1.0).
use_learning_rate_finder (bool):
- If to use learning rate finder or optimize as part of hyperparameters. Defaults to True.
+ Whether to use the learning rate finder or to optimize the learning rate as part of
+ the hyperparameters. Defaults to True.
trainer_kwargs (Dict[str, Any], optional):
Additional arguments to the
- `PyTorch Lightning trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>`
- such as `limit_train_batches`. Defaults to {}.
+ PyTorch Lightning trainer such as ``limit_train_batches``.
+ Defaults to {}.
log_dir (str, optional):
Folder into which to log results for tensorboard. Defaults to "lightning_logs".
study (optuna.Study, optional):
Study to resume. Will create new study by default.
verbose (Union[int, bool]):
Level of verbosity.
- * None: no change in verbosity level (equivalent to verbose=1 by optuna-set default).
+ * None: no change in verbosity level (equivalent to verbose=1).
* 0 or False: log only warnings.
* 1 or True: log pruning events.
* 2: optuna logging level at debug level.
Defaults to None.
pruner (optuna.pruners.BasePruner, optional):
- The optuna pruner to use. Defaults to `optuna.pruners.SuccessiveHalvingPruner()`.
+ The optuna pruner to use. Defaults to ``optuna.pruners.SuccessiveHalvingPruner()``.
**kwargs:
Additional arguments for your model's class.
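For orientation (this is editorial commentary, not part of the commit above), the following is a minimal, self-contained sketch of how an ``input_params`` spec of the shape documented in the docstring maps onto ``optuna.Trial`` suggest calls. The ``dropout`` entry and the dummy objective are made up for illustration; the real tuner builds and trains a model and returns the monitored validation metric instead.

import optuna

# Hypothetical spec in the ``input_params`` format described in the docstring above.
input_params = {
    "hidden_size": {"method": "suggest_int", "ranges": (16, 265)},
    "dropout": {"method": "suggest_float", "ranges": (0.1, 0.3)},
}


def objective(trial: optuna.Trial) -> float:
    # Each entry resolves to a call on the ``optuna.Trial`` object,
    # e.g. trial.suggest_int("hidden_size", 16, 265).
    model_params = {
        name: getattr(trial, spec["method"])(name, *spec["ranges"])
        for name, spec in input_params.items()
    }
    # Stand-in for "build the model, train it, return the monitored metric".
    return model_params["dropout"] * model_params["hidden_size"]


study = optuna.create_study(
    direction="minimize",
    pruner=optuna.pruners.SuccessiveHalvingPruner(),
)
study.optimize(objective, n_trials=5)
print(study.best_trial.params)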
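Similarly, a hedged sketch of the ``input_params_generator``/``generator_params`` route described above. The names ``suggest_model_params`` and ``hidden_range`` are hypothetical, and in the real tuner the returned dictionary would be forwarded to ``YourModelClass.from_dataset(...)`` rather than to a dummy objective.

from typing import Any, Dict, Tuple

import optuna


# Hypothetical generator following the documented signature
# fn(trial: optuna.Trial, **kwargs: Any) -> Dict[str, Any].
def suggest_model_params(
    trial: optuna.Trial, hidden_range: Tuple[int, int] = (16, 265)
) -> Dict[str, Any]:
    return {"hidden_size": trial.suggest_int("hidden_size", *hidden_range, log=True)}


# Extra keyword arguments that would be routed through ``generator_params``.
generator_params = {"hidden_range": (32, 128)}


def objective(trial: optuna.Trial) -> float:
    model_params = suggest_model_params(trial, **generator_params)
    # The tuner would pass ``model_params`` to ``YourModelClass.from_dataset(...)``;
    # a dummy metric keeps this sketch runnable on its own.
    return float(model_params["hidden_size"])


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=3)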