Home | History | Annotate | Download | only in training
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Base class for optimizers."""
     17 # pylint: disable=g-bad-name
     18 
     19 from __future__ import absolute_import
     20 from __future__ import division
     21 from __future__ import print_function
     22 
     23 import abc
     24 
     25 from tensorflow.python.eager import backprop
     26 from tensorflow.python.eager import context
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import ops
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import control_flow_ops
     31 from tensorflow.python.ops import gradients
     32 from tensorflow.python.ops import math_ops
     33 from tensorflow.python.ops import resource_variable_ops
     34 from tensorflow.python.ops import state_ops
     35 from tensorflow.python.ops import variable_scope
     36 from tensorflow.python.ops import variables
     37 from tensorflow.python.training import checkpointable
     38 from tensorflow.python.training import slot_creator
     39 from tensorflow.python.util import nest
     40 from tensorflow.python.util.tf_export import tf_export
     41 
     42 
     43 def _get_variable_for(v):
     44   """Returns the ResourceVariable responsible for v, or v if not necessary."""
     45   if context.in_eager_mode():
     46     return v
     47   if v.op.type == "VarHandleOp":
     48     for var in variables.trainable_variables():
     49       if (isinstance(var, resource_variable_ops.ResourceVariable)
     50           and var.handle.op is v.op):
     51         return var
     52     raise ValueError("Got %s but could not locate source variable." % (str(v)))
     53   return v
     54 
     55 
     56 def _deduplicate_indexed_slices(values, indices):
     57   """Sums `values` associated with any non-unique `indices`.
     58 
     59   Args:
     60     values: A `Tensor` with rank >= 1.
     61     indices: A one-dimensional integer `Tensor`, indexing into the first
     62       dimension of `values` (as in an IndexedSlices object).
     63   Returns:
     64     A tuple of (`summed_values`, `unique_indices`) where `unique_indices` is a
     65     de-duplicated version of `indices` and `summed_values` contains the sum of
     66     `values` slices associated with each unique index.
     67   """
     68   unique_indices, new_index_positions = array_ops.unique(indices)
     69   summed_values = math_ops.unsorted_segment_sum(
     70       values, new_index_positions,
     71       array_ops.shape(unique_indices)[0])
     72   return (summed_values, unique_indices)
     73 
     74 
     75 def _var_key(var):
     76   if context.in_eager_mode():
     77     return var._shared_name  # pylint: disable=protected-access
     78   return (var.op.graph, var.op.name)
     79 
     80 
     81 class _OptimizableVariable(object):
     82   """Interface for abstracting over variables in the optimizers."""
     83 
     84   @abc.abstractmethod
     85   def target(self):
     86     """Returns the optimization target for this variable."""
     87     raise NotImplementedError("Calling an abstract method.")
     88 
     89   @abc.abstractmethod
     90   def update_op(self, optimizer, g):
     91     """Returns the update ops for updating the variable."""
     92     raise NotImplementedError("Calling an abstract method.")
     93 
     94 
     95 class _RefVariableProcessor(_OptimizableVariable):
     96   """Processor for Variable."""
     97 
     98   def __init__(self, v):
     99     self._v = v
    100 
    101   def target(self):
    102     return self._v._ref()  # pylint: disable=protected-access
    103 
    104   def update_op(self, optimizer, g):
    105     if isinstance(g, ops.Tensor):
    106       update_op = optimizer._apply_dense(g, self._v)  # pylint: disable=protected-access
    107       if self._v.constraint is not None:
    108         with ops.control_dependencies([update_op]):
    109           return self._v.assign(self._v.constraint(self._v))
    110       else:
    111         return update_op
    112     else:
    113       assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
    114                                                 "tensor nor IndexedSlices.")
    115       if self._v.constraint is not None:
    116         raise RuntimeError(
    117             "Cannot use a constraint function on a sparse variable.")
    118       # pylint: disable=protected-access
    119       return optimizer._apply_sparse_duplicate_indices(g, self._v)
    120 
    121 
    122 class _DenseReadResourceVariableProcessor(_OptimizableVariable):
    123   """Processor for dense ResourceVariables."""
    124 
    125   def __init__(self, v):
    126     self._v = v
    127 
    128   def target(self):
    129     return self._v
    130 
    131   def update_op(self, optimizer, g):
    132     # pylint: disable=protected-access
    133     update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0])
    134     if self._v.constraint is not None:
    135       with ops.control_dependencies([update_op]):
    136         return self._v.assign(self._v.constraint(self._v))
    137     else:
    138       return update_op
    139 
    140 
    141 class _DenseResourceVariableProcessor(_OptimizableVariable):
    142   """Processor for dense ResourceVariables."""
    143 
    144   def __init__(self, v):
    145     self._v = v
    146 
    147   def target(self):
    148     return self._v
    149 
    150   def update_op(self, optimizer, g):
    151     # pylint: disable=protected-access
    152     if isinstance(g, ops.IndexedSlices):
    153       if self._v.constraint is not None:
    154         raise RuntimeError(
    155             "Cannot use a constraint function on a sparse variable.")
    156       return optimizer._resource_apply_sparse_duplicate_indices(
    157           g.values, self._v, g.indices)
    158     update_op = optimizer._resource_apply_dense(g, self._v)
    159     if self._v.constraint is not None:
    160       with ops.control_dependencies([update_op]):
    161         return self._v.assign(self._v.constraint(self._v))
    162     else:
    163       return update_op
    164 
    165 
    166 class _StreamingModelPortProcessor(_OptimizableVariable):
    167   """Processor for streaming ModelPorts."""
    168 
    169   def __init__(self, v):
    170     self._v = v
    171 
    172   def target(self):
    173     return self._v
    174 
    175   def update_op(self, optimizer, g):
    176     return g
    177 
    178 
    179 class _TensorProcessor(_OptimizableVariable):
    180   """Processor for ordinary Tensors.
    181 
    182   Even though a Tensor can't really be updated, sometimes it is useful to
    183   compute the gradients with respect to a Tensor using the optimizer. Updating
    184   the Tensor is, of course, unsupported.
    185   """
    186 
    187   def __init__(self, v):
    188     self._v = v
    189 
    190   def target(self):
    191     return self._v
    192 
    193   def update_op(self, optimizer, g):
    194     raise NotImplementedError("Trying to update a Tensor ", self._v)
    195 
    196 
    197 def _get_processor(v):
    198   """The processor of v."""
    199   if context.in_eager_mode():
    200     if isinstance(v, ops.Tensor):
    201       return _TensorProcessor(v)
    202     else:
    203       return _DenseResourceVariableProcessor(v)
    204   if v.op.type == "VarHandleOp":
    205     return _DenseResourceVariableProcessor(v)
    206   if isinstance(v, variables.Variable):
    207     return _RefVariableProcessor(v)
    208   if v.op.type == "SubmodelPort":
    209     return _StreamingModelPortProcessor(v)
    210   if isinstance(v, ops.Tensor):
    211     return _TensorProcessor(v)
    212   raise NotImplementedError("Trying to optimize unsupported type ", v)
    213 
    214 
    215 @tf_export("train.Optimizer")
    216 class Optimizer(checkpointable.Checkpointable):
    217   """Base class for optimizers.
    218 
    219   This class defines the API to add Ops to train a model.  You never use this
    220   class directly, but instead instantiate one of its subclasses such as
    221   `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
    222 
    223   ### Usage
    224 
    225   ```python
    226   # Create an optimizer with the desired parameters.
    227   opt = GradientDescentOptimizer(learning_rate=0.1)
    228   # Add Ops to the graph to minimize a cost by updating a list of variables.
    229   # "cost" is a Tensor, and the list of variables contains tf.Variable
    230   # objects.
    231   opt_op = opt.minimize(cost, var_list=<list of variables>)
    232   ```
    233 
    234   In the training program you will just have to run the returned Op.
    235 
    236   ```python
    237   # Execute opt_op to do one step of training:
    238   opt_op.run()
    239   ```
    240 
    241   ### Processing gradients before applying them.
    242 
    243   Calling `minimize()` takes care of both computing the gradients and
    244   applying them to the variables.  If you want to process the gradients
    245   before applying them you can instead use the optimizer in three steps:
    246 
    247   1.  Compute the gradients with `compute_gradients()`.
    248   2.  Process the gradients as you wish.
    249   3.  Apply the processed gradients with `apply_gradients()`.
    250 
    251   Example:
    252 
    253   ```python
    254   # Create an optimizer.
    255   opt = GradientDescentOptimizer(learning_rate=0.1)
    256 
    257   # Compute the gradients for a list of variables.
    258   grads_and_vars = opt.compute_gradients(loss, <list of variables>)
    259 
    260   # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
    261   # need to the 'gradient' part, for example cap them, etc.
    262   capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
    263 
    264   # Ask the optimizer to apply the capped gradients.
    265   opt.apply_gradients(capped_grads_and_vars)
    266   ```
    267 
    268   ### Gating Gradients
    269 
    270   Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
    271   argument that controls the degree of parallelism during the application of
    272   the gradients.
    273 
    274   The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
    275 
    276   <b>`GATE_NONE`</b>: Compute and apply gradients in parallel.  This provides
    277   the maximum parallelism in execution, at the cost of some non-reproducibility
    278   in the results.  For example the two gradients of `matmul` depend on the input
    279   values: With `GATE_NONE` one of the gradients could be applied to one of the
    280   inputs _before_ the other gradient is computed resulting in non-reproducible
    281   results.
    282 
    283   <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
    284   they are used.  This prevents race conditions for Ops that generate gradients
    285   for multiple inputs where the gradients depend on the inputs.
    286 
    287   <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
    288   before any one of them is used.  This provides the least parallelism but can
    289   be useful if you want to process all gradients before applying any of them.
    290 
    291   ### Slots
    292 
    293   Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
    294   allocate and manage additional variables associated with the variables to
    295   train.  These are called <i>Slots</i>.  Slots have names and you can ask the
    296   optimizer for the names of the slots that it uses.  Once you have a slot name
    297   you can ask the optimizer for the variable it created to hold the slot value.
    298 
    299   This can be useful if you want to log debug a training algorithm, report stats
    300   about the slots, etc.
    301   """
    302 
    303   # Values for gate_gradients.
    304   GATE_NONE = 0
    305   GATE_OP = 1
    306   GATE_GRAPH = 2
    307 
    308   def __init__(self, use_locking, name):
    309     """Create a new Optimizer.
    310 
    311     This must be called by the constructors of subclasses.
    312 
    313     Args:
    314       use_locking: Bool. If True apply use locks to prevent concurrent updates
    315         to variables.
    316       name: A non-empty string.  The name to use for accumulators created
    317         for the optimizer.
    318 
    319     Raises:
    320       ValueError: If name is malformed.
    321     """
    322     if not name:
    323       raise ValueError("Must specify the optimizer name")
    324     self._use_locking = use_locking
    325     self._name = name
    326     # Dictionary of slots.
    327     #  {slot_name :
    328     #      {_var_key(variable_to_train): slot_for_the_variable, ... },
    329     #   ... }
    330     self._slots = {}
    331     self._non_slot_dict = {}
    332     # For implementing Checkpointable. Stores information about how to restore
    333     # slot variables which have not yet been created
    334     # (checkpointable._CheckpointPosition objects).
    335     #  {slot_name :
    336     #      {_var_key(variable_to_train): [checkpoint_position, ... ], ... },
    337     #   ... }
    338     self._deferred_slot_restorations = {}
    339 
    340   def get_name(self):
    341     return self._name
    342 
    343   def minimize(self, loss, global_step=None, var_list=None,
    344                gate_gradients=GATE_OP, aggregation_method=None,
    345                colocate_gradients_with_ops=False, name=None,
    346                grad_loss=None):
    347     """Add operations to minimize `loss` by updating `var_list`.
    348 
    349     This method simply combines calls `compute_gradients()` and
    350     `apply_gradients()`. If you want to process the gradient before applying
    351     them call `compute_gradients()` and `apply_gradients()` explicitly instead
    352     of using this function.
    353 
    354     Args:
    355       loss: A `Tensor` containing the value to minimize.
    356       global_step: Optional `Variable` to increment by one after the
    357         variables have been updated.
    358       var_list: Optional list or tuple of `Variable` objects to update to
    359         minimize `loss`.  Defaults to the list of variables collected in
    360         the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
    361       gate_gradients: How to gate the computation of gradients.  Can be
    362         `GATE_NONE`, `GATE_OP`, or  `GATE_GRAPH`.
    363       aggregation_method: Specifies the method used to combine gradient terms.
    364         Valid values are defined in the class `AggregationMethod`.
    365       colocate_gradients_with_ops: If True, try colocating gradients with
    366         the corresponding op.
    367       name: Optional name for the returned operation.
    368       grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
    369 
    370     Returns:
    371       An Operation that updates the variables in `var_list`.  If `global_step`
    372       was not `None`, that operation also increments `global_step`.
    373 
    374     Raises:
    375       ValueError: If some of the variables are not `Variable` objects.
    376 
    377     @compatibility(eager)
    378     When eager execution is enabled, `loss` should be a Python function that
    379     takes elements of `var_list` as arguments and computes the value to be
    380     minimized. If `var_list` is None, `loss` should take no arguments.
    381     Minimization (and gradient computation) is done with respect to the
    382     elements of `var_list` if not None, else with respect to any trainable
    383     variables created during the execution of the `loss` function.
    384     `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
    385     `grad_loss` are ignored when eager execution is enabled.
    386     @end_compatibility
    387     """
    388     grads_and_vars = self.compute_gradients(
    389         loss, var_list=var_list, gate_gradients=gate_gradients,
    390         aggregation_method=aggregation_method,
    391         colocate_gradients_with_ops=colocate_gradients_with_ops,
    392         grad_loss=grad_loss)
    393 
    394     vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    395     if not vars_with_grad:
    396       raise ValueError(
    397           "No gradients provided for any variable, check your graph for ops"
    398           " that do not support gradients, between variables %s and loss %s." %
    399           ([str(v) for _, v in grads_and_vars], loss))
    400 
    401     return self.apply_gradients(grads_and_vars, global_step=global_step,
    402                                 name=name)
    403 
    404   def compute_gradients(self, loss, var_list=None,
    405                         gate_gradients=GATE_OP,
    406                         aggregation_method=None,
    407                         colocate_gradients_with_ops=False,
    408                         grad_loss=None):
    409     """Compute gradients of `loss` for the variables in `var_list`.
    410 
    411     This is the first part of `minimize()`.  It returns a list
    412     of (gradient, variable) pairs where "gradient" is the gradient
    413     for "variable".  Note that "gradient" can be a `Tensor`, an
    414     `IndexedSlices`, or `None` if there is no gradient for the
    415     given variable.
    416 
    417     Args:
    418       loss: A Tensor containing the value to minimize or a callable taking
    419         no arguments which returns the value to minimize. When eager execution
    420         is enabled it must be a callable.
    421       var_list: Optional list or tuple of `tf.Variable` to update to minimize
    422         `loss`.  Defaults to the list of variables collected in the graph
    423         under the key `GraphKeys.TRAINABLE_VARIABLES`.
    424       gate_gradients: How to gate the computation of gradients.  Can be
    425         `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
    426       aggregation_method: Specifies the method used to combine gradient terms.
    427         Valid values are defined in the class `AggregationMethod`.
    428       colocate_gradients_with_ops: If True, try colocating gradients with
    429         the corresponding op.
    430       grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
    431 
    432     Returns:
    433       A list of (gradient, variable) pairs. Variable is always present, but
    434       gradient can be `None`.
    435 
    436     Raises:
    437       TypeError: If `var_list` contains anything else than `Variable` objects.
    438       ValueError: If some arguments are invalid.
    439       RuntimeError: If called with eager execution enabled and `loss` is
    440         not callable.
    441 
    442     @compatibility(eager)
    443     When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    444     and `colocate_gradients_with_ops` are ignored.
    445     @end_compatibility
    446     """
    447     if callable(loss):
    448       with backprop.GradientTape() as tape:
    449         if var_list is not None:
    450           tape.watch(var_list)
    451         loss_value = loss()
    452       if var_list is None:
    453         var_list = tape.watched_variables()
    454       grads = tape.gradient(loss_value, var_list, grad_loss)
    455       return list(zip(grads, var_list))
    456     if context.in_eager_mode():
    457       raise RuntimeError(
    458           "`loss` passed to Optimizer.compute_gradients should "
    459           "be a function when eager execution is enabled.")
    460     if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
    461                               Optimizer.GATE_GRAPH]:
    462       raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
    463                        "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
    464                        gate_gradients)
    465     self._assert_valid_dtypes([loss])
    466     if grad_loss is not None:
    467       self._assert_valid_dtypes([grad_loss])
    468     if var_list is None:
    469       var_list = (
    470           variables.trainable_variables() +
    471           ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    472     else:
    473       var_list = nest.flatten(var_list)
    474     # pylint: disable=protected-access
    475     var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    476     # pylint: enable=protected-access
    477     processors = [_get_processor(v) for v in var_list]
    478     if not var_list:
    479       raise ValueError("No variables to optimize.")
    480     var_refs = [p.target() for p in processors]
    481     grads = gradients.gradients(
    482         loss, var_refs, grad_ys=grad_loss,
    483         gate_gradients=(gate_gradients == Optimizer.GATE_OP),
    484         aggregation_method=aggregation_method,
    485         colocate_gradients_with_ops=colocate_gradients_with_ops)
    486     if gate_gradients == Optimizer.GATE_GRAPH:
    487       grads = control_flow_ops.tuple(grads)
    488     grads_and_vars = list(zip(grads, var_list))
    489     self._assert_valid_dtypes(
    490         [v for g, v in grads_and_vars
    491          if g is not None and v.dtype != dtypes.resource])
    492     return grads_and_vars
    493 
    494   def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    495     """Apply gradients to variables.
    496 
    497     This is the second part of `minimize()`. It returns an `Operation` that
    498     applies gradients.
    499 
    500     Args:
    501       grads_and_vars: List of (gradient, variable) pairs as returned by
    502         `compute_gradients()`.
    503       global_step: Optional `Variable` to increment by one after the
    504         variables have been updated.
    505       name: Optional name for the returned operation.  Default to the
    506         name passed to the `Optimizer` constructor.
    507 
    508     Returns:
    509       An `Operation` that applies the specified gradients. If `global_step`
    510       was not None, that operation also increments `global_step`.
    511 
    512     Raises:
    513       TypeError: If `grads_and_vars` is malformed.
    514       ValueError: If none of the variables have gradients.
    515     """
    516     # This is a default implementation of apply_gradients() that can be shared
    517     # by most optimizers.  It relies on the subclass implementing the following
    518     # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
    519 
    520     grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    521     if not grads_and_vars:
    522       raise ValueError("No variables provided.")
    523     converted_grads_and_vars = []
    524     for g, v in grads_and_vars:
    525       if g is not None:
    526         try:
    527           # Convert the grad to Tensor or IndexedSlices if necessary.
    528           g = ops.convert_to_tensor_or_indexed_slices(g)
    529         except TypeError:
    530           raise TypeError(
    531               "Gradient must be convertible to a Tensor"
    532               " or IndexedSlices, or None: %s" % g)
    533         if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
    534           raise TypeError(
    535               "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
    536       p = _get_processor(v)
    537       converted_grads_and_vars.append((g, v, p))
    538 
    539     converted_grads_and_vars = tuple(converted_grads_and_vars)
    540     var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
    541     if not var_list:
    542       raise ValueError("No gradients provided for any variable: %s." %
    543                        ([str(v) for _, _, v in converted_grads_and_vars],))
    544     with ops.init_scope():
    545       self._create_slots([_get_variable_for(v) for v in var_list])
    546     update_ops = []
    547     with ops.name_scope(name, self._name) as name:
    548       self._prepare()
    549       for grad, var, processor in converted_grads_and_vars:
    550         if grad is None:
    551           continue
    552         # We colocate all ops created in _apply_dense or _apply_sparse
    553         # on the same device as the variable.
    554         # TODO(apassos): figure out how to get the variable name here.
    555         scope_name = var.op.name if context.in_graph_mode() else ""
    556         with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
    557           update_ops.append(processor.update_op(self, grad))
    558       if global_step is None:
    559         apply_updates = self._finish(update_ops, name)
    560       else:
    561         with ops.control_dependencies([self._finish(update_ops, "update")]):
    562           with ops.colocate_with(global_step):
    563             if isinstance(global_step, resource_variable_ops.ResourceVariable):
    564               # TODO(apassos): the implicit read in assign_add is slow; consider
    565               # making it less so.
    566               apply_updates = resource_variable_ops.assign_add_variable_op(
    567                   global_step.handle,
    568                   ops.convert_to_tensor(1, dtype=global_step.dtype),
    569                   name=name)
    570             else:
    571               apply_updates = state_ops.assign_add(global_step, 1, name=name)
    572 
    573       if context.in_graph_mode():
    574         if isinstance(apply_updates, ops.Tensor):
    575           apply_updates = apply_updates.op
    576         train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    577         if apply_updates not in train_op:
    578           train_op.append(apply_updates)
    579 
    580       return apply_updates
    581 
    582   def get_slot(self, var, name):
    583     """Return a slot named `name` created for `var` by the Optimizer.
    584 
    585     Some `Optimizer` subclasses use additional variables.  For example
    586     `Momentum` and `Adagrad` use variables to accumulate updates.  This method
    587     gives access to these `Variable` objects if for some reason you need them.
    588 
    589     Use `get_slot_names()` to get the list of slot names created by the
    590     `Optimizer`.
    591 
    592     Args:
    593       var: A variable passed to `minimize()` or `apply_gradients()`.
    594       name: A string.
    595 
    596     Returns:
    597       The `Variable` for the slot if it was created, `None` otherwise.
    598     """
    599     named_slots = self._slots.get(name, None)
    600     if not named_slots:
    601       return None
    602     return named_slots.get(_var_key(var), None)
    603 
    604   def get_slot_names(self):
    605     """Return a list of the names of slots created by the `Optimizer`.
    606 
    607     See `get_slot()`.
    608 
    609     Returns:
    610       A list of strings.
    611     """
    612     return sorted(self._slots.keys())
    613 
    614   def variables(self):
    615     """A list of variables which encode the current state of `Optimizer`.
    616 
    617     Includes slot variables and additional global variables created by the
    618     optimizer in the current default graph.
    619 
    620     Returns:
    621       A list of variables.
    622     """
    623     executing_eagerly = context.in_eager_mode()
    624     current_graph = ops.get_default_graph()
    625 
    626     def _from_current_graph(variable):
    627       if executing_eagerly:
    628         # No variable.op in eager mode. We don't expect lots of eager graphs,
    629         # but behavior should be consistent with graph mode.
    630         return variable._graph_key == current_graph._graph_key  # pylint: disable=protected-access
    631       else:
    632         return variable.op.graph is current_graph
    633 
    634     optimizer_variables = [v for v in self._non_slot_variables()
    635                            if _from_current_graph(v)]
    636     for _, variable_dict in self._slots.items():
    637       for _, slot_for_variable in variable_dict.items():
    638         if _from_current_graph(slot_for_variable):
    639           optimizer_variables.append(slot_for_variable)
    640     # Sort variables by name so that the return is deterministic.
    641     return sorted(optimizer_variables, key=lambda v: v.name)
    642 
    643   def _create_non_slot_variable(self, initial_value, name, colocate_with):
    644     """Add an extra variable, not associated with a slot."""
    645     if context.in_graph_mode():
    646       graph = colocate_with.graph
    647     else:
    648       graph = None
    649 
    650     key = (name, graph)
    651     v = self._non_slot_dict.get(key, None)
    652     if v is None:
    653       with ops.colocate_with(colocate_with):
    654         v = variable_scope.variable(initial_value, name=name, trainable=False)
    655       self._non_slot_dict[key] = v
    656 
    657     return v
    658 
    659   def _get_non_slot_variable(self, name, graph=None):
    660     return self._non_slot_dict.get((name, graph), None)
    661 
    662   def _non_slot_variables(self):
    663     """Additional variables created by the `Optimizer`.
    664 
    665     Returns:
    666       A list or tuple of variables.
    667     """
    668     return self._non_slot_dict.values()
    669 
    670   def _assert_valid_dtypes(self, tensors):
    671     """Asserts tensors are all valid types (see `_valid_dtypes`).
    672 
    673     Args:
    674       tensors: Tensors to check.
    675 
    676     Raises:
    677       ValueError: If any tensor is not a valid type.
    678     """
    679     valid_dtypes = self._valid_dtypes()
    680     for t in tensors:
    681       dtype = t.dtype.base_dtype
    682       if dtype not in valid_dtypes:
    683         raise ValueError(
    684             "Invalid type %r for %s, expected: %s." % (
    685                 dtype, t.name, [v for v in valid_dtypes]))
    686 
    687   # --------------
    688   # Methods to be implemented by subclasses if they want to use the
    689   # inherited implementation of apply_gradients() or compute_gradients().
    690   # --------------
    691   def _valid_dtypes(self):
    692     """Valid types for loss, variables and gradients.
    693 
    694     Subclasses should override to allow other float types.
    695 
    696     Returns:
    697       Valid types for loss, variables and gradients.
    698     """
    699     return set(
    700         [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64])
    701 
    702   def _create_slots(self, var_list):
    703     """Create all slots needed by the variables.
    704 
    705     Args:
    706       var_list: A list of `Variable` objects.
    707     """
    708     # No slots needed by default
    709     pass
    710 
    711   def _prepare(self):
    712     """Create all needed tensors before applying gradients.
    713 
    714     This is called with the name_scope using the "name" that
    715     users have chosen for the application of gradients.
    716     """
    717     pass
    718 
    719   def _apply_dense(self, grad, var):
    720     """Add ops to apply dense gradients to `var`.
    721 
    722     Args:
    723       grad: A `Tensor`.
    724       var: A `Variable` object.
    725 
    726     Returns:
    727       An `Operation`.
    728     """
    729     raise NotImplementedError()
    730 
    731   def _resource_apply_dense(self, grad, handle):
    732     """Add ops to apply dense gradients to the variable `handle`.
    733 
    734     Args:
    735       grad: a `Tensor` representing the gradient.
    736       handle: a `Tensor` of dtype `resource` which points to the variable
    737        to be updated.
    738 
    739     Returns:
    740       An `Operation` which updates the value of the variable.
    741     """
    742     raise NotImplementedError()
    743 
    744   def _resource_apply_sparse_duplicate_indices(self, grad, handle, indices):
    745     """Add ops to apply sparse gradients to `handle`, with repeated indices.
    746 
    747     Optimizers which override this method must deal with repeated indices. See
    748     the docstring of `_apply_sparse_duplicate_indices` for details. By default
    749     the correct behavior, to sum non-unique indices and their associated
    750     gradients, is enforced by first pre-processing `grad` and `indices` and
    751     passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    752     with duplicate indices may instead override this method to avoid the
    753     overhead of summing.
    754 
    755     Args:
    756       grad: a `Tensor` representing the gradient for the affected indices.
    757       handle: a `Tensor` of dtype `resource` which points to the variable
    758        to be updated.
    759       indices: a `Tensor` of integral type representing the indices for
    760        which the gradient is nonzero. Indices may be repeated.
    761 
    762     Returns:
    763       An `Operation` which updates the value of the variable.
    764     """
    765     summed_grad, unique_indices = _deduplicate_indexed_slices(
    766         values=grad, indices=indices)
    767     return self._resource_apply_sparse(summed_grad, handle, unique_indices)
    768 
    769   def _resource_apply_sparse(self, grad, handle, indices):
    770     """Add ops to apply sparse gradients to the variable `handle`.
    771 
    772     Similar to `_apply_sparse`, the `indices` argument to this method has been
    773     de-duplicated. Optimizers which deal correctly with non-unique indices may
    774     instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    775     overhead.
    776 
    777     Args:
    778       grad: a `Tensor` representing the gradient for the affected indices.
    779       handle: a `Tensor` of dtype `resource` which points to the variable
    780        to be updated.
    781       indices: a `Tensor` of integral type representing the indices for
    782        which the gradient is nonzero. Indices are unique.
    783 
    784     Returns:
    785       An `Operation` which updates the value of the variable.
    786     """
    787     raise NotImplementedError()
    788 
    789   def _apply_sparse_duplicate_indices(self, grad, var):
    790     """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
    791 
    792     Optimizers which override this method must deal with IndexedSlices objects
    793     such as the following:
    794 
    795       IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
    796 
    797     The correct interpretation is:
    798 
    799       IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
    800 
    801     Many optimizers deal incorrectly with repeated indices when updating based
    802     on sparse gradients (e.g. summing squares rather than squaring the sum, or
    803     applying momentum terms multiple times). Adding first is always the correct
    804     behavior, so this is enforced here by reconstructing the IndexedSlices to
    805     have only unique indices, then calling _apply_sparse.
    806 
    807     Optimizers which deal correctly with repeated indices may instead override
    808     this method to avoid the overhead of summing indices.
    809 
    810     Args:
    811       grad: `IndexedSlices`.
    812       var: A `Variable` object.
    813 
    814     Returns:
    815       An `Operation`.
    816     """
    817     summed_values, unique_indices = _deduplicate_indexed_slices(
    818         values=grad.values, indices=grad.indices)
    819     gradient_no_duplicate_indices = ops.IndexedSlices(
    820         indices=unique_indices,
    821         values=summed_values,
    822         dense_shape=grad.dense_shape)
    823     return self._apply_sparse(gradient_no_duplicate_indices, var)
    824 
    825   def _apply_sparse(self, grad, var):
    826     """Add ops to apply sparse gradients to `var`.
    827 
    828     The IndexedSlices object passed to `grad` in this function is by default
    829     pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
    830     indices (see its docstring for details). Optimizers which can tolerate or
    831     have correct special cases for duplicate sparse indices may override
    832     `_apply_sparse_duplicate_indices` instead of this function, avoiding that
    833     overhead.
    834 
    835     Args:
    836       grad: `IndexedSlices`, with no repeated indices.
    837       var: A `Variable` object.
    838 
    839     Returns:
    840       An `Operation`.
    841     """
    842     raise NotImplementedError()
    843 
    844   def _finish(self, update_ops, name_scope):
    845     """Do what is needed to finish the update.
    846 
    847     This is called with the `name_scope` using the "name" that
    848     users have chosen for the application of gradients.
    849 
    850     Args:
    851       update_ops: List of `Operation` objects to update variables.  This list
    852         contains the values returned by the `_apply_dense()` and
    853         `_apply_sparse()` calls.
    854       name_scope: String.  Name to use for the returned operation.
    855 
    856     Returns:
    857       The operation to apply updates.
    858     """
    859     return control_flow_ops.group(*update_ops, name=name_scope)
    860 
    861   # --------------
    862   # Utility methods for subclasses.
    863   # --------------
    864 
    865   def _slot_dict(self, slot_name):
    866     """Returns a dict for caching slots created under the given name.
    867 
    868     Args:
    869       slot_name: Name for the slot.
    870 
    871     Returns:
    872       A dict that maps primary `Variable` objects to the slot created
    873       for that variable, under the given slot name.
    874     """
    875     named_slots = self._slots.get(slot_name, None)
    876     if named_slots is None:
    877       named_slots = {}
    878       self._slots[slot_name] = named_slots
    879     return named_slots
    880 
    881   def _get_or_make_slot(self, var, val, slot_name, op_name):
    882     """Find or create a slot for a variable.
    883 
    884     Args:
    885       var: A `Variable` object.
    886       val: A `Tensor`.  The initial value of the slot.
    887       slot_name: Name for the slot.
    888       op_name: Name to use when scoping the Variable that
    889         needs to be created for the slot.
    890 
    891     Returns:
    892       A `Variable` object.
    893     """
    894     named_slots = self._slot_dict(slot_name)
    895     if _var_key(var) not in named_slots:
    896       new_slot_variable = slot_creator.create_slot(var, val, op_name)
    897       self._restore_slot_variable(
    898           slot_name=slot_name, variable=var,
    899           slot_variable=new_slot_variable)
    900       named_slots[_var_key(var)] = new_slot_variable
    901     return named_slots[_var_key(var)]
    902 
    903   def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
    904                                          slot_name, op_name):
    905     """Find or create a slot for a variable, using an Initializer.
    906 
    907     Args:
    908       var: A `Variable` object.
    909       initializer: An `Initializer`.  The initial value of the slot.
    910       shape: Shape of the initial value of the slot.
    911       dtype: Type of the value of the slot.
    912       slot_name: Name for the slot.
    913       op_name: Name to use when scoping the Variable that
    914         needs to be created for the slot.
    915 
    916     Returns:
    917       A `Variable` object.
    918     """
    919     named_slots = self._slot_dict(slot_name)
    920     if _var_key(var) not in named_slots:
    921       new_slot_variable = slot_creator.create_slot_with_initializer(
    922           var, initializer, shape, dtype, op_name)
    923       self._restore_slot_variable(
    924           slot_name=slot_name, variable=var,
    925           slot_variable=new_slot_variable)
    926       named_slots[_var_key(var)] = new_slot_variable
    927     return named_slots[_var_key(var)]
    928 
    929   def _zeros_slot(self, var, slot_name, op_name):
    930     """Find or create a slot initialized with 0.0.
    931 
    932     Args:
    933       var: A `Variable` object.
    934       slot_name: Name for the slot.
    935       op_name: Name to use when scoping the Variable that
    936         needs to be created for the slot.
    937 
    938     Returns:
    939       A `Variable` object.
    940     """
    941     named_slots = self._slot_dict(slot_name)
    942     if _var_key(var) not in named_slots:
    943       new_slot_variable = slot_creator.create_zeros_slot(var, op_name)
    944       self._restore_slot_variable(
    945           slot_name=slot_name, variable=var,
    946           slot_variable=new_slot_variable)
    947       named_slots[_var_key(var)] = new_slot_variable
    948     return named_slots[_var_key(var)]
    949 
    950   # --------------
    951   # For implementing the Checkpointable interface.
    952   # --------------
    953 
    954   def _restore_slot_variable(self, slot_name, variable, slot_variable):
    955     """Restore a newly created slot variable's value."""
    956     variable_key = _var_key(variable)
    957     deferred_restorations = self._deferred_slot_restorations.get(
    958         slot_name, {}).pop(variable_key, [])
    959     # Iterate over restores, highest restore UID first to minimize the number
    960     # of assignments.
    961     deferred_restorations.sort(key=lambda position: position.restore_uid,
    962                                reverse=True)
    963     for checkpoint_position in deferred_restorations:
    964       checkpoint_position.restore(slot_variable)
    965 
    966   def _create_or_restore_slot_variable(
    967       self, slot_variable_position, slot_name, variable):
    968     """Restore a slot variable's value, possibly creating it.
    969 
    970     Called when a variable which has an associated slot variable is created or
    971     restored. When executing eagerly, we create the slot variable with a
    972     restoring initializer.
    973 
    974     No new variables are created when graph building. Instead,
    975     _restore_slot_variable catches these after normal creation and adds restore
    976     ops to the graph. This method is nonetheless important when graph building
    977     for the case when a slot variable has already been created but `variable`
    978     has just been added to a dependency graph (causing us to realize that the
    979     slot variable needs to be restored).
    980 
    981     Args:
    982       slot_variable_position: A `checkpointable._CheckpointPosition` object
    983         indicating the slot variable `Checkpointable` object to be restored.
    984       slot_name: The name of this `Optimizer`'s slot to restore into.
    985       variable: The variable object this slot is being created for.
    986     """
    987     named_slots = self._slot_dict(slot_name)
    988     variable_key = _var_key(variable)
    989     slot_variable = named_slots.get(variable_key, None)
    990     if (slot_variable is None
    991         and context.in_eager_mode()
    992         and slot_variable_position.is_simple_variable()):
    993       initializer = checkpointable.CheckpointInitialValue(
    994           checkpoint_position=slot_variable_position)
    995       slot_variable = self._get_or_make_slot(
    996           var=variable,
    997           val=initializer,
    998           slot_name=slot_name,
    999           op_name=self._name)
   1000       # Slot variables are not owned by any one object (because we don't want to
   1001       # save the slot variable if the optimizer is saved without the non-slot
   1002       # variable, or if the non-slot variable is saved without the optimizer;
   1003       # it's a dependency hypergraph with edges of the form (optimizer, non-slot
   1004       # variable, variable)). So we don't _track_ slot variables anywhere, and
   1005       # instead special-case this dependency and otherwise pretend it's a normal
   1006       # graph.
   1007     if slot_variable is not None:
   1008       # If we've either made this slot variable, or if we've pulled out an
   1009       # existing slot variable, we should restore it.
   1010       slot_variable_position.restore(slot_variable)
   1011     else:
   1012       # We didn't make the slot variable. Defer restoring until it gets created
   1013       # normally. We keep a list rather than the one with the highest restore
   1014       # UID in case slot variables have their own dependencies, in which case
   1015       # those could differ between restores.
   1016       self._deferred_slot_restorations.setdefault(
   1017           slot_name, {}).setdefault(variable_key, []).append(
   1018               slot_variable_position)
   1019