# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test

linalg = linalg_lib
random_seed.set_random_seed(23)
rng = np.random.RandomState(0)


class BaseLinearOperatorLowRankUpdatetest(object):
  """Base test for this type of operator."""

  # Subclasses should set these attributes to either True or False.

  # If True, A = L + UDV^H or A = L + UDU^H, depending on _use_v.
  # If False, A = L + UV^H or A = L + UU^H, depending on _use_v.
  _use_diag_update = None

  # If True, diag is > 0, which means D is symmetric positive definite.
  _is_diag_update_positive = None

  # If True, A = L + UDV^H or A = L + UV^H, depending on _use_diag_update.
  # If False, A = L + UDU^H or A = L + UU^H, depending on _use_diag_update.
  _use_v = None

  @property
  def _dtypes_to_test(self):
    # TODO(langmore) Test complex types once cholesky works with them.
    # See comment in LinearOperatorLowRankUpdate.__init__.
    return [dtypes.float32, dtypes.float64]

  @property
  def _shapes_to_test(self):
    # Previously we had a (2, 10, 10) shape at the end.  We did this to test
    # the inversion and determinant lemmas on not-tiny matrices, since these
    # are known to have stability issues.  This resulted in test timeouts, so
    # this shape has been removed, but rest assured, the tests did pass.
    return [(0, 0), (1, 1), (1, 3, 3), (3, 4, 4), (2, 1, 4, 4)]

  def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
    # Recall A = L + UDV^H
    shape = list(shape)
    diag_shape = shape[:-1]
    k = shape[-2] // 2 + 1
    u_perturbation_shape = shape[:-1] + [k]
    diag_update_shape = shape[:-2] + [k]

    # base_operator L will be a symmetric positive definite diagonal linear
    # operator, with condition number as high as 1e4.
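    # For reference, the "inversion and determinant lemmas" mentioned in
    # _shapes_to_test above are the Woodbury identity,
    #   (L + UDV^H)^{-1}
    #       = L^{-1} - L^{-1} U (D^{-1} + V^H L^{-1} U)^{-1} V^H L^{-1},
    # and the matrix determinant lemma,
    #   det(L + UDV^H) = det(D^{-1} + V^H L^{-1} U) det(D) det(L).
    # Both can lose accuracy when L is ill-conditioned, which is why the test
    # classes below loosen their tolerances.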
    base_diag = linear_operator_test_util.random_uniform(
        diag_shape, minval=1e-4, maxval=1., dtype=dtype)
    base_diag_ph = array_ops.placeholder(dtype=dtype)

    # U
    u = linear_operator_test_util.random_normal_correlated_columns(
        u_perturbation_shape, dtype=dtype)
    u_ph = array_ops.placeholder(dtype=dtype)

    # V
    v = linear_operator_test_util.random_normal_correlated_columns(
        u_perturbation_shape, dtype=dtype)
    v_ph = array_ops.placeholder(dtype=dtype)

    # D
    if self._is_diag_update_positive:
      diag_update = linear_operator_test_util.random_uniform(
          diag_update_shape, minval=1e-4, maxval=1., dtype=dtype)
    else:
      diag_update = linear_operator_test_util.random_normal(
          diag_update_shape, stddev=1e-4, dtype=dtype)
    diag_update_ph = array_ops.placeholder(dtype=dtype)

    if use_placeholder:
      # Evaluate here because (i) you cannot feed a tensor, and (ii)
      # values are random and we want the same value used for both mat and
      # feed_dict.
      base_diag = base_diag.eval()
      u = u.eval()
      v = v.eval()
      diag_update = diag_update.eval()

      # In all cases, set base_operator to be positive definite.
      base_operator = linalg.LinearOperatorDiag(
          base_diag_ph, is_positive_definite=True)

      operator = linalg.LinearOperatorLowRankUpdate(
          base_operator,
          u=u_ph,
          v=v_ph if self._use_v else None,
          diag_update=diag_update_ph if self._use_diag_update else None,
          is_diag_update_positive=self._is_diag_update_positive)
      feed_dict = {
          base_diag_ph: base_diag,
          u_ph: u,
          v_ph: v,
          diag_update_ph: diag_update}
    else:
      base_operator = linalg.LinearOperatorDiag(
          base_diag, is_positive_definite=True)
      operator = linalg.LinearOperatorLowRankUpdate(
          base_operator,
          u,
          v=v if self._use_v else None,
          diag_update=diag_update if self._use_diag_update else None,
          is_diag_update_positive=self._is_diag_update_positive)
      feed_dict = None

    # The matrix representing L
    base_diag_mat = array_ops.matrix_diag(base_diag)

    # The matrix representing D
    diag_update_mat = array_ops.matrix_diag(diag_update)

    # Set up mat as some variant of A = L + UDV^H
    if self._use_v and self._use_diag_update:
      # In this case, we have L + UDV^H and it isn't symmetric.
      expect_use_cholesky = False
      mat = base_diag_mat + math_ops.matmul(
          u, math_ops.matmul(diag_update_mat, v, adjoint_b=True))
    elif self._use_v:
      # In this case, we have L + UV^H and it isn't symmetric.
      expect_use_cholesky = False
      mat = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True)
    elif self._use_diag_update:
      # In this case, we have L + UDU^H, which is PD if D > 0, since L > 0.
      expect_use_cholesky = self._is_diag_update_positive
      mat = base_diag_mat + math_ops.matmul(
          u, math_ops.matmul(diag_update_mat, u, adjoint_b=True))
    else:
      # In this case, we have L + UU^H, which is PD since L > 0.
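      # (For any x != 0, x^H (L + UU^H) x = x^H L x + ||U^H x||^2 > 0.)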
      expect_use_cholesky = True
      mat = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True)

    if expect_use_cholesky:
      self.assertTrue(operator._use_cholesky)
    else:
      self.assertFalse(operator._use_cholesky)

    return operator, mat, feed_dict


class LinearOperatorLowRankUpdatetestWithDiagUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UDU^H, D > 0, L > 0 ==> A > 0 and we can use a Cholesky."""

  _use_diag_update = True
  _is_diag_update_positive = True
  _use_v = False

  def setUp(self):
    # Decrease tolerance since we are testing with condition numbers as high as
    # 1e4.
    self._atol[dtypes.float32] = 1e-5
    self._rtol[dtypes.float32] = 1e-5
    self._atol[dtypes.float64] = 1e-10
    self._rtol[dtypes.float64] = 1e-10


class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UDU^H, D !> 0, L > 0 ==> A !> 0 and we cannot use a Cholesky."""

  _use_diag_update = True
  _is_diag_update_positive = False
  _use_v = False

  def setUp(self):
    # Decrease tolerance since we are testing with condition numbers as high as
    # 1e4. This class does not use Cholesky, and thus needs even looser
    # tolerance.
    self._atol[dtypes.float32] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._atol[dtypes.float64] = 1e-9
    self._rtol[dtypes.float64] = 1e-9


class LinearOperatorLowRankUpdatetestNoDiagUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UU^H, L > 0 ==> A > 0 and we can use a Cholesky."""

  _use_diag_update = False
  _is_diag_update_positive = None
  _use_v = False

  def setUp(self):
    # Decrease tolerance since we are testing with condition numbers as high as
    # 1e4.
    self._atol[dtypes.float32] = 1e-5
    self._rtol[dtypes.float32] = 1e-5
    self._atol[dtypes.float64] = 1e-10
    self._rtol[dtypes.float64] = 1e-10


class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
  """A = L + UV^H, L > 0 ==> A is not symmetric and we cannot use a Cholesky."""

  _use_diag_update = False
  _is_diag_update_positive = None
  _use_v = True

  def setUp(self):
    # Decrease tolerance since we are testing with condition numbers as high as
    # 1e4. This class does not use Cholesky, and thus needs even looser
    # tolerance.
    self._atol[dtypes.float32] = 1e-4
    self._rtol[dtypes.float32] = 1e-4
    self._atol[dtypes.float64] = 1e-9
    self._rtol[dtypes.float64] = 1e-9


class LinearOperatorLowRankUpdatetestWithDiagNotSquare(
    BaseLinearOperatorLowRankUpdatetest,
    linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
  """A = L + UDV^H, D > 0, L > 0, but U != V ==> no Cholesky."""

  _use_diag_update = True
  _is_diag_update_positive = True
  _use_v = True


class LinearOperatorLowRankUpdateBroadcastsShape(test.TestCase):
  """Test that the operator's shape is the broadcast of arguments."""

  def test_static_shape_broadcasts_up_from_operator_to_other_args(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3)
    u = array_ops.ones(shape=[2, 3, 2])
    diag = array_ops.ones(shape=[2, 2])

    operator = linalg.LinearOperatorLowRankUpdate(base_operator, u, diag)

    # domain_dimension is 3
    self.assertAllEqual([2, 3, 3], operator.shape)
    with self.test_session():
      self.assertAllEqual([2, 3, 3], operator.to_dense().eval().shape)

  def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self):
    num_rows_ph = array_ops.placeholder(dtypes.int32)

    base_operator = linalg.LinearOperatorIdentity(num_rows=num_rows_ph)

    u_shape_ph = array_ops.placeholder(dtypes.int32)
    u = array_ops.ones(shape=u_shape_ph)

    operator = linalg.LinearOperatorLowRankUpdate(base_operator, u)

    feed_dict = {
        num_rows_ph: 3,
        u_shape_ph: [2, 3, 2],  # batch_shape = [2]
    }

    with self.test_session():
      shape_tensor = operator.shape_tensor().eval(feed_dict=feed_dict)
      self.assertAllEqual([2, 3, 3], shape_tensor)
      dense = operator.to_dense().eval(feed_dict=feed_dict)
      self.assertAllEqual([2, 3, 3], dense.shape)

  def test_u_and_v_incompatible_batch_shape_raises(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 3, 2)
    v = rng.rand(4, 3, 2)
    with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u, v=v)

  def test_u_and_base_operator_incompatible_batch_shape_raises(self):
    base_operator = linalg.LinearOperatorIdentity(
        num_rows=3, batch_shape=[4], dtype=np.float64)
    u = rng.rand(5, 3, 2)
    with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u)

  def test_u_and_base_operator_incompatible_domain_dimension(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 4, 2)
    with self.assertRaisesRegexp(ValueError, "not compatible"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u)

  def test_u_and_diag_incompatible_low_rank_raises(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 3, 2)
    diag = rng.rand(5, 4)  # Last dimension should be 2
    with self.assertRaisesRegexp(ValueError, "not compatible"):
      linalg.LinearOperatorLowRankUpdate(base_operator, u=u, diag_update=diag)

  def test_diag_incompatible_batch_shape_raises(self):
    base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    u = rng.rand(5, 3, 2)
    diag = rng.rand(4, 2)  # First dimension should be 5
    with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
      linalg.LinearOperatorLowRankUpdate(
          base_operator, u=u, diag_update=diag)


if __name__ == "__main__":
  test.main()