Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.python.ops.linalg_ops."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import ops
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.ops import linalg_ops
     27 from tensorflow.python.ops import math_ops
     28 from tensorflow.python.ops.linalg import linalg
     29 from tensorflow.python.platform import test
     30 
     31 
     32 def _AddTest(test_class, op_name, testcase_name, fn):
     33   test_name = "_".join(["test", op_name, testcase_name])
     34   if hasattr(test_class, test_name):
     35     raise RuntimeError("Test %s defined more than once" % test_name)
     36   setattr(test_class, test_name, fn)
     37 
     38 
     39 def _RandomPDMatrix(n, rng, dtype=np.float64):
     40   """Random positive definite matrix."""
     41   temp = rng.randn(n, n).astype(dtype)
     42   if dtype in [np.complex64, np.complex128]:
     43     temp.imag = rng.randn(n, n)
     44   return np.conj(temp).dot(temp.T)
     45 
     46 
     47 class CholeskySolveTest(test.TestCase):
     48 
     49   def setUp(self):
     50     self.rng = np.random.RandomState(0)
     51 
     52   def test_works_with_five_different_random_pos_def_matrices(self):
     53     for n in range(1, 6):
     54       for np_type, atol in [(np.float32, 0.05), (np.float64, 1e-5)]:
     55         with self.test_session(use_gpu=True):
     56           # Create 2 x n x n matrix
     57           array = np.array(
     58               [_RandomPDMatrix(n, self.rng),
     59                _RandomPDMatrix(n, self.rng)]).astype(np_type)
     60           chol = linalg_ops.cholesky(array)
     61           for k in range(1, 3):
     62             rhs = self.rng.randn(2, n, k).astype(np_type)
     63             x = linalg_ops.cholesky_solve(chol, rhs)
     64             self.assertAllClose(
     65                 rhs, math_ops.matmul(array, x).eval(), atol=atol)
     66 
     67 
     68 class LogdetTest(test.TestCase):
     69 
     70   def setUp(self):
     71     self.rng = np.random.RandomState(42)
     72 
     73   def test_works_with_five_different_random_pos_def_matrices(self):
     74     for n in range(1, 6):
     75       for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5),
     76                              (np.complex64, 0.05), (np.complex128, 1e-5)]:
     77         matrix = _RandomPDMatrix(n, self.rng, np_dtype)
     78         _, logdet_np = np.linalg.slogdet(matrix)
     79         with self.test_session(use_gpu=True):
     80           # Create 2 x n x n matrix
     81           # matrix = np.array(
     82           #     [_RandomPDMatrix(n, self.rng, np_dtype),
     83           #      _RandomPDMatrix(n, self.rng, np_dtype)]).astype(np_dtype)
     84           logdet_tf = linalg.logdet(matrix)
     85           self.assertAllClose(logdet_np, logdet_tf.eval(), atol=atol)
     86 
     87   def test_works_with_underflow_case(self):
     88     for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5),
     89                            (np.complex64, 0.05), (np.complex128, 1e-5)]:
     90       matrix = (np.eye(20) * 1e-6).astype(np_dtype)
     91       _, logdet_np = np.linalg.slogdet(matrix)
     92       with self.test_session(use_gpu=True):
     93         logdet_tf = linalg.logdet(matrix)
     94         self.assertAllClose(logdet_np, logdet_tf.eval(), atol=atol)
     95 
     96 
     97 class SlogdetTest(test.TestCase):
     98 
     99   def setUp(self):
    100     self.rng = np.random.RandomState(42)
    101 
    102   def test_works_with_five_different_random_pos_def_matrices(self):
    103     for n in range(1, 6):
    104       for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5),
    105                              (np.complex64, 0.05), (np.complex128, 1e-5)]:
    106         matrix = _RandomPDMatrix(n, self.rng, np_dtype)
    107         sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
    108         with self.test_session(use_gpu=True):
    109           sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
    110           self.assertAllClose(log_abs_det_np, log_abs_det_tf.eval(), atol=atol)
    111           self.assertAllClose(sign_np, sign_tf.eval(), atol=atol)
    112 
    113   def test_works_with_underflow_case(self):
    114     for np_dtype, atol in [(np.float32, 0.05), (np.float64, 1e-5),
    115                            (np.complex64, 0.05), (np.complex128, 1e-5)]:
    116       matrix = (np.eye(20) * 1e-6).astype(np_dtype)
    117       sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
    118       with self.test_session(use_gpu=True):
    119         sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
    120         self.assertAllClose(log_abs_det_np, log_abs_det_tf.eval(), atol=atol)
    121         self.assertAllClose(sign_np, sign_tf.eval(), atol=atol)
    122 
    123 
    124 class AdjointTest(test.TestCase):
    125 
    126   def test_compare_to_numpy(self):
    127     for dtype in np.float64, np.float64, np.complex64, np.complex128:
    128       matrix_np = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j,
    129                                                        6 + 6j]]).astype(dtype)
    130       expected_transposed = np.conj(matrix_np.T)
    131       with self.test_session():
    132         matrix = ops.convert_to_tensor(matrix_np)
    133         transposed = linalg.adjoint(matrix)
    134         self.assertEqual((3, 2), transposed.get_shape())
    135         self.assertAllEqual(expected_transposed, transposed.eval())
    136 
    137 
    138 class EyeTest(test.TestCase):
    139   pass  # Will be filled in below
    140 
    141 
    142 def _GetEyeTest(num_rows, num_columns, batch_shape, dtype):
    143 
    144   def Test(self):
    145     eye_np = np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype)
    146     if batch_shape is not None:
    147       eye_np = np.tile(eye_np, batch_shape + [1, 1])
    148     for use_placeholder in False, True:
    149       if use_placeholder and (num_columns is None or batch_shape is None):
    150         return
    151       with self.test_session(use_gpu=True) as sess:
    152         if use_placeholder:
    153           num_rows_placeholder = array_ops.placeholder(
    154               dtypes.int32, name="num_rows")
    155           num_columns_placeholder = array_ops.placeholder(
    156               dtypes.int32, name="num_columns")
    157           batch_shape_placeholder = array_ops.placeholder(
    158               dtypes.int32, name="batch_shape")
    159           eye = linalg_ops.eye(
    160               num_rows_placeholder,
    161               num_columns=num_columns_placeholder,
    162               batch_shape=batch_shape_placeholder,
    163               dtype=dtype)
    164           eye_tf = sess.run(
    165               eye,
    166               feed_dict={
    167                   num_rows_placeholder: num_rows,
    168                   num_columns_placeholder: num_columns,
    169                   batch_shape_placeholder: batch_shape
    170               })
    171         else:
    172           eye_tf = linalg_ops.eye(
    173               num_rows,
    174               num_columns=num_columns,
    175               batch_shape=batch_shape,
    176               dtype=dtype).eval()
    177         self.assertAllEqual(eye_np, eye_tf)
    178 
    179   return Test
    180 
    181 
    182 if __name__ == "__main__":
    183   for _num_rows in 0, 1, 2, 5:
    184     for _num_columns in None, 0, 1, 2, 5:
    185       for _batch_shape in None, [], [2], [2, 3]:
    186         for _dtype in (dtypes.int32, dtypes.int64, dtypes.float32,
    187                        dtypes.float64, dtypes.complex64, dtypes.complex128):
    188           name = "dtype_%s_num_rows_%s_num_column_%s_batch_shape_%s_" % (
    189               _dtype.name, _num_rows, _num_columns, _batch_shape)
    190           _AddTest(EyeTest, "EyeTest", name,
    191                    _GetEyeTest(_num_rows, _num_columns, _batch_shape, _dtype))
    192 
    193   test.main()
    194