From a53a152f1e0323a61a35a0968fdc21aabaafce6f Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Tue, 29 Apr 2025 10:08:02 -0700 Subject: [PATCH] Add custom variable updater. Allows customization for how variables are updated by the optimizer. The base optimizer simply defers to the update handler to do the update, allowing full customization. Can replace the existing `overwrite_with_gradient` attribute on variables, which currently is very application-specific. --- keras/api/optimizers/__init__.py | 1 + keras/src/backend/common/variables.py | 28 +++++++ keras/src/optimizers/__init__.py | 2 + keras/src/optimizers/base_optimizer.py | 63 +++++--------- keras/src/optimizers/loss_scale_optimizer.py | 10 +-- keras/src/optimizers/optimizer.py | 86 ++++++++++++++++++++ keras/src/optimizers/optimizer_test.py | 19 ++++- 7 files changed, 156 insertions(+), 53 deletions(-) diff --git a/keras/api/optimizers/__init__.py b/keras/api/optimizers/__init__.py index 40f6ab4018f5..044b78ba38e7 100644 --- a/keras/api/optimizers/__init__.py +++ b/keras/api/optimizers/__init__.py @@ -24,5 +24,6 @@ from keras.src.optimizers.muon import Muon as Muon from keras.src.optimizers.nadam import Nadam as Nadam from keras.src.optimizers.optimizer import Optimizer as Optimizer +from keras.src.optimizers.optimizer import VariableUpdater as VariableUpdater from keras.src.optimizers.rmsprop import RMSprop as RMSprop from keras.src.optimizers.sgd import SGD as SGD diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py index 47ce553f1e98..e3955fd1fa7b 100644 --- a/keras/src/backend/common/variables.py +++ b/keras/src/backend/common/variables.py @@ -150,6 +150,8 @@ def __init__( self._autocast = bool(autocast) self._aggregation = aggregation self._synchronization = synchronization + # Custom variable updater. + self._updater = None # `self._overwrite_with_gradient` is an internal property to determine # whether this variable should be overwritten by the computed gradient. # Ref: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py @@ -334,6 +336,29 @@ def path(self): """The path of the variable within the Keras model or layer.""" return self._path + @property + def updater(self): + """Custom variable updater. + + This property is designed for special-casing variable updates during + training, such as quantized float8 `scale` and `amax_history`, where + the gradients represent updated scale factors, or for updating large + embedding tables, where we need to handle sparse updates to a dense + table. + """ + return self._updater + + @updater.setter + def updater(self, updater): + from keras.src import optimizers + + if not isinstance(updater, optimizers.VariableUpdater): + raise TypeError( + "`updater` must be a `keras.optimizers.VariableUpdater`. " + f"Received: {updater.__class__.__name__}." + ) + self._updater = updater + @property def overwrite_with_gradient(self): """Whether this variable should be overwritten by the gradient. @@ -355,6 +380,9 @@ def overwrite_with_gradient(self, value): f"Received: {value}" ) self._overwrite_with_gradient = value + from keras.src import optimizers + + self._updater = optimizers.OverwriteScaleWithGradientUpdater() @property def regularizer(self): diff --git a/keras/src/optimizers/__init__.py b/keras/src/optimizers/__init__.py index 4db5319793ea..56ba2d71ec31 100644 --- a/keras/src/optimizers/__init__.py +++ b/keras/src/optimizers/__init__.py @@ -11,6 +11,8 @@ from keras.src.optimizers.muon import Muon from keras.src.optimizers.nadam import Nadam from keras.src.optimizers.optimizer import Optimizer +from keras.src.optimizers.optimizer import OverwriteScaleWithGradientUpdater +from keras.src.optimizers.optimizer import VariableUpdater from keras.src.optimizers.rmsprop import RMSprop from keras.src.optimizers.sgd import SGD from keras.src.saving import serialization_lib diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py index 261feff5824a..a6a57a5f18b5 100644 --- a/keras/src/optimizers/base_optimizer.py +++ b/keras/src/optimizers/base_optimizer.py @@ -204,6 +204,9 @@ def iterations(self): def _track_variable(self, variable): self._tracker.add_to_store("variables", variable) + def _get_variable_updater(self, variable): + return getattr(variable, "updater", None) + @tracking.no_automatic_dependency_tracking def build(self, variables): if self.use_ema: @@ -212,6 +215,11 @@ def build(self, variables): self._accumulated_gradients = [] for i, variable in enumerate(variables): self._trainable_variables_indices[self._var_key(variable)] = i + custom_updater = self._get_variable_updater(variable) + if custom_updater is not None: + # Build the updater. + custom_updater.build(self, variable) + if self.use_ema: self._model_variables_moving_average.append( self.add_variable_from_reference( @@ -431,10 +439,8 @@ def apply(self, grads, trainable_variables=None): # Overwrite targeted variables directly with their gradients if # their `overwrite_with_gradient` is set. - grads, trainable_variables = ( - self._overwrite_variables_directly_with_gradients( - grads, trainable_variables - ) + grads, trainable_variables = self.__handle_custom_updaters( + grads, trainable_variables ) if len(list(grads)) == 0: @@ -698,21 +704,14 @@ def _get_current_learning_rate(self): return self._learning_rate() return self._learning_rate - def _overwrite_variables_directly_with_gradients(self, grads, vars): - """Overwrite the variables directly by their gradients. - - This method is designed for a special case where we want to overwrite - the variable directly with its computed gradient. For example, in float8 - training, new `scale` and `amax_history` are computed as gradients, and - we want to overwrite them directly instead of following the typical - procedure such as gradient descent with a learning rate, gradient - clipping and weight decaying. + def __handle_custom_updaters(self, grads, vars): + """Update any variable that has a custom updater. After the update, the processed pairs will be filtered out. """ # Shortcut for `tf.Variable` because it doesn't have a - # `overwrite_with_gradient` attr - if any(not hasattr(v, "overwrite_with_gradient") for v in vars): + # `updater` attr. + if not any(self._get_variable_updater(v) is not None for v in vars): return grads, vars # Shallow copies @@ -722,33 +721,8 @@ def _overwrite_variables_directly_with_gradients(self, grads, vars): # Iterate from right to left for safe popping for i in range(len(filtered_grads) - 1, -1, -1): g, v = filtered_grads[i], filtered_vars[i] - if v.overwrite_with_gradient: - if self.gradient_accumulation_steps: - # Utilize a stateless manner for JAX compatibility - steps = self.gradient_accumulation_steps - is_update_step = (self._iterations + 1) % steps == 0 - acc_g = self._accumulated_gradients[ - self._get_variable_index(v) - ] - # `ops.maximum` is utilized for gradient accumulation for - # `overwrite_with_gradient=True` variables - new_g_acc = ops.cond( - is_update_step, - lambda: ops.zeros(g.shape, dtype=g.dtype), - lambda: ops.maximum(g, acc_g), - ) - new_g = ops.cond( - is_update_step, - lambda: ops.maximum(g, acc_g), - lambda: g, - ) - new_v = ops.cond( - is_update_step, lambda: new_g, lambda: v.value - ) - v.assign(new_v) - acc_g.assign(new_g_acc) - else: - v.assign(g) + if v.updater: + v.updater.update_step(g, v) filtered_grads.pop(i) filtered_vars.pop(i) return filtered_grads, filtered_vars @@ -926,6 +900,11 @@ def finalize_variable_values(self, var_list): # optimizer. self._overwrite_model_variables_with_average_value(var_list) + for var in var_list: + updater = self._get_variable_updater(var) + if updater is not None: + updater.finalize_variable_value(var) + def _obj_type(self): return "Optimizer" diff --git a/keras/src/optimizers/loss_scale_optimizer.py b/keras/src/optimizers/loss_scale_optimizer.py index 1e7449b2fd81..9be5401f668c 100644 --- a/keras/src/optimizers/loss_scale_optimizer.py +++ b/keras/src/optimizers/loss_scale_optimizer.py @@ -102,12 +102,6 @@ def stateless_apply(self, optimizer_variables, grads, trainable_variables): ), ) - def _overwrite_variable_with_gradient(self, variable): - return ( - hasattr(variable, "overwrite_with_gradient") - and variable.overwrite_with_gradient - ) - def _stateless_handle_finite_grads( self, optimizer_variables, grads, trainable_variables ): @@ -137,7 +131,7 @@ def increment(): scale = self.dynamic_scale unscaled_grads = [ g - if g is None or self._overwrite_variable_with_gradient(v) + if g is None or self._get_variable_updater(v) is not None else ops.divide(g, scale) for g, v in zip(grads, trainable_variables) ] @@ -183,7 +177,7 @@ def _stateful_handle_finite_grads(self, grads, trainable_variables): tvs = trainable_variables or self._trainable_variables unscaled_grads = [ g - if g is None or self._overwrite_variable_with_gradient(v) + if g is None or self._get_variable_updater(v) is not None else ops.divide(g, scale) for g, v in zip(grads, tvs) ] diff --git a/keras/src/optimizers/optimizer.py b/keras/src/optimizers/optimizer.py index c285b814ba74..6ac3a98fec41 100644 --- a/keras/src/optimizers/optimizer.py +++ b/keras/src/optimizers/optimizer.py @@ -1,4 +1,5 @@ from keras.src import backend +from keras.src import ops from keras.src.api_export import keras_export from keras.src.optimizers import base_optimizer @@ -23,5 +24,90 @@ class Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer): pass +@keras_export("keras.optimizers.VariableUpdater") +class VariableUpdater: + """Allows special handling of variable updates.""" + + def build(self, optimizer, variable): + """Set up any state that might depend on the optimizer. + + This may add variables directly to the optimizer for updating state. + + Args: + optimizer: The optimizer used to update the variables during training. + variable: Variable to update. + """ + pass + + def update_step(self, gradient, variable): + """Update the variable state using the supplied gradient. + + Args: + gradient: Gradient for the variable. + variable: Variable to update. + """ + pass + + def finalize_variable_value(self, variable): + """Set the final value of the trainable variable. + + Sometimes there are some extra steps before ending the variable updates, + such as overriding the model variables with its average value. + + Args: + variable: Variable to finalize. + """ + pass + + +class OverwriteScaleWithGradientUpdater(VariableUpdater): + """Special variable update handler for float8 quantization scales. + + The "gradient" of the scale factor (scale, amax_history) is actually the + updated scale to assign to the variable. Supports gradient accumulation + steps, in which the maximum scale factor between intermediate gradient + steps is recorded. + """ + + def build(self, optimizer, variable): + # Keep reference copy of iterations so we can update gradient + # accumulators appropriately. + self._iterations = optimizer._iterations + # Support gradient accumulation by adding an accumulator directly + # to the optimizer. + self._gradient_accumulation_steps = ( + optimizer.gradient_accumulation_steps + ) + if self._gradient_accumulation_steps: + self.gradient_accumulator = optimizer.add_variable_from_reference( + reference_variable=variable, name="gradient_accumulation" + ) + + def update_step(self, gradient, variable): + if self._gradient_accumulation_steps: + # Utilize a stateless manner for JAX compatibility + steps = self._gradient_accumulation_steps + is_update_step = (self._iterations + 1) % steps == 0 + # Keep track of the maximum scale factor encountered. + new_g_acc = ops.cond( + is_update_step, + lambda: ops.zeros(gradient.shape, dtype=gradient.dtype), + lambda: ops.maximum(gradient, self.gradient_accumulator), + ) + new_g = ops.cond( + is_update_step, + lambda: ops.maximum(gradient, self.gradient_accumulator), + lambda: gradient, + ) + new_v = ops.cond( + is_update_step, lambda: new_g, lambda: variable.value + ) + variable.assign(new_v) + self.gradient_accumulator.assign(new_g_acc) + else: + # Assign scale "gradient" directly to variable. + variable.assign(gradient) + + Optimizer.__doc__ = base_optimizer.BaseOptimizer.__doc__ base_optimizer_keyword_args = base_optimizer.base_optimizer_keyword_args diff --git a/keras/src/optimizers/optimizer_test.py b/keras/src/optimizers/optimizer_test.py index 7d661df9a3c0..606bd20acfd2 100644 --- a/keras/src/optimizers/optimizer_test.py +++ b/keras/src/optimizers/optimizer_test.py @@ -152,6 +152,19 @@ def test_constraints_are_applied(self): optimizer.apply_gradients([(grad, v)]) self.assertAlmostEqual(np.min(v), 0.0) + def test_custom_variable_updater(self): + class IncrementVariable(optimizers.VariableUpdater): + def update_step(self, gradient, variable): + variable.assign_add(1.0) + + orig_value = np.random.random((2, 2)) - 1.0 + v = backend.Variable(orig_value) + v.updater = IncrementVariable() + optimizer = optimizers.SGD(learning_rate=0.0001) + grad = backend.numpy.zeros((2, 2)) + optimizer.apply_gradients([(grad, v)]) + self.assertAllClose(v, orig_value + 1) + def test_get_method(self): obj = optimizers.get("sgd") self.assertIsInstance(obj, optimizers.SGD) @@ -298,7 +311,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self): self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]]) self.assertAllClose(v2, [[1.0, 2.0], [3.0, 4.0]]) self.assertAllClose( - optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]] + v.updater.gradient_accumulator, [[1.0, 1.0], [1.0, 1.0]] ) self.assertAllClose( optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]] @@ -311,7 +324,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self): self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]]) self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]]) self.assertAllClose( - optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]] + v.updater.gradient_accumulator, [[0.0, 0.0], [0.0, 0.0]] ) self.assertAllClose( optimizer._accumulated_gradients[1], [[0.0, 0.0], [0.0, 0.0]] @@ -324,7 +337,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self): self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]]) self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]]) self.assertAllClose( - optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]] + v.updater.gradient_accumulator, [[1.0, 1.0], [1.0, 1.0]] ) self.assertAllClose( optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]]