1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <functional> 17 #include <memory> 18 19 #include "tensorflow/core/common_runtime/device.h" 20 #include "tensorflow/core/common_runtime/device_factory.h" 21 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" 22 #include "tensorflow/core/framework/allocator.h" 23 #include "tensorflow/core/framework/fake_input.h" 24 #include "tensorflow/core/framework/node_def_builder.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/framework/types.pb.h" 29 #include "tensorflow/core/kernels/ops_testutil.h" 30 #include "tensorflow/core/kernels/ops_util.h" 31 #include "tensorflow/core/lib/io/path.h" 32 #include "tensorflow/core/lib/strings/strcat.h" 33 #include "tensorflow/core/platform/test.h" 34 #include "tensorflow/core/platform/test_benchmark.h" 35 36 namespace tensorflow { 37 namespace { 38 39 class ReverseOpTest : public OpsTestBase { 40 protected: 41 void MakeOp(DataType data_type) { 42 TF_ASSERT_OK(NodeDefBuilder("myop", "Reverse") 43 .Input(FakeInput(data_type)) 44 .Input(FakeInput()) 45 .Attr("T", data_type) 46 .Finalize(node_def())); 47 TF_ASSERT_OK(InitOp()); 48 } 49 50 template <typename T> 51 void Reverse_0() { 52 MakeOp(DataTypeToEnum<T>::value); 53 AddInputFromArray<T>(TensorShape({}), {3}); 54 AddInputFromArray<bool>(TensorShape({}), {true}); 55 TF_ASSERT_OK(RunOpKernel()); 56 57 Tensor* output = GetOutput(0); 58 Tensor expected(allocator(), DataTypeToEnum<T>::value, TensorShape({})); 59 expected.scalar<T>() = expected.scalar<T>().constant(3); 60 test::ExpectTensorEqual<T>(expected, *output); 61 } 62 63 template <typename T> 64 void Reverse_234() { 65 MakeOp(DataTypeToEnum<T>::value); 66 // Feed and run 67 // [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] 68 // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]] 69 AddInputFromArray<T>(TensorShape({2, 3, 4}), 70 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 71 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); 72 AddInputFromArray<bool>(TensorShape({3}), {true, false, true}); 73 74 TF_ASSERT_OK(RunOpKernel()); 75 76 // Check the new state of the input 77 Tensor* params_tensor = GetOutput(0); 78 Tensor expected(allocator(), DataTypeToEnum<T>::value, 79 TensorShape({2, 3, 4})); 80 // Should become 81 // [[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] 82 // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]] 83 test::FillValues<T>(&expected, 84 {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 85 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8}); 86 test::ExpectTensorEqual<T>(expected, *params_tensor); 87 } 88 89 template <typename T> 90 void Reverse_1234() { 91 MakeOp(DataTypeToEnum<T>::value); 92 // Feed and run 93 // [[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] 94 // [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]] 95 AddInputFromArray<T>(TensorShape({1, 2, 3, 4}), 96 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 97 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); 98 AddInputFromArray<bool>(TensorShape({4}), {true, true, false, true}); 99 100 TF_ASSERT_OK(RunOpKernel()); 101 102 // Check the new state of the input 103 Tensor* params_tensor = GetOutput(0); 104 Tensor expected(allocator(), DataTypeToEnum<T>::value, 105 TensorShape({1, 2, 3, 4})); 106 // Should become 107 // [[[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]] 108 // [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]]] 109 test::FillValues<T>(&expected, 110 {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 111 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8}); 112 test::ExpectTensorEqual<T>(expected, *params_tensor); 113 } 114 }; 115 116 TEST_F(ReverseOpTest, Reverse_0_uint8) { Reverse_0<uint8>(); } 117 118 TEST_F(ReverseOpTest, Reverse_0_int8) { Reverse_0<int8>(); } 119 120 TEST_F(ReverseOpTest, Reverse_0_uint16) { Reverse_0<uint16>(); } 121 122 TEST_F(ReverseOpTest, Reverse_0_int16) { Reverse_0<int16>(); } 123 124 TEST_F(ReverseOpTest, Reverse_0_float) { Reverse_0<float>(); } 125 126 TEST_F(ReverseOpTest, Reverse_0_int32) { Reverse_0<int32>(); } 127 128 TEST_F(ReverseOpTest, Reverse_0_int64) { Reverse_0<int64>(); } 129 130 TEST_F(ReverseOpTest, Reverse_0_double) { Reverse_0<double>(); } 131 132 TEST_F(ReverseOpTest, Reverse_0_complex64) { Reverse_0<complex64>(); } 133 134 TEST_F(ReverseOpTest, Reverse_0_complex128) { Reverse_0<complex128>(); } 135 136 TEST_F(ReverseOpTest, Reverse_234_uint8) { Reverse_234<uint8>(); } 137 138 TEST_F(ReverseOpTest, Reverse_234_int8) { Reverse_234<int8>(); } 139 140 TEST_F(ReverseOpTest, Reverse_234_uint16) { Reverse_234<uint16>(); } 141 142 TEST_F(ReverseOpTest, Reverse_234_int16) { Reverse_234<int16>(); } 143 144 TEST_F(ReverseOpTest, Reverse_234_float) { Reverse_234<float>(); } 145 146 TEST_F(ReverseOpTest, Reverse_234_int32) { Reverse_234<int32>(); } 147 148 TEST_F(ReverseOpTest, Reverse_234_int64) { Reverse_234<int64>(); } 149 150 TEST_F(ReverseOpTest, Reverse_234_double) { Reverse_234<double>(); } 151 152 TEST_F(ReverseOpTest, Reverse_234_complex64) { Reverse_234<complex64>(); } 153 154 TEST_F(ReverseOpTest, Reverse_234_complex128) { Reverse_234<complex128>(); } 155 156 TEST_F(ReverseOpTest, Reverse_1234_uint8) { Reverse_1234<uint8>(); } 157 158 TEST_F(ReverseOpTest, Reverse_1234_int8) { Reverse_1234<int8>(); } 159 160 TEST_F(ReverseOpTest, Reverse_1234_uint16) { Reverse_1234<uint16>(); } 161 162 TEST_F(ReverseOpTest, Reverse_1234_int16) { Reverse_1234<int16>(); } 163 164 TEST_F(ReverseOpTest, Reverse_1234_float) { Reverse_1234<float>(); } 165 166 TEST_F(ReverseOpTest, Reverse_1234_int32) { Reverse_1234<int32>(); } 167 168 TEST_F(ReverseOpTest, Reverse_1234_int64) { Reverse_1234<int64>(); } 169 170 TEST_F(ReverseOpTest, Reverse_1234_double) { Reverse_1234<double>(); } 171 172 TEST_F(ReverseOpTest, Reverse_1234_complex64) { Reverse_1234<complex64>(); } 173 174 TEST_F(ReverseOpTest, Reverse_1234_complex128) { Reverse_1234<complex128>(); } 175 176 static SessionOptions GetOptions(int intra_threads) { 177 SessionOptions opts; 178 opts.config.set_intra_op_parallelism_threads(intra_threads); 179 opts.config.set_inter_op_parallelism_threads(1); 180 return opts; 181 } 182 183 // Creates a Graph which "reduce"s a 3D float tensor of "num" elements 184 // into a scalar. 185 template <typename T> 186 static Graph* Reverse(const TensorShape& shape, int reverse_axis) { 187 Graph* g = new Graph(OpRegistry::Global()); 188 Tensor data(DataTypeToEnum<T>::value, shape); 189 data.flat<T>().setRandom(); 190 Tensor axes(DT_INT32, TensorShape({1})); 191 axes.flat<int32>()(0) = reverse_axis; 192 test::graph::Reverse(g, test::graph::Constant(g, data), 193 test::graph::Constant(g, axes)); 194 return g; 195 } 196 197 template <typename T> 198 static void RunReverseRowsBenchmark(int iters, int outer_dim, int middle_dim, 199 int intra_threads, int channels) { 200 SessionOptions opts = GetOptions(intra_threads); 201 TensorShape shape{outer_dim, middle_dim, channels}; 202 const int64 num_items = static_cast<int64>(iters) * shape.num_elements(); 203 testing::ItemsProcessed(num_items); 204 testing::BytesProcessed(num_items * sizeof(T)); 205 testing::UseRealTime(); 206 test::Benchmark("cpu", Reverse<T>(shape, 1), &opts).Run(iters); 207 } 208 209 static void BM_ReverseRowsOf1Channel_1T_float(int iters, int outer_dim, 210 int middle_dim) { 211 RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, 212 1 /* intra_threads */, 1 /* channels */); 213 } 214 215 BENCHMARK(BM_ReverseRowsOf1Channel_1T_float) 216 ->ArgPair(288, 288) 217 ->ArgPair(1024, 1024) 218 ->ArgPair(10 * 1024, 1024); 219 220 static void BM_ReverseRowsOf1Channel_1T_uint8(int iters, int outer_dim, 221 int middle_dim) { 222 RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, 223 1 /* intra_threads */, 1 /* channels */); 224 } 225 226 BENCHMARK(BM_ReverseRowsOf1Channel_1T_uint8) 227 ->ArgPair(288, 288) 228 ->ArgPair(1024, 1024) 229 ->ArgPair(10 * 1024, 1024); 230 231 static void BM_ReverseRowsOf1Channel_4T_float(int iters, int outer_dim, 232 int middle_dim) { 233 RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, 234 4 /* intra_threads */, 1 /* channels */); 235 } 236 237 BENCHMARK(BM_ReverseRowsOf1Channel_4T_float) 238 ->ArgPair(288, 288) 239 ->ArgPair(1024, 1024) 240 ->ArgPair(10 * 1024, 1024); 241 242 static void BM_ReverseRowsOf1Channel_4T_uint8(int iters, int outer_dim, 243 int middle_dim) { 244 RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, 245 4 /* intra_threads */, 1 /* channels */); 246 } 247 248 BENCHMARK(BM_ReverseRowsOf1Channel_4T_uint8) 249 ->ArgPair(288, 288) 250 ->ArgPair(1024, 1024) 251 ->ArgPair(10 * 1024, 1024); 252 253 static void BM_ReverseRowsOf3Channels_1T_float(int iters, int outer_dim, 254 int middle_dim) { 255 RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, 256 1 /* intra_threads */, 3 /* channels */); 257 } 258 259 BENCHMARK(BM_ReverseRowsOf3Channels_1T_float) 260 ->ArgPair(288, 288) 261 ->ArgPair(30, 30) 262 ->ArgPair(1024, 1024) 263 ->ArgPair(10 * 1024, 1024); 264 265 static void BM_ReverseRowsOf3Channels_1T_uint8(int iters, int outer_dim, 266 int middle_dim) { 267 RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, 268 1 /* intra_threads */, 3 /* channels */); 269 } 270 271 BENCHMARK(BM_ReverseRowsOf3Channels_1T_uint8) 272 ->ArgPair(288, 288) 273 ->ArgPair(30, 30) 274 ->ArgPair(1024, 1024) 275 ->ArgPair(10 * 1024, 1024); 276 277 static void BM_ReverseRowsOf3Channels_4T_float(int iters, int outer_dim, 278 int middle_dim) { 279 RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, 280 4 /* intra_threads */, 3 /* channels */); 281 } 282 283 BENCHMARK(BM_ReverseRowsOf3Channels_4T_float) 284 ->ArgPair(288, 288) 285 ->ArgPair(30, 30) 286 ->ArgPair(1024, 1024) 287 ->ArgPair(10 * 1024, 1024); 288 289 static void BM_ReverseRowsOf3Channels_4T_uint8(int iters, int outer_dim, 290 int middle_dim) { 291 RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, 292 4 /* intra_threads */, 3 /* channels */); 293 } 294 BENCHMARK(BM_ReverseRowsOf3Channels_4T_uint8) 295 ->ArgPair(288, 288) 296 ->ArgPair(30, 30) 297 ->ArgPair(1024, 1024) 298 ->ArgPair(10 * 1024, 1024); 299 300 static void BM_ReverseRowsOf4Channels_1T_float(int iters, int outer_dim, 301 int middle_dim) { 302 RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, 303 1 /* intra_threads */, 4 /* channels */); 304 } 305 306 BENCHMARK(BM_ReverseRowsOf4Channels_1T_float) 307 ->ArgPair(288, 288) 308 ->ArgPair(1024, 1024) 309 ->ArgPair(10 * 1024, 1024); 310 311 static void BM_ReverseRowsOf4Channels_1T_uint8(int iters, int outer_dim, 312 int middle_dim) { 313 RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, 314 1 /* intra_threads */, 4 /* channels */); 315 } 316 317 BENCHMARK(BM_ReverseRowsOf4Channels_1T_uint8) 318 ->ArgPair(288, 288) 319 ->ArgPair(1024, 1024) 320 ->ArgPair(10 * 1024, 1024); 321 322 static void BM_ReverseRowsOf4Channels_4T_float(int iters, int outer_dim, 323 int middle_dim) { 324 RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim, 325 4 /* intra_threads */, 4 /* channels */); 326 } 327 328 BENCHMARK(BM_ReverseRowsOf4Channels_4T_float) 329 ->ArgPair(288, 288) 330 ->ArgPair(1024, 1024) 331 ->ArgPair(10 * 1024, 1024); 332 333 static void BM_ReverseRowsOf4Channels_4T_uint8(int iters, int outer_dim, 334 int middle_dim) { 335 RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim, 336 4 /* intra_threads */, 4 /* channels */); 337 } 338 339 BENCHMARK(BM_ReverseRowsOf4Channels_4T_uint8) 340 ->ArgPair(288, 288) 341 ->ArgPair(1024, 1024) 342 ->ArgPair(10 * 1024, 1024); 343 344 } // namespace 345 } // namespace tensorflow 346