Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.python.framework import constant_op
     23 from tensorflow.python.framework import dtypes as dtypes_lib
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.framework import test_util
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import gradient_checker
     28 from tensorflow.python.ops import gradients_impl
     29 from tensorflow.python.platform import test
     30 from tensorflow.python.platform import tf_logging
     31 
     32 
     33 class MatrixDiagTest(test.TestCase):
     34 
     35   @test_util.run_deprecated_v1
     36   def testVector(self):
     37     with self.session(use_gpu=True):
     38       v = np.array([1.0, 2.0, 3.0])
     39       mat = np.diag(v)
     40       v_diag = array_ops.matrix_diag(v)
     41       self.assertEqual((3, 3), v_diag.get_shape())
     42       self.assertAllEqual(v_diag.eval(), mat)
     43 
     44   def _testBatchVector(self, dtype):
     45     with self.cached_session(use_gpu=True):
     46       v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
     47       mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
     48                             [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
     49                              [0.0, 0.0, 6.0]]]).astype(dtype)
     50       v_batch_diag = array_ops.matrix_diag(v_batch)
     51       self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
     52       self.assertAllEqual(v_batch_diag.eval(), mat_batch)
     53 
     54   @test_util.run_deprecated_v1
     55   def testBatchVector(self):
     56     self._testBatchVector(np.float32)
     57     self._testBatchVector(np.float64)
     58     self._testBatchVector(np.int32)
     59     self._testBatchVector(np.int64)
     60     self._testBatchVector(np.bool)
     61 
     62   @test_util.run_deprecated_v1
     63   def testInvalidShape(self):
     64     with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
     65       array_ops.matrix_diag(0)
     66 
     67   @test_util.run_deprecated_v1
     68   @test_util.disable_xla("b/123337890")  # Error messages differ
     69   def testInvalidShapeAtEval(self):
     70     with self.session(use_gpu=True):
     71       v = array_ops.placeholder(dtype=dtypes_lib.float32)
     72       with self.assertRaisesOpError("input must be at least 1-dim"):
     73         array_ops.matrix_diag(v).eval(feed_dict={v: 0.0})
     74 
     75   @test_util.run_deprecated_v1
     76   def testGrad(self):
     77     shapes = ((3,), (7, 4))
     78     with self.session(use_gpu=True):
     79       for shape in shapes:
     80         x = constant_op.constant(np.random.rand(*shape), np.float32)
     81         y = array_ops.matrix_diag(x)
     82         error = gradient_checker.compute_gradient_error(x,
     83                                                         x.get_shape().as_list(),
     84                                                         y,
     85                                                         y.get_shape().as_list())
     86         self.assertLess(error, 1e-4)
     87 
     88 
     89 class MatrixSetDiagTest(test.TestCase):
     90 
     91   @test_util.run_deprecated_v1
     92   def testSquare(self):
     93     with self.session(use_gpu=True):
     94       v = np.array([1.0, 2.0, 3.0])
     95       mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
     96       mat_set_diag = np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0],
     97                                [1.0, 1.0, 3.0]])
     98       output = array_ops.matrix_set_diag(mat, v)
     99       self.assertEqual((3, 3), output.get_shape())
    100       self.assertAllEqual(mat_set_diag, self.evaluate(output))
    101 
    102   @test_util.run_deprecated_v1
    103   def testRectangular(self):
    104     with self.session(use_gpu=True):
    105       v = np.array([3.0, 4.0])
    106       mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
    107       expected = np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]])
    108       output = array_ops.matrix_set_diag(mat, v)
    109       self.assertEqual((2, 3), output.get_shape())
    110       self.assertAllEqual(expected, self.evaluate(output))
    111 
    112       v = np.array([3.0, 4.0])
    113       mat = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    114       expected = np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]])
    115       output = array_ops.matrix_set_diag(mat, v)
    116       self.assertEqual((3, 2), output.get_shape())
    117       self.assertAllEqual(expected, self.evaluate(output))
    118 
    119   def _testSquareBatch(self, dtype):
    120     with self.cached_session(use_gpu=True):
    121       v_batch = np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]]).astype(dtype)
    122       mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
    123                             [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0],
    124                              [2.0, 0.0, 6.0]]]).astype(dtype)
    125 
    126       mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0],
    127                                       [1.0, 0.0, -3.0]],
    128                                      [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0],
    129                                       [2.0, 0.0, -6.0]]]).astype(dtype)
    130 
    131       output = array_ops.matrix_set_diag(mat_batch, v_batch)
    132       self.assertEqual((2, 3, 3), output.get_shape())
    133       self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
    134 
    135   @test_util.run_deprecated_v1
    136   def testSquareBatch(self):
    137     self._testSquareBatch(np.float32)
    138     self._testSquareBatch(np.float64)
    139     self._testSquareBatch(np.int32)
    140     self._testSquareBatch(np.int64)
    141     self._testSquareBatch(np.bool)
    142 
    143   @test_util.run_deprecated_v1
    144   def testRectangularBatch(self):
    145     with self.session(use_gpu=True):
    146       v_batch = np.array([[-1.0, -2.0], [-4.0, -5.0]])
    147       mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
    148                             [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]])
    149 
    150       mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
    151                                      [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]])
    152       output = array_ops.matrix_set_diag(mat_batch, v_batch)
    153       self.assertEqual((2, 2, 3), output.get_shape())
    154       self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
    155 
    156   @test_util.run_deprecated_v1
    157   def testInvalidShape(self):
    158     with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
    159       array_ops.matrix_set_diag(0, [0])
    160     with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
    161       array_ops.matrix_set_diag([[0]], 0)
    162 
    163   @test_util.run_deprecated_v1
    164   def testInvalidShapeAtEval(self):
    165     with self.session(use_gpu=True):
    166       v = array_ops.placeholder(dtype=dtypes_lib.float32)
    167       with self.assertRaisesOpError("input must be at least 2-dim"):
    168         array_ops.matrix_set_diag(v, [v]).eval(feed_dict={v: 0.0})
    169       with self.assertRaisesOpError(
    170           r"but received input shape: \[1,1\] and diagonal shape: \[\]"):
    171         array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})
    172 
    173   @test_util.run_deprecated_v1
    174   def testGrad(self):
    175     shapes = ((3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8))
    176     with self.session(use_gpu=True):
    177       for shape in shapes:
    178         x = constant_op.constant(
    179             np.random.rand(*shape), dtype=dtypes_lib.float32)
    180         diag_shape = shape[:-2] + (min(shape[-2:]),)
    181         x_diag = constant_op.constant(
    182             np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
    183         y = array_ops.matrix_set_diag(x, x_diag)
    184         error_x = gradient_checker.compute_gradient_error(
    185             x,
    186             x.get_shape().as_list(), y,
    187             y.get_shape().as_list())
    188         self.assertLess(error_x, 1e-4)
    189         error_x_diag = gradient_checker.compute_gradient_error(
    190             x_diag,
    191             x_diag.get_shape().as_list(), y,
    192             y.get_shape().as_list())
    193         self.assertLess(error_x_diag, 1e-4)
    194 
    195   @test_util.run_deprecated_v1
    196   def testGradWithNoShapeInformation(self):
    197     with self.session(use_gpu=True) as sess:
    198       v = array_ops.placeholder(dtype=dtypes_lib.float32)
    199       mat = array_ops.placeholder(dtype=dtypes_lib.float32)
    200       grad_input = array_ops.placeholder(dtype=dtypes_lib.float32)
    201       output = array_ops.matrix_set_diag(mat, v)
    202       grads = gradients_impl.gradients(output, [mat, v], grad_ys=grad_input)
    203       grad_input_val = np.random.rand(3, 3).astype(np.float32)
    204       grad_vals = sess.run(
    205           grads,
    206           feed_dict={
    207               v: 2 * np.ones(3),
    208               mat: np.ones((3, 3)),
    209               grad_input: grad_input_val
    210           })
    211       self.assertAllEqual(np.diag(grad_input_val), grad_vals[1])
    212       self.assertAllEqual(grad_input_val - np.diag(np.diag(grad_input_val)),
    213                           grad_vals[0])
    214 
    215 
    216 class MatrixDiagPartTest(test.TestCase):
    217 
    218   @test_util.run_deprecated_v1
    219   def testSquare(self):
    220     with self.session(use_gpu=True):
    221       v = np.array([1.0, 2.0, 3.0])
    222       mat = np.diag(v)
    223       mat_diag = array_ops.matrix_diag_part(mat)
    224       self.assertEqual((3,), mat_diag.get_shape())
    225       self.assertAllEqual(mat_diag.eval(), v)
    226 
    227   @test_util.run_deprecated_v1
    228   def testRectangular(self):
    229     with self.session(use_gpu=True):
    230       mat = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    231       mat_diag = array_ops.matrix_diag_part(mat)
    232       self.assertAllEqual(mat_diag.eval(), np.array([1.0, 5.0]))
    233       mat = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    234       mat_diag = array_ops.matrix_diag_part(mat)
    235       self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
    236 
    237   def _testSquareBatch(self, dtype):
    238     with self.cached_session(use_gpu=True):
    239       v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
    240       mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
    241                             [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
    242                              [0.0, 0.0, 6.0]]]).astype(dtype)
    243       self.assertEqual(mat_batch.shape, (2, 3, 3))
    244       mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
    245       self.assertEqual((2, 3), mat_batch_diag.get_shape())
    246       self.assertAllEqual(mat_batch_diag.eval(), v_batch)
    247 
    248   @test_util.run_deprecated_v1
    249   def testSquareBatch(self):
    250     self._testSquareBatch(np.float32)
    251     self._testSquareBatch(np.float64)
    252     self._testSquareBatch(np.int32)
    253     self._testSquareBatch(np.int64)
    254     self._testSquareBatch(np.bool)
    255 
    256   @test_util.run_deprecated_v1
    257   def testRectangularBatch(self):
    258     with self.session(use_gpu=True):
    259       v_batch = np.array([[1.0, 2.0], [4.0, 5.0]])
    260       mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
    261                             [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0]]])
    262       self.assertEqual(mat_batch.shape, (2, 2, 3))
    263       mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
    264       self.assertEqual((2, 2), mat_batch_diag.get_shape())
    265       self.assertAllEqual(mat_batch_diag.eval(), v_batch)
    266 
    267   @test_util.run_deprecated_v1
    268   def testInvalidShape(self):
    269     with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
    270       array_ops.matrix_diag_part(0)
    271 
    272   @test_util.run_deprecated_v1
    273   @test_util.disable_xla("b/123337890")  # Error messages differ
    274   def testInvalidShapeAtEval(self):
    275     with self.session(use_gpu=True):
    276       v = array_ops.placeholder(dtype=dtypes_lib.float32)
    277       with self.assertRaisesOpError("input must be at least 2-dim"):
    278         array_ops.matrix_diag_part(v).eval(feed_dict={v: 0.0})
    279 
    280   @test_util.run_deprecated_v1
    281   def testGrad(self):
    282     shapes = ((3, 3), (2, 3), (3, 2), (5, 3, 3))
    283     with self.session(use_gpu=True):
    284       for shape in shapes:
    285         x = constant_op.constant(np.random.rand(*shape), dtype=np.float32)
    286         y = array_ops.matrix_diag_part(x)
    287         error = gradient_checker.compute_gradient_error(x,
    288                                                         x.get_shape().as_list(),
    289                                                         y,
    290                                                         y.get_shape().as_list())
    291         self.assertLess(error, 1e-4)
    292 
    293 
    294 class DiagTest(test.TestCase):
    295 
    296   def _diagOp(self, diag, dtype, expected_ans, use_gpu):
    297     with self.cached_session(use_gpu=use_gpu):
    298       tf_ans = array_ops.diag(ops.convert_to_tensor(diag.astype(dtype)))
    299       out = self.evaluate(tf_ans)
    300       tf_ans_inv = array_ops.diag_part(expected_ans)
    301       inv_out = self.evaluate(tf_ans_inv)
    302     self.assertAllClose(out, expected_ans)
    303     self.assertAllClose(inv_out, diag)
    304     self.assertShapeEqual(expected_ans, tf_ans)
    305     self.assertShapeEqual(diag, tf_ans_inv)
    306 
    307   def diagOp(self, diag, dtype, expected_ans):
    308     self._diagOp(diag, dtype, expected_ans, False)
    309     self._diagOp(diag, dtype, expected_ans, True)
    310 
    311   def testEmptyTensor(self):
    312     x = np.array([])
    313     expected_ans = np.empty([0, 0])
    314     self.diagOp(x, np.int32, expected_ans)
    315 
    316   def testRankOneIntTensor(self):
    317     x = np.array([1, 2, 3])
    318     expected_ans = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
    319     self.diagOp(x, np.int32, expected_ans)
    320     self.diagOp(x, np.int64, expected_ans)
    321 
    322   def testRankOneFloatTensor(self):
    323     x = np.array([1.1, 2.2, 3.3])
    324     expected_ans = np.array([[1.1, 0, 0], [0, 2.2, 0], [0, 0, 3.3]])
    325     self.diagOp(x, np.float32, expected_ans)
    326     self.diagOp(x, np.float64, expected_ans)
    327 
    328   def testRankOneComplexTensor(self):
    329     for dtype in [np.complex64, np.complex128]:
    330       x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype=dtype)
    331       expected_ans = np.array(
    332           [[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 2.2 + 2.2j, 0 + 0j],
    333            [0 + 0j, 0 + 0j, 3.3 + 3.3j]],
    334           dtype=dtype)
    335       self.diagOp(x, dtype, expected_ans)
    336 
    337   def testRankTwoIntTensor(self):
    338     x = np.array([[1, 2, 3], [4, 5, 6]])
    339     expected_ans = np.array([[[[1, 0, 0], [0, 0, 0]], [[0, 2, 0], [0, 0, 0]],
    340                               [[0, 0, 3], [0, 0, 0]]],
    341                              [[[0, 0, 0], [4, 0, 0]], [[0, 0, 0], [0, 5, 0]],
    342                               [[0, 0, 0], [0, 0, 6]]]])
    343     self.diagOp(x, np.int32, expected_ans)
    344     self.diagOp(x, np.int64, expected_ans)
    345 
    346   def testRankTwoFloatTensor(self):
    347     x = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]])
    348     expected_ans = np.array(
    349         [[[[1.1, 0, 0], [0, 0, 0]], [[0, 2.2, 0], [0, 0, 0]],
    350           [[0, 0, 3.3], [0, 0, 0]]], [[[0, 0, 0], [4.4, 0, 0]],
    351                                       [[0, 0, 0], [0, 5.5, 0]], [[0, 0, 0],
    352                                                                  [0, 0, 6.6]]]])
    353     self.diagOp(x, np.float32, expected_ans)
    354     self.diagOp(x, np.float64, expected_ans)
    355 
    356   def testRankTwoComplexTensor(self):
    357     for dtype in [np.complex64, np.complex128]:
    358       x = np.array(
    359           [[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j],
    360            [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]],
    361           dtype=dtype)
    362       expected_ans = np.array(
    363           [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], [
    364               [0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]
    365           ], [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]], [[
    366               [0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]
    367           ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]
    368              ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]],
    369           dtype=dtype)
    370       self.diagOp(x, dtype, expected_ans)
    371 
    372   def testRankThreeFloatTensor(self):
    373     x = np.array([[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]])
    374     expected_ans = np.array([[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]],
    375                                [[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]],
    376                               [[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]],
    377                                [[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]],
    378                              [[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]],
    379                                [[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]],
    380                               [[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]],
    381                                [[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]])
    382     self.diagOp(x, np.float32, expected_ans)
    383     self.diagOp(x, np.float64, expected_ans)
    384 
    385   def testRankThreeComplexTensor(self):
    386     for dtype in [np.complex64, np.complex128]:
    387       x = np.array(
    388           [[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]],
    389            [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]],
    390           dtype=dtype)
    391       expected_ans = np.array(
    392           [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
    393               0 + 0j, 0 + 0j
    394           ]]], [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
    395               0 + 0j, 0 + 0j
    396           ]]]], [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
    397               0 + 0j, 0 + 0j
    398           ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]], [[0 + 0j, 0 + 0j], [
    399               0 + 0j, 0 + 0j
    400           ]]]]], [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [
    401               [5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]
    402           ]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 6.6 + 6.6j], [
    403               0 + 0j, 0 + 0j
    404           ]]]], [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
    405               7.7 + 7.7j, 0 + 0j
    406           ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
    407                 [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]],
    408           dtype=dtype)
    409       self.diagOp(x, dtype, expected_ans)
    410 
    411   def testRankFourNumberTensor(self):
    412     for dtype in [np.float32, np.float64, np.int64, np.int32]:
    413       # Input with shape [2, 1, 2, 3]
    414       x = np.array(
    415           [[[[1, 2, 3], [4, 5, 6]]], [[[7, 8, 9], [10, 11, 12]]]], dtype=dtype)
    416       # Output with shape [2, 1, 2, 3, 2, 1, 2, 3]
    417       expected_ans = np.array(
    418           [[[[[[[[1, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
    419               [[[0, 2, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
    420           ], [[[[0, 0, 3], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]]], [[
    421               [[[0, 0, 0], [4, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
    422           ], [[[[0, 0, 0], [0, 5, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
    423               [[[0, 0, 0], [0, 0, 6]]], [[[0, 0, 0], [0, 0, 0]]]
    424           ]]]], [[[[[[[0, 0, 0], [0, 0, 0]]], [[[7, 0, 0], [0, 0, 0]]]], [
    425               [[[0, 0, 0], [0, 0, 0]]], [[[0, 8, 0], [0, 0, 0]]]
    426           ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 9], [0, 0, 0]]]]], [[
    427               [[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [10, 0, 0]]]
    428           ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 11, 0]]]
    429              ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 12]]]]]]]],
    430           dtype=dtype)
    431       self.diagOp(x, dtype, expected_ans)
    432 
    433   @test_util.run_deprecated_v1
    434   def testInvalidRank(self):
    435     with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
    436       array_ops.diag(0.0)
    437 
    438 
    439 class DiagPartOpTest(test.TestCase):
    440 
    441   def setUp(self):
    442     np.random.seed(0)
    443 
    444   def _diagPartOp(self, tensor, dtype, expected_ans, use_gpu):
    445     with self.cached_session(use_gpu=use_gpu):
    446       tensor = ops.convert_to_tensor(tensor.astype(dtype))
    447       tf_ans_inv = array_ops.diag_part(tensor)
    448       inv_out = self.evaluate(tf_ans_inv)
    449     self.assertAllClose(inv_out, expected_ans)
    450     self.assertShapeEqual(expected_ans, tf_ans_inv)
    451 
    452   def diagPartOp(self, tensor, dtype, expected_ans):
    453     self._diagPartOp(tensor, dtype, expected_ans, False)
    454     self._diagPartOp(tensor, dtype, expected_ans, True)
    455 
    456   def testRankTwoFloatTensor(self):
    457     x = np.random.rand(3, 3)
    458     i = np.arange(3)
    459     expected_ans = x[i, i]
    460     self.diagPartOp(x, np.float32, expected_ans)
    461     self.diagPartOp(x, np.float64, expected_ans)
    462 
    463   def testRankFourFloatTensorUnknownShape(self):
    464     x = np.random.rand(3, 3)
    465     i = np.arange(3)
    466     expected_ans = x[i, i]
    467     for shape in None, (None, 3), (3, None):
    468       with self.cached_session(use_gpu=False):
    469         t = ops.convert_to_tensor(x.astype(np.float32))
    470         t.set_shape(shape)
    471         tf_ans = array_ops.diag_part(t)
    472         out = self.evaluate(tf_ans)
    473       self.assertAllClose(out, expected_ans)
    474       self.assertShapeEqual(expected_ans, tf_ans)
    475 
    476   def testRankFourFloatTensor(self):
    477     x = np.random.rand(2, 3, 2, 3)
    478     i = np.arange(2)[:, None]
    479     j = np.arange(3)
    480     expected_ans = x[i, j, i, j]
    481     self.diagPartOp(x, np.float32, expected_ans)
    482     self.diagPartOp(x, np.float64, expected_ans)
    483 
    484   def testRankSixFloatTensor(self):
    485     x = np.random.rand(2, 2, 2, 2, 2, 2)
    486     i = np.arange(2)[:, None, None]
    487     j = np.arange(2)[:, None]
    488     k = np.arange(2)
    489     expected_ans = x[i, j, k, i, j, k]
    490     self.diagPartOp(x, np.float32, expected_ans)
    491     self.diagPartOp(x, np.float64, expected_ans)
    492 
    493   def testRankEightComplexTensor(self):
    494     x = np.random.rand(2, 2, 2, 3, 2, 2, 2, 3)
    495     i = np.arange(2)[:, None, None, None]
    496     j = np.arange(2)[:, None, None]
    497     k = np.arange(2)[:, None]
    498     l = np.arange(3)
    499     expected_ans = x[i, j, k, l, i, j, k, l]
    500     self.diagPartOp(x, np.complex64, expected_ans)
    501     self.diagPartOp(x, np.complex128, expected_ans)
    502 
    503   @test_util.run_deprecated_v1
    504   def testOddRank(self):
    505     w = np.random.rand(2)
    506     x = np.random.rand(2, 2, 2)
    507     self.assertRaises(ValueError, self.diagPartOp, w, np.float32, 0)
    508     self.assertRaises(ValueError, self.diagPartOp, x, np.float32, 0)
    509     with self.assertRaises(ValueError):
    510       array_ops.diag_part(0.0)
    511 
    512   @test_util.run_deprecated_v1
    513   def testUnevenDimensions(self):
    514     w = np.random.rand(2, 5)
    515     x = np.random.rand(2, 1, 2, 3)
    516     self.assertRaises(ValueError, self.diagPartOp, w, np.float32, 0)
    517     self.assertRaises(ValueError, self.diagPartOp, x, np.float32, 0)
    518 
    519 
    520 class DiagGradOpTest(test.TestCase):
    521 
    522   @test_util.run_deprecated_v1
    523   def testDiagGrad(self):
    524     np.random.seed(0)
    525     shapes = ((3,), (3, 3), (3, 3, 3))
    526     dtypes = (dtypes_lib.float32, dtypes_lib.float64)
    527     with self.session(use_gpu=False):
    528       errors = []
    529       for shape in shapes:
    530         for dtype in dtypes:
    531           x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
    532           y = array_ops.diag(x1)
    533           error = gradient_checker.compute_gradient_error(
    534               x1,
    535               x1.get_shape().as_list(), y,
    536               y.get_shape().as_list())
    537           tf_logging.info("error = %f", error)
    538           self.assertLess(error, 1e-4)
    539 
    540 
    541 class DiagGradPartOpTest(test.TestCase):
    542 
    543   @test_util.run_deprecated_v1
    544   def testDiagPartGrad(self):
    545     np.random.seed(0)
    546     shapes = ((3, 3), (3, 3, 3, 3))
    547     dtypes = (dtypes_lib.float32, dtypes_lib.float64)
    548     with self.session(use_gpu=False):
    549       errors = []
    550       for shape in shapes:
    551         for dtype in dtypes:
    552           x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
    553           y = array_ops.diag_part(x1)
    554           error = gradient_checker.compute_gradient_error(
    555               x1,
    556               x1.get_shape().as_list(), y,
    557               y.get_shape().as_list())
    558           tf_logging.info("error = %f", error)
    559           self.assertLess(error, 1e-4)
    560 
    561 
    562 if __name__ == "__main__":
    563   test.main()
    564