1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.python.ops.linalg_ops.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import dtypes 24 from tensorflow.python.framework import ops 25 from tensorflow.python.ops import array_ops 26 from tensorflow.python.ops import linalg_ops 27 from tensorflow.python.ops import math_ops 28 from tensorflow.python.ops.linalg import linalg 29 from tensorflow.python.platform import test 30 31 32 def _AddTest(test_class, op_name, testcase_name, fn): 33 test_name = "_".join(["test", op_name, testcase_name]) 34 if hasattr(test_class, test_name): 35 raise RuntimeError("Test %s defined more than once" % test_name) 36 setattr(test_class, test_name, fn) 37 38 39 def _RandomPDMatrix(n, rng, dtype=np.float64): 40 """Random positive definite matrix.""" 41 temp = rng.randn(n, n).astype(dtype) 42 if dtype in [np.complex64, np.complex128]: 43 temp.imag = rng.randn(n, n) 44 return np.conj(temp).dot(temp.T) 45 46 47 class CholeskySolveTest(test.TestCase): 48 49 def setUp(self): 50 self.rng = np.random.RandomState(0) 51 52 def test_works_with_five_different_random_pos_def_matrices(self): 53 for n in range(1, 6): 54 for np_type, atol in [(np.float32, 0.05), (np.float64, 1e-5)]: 55 with self.test_session(use_gpu=True): 56 # Create 2 x n x n matrix 57 array = np.array( 58 [_RandomPDMatrix(n, self.rng), 59 _RandomPDMatrix(n, self.rng)]).astype(np_type) 60 chol = linalg_ops.cholesky(array) 61 for k in range(1, 3): 62 rhs = self.rng.randn(2, n, k).astype(np_type) 63 x = linalg_ops.cholesky_solve(chol, rhs) 64 self.assertAllClose( 65 rhs, math_ops.matmul(array, x).eval(), atol=atol) 66 67 68 class LogdetTest(test.TestCase): 69 70 def setUp(self): 71 self.rng = np.random.RandomState(42) 72 73 def test_works_with_five_different_random_pos_def_matrices(self): 74 for n in range(1, 6): 75 for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5), 76 (np.complex64, 0.05), (np.complex128, 1e-5)]: 77 matrix = _RandomPDMatrix(n, self.rng, np_dtype) 78 _, logdet_np = np.linalg.slogdet(matrix) 79 with self.test_session(use_gpu=True): 80 # Create 2 x n x n matrix 81 # matrix = np.array( 82 # [_RandomPDMatrix(n, self.rng, np_dtype), 83 # _RandomPDMatrix(n, self.rng, np_dtype)]).astype(np_dtype) 84 logdet_tf = linalg.logdet(matrix) 85 self.assertAllClose(logdet_np, logdet_tf.eval(), atol=atol) 86 87 def test_works_with_underflow_case(self): 88 for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5), 89 (np.complex64, 0.05), (np.complex128, 1e-5)]: 90 matrix = (np.eye(20) * 1e-6).astype(np_dtype) 91 _, logdet_np = np.linalg.slogdet(matrix) 92 with self.test_session(use_gpu=True): 93 logdet_tf = linalg.logdet(matrix) 94 self.assertAllClose(logdet_np, logdet_tf.eval(), atol=atol) 95 96 97 class SlogdetTest(test.TestCase): 98 99 def setUp(self): 100 self.rng = np.random.RandomState(42) 101 102 def test_works_with_five_different_random_pos_def_matrices(self): 103 for n in range(1, 6): 104 for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5), 105 (np.complex64, 0.05), (np.complex128, 1e-5)]: 106 matrix = _RandomPDMatrix(n, self.rng, np_dtype) 107 sign_np, log_abs_det_np = np.linalg.slogdet(matrix) 108 with self.test_session(use_gpu=True): 109 sign_tf, log_abs_det_tf = linalg.slogdet(matrix) 110 self.assertAllClose(log_abs_det_np, log_abs_det_tf.eval(), atol=atol) 111 self.assertAllClose(sign_np, sign_tf.eval(), atol=atol) 112 113 def test_works_with_underflow_case(self): 114 for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5), 115 (np.complex64, 0.05), (np.complex128, 1e-5)]: 116 matrix = (np.eye(20) * 1e-6).astype(np_dtype) 117 sign_np, log_abs_det_np = np.linalg.slogdet(matrix) 118 with self.test_session(use_gpu=True): 119 sign_tf, log_abs_det_tf = linalg.slogdet(matrix) 120 self.assertAllClose(log_abs_det_np, log_abs_det_tf.eval(), atol=atol) 121 self.assertAllClose(sign_np, sign_tf.eval(), atol=atol) 122 123 124 class AdjointTest(test.TestCase): 125 126 def test_compare_to_numpy(self): 127 for dtype in np.float64, np.float64, np.complex64, np.complex128: 128 matrix_np = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j, 129 6 + 6j]]).astype(dtype) 130 expected_transposed = np.conj(matrix_np.T) 131 with self.test_session(): 132 matrix = ops.convert_to_tensor(matrix_np) 133 transposed = linalg.adjoint(matrix) 134 self.assertEqual((3, 2), transposed.get_shape()) 135 self.assertAllEqual(expected_transposed, transposed.eval()) 136 137 138 class EyeTest(test.TestCase): 139 pass # Will be filled in below 140 141 142 def _GetEyeTest(num_rows, num_columns, batch_shape, dtype): 143 144 def Test(self): 145 eye_np = np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype) 146 if batch_shape is not None: 147 eye_np = np.tile(eye_np, batch_shape + [1, 1]) 148 for use_placeholder in False, True: 149 if use_placeholder and (num_columns is None or batch_shape is None): 150 return 151 with self.test_session(use_gpu=True) as sess: 152 if use_placeholder: 153 num_rows_placeholder = array_ops.placeholder( 154 dtypes.int32, name="num_rows") 155 num_columns_placeholder = array_ops.placeholder( 156 dtypes.int32, name="num_columns") 157 batch_shape_placeholder = array_ops.placeholder( 158 dtypes.int32, name="batch_shape") 159 eye = linalg_ops.eye( 160 num_rows_placeholder, 161 num_columns=num_columns_placeholder, 162 batch_shape=batch_shape_placeholder, 163 dtype=dtype) 164 eye_tf = sess.run( 165 eye, 166 feed_dict={ 167 num_rows_placeholder: num_rows, 168 num_columns_placeholder: num_columns, 169 batch_shape_placeholder: batch_shape 170 }) 171 else: 172 eye_tf = linalg_ops.eye( 173 num_rows, 174 num_columns=num_columns, 175 batch_shape=batch_shape, 176 dtype=dtype).eval() 177 self.assertAllEqual(eye_np, eye_tf) 178 179 return Test 180 181 182 if __name__ == "__main__": 183 for _num_rows in 0, 1, 2, 5: 184 for _num_columns in None, 0, 1, 2, 5: 185 for _batch_shape in None, [], [2], [2, 3]: 186 for _dtype in (dtypes.int32, dtypes.int64, dtypes.float32, 187 dtypes.float64, dtypes.complex64, dtypes.complex128): 188 name = "dtype_%s_num_rows_%s_num_column_%s_batch_shape_%s_" % ( 189 _dtype.name, _num_rows, _num_columns, _batch_shape) 190 _AddTest(EyeTest, "EyeTest", name, 191 _GetEyeTest(_num_rows, _num_columns, _batch_shape, _dtype)) 192 193 test.main() 194