# Tests for LinearOperatorLowRankUpdate (tensorflow/python/ops/linalg).
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.python.framework import dtypes
     23 from tensorflow.python.framework import random_seed
     24 from tensorflow.python.ops import array_ops
     25 from tensorflow.python.ops import math_ops
     26 from tensorflow.python.ops.linalg import linalg as linalg_lib
     27 from tensorflow.python.ops.linalg import linear_operator_test_util
     28 from tensorflow.python.platform import test
     29 
linalg = linalg_lib  # Short alias used throughout this file.
random_seed.set_random_seed(23)  # Fix the TF graph-level seed so tests are deterministic.
rng = np.random.RandomState(0)  # Seeded NumPy RNG for deterministic test inputs.
     33 
     34 
class BaseLinearOperatorLowRankUpdatetest(object):
  """Base test for this type of operator."""

  # Subclasses should set these attributes to either True or False.

  # If True, A = L + UDV^H
  # If False, A = L + UV^H or A = L + UU^H, depending on _use_v.
  _use_diag_update = None

  # If True, diag is > 0, which means D is symmetric positive definite.
  _is_diag_update_positive = None

  # If True, A = L + UDV^H
  # If False, A = L + UDU^H or A = L + UU^H, depending on _use_diag_update
  _use_v = None

  @property
  def _dtypes_to_test(self):
    # TODO(langmore) Test complex types once cholesky works with them.
    # See comment in LinearOperatorLowRankUpdate.__init__.
    return [dtypes.float32, dtypes.float64]

  @property
  def _shapes_to_test(self):
    # Previously we had a (2, 10, 10) shape at the end.  We did this to test the
    # inversion and determinant lemmas on not-tiny matrices, since these are
    # known to have stability issues.  This resulted in test timeouts, so this
    # shape has been removed, but rest assured, the tests did pass.
    return [(0, 0), (1, 1), (1, 3, 3), (3, 4, 4), (2, 1, 4, 4)]

  def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
    """Builds the operator under test and a dense matrix it should equal.

    Args:
      shape: List or tuple of integers; the full static shape of the dense
        matrix A (trailing two entries are the matrix dimensions, the rest
        are batch dimensions).
      dtype: TensorFlow dtype used for all constructed tensors.
      use_placeholder: Python bool.  If True, all inputs are fed through
        placeholders (fully dynamic shapes) and a feed_dict is returned;
        if False, graph tensors are used directly and feed_dict is None.

    Returns:
      Tuple `(operator, mat, feed_dict)`: the LinearOperatorLowRankUpdate
      under test, a dense Tensor it should equal, and the feed_dict (or
      None) needed to evaluate them.
    """
    # Recall A = L + UDV^H
    shape = list(shape)
    diag_shape = shape[:-1]
    # Rank of the update: about half the matrix dimension, and at least 1.
    k = shape[-2] // 2 + 1
    u_perturbation_shape = shape[:-1] + [k]
    diag_update_shape = shape[:-2] + [k]

    # base_operator L will be a symmetric positive definite diagonal linear
    # operator, with condition number as high as 1e4.
    base_diag = linear_operator_test_util.random_uniform(
        diag_shape, minval=1e-4, maxval=1., dtype=dtype)
    base_diag_ph = array_ops.placeholder(dtype=dtype)

    # U
    u = linear_operator_test_util.random_normal_correlated_columns(
        u_perturbation_shape, dtype=dtype)
    u_ph = array_ops.placeholder(dtype=dtype)

    # V
    v = linear_operator_test_util.random_normal_correlated_columns(
        u_perturbation_shape, dtype=dtype)
    v_ph = array_ops.placeholder(dtype=dtype)

    # D:  Positive entries guarantee D (and hence UDU^H) is PSD; otherwise
    # small zero-mean normal entries are used.
    if self._is_diag_update_positive:
      diag_update = linear_operator_test_util.random_uniform(
          diag_update_shape, minval=1e-4, maxval=1., dtype=dtype)
    else:
      diag_update = linear_operator_test_util.random_normal(
          diag_update_shape, stddev=1e-4, dtype=dtype)
    diag_update_ph = array_ops.placeholder(dtype=dtype)

    if use_placeholder:
      # Evaluate here because (i) you cannot feed a tensor, and (ii)
      # values are random and we want the same value used for both mat and
      # feed_dict.
      base_diag = base_diag.eval()
      u = u.eval()
      v = v.eval()
      diag_update = diag_update.eval()

      # In all cases, set base_operator to be positive definite.
      base_operator = linalg.LinearOperatorDiag(
          base_diag_ph, is_positive_definite=True)

      operator = linalg.LinearOperatorLowRankUpdate(
          base_operator,
          u=u_ph,
          v=v_ph if self._use_v else None,
          diag_update=diag_update_ph if self._use_diag_update else None,
          is_diag_update_positive=self._is_diag_update_positive)
      # v_ph / diag_update_ph are fed even when unused by the operator;
      # the extra entries are harmless.
      feed_dict = {
          base_diag_ph: base_diag,
          u_ph: u,
          v_ph: v,
          diag_update_ph: diag_update}
    else:
      base_operator = linalg.LinearOperatorDiag(
          base_diag, is_positive_definite=True)
      operator = linalg.LinearOperatorLowRankUpdate(
          base_operator,
          u,
          v=v if self._use_v else None,
          diag_update=diag_update if self._use_diag_update else None,
          is_diag_update_positive=self._is_diag_update_positive)
      feed_dict = None

    # The matrix representing L
    base_diag_mat = array_ops.matrix_diag(base_diag)

    # The matrix representing D
    diag_update_mat = array_ops.matrix_diag(diag_update)

    # Set up mat as some variant of A = L + UDV^H
    if self._use_v and self._use_diag_update:
      # In this case, we have L + UDV^H and it isn't symmetric.
      expect_use_cholesky = False
      mat = base_diag_mat + math_ops.matmul(
          u, math_ops.matmul(diag_update_mat, v, adjoint_b=True))
    elif self._use_v:
      # In this case, we have L + UV^H (no D) and it isn't symmetric.
      expect_use_cholesky = False
      mat = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True)
    elif self._use_diag_update:
      # In this case, we have L + UDU^H, which is PD if D > 0, since L > 0.
      expect_use_cholesky = self._is_diag_update_positive
      mat = base_diag_mat + math_ops.matmul(
          u, math_ops.matmul(diag_update_mat, u, adjoint_b=True))
    else:
      # In this case, we have L + UU^H, which is PD since L > 0.
      expect_use_cholesky = True
      mat = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True)

    # Verify the operator chose the solver strategy we expect for this
    # configuration (reaches into the private attribute on purpose).
    if expect_use_cholesky:
      self.assertTrue(operator._use_cholesky)
    else:
      self.assertFalse(operator._use_cholesky)

    return operator, mat, feed_dict
    165 
    166 
    167 class LinearOperatorLowRankUpdatetestWithDiagUseCholesky(
    168     BaseLinearOperatorLowRankUpdatetest,
    169     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
    170   """A = L + UDU^H, D > 0, L > 0 ==> A > 0 and we can use a Cholesky."""
    171 
    172   _use_diag_update = True
    173   _is_diag_update_positive = True
    174   _use_v = False
    175 
    176   def setUp(self):
    177     # Decrease tolerance since we are testing with condition numbers as high as
    178     # 1e4.
    179     self._atol[dtypes.float32] = 1e-5
    180     self._rtol[dtypes.float32] = 1e-5
    181     self._atol[dtypes.float64] = 1e-10
    182     self._rtol[dtypes.float64] = 1e-10
    183 
    184 
    185 class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky(
    186     BaseLinearOperatorLowRankUpdatetest,
    187     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
    188   """A = L + UDU^H, D !> 0, L > 0 ==> A !> 0 and we cannot use a Cholesky."""
    189 
    190   _use_diag_update = True
    191   _is_diag_update_positive = False
    192   _use_v = False
    193 
    194   def setUp(self):
    195     # Decrease tolerance since we are testing with condition numbers as high as
    196     # 1e4.  This class does not use Cholesky, and thus needs even looser
    197     # tolerance.
    198     self._atol[dtypes.float32] = 1e-4
    199     self._rtol[dtypes.float32] = 1e-4
    200     self._atol[dtypes.float64] = 1e-9
    201     self._rtol[dtypes.float64] = 1e-9
    202 
    203 
    204 class LinearOperatorLowRankUpdatetestNoDiagUseCholesky(
    205     BaseLinearOperatorLowRankUpdatetest,
    206     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
    207   """A = L + UU^H, L > 0 ==> A > 0 and we can use a Cholesky."""
    208 
    209   _use_diag_update = False
    210   _is_diag_update_positive = None
    211   _use_v = False
    212 
    213   def setUp(self):
    214     # Decrease tolerance since we are testing with condition numbers as high as
    215     # 1e4.
    216     self._atol[dtypes.float32] = 1e-5
    217     self._rtol[dtypes.float32] = 1e-5
    218     self._atol[dtypes.float64] = 1e-10
    219     self._rtol[dtypes.float64] = 1e-10
    220 
    221 
    222 class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky(
    223     BaseLinearOperatorLowRankUpdatetest,
    224     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
    225   """A = L + UV^H, L > 0 ==> A is not symmetric and we cannot use a Cholesky."""
    226 
    227   _use_diag_update = False
    228   _is_diag_update_positive = None
    229   _use_v = True
    230 
    231   def setUp(self):
    232     # Decrease tolerance since we are testing with condition numbers as high as
    233     # 1e4.  This class does not use Cholesky, and thus needs even looser
    234     # tolerance.
    235     self._atol[dtypes.float32] = 1e-4
    236     self._rtol[dtypes.float32] = 1e-4
    237     self._atol[dtypes.float64] = 1e-9
    238     self._rtol[dtypes.float64] = 1e-9
    239 
    240 
    241 class LinearOperatorLowRankUpdatetestWithDiagNotSquare(
    242     BaseLinearOperatorLowRankUpdatetest,
    243     linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
    244   """A = L + UDU^H, D > 0, L > 0 ==> A > 0 and we can use a Cholesky."""
    245 
    246   _use_diag_update = True
    247   _is_diag_update_positive = True
    248   _use_v = True
    249 
    250 
    251 class LinearOpearatorLowRankUpdateBroadcastsShape(test.TestCase):
    252   """Test that the operator's shape is the broadcast of arguments."""
    253 
    254   def test_static_shape_broadcasts_up_from_operator_to_other_args(self):
    255     base_operator = linalg.LinearOperatorIdentity(num_rows=3)
    256     u = array_ops.ones(shape=[2, 3, 2])
    257     diag = array_ops.ones(shape=[2, 2])
    258 
    259     operator = linalg.LinearOperatorLowRankUpdate(base_operator, u, diag)
    260 
    261     # domain_dimension is 3
    262     self.assertAllEqual([2, 3, 3], operator.shape)
    263     with self.test_session():
    264       self.assertAllEqual([2, 3, 3], operator.to_dense().eval().shape)
    265 
    266   def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self):
    267     num_rows_ph = array_ops.placeholder(dtypes.int32)
    268 
    269     base_operator = linalg.LinearOperatorIdentity(num_rows=num_rows_ph)
    270 
    271     u_shape_ph = array_ops.placeholder(dtypes.int32)
    272     u = array_ops.ones(shape=u_shape_ph)
    273 
    274     operator = linalg.LinearOperatorLowRankUpdate(base_operator, u)
    275 
    276     feed_dict = {
    277         num_rows_ph: 3,
    278         u_shape_ph: [2, 3, 2],  # batch_shape = [2]
    279     }
    280 
    281     with self.test_session():
    282       shape_tensor = operator.shape_tensor().eval(feed_dict=feed_dict)
    283       self.assertAllEqual([2, 3, 3], shape_tensor)
    284       dense = operator.to_dense().eval(feed_dict=feed_dict)
    285       self.assertAllEqual([2, 3, 3], dense.shape)
    286 
    287   def test_u_and_v_incompatible_batch_shape_raises(self):
    288     base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    289     u = rng.rand(5, 3, 2)
    290     v = rng.rand(4, 3, 2)
    291     with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
    292       linalg.LinearOperatorLowRankUpdate(base_operator, u=u, v=v)
    293 
    294   def test_u_and_base_operator_incompatible_batch_shape_raises(self):
    295     base_operator = linalg.LinearOperatorIdentity(
    296         num_rows=3, batch_shape=[4], dtype=np.float64)
    297     u = rng.rand(5, 3, 2)
    298     with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
    299       linalg.LinearOperatorLowRankUpdate(base_operator, u=u)
    300 
    301   def test_u_and_base_operator_incompatible_domain_dimension(self):
    302     base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    303     u = rng.rand(5, 4, 2)
    304     with self.assertRaisesRegexp(ValueError, "not compatible"):
    305       linalg.LinearOperatorLowRankUpdate(base_operator, u=u)
    306 
    307   def test_u_and_diag_incompatible_low_rank_raises(self):
    308     base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    309     u = rng.rand(5, 3, 2)
    310     diag = rng.rand(5, 4)  # Last dimension should be 2
    311     with self.assertRaisesRegexp(ValueError, "not compatible"):
    312       linalg.LinearOperatorLowRankUpdate(base_operator, u=u, diag_update=diag)
    313 
    314   def test_diag_incompatible_batch_shape_raises(self):
    315     base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
    316     u = rng.rand(5, 3, 2)
    317     diag = rng.rand(4, 2)  # First dimension should be 5
    318     with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
    319       linalg.LinearOperatorLowRankUpdate(base_operator, u=u, diag_update=diag)
    320 
    321 
# Run all test cases when executed as a script.
if __name__ == "__main__":
  test.main()
    324