# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests and benchmarks for array_ops.matrix_band_part."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import variables
from tensorflow.python.platform import test as test_lib


def _AddTest(test, op_name, testcase_name, fn):
  """Attaches `fn` to the class `test` as method `test_<op_name>_<case>`.

  Args:
    test: The test class to extend.
    op_name: Op name used as the middle component of the method name.
    testcase_name: Case-specific suffix of the method name.
    fn: The function to attach as a test method.

  Raises:
    RuntimeError: If a method with the generated name already exists.
  """
  test_name = "_".join(["test", op_name, testcase_name])
  if hasattr(test, test_name):
    raise RuntimeError("Test %s defined more than once" % test_name)
  setattr(test, test_name, fn)


class MatrixBandPartTest(test_lib.TestCase):
  """Forward-value tests; methods are attached by the `__main__` driver."""
  pass  # Filled in below


def _GetMatrixBandPartTest(dtype_, batch_shape_, shape_):
  """Returns a test comparing matrix_band_part with a NumPy reference.

  Args:
    dtype_: NumPy dtype of the all-ones input matrices.
    batch_shape_: Tuple of leading batch dimensions (may be empty).
    shape_: `(rows, cols)` of each inner matrix.
  """

  def Test(self):
    mat = np.ones(shape_).astype(dtype_)
    batch_mat = np.tile(mat, batch_shape_ + (1, 1))
    for lower in -1, 0, 1, shape_[-2] - 1:
      for upper in -1, 0, 1, shape_[-1] - 1:
        # NumPy reference: a negative bound leaves that side unrestricted;
        # otherwise keep only the diagonals in [-lower, upper].
        band_np = mat
        if lower >= 0:
          band_np = np.triu(band_np, -lower)
        if upper >= 0:
          band_np = np.tril(band_np, upper)
        # Truthiness check instead of `is not ()`: identity comparison with a
        # literal depends on CPython interning and is a SyntaxWarning (3.8+).
        if batch_shape_:
          band_np = np.tile(band_np, batch_shape_ + (1, 1))
        # Exercise both int32 and int64 band-limit tensors.
        for index_dtype in [dtypes_lib.int32, dtypes_lib.int64]:
          with self.test_session(use_gpu=False):
            band = array_ops.matrix_band_part(
                batch_mat,
                constant_op.constant(lower, index_dtype),
                constant_op.constant(upper, index_dtype))
            self.assertAllEqual(band_np, band.eval())

  return Test


class MatrixBandPartGradTest(test_lib.TestCase):
  """Gradient tests; methods are attached by the `__main__` driver."""
  pass  # Filled in below


def _GetMatrixBandPartGradTest(dtype_, batch_shape_, shape_):
  """Returns a test checking matrix_band_part gradients numerically.

  Args:
    dtype_: Floating-point dtype of the random input.
    batch_shape_: Tuple of leading batch dimensions (may be empty).
    shape_: `(rows, cols)` of each inner matrix.
  """

  def Test(self):
    shape = batch_shape_ + shape_
    x = constant_op.constant(np.random.rand(*shape), dtype=dtype_)
    with self.test_session(use_gpu=False):
      for lower in -1, 0, 1, shape_[-2] - 1:
        for upper in -1, 0, 1, shape_[-1] - 1:
          y = array_ops.matrix_band_part(x, lower, upper)
          # Max abs difference between the theoretical and numeric Jacobians.
          error = gradient_checker.compute_gradient_error(
              x, x.get_shape().as_list(), y, y.get_shape().as_list())
          self.assertLess(error, 1e-4)

  return Test


class MatrixBandPartBenchmark(test_lib.Benchmark):
  """Benchmarks matrix_band_part on CPU and, if a CUDA GPU exists, on GPU."""

  shapes = [
      (10, 16, 16),
      (10, 101, 101),
      (10, 256, 256),
      (10, 1000, 1000),
      (10, 1024, 1024),
      (10, 2048, 2048),
      (10, 10, 4, 4),
      (10, 10, 10, 10),
      (10, 10, 16, 16),
      (10, 10, 101, 101),
      (10, 10, 256, 256),
      (10, 10, 1000, 1000),
      (10, 10, 1024, 1024),
      (10, 10, 2048, 2048),
  ]

  def _BenchmarkOnDevice(self, device_name, device, shape_, limits):
    """Runs and reports one matrix_band_part benchmark on `device`.

    Args:
      device_name: Short label ("cpu"/"gpu") embedded in the benchmark name.
      device: TensorFlow device string to pin the ops to.
      shape_: Shape of the all-ones input variable.
      limits: `(num_lower, num_upper)` passed to matrix_band_part.
    """
    with ops.Graph().as_default(), \
        session.Session() as sess, \
        ops.device(device):
      matrix = variables.Variable(array_ops.ones(shape_))
      band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
      variables.global_variables_initializer().run()
      self.run_op_benchmark(
          sess,
          control_flow_ops.group(band),
          min_iters=10,
          name="matrix_band_part_{device}_{shape}_{limits}".format(
              device=device_name, shape=shape_, limits=limits))

  def benchmarkMatrixBandPartOp(self):
    # GPU availability does not change between iterations; check it once.
    devices = [("cpu", "/cpu:0")]
    if test_lib.is_gpu_available(True):
      devices.append(("gpu", "/gpu:0"))
    for shape_ in self.shapes:
      for limits in (-1, -1), (-1, 0), (0, -1), (2, 2):
        for device_name, device in devices:
          self._BenchmarkOnDevice(device_name, device, shape_, limits)


if __name__ == "__main__":
  # Test methods are generated and attached only when this file runs as a
  # script, before control passes to test_lib.main().
  # Builtin `bool` replaces the `np.bool` alias (removed in NumPy 1.24); the
  # alias was this same object, so generated test names are unchanged.
  dtypes = (bool, np.int32, np.int64, np.float32, np.float64, np.complex64,
            np.complex128)
  for dtype in dtypes:
    for batch_shape in ((), (2,), (1, 3, 2)):
      for rows in 1, 2, 7:
        for cols in 1, 2, 7:
          shape = (rows, cols)
          name = "%s_%s" % (dtype.__name__,
                            "_".join(map(str, batch_shape + shape)))
          _AddTest(MatrixBandPartTest, "MatrixBandPart", name,
                   _GetMatrixBandPartTest(dtype, batch_shape, shape))

  for dtype in (np.float32, np.float64):
    for batch_shape in ((), (2,)):
      for rows in 1, 2, 7:
        for cols in 1, 2, 7:
          shape = (rows, cols)
          name = "%s_%s" % (dtype.__name__,
                            "_".join(map(str, batch_shape + shape)))
          _AddTest(MatrixBandPartGradTest, "MatrixBandPartGrad", name,
                   _GetMatrixBandPartGradTest(dtype, batch_shape, shape))

  test_lib.main()