Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import numpy as np
     21 
     22 from tensorflow.python.client import session
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import dtypes as dtypes_lib
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import control_flow_ops
     28 from tensorflow.python.ops import gradient_checker
     29 from tensorflow.python.ops import variables
     30 from tensorflow.python.platform import test as test_lib
     31 
     32 
     33 def _AddTest(test, op_name, testcase_name, fn):
     34   test_name = "_".join(["test", op_name, testcase_name])
     35   if hasattr(test, test_name):
     36     raise RuntimeError("Test %s defined more than once" % test_name)
     37   setattr(test, test_name, fn)
     38 
     39 
     40 class MatrixBandPartTest(test_lib.TestCase):
     41   pass  # Filled in below
     42 
     43 
     44 def _GetMatrixBandPartTest(dtype_, batch_shape_, shape_):
     45 
     46   def Test(self):
     47     mat = np.ones(shape_).astype(dtype_)
     48     batch_mat = np.tile(mat, batch_shape_ + (1, 1))
     49     for lower in -1, 0, 1, shape_[-2] - 1:
     50       for upper in -1, 0, 1, shape_[-1] - 1:
     51         band_np = mat
     52         if lower >= 0:
     53           band_np = np.triu(band_np, -lower)
     54         if upper >= 0:
     55           band_np = np.tril(band_np, upper)
     56         if batch_shape_ is not ():
     57           band_np = np.tile(band_np, batch_shape_ + (1, 1))
     58         for index_dtype in [dtypes_lib.int32, dtypes_lib.int64]:
     59           with self.test_session(use_gpu=False):
     60             band = array_ops.matrix_band_part(
     61                 batch_mat,
     62                 constant_op.constant(lower, index_dtype),
     63                 constant_op.constant(upper, index_dtype))
     64             self.assertAllEqual(band_np, band.eval())
     65 
     66   return Test
     67 
     68 
     69 class MatrixBandPartGradTest(test_lib.TestCase):
     70   pass  # Filled in below
     71 
     72 
     73 def _GetMatrixBandPartGradTest(dtype_, batch_shape_, shape_):
     74 
     75   def Test(self):
     76     shape = batch_shape_ + shape_
     77     x = constant_op.constant(np.random.rand(*shape), dtype=dtype_)
     78     with self.test_session(use_gpu=False):
     79       for lower in -1, 0, 1, shape_[-2] - 1:
     80         for upper in -1, 0, 1, shape_[-1] - 1:
     81           y = array_ops.matrix_band_part(x, lower, upper)
     82           error = gradient_checker.compute_gradient_error(
     83               x, x.get_shape().as_list(), y, y.get_shape().as_list())
     84           self.assertLess(error, 1e-4)
     85 
     86   return Test
     87 
     88 
     89 class MatrixBandPartBenchmark(test_lib.Benchmark):
     90 
     91   shapes = [
     92       (10, 16, 16),
     93       (10, 101, 101),
     94       (10, 256, 256),
     95       (10, 1000, 1000),
     96       (10, 1024, 1024),
     97       (10, 2048, 2048),
     98       (10, 10, 4, 4),
     99       (10, 10, 10, 10),
    100       (10, 10, 16, 16),
    101       (10, 10, 101, 101),
    102       (10, 10, 256, 256),
    103       (10, 10, 1000, 1000),
    104       (10, 10, 1024, 1024),
    105       (10, 10, 2048, 2048),
    106   ]
    107 
    108   def benchmarkMatrixBandPartOp(self):
    109     for shape_ in self.shapes:
    110       for limits in (-1, -1), (-1, 0), (0, -1), (2, 2):
    111         with ops.Graph().as_default(), \
    112             session.Session() as sess, \
    113             ops.device("/cpu:0"):
    114           matrix = variables.Variable(array_ops.ones(shape_))
    115           band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
    116           variables.global_variables_initializer().run()
    117           self.run_op_benchmark(
    118               sess,
    119               control_flow_ops.group(band),
    120               min_iters=10,
    121               name="matrix_band_part_cpu_{shape}_{limits}".format(
    122                   shape=shape_, limits=limits))
    123 
    124         if test_lib.is_gpu_available(True):
    125           with ops.Graph().as_default(), \
    126               session.Session() as sess, \
    127               ops.device("/gpu:0"):
    128             matrix = variables.Variable(array_ops.ones(shape_))
    129             band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
    130             variables.global_variables_initializer().run()
    131             self.run_op_benchmark(
    132                 sess,
    133                 control_flow_ops.group(band),
    134                 min_iters=10,
    135                 name="matrix_band_part_gpu_{shape}_{limits}".format(
    136                     shape=shape_, limits=limits))
    137 
    138 
    139 if __name__ == "__main__":
    140   dtypes = (np.bool, np.int32, np.int64, np.float32, np.float64, np.complex64,
    141             np.complex128)
    142   for dtype in dtypes:
    143     for batch_shape in ((), (2,), (1, 3, 2)):
    144       for rows in 1, 2, 7:
    145         for cols in 1, 2, 7:
    146           shape = (rows, cols)
    147           name = "%s_%s" % (dtype.__name__,
    148                             "_".join(map(str, batch_shape + shape)))
    149           _AddTest(MatrixBandPartTest, "MatrixBandPart", name,
    150                    _GetMatrixBandPartTest(dtype, batch_shape, shape))
    151 
    152   for dtype in (np.float32, np.float64):
    153     for batch_shape in ((), (2,)):
    154       for rows in 1, 2, 7:
    155         for cols in 1, 2, 7:
    156           shape = (rows, cols)
    157           name = "%s_%s" % (dtype.__name__,
    158                             "_".join(map(str, batch_shape + shape)))
    159           _AddTest(MatrixBandPartGradTest, "MatrixBandPartGrad", name,
    160                    _GetMatrixBandPartGradTest(dtype, batch_shape, shape))
    161 
    162   test_lib.main()
    163