1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.python.framework import dtypes 23 from tensorflow.python.framework import test_util 24 from tensorflow.python.ops import array_ops 25 from tensorflow.python.ops.linalg import linalg as linalg_lib 26 from tensorflow.python.ops.linalg import linear_operator_test_util 27 from tensorflow.python.platform import test 28 29 30 rng = np.random.RandomState(2016) 31 32 33 class LinearOperatorZerosTest( 34 linear_operator_test_util.SquareLinearOperatorDerivedClassTest): 35 """Most tests done in the base class LinearOperatorDerivedClassTest.""" 36 37 @property 38 def _tests_to_skip(self): 39 return [ 40 "cholesky", "log_abs_det", "inverse", "solve", "solve_with_broadcast"] 41 42 @property 43 def _operator_build_infos(self): 44 build_info = linear_operator_test_util.OperatorBuildInfo 45 return [ 46 build_info((1, 1)), 47 build_info((1, 3, 3)), 48 build_info((3, 4, 4)), 49 build_info((2, 1, 4, 4))] 50 51 def _operator_and_matrix( 52 self, build_info, dtype, use_placeholder, 53 ensure_self_adjoint_and_pd=False): 54 del ensure_self_adjoint_and_pd 55 del use_placeholder 56 shape = list(build_info.shape) 57 assert shape[-1] == shape[-2] 58 59 batch_shape = shape[:-2] 60 num_rows = shape[-1] 61 62 operator = linalg_lib.LinearOperatorZeros( 63 num_rows, batch_shape=batch_shape, dtype=dtype) 64 matrix = array_ops.zeros(shape=shape, dtype=dtype) 65 66 return operator, matrix 67 68 def test_assert_positive_definite(self): 69 operator = linalg_lib.LinearOperatorZeros(num_rows=2) 70 with self.assertRaisesOpError("non-positive definite"): 71 operator.assert_positive_definite() 72 73 def test_assert_non_singular(self): 74 with self.assertRaisesOpError("non-invertible"): 75 operator = linalg_lib.LinearOperatorZeros(num_rows=2) 76 operator.assert_non_singular() 77 78 @test_util.run_deprecated_v1 79 def test_assert_self_adjoint(self): 80 with self.cached_session(): 81 operator = linalg_lib.LinearOperatorZeros(num_rows=2) 82 operator.assert_self_adjoint().run() # Should not fail 83 84 def test_non_scalar_num_rows_raises_static(self): 85 with self.assertRaisesRegexp(ValueError, "must be a 0-D Tensor"): 86 linalg_lib.LinearOperatorZeros(num_rows=[2]) 87 with self.assertRaisesRegexp(ValueError, "must be a 0-D Tensor"): 88 linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=[2]) 89 90 def test_non_integer_num_rows_raises_static(self): 91 with self.assertRaisesRegexp(TypeError, "must be integer"): 92 linalg_lib.LinearOperatorZeros(num_rows=2.) 93 with self.assertRaisesRegexp(TypeError, "must be integer"): 94 linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=2.) 95 96 def test_negative_num_rows_raises_static(self): 97 with self.assertRaisesRegexp(ValueError, "must be non-negative"): 98 linalg_lib.LinearOperatorZeros(num_rows=-2) 99 with self.assertRaisesRegexp(ValueError, "must be non-negative"): 100 linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=-2) 101 102 def test_non_1d_batch_shape_raises_static(self): 103 with self.assertRaisesRegexp(ValueError, "must be a 1-D"): 104 linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=2) 105 106 def test_non_integer_batch_shape_raises_static(self): 107 with self.assertRaisesRegexp(TypeError, "must be integer"): 108 linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[2.]) 109 110 def test_negative_batch_shape_raises_static(self): 111 with self.assertRaisesRegexp(ValueError, "must be non-negative"): 112 linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[-2]) 113 114 @test_util.run_deprecated_v1 115 def test_non_scalar_num_rows_raises_dynamic(self): 116 with self.cached_session(): 117 num_rows = array_ops.placeholder(dtypes.int32) 118 operator = linalg_lib.LinearOperatorZeros( 119 num_rows, assert_proper_shapes=True) 120 with self.assertRaisesOpError("must be a 0-D Tensor"): 121 operator.to_dense().eval(feed_dict={num_rows: [2]}) 122 123 @test_util.run_deprecated_v1 124 def test_negative_num_rows_raises_dynamic(self): 125 with self.cached_session(): 126 n = array_ops.placeholder(dtypes.int32) 127 operator = linalg_lib.LinearOperatorZeros( 128 num_rows=n, assert_proper_shapes=True) 129 with self.assertRaisesOpError("must be non-negative"): 130 operator.to_dense().eval(feed_dict={n: -2}) 131 132 operator = linalg_lib.LinearOperatorZeros( 133 num_rows=2, num_columns=n, assert_proper_shapes=True) 134 with self.assertRaisesOpError("must be non-negative"): 135 operator.to_dense().eval(feed_dict={n: -2}) 136 137 @test_util.run_deprecated_v1 138 def test_non_1d_batch_shape_raises_dynamic(self): 139 with self.cached_session(): 140 batch_shape = array_ops.placeholder(dtypes.int32) 141 operator = linalg_lib.LinearOperatorZeros( 142 num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True) 143 with self.assertRaisesOpError("must be a 1-D"): 144 operator.to_dense().eval(feed_dict={batch_shape: 2}) 145 146 @test_util.run_deprecated_v1 147 def test_negative_batch_shape_raises_dynamic(self): 148 with self.cached_session(): 149 batch_shape = array_ops.placeholder(dtypes.int32) 150 operator = linalg_lib.LinearOperatorZeros( 151 num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True) 152 with self.assertRaisesOpError("must be non-negative"): 153 operator.to_dense().eval(feed_dict={batch_shape: [-2]}) 154 155 def test_wrong_matrix_dimensions_raises_static(self): 156 operator = linalg_lib.LinearOperatorZeros(num_rows=2) 157 x = rng.randn(3, 3).astype(np.float32) 158 with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"): 159 operator.matmul(x) 160 161 @test_util.run_deprecated_v1 162 def test_wrong_matrix_dimensions_raises_dynamic(self): 163 num_rows = array_ops.placeholder(dtypes.int32) 164 x = array_ops.placeholder(dtypes.float32) 165 166 with self.cached_session(): 167 operator = linalg_lib.LinearOperatorZeros( 168 num_rows, assert_proper_shapes=True) 169 y = operator.matmul(x) 170 with self.assertRaisesOpError("Incompatible.*dimensions"): 171 y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)}) 172 173 def test_is_x_flags(self): 174 # The is_x flags are by default all True. 175 operator = linalg_lib.LinearOperatorZeros(num_rows=2) 176 self.assertFalse(operator.is_positive_definite) 177 self.assertFalse(operator.is_non_singular) 178 self.assertTrue(operator.is_self_adjoint) 179 180 def test_zeros_matmul(self): 181 operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2) 182 operator2 = linalg_lib.LinearOperatorZeros(num_rows=2) 183 self.assertTrue(isinstance( 184 operator1.matmul(operator2), 185 linalg_lib.LinearOperatorZeros)) 186 187 self.assertTrue(isinstance( 188 operator2.matmul(operator1), 189 linalg_lib.LinearOperatorZeros)) 190 191 192 class LinearOperatorZerosNotSquareTest( 193 linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest): 194 195 def _operator_and_matrix(self, build_info, dtype, use_placeholder): 196 del use_placeholder 197 shape = list(build_info.shape) 198 199 batch_shape = shape[:-2] 200 num_rows = shape[-2] 201 num_columns = shape[-1] 202 203 operator = linalg_lib.LinearOperatorZeros( 204 num_rows, num_columns, is_square=False, is_self_adjoint=False, 205 batch_shape=batch_shape, dtype=dtype) 206 matrix = array_ops.zeros(shape=shape, dtype=dtype) 207 208 return operator, matrix 209 210 211 if __name__ == "__main__": 212 test.main() 213