# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import warnings

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest


def _OpsBetween(graph, to_ops, from_ops):
  """Build the list of operations between two lists of Operations.

  Args:
    graph: a Graph.
    to_ops: list of Operations.
    from_ops: list of Operations.

  Returns:
    The list of operations between "from_ops" and "to_ops", sorted by
    decreasing operation id. This list contains all elements of to_ops.

    TODO(touts): Think about returning an empty list if from_ops are not
    reachable from to_ops.  Presently it returns to_ops in that case.
  """
  # List of booleans, indexed by operation id, indicating if
  # an op is reached from the output of "input_ops".
  reached_ops = [False] * (graph._last_id + 1)
  # We only care to reach up to "output_ops" so we mark the
  # output ops as reached to avoid recursing past them.
  for op in to_ops:
    reached_ops[op._id] = True
  gradients_impl._MarkReachedOps(from_ops, reached_ops)
  between_ops = gradients_impl._GatherInputs(to_ops, reached_ops)
  between_ops.sort(key=lambda x: -x._id)
  return between_ops


@test_util.with_c_api
class GradientsTest(test_util.TensorFlowTestCase):

  def _OpNames(self, op_list):
    return ["%s/%d" % (str(op.name), op._id) for op in op_list]

  def _assertOpListEqual(self, ops1, ops2):
    self.assertEquals(self._OpNames(ops1), self._OpNames(ops2))

  def testOpsBetweenSimple(self):
    with ops.Graph().as_default() as g:
      t1 = constant(1.0)
      t2 = constant(2.0)
      t3 = array_ops.stack([t1, t2])
    # Full graph
    self._assertOpListEqual([t3.op, t2.op, t1.op],
                            _OpsBetween(g, [t3.op], [t1.op, t2.op]))
    # Only t1, t3.
    self._assertOpListEqual([t3.op, t1.op], _OpsBetween(g, [t3.op], [t1.op]))

  def testOpsBetweenUnreachable(self):
    with ops.Graph().as_default() as g:
      t1 = constant(1.0)
      t2 = constant(2.0)
      _ = array_ops.stack([t1, t2])
      t4 = constant(1.0)
      t5 = constant(2.0)
      t6 = array_ops.stack([t4, t5])
    # Elements of to_ops are always listed.
    self._assertOpListEqual([t6.op], _OpsBetween(g, [t6.op], [t1.op]))

  def testOpsBetweenCut(self):
    with ops.Graph().as_default() as g:
      t1 = constant(1.0)
      t2 = constant(2.0)
      t3 = array_ops.stack([t1, t2])
      t4 = constant([1.0])
      t5 = array_ops.concat([t4, t3], 0)
      t6 = constant([2.0])
      t7 = array_ops.concat([t5, t6], 0)
    self._assertOpListEqual([t7.op, t5.op, t4.op],
                            _OpsBetween(g, [t7.op], [t4.op]))

  def testOpsBetweenCycle(self):
    with ops.Graph().as_default() as g:
      t1 = constant(1.0)
      t2 = constant(2.0)
      t3 = array_ops.stack([t1, t2])
      t4 = array_ops.concat([t3, t3, t3], 0)
      t5 = constant([1.0])
      t6 = array_ops.concat([t4, t5], 0)
      t7 = array_ops.concat([t6, t3], 0)
    self._assertOpListEqual([t6.op, t4.op, t3.op],
                            _OpsBetween(g, [t6.op], [t3.op]))
    self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op],
                            _OpsBetween(g, [t7.op], [t1.op, t5.op]))
    self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op],
                            _OpsBetween(g, [t6.op], [t2.op, t5.op]))

  def testGradients(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[32, 100], name="in")
      w = constant(1.0, shape=[100, 10], name="w")
      b = constant(1.0, shape=[10], name="b")
      xw = math_ops.matmul(inp, w, name="xw")
      h = bias_add(xw, b, name="h")
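      # For h = matmul(x, w) + b, the gradient w.r.t. w is x^T * grad, which
      # shows up in the graph as a MatMul with transpose_a=True; the
      # assertions below check the resulting op type and attrs.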
      w_grad = gradients.gradients(h, w)[0]
    self.assertEquals("MatMul", w_grad.op.type)
    self.assertEquals(w_grad.op._original_op, xw.op)
    self.assertTrue(w_grad.op.get_attr("transpose_a"))
    self.assertFalse(w_grad.op.get_attr("transpose_b"))

  def testUnusedOutput(self):
    with ops.Graph().as_default():
      w = constant(1.0, shape=[2, 2])
      x = constant(1.0, shape=[2, 2])
      wx = math_ops.matmul(w, x)
      split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
      c = math_ops.reduce_sum(split_wx[1])
      gw = gradients.gradients(c, [w])[0]
    self.assertEquals("MatMul", gw.op.type)

  def testColocateGradients(self):
    with ops.Graph().as_default() as g:
      w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      with g.device("/device:GPU:0"):
        wx = math_ops.matmul(w, x)
      gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
    self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())

  def testColocateGradientsWithAggregation(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      y = constant(1.0, shape=[1, 2])
      wx = math_ops.matmul(w, x)
      wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups())

  def testColocateGradientsWithAggregationInMultipleDevices(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      y = constant(1.0, shape=[1, 2])
      with g.device("/task:1"):
        wx = math_ops.matmul(w, x)
      with g.device("/task:2"):
        wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups())

  def testColocateGradientsWithGateGradients(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")
    with ops.Graph().as_default() as g:
      with g.device("/device:CPU:0"):
        x = constant(1.0, shape=[1, 1])
        y = constant(1.0, shape=[1, 1])
        s = x + y
      with g.device("/device:GPU:0"):
        z = math_ops.reduce_sum(s)

      gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
                                 gate_gradients=True)[0]
      with session.Session():
        # Make sure the placer doesn't complain.
        gz_x.eval()

  def testBoundaryStop(self):
    # Test that we don't differentiate 'x'. The gradient function for 'x' is
    # set explicitly to None so we will get an exception if the gradient code
    # tries to differentiate 'x'.
    with ops.Graph().as_default():
      c = constant(1.0)
      x = array_ops.identity(c)
      y = x + 1.0
      z = y + 1
      grads = gradients.gradients(z, [x])
      self.assertTrue(all(x is not None for x in grads))

  def testBoundaryContinue(self):
    # Test that we differentiate both 'x' and 'y' correctly when x is a
    # predecessor of y.
    with self.test_session():
      x = constant(1.0)
      y = x * 2.0
      z = y * 3.0
      grads = gradients.gradients(z, [x, y])
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(6.0, grads[0].eval())

  def testAggregationMethodAccumulateN(self):
    with self.test_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
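      # z sums ten copies of y and y = 2 * x, so dz/dy = 10 and dz/dx = 20.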
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.
          EXPERIMENTAL_ACCUMULATE_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  def testAggregationMethodAddN(self):
    with self.test_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  def testAggregationMethodTree(self):
    with self.test_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  def testNoGradientForStringOutputs(self):
    # This test can't be run twice because the TestStringOutput gradient can
    # only be registered once. Just run with the C API enabled.
    if not ops._USE_C_API: return

    with ops.Graph().as_default():

      def _TestOpGrad(_, float_grad, string_grad):
        """Gradient function for TestStringOutput."""
        self.assertEquals(float_grad.dtype, dtypes.float32)
        self.assertFalse(string_grad)
        return float_grad

      ops.RegisterGradient("TestStringOutput")(_TestOpGrad)

      c = constant(1.0)
      x, _ = test_ops.test_string_output(c)
      z = x * 2.0
      w = z * 3.0
      grads = gradients.gradients(z, [c])
      self.assertTrue(isinstance(grads[0], ops.Tensor))
      grads = gradients.gradients(w, [c])
      self.assertTrue(isinstance(grads[0], ops.Tensor))

  def testSingletonIndexedSlices(self):
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.identity(x)
      dy = ops.IndexedSlices(
          array_ops.placeholder(dtypes.float32),
          array_ops.placeholder(dtypes.int32))
      dx, = gradients.gradients(y, x, grad_ys=dy)
      # The IndexedSlices gradient of tf.identity is the identity map.
      with self.test_session() as sess:
        vdx, vdy = sess.run(
            [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
      self.assertEqual(vdx, vdy)

  def testNonDifferentiableSwitchInWhileLoop(self):
    with ops.Graph().as_default():
      v = array_ops.placeholder(dtypes.float32, [])

      def _Step(i, a, ta):
        a += math_ops.cast(v, dtypes.int32)
        return (i + 1, a, ta.write(i, a))

      n = 4
      i, _, ta = control_flow_ops.while_loop(
          lambda i, *_: i < n,
          _Step, [0, 0, tensor_array_ops.TensorArray(
              dtypes.int32, size=n)])
      target = ta.read(i - 1)
      grad, = gradients.gradients(target, v)
      self.assertIsNone(grad)

  def testVariableReadValueGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(var.read_value(), var)
      self.assertIsNotNone(gradient)

  def testVariableAsGraphElementGradient(self):
    with ops.Graph().as_default() as graph:
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(graph.as_graph_element(var), var)
      self.assertIsNotNone(gradient)

  def testVariableRefGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(var._ref(), var)
      self.assertIsNotNone(gradient)

  def testDependentYs(self):
    with self.test_session():
      x = constant_op.constant(3.0)
      y = math_ops.square(x)
      y1 = math_ops.square(y)
      y2 = math_ops.square(y1)
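      # y = x^2 and y2 = x^8, so d(y + y2)/dx = 2*x + 8*x^7 = 17502 at x = 3.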
      g = gradients.gradients([y, y2], x)
      self.assertAllClose(17502.0, g[0].eval())
      g = gradients.gradients(y + y2, x)
      self.assertAllClose(17502.0, g[0].eval())
      z = array_ops.identity(y)
      z2 = array_ops.identity(y2)
      g = gradients.gradients([z, z2], x)
      self.assertAllClose(17502.0, g[0].eval())

  def testPartialDerivatives(self):
    with self.test_session():
      x = constant_op.constant(1.)
      y = 2 * x
      z = x + y
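      # Without stop_gradients the total derivative is dz/dx = 1 + dy/dx = 3;
      # with stop_gradients=[x, y] both inputs are treated as independent,
      # giving a partial derivative of 1 each.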
      totalg = gradients.gradients(z, [x, y])
      self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
      partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
      self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])

  def testStopGradients(self):
    def _MakeGraph(rng, stop_gradients=()):
      def _FunctionOf(xs, k=3):
        return ops.convert_to_tensor(
            sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
            + rng.rand(k, k))

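      # Build a chain a -> b -> c -> d, where b depends on a, c on (a, b) and
      # d on (b, c), optionally inserting a stop_gradient after each node.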
      a = _FunctionOf([])
      if "a" in stop_gradients: a = array_ops.stop_gradient(a)
      b = _FunctionOf([a])
      if "b" in stop_gradients: b = array_ops.stop_gradient(b)
      c = _FunctionOf([a, b])
      if "c" in stop_gradients: c = array_ops.stop_gradient(c)
      d = _FunctionOf([b, c])
      if "d" in stop_gradients: d = array_ops.stop_gradient(d)
      return dict(a=a, b=b, c=c, d=d)

    def _Gradients(ys, xs, **kwargs):
      dydxs = gradients.gradients(ys, xs, **kwargs)
      dydxs = [0. * x if dydx is None else dydx
               for x, dydx in zip(xs, dydxs)]
      return dydxs

    seed = np.random.randint(1000)
    cases = []
    subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
    graph = _MakeGraph(np.random.RandomState(seed))
    for constants in subsets:
      graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
      for variables_ in subsets:
        # compute the gradient when stopped using tf.stop_gradient
        grad1 = _Gradients([graph_with_stops["d"]],
                           [graph_with_stops[v] for v in variables_])
        # compute the gradient when stopped using the stop_gradients kwarg
        grad2 = _Gradients([graph["d"]],
                           [graph[v] for v in variables_],
                           stop_gradients=[graph[v] for v in constants])
        cases.append(dict(grad1=grad1, grad2=grad2,
                          constants=constants, variables=variables_))

    # evaluate all tensors in one call to session.run for speed
    with self.test_session() as sess:
      results = sess.run([(case["grad1"], case["grad2"]) for case in cases])

    for (npgrad1, npgrad2), case in zip(results, cases):
      for a, b in zip(npgrad1, npgrad2):
        np.testing.assert_allclose(a, b)


@test_util.with_c_api
class FunctionGradientsTest(test_util.TensorFlowTestCase):

  @classmethod
  def XSquarePlusB(cls, x, b):
    return x * x + b

  @classmethod
  def XSquarePlusBGradient(cls, x, b, g):
    # Perturb gradients (multiply by 2), so we can test that this was called.
    g *= 2.0
    return g * 2.0 * x, g

  @classmethod
  def _PythonGradient(cls, op, grad):
    # Perturb gradients (multiply by 3), so we can test that this was called.
    grad *= 3.0
    return grad * op.inputs[0] * 2.0, grad

  @classmethod
  def _GetFunc(cls, **kwargs):
    return function.Defun(dtypes.float32, dtypes.float32,
                          **kwargs)(cls.XSquarePlusB)

  def _GetFuncGradients(self, f, x_value, b_value):
    x = constant_op.constant(x_value, name="x")
    b = constant_op.constant(b_value, name="b")

    y = f(x, b)
    grads = gradients.gradients(y, [x, b])
    with self.test_session() as sess:
      return sess.run(grads)

  def testFunctionGradientsBasic(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc()
      # Get gradients (should add SymbolicGradient node for function).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
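      # f(x, b) = x^2 + b, so df/dx = 2*x = 4 at x = 2 and df/db = 1.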
      self.assertAllEqual([4.0], grads[0])
      self.assertAllEqual([1.0], grads[1])

  def testFunctionGradientsComposition(self):
    with ops.Graph().as_default():
      f = self._GetFunc()
      x = constant_op.constant([2.0], name="x")
      b1 = constant_op.constant([1.0], name="b1")
      b2 = constant_op.constant([1.0], name="b2")

      y = f(f(x, b1), b2)
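      # y = (x^2 + b1)^2 + b2, so dy/dx = 4*x*(x^2 + b1) = 40 and
      # dy/db1 = 2*(x^2 + b1) = 10 at x = 2, b1 = 1.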
      # Build gradient graph (should add SymbolicGradient node for function).
      grads = gradients.gradients(y, [x, b1])

      with self.test_session() as sess:
        self.assertAllEqual([40.0], sess.run(grads)[0])
        self.assertAllEqual([10.0], sess.run(grads)[1])

  def testFunctionGradientsWithGradFunc(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = function.Defun(dtypes.float32, dtypes.float32,
                                 dtypes.float32)(self.XSquarePlusBGradient)
      f = self._GetFunc(grad_func=grad_func)
      # Get gradients (should add SymbolicGradient node for function, which
      # uses the grad_func above, which multiplies all gradients by 2).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 2], grads[0])
      self.assertAllEqual([1.0 * 2], grads[1])

  def testFunctionGradientWithRegistration(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc(python_grad_func=self._PythonGradient)
      # Get gradients, using the python gradient function. It multiplies the
      # gradients by 3.
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 3], grads[0])
      self.assertAllEqual([1.0 * 3], grads[1])

  def testFunctionGradientWithGradFuncAndRegistration(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = function.Defun(dtypes.float32, dtypes.float32,
                                 dtypes.float32)(self.XSquarePlusBGradient)
      with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
        f = self._GetFunc(
            grad_func=grad_func, python_grad_func=self._PythonGradient)
        f.add_to_graph(ops.Graph())


@test_util.with_c_api
class StopGradientTest(test_util.TensorFlowTestCase):

  def testStopGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.stop_gradient(inp)
      igrad = gradients.gradients(out, inp)[0]
    assert igrad is None


@test_util.with_c_api
class PreventGradientTest(test_util.TensorFlowTestCase):

  def testPreventGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.prevent_gradient(inp)
      with self.assertRaisesRegexp(LookupError, "explicitly disabled"):
        _ = gradients.gradients(out, inp)


@test_util.with_c_api
class HessianVectorProductTest(test_util.TensorFlowTestCase):

  def testHessianVectorProduct(self):
    # Compute the Hessian explicitly for a low-dimensional problem and check
    # that HessianVectorProduct matches multiplication by the explicit
    # Hessian. Specifically, for f(x) = x^T A x the gradient is (A + A^T) x,
    # so the Hessian is H = A + A^T. We expect
    # HessianVectorProduct(f(x), x, v) to be H v.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    v_value = rng.randn(m, 1).astype("float32")
    x_value = rng.randn(m, 1).astype("float32")
    hess_value = mat_value + mat_value.T
    hess_v_value = np.dot(hess_value, v_value)
    for use_gpu in [False, True]:
      with self.test_session(use_gpu=use_gpu):
        mat = constant_op.constant(mat_value)
        v = constant_op.constant(v_value)
        x = constant_op.constant(x_value)
        mat_x = math_ops.matmul(mat, x, name="Ax")
        x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
        hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0]
        hess_v_actual = hess_v.eval()
      self.assertAllClose(hess_v_value, hess_v_actual)


@test_util.with_c_api
class HessianTest(test_util.TensorFlowTestCase):

  def testHessian1D(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = x^T A x is H = A + A^T.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    with self.test_session(use_gpu=True):
      mat = constant_op.constant(mat_value)
      x = constant_op.constant(x_value)
      x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
      hess = gradients.hessians(x_mat_x, x)[0]
      hess_actual = hess.eval()
    self.assertAllClose(hess_value, hess_actual)

  def testHessian1D_multi(self):
    # Test the computation of the hessian with respect to multiple tensors
    m = 4
    n = 3
    rng = np.random.RandomState([1, 2, 3])
    mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
    x_values = [rng.randn(m).astype("float32") for _ in range(n)]
    hess_values = [mat_value + mat_value.T for mat_value in mat_values]
    with self.test_session(use_gpu=True):
      mats = [constant_op.constant(mat_value) for mat_value in mat_values]
      xs = [constant_op.constant(x_value) for x_value in x_values]
      xs_mats_xs = [
          math_ops.reduce_sum(x[:, None] * mat * x[None, :])
          for x, mat in zip(xs, mats)
      ]
      hessians = gradients.hessians(xs_mats_xs, xs)
      hessians_actual = [hess.eval() for hess in hessians]
    for hess_value, hess_actual in zip(hess_values, hessians_actual):
      self.assertAllClose(hess_value, hess_actual)

  def testHessianInvalidDimension(self):
    for shape in [(10, 10), None]:
      with self.test_session(use_gpu=True):
        x = array_ops.placeholder(dtypes.float32, shape)
        # Expect a ValueError because the dimensions are wrong
        with self.assertRaises(ValueError):
          gradients.hessians(x, x)

  def testHessian2D_square_matrix(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
    m = 3
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, m).astype("float32")
    with self.test_session(use_gpu=True):
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = hess.eval()
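    # d2f/(dx[a, b] dx[c, d]) = delta(a, c), so the expected Hessian is an
    # m x m grid of blocks with all-ones blocks on the diagonal.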
    hess_value = np.bmat([
        [elem*np.ones((m, m)) for elem in vec]
        for vec in np.eye(m)
    ]).astype("float32")
    self.assertAllEqual((m, m, m, m), hess_actual.shape)
    self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))

  def testHessian2D_non_square_matrix(self):
    m = 3
    n = 4
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, n).astype("float32")
    with self.test_session(use_gpu=True):
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = hess.eval()
    hess_value = np.bmat([
        [elem*np.ones((n, n)) for elem in vec]
        for vec in np.eye(m)
    ]).astype("float32")
    self.assertAllEqual((m, n, m, n), hess_actual.shape)
    self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))


@test_util.with_c_api
class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):

  def testIndexedSlicesToTensor(self):
    with self.test_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, c_dense.eval())

  def testIndexedSlicesToTensorList(self):
    with self.test_session():
      numpy_list = []
      dense_list = []
      sparse_list = []
      for _ in range(3):
        np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
        c = constant_op.constant(np_val)
        c_sparse = math_ops._as_indexed_slices(c)
        numpy_list.append(np_val)
        dense_list.append(c)
        sparse_list.append(c_sparse)
      packed_dense = array_ops.stack(dense_list)
      packed_sparse = array_ops.stack(sparse_list)
      self.assertAllClose(packed_dense.eval(), packed_sparse.eval())

  def testInt64Indices(self):
    with self.test_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      c_sparse = ops.IndexedSlices(
          c_sparse.values,
          math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, c_dense.eval())

  def testWarnings(self):
    # TODO(gunan) Reenable after this issue is fixed:
    # https://github.com/google/protobuf/issues/2812
    if sys.version_info >= (3, 5):
      self.skipTest("Skipped test for Python 3.5+")

    # Smaller than the threshold: no warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4]))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(0, len(w))

    # Greater than or equal to the threshold: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100]))
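    # 100 * 100 * 100 * 100 = 10**8 elements, matching the element count in
    # the asserted warning message below.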
    # "always" filter prevents the warning from being suppressed if it was
    # already triggered in a different test.
    warnings.simplefilter("always")
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertTrue(
        "with 100000000 elements. This may consume a large amount of memory." in
        str(w[0].message))

    # Unknown dense shape: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32),
        array_ops.placeholder(dtypes.int32))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertTrue(
        "of unknown shape. This may consume a large amount of memory." in
        str(w[0].message))


@test_util.with_c_api
class OnlyRealGradientsTest(test_util.TensorFlowTestCase):

  def testRealOnly(self):
    x = constant_op.constant(7+3j, dtype=dtypes.complex64)
    y = math_ops.square(x)
    with self.assertRaisesRegexp(
        TypeError,
        r"Gradients of complex tensors must set grad_ys "
        r"\(y\.dtype = tf\.complex64\)"):
      gradients.gradients(y, x)


if __name__ == "__main__":
  googletest.main()