We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c08d4a5 commit 224f4c5Copy full SHA for 224f4c5
pytorch_forecasting/models/temporal_fusion_transformer/tuning.py
@@ -22,6 +22,11 @@
22
optuna_logger = logging.getLogger("optuna")
23
24
25
+# need to inherit from callback for this to work
26
+class PyTorchLightningPruningCallbackAdjusted(PyTorchLightningPruningCallback, pl.Callback):
27
+ pass
28
+
29
30
def optimize_hyperparameters(
31
train_dataloaders: DataLoader,
32
val_dataloaders: DataLoader,
0 commit comments