1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/lite/toco/tflite/operator.h" 16 17 #include "tensorflow/core/framework/attr_value.pb.h" 18 #include "tensorflow/core/framework/node_def.pb.h" 19 #include "tensorflow/core/framework/op.h" 20 #include "tensorflow/core/framework/op_def.pb.h" 21 #include "tensorflow/core/util/ptr_util.h" 22 23 // TODO(ycling): Consider refactoring to extract the LSTM definition out of 24 // graph_transformation module. 25 #include "tensorflow/lite/schema/schema_generated.h" 26 #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h" 27 #include "tensorflow/lite/toco/model.h" 28 #include "tensorflow/lite/toco/tflite/builtin_operator.h" 29 #include "tensorflow/lite/toco/tflite/custom_operator.h" 30 #include "tensorflow/lite/toco/tflite/simple_operator.h" 31 #include "tensorflow/lite/toco/tflite/types.h" 32 #include "tensorflow/lite/toco/tflite/whitelisted_flex_ops.h" 33 34 namespace toco { 35 36 namespace tflite { 37 38 class AveragePool 39 : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions, 40 ::tflite::BuiltinOptions_Pool2DOptions> { 41 public: 42 using BuiltinOperator::BuiltinOperator; 43 44 flatbuffers::Offset<TfLiteOptions> WriteOptions( 45 const TocoOperator& op, 46 flatbuffers::FlatBufferBuilder* builder) const override { 47 auto padding = Padding::Serialize(op.padding.type); 48 auto activation_function = 49 ActivationFunction::Serialize(op.fused_activation_function); 50 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, 51 op.stride_height, op.kwidth, 52 op.kheight, activation_function); 53 } 54 55 void ReadOptions(const TfLiteOptions& options, 56 TocoOperator* op) const override { 57 op->padding.type = Padding::Deserialize(options.padding()); 58 op->stride_width = options.stride_w(); 59 op->stride_height = options.stride_h(); 60 op->kwidth = options.filter_width(); 61 op->kheight = options.filter_height(); 62 op->fused_activation_function = 63 ActivationFunction::Deserialize(options.fused_activation_function()); 64 } 65 66 int GetVersion(const OperatorSignature& op_signature) const override { 67 const string& input_name = op_signature.op->inputs[0]; 68 const Array& input_array = op_signature.model->GetArray(input_name); 69 if (input_array.data_type == ArrayDataType::kInt8) { 70 return 2; 71 } 72 return 1; 73 } 74 }; 75 76 class Convolution 77 : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions, 78 ::tflite::BuiltinOptions_Conv2DOptions> { 79 public: 80 using BuiltinOperator::BuiltinOperator; 81 82 flatbuffers::Offset<TfLiteOptions> WriteOptions( 83 const TocoOperator& op, 84 flatbuffers::FlatBufferBuilder* builder) const override { 85 auto padding = Padding::Serialize(op.padding.type); 86 auto activation_function = 87 ActivationFunction::Serialize(op.fused_activation_function); 88 return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width, 89 op.stride_height, activation_function, 90 op.dilation_width_factor, 91 op.dilation_height_factor); 92 } 93 94 void ReadOptions(const TfLiteOptions& options, 95 TocoOperator* op) const override { 96 op->padding.type = Padding::Deserialize(options.padding()); 97 op->stride_width = options.stride_w(); 98 op->stride_height = options.stride_h(); 99 op->dilation_width_factor = options.dilation_w_factor(); 100 op->dilation_height_factor = options.dilation_h_factor(); 101 op->fused_activation_function = 102 ActivationFunction::Deserialize(options.fused_activation_function()); 103 } 104 105 int GetVersion(const OperatorSignature& op_signature) const override { 106 const string& input_name = op_signature.op->inputs[0]; 107 const string& filter_name = op_signature.op->inputs[1]; 108 const string& output_name = op_signature.op->outputs[0]; 109 const Array& input_array = op_signature.model->GetArray(input_name); 110 const Array& filter_array = op_signature.model->GetArray(filter_name); 111 const Array& output_array = op_signature.model->GetArray(output_name); 112 // If the op has signed int8 inputs and outputs, its version 3. 113 if (input_array.data_type == ArrayDataType::kInt8 && 114 filter_array.data_type == ArrayDataType::kInt8 && 115 output_array.data_type == ArrayDataType::kInt8) { 116 return 3; 117 } 118 // If the op is a signed int8 hybrid operation, we need to return 119 // version 2. 120 if (input_array.data_type == ArrayDataType::kFloat && 121 filter_array.data_type == ArrayDataType::kInt8 && 122 output_array.data_type == ArrayDataType::kFloat) { 123 return 2; 124 } 125 return 1; 126 } 127 }; 128 129 class DepthwiseConvolution 130 : public BuiltinOperator<DepthwiseConvOperator, 131 ::tflite::DepthwiseConv2DOptions, 132 ::tflite::BuiltinOptions_DepthwiseConv2DOptions> { 133 public: 134 using BuiltinOperator::BuiltinOperator; 135 136 flatbuffers::Offset<TfLiteOptions> WriteOptions( 137 const TocoOperator& op, 138 flatbuffers::FlatBufferBuilder* builder) const override { 139 auto padding = Padding::Serialize(op.padding.type); 140 auto activation_function = 141 ActivationFunction::Serialize(op.fused_activation_function); 142 return ::tflite::CreateDepthwiseConv2DOptions( 143 *builder, padding, op.stride_width, op.stride_height, 144 op.depth_multiplier, activation_function, op.dilation_width_factor, 145 op.dilation_height_factor); 146 } 147 148 void ReadOptions(const TfLiteOptions& options, 149 TocoOperator* op) const override { 150 op->padding.type = Padding::Deserialize(options.padding()); 151 op->stride_width = options.stride_w(); 152 op->stride_height = options.stride_h(); 153 op->depth_multiplier = options.depth_multiplier(); 154 op->fused_activation_function = 155 ActivationFunction::Deserialize(options.fused_activation_function()); 156 op->dilation_width_factor = options.dilation_w_factor(); 157 op->dilation_height_factor = options.dilation_h_factor(); 158 } 159 160 int GetVersion(const OperatorSignature& op_signature) const override { 161 const auto& conv_op = 162 static_cast<const DepthwiseConvOperator&>(*op_signature.op); 163 const string& input_name = op_signature.op->inputs[0]; 164 const string& filter_name = op_signature.op->inputs[1]; 165 const string& output_name = op_signature.op->outputs[0]; 166 const Array& input_array = op_signature.model->GetArray(input_name); 167 const Array& filter_array = op_signature.model->GetArray(filter_name); 168 const Array& output_array = op_signature.model->GetArray(output_name); 169 // If the op has signed int8 inputs and outputs, its version 3. 170 if (input_array.data_type == ArrayDataType::kInt8 && 171 filter_array.data_type == ArrayDataType::kInt8 && 172 output_array.data_type == ArrayDataType::kInt8) { 173 return 3; 174 } 175 if (conv_op.dilation_width_factor != 1 || 176 conv_op.dilation_height_factor != 1) { 177 return 2; 178 } 179 return 1; 180 } 181 }; 182 183 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions, 184 ::tflite::BuiltinOptions_AddOptions> { 185 public: 186 using BuiltinOperator::BuiltinOperator; 187 188 flatbuffers::Offset<TfLiteOptions> WriteOptions( 189 const TocoOperator& op, 190 flatbuffers::FlatBufferBuilder* builder) const override { 191 auto activation_function = 192 ActivationFunction::Serialize(op.fused_activation_function); 193 return ::tflite::CreateAddOptions(*builder, activation_function); 194 } 195 196 void ReadOptions(const TfLiteOptions& options, 197 TocoOperator* op) const override { 198 op->fused_activation_function = 199 ActivationFunction::Deserialize(options.fused_activation_function()); 200 } 201 202 int GetVersion(const OperatorSignature& op_signature) const override { 203 const string& input_name = op_signature.op->inputs[0]; 204 const Array& input_array = op_signature.model->GetArray(input_name); 205 // Version 2 supports signed int8 input types. 206 if (input_array.data_type == ArrayDataType::kInt8) { 207 return 2; 208 } 209 return 1; 210 } 211 }; 212 213 class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions, 214 ::tflite::BuiltinOptions_AddNOptions> { 215 public: 216 using BuiltinOperator::BuiltinOperator; 217 218 flatbuffers::Offset<TfLiteOptions> WriteOptions( 219 const TocoOperator& op, 220 flatbuffers::FlatBufferBuilder* builder) const override { 221 return ::tflite::CreateAddNOptions(*builder); 222 } 223 224 void ReadOptions(const TfLiteOptions& options, 225 TocoOperator* op) const override {} 226 227 int GetVersion(const OperatorSignature& op_signature) const override { 228 return 1; 229 } 230 }; 231 232 class SpaceToBatchND 233 : public BuiltinOperator<SpaceToBatchNDOperator, 234 ::tflite::SpaceToBatchNDOptions, 235 ::tflite::BuiltinOptions_SpaceToBatchNDOptions> { 236 public: 237 using BuiltinOperator::BuiltinOperator; 238 239 flatbuffers::Offset<TfLiteOptions> WriteOptions( 240 const TocoOperator& op, 241 flatbuffers::FlatBufferBuilder* builder) const override { 242 return ::tflite::CreateSpaceToBatchNDOptions(*builder); 243 } 244 245 void ReadOptions(const TfLiteOptions& options, 246 TocoOperator* op) const override {} 247 248 int GetVersion(const OperatorSignature& op_signature) const override { 249 const string& input_name = op_signature.op->inputs[0]; 250 const Array& input_array = op_signature.model->GetArray(input_name); 251 // If the op take int8 input, it is version 2. 252 if (input_array.data_type == ArrayDataType::kInt8) { 253 return 2; 254 } 255 return 1; 256 } 257 }; 258 259 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions, 260 ::tflite::BuiltinOptions_SubOptions> { 261 public: 262 using BuiltinOperator::BuiltinOperator; 263 264 flatbuffers::Offset<TfLiteOptions> WriteOptions( 265 const TocoOperator& op, 266 flatbuffers::FlatBufferBuilder* builder) const override { 267 auto activation_function = 268 ActivationFunction::Serialize(op.fused_activation_function); 269 return ::tflite::CreateSubOptions(*builder, activation_function); 270 } 271 272 void ReadOptions(const TfLiteOptions& options, 273 TocoOperator* op) const override { 274 op->fused_activation_function = 275 ActivationFunction::Deserialize(options.fused_activation_function()); 276 } 277 278 int GetVersion(const OperatorSignature& op_signature) const override { 279 const string& input_name = op_signature.op->inputs[0]; 280 const Array& input_array = op_signature.model->GetArray(input_name); 281 // If the op take int8 input, it is version 2. 282 if (input_array.data_type == ArrayDataType::kInt8) { 283 return 2; 284 } 285 return 1; 286 } 287 }; 288 289 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions, 290 ::tflite::BuiltinOptions_DivOptions> { 291 public: 292 using BuiltinOperator::BuiltinOperator; 293 294 flatbuffers::Offset<TfLiteOptions> WriteOptions( 295 const TocoOperator& op, 296 flatbuffers::FlatBufferBuilder* builder) const override { 297 auto activation_function = 298 ActivationFunction::Serialize(op.fused_activation_function); 299 return ::tflite::CreateDivOptions(*builder, activation_function); 300 } 301 302 void ReadOptions(const TfLiteOptions& options, 303 TocoOperator* op) const override { 304 op->fused_activation_function = 305 ActivationFunction::Deserialize(options.fused_activation_function()); 306 } 307 308 int GetVersion(const OperatorSignature& op_signature) const override { 309 return 1; 310 } 311 }; 312 313 class BatchToSpaceND 314 : public BuiltinOperator<BatchToSpaceNDOperator, 315 ::tflite::BatchToSpaceNDOptions, 316 ::tflite::BuiltinOptions_BatchToSpaceNDOptions> { 317 public: 318 using BuiltinOperator::BuiltinOperator; 319 320 flatbuffers::Offset<TfLiteOptions> WriteOptions( 321 const TocoOperator& op, 322 flatbuffers::FlatBufferBuilder* builder) const override { 323 return ::tflite::CreateBatchToSpaceNDOptions(*builder); 324 } 325 326 void ReadOptions(const TfLiteOptions& options, 327 TocoOperator* op) const override {} 328 329 int GetVersion(const OperatorSignature& op_signature) const override { 330 const string& input_name = op_signature.op->inputs[0]; 331 const Array& input_array = op_signature.model->GetArray(input_name); 332 // If the op take int8 input, it is version 2. 333 if (input_array.data_type == ArrayDataType::kInt8) { 334 return 2; 335 } 336 return 1; 337 } 338 }; 339 340 class Cast : public BuiltinOperator<CastOperator, ::tflite::CastOptions, 341 ::tflite::BuiltinOptions_CastOptions> { 342 public: 343 using BuiltinOperator::BuiltinOperator; 344 flatbuffers::Offset<TfLiteOptions> WriteOptions( 345 const TocoOperator& op, 346 flatbuffers::FlatBufferBuilder* builder) const override { 347 return ::tflite::CreateCastOptions(*builder, 348 DataType::Serialize(op.src_data_type), 349 DataType::Serialize(op.dst_data_type)); 350 } 351 352 void ReadOptions(const TfLiteOptions& options, 353 TocoOperator* op) const override { 354 op->src_data_type = DataType::Deserialize(options.in_data_type()); 355 op->dst_data_type = DataType::Deserialize(options.out_data_type()); 356 } 357 358 int GetVersion(const OperatorSignature& op_signature) const override { 359 return 1; 360 } 361 }; 362 363 class Concatenation 364 : public BuiltinOperator<ConcatenationOperator, 365 ::tflite::ConcatenationOptions, 366 ::tflite::BuiltinOptions_ConcatenationOptions> { 367 public: 368 using BuiltinOperator::BuiltinOperator; 369 flatbuffers::Offset<TfLiteOptions> WriteOptions( 370 const TocoOperator& op, 371 flatbuffers::FlatBufferBuilder* builder) const override { 372 return ::tflite::CreateConcatenationOptions(*builder, op.axis); 373 } 374 375 void ReadOptions(const TfLiteOptions& options, 376 TocoOperator* op) const override { 377 op->axis = options.axis(); 378 } 379 380 int GetVersion(const OperatorSignature& op_signature) const override { 381 const string& input_name = op_signature.op->inputs[0]; 382 const Array& input_array = op_signature.model->GetArray(input_name); 383 // If the op take int8 input, it is version 2. 384 if (input_array.data_type == ArrayDataType::kInt8) { 385 return 2; 386 } 387 return 1; 388 } 389 }; 390 391 class DepthToSpace : public CustomOperator<DepthToSpaceOperator> { 392 public: 393 using CustomOperator::CustomOperator; 394 void WriteOptions(const TocoOperator& op, 395 flexbuffers::Builder* fbb) const override { 396 fbb->Int("block_size", op.block_size); 397 } 398 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { 399 op->block_size = m["block_size"].AsInt64(); 400 } 401 402 int GetVersion(const OperatorSignature& op_signature) const override { 403 return 1; 404 } 405 }; 406 407 class FakeQuant 408 : public BuiltinOperator<FakeQuantOperator, ::tflite::FakeQuantOptions, 409 ::tflite::BuiltinOptions_FakeQuantOptions> { 410 public: 411 using BuiltinOperator::BuiltinOperator; 412 flatbuffers::Offset<TfLiteOptions> WriteOptions( 413 const TocoOperator& op, 414 flatbuffers::FlatBufferBuilder* builder) const override { 415 return ::tflite::CreateFakeQuantOptions( 416 *builder, op.minmax->min, op.minmax->max, op.num_bits, op.narrow_range); 417 } 418 void ReadOptions(const TfLiteOptions& options, 419 TocoOperator* op) const override { 420 auto* minmax = new MinMax; 421 minmax->min = options.min(); 422 minmax->max = options.max(); 423 op->minmax.reset(minmax); 424 op->num_bits = options.num_bits(); 425 op->narrow_range = options.narrow_range(); 426 } 427 int GetVersion(const OperatorSignature& op_signature) const override { 428 const auto& fq_op = static_cast<const FakeQuantOperator&>(*op_signature.op); 429 return fq_op.narrow_range ? 2 : 1; 430 } 431 }; 432 433 class FullyConnected 434 : public BuiltinOperator<FullyConnectedOperator, 435 ::tflite::FullyConnectedOptions, 436 ::tflite::BuiltinOptions_FullyConnectedOptions> { 437 public: 438 using BuiltinOperator::BuiltinOperator; 439 flatbuffers::Offset<TfLiteOptions> WriteOptions( 440 const TocoOperator& op, 441 flatbuffers::FlatBufferBuilder* builder) const override { 442 auto activation_function = 443 ActivationFunction::Serialize(op.fused_activation_function); 444 ::tflite::FullyConnectedOptionsWeightsFormat tflite_weights_format; 445 switch (op.weights_format) { 446 case FullyConnectedWeightsFormat::kDefault: 447 tflite_weights_format = 448 ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; 449 break; 450 case FullyConnectedWeightsFormat::kShuffled4x16Int8: 451 tflite_weights_format = 452 ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; 453 break; 454 default: 455 LOG(ERROR) << "Unhandled FC weights format"; 456 tflite_weights_format = 457 ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT; 458 } 459 return ::tflite::CreateFullyConnectedOptions(*builder, activation_function, 460 tflite_weights_format); 461 } 462 463 void ReadOptions(const TfLiteOptions& options, 464 TocoOperator* op) const override { 465 op->fused_activation_function = 466 ActivationFunction::Deserialize(options.fused_activation_function()); 467 switch (options.weights_format()) { 468 case ::tflite::FullyConnectedOptionsWeightsFormat_DEFAULT: 469 op->weights_format = FullyConnectedWeightsFormat::kDefault; 470 break; 471 case ::tflite::FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8: 472 op->weights_format = FullyConnectedWeightsFormat::kShuffled4x16Int8; 473 break; 474 default: 475 LOG(ERROR) << "Unhandled FC weights format"; 476 op->weights_format = FullyConnectedWeightsFormat::kDefault; 477 } 478 } 479 480 // +-----------------+--------------------+--------------------------+ 481 // | | Weight::Default | Weight::Shuffled4x16Int8 | 482 // +-----------------+--------------------+--------------------------+ 483 // | Float | 1 | 2 | 484 // | Quantized Uint8 | 1 | 2 | 485 // | Hybrid | 3 | 3 | 486 // | Quantized Int8 | 4 | 4 | 487 // +-----------------+--------------------+--------------------------+ 488 int GetVersion(const OperatorSignature& op_signature) const override { 489 const auto& fc_op = 490 static_cast<const FullyConnectedOperator&>(*op_signature.op); 491 const string& input_name = op_signature.op->inputs[0]; 492 const string& weights_name = op_signature.op->inputs[1]; 493 const string& output_name = op_signature.op->outputs[0]; 494 const Array& input_array = op_signature.model->GetArray(input_name); 495 const Array& weights_array = op_signature.model->GetArray(weights_name); 496 const Array& output_array = op_signature.model->GetArray(output_name); 497 // Int8 fully fixed point kernel is at version 4. 498 if (input_array.data_type == ArrayDataType::kInt8 && 499 weights_array.data_type == ArrayDataType::kInt8 && 500 output_array.data_type == ArrayDataType::kInt8) { 501 return 4; 502 } 503 // If the op is a signed int8 hybrid operation, we need to return 504 // version 3. 505 if (input_array.data_type == ArrayDataType::kFloat && 506 weights_array.data_type == ArrayDataType::kInt8 && 507 output_array.data_type == ArrayDataType::kFloat) { 508 return 3; 509 } 510 // For float and uint8 fixed point kernels, if the weight is 511 // Shuffled4x16Int8, is is version 2. 512 if (fc_op.weights_format == 513 FullyConnectedWeightsFormat::kShuffled4x16Int8) { 514 return 2; 515 } 516 517 // Otherwise (weight is default), the version is 1. 518 return 1; 519 } 520 }; 521 522 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions, 523 ::tflite::BuiltinOptions_GatherOptions> { 524 public: 525 using BuiltinOperator::BuiltinOperator; 526 flatbuffers::Offset<TfLiteOptions> WriteOptions( 527 const TocoOperator& op, 528 flatbuffers::FlatBufferBuilder* builder) const override { 529 int axis = op.axis ? op.axis.value() : 0; 530 return ::tflite::CreateGatherOptions(*builder, axis); 531 } 532 533 void ReadOptions(const TfLiteOptions& options, 534 TocoOperator* op) const override { 535 op->axis = {options.axis()}; 536 } 537 538 int GetVersion(const OperatorSignature& op_signature) const override { 539 const string& input_name = op_signature.op->inputs[0]; 540 const Array& input_array = op_signature.model->GetArray(input_name); 541 // If the op take int8 input, it is version 2. 542 if (input_array.data_type == ArrayDataType::kInt8) { 543 return 2; 544 } 545 return 1; 546 } 547 }; 548 549 class GatherNd 550 : public BuiltinOperator<GatherNdOperator, ::tflite::GatherNdOptions, 551 ::tflite::BuiltinOptions_GatherNdOptions> { 552 public: 553 using BuiltinOperator::BuiltinOperator; 554 555 flatbuffers::Offset<TfLiteOptions> WriteOptions( 556 const TocoOperator& op, 557 flatbuffers::FlatBufferBuilder* builder) const override { 558 return ::tflite::CreateGatherNdOptions(*builder); 559 } 560 561 void ReadOptions(const TfLiteOptions& options, 562 TocoOperator* op) const override {} 563 564 int GetVersion(const OperatorSignature& op_signature) const override { 565 return 1; 566 } 567 }; 568 569 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions, 570 ::tflite::BuiltinOptions_SVDFOptions> { 571 public: 572 using BuiltinOperator::BuiltinOperator; 573 flatbuffers::Offset<TfLiteOptions> WriteOptions( 574 const TocoOperator& op, 575 flatbuffers::FlatBufferBuilder* builder) const override { 576 auto activation_function = 577 ActivationFunction::Serialize(op.fused_activation_function); 578 return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function); 579 } 580 581 void ReadOptions(const TfLiteOptions& options, 582 TocoOperator* op) const override { 583 op->fused_activation_function = 584 ActivationFunction::Deserialize(options.fused_activation_function()); 585 op->rank = options.rank(); 586 } 587 588 int GetVersion(const OperatorSignature& op_signature) const override { 589 const string& input_name = op_signature.op->inputs[0]; 590 const string& weights_feature_name = op_signature.op->inputs[1]; 591 const string& output_name = op_signature.op->outputs[0]; 592 const Array& input_array = op_signature.model->GetArray(input_name); 593 const Array& weights_feature_array = 594 op_signature.model->GetArray(weights_feature_name); 595 const Array& output_array = op_signature.model->GetArray(output_name); 596 // If the op is a signed int8 hybrid operation, we need to return 597 // version 2. 598 if (input_array.data_type == ArrayDataType::kFloat && 599 weights_feature_array.data_type == ArrayDataType::kInt8 && 600 output_array.data_type == ArrayDataType::kFloat) { 601 return 2; 602 } 603 return 1; 604 } 605 }; 606 607 class L2Normalization 608 : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions, 609 ::tflite::BuiltinOptions_L2NormOptions> { 610 public: 611 using BuiltinOperator::BuiltinOperator; 612 flatbuffers::Offset<TfLiteOptions> WriteOptions( 613 const TocoOperator& op, 614 flatbuffers::FlatBufferBuilder* builder) const override { 615 auto activation_function = 616 ActivationFunction::Serialize(op.fused_activation_function); 617 return ::tflite::CreateL2NormOptions(*builder, activation_function); 618 } 619 620 void ReadOptions(const TfLiteOptions& options, 621 TocoOperator* op) const override { 622 op->fused_activation_function = 623 ActivationFunction::Deserialize(options.fused_activation_function()); 624 } 625 626 int GetVersion(const OperatorSignature& op_signature) const override { 627 const string& output_name = op_signature.op->outputs[0]; 628 const Array& output_array = op_signature.model->GetArray(output_name); 629 // Version 2 supports signed int8 input types. 630 if (output_array.data_type == ArrayDataType::kInt8) { 631 return 2; 632 } 633 return 1; 634 } 635 }; 636 637 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions, 638 ::tflite::BuiltinOptions_Pool2DOptions> { 639 public: 640 using BuiltinOperator::BuiltinOperator; 641 flatbuffers::Offset<TfLiteOptions> WriteOptions( 642 const TocoOperator& op, 643 flatbuffers::FlatBufferBuilder* builder) const override { 644 auto padding = Padding::Serialize(op.padding.type); 645 auto activation_function = 646 ActivationFunction::Serialize(op.fused_activation_function); 647 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, 648 op.stride_height, op.kwidth, 649 op.kheight, activation_function); 650 } 651 652 void ReadOptions(const TfLiteOptions& options, 653 TocoOperator* op) const override { 654 op->padding.type = Padding::Deserialize(options.padding()); 655 op->stride_width = options.stride_w(); 656 op->stride_height = options.stride_h(); 657 op->kwidth = options.filter_width(); 658 op->kheight = options.filter_height(); 659 op->fused_activation_function = 660 ActivationFunction::Deserialize(options.fused_activation_function()); 661 } 662 663 int GetVersion(const OperatorSignature& op_signature) const override { 664 return 1; 665 } 666 }; 667 668 class LocalResponseNormalization 669 : public BuiltinOperator< 670 LocalResponseNormalizationOperator, 671 ::tflite::LocalResponseNormalizationOptions, 672 ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> { 673 public: 674 using BuiltinOperator::BuiltinOperator; 675 flatbuffers::Offset<TfLiteOptions> WriteOptions( 676 const TocoOperator& op, 677 flatbuffers::FlatBufferBuilder* builder) const override { 678 return ::tflite::CreateLocalResponseNormalizationOptions( 679 *builder, op.range, op.bias, op.alpha, op.beta); 680 } 681 682 void ReadOptions(const TfLiteOptions& options, 683 TocoOperator* op) const override { 684 op->range = options.radius(); 685 op->bias = options.bias(); 686 op->alpha = options.alpha(); 687 op->beta = options.beta(); 688 } 689 690 int GetVersion(const OperatorSignature& op_signature) const override { 691 return 1; 692 } 693 }; 694 695 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions, 696 ::tflite::BuiltinOptions_Pool2DOptions> { 697 public: 698 using BuiltinOperator::BuiltinOperator; 699 flatbuffers::Offset<TfLiteOptions> WriteOptions( 700 const TocoOperator& op, 701 flatbuffers::FlatBufferBuilder* builder) const override { 702 auto padding = Padding::Serialize(op.padding.type); 703 auto activation_function = 704 ActivationFunction::Serialize(op.fused_activation_function); 705 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, 706 op.stride_height, op.kwidth, 707 op.kheight, activation_function); 708 } 709 710 void ReadOptions(const TfLiteOptions& options, 711 TocoOperator* op) const override { 712 op->padding.type = Padding::Deserialize(options.padding()); 713 op->stride_width = options.stride_w(); 714 op->stride_height = options.stride_h(); 715 op->kwidth = options.filter_width(); 716 op->kheight = options.filter_height(); 717 op->fused_activation_function = 718 ActivationFunction::Deserialize(options.fused_activation_function()); 719 } 720 721 int GetVersion(const OperatorSignature& op_signature) const override { 722 const string& input_name = op_signature.op->inputs[0]; 723 const Array& input_array = op_signature.model->GetArray(input_name); 724 if (input_array.data_type == ArrayDataType::kInt8) { 725 return 2; 726 } 727 return 1; 728 } 729 }; 730 731 class Maximum : public SimpleOperator<TensorFlowMaximumOperator> { 732 public: 733 explicit Maximum() : SimpleOperator("MAXIMUM", OperatorType::kMaximum) {} 734 int GetVersion(const OperatorSignature& op_signature) const override { 735 const string& input_name = op_signature.op->inputs[0]; 736 const Array& input_array = op_signature.model->GetArray(input_name); 737 // Version 2 supports signed int8 input types. 738 if (input_array.data_type == ArrayDataType::kInt8) { 739 return 2; 740 } 741 return 1; 742 } 743 }; 744 745 class Minimum : public SimpleOperator<TensorFlowMinimumOperator> { 746 public: 747 explicit Minimum() : SimpleOperator("MINIMUM", OperatorType::kMinimum) {} 748 int GetVersion(const OperatorSignature& op_signature) const override { 749 const string& input_name = op_signature.op->inputs[0]; 750 const Array& input_array = op_signature.model->GetArray(input_name); 751 // Version 2 supports signed int8 input types. 752 if (input_array.data_type == ArrayDataType::kInt8) { 753 return 2; 754 } 755 return 1; 756 } 757 }; 758 759 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions, 760 ::tflite::BuiltinOptions_MulOptions> { 761 public: 762 using BuiltinOperator::BuiltinOperator; 763 764 flatbuffers::Offset<TfLiteOptions> WriteOptions( 765 const TocoOperator& op, 766 flatbuffers::FlatBufferBuilder* builder) const override { 767 auto activation_function = 768 ActivationFunction::Serialize(op.fused_activation_function); 769 return ::tflite::CreateMulOptions(*builder, activation_function); 770 } 771 772 void ReadOptions(const TfLiteOptions& options, 773 TocoOperator* op) const override { 774 op->fused_activation_function = 775 ActivationFunction::Deserialize(options.fused_activation_function()); 776 } 777 778 int GetVersion(const OperatorSignature& op_signature) const override { 779 const string& input_name = op_signature.op->inputs[0]; 780 const Array& input_array = op_signature.model->GetArray(input_name); 781 // Version 2 supports signed int8 input types. 782 if (input_array.data_type == ArrayDataType::kInt8) { 783 return 2; 784 } 785 return 1; 786 } 787 }; 788 789 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions, 790 ::tflite::BuiltinOptions_PadOptions> { 791 public: 792 using BuiltinOperator::BuiltinOperator; 793 794 flatbuffers::Offset<TfLiteOptions> WriteOptions( 795 const TocoOperator& op, 796 flatbuffers::FlatBufferBuilder* builder) const override { 797 return ::tflite::CreatePadOptions(*builder); 798 } 799 800 void ReadOptions(const TfLiteOptions& options, 801 TocoOperator* op) const override {} 802 803 int GetVersion(const OperatorSignature& op_signature) const override { 804 const string& input_name = op_signature.op->inputs[0]; 805 const Array& input_array = op_signature.model->GetArray(input_name); 806 // If the op take int8 input, it is version 2. 807 if (input_array.data_type == ArrayDataType::kInt8) { 808 return 2; 809 } 810 return 1; 811 } 812 }; 813 814 class Tile 815 : public BuiltinOperator<TensorFlowTileOperator, ::tflite::TileOptions, 816 ::tflite::BuiltinOptions_TileOptions> { 817 using BuiltinOperator::BuiltinOperator; 818 819 flatbuffers::Offset<TfLiteOptions> WriteOptions( 820 const TocoOperator& op, 821 flatbuffers::FlatBufferBuilder* builder) const override { 822 return ::tflite::CreateTileOptions(*builder); 823 } 824 825 void ReadOptions(const TfLiteOptions& options, 826 TocoOperator* op) const override {} 827 int GetVersion(const OperatorSignature& op_signature) const override { 828 return 1; 829 } 830 }; 831 832 class PadV2 : public BuiltinOperator<PadV2Operator, ::tflite::PadV2Options, 833 ::tflite::BuiltinOptions_PadV2Options> { 834 public: 835 using BuiltinOperator::BuiltinOperator; 836 837 flatbuffers::Offset<TfLiteOptions> WriteOptions( 838 const TocoOperator& op, 839 flatbuffers::FlatBufferBuilder* builder) const override { 840 return ::tflite::CreatePadV2Options(*builder); 841 } 842 843 void ReadOptions(const TfLiteOptions& options, 844 TocoOperator* op) const override {} 845 846 int GetVersion(const OperatorSignature& op_signature) const override { 847 const string& input_name = op_signature.op->inputs[0]; 848 const Array& input_array = op_signature.model->GetArray(input_name); 849 // If the op take int8 input, it is version 2. 850 if (input_array.data_type == ArrayDataType::kInt8) { 851 return 2; 852 } 853 return 1; 854 } 855 }; 856 857 class Reshape 858 : public BuiltinOperator<TensorFlowReshapeOperator, 859 ::tflite::ReshapeOptions, 860 ::tflite::BuiltinOptions_ReshapeOptions> { 861 public: 862 using BuiltinOperator::BuiltinOperator; 863 864 flatbuffers::Offset<TfLiteOptions> WriteOptions( 865 const TocoOperator& op, 866 flatbuffers::FlatBufferBuilder* builder) const override { 867 return ::tflite::CreateReshapeOptions(*builder, 868 builder->CreateVector(op.shape)); 869 } 870 871 void ReadOptions(const TfLiteOptions& options, 872 TocoOperator* op) const override { 873 op->shape.insert(op->shape.end(), options.new_shape()->begin(), 874 options.new_shape()->end()); 875 } 876 877 int GetVersion(const OperatorSignature& op_signature) const override { 878 return 1; 879 } 880 }; 881 882 class Softmax 883 : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions, 884 ::tflite::BuiltinOptions_SoftmaxOptions> { 885 public: 886 using BuiltinOperator::BuiltinOperator; 887 flatbuffers::Offset<TfLiteOptions> WriteOptions( 888 const TocoOperator& op, 889 flatbuffers::FlatBufferBuilder* builder) const override { 890 return ::tflite::CreateSoftmaxOptions(*builder, op.beta); 891 } 892 893 void ReadOptions(const TfLiteOptions& options, 894 TocoOperator* op) const override { 895 op->beta = options.beta(); 896 } 897 898 int GetVersion(const OperatorSignature& op_signature) const override { 899 const string& input_name = op_signature.op->inputs[0]; 900 const Array& input_array = op_signature.model->GetArray(input_name); 901 if (input_array.data_type == ArrayDataType::kInt8) { 902 return 2; 903 } 904 return 1; 905 } 906 }; 907 908 class SpaceToDepth 909 : public BuiltinOperator<SpaceToDepthOperator, 910 ::tflite::SpaceToDepthOptions, 911 ::tflite::BuiltinOptions_SpaceToDepthOptions> { 912 public: 913 using BuiltinOperator::BuiltinOperator; 914 flatbuffers::Offset<TfLiteOptions> WriteOptions( 915 const TocoOperator& op, 916 flatbuffers::FlatBufferBuilder* builder) const override { 917 return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size); 918 } 919 920 void ReadOptions(const TfLiteOptions& options, 921 TocoOperator* op) const override { 922 op->block_size = options.block_size(); 923 } 924 925 int GetVersion(const OperatorSignature& op_signature) const override { 926 const string& input_name = op_signature.op->inputs[0]; 927 const Array& input_array = op_signature.model->GetArray(input_name); 928 // If the op take int8 input, it is version 2. 929 if (input_array.data_type == ArrayDataType::kInt8) { 930 return 2; 931 } 932 return 1; 933 } 934 }; 935 936 class Transpose 937 : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions, 938 ::tflite::BuiltinOptions_TransposeOptions> { 939 public: 940 using BuiltinOperator::BuiltinOperator; 941 flatbuffers::Offset<TfLiteOptions> WriteOptions( 942 const TocoOperator& op, 943 flatbuffers::FlatBufferBuilder* builder) const override { 944 return ::tflite::CreateTransposeOptions(*builder); 945 } 946 947 void ReadOptions(const TfLiteOptions& options, 948 TocoOperator* op) const override {} 949 950 int GetVersion(const OperatorSignature& op_signature) const override { 951 const string& input_name = op_signature.op->inputs[0]; 952 const Array& input_array = op_signature.model->GetArray(input_name); 953 // If the op take int8 input, it is version 2. 954 if (input_array.data_type == ArrayDataType::kInt8) { 955 return 2; 956 } 957 return 1; 958 } 959 }; 960 961 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions, 962 ::tflite::BuiltinOptions_LSTMOptions> { 963 public: 964 using BuiltinOperator::BuiltinOperator; 965 flatbuffers::Offset<TfLiteOptions> WriteOptions( 966 const TocoOperator& op, 967 flatbuffers::FlatBufferBuilder* builder) const override { 968 ::tflite::LSTMKernelType kernel_type = ::tflite::LSTMKernelType_FULL; 969 switch (op.kernel_type) { 970 case LstmCellOperator::KERNEL_BASIC: 971 kernel_type = ::tflite::LSTMKernelType_BASIC; 972 break; 973 case LstmCellOperator::KERNEL_FULL: 974 kernel_type = ::tflite::LSTMKernelType_FULL; 975 break; 976 default: 977 return -1; 978 } 979 980 // Current toco converter only supports tanh, no clip. 981 return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/ 982 ::tflite::ActivationFunctionType_TANH, 983 /*cell_clip=*/0.0, 984 /*proj_clip=*/0.0, kernel_type); 985 } 986 987 void ReadOptions(const TfLiteOptions& options, 988 TocoOperator* op) const override { 989 // Only support tanh activation, so check that tflite type is tanh. 990 CHECK(options.fused_activation_function() == 991 ::tflite::ActivationFunctionType_TANH); 992 993 switch (options.kernel_type()) { 994 case ::tflite::LSTMKernelType_BASIC: 995 op->kernel_type = LstmCellOperator::KERNEL_BASIC; 996 break; 997 case ::tflite::LSTMKernelType_FULL: 998 op->kernel_type = LstmCellOperator::KERNEL_FULL; 999 break; 1000 } 1001 } 1002 1003 int GetVersion(const OperatorSignature& op_signature) const override { 1004 const auto& lstm_op = 1005 static_cast<const LstmCellOperator&>(*op_signature.op); 1006 switch (lstm_op.kernel_type) { 1007 case LstmCellOperator::KERNEL_FULL: { 1008 // If the input tensor is float and a weight is int8, this is a version 1009 // 3 hybrid operation. 1010 const string& input_name = op_signature.op->inputs[0]; 1011 const string& weights_name = op_signature.op->inputs[2]; 1012 const string& output_name = op_signature.op->outputs[0]; 1013 const Array& input_array = op_signature.model->GetArray(input_name); 1014 const Array& weights_array = op_signature.model->GetArray(weights_name); 1015 const Array& output_array = op_signature.model->GetArray(output_name); 1016 if (input_array.data_type == ArrayDataType::kFloat && 1017 weights_array.data_type == ArrayDataType::kInt8 && 1018 output_array.data_type == ArrayDataType::kFloat) { 1019 return 3; 1020 } 1021 return 1; 1022 } 1023 case LstmCellOperator::KERNEL_BASIC: 1024 // KERNEL_BASIC was added in version 2. 1025 return 2; 1026 } 1027 } 1028 1029 std::vector<bool> GetMutatingInputVariables( 1030 const Operator& op) const override { 1031 const auto& lstm_op = static_cast<const LstmCellOperator&>(op); 1032 1033 std::vector<bool> mutating_input_variables(op.inputs.size(), false); 1034 switch (lstm_op.kernel_type) { 1035 case LstmCellOperator::KERNEL_FULL: { 1036 mutating_input_variables[kInputActivationStateTensor] = true; 1037 mutating_input_variables[kInputCellStateTensor] = true; 1038 break; 1039 } 1040 case LstmCellOperator::KERNEL_BASIC: { 1041 mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true; 1042 mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true; 1043 break; 1044 } 1045 } 1046 return mutating_input_variables; 1047 } 1048 }; 1049 1050 class UnidirectionalSequenceLstm 1051 : public BuiltinOperator< 1052 UnidirectionalSequenceLstmOperator, 1053 ::tflite::UnidirectionalSequenceLSTMOptions, 1054 ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> { 1055 public: 1056 using BuiltinOperator::BuiltinOperator; 1057 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1058 const TocoOperator& op, 1059 flatbuffers::FlatBufferBuilder* builder) const override { 1060 // Current toco converter only supports tanh, no clip. 1061 return ::tflite::CreateUnidirectionalSequenceLSTMOptions( 1062 *builder, /*fused_activation_function=*/ 1063 ::tflite::ActivationFunctionType_TANH, 1064 /*cell_clip=*/0.0, 1065 /*proj_clip=*/0.0, 1066 /*time_major=*/true); 1067 } 1068 1069 void ReadOptions(const TfLiteOptions& options, 1070 TocoOperator* op) const override { 1071 // Only support tanh activation, so check that tflite type is tanh. 1072 DCHECK(options.fused_activation_function() == 1073 ::tflite::ActivationFunctionType_TANH); 1074 } 1075 1076 int GetVersion(const OperatorSignature& op_signature) const override { 1077 // If the input tensor is float and a weight is int8, this is a version 1078 // 2 hybrid operation. 1079 const string& input_name = op_signature.op->inputs[0]; 1080 const string& weights_name = op_signature.op->inputs[2]; 1081 const string& output_name = op_signature.op->outputs[0]; 1082 const Array& input_array = op_signature.model->GetArray(input_name); 1083 const Array& weights_array = op_signature.model->GetArray(weights_name); 1084 const Array& output_array = op_signature.model->GetArray(output_name); 1085 if (input_array.data_type == ArrayDataType::kFloat && 1086 weights_array.data_type == ArrayDataType::kInt8 && 1087 output_array.data_type == ArrayDataType::kFloat) { 1088 return 2; 1089 } 1090 return 1; 1091 } 1092 1093 std::vector<bool> GetMutatingInputVariables( 1094 const Operator& op) const override { 1095 std::vector<bool> mutating_input_variables(op.inputs.size(), false); 1096 mutating_input_variables[kInputActivationStateTensor] = true; 1097 mutating_input_variables[kInputCellStateTensor] = true; 1098 return mutating_input_variables; 1099 } 1100 }; 1101 1102 class BidirectionalSequenceLstm 1103 : public BuiltinOperator< 1104 BidirectionalSequenceLstmOperator, 1105 ::tflite::BidirectionalSequenceLSTMOptions, 1106 ::tflite::BuiltinOptions_BidirectionalSequenceLSTMOptions> { 1107 public: 1108 using BuiltinOperator::BuiltinOperator; 1109 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1110 const TocoOperator& op, 1111 flatbuffers::FlatBufferBuilder* builder) const override { 1112 // Current toco converter only supports tanh, no clip. 1113 return ::tflite::CreateBidirectionalSequenceLSTMOptions( 1114 *builder, /*fused_activation_function=*/ 1115 ::tflite::ActivationFunctionType_TANH, 1116 /*cell_clip=*/0.0, 1117 /*proj_clip=*/0.0, 1118 /*merge_outputs=*/op.merge_outputs, 1119 /*time_major=*/true); 1120 } 1121 1122 void ReadOptions(const TfLiteOptions& options, 1123 TocoOperator* op) const override { 1124 // Only support tanh activation, so check that tflite type is tanh. 1125 DCHECK(options.fused_activation_function() == 1126 ::tflite::ActivationFunctionType_TANH); 1127 op->merge_outputs = options.merge_outputs(); 1128 } 1129 1130 int GetVersion(const OperatorSignature& op_signature) const override { 1131 return 1; 1132 } 1133 1134 std::vector<bool> GetMutatingInputVariables( 1135 const Operator& op) const override { 1136 std::vector<bool> mutating_input_variables(op.inputs.size(), false); 1137 // Forward input activation state. 1138 mutating_input_variables[35] = true; 1139 // Forward input cell state. 1140 mutating_input_variables[36] = true; 1141 // Backward input activation state. 1142 mutating_input_variables[37] = true; 1143 // Backward input cell state. 1144 mutating_input_variables[38] = true; 1145 return mutating_input_variables; 1146 } 1147 }; 1148 1149 class BidirectionalSequenceRnn 1150 : public BuiltinOperator< 1151 BidirectionalSequenceRnnOperator, 1152 ::tflite::BidirectionalSequenceRNNOptions, 1153 ::tflite::BuiltinOptions_BidirectionalSequenceRNNOptions> { 1154 public: 1155 using BuiltinOperator::BuiltinOperator; 1156 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1157 const TocoOperator& op, 1158 flatbuffers::FlatBufferBuilder* builder) const override { 1159 // Current toco converter only supports tanh, no clip. 1160 return ::tflite::CreateBidirectionalSequenceRNNOptions( 1161 *builder, /*time_major=*/true, 1162 /*fused_activation_function=*/ 1163 ::tflite::ActivationFunctionType_TANH, 1164 /*merge_outputs=*/op.merge_outputs); 1165 } 1166 1167 void ReadOptions(const TfLiteOptions& options, 1168 TocoOperator* op) const override { 1169 // Only support tanh activation, so check that tflite type is tanh. 1170 DCHECK(options.fused_activation_function() == 1171 ::tflite::ActivationFunctionType_TANH); 1172 op->merge_outputs = options.merge_outputs(); 1173 } 1174 1175 int GetVersion(const OperatorSignature& op_signature) const override { 1176 return 1; 1177 } 1178 1179 std::vector<bool> GetMutatingInputVariables( 1180 const Operator& op) const override { 1181 std::vector<bool> mutating_input_variables(op.inputs.size(), false); 1182 // Forward hidden state. 1183 mutating_input_variables[4] = true; 1184 // Backward hidden state. 1185 mutating_input_variables[8] = true; 1186 return mutating_input_variables; 1187 } 1188 }; 1189 1190 class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions, 1191 ::tflite::BuiltinOptions_ReducerOptions> { 1192 public: 1193 using BuiltinOperator::BuiltinOperator; 1194 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1195 const TocoOperator& op, 1196 flatbuffers::FlatBufferBuilder* builder) const override { 1197 return ::tflite::CreateReducerOptions(*builder, op.keep_dims); 1198 } 1199 1200 void ReadOptions(const TfLiteOptions& options, 1201 TocoOperator* op) const override { 1202 op->keep_dims = options.keep_dims(); 1203 } 1204 1205 int GetVersion(const OperatorSignature& op_signature) const override { 1206 return 1; 1207 } 1208 }; 1209 1210 class Sum 1211 : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions, 1212 ::tflite::BuiltinOptions_ReducerOptions> { 1213 public: 1214 using BuiltinOperator::BuiltinOperator; 1215 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1216 const TocoOperator& op, 1217 flatbuffers::FlatBufferBuilder* builder) const override { 1218 return ::tflite::CreateReducerOptions(*builder, op.keep_dims); 1219 } 1220 1221 void ReadOptions(const TfLiteOptions& options, 1222 TocoOperator* op) const override { 1223 op->keep_dims = options.keep_dims(); 1224 } 1225 1226 int GetVersion(const OperatorSignature& op_signature) const override { 1227 return 1; 1228 } 1229 }; 1230 1231 class ReduceMax 1232 : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions, 1233 ::tflite::BuiltinOptions_ReducerOptions> { 1234 public: 1235 using BuiltinOperator::BuiltinOperator; 1236 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1237 const TocoOperator& op, 1238 flatbuffers::FlatBufferBuilder* builder) const override { 1239 return ::tflite::CreateReducerOptions(*builder, op.keep_dims); 1240 } 1241 1242 void ReadOptions(const TfLiteOptions& options, 1243 TocoOperator* op) const override { 1244 op->keep_dims = options.keep_dims(); 1245 } 1246 1247 int GetVersion(const OperatorSignature& op_signature) const override { 1248 const string& input_name = op_signature.op->inputs[0]; 1249 const Array& input_array = op_signature.model->GetArray(input_name); 1250 // If the op take int8 input, it is version 2. 1251 if (input_array.data_type == ArrayDataType::kInt8) { 1252 return 2; 1253 } 1254 return 1; 1255 } 1256 }; 1257 1258 class ReduceMin 1259 : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions, 1260 ::tflite::BuiltinOptions_ReducerOptions> { 1261 public: 1262 using BuiltinOperator::BuiltinOperator; 1263 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1264 const TocoOperator& op, 1265 flatbuffers::FlatBufferBuilder* builder) const override { 1266 return ::tflite::CreateReducerOptions(*builder, op.keep_dims); 1267 } 1268 1269 void ReadOptions(const TfLiteOptions& options, 1270 TocoOperator* op) const override { 1271 op->keep_dims = options.keep_dims(); 1272 } 1273 1274 int GetVersion(const OperatorSignature& op_signature) const override { 1275 const string& input_name = op_signature.op->inputs[0]; 1276 const Array& input_array = op_signature.model->GetArray(input_name); 1277 // If the op take int8 input, it is version 2. 1278 if (input_array.data_type == ArrayDataType::kInt8) { 1279 return 2; 1280 } 1281 return 1; 1282 } 1283 }; 1284 1285 class ReduceProd 1286 : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions, 1287 ::tflite::BuiltinOptions_ReducerOptions> { 1288 public: 1289 using BuiltinOperator::BuiltinOperator; 1290 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1291 const TocoOperator& op, 1292 flatbuffers::FlatBufferBuilder* builder) const override { 1293 return ::tflite::CreateReducerOptions(*builder, op.keep_dims); 1294 } 1295 1296 void ReadOptions(const TfLiteOptions& options, 1297 TocoOperator* op) const override { 1298 op->keep_dims = options.keep_dims(); 1299 } 1300 1301 int GetVersion(const OperatorSignature& op_signature) const override { 1302 return 1; 1303 } 1304 }; 1305 1306 class ReduceAny 1307 : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions, 1308 ::tflite::BuiltinOptions_ReducerOptions> { 1309 public: 1310 using BuiltinOperator::BuiltinOperator; 1311 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1312 const TocoOperator& op, 1313 flatbuffers::FlatBufferBuilder* builder) const override { 1314 return ::tflite::CreateReducerOptions(*builder, op.keep_dims); 1315 } 1316 1317 void ReadOptions(const TfLiteOptions& options, 1318 TocoOperator* op) const override { 1319 op->keep_dims = options.keep_dims(); 1320 } 1321 1322 int GetVersion(const OperatorSignature& op_signature) const override { 1323 return 1; 1324 } 1325 }; 1326 1327 class Relu6 : public SimpleOperator<Relu6Operator> { 1328 public: 1329 explicit Relu6() : SimpleOperator("RELU6", OperatorType::kRelu6) {} 1330 int GetVersion(const OperatorSignature& op_signature) const override { 1331 const string& input_name = op_signature.op->inputs[0]; 1332 const Array& input_array = op_signature.model->GetArray(input_name); 1333 // Version 2 supports signed int8 input types. 1334 if (input_array.data_type == ArrayDataType::kInt8) { 1335 return 2; 1336 } 1337 return 1; 1338 } 1339 }; 1340 1341 class ResizeBilinear 1342 : public BuiltinOperator<ResizeBilinearOperator, 1343 ::tflite::ResizeBilinearOptions, 1344 ::tflite::BuiltinOptions_ResizeBilinearOptions> { 1345 public: 1346 using BuiltinOperator::BuiltinOperator; 1347 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1348 const TocoOperator& op, 1349 flatbuffers::FlatBufferBuilder* builder) const override { 1350 return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners); 1351 } 1352 1353 void ReadOptions(const TfLiteOptions& options, 1354 TocoOperator* op) const override { 1355 op->align_corners = options.align_corners(); 1356 } 1357 1358 int GetVersion(const OperatorSignature& op_signature) const override { 1359 const string& input_name = op_signature.op->inputs[0]; 1360 const Array& input_array = op_signature.model->GetArray(input_name); 1361 // If the op takes int8 input, it is version 2. 1362 if (input_array.data_type == ArrayDataType::kInt8) { 1363 return 2; 1364 } 1365 return 1; 1366 } 1367 }; 1368 1369 class ResizeNearestNeighbor 1370 : public BuiltinOperator< 1371 ResizeNearestNeighborOperator, ::tflite::ResizeNearestNeighborOptions, 1372 ::tflite::BuiltinOptions_ResizeNearestNeighborOptions> { 1373 public: 1374 using BuiltinOperator::BuiltinOperator; 1375 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1376 const TocoOperator& op, 1377 flatbuffers::FlatBufferBuilder* builder) const override { 1378 return ::tflite::CreateResizeNearestNeighborOptions(*builder, 1379 op.align_corners); 1380 } 1381 1382 void ReadOptions(const TfLiteOptions& options, 1383 TocoOperator* op) const override { 1384 op->align_corners = options.align_corners(); 1385 } 1386 1387 int GetVersion(const OperatorSignature& op_signature) const override { 1388 const string& input_name = op_signature.op->inputs[0]; 1389 const Array& input_array = op_signature.model->GetArray(input_name); 1390 // Version 2 supports signed int8 input types. 1391 if (input_array.data_type == ArrayDataType::kInt8) { 1392 return 2; 1393 } 1394 return 1; 1395 } 1396 }; 1397 1398 class Squeeze 1399 : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions, 1400 ::tflite::BuiltinOptions_SqueezeOptions> { 1401 public: 1402 using BuiltinOperator::BuiltinOperator; 1403 1404 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1405 const TocoOperator& op, 1406 flatbuffers::FlatBufferBuilder* builder) const override { 1407 auto squeeze_dims = builder->CreateVector(op.squeeze_dims); 1408 return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims); 1409 } 1410 1411 void ReadOptions(const TfLiteOptions& options, 1412 TocoOperator* op) const override { 1413 op->squeeze_dims.insert(op->squeeze_dims.end(), 1414 options.squeeze_dims()->begin(), 1415 options.squeeze_dims()->end()); 1416 } 1417 1418 int GetVersion(const OperatorSignature& op_signature) const override { 1419 return 1; 1420 } 1421 }; 1422 1423 class Split 1424 : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions, 1425 ::tflite::BuiltinOptions_SplitOptions> { 1426 public: 1427 using BuiltinOperator::BuiltinOperator; 1428 1429 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1430 const TocoOperator& op, 1431 flatbuffers::FlatBufferBuilder* builder) const override { 1432 return ::tflite::CreateSplitOptions(*builder, op.num_split); 1433 } 1434 1435 void ReadOptions(const TfLiteOptions& options, 1436 TocoOperator* op) const override { 1437 op->num_split = options.num_splits(); 1438 } 1439 1440 int GetVersion(const OperatorSignature& op_signature) const override { 1441 const string& input_name = op_signature.op->inputs[0]; 1442 const Array& input_array = op_signature.model->GetArray(input_name); 1443 // If the op take int8 input, it is version 2, for int32 it's version 3. 1444 if (input_array.data_type == ArrayDataType::kInt8) { 1445 return 2; 1446 } else if (input_array.data_type == ArrayDataType::kInt32) { 1447 return 3; 1448 } 1449 return 1; 1450 } 1451 }; 1452 1453 class SplitV 1454 : public BuiltinOperator<TensorFlowSplitVOperator, ::tflite::SplitVOptions, 1455 ::tflite::BuiltinOptions_SplitVOptions> { 1456 public: 1457 using BuiltinOperator::BuiltinOperator; 1458 1459 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1460 const TocoOperator& op, 1461 flatbuffers::FlatBufferBuilder* builder) const override { 1462 return ::tflite::CreateSplitVOptions(*builder, op.num_split); 1463 } 1464 1465 void ReadOptions(const TfLiteOptions& options, 1466 TocoOperator* op) const override { 1467 op->num_split = options.num_splits(); 1468 } 1469 1470 int GetVersion(const OperatorSignature& op_signature) const override { 1471 return 1; 1472 } 1473 }; 1474 1475 class StridedSlice 1476 : public BuiltinOperator<StridedSliceOperator, 1477 ::tflite::StridedSliceOptions, 1478 ::tflite::BuiltinOptions_StridedSliceOptions> { 1479 public: 1480 using BuiltinOperator::BuiltinOperator; 1481 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1482 const TocoOperator& op, 1483 flatbuffers::FlatBufferBuilder* builder) const override { 1484 return ::tflite::CreateStridedSliceOptions( 1485 *builder, op.begin_mask, op.end_mask, op.ellipsis_mask, 1486 op.new_axis_mask, op.shrink_axis_mask); 1487 } 1488 1489 void ReadOptions(const TfLiteOptions& options, 1490 TocoOperator* op) const override { 1491 op->begin_mask = options.begin_mask(); 1492 op->end_mask = options.end_mask(); 1493 op->ellipsis_mask = options.ellipsis_mask(); 1494 op->new_axis_mask = options.new_axis_mask(); 1495 op->shrink_axis_mask = options.shrink_axis_mask(); 1496 } 1497 1498 int GetVersion(const OperatorSignature& op_signature) const override { 1499 const string& input_name = op_signature.op->inputs[0]; 1500 const Array& input_array = op_signature.model->GetArray(input_name); 1501 // If the op take int8 input, it is version 2. 1502 if (input_array.data_type == ArrayDataType::kInt8) { 1503 return 2; 1504 } 1505 return 1; 1506 } 1507 }; 1508 1509 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options, 1510 ::tflite::BuiltinOptions_TopKV2Options> { 1511 public: 1512 using BuiltinOperator::BuiltinOperator; 1513 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1514 const TocoOperator& op, 1515 flatbuffers::FlatBufferBuilder* builder) const override { 1516 return ::tflite::CreateTopKV2Options(*builder); 1517 } 1518 1519 void ReadOptions(const TfLiteOptions& options, 1520 TocoOperator* op) const override {} 1521 1522 int GetVersion(const OperatorSignature& op_signature) const override { 1523 const string& input_name = op_signature.op->inputs[0]; 1524 const Array& input_array = op_signature.model->GetArray(input_name); 1525 if (input_array.data_type == ArrayDataType::kInt8) { 1526 return 2; 1527 } 1528 return 1; 1529 } 1530 }; 1531 1532 class ArgMax : public BuiltinOperator<ArgMaxOperator, ::tflite::ArgMaxOptions, 1533 ::tflite::BuiltinOptions_ArgMaxOptions> { 1534 public: 1535 using BuiltinOperator::BuiltinOperator; 1536 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1537 const TocoOperator& op, 1538 flatbuffers::FlatBufferBuilder* builder) const override { 1539 return ::tflite::CreateArgMaxOptions( 1540 *builder, DataType::Serialize(op.output_data_type)); 1541 } 1542 1543 void ReadOptions(const TfLiteOptions& options, 1544 TocoOperator* op) const override { 1545 op->output_data_type = DataType::Deserialize(options.output_type()); 1546 } 1547 1548 int GetVersion(const OperatorSignature& op_signature) const override { 1549 const string& input_name = op_signature.op->inputs[0]; 1550 const Array& input_array = op_signature.model->GetArray(input_name); 1551 if (input_array.data_type == ArrayDataType::kInt8) { 1552 return 2; 1553 } 1554 1555 return 1; 1556 } 1557 }; 1558 1559 class ArgMin : public BuiltinOperator<ArgMinOperator, ::tflite::ArgMinOptions, 1560 ::tflite::BuiltinOptions_ArgMinOptions> { 1561 public: 1562 using BuiltinOperator::BuiltinOperator; 1563 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1564 const TocoOperator& op, 1565 flatbuffers::FlatBufferBuilder* builder) const override { 1566 return ::tflite::CreateArgMinOptions( 1567 *builder, DataType::Serialize(op.output_data_type)); 1568 } 1569 1570 void ReadOptions(const TfLiteOptions& options, 1571 TocoOperator* op) const override { 1572 op->output_data_type = DataType::Deserialize(options.output_type()); 1573 } 1574 1575 int GetVersion(const OperatorSignature& op_signature) const override { 1576 const string& input_name = op_signature.op->inputs[0]; 1577 const Array& input_array = op_signature.model->GetArray(input_name); 1578 if (input_array.data_type == ArrayDataType::kInt8) { 1579 return 2; 1580 } 1581 1582 return 1; 1583 } 1584 }; 1585 1586 class TransposeConv 1587 : public BuiltinOperator<TransposeConvOperator, 1588 ::tflite::TransposeConvOptions, 1589 ::tflite::BuiltinOptions_TransposeConvOptions> { 1590 public: 1591 using BuiltinOperator::BuiltinOperator; 1592 1593 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1594 const TocoOperator& op, 1595 flatbuffers::FlatBufferBuilder* builder) const override { 1596 auto padding = Padding::Serialize(op.padding.type); 1597 return ::tflite::CreateTransposeConvOptions( 1598 *builder, padding, op.stride_width, op.stride_height); 1599 } 1600 1601 void ReadOptions(const TfLiteOptions& options, 1602 TocoOperator* op) const override { 1603 op->padding.type = Padding::Deserialize(options.padding()); 1604 op->stride_width = options.stride_w(); 1605 op->stride_height = options.stride_h(); 1606 } 1607 1608 int GetVersion(const OperatorSignature& op_signature) const override { 1609 return 1; 1610 } 1611 }; 1612 1613 class SparseToDense 1614 : public BuiltinOperator<SparseToDenseOperator, 1615 ::tflite::SparseToDenseOptions, 1616 ::tflite::BuiltinOptions_SparseToDenseOptions> { 1617 public: 1618 using BuiltinOperator::BuiltinOperator; 1619 1620 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1621 const TocoOperator& op, 1622 flatbuffers::FlatBufferBuilder* builder) const override { 1623 return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices); 1624 } 1625 1626 void ReadOptions(const TfLiteOptions& options, 1627 TocoOperator* op) const override { 1628 op->validate_indices = options.validate_indices(); 1629 } 1630 1631 int GetVersion(const OperatorSignature& op_signature) const override { 1632 return 1; 1633 } 1634 }; 1635 1636 class ExpandDims 1637 : public BuiltinOperator<ExpandDimsOperator, ::tflite::ExpandDimsOptions, 1638 ::tflite::BuiltinOptions_ExpandDimsOptions> { 1639 public: 1640 using BuiltinOperator::BuiltinOperator; 1641 1642 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1643 const TocoOperator& op, 1644 flatbuffers::FlatBufferBuilder* builder) const override { 1645 return ::tflite::CreateExpandDimsOptions(*builder); 1646 } 1647 1648 void ReadOptions(const TfLiteOptions& options, 1649 TocoOperator* op) const override {} 1650 1651 int GetVersion(const OperatorSignature& op_signature) const override { 1652 return 1; 1653 } 1654 }; 1655 1656 class Pack : public BuiltinOperator<PackOperator, ::tflite::PackOptions, 1657 ::tflite::BuiltinOptions_PackOptions> { 1658 public: 1659 using BuiltinOperator::BuiltinOperator; 1660 1661 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1662 const TocoOperator& op, 1663 flatbuffers::FlatBufferBuilder* builder) const override { 1664 return ::tflite::CreatePackOptions(*builder, op.values_count, op.axis); 1665 } 1666 1667 void ReadOptions(const TfLiteOptions& options, 1668 TocoOperator* op) const override { 1669 op->values_count = options.values_count(); 1670 op->axis = options.axis(); 1671 } 1672 1673 int GetVersion(const OperatorSignature& op_signature) const override { 1674 const string& input_name = op_signature.op->inputs[0]; 1675 const Array& input_array = op_signature.model->GetArray(input_name); 1676 // If the op take int8 input, it is version 2. 1677 if (input_array.data_type == ArrayDataType::kInt8) { 1678 return 2; 1679 } 1680 return 1; 1681 } 1682 }; 1683 1684 class Shape 1685 : public BuiltinOperator<TensorFlowShapeOperator, ::tflite::ShapeOptions, 1686 ::tflite::BuiltinOptions_ShapeOptions> { 1687 public: 1688 using BuiltinOperator::BuiltinOperator; 1689 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1690 const TocoOperator& op, 1691 flatbuffers::FlatBufferBuilder* builder) const override { 1692 return ::tflite::CreateShapeOptions( 1693 *builder, DataType::Serialize(op.output_data_type)); 1694 } 1695 1696 void ReadOptions(const TfLiteOptions& options, 1697 TocoOperator* op) const override { 1698 op->output_data_type = DataType::Deserialize(options.out_type()); 1699 } 1700 1701 int GetVersion(const OperatorSignature& op_signature) const override { 1702 return 1; 1703 } 1704 }; 1705 1706 class Slice : public SimpleOperator<SliceOperator> { 1707 public: 1708 explicit Slice() : SimpleOperator("SLICE", OperatorType::kSlice) {} 1709 int GetVersion(const OperatorSignature& op_signature) const override { 1710 const string& input_name = op_signature.op->inputs[0]; 1711 const Array& input_array = op_signature.model->GetArray(input_name); 1712 // Version 2 supports signed int8 input types. 1713 if (input_array.data_type == ArrayDataType::kInt8) { 1714 return 2; 1715 } 1716 return 1; 1717 } 1718 }; 1719 1720 class Tanh : public SimpleOperator<TanhOperator> { 1721 public: 1722 explicit Tanh() : SimpleOperator("TANH", OperatorType::kTanh) {} 1723 int GetVersion(const OperatorSignature& op_signature) const override { 1724 const string& input_name = op_signature.op->inputs[0]; 1725 const Array& input_array = op_signature.model->GetArray(input_name); 1726 // Version 2 supports signed int8 input types. 1727 if (input_array.data_type == ArrayDataType::kInt8) { 1728 return 2; 1729 } 1730 return 1; 1731 } 1732 }; 1733 1734 class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions, 1735 ::tflite::BuiltinOptions_OneHotOptions> { 1736 public: 1737 using BuiltinOperator::BuiltinOperator; 1738 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1739 const TocoOperator& op, 1740 flatbuffers::FlatBufferBuilder* builder) const override { 1741 return ::tflite::CreateOneHotOptions(*builder, op.axis); 1742 } 1743 void ReadOptions(const TfLiteOptions& options, 1744 TocoOperator* op) const override { 1745 op->axis = options.axis(); 1746 } 1747 1748 int GetVersion(const OperatorSignature& op_signature) const override { 1749 return 1; 1750 } 1751 }; 1752 1753 class CTCBeamSearchDecoder 1754 : public CustomOperator<CTCBeamSearchDecoderOperator> { 1755 public: 1756 using CustomOperator::CustomOperator; 1757 1758 void WriteOptions(const TocoOperator& op, 1759 flexbuffers::Builder* fbb) const override { 1760 fbb->Int("beam_width", op.beam_width); 1761 fbb->Int("top_paths", op.top_paths); 1762 fbb->Bool("merge_repeated", op.merge_repeated); 1763 } 1764 1765 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { 1766 op->beam_width = m["beam_width"].AsInt32(); 1767 op->top_paths = m["top_paths"].AsInt32(); 1768 op->merge_repeated = m["merge_repeated"].AsBool(); 1769 } 1770 1771 int GetVersion(const OperatorSignature& op_signature) const override { 1772 return 1; 1773 } 1774 }; 1775 1776 class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions, 1777 ::tflite::BuiltinOptions_UnpackOptions> { 1778 public: 1779 using BuiltinOperator::BuiltinOperator; 1780 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1781 const TocoOperator& op, 1782 flatbuffers::FlatBufferBuilder* builder) const override { 1783 return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis); 1784 } 1785 void ReadOptions(const TfLiteOptions& options, 1786 TocoOperator* op) const override { 1787 op->num = options.num(); 1788 op->axis = options.axis(); 1789 } 1790 1791 int GetVersion(const OperatorSignature& op_signature) const override { 1792 return 1; 1793 } 1794 }; 1795 1796 class LeakyRelu 1797 : public BuiltinOperator<LeakyReluOperator, ::tflite::LeakyReluOptions, 1798 ::tflite::BuiltinOptions_LeakyReluOptions> { 1799 public: 1800 using BuiltinOperator::BuiltinOperator; 1801 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1802 const TocoOperator& op, 1803 flatbuffers::FlatBufferBuilder* builder) const override { 1804 return ::tflite::CreateLeakyReluOptions(*builder, op.alpha); 1805 } 1806 void ReadOptions(const TfLiteOptions& options, 1807 TocoOperator* op) const override { 1808 op->alpha = options.alpha(); 1809 } 1810 1811 int GetVersion(const OperatorSignature& op_signature) const override { 1812 return 1; 1813 } 1814 }; 1815 1816 class Logistic : public SimpleOperator<LogisticOperator> { 1817 public: 1818 explicit Logistic() : SimpleOperator("LOGISTIC", OperatorType::kLogistic) {} 1819 int GetVersion(const OperatorSignature& op_signature) const override { 1820 const string& input_name = op_signature.op->inputs[0]; 1821 const Array& input_array = op_signature.model->GetArray(input_name); 1822 // Version 2 supports signed int8 input types. 1823 if (input_array.data_type == ArrayDataType::kInt8) { 1824 return 2; 1825 } 1826 return 1; 1827 } 1828 }; 1829 1830 class LogSoftmax : public SimpleOperator<LogSoftmaxOperator> { 1831 public: 1832 explicit LogSoftmax() 1833 : SimpleOperator("LOG_SOFTMAX", OperatorType::kLogSoftmax) {} 1834 int GetVersion(const OperatorSignature& op_signature) const override { 1835 const string& input_name = op_signature.op->inputs[0]; 1836 const Array& input_array = op_signature.model->GetArray(input_name); 1837 // Version 2 supports signed int8 input types. 1838 if (input_array.data_type == ArrayDataType::kInt8) { 1839 return 2; 1840 } 1841 return 1; 1842 } 1843 }; 1844 1845 class SquaredDifference 1846 : public BuiltinOperator< 1847 SquaredDifferenceOperator, ::tflite::SquaredDifferenceOptions, 1848 ::tflite::BuiltinOptions_SquaredDifferenceOptions> { 1849 public: 1850 using BuiltinOperator::BuiltinOperator; 1851 1852 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1853 const TocoOperator& op, 1854 flatbuffers::FlatBufferBuilder* builder) const override { 1855 return ::tflite::CreateSquaredDifferenceOptions(*builder); 1856 } 1857 1858 void ReadOptions(const TfLiteOptions& options, 1859 TocoOperator* op) const override {} 1860 1861 int GetVersion(const OperatorSignature& op_signature) const override { 1862 return 1; 1863 } 1864 }; 1865 1866 class MirrorPad 1867 : public BuiltinOperator<MirrorPadOperator, ::tflite::MirrorPadOptions, 1868 ::tflite::BuiltinOptions_MirrorPadOptions> { 1869 public: 1870 using BuiltinOperator::BuiltinOperator; 1871 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1872 const TocoOperator& op, 1873 flatbuffers::FlatBufferBuilder* builder) const override { 1874 return ::tflite::CreateMirrorPadOptions( 1875 *builder, op.mode == MirrorPadMode::kReflect 1876 ? ::tflite::MirrorPadMode::MirrorPadMode_REFLECT 1877 : ::tflite::MirrorPadMode::MirrorPadMode_SYMMETRIC); 1878 } 1879 void ReadOptions(const TfLiteOptions& options, 1880 TocoOperator* op) const override { 1881 op->mode = options.mode() == ::tflite::MirrorPadMode::MirrorPadMode_REFLECT 1882 ? MirrorPadMode::kReflect 1883 : MirrorPadMode::kSymmetric; 1884 } 1885 1886 int GetVersion(const OperatorSignature& op) const override { return 1; } 1887 }; 1888 1889 class Unique : public BuiltinOperator<UniqueOperator, ::tflite::UniqueOptions, 1890 ::tflite::BuiltinOptions_UniqueOptions> { 1891 public: 1892 using BuiltinOperator::BuiltinOperator; 1893 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1894 const TocoOperator& op, 1895 flatbuffers::FlatBufferBuilder* builder) const override { 1896 const UniqueOperator& unique_op = static_cast<const UniqueOperator&>(op); 1897 return ::tflite::CreateUniqueOptions( 1898 *builder, unique_op.idx_out_type == toco::ArrayDataType::kInt64 1899 ? ::tflite::TensorType::TensorType_INT64 1900 : ::tflite::TensorType_INT32); 1901 } 1902 void ReadOptions(const TfLiteOptions& options, 1903 TocoOperator* op) const override { 1904 UniqueOperator* unique_op = static_cast<UniqueOperator*>(op); 1905 unique_op->idx_out_type = 1906 options.idx_out_type() == ::tflite::TensorType_INT64 1907 ? toco::ArrayDataType::kInt64 1908 : toco::ArrayDataType::kInt32; 1909 } 1910 1911 int GetVersion(const OperatorSignature& op_signature) const override { 1912 return 1; 1913 } 1914 }; 1915 1916 class UnidirectionalSequenceRnn 1917 : public BuiltinOperator<UnidirectionalSequenceRnnOperator, 1918 ::tflite::SequenceRNNOptions, 1919 ::tflite::BuiltinOptions_SequenceRNNOptions> { 1920 public: 1921 using BuiltinOperator::BuiltinOperator; 1922 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1923 const TocoOperator& op, 1924 flatbuffers::FlatBufferBuilder* builder) const override { 1925 return ::tflite::CreateSequenceRNNOptions( 1926 *builder, /*time_major=*/true, 1927 /*fused_activation_function=*/ 1928 ::tflite::ActivationFunctionType_TANH); 1929 } 1930 void ReadOptions(const TfLiteOptions& options, 1931 TocoOperator* op) const override { 1932 // Only support tanh activation, so check that tflite type is tanh. 1933 DCHECK(options.fused_activation_function() == 1934 ::tflite::ActivationFunctionType_TANH); 1935 } 1936 1937 int GetVersion(const OperatorSignature& op_signature) const override { 1938 return 1; 1939 } 1940 1941 std::vector<bool> GetMutatingInputVariables( 1942 const Operator& op) const override { 1943 std::vector<bool> mutating_input_variables(op.inputs.size(), false); 1944 mutating_input_variables[4] = true; 1945 return mutating_input_variables; 1946 } 1947 }; 1948 1949 class Where : public BuiltinOperator<WhereOperator, ::tflite::WhereOptions, 1950 ::tflite::BuiltinOptions_WhereOptions> { 1951 public: 1952 using BuiltinOperator::BuiltinOperator; 1953 1954 flatbuffers::Offset<TfLiteOptions> WriteOptions( 1955 const TocoOperator& op, 1956 flatbuffers::FlatBufferBuilder* builder) const override { 1957 return ::tflite::CreateWhereOptions(*builder); 1958 } 1959 1960 void ReadOptions(const TfLiteOptions& options, 1961 TocoOperator* op) const override {} 1962 1963 int GetVersion(const OperatorSignature& op_signature) const override { 1964 return 1; 1965 } 1966 }; 1967 1968 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions( 1969 const string& tensorflow_node_def) { 1970 auto fbb = absl::make_unique<flexbuffers::Builder>(); 1971 1972 ::tensorflow::NodeDef node_def; 1973 if (!node_def.ParseFromString(tensorflow_node_def)) { 1974 LOG(ERROR) << "Failed to parse TensorFlow NodeDef"; 1975 return {}; 1976 } 1977 1978 fbb->Vector([&]() { 1979 fbb->String(node_def.op()); 1980 fbb->String(tensorflow_node_def); 1981 }); 1982 fbb->Finish(); 1983 LOG(INFO) << "Writing flex op: " << node_def.op(); 1984 return std::unique_ptr<flexbuffers::Builder>(fbb.release()); 1985 } 1986 1987 class TensorFlowUnsupported : public BaseOperator { 1988 public: 1989 TensorFlowUnsupported(const string& name, OperatorType type, 1990 bool enable_select_tf_ops) 1991 : BaseOperator(name, type), enable_select_tf_ops_(enable_select_tf_ops) {} 1992 1993 Options Serialize(const Operator& op, 1994 flatbuffers::FlatBufferBuilder* builder) const override { 1995 auto fbb = 1996 WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op)); 1997 if (fbb) { 1998 return Options::Custom(builder->CreateVector(fbb->GetBuffer())); 1999 } else { 2000 return Options::Custom(0); 2001 } 2002 } 2003 2004 std::unique_ptr<Operator> Deserialize( 2005 const BuiltinOptions* builtin_options, 2006 const CustomOptions* custom_options) const override { 2007 // Deserializing Flex ops doesn't work now. 2008 // TODO(ycling): Revisit and decide if we should fix the flow for importing 2009 // TFLite models with Flex ops. 2010 auto op = absl::make_unique<TensorFlowUnsupportedOperator>(); 2011 if (custom_options) { 2012 auto flexbuffer_map = 2013 flexbuffers::GetRoot(custom_options->data(), custom_options->size()) 2014 .AsMap(); 2015 ReadOptions(flexbuffer_map, op.get()); 2016 } 2017 return std::unique_ptr<Operator>(op.release()); 2018 } 2019 2020 std::unique_ptr<flexbuffers::Builder> WriteOptions( 2021 const TensorFlowUnsupportedOperator& op) const { 2022 if (enable_select_tf_ops_) { 2023 return WriteFlexOpOptions(op.tensorflow_node_def); 2024 } 2025 auto fbb = absl::make_unique<flexbuffers::Builder>(); 2026 2027 ::tensorflow::NodeDef node_def; 2028 if (!node_def.ParseFromString(op.tensorflow_node_def)) { 2029 LOG(ERROR) << "Failed to parse TensorFlow NodeDef"; 2030 return std::unique_ptr<flexbuffers::Builder>(); 2031 } 2032 2033 if (ShouldExportAsFlexOp(enable_select_tf_ops_, node_def.op())) { 2034 fbb->Vector([&]() { 2035 fbb->String(node_def.op()); 2036 fbb->String(op.tensorflow_node_def); 2037 }); 2038 fbb->Finish(); 2039 LOG(INFO) << "Writing flex op: " << node_def.op(); 2040 return std::unique_ptr<flexbuffers::Builder>(fbb.release()); 2041 } 2042 2043 bool has_valid_attr = false; 2044 size_t map_start = fbb->StartMap(); 2045 for (const auto& pair : node_def.attr()) { 2046 const char* key = pair.first.c_str(); 2047 const auto& attr = pair.second; 2048 switch (attr.value_case()) { 2049 case ::tensorflow::AttrValue::kS: 2050 fbb->String(key, attr.s()); 2051 has_valid_attr = true; 2052 break; 2053 case ::tensorflow::AttrValue::kI: 2054 fbb->Int(key, attr.i()); 2055 has_valid_attr = true; 2056 break; 2057 case ::tensorflow::AttrValue::kF: 2058 fbb->Float(key, attr.f()); 2059 has_valid_attr = true; 2060 break; 2061 case ::tensorflow::AttrValue::kB: 2062 fbb->Bool(key, attr.b()); 2063 has_valid_attr = true; 2064 break; 2065 case tensorflow::AttrValue::kList: 2066 if (attr.list().s_size() > 0) { 2067 auto start = fbb->StartVector(key); 2068 for (const string& v : attr.list().s()) { 2069 fbb->Add(v); 2070 } 2071 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); 2072 has_valid_attr = true; 2073 } else if (attr.list().i_size() > 0) { 2074 auto start = fbb->StartVector(key); 2075 for (const int64_t v : attr.list().i()) { 2076 fbb->Add(v); 2077 } 2078 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); 2079 has_valid_attr = true; 2080 } else if (attr.list().f_size() > 0) { 2081 auto start = fbb->StartVector(key); 2082 for (const float v : attr.list().f()) { 2083 fbb->Add(v); 2084 } 2085 fbb->EndVector(start, /*typed=*/true, /*fixed=*/false); 2086 has_valid_attr = true; 2087 } else { 2088 LOG(WARNING) 2089 << "Ignoring unsupported type in list attribute with key '" 2090 << key << "'"; 2091 } 2092 break; 2093 default: 2094 LOG(WARNING) << "Ignoring unsupported attribute type with key '" 2095 << key << "'"; 2096 break; 2097 } 2098 } 2099 if (!has_valid_attr) { 2100 return std::unique_ptr<flexbuffers::Builder>(); 2101 } 2102 fbb->EndMap(map_start); 2103 fbb->Finish(); 2104 return std::unique_ptr<flexbuffers::Builder>(fbb.release()); 2105 } 2106 2107 void ReadOptions(const flexbuffers::Map& m, 2108 TensorFlowUnsupportedOperator* op) const { 2109 ::tensorflow::NodeDef node_def; 2110 auto attr = node_def.mutable_attr(); 2111 2112 const auto& keys = m.Keys(); 2113 for (size_t i = 0; i < keys.size(); ++i) { 2114 const auto key = keys[i].AsKey(); 2115 const auto& value = m[key]; 2116 // TODO(wvo): hack to make this code compile with 2 different API 2117 // versions. 2118 // Please remove once OS/internal versions are in sync. 2119 // See hardcoded values in the switch below. 2120 switch (value.GetType()) { 2121 case 5: // flexbuffers::FBT_STRING: 2122 (*attr)[key].set_s(value.AsString().c_str()); 2123 break; 2124 case 1: // flexbuffers::FBT_INT: 2125 (*attr)[key].set_i(value.AsInt64()); 2126 break; 2127 case 3: // flexbuffers::FBT_FLOAT: 2128 (*attr)[key].set_f(value.AsFloat()); 2129 break; 2130 case 26: // flexbuffers::FBT_BOOL: 2131 (*attr)[key].set_b(value.AsBool()); 2132 if (string(key) == "_output_quantized") { 2133 op->quantized = value.AsBool(); 2134 } 2135 if (string(key) == "_support_output_type_float_in_quantized_op") { 2136 op->support_output_type_float_in_quantized_op = value.AsBool(); 2137 } 2138 break; 2139 case 11: { // flexbuffers::FBT_VECTOR_INT: { 2140 auto* list = (*attr)[key].mutable_list(); 2141 const auto& vector = value.AsTypedVector(); 2142 for (size_t i = 0; i < vector.size(); i++) { 2143 list->add_i(vector[i].AsInt64()); 2144 } 2145 break; 2146 } 2147 case 13: { // flexbuffers::FBT_VECTOR_FLOAT: { 2148 auto* list = (*attr)[key].mutable_list(); 2149 const auto& vector = value.AsTypedVector(); 2150 for (size_t i = 0; i < vector.size(); i++) { 2151 list->add_f(vector[i].AsFloat()); 2152 } 2153 break; 2154 } 2155 case 15: { // flexbuffers::FBT_VECTOR_STRING: { 2156 auto* list = (*attr)[key].mutable_list(); 2157 const auto& vector = value.AsTypedVector(); 2158 for (size_t i = 0; i < vector.size(); i++) { 2159 list->add_s(vector[i].AsString().str()); 2160 } 2161 break; 2162 } 2163 default: 2164 LOG(WARNING) << "Ignoring unsupported attribute type with key '" 2165 << key << "'"; 2166 break; 2167 } 2168 } 2169 node_def.SerializeToString(&op->tensorflow_node_def); 2170 } 2171 2172 int GetVersion(const OperatorSignature& op_signature) const override { 2173 // TODO(ycling): Design and implement a way to plumb the version of 2174 // custom ops. 2175 return 1; 2176 } 2177 2178 private: 2179 const bool enable_select_tf_ops_; 2180 }; 2181 2182 class Dequantize 2183 : public BuiltinOperator<DequantizeOperator, ::tflite::DequantizeOptions, 2184 ::tflite::BuiltinOptions_DequantizeOptions> { 2185 public: 2186 using BuiltinOperator::BuiltinOperator; 2187 2188 flatbuffers::Offset<TfLiteOptions> WriteOptions( 2189 const TocoOperator& op, 2190 flatbuffers::FlatBufferBuilder* builder) const override { 2191 return ::tflite::CreateDequantizeOptions(*builder); 2192 } 2193 2194 void ReadOptions(const TfLiteOptions& options, 2195 TocoOperator* op) const override {} 2196 2197 int GetVersion(const OperatorSignature& op_signature) const override { 2198 const string& input_name = op_signature.op->inputs[0]; 2199 const Array& input_array = op_signature.model->GetArray(input_name); 2200 // Version 2 supports signed int8 input types. 2201 if (input_array.data_type == ArrayDataType::kInt8) { 2202 return 2; 2203 } 2204 return 1; 2205 } 2206 }; 2207 2208 class ReverseSequence 2209 : public BuiltinOperator<ReverseSequenceOperator, 2210 ::tflite::ReverseSequenceOptions, 2211 ::tflite::BuiltinOptions_ReverseSequenceOptions> { 2212 public: 2213 using BuiltinOperator::BuiltinOperator; 2214 2215 flatbuffers::Offset<TfLiteOptions> WriteOptions( 2216 const TocoOperator& op, 2217 flatbuffers::FlatBufferBuilder* builder) const override { 2218 return ::tflite::CreateReverseSequenceOptions(*builder, op.seq_dim, 2219 op.batch_dim); 2220 } 2221 2222 void ReadOptions(const TfLiteOptions& options, 2223 TocoOperator* op) const override { 2224 op->seq_dim = options.seq_dim(); 2225 op->batch_dim = options.batch_dim(); 2226 } 2227 2228 int GetVersion(const OperatorSignature& op_signature) const override { 2229 return 1; 2230 } 2231 }; 2232 2233 class Equal : public SimpleOperator<TensorFlowEqualOperator> { 2234 public: 2235 explicit Equal() : SimpleOperator("EQUAL", OperatorType::kEqual) {} 2236 int GetVersion(const OperatorSignature& op_signature) const override { 2237 const string& input_name = op_signature.op->inputs[0]; 2238 const Array& input_array = op_signature.model->GetArray(input_name); 2239 // Version 2 supports signed int8 input types. 2240 if (input_array.data_type == ArrayDataType::kInt8) { 2241 return 2; 2242 } 2243 return 1; 2244 } 2245 }; 2246 2247 class NotEqual : public SimpleOperator<TensorFlowNotEqualOperator> { 2248 public: 2249 explicit NotEqual() : SimpleOperator("NOT_EQUAL", OperatorType::kNotEqual) {} 2250 int GetVersion(const OperatorSignature& op_signature) const override { 2251 const string& input_name = op_signature.op->inputs[0]; 2252 const Array& input_array = op_signature.model->GetArray(input_name); 2253 // Version 2 supports signed int8 input types. 2254 if (input_array.data_type == ArrayDataType::kInt8) { 2255 return 2; 2256 } 2257 return 1; 2258 } 2259 }; 2260 2261 class Greater : public SimpleOperator<TensorFlowGreaterOperator> { 2262 public: 2263 explicit Greater() : SimpleOperator("GREATER", OperatorType::kGreater) {} 2264 int GetVersion(const OperatorSignature& op_signature) const override { 2265 const string& input_name = op_signature.op->inputs[0]; 2266 const Array& input_array = op_signature.model->GetArray(input_name); 2267 // Version 2 supports signed int8 input types. 2268 if (input_array.data_type == ArrayDataType::kInt8) { 2269 return 2; 2270 } 2271 return 1; 2272 } 2273 }; 2274 2275 class GreaterEqual : public SimpleOperator<TensorFlowGreaterEqualOperator> { 2276 public: 2277 explicit GreaterEqual() 2278 : SimpleOperator("GREATER_EQUAL", OperatorType::kGreaterEqual) {} 2279 int GetVersion(const OperatorSignature& op_signature) const override { 2280 const string& input_name = op_signature.op->inputs[0]; 2281 const Array& input_array = op_signature.model->GetArray(input_name); 2282 // Version 2 supports signed int8 input types. 2283 if (input_array.data_type == ArrayDataType::kInt8) { 2284 return 2; 2285 } 2286 return 1; 2287 } 2288 }; 2289 2290 class Less : public SimpleOperator<TensorFlowLessOperator> { 2291 public: 2292 explicit Less() : SimpleOperator("LESS", OperatorType::kLess) {} 2293 int GetVersion(const OperatorSignature& op_signature) const override { 2294 const string& input_name = op_signature.op->inputs[0]; 2295 const Array& input_array = op_signature.model->GetArray(input_name); 2296 // Version 2 supports signed int8 input types. 2297 if (input_array.data_type == ArrayDataType::kInt8) { 2298 return 2; 2299 } 2300 return 1; 2301 } 2302 }; 2303 2304 class LessEqual : public SimpleOperator<TensorFlowLessEqualOperator> { 2305 public: 2306 explicit LessEqual() 2307 : SimpleOperator("LESS_EQUAL", OperatorType::kLessEqual) {} 2308 int GetVersion(const OperatorSignature& op_signature) const override { 2309 const string& input_name = op_signature.op->inputs[0]; 2310 const Array& input_array = op_signature.model->GetArray(input_name); 2311 // Version 2 supports signed int8 input types. 2312 if (input_array.data_type == ArrayDataType::kInt8) { 2313 return 2; 2314 } 2315 return 1; 2316 } 2317 }; 2318 2319 class Select : public SimpleOperator<SelectOperator> { 2320 public: 2321 explicit Select() : SimpleOperator("SELECT", OperatorType::kSelect) {} 2322 int GetVersion(const OperatorSignature& op_signature) const override { 2323 const string& input_name = op_signature.op->inputs[0]; 2324 const Array& input_array = op_signature.model->GetArray(input_name); 2325 // Version 2 supports signed int8 input types. 2326 if (input_array.data_type == ArrayDataType::kInt8) { 2327 return 2; 2328 } 2329 return 1; 2330 } 2331 }; 2332 2333 namespace { 2334 // Build a vector containing all the known operators. 2335 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList( 2336 bool enable_select_tf_ops = false) { 2337 std::vector<std::unique_ptr<BaseOperator>> ops; 2338 using tensorflow::MakeUnique; 2339 // Builtin Operators. 2340 ops.push_back( 2341 MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd)); 2342 ops.push_back( 2343 MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN)); 2344 ops.push_back( 2345 MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv)); 2346 ops.push_back( 2347 MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub)); 2348 ops.push_back(MakeUnique<AveragePool>( 2349 ::tflite::BuiltinOperator_AVERAGE_POOL_2D, OperatorType::kAveragePool)); 2350 ops.push_back( 2351 MakeUnique<SpaceToBatchND>(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND, 2352 OperatorType::kSpaceToBatchND)); 2353 ops.push_back( 2354 MakeUnique<BatchToSpaceND>(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND, 2355 OperatorType::kBatchToSpaceND)); 2356 ops.push_back(MakeUnique<Concatenation>( 2357 ::tflite::BuiltinOperator_CONCATENATION, OperatorType::kConcatenation)); 2358 ops.push_back(MakeUnique<Convolution>(::tflite::BuiltinOperator_CONV_2D, 2359 OperatorType::kConv)); 2360 ops.push_back(MakeUnique<DepthwiseConvolution>( 2361 ::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, 2362 OperatorType::kDepthwiseConv)); 2363 ops.push_back(MakeUnique<Dequantize>(::tflite::BuiltinOperator_DEQUANTIZE, 2364 OperatorType::kDequantize)); 2365 ops.push_back( 2366 MakeUnique<FullyConnected>(::tflite::BuiltinOperator_FULLY_CONNECTED, 2367 OperatorType::kFullyConnected)); 2368 ops.push_back(MakeUnique<Gather>(::tflite::BuiltinOperator_GATHER, 2369 OperatorType::kGather)); 2370 ops.push_back(MakeUnique<GatherNd>(::tflite::BuiltinOperator_GATHER_ND, 2371 OperatorType::kGatherNd)); 2372 ops.push_back( 2373 MakeUnique<L2Normalization>(::tflite::BuiltinOperator_L2_NORMALIZATION, 2374 OperatorType::kL2Normalization)); 2375 ops.push_back(MakeUnique<L2Pool>(::tflite::BuiltinOperator_L2_POOL_2D, 2376 OperatorType::kL2Pool)); 2377 ops.push_back(MakeUnique<LocalResponseNormalization>( 2378 ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, 2379 OperatorType::kLocalResponseNormalization)); 2380 ops.push_back(MakeUnique<MaxPool>(::tflite::BuiltinOperator_MAX_POOL_2D, 2381 OperatorType::kMaxPool)); 2382 ops.push_back( 2383 MakeUnique<Mul>(::tflite::BuiltinOperator_MUL, OperatorType::kMul)); 2384 2385 ops.push_back( 2386 MakeUnique<Pad>(::tflite::BuiltinOperator_PAD, OperatorType::kPad)); 2387 ops.push_back( 2388 MakeUnique<PadV2>(::tflite::BuiltinOperator_PADV2, OperatorType::kPadV2)); 2389 ops.push_back(MakeUnique<Reshape>(::tflite::BuiltinOperator_RESHAPE, 2390 OperatorType::kReshape)); 2391 ops.push_back(MakeUnique<Softmax>(::tflite::BuiltinOperator_SOFTMAX, 2392 OperatorType::kSoftmax)); 2393 ops.push_back(MakeUnique<SpaceToDepth>( 2394 ::tflite::BuiltinOperator_SPACE_TO_DEPTH, OperatorType::kSpaceToDepth)); 2395 ops.push_back( 2396 MakeUnique<Svdf>(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf)); 2397 ops.push_back(MakeUnique<Transpose>(::tflite::BuiltinOperator_TRANSPOSE, 2398 OperatorType::kTranspose)); 2399 ops.push_back( 2400 MakeUnique<Mean>(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); 2401 ops.push_back( 2402 MakeUnique<Sum>(::tflite::BuiltinOperator_SUM, OperatorType::kSum)); 2403 ops.push_back(MakeUnique<ReduceProd>(::tflite::BuiltinOperator_REDUCE_PROD, 2404 OperatorType::kReduceProd)); 2405 ops.push_back(MakeUnique<ReduceMax>(::tflite::BuiltinOperator_REDUCE_MAX, 2406 OperatorType::kReduceMax)); 2407 ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN, 2408 OperatorType::kReduceMin)); 2409 ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY, 2410 OperatorType::kAny)); 2411 ops.push_back( 2412 MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR, 2413 OperatorType::kResizeBilinear)); 2414 ops.push_back(MakeUnique<ResizeNearestNeighbor>( 2415 ::tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR, 2416 OperatorType::kResizeNearestNeighbor)); 2417 ops.push_back(MakeUnique<Squeeze>(::tflite::BuiltinOperator_SQUEEZE, 2418 OperatorType::kSqueeze)); 2419 ops.push_back( 2420 MakeUnique<Split>(::tflite::BuiltinOperator_SPLIT, OperatorType::kSplit)); 2421 ops.push_back(MakeUnique<SplitV>(::tflite::BuiltinOperator_SPLIT_V, 2422 OperatorType::kSplitV)); 2423 ops.push_back(MakeUnique<StridedSlice>( 2424 ::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice)); 2425 ops.push_back(MakeUnique<TopK_V2>(::tflite::BuiltinOperator_TOPK_V2, 2426 OperatorType::kTopK_V2)); 2427 ops.push_back(MakeUnique<Lstm>(::tflite::BuiltinOperator_LSTM, 2428 OperatorType::kLstmCell)); 2429 ops.push_back( 2430 MakeUnique<Cast>(::tflite::BuiltinOperator_CAST, OperatorType::kCast)); 2431 ops.push_back(MakeUnique<ArgMax>(::tflite::BuiltinOperator_ARG_MAX, 2432 OperatorType::kArgMax)); 2433 ops.push_back(MakeUnique<ArgMin>(::tflite::BuiltinOperator_ARG_MIN, 2434 OperatorType::kArgMin)); 2435 ops.push_back( 2436 MakeUnique<Tile>(::tflite::BuiltinOperator_TILE, OperatorType::kTile)); 2437 ops.push_back(MakeUnique<ExpandDims>(::tflite::BuiltinOperator_EXPAND_DIMS, 2438 OperatorType::kExpandDims)); 2439 ops.push_back(MakeUnique<TransposeConv>( 2440 ::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv)); 2441 ops.push_back(MakeUnique<SparseToDense>( 2442 ::tflite::BuiltinOperator_SPARSE_TO_DENSE, OperatorType::kSparseToDense)); 2443 ops.push_back( 2444 MakeUnique<Shape>(::tflite::BuiltinOperator_SHAPE, OperatorType::kShape)); 2445 ops.push_back(MakeUnique<FakeQuant>(::tflite::BuiltinOperator_FAKE_QUANT, 2446 OperatorType::kFakeQuant)); 2447 ops.push_back( 2448 MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); 2449 ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>( 2450 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, 2451 OperatorType::kUnidirectionalSequenceLstm)); 2452 ops.emplace_back(MakeUnique<BidirectionalSequenceLstm>( 2453 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, 2454 OperatorType::kBidirectionalSequenceLstm)); 2455 ops.emplace_back(MakeUnique<BidirectionalSequenceRnn>( 2456 ::tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, 2457 OperatorType::kBidirectionalSequenceRnn)); 2458 ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT, 2459 OperatorType::kOneHot)); 2460 ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK, 2461 OperatorType::kUnpack)); 2462 ops.push_back(MakeUnique<LeakyRelu>(::tflite::BuiltinOperator_LEAKY_RELU, 2463 OperatorType::kLeakyRelu)); 2464 ops.push_back(MakeUnique<SquaredDifference>( 2465 ::tflite::BuiltinOperator_SQUARED_DIFFERENCE, 2466 OperatorType::kSquaredDifference)); 2467 ops.push_back(MakeUnique<MirrorPad>(::tflite::BuiltinOperator_MIRROR_PAD, 2468 OperatorType::kMirrorPad)); 2469 ops.push_back(MakeUnique<Unique>(::tflite::BuiltinOperator_UNIQUE, 2470 OperatorType::kUnique)); 2471 ops.push_back(MakeUnique<UnidirectionalSequenceRnn>( 2472 ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, 2473 OperatorType::kUnidirectionalSequenceRnn)); 2474 ops.push_back( 2475 MakeUnique<Where>(::tflite::BuiltinOperator_WHERE, OperatorType::kWhere)); 2476 ops.push_back( 2477 MakeUnique<ReverseSequence>(::tflite::BuiltinOperator_REVERSE_SEQUENCE, 2478 OperatorType::kReverseSequence)); 2479 2480 // Custom Operators. 2481 ops.push_back( 2482 MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); 2483 ops.push_back(MakeUnique<CTCBeamSearchDecoder>( 2484 "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder)); 2485 ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED", 2486 OperatorType::kUnsupported, 2487 enable_select_tf_ops)); 2488 2489 // SimpleOperator was designed to export CUSTOM TF Lite ops, but has since 2490 // been modified to also export builtins. As TOCO evolved we added warnings 2491 // when custom ops are exported but SimpleOperator bypasses thoses. To 2492 // prevent user confusion we are settling on using SimpleOperator only for 2493 // builtins. 2494 ops.push_back( 2495 MakeUnique<SimpleOperator<FloorOperator>>("FLOOR", OperatorType::kFloor)); 2496 ops.push_back( 2497 MakeUnique<SimpleOperator<CeilOperator>>("CEIL", OperatorType::kCeil)); 2498 ops.push_back( 2499 MakeUnique<SimpleOperator<EluOperator>>("ELU", OperatorType::kElu)); 2500 ops.push_back( 2501 MakeUnique<SimpleOperator<ReluOperator>>("RELU", OperatorType::kRelu)); 2502 ops.push_back(MakeUnique<SimpleOperator<Relu1Operator>>( 2503 "RELU_N1_TO_1", OperatorType::kRelu1)); 2504 ops.push_back(MakeUnique<Relu6>()); 2505 ops.push_back( 2506 MakeUnique<SimpleOperator<PReluOperator>>("PRELU", OperatorType::kPRelu)); 2507 ops.push_back(MakeUnique<Logistic>()); 2508 ops.push_back(MakeUnique<Tanh>()); 2509 ops.push_back( 2510 MakeUnique<SimpleOperator<ExpOperator>>("EXP", OperatorType::kExp)); 2511 ops.push_back( 2512 MakeUnique<SimpleOperator<CosOperator>>("COS", OperatorType::kCos)); 2513 ops.push_back(MakeUnique<LogSoftmax>()); 2514 ops.push_back(MakeUnique<Maximum>()); // Element-wise Maximum 2515 ops.push_back(MakeUnique<Minimum>()); // Element-wise Minimum 2516 ops.push_back(MakeUnique<Greater>()); 2517 ops.push_back(MakeUnique<GreaterEqual>()); 2518 ops.push_back(MakeUnique<Less>()); 2519 ops.push_back(MakeUnique<LessEqual>()); 2520 ops.push_back(MakeUnique<Equal>()); 2521 ops.push_back(MakeUnique<NotEqual>()); 2522 ops.push_back( 2523 MakeUnique<SimpleOperator<NegOperator>>("NEG", OperatorType::kNeg)); 2524 ops.push_back(MakeUnique<Select>()); 2525 ops.push_back(MakeUnique<Slice>()); 2526 ops.push_back( 2527 MakeUnique<SimpleOperator<PowOperator>>("POW", OperatorType::kPow)); 2528 ops.push_back(MakeUnique<SimpleOperator<LogicalOrOperator>>( 2529 "LOGICAL_OR", OperatorType::kLogicalOr)); 2530 ops.emplace_back(new SimpleOperator<LogicalAndOperator>( 2531 "LOGICAL_AND", OperatorType::kLogicalAnd)); 2532 ops.emplace_back(new SimpleOperator<LogicalNotOperator>( 2533 "LOGICAL_NOT", OperatorType::kLogicalNot)); 2534 ops.emplace_back(new SimpleOperator<FloorDivOperator>( 2535 "FLOOR_DIV", OperatorType::kFloorDiv)); 2536 ops.emplace_back(new SimpleOperator<FloorModOperator>( 2537 "FLOOR_MOD", OperatorType::kFloorMod)); 2538 ops.emplace_back( 2539 new SimpleOperator<RangeOperator>("RANGE", OperatorType::kRange)); 2540 // Element-wise operator 2541 ops.push_back( 2542 MakeUnique<SimpleOperator<SinOperator>>("SIN", OperatorType::kSin)); 2543 ops.push_back( 2544 MakeUnique<SimpleOperator<LogOperator>>("LOG", OperatorType::kLog)); 2545 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSqrtOperator>>( 2546 "SQRT", OperatorType::kSqrt)); 2547 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRsqrtOperator>>( 2548 "RSQRT", OperatorType::kRsqrt)); 2549 ops.push_back(MakeUnique<SimpleOperator<TensorFlowSquareOperator>>( 2550 "SQUARE", OperatorType::kSquare)); 2551 ops.push_back(MakeUnique<SimpleOperator<TensorFlowZerosLikeOperator>>( 2552 "ZEROS_LIKE", OperatorType::kZerosLike)); 2553 ops.push_back( 2554 MakeUnique<SimpleOperator<AbsOperator>>("ABS", OperatorType::kAbs)); 2555 ops.push_back( 2556 MakeUnique<SimpleOperator<FillOperator>>("FILL", OperatorType::kFill)); 2557 ops.push_back(MakeUnique<SimpleOperator<ReverseV2Operator>>( 2558 "REVERSE_V2", OperatorType::kReverseV2)); 2559 ops.push_back(MakeUnique<SimpleOperator<TensorFlowRankOperator>>( 2560 "RANK", OperatorType::kRank)); 2561 return ops; 2562 } 2563 } // namespace 2564 2565 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap( 2566 bool enable_select_tf_ops) { 2567 std::map<OperatorType, std::unique_ptr<BaseOperator>> result; 2568 2569 std::vector<std::unique_ptr<BaseOperator>> ops = 2570 BuildOperatorList(enable_select_tf_ops); 2571 for (auto& op : ops) { 2572 result[op->type()] = std::move(op); 2573 } 2574 2575 return result; 2576 } 2577 2578 std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap( 2579 bool enable_select_tf_ops) { 2580 std::map<string, std::unique_ptr<BaseOperator>> result; 2581 2582 std::vector<std::unique_ptr<BaseOperator>> ops = 2583 BuildOperatorList(enable_select_tf_ops); 2584 for (auto& op : ops) { 2585 result[op->name()] = std::move(op); 2586 } 2587 2588 return result; 2589 } 2590 2591 bool ShouldExportAsFlexOp(bool enable_select_tf_ops, 2592 const string& tensorflow_op_name) { 2593 // If Flex ops aren't allow at all, simply return false. 2594 if (!enable_select_tf_ops) { 2595 return false; 2596 } 2597 // Check if we can find the `OpDef` for the TensorFlow op. If we can find 2598 // it and it has been whitelisted, export the op as an Flex op. Otherwise, 2599 // export it as a regular custom op. 2600 const tensorflow::OpDef* op_def = nullptr; 2601 if (!tensorflow::OpRegistry::Global() 2602 ->LookUpOpDef(tensorflow_op_name, &op_def) 2603 .ok()) { 2604 return false; 2605 } 2606 2607 if (!IsWhitelistedFlexOp(tensorflow_op_name)) { 2608 LOG(WARNING) << "Op " << tensorflow_op_name 2609 << " is a valid TensorFlow op but has not been whitelisted for" 2610 " the TensorFlow Lite flex op set."; 2611 return false; 2612 } 2613 2614 return true; 2615 } 2616 2617 } // namespace tflite 2618 2619 } // namespace toco 2620