Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <functional>
     17 #include <memory>
     18 #include <vector>
     19 
     20 #include "tensorflow/core/framework/allocator.h"
     21 #include "tensorflow/core/framework/fake_input.h"
     22 #include "tensorflow/core/framework/node_def_builder.h"
     23 #include "tensorflow/core/framework/op_kernel.h"
     24 #include "tensorflow/core/framework/tensor.h"
     25 #include "tensorflow/core/framework/types.h"
     26 #include "tensorflow/core/framework/types.pb.h"
     27 #include "tensorflow/core/kernels/ops_testutil.h"
     28 #include "tensorflow/core/kernels/ops_util.h"
     29 #include "tensorflow/core/lib/core/status_test_util.h"
     30 #include "tensorflow/core/lib/random/simple_philox.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 #include "tensorflow/core/platform/test.h"
     33 #include "tensorflow/core/platform/test_benchmark.h"
     34 
     35 namespace tensorflow {
     36 namespace {
     37 
     38 class ScatterUpdateOpTest : public OpsTestBase {
     39  protected:
     40   void MakeOp(DataType variable_ref_type, DataType index_type) {
     41     TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterUpdate")
     42                      .Input(FakeInput(variable_ref_type))
     43                      .Input(FakeInput(index_type))
     44                      .Input(FakeInput(RemoveRefType(variable_ref_type)))
     45                      .Finalize(node_def()));
     46     TF_ASSERT_OK(InitOp());
     47   }
     48 };
     49 
     50 TEST_F(ScatterUpdateOpTest, Simple_StringType) {
     51   MakeOp(DT_STRING_REF, DT_INT32);
     52   AddInputFromArray<string>(TensorShape({1}), {"Brain"});
     53   AddInputFromArray<int32>(TensorShape({1}), {0});
     54   AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"});
     55   TF_ASSERT_OK(RunOpKernel());
     56   // Check the new state of the input
     57   Tensor params_tensor = *mutable_input(0).tensor;
     58   Tensor expected(allocator(), DT_STRING, TensorShape({1}));
     59   test::FillValues<string>(&expected, {"TensorFlow"});
     60   test::ExpectTensorEqual<string>(expected, params_tensor);
     61 }
     62 
     63 TEST_F(ScatterUpdateOpTest, Simple_BoolType) {
     64   MakeOp(DT_BOOL_REF, DT_INT32);
     65   AddInputFromArray<bool>(TensorShape({1}), {false});
     66   AddInputFromArray<int32>(TensorShape({1}), {0});
     67   AddInputFromArray<bool>(TensorShape({1}), {true});
     68   TF_ASSERT_OK(RunOpKernel());
     69   // Check the new state of the input
     70   Tensor params_tensor = *mutable_input(0).tensor;
     71   Tensor expected(allocator(), DT_BOOL, TensorShape({1}));
     72   test::FillValues<bool>(&expected, {true});
     73   test::ExpectTensorEqual<bool>(expected, params_tensor);
     74 }
     75 
     76 TEST_F(ScatterUpdateOpTest, Simple_TwoD32) {
     77   MakeOp(DT_FLOAT_REF, DT_INT32);
     78 
     79   // Feed and run
     80   AddInputFromArray<float>(TensorShape({5, 3}),
     81                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
     82   AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
     83   AddInputFromArray<float>(TensorShape({3, 3}),
     84                            {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
     85   TF_ASSERT_OK(RunOpKernel());
     86 
     87   // Check the new state of the input
     88   Tensor params_tensor = *mutable_input(0).tensor;
     89   Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
     90   test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
     91                                       10002, 0, 0, 0, 777, 778, 779});
     92   test::ExpectTensorEqual<float>(expected, params_tensor);
     93 }
     94 
     95 TEST_F(ScatterUpdateOpTest, Simple_Two64) {
     96   MakeOp(DT_FLOAT_REF, DT_INT64);
     97 
     98   // Feed and run
     99   AddInputFromArray<float>(TensorShape({5, 3}),
    100                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
    101   AddInputFromArray<int64>(TensorShape({3}), {0, 4, 2});
    102   AddInputFromArray<float>(TensorShape({3, 3}),
    103                            {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
    104   TF_ASSERT_OK(RunOpKernel());
    105 
    106   // Check the new state of the input
    107   Tensor params_tensor = *mutable_input(0).tensor;
    108   Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3}));
    109   test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001,
    110                                       10002, 0, 0, 0, 777, 778, 779});
    111   test::ExpectTensorEqual<float>(expected, params_tensor);
    112 }
    113 
    114 TEST_F(ScatterUpdateOpTest, Simple_ZeroD) {
    115   MakeOp(DT_FLOAT_REF, DT_INT32);
    116 
    117   // Feed and run
    118   AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
    119   AddInputFromArray<int32>(TensorShape({}), {3});
    120   AddInputFromArray<float>(TensorShape({}), {101});
    121   TF_ASSERT_OK(RunOpKernel());
    122 
    123   // Check the new state of the input
    124   Tensor params_tensor = *mutable_input(0).tensor;
    125   Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
    126   test::FillValues<float>(&expected, {0, 0, 0, 101, 0});
    127   test::ExpectTensorEqual<float>(expected, params_tensor);
    128 }
    129 
    130 TEST_F(ScatterUpdateOpTest, Simple_OneD) {
    131   MakeOp(DT_FLOAT_REF, DT_INT32);
    132 
    133   // Feed and run
    134   AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0});
    135   AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
    136   AddInputFromArray<float>(TensorShape({3}), {100, 101, 102});
    137   TF_ASSERT_OK(RunOpKernel());
    138 
    139   // Check the new state of the input
    140   Tensor params_tensor = *mutable_input(0).tensor;
    141   Tensor expected(allocator(), DT_FLOAT, TensorShape({5}));
    142   test::FillValues<float>(&expected, {100, 0, 102, 0, 101});
    143   test::ExpectTensorEqual<float>(expected, params_tensor);
    144 }
    145 
    146 TEST_F(ScatterUpdateOpTest, HigherRank) {
    147   MakeOp(DT_FLOAT_REF, DT_INT32);
    148 
    149   // Feed and run
    150   AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0});
    151   AddInputFromArray<int32>(TensorShape({2, 3}), {0, 4, 2, 1, 3, 6});
    152   AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60});
    153   TF_ASSERT_OK(RunOpKernel());
    154 
    155   // Check the new state of the input
    156   Tensor params_tensor = *mutable_input(0).tensor;
    157   Tensor expected(allocator(), DT_FLOAT, TensorShape({8}));
    158   test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0});
    159   test::ExpectTensorEqual<float>(expected, params_tensor);
    160 }
    161 
    162 TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) {
    163   MakeOp(DT_FLOAT_REF, DT_INT32);
    164 
    165   // Feed and run
    166   AddInputFromArray<float>(TensorShape({5, 3}),
    167                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
    168   AddInputFromArray<int32>(TensorShape({3}), {0, 4, 99});
    169   AddInputFromArray<float>(TensorShape({3, 3}),
    170                            {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
    171   Status s = RunOpKernel();
    172   EXPECT_TRUE(
    173       StringPiece(s.ToString()).contains("indices[2] = 99 is not in [0, 5)"))
    174       << s;
    175 }
    176 
    177 TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) {
    178   MakeOp(DT_FLOAT_REF, DT_INT32);
    179 
    180   // Feed and run
    181   AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0});
    182   AddInputFromArray<int32>(TensorShape({1, 3}), {0, 4, 99});
    183   AddInputFromArray<float>(TensorShape({3, 3}),
    184                            {100, 101, 102, 777, 778, 779, 10000, 10001, 10002});
    185   Status s = RunOpKernel();
    186   EXPECT_TRUE(StringPiece(s.ToString())
    187                   .contains("Must have updates.shape = indices.shape + "
    188                             "params.shape[1:], got "))
    189       << s;
    190 }
    191 
    192 TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) {
    193   MakeOp(DT_FLOAT_REF, DT_INT32);
    194 
    195   // Feed and run
    196   AddInputFromArray<float>(TensorShape({5, 3}),
    197                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
    198   AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
    199   AddInputFromArray<float>(
    200       TensorShape({3, 4}),
    201       {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004});
    202   Status s = RunOpKernel();
    203   EXPECT_TRUE(StringPiece(s.ToString())
    204                   .contains("Must have updates.shape = indices.shape + "
    205                             "params.shape[1:], got "))
    206 
    207       << s;
    208 }
    209 
    210 TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) {
    211   MakeOp(DT_FLOAT_REF, DT_INT32);
    212 
    213   // Feed and run
    214   AddInputFromArray<float>(TensorShape({5, 3}),
    215                            {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
    216   AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2});
    217   AddInputFromArray<float>(TensorShape({2, 3}),
    218                            {100, 101, 102, 10000, 10001, 10002});
    219   Status s = RunOpKernel();
    220   EXPECT_TRUE(StringPiece(s.ToString())
    221                   .contains("Must have updates.shape = indices.shape + "
    222                             "params.shape[1:], got "))
    223       << s;
    224 }
    225 
    226 class ScatterUpdateBM : public ScatterUpdateOpTest {
    227  public:
    228   void TestBody() override {}
    229   void MakeBenchmarkOp(const char* op, DataType index_type) {
    230     TF_ASSERT_OK(NodeDefBuilder("myop", op)
    231                      .Input(FakeInput(DT_FLOAT_REF))
    232                      .Input(FakeInput(index_type))
    233                      .Input(FakeInput(DT_FLOAT))
    234                      .Finalize(node_def()));
    235     TF_CHECK_OK(InitOp());
    236   }
    237 };
    238 
    239 template <typename Index>
    240 static void BM_ScatterHelper(int iters, int embedding_size, const char* op) {
    241   testing::StopTiming();
    242   const int kRows = 10000000 / embedding_size;
    243   std::vector<float> values;
    244   values.reserve(kRows);
    245   for (int i = 0; i < kRows * embedding_size; i++) {
    246     values.push_back(i);
    247   }
    248   const int kNumUpdates = 1000;
    249   random::PhiloxRandom philox(301, 17);
    250   random::SimplePhilox rnd(&philox);
    251   std::vector<Index> indices;
    252   std::vector<float> updates;
    253   for (int i = 0; i < kNumUpdates; i++) {
    254     indices.push_back(rnd.Uniform(kRows));
    255     for (int j = 0; j < embedding_size; j++) {
    256       updates.push_back(i * 10 + j);
    257     }
    258   }
    259 
    260   ScatterUpdateBM bm;
    261   bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v());
    262   bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values);
    263   bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices);
    264   bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}),
    265                               updates);
    266   testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) *
    267                           iters);
    268   testing::StartTiming();
    269   while (iters-- > 0) {
    270     Status s = bm.RunOpKernel();
    271   }
    272   testing::StopTiming();
    273 }
    274 
    275 static void BM_ScatterUpdateInt32(int iters, int embedding_size) {
    276   BM_ScatterHelper<int32>(iters, embedding_size, "ScatterUpdate");
    277 }
    278 static void BM_ScatterUpdateInt64(int iters, int embedding_size) {
    279   BM_ScatterHelper<int64>(iters, embedding_size, "ScatterUpdate");
    280 }
    281 
    282 static void BM_ScatterAddInt32(int iters, int embedding_size) {
    283   BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd");
    284 }
    285 static void BM_ScatterAddInt64(int iters, int embedding_size) {
    286   BM_ScatterHelper<int64>(iters, embedding_size, "ScatterAdd");
    287 }
    288 
    289 static void BM_ScatterMulInt32(int iters, int embedding_size) {
    290   BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMul");
    291 }
    292 static void BM_ScatterMulInt64(int iters, int embedding_size) {
    293   BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMul");
    294 }
    295 
    296 static void BM_ScatterDivInt32(int iters, int embedding_size) {
    297   BM_ScatterHelper<int32>(iters, embedding_size, "ScatterDiv");
    298 }
    299 static void BM_ScatterDivInt64(int iters, int embedding_size) {
    300   BM_ScatterHelper<int64>(iters, embedding_size, "ScatterDiv");
    301 }
    302 
    303 BENCHMARK(BM_ScatterUpdateInt32)
    304     ->Arg(1)
    305     ->Arg(10)
    306     ->Arg(32)
    307     ->Arg(50)
    308     ->Arg(64)
    309     ->Arg(80)
    310     ->Arg(96)
    311     ->Arg(112)
    312     ->Arg(192)
    313     ->Arg(256)
    314     ->Arg(1024)
    315     ->Arg(10000)
    316     ->Arg(100000)
    317     ->Arg(1000000);
    318 BENCHMARK(BM_ScatterUpdateInt64)
    319     ->Arg(1)
    320     ->Arg(10)
    321     ->Arg(64)
    322     ->Arg(256)
    323     ->Arg(1024)
    324     ->Arg(100000);
    325 
    326 BENCHMARK(BM_ScatterAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
    327 BENCHMARK(BM_ScatterAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
    328 
    329 BENCHMARK(BM_ScatterMulInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
    330 BENCHMARK(BM_ScatterMulInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
    331 
    332 BENCHMARK(BM_ScatterDivInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
    333 BENCHMARK(BM_ScatterDivInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
    334 
    335 }  // namespace
    336 }  // namespace tensorflow
    337