Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.contrib.solvers.python.ops import least_squares
     23 from tensorflow.contrib.solvers.python.ops import util
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.ops import array_ops
     26 from tensorflow.python.platform import test as test_lib
     27 
     28 
     29 def _add_test(test, test_name, fn):
     30   test_name = "_".join(["test", test_name])
     31   if hasattr(test, test_name):
     32     raise RuntimeError("Test %s defined more than once" % test_name)
     33   setattr(test, test_name, fn)
     34 
     35 
     36 class LeastSquaresTest(test_lib.TestCase):
     37   pass  # Filled in below.
     38 
     39 
     40 def _get_least_squares_tests(dtype_, use_static_shape_, shape_):
     41 
     42   def test_cgls(self):
     43     np.random.seed(1)
     44     a_np = np.random.uniform(
     45         low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
     46     rhs_np = np.random.uniform(
     47         low=-1.0, high=1.0, size=shape_[0]).astype(dtype_)
     48     tol = 1e-12 if dtype_ == np.float64 else 1e-6
     49     max_iter = 20
     50     with self.test_session() as sess:
     51       if use_static_shape_:
     52         a = constant_op.constant(a_np)
     53         rhs = constant_op.constant(rhs_np)
     54       else:
     55         a = array_ops.placeholder(dtype_)
     56         rhs = array_ops.placeholder(dtype_)
     57       operator = util.create_operator(a)
     58       cgls_graph = least_squares.cgls(operator, rhs, tol=tol, max_iter=max_iter)
     59       if use_static_shape_:
     60         cgls_val = sess.run(cgls_graph)
     61       else:
     62         cgls_val = sess.run(cgls_graph, feed_dict={a: a_np, rhs: rhs_np})
     63       # Below we use s = A^* (rhs - A x), s0 = A^* rhs
     64       norm_s0 = np.linalg.norm(np.dot(a_np.T, rhs_np))
     65       norm_s = np.sqrt(cgls_val.gamma)
     66       self.assertLessEqual(norm_s, tol * norm_s0)
     67       # Validate that we get an equally small residual norm with numpy
     68       # using the computed solution.
     69       r_np = rhs_np - np.dot(a_np, cgls_val.x)
     70       norm_s_np = np.linalg.norm(np.dot(a_np.T, r_np))
     71       self.assertLessEqual(norm_s_np, tol * norm_s0)
     72 
     73   return [test_cgls]
     74 
     75 
     76 if __name__ == "__main__":
     77   for dtype in np.float32, np.float64:
     78     for shape in [[4, 4], [8, 5], [3, 7]]:
     79       for use_static_shape in True, False:
     80         arg_string = "%s_%s_staticshape_%s" % (dtype.__name__,
     81                                                "_".join(map(str, shape)),
     82                                                use_static_shape)
     83         for test_fn in _get_least_squares_tests(dtype, use_static_shape, shape):
     84           name = "_".join(["LeastSquares", test_fn.__name__, arg_string])
     85           _add_test(LeastSquaresTest, name, test_fn)
     86 
     87   test_lib.main()
     88