1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.ops.tf.MatrixTriangularSolve.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import itertools 22 23 import numpy as np 24 25 from tensorflow.compiler.tests.xla_test import XLATestCase 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.ops import linalg_ops 30 from tensorflow.python.ops import math_ops 31 from tensorflow.python.platform import test 32 33 34 def MakePlaceholder(x): 35 return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape) 36 37 38 class MatrixTriangularSolveOpTest(XLATestCase): 39 40 def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca, 41 placeholder_b, a, clean_a, b, verification, 42 atol): 43 feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b} 44 verification_np = sess.run(verification, feed_dict) 45 self.assertAllClose(b, verification_np, atol=atol) 46 47 def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol): 48 clean_a = np.tril(a) if lower else np.triu(a) 49 with self.test_session() as sess: 50 placeholder_a = MakePlaceholder(a) 51 placeholder_ca = MakePlaceholder(clean_a) 52 placeholder_b = MakePlaceholder(b) 53 with self.test_scope(): 54 x = linalg_ops.matrix_triangular_solve( 55 placeholder_a, placeholder_b, lower=lower, adjoint=adjoint) 56 verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint) 57 self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca, 58 placeholder_b, a, clean_a, b, 59 verification, atol) 60 61 def _VerifyTriangularSolveCombo(self, a, b, atol=1e-4): 62 transp = lambda x: np.swapaxes(x, -1, -2) 63 for lower, adjoint in itertools.product([True, False], repeat=2): 64 self._VerifyTriangularSolve( 65 a if lower else transp(a), b, lower, adjoint, atol) 66 67 def testBasic(self): 68 rng = np.random.RandomState(0) 69 a = np.tril(rng.randn(5, 5)) 70 b = rng.randn(5, 7) 71 for dtype in self.float_types: 72 self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) 73 74 def testBasicNotActuallyTriangular(self): 75 rng = np.random.RandomState(0) 76 a = rng.randn(5, 5) # the `a` matrix is not lower-triangular 77 b = rng.randn(5, 7) 78 for dtype in self.float_types: 79 self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) 80 81 def testBasicComplexDtypes(self): 82 rng = np.random.RandomState(0) 83 a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j) 84 b = rng.randn(5, 7) + rng.randn(5, 7) * 1j 85 for dtype in self.complex_types: 86 self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) 87 88 def testBatch(self): 89 rng = np.random.RandomState(0) 90 shapes = [((4, 3, 3), (4, 3, 5)), ((1, 2, 2), (1, 2, 1)), 91 ((1, 1, 1), (1, 1, 2)), ((2, 3, 4, 4), (2, 3, 4, 1))] 92 tuples = itertools.product(self.float_types, shapes) 93 for dtype, (a_shape, b_shape) in tuples: 94 n = a_shape[-1] 95 a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n) 96 b = rng.randn(*b_shape) 97 self._VerifyTriangularSolveCombo( 98 a.astype(dtype), b.astype(dtype), atol=1e-3) 99 100 def testLarge(self): 101 n = 1024 102 rng = np.random.RandomState(0) 103 a = np.tril(rng.rand(n, n) - 0.5) / (2.0 * n) + np.eye(n) 104 b = rng.randn(n, n) 105 self._VerifyTriangularSolve( 106 a.astype(np.float32), b.astype(np.float32), True, False, 1e-4) 107 108 def testNonSquareCoefficientMatrix(self): 109 rng = np.random.RandomState(0) 110 for dtype in self.float_types: 111 a = rng.randn(3, 4).astype(dtype) 112 b = rng.randn(4, 4).astype(dtype) 113 with self.assertRaises(ValueError): 114 linalg_ops.matrix_triangular_solve(a, b) 115 with self.assertRaises(ValueError): 116 linalg_ops.matrix_triangular_solve(a, b) 117 118 def testWrongDimensions(self): 119 randn = np.random.RandomState(0).randn 120 for dtype in self.float_types: 121 lhs = constant_op.constant(randn(3, 3), dtype=dtype) 122 rhs = constant_op.constant(randn(4, 3), dtype=dtype) 123 with self.assertRaises(ValueError): 124 linalg_ops.matrix_triangular_solve(lhs, rhs) 125 with self.assertRaises(ValueError): 126 linalg_ops.matrix_triangular_solve(lhs, rhs) 127 128 129 if __name__ == "__main__": 130 test.main() 131