1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import numpy as np 21 22 from tensorflow.contrib.solvers.python.ops import least_squares 23 from tensorflow.contrib.solvers.python.ops import util 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.ops import array_ops 26 from tensorflow.python.platform import test as test_lib 27 28 29 def _add_test(test, test_name, fn): 30 test_name = "_".join(["test", test_name]) 31 if hasattr(test, test_name): 32 raise RuntimeError("Test %s defined more than once" % test_name) 33 setattr(test, test_name, fn) 34 35 36 class LeastSquaresTest(test_lib.TestCase): 37 pass # Filled in below. 
def _get_least_squares_tests(dtype_, use_static_shape_, shape_):
  """Builds the least-squares test functions for one parameter combination.

  Args:
    dtype_: NumPy scalar type the random test data is cast to.
    use_static_shape_: If true, the inputs are embedded in the graph as
      constants; otherwise they are fed through placeholders at run time.
    shape_: Shape of the random matrix. The right-hand side vector has
      `shape_[0]` entries.

  Returns:
    A list containing the single generated test function.
  """

  # Documented with comments rather than a docstring so that the description
  # unittest reports for the generated test is unchanged.
  #
  # The test solves a random least-squares problem with `least_squares.cgls`
  # and checks that the norm of s = A^* (rhs - A x) has dropped below
  # tol * ||s0||, where s0 = A^* rhs.
  def test_cgls(self):
    # Fixed seed: every run draws the same matrix and right-hand side.
    np.random.seed(1)
    num_entries = np.prod(shape_)
    flat_values = np.random.uniform(low=-1.0, high=1.0, size=num_entries)
    a_np = flat_values.reshape(shape_).astype(dtype_)
    rhs_values = np.random.uniform(low=-1.0, high=1.0, size=shape_[0])
    rhs_np = rhs_values.astype(dtype_)

    # Double precision is held to a much tighter relative tolerance.
    if dtype_ == np.float64:
      tol = 1e-12
    else:
      tol = 1e-6
    max_iter = 20

    with self.test_session() as sess:
      if use_static_shape_:
        a = constant_op.constant(a_np)
        rhs = constant_op.constant(rhs_np)
        feeds = None
      else:
        a = array_ops.placeholder(dtype_)
        rhs = array_ops.placeholder(dtype_)
        feeds = {a: a_np, rhs: rhs_np}
      operator = util.create_operator(a)
      cgls_graph = least_squares.cgls(
          operator, rhs, tol=tol, max_iter=max_iter)
      cgls_val = sess.run(cgls_graph, feed_dict=feeds)

      # Reference scale ||s0|| with s0 = A^* rhs.
      s0_np = np.dot(a_np.T, rhs_np)
      norm_s0 = np.linalg.norm(s0_np)
      # The solver's `gamma` is treated here as the squared norm of s.
      norm_s = np.sqrt(cgls_val.gamma)
      self.assertLessEqual(norm_s, tol * norm_s0)

      # Recompute the same quantity independently with numpy from the
      # returned solution and require it to be equally small.
      r_np = rhs_np - np.dot(a_np, cgls_val.x)
      s_np = np.dot(a_np.T, r_np)
      norm_s_np = np.linalg.norm(s_np)
      self.assertLessEqual(norm_s_np, tol * norm_s0)

  return [test_cgls]


if __name__ == "__main__":
  # Attach one generated test per (dtype, shape, static/dynamic) combination
  # to the test case class before starting the test runner.
  for dtype in (np.float32, np.float64):
    for shape in ([4, 4], [8, 5], [3, 7]):
      for use_static_shape in (True, False):
        shape_string = "_".join(str(dim) for dim in shape)
        arg_string = "%s_%s_staticshape_%s" % (
            dtype.__name__, shape_string, use_static_shape)
        generated = _get_least_squares_tests(dtype, use_static_shape, shape)
        for test_fn in generated:
          name = "_".join(["LeastSquares", test_fn.__name__, arg_string])
          _add_test(LeastSquaresTest, name, test_fn)

  test_lib.main()