# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reversible Residual Block.

From
[The Reversible Residual Network: Backpropagation Without Storing
Activations](https://arxiv.org/abs/1707.04585).

Also contains the @recompute_grad decorator, which recomputes the forward
function on the backwards pass.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import re

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.framework.python import ops as contrib_framework_ops
from tensorflow.python.framework import function
from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest

__all__ = ["rev_block", "RevBlock", "recompute_grad"]

LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*")


def _acc_grads(*lists_of_grads):
  """Accumulates lists of gradients."""
  acc_grads = []
  for grads in zip(*lists_of_grads):
    grads = [g for g in grads if g is not None]
    if grads:
      acc_grads.append(math_ops.add_n(grads))
    else:
      acc_grads.append(None)
  return acc_grads


def _rev_layer_forward(xs, f, g, f_side_input, g_side_input,
                       gate_outputs=False):
  """Forward for 1 reversible layer."""
  x1, x2 = xs
  y1 = x1 + (f(x2, f_side_input) if f_side_input else f(x2))
  y2 = x2 + (g(y1, g_side_input) if g_side_input else g(y1))
  if gate_outputs:
    return control_flow_ops.tuple([y1, y2])
  else:
    return (y1, y2)


def _rev_layer_backward(ys, grad_ys, f, g, f_vars, f_side_input, g_vars,
                        g_side_input):
  """Backprop for 1 layer."""
  y1, y2 = ys
  grad_y1, grad_y2 = grad_ys

  # Reconstruct intermediates and inputs (x1, x2).
  # stop_gradient is required on the fn inputs to prevent infinite recursion
  # into this grad function on the calls to gradients.
  y1_stop = array_ops.stop_gradient(y1)
  g_side_input = [array_ops.stop_gradient(t) for t in g_side_input]
  gy1 = g(y1_stop, g_side_input) if g_side_input else g(y1_stop)

  x2 = y2 - gy1
  x2_stop = array_ops.stop_gradient(x2)
  f_side_input = [array_ops.stop_gradient(t) for t in f_side_input]
  fx2 = f(x2_stop, f_side_input) if f_side_input else f(x2_stop)

  x1 = y1 - fx2

  # Compute gradients w.r.t. inputs
  # dL/dy2 * dG(y1)/y1
  grad_gy1_y2 = gradients_impl.gradients(gy1, y1_stop, grad_y2)[0]
  grad_x1 = grad_y1 + grad_gy1_y2
  grad_x2 = (
      gradients_impl.gradients(fx2, x2_stop, grad_y1)[0] + grad_y2 +
      gradients_impl.gradients(fx2, x2_stop, grad_gy1_y2)[0])

  # Compute gradients w.r.t. vars and side inputs in f and g
  grads1 = gradients_impl.gradients(gy1, g_vars + g_side_input, grad_y2)
  grad_g_vars, grad_g_side = grads1[:len(g_vars)], grads1[len(g_vars):]
  grads2 = gradients_impl.gradients(fx2, f_vars + f_side_input, grad_y1)
  grad_f_y1, grad_f_side1 = grads2[:len(f_vars)], grads2[len(f_vars):]
  grads3 = gradients_impl.gradients(fx2, f_vars + f_side_input, grad_gy1_y2)
  grad_f_y2, grad_f_side2 = grads3[:len(f_vars)], grads3[len(f_vars):]
  grad_f_vars = _acc_grads(grad_f_y1, grad_f_y2)

  grad_f_side = _acc_grads(grad_f_side1, grad_f_side2)

  # Put returns in a tuple to ensure a constant memory budget (i.e. we don't
  # want the subsequent layer to start computing and consuming memory based on
  # a subset of these values).
  outputs = ((x1, x2), (grad_x1, grad_x2), (grad_f_vars, grad_f_side),
             (grad_g_vars, grad_g_side))
  tupled = control_flow_ops.tuple(nest.flatten(outputs))
  return nest.pack_sequence_as(outputs, tupled)


def _rev_block_forward(x1,
                       x2,
                       f,
                       g,
                       num_layers=1,
                       f_side_input=None,
                       g_side_input=None,
                       gate_outputs=False):
  """Forward for a series of reversible layers."""
  out = (x1, x2)
  for i in xrange(num_layers):
    out = _rev_layer_forward(
        out, f[i], g[i], f_side_input, g_side_input, gate_outputs=gate_outputs)

  y1, y2 = out
  return y1, y2


def _scope_wrap(fn, scope):

  @functools.wraps(fn)
  def wrap(*args, **kwargs):
    with variable_scope.variable_scope(scope):
      return fn(*args, **kwargs)

  return wrap


class RevBlock(base.Layer):
  """Block of reversible layers. See rev_block."""

  def __init__(self,
               f,
               g,
               num_layers=1,
               f_side_input=None,
               g_side_input=None,
               use_efficient_backprop=True,
               name="revblock",
               **kwargs):
    super(RevBlock, self).__init__(name=name, **kwargs)

    if isinstance(f, list):
      assert len(f) == num_layers
    else:
      f = [f] * num_layers

    if isinstance(g, list):
      assert len(g) == num_layers
    else:
      g = [g] * num_layers

    f = [_scope_wrap(fn, "revlayer_%d/f" % i) for i, fn in enumerate(f)]
    g = [_scope_wrap(fn, "revlayer_%d/g" % i) for i, fn in enumerate(g)]

    self.f = f
    self.g = g

    self.num_layers = num_layers
    self.f_side_input = f_side_input or []
    self.g_side_input = g_side_input or []

    self._use_efficient_backprop = use_efficient_backprop

  def call(self, inputs, forward=True):
    vs = variable_scope.get_variable_scope()
    vars_before = vs.global_variables()

    if forward:
      x1, x2 = inputs
      out = self._forward(x1, x2)
    else:
      y1, y2 = inputs
      out = self._backward(y1, y2)

    # Add any created variables to the Layer's variable stores
    new_vars = vs.global_variables()[len(vars_before):]
    train_vars = vs.trainable_variables()
    for new_var in new_vars:
      if new_var in train_vars:
        self._trainable_weights.append(new_var)
      else:
        self._non_trainable_weights.append(new_var)

    return out

  def forward(self, x1, x2):
    return self.apply([x1, x2])

  def backward(self, y1, y2):
    return self.apply([y1, y2], forward=False)

  def build(self, _):
    logging.warn("RevBlock constructs its variables on first call, not on "
                 "build.")
    self.built = True

  def _efficient_grad_fn(self, inputs, variables, ys, grad_ys):
    """Custom gradient fn for a block of reversible residual layers."""
    side_inputs = inputs[2:]
    f_side_idxs = [None] * len(self.f_side_input)
    g_side_idxs = [None] * len(self.g_side_input)
    assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input)

    for i, t in enumerate(side_inputs):
      if t in self.f_side_input:
        f_side_idxs[self.f_side_input.index(t)] = i
      elif t in self.g_side_input:
        g_side_idxs[self.g_side_input.index(t)] = i
      else:
        assert False

    f_vars = [[] for _ in range(self.num_layers)]
    g_vars = [[] for _ in range(self.num_layers)]
    f_vars_idxs = [[] for _ in range(self.num_layers)]
    g_vars_idxs = [[] for _ in range(self.num_layers)]

    for i, t in enumerate(variables):
      ref = _underlying_variable_ref(t)

      # Use the name to identify the layer number and function (f or g)
      regex = LAYER_RE.match(ref.name)
      layer_no = int(regex.group(1))
      fn_name = regex.group(2)
      if fn_name == "f":
        f_vars[layer_no].append(ref)
        f_vars_idxs[layer_no].append(i)
      else:
        assert fn_name == "g"
        g_vars[layer_no].append(ref)
        g_vars_idxs[layer_no].append(i)

    f_var_grads = []
    g_var_grads = []
    f_side_grads = []
    g_side_grads = []

    # Reverse variable containers to go backward
    f_vars.reverse()
    g_vars.reverse()
    f = list(self.f)
    g = list(self.g)
    f.reverse()
    g.reverse()

    with variable_scope.variable_scope(self.scope_name, reuse=True):
      for i in xrange(self.num_layers):
        ys, grad_ys, f_ret, g_ret = _rev_layer_backward(
            ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i],
            self.g_side_input)

        grad_f_vars, grad_f_side = f_ret
        grad_g_vars, grad_g_side = g_ret
        f_var_grads.append(grad_f_vars)
        g_var_grads.append(grad_g_vars)
        f_side_grads.append(grad_f_side)
        g_side_grads.append(grad_g_side)

    # Accumulate layer gradients for f_side_input and g_side_input
    acc_f_side_grads = _acc_grads(*f_side_grads)
    acc_g_side_grads = _acc_grads(*g_side_grads)

    # Use the stored idxs to put gradients in the passed-in order.
    side_input_grads = [None] * len(side_inputs)
    variable_grads = [None] * len(variables)

    # Variable gradients were collected in reverse layer order. Reverse to
    # match idxs.
    f_var_grads.reverse()
    g_var_grads.reverse()
    for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list(
        zip(g_vars_idxs, g_var_grads)):
      for i, grad in zip(idxs, grads):
        variable_grads[i] = grad

    for i, grad in zip(f_side_idxs, acc_f_side_grads):
      side_input_grads[i] = grad
    for i, grad in zip(g_side_idxs, acc_g_side_grads):
      side_input_grads[i] = grad

    grad_x1, grad_x2 = grad_ys
    return [grad_x1, grad_x2] + side_input_grads, variable_grads

  def _forward(self, x1, x2):
    """Run forward through the reversible layers."""

    side_inputs = [self.f_side_input, self.g_side_input]
    flat_side_inputs = nest.flatten(side_inputs)

    custom_grad_fn = (
        self._efficient_grad_fn if self._use_efficient_backprop else None)

    @_fn_with_custom_grad(custom_grad_fn)
    def _forward_wrap(x1_, x2_, *flat_side_inputs):
      f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs)
      return _rev_block_forward(
          x1_,
          x2_,
          self.f,
          self.g,
          num_layers=self.num_layers,
          f_side_input=f_side,
          g_side_input=g_side,
          gate_outputs=self._use_efficient_backprop)

    return _forward_wrap(x1, x2, *flat_side_inputs)

  def _backward(self, y1, y2):
    """Run backward through the reversible layers."""

    f = list(self.f)
    g = list(self.g)
    f.reverse()
    g.reverse()

    for i in xrange(self.num_layers):
      gy1 = g[i](y1, self.g_side_input) if self.g_side_input else g[i](y1)
      x2 = y2 - gy1
      fx2 = f[i](x2, self.f_side_input) if self.f_side_input else f[i](x2)
      x1 = y1 - fx2

      y1, y2 = x1, x2

    return x1, x2


def rev_block(x1,
              x2,
              f,
              g,
              num_layers=1,
              f_side_input=None,
              g_side_input=None,
              is_training=True):
  """A block of reversible residual layers.

  A reversible residual layer is defined as:

  ```
  y1 = x1 + f(x2, f_side_input)
  y2 = x2 + g(y1, g_side_input)
  ```

  A reversible residual block, defined here, is a series of reversible
  residual layers.

  Limitations:
  * f and g must not close over any Tensors; all side inputs to f and g should
    be passed in with f_side_input and g_side_input, which will be forwarded
    to f and g.
  * f and g must not change the dimensionality of their inputs in order for
    the addition in the equations above to work.

  Args:
    x1: a float Tensor.
    x2: a float Tensor.
    f: a function, (Tensor) -> (Tensor) (or list of such of length num_layers).
      Should not change the shape of the Tensor. Can make calls to
      get_variable. See f_side_input if there are side inputs.
    g: a function, (Tensor) -> (Tensor) (or list of such of length num_layers).
      Should not change the shape of the Tensor. Can make calls to
      get_variable. See g_side_input if there are side inputs.
    num_layers: int, number of reversible residual layers. Each layer will
      apply f and g according to the equations above, with new variables in
      each layer.
    f_side_input: list of Tensors, side input to f. If not None, signature of f
      should be (Tensor, list<Tensor>) -> (Tensor).
    g_side_input: list of Tensors, side input to g. If not None, signature of g
      should be (Tensor, list<Tensor>) -> (Tensor).
    is_training: bool, whether to actually use the efficient backprop codepath.

  Returns:
    y1, y2: tuple of float Tensors.
  """
  block = RevBlock(
      f=f,
      g=g,
      num_layers=num_layers,
      f_side_input=f_side_input,
      g_side_input=g_side_input,
      use_efficient_backprop=is_training,
      _reuse=variable_scope.get_variable_scope().reuse)
  return block.forward(x1, x2)


def recompute_grad(fn):
  """Decorator that recomputes the function on the backwards pass.

  Args:
    fn: a function that takes Tensors (all as positional arguments) and returns
      a tuple of Tensors.

  Returns:
    A wrapped fn that is identical to fn when called, but its activations will
    be discarded and recomputed on the backwards pass (i.e. on a call to
    tf.gradients).
  """

  @functools.wraps(fn)
  def wrapped(*args):
    return _recompute_grad(fn, args)

  return wrapped


def _recompute_grad(fn, args):
  """See recompute_grad."""

  cached_vs = []
  cached_arg_scope = []

  def grad_fn(inputs, variables, outputs, output_grads):
    """Recompute outputs for gradient computation."""
    del outputs
    # Recompute outputs
    with framework_ops.control_dependencies(output_grads):
      with contrib_framework_ops.arg_scope(cached_arg_scope[0]):
        with variable_scope.variable_scope(cached_vs[0], reuse=True):
          outputs = fn(*inputs)

    if not (isinstance(outputs, list) or isinstance(outputs, tuple)):
      outputs = [outputs]
    outputs = list(outputs)
    grads = gradients_impl.gradients(outputs, inputs + variables, output_grads)
    grad_inputs = grads[:len(inputs)]
    grad_vars = grads[len(inputs):]
    return grad_inputs, grad_vars

  @_fn_with_custom_grad(grad_fn)
  def fn_with_recompute(*args):
    cached_vs.append(variable_scope.get_variable_scope())
    # TODO(rsepassi): Rm conditional in TF 1.4
    if hasattr(contrib_framework_ops, "current_arg_scope"):
      cached_arg_scope.append(contrib_framework_ops.current_arg_scope())
    else:
      cached_arg_scope.append({})
    return fn(*args)

  return fn_with_recompute(*args)


def _underlying_variable_ref(t):
  """Find the underlying variable ref.

  Traverses through Identity, ReadVariableOp, and Enter ops.
  Stops when op type has Variable or VarHandle in name.

  Args:
    t: a Tensor

  Returns:
    a Tensor that is a variable ref, or None on error.
  """
  while t.op.type in ["Identity", "ReadVariableOp", "Enter"]:
    t = t.op.inputs[0]

  op_type = t.op.type
  if "Variable" in op_type or "VarHandle" in op_type:
    return t
  else:
    return None


def _fn_with_custom_grad(grad_fn, use_global_vars=False):
  """Decorator to create a subgraph with a custom gradient function.

  The subgraph created by the decorated function is NOT put in a Defun and so
  does not suffer from the limitations of the Defun (all subgraph ops on the
  same device, no summaries).

  Args:
    grad_fn: function with signature
      (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars),
      all of which are lists of Tensors.
    use_global_vars: if True, variables will be the global variables created.
      If False, will be the trainable variables.

  Returns:
    Decorator for function such that the gradient is defined by grad_fn.
  """

  def dec(fn):

    @functools.wraps(fn)
    def wrapped(*args):
      return _fn_with_custom_grad_internal(
          fn, args, grad_fn, use_global_vars=use_global_vars)

    return wrapped

  return dec


def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False):
  """Create a subgraph with a custom gradient.

  Args:
    fn: function that takes inputs as arguments and produces 1 or more Tensors.
    inputs: list<Tensor>, will be passed as fn(*inputs).
    grad_fn: function with signature
      (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars),
      all of which are lists of Tensors.
    use_global_vars: if True, variables will be the global variables created.
      If False, will be the trainable variables.

  Returns:
    fn(*inputs)
  """
  vs = variable_scope.get_variable_scope()
  get_vars_fn = (
      vs.global_variables if use_global_vars else vs.trainable_variables)
  len_before_vars = len(get_vars_fn())
  inputs = list(inputs)
  outputs = fn(*inputs)
  train_vars = get_vars_fn()[len_before_vars:]

  if grad_fn is None:
    return outputs

  if not (isinstance(outputs, tuple) or isinstance(outputs, list)):
    outputs = [outputs]
  outputs = list(outputs)

  defun_inputs = [inputs, train_vars, outputs]

  def custom_grad_fn(op, *dys):
    """Custom grad fn applying grad_fn for identity Defun."""
    fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as(
        defun_inputs, list(op.inputs))
    dys = list(dys)
    assert len(fn_outputs) == len(outputs)
    assert len(fn_outputs) == len(dys)

    grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys)
    grad_outputs = [None] * len(fn_outputs)
    return tuple(grad_inputs + grad_vars + grad_outputs)

  # The Defun takes as input the original inputs, the trainable variables
  # created in fn, and the outputs. In the forward it passes through the
  # outputs. In the backwards, it produces gradients for the original inputs
  # and the trainable variables.
  in_types = [t.dtype for t in inputs]
  out_types = [t.dtype for t in outputs]
  var_types = [t.dtype for t in train_vars]

  # Get a unique name for the Defun
  with framework_ops.name_scope("identity_custom_grad") as ns:
    defun_name = ns

  @function.Defun(
      *(in_types + var_types + out_types),
      func_name=defun_name,
      python_grad_func=custom_grad_fn,
      shape_func=lambda _: [t.get_shape() for t in outputs])
  def identity(*args):
    _, _, outs = nest.pack_sequence_as(defun_inputs, args)
    return tuple([array_ops.identity(t) for t in outs])

  flat_inputs = nest.flatten(defun_inputs)
  id_out = identity(*flat_inputs)
  return id_out
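

# ------------------------------------------------------------------------------
# Illustrative usage sketch, not part of this module's API. The names
# `_example_reversible_net`, `_example_f`, `_example_g`, and
# `_recomputed_projection` are hypothetical; any shape-preserving callables
# that create their variables via get_variable can be used with rev_block, and
# recompute_grad can wrap any function of positional Tensor arguments.


def _example_reversible_net(x1, x2):
  """Sketch: a 2-layer reversible stack followed by a recomputed projection.

  Assumes x1 and x2 are 2-D float Tensors of shape [batch, dim].
  """

  def _example_f(x):
    # Shape-preserving transform for the f branch. Variables come from
    # get_variable so the efficient backprop path can attribute them to a
    # revlayer_*/f scope by name.
    dim = x.get_shape().as_list()[-1]
    w = variable_scope.get_variable("w_f", shape=[dim, dim], dtype=x.dtype)
    return math_ops.tanh(math_ops.matmul(x, w))

  def _example_g(x):
    dim = x.get_shape().as_list()[-1]
    w = variable_scope.get_variable("w_g", shape=[dim, dim], dtype=x.dtype)
    return math_ops.tanh(math_ops.matmul(x, w))

  # With is_training=True, activations inside the block are reconstructed on
  # the backward pass rather than stored.
  y1, y2 = rev_block(x1, x2, _example_f, _example_g, num_layers=2)

  @recompute_grad
  def _recomputed_projection(y):
    # This layer's forward pass is re-run during tf.gradients instead of
    # keeping its activations in memory.
    dim = y.get_shape().as_list()[-1]
    w = variable_scope.get_variable("w_proj", shape=[dim, dim], dtype=y.dtype)
    return math_ops.matmul(y, w)

  return _recomputed_projection(y1 + y2)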