Home | History | Annotate | Download | only in tflite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/lite/toco/tflite/operator.h"
     16 
     17 #include "tensorflow/core/framework/attr_value.pb.h"
     18 #include "tensorflow/core/framework/node_def.pb.h"
     19 #include "tensorflow/core/framework/op.h"
     20 #include "tensorflow/core/framework/op_def.pb.h"
     21 #include "tensorflow/core/util/ptr_util.h"
     22 
     23 // TODO(ycling): Consider refactoring to extract the LSTM definition out of
     24 // graph_transformation module.
     25 #include "tensorflow/lite/schema/schema_generated.h"
     26 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h"
     27 #include "tensorflow/lite/toco/model.h"
     28 #include "tensorflow/lite/toco/tflite/builtin_operator.h"
     29 #include "tensorflow/lite/toco/tflite/custom_operator.h"
     30 #include "tensorflow/lite/toco/tflite/simple_operator.h"
     31 #include "tensorflow/lite/toco/tflite/types.h"
     32 #include "tensorflow/lite/toco/tflite/whitelisted_flex_ops.h"
     33 
     34 namespace toco {
     35 
     36 namespace tflite {
     37 
     38 class AveragePool
     39     : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions,
     40                              ::tflite::BuiltinOptions_Pool2DOptions> {
     41  public:
     42   using BuiltinOperator::BuiltinOperator;
     43 
     44   flatbuffers::Offset<TfLiteOptions> WriteOptions(
     45       const TocoOperator& op,
     46       flatbuffers::FlatBufferBuilder* builder) const override {
     47     auto padding = Padding::Serialize(op.padding.type);
     48     auto activation_function =
     49         ActivationFunction::Serialize(op.fused_activation_function);
     50     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
     51                                          op.stride_height, op.kwidth,
     52                                          op.kheight, activation_function);
     53   }
     54 
     55   void ReadOptions(const TfLiteOptions& options,
     56                    TocoOperator* op) const override {
     57     op->padding.type = Padding::Deserialize(options.padding());
     58     op->stride_width = options.stride_w();
     59     op->stride_height = options.stride_h();
     60     op->kwidth = options.filter_width();
     61     op->kheight = options.filter_height();
     62     op->fused_activation_function =
     63         ActivationFunction::Deserialize(options.fused_activation_function());
     64   }
     65 
     66   int GetVersion(const OperatorSignature& op_signature) const override {
     67     const string& input_name = op_signature.op->inputs[0];
     68     const Array& input_array = op_signature.model->GetArray(input_name);
     69     if (input_array.data_type == ArrayDataType::kInt8) {
     70       return 2;
     71     }
     72     return 1;
     73   }
     74 };
     75 
     76 class Convolution
     77     : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions,
     78                              ::tflite::BuiltinOptions_Conv2DOptions> {
     79  public:
     80   using BuiltinOperator::BuiltinOperator;
     81 
     82   flatbuffers::Offset<TfLiteOptions> WriteOptions(
     83       const TocoOperator& op,
     84       flatbuffers::FlatBufferBuilder* builder) const override {
     85     auto padding = Padding::Serialize(op.padding.type);
     86     auto activation_function =
     87         ActivationFunction::Serialize(op.fused_activation_function);
     88     return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width,
     89                                          op.stride_height, activation_function,
     90                                          op.dilation_width_factor,
     91                                          op.dilation_height_factor);
     92   }
     93 
     94   void ReadOptions(const TfLiteOptions& options,
     95                    TocoOperator* op) const override {
     96     op->padding.type = Padding::Deserialize(options.padding());
     97     op->stride_width = options.stride_w();
     98     op->stride_height = options.stride_h();
     99     op->dilation_width_factor = options.dilation_w_factor();
    100     op->dilation_height_factor = options.dilation_h_factor();
    101     op->fused_activation_function =
    102         ActivationFunction::Deserialize(options.fused_activation_function());
    103   }
    104 
    105   int GetVersion(const OperatorSignature& op_signature) const override {
    106     const string& input_name = op_signature.op->inputs[0];
    107     const string& filter_name = op_signature.op->inputs[1];
    108     const string& output_name = op_signature.op->outputs[0];
    109     const Array& input_array = op_signature.model->GetArray(input_name);
    110     const Array& filter_array = op_signature.model->GetArray(filter_name);
    111     const Array& output_array = op_signature.model->GetArray(output_name);
    112     // If the op has signed int8 inputs and outputs, its version 3.
    113     if (input_array.data_type == ArrayDataType::kInt8 &&
    114         filter_array.data_type == ArrayDataType::kInt8 &&
    115         output_array.data_type == ArrayDataType::kInt8) {
    116       return 3;
    117     }
    118     // If the op is a signed int8 hybrid operation, we need to return
    119     // version 2.
    120     if (input_array.data_type == ArrayDataType::kFloat &&
    121         filter_array.data_type == ArrayDataType::kInt8 &&
    122         output_array.data_type == ArrayDataType::kFloat) {
    123       return 2;
    124     }
    125     return 1;
    126   }
    127 };
    128 
    129 class DepthwiseConvolution
    130     : public BuiltinOperator<DepthwiseConvOperator,
    131                              ::tflite::DepthwiseConv2DOptions,
    132                              ::tflite::BuiltinOptions_DepthwiseConv2DOptions> {
    133  public:
    134   using BuiltinOperator::BuiltinOperator;
    135 
    136   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    137       const TocoOperator& op,
    138       flatbuffers::FlatBufferBuilder* builder) const override {
    139     auto padding = Padding::Serialize(op.padding.type);
    140     auto activation_function =
    141         ActivationFunction::Serialize(op.fused_activation_function);
    142     return ::tflite::CreateDepthwiseConv2DOptions(
    143         *builder, padding, op.stride_width, op.stride_height,
    144         op.depth_multiplier, activation_function, op.dilation_width_factor,
    145         op.dilation_height_factor);
    146   }
    147 
    148   void ReadOptions(const TfLiteOptions& options,
    149                    TocoOperator* op) const override {
    150     op->padding.type = Padding::Deserialize(options.padding());
    151     op->stride_width = options.stride_w();
    152     op->stride_height = options.stride_h();
    153     op->depth_multiplier = options.depth_multiplier();
    154     op->fused_activation_function =
    155         ActivationFunction::Deserialize(options.fused_activation_function());
    156     op->dilation_width_factor = options.dilation_w_factor();
    157     op->dilation_height_factor = options.dilation_h_factor();
    158   }
    159 
    160   int GetVersion(const OperatorSignature& op_signature) const override {
    161     const auto& conv_op =
    162         static_cast<const DepthwiseConvOperator&>(*op_signature.op);
    163     const string& input_name = op_signature.op->inputs[0];
    164     const string& filter_name = op_signature.op->inputs[1];
    165     const string& output_name = op_signature.op->outputs[0];
    166     const Array& input_array = op_signature.model->GetArray(input_name);
    167     const Array& filter_array = op_signature.model->GetArray(filter_name);
    168     const Array& output_array = op_signature.model->GetArray(output_name);
    169     // If the op has signed int8 inputs and outputs, its version 3.
    170     if (input_array.data_type == ArrayDataType::kInt8 &&
    171         filter_array.data_type == ArrayDataType::kInt8 &&
    172         output_array.data_type == ArrayDataType::kInt8) {
    173       return 3;
    174     }
    175     if (conv_op.dilation_width_factor != 1 ||
    176         conv_op.dilation_height_factor != 1) {
    177       return 2;
    178     }
    179     return 1;
    180   }
    181 };
    182 
    183 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions,
    184                                    ::tflite::BuiltinOptions_AddOptions> {
    185  public:
    186   using BuiltinOperator::BuiltinOperator;
    187 
    188   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    189       const TocoOperator& op,
    190       flatbuffers::FlatBufferBuilder* builder) const override {
    191     auto activation_function =
    192         ActivationFunction::Serialize(op.fused_activation_function);
    193     return ::tflite::CreateAddOptions(*builder, activation_function);
    194   }
    195 
    196   void ReadOptions(const TfLiteOptions& options,
    197                    TocoOperator* op) const override {
    198     op->fused_activation_function =
    199         ActivationFunction::Deserialize(options.fused_activation_function());
    200   }
    201 
    202   int GetVersion(const OperatorSignature& op_signature) const override {
    203     const string& input_name = op_signature.op->inputs[0];
    204     const Array& input_array = op_signature.model->GetArray(input_name);
    205     // Version 2 supports signed int8 input types.
    206     if (input_array.data_type == ArrayDataType::kInt8) {
    207       return 2;
    208     }
    209     return 1;
    210   }
    211 };
    212 
    213 class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
    214                                     ::tflite::BuiltinOptions_AddNOptions> {
    215  public:
    216   using BuiltinOperator::BuiltinOperator;
    217 
    218   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    219       const TocoOperator& op,
    220       flatbuffers::FlatBufferBuilder* builder) const override {
    221     return ::tflite::CreateAddNOptions(*builder);
    222   }
    223 
    224   void ReadOptions(const TfLiteOptions& options,
    225                    TocoOperator* op) const override {}
    226 
    227   int GetVersion(const OperatorSignature& op_signature) const override {
    228     return 1;
    229   }
    230 };
    231 
    232 class SpaceToBatchND
    233     : public BuiltinOperator<SpaceToBatchNDOperator,
    234                              ::tflite::SpaceToBatchNDOptions,
    235                              ::tflite::BuiltinOptions_SpaceToBatchNDOptions> {
    236  public:
    237   using BuiltinOperator::BuiltinOperator;
    238 
    239   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    240       const TocoOperator& op,
    241       flatbuffers::FlatBufferBuilder* builder) const override {
    242     return ::tflite::CreateSpaceToBatchNDOptions(*builder);
    243   }
    244 
    245   void ReadOptions(const TfLiteOptions& options,
    246                    TocoOperator* op) const override {}
    247 
    248   int GetVersion(const OperatorSignature& op_signature) const override {
    249     const string& input_name = op_signature.op->inputs[0];
    250     const Array& input_array = op_signature.model->GetArray(input_name);
    251     // If the op take int8 input, it is version 2.
    252     if (input_array.data_type == ArrayDataType::kInt8) {
    253       return 2;
    254     }
    255     return 1;
    256   }
    257 };
    258 
    259 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions,
    260                                    ::tflite::BuiltinOptions_SubOptions> {
    261  public:
    262   using BuiltinOperator::BuiltinOperator;
    263 
    264   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    265       const TocoOperator& op,
    266       flatbuffers::FlatBufferBuilder* builder) const override {
    267     auto activation_function =
    268         ActivationFunction::Serialize(op.fused_activation_function);
    269     return ::tflite::CreateSubOptions(*builder, activation_function);
    270   }
    271 
    272   void ReadOptions(const TfLiteOptions& options,
    273                    TocoOperator* op) const override {
    274     op->fused_activation_function =
    275         ActivationFunction::Deserialize(options.fused_activation_function());
    276   }
    277 
    278   int GetVersion(const OperatorSignature& op_signature) const override {
    279     const string& input_name = op_signature.op->inputs[0];
    280     const Array& input_array = op_signature.model->GetArray(input_name);
    281     // If the op take int8 input, it is version 2.
    282     if (input_array.data_type == ArrayDataType::kInt8) {
    283       return 2;
    284     }
    285     return 1;
    286   }
    287 };
    288 
    289 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions,
    290                                    ::tflite::BuiltinOptions_DivOptions> {
    291  public:
    292   using BuiltinOperator::BuiltinOperator;
    293 
    294   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    295       const TocoOperator& op,
    296       flatbuffers::FlatBufferBuilder* builder) const override {
    297     auto activation_function =
    298         ActivationFunction::Serialize(op.fused_activation_function);
    299     return ::tflite::CreateDivOptions(*builder, activation_function);
    300   }
    301 
    302   void ReadOptions(const TfLiteOptions& options,
    303                    TocoOperator* op) const override {
    304     op->fused_activation_function =
    305         ActivationFunction::Deserialize(options.fused_activation_function());
    306   }
    307 
    308   int GetVersion(const OperatorSignature& op_signature) const override {
    309     return 1;
    310   }
    311 };
    312 
    313 class BatchToSpaceND
    314     : public BuiltinOperator<BatchToSpaceNDOperator,
    315                              ::tflite::BatchToSpaceNDOptions,
    316                              ::tflite::BuiltinOptions_BatchToSpaceNDOptions> {
    317  public:
    318   using BuiltinOperator::BuiltinOperator;
    319 
    320   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    321       const TocoOperator& op,
    322       flatbuffers::FlatBufferBuilder* builder) const override {
    323     return ::tflite::CreateBatchToSpaceNDOptions(*builder);
    324   }
    325 
    326   void ReadOptions(const TfLiteOptions& options,
    327                    TocoOperator* op) const override {}
    328 
    329   int GetVersion(const OperatorSignature& op_signature) const override {
    330     const string& input_name = op_signature.op->inputs[0];
    331     const Array& input_array = op_signature.model->GetArray(input_name);
    332     // If the op take int8 input, it is version 2.
    333     if (input_array.data_type == ArrayDataType::kInt8) {
    334       return 2;
    335     }
    336     return 1;
    337   }
    338 };
    339 
    340 class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions,
    341                                     ::tflite::BuiltinOptions_CastOptions> {
    342  public:
    343   using BuiltinOperator::BuiltinOperator;
    344   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    345       const TocoOperator& op,
    346       flatbuffers::FlatBufferBuilder* builder) const override {
    347     return ::tflite::CreateCastOptions(*builder,
    348                                        DataType::Serialize(op.src_data_type),
    349                                        DataType::Serialize(op.dst_data_type));
    350   }
    351 
    352   void ReadOptions(const TfLiteOptions& options,
    353                    TocoOperator* op) const override {
    354     op->src_data_type = DataType::Deserialize(options.in_data_type());
    355     op->dst_data_type = DataType::Deserialize(options.out_data_type());
    356   }
    357 
    358   int GetVersion(const OperatorSignature& op_signature) const override {
    359     return 1;
    360   }
    361 };
    362 
    363 class Concatenation
    364     : public BuiltinOperator<ConcatenationOperator,
    365                              ::tflite::ConcatenationOptions,
    366                              ::tflite::BuiltinOptions_ConcatenationOptions> {
    367  public:
    368   using BuiltinOperator::BuiltinOperator;
    369   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    370       const TocoOperator& op,
    371       flatbuffers::FlatBufferBuilder* builder) const override {
    372     return ::tflite::CreateConcatenationOptions(*builder, op.axis);
    373   }
    374 
    375   void ReadOptions(const TfLiteOptions& options,
    376                    TocoOperator* op) const override {
    377     op->axis = options.axis();
    378   }
    379 
    380   int GetVersion(const OperatorSignature& op_signature) const override {
    381     const string& input_name = op_signature.op->inputs[0];
    382     const Array& input_array = op_signature.model->GetArray(input_name);
    383     // If the op take int8 input, it is version 2.
    384     if (input_array.data_type == ArrayDataType::kInt8) {
    385       return 2;
    386     }
    387     return 1;
    388   }
    389 };
    390 
    391 class DepthToSpace : public CustomOperator<DepthToSpaceOperator> {
    392  public:
    393   using CustomOperator::CustomOperator;
    394   void WriteOptions(const TocoOperator& op,
    395                     flexbuffers::Builder* fbb) const override {
    396     fbb->Int("block_size", op.block_size);
    397   }
    398   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
    399     op->block_size = m["block_size"].AsInt64();
    400   }
    401 
    402   int GetVersion(const OperatorSignature& op_signature) const override {
    403     return 1;
    404   }
    405 };
    406 
    407 class FakeQuant
    408     : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions,
    409                              ::tflite::BuiltinOptions_FakeQuantOptions> {
    410  public:
    411   using BuiltinOperator::BuiltinOperator;
    412   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    413       const TocoOperator& op,
    414       flatbuffers::FlatBufferBuilder* builder) const override {
    415     return ::tflite::CreateFakeQuantOptions(
    416         *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range);
    417   }
    418   void ReadOptions(const TfLiteOptions& options,
    419                    TocoOperator* op) const override {
    420     auto* minmax = new MinMax;
    421     minmax->min = options.min();
    422     minmax->max = options.max();
    423     op->minmax.reset(minmax);
    424     op->num_bits = options.num_bits();
    425     op->narrow_range = options.narrow_range();
    426   }
    427   int GetVersion(const OperatorSignature& op_signature) const override {
    428     const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op);
    429     return fq_op.narrow_range ? 2 : 1;
    430   }
    431 };
    432 
    433 class FullyConnected
    434     : public BuiltinOperator<FullyConnectedOperator,
    435                              ::tflite::FullyConnectedOptions,
    436                              ::tflite::BuiltinOptions_FullyConnectedOptions> {
    437  public:
    438   using BuiltinOperator::BuiltinOperator;
    439   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    440       const TocoOperator& op,
    441       flatbuffers::FlatBufferBuilder* builder) const override {
    442     auto activation_function =
    443         ActivationFunction::Serialize(op.fused_activation_function);
    444     ::tflite::FullyConnectedOptionsWeightsFormat tflite_weights_format;
    445     switch (op.weights_format) {
    446       case FullyConnectedWeightsFormat::kDefault:
    447         tflite_weights_format =
    448             ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
    449         break;
    450       case FullyConnectedWeightsFormat::kShuffled4x16Int8:
    451         tflite_weights_format =
    452             ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
    453         break;
    454       default:
    455         LOG(ERROR) << "Unhandled FC weights format";
    456         tflite_weights_format =
    457             ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT;
    458     }
    459     return ::tflite::CreateFullyConnectedOptions(*builder, activation_function,
    460                                                  tflite_weights_format);
    461   }
    462 
    463   void ReadOptions(const TfLiteOptions& options,
    464                    TocoOperator* op) const override {
    465     op->fused_activation_function =
    466         ActivationFunction::Deserialize(options.fused_activation_function());
    467     switch (options.weights_format()) {
    468       case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT:
    469         op->weights_format = FullyConnectedWeightsFormat::kDefault;
    470         break;
    471       case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
    472         op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8;
    473         break;
    474       default:
    475         LOG(ERROR) << "Unhandled FC weights format";
    476         op->weights_format = FullyConnectedWeightsFormat::kDefault;
    477     }
    478   }
    479 
    480   // +-----------------+--------------------+--------------------------+
    481   // |                 |    Weight::Default | Weight::Shuffled4x16Int8 |
    482   // +-----------------+--------------------+--------------------------+
    483   // | Float           |                  1 |                        2 |
    484   // | Quantized Uint8 |                  1 |                        2 |
    485   // | Hybrid          |                  3 |                        3 |
    486   // | Quantized Int8  |                  4 |                        4 |
    487   // +-----------------+--------------------+--------------------------+
    488   int GetVersion(const OperatorSignature& op_signature) const override {
    489     const auto& fc_op =
    490         static_cast<const FullyConnectedOperator&>(*op_signature.op);
    491     const string& input_name = op_signature.op->inputs[0];
    492     const string& weights_name = op_signature.op->inputs[1];
    493     const string& output_name = op_signature.op->outputs[0];
    494     const Array& input_array = op_signature.model->GetArray(input_name);
    495     const Array& weights_array = op_signature.model->GetArray(weights_name);
    496     const Array& output_array = op_signature.model->GetArray(output_name);
    497     // Int8 fully fixed point kernel is at version 4.
    498     if (input_array.data_type == ArrayDataType::kInt8 &&
    499         weights_array.data_type == ArrayDataType::kInt8 &&
    500         output_array.data_type == ArrayDataType::kInt8) {
    501       return 4;
    502     }
    503     // If the op is a signed int8 hybrid operation, we need to return
    504     // version 3.
    505     if (input_array.data_type == ArrayDataType::kFloat &&
    506         weights_array.data_type == ArrayDataType::kInt8 &&
    507         output_array.data_type == ArrayDataType::kFloat) {
    508       return 3;
    509     }
    510     // For float and uint8 fixed point kernels, if the weight is
    511     // Shuffled4x16Int8, is is version 2.
    512     if (fc_op.weights_format ==
    513         FullyConnectedWeightsFormat::kShuffled4x16Int8) {
    514       return 2;
    515     }
    516 
    517     // Otherwise (weight is default), the version is 1.
    518     return 1;
    519   }
    520 };
    521 
    522 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions,
    523                                       ::tflite::BuiltinOptions_GatherOptions> {
    524  public:
    525   using BuiltinOperator::BuiltinOperator;
    526   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    527       const TocoOperator& op,
    528       flatbuffers::FlatBufferBuilder* builder) const override {
    529     int axis = op.axis ? op.axis.value() : 0;
    530     return ::tflite::CreateGatherOptions(*builder, axis);
    531   }
    532 
    533   void ReadOptions(const TfLiteOptions& options,
    534                    TocoOperator* op) const override {
    535     op->axis = {options.axis()};
    536   }
    537 
    538   int GetVersion(const OperatorSignature& op_signature) const override {
    539     const string& input_name = op_signature.op->inputs[0];
    540     const Array& input_array = op_signature.model->GetArray(input_name);
    541     // If the op take int8 input, it is version 2.
    542     if (input_array.data_type == ArrayDataType::kInt8) {
    543       return 2;
    544     }
    545     return 1;
    546   }
    547 };
    548 
    549 class GatherNd
    550     : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions,
    551                              ::tflite::BuiltinOptions_GatherNdOptions> {
    552  public:
    553   using BuiltinOperator::BuiltinOperator;
    554 
    555   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    556       const TocoOperator& op,
    557       flatbuffers::FlatBufferBuilder* builder) const override {
    558     return ::tflite::CreateGatherNdOptions(*builder);
    559   }
    560 
    561   void ReadOptions(const TfLiteOptions& options,
    562                    TocoOperator* op) const override {}
    563 
    564   int GetVersion(const OperatorSignature& op_signature) const override {
    565     return 1;
    566   }
    567 };
    568 
    569 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions,
    570                                     ::tflite::BuiltinOptions_SVDFOptions> {
    571  public:
    572   using BuiltinOperator::BuiltinOperator;
    573   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    574       const TocoOperator& op,
    575       flatbuffers::FlatBufferBuilder* builder) const override {
    576     auto activation_function =
    577         ActivationFunction::Serialize(op.fused_activation_function);
    578     return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function);
    579   }
    580 
    581   void ReadOptions(const TfLiteOptions& options,
    582                    TocoOperator* op) const override {
    583     op->fused_activation_function =
    584         ActivationFunction::Deserialize(options.fused_activation_function());
    585     op->rank = options.rank();
    586   }
    587 
    588   int GetVersion(const OperatorSignature& op_signature) const override {
    589     const string& input_name = op_signature.op->inputs[0];
    590     const string& weights_feature_name = op_signature.op->inputs[1];
    591     const string& output_name = op_signature.op->outputs[0];
    592     const Array& input_array = op_signature.model->GetArray(input_name);
    593     const Array& weights_feature_array =
    594         op_signature.model->GetArray(weights_feature_name);
    595     const Array& output_array = op_signature.model->GetArray(output_name);
    596     // If the op is a signed int8 hybrid operation, we need to return
    597     // version 2.
    598     if (input_array.data_type == ArrayDataType::kFloat &&
    599         weights_feature_array.data_type == ArrayDataType::kInt8 &&
    600         output_array.data_type == ArrayDataType::kFloat) {
    601       return 2;
    602     }
    603     return 1;
    604   }
    605 };
    606 
    607 class L2Normalization
    608     : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions,
    609                              ::tflite::BuiltinOptions_L2NormOptions> {
    610  public:
    611   using BuiltinOperator::BuiltinOperator;
    612   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    613       const TocoOperator& op,
    614       flatbuffers::FlatBufferBuilder* builder) const override {
    615     auto activation_function =
    616         ActivationFunction::Serialize(op.fused_activation_function);
    617     return ::tflite::CreateL2NormOptions(*builder, activation_function);
    618   }
    619 
    620   void ReadOptions(const TfLiteOptions& options,
    621                    TocoOperator* op) const override {
    622     op->fused_activation_function =
    623         ActivationFunction::Deserialize(options.fused_activation_function());
    624   }
    625 
    626   int GetVersion(const OperatorSignature& op_signature) const override {
    627     const string& output_name = op_signature.op->outputs[0];
    628     const Array& output_array = op_signature.model->GetArray(output_name);
    629     // Version 2 supports signed int8 input types.
    630     if (output_array.data_type == ArrayDataType::kInt8) {
    631       return 2;
    632     }
    633     return 1;
    634   }
    635 };
    636 
    637 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions,
    638                                       ::tflite::BuiltinOptions_Pool2DOptions> {
    639  public:
    640   using BuiltinOperator::BuiltinOperator;
    641   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    642       const TocoOperator& op,
    643       flatbuffers::FlatBufferBuilder* builder) const override {
    644     auto padding = Padding::Serialize(op.padding.type);
    645     auto activation_function =
    646         ActivationFunction::Serialize(op.fused_activation_function);
    647     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
    648                                          op.stride_height, op.kwidth,
    649                                          op.kheight, activation_function);
    650   }
    651 
    652   void ReadOptions(const TfLiteOptions& options,
    653                    TocoOperator* op) const override {
    654     op->padding.type = Padding::Deserialize(options.padding());
    655     op->stride_width = options.stride_w();
    656     op->stride_height = options.stride_h();
    657     op->kwidth = options.filter_width();
    658     op->kheight = options.filter_height();
    659     op->fused_activation_function =
    660         ActivationFunction::Deserialize(options.fused_activation_function());
    661   }
    662 
    663   int GetVersion(const OperatorSignature& op_signature) const override {
    664     return 1;
    665   }
    666 };
    667 
    668 class LocalResponseNormalization
    669     : public BuiltinOperator<
    670           LocalResponseNormalizationOperator,
    671           ::tflite::LocalResponseNormalizationOptions,
    672           ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> {
    673  public:
    674   using BuiltinOperator::BuiltinOperator;
    675   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    676       const TocoOperator& op,
    677       flatbuffers::FlatBufferBuilder* builder) const override {
    678     return ::tflite::CreateLocalResponseNormalizationOptions(
    679         *builder, op.range, op.bias, op.alpha, op.beta);
    680   }
    681 
    682   void ReadOptions(const TfLiteOptions& options,
    683                    TocoOperator* op) const override {
    684     op->range = options.radius();
    685     op->bias = options.bias();
    686     op->alpha = options.alpha();
    687     op->beta = options.beta();
    688   }
    689 
    690   int GetVersion(const OperatorSignature& op_signature) const override {
    691     return 1;
    692   }
    693 };
    694 
    695 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions,
    696                                        ::tflite::BuiltinOptions_Pool2DOptions> {
    697  public:
    698   using BuiltinOperator::BuiltinOperator;
    699   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    700       const TocoOperator& op,
    701       flatbuffers::FlatBufferBuilder* builder) const override {
    702     auto padding = Padding::Serialize(op.padding.type);
    703     auto activation_function =
    704         ActivationFunction::Serialize(op.fused_activation_function);
    705     return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width,
    706                                          op.stride_height, op.kwidth,
    707                                          op.kheight, activation_function);
    708   }
    709 
    710   void ReadOptions(const TfLiteOptions& options,
    711                    TocoOperator* op) const override {
    712     op->padding.type = Padding::Deserialize(options.padding());
    713     op->stride_width = options.stride_w();
    714     op->stride_height = options.stride_h();
    715     op->kwidth = options.filter_width();
    716     op->kheight = options.filter_height();
    717     op->fused_activation_function =
    718         ActivationFunction::Deserialize(options.fused_activation_function());
    719   }
    720 
    721   int GetVersion(const OperatorSignature& op_signature) const override {
    722     const string& input_name = op_signature.op->inputs[0];
    723     const Array& input_array = op_signature.model->GetArray(input_name);
    724     if (input_array.data_type == ArrayDataType::kInt8) {
    725       return 2;
    726     }
    727     return 1;
    728   }
    729 };
    730 
    731 class Maximum : public SimpleOperator<TensorFlowMaximumOperator> {
    732  public:
    733   explicit Maximum() : SimpleOperator("MAXIMUM", OperatorType::kMaximum) {}
    734   int GetVersion(const OperatorSignature& op_signature) const override {
    735     const string& input_name = op_signature.op->inputs[0];
    736     const Array& input_array = op_signature.model->GetArray(input_name);
    737     // Version 2 supports signed int8 input types.
    738     if (input_array.data_type == ArrayDataType::kInt8) {
    739       return 2;
    740     }
    741     return 1;
    742   }
    743 };
    744 
    745 class Minimum : public SimpleOperator<TensorFlowMinimumOperator> {
    746  public:
    747   explicit Minimum() : SimpleOperator("MINIMUM", OperatorType::kMinimum) {}
    748   int GetVersion(const OperatorSignature& op_signature) const override {
    749     const string& input_name = op_signature.op->inputs[0];
    750     const Array& input_array = op_signature.model->GetArray(input_name);
    751     // Version 2 supports signed int8 input types.
    752     if (input_array.data_type == ArrayDataType::kInt8) {
    753       return 2;
    754     }
    755     return 1;
    756   }
    757 };
    758 
    759 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions,
    760                                    ::tflite::BuiltinOptions_MulOptions> {
    761  public:
    762   using BuiltinOperator::BuiltinOperator;
    763 
    764   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    765       const TocoOperator& op,
    766       flatbuffers::FlatBufferBuilder* builder) const override {
    767     auto activation_function =
    768         ActivationFunction::Serialize(op.fused_activation_function);
    769     return ::tflite::CreateMulOptions(*builder, activation_function);
    770   }
    771 
    772   void ReadOptions(const TfLiteOptions& options,
    773                    TocoOperator* op) const override {
    774     op->fused_activation_function =
    775         ActivationFunction::Deserialize(options.fused_activation_function());
    776   }
    777 
    778   int GetVersion(const OperatorSignature& op_signature) const override {
    779     const string& input_name = op_signature.op->inputs[0];
    780     const Array& input_array = op_signature.model->GetArray(input_name);
    781     // Version 2 supports signed int8 input types.
    782     if (input_array.data_type == ArrayDataType::kInt8) {
    783       return 2;
    784     }
    785     return 1;
    786   }
    787 };
    788 
    789 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions,
    790                                    ::tflite::BuiltinOptions_PadOptions> {
    791  public:
    792   using BuiltinOperator::BuiltinOperator;
    793 
    794   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    795       const TocoOperator& op,
    796       flatbuffers::FlatBufferBuilder* builder) const override {
    797     return ::tflite::CreatePadOptions(*builder);
    798   }
    799 
    800   void ReadOptions(const TfLiteOptions& options,
    801                    TocoOperator* op) const override {}
    802 
    803   int GetVersion(const OperatorSignature& op_signature) const override {
    804     const string& input_name = op_signature.op->inputs[0];
    805     const Array& input_array = op_signature.model->GetArray(input_name);
    806     // If the op take int8 input, it is version 2.
    807     if (input_array.data_type == ArrayDataType::kInt8) {
    808       return 2;
    809     }
    810     return 1;
    811   }
    812 };
    813 
    814 class Tile
    815     : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions,
    816                              ::tflite::BuiltinOptions_TileOptions> {
    817   using BuiltinOperator::BuiltinOperator;
    818 
    819   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    820       const TocoOperator& op,
    821       flatbuffers::FlatBufferBuilder* builder) const override {
    822     return ::tflite::CreateTileOptions(*builder);
    823   }
    824 
    825   void ReadOptions(const TfLiteOptions& options,
    826                    TocoOperator* op) const override {}
    827   int GetVersion(const OperatorSignature& op_signature) const override {
    828     return 1;
    829   }
    830 };
    831 
    832 class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options,
    833                                      ::tflite::BuiltinOptions_PadV2Options> {
    834  public:
    835   using BuiltinOperator::BuiltinOperator;
    836 
    837   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    838       const TocoOperator& op,
    839       flatbuffers::FlatBufferBuilder* builder) const override {
    840     return ::tflite::CreatePadV2Options(*builder);
    841   }
    842 
    843   void ReadOptions(const TfLiteOptions& options,
    844                    TocoOperator* op) const override {}
    845 
    846   int GetVersion(const OperatorSignature& op_signature) const override {
    847     const string& input_name = op_signature.op->inputs[0];
    848     const Array& input_array = op_signature.model->GetArray(input_name);
    849     // If the op take int8 input, it is version 2.
    850     if (input_array.data_type == ArrayDataType::kInt8) {
    851       return 2;
    852     }
    853     return 1;
    854   }
    855 };
    856 
    857 class Reshape
    858     : public BuiltinOperator<TensorFlowReshapeOperator,
    859                              ::tflite::ReshapeOptions,
    860                              ::tflite::BuiltinOptions_ReshapeOptions> {
    861  public:
    862   using BuiltinOperator::BuiltinOperator;
    863 
    864   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    865       const TocoOperator& op,
    866       flatbuffers::FlatBufferBuilder* builder) const override {
    867     return ::tflite::CreateReshapeOptions(*builder,
    868                                           builder->CreateVector(op.shape));
    869   }
    870 
    871   void ReadOptions(const TfLiteOptions& options,
    872                    TocoOperator* op) const override {
    873     op->shape.insert(op->shape.end(), options.new_shape()->begin(),
    874                      options.new_shape()->end());
    875   }
    876 
    877   int GetVersion(const OperatorSignature& op_signature) const override {
    878     return 1;
    879   }
    880 };
    881 
    882 class Softmax
    883     : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions,
    884                              ::tflite::BuiltinOptions_SoftmaxOptions> {
    885  public:
    886   using BuiltinOperator::BuiltinOperator;
    887   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    888       const TocoOperator& op,
    889       flatbuffers::FlatBufferBuilder* builder) const override {
    890     return ::tflite::CreateSoftmaxOptions(*builder, op.beta);
    891   }
    892 
    893   void ReadOptions(const TfLiteOptions& options,
    894                    TocoOperator* op) const override {
    895     op->beta = options.beta();
    896   }
    897 
    898   int GetVersion(const OperatorSignature& op_signature) const override {
    899     const string& input_name = op_signature.op->inputs[0];
    900     const Array& input_array = op_signature.model->GetArray(input_name);
    901     if (input_array.data_type == ArrayDataType::kInt8) {
    902       return 2;
    903     }
    904     return 1;
    905   }
    906 };
    907 
    908 class SpaceToDepth
    909     : public BuiltinOperator<SpaceToDepthOperator,
    910                              ::tflite::SpaceToDepthOptions,
    911                              ::tflite::BuiltinOptions_SpaceToDepthOptions> {
    912  public:
    913   using BuiltinOperator::BuiltinOperator;
    914   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    915       const TocoOperator& op,
    916       flatbuffers::FlatBufferBuilder* builder) const override {
    917     return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size);
    918   }
    919 
    920   void ReadOptions(const TfLiteOptions& options,
    921                    TocoOperator* op) const override {
    922     op->block_size = options.block_size();
    923   }
    924 
    925   int GetVersion(const OperatorSignature& op_signature) const override {
    926     const string& input_name = op_signature.op->inputs[0];
    927     const Array& input_array = op_signature.model->GetArray(input_name);
    928     // If the op take int8 input, it is version 2.
    929     if (input_array.data_type == ArrayDataType::kInt8) {
    930       return 2;
    931     }
    932     return 1;
    933   }
    934 };
    935 
    936 class Transpose
    937     : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
    938                              ::tflite::BuiltinOptions_TransposeOptions> {
    939  public:
    940   using BuiltinOperator::BuiltinOperator;
    941   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    942       const TocoOperator& op,
    943       flatbuffers::FlatBufferBuilder* builder) const override {
    944     return ::tflite::CreateTransposeOptions(*builder);
    945   }
    946 
    947   void ReadOptions(const TfLiteOptions& options,
    948                    TocoOperator* op) const override {}
    949 
    950   int GetVersion(const OperatorSignature& op_signature) const override {
    951     const string& input_name = op_signature.op->inputs[0];
    952     const Array& input_array = op_signature.model->GetArray(input_name);
    953     // If the op take int8 input, it is version 2.
    954     if (input_array.data_type == ArrayDataType::kInt8) {
    955       return 2;
    956     }
    957     return 1;
    958   }
    959 };
    960 
    961 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
    962                                     ::tflite::BuiltinOptions_LSTMOptions> {
    963  public:
    964   using BuiltinOperator::BuiltinOperator;
    965   flatbuffers::Offset<TfLiteOptions> WriteOptions(
    966       const TocoOperator& op,
    967       flatbuffers::FlatBufferBuilder* builder) const override {
    968     ::tflite::LSTMKernelType kernel_type = ::tflite::LSTMKernelType_FULL;
    969     switch (op.kernel_type) {
    970       case LstmCellOperator::KERNEL_BASIC:
    971         kernel_type = ::tflite::LSTMKernelType_BASIC;
    972         break;
    973       case LstmCellOperator::KERNEL_FULL:
    974         kernel_type = ::tflite::LSTMKernelType_FULL;
    975         break;
    976       default:
    977         return -1;
    978     }
    979 
    980     // Current toco converter only supports tanh, no clip.
    981     return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/
    982                                        ::tflite::ActivationFunctionType_TANH,
    983                                        /*cell_clip=*/0.0,
    984                                        /*proj_clip=*/0.0, kernel_type);
    985   }
    986 
    987   void ReadOptions(const TfLiteOptions& options,
    988                    TocoOperator* op) const override {
    989     // Only support tanh activation, so check that tflite type is tanh.
    990     CHECK(options.fused_activation_function() ==
    991           ::tflite::ActivationFunctionType_TANH);
    992 
    993     switch (options.kernel_type()) {
    994       case ::tflite::LSTMKernelType_BASIC:
    995         op->kernel_type = LstmCellOperator::KERNEL_BASIC;
    996         break;
    997       case ::tflite::LSTMKernelType_FULL:
    998         op->kernel_type = LstmCellOperator::KERNEL_FULL;
    999         break;
   1000     }
   1001   }
   1002 
   1003   int GetVersion(const OperatorSignature& op_signature) const override {
   1004     const auto& lstm_op =
   1005         static_cast<const LstmCellOperator&>(*op_signature.op);
   1006     switch (lstm_op.kernel_type) {
   1007       case LstmCellOperator::KERNEL_FULL: {
   1008         // If the input tensor is float and a weight is int8, this is a version
   1009         // 3 hybrid operation.
   1010         const string& input_name = op_signature.op->inputs[0];
   1011         const string& weights_name = op_signature.op->inputs[2];
   1012         const string& output_name = op_signature.op->outputs[0];
   1013         const Array& input_array = op_signature.model->GetArray(input_name);
   1014         const Array& weights_array = op_signature.model->GetArray(weights_name);
   1015         const Array& output_array = op_signature.model->GetArray(output_name);
   1016         if (input_array.data_type == ArrayDataType::kFloat &&
   1017             weights_array.data_type == ArrayDataType::kInt8 &&
   1018             output_array.data_type == ArrayDataType::kFloat) {
   1019           return 3;
   1020         }
   1021         return 1;
   1022       }
   1023       case LstmCellOperator::KERNEL_BASIC:
   1024         // KERNEL_BASIC was added in version 2.
   1025         return 2;
   1026     }
   1027   }
   1028 
   1029   std::vector<bool> GetMutatingInputVariables(
   1030       const Operator& op) const override {
   1031     const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
   1032 
   1033     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
   1034     switch (lstm_op.kernel_type) {
   1035       case LstmCellOperator::KERNEL_FULL: {
   1036         mutating_input_variables[kInputActivationStateTensor] = true;
   1037         mutating_input_variables[kInputCellStateTensor] = true;
   1038         break;
   1039       }
   1040       case LstmCellOperator::KERNEL_BASIC: {
   1041         mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
   1042         mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
   1043         break;
   1044       }
   1045     }
   1046     return mutating_input_variables;
   1047   }
   1048 };
   1049 
   1050 class UnidirectionalSequenceLstm
   1051     : public BuiltinOperator<
   1052           UnidirectionalSequenceLstmOperator,
   1053           ::tflite::UnidirectionalSequenceLSTMOptions,
   1054           ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> {
   1055  public:
   1056   using BuiltinOperator::BuiltinOperator;
   1057   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1058       const TocoOperator& op,
   1059       flatbuffers::FlatBufferBuilder* builder) const override {
   1060     // Current toco converter only supports tanh, no clip.
   1061     return ::tflite::CreateUnidirectionalSequenceLSTMOptions(
   1062         *builder, /*fused_activation_function=*/
   1063         ::tflite::ActivationFunctionType_TANH,
   1064         /*cell_clip=*/0.0,
   1065         /*proj_clip=*/0.0,
   1066         /*time_major=*/true);
   1067   }
   1068 
   1069   void ReadOptions(const TfLiteOptions& options,
   1070                    TocoOperator* op) const override {
   1071     // Only support tanh activation, so check that tflite type is tanh.
   1072     DCHECK(options.fused_activation_function() ==
   1073            ::tflite::ActivationFunctionType_TANH);
   1074   }
   1075 
   1076   int GetVersion(const OperatorSignature& op_signature) const override {
   1077     // If the input tensor is float and a weight is int8, this is a version
   1078     // 2 hybrid operation.
   1079     const string& input_name = op_signature.op->inputs[0];
   1080     const string& weights_name = op_signature.op->inputs[2];
   1081     const string& output_name = op_signature.op->outputs[0];
   1082     const Array& input_array = op_signature.model->GetArray(input_name);
   1083     const Array& weights_array = op_signature.model->GetArray(weights_name);
   1084     const Array& output_array = op_signature.model->GetArray(output_name);
   1085     if (input_array.data_type == ArrayDataType::kFloat &&
   1086         weights_array.data_type == ArrayDataType::kInt8 &&
   1087         output_array.data_type == ArrayDataType::kFloat) {
   1088       return 2;
   1089     }
   1090     return 1;
   1091   }
   1092 
   1093   std::vector<bool> GetMutatingInputVariables(
   1094       const Operator& op) const override {
   1095     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
   1096     mutating_input_variables[kInputActivationStateTensor] = true;
   1097     mutating_input_variables[kInputCellStateTensor] = true;
   1098     return mutating_input_variables;
   1099   }
   1100 };
   1101 
   1102 class BidirectionalSequenceLstm
   1103     : public BuiltinOperator<
   1104           BidirectionalSequenceLstmOperator,
   1105           ::tflite::BidirectionalSequenceLSTMOptions,
   1106           ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> {
   1107  public:
   1108   using BuiltinOperator::BuiltinOperator;
   1109   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1110       const TocoOperator& op,
   1111       flatbuffers::FlatBufferBuilder* builder) const override {
   1112     // Current toco converter only supports tanh, no clip.
   1113     return ::tflite::CreateBidirectionalSequenceLSTMOptions(
   1114         *builder, /*fused_activation_function=*/
   1115         ::tflite::ActivationFunctionType_TANH,
   1116         /*cell_clip=*/0.0,
   1117         /*proj_clip=*/0.0,
   1118         /*merge_outputs=*/op.merge_outputs,
   1119         /*time_major=*/true);
   1120   }
   1121 
   1122   void ReadOptions(const TfLiteOptions& options,
   1123                    TocoOperator* op) const override {
   1124     // Only support tanh activation, so check that tflite type is tanh.
   1125     DCHECK(options.fused_activation_function() ==
   1126            ::tflite::ActivationFunctionType_TANH);
   1127     op->merge_outputs = options.merge_outputs();
   1128   }
   1129 
   1130   int GetVersion(const OperatorSignature& op_signature) const override {
   1131     return 1;
   1132   }
   1133 
   1134   std::vector<bool> GetMutatingInputVariables(
   1135       const Operator& op) const override {
   1136     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
   1137     // Forward input activation state.
   1138     mutating_input_variables[35] = true;
   1139     // Forward input cell state.
   1140     mutating_input_variables[36] = true;
   1141     // Backward input activation state.
   1142     mutating_input_variables[37] = true;
   1143     // Backward input cell state.
   1144     mutating_input_variables[38] = true;
   1145     return mutating_input_variables;
   1146   }
   1147 };
   1148 
   1149 class BidirectionalSequenceRnn
   1150     : public BuiltinOperator<
   1151           BidirectionalSequenceRnnOperator,
   1152           ::tflite::BidirectionalSequenceRNNOptions,
   1153           ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> {
   1154  public:
   1155   using BuiltinOperator::BuiltinOperator;
   1156   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1157       const TocoOperator& op,
   1158       flatbuffers::FlatBufferBuilder* builder) const override {
   1159     // Current toco converter only supports tanh, no clip.
   1160     return ::tflite::CreateBidirectionalSequenceRNNOptions(
   1161         *builder, /*time_major=*/true,
   1162         /*fused_activation_function=*/
   1163         ::tflite::ActivationFunctionType_TANH,
   1164         /*merge_outputs=*/op.merge_outputs);
   1165   }
   1166 
   1167   void ReadOptions(const TfLiteOptions& options,
   1168                    TocoOperator* op) const override {
   1169     // Only support tanh activation, so check that tflite type is tanh.
   1170     DCHECK(options.fused_activation_function() ==
   1171            ::tflite::ActivationFunctionType_TANH);
   1172     op->merge_outputs = options.merge_outputs();
   1173   }
   1174 
   1175   int GetVersion(const OperatorSignature& op_signature) const override {
   1176     return 1;
   1177   }
   1178 
   1179   std::vector<bool> GetMutatingInputVariables(
   1180       const Operator& op) const override {
   1181     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
   1182     // Forward hidden state.
   1183     mutating_input_variables[4] = true;
   1184     // Backward hidden state.
   1185     mutating_input_variables[8] = true;
   1186     return mutating_input_variables;
   1187   }
   1188 };
   1189 
   1190 class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions,
   1191                                     ::tflite::BuiltinOptions_ReducerOptions> {
   1192  public:
   1193   using BuiltinOperator::BuiltinOperator;
   1194   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1195       const TocoOperator& op,
   1196       flatbuffers::FlatBufferBuilder* builder) const override {
   1197     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
   1198   }
   1199 
   1200   void ReadOptions(const TfLiteOptions& options,
   1201                    TocoOperator* op) const override {
   1202     op->keep_dims = options.keep_dims();
   1203   }
   1204 
   1205   int GetVersion(const OperatorSignature& op_signature) const override {
   1206     return 1;
   1207   }
   1208 };
   1209 
   1210 class Sum
   1211     : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
   1212                              ::tflite::BuiltinOptions_ReducerOptions> {
   1213  public:
   1214   using BuiltinOperator::BuiltinOperator;
   1215   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1216       const TocoOperator& op,
   1217       flatbuffers::FlatBufferBuilder* builder) const override {
   1218     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
   1219   }
   1220 
   1221   void ReadOptions(const TfLiteOptions& options,
   1222                    TocoOperator* op) const override {
   1223     op->keep_dims = options.keep_dims();
   1224   }
   1225 
   1226   int GetVersion(const OperatorSignature& op_signature) const override {
   1227     return 1;
   1228   }
   1229 };
   1230 
   1231 class ReduceMax
   1232     : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
   1233                              ::tflite::BuiltinOptions_ReducerOptions> {
   1234  public:
   1235   using BuiltinOperator::BuiltinOperator;
   1236   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1237       const TocoOperator& op,
   1238       flatbuffers::FlatBufferBuilder* builder) const override {
   1239     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
   1240   }
   1241 
   1242   void ReadOptions(const TfLiteOptions& options,
   1243                    TocoOperator* op) const override {
   1244     op->keep_dims = options.keep_dims();
   1245   }
   1246 
   1247   int GetVersion(const OperatorSignature& op_signature) const override {
   1248     const string& input_name = op_signature.op->inputs[0];
   1249     const Array& input_array = op_signature.model->GetArray(input_name);
   1250     // If the op take int8 input, it is version 2.
   1251     if (input_array.data_type == ArrayDataType::kInt8) {
   1252       return 2;
   1253     }
   1254     return 1;
   1255   }
   1256 };
   1257 
   1258 class ReduceMin
   1259     : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
   1260                              ::tflite::BuiltinOptions_ReducerOptions> {
   1261  public:
   1262   using BuiltinOperator::BuiltinOperator;
   1263   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1264       const TocoOperator& op,
   1265       flatbuffers::FlatBufferBuilder* builder) const override {
   1266     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
   1267   }
   1268 
   1269   void ReadOptions(const TfLiteOptions& options,
   1270                    TocoOperator* op) const override {
   1271     op->keep_dims = options.keep_dims();
   1272   }
   1273 
   1274   int GetVersion(const OperatorSignature& op_signature) const override {
   1275     const string& input_name = op_signature.op->inputs[0];
   1276     const Array& input_array = op_signature.model->GetArray(input_name);
   1277     // If the op take int8 input, it is version 2.
   1278     if (input_array.data_type == ArrayDataType::kInt8) {
   1279       return 2;
   1280     }
   1281     return 1;
   1282   }
   1283 };
   1284 
   1285 class ReduceProd
   1286     : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
   1287                              ::tflite::BuiltinOptions_ReducerOptions> {
   1288  public:
   1289   using BuiltinOperator::BuiltinOperator;
   1290   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1291       const TocoOperator& op,
   1292       flatbuffers::FlatBufferBuilder* builder) const override {
   1293     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
   1294   }
   1295 
   1296   void ReadOptions(const TfLiteOptions& options,
   1297                    TocoOperator* op) const override {
   1298     op->keep_dims = options.keep_dims();
   1299   }
   1300 
   1301   int GetVersion(const OperatorSignature& op_signature) const override {
   1302     return 1;
   1303   }
   1304 };
   1305 
   1306 class ReduceAny
   1307     : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
   1308                              ::tflite::BuiltinOptions_ReducerOptions> {
   1309  public:
   1310   using BuiltinOperator::BuiltinOperator;
   1311   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1312       const TocoOperator& op,
   1313       flatbuffers::FlatBufferBuilder* builder) const override {
   1314     return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
   1315   }
   1316 
   1317   void ReadOptions(const TfLiteOptions& options,
   1318                    TocoOperator* op) const override {
   1319     op->keep_dims = options.keep_dims();
   1320   }
   1321 
   1322   int GetVersion(const OperatorSignature& op_signature) const override {
   1323     return 1;
   1324   }
   1325 };
   1326 
   1327 class Relu6 : public SimpleOperator<Relu6Operator> {
   1328  public:
   1329   explicit Relu6() : SimpleOperator("RELU6", OperatorType::kRelu6) {}
   1330   int GetVersion(const OperatorSignature& op_signature) const override {
   1331     const string& input_name = op_signature.op->inputs[0];
   1332     const Array& input_array = op_signature.model->GetArray(input_name);
   1333     // Version 2 supports signed int8 input types.
   1334     if (input_array.data_type == ArrayDataType::kInt8) {
   1335       return 2;
   1336     }
   1337     return 1;
   1338   }
   1339 };
   1340 
   1341 class ResizeBilinear
   1342     : public BuiltinOperator<ResizeBilinearOperator,
   1343                              ::tflite::ResizeBilinearOptions,
   1344                              ::tflite::BuiltinOptions_ResizeBilinearOptions> {
   1345  public:
   1346   using BuiltinOperator::BuiltinOperator;
   1347   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1348       const TocoOperator& op,
   1349       flatbuffers::FlatBufferBuilder* builder) const override {
   1350     return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners);
   1351   }
   1352 
   1353   void ReadOptions(const TfLiteOptions& options,
   1354                    TocoOperator* op) const override {
   1355     op->align_corners = options.align_corners();
   1356   }
   1357 
   1358   int GetVersion(const OperatorSignature& op_signature) const override {
   1359     const string& input_name = op_signature.op->inputs[0];
   1360     const Array& input_array = op_signature.model->GetArray(input_name);
   1361     // If the op takes int8 input, it is version 2.
   1362     if (input_array.data_type == ArrayDataType::kInt8) {
   1363       return 2;
   1364     }
   1365     return 1;
   1366   }
   1367 };
   1368 
   1369 class ResizeNearestNeighbor
   1370     : public BuiltinOperator<
   1371           ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions,
   1372           ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> {
   1373  public:
   1374   using BuiltinOperator::BuiltinOperator;
   1375   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1376       const TocoOperator& op,
   1377       flatbuffers::FlatBufferBuilder* builder) const override {
   1378     return ::tflite::CreateResizeNearestNeighborOptions(*builder,
   1379                                                         op.align_corners);
   1380   }
   1381 
   1382   void ReadOptions(const TfLiteOptions& options,
   1383                    TocoOperator* op) const override {
   1384     op->align_corners = options.align_corners();
   1385   }
   1386 
   1387   int GetVersion(const OperatorSignature& op_signature) const override {
   1388     const string& input_name = op_signature.op->inputs[0];
   1389     const Array& input_array = op_signature.model->GetArray(input_name);
   1390     // Version 2 supports signed int8 input types.
   1391     if (input_array.data_type == ArrayDataType::kInt8) {
   1392       return 2;
   1393     }
   1394     return 1;
   1395   }
   1396 };
   1397 
   1398 class Squeeze
   1399     : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
   1400                              ::tflite::BuiltinOptions_SqueezeOptions> {
   1401  public:
   1402   using BuiltinOperator::BuiltinOperator;
   1403 
   1404   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1405       const TocoOperator& op,
   1406       flatbuffers::FlatBufferBuilder* builder) const override {
   1407     auto squeeze_dims = builder->CreateVector(op.squeeze_dims);
   1408     return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims);
   1409   }
   1410 
   1411   void ReadOptions(const TfLiteOptions& options,
   1412                    TocoOperator* op) const override {
   1413     op->squeeze_dims.insert(op->squeeze_dims.end(),
   1414                             options.squeeze_dims()->begin(),
   1415                             options.squeeze_dims()->end());
   1416   }
   1417 
   1418   int GetVersion(const OperatorSignature& op_signature) const override {
   1419     return 1;
   1420   }
   1421 };
   1422 
   1423 class Split
   1424     : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions,
   1425                              ::tflite::BuiltinOptions_SplitOptions> {
   1426  public:
   1427   using BuiltinOperator::BuiltinOperator;
   1428 
   1429   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1430       const TocoOperator& op,
   1431       flatbuffers::FlatBufferBuilder* builder) const override {
   1432     return ::tflite::CreateSplitOptions(*builder, op.num_split);
   1433   }
   1434 
   1435   void ReadOptions(const TfLiteOptions& options,
   1436                    TocoOperator* op) const override {
   1437     op->num_split = options.num_splits();
   1438   }
   1439 
   1440   int GetVersion(const OperatorSignature& op_signature) const override {
   1441     const string& input_name = op_signature.op->inputs[0];
   1442     const Array& input_array = op_signature.model->GetArray(input_name);
   1443     // If the op take int8 input, it is version 2, for int32 it's version 3.
   1444     if (input_array.data_type == ArrayDataType::kInt8) {
   1445       return 2;
   1446     } else if (input_array.data_type == ArrayDataType::kInt32) {
   1447       return 3;
   1448     }
   1449     return 1;
   1450   }
   1451 };
   1452 
   1453 class SplitV
   1454     : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions,
   1455                              ::tflite::BuiltinOptions_SplitVOptions> {
   1456  public:
   1457   using BuiltinOperator::BuiltinOperator;
   1458 
   1459   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1460       const TocoOperator& op,
   1461       flatbuffers::FlatBufferBuilder* builder) const override {
   1462     return ::tflite::CreateSplitVOptions(*builder, op.num_split);
   1463   }
   1464 
   1465   void ReadOptions(const TfLiteOptions& options,
   1466                    TocoOperator* op) const override {
   1467     op->num_split = options.num_splits();
   1468   }
   1469 
   1470   int GetVersion(const OperatorSignature& op_signature) const override {
   1471     return 1;
   1472   }
   1473 };
   1474 
   1475 class StridedSlice
   1476     : public BuiltinOperator<StridedSliceOperator,
   1477                              ::tflite::StridedSliceOptions,
   1478                              ::tflite::BuiltinOptions_StridedSliceOptions> {
   1479  public:
   1480   using BuiltinOperator::BuiltinOperator;
   1481   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1482       const TocoOperator& op,
   1483       flatbuffers::FlatBufferBuilder* builder) const override {
   1484     return ::tflite::CreateStridedSliceOptions(
   1485         *builder, op.begin_mask, op.end_mask, op.ellipsis_mask,
   1486         op.new_axis_mask, op.shrink_axis_mask);
   1487   }
   1488 
   1489   void ReadOptions(const TfLiteOptions& options,
   1490                    TocoOperator* op) const override {
   1491     op->begin_mask = options.begin_mask();
   1492     op->end_mask = options.end_mask();
   1493     op->ellipsis_mask = options.ellipsis_mask();
   1494     op->new_axis_mask = options.new_axis_mask();
   1495     op->shrink_axis_mask = options.shrink_axis_mask();
   1496   }
   1497 
   1498   int GetVersion(const OperatorSignature& op_signature) const override {
   1499     const string& input_name = op_signature.op->inputs[0];
   1500     const Array& input_array = op_signature.model->GetArray(input_name);
   1501     // If the op take int8 input, it is version 2.
   1502     if (input_array.data_type == ArrayDataType::kInt8) {
   1503       return 2;
   1504     }
   1505     return 1;
   1506   }
   1507 };
   1508 
   1509 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options,
   1510                                        ::tflite::BuiltinOptions_TopKV2Options> {
   1511  public:
   1512   using BuiltinOperator::BuiltinOperator;
   1513   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1514       const TocoOperator& op,
   1515       flatbuffers::FlatBufferBuilder* builder) const override {
   1516     return ::tflite::CreateTopKV2Options(*builder);
   1517   }
   1518 
   1519   void ReadOptions(const TfLiteOptions& options,
   1520                    TocoOperator* op) const override {}
   1521 
   1522   int GetVersion(const OperatorSignature& op_signature) const override {
   1523     const string& input_name = op_signature.op->inputs[0];
   1524     const Array& input_array = op_signature.model->GetArray(input_name);
   1525     if (input_array.data_type == ArrayDataType::kInt8) {
   1526       return 2;
   1527     }
   1528     return 1;
   1529   }
   1530 };
   1531 
   1532 class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions,
   1533                                       ::tflite::BuiltinOptions_ArgMaxOptions> {
   1534  public:
   1535   using BuiltinOperator::BuiltinOperator;
   1536   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1537       const TocoOperator& op,
   1538       flatbuffers::FlatBufferBuilder* builder) const override {
   1539     return ::tflite::CreateArgMaxOptions(
   1540         *builder, DataType::Serialize(op.output_data_type));
   1541   }
   1542 
   1543   void ReadOptions(const TfLiteOptions& options,
   1544                    TocoOperator* op) const override {
   1545     op->output_data_type = DataType::Deserialize(options.output_type());
   1546   }
   1547 
   1548   int GetVersion(const OperatorSignature& op_signature) const override {
   1549     const string& input_name = op_signature.op->inputs[0];
   1550     const Array& input_array = op_signature.model->GetArray(input_name);
   1551     if (input_array.data_type == ArrayDataType::kInt8) {
   1552       return 2;
   1553     }
   1554 
   1555     return 1;
   1556   }
   1557 };
   1558 
   1559 class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions,
   1560                                       ::tflite::BuiltinOptions_ArgMinOptions> {
   1561  public:
   1562   using BuiltinOperator::BuiltinOperator;
   1563   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1564       const TocoOperator& op,
   1565       flatbuffers::FlatBufferBuilder* builder) const override {
   1566     return ::tflite::CreateArgMinOptions(
   1567         *builder, DataType::Serialize(op.output_data_type));
   1568   }
   1569 
   1570   void ReadOptions(const TfLiteOptions& options,
   1571                    TocoOperator* op) const override {
   1572     op->output_data_type = DataType::Deserialize(options.output_type());
   1573   }
   1574 
   1575   int GetVersion(const OperatorSignature& op_signature) const override {
   1576     const string& input_name = op_signature.op->inputs[0];
   1577     const Array& input_array = op_signature.model->GetArray(input_name);
   1578     if (input_array.data_type == ArrayDataType::kInt8) {
   1579       return 2;
   1580     }
   1581 
   1582     return 1;
   1583   }
   1584 };
   1585 
   1586 class TransposeConv
   1587     : public BuiltinOperator<TransposeConvOperator,
   1588                              ::tflite::TransposeConvOptions,
   1589                              ::tflite::BuiltinOptions_TransposeConvOptions> {
   1590  public:
   1591   using BuiltinOperator::BuiltinOperator;
   1592 
   1593   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1594       const TocoOperator& op,
   1595       flatbuffers::FlatBufferBuilder* builder) const override {
   1596     auto padding = Padding::Serialize(op.padding.type);
   1597     return ::tflite::CreateTransposeConvOptions(
   1598         *builder, padding, op.stride_width, op.stride_height);
   1599   }
   1600 
   1601   void ReadOptions(const TfLiteOptions& options,
   1602                    TocoOperator* op) const override {
   1603     op->padding.type = Padding::Deserialize(options.padding());
   1604     op->stride_width = options.stride_w();
   1605     op->stride_height = options.stride_h();
   1606   }
   1607 
   1608   int GetVersion(const OperatorSignature& op_signature) const override {
   1609     return 1;
   1610   }
   1611 };
   1612 
   1613 class SparseToDense
   1614     : public BuiltinOperator<SparseToDenseOperator,
   1615                              ::tflite::SparseToDenseOptions,
   1616                              ::tflite::BuiltinOptions_SparseToDenseOptions> {
   1617  public:
   1618   using BuiltinOperator::BuiltinOperator;
   1619 
   1620   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1621       const TocoOperator& op,
   1622       flatbuffers::FlatBufferBuilder* builder) const override {
   1623     return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices);
   1624   }
   1625 
   1626   void ReadOptions(const TfLiteOptions& options,
   1627                    TocoOperator* op) const override {
   1628     op->validate_indices = options.validate_indices();
   1629   }
   1630 
   1631   int GetVersion(const OperatorSignature& op_signature) const override {
   1632     return 1;
   1633   }
   1634 };
   1635 
   1636 class ExpandDims
   1637     : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions,
   1638                              ::tflite::BuiltinOptions_ExpandDimsOptions> {
   1639  public:
   1640   using BuiltinOperator::BuiltinOperator;
   1641 
   1642   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1643       const TocoOperator& op,
   1644       flatbuffers::FlatBufferBuilder* builder) const override {
   1645     return ::tflite::CreateExpandDimsOptions(*builder);
   1646   }
   1647 
   1648   void ReadOptions(const TfLiteOptions& options,
   1649                    TocoOperator* op) const override {}
   1650 
   1651   int GetVersion(const OperatorSignature& op_signature) const override {
   1652     return 1;
   1653   }
   1654 };
   1655 
   1656 class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions,
   1657                                     ::tflite::BuiltinOptions_PackOptions> {
   1658  public:
   1659   using BuiltinOperator::BuiltinOperator;
   1660 
   1661   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1662       const TocoOperator& op,
   1663       flatbuffers::FlatBufferBuilder* builder) const override {
   1664     return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis);
   1665   }
   1666 
   1667   void ReadOptions(const TfLiteOptions& options,
   1668                    TocoOperator* op) const override {
   1669     op->values_count = options.values_count();
   1670     op->axis = options.axis();
   1671   }
   1672 
   1673   int GetVersion(const OperatorSignature& op_signature) const override {
   1674     const string& input_name = op_signature.op->inputs[0];
   1675     const Array& input_array = op_signature.model->GetArray(input_name);
   1676     // If the op take int8 input, it is version 2.
   1677     if (input_array.data_type == ArrayDataType::kInt8) {
   1678       return 2;
   1679     }
   1680     return 1;
   1681   }
   1682 };
   1683 
   1684 class Shape
   1685     : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions,
   1686                              ::tflite::BuiltinOptions_ShapeOptions> {
   1687  public:
   1688   using BuiltinOperator::BuiltinOperator;
   1689   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1690       const TocoOperator& op,
   1691       flatbuffers::FlatBufferBuilder* builder) const override {
   1692     return ::tflite::CreateShapeOptions(
   1693         *builder, DataType::Serialize(op.output_data_type));
   1694   }
   1695 
   1696   void ReadOptions(const TfLiteOptions& options,
   1697                    TocoOperator* op) const override {
   1698     op->output_data_type = DataType::Deserialize(options.out_type());
   1699   }
   1700 
   1701   int GetVersion(const OperatorSignature& op_signature) const override {
   1702     return 1;
   1703   }
   1704 };
   1705 
   1706 class Slice : public SimpleOperator<SliceOperator> {
   1707  public:
   1708   explicit Slice() : SimpleOperator("SLICE", OperatorType::kSlice) {}
   1709   int GetVersion(const OperatorSignature& op_signature) const override {
   1710     const string& input_name = op_signature.op->inputs[0];
   1711     const Array& input_array = op_signature.model->GetArray(input_name);
   1712     // Version 2 supports signed int8 input types.
   1713     if (input_array.data_type == ArrayDataType::kInt8) {
   1714       return 2;
   1715     }
   1716     return 1;
   1717   }
   1718 };
   1719 
   1720 class Tanh : public SimpleOperator<TanhOperator> {
   1721  public:
   1722   explicit Tanh() : SimpleOperator("TANH", OperatorType::kTanh) {}
   1723   int GetVersion(const OperatorSignature& op_signature) const override {
   1724     const string& input_name = op_signature.op->inputs[0];
   1725     const Array& input_array = op_signature.model->GetArray(input_name);
   1726     // Version 2 supports signed int8 input types.
   1727     if (input_array.data_type == ArrayDataType::kInt8) {
   1728       return 2;
   1729     }
   1730     return 1;
   1731   }
   1732 };
   1733 
   1734 class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
   1735                                       ::tflite::BuiltinOptions_OneHotOptions> {
   1736  public:
   1737   using BuiltinOperator::BuiltinOperator;
   1738   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1739       const TocoOperator& op,
   1740       flatbuffers::FlatBufferBuilder* builder) const override {
   1741     return ::tflite::CreateOneHotOptions(*builder, op.axis);
   1742   }
   1743   void ReadOptions(const TfLiteOptions& options,
   1744                    TocoOperator* op) const override {
   1745     op->axis = options.axis();
   1746   }
   1747 
   1748   int GetVersion(const OperatorSignature& op_signature) const override {
   1749     return 1;
   1750   }
   1751 };
   1752 
   1753 class CTCBeamSearchDecoder
   1754     : public CustomOperator<CTCBeamSearchDecoderOperator> {
   1755  public:
   1756   using CustomOperator::CustomOperator;
   1757 
   1758   void WriteOptions(const TocoOperator& op,
   1759                     flexbuffers::Builder* fbb) const override {
   1760     fbb->Int("beam_width", op.beam_width);
   1761     fbb->Int("top_paths", op.top_paths);
   1762     fbb->Bool("merge_repeated", op.merge_repeated);
   1763   }
   1764 
   1765   void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override {
   1766     op->beam_width = m["beam_width"].AsInt32();
   1767     op->top_paths = m["top_paths"].AsInt32();
   1768     op->merge_repeated = m["merge_repeated"].AsBool();
   1769   }
   1770 
   1771   int GetVersion(const OperatorSignature& op_signature) const override {
   1772     return 1;
   1773   }
   1774 };
   1775 
   1776 class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
   1777                                       ::tflite::BuiltinOptions_UnpackOptions> {
   1778  public:
   1779   using BuiltinOperator::BuiltinOperator;
   1780   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1781       const TocoOperator& op,
   1782       flatbuffers::FlatBufferBuilder* builder) const override {
   1783     return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
   1784   }
   1785   void ReadOptions(const TfLiteOptions& options,
   1786                    TocoOperator* op) const override {
   1787     op->num = options.num();
   1788     op->axis = options.axis();
   1789   }
   1790 
   1791   int GetVersion(const OperatorSignature& op_signature) const override {
   1792     return 1;
   1793   }
   1794 };
   1795 
   1796 class LeakyRelu
   1797     : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions,
   1798                              ::tflite::BuiltinOptions_LeakyReluOptions> {
   1799  public:
   1800   using BuiltinOperator::BuiltinOperator;
   1801   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1802       const TocoOperator& op,
   1803       flatbuffers::FlatBufferBuilder* builder) const override {
   1804     return ::tflite::CreateLeakyReluOptions(*builder, op.alpha);
   1805   }
   1806   void ReadOptions(const TfLiteOptions& options,
   1807                    TocoOperator* op) const override {
   1808     op->alpha = options.alpha();
   1809   }
   1810 
   1811   int GetVersion(const OperatorSignature& op_signature) const override {
   1812     return 1;
   1813   }
   1814 };
   1815 
   1816 class Logistic : public SimpleOperator<LogisticOperator> {
   1817  public:
   1818   explicit Logistic() : SimpleOperator("LOGISTIC", OperatorType::kLogistic) {}
   1819   int GetVersion(const OperatorSignature& op_signature) const override {
   1820     const string& input_name = op_signature.op->inputs[0];
   1821     const Array& input_array = op_signature.model->GetArray(input_name);
   1822     // Version 2 supports signed int8 input types.
   1823     if (input_array.data_type == ArrayDataType::kInt8) {
   1824       return 2;
   1825     }
   1826     return 1;
   1827   }
   1828 };
   1829 
   1830 class LogSoftmax : public SimpleOperator<LogSoftmaxOperator> {
   1831  public:
   1832   explicit LogSoftmax()
   1833       : SimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax) {}
   1834   int GetVersion(const OperatorSignature& op_signature) const override {
   1835     const string& input_name = op_signature.op->inputs[0];
   1836     const Array& input_array = op_signature.model->GetArray(input_name);
   1837     // Version 2 supports signed int8 input types.
   1838     if (input_array.data_type == ArrayDataType::kInt8) {
   1839       return 2;
   1840     }
   1841     return 1;
   1842   }
   1843 };
   1844 
   1845 class SquaredDifference
   1846     : public BuiltinOperator<
   1847           SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions,
   1848           ::tflite::BuiltinOptions_SquaredDifferenceOptions> {
   1849  public:
   1850   using BuiltinOperator::BuiltinOperator;
   1851 
   1852   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1853       const TocoOperator& op,
   1854       flatbuffers::FlatBufferBuilder* builder) const override {
   1855     return ::tflite::CreateSquaredDifferenceOptions(*builder);
   1856   }
   1857 
   1858   void ReadOptions(const TfLiteOptions& options,
   1859                    TocoOperator* op) const override {}
   1860 
   1861   int GetVersion(const OperatorSignature& op_signature) const override {
   1862     return 1;
   1863   }
   1864 };
   1865 
   1866 class MirrorPad
   1867     : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions,
   1868                              ::tflite::BuiltinOptions_MirrorPadOptions> {
   1869  public:
   1870   using BuiltinOperator::BuiltinOperator;
   1871   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1872       const TocoOperator& op,
   1873       flatbuffers::FlatBufferBuilder* builder) const override {
   1874     return ::tflite::CreateMirrorPadOptions(
   1875         *builder, op.mode == MirrorPadMode::kReflect
   1876                       ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
   1877                       : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC);
   1878   }
   1879   void ReadOptions(const TfLiteOptions& options,
   1880                    TocoOperator* op) const override {
   1881     op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT
   1882                    ? MirrorPadMode::kReflect
   1883                    : MirrorPadMode::kSymmetric;
   1884   }
   1885 
   1886   int GetVersion(const OperatorSignature& op) const override { return 1; }
   1887 };
   1888 
   1889 class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions,
   1890                                       ::tflite::BuiltinOptions_UniqueOptions> {
   1891  public:
   1892   using BuiltinOperator::BuiltinOperator;
   1893   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1894       const TocoOperator& op,
   1895       flatbuffers::FlatBufferBuilder* builder) const override {
   1896     const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op);
   1897     return ::tflite::CreateUniqueOptions(
   1898         *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64
   1899                       ? ::tflite::TensorType::TensorType_INT64
   1900                       : ::tflite::TensorType_INT32);
   1901   }
   1902   void ReadOptions(const TfLiteOptions& options,
   1903                    TocoOperator* op) const override {
   1904     UniqueOperator* unique_op = static_cast<UniqueOperator*>(op);
   1905     unique_op->idx_out_type =
   1906         options.idx_out_type() == ::tflite::TensorType_INT64
   1907             ? toco::ArrayDataType::kInt64
   1908             : toco::ArrayDataType::kInt32;
   1909   }
   1910 
   1911   int GetVersion(const OperatorSignature& op_signature) const override {
   1912     return 1;
   1913   }
   1914 };
   1915 
   1916 class UnidirectionalSequenceRnn
   1917     : public BuiltinOperator<UnidirectionalSequenceRnnOperator,
   1918                              ::tflite::SequenceRNNOptions,
   1919                              ::tflite::BuiltinOptions_SequenceRNNOptions> {
   1920  public:
   1921   using BuiltinOperator::BuiltinOperator;
   1922   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1923       const TocoOperator& op,
   1924       flatbuffers::FlatBufferBuilder* builder) const override {
   1925     return ::tflite::CreateSequenceRNNOptions(
   1926         *builder, /*time_major=*/true,
   1927         /*fused_activation_function=*/
   1928         ::tflite::ActivationFunctionType_TANH);
   1929   }
   1930   void ReadOptions(const TfLiteOptions& options,
   1931                    TocoOperator* op) const override {
   1932     // Only support tanh activation, so check that tflite type is tanh.
   1933     DCHECK(options.fused_activation_function() ==
   1934            ::tflite::ActivationFunctionType_TANH);
   1935   }
   1936 
   1937   int GetVersion(const OperatorSignature& op_signature) const override {
   1938     return 1;
   1939   }
   1940 
   1941   std::vector<bool> GetMutatingInputVariables(
   1942       const Operator& op) const override {
   1943     std::vector<bool> mutating_input_variables(op.inputs.size(), false);
   1944     mutating_input_variables[4] = true;
   1945     return mutating_input_variables;
   1946   }
   1947 };
   1948 
   1949 class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions,
   1950                                      ::tflite::BuiltinOptions_WhereOptions> {
   1951  public:
   1952   using BuiltinOperator::BuiltinOperator;
   1953 
   1954   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   1955       const TocoOperator& op,
   1956       flatbuffers::FlatBufferBuilder* builder) const override {
   1957     return ::tflite::CreateWhereOptions(*builder);
   1958   }
   1959 
   1960   void ReadOptions(const TfLiteOptions& options,
   1961                    TocoOperator* op) const override {}
   1962 
   1963   int GetVersion(const OperatorSignature& op_signature) const override {
   1964     return 1;
   1965   }
   1966 };
   1967 
   1968 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
   1969     const string& tensorflow_node_def) {
   1970   auto fbb = absl::make_unique<flexbuffers::Builder>();
   1971 
   1972   ::tensorflow::NodeDef node_def;
   1973   if (!node_def.ParseFromString(tensorflow_node_def)) {
   1974     LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
   1975     return {};
   1976   }
   1977 
   1978   fbb->Vector([&]() {
   1979     fbb->String(node_def.op());
   1980     fbb->String(tensorflow_node_def);
   1981   });
   1982   fbb->Finish();
   1983   LOG(INFO) << "Writing flex op: " << node_def.op();
   1984   return std::unique_ptr<flexbuffers::Builder>(fbb.release());
   1985 }
   1986 
   1987 class TensorFlowUnsupported : public BaseOperator {
   1988  public:
   1989   TensorFlowUnsupported(const string& name, OperatorType type,
   1990                         bool enable_select_tf_ops)
   1991       : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {}
   1992 
   1993   Options Serialize(const Operator& op,
   1994                     flatbuffers::FlatBufferBuilder* builder) const override {
   1995     auto fbb =
   1996         WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op));
   1997     if (fbb) {
   1998       return Options::Custom(builder->CreateVector(fbb->GetBuffer()));
   1999     } else {
   2000       return Options::Custom(0);
   2001     }
   2002   }
   2003 
   2004   std::unique_ptr<Operator> Deserialize(
   2005       const BuiltinOptions* builtin_options,
   2006       const CustomOptions* custom_options) const override {
   2007     // Deserializing Flex ops doesn't work now.
   2008     // TODO(ycling): Revisit and decide if we should fix the flow for importing
   2009     // TFLite models with Flex ops.
   2010     auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
   2011     if (custom_options) {
   2012       auto flexbuffer_map =
   2013           flexbuffers::GetRoot(custom_options->data(), custom_options->size())
   2014               .AsMap();
   2015       ReadOptions(flexbuffer_map, op.get());
   2016     }
   2017     return std::unique_ptr<Operator>(op.release());
   2018   }
   2019 
   2020   std::unique_ptr<flexbuffers::Builder> WriteOptions(
   2021       const TensorFlowUnsupportedOperator& op) const {
   2022     if (enable_select_tf_ops_) {
   2023       return WriteFlexOpOptions(op.tensorflow_node_def);
   2024     }
   2025     auto fbb = absl::make_unique<flexbuffers::Builder>();
   2026 
   2027     ::tensorflow::NodeDef node_def;
   2028     if (!node_def.ParseFromString(op.tensorflow_node_def)) {
   2029       LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
   2030       return std::unique_ptr<flexbuffers::Builder>();
   2031     }
   2032 
   2033     if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) {
   2034       fbb->Vector([&]() {
   2035         fbb->String(node_def.op());
   2036         fbb->String(op.tensorflow_node_def);
   2037       });
   2038       fbb->Finish();
   2039       LOG(INFO) << "Writing flex op: " << node_def.op();
   2040       return std::unique_ptr<flexbuffers::Builder>(fbb.release());
   2041     }
   2042 
   2043     bool has_valid_attr = false;
   2044     size_t map_start = fbb->StartMap();
   2045     for (const auto& pair : node_def.attr()) {
   2046       const char* key = pair.first.c_str();
   2047       const auto& attr = pair.second;
   2048       switch (attr.value_case()) {
   2049         case ::tensorflow::AttrValue::kS:
   2050           fbb->String(key, attr.s());
   2051           has_valid_attr = true;
   2052           break;
   2053         case ::tensorflow::AttrValue::kI:
   2054           fbb->Int(key, attr.i());
   2055           has_valid_attr = true;
   2056           break;
   2057         case ::tensorflow::AttrValue::kF:
   2058           fbb->Float(key, attr.f());
   2059           has_valid_attr = true;
   2060           break;
   2061         case ::tensorflow::AttrValue::kB:
   2062           fbb->Bool(key, attr.b());
   2063           has_valid_attr = true;
   2064           break;
   2065         case tensorflow::AttrValue::kList:
   2066           if (attr.list().s_size() > 0) {
   2067             auto start = fbb->StartVector(key);
   2068             for (const string& v : attr.list().s()) {
   2069               fbb->Add(v);
   2070             }
   2071             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
   2072             has_valid_attr = true;
   2073           } else if (attr.list().i_size() > 0) {
   2074             auto start = fbb->StartVector(key);
   2075             for (const int64_t v : attr.list().i()) {
   2076               fbb->Add(v);
   2077             }
   2078             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
   2079             has_valid_attr = true;
   2080           } else if (attr.list().f_size() > 0) {
   2081             auto start = fbb->StartVector(key);
   2082             for (const float v : attr.list().f()) {
   2083               fbb->Add(v);
   2084             }
   2085             fbb->EndVector(start, /*typed=*/true, /*fixed=*/false);
   2086             has_valid_attr = true;
   2087           } else {
   2088             LOG(WARNING)
   2089                 << "Ignoring unsupported type in list attribute with key '"
   2090                 << key << "'";
   2091           }
   2092           break;
   2093         default:
   2094           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
   2095                        << key << "'";
   2096           break;
   2097       }
   2098     }
   2099     if (!has_valid_attr) {
   2100       return std::unique_ptr<flexbuffers::Builder>();
   2101     }
   2102     fbb->EndMap(map_start);
   2103     fbb->Finish();
   2104     return std::unique_ptr<flexbuffers::Builder>(fbb.release());
   2105   }
   2106 
   2107   void ReadOptions(const flexbuffers::Map& m,
   2108                    TensorFlowUnsupportedOperator* op) const {
   2109     ::tensorflow::NodeDef node_def;
   2110     auto attr = node_def.mutable_attr();
   2111 
   2112     const auto& keys = m.Keys();
   2113     for (size_t i = 0; i < keys.size(); ++i) {
   2114       const auto key = keys[i].AsKey();
   2115       const auto& value = m[key];
   2116       // TODO(wvo): hack to make this code compile with 2 different API
   2117       // versions.
   2118       // Please remove once OS/internal versions are in sync.
   2119       // See hardcoded values in the switch below.
   2120       switch (value.GetType()) {
   2121         case 5:  // flexbuffers::FBT_STRING:
   2122           (*attr)[key].set_s(value.AsString().c_str());
   2123           break;
   2124         case 1:  // flexbuffers::FBT_INT:
   2125           (*attr)[key].set_i(value.AsInt64());
   2126           break;
   2127         case 3:  // flexbuffers::FBT_FLOAT:
   2128           (*attr)[key].set_f(value.AsFloat());
   2129           break;
   2130         case 26:  // flexbuffers::FBT_BOOL:
   2131           (*attr)[key].set_b(value.AsBool());
   2132           if (string(key) == "_output_quantized") {
   2133             op->quantized = value.AsBool();
   2134           }
   2135           if (string(key) == "_support_output_type_float_in_quantized_op") {
   2136             op->support_output_type_float_in_quantized_op = value.AsBool();
   2137           }
   2138           break;
   2139         case 11: {  // flexbuffers::FBT_VECTOR_INT: {
   2140           auto* list = (*attr)[key].mutable_list();
   2141           const auto& vector = value.AsTypedVector();
   2142           for (size_t i = 0; i < vector.size(); i++) {
   2143             list->add_i(vector[i].AsInt64());
   2144           }
   2145           break;
   2146         }
   2147         case 13: {  // flexbuffers::FBT_VECTOR_FLOAT: {
   2148           auto* list = (*attr)[key].mutable_list();
   2149           const auto& vector = value.AsTypedVector();
   2150           for (size_t i = 0; i < vector.size(); i++) {
   2151             list->add_f(vector[i].AsFloat());
   2152           }
   2153           break;
   2154         }
   2155         case 15: {  // flexbuffers::FBT_VECTOR_STRING: {
   2156           auto* list = (*attr)[key].mutable_list();
   2157           const auto& vector = value.AsTypedVector();
   2158           for (size_t i = 0; i < vector.size(); i++) {
   2159             list->add_s(vector[i].AsString().str());
   2160           }
   2161           break;
   2162         }
   2163         default:
   2164           LOG(WARNING) << "Ignoring unsupported attribute type with key '"
   2165                        << key << "'";
   2166           break;
   2167       }
   2168     }
   2169     node_def.SerializeToString(&op->tensorflow_node_def);
   2170   }
   2171 
   2172   int GetVersion(const OperatorSignature& op_signature) const override {
   2173     // TODO(ycling): Design and implement a way to plumb the version of
   2174     // custom ops.
   2175     return 1;
   2176   }
   2177 
   2178  private:
   2179   const bool enable_select_tf_ops_;
   2180 };
   2181 
   2182 class Dequantize
   2183     : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions,
   2184                              ::tflite::BuiltinOptions_DequantizeOptions> {
   2185  public:
   2186   using BuiltinOperator::BuiltinOperator;
   2187 
   2188   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   2189       const TocoOperator& op,
   2190       flatbuffers::FlatBufferBuilder* builder) const override {
   2191     return ::tflite::CreateDequantizeOptions(*builder);
   2192   }
   2193 
   2194   void ReadOptions(const TfLiteOptions& options,
   2195                    TocoOperator* op) const override {}
   2196 
   2197   int GetVersion(const OperatorSignature& op_signature) const override {
   2198     const string& input_name = op_signature.op->inputs[0];
   2199     const Array& input_array = op_signature.model->GetArray(input_name);
   2200     // Version 2 supports signed int8 input types.
   2201     if (input_array.data_type == ArrayDataType::kInt8) {
   2202       return 2;
   2203     }
   2204     return 1;
   2205   }
   2206 };
   2207 
   2208 class ReverseSequence
   2209     : public BuiltinOperator<ReverseSequenceOperator,
   2210                              ::tflite::ReverseSequenceOptions,
   2211                              ::tflite::BuiltinOptions_ReverseSequenceOptions> {
   2212  public:
   2213   using BuiltinOperator::BuiltinOperator;
   2214 
   2215   flatbuffers::Offset<TfLiteOptions> WriteOptions(
   2216       const TocoOperator& op,
   2217       flatbuffers::FlatBufferBuilder* builder) const override {
   2218     return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim,
   2219                                                   op.batch_dim);
   2220   }
   2221 
   2222   void ReadOptions(const TfLiteOptions& options,
   2223                    TocoOperator* op) const override {
   2224     op->seq_dim = options.seq_dim();
   2225     op->batch_dim = options.batch_dim();
   2226   }
   2227 
   2228   int GetVersion(const OperatorSignature& op_signature) const override {
   2229     return 1;
   2230   }
   2231 };
   2232 
   2233 class Equal : public SimpleOperator<TensorFlowEqualOperator> {
   2234  public:
   2235   explicit Equal() : SimpleOperator("EQUAL", OperatorType::kEqual) {}
   2236   int GetVersion(const OperatorSignature& op_signature) const override {
   2237     const string& input_name = op_signature.op->inputs[0];
   2238     const Array& input_array = op_signature.model->GetArray(input_name);
   2239     // Version 2 supports signed int8 input types.
   2240     if (input_array.data_type == ArrayDataType::kInt8) {
   2241       return 2;
   2242     }
   2243     return 1;
   2244   }
   2245 };
   2246 
   2247 class NotEqual : public SimpleOperator<TensorFlowNotEqualOperator> {
   2248  public:
   2249   explicit NotEqual() : SimpleOperator("NOT_EQUAL", OperatorType::kNotEqual) {}
   2250   int GetVersion(const OperatorSignature& op_signature) const override {
   2251     const string& input_name = op_signature.op->inputs[0];
   2252     const Array& input_array = op_signature.model->GetArray(input_name);
   2253     // Version 2 supports signed int8 input types.
   2254     if (input_array.data_type == ArrayDataType::kInt8) {
   2255       return 2;
   2256     }
   2257     return 1;
   2258   }
   2259 };
   2260 
   2261 class Greater : public SimpleOperator<TensorFlowGreaterOperator> {
   2262  public:
   2263   explicit Greater() : SimpleOperator("GREATER", OperatorType::kGreater) {}
   2264   int GetVersion(const OperatorSignature& op_signature) const override {
   2265     const string& input_name = op_signature.op->inputs[0];
   2266     const Array& input_array = op_signature.model->GetArray(input_name);
   2267     // Version 2 supports signed int8 input types.
   2268     if (input_array.data_type == ArrayDataType::kInt8) {
   2269       return 2;
   2270     }
   2271     return 1;
   2272   }
   2273 };
   2274 
   2275 class GreaterEqual : public SimpleOperator<TensorFlowGreaterEqualOperator> {
   2276  public:
   2277   explicit GreaterEqual()
   2278       : SimpleOperator("GREATER_EQUAL", OperatorType::kGreaterEqual) {}
   2279   int GetVersion(const OperatorSignature& op_signature) const override {
   2280     const string& input_name = op_signature.op->inputs[0];
   2281     const Array& input_array = op_signature.model->GetArray(input_name);
   2282     // Version 2 supports signed int8 input types.
   2283     if (input_array.data_type == ArrayDataType::kInt8) {
   2284       return 2;
   2285     }
   2286     return 1;
   2287   }
   2288 };
   2289 
   2290 class Less : public SimpleOperator<TensorFlowLessOperator> {
   2291  public:
   2292   explicit Less() : SimpleOperator("LESS", OperatorType::kLess) {}
   2293   int GetVersion(const OperatorSignature& op_signature) const override {
   2294     const string& input_name = op_signature.op->inputs[0];
   2295     const Array& input_array = op_signature.model->GetArray(input_name);
   2296     // Version 2 supports signed int8 input types.
   2297     if (input_array.data_type == ArrayDataType::kInt8) {
   2298       return 2;
   2299     }
   2300     return 1;
   2301   }
   2302 };
   2303 
   2304 class LessEqual : public SimpleOperator<TensorFlowLessEqualOperator> {
   2305  public:
   2306   explicit LessEqual()
   2307       : SimpleOperator("LESS_EQUAL", OperatorType::kLessEqual) {}
   2308   int GetVersion(const OperatorSignature& op_signature) const override {
   2309     const string& input_name = op_signature.op->inputs[0];
   2310     const Array& input_array = op_signature.model->GetArray(input_name);
   2311     // Version 2 supports signed int8 input types.
   2312     if (input_array.data_type == ArrayDataType::kInt8) {
   2313       return 2;
   2314     }
   2315     return 1;
   2316   }
   2317 };
   2318 
   2319 class Select : public SimpleOperator<SelectOperator> {
   2320  public:
   2321   explicit Select() : SimpleOperator("SELECT", OperatorType::kSelect) {}
   2322   int GetVersion(const OperatorSignature& op_signature) const override {
   2323     const string& input_name = op_signature.op->inputs[0];
   2324     const Array& input_array = op_signature.model->GetArray(input_name);
   2325     // Version 2 supports signed int8 input types.
   2326     if (input_array.data_type == ArrayDataType::kInt8) {
   2327       return 2;
   2328     }
   2329     return 1;
   2330   }
   2331 };
   2332 
   2333 namespace {
   2334 // Build a vector containing all the known operators.
   2335 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
   2336     bool enable_select_tf_ops = false) {
   2337   std::vector<std::unique_ptr<BaseOperator>> ops;
   2338   using tensorflow::MakeUnique;
   2339   // Builtin Operators.
   2340   ops.push_back(
   2341       MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
   2342   ops.push_back(
   2343       MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
   2344   ops.push_back(
   2345       MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
   2346   ops.push_back(
   2347       MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
   2348   ops.push_back(MakeUnique<AveragePool>(
   2349       ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool));
   2350   ops.push_back(
   2351       MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND,
   2352                                  OperatorType::kSpaceToBatchND));
   2353   ops.push_back(
   2354       MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND,
   2355                                  OperatorType::kBatchToSpaceND));
   2356   ops.push_back(MakeUnique<Concatenation>(
   2357       ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation));
   2358   ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D,
   2359                                         OperatorType::kConv));
   2360   ops.push_back(MakeUnique<DepthwiseConvolution>(
   2361       ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
   2362       OperatorType::kDepthwiseConv));
   2363   ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE,
   2364                                        OperatorType::kDequantize));
   2365   ops.push_back(
   2366       MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED,
   2367                                  OperatorType::kFullyConnected));
   2368   ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER,
   2369                                    OperatorType::kGather));
   2370   ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND,
   2371                                      OperatorType::kGatherNd));
   2372   ops.push_back(
   2373       MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION,
   2374                                   OperatorType::kL2Normalization));
   2375   ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D,
   2376                                    OperatorType::kL2Pool));
   2377   ops.push_back(MakeUnique<LocalResponseNormalization>(
   2378       ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
   2379       OperatorType::kLocalResponseNormalization));
   2380   ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D,
   2381                                     OperatorType::kMaxPool));
   2382   ops.push_back(
   2383       MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul));
   2384 
   2385   ops.push_back(
   2386       MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad));
   2387   ops.push_back(
   2388       MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2));
   2389   ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE,
   2390                                     OperatorType::kReshape));
   2391   ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX,
   2392                                     OperatorType::kSoftmax));
   2393   ops.push_back(MakeUnique<SpaceToDepth>(
   2394       ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth));
   2395   ops.push_back(
   2396       MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
   2397   ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE,
   2398                                       OperatorType::kTranspose));
   2399   ops.push_back(
   2400       MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
   2401   ops.push_back(
   2402       MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum));
   2403   ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD,
   2404                                        OperatorType::kReduceProd));
   2405   ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX,
   2406                                       OperatorType::kReduceMax));
   2407   ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
   2408                                       OperatorType::kReduceMin));
   2409   ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
   2410                                       OperatorType::kAny));
   2411   ops.push_back(
   2412       MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
   2413                                  OperatorType::kResizeBilinear));
   2414   ops.push_back(MakeUnique<ResizeNearestNeighbor>(
   2415       ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
   2416       OperatorType::kResizeNearestNeighbor));
   2417   ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE,
   2418                                     OperatorType::kSqueeze));
   2419   ops.push_back(
   2420       MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit));
   2421   ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V,
   2422                                    OperatorType::kSplitV));
   2423   ops.push_back(MakeUnique<StridedSlice>(
   2424       ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice));
   2425   ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2,
   2426                                     OperatorType::kTopK_V2));
   2427   ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM,
   2428                                  OperatorType::kLstmCell));
   2429   ops.push_back(
   2430       MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast));
   2431   ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX,
   2432                                    OperatorType::kArgMax));
   2433   ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN,
   2434                                    OperatorType::kArgMin));
   2435   ops.push_back(
   2436       MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile));
   2437   ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS,
   2438                                        OperatorType::kExpandDims));
   2439   ops.push_back(MakeUnique<TransposeConv>(
   2440       ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv));
   2441   ops.push_back(MakeUnique<SparseToDense>(
   2442       ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense));
   2443   ops.push_back(
   2444       MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape));
   2445   ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT,
   2446                                       OperatorType::kFakeQuant));
   2447   ops.push_back(
   2448       MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
   2449   ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>(
   2450       ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
   2451       OperatorType::kUnidirectionalSequenceLstm));
   2452   ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>(
   2453       ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
   2454       OperatorType::kBidirectionalSequenceLstm));
   2455   ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>(
   2456       ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
   2457       OperatorType::kBidirectionalSequenceRnn));
   2458   ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
   2459                                    OperatorType::kOneHot));
   2460   ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
   2461                                    OperatorType::kUnpack));
   2462   ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU,
   2463                                       OperatorType::kLeakyRelu));
   2464   ops.push_back(MakeUnique<SquaredDifference>(
   2465       ::tflite::BuiltinOperator_SQUARED_DIFFERENCE,
   2466       OperatorType::kSquaredDifference));
   2467   ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD,
   2468                                       OperatorType::kMirrorPad));
   2469   ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE,
   2470                                    OperatorType::kUnique));
   2471   ops.push_back(MakeUnique<UnidirectionalSequenceRnn>(
   2472       ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
   2473       OperatorType::kUnidirectionalSequenceRnn));
   2474   ops.push_back(
   2475       MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere));
   2476   ops.push_back(
   2477       MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE,
   2478                                   OperatorType::kReverseSequence));
   2479 
   2480   // Custom Operators.
   2481   ops.push_back(
   2482       MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
   2483   ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
   2484       "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
   2485   ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
   2486                                                   OperatorType::kUnsupported,
   2487                                                   enable_select_tf_ops));
   2488 
   2489   // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since
   2490   // been modified to also export builtins. As TOCO evolved we added warnings
   2491   // when custom ops are exported but SimpleOperator bypasses thoses. To
   2492   // prevent user confusion we are settling on using SimpleOperator only for
   2493   // builtins.
   2494   ops.push_back(
   2495       MakeUnique<SimpleOperator<FloorOperator>>("FLOOR", OperatorType::kFloor));
   2496   ops.push_back(
   2497       MakeUnique<SimpleOperator<CeilOperator>>("CEIL", OperatorType::kCeil));
   2498   ops.push_back(
   2499       MakeUnique<SimpleOperator<EluOperator>>("ELU", OperatorType::kElu));
   2500   ops.push_back(
   2501       MakeUnique<SimpleOperator<ReluOperator>>("RELU", OperatorType::kRelu));
   2502   ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>(
   2503       "RELU_N1_TO_1", OperatorType::kRelu1));
   2504   ops.push_back(MakeUnique<Relu6>());
   2505   ops.push_back(
   2506       MakeUnique<SimpleOperator<PReluOperator>>("PRELU", OperatorType::kPRelu));
   2507   ops.push_back(MakeUnique<Logistic>());
   2508   ops.push_back(MakeUnique<Tanh>());
   2509   ops.push_back(
   2510       MakeUnique<SimpleOperator<ExpOperator>>("EXP", OperatorType::kExp));
   2511   ops.push_back(
   2512       MakeUnique<SimpleOperator<CosOperator>>("COS", OperatorType::kCos));
   2513   ops.push_back(MakeUnique<LogSoftmax>());
   2514   ops.push_back(MakeUnique<Maximum>());  //  Element-wise Maximum
   2515   ops.push_back(MakeUnique<Minimum>());  //  Element-wise Minimum
   2516   ops.push_back(MakeUnique<Greater>());
   2517   ops.push_back(MakeUnique<GreaterEqual>());
   2518   ops.push_back(MakeUnique<Less>());
   2519   ops.push_back(MakeUnique<LessEqual>());
   2520   ops.push_back(MakeUnique<Equal>());
   2521   ops.push_back(MakeUnique<NotEqual>());
   2522   ops.push_back(
   2523       MakeUnique<SimpleOperator<NegOperator>>("NEG", OperatorType::kNeg));
   2524   ops.push_back(MakeUnique<Select>());
   2525   ops.push_back(MakeUnique<Slice>());
   2526   ops.push_back(
   2527       MakeUnique<SimpleOperator<PowOperator>>("POW", OperatorType::kPow));
   2528   ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>(
   2529       "LOGICAL_OR", OperatorType::kLogicalOr));
   2530   ops.emplace_back(new SimpleOperator<LogicalAndOperator>(
   2531       "LOGICAL_AND", OperatorType::kLogicalAnd));
   2532   ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
   2533       "LOGICAL_NOT", OperatorType::kLogicalNot));
   2534   ops.emplace_back(new SimpleOperator<FloorDivOperator>(
   2535       "FLOOR_DIV", OperatorType::kFloorDiv));
   2536   ops.emplace_back(new SimpleOperator<FloorModOperator>(
   2537       "FLOOR_MOD", OperatorType::kFloorMod));
   2538   ops.emplace_back(
   2539       new SimpleOperator<RangeOperator>("RANGE", OperatorType::kRange));
   2540   // Element-wise operator
   2541   ops.push_back(
   2542       MakeUnique<SimpleOperator<SinOperator>>("SIN", OperatorType::kSin));
   2543   ops.push_back(
   2544       MakeUnique<SimpleOperator<LogOperator>>("LOG", OperatorType::kLog));
   2545   ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>(
   2546       "SQRT", OperatorType::kSqrt));
   2547   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>(
   2548       "RSQRT", OperatorType::kRsqrt));
   2549   ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>(
   2550       "SQUARE", OperatorType::kSquare));
   2551   ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>(
   2552       "ZEROS_LIKE", OperatorType::kZerosLike));
   2553   ops.push_back(
   2554       MakeUnique<SimpleOperator<AbsOperator>>("ABS", OperatorType::kAbs));
   2555   ops.push_back(
   2556       MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill));
   2557   ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>(
   2558       "REVERSE_V2", OperatorType::kReverseV2));
   2559   ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>(
   2560       "RANK", OperatorType::kRank));
   2561   return ops;
   2562 }
   2563 }  // namespace
   2564 
   2565 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
   2566     bool enable_select_tf_ops) {
   2567   std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
   2568 
   2569   std::vector<std::unique_ptr<BaseOperator>> ops =
   2570       BuildOperatorList(enable_select_tf_ops);
   2571   for (auto& op : ops) {
   2572     result[op->type()] = std::move(op);
   2573   }
   2574 
   2575   return result;
   2576 }
   2577 
   2578 std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
   2579     bool enable_select_tf_ops) {
   2580   std::map<string, std::unique_ptr<BaseOperator>> result;
   2581 
   2582   std::vector<std::unique_ptr<BaseOperator>> ops =
   2583       BuildOperatorList(enable_select_tf_ops);
   2584   for (auto& op : ops) {
   2585     result[op->name()] = std::move(op);
   2586   }
   2587 
   2588   return result;
   2589 }
   2590 
   2591 bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
   2592                           const string& tensorflow_op_name) {
   2593   // If Flex ops aren't allow at all, simply return false.
   2594   if (!enable_select_tf_ops) {
   2595     return false;
   2596   }
   2597   // Check if we can find the `OpDef` for the TensorFlow op. If we can find
   2598   // it and it has been whitelisted, export the op as an Flex op. Otherwise,
   2599   // export it as a regular custom op.
   2600   const tensorflow::OpDef* op_def = nullptr;
   2601   if (!tensorflow::OpRegistry::Global()
   2602            ->LookUpOpDef(tensorflow_op_name, &op_def)
   2603            .ok()) {
   2604     return false;
   2605   }
   2606 
   2607   if (!IsWhitelistedFlexOp(tensorflow_op_name)) {
   2608     LOG(WARNING) << "Op " << tensorflow_op_name
   2609                  << " is a valid TensorFlow op but has not been whitelisted for"
   2610                     " the TensorFlow Lite flex op set.";
   2611     return false;
   2612   }
   2613 
   2614   return true;
   2615 }
   2616 
   2617 }  // namespace tflite
   2618 
   2619 }  // namespace toco
   2620