Home | History | Annotate | Download | only in tflite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/lite/toco/tflite/operator.h"
     16 
     17 #include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
     18 #include "tensorflow/contrib/lite/toco/tflite/custom_operator.h"
     19 #include "tensorflow/contrib/lite/toco/tflite/simple_operator.h"
     20 #include "tensorflow/contrib/lite/toco/tflite/types.h"
     21 
     22 #include "tensorflow/core/framework/attr_value.pb.h"
     23 #include "tensorflow/core/framework/node_def.pb.h"
     24 
     25 namespace toco {
     26 
     27 namespace tflite {
     28 
     29 class AveragePool
     30     : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
     31                              ::tflite::BuiltinOptions_Pool2DOptions> {
     32  public:
     33   using BuiltinOperator::BuiltinOperator;
     34 
     35   flatbuffers::Offset<TfLiteOptions> WriteOptions(
     36       const TocoOperator& op,
     37       flatbuffers::FlatBufferBuilder* builder) const override {
     38     auto padding = Padding::Serialize(op.padding.type);
     39     auto activation_function =
     40         ActivationFunction::Serialize(op.fused_activation_function);
     41     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
     42                                          op.stride_height, op.kwidth,
     43                                          op.kheight, activation_function);
     44   }
     45 
     46   void ReadOptions(const TfLiteOptions& options,
     47                    TocoOperator* op) const override {
     48     op->padding.type = Padding::Deserialize(options.padding());
     49     op->stride_width = options.stride_w();
     50     op->stride_height = options.stride_h();
     51     op->kwidth = options.filter_width();
     52     op->kheight = options.filter_height();
     53     op->fused_activation_function =
     54         ActivationFunction::Deserialize(options.fused_activation_function());
     55   }
     56 };
     57 
     58 class Convolution
     59     : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
     60                              ::tflite::BuiltinOptions_Conv2DOptions> {
     61  public:
     62   using BuiltinOperator::BuiltinOperator;
     63 
     64   flatbuffers::Offset<TfLiteOptions> WriteOptions(
     65       const TocoOperator& op,
     66       flatbuffers::FlatBufferBuilder* builder) const override {
     67     auto padding = Padding::Serialize(op.padding.type);
     68     auto activation_function =
     69         ActivationFunction::Serialize(op.fused_activation_function);
     70     return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
     71                                          op.stride_height, activation_function);
     72   }
     73 
     74   void ReadOptions(const TfLiteOptions& options,
     75                    TocoOperator* op) const override {
     76     op->padding.type = Padding::Deserialize(options.padding());
     77     op->stride_width = options.stride_w();
     78     op->stride_height = options.stride_h();
     79     op->fused_activation_function =
     80         ActivationFunction::Deserialize(options.fused_activation_function());
     81   }
     82 };
     83 
     84 class DepthwiseConvolution
     85     : public BuiltinOperator<DepthwiseConvOperator,
     86                              ::tflite::DepthwiseConv2DOptions,
     87                              ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
     88  public:
     89   using BuiltinOperator::BuiltinOperator;
     90 
     91   flatbuffers::Offset<TfLiteOptions> WriteOptions(
     92       const TocoOperator& op,
     93       flatbuffers::FlatBufferBuilder* builder) const override {
     94     auto padding = Padding::Serialize(op.padding.type);
     95     auto activation_function =
     96         ActivationFunction::Serialize(op.fused_activation_function);
     97     return ::tflite::CreateDepthwiseConv2DOptions(
     98         *builder, padding, op.stride_width, op.stride_height,
     99         op.depth_multiplier, activation_function);
    100   }
    101 
    102   void ReadOptions(const TfLiteOptions& options,
    103                    TocoOperator* op) const override {
    104     op->padding.type = Padding::Deserialize(options.padding());
    105     op->stride_width = options.stride_w();
    106     op->stride_height = options.stride_h();
    107     op->depth_multiplier = options.depth_multiplier();
    108     op->fused_activation_function =
    109         ActivationFunction::Deserialize(options.fused_activation_function());
    110   }
    111 };
    112 
    113 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
    114                                    ::tflite::BuiltinOptions_AddOptions> {
    115  public:
    116   using BuiltinOperator::BuiltinOperator;
    117 
    118   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    119       const TocoOperator& op,
    120       flatbuffers::FlatBufferBuilder* builder) const override {
    121     auto activation_function =
    122         ActivationFunction::Serialize(op.fused_activation_function);
    123     return ::tflite::CreateAddOptions(*builder, activation_function);
    124   }
    125 
    126   void ReadOptions(const TfLiteOptions& options,
    127                    TocoOperator* op) const override {
    128     op->fused_activation_function =
    129         ActivationFunction::Deserialize(options.fused_activation_function());
    130   }
    131 };
    132 
    133 class SpaceToBatchND
    134     : public BuiltinOperator<SpaceToBatchNDOperator,
    135                              ::tflite::SpaceToBatchNDOptions,
    136                              ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
    137  public:
    138   using BuiltinOperator::BuiltinOperator;
    139 
    140   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    141       const TocoOperator& op,
    142       flatbuffers::FlatBufferBuilder* builder) const override {
    143     return ::tflite::CreateSpaceToBatchNDOptions(*builder);
    144   }
    145 
    146   void ReadOptions(const TfLiteOptions& options,
    147                    TocoOperator* op) const override {}
    148 };
    149 
    150 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
    151                                    ::tflite::BuiltinOptions_SubOptions> {
    152  public:
    153   using BuiltinOperator::BuiltinOperator;
    154 
    155   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    156       const TocoOperator& op,
    157       flatbuffers::FlatBufferBuilder* builder) const override {
    158     auto activation_function =
    159         ActivationFunction::Serialize(op.fused_activation_function);
    160     return ::tflite::CreateSubOptions(*builder, activation_function);
    161   }
    162 
    163   void ReadOptions(const TfLiteOptions& options,
    164                    TocoOperator* op) const override {
    165     op->fused_activation_function =
    166         ActivationFunction::Deserialize(options.fused_activation_function());
    167   }
    168 };
    169 
    170 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
    171                                    ::tflite::BuiltinOptions_DivOptions> {
    172  public:
    173   using BuiltinOperator::BuiltinOperator;
    174 
    175   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    176       const TocoOperator& op,
    177       flatbuffers::FlatBufferBuilder* builder) const override {
    178     auto activation_function =
    179         ActivationFunction::Serialize(op.fused_activation_function);
    180     return ::tflite::CreateDivOptions(*builder, activation_function);
    181   }
    182 
    183   void ReadOptions(const TfLiteOptions& options,
    184                    TocoOperator* op) const override {
    185     op->fused_activation_function =
    186         ActivationFunction::Deserialize(options.fused_activation_function());
    187   }
    188 };
    189 
    190 class BatchToSpaceND
    191     : public BuiltinOperator<BatchToSpaceNDOperator,
    192                              ::tflite::BatchToSpaceNDOptions,
    193                              ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
    194  public:
    195   using BuiltinOperator::BuiltinOperator;
    196 
    197   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    198       const TocoOperator& op,
    199       flatbuffers::FlatBufferBuilder* builder) const override {
    200     return ::tflite::CreateBatchToSpaceNDOptions(*builder);
    201   }
    202 
    203   void ReadOptions(const TfLiteOptions& options,
    204                    TocoOperator* op) const override {}
    205 };
    206 
    207 class Cast : public CustomOperator<CastOperator> {
    208  public:
    209   using CustomOperator::CustomOperator;
    210   void WriteOptions(const TocoOperator& op,
    211                     flexbuffers::Builder* fbb) const override {
    212     fbb->Int("src_data_type", DataType::Serialize(op.src_data_type));
    213     fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type));
    214   }
    215   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
    216     op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64());
    217     op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64());
    218   }
    219 };
    220 
    221 class Concatenation
    222     : public BuiltinOperator<ConcatenationOperator,
    223                              ::tflite::ConcatenationOptions,
    224                              ::tflite::BuiltinOptions_ConcatenationOptions> {
    225  public:
    226   using BuiltinOperator::BuiltinOperator;
    227   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    228       const TocoOperator& op,
    229       flatbuffers::FlatBufferBuilder* builder) const override {
    230     return ::tflite::CreateConcatenationOptions(*builder, op.axis);
    231   }
    232 
    233   void ReadOptions(const TfLiteOptions& options,
    234                    TocoOperator* op) const override {
    235     op->axis = options.axis();
    236   }
    237 };
    238 
    239 class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
    240  public:
    241   using CustomOperator::CustomOperator;
    242   void WriteOptions(const TocoOperator& op,
    243                     flexbuffers::Builder* fbb) const override {
    244     fbb->Int("block_size", op.block_size);
    245   }
    246   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
    247     op->block_size = m["block_size"].AsInt64();
    248   }
    249 };
    250 
    251 class FakeQuant : public CustomOperator<FakeQuantOperator> {
    252  public:
    253   using CustomOperator::CustomOperator;
    254   void WriteOptions(const TocoOperator& op,
    255                     flexbuffers::Builder* fbb) const override {
    256     fbb->Float("min", op.minmax->min);
    257     fbb->Float("max", op.minmax->max);
    258   }
    259   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
    260     auto* minmax = new MinMax;
    261     minmax->min = m["min"].AsFloat();
    262     minmax->max = m["max"].AsFloat();
    263     op->minmax.reset(minmax);
    264   }
    265 };
    266 
    267 class FullyConnected
    268     : public BuiltinOperator<FullyConnectedOperator,
    269                              ::tflite::FullyConnectedOptions,
    270                              ::tflite::BuiltinOptions_FullyConnectedOptions> {
    271  public:
    272   using BuiltinOperator::BuiltinOperator;
    273   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    274       const TocoOperator& op,
    275       flatbuffers::FlatBufferBuilder* builder) const override {
    276     auto activation_function =
    277         ActivationFunction::Serialize(op.fused_activation_function);
    278     return ::tflite::CreateFullyConnectedOptions(*builder, activation_function);
    279   }
    280 
    281   void ReadOptions(const TfLiteOptions& options,
    282                    TocoOperator* op) const override {
    283     op->fused_activation_function =
    284         ActivationFunction::Deserialize(options.fused_activation_function());
    285   }
    286 };
    287 
    288 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
    289                                       ::tflite::BuiltinOptions_GatherOptions> {
    290  public:
    291   using BuiltinOperator::BuiltinOperator;
    292   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    293       const TocoOperator& op,
    294       flatbuffers::FlatBufferBuilder* builder) const override {
    295     return ::tflite::CreateGatherOptions(*builder, op.axis);
    296   }
    297 
    298   void ReadOptions(const TfLiteOptions& options,
    299                    TocoOperator* op) const override {
    300     op->axis = options.axis();
    301   }
    302 };
    303 
    304 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
    305                                     ::tflite::BuiltinOptions_SVDFOptions> {
    306  public:
    307   using BuiltinOperator::BuiltinOperator;
    308   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    309       const TocoOperator& op,
    310       flatbuffers::FlatBufferBuilder* builder) const override {
    311     auto activation_function =
    312         ActivationFunction::Serialize(op.fused_activation_function);
    313     return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
    314   }
    315 
    316   void ReadOptions(const TfLiteOptions& options,
    317                    TocoOperator* op) const override {
    318     op->fused_activation_function =
    319         ActivationFunction::Deserialize(options.fused_activation_function());
    320     op->rank = options.rank();
    321   }
    322 };
    323 
    324 class L2Normalization
    325     : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
    326                              ::tflite::BuiltinOptions_L2NormOptions> {
    327  public:
    328   using BuiltinOperator::BuiltinOperator;
    329   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    330       const TocoOperator& op,
    331       flatbuffers::FlatBufferBuilder* builder) const override {
    332     auto activation_function =
    333         ActivationFunction::Serialize(op.fused_activation_function);
    334     return ::tflite::CreateL2NormOptions(*builder, activation_function);
    335   }
    336 
    337   void ReadOptions(const TfLiteOptions& options,
    338                    TocoOperator* op) const override {
    339     op->fused_activation_function =
    340         ActivationFunction::Deserialize(options.fused_activation_function());
    341   }
    342 };
    343 
    344 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
    345                                       ::tflite::BuiltinOptions_Pool2DOptions> {
    346  public:
    347   using BuiltinOperator::BuiltinOperator;
    348   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    349       const TocoOperator& op,
    350       flatbuffers::FlatBufferBuilder* builder) const override {
    351     auto padding = Padding::Serialize(op.padding.type);
    352     auto activation_function =
    353         ActivationFunction::Serialize(op.fused_activation_function);
    354     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
    355                                          op.stride_height, op.kwidth,
    356                                          op.kheight, activation_function);
    357   }
    358 
    359   void ReadOptions(const TfLiteOptions& options,
    360                    TocoOperator* op) const override {
    361     op->padding.type = Padding::Deserialize(options.padding());
    362     op->stride_width = options.stride_w();
    363     op->stride_height = options.stride_h();
    364     op->kwidth = options.filter_width();
    365     op->kheight = options.filter_height();
    366     op->fused_activation_function =
    367         ActivationFunction::Deserialize(options.fused_activation_function());
    368   }
    369 };
    370 
    371 class LocalResponseNormalization
    372     : public BuiltinOperator<
    373           LocalResponseNormalizationOperator,
    374           ::tflite::LocalResponseNormalizationOptions,
    375           ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
    376  public:
    377   using BuiltinOperator::BuiltinOperator;
    378   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    379       const TocoOperator& op,
    380       flatbuffers::FlatBufferBuilder* builder) const override {
    381     return ::tflite::CreateLocalResponseNormalizationOptions(
    382         *builder, op.range, op.bias, op.alpha, op.beta);
    383   }
    384 
    385   void ReadOptions(const TfLiteOptions& options,
    386                    TocoOperator* op) const override {
    387     op->range = options.radius();
    388     op->bias = options.bias();
    389     op->alpha = options.alpha();
    390     op->beta = options.beta();
    391   }
    392 };
    393 
    394 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
    395                                        ::tflite::BuiltinOptions_Pool2DOptions> {
    396  public:
    397   using BuiltinOperator::BuiltinOperator;
    398   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    399       const TocoOperator& op,
    400       flatbuffers::FlatBufferBuilder* builder) const override {
    401     auto padding = Padding::Serialize(op.padding.type);
    402     auto activation_function =
    403         ActivationFunction::Serialize(op.fused_activation_function);
    404     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
    405                                          op.stride_height, op.kwidth,
    406                                          op.kheight, activation_function);
    407   }
    408 
    409   void ReadOptions(const TfLiteOptions& options,
    410                    TocoOperator* op) const override {
    411     op->padding.type = Padding::Deserialize(options.padding());
    412     op->stride_width = options.stride_w();
    413     op->stride_height = options.stride_h();
    414     op->kwidth = options.filter_width();
    415     op->kheight = options.filter_height();
    416     op->fused_activation_function =
    417         ActivationFunction::Deserialize(options.fused_activation_function());
    418   }
    419 };
    420 
    421 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
    422                                    ::tflite::BuiltinOptions_MulOptions> {
    423  public:
    424   using BuiltinOperator::BuiltinOperator;
    425 
    426   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    427       const TocoOperator& op,
    428       flatbuffers::FlatBufferBuilder* builder) const override {
    429     auto activation_function =
    430         ActivationFunction::Serialize(op.fused_activation_function);
    431     return ::tflite::CreateMulOptions(*builder, activation_function);
    432   }
    433 
    434   void ReadOptions(const TfLiteOptions& options,
    435                    TocoOperator* op) const override {
    436     op->fused_activation_function =
    437         ActivationFunction::Deserialize(options.fused_activation_function());
    438   }
    439 };
    440 
    441 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
    442                                    ::tflite::BuiltinOptions_PadOptions> {
    443  public:
    444   using BuiltinOperator::BuiltinOperator;
    445 
    446   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    447       const TocoOperator& op,
    448       flatbuffers::FlatBufferBuilder* builder) const override {
    449     return ::tflite::CreatePadOptions(*builder);
    450   }
    451 
    452   void ReadOptions(const TfLiteOptions& options,
    453                    TocoOperator* op) const override {}
    454 };
    455 
    456 class Reshape
    457     : public BuiltinOperator<TensorFlowReshapeOperator,
    458                              ::tflite::ReshapeOptions,
    459                              ::tflite::BuiltinOptions_ReshapeOptions> {
    460  public:
    461   using BuiltinOperator::BuiltinOperator;
    462 
    463   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    464       const TocoOperator& op,
    465       flatbuffers::FlatBufferBuilder* builder) const override {
    466     return ::tflite::CreateReshapeOptions(*builder,
    467                                           builder->CreateVector(op.shape));
    468   }
    469 
    470   void ReadOptions(const TfLiteOptions& options,
    471                    TocoOperator* op) const override {
    472     op->shape.insert(op->shape.end(), options.new_shape()->begin(),
    473                      options.new_shape()->end());
    474   }
    475 };
    476 
    477 class Softmax
    478     : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
    479                              ::tflite::BuiltinOptions_SoftmaxOptions> {
    480  public:
    481   using BuiltinOperator::BuiltinOperator;
    482   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    483       const TocoOperator& op,
    484       flatbuffers::FlatBufferBuilder* builder) const override {
    485     return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
    486   }
    487 
    488   void ReadOptions(const TfLiteOptions& options,
    489                    TocoOperator* op) const override {
    490     op->beta = options.beta();
    491   }
    492 };
    493 
    494 class SpaceToDepth
    495     : public BuiltinOperator<SpaceToDepthOperator,
    496                              ::tflite::SpaceToDepthOptions,
    497                              ::tflite::BuiltinOptions_SpaceToDepthOptions> {
    498  public:
    499   using BuiltinOperator::BuiltinOperator;
    500   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    501       const TocoOperator& op,
    502       flatbuffers::FlatBufferBuilder* builder) const override {
    503     return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
    504   }
    505 
    506   void ReadOptions(const TfLiteOptions& options,
    507                    TocoOperator* op) const override {
    508     op->block_size = options.block_size();
    509   }
    510 };
    511 
    512 class Transpose
    513     : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
    514                              ::tflite::BuiltinOptions_TransposeOptions> {
    515  public:
    516   using BuiltinOperator::BuiltinOperator;
    517   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    518       const TocoOperator& op,
    519       flatbuffers::FlatBufferBuilder* builder) const override {
    520     return ::tflite::CreateTransposeOptions(*builder);
    521   }
    522 
    523   void ReadOptions(const TfLiteOptions& options,
    524                    TocoOperator* op) const override {}
    525 };
    526 
    527 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
    528                                     ::tflite::BuiltinOptions_LSTMOptions> {
    529  public:
    530   using BuiltinOperator::BuiltinOperator;
    531   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    532       const TocoOperator& op,
    533       flatbuffers::FlatBufferBuilder* builder) const override {
    534     // Current toco converter only supports tanh, no clip.
    535     return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
    536                                        ::tflite::ActivationFunctionType_TANH,
    537                                        /*cell_clip=*/0.0,
    538                                        /*proj_clip=*/0.0);
    539   }
    540 
    541   void ReadOptions(const TfLiteOptions& options,
    542                    TocoOperator* op) const override {
    543     // Only support tanh activation, so check that tflite type is tanh.
    544     CHECK(options.fused_activation_function() ==
    545           ::tflite::ActivationFunctionType_TANH);
    546   }
    547 };
    548 
    549 class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
    550                                     ::tflite::BuiltinOptions_MeanOptions> {
    551  public:
    552   using BuiltinOperator::BuiltinOperator;
    553   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    554       const TocoOperator& op,
    555       flatbuffers::FlatBufferBuilder* builder) const override {
    556     return ::tflite::CreateMeanOptions(*builder, op.keep_dims);
    557   }
    558 
    559   void ReadOptions(const TfLiteOptions& options,
    560                    TocoOperator* op) const override {
    561     op->keep_dims = options.keep_dims();
    562   }
    563 };
    564 
    565 class ResizeBilinear
    566     : public BuiltinOperator<ResizeBilinearOperator,
    567                              ::tflite::ResizeBilinearOptions,
    568                              ::tflite::BuiltinOptions_ResizeBilinearOptions> {
    569  public:
    570   using BuiltinOperator::BuiltinOperator;
    571   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    572       const TocoOperator& op,
    573       flatbuffers::FlatBufferBuilder* builder) const override {
    574     return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners);
    575   }
    576 
    577   void ReadOptions(const TfLiteOptions& options,
    578                    TocoOperator* op) const override {
    579     op->align_corners = options.align_corners();
    580   }
    581 };
    582 
    583 class Squeeze
    584     : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
    585                              ::tflite::BuiltinOptions_SqueezeOptions> {
    586  public:
    587   using BuiltinOperator::BuiltinOperator;
    588 
    589   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    590       const TocoOperator& op,
    591       flatbuffers::FlatBufferBuilder* builder) const override {
    592     auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
    593     return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
    594   }
    595 
    596   void ReadOptions(const TfLiteOptions& options,
    597                    TocoOperator* op) const override {
    598     op->squeeze_dims.insert(op->squeeze_dims.end(),
    599                             options.squeeze_dims()->begin(),
    600                             options.squeeze_dims()->end());
    601   }
    602 };
    603 
    604 class Split
    605     : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
    606                              ::tflite::BuiltinOptions_SplitOptions> {
    607  public:
    608   using BuiltinOperator::BuiltinOperator;
    609 
    610   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    611       const TocoOperator& op,
    612       flatbuffers::FlatBufferBuilder* builder) const override {
    613     return ::tflite::CreateSplitOptions(*builder, op.num_split);
    614   }
    615 
    616   void ReadOptions(const TfLiteOptions& options,
    617                    TocoOperator* op) const override {
    618     op->num_split = options.num_splits();
    619   }
    620 };
    621 
    622 class StridedSlice
    623     : public BuiltinOperator<StridedSliceOperator,
    624                              ::tflite::StridedSliceOptions,
    625                              ::tflite::BuiltinOptions_StridedSliceOptions> {
    626  public:
    627   using BuiltinOperator::BuiltinOperator;
    628   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    629       const TocoOperator& op,
    630       flatbuffers::FlatBufferBuilder* builder) const override {
    631     return ::tflite::CreateStridedSliceOptions(
    632         *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
    633         op.new_axis_mask, op.shrink_axis_mask);
    634   }
    635 
    636   void ReadOptions(const TfLiteOptions& options,
    637                    TocoOperator* op) const override {
    638     op->begin_mask = options.begin_mask();
    639     op->end_mask = options.end_mask();
    640     op->ellipsis_mask = options.ellipsis_mask();
    641     op->new_axis_mask = options.new_axis_mask();
    642     op->shrink_axis_mask = options.shrink_axis_mask();
    643   }
    644 };
    645 
    646 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
    647                                        ::tflite::BuiltinOptions_TopKV2Options> {
    648  public:
    649   using BuiltinOperator::BuiltinOperator;
    650   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    651       const TocoOperator& op,
    652       flatbuffers::FlatBufferBuilder* builder) const override {
    653     return ::tflite::CreateTopKV2Options(*builder);
    654   }
    655 
    656   void ReadOptions(const TfLiteOptions& options,
    657                    TocoOperator* op) const override {}
    658 };
    659 
    660 class TensorFlowUnsupported : public BaseOperator {
    661  public:
    662   using BaseOperator::BaseOperator;
    663 
    664   Options Serialize(const Operator& op,
    665                     flatbuffers::FlatBufferBuilder* builder) const override {
    666     auto fbb =
    667         WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
    668     if (fbb) {
    669       return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
    670     } else {
    671       return Options::Custom(0);
    672     }
    673   }
    674 
    675   std::unique_ptr<Operator> Deserialize(
    676       const BuiltinOptions* builtin_options,
    677       const CustomOptions* custom_options) const override {
    678     auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
    679     if (custom_options) {
    680       auto flexbuffer_map =
    681           flexbuffers::GetRoot(custom_options->data(), custom_options->size())
    682               .AsMap();
    683       ReadOptions(flexbuffer_map, op.get());
    684     }
    685     return std::unique_ptr<Operator>(op.release());
    686   }
    687 
    688   std::unique_ptr<flexbuffers::Builder> WriteOptions(
    689       const TensorFlowUnsupportedOperator& op) const {
    690     auto fbb = absl::make_unique<flexbuffers::Builder>();
    691 
    692     ::tensorflow::NodeDef node_def;
    693     if (!node_def.ParseFromString(op.tensorflow_node_def)) {
    694       LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
    695       return std::unique_ptr<flexbuffers::Builder>();
    696     }
    697 
    698     bool has_valid_attr = false;
    699     size_t map_start = fbb->StartMap();
    700     for (const auto& pair : node_def.attr()) {
    701       const char* key = pair.first.c_str();
    702       const auto& attr = pair.second;
    703       switch (attr.value_case()) {
    704         case ::tensorflow::AttrValue::kS:
    705           fbb->String(key, attr.s());
    706           has_valid_attr = true;
    707           break;
    708         case ::tensorflow::AttrValue::kI:
    709           fbb->Int(key, attr.i());
    710           has_valid_attr = true;
    711           break;
    712         case ::tensorflow::AttrValue::kF:
    713           fbb->Float(key, attr.f());
    714           has_valid_attr = true;
    715           break;
    716         case ::tensorflow::AttrValue::kB:
    717           fbb->Bool(key, attr.b());
    718           has_valid_attr = true;
    719           break;
    720         default:
    721           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
    722                        << key << "'";
    723           break;
    724       }
    725     }
    726     if (!has_valid_attr) {
    727       return std::unique_ptr<flexbuffers::Builder>();
    728     }
    729     fbb->EndMap(map_start);
    730     fbb->Finish();
    731     return std::unique_ptr<flexbuffers::Builder>(fbb.release());
    732   }
    733 
    734   void ReadOptions(const flexbuffers::Map& m,
    735                    TensorFlowUnsupportedOperator* op) const {
    736     ::tensorflow::NodeDef node_def;
    737     auto attr = node_def.mutable_attr();
    738 
    739     const auto& keys = m.Keys();
    740     for (size_t i = 0; i < keys.size(); ++i) {
    741       const auto key = keys[i].AsKey();
    742       const auto& value = m[key];
    743       switch (value.GetType()) {
    744         case flexbuffers::TYPE_STRING:
    745           (*attr)[key].set_s(value.AsString().c_str());
    746           break;
    747         case flexbuffers::TYPE_INT:
    748           (*attr)[key].set_i(value.AsInt64());
    749           break;
    750         case flexbuffers::TYPE_FLOAT:
    751           (*attr)[key].set_f(value.AsFloat());
    752           break;
    753         case flexbuffers::TYPE_BOOL:
    754           (*attr)[key].set_b(value.AsBool());
    755           break;
    756         default:
    757           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
    758                        << key << "'";
    759           break;
    760       }
    761     }
    762     node_def.SerializeToString(&op->tensorflow_node_def);
    763   }
    764 };
    765 
    766 namespace {
    767 // Build a vector containing all the known operators.
    768 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
    769   std::vector<std::unique_ptr<BaseOperator>> ops;
    770 
    771   // Builtin Operators.
    772   ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
    773   ops.emplace_back(new Div(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
    774   ops.emplace_back(new Sub(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
    775   ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D,
    776                                    OperatorType::kAveragePool));
    777   ops.emplace_back(
    778       new SpaceToBatchND(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
    779                          OperatorType::kSpaceToBatchND));
    780   ops.emplace_back(
    781       new BatchToSpaceND(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
    782                          OperatorType::kBatchToSpaceND));
    783   ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION,
    784                                      OperatorType::kConcatenation));
    785   ops.emplace_back(
    786       new Convolution(::tflite::BuiltinOperator_CONV_2D, OperatorType::kConv));
    787   ops.emplace_back(
    788       new DepthwiseConvolution(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
    789                                OperatorType::kDepthwiseConv));
    790   ops.emplace_back(new FullyConnected(::tflite::BuiltinOperator_FULLY_CONNECTED,
    791                                       OperatorType::kFullyConnected));
    792   ops.emplace_back(
    793       new Gather(::tflite::BuiltinOperator_GATHER, OperatorType::kGather));
    794   ops.emplace_back(
    795       new L2Normalization(::tflite::BuiltinOperator_L2_NORMALIZATION,
    796                           OperatorType::kL2Normalization));
    797   ops.emplace_back(
    798       new L2Pool(::tflite::BuiltinOperator_L2_POOL_2D, OperatorType::kL2Pool));
    799   ops.emplace_back(new LocalResponseNormalization(
    800       ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
    801       OperatorType::kLocalResponseNormalization));
    802   ops.emplace_back(new MaxPool(::tflite::BuiltinOperator_MAX_POOL_2D,
    803                                OperatorType::kMaxPool));
    804   ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
    805   ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
    806   ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE,
    807                                OperatorType::kTensorFlowReshape));
    808   ops.emplace_back(
    809       new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax));
    810   ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH,
    811                                     OperatorType::kSpaceToDepth));
    812   ops.emplace_back(
    813       new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
    814   ops.emplace_back(new Transpose(::tflite::BuiltinOperator_TRANSPOSE,
    815                                  OperatorType::kTranspose));
    816   ops.emplace_back(
    817       new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
    818   ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR,
    819                                       OperatorType::kResizeBilinear));
    820   ops.emplace_back(
    821       new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
    822   ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT,
    823                              OperatorType::kTensorFlowSplit));
    824   ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
    825                                     OperatorType::kStridedSlice));
    826   ops.emplace_back(
    827       new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2));
    828   ops.emplace_back(
    829       new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell));
    830 
    831   // Custom Operators.
    832   ops.emplace_back(new Cast("CAST", OperatorType::kCast));
    833   ops.emplace_back(
    834       new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
    835   ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant));
    836   ops.emplace_back(new TensorFlowUnsupported(
    837       "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported));
    838 
    839   // There operators are supported by Toco, but not by TF Lite, and has no
    840   // attributes.
    841   ops.emplace_back(
    842       new SimpleOperator<AddNOperator>("ADDN", OperatorType::kAddN));
    843   ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg));
    844   ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>(
    845       "RSQRT", OperatorType::kTensorFlowRsqrt));
    846   // Simple Operators.
    847   ops.emplace_back(new SimpleOperator<DequantizeOperator>(
    848       "DEQUANTIZE", OperatorType::kDequantize));
    849   ops.emplace_back(
    850       new SimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor));
    851   ops.emplace_back(
    852       new SimpleOperator<ReluOperator>("RELU", OperatorType::kRelu));
    853   ops.emplace_back(
    854       new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1));
    855   ops.emplace_back(
    856       new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
    857   ops.emplace_back(new SimpleOperator<LogisticOperator>(
    858       "LOGISTIC", OperatorType::kLogistic));
    859   ops.emplace_back(
    860       new SimpleOperator<TanhOperator>("TANH", OperatorType::kTanh));
    861   ops.emplace_back(new SimpleOperator<ExpOperator>("EXP", OperatorType::kExp));
    862 
    863   return ops;
    864 }
    865 }  // namespace
    866 
    867 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
    868   std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
    869 
    870   std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
    871   for (auto& op : ops) {
    872     result[op->type()] = std::move(op);
    873   }
    874 
    875   return result;
    876 }
    877 
    878 std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
    879   std::map<string, std::unique_ptr<BaseOperator>> result;
    880 
    881   std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
    882   for (auto& op : ops) {
    883     result[op->name()] = std::move(op);
    884   }
    885 
    886   return result;
    887 }
    888 
    889 }  // namespace tflite
    890 
    891 }  // namespace toco
    892