diff --git a/keras/api/optimizers/__init__.py b/keras/api/optimizers/__init__.py index 40f6ab4018f5..044b78ba38e7 100644 --- a/keras/api/optimizers/__init__.py +++ b/keras/api/optimizers/__init__.py @@ -24,5 +24,6 @@ from keras.src.optimizers.muon import Muon as Muon from keras.src.optimizers.nadam import Nadam as Nadam from keras.src.optimizers.optimizer import Optimizer as Optimizer +from keras.src.optimizers.optimizer import VariableUpdater as VariableUpdater from keras.src.optimizers.rmsprop import RMSprop as RMSprop from keras.src.optimizers.sgd import SGD as SGD diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py index 47ce553f1e98..e3955fd1fa7b 100644 --- a/keras/src/backend/common/variables.py +++ b/keras/src/backend/common/variables.py @@ -150,6 +150,8 @@ def __init__( self._autocast = bool(autocast) self._aggregation = aggregation self._synchronization = synchronization + # Custom variable updater. + self._updater = None # `self._overwrite_with_gradient` is an internal property to determine # whether this variable should be overwritten by the computed gradient. # Ref: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py @@ -334,6 +336,29 @@ def path(self): """The path of the variable within the Keras model or layer.""" return self._path + @property + def updater(self): + """Custom variable updater. + + This property is designed for special-casing variable updates during + training, such as quantized float8 `scale` and `amax_history`, where + the gradients represent updated scale factors, or for updating large + embedding tables, where we need to handle sparse updates to a dense + table. + """ + return self._updater + + @updater.setter + def updater(self, updater): + from keras.src import optimizers + + if not isinstance(updater, optimizers.VariableUpdater): + raise TypeError( + "`updater` must be a `keras.optimizers.VariableUpdater`. " + f"Received: {updater.__class__.__name__}." + ) + self._updater = updater + @property def overwrite_with_gradient(self): """Whether this variable should be overwritten by the gradient. @@ -355,6 +380,9 @@ def overwrite_with_gradient(self, value): f"Received: {value}" ) self._overwrite_with_gradient = value + from keras.src import optimizers + + self._updater = optimizers.OverwriteScaleWithGradientUpdater() @property def regularizer(self): diff --git a/keras/src/optimizers/__init__.py b/keras/src/optimizers/__init__.py index 4db5319793ea..56ba2d71ec31 100644 --- a/keras/src/optimizers/__init__.py +++ b/keras/src/optimizers/__init__.py @@ -11,6 +11,8 @@ from keras.src.optimizers.muon import Muon from keras.src.optimizers.nadam import Nadam from keras.src.optimizers.optimizer import Optimizer +from keras.src.optimizers.optimizer import OverwriteScaleWithGradientUpdater +from keras.src.optimizers.optimizer import VariableUpdater from keras.src.optimizers.rmsprop import RMSprop from keras.src.optimizers.sgd import SGD from keras.src.saving import serialization_lib diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py index 261feff5824a..a6a57a5f18b5 100644 --- a/keras/src/optimizers/base_optimizer.py +++ b/keras/src/optimizers/base_optimizer.py @@ -204,6 +204,9 @@ def iterations(self): def _track_variable(self, variable): self._tracker.add_to_store("variables", variable) + def _get_variable_updater(self, variable): + return getattr(variable, "updater", None) + @tracking.no_automatic_dependency_tracking def build(self, variables): if self.use_ema: @@ -212,6 +215,11 @@ def build(self, variables): self._accumulated_gradients = [] for i, variable in enumerate(variables): self._trainable_variables_indices[self._var_key(variable)] = i + custom_updater = self._get_variable_updater(variable) + if custom_updater is not None: + # Build the updater. + custom_updater.build(self, variable) + if self.use_ema: self._model_variables_moving_average.append( self.add_variable_from_reference( @@ -431,10 +439,8 @@ def apply(self, grads, trainable_variables=None): # Overwrite targeted variables directly with their gradients if # their `overwrite_with_gradient` is set. - grads, trainable_variables = ( - self._overwrite_variables_directly_with_gradients( - grads, trainable_variables - ) + grads, trainable_variables = self.__handle_custom_updaters( + grads, trainable_variables ) if len(list(grads)) == 0: @@ -698,21 +704,14 @@ def _get_current_learning_rate(self): return self._learning_rate() return self._learning_rate - def _overwrite_variables_directly_with_gradients(self, grads, vars): - """Overwrite the variables directly by their gradients. - - This method is designed for a special case where we want to overwrite - the variable directly with its computed gradient. For example, in float8 - training, new `scale` and `amax_history` are computed as gradients, and - we want to overwrite them directly instead of following the typical - procedure such as gradient descent with a learning rate, gradient - clipping and weight decaying. + def __handle_custom_updaters(self, grads, vars): + """Update any variable that has a custom updater. After the update, the processed pairs will be filtered out. """ # Shortcut for `tf.Variable` because it doesn't have a - # `overwrite_with_gradient` attr - if any(not hasattr(v, "overwrite_with_gradient") for v in vars): + # `updater` attr. + if not any(self._get_variable_updater(v) is not None for v in vars): return grads, vars # Shallow copies @@ -722,33 +721,8 @@ def _overwrite_variables_directly_with_gradients(self, grads, vars): # Iterate from right to left for safe popping for i in range(len(filtered_grads) - 1, -1, -1): g, v = filtered_grads[i], filtered_vars[i] - if v.overwrite_with_gradient: - if self.gradient_accumulation_steps: - # Utilize a stateless manner for JAX compatibility - steps = self.gradient_accumulation_steps - is_update_step = (self._iterations + 1) % steps == 0 - acc_g = self._accumulated_gradients[ - self._get_variable_index(v) - ] - # `ops.maximum` is utilized for gradient accumulation for - # `overwrite_with_gradient=True` variables - new_g_acc = ops.cond( - is_update_step, - lambda: ops.zeros(g.shape, dtype=g.dtype), - lambda: ops.maximum(g, acc_g), - ) - new_g = ops.cond( - is_update_step, - lambda: ops.maximum(g, acc_g), - lambda: g, - ) - new_v = ops.cond( - is_update_step, lambda: new_g, lambda: v.value - ) - v.assign(new_v) - acc_g.assign(new_g_acc) - else: - v.assign(g) + if v.updater: + v.updater.update_step(g, v) filtered_grads.pop(i) filtered_vars.pop(i) return filtered_grads, filtered_vars @@ -926,6 +900,11 @@ def finalize_variable_values(self, var_list): # optimizer. self._overwrite_model_variables_with_average_value(var_list) + for var in var_list: + updater = self._get_variable_updater(var) + if updater is not None: + updater.finalize_variable_value(var) + def _obj_type(self): return "Optimizer" diff --git a/keras/src/optimizers/loss_scale_optimizer.py b/keras/src/optimizers/loss_scale_optimizer.py index 1e7449b2fd81..9be5401f668c 100644 --- a/keras/src/optimizers/loss_scale_optimizer.py +++ b/keras/src/optimizers/loss_scale_optimizer.py @@ -102,12 +102,6 @@ def stateless_apply(self, optimizer_variables, grads, trainable_variables): ), ) - def _overwrite_variable_with_gradient(self, variable): - return ( - hasattr(variable, "overwrite_with_gradient") - and variable.overwrite_with_gradient - ) - def _stateless_handle_finite_grads( self, optimizer_variables, grads, trainable_variables ): @@ -137,7 +131,7 @@ def increment(): scale = self.dynamic_scale unscaled_grads = [ g - if g is None or self._overwrite_variable_with_gradient(v) + if g is None or self._get_variable_updater(v) is not None else ops.divide(g, scale) for g, v in zip(grads, trainable_variables) ] @@ -183,7 +177,7 @@ def _stateful_handle_finite_grads(self, grads, trainable_variables): tvs = trainable_variables or self._trainable_variables unscaled_grads = [ g - if g is None or self._overwrite_variable_with_gradient(v) + if g is None or self._get_variable_updater(v) is not None else ops.divide(g, scale) for g, v in zip(grads, tvs) ] diff --git a/keras/src/optimizers/optimizer.py b/keras/src/optimizers/optimizer.py index c285b814ba74..6ac3a98fec41 100644 --- a/keras/src/optimizers/optimizer.py +++ b/keras/src/optimizers/optimizer.py @@ -1,4 +1,5 @@ from keras.src import backend +from keras.src import ops from keras.src.api_export import keras_export from keras.src.optimizers import base_optimizer @@ -23,5 +24,90 @@ class Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer): pass +@keras_export("keras.optimizers.VariableUpdater") +class VariableUpdater: + """Allows special handling of variable updates.""" + + def build(self, optimizer, variable): + """Set up any state that might depend on the optimizer. + + This may add variables directly to the optimizer for updating state. + + Args: + optimizer: The optimizer used to update the variables during training. + variable: Variable to update. + """ + pass + + def update_step(self, gradient, variable): + """Update the variable state using the supplied gradient. + + Args: + gradient: Gradient for the variable. + variable: Variable to update. + """ + pass + + def finalize_variable_value(self, variable): + """Set the final value of the trainable variable. + + Sometimes there are some extra steps before ending the variable updates, + such as overriding the model variables with its average value. + + Args: + variable: Variable to finalize. + """ + pass + + +class OverwriteScaleWithGradientUpdater(VariableUpdater): + """Special variable update handler for float8 quantization scales. + + The "gradient" of the scale factor (scale, amax_history) is actually the + updated scale to assign to the variable. Supports gradient accumulation + steps, in which the maximum scale factor between intermediate gradient + steps is recorded. + """ + + def build(self, optimizer, variable): + # Keep reference copy of iterations so we can update gradient + # accumulators appropriately. + self._iterations = optimizer._iterations + # Support gradient accumulation by adding an accumulator directly + # to the optimizer. + self._gradient_accumulation_steps = ( + optimizer.gradient_accumulation_steps + ) + if self._gradient_accumulation_steps: + self.gradient_accumulator = optimizer.add_variable_from_reference( + reference_variable=variable, name="gradient_accumulation" + ) + + def update_step(self, gradient, variable): + if self._gradient_accumulation_steps: + # Utilize a stateless manner for JAX compatibility + steps = self._gradient_accumulation_steps + is_update_step = (self._iterations + 1) % steps == 0 + # Keep track of the maximum scale factor encountered. + new_g_acc = ops.cond( + is_update_step, + lambda: ops.zeros(gradient.shape, dtype=gradient.dtype), + lambda: ops.maximum(gradient, self.gradient_accumulator), + ) + new_g = ops.cond( + is_update_step, + lambda: ops.maximum(gradient, self.gradient_accumulator), + lambda: gradient, + ) + new_v = ops.cond( + is_update_step, lambda: new_g, lambda: variable.value + ) + variable.assign(new_v) + self.gradient_accumulator.assign(new_g_acc) + else: + # Assign scale "gradient" directly to variable. + variable.assign(gradient) + + Optimizer.__doc__ = base_optimizer.BaseOptimizer.__doc__ base_optimizer_keyword_args = base_optimizer.base_optimizer_keyword_args diff --git a/keras/src/optimizers/optimizer_test.py b/keras/src/optimizers/optimizer_test.py index 7d661df9a3c0..606bd20acfd2 100644 --- a/keras/src/optimizers/optimizer_test.py +++ b/keras/src/optimizers/optimizer_test.py @@ -152,6 +152,19 @@ def test_constraints_are_applied(self): optimizer.apply_gradients([(grad, v)]) self.assertAlmostEqual(np.min(v), 0.0) + def test_custom_variable_updater(self): + class IncrementVariable(optimizers.VariableUpdater): + def update_step(self, gradient, variable): + variable.assign_add(1.0) + + orig_value = np.random.random((2, 2)) - 1.0 + v = backend.Variable(orig_value) + v.updater = IncrementVariable() + optimizer = optimizers.SGD(learning_rate=0.0001) + grad = backend.numpy.zeros((2, 2)) + optimizer.apply_gradients([(grad, v)]) + self.assertAllClose(v, orig_value + 1) + def test_get_method(self): obj = optimizers.get("sgd") self.assertIsInstance(obj, optimizers.SGD) @@ -298,7 +311,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self): self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]]) self.assertAllClose(v2, [[1.0, 2.0], [3.0, 4.0]]) self.assertAllClose( - optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]] + v.updater.gradient_accumulator, [[1.0, 1.0], [1.0, 1.0]] ) self.assertAllClose( optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]] @@ -311,7 +324,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self): self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]]) self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]]) self.assertAllClose( - optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]] + v.updater.gradient_accumulator, [[0.0, 0.0], [0.0, 0.0]] ) self.assertAllClose( optimizer._accumulated_gradients[1], [[0.0, 0.0], [0.0, 0.0]] @@ -324,7 +337,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self): self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]]) self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]]) self.assertAllClose( - optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]] + v.updater.gradient_accumulator, [[1.0, 1.0], [1.0, 1.0]] ) self.assertAllClose( optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]]