
DRAFT: Add custom variable updater. #21225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: master

1 change: 1 addition & 0 deletions keras/api/optimizers/__init__.py
@@ -24,5 +24,6 @@
from keras.src.optimizers.muon import Muon as Muon
from keras.src.optimizers.nadam import Nadam as Nadam
from keras.src.optimizers.optimizer import Optimizer as Optimizer
from keras.src.optimizers.optimizer import VariableUpdater as VariableUpdater
from keras.src.optimizers.rmsprop import RMSprop as RMSprop
from keras.src.optimizers.sgd import SGD as SGD
28 changes: 28 additions & 0 deletions keras/src/backend/common/variables.py
@@ -150,6 +150,8 @@ def __init__(
self._autocast = bool(autocast)
self._aggregation = aggregation
self._synchronization = synchronization
# Custom variable updater.
self._updater = None
# `self._overwrite_with_gradient` is an internal property to determine
# whether this variable should be overwritten by the computed gradient.
# Ref: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py
@@ -334,6 +336,29 @@ def path(self):
"""The path of the variable within the Keras model or layer."""
return self._path

@property
def updater(self):
"""Custom variable updater.

This property is designed for special-casing variable updates during
training, such as quantized float8 `scale` and `amax_history`, where
the gradients represent updated scale factors, or for updating large
embedding tables, where we need to handle sparse updates to a dense
table.
"""
return self._updater

@updater.setter
def updater(self, updater):
from keras.src import optimizers

if not isinstance(updater, optimizers.VariableUpdater):
raise TypeError(
"`updater` must be a `keras.optimizers.VariableUpdater`. "
f"Received: {updater.__class__.__name__}."
)
self._updater = updater

@property
def overwrite_with_gradient(self):
"""Whether this variable should be overwritten by the gradient.
@@ -355,6 +380,9 @@ def overwrite_with_gradient(self, value):
f"Received: {value}"
)
self._overwrite_with_gradient = value
from keras.src import optimizers

self._updater = optimizers.OverwriteScaleWithGradientUpdater()

@property
def regularizer(self):
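For context, a minimal sketch of how the variable-level API above could be used, assuming the `keras.optimizers.VariableUpdater` export proposed in this PR; `EmbeddingTableUpdater` is a hypothetical subclass used only for illustration:

import keras
from keras import optimizers

class EmbeddingTableUpdater(optimizers.VariableUpdater):
    """Hypothetical updater standing in for sparse embedding-table updates."""

    def update_step(self, gradient, variable):
        # Apply the update directly to the dense table instead of the
        # optimizer's regular update rule.
        variable.assign_add(gradient)

table = keras.Variable(keras.ops.zeros((4, 8)), name="embedding_table")
table.updater = EmbeddingTableUpdater()  # the setter type-checks the value
# table.updater = object()  # would raise TypeError per the setter above

# Per the `overwrite_with_gradient` setter above, flipping the flag now also
# installs an `OverwriteScaleWithGradientUpdater` on the variable.
scale = keras.Variable(keras.ops.ones(()), name="scale")
scale.overwrite_with_gradient = True
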
2 changes: 2 additions & 0 deletions keras/src/optimizers/__init__.py
@@ -11,6 +11,8 @@
from keras.src.optimizers.muon import Muon
from keras.src.optimizers.nadam import Nadam
from keras.src.optimizers.optimizer import Optimizer
from keras.src.optimizers.optimizer import OverwriteScaleWithGradientUpdater
from keras.src.optimizers.optimizer import VariableUpdater
from keras.src.optimizers.rmsprop import RMSprop
from keras.src.optimizers.sgd import SGD
from keras.src.saving import serialization_lib
63 changes: 21 additions & 42 deletions keras/src/optimizers/base_optimizer.py
@@ -204,6 +204,9 @@ def iterations(self):
def _track_variable(self, variable):
self._tracker.add_to_store("variables", variable)

def _get_variable_updater(self, variable):
return getattr(variable, "updater", None)

@tracking.no_automatic_dependency_tracking
def build(self, variables):
if self.use_ema:
@@ -212,6 +215,11 @@ def build(self, variables):
self._accumulated_gradients = []
for i, variable in enumerate(variables):
self._trainable_variables_indices[self._var_key(variable)] = i
custom_updater = self._get_variable_updater(variable)
if custom_updater is not None:
# Build the updater.
custom_updater.build(self, variable)

if self.use_ema:
self._model_variables_moving_average.append(
self.add_variable_from_reference(
@@ -431,10 +439,8 @@ def apply(self, grads, trainable_variables=None):

# Overwrite targeted variables directly with their gradients if
# their `overwrite_with_gradient` is set.
grads, trainable_variables = (
self._overwrite_variables_directly_with_gradients(
grads, trainable_variables
)
grads, trainable_variables = self.__handle_custom_updaters(
grads, trainable_variables
)

if len(list(grads)) == 0:
@@ -698,21 +704,14 @@ def _get_current_learning_rate(self):
return self._learning_rate()
return self._learning_rate

def _overwrite_variables_directly_with_gradients(self, grads, vars):
"""Overwrite the variables directly by their gradients.

This method is designed for a special case where we want to overwrite
the variable directly with its computed gradient. For example, in float8
training, new `scale` and `amax_history` are computed as gradients, and
we want to overwrite them directly instead of following the typical
procedure such as gradient descent with a learning rate, gradient
clipping and weight decaying.
def __handle_custom_updaters(self, grads, vars):
"""Update any variable that has a custom updater.

After the update, the processed pairs will be filtered out.
"""
# Shortcut for `tf.Variable` because it doesn't have a
# `overwrite_with_gradient` attr
if any(not hasattr(v, "overwrite_with_gradient") for v in vars):
# `updater` attr.
if not any(self._get_variable_updater(v) is not None for v in vars):
return grads, vars

# Shallow copies
@@ -722,33 +721,8 @@ def _overwrite_variables_directly_with_gradients(self, grads, vars):
# Iterate from right to left for safe popping
for i in range(len(filtered_grads) - 1, -1, -1):
g, v = filtered_grads[i], filtered_vars[i]
if v.overwrite_with_gradient:
if self.gradient_accumulation_steps:
# Utilize a stateless manner for JAX compatibility
steps = self.gradient_accumulation_steps
is_update_step = (self._iterations + 1) % steps == 0
acc_g = self._accumulated_gradients[
self._get_variable_index(v)
]
# `ops.maximum` is utilized for gradient accumulation for
# `overwrite_with_gradient=True` variables
new_g_acc = ops.cond(
is_update_step,
lambda: ops.zeros(g.shape, dtype=g.dtype),
lambda: ops.maximum(g, acc_g),
)
new_g = ops.cond(
is_update_step,
lambda: ops.maximum(g, acc_g),
lambda: g,
)
new_v = ops.cond(
is_update_step, lambda: new_g, lambda: v.value
)
v.assign(new_v)
acc_g.assign(new_g_acc)
else:
v.assign(g)
if v.updater:
v.updater.update_step(g, v)
filtered_grads.pop(i)
filtered_vars.pop(i)
return filtered_grads, filtered_vars
@@ -926,6 +900,11 @@ def finalize_variable_values(self, var_list):
# optimizer.
self._overwrite_model_variables_with_average_value(var_list)

for var in var_list:
updater = self._get_variable_updater(var)
if updater is not None:
updater.finalize_variable_value(var)

def _obj_type(self):
return "Optimizer"

10 changes: 2 additions & 8 deletions keras/src/optimizers/loss_scale_optimizer.py
@@ -102,12 +102,6 @@ def stateless_apply(self, optimizer_variables, grads, trainable_variables):
),
)

def _overwrite_variable_with_gradient(self, variable):
return (
hasattr(variable, "overwrite_with_gradient")
and variable.overwrite_with_gradient
)

def _stateless_handle_finite_grads(
self, optimizer_variables, grads, trainable_variables
):
@@ -137,7 +131,7 @@ def increment():
scale = self.dynamic_scale
unscaled_grads = [
g
if g is None or self._overwrite_variable_with_gradient(v)
if g is None or self._get_variable_updater(v) is not None
else ops.divide(g, scale)
for g, v in zip(grads, trainable_variables)
]
@@ -183,7 +177,7 @@ def _stateful_handle_finite_grads(self, grads, trainable_variables):
tvs = trainable_variables or self._trainable_variables
unscaled_grads = [
g
if g is None or self._overwrite_variable_with_gradient(v)
if g is None or self._get_variable_updater(v) is not None
else ops.divide(g, scale)
for g, v in zip(grads, tvs)
]
86 changes: 86 additions & 0 deletions keras/src/optimizers/optimizer.py
@@ -1,4 +1,5 @@
from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.optimizers import base_optimizer

@@ -23,5 +24,90 @@ class Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer):
pass


@keras_export("keras.optimizers.VariableUpdater")
class VariableUpdater:
"""Allows special handling of variable updates."""
Collaborator:

If this is intended to be public in the API then we should provide an explanation and a usage example in the docstring.

Contributor Author:

Great! Yes, I'd like it to be public so I can inherit from it without needing to import internal keras APIs.

Is this location (within optimizer.py) okay, or would you prefer it somewhere else (e.g. in its own file)?
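
For reference, a usage sketch of the kind being requested above, modeled on the test added later in this PR and assuming the public `keras.optimizers.VariableUpdater` export proposed here (the `IncrementVariable` updater is illustrative only):

import numpy as np

import keras
from keras import optimizers

class IncrementVariable(optimizers.VariableUpdater):
    # Ignore the gradient and simply bump the variable by one.
    def update_step(self, gradient, variable):
        variable.assign_add(1.0)

v = keras.Variable(np.zeros((2, 2)))
v.updater = IncrementVariable()

optimizer = optimizers.SGD(learning_rate=0.1)
optimizer.apply_gradients([(keras.ops.zeros((2, 2)), v)])
# `v` is now all ones: the optimizer delegated the update to the custom
# updater and skipped its own SGD rule for this variable.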


def build(self, optimizer, variable):
"""Set up any state that might depend on the optimizer.

This may add variables directly to the optimizer for updating state.

Args:
optimizer: The optimizer used to update the variables during training.
variable: Variable to update.
"""
pass

def update_step(self, gradient, variable):
"""Update the variable state using the supplied gradient.

Args:
gradient: Gradient for the variable.
variable: Variable to update.
"""
pass

def finalize_variable_value(self, variable):
"""Set the final value of the trainable variable.

Sometimes there are extra steps before ending the variable updates,
such as overriding the model variable with its averaged value.

Args:
variable: Variable to finalize.
"""
pass


class OverwriteScaleWithGradientUpdater(VariableUpdater):
"""Special variable update handler for float8 quantization scales.

The "gradient" of the scale factor (scale, amax_history) is actually the
updated scale to assign to the variable. Supports gradient accumulation
steps, in which the maximum scale factor between intermediate gradient
steps is recorded.
"""

def build(self, optimizer, variable):
# Keep reference copy of iterations so we can update gradient
# accumulators appropriately.
self._iterations = optimizer._iterations
# Support gradient accumulation by adding an accumulator directly
# to the optimizer.
self._gradient_accumulation_steps = (
optimizer.gradient_accumulation_steps
)
if self._gradient_accumulation_steps:
self.gradient_accumulator = optimizer.add_variable_from_reference(
reference_variable=variable, name="gradient_accumulation"
)

def update_step(self, gradient, variable):
if self._gradient_accumulation_steps:
# Utilize a stateless manner for JAX compatibility
steps = self._gradient_accumulation_steps
is_update_step = (self._iterations + 1) % steps == 0
# Keep track of the maximum scale factor encountered.
new_g_acc = ops.cond(
is_update_step,
lambda: ops.zeros(gradient.shape, dtype=gradient.dtype),
lambda: ops.maximum(gradient, self.gradient_accumulator),
)
new_g = ops.cond(
is_update_step,
lambda: ops.maximum(gradient, self.gradient_accumulator),
lambda: gradient,
)
new_v = ops.cond(
is_update_step, lambda: new_g, lambda: variable.value
)
variable.assign(new_v)
self.gradient_accumulator.assign(new_g_acc)
else:
# Assign scale "gradient" directly to variable.
variable.assign(gradient)


Optimizer.__doc__ = base_optimizer.BaseOptimizer.__doc__
base_optimizer_keyword_args = base_optimizer.base_optimizer_keyword_args
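
To make the accumulation behavior of `OverwriteScaleWithGradientUpdater` above concrete, here is a small standalone sketch (plain NumPy, not part of the diff) assuming `gradient_accumulation_steps=2`:

import numpy as np

steps = 2
scale_grads = [np.array(3.0), np.array(5.0)]  # "gradients" are new scale factors
accumulator = np.zeros(())
variable = np.array(1.0)

for iteration, g in enumerate(scale_grads):
    is_update_step = (iteration + 1) % steps == 0
    if is_update_step:
        # On the update step, assign the running maximum and reset the
        # accumulator, mirroring the `ops.cond` branches above.
        variable = np.maximum(g, accumulator)
        accumulator = np.zeros(())
    else:
        # On intermediate steps, only record the maximum scale seen so far.
        accumulator = np.maximum(g, accumulator)

print(variable)  # 5.0, the max of the intermediate scale factors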
19 changes: 16 additions & 3 deletions keras/src/optimizers/optimizer_test.py
@@ -152,6 +152,19 @@ def test_constraints_are_applied(self):
optimizer.apply_gradients([(grad, v)])
self.assertAlmostEqual(np.min(v), 0.0)

def test_custom_variable_updater(self):
class IncrementVariable(optimizers.VariableUpdater):
def update_step(self, gradient, variable):
variable.assign_add(1.0)

orig_value = np.random.random((2, 2)) - 1.0
v = backend.Variable(orig_value)
v.updater = IncrementVariable()
optimizer = optimizers.SGD(learning_rate=0.0001)
grad = backend.numpy.zeros((2, 2))
optimizer.apply_gradients([(grad, v)])
self.assertAllClose(v, orig_value + 1)

def test_get_method(self):
obj = optimizers.get("sgd")
self.assertIsInstance(obj, optimizers.SGD)
@@ -298,7 +311,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self):
self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]])
self.assertAllClose(v2, [[1.0, 2.0], [3.0, 4.0]])
self.assertAllClose(
optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]]
v.updater.gradient_accumulator, [[1.0, 1.0], [1.0, 1.0]]
)
self.assertAllClose(
optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]]
@@ -311,7 +324,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self):
self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]])
self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]])
self.assertAllClose(
optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]]
v.updater.gradient_accumulator, [[0.0, 0.0], [0.0, 0.0]]
)
self.assertAllClose(
optimizer._accumulated_gradients[1], [[0.0, 0.0], [0.0, 0.0]]
@@ -324,7 +337,7 @@ def test_overwrite_with_gradient_with_gradient_accumulation(self):
self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]])
self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]])
self.assertAllClose(
optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]]
v.updater.gradient_accumulator, [[1.0, 1.0], [1.0, 1.0]]
)
self.assertAllClose(
optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]]