1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <functional> 17 #include <memory> 18 #include <vector> 19 20 #include "tensorflow/core/framework/allocator.h" 21 #include "tensorflow/core/framework/fake_input.h" 22 #include "tensorflow/core/framework/node_def_builder.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/tensor.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/framework/types.pb.h" 27 #include "tensorflow/core/kernels/ops_testutil.h" 28 #include "tensorflow/core/kernels/ops_util.h" 29 #include "tensorflow/core/lib/core/status_test_util.h" 30 #include "tensorflow/core/lib/random/simple_philox.h" 31 #include "tensorflow/core/platform/logging.h" 32 #include "tensorflow/core/platform/test.h" 33 #include "tensorflow/core/platform/test_benchmark.h" 34 35 namespace tensorflow { 36 namespace { 37 38 class ScatterUpdateOpTest : public OpsTestBase { 39 protected: 40 void MakeOp(DataType variable_ref_type, DataType index_type) { 41 TF_ASSERT_OK(NodeDefBuilder("myop", "ScatterUpdate") 42 .Input(FakeInput(variable_ref_type)) 43 .Input(FakeInput(index_type)) 44 .Input(FakeInput(RemoveRefType(variable_ref_type))) 45 .Finalize(node_def())); 46 TF_ASSERT_OK(InitOp()); 47 } 48 }; 49 50 TEST_F(ScatterUpdateOpTest, Simple_StringType) { 51 
MakeOp(DT_STRING_REF, DT_INT32); 52 AddInputFromArray<string>(TensorShape({1}), {"Brain"}); 53 AddInputFromArray<int32>(TensorShape({1}), {0}); 54 AddInputFromArray<string>(TensorShape({1}), {"TensorFlow"}); 55 TF_ASSERT_OK(RunOpKernel()); 56 // Check the new state of the input 57 Tensor params_tensor = *mutable_input(0).tensor; 58 Tensor expected(allocator(), DT_STRING, TensorShape({1})); 59 test::FillValues<string>(&expected, {"TensorFlow"}); 60 test::ExpectTensorEqual<string>(expected, params_tensor); 61 } 62 63 TEST_F(ScatterUpdateOpTest, Simple_BoolType) { 64 MakeOp(DT_BOOL_REF, DT_INT32); 65 AddInputFromArray<bool>(TensorShape({1}), {false}); 66 AddInputFromArray<int32>(TensorShape({1}), {0}); 67 AddInputFromArray<bool>(TensorShape({1}), {true}); 68 TF_ASSERT_OK(RunOpKernel()); 69 // Check the new state of the input 70 Tensor params_tensor = *mutable_input(0).tensor; 71 Tensor expected(allocator(), DT_BOOL, TensorShape({1})); 72 test::FillValues<bool>(&expected, {true}); 73 test::ExpectTensorEqual<bool>(expected, params_tensor); 74 } 75 76 TEST_F(ScatterUpdateOpTest, Simple_TwoD32) { 77 MakeOp(DT_FLOAT_REF, DT_INT32); 78 79 // Feed and run 80 AddInputFromArray<float>(TensorShape({5, 3}), 81 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); 82 AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2}); 83 AddInputFromArray<float>(TensorShape({3, 3}), 84 {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); 85 TF_ASSERT_OK(RunOpKernel()); 86 87 // Check the new state of the input 88 Tensor params_tensor = *mutable_input(0).tensor; 89 Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3})); 90 test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001, 91 10002, 0, 0, 0, 777, 778, 779}); 92 test::ExpectTensorEqual<float>(expected, params_tensor); 93 } 94 95 TEST_F(ScatterUpdateOpTest, Simple_Two64) { 96 MakeOp(DT_FLOAT_REF, DT_INT64); 97 98 // Feed and run 99 AddInputFromArray<float>(TensorShape({5, 3}), 100 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0}); 101 AddInputFromArray<int64>(TensorShape({3}), {0, 4, 2}); 102 AddInputFromArray<float>(TensorShape({3, 3}), 103 {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); 104 TF_ASSERT_OK(RunOpKernel()); 105 106 // Check the new state of the input 107 Tensor params_tensor = *mutable_input(0).tensor; 108 Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 3})); 109 test::FillValues<float>(&expected, {100, 101, 102, 0, 0, 0, 10000, 10001, 110 10002, 0, 0, 0, 777, 778, 779}); 111 test::ExpectTensorEqual<float>(expected, params_tensor); 112 } 113 114 TEST_F(ScatterUpdateOpTest, Simple_ZeroD) { 115 MakeOp(DT_FLOAT_REF, DT_INT32); 116 117 // Feed and run 118 AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0}); 119 AddInputFromArray<int32>(TensorShape({}), {3}); 120 AddInputFromArray<float>(TensorShape({}), {101}); 121 TF_ASSERT_OK(RunOpKernel()); 122 123 // Check the new state of the input 124 Tensor params_tensor = *mutable_input(0).tensor; 125 Tensor expected(allocator(), DT_FLOAT, TensorShape({5})); 126 test::FillValues<float>(&expected, {0, 0, 0, 101, 0}); 127 test::ExpectTensorEqual<float>(expected, params_tensor); 128 } 129 130 TEST_F(ScatterUpdateOpTest, Simple_OneD) { 131 MakeOp(DT_FLOAT_REF, DT_INT32); 132 133 // Feed and run 134 AddInputFromArray<float>(TensorShape({5}), {0, 0, 0, 0, 0}); 135 AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2}); 136 AddInputFromArray<float>(TensorShape({3}), {100, 101, 102}); 137 TF_ASSERT_OK(RunOpKernel()); 138 139 // Check the new state of the input 140 Tensor params_tensor = *mutable_input(0).tensor; 141 Tensor expected(allocator(), DT_FLOAT, TensorShape({5})); 142 test::FillValues<float>(&expected, {100, 0, 102, 0, 101}); 143 test::ExpectTensorEqual<float>(expected, params_tensor); 144 } 145 146 TEST_F(ScatterUpdateOpTest, HigherRank) { 147 MakeOp(DT_FLOAT_REF, DT_INT32); 148 149 // Feed and run 150 AddInputFromArray<float>(TensorShape({8}), {0, 0, 0, 0, 0, 0, 0, 0}); 151 
AddInputFromArray<int32>(TensorShape({2, 3}), {0, 4, 2, 1, 3, 6}); 152 AddInputFromArray<float>(TensorShape({2, 3}), {10, 20, 30, 40, 50, 60}); 153 TF_ASSERT_OK(RunOpKernel()); 154 155 // Check the new state of the input 156 Tensor params_tensor = *mutable_input(0).tensor; 157 Tensor expected(allocator(), DT_FLOAT, TensorShape({8})); 158 test::FillValues<float>(&expected, {10, 40, 30, 50, 20, 0, 60, 0}); 159 test::ExpectTensorEqual<float>(expected, params_tensor); 160 } 161 162 TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) { 163 MakeOp(DT_FLOAT_REF, DT_INT32); 164 165 // Feed and run 166 AddInputFromArray<float>(TensorShape({5, 3}), 167 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); 168 AddInputFromArray<int32>(TensorShape({3}), {0, 4, 99}); 169 AddInputFromArray<float>(TensorShape({3, 3}), 170 {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); 171 Status s = RunOpKernel(); 172 EXPECT_TRUE( 173 StringPiece(s.ToString()).contains("indices[2] = 99 is not in [0, 5)")) 174 << s; 175 } 176 177 TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) { 178 MakeOp(DT_FLOAT_REF, DT_INT32); 179 180 // Feed and run 181 AddInputFromArray<float>(TensorShape({2, 3}), {0, 0, 0, 0, 0, 0}); 182 AddInputFromArray<int32>(TensorShape({1, 3}), {0, 4, 99}); 183 AddInputFromArray<float>(TensorShape({3, 3}), 184 {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); 185 Status s = RunOpKernel(); 186 EXPECT_TRUE(StringPiece(s.ToString()) 187 .contains("Must have updates.shape = indices.shape + " 188 "params.shape[1:], got ")) 189 << s; 190 } 191 192 TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { 193 MakeOp(DT_FLOAT_REF, DT_INT32); 194 195 // Feed and run 196 AddInputFromArray<float>(TensorShape({5, 3}), 197 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); 198 AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2}); 199 AddInputFromArray<float>( 200 TensorShape({3, 4}), 201 {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004}); 202 Status s = 
RunOpKernel(); 203 EXPECT_TRUE(StringPiece(s.ToString()) 204 .contains("Must have updates.shape = indices.shape + " 205 "params.shape[1:], got ")) 206 207 << s; 208 } 209 210 TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { 211 MakeOp(DT_FLOAT_REF, DT_INT32); 212 213 // Feed and run 214 AddInputFromArray<float>(TensorShape({5, 3}), 215 {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); 216 AddInputFromArray<int32>(TensorShape({3}), {0, 4, 2}); 217 AddInputFromArray<float>(TensorShape({2, 3}), 218 {100, 101, 102, 10000, 10001, 10002}); 219 Status s = RunOpKernel(); 220 EXPECT_TRUE(StringPiece(s.ToString()) 221 .contains("Must have updates.shape = indices.shape + " 222 "params.shape[1:], got ")) 223 << s; 224 } 225 226 class ScatterUpdateBM : public ScatterUpdateOpTest { 227 public: 228 void TestBody() override {} 229 void MakeBenchmarkOp(const char* op, DataType index_type) { 230 TF_ASSERT_OK(NodeDefBuilder("myop", op) 231 .Input(FakeInput(DT_FLOAT_REF)) 232 .Input(FakeInput(index_type)) 233 .Input(FakeInput(DT_FLOAT)) 234 .Finalize(node_def())); 235 TF_CHECK_OK(InitOp()); 236 } 237 }; 238 239 template <typename Index> 240 static void BM_ScatterHelper(int iters, int embedding_size, const char* op) { 241 testing::StopTiming(); 242 const int kRows = 10000000 / embedding_size; 243 std::vector<float> values; 244 values.reserve(kRows); 245 for (int i = 0; i < kRows * embedding_size; i++) { 246 values.push_back(i); 247 } 248 const int kNumUpdates = 1000; 249 random::PhiloxRandom philox(301, 17); 250 random::SimplePhilox rnd(&philox); 251 std::vector<Index> indices; 252 std::vector<float> updates; 253 for (int i = 0; i < kNumUpdates; i++) { 254 indices.push_back(rnd.Uniform(kRows)); 255 for (int j = 0; j < embedding_size; j++) { 256 updates.push_back(i * 10 + j); 257 } 258 } 259 260 ScatterUpdateBM bm; 261 bm.MakeBenchmarkOp(op, DataTypeToEnum<Index>::v()); 262 bm.AddInputFromArray<float>(TensorShape({kRows, embedding_size}), values); 263 
bm.AddInputFromArray<Index>(TensorShape({kNumUpdates}), indices); 264 bm.AddInputFromArray<float>(TensorShape({kNumUpdates, embedding_size}), 265 updates); 266 testing::ItemsProcessed((static_cast<int64>(kNumUpdates) * embedding_size) * 267 iters); 268 testing::StartTiming(); 269 while (iters-- > 0) { 270 Status s = bm.RunOpKernel(); 271 } 272 testing::StopTiming(); 273 } 274 275 static void BM_ScatterUpdateInt32(int iters, int embedding_size) { 276 BM_ScatterHelper<int32>(iters, embedding_size, "ScatterUpdate"); 277 } 278 static void BM_ScatterUpdateInt64(int iters, int embedding_size) { 279 BM_ScatterHelper<int64>(iters, embedding_size, "ScatterUpdate"); 280 } 281 282 static void BM_ScatterAddInt32(int iters, int embedding_size) { 283 BM_ScatterHelper<int32>(iters, embedding_size, "ScatterAdd"); 284 } 285 static void BM_ScatterAddInt64(int iters, int embedding_size) { 286 BM_ScatterHelper<int64>(iters, embedding_size, "ScatterAdd"); 287 } 288 289 static void BM_ScatterMulInt32(int iters, int embedding_size) { 290 BM_ScatterHelper<int32>(iters, embedding_size, "ScatterMul"); 291 } 292 static void BM_ScatterMulInt64(int iters, int embedding_size) { 293 BM_ScatterHelper<int64>(iters, embedding_size, "ScatterMul"); 294 } 295 296 static void BM_ScatterDivInt32(int iters, int embedding_size) { 297 BM_ScatterHelper<int32>(iters, embedding_size, "ScatterDiv"); 298 } 299 static void BM_ScatterDivInt64(int iters, int embedding_size) { 300 BM_ScatterHelper<int64>(iters, embedding_size, "ScatterDiv"); 301 } 302 303 BENCHMARK(BM_ScatterUpdateInt32) 304 ->Arg(1) 305 ->Arg(10) 306 ->Arg(32) 307 ->Arg(50) 308 ->Arg(64) 309 ->Arg(80) 310 ->Arg(96) 311 ->Arg(112) 312 ->Arg(192) 313 ->Arg(256) 314 ->Arg(1024) 315 ->Arg(10000) 316 ->Arg(100000) 317 ->Arg(1000000); 318 BENCHMARK(BM_ScatterUpdateInt64) 319 ->Arg(1) 320 ->Arg(10) 321 ->Arg(64) 322 ->Arg(256) 323 ->Arg(1024) 324 ->Arg(100000); 325 326 BENCHMARK(BM_ScatterAddInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024); 327 
// Benchmark registrations for the remaining scatter variants; each Arg() is
// the embedding size (row length) passed to BM_ScatterHelper.
BENCHMARK(BM_ScatterAddInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);

BENCHMARK(BM_ScatterMulInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterMulInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);

BENCHMARK(BM_ScatterDivInt32)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);
BENCHMARK(BM_ScatterDivInt64)->Arg(1)->Arg(10)->Arg(64)->Arg(256)->Arg(1024);

}  // namespace
}  // namespace tensorflow