Home | History | Annotate | Download | only in tests
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.ops.tf.MatrixTriangularSolve."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import itertools
     22 
     23 import numpy as np
     24 
     25 from tensorflow.compiler.tests.xla_test import XLATestCase
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import linalg_ops
     30 from tensorflow.python.ops import math_ops
     31 from tensorflow.python.platform import test
     32 
     33 
     34 def MakePlaceholder(x):
     35   return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape)
     36 
     37 
     38 class MatrixTriangularSolveOpTest(XLATestCase):
     39 
     40   def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca,
     41                                  placeholder_b, a, clean_a, b, verification,
     42                                  atol):
     43     feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b}
     44     verification_np = sess.run(verification, feed_dict)
     45     self.assertAllClose(b, verification_np, atol=atol)
     46 
     47   def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol):
     48     clean_a = np.tril(a) if lower else np.triu(a)
     49     with self.test_session() as sess:
     50       placeholder_a = MakePlaceholder(a)
     51       placeholder_ca = MakePlaceholder(clean_a)
     52       placeholder_b = MakePlaceholder(b)
     53       with self.test_scope():
     54         x = linalg_ops.matrix_triangular_solve(
     55             placeholder_a, placeholder_b, lower=lower, adjoint=adjoint)
     56       verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint)
     57       self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca,
     58                                       placeholder_b, a, clean_a, b,
     59                                       verification, atol)
     60 
     61   def _VerifyTriangularSolveCombo(self, a, b, atol=1e-4):
     62     transp = lambda x: np.swapaxes(x, -1, -2)
     63     for lower, adjoint in itertools.product([True, False], repeat=2):
     64       self._VerifyTriangularSolve(
     65           a if lower else transp(a), b, lower, adjoint, atol)
     66 
     67   def testBasic(self):
     68     rng = np.random.RandomState(0)
     69     a = np.tril(rng.randn(5, 5))
     70     b = rng.randn(5, 7)
     71     for dtype in self.float_types:
     72       self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))
     73 
     74   def testBasicNotActuallyTriangular(self):
     75     rng = np.random.RandomState(0)
     76     a = rng.randn(5, 5)  # the `a` matrix is not lower-triangular
     77     b = rng.randn(5, 7)
     78     for dtype in self.float_types:
     79       self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))
     80 
     81   def testBasicComplexDtypes(self):
     82     rng = np.random.RandomState(0)
     83     a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j)
     84     b = rng.randn(5, 7) + rng.randn(5, 7) * 1j
     85     for dtype in self.complex_types:
     86       self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))
     87 
     88   def testBatch(self):
     89     rng = np.random.RandomState(0)
     90     shapes = [((4, 3, 3), (4, 3, 5)), ((1, 2, 2), (1, 2, 1)),
     91               ((1, 1, 1), (1, 1, 2)), ((2, 3, 4, 4), (2, 3, 4, 1))]
     92     tuples = itertools.product(self.float_types, shapes)
     93     for dtype, (a_shape, b_shape) in tuples:
     94       n = a_shape[-1]
     95       a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n)
     96       b = rng.randn(*b_shape)
     97       self._VerifyTriangularSolveCombo(
     98           a.astype(dtype), b.astype(dtype), atol=1e-3)
     99 
    100   def testLarge(self):
    101     n = 1024
    102     rng = np.random.RandomState(0)
    103     a = np.tril(rng.rand(n, n) - 0.5) / (2.0 * n) + np.eye(n)
    104     b = rng.randn(n, n)
    105     self._VerifyTriangularSolve(
    106         a.astype(np.float32), b.astype(np.float32), True, False, 1e-4)
    107 
    108   def testNonSquareCoefficientMatrix(self):
    109     rng = np.random.RandomState(0)
    110     for dtype in self.float_types:
    111       a = rng.randn(3, 4).astype(dtype)
    112       b = rng.randn(4, 4).astype(dtype)
    113       with self.assertRaises(ValueError):
    114         linalg_ops.matrix_triangular_solve(a, b)
    115       with self.assertRaises(ValueError):
    116         linalg_ops.matrix_triangular_solve(a, b)
    117 
    118   def testWrongDimensions(self):
    119     randn = np.random.RandomState(0).randn
    120     for dtype in self.float_types:
    121       lhs = constant_op.constant(randn(3, 3), dtype=dtype)
    122       rhs = constant_op.constant(randn(4, 3), dtype=dtype)
    123       with self.assertRaises(ValueError):
    124         linalg_ops.matrix_triangular_solve(lhs, rhs)
    125       with self.assertRaises(ValueError):
    126         linalg_ops.matrix_triangular_solve(lhs, rhs)
    127 
    128 
    129 if __name__ == "__main__":
    130   test.main()
    131