Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <functional>
     17 #include <memory>
     18 
     19 #include "tensorflow/core/common_runtime/device.h"
     20 #include "tensorflow/core/common_runtime/device_factory.h"
     21 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
     22 #include "tensorflow/core/framework/allocator.h"
     23 #include "tensorflow/core/framework/fake_input.h"
     24 #include "tensorflow/core/framework/node_def_builder.h"
     25 #include "tensorflow/core/framework/op_kernel.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/framework/types.h"
     28 #include "tensorflow/core/framework/types.pb.h"
     29 #include "tensorflow/core/kernels/ops_testutil.h"
     30 #include "tensorflow/core/kernels/ops_util.h"
     31 #include "tensorflow/core/lib/io/path.h"
     32 #include "tensorflow/core/lib/strings/strcat.h"
     33 #include "tensorflow/core/platform/test.h"
     34 #include "tensorflow/core/platform/test_benchmark.h"
     35 
     36 namespace tensorflow {
     37 namespace {
     38 
     39 class ReverseOpTest : public OpsTestBase {
     40  protected:
     41   void MakeOp(DataType data_type) {
     42     TF_ASSERT_OK(NodeDefBuilder("myop", "Reverse")
     43                      .Input(FakeInput(data_type))
     44                      .Input(FakeInput())
     45                      .Attr("T", data_type)
     46                      .Finalize(node_def()));
     47     TF_ASSERT_OK(InitOp());
     48   }
     49 
     50   template <typename T>
     51   void Reverse_0() {
     52     MakeOp(DataTypeToEnum<T>::value);
     53     AddInputFromArray<T>(TensorShape({}), {3});
     54     AddInputFromArray<bool>(TensorShape({}), {true});
     55     TF_ASSERT_OK(RunOpKernel());
     56 
     57     Tensor* output = GetOutput(0);
     58     Tensor expected(allocator(), DataTypeToEnum<T>::value, TensorShape({}));
     59     expected.scalar<T>() = expected.scalar<T>().constant(3);
     60     test::ExpectTensorEqual<T>(expected, *output);
     61   }
     62 
     63   template <typename T>
     64   void Reverse_234() {
     65     MakeOp(DataTypeToEnum<T>::value);
     66     // Feed and run
     67     // [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
     68     //  [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]
     69     AddInputFromArray<T>(TensorShape({2, 3, 4}),
     70                          {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
     71                           12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
     72     AddInputFromArray<bool>(TensorShape({3}), {true, false, true});
     73 
     74     TF_ASSERT_OK(RunOpKernel());
     75 
     76     // Check the new state of the input
     77     Tensor* params_tensor = GetOutput(0);
     78     Tensor expected(allocator(), DataTypeToEnum<T>::value,
     79                     TensorShape({2, 3, 4}));
     80     // Should become
     81     // [[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]]
     82     //  [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]]
     83     test::FillValues<T>(&expected,
     84                         {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20,
     85                          3,  2,  1,  0,  7,  6,  5,  4,  11, 10, 9,  8});
     86     test::ExpectTensorEqual<T>(expected, *params_tensor);
     87   }
     88 
     89   template <typename T>
     90   void Reverse_1234() {
     91     MakeOp(DataTypeToEnum<T>::value);
     92     // Feed and run
     93     // [[[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
     94     //   [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]]
     95     AddInputFromArray<T>(TensorShape({1, 2, 3, 4}),
     96                          {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
     97                           12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
     98     AddInputFromArray<bool>(TensorShape({4}), {true, true, false, true});
     99 
    100     TF_ASSERT_OK(RunOpKernel());
    101 
    102     // Check the new state of the input
    103     Tensor* params_tensor = GetOutput(0);
    104     Tensor expected(allocator(), DataTypeToEnum<T>::value,
    105                     TensorShape({1, 2, 3, 4}));
    106     // Should become
    107     // [[[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]]
    108     //   [[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]]]]
    109     test::FillValues<T>(&expected,
    110                         {15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20,
    111                          3,  2,  1,  0,  7,  6,  5,  4,  11, 10, 9,  8});
    112     test::ExpectTensorEqual<T>(expected, *params_tensor);
    113   }
    114 };
    115 
    116 TEST_F(ReverseOpTest, Reverse_0_uint8) { Reverse_0<uint8>(); }
    117 
    118 TEST_F(ReverseOpTest, Reverse_0_int8) { Reverse_0<int8>(); }
    119 
    120 TEST_F(ReverseOpTest, Reverse_0_uint16) { Reverse_0<uint16>(); }
    121 
    122 TEST_F(ReverseOpTest, Reverse_0_int16) { Reverse_0<int16>(); }
    123 
    124 TEST_F(ReverseOpTest, Reverse_0_float) { Reverse_0<float>(); }
    125 
    126 TEST_F(ReverseOpTest, Reverse_0_int32) { Reverse_0<int32>(); }
    127 
    128 TEST_F(ReverseOpTest, Reverse_0_int64) { Reverse_0<int64>(); }
    129 
    130 TEST_F(ReverseOpTest, Reverse_0_double) { Reverse_0<double>(); }
    131 
    132 TEST_F(ReverseOpTest, Reverse_0_complex64) { Reverse_0<complex64>(); }
    133 
    134 TEST_F(ReverseOpTest, Reverse_0_complex128) { Reverse_0<complex128>(); }
    135 
    136 TEST_F(ReverseOpTest, Reverse_234_uint8) { Reverse_234<uint8>(); }
    137 
    138 TEST_F(ReverseOpTest, Reverse_234_int8) { Reverse_234<int8>(); }
    139 
    140 TEST_F(ReverseOpTest, Reverse_234_uint16) { Reverse_234<uint16>(); }
    141 
    142 TEST_F(ReverseOpTest, Reverse_234_int16) { Reverse_234<int16>(); }
    143 
    144 TEST_F(ReverseOpTest, Reverse_234_float) { Reverse_234<float>(); }
    145 
    146 TEST_F(ReverseOpTest, Reverse_234_int32) { Reverse_234<int32>(); }
    147 
    148 TEST_F(ReverseOpTest, Reverse_234_int64) { Reverse_234<int64>(); }
    149 
    150 TEST_F(ReverseOpTest, Reverse_234_double) { Reverse_234<double>(); }
    151 
    152 TEST_F(ReverseOpTest, Reverse_234_complex64) { Reverse_234<complex64>(); }
    153 
    154 TEST_F(ReverseOpTest, Reverse_234_complex128) { Reverse_234<complex128>(); }
    155 
    156 TEST_F(ReverseOpTest, Reverse_1234_uint8) { Reverse_1234<uint8>(); }
    157 
    158 TEST_F(ReverseOpTest, Reverse_1234_int8) { Reverse_1234<int8>(); }
    159 
    160 TEST_F(ReverseOpTest, Reverse_1234_uint16) { Reverse_1234<uint16>(); }
    161 
    162 TEST_F(ReverseOpTest, Reverse_1234_int16) { Reverse_1234<int16>(); }
    163 
    164 TEST_F(ReverseOpTest, Reverse_1234_float) { Reverse_1234<float>(); }
    165 
    166 TEST_F(ReverseOpTest, Reverse_1234_int32) { Reverse_1234<int32>(); }
    167 
    168 TEST_F(ReverseOpTest, Reverse_1234_int64) { Reverse_1234<int64>(); }
    169 
    170 TEST_F(ReverseOpTest, Reverse_1234_double) { Reverse_1234<double>(); }
    171 
    172 TEST_F(ReverseOpTest, Reverse_1234_complex64) { Reverse_1234<complex64>(); }
    173 
    174 TEST_F(ReverseOpTest, Reverse_1234_complex128) { Reverse_1234<complex128>(); }
    175 
    176 static SessionOptions GetOptions(int intra_threads) {
    177   SessionOptions opts;
    178   opts.config.set_intra_op_parallelism_threads(intra_threads);
    179   opts.config.set_inter_op_parallelism_threads(1);
    180   return opts;
    181 }
    182 
    183 // Creates a Graph which "reduce"s a 3D float tensor of "num" elements
    184 // into a scalar.
    185 template <typename T>
    186 static Graph* Reverse(const TensorShape& shape, int reverse_axis) {
    187   Graph* g = new Graph(OpRegistry::Global());
    188   Tensor data(DataTypeToEnum<T>::value, shape);
    189   data.flat<T>().setRandom();
    190   Tensor axes(DT_INT32, TensorShape({1}));
    191   axes.flat<int32>()(0) = reverse_axis;
    192   test::graph::Reverse(g, test::graph::Constant(g, data),
    193                        test::graph::Constant(g, axes));
    194   return g;
    195 }
    196 
    197 template <typename T>
    198 static void RunReverseRowsBenchmark(int iters, int outer_dim, int middle_dim,
    199                                     int intra_threads, int channels) {
    200   SessionOptions opts = GetOptions(intra_threads);
    201   TensorShape shape{outer_dim, middle_dim, channels};
    202   const int64 num_items = static_cast<int64>(iters) * shape.num_elements();
    203   testing::ItemsProcessed(num_items);
    204   testing::BytesProcessed(num_items * sizeof(T));
    205   testing::UseRealTime();
    206   test::Benchmark("cpu", Reverse<T>(shape, 1), &opts).Run(iters);
    207 }
    208 
    209 static void BM_ReverseRowsOf1Channel_1T_float(int iters, int outer_dim,
    210                                               int middle_dim) {
    211   RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
    212                                  1 /* intra_threads */, 1 /* channels */);
    213 }
    214 
    215 BENCHMARK(BM_ReverseRowsOf1Channel_1T_float)
    216     ->ArgPair(288, 288)
    217     ->ArgPair(1024, 1024)
    218     ->ArgPair(10 * 1024, 1024);
    219 
    220 static void BM_ReverseRowsOf1Channel_1T_uint8(int iters, int outer_dim,
    221                                               int middle_dim) {
    222   RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
    223                                  1 /* intra_threads */, 1 /* channels */);
    224 }
    225 
    226 BENCHMARK(BM_ReverseRowsOf1Channel_1T_uint8)
    227     ->ArgPair(288, 288)
    228     ->ArgPair(1024, 1024)
    229     ->ArgPair(10 * 1024, 1024);
    230 
    231 static void BM_ReverseRowsOf1Channel_4T_float(int iters, int outer_dim,
    232                                               int middle_dim) {
    233   RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
    234                                  4 /* intra_threads */, 1 /* channels */);
    235 }
    236 
    237 BENCHMARK(BM_ReverseRowsOf1Channel_4T_float)
    238     ->ArgPair(288, 288)
    239     ->ArgPair(1024, 1024)
    240     ->ArgPair(10 * 1024, 1024);
    241 
    242 static void BM_ReverseRowsOf1Channel_4T_uint8(int iters, int outer_dim,
    243                                               int middle_dim) {
    244   RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
    245                                  4 /* intra_threads */, 1 /* channels */);
    246 }
    247 
    248 BENCHMARK(BM_ReverseRowsOf1Channel_4T_uint8)
    249     ->ArgPair(288, 288)
    250     ->ArgPair(1024, 1024)
    251     ->ArgPair(10 * 1024, 1024);
    252 
    253 static void BM_ReverseRowsOf3Channels_1T_float(int iters, int outer_dim,
    254                                                int middle_dim) {
    255   RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
    256                                  1 /* intra_threads */, 3 /* channels */);
    257 }
    258 
    259 BENCHMARK(BM_ReverseRowsOf3Channels_1T_float)
    260     ->ArgPair(288, 288)
    261     ->ArgPair(30, 30)
    262     ->ArgPair(1024, 1024)
    263     ->ArgPair(10 * 1024, 1024);
    264 
    265 static void BM_ReverseRowsOf3Channels_1T_uint8(int iters, int outer_dim,
    266                                                int middle_dim) {
    267   RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
    268                                  1 /* intra_threads */, 3 /* channels */);
    269 }
    270 
    271 BENCHMARK(BM_ReverseRowsOf3Channels_1T_uint8)
    272     ->ArgPair(288, 288)
    273     ->ArgPair(30, 30)
    274     ->ArgPair(1024, 1024)
    275     ->ArgPair(10 * 1024, 1024);
    276 
    277 static void BM_ReverseRowsOf3Channels_4T_float(int iters, int outer_dim,
    278                                                int middle_dim) {
    279   RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
    280                                  4 /* intra_threads */, 3 /* channels */);
    281 }
    282 
    283 BENCHMARK(BM_ReverseRowsOf3Channels_4T_float)
    284     ->ArgPair(288, 288)
    285     ->ArgPair(30, 30)
    286     ->ArgPair(1024, 1024)
    287     ->ArgPair(10 * 1024, 1024);
    288 
    289 static void BM_ReverseRowsOf3Channels_4T_uint8(int iters, int outer_dim,
    290                                                int middle_dim) {
    291   RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
    292                                  4 /* intra_threads */, 3 /* channels */);
    293 }
    294 BENCHMARK(BM_ReverseRowsOf3Channels_4T_uint8)
    295     ->ArgPair(288, 288)
    296     ->ArgPair(30, 30)
    297     ->ArgPair(1024, 1024)
    298     ->ArgPair(10 * 1024, 1024);
    299 
    300 static void BM_ReverseRowsOf4Channels_1T_float(int iters, int outer_dim,
    301                                                int middle_dim) {
    302   RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
    303                                  1 /* intra_threads */, 4 /* channels */);
    304 }
    305 
    306 BENCHMARK(BM_ReverseRowsOf4Channels_1T_float)
    307     ->ArgPair(288, 288)
    308     ->ArgPair(1024, 1024)
    309     ->ArgPair(10 * 1024, 1024);
    310 
    311 static void BM_ReverseRowsOf4Channels_1T_uint8(int iters, int outer_dim,
    312                                                int middle_dim) {
    313   RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
    314                                  1 /* intra_threads */, 4 /* channels */);
    315 }
    316 
    317 BENCHMARK(BM_ReverseRowsOf4Channels_1T_uint8)
    318     ->ArgPair(288, 288)
    319     ->ArgPair(1024, 1024)
    320     ->ArgPair(10 * 1024, 1024);
    321 
    322 static void BM_ReverseRowsOf4Channels_4T_float(int iters, int outer_dim,
    323                                                int middle_dim) {
    324   RunReverseRowsBenchmark<float>(iters, outer_dim, middle_dim,
    325                                  4 /* intra_threads */, 4 /* channels */);
    326 }
    327 
    328 BENCHMARK(BM_ReverseRowsOf4Channels_4T_float)
    329     ->ArgPair(288, 288)
    330     ->ArgPair(1024, 1024)
    331     ->ArgPair(10 * 1024, 1024);
    332 
    333 static void BM_ReverseRowsOf4Channels_4T_uint8(int iters, int outer_dim,
    334                                                int middle_dim) {
    335   RunReverseRowsBenchmark<uint8>(iters, outer_dim, middle_dim,
    336                                  4 /* intra_threads */, 4 /* channels */);
    337 }
    338 
    339 BENCHMARK(BM_ReverseRowsOf4Channels_4T_uint8)
    340     ->ArgPair(288, 288)
    341     ->ArgPair(1024, 1024)
    342     ->ArgPair(10 * 1024, 1024);
    343 
    344 }  // namespace
    345 }  // namespace tensorflow
    346