Home | History | Annotate | Download | only in linalg
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.python.framework import dtypes
     23 from tensorflow.python.framework import test_util
     24 from tensorflow.python.ops import array_ops
     25 from tensorflow.python.ops.linalg import linalg as linalg_lib
     26 from tensorflow.python.ops.linalg import linear_operator_test_util
     27 from tensorflow.python.platform import test
     28 
     29 
     30 rng = np.random.RandomState(2016)
     31 
     32 
     33 class LinearOperatorZerosTest(
     34     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
     35   """Most tests done in the base class LinearOperatorDerivedClassTest."""
     36 
     37   @property
     38   def _tests_to_skip(self):
     39     return [
     40         "cholesky", "log_abs_det", "inverse", "solve", "solve_with_broadcast"]
     41 
     42   @property
     43   def _operator_build_infos(self):
     44     build_info = linear_operator_test_util.OperatorBuildInfo
     45     return [
     46         build_info((1, 1)),
     47         build_info((1, 3, 3)),
     48         build_info((3, 4, 4)),
     49         build_info((2, 1, 4, 4))]
     50 
     51   def _operator_and_matrix(
     52       self, build_info, dtype, use_placeholder,
     53       ensure_self_adjoint_and_pd=False):
     54     del ensure_self_adjoint_and_pd
     55     del use_placeholder
     56     shape = list(build_info.shape)
     57     assert shape[-1] == shape[-2]
     58 
     59     batch_shape = shape[:-2]
     60     num_rows = shape[-1]
     61 
     62     operator = linalg_lib.LinearOperatorZeros(
     63         num_rows, batch_shape=batch_shape, dtype=dtype)
     64     matrix = array_ops.zeros(shape=shape, dtype=dtype)
     65 
     66     return operator, matrix
     67 
     68   def test_assert_positive_definite(self):
     69     operator = linalg_lib.LinearOperatorZeros(num_rows=2)
     70     with self.assertRaisesOpError("non-positive definite"):
     71       operator.assert_positive_definite()
     72 
     73   def test_assert_non_singular(self):
     74     with self.assertRaisesOpError("non-invertible"):
     75       operator = linalg_lib.LinearOperatorZeros(num_rows=2)
     76       operator.assert_non_singular()
     77 
     78   @test_util.run_deprecated_v1
     79   def test_assert_self_adjoint(self):
     80     with self.cached_session():
     81       operator = linalg_lib.LinearOperatorZeros(num_rows=2)
     82       operator.assert_self_adjoint().run()  # Should not fail
     83 
     84   def test_non_scalar_num_rows_raises_static(self):
     85     with self.assertRaisesRegexp(ValueError, "must be a 0-D Tensor"):
     86       linalg_lib.LinearOperatorZeros(num_rows=[2])
     87     with self.assertRaisesRegexp(ValueError, "must be a 0-D Tensor"):
     88       linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=[2])
     89 
     90   def test_non_integer_num_rows_raises_static(self):
     91     with self.assertRaisesRegexp(TypeError, "must be integer"):
     92       linalg_lib.LinearOperatorZeros(num_rows=2.)
     93     with self.assertRaisesRegexp(TypeError, "must be integer"):
     94       linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=2.)
     95 
     96   def test_negative_num_rows_raises_static(self):
     97     with self.assertRaisesRegexp(ValueError, "must be non-negative"):
     98       linalg_lib.LinearOperatorZeros(num_rows=-2)
     99     with self.assertRaisesRegexp(ValueError, "must be non-negative"):
    100       linalg_lib.LinearOperatorZeros(num_rows=2, num_columns=-2)
    101 
    102   def test_non_1d_batch_shape_raises_static(self):
    103     with self.assertRaisesRegexp(ValueError, "must be a 1-D"):
    104       linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=2)
    105 
    106   def test_non_integer_batch_shape_raises_static(self):
    107     with self.assertRaisesRegexp(TypeError, "must be integer"):
    108       linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[2.])
    109 
    110   def test_negative_batch_shape_raises_static(self):
    111     with self.assertRaisesRegexp(ValueError, "must be non-negative"):
    112       linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[-2])
    113 
    114   @test_util.run_deprecated_v1
    115   def test_non_scalar_num_rows_raises_dynamic(self):
    116     with self.cached_session():
    117       num_rows = array_ops.placeholder(dtypes.int32)
    118       operator = linalg_lib.LinearOperatorZeros(
    119           num_rows, assert_proper_shapes=True)
    120       with self.assertRaisesOpError("must be a 0-D Tensor"):
    121         operator.to_dense().eval(feed_dict={num_rows: [2]})
    122 
    123   @test_util.run_deprecated_v1
    124   def test_negative_num_rows_raises_dynamic(self):
    125     with self.cached_session():
    126       n = array_ops.placeholder(dtypes.int32)
    127       operator = linalg_lib.LinearOperatorZeros(
    128           num_rows=n, assert_proper_shapes=True)
    129       with self.assertRaisesOpError("must be non-negative"):
    130         operator.to_dense().eval(feed_dict={n: -2})
    131 
    132       operator = linalg_lib.LinearOperatorZeros(
    133           num_rows=2, num_columns=n, assert_proper_shapes=True)
    134       with self.assertRaisesOpError("must be non-negative"):
    135         operator.to_dense().eval(feed_dict={n: -2})
    136 
    137   @test_util.run_deprecated_v1
    138   def test_non_1d_batch_shape_raises_dynamic(self):
    139     with self.cached_session():
    140       batch_shape = array_ops.placeholder(dtypes.int32)
    141       operator = linalg_lib.LinearOperatorZeros(
    142           num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
    143       with self.assertRaisesOpError("must be a 1-D"):
    144         operator.to_dense().eval(feed_dict={batch_shape: 2})
    145 
    146   @test_util.run_deprecated_v1
    147   def test_negative_batch_shape_raises_dynamic(self):
    148     with self.cached_session():
    149       batch_shape = array_ops.placeholder(dtypes.int32)
    150       operator = linalg_lib.LinearOperatorZeros(
    151           num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
    152       with self.assertRaisesOpError("must be non-negative"):
    153         operator.to_dense().eval(feed_dict={batch_shape: [-2]})
    154 
    155   def test_wrong_matrix_dimensions_raises_static(self):
    156     operator = linalg_lib.LinearOperatorZeros(num_rows=2)
    157     x = rng.randn(3, 3).astype(np.float32)
    158     with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
    159       operator.matmul(x)
    160 
    161   @test_util.run_deprecated_v1
    162   def test_wrong_matrix_dimensions_raises_dynamic(self):
    163     num_rows = array_ops.placeholder(dtypes.int32)
    164     x = array_ops.placeholder(dtypes.float32)
    165 
    166     with self.cached_session():
    167       operator = linalg_lib.LinearOperatorZeros(
    168           num_rows, assert_proper_shapes=True)
    169       y = operator.matmul(x)
    170       with self.assertRaisesOpError("Incompatible.*dimensions"):
    171         y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)})
    172 
    173   def test_is_x_flags(self):
    174     # The is_x flags are by default all True.
    175     operator = linalg_lib.LinearOperatorZeros(num_rows=2)
    176     self.assertFalse(operator.is_positive_definite)
    177     self.assertFalse(operator.is_non_singular)
    178     self.assertTrue(operator.is_self_adjoint)
    179 
    180   def test_zeros_matmul(self):
    181     operator1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
    182     operator2 = linalg_lib.LinearOperatorZeros(num_rows=2)
    183     self.assertTrue(isinstance(
    184         operator1.matmul(operator2),
    185         linalg_lib.LinearOperatorZeros))
    186 
    187     self.assertTrue(isinstance(
    188         operator2.matmul(operator1),
    189         linalg_lib.LinearOperatorZeros))
    190 
    191 
    192 class LinearOperatorZerosNotSquareTest(
    193     linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
    194 
    195   def _operator_and_matrix(self, build_info, dtype, use_placeholder):
    196     del use_placeholder
    197     shape = list(build_info.shape)
    198 
    199     batch_shape = shape[:-2]
    200     num_rows = shape[-2]
    201     num_columns = shape[-1]
    202 
    203     operator = linalg_lib.LinearOperatorZeros(
    204         num_rows, num_columns, is_square=False, is_self_adjoint=False,
    205         batch_shape=batch_shape, dtype=dtype)
    206     matrix = array_ops.zeros(shape=shape, dtype=dtype)
    207 
    208     return operator, matrix
    209 
    210 
if __name__ == "__main__":
  # When executed as a script, delegate to tensorflow.python.platform.test.main.
  test.main()
    213