import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import training_ops


class MomentumOptimizer(tf.keras.optimizers.Optimizer):
  r"""Gradient descent (with momentum) optimizer.

  Unlike `tf.keras.optimizers.SGD`, this class does not reuse the Keras
  momentum implementation; it calls the `training_ops` kernels directly.

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
      that takes no arguments and returns the actual value to use. The
      learning rate. Defaults to 0.01.
    momentum: float hyperparameter in `[0, 1]` that accelerates gradient
      descent in the relevant direction and dampens oscillations. Defaults
      to 0, i.e., vanilla gradient descent.
    nesterov: boolean. Whether to apply Nesterov momentum. Defaults to
      `False`.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"MomentumOptimizer"`.
    **kwargs: Keyword arguments. Allowed to be one of `"clipnorm"` or
      `"clipvalue"`. `"clipnorm"` (float) clips gradients by norm;
      `"clipvalue"` (float) clips gradients by value.
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.01,
               momentum=0.0,
               nesterov=False,
               name="MomentumOptimizer",
               **kwargs):
    super(MomentumOptimizer, self).__init__(name, **kwargs)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("decay", self._initial_decay)

    # Momentum is enabled for a positive scalar, a tensor, or a callable
    # (tensors and callables are resolved lazily, so they always enable the
    # momentum slot).
    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be in the range [0, 1].")
    self._set_hyper("momentum", momentum)

    self.nesterov = nesterov

  def _create_slots(self, var_list):
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(MomentumOptimizer, self)._prepare_local(var_device, var_dtype,
                                                  apply_state)
    apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity(
        self._get_hyper("momentum", var_dtype))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    if self._momentum:
      momentum_var = self.get_slot(var, "momentum")
      return training_ops.resource_apply_momentum(
          var.handle,
          momentum_var.handle,
          coefficients["lr_t"],
          grad,
          coefficients["momentum"],
          use_locking=self._use_locking,
          use_nesterov=self.nesterov)
    else:
      return training_ops.resource_apply_gradient_descent(
          var.handle,
          coefficients["lr_t"],
          grad,
          use_locking=self._use_locking)

  def _resource_apply_sparse_duplicate_indices(self, grad, var, indices,
                                               **kwargs):
    if self._momentum:
      # The base class deduplicates indices, then dispatches to
      # `_resource_apply_sparse` below.
      return super(MomentumOptimizer,
                   self)._resource_apply_sparse_duplicate_indices(
                       grad, var, indices, **kwargs)
    else:
      # Without momentum, scatter-add accumulates duplicate indices
      # correctly, so no deduplication is needed.
      var_device, var_dtype = var.device, var.dtype.base_dtype
      coefficients = (
          kwargs.get("apply_state", {}).get((var_device, var_dtype))
          or self._fallback_apply_state(var_device, var_dtype))

      return resource_variable_ops.resource_scatter_add(
          var.handle, indices, -grad * coefficients["lr_t"])

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    # This method is only needed for momentum optimization.
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    momentum_var = self.get_slot(var, "momentum")
    return training_ops.resource_sparse_apply_momentum(
        var.handle,
        momentum_var.handle,
        coefficients["lr_t"],
        grad,
        indices,
        coefficients["momentum"],
        use_locking=self._use_locking,
        use_nesterov=self.nesterov)

  def get_config(self):
    config = super(MomentumOptimizer, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._serialize_hyperparameter("decay"),
        "momentum": self._serialize_hyperparameter("momentum"),
        "nesterov": self.nesterov,
    })
    return config
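

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test of
# both update paths, assuming TF 2.x where `tf.keras.optimizers.Optimizer`
# still resolves to the legacy OptimizerV2 base class. The variable names
# below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
  opt = MomentumOptimizer(learning_rate=0.05, momentum=0.9, nesterov=True)

  # Dense path: minimize f(w) = w^2, whose gradient is 2w, so `w` should
  # decay toward 0 as steps accumulate.
  w = tf.Variable(10.0)
  for _ in range(100):
    with tf.GradientTape() as tape:
      loss = w * w
    opt.apply_gradients(zip(tape.gradient(loss, [w]), [w]))
  print("w after 100 steps:", w.numpy())

  # Sparse path: the gradient of a `tf.gather` is an `IndexedSlices`, which
  # routes through `_resource_apply_sparse_duplicate_indices`. The repeated
  # index 1 exercises the deduplication delegated to the base class when
  # momentum is enabled.
  emb = tf.Variable(tf.random.normal([8, 4]))
  with tf.GradientTape() as tape:
    rows = tf.gather(emb, [1, 1, 3])
    loss = tf.reduce_sum(rows * rows)
  opt.apply_gradients(zip(tape.gradient(loss, [emb]), [emb]))
  print("serialized momentum:", opt.get_config()["momentum"])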