Home | History | Annotate | Download | only in tflite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/lite/toco/tflite/operator.h"
     16 
     17 #include "flatbuffers/flexbuffers.h"
     18 #include <gmock/gmock.h>
     19 #include <gtest/gtest.h>
     20 #include "tensorflow/contrib/lite/toco/tooling_util.h"
     21 
     22 #include "tensorflow/core/framework/attr_value.pb.h"
     23 #include "tensorflow/core/framework/node_def.pb.h"
     24 
     25 namespace toco {
     26 
     27 namespace tflite {
     28 namespace {
     29 
     30 class OperatorTest : public ::testing::Test {
     31  protected:
     32   // Return the operator for the given name and type.
     33   const BaseOperator& GetOperator(const string& name, OperatorType type) {
     34     using OpsByName = std::map<string, std::unique_ptr<BaseOperator>>;
     35     using OpsByType = std::map<OperatorType, std::unique_ptr<BaseOperator>>;
     36 
     37     static auto* by_name = new OpsByName(BuildOperatorByNameMap());
     38     static auto* by_type = new OpsByType(BuildOperatorByTypeMap());
     39 
     40     // Make sure the two maps were consitently built.
     41     CHECK(by_name->count(name)) << "No operator for '" << name << "'.";
     42     BaseOperator* op1 = by_name->at(name).get();
     43     CHECK(op1->type() == type) << "while verifying '" << name << "'.";
     44 
     45     CHECK(by_type->count(type))
     46         << "No operator for '" << OperatorTypeName(type) << "'.";
     47     BaseOperator* op2 = by_type->at(type).get();
     48     CHECK(op2->name() == name)
     49         << "while verifying '" << OperatorTypeName(type) << "'.";
     50 
     51     return *op1;
     52   }
     53 
     54   // Use the given BaseOperator to serialize the tf.mini operator into a set of
     55   // TF Lite options. Proceed to deserialize the options back into a new
     56   // tf.mini operator, which is then returned. If `options` is given, it will
     57   // be populated with the serialized options.
     58   template <typename T>
     59   std::unique_ptr<T> SerializeAndDeserialize(const BaseOperator& op,
     60                                              const T& toco_op,
     61                                              Options* options = nullptr) {
     62     flatbuffers::FlatBufferBuilder builder;
     63     Options input_options = op.Serialize(toco_op, &builder);
     64 
     65     if (options) {
     66       *options = input_options;
     67     }
     68 
     69     builder.Finish(CreateOperator(builder, 0, 0, 0, input_options.type,
     70                                   input_options.builtin, input_options.custom,
     71                                   ::tflite::CustomOptionsFormat_FLEXBUFFERS));
     72     auto* output_options =
     73         flatbuffers::GetRoot<::tflite::Operator>(builder.GetBufferPointer());
     74     auto new_toco_op = op.Deserialize(output_options->builtin_options(),
     75                                       output_options->custom_options());
     76 
     77     CHECK(dynamic_cast<T*>(new_toco_op.get()))
     78         << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to "
     79         << HelpfulOperatorTypeName(toco_op);
     80 
     81     return std::unique_ptr<T>(dynamic_cast<T*>(new_toco_op.release()));
     82   }
     83 
     84   // Verify serialization and deserialization of simple operators (those
     85   // that don't have any configuration parameters).
     86   template <typename T>
     87   void CheckSimpleOperator(const string& name, OperatorType type) {
     88     Options options;
     89     auto output_toco_op =
     90         SerializeAndDeserialize(GetOperator(name, type), T(), &options);
     91 
     92     ASSERT_EQ(0, options.builtin.o);
     93     ASSERT_EQ(0, options.custom.o);
     94     ASSERT_EQ(::tflite::BuiltinOptions_NONE, options.type);
     95 
     96     ASSERT_NE(nullptr, output_toco_op.get());
     97   }
     98 };
     99 
    100 TEST_F(OperatorTest, SimpleOperators) {
    101   CheckSimpleOperator<DequantizeOperator>("DEQUANTIZE",
    102                                           OperatorType::kDequantize);
    103   CheckSimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor);
    104   CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
    105   CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
    106   CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
    107   CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
    108   CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
    109   CheckSimpleOperator<ExpOperator>("EXP", OperatorType::kExp);
    110 }
    111 
    112 TEST_F(OperatorTest, BuiltinAdd) {
    113   AddOperator op;
    114   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    115   auto output_toco_op =
    116       SerializeAndDeserialize(GetOperator("ADD", OperatorType::kAdd), op);
    117   EXPECT_EQ(op.fused_activation_function,
    118             output_toco_op->fused_activation_function);
    119 }
    120 
    121 TEST_F(OperatorTest, BuiltinMean) {
    122   MeanOperator op;
    123   op.keep_dims = false;
    124 
    125   auto output_toco_op =
    126       SerializeAndDeserialize(GetOperator("MEAN", OperatorType::kMean), op);
    127   EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
    128 }
    129 
    130 TEST_F(OperatorTest, CustomCast) {
    131   CastOperator op;
    132   op.src_data_type = ArrayDataType::kFloat;
    133   op.dst_data_type = ArrayDataType::kUint8;
    134   auto output_toco_op =
    135       SerializeAndDeserialize(GetOperator("CAST", OperatorType::kCast), op);
    136   EXPECT_EQ(op.src_data_type, output_toco_op->src_data_type);
    137   EXPECT_EQ(op.dst_data_type, output_toco_op->dst_data_type);
    138 }
    139 
    140 TEST_F(OperatorTest, CustomConcatenation) {
    141   ConcatenationOperator op;
    142   op.axis = 123;
    143   auto output_toco_op = SerializeAndDeserialize(
    144       GetOperator("CONCATENATION", OperatorType::kConcatenation), op);
    145   EXPECT_EQ(op.axis, output_toco_op->axis);
    146 }
    147 
    148 TEST_F(OperatorTest, CustomDepthToSpace) {
    149   DepthToSpaceOperator op;
    150   op.block_size = 123;
    151   auto output_toco_op = SerializeAndDeserialize(
    152       GetOperator("DEPTH_TO_SPACE", OperatorType::kDepthToSpace), op);
    153   EXPECT_EQ(op.block_size, output_toco_op->block_size);
    154 }
    155 
    156 TEST_F(OperatorTest, CustomFakeQuant) {
    157   FakeQuantOperator op;
    158   auto* minmax = new MinMax;
    159   minmax->min = -10;
    160   minmax->max = 200;
    161   op.minmax.reset(minmax);
    162   auto output_toco_op = SerializeAndDeserialize(
    163       GetOperator("FAKE_QUANT", OperatorType::kFakeQuant), op);
    164   EXPECT_EQ(op.minmax->min, output_toco_op->minmax->min);
    165   EXPECT_EQ(op.minmax->max, output_toco_op->minmax->max);
    166 }
    167 
    168 TEST_F(OperatorTest, CustomFullyConnected) {
    169   FullyConnectedOperator op;
    170   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    171   auto output_toco_op = SerializeAndDeserialize(
    172       GetOperator("FULLY_CONNECTED", OperatorType::kFullyConnected), op);
    173   EXPECT_EQ(op.fused_activation_function,
    174             output_toco_op->fused_activation_function);
    175 }
    176 
    177 TEST_F(OperatorTest, BuiltinGather) {
    178   GatherOperator op;
    179   auto output_toco_op =
    180       SerializeAndDeserialize(GetOperator("GATHER", OperatorType::kGather), op);
    181   ASSERT_NE(nullptr, output_toco_op.get());
    182 }
    183 
    184 TEST_F(OperatorTest, BuiltinL2Pool) {
    185   L2PoolOperator op;
    186   op.stride_width = 123;
    187   op.stride_height = 124;
    188   op.padding.type = PaddingType::kValid;
    189   op.kwidth = 480;
    190   op.kheight = 1080;
    191   auto output_toco_op = SerializeAndDeserialize(
    192       GetOperator("L2_POOL_2D", OperatorType::kL2Pool), op);
    193   EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
    194   EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
    195   EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
    196   EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
    197   EXPECT_EQ(op.kheight, output_toco_op->kheight);
    198 }
    199 
    200 TEST_F(OperatorTest, BuiltinLocalResponseNormalization) {
    201   LocalResponseNormalizationOperator op;
    202   op.range = 123;
    203   op.bias = 1.23;
    204   op.alpha = 12.3;
    205   op.beta = .123;
    206   auto output_toco_op = SerializeAndDeserialize(
    207       GetOperator("LOCAL_RESPONSE_NORMALIZATION",
    208                   OperatorType::kLocalResponseNormalization),
    209       op);
    210   EXPECT_EQ(op.range, output_toco_op->range);
    211   EXPECT_EQ(op.bias, output_toco_op->bias);
    212   EXPECT_EQ(op.alpha, output_toco_op->alpha);
    213   EXPECT_EQ(op.beta, output_toco_op->beta);
    214 }
    215 
    216 TEST_F(OperatorTest, BuiltinMaxPool) {
    217   MaxPoolOperator op;
    218   op.stride_width = 123;
    219   op.stride_height = 124;
    220   op.padding.type = PaddingType::kValid;
    221   op.kwidth = 480;
    222   op.kheight = 1080;
    223   auto output_toco_op = SerializeAndDeserialize(
    224       GetOperator("MAX_POOL_2D", OperatorType::kMaxPool), op);
    225   EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
    226   EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
    227   EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
    228   EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
    229   EXPECT_EQ(op.kheight, output_toco_op->kheight);
    230 }
    231 
    232 TEST_F(OperatorTest, BuiltinReshape) {
    233   TensorFlowReshapeOperator op;
    234   op.shape = {1, 2, 4, 5, 8};
    235   auto output_toco_op = SerializeAndDeserialize(
    236       GetOperator("RESHAPE", OperatorType::kTensorFlowReshape), op);
    237   EXPECT_EQ(op.shape, output_toco_op->shape);
    238 }
    239 
    240 TEST_F(OperatorTest, CustomSoftmax) {
    241   SoftmaxOperator op;
    242   op.beta = 123.1;
    243   auto output_toco_op = SerializeAndDeserialize(
    244       GetOperator("SOFTMAX", OperatorType::kSoftmax), op);
    245   EXPECT_EQ(op.beta, output_toco_op->beta);
    246 }
    247 
    248 TEST_F(OperatorTest, BuiltinSpaceToDepth) {
    249   SpaceToDepthOperator op;
    250   op.block_size = 123;
    251   auto output_toco_op = SerializeAndDeserialize(
    252       GetOperator("SPACE_TO_DEPTH", OperatorType::kSpaceToDepth), op);
    253   EXPECT_EQ(op.block_size, output_toco_op->block_size);
    254 }
    255 
    256 TEST_F(OperatorTest, CustomSplit) {
    257   TensorFlowSplitOperator op;
    258   op.num_split = 123;
    259   auto output_toco_op = SerializeAndDeserialize(
    260       GetOperator("SPLIT", OperatorType::kTensorFlowSplit), op);
    261   EXPECT_EQ(op.num_split, output_toco_op->num_split);
    262 }
    263 
    264 TEST_F(OperatorTest, BuiltinAveragePool) {
    265   AveragePoolOperator op;
    266   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    267   op.stride_width = 123;
    268   op.stride_height = 124;
    269   op.padding.type = PaddingType::kValid;
    270   op.kwidth = 480;
    271   op.kheight = 1080;
    272   auto output_toco_op = SerializeAndDeserialize(
    273       GetOperator("AVERAGE_POOL_2D", OperatorType::kAveragePool), op);
    274   EXPECT_EQ(op.fused_activation_function,
    275             output_toco_op->fused_activation_function);
    276   EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
    277   EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
    278   EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
    279   EXPECT_EQ(op.kwidth, output_toco_op->kwidth);
    280   EXPECT_EQ(op.kheight, output_toco_op->kheight);
    281 }
    282 
    283 TEST_F(OperatorTest, BuiltinConvolution) {
    284   ConvOperator op;
    285   op.stride_width = 123;
    286   op.stride_height = 124;
    287   op.padding.type = PaddingType::kValid;
    288   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    289   auto output_toco_op =
    290       SerializeAndDeserialize(GetOperator("CONV_2D", OperatorType::kConv), op);
    291   EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
    292   EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
    293   EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
    294   EXPECT_EQ(op.fused_activation_function,
    295             output_toco_op->fused_activation_function);
    296 }
    297 
    298 TEST_F(OperatorTest, BuiltinDepthwiseConvolution) {
    299   DepthwiseConvOperator op;
    300   op.stride_width = 123;
    301   op.stride_height = 124;
    302   op.padding.type = PaddingType::kValid;
    303   op.depth_multiplier = 6;
    304   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    305   auto output_toco_op = SerializeAndDeserialize(
    306       GetOperator("DEPTHWISE_CONV_2D", OperatorType::kDepthwiseConv), op);
    307   EXPECT_EQ(op.stride_width, output_toco_op->stride_width);
    308   EXPECT_EQ(op.stride_height, output_toco_op->stride_height);
    309   EXPECT_EQ(op.padding.type, output_toco_op->padding.type);
    310   EXPECT_EQ(op.depth_multiplier, output_toco_op->depth_multiplier);
    311   EXPECT_EQ(op.fused_activation_function,
    312             output_toco_op->fused_activation_function);
    313 }
    314 
    315 TEST_F(OperatorTest, BuiltinL2Norm) {
    316   L2NormalizationOperator op;
    317   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    318   auto output_toco_op = SerializeAndDeserialize(
    319       GetOperator("L2_NORMALIZATION", OperatorType::kL2Normalization), op);
    320   EXPECT_EQ(op.fused_activation_function,
    321             output_toco_op->fused_activation_function);
    322 }
    323 
    324 TEST_F(OperatorTest, BuiltinMul) {
    325   MulOperator op;
    326   op.fused_activation_function = FusedActivationFunctionType::kRelu6;
    327   auto output_toco_op =
    328       SerializeAndDeserialize(GetOperator("MUL", OperatorType::kMul), op);
    329   EXPECT_EQ(op.fused_activation_function,
    330             output_toco_op->fused_activation_function);
    331 }
    332 
    333 TEST_F(OperatorTest, ResizeBilinear) {
    334   ResizeBilinearOperator op;
    335   op.align_corners = true;
    336   auto output_toco_op = SerializeAndDeserialize(
    337       GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op);
    338   EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
    339 }
    340 
    341 TEST_F(OperatorTest, Svdf) {
    342   SvdfOperator op;
    343   op.fused_activation_function = FusedActivationFunctionType::kRelu;
    344   op.rank = 1;
    345   auto output_toco_op =
    346       SerializeAndDeserialize(GetOperator("SVDF", OperatorType::kSvdf), op);
    347   EXPECT_EQ(op.fused_activation_function,
    348             output_toco_op->fused_activation_function);
    349   EXPECT_EQ(op.rank, output_toco_op->rank);
    350 }
    351 
    352 TEST_F(OperatorTest, Squeeze) {
    353   SqueezeOperator op;
    354   op.squeeze_dims = {-2, -3, 4, 1, 4};
    355 
    356   auto output_toco_op = SerializeAndDeserialize(
    357       GetOperator("SQUEEZE", OperatorType::kSqueeze), op);
    358   EXPECT_EQ(op.squeeze_dims, output_toco_op->squeeze_dims);
    359 }
    360 
    361 TEST_F(OperatorTest, StridedSlice) {
    362   StridedSliceOperator op;
    363 
    364   op.begin_mask = 1;
    365   op.end_mask = 2;
    366   op.ellipsis_mask = 1;
    367   op.new_axis_mask = 1;
    368   op.shrink_axis_mask = 2;
    369 
    370   auto output_toco_op = SerializeAndDeserialize(
    371       GetOperator("STRIDED_SLICE", OperatorType::kStridedSlice), op);
    372   EXPECT_EQ(op.start_indices, output_toco_op->start_indices);
    373   EXPECT_EQ(op.stop_indices, output_toco_op->stop_indices);
    374   EXPECT_EQ(op.strides, output_toco_op->strides);
    375   EXPECT_EQ(op.begin_mask, output_toco_op->begin_mask);
    376   EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
    377   EXPECT_EQ(op.end_mask, output_toco_op->end_mask);
    378   EXPECT_EQ(op.ellipsis_mask, output_toco_op->ellipsis_mask);
    379   EXPECT_EQ(op.new_axis_mask, output_toco_op->new_axis_mask);
    380   EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask);
    381 }
    382 
    383 TEST_F(OperatorTest, BuiltinTopKV2) {
    384   TopKV2Operator op;
    385   auto output_toco_op = SerializeAndDeserialize(
    386       GetOperator("TOPK_V2", OperatorType::kTopK_V2), op);
    387   ASSERT_NE(nullptr, output_toco_op.get());
    388 }
    389 
    390 TEST_F(OperatorTest, TensorFlowUnsupported) {
    391   TensorFlowUnsupportedOperator op;
    392   op.tensorflow_op = "MyCustomUnsupportedOp";
    393 
    394   ::tensorflow::NodeDef node_def;
    395   auto attr = node_def.mutable_attr();
    396   (*attr)["float_attr"].set_f(2.0);
    397   (*attr)["str_attr"].set_s("Hello World");
    398   (*attr)["int_attr"].set_i(17);
    399   (*attr)["bool_attr"].set_b(true);
    400   node_def.SerializeToString(&op.tensorflow_node_def);
    401 
    402   auto output_toco_op =
    403       SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
    404                                           OperatorType::kTensorFlowUnsupported),
    405                               op);
    406 
    407   ::tensorflow::NodeDef output_node_def;
    408   output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
    409   const auto& output_attr = output_node_def.attr();
    410   EXPECT_EQ(2.0, output_attr.at("float_attr").f());
    411   EXPECT_EQ("Hello World", output_attr.at("str_attr").s());
    412   EXPECT_EQ(17, output_attr.at("int_attr").i());
    413   EXPECT_EQ(true, output_attr.at("bool_attr").b());
    414 }
    415 
    416 TEST_F(OperatorTest, TensorFlowUnsupportedWithoutAttr) {
    417   TensorFlowUnsupportedOperator op;
    418   op.tensorflow_op = "MyCustomUnsupportedOp";
    419   auto output_toco_op =
    420       SerializeAndDeserialize(GetOperator("TENSORFLOW_UNSUPPORTED",
    421                                           OperatorType::kTensorFlowUnsupported),
    422                               op);
    423 
    424   ::tensorflow::NodeDef output_node_def;
    425   output_node_def.ParseFromString(output_toco_op->tensorflow_node_def);
    426   EXPECT_TRUE(output_node_def.attr().empty());
    427 }
    428 
    429 }  // namespace
    430 }  // namespace tflite
    431 
    432 }  // namespace toco
    433