Skip to content

Commit e6784e4

Browse files
Merge pull request #17225 from lgeiger:fix-mixed-precision-ema (#17226)
PiperOrigin-RevId: 486999108 Co-authored-by: TensorFlower Gardener <gardener@tensorflow.org>
1 parent 7411db0 commit e6784e4

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

keras/engine/training_test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2172,9 +2172,13 @@ def metrics(self):
21722172
)
21732173

21742174
@test_combinations.run_all_keras_modes(always_skip_v1=True)
2175-
def test_ema_overwrite(self):
2175+
@parameterized.named_parameters(
2176+
("mixed_float16", "mixed_float16"), ("float32", "float32")
2177+
)
2178+
def test_ema_overwrite(self, test_policy):
21762179
if not tf.__internal__.tf2.enabled():
21772180
self.skipTest("EMA optimizer is only available in TF2.")
2181+
policy.set_global_policy(test_policy)
21782182
model = sequential.Sequential()
21792183
model.add(input_layer.Input(shape=(4,)))
21802184
model.add(layers_module.Dense(1, activation="relu"))
@@ -2188,6 +2192,7 @@ def test_ema_overwrite(self):
21882192
history = model.fit(dataset, epochs=2, steps_per_epoch=10)
21892193
self.assertLen(history.history["loss"], 2)
21902194
self.assertAllClose(initial_value, model.trainable_variables[0])
2195+
policy.set_global_policy("float32")
21912196

21922197
@test_combinations.run_all_keras_modes(always_skip_v1=True)
21932198
def test_get_verbosity(self):

keras/mixed_precision/loss_scale_optimizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,9 @@ def ema_momentum(self):
14151415
def ema_momentum(self, ema_momentum):
14161416
self._optimizer.ema_momentum = ema_momentum
14171417

1418+
def finalize_variable_values(self, var_list):
1419+
self._optimizer.finalize_variable_values(var_list)
1420+
14181421

14191422
class FakeOptimizerForRestoration(tf.__internal__.tracking.Trackable):
14201423
"""A fake optimizer used to support restoring TensorFlow 2.2 checkpoints.

0 commit comments

Comments
 (0)