# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reversible Residual Block.

From
[The Reversible Residual Network: Backpropagation Without Storing
Activations](https://arxiv.org/abs/1707.04585).

Also contains the @recompute_grad decorator, which recomputes the forward
function on the backwards pass.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import re

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.framework.python import ops as contrib_framework_ops
from tensorflow.python.framework import function
from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest

__all__ = ["rev_block", "RevBlock", "recompute_grad"]

LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*")


def _acc_grads(*lists_of_grads):
  """Accumulates lists of gradients."""
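  # Note (illustrative, based on the logic below): accumulation is element-wise
  # over parallel lists, e.g. _acc_grads([g0, None], [h0, None]) yields
  # [g0 + h0, None] -- entries where every contribution is None stay None
  # rather than becoming zeros.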
  acc_grads = []
  for grads in zip(*lists_of_grads):
    grads = [g for g in grads if g is not None]
    if grads:
      acc_grads.append(math_ops.add_n(grads))
    else:
      acc_grads.append(None)
  return acc_grads


def _rev_layer_forward(xs, f, g, f_side_input, g_side_input,
                       gate_outputs=False):
  """Forward for 1 reversible layer."""
  x1, x2 = xs
  y1 = x1 + (f(x2, f_side_input) if f_side_input else f(x2))
  y2 = x2 + (g(y1, g_side_input) if g_side_input else g(y1))
  if gate_outputs:
    return control_flow_ops.tuple([y1, y2])
  else:
    return (y1, y2)


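# The coupling above is invertible in closed form, which is what the backward
# computations below rely on to reconstruct the inputs rather than storing
# them:
#   x2 = y2 - g(y1)
#   x1 = y1 - f(x2)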
def _rev_layer_backward(ys, grad_ys, f, g, f_vars, f_side_input, g_vars,
                        g_side_input):
  """Backprop for 1 layer."""
  y1, y2 = ys
  grad_y1, grad_y2 = grad_ys

  # Reconstruct intermediates and inputs (x1, x2).
  # stop_gradient is required on the fn inputs so that the gradients calls
  # below do not recurse into this grad function.
  y1_stop = array_ops.stop_gradient(y1)
  g_side_input = [array_ops.stop_gradient(t) for t in g_side_input]
  gy1 = g(y1_stop, g_side_input) if g_side_input else g(y1_stop)

  x2 = y2 - gy1
  x2_stop = array_ops.stop_gradient(x2)
  f_side_input = [array_ops.stop_gradient(t) for t in f_side_input]
  fx2 = f(x2_stop, f_side_input) if f_side_input else f(x2_stop)

  x1 = y1 - fx2

  # Compute gradients with respect to the inputs.
  # dL/dy2 * dG(y1)/dy1
  grad_gy1_y2 = gradients_impl.gradients(gy1, y1_stop, grad_y2)[0]
  grad_x1 = grad_y1 + grad_gy1_y2
  grad_x2 = (
      gradients_impl.gradients(fx2, x2_stop, grad_y1)[0] + grad_y2 +
      gradients_impl.gradients(fx2, x2_stop, grad_gy1_y2)[0])

  # Compute gradients with respect to the variables and side inputs of f and g.
  grads1 = gradients_impl.gradients(gy1, g_vars + g_side_input, grad_y2)
  grad_g_vars, grad_g_side = grads1[:len(g_vars)], grads1[len(g_vars):]
  grads2 = gradients_impl.gradients(fx2, f_vars + f_side_input, grad_y1)
  grad_f_y1, grad_f_side1 = grads2[:len(f_vars)], grads2[len(f_vars):]
  grads3 = gradients_impl.gradients(fx2, f_vars + f_side_input, grad_gy1_y2)
  grad_f_y2, grad_f_side2 = grads3[:len(f_vars)], grads3[len(f_vars):]
  grad_f_vars = _acc_grads(grad_f_y1, grad_f_y2)

  grad_f_side = _acc_grads(grad_f_side1, grad_f_side2)

  # Put the returns in a tuple to ensure a constant memory budget (i.e. we
  # don't want the subsequent layer to start computing and consuming memory
  # based on a subset of these values).
  outputs = ((x1, x2), (grad_x1, grad_x2), (grad_f_vars, grad_f_side),
             (grad_g_vars, grad_g_side))
  tupled = control_flow_ops.tuple(nest.flatten(outputs))
  return nest.pack_sequence_as(outputs, tupled)


def _rev_block_forward(x1,
                       x2,
                       f,
                       g,
                       num_layers=1,
                       f_side_input=None,
                       g_side_input=None,
                       gate_outputs=False):
  """Forward for a series of reversible layers."""
  out = (x1, x2)
  for i in xrange(num_layers):
    out = _rev_layer_forward(
        out, f[i], g[i], f_side_input, g_side_input, gate_outputs=gate_outputs)

  y1, y2 = out
  return y1, y2


def _scope_wrap(fn, scope):
  """Wraps fn so that it runs inside the given variable scope."""

  @functools.wraps(fn)
  def wrap(*args, **kwargs):
    with variable_scope.variable_scope(scope):
      return fn(*args, **kwargs)

  return wrap


class RevBlock(base.Layer):
  """Block of reversible layers. See rev_block."""

  def __init__(self,
               f,
               g,
               num_layers=1,
               f_side_input=None,
               g_side_input=None,
               use_efficient_backprop=True,
               name="revblock",
               **kwargs):
    super(RevBlock, self).__init__(name=name, **kwargs)

    if isinstance(f, list):
      assert len(f) == num_layers
    else:
      f = [f] * num_layers

    if isinstance(g, list):
      assert len(g) == num_layers
    else:
      g = [g] * num_layers

    f = [_scope_wrap(fn, "revlayer_%d/f" % i) for i, fn in enumerate(f)]
    g = [_scope_wrap(fn, "revlayer_%d/g" % i) for i, fn in enumerate(g)]

    self.f = f
    self.g = g

    self.num_layers = num_layers
    self.f_side_input = f_side_input or []
    self.g_side_input = g_side_input or []

    self._use_efficient_backprop = use_efficient_backprop

  def call(self, inputs, forward=True):
    vs = variable_scope.get_variable_scope()
    vars_before = vs.global_variables()

    if forward:
      x1, x2 = inputs
      out = self._forward(x1, x2)
    else:
      y1, y2 = inputs
      out = self._backward(y1, y2)

    # Add any created variables to the Layer's variable stores
    new_vars = vs.global_variables()[len(vars_before):]
    train_vars = vs.trainable_variables()
    for new_var in new_vars:
      if new_var in train_vars:
        self._trainable_weights.append(new_var)
      else:
        self._non_trainable_weights.append(new_var)

    return out

  def forward(self, x1, x2):
    return self.apply([x1, x2])

  def backward(self, y1, y2):
    return self.apply([y1, y2], forward=False)

  def build(self, _):
    logging.warn("RevBlock constructs its variables on first call, not on "
                 "build.")
    self.built = True

  def _efficient_grad_fn(self, inputs, variables, ys, grad_ys):
    """Custom gradient fn for a block of reversible residual layers."""
    side_inputs = inputs[2:]
    f_side_idxs = [None] * len(self.f_side_input)
    g_side_idxs = [None] * len(self.g_side_input)
    assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)

    for i, t in enumerate(side_inputs):
      if t in self.f_side_input:
        f_side_idxs[self.f_side_input.index(t)] = i
      elif t in self.g_side_input:
        g_side_idxs[self.g_side_input.index(t)] = i
      else:
        assert False

    f_vars = [[] for _ in range(self.num_layers)]
    g_vars = [[] for _ in range(self.num_layers)]
    f_vars_idxs = [[] for _ in range(self.num_layers)]
    g_vars_idxs = [[] for _ in range(self.num_layers)]

    for i, t in enumerate(variables):
      ref = _underlying_variable_ref(t)

      # Use the name to identify the layer number and function (f or g)
      regex = LAYER_RE.match(ref.name)
      layer_no = int(regex.group(1))
      fn_name = regex.group(2)
      if fn_name == "f":
        f_vars[layer_no].append(ref)
        f_vars_idxs[layer_no].append(i)
      else:
        assert fn_name == "g"
        g_vars[layer_no].append(ref)
        g_vars_idxs[layer_no].append(i)

    f_var_grads = []
    g_var_grads = []
    f_side_grads = []
    g_side_grads = []

    # Reverse variable containers to go backward
    f_vars.reverse()
    g_vars.reverse()
    f = list(self.f)
    g = list(self.g)
    f.reverse()
    g.reverse()

    with variable_scope.variable_scope(self.scope_name, reuse=True):
      for i in xrange(self.num_layers):
        ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
            ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i],
            self.g_side_input)

        grad_f_vars, grad_f_side = f_ret
        grad_g_vars, grad_g_side = g_ret
        f_var_grads.append(grad_f_vars)
        g_var_grads.append(grad_g_vars)
        f_side_grads.append(grad_f_side)
        g_side_grads.append(grad_g_side)

    # Accumulate layer gradients for f_side_input and g_side_input
    acc_f_side_grads = _acc_grads(*f_side_grads)
    acc_g_side_grads = _acc_grads(*g_side_grads)

    # Use the stored idxs to put gradients in the passed-in order.
    side_input_grads = [None] * len(side_inputs)
    variable_grads = [None] * len(variables)

    # Variable gradients were collected in reverse layer order. Reverse to match
    # idxs.
    f_var_grads.reverse()
    g_var_grads.reverse()
    for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
        zip(g_vars_idxs, g_var_grads)):
      for i, grad in zip(idxs, grads):
        variable_grads[i] = grad

    for i, grad in zip(f_side_idxs, acc_f_side_grads):
      side_input_grads[i] = grad
    for i, grad in zip(g_side_idxs, acc_g_side_grads):
      side_input_grads[i] = grad

    grad_x1, grad_x2 = grad_ys
    return [grad_x1, grad_x2] + side_input_grads, variable_grads

  def _forward(self, x1, x2):
    """Run forward through the reversible layers."""

    side_inputs = [self.f_side_input, self.g_side_input]
    flat_side_inputs = nest.flatten(side_inputs)

    custom_grad_fn = (
        self._efficient_grad_fn if self._use_efficient_backprop else None)

    @_fn_with_custom_grad(custom_grad_fn)
    def _forward_wrap(x1_, x2_, *flat_side_inputs):
      f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs)
      return _rev_block_forward(
          x1_,
          x2_,
          self.f,
          self.g,
          num_layers=self.num_layers,
          f_side_input=f_side,
          g_side_input=g_side,
          gate_outputs=self._use_efficient_backprop)

    return _forward_wrap(x1, x2, *flat_side_inputs)

  def _backward(self, y1, y2):
    """Run backward through the reversible layers."""

    f = list(self.f)
    g = list(self.g)
    f.reverse()
    g.reverse()

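    # Invert one layer at a time, most recently applied layer first:
    #   x2 = y2 - g(y1), then x1 = y1 - f(x2).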
    for i in xrange(self.num_layers):
      gy1 = g[i](y1, self.g_side_input) if self.g_side_input else g[i](y1)
      x2 = y2 - gy1
      fx2 = f[i](x2, self.f_side_input) if self.f_side_input else f[i](x2)
      x1 = y1 - fx2

      y1, y2 = x1, x2

    return x1, x2

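
# Illustrative usage sketch (not part of the original module): one way to use
# RevBlock directly. The shapes, variable names, and matmul-based f/g below
# are assumptions made for this example only.
def _example_rev_block_usage():
  """Couples two tensors through a RevBlock, then inverts the block."""

  def f(x):
    w = variable_scope.get_variable("f_w", [8, 8])
    return math_ops.matmul(x, w)

  def g(x):
    w = variable_scope.get_variable("g_w", [8, 8])
    return math_ops.matmul(x, w)

  x1 = array_ops.ones([4, 8])
  x2 = array_ops.ones([4, 8])
  block = RevBlock(f, g, num_layers=2)
  y1, y2 = block.forward(x1, x2)
  # backward reconstructs x1 and x2 from y1 and y2; the block does not need
  # stored activations, and gradients flow through _efficient_grad_fn.
  x1_re, x2_re = block.backward(y1, y2)
  return y1, y2, x1_re, x2_re
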

def rev_block(x1,
              x2,
              f,
              g,
              num_layers=1,
              f_side_input=None,
              g_side_input=None,
              is_training=True):
  """A block of reversible residual layers.

  A reversible residual layer is defined as:

  ```
  y1 = x1 + f(x2, f_side_input)
  y2 = x2 + g(y1, g_side_input)
  ```

  A reversible residual block, defined here, is a series of reversible residual
  layers.

  Limitations:
  * f and g must not close over any Tensors; all side inputs to f and g should
    be passed in with f_side_input and g_side_input, which will be forwarded to
    f and g.
  * f and g must not change the dimensionality of their inputs in order for the
    addition in the equations above to work.

  Args:
    x1: a float Tensor.
    x2: a float Tensor.
    f: a function, (Tensor) -> (Tensor) (or list of such of length num_layers).
      Should not change the shape of the Tensor. Can make calls to get_variable.
      See f_side_input if there are side inputs.
    g: a function, (Tensor) -> (Tensor) (or list of such of length num_layers).
      Should not change the shape of the Tensor. Can make calls to get_variable.
      See g_side_input if there are side inputs.
    num_layers: int, number of reversible residual layers. Each layer will
      apply f and g according to the equations above, with new variables in each
      layer.
    f_side_input: list of Tensors, side input to f. If not None, signature of f
      should be (Tensor, list<Tensor>) -> (Tensor).
    g_side_input: list of Tensors, side input to g. If not None, signature of g
      should be (Tensor, list<Tensor>) -> (Tensor).
    is_training: bool, whether to actually use the efficient backprop codepath.

  Returns:
    y1, y2: tuple of float Tensors.
  """
  block = RevBlock(
      f=f,
      g=g,
      num_layers=num_layers,
      f_side_input=f_side_input,
      g_side_input=g_side_input,
      use_efficient_backprop=is_training,
      _reuse=variable_scope.get_variable_scope().reuse)
  return block.forward(x1, x2)

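
# Illustrative usage sketch (not part of the original module): the functional
# rev_block API with a side input to f. The bias-style side input and the
# shapes below are assumptions made for this example only.
def _example_rev_block_fn_usage():
  """Runs rev_block where f consumes a side input."""
  x1 = array_ops.ones([4, 8])
  x2 = array_ops.ones([4, 8])
  side = array_ops.ones([4, 8])

  def f(x, side_inputs):
    w = variable_scope.get_variable("f_w", [8, 8])
    return math_ops.matmul(x, w) + side_inputs[0]

  def g(x):
    w = variable_scope.get_variable("g_w", [8, 8])
    return math_ops.matmul(x, w)

  # With is_training=True (the default), gradients are routed through the
  # memory-efficient custom gradient in RevBlock._efficient_grad_fn.
  return rev_block(x1, x2, f, g, num_layers=2, f_side_input=[side])
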

def recompute_grad(fn):
  """Decorator that recomputes the function on the backwards pass.

  Args:
    fn: a function that takes Tensors (all as positional arguments) and returns
      a tuple of Tensors.

  Returns:
    A wrapped fn that is identical to fn when called, but its activations will
    be discarded and recomputed on the backwards pass (i.e. on a call to
    tf.gradients).
  """

  @functools.wraps(fn)
  def wrapped(*args):
    return _recompute_grad(fn, args)

  return wrapped

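
# Illustrative usage sketch (not part of the original module): decorating a
# small two-matmul block with @recompute_grad. The shapes and variable names
# below are assumptions made for this example only.
def _example_recompute_grad_usage():
  """Differentiates through a block whose activations are recomputed."""

  @recompute_grad
  def two_matmuls(x):
    w1 = variable_scope.get_variable("w1", [8, 8])
    w2 = variable_scope.get_variable("w2", [8, 8])
    return math_ops.matmul(math_ops.tanh(math_ops.matmul(x, w1)), w2)

  x = array_ops.ones([4, 8])
  y = two_matmuls(x)
  # The intermediate tanh activation is not kept for the backward pass; it is
  # rebuilt inside the custom gradient when gradients are requested.
  return gradients_impl.gradients(y, [x])
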

def _recompute_grad(fn, args):
  """See recompute_grad."""

  cached_vs = []
  cached_arg_scope = []

  def grad_fn(inputs, variables, outputs, output_grads):
    """Recompute outputs for gradient computation."""
    del outputs
    # Recompute outputs
    with framework_ops.control_dependencies(output_grads):
      with contrib_framework_ops.arg_scope(cached_arg_scope[0]):
        with variable_scope.variable_scope(cached_vs[0], reuse=True):
          outputs = fn(*inputs)

    if not (isinstance(outputs, list) or isinstance(outputs, tuple)):
      outputs = [outputs]
    outputs = list(outputs)
    grads = gradients_impl.gradients(outputs, inputs + variables, output_grads)
    grad_inputs = grads[:len(inputs)]
    grad_vars = grads[len(inputs):]
    return grad_inputs, grad_vars

  @_fn_with_custom_grad(grad_fn)
  def fn_with_recompute(*args):
    cached_vs.append(variable_scope.get_variable_scope())
    # TODO(rsepassi): Rm conditional in TF 1.4
    if hasattr(contrib_framework_ops, "current_arg_scope"):
      cached_arg_scope.append(contrib_framework_ops.current_arg_scope())
    else:
      cached_arg_scope.append({})
    return fn(*args)

  return fn_with_recompute(*args)


def _underlying_variable_ref(t):
  """Find the underlying variable ref.

  Traverses through Identity, ReadVariableOp, and Enter ops.
  Stops when op type has Variable or VarHandle in name.

  Args:
    t: a Tensor

  Returns:
    a Tensor that is a variable ref, or None on error.
  """
  while t.op.type in ["Identity", "ReadVariableOp", "Enter"]:
    t = t.op.inputs[0]

  op_type = t.op.type
  if "Variable" in op_type or "VarHandle" in op_type:
    return t
  else:
    return None


def _fn_with_custom_grad(grad_fn, use_global_vars=False):
  """Decorator to create a subgraph with a custom gradient function.

  The subgraph created by the decorated function is NOT put in a Defun and so
  does not suffer from the limitations of the Defun (all subgraph ops on the
  same device, no summaries).

  Args:
    grad_fn: function with signature
      (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars),
      all of which are lists of Tensors.
    use_global_vars: if True, the variables passed to grad_fn are the global
      variables created by the decorated fn; if False, they are its trainable
      variables.

  Returns:
    Decorator for function such that the gradient is defined by grad_fn.
  """

  def dec(fn):

    @functools.wraps(fn)
    def wrapped(*args):
      return _fn_with_custom_grad_internal(
          fn, args, grad_fn, use_global_vars=use_global_vars)

    return wrapped

  return dec

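
# Illustrative sketch (not part of the original module): a minimal use of
# _fn_with_custom_grad. The function and its hand-written gradient below are
# assumptions made for this example only; no variables are created, so the
# grad_vars list is empty.
def _example_custom_grad_usage():
  """Attaches a hand-written gradient to a tiny subgraph."""

  def triple_grad_fn(inputs, variables, outputs, output_grads):
    # d(3 * x)/dx == 3, so the incoming output gradient is simply rescaled.
    del inputs, variables, outputs
    return [3.0 * output_grads[0]], []

  @_fn_with_custom_grad(triple_grad_fn)
  def triple(x):
    return 3.0 * x

  return triple(array_ops.ones([4]))
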

def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
  """Create a subgraph with a custom gradient.

  Args:
    fn: function that takes inputs as arguments and produces 1 or more Tensors.
    inputs: list<Tensor>, will be passed as fn(*inputs).
    grad_fn: function with signature
      (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars),
      all of which are lists of Tensors.
    use_global_vars: if True, the variables passed to grad_fn are the global
      variables created by fn; if False, they are fn's trainable variables.

  Returns:
    fn(*inputs)
  """
  vs = variable_scope.get_variable_scope()
  get_vars_fn = (
      vs.global_variables if use_global_vars else vs.trainable_variables)
  len_before_vars = len(get_vars_fn())
  inputs = list(inputs)
  outputs = fn(*inputs)
  train_vars = get_vars_fn()[len_before_vars:]

  if grad_fn is None:
    return outputs

  if not (isinstance(outputs, tuple) or isinstance(outputs, list)):
    outputs = [outputs]
  outputs = list(outputs)

  defun_inputs = [inputs, train_vars, outputs]

  def custom_grad_fn(op, *dys):
    """Custom grad fn applying grad_fn for identity Defun."""
    fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as(
        defun_inputs, list(op.inputs))
    dys = list(dys)
    assert len(fn_outputs) == len(outputs)
    assert len(fn_outputs) == len(dys)

    grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys)
    grad_outputs = [None] * len(fn_outputs)
    return tuple(grad_inputs + grad_vars + grad_outputs)

  # The Defun takes as input the original inputs, the trainable variables
  # created in fn, and the outputs. In the forward pass it simply passes
  # through the outputs. In the backward pass, it produces gradients for the
  # original inputs and the trainable variables.
  in_types = [t.dtype for t in inputs]
  out_types = [t.dtype for t in outputs]
  var_types = [t.dtype for t in train_vars]

  # Get a unique name for the Defun
  with framework_ops.name_scope("identity_custom_grad") as ns:
    defun_name = ns

  @function.Defun(
      *(in_types + var_types + out_types),
      func_name=defun_name,
      python_grad_func=custom_grad_fn,
      shape_func=lambda _: [t.get_shape() for t in outputs])
  def identity(*args):
    _, _, outs = nest.pack_sequence_as(defun_inputs, args)
    return tuple([array_ops.identity(t) for t in outs])

  flat_inputs = nest.flatten(defun_inputs)
  id_out = identity(*flat_inputs)
  return id_out