1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/lite/toco/tflite/operator.h" 16 17 #include "flatbuffers/flexbuffers.h" 18 #include <gmock/gmock.h> 19 #include <gtest/gtest.h> 20 #include "tensorflow/lite/toco/model.h" 21 #include "tensorflow/lite/toco/tooling_util.h" 22 23 #include "tensorflow/core/framework/attr_value.pb.h" 24 #include "tensorflow/core/framework/node_def.pb.h" 25 26 namespace toco { 27 28 namespace tflite { 29 namespace { 30 31 class OperatorTest : public ::testing::Test { 32 protected: 33 // Return the operator for the given name and type. 34 const BaseOperator& GetOperator(const string& name, OperatorType type) { 35 using OpsByName = std::map<string, std::unique_ptr<BaseOperator>>; 36 using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>; 37 38 static auto* by_name = new OpsByName(BuildOperatorByNameMap()); 39 static auto* by_type = new OpsByType(BuildOperatorByTypeMap()); 40 41 // Make sure the two maps were consitently built. 42 CHECK(by_name->count(name)) << "No operator for '" << name << "'."; 43 BaseOperator* op1 = by_name->at(name).get(); 44 CHECK(op1->type() == type) << "while verifying '" << name << "'."; 45 46 CHECK(by_type->count(type)) 47 << "No operator for '" << OperatorTypeName(type) << "'."; 48 BaseOperator* op2 = by_type->at(type).get(); 49 CHECK(op2->name() == name) 50 << "while verifying '" << OperatorTypeName(type) << "'."; 51 52 return *op1; 53 } 54 55 // Use the given BaseOperator to serialize the tf.mini operator into a set of 56 // TF Lite options. Proceed to deserialize the options back into a new 57 // tf.mini operator, which is then returned. If `options` is given, it will 58 // be populated with the serialized options. 59 template <typename T> 60 std::unique_ptr<T> SerializeAndDeserialize(const BaseOperator& op, 61 const T& toco_op, 62 Options* options = nullptr) { 63 flatbuffers::FlatBufferBuilder builder; 64 Options input_options = op.Serialize(toco_op, &builder); 65 66 if (options) { 67 *options = input_options; 68 } 69 70 builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type, 71 input_options.builtin, input_options.custom, 72 ::tflite::CustomOptionsFormat_FLEXBUFFERS)); 73 auto* output_options = 74 flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer()); 75 auto new_toco_op = op.Deserialize(output_options->builtin_options(), 76 output_options->custom_options()); 77 78 CHECK(new_toco_op->type == toco_op.type) 79 << "The type of the serialized and deserialized" 80 << HelpfulOperatorTypeName(*new_toco_op) 81 << " does not match the type of the original " 82 << HelpfulOperatorTypeName(toco_op); 83 84 return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release())); 85 } 86 87 // Verify serialization and deserialization of simple operators (those 88 // that don't have any configuration parameters). 89 template <typename T> 90 void CheckSimpleOperator(const string& name, OperatorType type) { 91 Options options; 92 auto output_toco_op = 93 SerializeAndDeserialize(GetOperator(name, type), T(), &options); 94 95 ASSERT_EQ(0, options.builtin.o); 96 ASSERT_EQ(0, options.custom.o); 97 ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type); 98 99 ASSERT_NE(nullptr, output_toco_op.get()); 100 } 101 102 template <typename T> 103 void CheckReducerOperator(const string& name, OperatorType type) { 104 T op; 105 106 op.keep_dims = false; 107 108 auto output_toco_op = SerializeAndDeserialize(GetOperator(name, type), op); 109 EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims); 110 } 111 }; 112 113 TEST_F(OperatorTest, SimpleOperators) { 114 CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor); 115 CheckSimpleOperator<CeilOperator>("CEIL", OperatorType::kCeil); 116 CheckSimpleOperator<EluOperator>("ELU", OperatorType::kElu); 117 CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu); 118 CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1); 119 CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6); 120 CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic); 121 CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh); 122 CheckSimpleOperator<ExpOperator>("EXP", OperatorType::kExp); 123 CheckSimpleOperator<CosOperator>("COS", OperatorType::kCos); 124 CheckSimpleOperator<LogSoftmaxOperator>("LOG_SOFTMAX", 125 OperatorType::kLogSoftmax); 126 CheckSimpleOperator<TensorFlowMaximumOperator>( 127 "MAXIMUM", OperatorType::kMaximum); // Element-wise Maximum 128 CheckSimpleOperator<TensorFlowMinimumOperator>( 129 "MINIMUM", OperatorType::kMinimum); // Element-wise Minimum 130 CheckSimpleOperator<TensorFlowLessOperator>("LESS", OperatorType::kLess); 131 CheckSimpleOperator<NegOperator>("NEG", OperatorType::kNeg); 132 CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect); 133 CheckSimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice); 134 CheckSimpleOperator<SinOperator>("SIN", OperatorType::kSin); 135 CheckSimpleOperator<TensorFlowEqualOperator>("EQUAL", OperatorType::kEqual); 136 CheckSimpleOperator<TensorFlowNotEqualOperator>("NOT_EQUAL", 137 OperatorType::kNotEqual); 138 CheckSimpleOperator<LogOperator>("LOG", OperatorType::kLog); 139 CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt); 140 CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt); 141 CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow); 142 CheckSimpleOperator<LogicalOrOperator>("LOGICAL_OR", 143 OperatorType::kLogicalOr); 144 CheckSimpleOperator<LogicalAndOperator>("LOGICAL_AND", 145 OperatorType::kLogicalAnd); 146 CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT", 147 OperatorType::kLogicalNot); 148 CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv); 149 CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE", 150 OperatorType::kSquare); 151 CheckSimpleOperator<TensorFlowZerosLikeOperator>("ZEROS_LIKE", 152 OperatorType::kZerosLike); 153 CheckSimpleOperator<FloorModOperator>("FLOOR_MOD", OperatorType::kFloorMod); 154 CheckSimpleOperator<RangeOperator>("RANGE", OperatorType::kRange); 155 CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill); 156 CheckSimpleOperator<ReverseV2Operator>("REVERSE_V2", 157 OperatorType::kReverseV2); 158 CheckSimpleOperator<TensorFlowRankOperator>("RANK", OperatorType::kRank); 159 } 160 161 TEST_F(OperatorTest, BuiltinAdd) { 162 AddOperator op; 163 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 164 auto output_toco_op = 165 SerializeAndDeserialize(GetOperator("ADD", OperatorType::kAdd), op); 166 EXPECT_EQ(op.fused_activation_function, 167 output_toco_op->fused_activation_function); 168 } 169 170 TEST_F(OperatorTest, BuiltinAddN) { 171 AddNOperator op; 172 auto output_toco_op = 173 SerializeAndDeserialize(GetOperator("ADD_N", OperatorType::kAddN), op); 174 ASSERT_NE(output_toco_op.get(), nullptr); 175 } 176 177 TEST_F(OperatorTest, BuiltinReducerOps) { 178 CheckReducerOperator<MeanOperator>("MEAN", OperatorType::kMean); 179 CheckReducerOperator<TensorFlowSumOperator>("SUM", OperatorType::kSum); 180 CheckReducerOperator<TensorFlowProdOperator>("REDUCE_PROD", 181 OperatorType::kReduceProd); 182 CheckReducerOperator<TensorFlowMaxOperator>("REDUCE_MAX", 183 OperatorType::kReduceMax); 184 CheckReducerOperator<TensorFlowMinOperator>("REDUCE_MIN", 185 OperatorType::kReduceMin); 186 CheckReducerOperator<TensorFlowAnyOperator>("REDUCE_ANY", OperatorType::kAny); 187 } 188 189 TEST_F(OperatorTest, BuiltinCast) { 190 CastOperator op; 191 op.src_data_type = ArrayDataType::kFloat; 192 op.dst_data_type = ArrayDataType::kUint8; 193 auto output_toco_op = 194 SerializeAndDeserialize(GetOperator("CAST", OperatorType::kCast), op); 195 EXPECT_EQ(op.src_data_type, output_toco_op->src_data_type); 196 EXPECT_EQ(op.dst_data_type, output_toco_op->dst_data_type); 197 } 198 199 TEST_F(OperatorTest, CustomConcatenation) { 200 ConcatenationOperator op; 201 op.axis = 123; 202 auto output_toco_op = SerializeAndDeserialize( 203 GetOperator("CONCATENATION", OperatorType::kConcatenation), op); 204 EXPECT_EQ(op.axis, output_toco_op->axis); 205 } 206 207 TEST_F(OperatorTest, CustomDepthToSpace) { 208 DepthToSpaceOperator op; 209 op.block_size = 123; 210 auto output_toco_op = SerializeAndDeserialize( 211 GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace), op); 212 EXPECT_EQ(op.block_size, output_toco_op->block_size); 213 } 214 215 TEST_F(OperatorTest, CustomFakeQuant) { 216 FakeQuantOperator op; 217 auto* minmax = new MinMax; 218 minmax->min = -10; 219 minmax->max = 200; 220 op.minmax.reset(minmax); 221 op.num_bits = 16; 222 auto output_toco_op = SerializeAndDeserialize( 223 GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op); 224 EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min); 225 EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max); 226 EXPECT_EQ(op.num_bits, output_toco_op->num_bits); 227 } 228 229 TEST_F(OperatorTest, CustomFullyConnected) { 230 FullyConnectedOperator op; 231 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 232 auto output_toco_op = SerializeAndDeserialize( 233 GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected), op); 234 EXPECT_EQ(op.fused_activation_function, 235 output_toco_op->fused_activation_function); 236 } 237 238 TEST_F(OperatorTest, BuiltinGather) { 239 GatherOperator op; 240 auto output_toco_op = 241 SerializeAndDeserialize(GetOperator("GATHER", OperatorType::kGather), op); 242 ASSERT_NE(nullptr, output_toco_op.get()); 243 } 244 245 TEST_F(OperatorTest, BuiltinGatherNd) { 246 GatherNdOperator op; 247 auto output_toco_op = SerializeAndDeserialize( 248 GetOperator("GATHER_ND", OperatorType::kGatherNd), op); 249 ASSERT_NE(output_toco_op.get(), nullptr); 250 } 251 252 TEST_F(OperatorTest, BuiltinWhere) { 253 WhereOperator op; 254 auto output_toco_op = 255 SerializeAndDeserialize(GetOperator("WHERE", OperatorType::kWhere), op); 256 ASSERT_NE(output_toco_op.get(), nullptr); 257 } 258 259 TEST_F(OperatorTest, BuiltinL2Pool) { 260 L2PoolOperator op; 261 op.stride_width = 123; 262 op.stride_height = 124; 263 op.padding.type = PaddingType::kValid; 264 op.kwidth = 480; 265 op.kheight = 1080; 266 auto output_toco_op = SerializeAndDeserialize( 267 GetOperator("L2_POOL_2D", OperatorType::kL2Pool), op); 268 EXPECT_EQ(op.stride_width, output_toco_op->stride_width); 269 EXPECT_EQ(op.stride_height, output_toco_op->stride_height); 270 EXPECT_EQ(op.padding.type, output_toco_op->padding.type); 271 EXPECT_EQ(op.kwidth, output_toco_op->kwidth); 272 EXPECT_EQ(op.kheight, output_toco_op->kheight); 273 } 274 275 TEST_F(OperatorTest, BuiltinLocalResponseNormalization) { 276 LocalResponseNormalizationOperator op; 277 op.range = 123; 278 op.bias = 1.23; 279 op.alpha = 12.3; 280 op.beta = .123; 281 auto output_toco_op = SerializeAndDeserialize( 282 GetOperator("LOCAL_RESPONSE_NORMALIZATION", 283 OperatorType::kLocalResponseNormalization), 284 op); 285 EXPECT_EQ(op.range, output_toco_op->range); 286 EXPECT_EQ(op.bias, output_toco_op->bias); 287 EXPECT_EQ(op.alpha, output_toco_op->alpha); 288 EXPECT_EQ(op.beta, output_toco_op->beta); 289 } 290 291 TEST_F(OperatorTest, BuiltinMaxPool) { 292 MaxPoolOperator op; 293 op.stride_width = 123; 294 op.stride_height = 124; 295 op.padding.type = PaddingType::kValid; 296 op.kwidth = 480; 297 op.kheight = 1080; 298 auto output_toco_op = SerializeAndDeserialize( 299 GetOperator("MAX_POOL_2D", OperatorType::kMaxPool), op); 300 EXPECT_EQ(op.stride_width, output_toco_op->stride_width); 301 EXPECT_EQ(op.stride_height, output_toco_op->stride_height); 302 EXPECT_EQ(op.padding.type, output_toco_op->padding.type); 303 EXPECT_EQ(op.kwidth, output_toco_op->kwidth); 304 EXPECT_EQ(op.kheight, output_toco_op->kheight); 305 } 306 307 TEST_F(OperatorTest, BuiltinReshape) { 308 TensorFlowReshapeOperator op; 309 op.shape = {1, 2, 4, 5, 8}; 310 auto output_toco_op = SerializeAndDeserialize( 311 GetOperator("RESHAPE", OperatorType::kReshape), op); 312 EXPECT_EQ(op.shape, output_toco_op->shape); 313 } 314 315 TEST_F(OperatorTest, CustomSoftmax) { 316 SoftmaxOperator op; 317 op.beta = 123.1; 318 auto output_toco_op = SerializeAndDeserialize( 319 GetOperator("SOFTMAX", OperatorType::kSoftmax), op); 320 EXPECT_EQ(op.beta, output_toco_op->beta); 321 } 322 323 TEST_F(OperatorTest, BuiltinSpaceToDepth) { 324 SpaceToDepthOperator op; 325 op.block_size = 123; 326 auto output_toco_op = SerializeAndDeserialize( 327 GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth), op); 328 EXPECT_EQ(op.block_size, output_toco_op->block_size); 329 } 330 331 TEST_F(OperatorTest, CustomSplit) { 332 TensorFlowSplitOperator op; 333 op.num_split = 123; 334 auto output_toco_op = 335 SerializeAndDeserialize(GetOperator("SPLIT", OperatorType::kSplit), op); 336 EXPECT_EQ(op.num_split, output_toco_op->num_split); 337 } 338 339 TEST_F(OperatorTest, CustomSplitV) { 340 TensorFlowSplitVOperator op; 341 op.num_split = 123; 342 auto output_toco_op = SerializeAndDeserialize( 343 GetOperator("SPLIT_V", OperatorType::kSplitV), op); 344 EXPECT_EQ(op.num_split, output_toco_op->num_split); 345 } 346 347 TEST_F(OperatorTest, BuiltinAveragePool) { 348 AveragePoolOperator op; 349 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 350 op.stride_width = 123; 351 op.stride_height = 124; 352 op.padding.type = PaddingType::kValid; 353 op.kwidth = 480; 354 op.kheight = 1080; 355 auto output_toco_op = SerializeAndDeserialize( 356 GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool), op); 357 EXPECT_EQ(op.fused_activation_function, 358 output_toco_op->fused_activation_function); 359 EXPECT_EQ(op.stride_width, output_toco_op->stride_width); 360 EXPECT_EQ(op.stride_height, output_toco_op->stride_height); 361 EXPECT_EQ(op.padding.type, output_toco_op->padding.type); 362 EXPECT_EQ(op.kwidth, output_toco_op->kwidth); 363 EXPECT_EQ(op.kheight, output_toco_op->kheight); 364 } 365 366 TEST_F(OperatorTest, BuiltinConvolution) { 367 ConvOperator op; 368 op.stride_width = 123; 369 op.stride_height = 124; 370 op.padding.type = PaddingType::kValid; 371 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 372 auto output_toco_op = 373 SerializeAndDeserialize(GetOperator("CONV_2D", OperatorType::kConv), op); 374 EXPECT_EQ(op.stride_width, output_toco_op->stride_width); 375 EXPECT_EQ(op.stride_height, output_toco_op->stride_height); 376 EXPECT_EQ(op.padding.type, output_toco_op->padding.type); 377 EXPECT_EQ(op.fused_activation_function, 378 output_toco_op->fused_activation_function); 379 } 380 381 TEST_F(OperatorTest, BuiltinDepthwiseConvolution) { 382 DepthwiseConvOperator op; 383 op.stride_width = 123; 384 op.stride_height = 124; 385 op.padding.type = PaddingType::kValid; 386 op.depth_multiplier = 6; 387 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 388 auto output_toco_op = SerializeAndDeserialize( 389 GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv), op); 390 EXPECT_EQ(op.stride_width, output_toco_op->stride_width); 391 EXPECT_EQ(op.stride_height, output_toco_op->stride_height); 392 EXPECT_EQ(op.padding.type, output_toco_op->padding.type); 393 EXPECT_EQ(op.depth_multiplier, output_toco_op->depth_multiplier); 394 EXPECT_EQ(op.fused_activation_function, 395 output_toco_op->fused_activation_function); 396 } 397 398 TEST_F(OperatorTest, BuiltinL2Norm) { 399 L2NormalizationOperator op; 400 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 401 auto output_toco_op = SerializeAndDeserialize( 402 GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization), op); 403 EXPECT_EQ(op.fused_activation_function, 404 output_toco_op->fused_activation_function); 405 } 406 407 TEST_F(OperatorTest, BuiltinMul) { 408 MulOperator op; 409 op.fused_activation_function = FusedActivationFunctionType::kRelu6; 410 auto output_toco_op = 411 SerializeAndDeserialize(GetOperator("MUL", OperatorType::kMul), op); 412 EXPECT_EQ(op.fused_activation_function, 413 output_toco_op->fused_activation_function); 414 } 415 416 TEST_F(OperatorTest, ResizeBilinear) { 417 ResizeBilinearOperator op; 418 op.align_corners = true; 419 auto output_toco_op = SerializeAndDeserialize( 420 GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op); 421 EXPECT_EQ(op.align_corners, output_toco_op->align_corners); 422 } 423 424 TEST_F(OperatorTest, ResizeNearestNeighbor) { 425 ResizeNearestNeighborOperator op; 426 op.align_corners = true; 427 auto output_toco_op = 428 SerializeAndDeserialize(GetOperator("RESIZE_NEAREST_NEIGHBOR", 429 OperatorType::kResizeNearestNeighbor), 430 op); 431 EXPECT_EQ(op.align_corners, output_toco_op->align_corners); 432 } 433 434 TEST_F(OperatorTest, Svdf) { 435 SvdfOperator op; 436 op.fused_activation_function = FusedActivationFunctionType::kRelu; 437 op.rank = 1; 438 auto output_toco_op = 439 SerializeAndDeserialize(GetOperator("SVDF", OperatorType::kSvdf), op); 440 EXPECT_EQ(op.fused_activation_function, 441 output_toco_op->fused_activation_function); 442 EXPECT_EQ(op.rank, output_toco_op->rank); 443 } 444 445 TEST_F(OperatorTest, Squeeze) { 446 SqueezeOperator op; 447 op.squeeze_dims = {-2, -3, 4, 1, 4}; 448 449 auto output_toco_op = SerializeAndDeserialize( 450 GetOperator("SQUEEZE", OperatorType::kSqueeze), op); 451 EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims); 452 } 453 454 TEST_F(OperatorTest, StridedSlice) { 455 StridedSliceOperator op; 456 457 op.begin_mask = 1; 458 op.end_mask = 2; 459 op.ellipsis_mask = 1; 460 op.new_axis_mask = 1; 461 op.shrink_axis_mask = 2; 462 463 auto output_toco_op = SerializeAndDeserialize( 464 GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op); 465 EXPECT_EQ(op.start_indices, output_toco_op->start_indices); 466 EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices); 467 EXPECT_EQ(op.strides, output_toco_op->strides); 468 EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask); 469 EXPECT_EQ(op.end_mask, output_toco_op->end_mask); 470 EXPECT_EQ(op.end_mask, output_toco_op->end_mask); 471 EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask); 472 EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask); 473 EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask); 474 } 475 476 TEST_F(OperatorTest, BuiltinTopKV2) { 477 TopKV2Operator op; 478 auto output_toco_op = SerializeAndDeserialize( 479 GetOperator("TOPK_V2", OperatorType::kTopK_V2), op); 480 ASSERT_NE(nullptr, output_toco_op.get()); 481 } 482 483 TEST_F(OperatorTest, BuiltinArgMax) { 484 ArgMaxOperator op; 485 auto output_toco_op = SerializeAndDeserialize( 486 GetOperator("ARG_MAX", OperatorType::kArgMax), op); 487 EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); 488 } 489 490 TEST_F(OperatorTest, BuiltinArgMin) { 491 ArgMinOperator op; 492 auto output_toco_op = SerializeAndDeserialize( 493 GetOperator("ARG_MIN", OperatorType::kArgMin), op); 494 EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); 495 } 496 497 TEST_F(OperatorTest, BuiltinDequantize) { 498 DequantizeOperator op; 499 auto output_toco_op = SerializeAndDeserialize( 500 GetOperator("DEQUANTIZE", OperatorType::kDequantize), op); 501 } 502 503 TEST_F(OperatorTest, BuiltinTransposeConv) { 504 TransposeConvOperator op; 505 op.stride_width = 123; 506 op.stride_height = 124; 507 op.padding.type = PaddingType::kValid; 508 auto output_toco_op = SerializeAndDeserialize( 509 GetOperator("TRANSPOSE_CONV", OperatorType::kTransposeConv), op); 510 EXPECT_EQ(op.stride_width, output_toco_op->stride_width); 511 EXPECT_EQ(op.stride_height, output_toco_op->stride_height); 512 EXPECT_EQ(op.padding.type, output_toco_op->padding.type); 513 } 514 515 TEST_F(OperatorTest, BuiltinShape) { 516 TensorFlowShapeOperator op; 517 op.output_data_type = ArrayDataType::kInt64; 518 auto output_toco_op = 519 SerializeAndDeserialize(GetOperator("SHAPE", OperatorType::kShape), op); 520 EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type); 521 } 522 523 TEST_F(OperatorTest, BuiltinSparseToDense) { 524 SparseToDenseOperator op; 525 op.validate_indices = false; 526 std::unique_ptr<toco::SparseToDenseOperator> output_toco_op = 527 SerializeAndDeserialize( 528 GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense), op); 529 EXPECT_EQ(op.validate_indices, output_toco_op->validate_indices); 530 } 531 532 TEST_F(OperatorTest, BuiltinPack) { 533 PackOperator op; 534 op.values_count = 3; 535 op.axis = 1; 536 std::unique_ptr<toco::PackOperator> output_toco_op = 537 SerializeAndDeserialize(GetOperator("PACK", OperatorType::kPack), op); 538 EXPECT_EQ(op.values_count, output_toco_op->values_count); 539 EXPECT_EQ(op.axis, output_toco_op->axis); 540 } 541 542 TEST_F(OperatorTest, BuiltinOneHot) { 543 OneHotOperator op; 544 op.axis = 2; 545 auto output_toco_op = SerializeAndDeserialize( 546 GetOperator("ONE_HOT", OperatorType::kOneHot), op); 547 EXPECT_EQ(op.axis, output_toco_op->axis); 548 } 549 550 TEST_F(OperatorTest, BuiltinUnpack) { 551 UnpackOperator op; 552 op.num = 5; 553 op.axis = 2; 554 auto output_toco_op = 555 SerializeAndDeserialize(GetOperator("UNPACK", OperatorType::kUnpack), op); 556 EXPECT_EQ(op.num, output_toco_op->num); 557 EXPECT_EQ(op.axis, output_toco_op->axis); 558 } 559 560 TEST_F(OperatorTest, BuiltinLeakyRelu) { 561 LeakyReluOperator op; 562 op.alpha = 3; 563 auto output_toco_op = SerializeAndDeserialize( 564 GetOperator("LEAKY_RELU", OperatorType::kLeakyRelu), op); 565 EXPECT_EQ(op.alpha, output_toco_op->alpha); 566 } 567 568 TEST_F(OperatorTest, BuiltinSquaredDifference) { 569 SquaredDifferenceOperator op; 570 auto output_toco_op = SerializeAndDeserialize( 571 GetOperator("SQUARED_DIFFERENCE", OperatorType::kSquaredDifference), op); 572 ASSERT_NE(nullptr, output_toco_op.get()); 573 } 574 575 TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) { 576 CTCBeamSearchDecoderOperator op; 577 op.beam_width = 3; 578 op.top_paths = 2; 579 op.merge_repeated = false; 580 std::unique_ptr<toco::CTCBeamSearchDecoderOperator> output_toco_op = 581 SerializeAndDeserialize(GetOperator("CTC_BEAM_SEARCH_DECODER", 582 OperatorType::kCTCBeamSearchDecoder), 583 op); 584 EXPECT_EQ(op.beam_width, output_toco_op->beam_width); 585 EXPECT_EQ(op.top_paths, output_toco_op->top_paths); 586 EXPECT_EQ(op.merge_repeated, output_toco_op->merge_repeated); 587 } 588 589 TEST_F(OperatorTest, TensorFlowUnsupported) { 590 TensorFlowUnsupportedOperator op; 591 op.tensorflow_op = "MyCustomUnsupportedOp"; 592 593 ::tensorflow::NodeDef node_def; 594 auto attr = node_def.mutable_attr(); 595 (*attr)["float_attr"].set_f(2.0); 596 (*attr)["str_attr"].set_s("Hello World"); 597 (*attr)["int_attr"].set_i(17); 598 (*attr)["bool_attr"].set_b(true); 599 { 600 auto* list = (*attr)["list_string_attr"].mutable_list(); 601 list->add_s("abcde"); 602 list->add_s("1234"); 603 list->add_s(""); 604 list->add_s("zyxwv"); 605 list->add_s("!-."); 606 } 607 { 608 auto* list = (*attr)["list_float_attr"].mutable_list(); 609 list->add_f(std::numeric_limits<float>::min()); 610 list->add_f(2.0); 611 list->add_f(-std::numeric_limits<float>::max()); 612 } 613 { 614 auto* list = (*attr)["list_int_attr"].mutable_list(); 615 list->add_i(1); 616 list->add_i(20); 617 list->add_i(1LL << 40); 618 list->add_i(-(1LL << 40)); 619 } 620 node_def.SerializeToString(&op.tensorflow_node_def); 621 622 auto output_toco_op = SerializeAndDeserialize( 623 GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op); 624 625 ::tensorflow::NodeDef output_node_def; 626 output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); 627 const auto& output_attr = output_node_def.attr(); 628 EXPECT_EQ(2.0, output_attr.at("float_attr").f()); 629 EXPECT_EQ("Hello World", output_attr.at("str_attr").s()); 630 EXPECT_EQ(17, output_attr.at("int_attr").i()); 631 EXPECT_EQ(true, output_attr.at("bool_attr").b()); 632 { 633 const auto& list = output_attr.at("list_string_attr").list(); 634 ASSERT_EQ(5, list.s_size()); 635 EXPECT_EQ("abcde", list.s(0)); 636 EXPECT_EQ("1234", list.s(1)); 637 EXPECT_EQ("", list.s(2)); 638 EXPECT_EQ("zyxwv", list.s(3)); 639 EXPECT_EQ("!-.", list.s(4)); 640 } 641 { 642 const auto& list = output_attr.at("list_float_attr").list(); 643 ASSERT_EQ(3, list.f_size()); 644 EXPECT_EQ(std::numeric_limits<float>::min(), list.f(0)); 645 EXPECT_EQ(2.0, list.f(1)); 646 EXPECT_EQ(-std::numeric_limits<float>::max(), list.f(2)); 647 } 648 { 649 const auto& list = output_attr.at("list_int_attr").list(); 650 ASSERT_EQ(4, list.i_size()); 651 EXPECT_EQ(1, list.i(0)); 652 EXPECT_EQ(20, list.i(1)); 653 EXPECT_EQ(1LL << 40, list.i(2)); 654 EXPECT_EQ(-(1LL << 40), list.i(3)); 655 } 656 } 657 658 TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) { 659 TensorFlowUnsupportedOperator op; 660 op.tensorflow_op = "MyCustomUnsupportedOp"; 661 auto output_toco_op = SerializeAndDeserialize( 662 GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op); 663 664 ::tensorflow::NodeDef output_node_def; 665 output_node_def.ParseFromString(output_toco_op->tensorflow_node_def); 666 EXPECT_TRUE(output_node_def.attr().empty()); 667 } 668 669 TEST_F(OperatorTest, TestShouldExportAsFlexOp) { 670 EXPECT_FALSE(ShouldExportAsFlexOp(false, "Conv2D")); 671 EXPECT_TRUE(ShouldExportAsFlexOp(true, "Conv2D")); 672 EXPECT_TRUE(ShouldExportAsFlexOp(true, "EluGrad")); 673 EXPECT_TRUE(ShouldExportAsFlexOp(true, "RFFT")); 674 EXPECT_FALSE(ShouldExportAsFlexOp(true, "MyAwesomeCustomOp")); 675 // While the RandomShuffle op is available on desktop, it is not in the kernel 676 // set available on mobile and should be excluded. 677 EXPECT_FALSE(ShouldExportAsFlexOp(true, "RandomShuffle")); 678 } 679 680 TEST_F(OperatorTest, BuiltinMirrorPad) { 681 MirrorPadOperator op; 682 op.mode = MirrorPadMode::kReflect; 683 auto output_toco_op = SerializeAndDeserialize( 684 GetOperator("MIRROR_PAD", OperatorType::kMirrorPad), op); 685 EXPECT_EQ(op.mode, output_toco_op->mode); 686 } 687 688 TEST_F(OperatorTest, BuiltinUnique) { 689 UniqueOperator op; 690 op.idx_out_type = ArrayDataType::kInt64; 691 auto output_toco_op = 692 SerializeAndDeserialize(GetOperator("UNIQUE", OperatorType::kUnique), op); 693 ASSERT_NE(nullptr, output_toco_op.get()); 694 EXPECT_EQ(output_toco_op->idx_out_type, op.idx_out_type); 695 } 696 697 TEST_F(OperatorTest, BuiltinReverseSequence) { 698 ReverseSequenceOperator op; 699 op.seq_dim = 3; 700 op.batch_dim = 1; 701 std::unique_ptr<toco::ReverseSequenceOperator> output_toco_op = 702 SerializeAndDeserialize( 703 GetOperator("REVERSE_SEQUENCE", OperatorType::kReverseSequence), op); 704 EXPECT_EQ(op.seq_dim, output_toco_op->seq_dim); 705 EXPECT_EQ(op.batch_dim, output_toco_op->batch_dim); 706 } 707 708 // Test version for a simple Op with 2 versions and the input type controls the 709 // version. 710 template <typename Op> 711 void SimpleVersioningTest() { 712 Op op; 713 op.inputs = {"input1"}; 714 auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/); 715 const BaseOperator* base_op = operator_by_type_map.at(op.type).get(); 716 717 Model uint8_model; 718 Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]); 719 uint8_array.data_type = ArrayDataType::kUint8; 720 OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model}; 721 EXPECT_EQ(base_op->GetVersion(uint8_signature), 1); 722 723 Model int8_model; 724 Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]); 725 int8_array.data_type = ArrayDataType::kInt8; 726 OperatorSignature int8_signature = {.op = &op, .model = &int8_model}; 727 EXPECT_EQ(base_op->GetVersion(int8_signature), 2); 728 } 729 730 // Test version for a simple Op with 2 versions and the output type controls the 731 // version. 732 template <typename Op> 733 void SimpleOutputVersioningTest() { 734 Op op; 735 op.outputs = {"output1"}; 736 auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/); 737 const BaseOperator* base_op = operator_by_type_map.at(op.type).get(); 738 739 Model uint8_model; 740 Array& uint8_array = uint8_model.GetOrCreateArray(op.outputs[0]); 741 uint8_array.data_type = ArrayDataType::kUint8; 742 OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model}; 743 EXPECT_EQ(base_op->GetVersion(uint8_signature), 1); 744 745 Model int8_model; 746 Array& int8_array = int8_model.GetOrCreateArray(op.outputs[0]); 747 int8_array.data_type = ArrayDataType::kInt8; 748 OperatorSignature int8_signature = {.op = &op, .model = &int8_model}; 749 EXPECT_EQ(base_op->GetVersion(int8_signature), 2); 750 } 751 752 TEST_F(OperatorTest, VersioningEqualTest) { 753 SimpleVersioningTest<TensorFlowEqualOperator>(); 754 } 755 756 TEST_F(OperatorTest, VersioningNotEqualTest) { 757 SimpleVersioningTest<TensorFlowNotEqualOperator>(); 758 } 759 760 TEST_F(OperatorTest, VersioningLessTest) { 761 SimpleVersioningTest<TensorFlowLessOperator>(); 762 } 763 764 TEST_F(OperatorTest, VersioningLessEqualTest) { 765 SimpleVersioningTest<TensorFlowLessEqualOperator>(); 766 } 767 768 TEST_F(OperatorTest, VersioningGreaterTest) { 769 SimpleVersioningTest<TensorFlowGreaterOperator>(); 770 } 771 772 TEST_F(OperatorTest, VersioningGreaterEqualTest) { 773 SimpleVersioningTest<TensorFlowGreaterEqualOperator>(); 774 } 775 776 TEST_F(OperatorTest, VersioningSpaceToBatchNDTest) { 777 SimpleVersioningTest<SpaceToBatchNDOperator>(); 778 } 779 780 TEST_F(OperatorTest, VersioningLogSoftmaxTest) { 781 SimpleVersioningTest<LogSoftmaxOperator>(); 782 } 783 784 TEST_F(OperatorTest, VersioningPackTest) { 785 SimpleVersioningTest<PackOperator>(); 786 } 787 788 TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) { 789 SimpleVersioningTest<BatchToSpaceNDOperator>(); 790 } 791 792 TEST_F(OperatorTest, VersioningTanhTest) { 793 SimpleVersioningTest<TanhOperator>(); 794 } 795 796 TEST_F(OperatorTest, VersioningStridedSliceTest) { 797 SimpleVersioningTest<StridedSliceOperator>(); 798 } 799 800 TEST_F(OperatorTest, VersioningSpaceToDepthTest) { 801 SimpleVersioningTest<SpaceToDepthOperator>(); 802 } 803 804 TEST_F(OperatorTest, VersioningSliceTest) { 805 SimpleVersioningTest<SliceOperator>(); 806 } 807 808 TEST_F(OperatorTest, VersioningLogisticTest) { 809 SimpleVersioningTest<LogisticOperator>(); 810 } 811 812 TEST_F(OperatorTest, VersioningL2NormTest) { 813 SimpleOutputVersioningTest<L2NormalizationOperator>(); 814 } 815 816 TEST_F(OperatorTest, VersioningMaxTest) { 817 SimpleVersioningTest<TensorFlowMaximumOperator>(); 818 } 819 820 TEST_F(OperatorTest, VersioningMinTest) { 821 SimpleVersioningTest<TensorFlowMinimumOperator>(); 822 } 823 824 TEST_F(OperatorTest, VersioningAddTest) { SimpleVersioningTest<AddOperator>(); } 825 826 TEST_F(OperatorTest, VersioningSubTest) { SimpleVersioningTest<SubOperator>(); } 827 828 TEST_F(OperatorTest, VersioningMulTest) { SimpleVersioningTest<MulOperator>(); } 829 830 TEST_F(OperatorTest, VersioningPadTest) { SimpleVersioningTest<PadOperator>(); } 831 832 TEST_F(OperatorTest, VersioningPadV2Test) { 833 SimpleVersioningTest<PadV2Operator>(); 834 } 835 836 TEST_F(OperatorTest, VersioningConcatenationTest) { 837 SimpleVersioningTest<ConcatenationOperator>(); 838 } 839 840 TEST_F(OperatorTest, VersioningSelectTest) { 841 SimpleVersioningTest<SelectOperator>(); 842 } 843 844 TEST_F(OperatorTest, VersioningRelu6Test) { 845 SimpleVersioningTest<Relu6Operator>(); 846 } 847 848 TEST_F(OperatorTest, VersioningFullyConnectedTest) { 849 FullyConnectedOperator fully_connected_op; 850 fully_connected_op.inputs = {"input", "weight"}; 851 fully_connected_op.outputs = {"output"}; 852 auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/); 853 const BaseOperator* op = 854 operator_by_type_map.at(fully_connected_op.type).get(); 855 856 Model uint8_model; 857 Array& input_uint8_array = 858 uint8_model.GetOrCreateArray(fully_connected_op.inputs[0]); 859 input_uint8_array.data_type = ArrayDataType::kUint8; 860 Array& weight_uint8_array = 861 uint8_model.GetOrCreateArray(fully_connected_op.inputs[1]); 862 weight_uint8_array.data_type = ArrayDataType::kUint8; 863 Array& output_uint8_array = 864 uint8_model.GetOrCreateArray(fully_connected_op.outputs[0]); 865 output_uint8_array.data_type = ArrayDataType::kUint8; 866 OperatorSignature uint8_signature = {.op = &fully_connected_op, 867 .model = &uint8_model}; 868 EXPECT_EQ(op->GetVersion(uint8_signature), 1); 869 870 Model int8_model; 871 Array& input_int8_array = 872 int8_model.GetOrCreateArray(fully_connected_op.inputs[0]); 873 input_int8_array.data_type = ArrayDataType::kInt8; 874 Array& weight_int8_array = 875 int8_model.GetOrCreateArray(fully_connected_op.inputs[1]); 876 weight_int8_array.data_type = ArrayDataType::kInt8; 877 Array& output_int8_array = 878 int8_model.GetOrCreateArray(fully_connected_op.outputs[0]); 879 output_int8_array.data_type = ArrayDataType::kInt8; 880 OperatorSignature int8_signature = {.op = &fully_connected_op, 881 .model = &int8_model}; 882 EXPECT_EQ(op->GetVersion(int8_signature), 4); 883 } 884 885 } // namespace 886 } // namespace tflite 887 888 } // namespace toco 889