/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

// Test fixture for the "GatherV2" op.  GatherV2 takes three inputs:
// params (of `data_type`), indices (of `index_type`), and a scalar axis
// (also of `index_type`), and gathers slices of params along that axis.
class GatherOpTest : public OpsTestBase {
 protected:
  // Builds a single-node graph running "GatherV2" with the given data and
  // index dtypes, then initializes the kernel.  Each test calls this once
  // before feeding inputs via AddInputFromArray and calling RunOpKernel.
  void MakeOp(DataType data_type, DataType index_type) {
    TF_ASSERT_OK(NodeDefBuilder("myop", "GatherV2")
                     .Input(FakeInput(data_type))   // params
                     .Input(FakeInput(index_type))  // indices
                     .Input(FakeInput(index_type))  // axis (scalar)
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};
51 52 TEST_F(GatherOpTest, ScalarIndices) { 53 MakeOp(DT_FLOAT, DT_INT32); 54 55 // Feed and run 56 AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 3, 4}); 57 AddInputFromArray<int32>(TensorShape({}), {3}); 58 AddInputFromArray<int32>(TensorShape({}), {0}); 59 TF_ASSERT_OK(RunOpKernel()); 60 61 // Check the output. 62 Tensor expected(allocator(), DT_FLOAT, TensorShape({})); 63 test::FillValues<float>(&expected, {3}); 64 test::ExpectTensorEqual<float>(expected, *GetOutput(0)); 65 } 66 67 TEST_F(GatherOpTest, ScalarIndices_Complex) { 68 MakeOp(DT_COMPLEX64, DT_INT32); 69 70 // Feed and run 71 AddInputFromArray<std::complex<float>>( 72 TensorShape({5}), {std::complex<float>(0, 10), std::complex<float>(1, 11), 73 std::complex<float>(2, 12), std::complex<float>(3, 13), 74 std::complex<float>(4, 14)}); 75 AddInputFromArray<int32>(TensorShape({}), {3}); 76 AddInputFromArray<int32>(TensorShape({}), {0}); 77 TF_ASSERT_OK(RunOpKernel()); 78 79 // Check the output. 80 Tensor expected(allocator(), DT_COMPLEX64, TensorShape({})); 81 test::FillValues<std::complex<float>>(&expected, 82 {std::complex<float>(3, 13)}); 83 test::ExpectTensorEqual<std::complex<float>>(expected, *GetOutput(0)); 84 } 85 86 TEST_F(GatherOpTest, Simple_TwoD32_Axis0) { 87 MakeOp(DT_FLOAT, DT_INT32); 88 89 // Feed and run 90 AddInputFromArray<float>(TensorShape({5, 3}), 91 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); 92 AddInputFromArray<int32>(TensorShape({4}), {0, 4, 0, 2}); 93 AddInputFromArray<int32>(TensorShape({}), {0}); 94 TF_ASSERT_OK(RunOpKernel()); 95 96 // Check the output. 
97 Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 3})); 98 test::FillValues<float>(&expected, {0, 1, 2, 12, 13, 14, 0, 1, 2, 6, 7, 8}); 99 test::ExpectTensorEqual<float>(expected, *GetOutput(0)); 100 } 101 102 TEST_F(GatherOpTest, Simple_TwoD32_Axis1) { 103 MakeOp(DT_FLOAT, DT_INT32); 104 105 // Feed and run 106 AddInputFromArray<float>(TensorShape({5, 3}), 107 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); 108 AddInputFromArray<int32>(TensorShape({4}), {0, 1, 0, 2}); 109 AddInputFromArray<int32>(TensorShape({}), {1}); 110 TF_ASSERT_OK(RunOpKernel()); 111 112 // Check the output. 113 Tensor expected(allocator(), DT_FLOAT, TensorShape({5, 4})); 114 test::FillValues<float>(&expected, {0, 1, 0, 2, 3, 4, 3, 5, 6, 7, 115 6, 8, 9, 10, 9, 11, 12, 13, 12, 14}); 116 test::ExpectTensorEqual<float>(expected, *GetOutput(0)); 117 } 118 119 TEST_F(GatherOpTest, ZeroSize_TwoD32) { 120 MakeOp(DT_FLOAT, DT_INT32); 121 122 // Feed and run 123 AddInputFromArray<float>(TensorShape({5, 0}), {}); 124 AddInputFromArray<int32>(TensorShape({4}), {0, 4, 0, 2}); 125 AddInputFromArray<int32>(TensorShape({}), {0}); 126 TF_ASSERT_OK(RunOpKernel()); 127 128 // Check the output. 129 Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 0})); 130 test::ExpectTensorEqual<float>(expected, *GetOutput(0)); 131 } 132 133 TEST_F(GatherOpTest, Simple_TwoD64) { 134 MakeOp(DT_FLOAT, DT_INT64); 135 136 // Feed and run 137 AddInputFromArray<float>(TensorShape({5, 3}), 138 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); 139 AddInputFromArray<int64>(TensorShape({4}), {0, 4, 0, 2}); 140 AddInputFromArray<int64>(TensorShape({}), {0}); 141 TF_ASSERT_OK(RunOpKernel()); 142 143 // Check the output. 
144 Tensor expected(allocator(), DT_FLOAT, TensorShape({4, 3})); 145 test::FillValues<float>(&expected, {0, 1, 2, 12, 13, 14, 0, 1, 2, 6, 7, 8}); 146 test::ExpectTensorEqual<float>(expected, *GetOutput(0)); 147 } 148 149 TEST_F(GatherOpTest, HighRank) { 150 MakeOp(DT_FLOAT, DT_INT32); 151 152 // Feed and run 153 AddInputFromArray<float>(TensorShape({4}), {0, 1, 2, 3}); 154 AddInputFromArray<int32>(TensorShape({2, 3}), {1, 2, 0, 2, 3, 0}); 155 AddInputFromArray<int32>(TensorShape({}), {0}); 156 TF_ASSERT_OK(RunOpKernel()); 157 158 // Check the output 159 Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 3})); 160 test::FillValues<float>(&expected, {1, 2, 0, 2, 3, 0}); 161 test::ExpectTensorEqual<float>(expected, *GetOutput(0)); 162 } 163 164 TEST_F(GatherOpTest, Error_IndexOutOfRange) { 165 MakeOp(DT_FLOAT, DT_INT32); 166 167 // Feed and run 168 AddInputFromArray<float>(TensorShape({5, 3}), 169 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); 170 AddInputFromArray<int32>(TensorShape({4}), {0, 4, 99, 2}); 171 AddInputFromArray<int32>(TensorShape({}), {0}); 172 Status s = RunOpKernel(); 173 EXPECT_TRUE( 174 StringPiece(s.ToString()).contains("indices[2] = 99 is not in [0, 5)")) 175 << s; 176 } 177 178 constexpr int kLookups = 2000; 179 180 template <typename Index> 181 static Graph* Gather(int dim) { 182 Graph* g = new Graph(OpRegistry::Global()); 183 // Always use a 512MB buffer. 
184 const int kRows = ((512 << 20) / sizeof(float)) / dim; 185 Tensor params(DT_FLOAT, TensorShape({kRows, dim})); 186 params.flat<float>().setRandom(); 187 188 random::PhiloxRandom philox(301, 17); 189 random::SimplePhilox rnd(&philox); 190 std::vector<Index> indices_vec; 191 indices_vec.reserve(kLookups); 192 for (int i = 0; i < kLookups; i++) { 193 indices_vec.push_back(rnd.Uniform(kRows)); 194 } 195 Tensor indices(DataTypeToEnum<Index>::value, TensorShape({kLookups})); 196 for (int i = 0; i < indices_vec.size(); i++) { 197 indices.flat<Index>()(i) = indices_vec[i]; 198 } 199 200 Tensor axis(DataTypeToEnum<Index>::value, TensorShape({})); 201 axis.scalar<Index>()() = 0; 202 203 test::graph::Gather(g, test::graph::Constant(g, params), 204 test::graph::Constant(g, indices), 205 test::graph::HostConstant(g, axis)); 206 return g; 207 } 208 209 #define BM_GATHER(DEVICE, INDEX) \ 210 static void BM_##DEVICE##_gather_##INDEX(int iters, int dim) { \ 211 const int64 tot = static_cast<int64>(iters) * kLookups * dim; \ 212 testing::ItemsProcessed(tot); \ 213 testing::BytesProcessed(tot * sizeof(float)); \ 214 testing::UseRealTime(); \ 215 test::Benchmark(#DEVICE, Gather<INDEX>(dim)).Run(iters); \ 216 } \ 217 BENCHMARK(BM_##DEVICE##_gather_##INDEX) \ 218 ->Arg(1) \ 219 ->Arg(10) \ 220 ->Arg(20) \ 221 ->Arg(64) \ 222 ->Arg(100) \ 223 ->Arg(200) \ 224 ->Arg(1000) 225 226 BM_GATHER(cpu, int32); 227 BM_GATHER(gpu, int32); 228 BM_GATHER(cpu, int64); 229 BM_GATHER(gpu, int64); 230 231 } // namespace 232 } // namespace tensorflow 233