Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <functional>
     17 #include <memory>
     18 #include <vector>
     19 
     20 #include "tensorflow/core/framework/allocator.h"
     21 #include "tensorflow/core/framework/fake_input.h"
     22 #include "tensorflow/core/framework/node_def_builder.h"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/types.h"
     26 #include "tensorflow/core/framework/types.pb.h"
     27 #include "tensorflow/core/kernels/ops_testutil.h"
     28 #include "tensorflow/core/kernels/ops_util.h"
     29 #include "tensorflow/core/lib/core/status_test_util.h"
     30 #include "tensorflow/core/lib/random/simple_philox.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 #include "tensorflow/core/platform/test.h"
     33 #include "tensorflow/core/platform/test_benchmark.h"
     34 
     35 namespace tensorflow {
     36 namespace {
     37 
     38 class ScatterNdUpdateOpTest : public OpsTestBase {
     39  protected:
     40   void MakeOp(DataType variable_ref_type, DataType index_type) {
     41     TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterNdUpdate")
     42                      .Input(FakeInput(variable_ref_type))
     43                      .Input(FakeInput(index_type))
     44                      .Input(FakeInput(RemoveRefType(variable_ref_type)))
     45                      .Finalize(node_def()));
     46     TF_ASSERT_OK(InitOp());
     47   }
     48 };
     49 
     50 // TODO(simister): Re-enable this once binary size is under control.
     51 // TEST_F(ScatterNdUpdateOpTest, Simple_StringType) {
     52 //   MakeOp(DT_STRING_REF, DT_INT32);
     53 //   AddInputFromArray<string>(TensorShape({1}), {"Brain"});
     54 //   AddInputFromArray<int32>(TensorShape({1}), {0});
     55 //   AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"});
     56 //   TF_ASSERT_OK(RunOpKernel());
     57 //   // Check the new state of the input
     58 //   Tensor params_tensor = *mutable_input(0).tensor;
     59 //   Tensor expected(allocator(), DT_STRING, TensorShape({1}));
     60 //   test::FillValues<string>(&expected, {"TensorFlow"});
     61 //   test::ExpectTensorEqual<string>(expected, params_tensor);
     62 // }
     63 
     64 // TEST_F(ScatterNdUpdateOpTest, Simple_BoolType) {
     65 //   MakeOp(DT_BOOL_REF, DT_INT32);
     66 //   AddInputFromArray<bool>(TensorShape({1}), {false});
     67 //   AddInputFromArray<int32>(TensorShape({1}), {0});
     68 //   AddInputFromArray<bool>(TensorShape({1}), {true});
     69 //   TF_ASSERT_OK(RunOpKernel());
     70 //   // Check the new state of the input
     71 //   Tensor params_tensor = *mutable_input(0).tensor;
     72 //   Tensor expected(allocator(), DT_BOOL, TensorShape({1}));
     73 //   test::FillValues<bool>(&expected, {true});
     74 //   test::ExpectTensorEqual<bool>(expected, params_tensor);
     75 // }
     76 
     77 TEST_F(ScatterNdUpdateOpTest, Simple_TwoD32) {
     78   MakeOp(DT_FLOAT_REF, DT_INT32);
     79 
     80   // Feed and run
     81   AddInputFromArray<float>(TensorShape({5, 3}),
     82                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
     83   AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
     84   AddInputFromArray<float>(TensorShape({3, 3}),
     85                            {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
     86   TF_ASSERT_OK(RunOpKernel());
     87 
     88   // Check the new state of the input
     89   Tensor params_tensor = *mutable_input(0).tensor;
     90   Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
     91   test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
     92                                       10002, 0, 0, 0, 777, 778, 779});
     93   test::ExpectTensorEqual<float>(expected, params_tensor);
     94 }
     95 
     96 TEST_F(ScatterNdUpdateOpTest, Simple_Two64) {
     97   MakeOp(DT_FLOAT_REF, DT_INT64);
     98 
     99   // Feed and run
    100   AddInputFromArray<float>(TensorShape({5, 3}),
    101                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
    102   AddInputFromArray<int64>(TensorShape({3, 1}), {0, 4, 2});
    103   AddInputFromArray<float>(TensorShape({3, 3}),
    104                            {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
    105   TF_ASSERT_OK(RunOpKernel());
    106 
    107   // Check the new state of the input
    108   Tensor params_tensor = *mutable_input(0).tensor;
    109   Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
    110   test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
    111                                       10002, 0, 0, 0, 777, 778, 779});
    112   test::ExpectTensorEqual<float>(expected, params_tensor);
    113 }
    114 
    115 /*TEST_F(ScatterNdUpdateOpTest, Simple_ZeroElements) {
    116   MakeOp(DT_FLOAT_REF, DT_INT32);
    117 
    118   // Feed and run
    119   AddInputFromArray<float>(TensorShape({0}), {});
    120   AddInputFromArray<int32>(TensorShape({0}), {});
    121   AddInputFromArray<float>(TensorShape({0}), {});
    122   Status s = RunOpKernel();
    123   EXPECT_TRUE(StringPiece(s.ToString())
    124                   .contains("Output must not have 0 elements, got shape: "))
    125       << s;
    126 }*/
    127 
    128 TEST_F(ScatterNdUpdateOpTest, Simple_ZeroD) {
    129   MakeOp(DT_FLOAT_REF, DT_INT32);
    130 
    131   // Feed and run
    132   AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
    133   AddInputFromArray<int32>(TensorShape({1}), {3});
    134   AddInputFromArray<float>(TensorShape({1}), {101});
    135   TF_ASSERT_OK(RunOpKernel());
    136 
    137   // Check the new state of the input
    138   Tensor params_tensor = *mutable_input(0).tensor;
    139   Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
    140   test::FillValues<float>(&expected, {0, 0, 0, 101, 0});
    141   test::ExpectTensorEqual<float>(expected, params_tensor);
    142 }
    143 
    144 TEST_F(ScatterNdUpdateOpTest, Simple_OneD) {
    145   MakeOp(DT_FLOAT_REF, DT_INT32);
    146 
    147   // Feed and run
    148   AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
    149   AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
    150   AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});
    151   TF_ASSERT_OK(RunOpKernel());
    152 
    153   // Check the new state of the input
    154   Tensor params_tensor = *mutable_input(0).tensor;
    155   Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
    156   test::FillValues<float>(&expected, {100, 0, 102, 0, 101});
    157   test::ExpectTensorEqual<float>(expected, params_tensor);
    158 }
    159 
    160 TEST_F(ScatterNdUpdateOpTest, HigherRank) {
    161   MakeOp(DT_FLOAT_REF, DT_INT32);
    162 
    163   // Feed and run
    164   AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0});
    165   AddInputFromArray<int32>(TensorShape({2, 3, 1}), {0, 4, 2, 1, 3, 6});
    166   AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60});
    167   TF_ASSERT_OK(RunOpKernel());
    168 
    169   // Check the new state of the input
    170   Tensor params_tensor = *mutable_input(0).tensor;
    171   Tensor expected(allocator(), DT_FLOAT, TensorShape({8}));
    172   test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0});
    173   test::ExpectTensorEqual<float>(expected, params_tensor);
    174 }
    175 
    176 TEST_F(ScatterNdUpdateOpTest, Error_IndexOutOfRange) {
    177   MakeOp(DT_FLOAT_REF, DT_INT32);
    178 
    179   // Feed and run
    180   AddInputFromArray<float>(TensorShape({5, 3}),
    181                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
    182   AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 99});
    183   AddInputFromArray<float>(TensorShape({3, 3}),
    184                            {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
    185   Status s = RunOpKernel();
    186   EXPECT_TRUE(
    187       StringPiece(s.ToString())
    188           .contains("Invalid indices: [2,0] = [99] does not index into [5,3]"))
    189       << s;
    190 }
    191 
    192 TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) {
    193   MakeOp(DT_FLOAT_REF, DT_INT32);
    194 
    195   // Feed and run
    196   AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0});
    197   AddInputFromArray<int32>(TensorShape({1, 3, 1}), {0, 4, 99});
    198   AddInputFromArray<float>(TensorShape({3, 3}),
    199                            {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
    200   Status s = RunOpKernel();
    201   EXPECT_TRUE(StringPiece(s.ToString())
    202                   .contains("The outermost dimension of updates and indices "
    203                             "must match. Got indices.shape [1,3,1], "
    204                             "updates.shape [3,3]"))
    205       << s;
    206 }
    207 
    208 TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
    209   MakeOp(DT_FLOAT_REF, DT_INT32);
    210 
    211   // Feed and run
    212   AddInputFromArray<float>(TensorShape({5, 3}),
    213                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
    214   AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
    215   AddInputFromArray<float>(
    216       TensorShape({3, 4}),
    217       {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
    218   Status s = RunOpKernel();
    219   EXPECT_TRUE(
    220       StringPiece(s.ToString())
    221           .contains("Must have updates.shape = indices.shape[:batch_dim]"))
    222 
    223       << s;
    224 }
    225 
    226 TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
    227   MakeOp(DT_FLOAT_REF, DT_INT32);
    228 
    229   // Feed and run
    230   AddInputFromArray<float>(TensorShape({5, 3}),
    231                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
    232   AddInputFromArray<int32>(TensorShape({3, 1}), {0, 4, 2});
    233   AddInputFromArray<float>(TensorShape({2, 3}),
    234                            {100, 101, 102, 10000, 10001, 10002});
    235   Status s = RunOpKernel();
    236   EXPECT_TRUE(
    237       StringPiece(s.ToString())
    238           .contains(
    239               "The outermost dimension of updates and indices must match."))
    240       << s;
    241 }
    242 
    243 class ScatterNdUpdateBM : public ScatterNdUpdateOpTest {
    244  public:
    245   void TestBody() override {}
    246   void MakeBenchmarkOp(const char* op, DataType index_type) {
    247     TF_ASSERT_OK(NodeDefBuilder("myop", op)
    248                      .Input(FakeInput(DT_FLOAT_REF))
    249                      .Input(FakeInput(index_type))
    250                      .Input(FakeInput(DT_FLOAT))
    251                      .Finalize(node_def()));
    252     TF_CHECK_OK(InitOp());
    253   }
    254 };
    255 
    256 template <typename Index>
    257 static void BM_ScatterNdHelper(int iters, int embedding_size, const char* op) {
    258   testing::StopTiming();
    259   const int kRows = 10000000 / embedding_size;
    260   std::vector<float> values;
    261   values.reserve(kRows);
    262   for (int i = 0; i < kRows * embedding_size; i++) {
    263     values.push_back(i);
    264   }
    265   const int kNumUpdates = 1000;
    266   random::PhiloxRandom philox(301, 17);
    267   random::SimplePhilox rnd(&philox);
    268   std::vector<Index> indices;
    269   std::vector<float> updates;
    270   for (int i = 0; i < kNumUpdates; i++) {
    271     indices.push_back(rnd.Uniform(kRows));
    272     for (int j = 0; j < embedding_size; j++) {
    273       updates.push_back(i * 10 + j);
    274     }
    275   }
    276 
    277   ScatterNdUpdateBM bm;
    278   bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v());
    279   bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values);
    280   bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
    281   bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
    282                               updates);
    283   testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
    284                           iters);
    285   testing::StartTiming();
    286   while (iters-- > 0) {
    287     Status s = bm.RunOpKernel();
    288   }
    289   testing::StopTiming();
    290 }
    291 
    292 static void BM_ScatterNdUpdateInt32(int iters, int embedding_size) {
    293   BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdUpdate");
    294 }
    295 static void BM_ScatterNdUpdateInt64(int iters, int embedding_size) {
    296   BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdUpdate");
    297 }
    298 
    299 static void BM_ScatterNdAddInt32(int iters, int embedding_size) {
    300   BM_ScatterNdHelper<int32>(iters, embedding_size, "ScatterNdAdd");
    301 }
    302 static void BM_ScatterNdAddInt64(int iters, int embedding_size) {
    303   BM_ScatterNdHelper<int64>(iters, embedding_size, "ScatterNdAdd");
    304 }
    305 
    306 BENCHMARK(BM_ScatterNdUpdateInt32)
    307     ->Arg(1)
    308     ->Arg(10)
    309     ->Arg(64)
    310     ->Arg(256)
    311     ->Arg(1024);
    312 BENCHMARK(BM_ScatterNdUpdateInt64)
    313     ->Arg(1)
    314     ->Arg(10)
    315     ->Arg(64)
    316     ->Arg(256)
    317     ->Arg(1024);
    318 
    319 BENCHMARK(BM_ScatterNdAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
    320 BENCHMARK(BM_ScatterNdAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
    321 
    322 }  // namespace
    323 }  // namespace tensorflow
    324