Commit 37a0920

Fix for issue #21118: inconsistent behavior across callbacks (#21275)
* Replaced the mode="auto" logic in ModelCheckpoint with the generalized logic used in EarlyStopping.
* Added a unit test for the fixed use case #15
* Extracted _set_monitor_op() to MonitoredCallback class for reuse across all callbacks that need it
* Changed to MonitorCallback
* Changed to MonitorCallback
* Added MonitorCallback
* Added @pytest.mark.requires_trainable_backend
* Changed to MonitorCallback
* Removed exporting to public API
1 parent 22de2de commit 37a0920
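The gist of the change: with `mode="auto"`, EarlyStopping already inferred the metric direction from the metric object's `_direction` attribute, while ModelCheckpoint used an older name heuristic (`"acc"`/`"fmeasure"`), so the two callbacks could disagree about the same monitored quantity. After this commit both go through the shared MonitorCallback logic. A minimal sketch of the now-consistent usage (the model and data here are illustrative, not from the commit):

```python
import numpy as np
import keras

# Both callbacks monitor "val_auc" in auto mode; after this commit they
# agree that AUC is to be maximized (inferred from the metric itself).
model = keras.Sequential([keras.layers.Dense(1, activation="sigmoid")])
model.compile(
    loss="binary_crossentropy",
    optimizer="sgd",
    metrics=[keras.metrics.AUC()],
)

x = np.random.rand(64, 4).astype("float32")
y = np.random.randint(0, 2, size=(64, 1))

cbs = [
    keras.callbacks.ModelCheckpoint(
        "best.keras", monitor="val_auc", save_best_only=True, mode="auto"
    ),
    keras.callbacks.EarlyStopping(monitor="val_auc", patience=2, mode="auto"),
]
model.fit(x, y, validation_split=0.25, epochs=3, callbacks=cbs, verbose=0)
```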

7 files changed: +243 additions, −124 deletions

keras/src/callbacks/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@
 from keras.src.callbacks.lambda_callback import LambdaCallback
 from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
 from keras.src.callbacks.model_checkpoint import ModelCheckpoint
+from keras.src.callbacks.monitor_callback import MonitorCallback
 from keras.src.callbacks.progbar_logger import ProgbarLogger
 from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
 from keras.src.callbacks.remote_monitor import RemoteMonitor

keras/src/callbacks/early_stopping.py

Lines changed: 3 additions & 62 deletions

@@ -1,14 +1,12 @@
 import warnings
 
-from keras.src import ops
 from keras.src.api_export import keras_export
-from keras.src.callbacks.callback import Callback
-from keras.src.trainers import compile_utils
+from keras.src.callbacks.monitor_callback import MonitorCallback
 from keras.src.utils import io_utils
 
 
 @keras_export("keras.callbacks.EarlyStopping")
-class EarlyStopping(Callback):
+class EarlyStopping(MonitorCallback):
     """Stop training when a monitored metric has stopped improving.
 
     Assuming the goal of a training is to minimize the loss. With this, the

@@ -76,72 +74,20 @@ def __init__(
         restore_best_weights=False,
         start_from_epoch=0,
     ):
-        super().__init__()
-
-        self.monitor = monitor
+        super().__init__(monitor, mode, min_delta=min_delta)
         self.patience = patience
         self.verbose = verbose
         self.baseline = baseline
-        self.min_delta = abs(min_delta)
         self.wait = 0
         self.stopped_epoch = 0
         self.restore_best_weights = restore_best_weights
         self.best_weights = None
         self.start_from_epoch = start_from_epoch
 
-        if mode not in ["auto", "min", "max"]:
-            warnings.warn(
-                f"EarlyStopping mode {mode} is unknown, fallback to auto mode.",
-                stacklevel=2,
-            )
-            mode = "auto"
-        self.mode = mode
-        self.monitor_op = None
-
-    def _set_monitor_op(self):
-        if self.mode == "min":
-            self.monitor_op = ops.less
-        elif self.mode == "max":
-            self.monitor_op = ops.greater
-        else:
-            metric_name = self.monitor.removeprefix("val_")
-            if metric_name == "loss":
-                self.monitor_op = ops.less
-            if hasattr(self.model, "metrics"):
-                all_metrics = []
-                for m in self.model.metrics:
-                    if isinstance(
-                        m,
-                        (
-                            compile_utils.CompileMetrics,
-                            compile_utils.MetricsList,
-                        ),
-                    ):
-                        all_metrics.extend(m.metrics)
-                for m in all_metrics:
-                    if m.name == metric_name:
-                        if hasattr(m, "_direction"):
-                            if m._direction == "up":
-                                self.monitor_op = ops.greater
-                            else:
-                                self.monitor_op = ops.less
-        if self.monitor_op is None:
-            raise ValueError(
-                f"EarlyStopping callback received monitor={self.monitor} "
-                "but Keras isn't able to automatically determine whether "
-                "that metric should be maximized or minimized. "
-                "Pass `mode='max'` in order to do early stopping based "
-                "on the highest metric value, or pass `mode='min'` "
-                "in order to use the lowest value."
-            )
-        if self.monitor_op == ops.less:
-            self.min_delta *= -1
-
     def on_train_begin(self, logs=None):
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
-        self.best = None
         self.best_weights = None
         self.best_epoch = 0

@@ -206,8 +152,3 @@ def get_monitor_value(self, logs):
                 stacklevel=2,
             )
         return monitor_value
-
-    def _is_improvement(self, monitor_value, reference_value):
-        if reference_value is None:
-            return True
-        return self.monitor_op(monitor_value - self.min_delta, reference_value)

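The sign flip on `min_delta` (now living in MonitorCallback, shown in the new file below) is what lets a single comparison in `_is_improvement` handle both directions. A small sketch of the arithmetic, under the same convention:

```python
from keras import ops

# _is_improvement computes monitor_op(current - min_delta, best).
# When minimizing, min_delta is negated, so the test becomes
# less(current + min_delta, best): current must undercut best by min_delta.
min_delta = -0.1  # abs(0.1), negated because monitor_op is ops.less
best = 1.0
print(ops.less(0.85 - min_delta, best))  # True: 0.95 < 1.0, improvement
print(ops.less(0.95 - min_delta, best))  # False: 1.05 < 1.0 fails
```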
keras/src/callbacks/model_checkpoint.py

Lines changed: 10 additions & 35 deletions

@@ -6,13 +6,13 @@
 
 from keras.src import backend
 from keras.src.api_export import keras_export
-from keras.src.callbacks.callback import Callback
+from keras.src.callbacks.monitor_callback import MonitorCallback
 from keras.src.utils import file_utils
 from keras.src.utils import io_utils
 
 
 @keras_export("keras.callbacks.ModelCheckpoint")
-class ModelCheckpoint(Callback):
+class ModelCheckpoint(MonitorCallback):
     """Callback to save the Keras model or model weights at some frequency.
 
     `ModelCheckpoint` callback is used in conjunction with training using

@@ -105,9 +105,8 @@ class ModelCheckpoint(Callback):
         decision to overwrite the current save file is made based on either
         the maximization or the minimization of the monitored quantity.
         For `val_acc`, this should be `"max"`, for `val_loss` this should be
-        `"min"`, etc. In `"auto"` mode, the mode is set to `"max"` if the
-        quantities monitored are `"acc"` or start with `"fmeasure"` and are
-        set to `"min"` for the rest of the quantities.
+        `"min"`, etc. In `"auto"` mode, the direction is automatically
+        inferred from the name of the monitored quantity.
         save_weights_only: if `True`, then only the model's weights will be
             saved (`model.save_weights(filepath)`), else the full model is
             saved (`model.save(filepath)`).

@@ -136,42 +135,14 @@ def __init__(
         save_freq="epoch",
         initial_value_threshold=None,
     ):
-        super().__init__()
-        self.monitor = monitor
+        super().__init__(monitor, mode, initial_value_threshold)
         self.verbose = verbose
         self.filepath = file_utils.path_to_string(filepath)
         self.save_best_only = save_best_only
         self.save_weights_only = save_weights_only
         self.save_freq = save_freq
         self._batches_seen_since_last_saving = 0
         self._last_batch_seen = 0
-        self.best = initial_value_threshold
-
-        if mode not in ["auto", "min", "max"]:
-            warnings.warn(
-                f"ModelCheckpoint mode '{mode}' is unknown, "
-                "fallback to auto mode.",
-                stacklevel=2,
-            )
-            mode = "auto"
-
-        if mode == "min":
-            self.monitor_op = np.less
-            if self.best is None:
-                self.best = np.inf
-        elif mode == "max":
-            self.monitor_op = np.greater
-            if self.best is None:
-                self.best = -np.inf
-        else:
-            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
-                self.monitor_op = np.greater
-                if self.best is None:
-                    self.best = -np.inf
-            else:
-                self.monitor_op = np.less
-                if self.best is None:
-                    self.best = np.inf
 
         if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
             raise ValueError(

@@ -205,6 +176,10 @@ def on_epoch_begin(self, epoch, logs=None):
         self._current_epoch = epoch
 
     def on_epoch_end(self, epoch, logs=None):
+        if self.monitor_op is None:
+            # Delay setup until the model's metrics are all built
+            self._set_monitor_op()
+
         if self.save_freq == "epoch":
             self._save_model(epoch=epoch, batch=None, logs=logs)
 

@@ -262,7 +237,7 @@ def _should_save_model(self, epoch, batch, logs, filepath):
             )
             return True
         else:
-            if self.monitor_op(current, self.best):
+            if self._is_improvement(current, self.best):
                 if self.verbose > 0:
                     io_utils.print_msg(
                         f"\nEpoch {epoch + 1}: {self.monitor} "

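The deferred `_set_monitor_op()` call is the key behavioral change: in auto mode the direction can only be read off the compiled metric objects, which don't exist until the model is built, so resolution waits until the first `on_epoch_end`. The attribute it inspects is the metric's private `_direction` ("up" means higher is better), as this quick check (assuming the metric exposes it) illustrates:

```python
import keras

# MonitorCallback only trusts metrics that expose `_direction`;
# anything else in auto mode raises a ValueError asking for an
# explicit mode="min"/"max".
m = keras.metrics.AUC()
print(getattr(m, "_direction", "not exposed"))  # expected: "up"
```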
keras/src/callbacks/model_checkpoint_test.py

Lines changed: 32 additions & 1 deletion

@@ -164,7 +164,7 @@ def get_model():
         # Case 5: metric not available.
         cbks = [
             callbacks.ModelCheckpoint(
-                filepath, monitor="unknown", save_best_only=True
+                filepath, monitor="unknown", save_best_only=True, mode="min"
             )
         ]
         with pytest.warns(UserWarning):

@@ -453,6 +453,37 @@ def get_model():
         )
         self.assertFalse(os.path.exists(filepath))
 
+        # Case 15: ModelCheckpoint doesn't save model if auc was max earlier in
+        # auto mode
+        mode = "auto"
+        monitor = "val_auc"
+        initial_value_threshold = 1
+        save_best_only = True
+        cbks = [
+            callbacks.ModelCheckpoint(
+                filepath,
+                monitor=monitor,
+                save_best_only=save_best_only,
+                initial_value_threshold=initial_value_threshold,
+                mode=mode,
+            )
+        ]
+        model.compile(
+            loss="categorical_crossentropy",
+            optimizer="sgd",
+            metrics=[metrics.AUC()],
+        )
+        model.fit(
+            x_train,
+            y_train,
+            batch_size=BATCH_SIZE,
+            validation_data=(x_test, y_test),
+            callbacks=cbks,
+            epochs=1,
+            verbose=0,
+        )
+        self.assertFalse(os.path.exists(filepath))
+
     @pytest.mark.skipif(
         h5py is None,
         reason="`h5py` is a required dependency for `ModelCheckpoint` tests.",
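Why this test catches the old behavior: AUC is bounded above by 1, so with `initial_value_threshold=1` and a correctly inferred `"max"` direction nothing can improve and no checkpoint file is written. The removed heuristic would have classified `val_auc` as a min-mode metric (no `"acc"` substring, no `"fmeasure"` prefix) and saved on the first epoch. A sketch of that removed logic:

```python
# Sketch of the heuristic this commit removes from ModelCheckpoint:
monitor = "val_auc"
old_mode = (
    "max" if "acc" in monitor or monitor.startswith("fmeasure") else "min"
)
print(old_mode)  # "min" -- val_auc would wrongly be minimized, so any
# finite AUC "improves" on the threshold of 1 and a checkpoint gets
# written; the new test asserts that no file exists.
```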
keras/src/callbacks/monitor_callback.py

Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
+import warnings
+
+from keras.src import ops
+from keras.src.callbacks.callback import Callback
+from keras.src.trainers import compile_utils
+
+
+class MonitorCallback(Callback):
+    """Base class for callbacks that monitor a quantity and evaluate
+    improvements.
+
+    This class provides common functionality for callbacks that monitor a
+    metric during training to determine whether a condition has been met,
+    such as improvement over time. It encapsulates logic for selecting
+    the comparison operation based on a `monitor` value and `mode`, and
+    computing whether a new value is an improvement.
+
+    It is intended to be subclassed by other callbacks like `ModelCheckpoint`,
+    `EarlyStopping`, or `ReduceLROnPlateau`, and is not meant to be used
+    directly.
+
+    Arguments:
+        monitor: Quantity to be monitored. Defaults to `"val_loss"`.
+        mode: One of `{"auto", "min", "max"}`. In `"min"` mode, training will
+            aim to minimize the monitored quantity; in `"max"` mode it will
+            aim to maximize it; in `"auto"` mode, the direction is
+            automatically inferred from the name of the monitored quantity.
+            Defaults to `"auto"`.
+        baseline: Floating point initial "best" value of the metric to be
+            monitored. If `None` (default), the first monitored value will be
+            used.
+        min_delta: Minimum change in the monitored quantity to qualify as an
+            improvement, i.e. an absolute change of less than `min_delta` will
+            count as no improvement. Defaults to `0`.
+
+    Raises:
+        ValueError: If `mode="auto"` is selected and the direction of the
+            metric cannot be inferred.
+    """
+
+    def __init__(
+        self,
+        monitor="val_loss",
+        mode="auto",
+        baseline=None,
+        min_delta=0,
+    ):
+        super().__init__()
+        if mode not in ["auto", "min", "max"]:
+            warnings.warn(
+                f"{self.__class__.__name__} mode '{mode}' is unknown, fallback "
+                "to auto mode.",
+                stacklevel=2,
+            )
+            mode = "auto"
+        self.monitor = monitor
+        self.mode = mode
+        self.best = baseline
+        self.min_delta = abs(min_delta)
+        self.monitor_op = None
+
+    def _set_monitor_op(self):
+        if self.mode == "min":
+            self.monitor_op = ops.less
+        elif self.mode == "max":
+            self.monitor_op = ops.greater
+        else:
+            metric_name = self.monitor.removeprefix("val_")
+            if metric_name == "loss":
+                self.monitor_op = ops.less
+            if hasattr(self.model, "metrics"):
+                all_metrics = []
+                for m in self.model.metrics:
+                    if isinstance(
+                        m,
+                        (
+                            compile_utils.CompileMetrics,
+                            compile_utils.MetricsList,
+                        ),
+                    ):
+                        all_metrics.extend(m.metrics)
+                for m in all_metrics:
+                    if m.name == metric_name:
+                        if hasattr(m, "_direction"):
+                            if m._direction == "up":
+                                self.monitor_op = ops.greater
+                            else:
+                                self.monitor_op = ops.less
+        if self.monitor_op is None:
+            raise ValueError(
+                f"{self.__class__.__name__} callback received "
+                f"monitor={self.monitor}, but Keras isn't able to "
+                "automatically determine whether that metric should be "
+                "maximized or minimized. Pass `mode='max'` in order to "
+                "monitor based on the highest metric value, or pass "
+                "`mode='min'` in order to use the lowest value."
+            )
+        if self.monitor_op == ops.less:
+            self.min_delta *= -1
+
+    def _is_improvement(self, monitor_value, reference_value):
+        if reference_value is None:
+            return True
+        return self.monitor_op(monitor_value - self.min_delta, reference_value)
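Since MonitorCallback is deliberately not exported to the public API (last bullet of the commit message), in-tree callbacks import it from `keras.src`. A hypothetical subclass, sketching the intended reuse pattern (the class name and behavior here are illustrative, not part of the commit):

```python
from keras.src.callbacks.monitor_callback import MonitorCallback


class ImprovementLogger(MonitorCallback):
    """Hypothetical callback: prints whenever the monitored metric improves."""

    def on_epoch_end(self, epoch, logs=None):
        if self.monitor_op is None:
            # Same deferred setup as ModelCheckpoint: metrics exist only
            # after the model is compiled and built.
            self._set_monitor_op()
        current = (logs or {}).get(self.monitor)
        if current is not None and self._is_improvement(current, self.best):
            print(f"Epoch {epoch + 1}: {self.monitor} improved to {current}")
            self.best = current
```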
