1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Code for backpropagation using the tape utilities.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import collections 22 import functools 23 import operator 24 import threading 25 26 import six 27 28 from tensorflow.python import pywrap_tensorflow 29 from tensorflow.python.eager import context 30 from tensorflow.python.eager import execute 31 from tensorflow.python.eager import imperative_grad 32 from tensorflow.python.eager import tape 33 from tensorflow.python.framework import constant_op 34 from tensorflow.python.framework import dtypes 35 from tensorflow.python.framework import errors 36 from tensorflow.python.framework import ops 37 from tensorflow.python.framework import tensor_shape 38 from tensorflow.python.ops import array_ops 39 from tensorflow.python.ops import gen_array_ops 40 from tensorflow.python.ops import math_ops 41 from tensorflow.python.ops import resource_variable_ops 42 from tensorflow.python.util import nest 43 from tensorflow.python.util import tf_inspect 44 45 46 class _TensorCache(object): 47 """Simple cache which evicts items based on length in a FIFO manner.""" 48 49 def __init__(self, max_items=256): 50 self._data = collections.OrderedDict() 51 self._max_items = max_items if max_items else 256 52 53 def put(self, key, value): 54 self._data[key] = value 55 56 if len(self._data) > self._max_items: 57 self._data.popitem(last=False) 58 59 def get(self, key): 60 return self._data.get(key, None) 61 62 def flush(self): 63 self._data = {} 64 65 66 _op_attr_type_cache = {} 67 68 69 def op_attr_type(op_type, attr_name): 70 try: 71 return _op_attr_type_cache[(op_type, attr_name)] 72 except KeyError: 73 with errors.raise_exception_on_not_ok_status() as status: 74 h = context.context()._handle # pylint: disable=protected-access 75 attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType( 76 h, op_type, attr_name, status) 77 _op_attr_type_cache[(op_type, attr_name)] = attr_type 78 return attr_type 79 80 81 def make_attr(attr_type, value): 82 if attr_type == pywrap_tensorflow.TF_ATTR_TYPE: 83 return dtypes.as_dtype(value) 84 elif attr_type == [pywrap_tensorflow.TF_ATTR_TYPE]: 85 return [dtypes.as_dtype(v) for v in value] 86 elif attr_type == pywrap_tensorflow.TF_ATTR_SHAPE: 87 return tensor_shape.as_shape(value).as_proto() 88 elif attr_type == [pywrap_tensorflow.TF_ATTR_SHAPE]: 89 return [tensor_shape.as_shape(v).as_proto() for v in value] 90 return value 91 92 93 class _MockOp(object): 94 """Pretends to be a tf.Operation for the gradient functions.""" 95 96 def __init__(self, attrs, inputs, outputs, typ): 97 self.attrs = attrs 98 self.inputs = inputs 99 self.outputs = outputs 100 self.type = typ 101 102 def get_attr(self, attr): 103 typ = op_attr_type(self.type, attr) 104 for i in range(0, len(self.attrs), 2): 105 if self.attrs[i] == attr: 106 return make_attr(typ, self.attrs[i + 1]) 107 raise KeyError(attr) 108 109 110 def _magic_gradient_function(op_name, attr_tuple, num_inputs, 111 inputs, outputs, out_grads): 112 """Calls the gradient function of the op. 113 114 Args: 115 op_name: the name of the op to be differentiated. 116 attr_tuple: the attrs, as a tuple. 117 num_inputs: the number of inputs to the op. 118 inputs: inputs to the original operation. 119 outputs: outputs to the original operation. 120 out_grads: gradients of the operation wrt its outputs. 121 122 Returns: 123 The gradients with respect to the inputs of the function, as a list. 124 """ 125 mock_op = _MockOp(attr_tuple, inputs, outputs, op_name) 126 grad_fn = ops._gradient_registry.lookup(op_name) # pylint: disable=protected-access 127 if grad_fn is None: 128 return [None] * num_inputs 129 130 return grad_fn(mock_op, *out_grads) 131 132 133 _gradient_functions = {} 134 _gradient_functions_lock = threading.Lock() 135 136 137 _tracing = False 138 139 140 # TODO(apassos) replace this with a mechanism which can happen at the op 141 # gradient function registration site, to be less error-prone 142 # TODO(apassos) add ops other than those in nn_grad and math_grad 143 _ops_which_dont_need_outputs = set([ 144 "Identity", 145 "MatMul", 146 "Conv2DBackpropInput", 147 "Conv2DBackpropFilter", 148 "Conv3D", 149 "Conv3DBackpropInputV2", 150 "AvgPool3D", 151 "AvgPool3DGrad", 152 "MaxPool3D", 153 "MaxPool3DGrad", 154 "MaxPool3DGradGrad", 155 "BiasAdd", 156 "BiasAddV1", 157 "BiasAddGrad", 158 "Relu6", 159 "Softplus", 160 "SoftplusGrad", 161 "Softsign", 162 "ReluGrad", 163 "Conv2D", 164 "DepthwiseConv2dNative", 165 "Dilation2D", 166 "AvgPool", 167 "AvgPoolGrad", 168 "BatchNormWithGlobalNormalization", 169 "L2Loss", 170 "Sum", 171 "Prod", 172 "SegmentSum", 173 "SegmentMean", 174 "SparseSegmentSum", 175 "SparseSegmentMean", 176 "SparseSegmentSqrtN", 177 "SegmentMin", 178 "SegmentMax", 179 "UnsortedSegmentSum", 180 "UnsortedSegmentMax", 181 "UnsortedSegmentMin", 182 "UnsortedSegmentProd", 183 "Abs", 184 "Neg", 185 "ReciprocalGrad", 186 "Square", 187 "Expm1", 188 "Log", 189 "Log1p", 190 "TanhGrad", 191 "SigmoidGrad", 192 "Sign", 193 "Sin", 194 "Cos", 195 "Tan", 196 "Add", 197 "Sub", 198 "Mul", 199 "Div", 200 "RealDiv", 201 "Maximum", 202 "Minimum", 203 "SquaredDifference", 204 "Select", 205 "SparseMatMul", 206 "BatchMatMul", 207 "Complex", 208 "Real", 209 "Imag", 210 "Angle", 211 "Conj", 212 "Cast", 213 "Cross", 214 "Cumsum", 215 "Cumprod", 216 "ReadVariableOp", 217 "VarHandleOp", 218 "Shape", 219 ]) 220 221 _ops_which_dont_need_inputs = set([ 222 "Identity", 223 "Softmax", 224 "LogSoftmax", 225 "BiasAdd", 226 "Relu", 227 "Elu", 228 "Selu", 229 "SparseSoftmaxCrossEntropyWithLogits", 230 "Neg", 231 "Inv", 232 "Reciprocal", 233 "Sqrt", 234 "Exp", 235 "Tanh", 236 "Sigmoid", 237 "Real", 238 "Imag", 239 "Conj", 240 "ReadVariableOp", 241 "VarHandleOp", 242 "Shape", 243 ]) 244 245 246 # TODO(agarwal): use an automatic mechanism for handling None arguments to 247 # gradient functions. 248 # Some gradient functions can accept None arguments for gradients. The following 249 # maps the operation name to the indices at which the corresponding gradient 250 # function can accept None values. 251 # e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values 252 # during backprop. However the gradient function uses only the first of those 253 # values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4], 254 # indicates that only the gradient corresponding to index 0 is used, and the 255 # gradient values at indices 1-4 are ignored (and hence can be None). The 256 # backprop algorithm can then leverage this by not constructing zeros to 257 # pass for those indices. 258 _grad_fn_accepts_none_for_indices = { 259 "SoftmaxCrossEntropyWithLogits": [1], 260 "FusedBatchNorm": [1, 2, 3, 4] 261 } 262 263 264 def _record_gradient(op_name, inputs, attrs, results, name): 265 """Records gradients for a TensorFlow operation. 266 267 Args: 268 op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to 269 execute. 270 inputs: A flat list of Tensor object inputs to the operation. 271 attrs: A tuple with alternating string attr names and attr values for this 272 operation. 273 results: The results of the operation (as a flat list). 274 name: Customized name for the operation. 275 276 Returns: 277 A list of maybe-wrapped results. Either Tensors or TensorNodes. 278 279 Raises: 280 An exception on error. 281 """ 282 if not tape.could_possibly_record(): 283 return 284 285 if op_name in _ops_which_dont_need_outputs: 286 op_outputs = None 287 else: 288 # TODO(apassos) this line creates a weak circular reference where the 289 # backprop function keeps an output alive which in turn keeps the tape entry 290 # alive which keeps the backprop function alive. Figure out how to break 291 # this up without breaking second derivatives of ops like Exp whose 292 # gradients depend only on the outputs. 293 op_outputs = results 294 295 if op_name in _ops_which_dont_need_inputs: 296 op_inputs = None 297 else: 298 op_inputs = inputs 299 300 num_inputs = len(inputs) 301 302 def grad_fn(*orig_outputs): 303 """Generated gradient function.""" 304 result = _magic_gradient_function(op_name, attrs, num_inputs, 305 op_inputs, op_outputs, orig_outputs) 306 if _tracing: 307 print("Gradient for", (name if name else op_name), "inputs", op_inputs, 308 "output_grads", orig_outputs, "gradients", result) 309 return nest.flatten(result) 310 311 tape.record_operation(op_name, results, inputs, grad_fn) 312 if _tracing: 313 print("Computed op", (name if name else op_name), "inputs", inputs, 314 "outputs", results) 315 316 317 execute.record_gradient = _record_gradient 318 319 320 def implicit_val_and_grad(f): 321 """Returns a function which differentiates f with respect to variables. 322 323 The wrapped function returns the value and the gradient of f when called with 324 the same arguments. The gradient is with respect to all TFE variables which 325 have `variable.watch()` called on them by f. 326 327 This function is useful when the exact set of variables to differentiate with 328 is not known ahead of time. 329 330 Example: 331 332 ```python 333 dense_layer = tf.layers.Dense(1) 334 def loss(x, y): 335 return tf.reduce_sum(tf.square(dense_layer(x) - y)) 336 337 # Obtain the gradient function. 338 val_grad_fn = tfe.implicit_value_and_gradients(loss) 339 340 # Invoke the gradient function with concrete values of x and y. 341 x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 342 y = tf.constant([[10.0], [20.0]]) 343 value, grads_and_vars = val_grad_fn(x, y) 344 print('Value of loss: %s' % value) 345 346 # Apply the gradients to Variables. 347 optimizer = tf.train.GradientDescentOptimizer(0.1) 348 optimizer.apply_gradients(grads_and_vars) 349 ``` 350 351 Args: 352 f: function to be differentiated. If `f` returns a scalar, this scalar will 353 be differentiated. If `f` returns a tensor or list of tensors, by default 354 a scalar will be computed by adding all their values to produce a single 355 scalar. 356 357 Returns: 358 A function which, when called, returns a tuple pair. 359 Its first element is the value to which the function evaluates. 360 Its second element is list of (gradient, variable) pairs. 361 362 Raises: 363 ValueError: if `f` returns None. 364 """ 365 # TODO(cais): Remove calls to tf.constant() once the gradients functions 366 # accept lists and np.ndarrays. 367 368 def grad_fn(*args): 369 """Computes the gradient of the wrapped function.""" 370 this_tape = tape.push_new_tape() 371 try: 372 end_node = f(*args) 373 if end_node is None: 374 raise ValueError("Cannot differentiate a function that returns None; " 375 "did you forget to return a value from {}?".format( 376 f.__name__)) 377 finally: 378 tape.pop_tape(this_tape) 379 # Sorting variables by id, which is monotonically increasing in construction 380 # order. This ensures unique order across executions. 381 variables = list(sorted(this_tape.watched_variables(), 382 key=lambda v: v.handle._id)) # pylint: disable=protected-access 383 sources = [x.handle for x in variables] 384 385 if not sources: 386 raise ValueError("No trainable variables were accessed while the " 387 "function was being computed.") 388 grad = imperative_grad.imperative_grad(_default_vspace, 389 this_tape, 390 nest.flatten(end_node), 391 sources) 392 return end_node, list(zip(grad, variables)) 393 394 return grad_fn 395 396 397 def implicit_grad(f): 398 """Returns a function which differentiates f with respect to variables. 399 400 The wrapped function returns the gradient of f when called with the same 401 arguments. The gradient is with respect to all TFE variables which have 402 `variable.watch()` called on them by f. 403 404 This function is useful when the exact set of variables to differentiate with 405 is not known ahead of time. 406 407 Example: 408 409 ```python 410 dense_layer = tf.layers.Dense(1) 411 def loss(x, y): 412 return tf.reduce_sum(tf.square(dense_layer(x) - y)) 413 414 # Obtain the gradient function. 415 grad_fn = tfe.implicit_gradients(loss) 416 417 # Invoke the gradient function with concrete values of x and y. 418 x = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 419 y = tf.constant([[10.0], [20.0]]) 420 grads_and_vars = grad_fn(x, y) 421 422 # Apply the gradients to Variables. 423 optimizer = tf.train.GradientDescentOptimizer(0.1) 424 optimizer.apply_gradients(grads_and_vars) 425 ``` 426 427 Args: 428 f: function to be differentiated. If `f` returns a scalar, this scalar will 429 be differentiated. If `f` returns a tensor or list of tensors, by default 430 a scalar will be computed by adding all their values to produce a single 431 scalar. 432 433 Returns: 434 A function which, when called, returns a list of (gradient, variable) pairs. 435 """ 436 # TODO(cais): Remove calls to tf.constant() once the gradients functions 437 # accept lists and np.ndarrays. 438 439 def grad_fn(*args, **kwds): 440 """Computes the gradient of the wrapped function.""" 441 return implicit_val_and_grad(f)(*args, **kwds)[1] 442 443 return grad_fn 444 445 446 def _get_arg_spec(f, params, param_args): 447 """The positions of the parameters of f to be differentiated in param_args.""" 448 try: 449 args = tf_inspect.getargspec(f).args 450 except TypeError as e: 451 # TypeError can happen when f is a callable object. 452 if params is None: 453 return range(len(param_args)) 454 elif all(isinstance(x, int) for x in params): 455 return params 456 raise ValueError("Either callable provided is not a function or could not " 457 "inspect its arguments by name: %s. Original error: %s" 458 % (f, e)) 459 if params is None: 460 if not args: 461 return range(len(param_args)) 462 return range(len(args)) 463 elif all(isinstance(x, six.string_types) for x in params): 464 return [args.index(n) for n in params] 465 elif all(isinstance(x, int) for x in params): 466 return params 467 else: 468 raise ValueError( 469 "params must be all strings or all integers; got %s." % params) 470 471 472 def gradients_function(f, params=None): 473 """Returns a function which differentiates f with respect to params. 474 475 Example: 476 ```python 477 # f(x, y) = (x ^ 3) * y - x * (y ^ 2) 478 # Therefore, the 1st order derivatives are: 479 # df / dx = 3 * (x ^ 2) * y - y ^ 2 480 # df / dy = x ^ 3 - 2 * x * y 481 # The 2nd order derivatives with respect to x is: 482 # d^2 f / (dx)^2 = 6 * x * y 483 def f(x, y): 484 return x * x * x * y - x * y * y 485 486 # Obtain a function that returns 1st order gradients. 487 grad_fn = tfe.gradients_function(f) 488 489 x = 2.0 490 y = 3.0 491 492 # Invoke the 1st order gradient function. 493 x_grad, y_grad = grad_fn(x, y) 494 assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2 495 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 496 497 # Obtain a function that returns the 2nd order gradient with respect to x. 498 gradgrad_fn = tfe.gradients_function(lambda x, y: grad_fn(x, y)[0]) 499 500 # Invoke the 2nd order gradient function. 501 x_gradgrad = gradgrad_fn(x, y)[0] 502 assert x_gradgrad.numpy() == 6 * 2 * 3 503 504 # To obtain a callable that returns the gradient(s) of `f` with respect to a 505 # subset of its inputs, use the `params` keyword argument with 506 # `gradients_function()`. 507 ygrad_fn = tfe.gradients_function(f, params=[1]) 508 509 (y_grad,) = ygrad_fn(x, y) 510 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 511 ``` 512 513 Args: 514 f: function to be differentiated. If `f` returns a scalar, this scalar will 515 be differentiated. If `f` returns a tensor or list of tensors, by default 516 a scalar will be computed by adding all their values to produce a single 517 scalar. If desired, the tensors can be elementwise multiplied by the 518 tensors passed as the `dy` keyword argument to the returned gradient 519 function. 520 params: list of parameter names of f or list of integers indexing the 521 parameters with respect to which we'll differentiate. Passing None 522 differentiates with respect to all parameters. 523 524 Returns: 525 function which, when called, returns the value of f and the gradient 526 of f with respect to all of `params`. The function takes an extra optional 527 keyword argument "dy". Setting it allows computation of vector jacobian 528 products for vectors other than the vector of ones. 529 530 Raises: 531 ValueError: if the params are not all strings or all integers. 532 """ 533 534 def decorated(*args, **kwds): 535 """Computes the gradient of the decorated function.""" 536 537 _, grad = val_and_grad_function(f, params=params)(*args, **kwds) 538 return grad 539 540 return decorated 541 542 543 def _ensure_unique_tensor_objects(parameter_positions, args): 544 """Make each of the parameter_positions in args a unique ops.Tensor object. 545 546 Ensure that each parameter is treated independently. 547 For example: 548 549 def f(x, y): return x * y 550 g = gradients_function(f) 551 one = tf.constant(1.) 552 553 g(one, one) should return [1., 1.] 554 (even though the two arguments are the same Tensor object). 555 556 Args: 557 parameter_positions: List of indices into args defining the arguments to 558 differentiate against. 559 args: A list of arguments to the function to be differentiated. 560 561 Returns: 562 args, possibly edited in-place. 563 """ 564 s = set() 565 for (i, t) in enumerate(args): 566 if i in parameter_positions: 567 tid = ops.tensor_id(t) 568 if tid in s: 569 args[i] = gen_array_ops.identity(args[i]) 570 else: 571 s.add(tid) 572 return args 573 574 575 def val_and_grad_function(f, params=None): 576 """Returns a function that computes f and its derivative w.r.t. params. 577 578 Example: 579 ```python 580 # f(x, y) = (x ^ 3) * y - x * (y ^ 2) 581 # Therefore, the 1st order derivatives are: 582 # df / dx = 3 * (x ^ 2) * y - y ^ 2 583 # df / dy = x ^ 3 - 2 * x * y 584 def f(x, y): 585 return x * x * x * y - x * y * y 586 587 # Obtain a function that returns the function value and the 1st order 588 # gradients. 589 val_grads_fn = tfe.value_and_gradients_function(f) 590 591 x = 2.0 592 y = 3.0 593 594 # Invoke the value-and-gradients function. 595 f_val, (x_grad, y_grad) = val_grads_fn(x, y) 596 assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2) 597 assert x_grad.numpy() == 3 * (2 ** 2) * 3 - 3 ** 2 598 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 599 600 # To obtain a callable that returns the value of `f` and the gradient(s) of 601 # `f` with respect to a subset of its inputs, use the `params` keyword 602 # argument with `value_and_gradients_function()`. 603 val_ygrad_fn = tfe.value_and_gradients_function(f, params=[1]) 604 605 f_val, (y_grad,) = val_ygrad_fn(x, y) 606 assert f_val.numpy() == (2 ** 3) * 3 - 2 * (3 ** 2) 607 assert y_grad.numpy() == (2 ** 3) - 2 * 2 * 3 608 ``` 609 610 Args: 611 f: function to be differentiated. If `f` returns a scalar, this scalar will 612 be differentiated. If `f` returns a tensor or list of tensors, by default 613 a scalar will be computed by adding all their values to produce a single 614 scalar. If desired, the tensors can be elementwise multiplied by the 615 tensors passed as the `dy` keyword argument to the returned gradient 616 function. 617 params: list of parameter names of f or list of integers indexing the 618 parameters with respect to which we'll differentiate. Passing `None` 619 differentiates with respect to all parameters. 620 621 Returns: function which, when called, returns the value of f and the gradient 622 of f with respect to all of `params`. The function takes an extra optional 623 keyword argument "dy". Setting it allows computation of vector jacobian 624 products for vectors other than the vector of ones. 625 626 Raises: 627 ValueError: if the params are not all strings or all integers. 628 """ 629 630 def decorated(*args, **kwds): 631 """Computes the value and gradient of the decorated function.""" 632 dy = kwds.pop("dy", None) 633 if kwds: 634 raise ValueError("Functions to be differentiated cannot " 635 "receive keyword arguments.") 636 val, vjp = make_vjp(f, params)(*args, **kwds) 637 return val, vjp(dy=dy) 638 639 return decorated 640 641 642 def make_vjp(f, params=None): 643 """Returns a function that computes f and is vjp w.r.t. params. 644 645 The term "vjp" here is an abbreviation for vector-jacobian product. 646 647 Args: 648 f: the function to be differentiated. 649 params: the parameters (numbers or names) to differentiate with respect to. 650 A value of None will differentiate with respect to all parameters. 651 652 Returns: 653 A function, which when called, returns a tuple (value, vjp), where: 654 - value is the result of calling f. 655 - vjp is a function, which takes a vector as an argument and 656 returns the product of that vector with the Jacobian of f. 657 Providing no argument to vjp is equivalent to providing a 658 vector of ones. 659 660 For example, 661 ```python 662 def f(x): 663 return x * x 664 665 wrapped_fn = tfe.make_vjp(f) 666 result, vjp = wrapped_fn(tf.constant(3.0)) 667 # result is 9.0 668 vjp() # the vjp function rturns 6.0 669 670 Raises: 671 ValueError: if `f` returns None. 672 """ 673 674 def decorated(*args, **kwds): 675 """Computes the value and gradient of the decorated function.""" 676 parameter_positions = _get_arg_spec(f, params, args) 677 assert not kwds, "The gradient function can't take keyword arguments." 678 this_tape = tape.push_new_tape() 679 try: 680 sources = [] 681 args = [ 682 ops.convert_to_tensor(args[i]) 683 if i in parameter_positions else args[i] 684 for i in range(len(args)) 685 ] 686 args = _ensure_unique_tensor_objects(parameter_positions, args) 687 for i in parameter_positions: 688 sources.append(args[i]) 689 tape.watch(args[i]) 690 result = f(*args) 691 if result is None: 692 raise ValueError("Cannot differentiate a function that returns None; " 693 "did you forget to return a value from {}?".format( 694 f.__name__)) 695 flat_result = nest.flatten(result) 696 flat_result = [gen_array_ops.identity(x) for x in flat_result] 697 result = nest.pack_sequence_as(result, flat_result) 698 finally: 699 tape.pop_tape(this_tape) 700 def vjp(dy=None): 701 if dy is not None: 702 dy = [ops.convert_to_tensor(x) for x in nest.flatten(dy)] 703 return imperative_grad.imperative_grad( 704 _default_vspace, this_tape, nest.flatten(result), sources, 705 output_gradients=dy) 706 return result, vjp 707 708 return decorated 709 710 711 def _aggregate_grads(gradients): 712 """Aggregate gradients from multiple sources. 713 714 Args: 715 gradients: A list of 'Tensor' or 'IndexedSlices' gradients. 716 717 Returns: 718 If 'gradients' only has 'Tensor', returns an aggregated 'Tensor'. 719 Otherwise returns an aggregated 'IndexedSlices'. 720 """ 721 assert gradients, "No gradients to aggregate" 722 723 if len(gradients) == 1: 724 return gradients[0] 725 if all([isinstance(g, ops.Tensor) for g in gradients]): 726 return math_ops.add_n(gradients) 727 else: 728 assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices)) 729 for g in gradients]) 730 indexed_slices_list = [] 731 for grad in gradients: 732 # TODO(xpan): Support nested IndexedSlices and core IndexedSlices 733 if isinstance(grad, ops.Tensor): 734 indexed_slices = ops.IndexedSlices( 735 grad, 736 math_ops.range(grad.shape[0]), 737 constant_op.constant(grad.shape.as_list())) 738 indexed_slices_list.append(indexed_slices) 739 else: 740 indexed_slices_list.append(grad) 741 742 # Dense shapes from all gradients should be the same. 743 dense_shape = indexed_slices_list[0].dense_shape 744 # For simplicity now, always cast to int64. 745 indices = array_ops.concat([math_ops.cast(x.indices, dtypes.int64) 746 for x in indexed_slices_list], 0) 747 values = array_ops.concat([x.values for x in indexed_slices_list], 0) 748 return ops.IndexedSlices(values, indices, dense_shape) 749 750 751 def _num_elements(grad): 752 """The number of elements in the `grad` tensor.""" 753 if isinstance(grad, ops.Tensor): 754 return functools.reduce(operator.mul, grad._shape_tuple(), 1) # pylint: disable=protected-access 755 if isinstance(grad, ops.IndexedSlices): 756 return functools.reduce(operator.mul, grad.values._shape_tuple(), 1) # pylint: disable=protected-access 757 raise ValueError("`grad` not a Tensor or IndexedSlices.") 758 759 760 _zeros_cache = _TensorCache() 761 762 763 def _fast_fill(value, shape, dtype): 764 return array_ops.fill(shape, constant_op.constant(value, dtype=dtype)) 765 766 767 def _zeros(shape, dtype): 768 """Wraps array_ops.zeros to cache last zero for a given shape and dtype.""" 769 device = context.context().device_name 770 if dtype == dtypes.variant: 771 # TODO(apassos): need to save enough information about variant tensors to do 772 # a zeros 773 return None 774 cache_key = shape, dtype, device 775 cached = _zeros_cache.get(cache_key) 776 if cached is None: 777 cached = _fast_fill(0, shape, dtype) 778 _zeros_cache.put(cache_key, cached) 779 return cached 780 781 782 def _ones(shape, dtype): 783 if shape == (): # pylint: disable=g-explicit-bool-comparison 784 return constant_op.constant(1, dtype=dtype) 785 return _fast_fill(1, shape, dtype) 786 787 788 _default_vspace = imperative_grad.VSpace( 789 num_elements_fn=_num_elements, 790 aggregate_fn=_aggregate_grads, 791 tensor_id=ops.tensor_id, 792 zeros=_zeros, 793 ones=_ones) 794 795 796 class GradientTape(object): 797 """Records operations to use to compute gradients. 798 799 Operations are recorded if: 800 - they happen in code marked by this context manager 801 - at least one of their inputs is being watched 802 803 Outputs of recorded operations are watched. Variables are automatically 804 watched and tensors can be manually watched by calling the watch method on the 805 context manager. 806 807 Example usage: 808 809 ```python 810 with tfe.GradientTape() as g: 811 x = tf.constant(3.0) 812 g.watch(x) 813 y = x * x 814 grad = g.gradient(y, [x])[0] 815 assert grad.numpy() == 6.0 816 ``` 817 818 It is possible to use GradientTapes to compute higher-order derivatives as 819 follows: 820 821 ```python 822 with tfe.GradientTape() as g: 823 x = tf.constant(3.0) 824 g.watch(x) 825 y = x * x 826 with tfe.GradientTape() as gg: 827 gg.watch(y) 828 z = 2 * y 829 inner_grad = gg.gradient(z, [y])[0] 830 assert inner_grad.numpy() == 2 831 y = y + inner_grad 832 grad = g.gradient(y, [x])[0] 833 assert grad.numpy() == 6.0 834 ``` 835 836 By default, the resources held by a GradientTape are released as soon as 837 GradientTape.gradient() method is called. However, if one need to compute 838 multiple gradients over the same computation, she can create a persistent 839 GradientTape. Persistent tapes allow multiple calls to the gradient() method 840 and release resources when the tape object is destructed. 841 842 Example usage: 843 844 ```python 845 with tfe.GradientTape(persistent=True) as g: 846 x = tf.constant(3.0) 847 g.watch(x) 848 y = x * x 849 z = y * y 850 dz_dx = g.gradient(z, [x])[0] 851 assert dz_dx.numpy() == 108.0 # 4*x^3 at x = 3 852 dy_dx = g.gradient(y, [x])[0] 853 assert dy_dx.numpy() == 6.0 854 del g # Drop the reference to the tape 855 """ 856 857 def __init__(self, persistent=False): 858 """Creates a new GradientTape. 859 860 Args: 861 persistent: Boolean controlling whether a persistent gradient tape 862 is created. Must be True or False. 863 864 """ 865 self._tape = None 866 self._persistent = persistent 867 868 def __enter__(self): 869 self._tape = tape.push_new_tape(persistent=self._persistent) 870 return self 871 872 def __exit__(self, typ, value, traceback): 873 tape.pop_tape(self._tape) 874 875 def watch(self, tensor): 876 """Ensures that `tensor` is being traced by this tape. 877 878 Args: 879 tensor: a Tensor or Variable a list of Tensors or Variables. 880 """ 881 for t in nest.flatten(tensor): 882 if isinstance(t, resource_variable_ops.ResourceVariable): 883 t = t.handle 884 tape.watch(t) 885 886 def watched_variables(self): 887 return self._tape.watched_variables() 888 889 def gradient(self, target, sources, output_gradients=None): 890 """Computes the gradient using information traced by the tape. 891 892 Args: 893 target: the tensor to be differentiated. 894 sources: a list of Tensors or Variables, the target will be 895 differentiated with respect to the sources. 896 output_gradients: a list of gradients, one for each element of 897 target. Defaults to None. 898 899 Returns: 900 a list of Tensors (or IndexedSlices, or None), one for each element in 901 `sources`. 902 903 Raises: 904 RuntimeError: if called inside the context of the tape, or if called more 905 than once. 906 """ 907 if self._tape is None: 908 raise RuntimeError("GradientTape.gradient can only be called once " 909 "on non-persistent tapes, and " 910 "only when the context manager has exited.") 911 sources = [x.handle if isinstance(x, resource_variable_ops.ResourceVariable) 912 else x 913 for x in sources] 914 grad = imperative_grad.imperative_grad( 915 _default_vspace, self._tape, [target], sources, 916 output_gradients=output_gradients) 917 if not self._persistent: 918 self._tape = None 919 return grad 920