Skip to content

Commit 7411db0

Browse files
qlzh727chenmoneygithubfchollet
authored
Cherrypick for cl/482011499: Throw error on deprecated fields. (#17179)
* Throw error on deprecated fields. PiperOrigin-RevId: 482011499 * copyedits Co-authored-by: Chen Qian <chenmoney@google.com> Co-authored-by: François Chollet <francois.chollet@gmail.com>
1 parent b12b9a1 commit 7411db0

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

keras/optimizers/optimizer_experimental/optimizer.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,18 +101,25 @@ def _create_iteration_variable(self):
101101
def _process_kwargs(self, kwargs):
102102
# Remove the `is_legacy_optimizer` arg, which is for serialization only.
103103
kwargs.pop("is_legacy_optimizer", None)
104+
lr = kwargs.pop("lr", None)
105+
if lr:
106+
logging.warning(
107+
"`lr` is deprecated, please use "
108+
"`learning_rate` instead, or use the legacy optimizer, e.g.,"
109+
f"tf.keras.optimizers.legacy.{self.__class__.__name__}."
110+
)
104111
legacy_kwargs = {
105-
"lr",
106112
"decay",
107-
"gradient_transformers",
108113
"gradient_aggregator",
114+
"gradient_transformers",
109115
}
110116
for k in kwargs:
111117
if k in legacy_kwargs:
112-
logging.warning(
113-
"%s is deprecated in `optimizer_experimental.Optimizer`"
114-
", please check the docstring for valid arguments.",
115-
k,
118+
raise ValueError(
119+
f"{k} is deprecated in the new Keras optimizer, please"
120+
"check the docstring for valid arguments, or use the "
121+
"legacy optimizer, e.g., "
122+
f"tf.keras.optimizers.legacy.{self.__class__.__name__}."
116123
)
117124
else:
118125
raise TypeError(

keras/optimizers/optimizer_experimental/optimizer_test.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
"""
55

66
import os
7-
import re
87

98
import numpy as np
109
import tensorflow.compat.v2 as tf
11-
from absl import logging
1210
from absl.testing import parameterized
1311

1412
import keras
@@ -209,14 +207,9 @@ def testClipGlobalNorm(self):
209207
clipped_grad = optimizer._clip_gradients(grad)
210208
self.assertAllClose(clipped_grad[0], [0.5, 0.5])
211209

212-
def testPassingLegacyArgsRaiseWarning(self):
213-
with self.assertLogs(level="WARNING") as log_output:
214-
logging.set_verbosity(logging.WARNING)
210+
def testPassingLegacyArgsRaiseError(self):
211+
with self.assertRaisesRegex(ValueError, "decay is deprecated*"):
215212
_ = adam_new.Adam(clipnorm=1, decay=0.5)
216-
expected_log = "decay is deprecated in"
217-
output = log_output[0][0].message
218-
219-
self.assertTrue(re.search(expected_log, output))
220213

221214
def testPassingLegacyClipnorm(self):
222215
optimizer = adam_new.Adam(clipnorm=1)

0 commit comments

Comments
 (0)