# coding=utf-8
# Copyright 2019 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 16 | +"""Multi-step optimizers simulating large batches. |
| 17 | +
|
| 18 | +Optimizer variants which make it possible to use very large batch sizes with |
| 19 | +limited GPU memory. Optimizers in this module accumulate the gradients for n |
| 20 | +batches, and call the optimizer's update rule every n batches with the |
| 21 | +accumulated gradients. |
| 22 | +
|
| 23 | +See [Saunders et al., 2018](https://arxiv.org/abs/1805.00456) for details. |
| 24 | +""" |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export
from tensorflow.keras import backend as K


class MultistepAdamOptimizer(optimizer.Optimizer):
  """Adam variant that applies an update every n steps on accumulated grads."""

  def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
               use_locking=False, name="Adam", n=1):
    super(MultistepAdamOptimizer, self).__init__(
        use_locking=use_locking, name=name)
    self._lr = learning_rate
    self._beta1 = beta1
    self._beta2 = beta2
    self._epsilon = epsilon
    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._beta1_t = None
    self._beta2_t = None
    self._epsilon_t = None
    self._n = n  # Call the Adam update every n batches with accumulated grads.
    self._n_t = None  # n as a tensor.

  def _get_beta_accumulators(self):
    with ops.init_scope():
      if context.executing_eagerly():
        graph = None
      else:
        graph = ops.get_default_graph()
      return (self._get_non_slot_variable("beta1_power", graph=graph),
              self._get_non_slot_variable("beta2_power", graph=graph))

  def _create_slots(self, var_list):
    """Create slot variables for Adam with accumulated gradients."""
    first_var = min(var_list, key=lambda x: x.name)
    self._create_non_slot_variable(
        initial_value=self._beta1, name="beta1_power", colocate_with=first_var)
    self._create_non_slot_variable(
        initial_value=self._beta2, name="beta2_power", colocate_with=first_var)
    # If "iter" were initialized as an int32, this optimizer could not run
    # with tensorflow_hub on a tensorflow-gpu build, so it is kept as a float.
    self._create_non_slot_variable(
        initial_value=0.0 if self._n == 1 else 1.0, name="iter",
        colocate_with=first_var)
    # Create slots for the first and second moments, as well as grad_acc.
    for v in var_list:
      self._zeros_slot(v, "m", self._name)
      self._zeros_slot(v, "v", self._name)
      self._zeros_slot(v, "grad_acc", self._name)

  def _get_iter_variable(self):
    graph = (
        None if tf.executing_eagerly() else tf.get_default_graph())
    return self._get_non_slot_variable("iter", graph=graph)

  def _prepare(self):
    lr = self._call_if_callable(self._lr)
    beta1 = self._call_if_callable(self._beta1)
    beta2 = self._call_if_callable(self._beta2)
    epsilon = self._call_if_callable(self._epsilon)
    self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
    self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
    self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
    self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
    self._n_t = tf.convert_to_tensor(self._n, name="n")

  def _apply_cond(self, apply_fn, grad, var, *args, **kwargs):
    """Applies the Adam update if the counter is zero, else accumulates."""
    grad_acc = self.get_slot(var, "grad_acc")

    def apply_adam(grad_acc, apply_fn, grad, var, *args, **kwargs):
      # Average the accumulated gradients (plus the current one) over n,
      # run one Adam update, then reset the accumulator to zero.
      total_grad = (grad_acc + grad) / tf.cast(self._n_t, grad.dtype)
      adam_op = apply_fn(total_grad, var, *args, **kwargs)
      with tf.control_dependencies([adam_op]):
        grad_acc_to_zero_op = grad_acc.assign(tf.zeros_like(grad_acc),
                                              use_locking=self._use_locking)
      return tf.group(adam_op, grad_acc_to_zero_op)

    def accumulate_gradient(grad_acc, grad):
      assign_op = tf.assign_add(grad_acc, grad, use_locking=self._use_locking)
      return tf.group(assign_op)  # Strip return value.

    return tf.cond(
        tf.equal(self._get_iter_variable(), 0),
        lambda: apply_adam(grad_acc, apply_fn, grad, var, *args, **kwargs),
        lambda: accumulate_gradient(grad_acc, grad))

  def _apply_dense(self, grad, var):
    return self._apply_cond(self._apply_dense_in_action, grad, var)

  def _apply_dense_in_action(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
    return training_ops.apply_adam(
        var, m, v,
        math_ops.cast(beta1_power, var.dtype.base_dtype),
        math_ops.cast(beta2_power, var.dtype.base_dtype),
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._beta1_t, var.dtype.base_dtype),
        math_ops.cast(self._beta2_t, var.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
        grad,
        use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    return self._apply_cond(self._resource_apply_dense_in_action, grad, var)

  def _resource_apply_dense_in_action(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
    return training_ops.resource_apply_adam(
        var.handle,
        m.handle,
        v.handle,
        math_ops.cast(beta1_power, grad.dtype.base_dtype),
        math_ops.cast(beta2_power, grad.dtype.base_dtype),
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
        math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
        grad, use_locking=self._use_locking)

  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * (1 - beta1_t)
    m_t = state_ops.assign(m, m * beta1_t, use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
      m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = scatter_add(v, indices, v_scaled_g_values)
    v_sqrt = math_ops.sqrt(v_t)
    var_update = state_ops.assign_sub(
        var, lr * m_t / (v_sqrt + epsilon_t), use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t, v_t])

  def _apply_sparse(self, grad, var):
    # TODO(fstahlberg): Implement a sparse version.
    tf.logging.warning("MultistepAdamOptimizer does not support sparse updates")
    dense_grad = tf.convert_to_tensor(grad)
    return self._apply_cond(self._apply_dense_in_action, dense_grad, var)

  def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
    tf.logging.warning("MultistepAdamOptimizer does not support sparse updates")
    # Note that conversion to a dense Tensor handles duplicate `indices`
    # correctly (summing them). A real sparse implementation will probably want
    # to override _resource_apply_sparse instead so it gets them de-duplicated
    # automatically.
    dense_grad = tf.convert_to_tensor(
        tf.IndexedSlices(values=grad, indices=indices,
                         dense_shape=tf.shape(var)))
    return self._apply_cond(
        self._resource_apply_dense_in_action, dense_grad, var)

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(x.handle, i, v)]):
      return x.value()

  def _resource_apply_sparse(self, grad, var, indices):
    # Standard sparse Adam update; this path does not go through the multistep
    # gradient accumulation in _apply_cond.
    return self._apply_sparse_shared(
        grad, var, indices, self._resource_scatter_add)

  def _finish(self, update_ops, name_scope):
    """Updates beta_power variables every n batches and increments counter."""
    iter_ = self._get_iter_variable()
    beta1_power, beta2_power = self._get_beta_accumulators()
    with tf.control_dependencies(update_ops):
      with tf.colocate_with(iter_):

        def update_beta_op():
          update_beta1 = beta1_power.assign(
              beta1_power * self._beta1_t,
              use_locking=self._use_locking)
          update_beta2 = beta2_power.assign(
              beta2_power * self._beta2_t,
              use_locking=self._use_locking)
          return tf.group(update_beta1, update_beta2)

        maybe_update_beta = tf.cond(
            tf.equal(iter_, 0), update_beta_op, tf.no_op)
        with tf.control_dependencies([maybe_update_beta]):
          # TODO(Cuong): This is suboptimal because we cast twice (float to
          # int, and then int back to float).
          update_iter = iter_.assign(
              K.cast(tf.mod(K.cast(iter_ + 1.0, dtype=dtypes.int32),
                            self._n_t),
                     dtype=dtypes.float32),
              use_locking=self._use_locking)
        return tf.group(
            *update_ops + [update_iter, maybe_update_beta], name=name_scope)
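
# A minimal end-to-end sketch in TF1 graph mode (illustrative only; the toy
# model, variable names, and data below are not part of this module):
#
#   import numpy as np
#
#   x = tf.placeholder(tf.float32, [None, 4])
#   y = tf.placeholder(tf.float32, [None, 1])
#   w = tf.get_variable("w", shape=[4, 1])
#   loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
#
#   opt = MultistepAdamOptimizer(learning_rate=1e-3, n=4)
#   train_op = opt.minimize(loss)
#
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for _ in range(100):
#       xb = np.random.randn(32, 4).astype(np.float32)
#       yb = np.random.randn(32, 1).astype(np.float32)
#       # Three out of every four runs only accumulate gradients into the
#       # "grad_acc" slots; the fourth run applies a single Adam update with
#       # the accumulated gradients divided by n, emulating a batch of 4 * 32.
#       sess.run(train_op, feed_dict={x: xb, y: yb})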