Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.python.framework import constant_op
     23 from tensorflow.python.framework import dtypes as dtypes_lib
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import gradient_checker
     27 from tensorflow.python.ops import gradients_impl
     28 from tensorflow.python.platform import test
     29 from tensorflow.python.platform import tf_logging
     30 
     31 
     32 class MatrixDiagTest(test.TestCase):
     33 
     34   def testVector(self):
     35     with self.test_session(use_gpu=True):
     36       v = np.array([1.0, 2.0, 3.0])
     37       mat = np.diag(v)
     38       v_diag = array_ops.matrix_diag(v)
     39       self.assertEqual((3, 3), v_diag.get_shape())
     40       self.assertAllEqual(v_diag.eval(), mat)
     41 
     42   def _testBatchVector(self, dtype):
     43     with self.test_session(use_gpu=True):
     44       v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
     45       mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
     46                             [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
     47                              [0.0, 0.0, 6.0]]]).astype(dtype)
     48       v_batch_diag = array_ops.matrix_diag(v_batch)
     49       self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
     50       self.assertAllEqual(v_batch_diag.eval(), mat_batch)
     51 
     52   def testBatchVector(self):
     53     self._testBatchVector(np.float32)
     54     self._testBatchVector(np.float64)
     55     self._testBatchVector(np.int32)
     56     self._testBatchVector(np.int64)
     57     self._testBatchVector(np.bool)
     58 
     59   def testInvalidShape(self):
     60     with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
     61       array_ops.matrix_diag(0)
     62 
     63   def testInvalidShapeAtEval(self):
     64     with self.test_session(use_gpu=True):
     65       v = array_ops.placeholder(dtype=dtypes_lib.float32)
     66       with self.assertRaisesOpError("input must be at least 1-dim"):
     67         array_ops.matrix_diag(v).eval(feed_dict={v: 0.0})
     68 
     69   def testGrad(self):
     70     shapes = ((3,), (7, 4))
     71     with self.test_session(use_gpu=True):
     72       for shape in shapes:
     73         x = constant_op.constant(np.random.rand(*shape), np.float32)
     74         y = array_ops.matrix_diag(x)
     75         error = gradient_checker.compute_gradient_error(x,
     76                                                         x.get_shape().as_list(),
     77                                                         y,
     78                                                         y.get_shape().as_list())
     79         self.assertLess(error, 1e-4)
     80 
     81 
     82 class MatrixSetDiagTest(test.TestCase):
     83 
     84   def testSquare(self):
     85     with self.test_session(use_gpu=True):
     86       v = np.array([1.0, 2.0, 3.0])
     87       mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
     88       mat_set_diag = np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0],
     89                                [1.0, 1.0, 3.0]])
     90       output = array_ops.matrix_set_diag(mat, v)
     91       self.assertEqual((3, 3), output.get_shape())
     92       self.assertAllEqual(mat_set_diag, output.eval())
     93 
     94   def testRectangular(self):
     95     with self.test_session(use_gpu=True):
     96       v = np.array([3.0, 4.0])
     97       mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
     98       expected = np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]])
     99       output = array_ops.matrix_set_diag(mat, v)
    100       self.assertEqual((2, 3), output.get_shape())
    101       self.assertAllEqual(expected, output.eval())
    102 
    103       v = np.array([3.0, 4.0])
    104       mat = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    105       expected = np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]])
    106       output = array_ops.matrix_set_diag(mat, v)
    107       self.assertEqual((3, 2), output.get_shape())
    108       self.assertAllEqual(expected, output.eval())
    109 
    110   def _testSquareBatch(self, dtype):
    111     with self.test_session(use_gpu=True):
    112       v_batch = np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]]).astype(dtype)
    113       mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
    114                             [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0],
    115                              [2.0, 0.0, 6.0]]]).astype(dtype)
    116 
    117       mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0],
    118                                       [1.0, 0.0, -3.0]],
    119                                      [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0],
    120                                       [2.0, 0.0, -6.0]]]).astype(dtype)
    121 
    122       output = array_ops.matrix_set_diag(mat_batch, v_batch)
    123       self.assertEqual((2, 3, 3), output.get_shape())
    124       self.assertAllEqual(mat_set_diag_batch, output.eval())
    125 
    126   def testSquareBatch(self):
    127     self._testSquareBatch(np.float32)
    128     self._testSquareBatch(np.float64)
    129     self._testSquareBatch(np.int32)
    130     self._testSquareBatch(np.int64)
    131     self._testSquareBatch(np.bool)
    132 
    133   def testRectangularBatch(self):
    134     with self.test_session(use_gpu=True):
    135       v_batch = np.array([[-1.0, -2.0], [-4.0, -5.0]])
    136       mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
    137                             [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]])
    138 
    139       mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
    140                                      [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]])
    141       output = array_ops.matrix_set_diag(mat_batch, v_batch)
    142       self.assertEqual((2, 2, 3), output.get_shape())
    143       self.assertAllEqual(mat_set_diag_batch, output.eval())
    144 
    145   def testInvalidShape(self):
    146     with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
    147       array_ops.matrix_set_diag(0, [0])
    148     with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
    149       array_ops.matrix_set_diag([[0]], 0)
    150 
    151   def testInvalidShapeAtEval(self):
    152     with self.test_session(use_gpu=True):
    153       v = array_ops.placeholder(dtype=dtypes_lib.float32)
    154       with self.assertRaisesOpError("input must be at least 2-dim"):
    155         array_ops.matrix_set_diag(v, [v]).eval(feed_dict={v: 0.0})
    156       with self.assertRaisesOpError(
    157           r"but received input shape: \[1,1\] and diagonal shape: \[\]"):
    158         array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})
    159 
    160   def testGrad(self):
    161     shapes = ((3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8))
    162     with self.test_session(use_gpu=True):
    163       for shape in shapes:
    164         x = constant_op.constant(
    165             np.random.rand(*shape), dtype=dtypes_lib.float32)
    166         diag_shape = shape[:-2] + (min(shape[-2:]),)
    167         x_diag = constant_op.constant(
    168             np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
    169         y = array_ops.matrix_set_diag(x, x_diag)
    170         error_x = gradient_checker.compute_gradient_error(
    171             x,
    172             x.get_shape().as_list(), y,
    173             y.get_shape().as_list())
    174         self.assertLess(error_x, 1e-4)
    175         error_x_diag = gradient_checker.compute_gradient_error(
    176             x_diag,
    177             x_diag.get_shape().as_list(), y,
    178             y.get_shape().as_list())
    179         self.assertLess(error_x_diag, 1e-4)
    180 
    181   def testGradWithNoShapeInformation(self):
    182     with self.test_session(use_gpu=True) as sess:
    183       v = array_ops.placeholder(dtype=dtypes_lib.float32)
    184       mat = array_ops.placeholder(dtype=dtypes_lib.float32)
    185       grad_input = array_ops.placeholder(dtype=dtypes_lib.float32)
    186       output = array_ops.matrix_set_diag(mat, v)
    187       grads = gradients_impl.gradients(output, [mat, v], grad_ys=grad_input)
    188       grad_input_val = np.random.rand(3, 3).astype(np.float32)
    189       grad_vals = sess.run(
    190           grads,
    191           feed_dict={
    192               v: 2 * np.ones(3),
    193               mat: np.ones((3, 3)),
    194               grad_input: grad_input_val
    195           })
    196       self.assertAllEqual(np.diag(grad_input_val), grad_vals[1])
    197       self.assertAllEqual(grad_input_val - np.diag(np.diag(grad_input_val)),
    198                           grad_vals[0])
    199 
    200 
    201 class MatrixDiagPartTest(test.TestCase):
    202 
    203   def testSquare(self):
    204     with self.test_session(use_gpu=True):
    205       v = np.array([1.0, 2.0, 3.0])
    206       mat = np.diag(v)
    207       mat_diag = array_ops.matrix_diag_part(mat)
    208       self.assertEqual((3,), mat_diag.get_shape())
    209       self.assertAllEqual(mat_diag.eval(), v)
    210 
    211   def testRectangular(self):
    212     with self.test_session(use_gpu=True):
    213       mat = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    214       mat_diag = array_ops.matrix_diag_part(mat)
    215       self.assertAllEqual(mat_diag.eval(), np.array([1.0, 5.0]))
    216       mat = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    217       mat_diag = array_ops.matrix_diag_part(mat)
    218       self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
    219 
    220   def _testSquareBatch(self, dtype):
    221     with self.test_session(use_gpu=True):
    222       v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
    223       mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
    224                             [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0],
    225                              [0.0, 0.0, 6.0]]]).astype(dtype)
    226       self.assertEqual(mat_batch.shape, (2, 3, 3))
    227       mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
    228       self.assertEqual((2, 3), mat_batch_diag.get_shape())
    229       self.assertAllEqual(mat_batch_diag.eval(), v_batch)
    230 
    231   def testSquareBatch(self):
    232     self._testSquareBatch(np.float32)
    233     self._testSquareBatch(np.float64)
    234     self._testSquareBatch(np.int32)
    235     self._testSquareBatch(np.int64)
    236     self._testSquareBatch(np.bool)
    237 
    238   def testRectangularBatch(self):
    239     with self.test_session(use_gpu=True):
    240       v_batch = np.array([[1.0, 2.0], [4.0, 5.0]])
    241       mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
    242                             [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0]]])
    243       self.assertEqual(mat_batch.shape, (2, 2, 3))
    244       mat_batch_diag = array_ops.matrix_diag_part(mat_batch)
    245       self.assertEqual((2, 2), mat_batch_diag.get_shape())
    246       self.assertAllEqual(mat_batch_diag.eval(), v_batch)
    247 
    248   def testInvalidShape(self):
    249     with self.assertRaisesRegexp(ValueError, "must be at least rank 2"):
    250       array_ops.matrix_diag_part(0)
    251 
    252   def testInvalidShapeAtEval(self):
    253     with self.test_session(use_gpu=True):
    254       v = array_ops.placeholder(dtype=dtypes_lib.float32)
    255       with self.assertRaisesOpError("input must be at least 2-dim"):
    256         array_ops.matrix_diag_part(v).eval(feed_dict={v: 0.0})
    257 
    258   def testGrad(self):
    259     shapes = ((3, 3), (2, 3), (3, 2), (5, 3, 3))
    260     with self.test_session(use_gpu=True):
    261       for shape in shapes:
    262         x = constant_op.constant(np.random.rand(*shape), dtype=np.float32)
    263         y = array_ops.matrix_diag_part(x)
    264         error = gradient_checker.compute_gradient_error(x,
    265                                                         x.get_shape().as_list(),
    266                                                         y,
    267                                                         y.get_shape().as_list())
    268         self.assertLess(error, 1e-4)
    269 
    270 
    271 class DiagTest(test.TestCase):
    272 
    273   def _diagOp(self, diag, dtype, expected_ans, use_gpu):
    274     with self.test_session(use_gpu=use_gpu):
    275       tf_ans = array_ops.diag(ops.convert_to_tensor(diag.astype(dtype)))
    276       out = tf_ans.eval()
    277       tf_ans_inv = array_ops.diag_part(expected_ans)
    278       inv_out = tf_ans_inv.eval()
    279     self.assertAllClose(out, expected_ans)
    280     self.assertAllClose(inv_out, diag)
    281     self.assertShapeEqual(expected_ans, tf_ans)
    282     self.assertShapeEqual(diag, tf_ans_inv)
    283 
    284   def diagOp(self, diag, dtype, expected_ans):
    285     self._diagOp(diag, dtype, expected_ans, False)
    286     self._diagOp(diag, dtype, expected_ans, True)
    287 
    288   def testEmptyTensor(self):
    289     x = np.array([])
    290     expected_ans = np.empty([0, 0])
    291     self.diagOp(x, np.int32, expected_ans)
    292 
    293   def testRankOneIntTensor(self):
    294     x = np.array([1, 2, 3])
    295     expected_ans = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
    296     self.diagOp(x, np.int32, expected_ans)
    297     self.diagOp(x, np.int64, expected_ans)
    298 
    299   def testRankOneFloatTensor(self):
    300     x = np.array([1.1, 2.2, 3.3])
    301     expected_ans = np.array([[1.1, 0, 0], [0, 2.2, 0], [0, 0, 3.3]])
    302     self.diagOp(x, np.float32, expected_ans)
    303     self.diagOp(x, np.float64, expected_ans)
    304 
    305   def testRankOneComplexTensor(self):
    306     for dtype in [np.complex64, np.complex128]:
    307       x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype=dtype)
    308       expected_ans = np.array(
    309           [[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 2.2 + 2.2j, 0 + 0j],
    310            [0 + 0j, 0 + 0j, 3.3 + 3.3j]],
    311           dtype=dtype)
    312       self.diagOp(x, dtype, expected_ans)
    313 
    314   def testRankTwoIntTensor(self):
    315     x = np.array([[1, 2, 3], [4, 5, 6]])
    316     expected_ans = np.array([[[[1, 0, 0], [0, 0, 0]], [[0, 2, 0], [0, 0, 0]],
    317                               [[0, 0, 3], [0, 0, 0]]],
    318                              [[[0, 0, 0], [4, 0, 0]], [[0, 0, 0], [0, 5, 0]],
    319                               [[0, 0, 0], [0, 0, 6]]]])
    320     self.diagOp(x, np.int32, expected_ans)
    321     self.diagOp(x, np.int64, expected_ans)
    322 
    323   def testRankTwoFloatTensor(self):
    324     x = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]])
    325     expected_ans = np.array(
    326         [[[[1.1, 0, 0], [0, 0, 0]], [[0, 2.2, 0], [0, 0, 0]],
    327           [[0, 0, 3.3], [0, 0, 0]]], [[[0, 0, 0], [4.4, 0, 0]],
    328                                       [[0, 0, 0], [0, 5.5, 0]], [[0, 0, 0],
    329                                                                  [0, 0, 6.6]]]])
    330     self.diagOp(x, np.float32, expected_ans)
    331     self.diagOp(x, np.float64, expected_ans)
    332 
    333   def testRankTwoComplexTensor(self):
    334     for dtype in [np.complex64, np.complex128]:
    335       x = np.array(
    336           [[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j],
    337            [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]],
    338           dtype=dtype)
    339       expected_ans = np.array(
    340           [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], [
    341               [0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]
    342           ], [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]], [[
    343               [0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j]
    344           ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j]
    345              ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]],
    346           dtype=dtype)
    347       self.diagOp(x, dtype, expected_ans)
    348 
    349   def testRankThreeFloatTensor(self):
    350     x = np.array([[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]])
    351     expected_ans = np.array([[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]],
    352                                [[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]],
    353                               [[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]],
    354                                [[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]],
    355                              [[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]],
    356                                [[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]],
    357                               [[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]],
    358                                [[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]])
    359     self.diagOp(x, np.float32, expected_ans)
    360     self.diagOp(x, np.float64, expected_ans)
    361 
    362   def testRankThreeComplexTensor(self):
    363     for dtype in [np.complex64, np.complex128]:
    364       x = np.array(
    365           [[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]],
    366            [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]],
    367           dtype=dtype)
    368       expected_ans = np.array(
    369           [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
    370               0 + 0j, 0 + 0j
    371           ]]], [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
    372               0 + 0j, 0 + 0j
    373           ]]]], [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
    374               0 + 0j, 0 + 0j
    375           ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]], [[0 + 0j, 0 + 0j], [
    376               0 + 0j, 0 + 0j
    377           ]]]]], [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [
    378               [5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j]
    379           ]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 6.6 + 6.6j], [
    380               0 + 0j, 0 + 0j
    381           ]]]], [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [
    382               7.7 + 7.7j, 0 + 0j
    383           ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]],
    384                 [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]],
    385           dtype=dtype)
    386       self.diagOp(x, dtype, expected_ans)
    387 
    388   def testRankFourNumberTensor(self):
    389     for dtype in [np.float32, np.float64, np.int64, np.int32]:
    390       # Input with shape [2, 1, 2, 3]
    391       x = np.array(
    392           [[[[1, 2, 3], [4, 5, 6]]], [[[7, 8, 9], [10, 11, 12]]]], dtype=dtype)
    393       # Output with shape [2, 1, 2, 3, 2, 1, 2, 3]
    394       expected_ans = np.array(
    395           [[[[[[[[1, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
    396               [[[0, 2, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
    397           ], [[[[0, 0, 3], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]]], [[
    398               [[[0, 0, 0], [4, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]
    399           ], [[[[0, 0, 0], [0, 5, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [
    400               [[[0, 0, 0], [0, 0, 6]]], [[[0, 0, 0], [0, 0, 0]]]
    401           ]]]], [[[[[[[0, 0, 0], [0, 0, 0]]], [[[7, 0, 0], [0, 0, 0]]]], [
    402               [[[0, 0, 0], [0, 0, 0]]], [[[0, 8, 0], [0, 0, 0]]]
    403           ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 9], [0, 0, 0]]]]], [[
    404               [[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [10, 0, 0]]]
    405           ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 11, 0]]]
    406              ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 12]]]]]]]],
    407           dtype=dtype)
    408       self.diagOp(x, dtype, expected_ans)
    409 
    410   def testInvalidRank(self):
    411     with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
    412       array_ops.diag(0.0)
    413 
    414 
    415 class DiagPartOpTest(test.TestCase):
    416 
    417   def setUp(self):
    418     np.random.seed(0)
    419 
    420   def _diagPartOp(self, tensor, dtype, expected_ans, use_gpu):
    421     with self.test_session(use_gpu=use_gpu):
    422       tensor = ops.convert_to_tensor(tensor.astype(dtype))
    423       tf_ans_inv = array_ops.diag_part(tensor)
    424       inv_out = tf_ans_inv.eval()
    425     self.assertAllClose(inv_out, expected_ans)
    426     self.assertShapeEqual(expected_ans, tf_ans_inv)
    427 
    428   def diagPartOp(self, tensor, dtype, expected_ans):
    429     self._diagPartOp(tensor, dtype, expected_ans, False)
    430     self._diagPartOp(tensor, dtype, expected_ans, True)
    431 
    432   def testRankTwoFloatTensor(self):
    433     x = np.random.rand(3, 3)
    434     i = np.arange(3)
    435     expected_ans = x[i, i]
    436     self.diagPartOp(x, np.float32, expected_ans)
    437     self.diagPartOp(x, np.float64, expected_ans)
    438 
    439   def testRankFourFloatTensorUnknownShape(self):
    440     x = np.random.rand(3, 3)
    441     i = np.arange(3)
    442     expected_ans = x[i, i]
    443     for shape in None, (None, 3), (3, None):
    444       with self.test_session(use_gpu=False):
    445         t = ops.convert_to_tensor(x.astype(np.float32))
    446         t.set_shape(shape)
    447         tf_ans = array_ops.diag_part(t)
    448         out = tf_ans.eval()
    449       self.assertAllClose(out, expected_ans)
    450       self.assertShapeEqual(expected_ans, tf_ans)
    451 
    452   def testRankFourFloatTensor(self):
    453     x = np.random.rand(2, 3, 2, 3)
    454     i = np.arange(2)[:, None]
    455     j = np.arange(3)
    456     expected_ans = x[i, j, i, j]
    457     self.diagPartOp(x, np.float32, expected_ans)
    458     self.diagPartOp(x, np.float64, expected_ans)
    459 
    460   def testRankSixFloatTensor(self):
    461     x = np.random.rand(2, 2, 2, 2, 2, 2)
    462     i = np.arange(2)[:, None, None]
    463     j = np.arange(2)[:, None]
    464     k = np.arange(2)
    465     expected_ans = x[i, j, k, i, j, k]
    466     self.diagPartOp(x, np.float32, expected_ans)
    467     self.diagPartOp(x, np.float64, expected_ans)
    468 
    469   def testRankEightComplexTensor(self):
    470     x = np.random.rand(2, 2, 2, 3, 2, 2, 2, 3)
    471     i = np.arange(2)[:, None, None, None]
    472     j = np.arange(2)[:, None, None]
    473     k = np.arange(2)[:, None]
    474     l = np.arange(3)
    475     expected_ans = x[i, j, k, l, i, j, k, l]
    476     self.diagPartOp(x, np.complex64, expected_ans)
    477     self.diagPartOp(x, np.complex128, expected_ans)
    478 
    479   def testOddRank(self):
    480     w = np.random.rand(2)
    481     x = np.random.rand(2, 2, 2)
    482     self.assertRaises(ValueError, self.diagPartOp, w, np.float32, 0)
    483     self.assertRaises(ValueError, self.diagPartOp, x, np.float32, 0)
    484     with self.assertRaises(ValueError):
    485       array_ops.diag_part(0.0)
    486 
    487   def testUnevenDimensions(self):
    488     w = np.random.rand(2, 5)
    489     x = np.random.rand(2, 1, 2, 3)
    490     self.assertRaises(ValueError, self.diagPartOp, w, np.float32, 0)
    491     self.assertRaises(ValueError, self.diagPartOp, x, np.float32, 0)
    492 
    493 
    494 class DiagGradOpTest(test.TestCase):
    495 
    496   def testDiagGrad(self):
    497     np.random.seed(0)
    498     shapes = ((3,), (3, 3), (3, 3, 3))
    499     dtypes = (dtypes_lib.float32, dtypes_lib.float64)
    500     with self.test_session(use_gpu=False):
    501       errors = []
    502       for shape in shapes:
    503         for dtype in dtypes:
    504           x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
    505           y = array_ops.diag(x1)
    506           error = gradient_checker.compute_gradient_error(
    507               x1,
    508               x1.get_shape().as_list(), y,
    509               y.get_shape().as_list())
    510           tf_logging.info("error = %f", error)
    511           self.assertLess(error, 1e-4)
    512 
    513 
    514 class DiagGradPartOpTest(test.TestCase):
    515 
    516   def testDiagPartGrad(self):
    517     np.random.seed(0)
    518     shapes = ((3, 3), (3, 3, 3, 3))
    519     dtypes = (dtypes_lib.float32, dtypes_lib.float64)
    520     with self.test_session(use_gpu=False):
    521       errors = []
    522       for shape in shapes:
    523         for dtype in dtypes:
    524           x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype)
    525           y = array_ops.diag_part(x1)
    526           error = gradient_checker.compute_gradient_error(
    527               x1,
    528               x1.get_shape().as_list(), y,
    529               y.get_shape().as_list())
    530           tf_logging.info("error = %f", error)
    531           self.assertLess(error, 1e-4)
    532 
    533 
if __name__ == "__main__":
  # When executed directly, hand control to TensorFlow's test runner
  # (`tensorflow.python.platform.test.main`) to run the TestCases above.
  test.main()
    536