Home | History | Annotate | Download | only in tflite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/lite/toco/tflite/operator.h"
     16 
     17 #include "flatbuffers/flexbuffers.h"
     18 #include <gmock/gmock.h>
     19 #include <gtest/gtest.h>
     20 #include "tensorflow/lite/toco/model.h"
     21 #include "tensorflow/lite/toco/tooling_util.h"
     22 
     23 #include "tensorflow/core/framework/attr_value.pb.h"
     24 #include "tensorflow/core/framework/node_def.pb.h"
     25 
     26 namespace toco {
     27 
     28 namespace tflite {
     29 namespace {
     30 
     31 class OperatorTest : public ::testing::Test {
     32  protected:
     33   // Return the operator for the given name and type.
     34   const BaseOperator& GetOperator(const string& name, OperatorType type) {
     35     using OpsByName = std::map<string, std::unique_ptr<BaseOperator>>;
     36     using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>;
     37 
     38     static auto* by_name = new OpsByName(BuildOperatorByNameMap());
     39     static auto* by_type = new OpsByType(BuildOperatorByTypeMap());
     40 
     41     // Make sure the two maps were consitently built.
     42     CHECK(by_name->count(name)) << "No operator for '" << name << "'.";
     43     BaseOperator* op1 = by_name->at(name).get();
     44     CHECK(op1->type() == type) << "while verifying '" << name << "'.";
     45 
     46     CHECK(by_type->count(type))
     47         << "No operator for '" << OperatorTypeName(type) << "'.";
     48     BaseOperator* op2 = by_type->at(type).get();
     49     CHECK(op2->name() == name)
     50         << "while verifying '" << OperatorTypeName(type) << "'.";
     51 
     52     return *op1;
     53   }
     54 
     55   // Use the given BaseOperator to serialize the tf.mini operator into a set of
     56   // TF Lite options. Proceed to deserialize the options back into a new
     57   // tf.mini operator, which is then returned. If `options` is given, it will
     58   // be populated with the serialized options.
     59   template <typename T>
     60   std::unique_ptr<T> SerializeAndDeserialize(const BaseOperator& op,
     61                                              const T& toco_op,
     62                                              Options* options = nullptr) {
     63     flatbuffers::FlatBufferBuilder builder;
     64     Options input_options = op.Serialize(toco_op, &builder);
     65 
     66     if (options) {
     67       *options = input_options;
     68     }
     69 
     70     builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type,
     71                                   input_options.builtin, input_options.custom,
     72                                   ::tflite::CustomOptionsFormat_FLEXBUFFERS));
     73     auto* output_options =
     74         flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer());
     75     auto new_toco_op = op.Deserialize(output_options->builtin_options(),
     76                                       output_options->custom_options());
     77 
     78     CHECK(new_toco_op->type == toco_op.type)
     79         << "The type of the serialized and deserialized"
     80         << HelpfulOperatorTypeName(*new_toco_op)
     81         << " does not match the type of the original "
     82         << HelpfulOperatorTypeName(toco_op);
     83 
     84     return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release()));
     85   }
     86 
     87   // Verify serialization and deserialization of simple operators (those
     88   // that don't have any configuration parameters).
     89   template <typename T>
     90   void CheckSimpleOperator(const string& name, OperatorType type) {
     91     Options options;
     92     auto output_toco_op =
     93         SerializeAndDeserialize(GetOperator(name, type), T(), &options);
     94 
     95     ASSERT_EQ(0, options.builtin.o);
     96     ASSERT_EQ(0, options.custom.o);
     97     ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type);
     98 
     99     ASSERT_NE(nullptr, output_toco_op.get());
    100   }
    101 
    102   template <typename T>
    103   void CheckReducerOperator(const string& name, OperatorType type) {
    104     T op;
    105 
    106     op.keep_dims = false;
    107 
    108     auto output_toco_op = SerializeAndDeserialize(GetOperator(name, type), op);
    109     EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
    110   }
    111 };
    112 
    113 TEST_F(OperatorTest, SimpleOperators) {
    114   CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
    115   CheckSimpleOperator<CeilOperator>("CEIL", OperatorType::kCeil);
    116   CheckSimpleOperator<EluOperator>("ELU", OperatorType::kElu);
    117   CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
    118   CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
    119   CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
    120   CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
    121   CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
    122   CheckSimpleOperator<ExpOperator>("EXP", OperatorType::kExp);
    123   CheckSimpleOperator<CosOperator>("COS", OperatorType::kCos);
    124   CheckSimpleOperator<LogSoftmaxOperator>("LOG_SOFTMAX",
    125                                           OperatorType::kLogSoftmax);
    126   CheckSimpleOperator<TensorFlowMaximumOperator>(
    127       "MAXIMUM", OperatorType::kMaximum);  //  Element-wise Maximum
    128   CheckSimpleOperator<TensorFlowMinimumOperator>(
    129       "MINIMUM", OperatorType::kMinimum);  //  Element-wise Minimum
    130   CheckSimpleOperator<TensorFlowLessOperator>("LESS", OperatorType::kLess);
    131   CheckSimpleOperator<NegOperator>("NEG", OperatorType::kNeg);
    132   CheckSimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect);
    133   CheckSimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice);
    134   CheckSimpleOperator<SinOperator>("SIN", OperatorType::kSin);
    135   CheckSimpleOperator<TensorFlowEqualOperator>("EQUAL", OperatorType::kEqual);
    136   CheckSimpleOperator<TensorFlowNotEqualOperator>("NOT_EQUAL",
    137                                                   OperatorType::kNotEqual);
    138   CheckSimpleOperator<LogOperator>("LOG", OperatorType::kLog);
    139   CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt);
    140   CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt);
    141   CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow);
    142   CheckSimpleOperator<LogicalOrOperator>("LOGICAL_OR",
    143                                          OperatorType::kLogicalOr);
    144   CheckSimpleOperator<LogicalAndOperator>("LOGICAL_AND",
    145                                           OperatorType::kLogicalAnd);
    146   CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
    147                                           OperatorType::kLogicalNot);
    148   CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
    149   CheckSimpleOperator<TensorFlowSquareOperator>("SQUARE",
    150                                                 OperatorType::kSquare);
    151   CheckSimpleOperator<TensorFlowZerosLikeOperator>("ZEROS_LIKE",
    152                                                    OperatorType::kZerosLike);
    153   CheckSimpleOperator<FloorModOperator>("FLOOR_MOD", OperatorType::kFloorMod);
    154   CheckSimpleOperator<RangeOperator>("RANGE", OperatorType::kRange);
    155   CheckSimpleOperator<FillOperator>("FILL", OperatorType::kFill);
    156   CheckSimpleOperator<ReverseV2Operator>("REVERSE_V2",
    157                                          OperatorType::kReverseV2);
    158   CheckSimpleOperator<TensorFlowRankOperator>("RANK", OperatorType::kRank);
    159 }
    160 
    161 TEST_F(OperatorTest, BuiltinAdd) {
    162   AddOperator op;
    163   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    164   auto output_toco_op =
    165       SerializeAndDeserialize(GetOperator("ADD", OperatorType::kAdd), op);
    166   EXPECT_EQ(op.fused_activation_function,
    167             output_toco_op->fused_activation_function);
    168 }
    169 
    170 TEST_F(OperatorTest, BuiltinAddN) {
    171   AddNOperator op;
    172   auto output_toco_op =
    173       SerializeAndDeserialize(GetOperator("ADD_N", OperatorType::kAddN), op);
    174   ASSERT_NE(output_toco_op.get(), nullptr);
    175 }
    176 
    177 TEST_F(OperatorTest, BuiltinReducerOps) {
    178   CheckReducerOperator<MeanOperator>("MEAN", OperatorType::kMean);
    179   CheckReducerOperator<TensorFlowSumOperator>("SUM", OperatorType::kSum);
    180   CheckReducerOperator<TensorFlowProdOperator>("REDUCE_PROD",
    181                                                OperatorType::kReduceProd);
    182   CheckReducerOperator<TensorFlowMaxOperator>("REDUCE_MAX",
    183                                               OperatorType::kReduceMax);
    184   CheckReducerOperator<TensorFlowMinOperator>("REDUCE_MIN",
    185                                               OperatorType::kReduceMin);
    186   CheckReducerOperator<TensorFlowAnyOperator>("REDUCE_ANY", OperatorType::kAny);
    187 }
    188 
    189 TEST_F(OperatorTest, BuiltinCast) {
    190   CastOperator op;
    191   op.src_data_type = ArrayDataType::kFloat;
    192   op.dst_data_type = ArrayDataType::kUint8;
    193   auto output_toco_op =
    194       SerializeAndDeserialize(GetOperator("CAST", OperatorType::kCast), op);
    195   EXPECT_EQ(op.src_data_type, output_toco_op->src_data_type);
    196   EXPECT_EQ(op.dst_data_type, output_toco_op->dst_data_type);
    197 }
    198 
    199 TEST_F(OperatorTest, CustomConcatenation) {
    200   ConcatenationOperator op;
    201   op.axis = 123;
    202   auto output_toco_op = SerializeAndDeserialize(
    203       GetOperator("CONCATENATION", OperatorType::kConcatenation), op);
    204   EXPECT_EQ(op.axis, output_toco_op->axis);
    205 }
    206 
    207 TEST_F(OperatorTest, CustomDepthToSpace) {
    208   DepthToSpaceOperator op;
    209   op.block_size = 123;
    210   auto output_toco_op = SerializeAndDeserialize(
    211       GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace), op);
    212   EXPECT_EQ(op.block_size, output_toco_op->block_size);
    213 }
    214 
    215 TEST_F(OperatorTest, CustomFakeQuant) {
    216   FakeQuantOperator op;
    217   auto* minmax = new MinMax;
    218   minmax->min = -10;
    219   minmax->max = 200;
    220   op.minmax.reset(minmax);
    221   op.num_bits = 16;
    222   auto output_toco_op = SerializeAndDeserialize(
    223       GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op);
    224   EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min);
    225   EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max);
    226   EXPECT_EQ(op.num_bits, output_toco_op->num_bits);
    227 }
    228 
    229 TEST_F(OperatorTest, CustomFullyConnected) {
    230   FullyConnectedOperator op;
    231   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    232   auto output_toco_op = SerializeAndDeserialize(
    233       GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected), op);
    234   EXPECT_EQ(op.fused_activation_function,
    235             output_toco_op->fused_activation_function);
    236 }
    237 
    238 TEST_F(OperatorTest, BuiltinGather) {
    239   GatherOperator op;
    240   auto output_toco_op =
    241       SerializeAndDeserialize(GetOperator("GATHER", OperatorType::kGather), op);
    242   ASSERT_NE(nullptr, output_toco_op.get());
    243 }
    244 
    245 TEST_F(OperatorTest, BuiltinGatherNd) {
    246   GatherNdOperator op;
    247   auto output_toco_op = SerializeAndDeserialize(
    248       GetOperator("GATHER_ND", OperatorType::kGatherNd), op);
    249   ASSERT_NE(output_toco_op.get(), nullptr);
    250 }
    251 
    252 TEST_F(OperatorTest, BuiltinWhere) {
    253   WhereOperator op;
    254   auto output_toco_op =
    255       SerializeAndDeserialize(GetOperator("WHERE", OperatorType::kWhere), op);
    256   ASSERT_NE(output_toco_op.get(), nullptr);
    257 }
    258 
    259 TEST_F(OperatorTest, BuiltinL2Pool) {
    260   L2PoolOperator op;
    261   op.stride_width = 123;
    262   op.stride_height = 124;
    263   op.padding.type = PaddingType::kValid;
    264   op.kwidth = 480;
    265   op.kheight = 1080;
    266   auto output_toco_op = SerializeAndDeserialize(
    267       GetOperator("L2_POOL_2D", OperatorType::kL2Pool), op);
    268   EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
    269   EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
    270   EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
    271   EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
    272   EXPECT_EQ(op.kheight, output_toco_op->kheight);
    273 }
    274 
    275 TEST_F(OperatorTest, BuiltinLocalResponseNormalization) {
    276   LocalResponseNormalizationOperator op;
    277   op.range = 123;
    278   op.bias = 1.23;
    279   op.alpha = 12.3;
    280   op.beta = .123;
    281   auto output_toco_op = SerializeAndDeserialize(
    282       GetOperator("LOCAL_RESPONSE_NORMALIZATION",
    283                   OperatorType::kLocalResponseNormalization),
    284       op);
    285   EXPECT_EQ(op.range, output_toco_op->range);
    286   EXPECT_EQ(op.bias, output_toco_op->bias);
    287   EXPECT_EQ(op.alpha, output_toco_op->alpha);
    288   EXPECT_EQ(op.beta, output_toco_op->beta);
    289 }
    290 
    291 TEST_F(OperatorTest, BuiltinMaxPool) {
    292   MaxPoolOperator op;
    293   op.stride_width = 123;
    294   op.stride_height = 124;
    295   op.padding.type = PaddingType::kValid;
    296   op.kwidth = 480;
    297   op.kheight = 1080;
    298   auto output_toco_op = SerializeAndDeserialize(
    299       GetOperator("MAX_POOL_2D", OperatorType::kMaxPool), op);
    300   EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
    301   EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
    302   EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
    303   EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
    304   EXPECT_EQ(op.kheight, output_toco_op->kheight);
    305 }
    306 
    307 TEST_F(OperatorTest, BuiltinReshape) {
    308   TensorFlowReshapeOperator op;
    309   op.shape = {1, 2, 4, 5, 8};
    310   auto output_toco_op = SerializeAndDeserialize(
    311       GetOperator("RESHAPE", OperatorType::kReshape), op);
    312   EXPECT_EQ(op.shape, output_toco_op->shape);
    313 }
    314 
    315 TEST_F(OperatorTest, CustomSoftmax) {
    316   SoftmaxOperator op;
    317   op.beta = 123.1;
    318   auto output_toco_op = SerializeAndDeserialize(
    319       GetOperator("SOFTMAX", OperatorType::kSoftmax), op);
    320   EXPECT_EQ(op.beta, output_toco_op->beta);
    321 }
    322 
    323 TEST_F(OperatorTest, BuiltinSpaceToDepth) {
    324   SpaceToDepthOperator op;
    325   op.block_size = 123;
    326   auto output_toco_op = SerializeAndDeserialize(
    327       GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth), op);
    328   EXPECT_EQ(op.block_size, output_toco_op->block_size);
    329 }
    330 
    331 TEST_F(OperatorTest, CustomSplit) {
    332   TensorFlowSplitOperator op;
    333   op.num_split = 123;
    334   auto output_toco_op =
    335       SerializeAndDeserialize(GetOperator("SPLIT", OperatorType::kSplit), op);
    336   EXPECT_EQ(op.num_split, output_toco_op->num_split);
    337 }
    338 
    339 TEST_F(OperatorTest, CustomSplitV) {
    340   TensorFlowSplitVOperator op;
    341   op.num_split = 123;
    342   auto output_toco_op = SerializeAndDeserialize(
    343       GetOperator("SPLIT_V", OperatorType::kSplitV), op);
    344   EXPECT_EQ(op.num_split, output_toco_op->num_split);
    345 }
    346 
    347 TEST_F(OperatorTest, BuiltinAveragePool) {
    348   AveragePoolOperator op;
    349   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    350   op.stride_width = 123;
    351   op.stride_height = 124;
    352   op.padding.type = PaddingType::kValid;
    353   op.kwidth = 480;
    354   op.kheight = 1080;
    355   auto output_toco_op = SerializeAndDeserialize(
    356       GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool), op);
    357   EXPECT_EQ(op.fused_activation_function,
    358             output_toco_op->fused_activation_function);
    359   EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
    360   EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
    361   EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
    362   EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
    363   EXPECT_EQ(op.kheight, output_toco_op->kheight);
    364 }
    365 
    366 TEST_F(OperatorTest, BuiltinConvolution) {
    367   ConvOperator op;
    368   op.stride_width = 123;
    369   op.stride_height = 124;
    370   op.padding.type = PaddingType::kValid;
    371   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    372   auto output_toco_op =
    373       SerializeAndDeserialize(GetOperator("CONV_2D", OperatorType::kConv), op);
    374   EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
    375   EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
    376   EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
    377   EXPECT_EQ(op.fused_activation_function,
    378             output_toco_op->fused_activation_function);
    379 }
    380 
    381 TEST_F(OperatorTest, BuiltinDepthwiseConvolution) {
    382   DepthwiseConvOperator op;
    383   op.stride_width = 123;
    384   op.stride_height = 124;
    385   op.padding.type = PaddingType::kValid;
    386   op.depth_multiplier = 6;
    387   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    388   auto output_toco_op = SerializeAndDeserialize(
    389       GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv), op);
    390   EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
    391   EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
    392   EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
    393   EXPECT_EQ(op.depth_multiplier, output_toco_op->depth_multiplier);
    394   EXPECT_EQ(op.fused_activation_function,
    395             output_toco_op->fused_activation_function);
    396 }
    397 
    398 TEST_F(OperatorTest, BuiltinL2Norm) {
    399   L2NormalizationOperator op;
    400   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    401   auto output_toco_op = SerializeAndDeserialize(
    402       GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization), op);
    403   EXPECT_EQ(op.fused_activation_function,
    404             output_toco_op->fused_activation_function);
    405 }
    406 
    407 TEST_F(OperatorTest, BuiltinMul) {
    408   MulOperator op;
    409   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    410   auto output_toco_op =
    411       SerializeAndDeserialize(GetOperator("MUL", OperatorType::kMul), op);
    412   EXPECT_EQ(op.fused_activation_function,
    413             output_toco_op->fused_activation_function);
    414 }
    415 
    416 TEST_F(OperatorTest, ResizeBilinear) {
    417   ResizeBilinearOperator op;
    418   op.align_corners = true;
    419   auto output_toco_op = SerializeAndDeserialize(
    420       GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op);
    421   EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
    422 }
    423 
    424 TEST_F(OperatorTest, ResizeNearestNeighbor) {
    425   ResizeNearestNeighborOperator op;
    426   op.align_corners = true;
    427   auto output_toco_op =
    428       SerializeAndDeserialize(GetOperator("RESIZE_NEAREST_NEIGHBOR",
    429                                           OperatorType::kResizeNearestNeighbor),
    430                               op);
    431   EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
    432 }
    433 
    434 TEST_F(OperatorTest, Svdf) {
    435   SvdfOperator op;
    436   op.fused_activation_function = FusedActivationFunctionType::kRelu;
    437   op.rank = 1;
    438   auto output_toco_op =
    439       SerializeAndDeserialize(GetOperator("SVDF", OperatorType::kSvdf), op);
    440   EXPECT_EQ(op.fused_activation_function,
    441             output_toco_op->fused_activation_function);
    442   EXPECT_EQ(op.rank, output_toco_op->rank);
    443 }
    444 
    445 TEST_F(OperatorTest, Squeeze) {
    446   SqueezeOperator op;
    447   op.squeeze_dims = {-2, -3, 4, 1, 4};
    448 
    449   auto output_toco_op = SerializeAndDeserialize(
    450       GetOperator("SQUEEZE", OperatorType::kSqueeze), op);
    451   EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims);
    452 }
    453 
    454 TEST_F(OperatorTest, StridedSlice) {
    455   StridedSliceOperator op;
    456 
    457   op.begin_mask = 1;
    458   op.end_mask = 2;
    459   op.ellipsis_mask = 1;
    460   op.new_axis_mask = 1;
    461   op.shrink_axis_mask = 2;
    462 
    463   auto output_toco_op = SerializeAndDeserialize(
    464       GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op);
    465   EXPECT_EQ(op.start_indices, output_toco_op->start_indices);
    466   EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices);
    467   EXPECT_EQ(op.strides, output_toco_op->strides);
    468   EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask);
    469   EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
    470   EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
    471   EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask);
    472   EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask);
    473   EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask);
    474 }
    475 
    476 TEST_F(OperatorTest, BuiltinTopKV2) {
    477   TopKV2Operator op;
    478   auto output_toco_op = SerializeAndDeserialize(
    479       GetOperator("TOPK_V2", OperatorType::kTopK_V2), op);
    480   ASSERT_NE(nullptr, output_toco_op.get());
    481 }
    482 
    483 TEST_F(OperatorTest, BuiltinArgMax) {
    484   ArgMaxOperator op;
    485   auto output_toco_op = SerializeAndDeserialize(
    486       GetOperator("ARG_MAX", OperatorType::kArgMax), op);
    487   EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
    488 }
    489 
    490 TEST_F(OperatorTest, BuiltinArgMin) {
    491   ArgMinOperator op;
    492   auto output_toco_op = SerializeAndDeserialize(
    493       GetOperator("ARG_MIN", OperatorType::kArgMin), op);
    494   EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
    495 }
    496 
    497 TEST_F(OperatorTest, BuiltinDequantize) {
    498   DequantizeOperator op;
    499   auto output_toco_op = SerializeAndDeserialize(
    500       GetOperator("DEQUANTIZE", OperatorType::kDequantize), op);
    501 }
    502 
    503 TEST_F(OperatorTest, BuiltinTransposeConv) {
    504   TransposeConvOperator op;
    505   op.stride_width = 123;
    506   op.stride_height = 124;
    507   op.padding.type = PaddingType::kValid;
    508   auto output_toco_op = SerializeAndDeserialize(
    509       GetOperator("TRANSPOSE_CONV", OperatorType::kTransposeConv), op);
    510   EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
    511   EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
    512   EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
    513 }
    514 
    515 TEST_F(OperatorTest, BuiltinShape) {
    516   TensorFlowShapeOperator op;
    517   op.output_data_type = ArrayDataType::kInt64;
    518   auto output_toco_op =
    519       SerializeAndDeserialize(GetOperator("SHAPE", OperatorType::kShape), op);
    520   EXPECT_EQ(op.output_data_type, output_toco_op->output_data_type);
    521 }
    522 
    523 TEST_F(OperatorTest, BuiltinSparseToDense) {
    524   SparseToDenseOperator op;
    525   op.validate_indices = false;
    526   std::unique_ptr<toco::SparseToDenseOperator> output_toco_op =
    527       SerializeAndDeserialize(
    528           GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense), op);
    529   EXPECT_EQ(op.validate_indices, output_toco_op->validate_indices);
    530 }
    531 
    532 TEST_F(OperatorTest, BuiltinPack) {
    533   PackOperator op;
    534   op.values_count = 3;
    535   op.axis = 1;
    536   std::unique_ptr<toco::PackOperator> output_toco_op =
    537       SerializeAndDeserialize(GetOperator("PACK", OperatorType::kPack), op);
    538   EXPECT_EQ(op.values_count, output_toco_op->values_count);
    539   EXPECT_EQ(op.axis, output_toco_op->axis);
    540 }
    541 
    542 TEST_F(OperatorTest, BuiltinOneHot) {
    543   OneHotOperator op;
    544   op.axis = 2;
    545   auto output_toco_op = SerializeAndDeserialize(
    546       GetOperator("ONE_HOT", OperatorType::kOneHot), op);
    547   EXPECT_EQ(op.axis, output_toco_op->axis);
    548 }
    549 
    550 TEST_F(OperatorTest, BuiltinUnpack) {
    551   UnpackOperator op;
    552   op.num = 5;
    553   op.axis = 2;
    554   auto output_toco_op =
    555       SerializeAndDeserialize(GetOperator("UNPACK", OperatorType::kUnpack), op);
    556   EXPECT_EQ(op.num, output_toco_op->num);
    557   EXPECT_EQ(op.axis, output_toco_op->axis);
    558 }
    559 
    560 TEST_F(OperatorTest, BuiltinLeakyRelu) {
    561   LeakyReluOperator op;
    562   op.alpha = 3;
    563   auto output_toco_op = SerializeAndDeserialize(
    564       GetOperator("LEAKY_RELU", OperatorType::kLeakyRelu), op);
    565   EXPECT_EQ(op.alpha, output_toco_op->alpha);
    566 }
    567 
    568 TEST_F(OperatorTest, BuiltinSquaredDifference) {
    569   SquaredDifferenceOperator op;
    570   auto output_toco_op = SerializeAndDeserialize(
    571       GetOperator("SQUARED_DIFFERENCE", OperatorType::kSquaredDifference), op);
    572   ASSERT_NE(nullptr, output_toco_op.get());
    573 }
    574 
    575 TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) {
    576   CTCBeamSearchDecoderOperator op;
    577   op.beam_width = 3;
    578   op.top_paths = 2;
    579   op.merge_repeated = false;
    580   std::unique_ptr<toco::CTCBeamSearchDecoderOperator> output_toco_op =
    581       SerializeAndDeserialize(GetOperator("CTC_BEAM_SEARCH_DECODER",
    582                                           OperatorType::kCTCBeamSearchDecoder),
    583                               op);
    584   EXPECT_EQ(op.beam_width, output_toco_op->beam_width);
    585   EXPECT_EQ(op.top_paths, output_toco_op->top_paths);
    586   EXPECT_EQ(op.merge_repeated, output_toco_op->merge_repeated);
    587 }
    588 
    589 TEST_F(OperatorTest, TensorFlowUnsupported) {
    590   TensorFlowUnsupportedOperator op;
    591   op.tensorflow_op = "MyCustomUnsupportedOp";
    592 
    593   ::tensorflow::NodeDef node_def;
    594   auto attr = node_def.mutable_attr();
    595   (*attr)["float_attr"].set_f(2.0);
    596   (*attr)["str_attr"].set_s("Hello World");
    597   (*attr)["int_attr"].set_i(17);
    598   (*attr)["bool_attr"].set_b(true);
    599   {
    600     auto* list = (*attr)["list_string_attr"].mutable_list();
    601     list->add_s("abcde");
    602     list->add_s("1234");
    603     list->add_s("");
    604     list->add_s("zyxwv");
    605     list->add_s("!-.");
    606   }
    607   {
    608     auto* list = (*attr)["list_float_attr"].mutable_list();
    609     list->add_f(std::numeric_limits<float>::min());
    610     list->add_f(2.0);
    611     list->add_f(-std::numeric_limits<float>::max());
    612   }
    613   {
    614     auto* list = (*attr)["list_int_attr"].mutable_list();
    615     list->add_i(1);
    616     list->add_i(20);
    617     list->add_i(1LL << 40);
    618     list->add_i(-(1LL << 40));
    619   }
    620   node_def.SerializeToString(&op.tensorflow_node_def);
    621 
    622   auto output_toco_op = SerializeAndDeserialize(
    623       GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
    624 
    625   ::tensorflow::NodeDef output_node_def;
    626   output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
    627   const auto& output_attr = output_node_def.attr();
    628   EXPECT_EQ(2.0, output_attr.at("float_attr").f());
    629   EXPECT_EQ("Hello World", output_attr.at("str_attr").s());
    630   EXPECT_EQ(17, output_attr.at("int_attr").i());
    631   EXPECT_EQ(true, output_attr.at("bool_attr").b());
    632   {
    633     const auto& list = output_attr.at("list_string_attr").list();
    634     ASSERT_EQ(5, list.s_size());
    635     EXPECT_EQ("abcde", list.s(0));
    636     EXPECT_EQ("1234", list.s(1));
    637     EXPECT_EQ("", list.s(2));
    638     EXPECT_EQ("zyxwv", list.s(3));
    639     EXPECT_EQ("!-.", list.s(4));
    640   }
    641   {
    642     const auto& list = output_attr.at("list_float_attr").list();
    643     ASSERT_EQ(3, list.f_size());
    644     EXPECT_EQ(std::numeric_limits<float>::min(), list.f(0));
    645     EXPECT_EQ(2.0, list.f(1));
    646     EXPECT_EQ(-std::numeric_limits<float>::max(), list.f(2));
    647   }
    648   {
    649     const auto& list = output_attr.at("list_int_attr").list();
    650     ASSERT_EQ(4, list.i_size());
    651     EXPECT_EQ(1, list.i(0));
    652     EXPECT_EQ(20, list.i(1));
    653     EXPECT_EQ(1LL << 40, list.i(2));
    654     EXPECT_EQ(-(1LL << 40), list.i(3));
    655   }
    656 }
    657 
    658 TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
    659   TensorFlowUnsupportedOperator op;
    660   op.tensorflow_op = "MyCustomUnsupportedOp";
    661   auto output_toco_op = SerializeAndDeserialize(
    662       GetOperator("TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported), op);
    663 
    664   ::tensorflow::NodeDef output_node_def;
    665   output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
    666   EXPECT_TRUE(output_node_def.attr().empty());
    667 }
    668 
    669 TEST_F(OperatorTest, TestShouldExportAsFlexOp) {
    670   EXPECT_FALSE(ShouldExportAsFlexOp(false, "Conv2D"));
    671   EXPECT_TRUE(ShouldExportAsFlexOp(true, "Conv2D"));
    672   EXPECT_TRUE(ShouldExportAsFlexOp(true, "EluGrad"));
    673   EXPECT_TRUE(ShouldExportAsFlexOp(true, "RFFT"));
    674   EXPECT_FALSE(ShouldExportAsFlexOp(true, "MyAwesomeCustomOp"));
    675   // While the RandomShuffle op is available on desktop, it is not in the kernel
    676   // set available on mobile and should be excluded.
    677   EXPECT_FALSE(ShouldExportAsFlexOp(true, "RandomShuffle"));
    678 }
    679 
    680 TEST_F(OperatorTest, BuiltinMirrorPad) {
    681   MirrorPadOperator op;
    682   op.mode = MirrorPadMode::kReflect;
    683   auto output_toco_op = SerializeAndDeserialize(
    684       GetOperator("MIRROR_PAD", OperatorType::kMirrorPad), op);
    685   EXPECT_EQ(op.mode, output_toco_op->mode);
    686 }
    687 
    688 TEST_F(OperatorTest, BuiltinUnique) {
    689   UniqueOperator op;
    690   op.idx_out_type = ArrayDataType::kInt64;
    691   auto output_toco_op =
    692       SerializeAndDeserialize(GetOperator("UNIQUE", OperatorType::kUnique), op);
    693   ASSERT_NE(nullptr, output_toco_op.get());
    694   EXPECT_EQ(output_toco_op->idx_out_type, op.idx_out_type);
    695 }
    696 
    697 TEST_F(OperatorTest, BuiltinReverseSequence) {
    698   ReverseSequenceOperator op;
    699   op.seq_dim = 3;
    700   op.batch_dim = 1;
    701   std::unique_ptr<toco::ReverseSequenceOperator> output_toco_op =
    702       SerializeAndDeserialize(
    703           GetOperator("REVERSE_SEQUENCE", OperatorType::kReverseSequence), op);
    704   EXPECT_EQ(op.seq_dim, output_toco_op->seq_dim);
    705   EXPECT_EQ(op.batch_dim, output_toco_op->batch_dim);
    706 }
    707 
    708 // Test version for a simple Op with 2 versions and the input type controls the
    709 // version.
    710 template <typename Op>
    711 void SimpleVersioningTest() {
    712   Op op;
    713   op.inputs = {"input1"};
    714   auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
    715   const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
    716 
    717   Model uint8_model;
    718   Array& uint8_array = uint8_model.GetOrCreateArray(op.inputs[0]);
    719   uint8_array.data_type = ArrayDataType::kUint8;
    720   OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
    721   EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
    722 
    723   Model int8_model;
    724   Array& int8_array = int8_model.GetOrCreateArray(op.inputs[0]);
    725   int8_array.data_type = ArrayDataType::kInt8;
    726   OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
    727   EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
    728 }
    729 
    730 // Test version for a simple Op with 2 versions and the output type controls the
    731 // version.
    732 template <typename Op>
    733 void SimpleOutputVersioningTest() {
    734   Op op;
    735   op.outputs = {"output1"};
    736   auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
    737   const BaseOperator* base_op = operator_by_type_map.at(op.type).get();
    738 
    739   Model uint8_model;
    740   Array& uint8_array = uint8_model.GetOrCreateArray(op.outputs[0]);
    741   uint8_array.data_type = ArrayDataType::kUint8;
    742   OperatorSignature uint8_signature = {.op = &op, .model = &uint8_model};
    743   EXPECT_EQ(base_op->GetVersion(uint8_signature), 1);
    744 
    745   Model int8_model;
    746   Array& int8_array = int8_model.GetOrCreateArray(op.outputs[0]);
    747   int8_array.data_type = ArrayDataType::kInt8;
    748   OperatorSignature int8_signature = {.op = &op, .model = &int8_model};
    749   EXPECT_EQ(base_op->GetVersion(int8_signature), 2);
    750 }
    751 
    752 TEST_F(OperatorTest, VersioningEqualTest) {
    753   SimpleVersioningTest<TensorFlowEqualOperator>();
    754 }
    755 
    756 TEST_F(OperatorTest, VersioningNotEqualTest) {
    757   SimpleVersioningTest<TensorFlowNotEqualOperator>();
    758 }
    759 
    760 TEST_F(OperatorTest, VersioningLessTest) {
    761   SimpleVersioningTest<TensorFlowLessOperator>();
    762 }
    763 
    764 TEST_F(OperatorTest, VersioningLessEqualTest) {
    765   SimpleVersioningTest<TensorFlowLessEqualOperator>();
    766 }
    767 
    768 TEST_F(OperatorTest, VersioningGreaterTest) {
    769   SimpleVersioningTest<TensorFlowGreaterOperator>();
    770 }
    771 
    772 TEST_F(OperatorTest, VersioningGreaterEqualTest) {
    773   SimpleVersioningTest<TensorFlowGreaterEqualOperator>();
    774 }
    775 
    776 TEST_F(OperatorTest, VersioningSpaceToBatchNDTest) {
    777   SimpleVersioningTest<SpaceToBatchNDOperator>();
    778 }
    779 
    780 TEST_F(OperatorTest, VersioningLogSoftmaxTest) {
    781   SimpleVersioningTest<LogSoftmaxOperator>();
    782 }
    783 
    784 TEST_F(OperatorTest, VersioningPackTest) {
    785   SimpleVersioningTest<PackOperator>();
    786 }
    787 
    788 TEST_F(OperatorTest, VersioningBatchToSpaceNDTest) {
    789   SimpleVersioningTest<BatchToSpaceNDOperator>();
    790 }
    791 
    792 TEST_F(OperatorTest, VersioningTanhTest) {
    793   SimpleVersioningTest<TanhOperator>();
    794 }
    795 
    796 TEST_F(OperatorTest, VersioningStridedSliceTest) {
    797   SimpleVersioningTest<StridedSliceOperator>();
    798 }
    799 
    800 TEST_F(OperatorTest, VersioningSpaceToDepthTest) {
    801   SimpleVersioningTest<SpaceToDepthOperator>();
    802 }
    803 
    804 TEST_F(OperatorTest, VersioningSliceTest) {
    805   SimpleVersioningTest<SliceOperator>();
    806 }
    807 
    808 TEST_F(OperatorTest, VersioningLogisticTest) {
    809   SimpleVersioningTest<LogisticOperator>();
    810 }
    811 
    812 TEST_F(OperatorTest, VersioningL2NormTest) {
    813   SimpleOutputVersioningTest<L2NormalizationOperator>();
    814 }
    815 
    816 TEST_F(OperatorTest, VersioningMaxTest) {
    817   SimpleVersioningTest<TensorFlowMaximumOperator>();
    818 }
    819 
    820 TEST_F(OperatorTest, VersioningMinTest) {
    821   SimpleVersioningTest<TensorFlowMinimumOperator>();
    822 }
    823 
    824 TEST_F(OperatorTest, VersioningAddTest) { SimpleVersioningTest<AddOperator>(); }
    825 
    826 TEST_F(OperatorTest, VersioningSubTest) { SimpleVersioningTest<SubOperator>(); }
    827 
    828 TEST_F(OperatorTest, VersioningMulTest) { SimpleVersioningTest<MulOperator>(); }
    829 
    830 TEST_F(OperatorTest, VersioningPadTest) { SimpleVersioningTest<PadOperator>(); }
    831 
    832 TEST_F(OperatorTest, VersioningPadV2Test) {
    833   SimpleVersioningTest<PadV2Operator>();
    834 }
    835 
    836 TEST_F(OperatorTest, VersioningConcatenationTest) {
    837   SimpleVersioningTest<ConcatenationOperator>();
    838 }
    839 
    840 TEST_F(OperatorTest, VersioningSelectTest) {
    841   SimpleVersioningTest<SelectOperator>();
    842 }
    843 
    844 TEST_F(OperatorTest, VersioningRelu6Test) {
    845   SimpleVersioningTest<Relu6Operator>();
    846 }
    847 
    848 TEST_F(OperatorTest, VersioningFullyConnectedTest) {
    849   FullyConnectedOperator fully_connected_op;
    850   fully_connected_op.inputs = {"input", "weight"};
    851   fully_connected_op.outputs = {"output"};
    852   auto operator_by_type_map = BuildOperatorByTypeMap(false /*enable_flex_ops*/);
    853   const BaseOperator* op =
    854       operator_by_type_map.at(fully_connected_op.type).get();
    855 
    856   Model uint8_model;
    857   Array& input_uint8_array =
    858       uint8_model.GetOrCreateArray(fully_connected_op.inputs[0]);
    859   input_uint8_array.data_type = ArrayDataType::kUint8;
    860   Array& weight_uint8_array =
    861       uint8_model.GetOrCreateArray(fully_connected_op.inputs[1]);
    862   weight_uint8_array.data_type = ArrayDataType::kUint8;
    863   Array& output_uint8_array =
    864       uint8_model.GetOrCreateArray(fully_connected_op.outputs[0]);
    865   output_uint8_array.data_type = ArrayDataType::kUint8;
    866   OperatorSignature uint8_signature = {.op = &fully_connected_op,
    867                                        .model = &uint8_model};
    868   EXPECT_EQ(op->GetVersion(uint8_signature), 1);
    869 
    870   Model int8_model;
    871   Array& input_int8_array =
    872       int8_model.GetOrCreateArray(fully_connected_op.inputs[0]);
    873   input_int8_array.data_type = ArrayDataType::kInt8;
    874   Array& weight_int8_array =
    875       int8_model.GetOrCreateArray(fully_connected_op.inputs[1]);
    876   weight_int8_array.data_type = ArrayDataType::kInt8;
    877   Array& output_int8_array =
    878       int8_model.GetOrCreateArray(fully_connected_op.outputs[0]);
    879   output_int8_array.data_type = ArrayDataType::kInt8;
    880   OperatorSignature int8_signature = {.op = &fully_connected_op,
    881                                       .model = &int8_model};
    882   EXPECT_EQ(op->GetVersion(int8_signature), 4);
    883 }
    884 
    885 }  // namespace
    886 }  // namespace tflite
    887 
    888 }  // namespace toco
    889