Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.ops.gen_linalg_ops.matrix_exponential."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import itertools
     22 import math
     23 
     24 import numpy as np
     25 
     26 from tensorflow.python.client import session
     27 from tensorflow.python.framework import constant_op
     28 from tensorflow.python.framework import ops
     29 from tensorflow.python.ops import control_flow_ops
     30 from tensorflow.python.ops import gen_linalg_ops
     31 from tensorflow.python.ops import random_ops
     32 from tensorflow.python.ops import variables
     33 from tensorflow.python.platform import test
     34 
     35 
     36 def np_expm(x):
     37   """Slow but accurate Taylor series matrix exponential."""
     38   y = np.zeros(x.shape, dtype=x.dtype)
     39   xn = np.eye(x.shape[0], dtype=x.dtype)
     40   for n in range(40):
     41     y += xn / float(math.factorial(n))
     42     xn = np.dot(xn, x)
     43   return y
     44 
     45 
     46 class ExponentialOpTest(test.TestCase):
     47 
     48   def _verifyExponential(self, x, np_type):
     49     inp = x.astype(np_type)
     50     with self.test_session(use_gpu=True):
     51       tf_ans = gen_linalg_ops._matrix_exponential(inp)
     52       if x.size == 0:
     53         np_ans = np.empty(x.shape, dtype=np_type)
     54       else:
     55         if x.ndim > 2:
     56           np_ans = np.zeros(inp.shape, dtype=np_type)
     57           for i in itertools.product(*[range(x) for x in inp.shape[:-2]]):
     58             np_ans[i] = np_expm(inp[i])
     59         else:
     60           np_ans = np_expm(inp)
     61       out = tf_ans.eval()
     62       self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-3)
     63 
     64   def _verifyExponentialReal(self, x):
     65     for np_type in [np.float32, np.float64]:
     66       self._verifyExponential(x, np_type)
     67 
     68   def _verifyExponentialComplex(self, x):
     69     for np_type in [np.complex64, np.complex128]:
     70       self._verifyExponential(x, np_type)
     71 
     72   def _makeBatch(self, matrix1, matrix2):
     73     matrix_batch = np.concatenate(
     74         [np.expand_dims(matrix1, 0),
     75          np.expand_dims(matrix2, 0)])
     76     matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
     77     return matrix_batch
     78 
     79   def testNonsymmetric(self):
     80     # 2x2 matrices
     81     matrix1 = np.array([[1., 2.], [3., 4.]])
     82     matrix2 = np.array([[1., 3.], [3., 5.]])
     83     self._verifyExponentialReal(matrix1)
     84     self._verifyExponentialReal(matrix2)
     85     # A multidimensional batch of 2x2 matrices
     86     self._verifyExponentialReal(self._makeBatch(matrix1, matrix2))
     87     # Complex
     88     matrix1 = matrix1.astype(np.complex64)
     89     matrix1 += 1j * matrix1
     90     matrix2 = matrix2.astype(np.complex64)
     91     matrix2 += 1j * matrix2
     92     self._verifyExponentialComplex(matrix1)
     93     self._verifyExponentialComplex(matrix2)
     94     # Complex batch
     95     self._verifyExponentialComplex(self._makeBatch(matrix1, matrix2))
     96 
     97   def testSymmetricPositiveDefinite(self):
     98     # 2x2 matrices
     99     matrix1 = np.array([[2., 1.], [1., 2.]])
    100     matrix2 = np.array([[3., -1.], [-1., 3.]])
    101     self._verifyExponentialReal(matrix1)
    102     self._verifyExponentialReal(matrix2)
    103     # A multidimensional batch of 2x2 matrices
    104     self._verifyExponentialReal(self._makeBatch(matrix1, matrix2))
    105     # Complex
    106     matrix1 = matrix1.astype(np.complex64)
    107     matrix1 += 1j * matrix1
    108     matrix2 = matrix2.astype(np.complex64)
    109     matrix2 += 1j * matrix2
    110     self._verifyExponentialComplex(matrix1)
    111     self._verifyExponentialComplex(matrix2)
    112     # Complex batch
    113     self._verifyExponentialComplex(self._makeBatch(matrix1, matrix2))
    114 
    115   def testNonSquareMatrix(self):
    116     # When the exponential of a non-square matrix is attempted we should return
    117     # an error
    118     with self.assertRaises(ValueError):
    119       gen_linalg_ops._matrix_exponential(np.array([[1., 2., 3.], [3., 4., 5.]]))
    120 
    121   def testWrongDimensions(self):
    122     # The input to the exponential should be at least a 2-dimensional tensor.
    123     tensor3 = constant_op.constant([1., 2.])
    124     with self.assertRaises(ValueError):
    125       gen_linalg_ops._matrix_exponential(tensor3)
    126 
    127   def testEmpty(self):
    128     self._verifyExponentialReal(np.empty([0, 2, 2]))
    129     self._verifyExponentialReal(np.empty([2, 0, 0]))
    130 
    131   def testRandomSmallAndLarge(self):
    132     np.random.seed(42)
    133     for dtype in np.float32, np.float64, np.complex64, np.complex128:
    134       for batch_dims in [(), (1,), (3,), (2, 2)]:
    135         for size in 8, 31, 32:
    136           shape = batch_dims + (size, size)
    137           matrix = np.random.uniform(
    138               low=-1.0, high=1.0,
    139               size=np.prod(shape)).reshape(shape).astype(dtype)
    140           self._verifyExponentialReal(matrix)
    141 
    142   def testConcurrentExecutesWithoutError(self):
    143     with self.test_session(use_gpu=True) as sess:
    144       matrix1 = random_ops.random_normal([5, 5], seed=42)
    145       matrix2 = random_ops.random_normal([5, 5], seed=42)
    146       expm1 = gen_linalg_ops._matrix_exponential(matrix1)
    147       expm2 = gen_linalg_ops._matrix_exponential(matrix2)
    148       expm = sess.run([expm1, expm2])
    149       self.assertAllEqual(expm[0], expm[1])
    150 
    151 
    152 class MatrixExponentialBenchmark(test.Benchmark):
    153 
    154   shapes = [
    155       (4, 4),
    156       (10, 10),
    157       (16, 16),
    158       (101, 101),
    159       (256, 256),
    160       (1000, 1000),
    161       (1024, 1024),
    162       (2048, 2048),
    163       (513, 4, 4),
    164       (513, 16, 16),
    165       (513, 256, 256),
    166   ]
    167 
    168   def _GenerateMatrix(self, shape):
    169     batch_shape = shape[:-2]
    170     shape = shape[-2:]
    171     assert shape[0] == shape[1]
    172     n = shape[0]
    173     matrix = np.ones(shape).astype(np.float32) / (
    174         2.0 * n) + np.diag(np.ones(n).astype(np.float32))
    175     return variables.Variable(np.tile(matrix, batch_shape + (1, 1)))
    176 
    177   def benchmarkMatrixExponentialOp(self):
    178     for shape in self.shapes:
    179       with ops.Graph().as_default(), \
    180           session.Session() as sess, \
    181           ops.device("/cpu:0"):
    182         matrix = self._GenerateMatrix(shape)
    183         expm = gen_linalg_ops._matrix_exponential(matrix)
    184         variables.global_variables_initializer().run()
    185         self.run_op_benchmark(
    186             sess,
    187             control_flow_ops.group(expm),
    188             min_iters=25,
    189             name="matrix_exponential_cpu_{shape}".format(
    190                 shape=shape))
    191 
    192 
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner (tensorflow.python.platform.test)
  # when this file is executed directly.
  test.main()
    195