1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.ops.math_ops.matrix_inverse.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import constant_op 24 from tensorflow.python.ops import array_ops 25 from tensorflow.python.ops import gradient_checker 26 from tensorflow.python.ops import linalg_ops 27 from tensorflow.python.ops import math_ops 28 from tensorflow.python.ops import random_ops 29 from tensorflow.python.platform import test 30 31 32 def _AddTest(test_class, op_name, testcase_name, fn): 33 test_name = "_".join(["test", op_name, testcase_name]) 34 if hasattr(test_class, test_name): 35 raise RuntimeError("Test %s defined more than once" % test_name) 36 setattr(test_class, test_name, fn) 37 38 39 class QrOpTest(test.TestCase): 40 41 def testWrongDimensions(self): 42 # The input to qr should be a tensor of at least rank 2. 43 scalar = constant_op.constant(1.) 44 with self.assertRaisesRegexp(ValueError, 45 "Shape must be at least rank 2 but is rank 0"): 46 linalg_ops.qr(scalar) 47 vector = constant_op.constant([1., 2.]) 48 with self.assertRaisesRegexp(ValueError, 49 "Shape must be at least rank 2 but is rank 1"): 50 linalg_ops.qr(vector) 51 52 def testConcurrentExecutesWithoutError(self): 53 with self.test_session(use_gpu=True) as sess: 54 all_ops = [] 55 for full_matrices_ in True, False: 56 for rows_ in 4, 5: 57 for cols_ in 4, 5: 58 matrix1 = random_ops.random_normal([rows_, cols_], seed=42) 59 matrix2 = random_ops.random_normal([rows_, cols_], seed=42) 60 q1, r1 = linalg_ops.qr(matrix1, full_matrices=full_matrices_) 61 q2, r2 = linalg_ops.qr(matrix2, full_matrices=full_matrices_) 62 all_ops += [q1, r1, q2, r2] 63 val = sess.run(all_ops) 64 for i in range(8): 65 q = 4 * i 66 self.assertAllEqual(val[q], val[q + 2]) # q1 == q2 67 self.assertAllEqual(val[q + 1], val[q + 3]) # r1 == r2 68 69 70 def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_): 71 72 is_complex = dtype_ in (np.complex64, np.complex128) 73 is_single = dtype_ in (np.float32, np.complex64) 74 75 def CompareOrthogonal(self, x, y, rank): 76 if is_single: 77 atol = 5e-4 78 else: 79 atol = 5e-14 80 # We only compare the first 'rank' orthogonal vectors since the 81 # remainder form an arbitrary orthonormal basis for the 82 # (row- or column-) null space, whose exact value depends on 83 # implementation details. Notice that since we check that the 84 # matrices of singular vectors are unitary elsewhere, we do 85 # implicitly test that the trailing vectors of x and y span the 86 # same space. 87 x = x[..., 0:rank] 88 y = y[..., 0:rank] 89 # Q is only unique up to sign (complex phase factor for complex matrices), 90 # so we normalize the sign first. 91 sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True) 92 phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios)) 93 x *= phases 94 self.assertAllClose(x, y, atol=atol) 95 96 def CheckApproximation(self, a, q, r): 97 if is_single: 98 tol = 1e-5 99 else: 100 tol = 1e-14 101 # Tests that a ~= q*r. 102 a_recon = math_ops.matmul(q, r) 103 self.assertAllClose(a_recon.eval(), a, rtol=tol, atol=tol) 104 105 def CheckUnitary(self, x): 106 # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity. 107 xx = math_ops.matmul(x, x, adjoint_a=True) 108 identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0) 109 if is_single: 110 tol = 1e-5 111 else: 112 tol = 1e-14 113 self.assertAllClose(identity.eval(), xx.eval(), atol=tol) 114 115 def Test(self): 116 np.random.seed(1) 117 x_np = np.random.uniform( 118 low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_) 119 if is_complex: 120 x_np += 1j * np.random.uniform( 121 low=-1.0, high=1.0, 122 size=np.prod(shape_)).reshape(shape_).astype(dtype_) 123 124 with self.test_session(use_gpu=True) as sess: 125 if use_static_shape_: 126 x_tf = constant_op.constant(x_np) 127 else: 128 x_tf = array_ops.placeholder(dtype_) 129 q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices_) 130 131 if use_static_shape_: 132 q_tf_val, r_tf_val = sess.run([q_tf, r_tf]) 133 else: 134 q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np}) 135 136 q_dims = q_tf_val.shape 137 np_q = np.ndarray(q_dims, dtype_) 138 np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1])) 139 new_first_dim = np_q_reshape.shape[0] 140 141 x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1])) 142 for i in range(new_first_dim): 143 if full_matrices_: 144 np_q_reshape[i, :, :], _ = np.linalg.qr( 145 x_reshape[i, :, :], mode="complete") 146 else: 147 np_q_reshape[i, :, :], _ = np.linalg.qr( 148 x_reshape[i, :, :], mode="reduced") 149 np_q = np.reshape(np_q_reshape, q_dims) 150 CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:])) 151 CheckApproximation(self, x_np, q_tf_val, r_tf_val) 152 CheckUnitary(self, q_tf_val) 153 154 return Test 155 156 157 class QrGradOpTest(test.TestCase): 158 pass 159 160 161 def _GetQrGradOpTest(dtype_, shape_, full_matrices_): 162 163 def Test(self): 164 np.random.seed(42) 165 a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_) 166 if dtype_ in [np.complex64, np.complex128]: 167 a += 1j * np.random.uniform( 168 low=-1.0, high=1.0, size=shape_).astype(dtype_) 169 # Optimal stepsize for central difference is O(epsilon^{1/3}). 170 epsilon = np.finfo(dtype_).eps 171 delta = 0.1 * epsilon**(1.0 / 3.0) 172 if dtype_ in [np.float32, np.complex64]: 173 tol = 3e-2 174 else: 175 tol = 1e-6 176 with self.test_session(use_gpu=True): 177 tf_a = constant_op.constant(a) 178 tf_b = linalg_ops.qr(tf_a, full_matrices=full_matrices_) 179 for b in tf_b: 180 x_init = np.random.uniform( 181 low=-1.0, high=1.0, size=shape_).astype(dtype_) 182 if dtype_ in [np.complex64, np.complex128]: 183 x_init += 1j * np.random.uniform( 184 low=-1.0, high=1.0, size=shape_).astype(dtype_) 185 theoretical, numerical = gradient_checker.compute_gradient( 186 tf_a, 187 tf_a.get_shape().as_list(), 188 b, 189 b.get_shape().as_list(), 190 x_init_value=x_init, 191 delta=delta) 192 self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol) 193 194 return Test 195 196 197 if __name__ == "__main__": 198 for dtype in np.float32, np.float64, np.complex64, np.complex128: 199 for rows in 1, 2, 5, 10, 32, 100: 200 for cols in 1, 2, 5, 10, 32, 100: 201 for full_matrices in False, True: 202 for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): 203 for use_static_shape in True, False: 204 shape = batch_dims + (rows, cols) 205 name = "%s_%s_full_%s_static_%s" % (dtype.__name__, 206 "_".join(map(str, shape)), 207 full_matrices, 208 use_static_shape) 209 _AddTest(QrOpTest, "Qr", name, 210 _GetQrOpTest(dtype, shape, full_matrices, 211 use_static_shape)) 212 213 # TODO(pfau): Get working with complex types. 214 # TODO(pfau): Get working with full_matrices when rows != cols 215 # TODO(pfau): Get working when rows < cols 216 # TODO(pfau): Get working with shapeholders (dynamic shapes) 217 for full_matrices in False, True: 218 for dtype in np.float32, np.float64: 219 for rows in 1, 2, 5, 10: 220 for cols in 1, 2, 5, 10: 221 if rows == cols or (not full_matrices and rows > cols): 222 for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10): 223 shape = batch_dims + (rows, cols) 224 name = "%s_%s_full_%s" % (dtype.__name__, 225 "_".join(map(str, shape)), 226 full_matrices) 227 _AddTest(QrGradOpTest, "QrGrad", name, 228 _GetQrGradOpTest(dtype, shape, full_matrices)) 229 test.main() 230