1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.python.framework import constant_op 23 from tensorflow.python.framework import dtypes as dtypes_lib 24 from tensorflow.python.framework import ops 25 from tensorflow.python.framework import test_util 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import gradient_checker 28 from tensorflow.python.ops import gradients_impl 29 from tensorflow.python.platform import test 30 from tensorflow.python.platform import tf_logging 31 32 33 class MatrixDiagTest(test.TestCase): 34 35 @test_util.run_deprecated_v1 36 def testVector(self): 37 with self.session(use_gpu=True): 38 v = np.array([1.0, 2.0, 3.0]) 39 mat = np.diag(v) 40 v_diag = array_ops.matrix_diag(v) 41 self.assertEqual((3, 3), v_diag.get_shape()) 42 self.assertAllEqual(v_diag.eval(), mat) 43 44 def _testBatchVector(self, dtype): 45 with self.cached_session(use_gpu=True): 46 v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype) 47 mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]], 48 [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0], 49 [0.0, 0.0, 6.0]]]).astype(dtype) 50 v_batch_diag = array_ops.matrix_diag(v_batch) 51 self.assertEqual((2, 3, 3), v_batch_diag.get_shape()) 52 self.assertAllEqual(v_batch_diag.eval(), mat_batch) 53 54 @test_util.run_deprecated_v1 55 def testBatchVector(self): 56 self._testBatchVector(np.float32) 57 self._testBatchVector(np.float64) 58 self._testBatchVector(np.int32) 59 self._testBatchVector(np.int64) 60 self._testBatchVector(np.bool) 61 62 @test_util.run_deprecated_v1 63 def testInvalidShape(self): 64 with self.assertRaisesRegexp(ValueError, "must be at least rank 1"): 65 array_ops.matrix_diag(0) 66 67 @test_util.run_deprecated_v1 68 @test_util.disable_xla("b/123337890") # Error messages differ 69 def testInvalidShapeAtEval(self): 70 with self.session(use_gpu=True): 71 v = array_ops.placeholder(dtype=dtypes_lib.float32) 72 with self.assertRaisesOpError("input must be at least 1-dim"): 73 array_ops.matrix_diag(v).eval(feed_dict={v: 0.0}) 74 75 @test_util.run_deprecated_v1 76 def testGrad(self): 77 shapes = ((3,), (7, 4)) 78 with self.session(use_gpu=True): 79 for shape in shapes: 80 x = constant_op.constant(np.random.rand(*shape), np.float32) 81 y = array_ops.matrix_diag(x) 82 error = gradient_checker.compute_gradient_error(x, 83 x.get_shape().as_list(), 84 y, 85 y.get_shape().as_list()) 86 self.assertLess(error, 1e-4) 87 88 89 class MatrixSetDiagTest(test.TestCase): 90 91 @test_util.run_deprecated_v1 92 def testSquare(self): 93 with self.session(use_gpu=True): 94 v = np.array([1.0, 2.0, 3.0]) 95 mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]) 96 mat_set_diag = np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], 97 [1.0, 1.0, 3.0]]) 98 output = array_ops.matrix_set_diag(mat, v) 99 self.assertEqual((3, 3), output.get_shape()) 100 self.assertAllEqual(mat_set_diag, self.evaluate(output)) 101 102 @test_util.run_deprecated_v1 103 def testRectangular(self): 104 with self.session(use_gpu=True): 105 v = np.array([3.0, 4.0]) 106 mat = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]]) 107 expected = np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]]) 108 output = array_ops.matrix_set_diag(mat, v) 109 self.assertEqual((2, 3), output.get_shape()) 110 self.assertAllEqual(expected, self.evaluate(output)) 111 112 v = np.array([3.0, 4.0]) 113 mat = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]) 114 expected = np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]]) 115 output = array_ops.matrix_set_diag(mat, v) 116 self.assertEqual((3, 2), output.get_shape()) 117 self.assertAllEqual(expected, self.evaluate(output)) 118 119 def _testSquareBatch(self, dtype): 120 with self.cached_session(use_gpu=True): 121 v_batch = np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]]).astype(dtype) 122 mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]], 123 [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], 124 [2.0, 0.0, 6.0]]]).astype(dtype) 125 126 mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], 127 [1.0, 0.0, -3.0]], 128 [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], 129 [2.0, 0.0, -6.0]]]).astype(dtype) 130 131 output = array_ops.matrix_set_diag(mat_batch, v_batch) 132 self.assertEqual((2, 3, 3), output.get_shape()) 133 self.assertAllEqual(mat_set_diag_batch, self.evaluate(output)) 134 135 @test_util.run_deprecated_v1 136 def testSquareBatch(self): 137 self._testSquareBatch(np.float32) 138 self._testSquareBatch(np.float64) 139 self._testSquareBatch(np.int32) 140 self._testSquareBatch(np.int64) 141 self._testSquareBatch(np.bool) 142 143 @test_util.run_deprecated_v1 144 def testRectangularBatch(self): 145 with self.session(use_gpu=True): 146 v_batch = np.array([[-1.0, -2.0], [-4.0, -5.0]]) 147 mat_batch = np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]], 148 [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]]) 149 150 mat_set_diag_batch = np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]], 151 [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]]) 152 output = array_ops.matrix_set_diag(mat_batch, v_batch) 153 self.assertEqual((2, 2, 3), output.get_shape()) 154 self.assertAllEqual(mat_set_diag_batch, self.evaluate(output)) 155 156 @test_util.run_deprecated_v1 157 def testInvalidShape(self): 158 with self.assertRaisesRegexp(ValueError, "must be at least rank 2"): 159 array_ops.matrix_set_diag(0, [0]) 160 with self.assertRaisesRegexp(ValueError, "must be at least rank 1"): 161 array_ops.matrix_set_diag([[0]], 0) 162 163 @test_util.run_deprecated_v1 164 def testInvalidShapeAtEval(self): 165 with self.session(use_gpu=True): 166 v = array_ops.placeholder(dtype=dtypes_lib.float32) 167 with self.assertRaisesOpError("input must be at least 2-dim"): 168 array_ops.matrix_set_diag(v, [v]).eval(feed_dict={v: 0.0}) 169 with self.assertRaisesOpError( 170 r"but received input shape: \[1,1\] and diagonal shape: \[\]"): 171 array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0}) 172 173 @test_util.run_deprecated_v1 174 def testGrad(self): 175 shapes = ((3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8)) 176 with self.session(use_gpu=True): 177 for shape in shapes: 178 x = constant_op.constant( 179 np.random.rand(*shape), dtype=dtypes_lib.float32) 180 diag_shape = shape[:-2] + (min(shape[-2:]),) 181 x_diag = constant_op.constant( 182 np.random.rand(*diag_shape), dtype=dtypes_lib.float32) 183 y = array_ops.matrix_set_diag(x, x_diag) 184 error_x = gradient_checker.compute_gradient_error( 185 x, 186 x.get_shape().as_list(), y, 187 y.get_shape().as_list()) 188 self.assertLess(error_x, 1e-4) 189 error_x_diag = gradient_checker.compute_gradient_error( 190 x_diag, 191 x_diag.get_shape().as_list(), y, 192 y.get_shape().as_list()) 193 self.assertLess(error_x_diag, 1e-4) 194 195 @test_util.run_deprecated_v1 196 def testGradWithNoShapeInformation(self): 197 with self.session(use_gpu=True) as sess: 198 v = array_ops.placeholder(dtype=dtypes_lib.float32) 199 mat = array_ops.placeholder(dtype=dtypes_lib.float32) 200 grad_input = array_ops.placeholder(dtype=dtypes_lib.float32) 201 output = array_ops.matrix_set_diag(mat, v) 202 grads = gradients_impl.gradients(output, [mat, v], grad_ys=grad_input) 203 grad_input_val = np.random.rand(3, 3).astype(np.float32) 204 grad_vals = sess.run( 205 grads, 206 feed_dict={ 207 v: 2 * np.ones(3), 208 mat: np.ones((3, 3)), 209 grad_input: grad_input_val 210 }) 211 self.assertAllEqual(np.diag(grad_input_val), grad_vals[1]) 212 self.assertAllEqual(grad_input_val - np.diag(np.diag(grad_input_val)), 213 grad_vals[0]) 214 215 216 class MatrixDiagPartTest(test.TestCase): 217 218 @test_util.run_deprecated_v1 219 def testSquare(self): 220 with self.session(use_gpu=True): 221 v = np.array([1.0, 2.0, 3.0]) 222 mat = np.diag(v) 223 mat_diag = array_ops.matrix_diag_part(mat) 224 self.assertEqual((3,), mat_diag.get_shape()) 225 self.assertAllEqual(mat_diag.eval(), v) 226 227 @test_util.run_deprecated_v1 228 def testRectangular(self): 229 with self.session(use_gpu=True): 230 mat = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 231 mat_diag = array_ops.matrix_diag_part(mat) 232 self.assertAllEqual(mat_diag.eval(), np.array([1.0, 5.0])) 233 mat = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) 234 mat_diag = array_ops.matrix_diag_part(mat) 235 self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0])) 236 237 def _testSquareBatch(self, dtype): 238 with self.cached_session(use_gpu=True): 239 v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype) 240 mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]], 241 [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0], 242 [0.0, 0.0, 6.0]]]).astype(dtype) 243 self.assertEqual(mat_batch.shape, (2, 3, 3)) 244 mat_batch_diag = array_ops.matrix_diag_part(mat_batch) 245 self.assertEqual((2, 3), mat_batch_diag.get_shape()) 246 self.assertAllEqual(mat_batch_diag.eval(), v_batch) 247 248 @test_util.run_deprecated_v1 249 def testSquareBatch(self): 250 self._testSquareBatch(np.float32) 251 self._testSquareBatch(np.float64) 252 self._testSquareBatch(np.int32) 253 self._testSquareBatch(np.int64) 254 self._testSquareBatch(np.bool) 255 256 @test_util.run_deprecated_v1 257 def testRectangularBatch(self): 258 with self.session(use_gpu=True): 259 v_batch = np.array([[1.0, 2.0], [4.0, 5.0]]) 260 mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]], 261 [[4.0, 0.0, 0.0], [0.0, 5.0, 0.0]]]) 262 self.assertEqual(mat_batch.shape, (2, 2, 3)) 263 mat_batch_diag = array_ops.matrix_diag_part(mat_batch) 264 self.assertEqual((2, 2), mat_batch_diag.get_shape()) 265 self.assertAllEqual(mat_batch_diag.eval(), v_batch) 266 267 @test_util.run_deprecated_v1 268 def testInvalidShape(self): 269 with self.assertRaisesRegexp(ValueError, "must be at least rank 2"): 270 array_ops.matrix_diag_part(0) 271 272 @test_util.run_deprecated_v1 273 @test_util.disable_xla("b/123337890") # Error messages differ 274 def testInvalidShapeAtEval(self): 275 with self.session(use_gpu=True): 276 v = array_ops.placeholder(dtype=dtypes_lib.float32) 277 with self.assertRaisesOpError("input must be at least 2-dim"): 278 array_ops.matrix_diag_part(v).eval(feed_dict={v: 0.0}) 279 280 @test_util.run_deprecated_v1 281 def testGrad(self): 282 shapes = ((3, 3), (2, 3), (3, 2), (5, 3, 3)) 283 with self.session(use_gpu=True): 284 for shape in shapes: 285 x = constant_op.constant(np.random.rand(*shape), dtype=np.float32) 286 y = array_ops.matrix_diag_part(x) 287 error = gradient_checker.compute_gradient_error(x, 288 x.get_shape().as_list(), 289 y, 290 y.get_shape().as_list()) 291 self.assertLess(error, 1e-4) 292 293 294 class DiagTest(test.TestCase): 295 296 def _diagOp(self, diag, dtype, expected_ans, use_gpu): 297 with self.cached_session(use_gpu=use_gpu): 298 tf_ans = array_ops.diag(ops.convert_to_tensor(diag.astype(dtype))) 299 out = self.evaluate(tf_ans) 300 tf_ans_inv = array_ops.diag_part(expected_ans) 301 inv_out = self.evaluate(tf_ans_inv) 302 self.assertAllClose(out, expected_ans) 303 self.assertAllClose(inv_out, diag) 304 self.assertShapeEqual(expected_ans, tf_ans) 305 self.assertShapeEqual(diag, tf_ans_inv) 306 307 def diagOp(self, diag, dtype, expected_ans): 308 self._diagOp(diag, dtype, expected_ans, False) 309 self._diagOp(diag, dtype, expected_ans, True) 310 311 def testEmptyTensor(self): 312 x = np.array([]) 313 expected_ans = np.empty([0, 0]) 314 self.diagOp(x, np.int32, expected_ans) 315 316 def testRankOneIntTensor(self): 317 x = np.array([1, 2, 3]) 318 expected_ans = np.array([[1, 0, 0], [0, 2, 0], [0, 0, 3]]) 319 self.diagOp(x, np.int32, expected_ans) 320 self.diagOp(x, np.int64, expected_ans) 321 322 def testRankOneFloatTensor(self): 323 x = np.array([1.1, 2.2, 3.3]) 324 expected_ans = np.array([[1.1, 0, 0], [0, 2.2, 0], [0, 0, 3.3]]) 325 self.diagOp(x, np.float32, expected_ans) 326 self.diagOp(x, np.float64, expected_ans) 327 328 def testRankOneComplexTensor(self): 329 for dtype in [np.complex64, np.complex128]: 330 x = np.array([1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], dtype=dtype) 331 expected_ans = np.array( 332 [[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 2.2 + 2.2j, 0 + 0j], 333 [0 + 0j, 0 + 0j, 3.3 + 3.3j]], 334 dtype=dtype) 335 self.diagOp(x, dtype, expected_ans) 336 337 def testRankTwoIntTensor(self): 338 x = np.array([[1, 2, 3], [4, 5, 6]]) 339 expected_ans = np.array([[[[1, 0, 0], [0, 0, 0]], [[0, 2, 0], [0, 0, 0]], 340 [[0, 0, 3], [0, 0, 0]]], 341 [[[0, 0, 0], [4, 0, 0]], [[0, 0, 0], [0, 5, 0]], 342 [[0, 0, 0], [0, 0, 6]]]]) 343 self.diagOp(x, np.int32, expected_ans) 344 self.diagOp(x, np.int64, expected_ans) 345 346 def testRankTwoFloatTensor(self): 347 x = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]]) 348 expected_ans = np.array( 349 [[[[1.1, 0, 0], [0, 0, 0]], [[0, 2.2, 0], [0, 0, 0]], 350 [[0, 0, 3.3], [0, 0, 0]]], [[[0, 0, 0], [4.4, 0, 0]], 351 [[0, 0, 0], [0, 5.5, 0]], [[0, 0, 0], 352 [0, 0, 6.6]]]]) 353 self.diagOp(x, np.float32, expected_ans) 354 self.diagOp(x, np.float64, expected_ans) 355 356 def testRankTwoComplexTensor(self): 357 for dtype in [np.complex64, np.complex128]: 358 x = np.array( 359 [[1.1 + 1.1j, 2.2 + 2.2j, 3.3 + 3.3j], 360 [4.4 + 4.4j, 5.5 + 5.5j, 6.6 + 6.6j]], 361 dtype=dtype) 362 expected_ans = np.array( 363 [[[[1.1 + 1.1j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j]], [ 364 [0 + 0j, 2.2 + 2.2j, 0 + 0j], [0 + 0j, 0 + 0j, 0 + 0j] 365 ], [[0 + 0j, 0 + 0j, 3.3 + 3.3j], [0 + 0j, 0 + 0j, 0 + 0j]]], [[ 366 [0 + 0j, 0 + 0j, 0 + 0j], [4.4 + 4.4j, 0 + 0j, 0 + 0j] 367 ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 5.5 + 5.5j, 0 + 0j] 368 ], [[0 + 0j, 0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j, 6.6 + 6.6j]]]], 369 dtype=dtype) 370 self.diagOp(x, dtype, expected_ans) 371 372 def testRankThreeFloatTensor(self): 373 x = np.array([[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]]) 374 expected_ans = np.array([[[[[[1.1, 0], [0, 0]], [[0, 0], [0, 0]]], 375 [[[0, 2.2], [0, 0]], [[0, 0], [0, 0]]]], 376 [[[[0, 0], [3.3, 0]], [[0, 0], [0, 0]]], 377 [[[0, 0], [0, 4.4]], [[0, 0], [0, 0]]]]], 378 [[[[[0, 0], [0, 0]], [[5.5, 0], [0, 0]]], 379 [[[0, 0], [0, 0]], [[0, 6.6], [0, 0]]]], 380 [[[[0, 0], [0, 0]], [[0, 0], [7.7, 0]]], 381 [[[0, 0], [0, 0]], [[0, 0], [0, 8.8]]]]]]) 382 self.diagOp(x, np.float32, expected_ans) 383 self.diagOp(x, np.float64, expected_ans) 384 385 def testRankThreeComplexTensor(self): 386 for dtype in [np.complex64, np.complex128]: 387 x = np.array( 388 [[[1.1 + 1.1j, 2.2 + 2.2j], [3.3 + 3.3j, 4.4 + 4.4j]], 389 [[5.5 + 5.5j, 6.6 + 6.6j], [7.7 + 7.7j, 8.8 + 8.8j]]], 390 dtype=dtype) 391 expected_ans = np.array( 392 [[[[[[1.1 + 1.1j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [ 393 0 + 0j, 0 + 0j 394 ]]], [[[0 + 0j, 2.2 + 2.2j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [ 395 0 + 0j, 0 + 0j 396 ]]]], [[[[0 + 0j, 0 + 0j], [3.3 + 3.3j, 0 + 0j]], [[0 + 0j, 0 + 0j], [ 397 0 + 0j, 0 + 0j 398 ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 4.4 + 4.4j]], [[0 + 0j, 0 + 0j], [ 399 0 + 0j, 0 + 0j 400 ]]]]], [[[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [ 401 [5.5 + 5.5j, 0 + 0j], [0 + 0j, 0 + 0j] 402 ]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 6.6 + 6.6j], [ 403 0 + 0j, 0 + 0j 404 ]]]], [[[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], [[0 + 0j, 0 + 0j], [ 405 7.7 + 7.7j, 0 + 0j 406 ]]], [[[0 + 0j, 0 + 0j], [0 + 0j, 0 + 0j]], 407 [[0 + 0j, 0 + 0j], [0 + 0j, 8.8 + 8.8j]]]]]], 408 dtype=dtype) 409 self.diagOp(x, dtype, expected_ans) 410 411 def testRankFourNumberTensor(self): 412 for dtype in [np.float32, np.float64, np.int64, np.int32]: 413 # Input with shape [2, 1, 2, 3] 414 x = np.array( 415 [[[[1, 2, 3], [4, 5, 6]]], [[[7, 8, 9], [10, 11, 12]]]], dtype=dtype) 416 # Output with shape [2, 1, 2, 3, 2, 1, 2, 3] 417 expected_ans = np.array( 418 [[[[[[[[1, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [ 419 [[[0, 2, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]] 420 ], [[[[0, 0, 3], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]]]], [[ 421 [[[0, 0, 0], [4, 0, 0]]], [[[0, 0, 0], [0, 0, 0]]] 422 ], [[[[0, 0, 0], [0, 5, 0]]], [[[0, 0, 0], [0, 0, 0]]]], [ 423 [[[0, 0, 0], [0, 0, 6]]], [[[0, 0, 0], [0, 0, 0]]] 424 ]]]], [[[[[[[0, 0, 0], [0, 0, 0]]], [[[7, 0, 0], [0, 0, 0]]]], [ 425 [[[0, 0, 0], [0, 0, 0]]], [[[0, 8, 0], [0, 0, 0]]] 426 ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 9], [0, 0, 0]]]]], [[ 427 [[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [10, 0, 0]]] 428 ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 11, 0]]] 429 ], [[[[0, 0, 0], [0, 0, 0]]], [[[0, 0, 0], [0, 0, 12]]]]]]]], 430 dtype=dtype) 431 self.diagOp(x, dtype, expected_ans) 432 433 @test_util.run_deprecated_v1 434 def testInvalidRank(self): 435 with self.assertRaisesRegexp(ValueError, "must be at least rank 1"): 436 array_ops.diag(0.0) 437 438 439 class DiagPartOpTest(test.TestCase): 440 441 def setUp(self): 442 np.random.seed(0) 443 444 def _diagPartOp(self, tensor, dtype, expected_ans, use_gpu): 445 with self.cached_session(use_gpu=use_gpu): 446 tensor = ops.convert_to_tensor(tensor.astype(dtype)) 447 tf_ans_inv = array_ops.diag_part(tensor) 448 inv_out = self.evaluate(tf_ans_inv) 449 self.assertAllClose(inv_out, expected_ans) 450 self.assertShapeEqual(expected_ans, tf_ans_inv) 451 452 def diagPartOp(self, tensor, dtype, expected_ans): 453 self._diagPartOp(tensor, dtype, expected_ans, False) 454 self._diagPartOp(tensor, dtype, expected_ans, True) 455 456 def testRankTwoFloatTensor(self): 457 x = np.random.rand(3, 3) 458 i = np.arange(3) 459 expected_ans = x[i, i] 460 self.diagPartOp(x, np.float32, expected_ans) 461 self.diagPartOp(x, np.float64, expected_ans) 462 463 def testRankFourFloatTensorUnknownShape(self): 464 x = np.random.rand(3, 3) 465 i = np.arange(3) 466 expected_ans = x[i, i] 467 for shape in None, (None, 3), (3, None): 468 with self.cached_session(use_gpu=False): 469 t = ops.convert_to_tensor(x.astype(np.float32)) 470 t.set_shape(shape) 471 tf_ans = array_ops.diag_part(t) 472 out = self.evaluate(tf_ans) 473 self.assertAllClose(out, expected_ans) 474 self.assertShapeEqual(expected_ans, tf_ans) 475 476 def testRankFourFloatTensor(self): 477 x = np.random.rand(2, 3, 2, 3) 478 i = np.arange(2)[:, None] 479 j = np.arange(3) 480 expected_ans = x[i, j, i, j] 481 self.diagPartOp(x, np.float32, expected_ans) 482 self.diagPartOp(x, np.float64, expected_ans) 483 484 def testRankSixFloatTensor(self): 485 x = np.random.rand(2, 2, 2, 2, 2, 2) 486 i = np.arange(2)[:, None, None] 487 j = np.arange(2)[:, None] 488 k = np.arange(2) 489 expected_ans = x[i, j, k, i, j, k] 490 self.diagPartOp(x, np.float32, expected_ans) 491 self.diagPartOp(x, np.float64, expected_ans) 492 493 def testRankEightComplexTensor(self): 494 x = np.random.rand(2, 2, 2, 3, 2, 2, 2, 3) 495 i = np.arange(2)[:, None, None, None] 496 j = np.arange(2)[:, None, None] 497 k = np.arange(2)[:, None] 498 l = np.arange(3) 499 expected_ans = x[i, j, k, l, i, j, k, l] 500 self.diagPartOp(x, np.complex64, expected_ans) 501 self.diagPartOp(x, np.complex128, expected_ans) 502 503 @test_util.run_deprecated_v1 504 def testOddRank(self): 505 w = np.random.rand(2) 506 x = np.random.rand(2, 2, 2) 507 self.assertRaises(ValueError, self.diagPartOp, w, np.float32, 0) 508 self.assertRaises(ValueError, self.diagPartOp, x, np.float32, 0) 509 with self.assertRaises(ValueError): 510 array_ops.diag_part(0.0) 511 512 @test_util.run_deprecated_v1 513 def testUnevenDimensions(self): 514 w = np.random.rand(2, 5) 515 x = np.random.rand(2, 1, 2, 3) 516 self.assertRaises(ValueError, self.diagPartOp, w, np.float32, 0) 517 self.assertRaises(ValueError, self.diagPartOp, x, np.float32, 0) 518 519 520 class DiagGradOpTest(test.TestCase): 521 522 @test_util.run_deprecated_v1 523 def testDiagGrad(self): 524 np.random.seed(0) 525 shapes = ((3,), (3, 3), (3, 3, 3)) 526 dtypes = (dtypes_lib.float32, dtypes_lib.float64) 527 with self.session(use_gpu=False): 528 errors = [] 529 for shape in shapes: 530 for dtype in dtypes: 531 x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype) 532 y = array_ops.diag(x1) 533 error = gradient_checker.compute_gradient_error( 534 x1, 535 x1.get_shape().as_list(), y, 536 y.get_shape().as_list()) 537 tf_logging.info("error = %f", error) 538 self.assertLess(error, 1e-4) 539 540 541 class DiagGradPartOpTest(test.TestCase): 542 543 @test_util.run_deprecated_v1 544 def testDiagPartGrad(self): 545 np.random.seed(0) 546 shapes = ((3, 3), (3, 3, 3, 3)) 547 dtypes = (dtypes_lib.float32, dtypes_lib.float64) 548 with self.session(use_gpu=False): 549 errors = [] 550 for shape in shapes: 551 for dtype in dtypes: 552 x1 = constant_op.constant(np.random.rand(*shape), dtype=dtype) 553 y = array_ops.diag_part(x1) 554 error = gradient_checker.compute_gradient_error( 555 x1, 556 x1.get_shape().as_list(), y, 557 y.get_shape().as_list()) 558 tf_logging.info("error = %f", error) 559 self.assertLess(error, 1e-4) 560 561 562 if __name__ == "__main__": 563 test.main() 564