/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

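// Test fixture for the GatherV2 kernel. MakeOp wires up a "GatherV2" node
// whose three inputs are params (data_type), indices (index_type), and the
// axis scalar (index_type).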
class GatherOpTest : public OpsTestBase {
 protected:
  void MakeOp(DataType data_type, DataType index_type) {
    TF_ASSERT_OK(NodeDefBuilder("myop", "GatherV2")
                     .Input(FakeInput(data_type))
                     .Input(FakeInput(index_type))
                     .Input(FakeInput(index_type))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

TEST_F(GatherOpTest, ScalarIndices) {
  MakeOp(DT_FLOAT, DT_INT32);

  // Feed and run
  AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 3, 4});
  AddInputFromArray<int32>(TensorShape({}), {3});
  AddInputFromArray<int32>(TensorShape({}), {0});
  TF_ASSERT_OK(RunOpKernel());

  // Check the output.
  Tensor expected(allocator(), DT_FLOAT, TensorShape({}));
  test::FillValues<float>(&expected, {3});
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

TEST_F(GatherOpTest, ScalarIndices_Complex) {
  MakeOp(DT_COMPLEX64, DT_INT32);

  // Feed and run
  AddInputFromArray<std::complex<float>>(
      TensorShape({5}), {std::complex<float>(0, 10), std::complex<float>(1, 11),
                         std::complex<float>(2, 12), std::complex<float>(3, 13),
                         std::complex<float>(4, 14)});
  AddInputFromArray<int32>(TensorShape({}), {3});
  AddInputFromArray<int32>(TensorShape({}), {0});
  TF_ASSERT_OK(RunOpKernel());

  // Check the output.
  Tensor expected(allocator(), DT_COMPLEX64, TensorShape({}));
  test::FillValues<std::complex<float>>(&expected,
                                        {std::complex<float>(3, 13)});
  test::ExpectTensorEqual<std::complex<float>>(expected, *GetOutput(0));
}

TEST_F(GatherOpTest, Simple_TwoD32_Axis0) {
  MakeOp(DT_FLOAT, DT_INT32);

  // Feed and run
  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
  AddInputFromArray<int32>(TensorShape({4}), {0, 4, 0, 2});
  AddInputFromArray<int32>(TensorShape({}), {0});
  TF_ASSERT_OK(RunOpKernel());

  // Check the output.
  Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 3}));
  test::FillValues<float>(&expected, {0, 1, 2, 12, 13, 14, 0, 1, 2, 6, 7, 8});
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

TEST_F(GatherOpTest, Simple_TwoD32_Axis1) {
  MakeOp(DT_FLOAT, DT_INT32);

  // Feed and run
  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
  AddInputFromArray<int32>(TensorShape({4}), {0, 1, 0, 2});
  AddInputFromArray<int32>(TensorShape({}), {1});
  TF_ASSERT_OK(RunOpKernel());

  // Check the output.
  Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 4}));
  test::FillValues<float>(&expected, {0, 1, 0, 2,  3, 4,  3,  5,  6,  7,
                                      6, 8, 9, 10, 9, 11, 12, 13, 12, 14});
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

TEST_F(GatherOpTest, ZeroSize_TwoD32) {
  MakeOp(DT_FLOAT, DT_INT32);

  // Feed and run
  AddInputFromArray<float>(TensorShape({5, 0}), {});
  AddInputFromArray<int32>(TensorShape({4}), {0, 4, 0, 2});
  AddInputFromArray<int32>(TensorShape({}), {0});
  TF_ASSERT_OK(RunOpKernel());

  // Check the output.
  Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 0}));
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

TEST_F(GatherOpTest, Simple_TwoD64) {
  MakeOp(DT_FLOAT, DT_INT64);

  // Feed and run
  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
  AddInputFromArray<int64>(TensorShape({4}), {0, 4, 0, 2});
  AddInputFromArray<int64>(TensorShape({}), {0});
  TF_ASSERT_OK(RunOpKernel());

  // Check the output.
  Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 3}));
  test::FillValues<float>(&expected, {0, 1, 2, 12, 13, 14, 0, 1, 2, 6, 7, 8});
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

TEST_F(GatherOpTest, HighRank) {
  MakeOp(DT_FLOAT, DT_INT32);

  // Feed and run
  AddInputFromArray<float>(TensorShape({4}), {0, 1, 2, 3});
  AddInputFromArray<int32>(TensorShape({2, 3}), {1, 2, 0, 2, 3, 0});
  AddInputFromArray<int32>(TensorShape({}), {0});
  TF_ASSERT_OK(RunOpKernel());

  // Check the output.
  Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3}));
  test::FillValues<float>(&expected, {1, 2, 0, 2, 3, 0});
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

TEST_F(GatherOpTest, Error_IndexOutOfRange) {
  MakeOp(DT_FLOAT, DT_INT32);

  // Feed and run
  AddInputFromArray<float>(TensorShape({5, 3}),
                           {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
  AddInputFromArray<int32>(TensorShape({4}), {0, 4, 99, 2});
  AddInputFromArray<int32>(TensorShape({}), {0});
  Status s = RunOpKernel();
  EXPECT_TRUE(
      StringPiece(s.ToString()).contains("indices[2] = 99 is not in [0, 5)"))
      << s;
}

constexpr int kLookups = 2000;

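// Builds a benchmark graph that gathers kLookups randomly chosen rows (with a
// fixed Philox seed) along axis 0 from a params matrix of shape {kRows, dim},
// where kRows is sized so the params buffer is roughly 512MB of floats.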
template <typename Index>
static Graph* Gather(int dim) {
  Graph* g = new Graph(OpRegistry::Global());
  // Always use a 512MB buffer.
  const int kRows = ((512 << 20) / sizeof(float)) / dim;
  Tensor params(DT_FLOAT, TensorShape({kRows, dim}));
  params.flat<float>().setRandom();

  random::PhiloxRandom philox(301, 17);
  random::SimplePhilox rnd(&philox);
  std::vector<Index> indices_vec;
  indices_vec.reserve(kLookups);
  for (int i = 0; i < kLookups; i++) {
    indices_vec.push_back(rnd.Uniform(kRows));
  }
  Tensor indices(DataTypeToEnum<Index>::value, TensorShape({kLookups}));
  for (int i = 0; i < indices_vec.size(); i++) {
    indices.flat<Index>()(i) = indices_vec[i];
  }

  Tensor axis(DataTypeToEnum<Index>::value, TensorShape({}));
  axis.scalar<Index>()() = 0;

  test::graph::Gather(g, test::graph::Constant(g, params),
                      test::graph::Constant(g, indices),
                      test::graph::HostConstant(g, axis));
  return g;
}

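// Defines BM_<DEVICE>_gather_<INDEX>, whose benchmark argument is the inner
// dimension `dim` of the params matrix; throughput is reported as
// kLookups * dim gathered elements (and the corresponding float bytes) per
// iteration.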
#define BM_GATHER(DEVICE, INDEX)                                  \
  static void BM_##DEVICE##_gather_##INDEX(int iters, int dim) {  \
    const int64 tot = static_cast<int64>(iters) * kLookups * dim; \
    testing::ItemsProcessed(tot);                                 \
    testing::BytesProcessed(tot * sizeof(float));                 \
    testing::UseRealTime();                                       \
    test::Benchmark(#DEVICE, Gather<INDEX>(dim)).Run(iters);      \
  }                                                               \
  BENCHMARK(BM_##DEVICE##_gather_##INDEX)                         \
      ->Arg(1)                                                    \
      ->Arg(10)                                                   \
      ->Arg(20)                                                   \
      ->Arg(64)                                                   \
      ->Arg(100)                                                  \
      ->Arg(200)                                                  \
      ->Arg(1000)

BM_GATHER(cpu, int32);
BM_GATHER(gpu, int32);
BM_GATHER(cpu, int64);
BM_GATHER(gpu, int64);

}  // namespace
}  // namespace tensorflow