Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
"""Tests for tensorflow.ops.linalg_ops.qr."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.ops import array_ops
     25 from tensorflow.python.ops import gradient_checker
     26 from tensorflow.python.ops import linalg_ops
     27 from tensorflow.python.ops import math_ops
     28 from tensorflow.python.ops import random_ops
     29 from tensorflow.python.platform import test
     30 
     31 
     32 def _AddTest(test_class, op_name, testcase_name, fn):
     33   test_name = "_".join(["test", op_name, testcase_name])
     34   if hasattr(test_class, test_name):
     35     raise RuntimeError("Test %s defined more than once" % test_name)
     36   setattr(test_class, test_name, fn)
     37 
     38 
     39 class QrOpTest(test.TestCase):
     40 
     41   def testWrongDimensions(self):
     42     # The input to qr should be a tensor of at least rank 2.
     43     scalar = constant_op.constant(1.)
     44     with self.assertRaisesRegexp(ValueError,
     45                                  "Shape must be at least rank 2 but is rank 0"):
     46       linalg_ops.qr(scalar)
     47     vector = constant_op.constant([1., 2.])
     48     with self.assertRaisesRegexp(ValueError,
     49                                  "Shape must be at least rank 2 but is rank 1"):
     50       linalg_ops.qr(vector)
     51 
     52   def testConcurrentExecutesWithoutError(self):
     53     with self.test_session(use_gpu=True) as sess:
     54       all_ops = []
     55       for full_matrices_ in True, False:
     56         for rows_ in 4, 5:
     57           for cols_ in 4, 5:
     58             matrix1 = random_ops.random_normal([rows_, cols_], seed=42)
     59             matrix2 = random_ops.random_normal([rows_, cols_], seed=42)
     60             q1, r1 = linalg_ops.qr(matrix1, full_matrices=full_matrices_)
     61             q2, r2 = linalg_ops.qr(matrix2, full_matrices=full_matrices_)
     62             all_ops += [q1, r1, q2, r2]
     63       val = sess.run(all_ops)
     64       for i in range(8):
     65         q = 4 * i
     66         self.assertAllEqual(val[q], val[q + 2])  # q1 == q2
     67         self.assertAllEqual(val[q + 1], val[q + 3])  # r1 == r2
     68 
     69 
     70 def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
     71 
     72   is_complex = dtype_ in (np.complex64, np.complex128)
     73   is_single = dtype_ in (np.float32, np.complex64)
     74 
     75   def CompareOrthogonal(self, x, y, rank):
     76     if is_single:
     77       atol = 5e-4
     78     else:
     79       atol = 5e-14
     80     # We only compare the first 'rank' orthogonal vectors since the
     81     # remainder form an arbitrary orthonormal basis for the
     82     # (row- or column-) null space, whose exact value depends on
     83     # implementation details. Notice that since we check that the
     84     # matrices of singular vectors are unitary elsewhere, we do
     85     # implicitly test that the trailing vectors of x and y span the
     86     # same space.
     87     x = x[..., 0:rank]
     88     y = y[..., 0:rank]
     89     # Q is only unique up to sign (complex phase factor for complex matrices),
     90     # so we normalize the sign first.
     91     sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True)
     92     phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
     93     x *= phases
     94     self.assertAllClose(x, y, atol=atol)
     95 
     96   def CheckApproximation(self, a, q, r):
     97     if is_single:
     98       tol = 1e-5
     99     else:
    100       tol = 1e-14
    101     # Tests that a ~= q*r.
    102     a_recon = math_ops.matmul(q, r)
    103     self.assertAllClose(a_recon.eval(), a, rtol=tol, atol=tol)
    104 
    105   def CheckUnitary(self, x):
    106     # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
    107     xx = math_ops.matmul(x, x, adjoint_a=True)
    108     identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
    109     if is_single:
    110       tol = 1e-5
    111     else:
    112       tol = 1e-14
    113     self.assertAllClose(identity.eval(), xx.eval(), atol=tol)
    114 
    115   def Test(self):
    116     np.random.seed(1)
    117     x_np = np.random.uniform(
    118         low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
    119     if is_complex:
    120       x_np += 1j * np.random.uniform(
    121           low=-1.0, high=1.0,
    122           size=np.prod(shape_)).reshape(shape_).astype(dtype_)
    123 
    124     with self.test_session(use_gpu=True) as sess:
    125       if use_static_shape_:
    126         x_tf = constant_op.constant(x_np)
    127       else:
    128         x_tf = array_ops.placeholder(dtype_)
    129       q_tf, r_tf = linalg_ops.qr(x_tf, full_matrices=full_matrices_)
    130 
    131       if use_static_shape_:
    132         q_tf_val, r_tf_val = sess.run([q_tf, r_tf])
    133       else:
    134         q_tf_val, r_tf_val = sess.run([q_tf, r_tf], feed_dict={x_tf: x_np})
    135 
    136       q_dims = q_tf_val.shape
    137       np_q = np.ndarray(q_dims, dtype_)
    138       np_q_reshape = np.reshape(np_q, (-1, q_dims[-2], q_dims[-1]))
    139       new_first_dim = np_q_reshape.shape[0]
    140 
    141       x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
    142       for i in range(new_first_dim):
    143         if full_matrices_:
    144           np_q_reshape[i, :, :], _ = np.linalg.qr(
    145               x_reshape[i, :, :], mode="complete")
    146         else:
    147           np_q_reshape[i, :, :], _ = np.linalg.qr(
    148               x_reshape[i, :, :], mode="reduced")
    149       np_q = np.reshape(np_q_reshape, q_dims)
    150       CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
    151       CheckApproximation(self, x_np, q_tf_val, r_tf_val)
    152       CheckUnitary(self, q_tf_val)
    153 
    154   return Test
    155 
    156 
    157 class QrGradOpTest(test.TestCase):
    158   pass
    159 
    160 
    161 def _GetQrGradOpTest(dtype_, shape_, full_matrices_):
    162 
    163   def Test(self):
    164     np.random.seed(42)
    165     a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_)
    166     if dtype_ in [np.complex64, np.complex128]:
    167       a += 1j * np.random.uniform(
    168           low=-1.0, high=1.0, size=shape_).astype(dtype_)
    169     # Optimal stepsize for central difference is O(epsilon^{1/3}).
    170     epsilon = np.finfo(dtype_).eps
    171     delta = 0.1 * epsilon**(1.0 / 3.0)
    172     if dtype_ in [np.float32, np.complex64]:
    173       tol = 3e-2
    174     else:
    175       tol = 1e-6
    176     with self.test_session(use_gpu=True):
    177       tf_a = constant_op.constant(a)
    178       tf_b = linalg_ops.qr(tf_a, full_matrices=full_matrices_)
    179       for b in tf_b:
    180         x_init = np.random.uniform(
    181             low=-1.0, high=1.0, size=shape_).astype(dtype_)
    182         if dtype_ in [np.complex64, np.complex128]:
    183           x_init += 1j * np.random.uniform(
    184               low=-1.0, high=1.0, size=shape_).astype(dtype_)
    185         theoretical, numerical = gradient_checker.compute_gradient(
    186             tf_a,
    187             tf_a.get_shape().as_list(),
    188             b,
    189             b.get_shape().as_list(),
    190             x_init_value=x_init,
    191             delta=delta)
    192         self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
    193 
    194   return Test
    195 
    196 
    197 if __name__ == "__main__":
    198   for dtype in np.float32, np.float64, np.complex64, np.complex128:
    199     for rows in 1, 2, 5, 10, 32, 100:
    200       for cols in 1, 2, 5, 10, 32, 100:
    201         for full_matrices in False, True:
    202           for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
    203             for use_static_shape in True, False:
    204               shape = batch_dims + (rows, cols)
    205               name = "%s_%s_full_%s_static_%s" % (dtype.__name__,
    206                                                   "_".join(map(str, shape)),
    207                                                   full_matrices,
    208                                                   use_static_shape)
    209               _AddTest(QrOpTest, "Qr", name,
    210                        _GetQrOpTest(dtype, shape, full_matrices,
    211                                     use_static_shape))
    212 
    213   # TODO(pfau): Get working with complex types.
    214   # TODO(pfau): Get working with full_matrices when rows != cols
    215   # TODO(pfau): Get working when rows < cols
    216   # TODO(pfau): Get working with shapeholders (dynamic shapes)
    217   for full_matrices in False, True:
    218     for dtype in np.float32, np.float64:
    219       for rows in 1, 2, 5, 10:
    220         for cols in 1, 2, 5, 10:
    221           if rows == cols or (not full_matrices and rows > cols):
    222             for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
    223               shape = batch_dims + (rows, cols)
    224               name = "%s_%s_full_%s" % (dtype.__name__,
    225                                         "_".join(map(str, shape)),
    226                                         full_matrices)
    227               _AddTest(QrGradOpTest, "QrGrad", name,
    228                        _GetQrGradOpTest(dtype, shape, full_matrices))
    229   test.main()
    230