/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {

using test::graph::Constant;

class QuantizedConcatTest : public OpsTestBase {
 protected:
  QuantizedConcatTest() {}

  void TestSmall8Bit(float first_min, float first_max, float second_min,
                     float second_max);
  void TestSmall32Bit(float first_min, float first_max, float second_min,
                      float second_max);
  void TestSecondDim8Bit(float first_min, float first_max, float second_min,
                         float second_max);
};
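
// Each Test* helper below follows the same pattern: build float tensors,
// quantize them with FloatTensorToQuantized<T> over the given [min, max]
// ranges, run QuantizedConcat, dequantize the output with
// QuantizedTensorToFloat<T> using the op's reported output range, and compare
// against the plain float concatenation within a tolerance.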

TEST_F(QuantizedConcatTest, Small8Bit) {
  TestSmall8Bit(0.0f, 255.0f, 0.0f, 25.0f);
}

TEST_F(QuantizedConcatTest, Small8BitSameRange) {
  // Range for both is the same, so impl can use memcpy.
  TestSmall8Bit(0.0f, 255.0f, 0.0f, 255.0f);
}

void QuantizedConcatTest::TestSmall8Bit(float first_min, float first_max,
                                        float second_min, float second_max) {
  TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "QuantizedConcat")
                   .Input(FakeInput(DT_INT32))
                   .Input(FakeInput(2, DT_QUINT8))
                   .Input(FakeInput(2, DT_FLOAT))
                   .Input(FakeInput(2, DT_FLOAT))
                   .Attr("N", 2)
                   .Attr("T", DataTypeToEnum<quint8>::v())
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  const int first_batch = 2;
  const int first_height = 2;
  const int first_width = 3;
  Tensor first_float(DT_FLOAT, {first_batch, first_height, first_width});
  test::FillValues<float>(&first_float,
                          {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  Tensor first_quantized =
      FloatTensorToQuantized<quint8>(first_float, first_min, first_max);

  const int second_batch = 2;
  const int second_height = 2;
  const int second_width = 3;
  Tensor second_float(DT_FLOAT, {second_batch, second_height, second_width});
  test::FillValues<float>(&second_float,
                          {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  Tensor second_quantized =
      FloatTensorToQuantized<quint8>(second_float, second_min, second_max);

  const int expected_batch = first_batch + second_batch;
  Tensor expected_float(DT_FLOAT, {expected_batch, first_height, first_width});
  test::FillValues<float>(&expected_float,
                          {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});

  AddInputFromArray<int32>(TensorShape({}), {0});
  AddInputFromArray<quint8>(first_quantized.shape(),
                            first_quantized.flat<quint8>());
  AddInputFromArray<quint8>(second_quantized.shape(),
                            second_quantized.flat<quint8>());
  AddInputFromArray<float>(TensorShape({}), {first_min});
  AddInputFromArray<float>(TensorShape({}), {second_min});
  AddInputFromArray<float>(TensorShape({}), {first_max});
  AddInputFromArray<float>(TensorShape({}), {second_max});
  TF_ASSERT_OK(RunOpKernel());
  const Tensor& output_quantized = *GetOutput(0);
  const float output_min = GetOutput(1)->flat<float>()(0);
  const float output_max = GetOutput(2)->flat<float>()(0);
  Tensor output_float =
      QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
  test::ExpectTensorNear<float>(expected_float, output_float, 0.2);
}
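
// For reference, the per-element mapping that QuantizedTensorToFloat<quint8>
// applies is approximately the linear sketch below, assuming the [min, max]
// range is spread evenly over the raw 8-bit values [0, 255]. The helper is
// hypothetical and not part of this test; see quantization_utils.h for the
// exact rounding behavior:
//
//   float DequantizeSketch(uint8 q, float min, float max) {
//     return min + (static_cast<float>(q) / 255.0f) * (max - min);
//   }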

TEST_F(QuantizedConcatTest, Small32Bit) {
  TestSmall32Bit(0.0f, 1200.0f, 0.0f, 2400.0f);
}

TEST_F(QuantizedConcatTest, Small32BitSameRange) {
  TestSmall32Bit(-2400.0f, 2400.0f, -2400.0f, 2400.0f);
}

TEST_F(QuantizedConcatTest, Small32BitOneDimSameRangeAsOutput) {
  TestSmall32Bit(-2400.0f, 2400.0f, -1200.0f, 2400.0f);
}

void QuantizedConcatTest::TestSmall32Bit(float first_min, float first_max,
                                         float second_min, float second_max) {
  TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "QuantizedConcat")
                   .Input(FakeInput(DT_INT32))
                   .Input(FakeInput(2, DT_QINT32))
                   .Input(FakeInput(2, DT_FLOAT))
                   .Input(FakeInput(2, DT_FLOAT))
                   .Attr("N", 2)
                   .Attr("T", DataTypeToEnum<qint32>::v())
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  const int first_batch = 2;
  const int first_height = 2;
  const int first_width = 3;
  Tensor first_float(DT_FLOAT, {first_batch, first_height, first_width});
  test::FillValues<float>(&first_float, {100, 200, 300, 400, 500, 600, 700,
                                         800, 900, 1000, 1100, 1200});
  Tensor first_quantized =
      FloatTensorToQuantized<qint32>(first_float, first_min, first_max);

  const int second_batch = 2;
  const int second_height = 2;
  const int second_width = 3;
  Tensor second_float(DT_FLOAT, {second_batch, second_height, second_width});
  test::FillValues<float>(&second_float, {1300, 1400, 1500, 1600, 1700, 1800,
                                          1900, 2000, 2100, 2200, 2300, 2400});
  Tensor second_quantized =
      FloatTensorToQuantized<qint32>(second_float, second_min, second_max);

  const int expected_batch = first_batch + second_batch;
  Tensor expected_float(DT_FLOAT, {expected_batch, first_height, first_width});
  test::FillValues<float>(
      &expected_float,
      {100,  200,  300,  400,  500,  600,  700,  800,  900,  1000, 1100, 1200,
       1300, 1400, 1500, 1600, 1700, 1800, 1900, 2000, 2100, 2200, 2300, 2400});

  AddInputFromArray<int32>(TensorShape({}), {0});
  AddInputFromArray<qint32>(first_quantized.shape(),
                            first_quantized.flat<qint32>());
  AddInputFromArray<qint32>(second_quantized.shape(),
                            second_quantized.flat<qint32>());
  AddInputFromArray<float>(TensorShape({}), {first_min});
  AddInputFromArray<float>(TensorShape({}), {second_min});
  AddInputFromArray<float>(TensorShape({}), {first_max});
  AddInputFromArray<float>(TensorShape({}), {second_max});
  TF_ASSERT_OK(RunOpKernel());
  const Tensor& output_quantized = *GetOutput(0);
  const float output_min = GetOutput(1)->flat<float>()(0);
  const float output_max = GetOutput(2)->flat<float>()(0);
  Tensor output_float =
      QuantizedTensorToFloat<qint32>(output_quantized, output_min, output_max);
  test::ExpectTensorNear<float>(expected_float, output_float, 0.2);
}

TEST_F(QuantizedConcatTest, SecondDim8Bit) {
  TestSecondDim8Bit(-10.0f, 150.0f, 0.0f, 200.0f);
}

TEST_F(QuantizedConcatTest, SecondDim8BitSameRange) {
  TestSecondDim8Bit(-10.0f, 150.0f, -10.0f, 150.0f);
}

void QuantizedConcatTest::TestSecondDim8Bit(float first_min, float first_max,
                                            float second_min,
                                            float second_max) {
  TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "QuantizedConcat")
                   .Input(FakeInput(DT_INT32))
                   .Input(FakeInput(2, DT_QUINT8))
                   .Input(FakeInput(2, DT_FLOAT))
                   .Input(FakeInput(2, DT_FLOAT))
                   .Attr("N", 2)
                   .Attr("T", DataTypeToEnum<quint8>::v())
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());
  const int first_batch = 2;
  const int first_height = 2;
  const int first_width = 3;
  Tensor first_float(DT_FLOAT, {first_batch, first_height, first_width});
  test::FillValues<float>(&first_float,
                          {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
  Tensor first_quantized =
      FloatTensorToQuantized<quint8>(first_float, first_min, first_max);

  const int second_batch = 2;
  const int second_height = 2;
  const int second_width = 3;
  Tensor second_float(DT_FLOAT, {second_batch, second_height, second_width});
  test::FillValues<float>(&second_float,
                          {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  Tensor second_quantized =
      FloatTensorToQuantized<quint8>(second_float, second_min, second_max);

  const int expected_height = first_height + second_height;
  Tensor expected_float(DT_FLOAT, {first_batch, expected_height, first_width});
  // Along dimension 1, each batch holds the first input's rows followed by
  // the second input's rows.
  test::FillValues<float>(&expected_float,
                          {1, 2, 3, 4,  5,  6,  13, 14, 15, 16, 17, 18,
                           7, 8, 9, 10, 11, 12, 19, 20, 21, 22, 23, 24});

  AddInputFromArray<int32>(TensorShape({}), {1});
  AddInputFromArray<quint8>(first_quantized.shape(),
                            first_quantized.flat<quint8>());
  AddInputFromArray<quint8>(second_quantized.shape(),
                            second_quantized.flat<quint8>());
  AddInputFromArray<float>(TensorShape({}), {first_min});
  AddInputFromArray<float>(TensorShape({}), {second_min});
  AddInputFromArray<float>(TensorShape({}), {first_max});
  AddInputFromArray<float>(TensorShape({}), {second_max});
  TF_ASSERT_OK(RunOpKernel());
  const Tensor& output_quantized = *GetOutput(0);
  const float output_min = GetOutput(1)->flat<float>()(0);
  const float output_max = GetOutput(2)->flat<float>()(0);
  Tensor output_float =
      QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
  test::ExpectTensorNear<float>(expected_float, output_float, 1.0);
}

// For the benchmark, we set up two 2-dimensional tensors, each kDim1 x 'dim2'
// in size, and concat them together along "concat_dimension".
// If <same_limits> is true, then both inputs share the same quantized range;
// otherwise, they are given different ranges.
template <typename T>
static void ConcatHelper(int iters, int concat_dimension, bool same_limits,
                         int dim2) {
  testing::StopTiming();
  Graph* g = new Graph(OpRegistry::Global());

  DataType dt = DataTypeToEnum<T>::v();
  const int kDim1 = 100;
  TensorShape shape({kDim1, dim2});

  Tensor concat_dim = test::AsScalar<int32>(concat_dimension);
  Tensor in0(dt, shape);
  in0.flat<T>().setRandom();
  Tensor in1(dt, shape);
  in1.flat<T>().setRandom();

  Tensor mins0 = test::AsScalar<float>(-1.0);
  Tensor maxes0 = test::AsScalar<float>(1.0);
  Tensor mins1 = test::AsScalar<float>(same_limits ? -1.0 : -255.0);
  Tensor maxes1 = test::AsScalar<float>(same_limits ? 1.0 : 255.0);

  Node* node;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "QuantizedConcat")
                  .Input(Constant(g, concat_dim))
                  .Input({Constant(g, in0), Constant(g, in1)})
                  .Input({Constant(g, mins0), Constant(g, mins1)})
                  .Input({Constant(g, maxes0), Constant(g, maxes1)})
                  .Attr("N", 2)
                  .Attr("T", dt)
                  .Finalize(g, &node));

  // Two input tensors of kDim1 * dim2 elements each are read per iteration.
  testing::BytesProcessed(static_cast<int64>(iters) *
                          ((kDim1 * dim2) + (kDim1 * dim2)) * sizeof(T));
  testing::UseRealTime();
  testing::StartTiming();
  test::Benchmark("cpu", g).Run(iters);
}

static void BM_QConcatDim0SameLimitQInt32(int iters, int dim2) {
  ConcatHelper<qint32>(iters, 0 /* concat_dimension */, true /* same_limits */,
                       dim2);
}

static void BM_QConcatDim1SameLimitQInt32(int iters, int dim2) {
  ConcatHelper<qint32>(iters, 1 /* concat_dimension */, true /* same_limits */,
                       dim2);
}

static void BM_QConcatDim0DifferLimitQInt32(int iters, int dim2) {
  ConcatHelper<qint32>(iters, 0 /* concat_dimension */,
                       false /* same_limits */, dim2);
}

static void BM_QConcatDim1DifferLimitQInt32(int iters, int dim2) {
  ConcatHelper<qint32>(iters, 1 /* concat_dimension */,
                       false /* same_limits */, dim2);
}
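
// Each benchmark sweeps dim2 through the Arg() values below, covering both
// concat dimensions and both same/differing quantized ranges.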
BENCHMARK(BM_QConcatDim0SameLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim1SameLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim0DifferLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim1DifferLimitQInt32)->Arg(1000)->Arg(20000)->Arg(100000);

static void BM_QConcatDim0SameLimitQUint8(int iters, int dim2) {
  ConcatHelper<quint8>(iters, 0 /* concat_dimension */, true /* same_limits */,
                       dim2);
}

static void BM_QConcatDim1SameLimitQUint8(int iters, int dim2) {
  ConcatHelper<quint8>(iters, 1 /* concat_dimension */, true /* same_limits */,
                       dim2);
}

static void BM_QConcatDim0DifferLimitQUint8(int iters, int dim2) {
  ConcatHelper<quint8>(iters, 0 /* concat_dimension */,
                       false /* same_limits */, dim2);
}

static void BM_QConcatDim1DifferLimitQUint8(int iters, int dim2) {
  ConcatHelper<quint8>(iters, 1 /* concat_dimension */,
                       false /* same_limits */, dim2);
}

BENCHMARK(BM_QConcatDim0SameLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim1SameLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim0DifferLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);
BENCHMARK(BM_QConcatDim1DifferLimitQUint8)->Arg(1000)->Arg(20000)->Arg(100000);

}  // namespace tensorflow