1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/lite/toco/tflite/operator.h" 16 17 #include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" 18 #include "tensorflow/contrib/lite/toco/tflite/custom_operator.h" 19 #include "tensorflow/contrib/lite/toco/tflite/simple_operator.h" 20 #include "tensorflow/contrib/lite/toco/tflite/types.h" 21 22 #include "tensorflow/core/framework/attr_value.pb.h" 23 #include "tensorflow/core/framework/node_def.pb.h" 24 25 namespace toco { 26 27 namespace tflite { 28 29 class AveragePool 30 : public BuiltinOperator<AveragePoolOperator, ::tflite::Pool2DOptions, 31 ::tflite::BuiltinOptions_Pool2DOptions> { 32 public: 33 using BuiltinOperator::BuiltinOperator; 34 35 flatbuffers::Offset<TfLiteOptions> WriteOptions( 36 const TocoOperator& op, 37 flatbuffers::FlatBufferBuilder* builder) const override { 38 auto padding = Padding::Serialize(op.padding.type); 39 auto activation_function = 40 ActivationFunction::Serialize(op.fused_activation_function); 41 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, 42 op.stride_height, op.kwidth, 43 op.kheight, activation_function); 44 } 45 46 void ReadOptions(const TfLiteOptions& options, 47 TocoOperator* op) const override { 48 op->padding.type = Padding::Deserialize(options.padding()); 49 op->stride_width = options.stride_w(); 50 op->stride_height = options.stride_h(); 51 op->kwidth = options.filter_width(); 52 op->kheight = options.filter_height(); 53 op->fused_activation_function = 54 ActivationFunction::Deserialize(options.fused_activation_function()); 55 } 56 }; 57 58 class Convolution 59 : public BuiltinOperator<ConvOperator, ::tflite::Conv2DOptions, 60 ::tflite::BuiltinOptions_Conv2DOptions> { 61 public: 62 using BuiltinOperator::BuiltinOperator; 63 64 flatbuffers::Offset<TfLiteOptions> WriteOptions( 65 const TocoOperator& op, 66 flatbuffers::FlatBufferBuilder* builder) const override { 67 auto padding = Padding::Serialize(op.padding.type); 68 auto activation_function = 69 ActivationFunction::Serialize(op.fused_activation_function); 70 return ::tflite::CreateConv2DOptions(*builder, padding, op.stride_width, 71 op.stride_height, activation_function); 72 } 73 74 void ReadOptions(const TfLiteOptions& options, 75 TocoOperator* op) const override { 76 op->padding.type = Padding::Deserialize(options.padding()); 77 op->stride_width = options.stride_w(); 78 op->stride_height = options.stride_h(); 79 op->fused_activation_function = 80 ActivationFunction::Deserialize(options.fused_activation_function()); 81 } 82 }; 83 84 class DepthwiseConvolution 85 : public BuiltinOperator<DepthwiseConvOperator, 86 ::tflite::DepthwiseConv2DOptions, 87 ::tflite::BuiltinOptions_DepthwiseConv2DOptions> { 88 public: 89 using BuiltinOperator::BuiltinOperator; 90 91 flatbuffers::Offset<TfLiteOptions> WriteOptions( 92 const TocoOperator& op, 93 flatbuffers::FlatBufferBuilder* builder) const override { 94 auto padding = Padding::Serialize(op.padding.type); 95 auto activation_function = 96 ActivationFunction::Serialize(op.fused_activation_function); 97 return ::tflite::CreateDepthwiseConv2DOptions( 98 *builder, padding, op.stride_width, op.stride_height, 99 op.depth_multiplier, activation_function); 100 } 101 102 void ReadOptions(const TfLiteOptions& options, 103 TocoOperator* op) const override { 104 op->padding.type = Padding::Deserialize(options.padding()); 105 op->stride_width = options.stride_w(); 106 op->stride_height = options.stride_h(); 107 op->depth_multiplier = options.depth_multiplier(); 108 op->fused_activation_function = 109 ActivationFunction::Deserialize(options.fused_activation_function()); 110 } 111 }; 112 113 class Add : public BuiltinOperator<AddOperator, ::tflite::AddOptions, 114 ::tflite::BuiltinOptions_AddOptions> { 115 public: 116 using BuiltinOperator::BuiltinOperator; 117 118 flatbuffers::Offset<TfLiteOptions> WriteOptions( 119 const TocoOperator& op, 120 flatbuffers::FlatBufferBuilder* builder) const override { 121 auto activation_function = 122 ActivationFunction::Serialize(op.fused_activation_function); 123 return ::tflite::CreateAddOptions(*builder, activation_function); 124 } 125 126 void ReadOptions(const TfLiteOptions& options, 127 TocoOperator* op) const override { 128 op->fused_activation_function = 129 ActivationFunction::Deserialize(options.fused_activation_function()); 130 } 131 }; 132 133 class SpaceToBatchND 134 : public BuiltinOperator<SpaceToBatchNDOperator, 135 ::tflite::SpaceToBatchNDOptions, 136 ::tflite::BuiltinOptions_SpaceToBatchNDOptions> { 137 public: 138 using BuiltinOperator::BuiltinOperator; 139 140 flatbuffers::Offset<TfLiteOptions> WriteOptions( 141 const TocoOperator& op, 142 flatbuffers::FlatBufferBuilder* builder) const override { 143 return ::tflite::CreateSpaceToBatchNDOptions(*builder); 144 } 145 146 void ReadOptions(const TfLiteOptions& options, 147 TocoOperator* op) const override {} 148 }; 149 150 class Sub : public BuiltinOperator<SubOperator, ::tflite::SubOptions, 151 ::tflite::BuiltinOptions_SubOptions> { 152 public: 153 using BuiltinOperator::BuiltinOperator; 154 155 flatbuffers::Offset<TfLiteOptions> WriteOptions( 156 const TocoOperator& op, 157 flatbuffers::FlatBufferBuilder* builder) const override { 158 auto activation_function = 159 ActivationFunction::Serialize(op.fused_activation_function); 160 return ::tflite::CreateSubOptions(*builder, activation_function); 161 } 162 163 void ReadOptions(const TfLiteOptions& options, 164 TocoOperator* op) const override { 165 op->fused_activation_function = 166 ActivationFunction::Deserialize(options.fused_activation_function()); 167 } 168 }; 169 170 class Div : public BuiltinOperator<DivOperator, ::tflite::DivOptions, 171 ::tflite::BuiltinOptions_DivOptions> { 172 public: 173 using BuiltinOperator::BuiltinOperator; 174 175 flatbuffers::Offset<TfLiteOptions> WriteOptions( 176 const TocoOperator& op, 177 flatbuffers::FlatBufferBuilder* builder) const override { 178 auto activation_function = 179 ActivationFunction::Serialize(op.fused_activation_function); 180 return ::tflite::CreateDivOptions(*builder, activation_function); 181 } 182 183 void ReadOptions(const TfLiteOptions& options, 184 TocoOperator* op) const override { 185 op->fused_activation_function = 186 ActivationFunction::Deserialize(options.fused_activation_function()); 187 } 188 }; 189 190 class BatchToSpaceND 191 : public BuiltinOperator<BatchToSpaceNDOperator, 192 ::tflite::BatchToSpaceNDOptions, 193 ::tflite::BuiltinOptions_BatchToSpaceNDOptions> { 194 public: 195 using BuiltinOperator::BuiltinOperator; 196 197 flatbuffers::Offset<TfLiteOptions> WriteOptions( 198 const TocoOperator& op, 199 flatbuffers::FlatBufferBuilder* builder) const override { 200 return ::tflite::CreateBatchToSpaceNDOptions(*builder); 201 } 202 203 void ReadOptions(const TfLiteOptions& options, 204 TocoOperator* op) const override {} 205 }; 206 207 class Cast : public CustomOperator<CastOperator> { 208 public: 209 using CustomOperator::CustomOperator; 210 void WriteOptions(const TocoOperator& op, 211 flexbuffers::Builder* fbb) const override { 212 fbb->Int("src_data_type", DataType::Serialize(op.src_data_type)); 213 fbb->Int("dst_data_type", DataType::Serialize(op.dst_data_type)); 214 } 215 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { 216 op->src_data_type = DataType::Deserialize(m["src_data_type"].AsInt64()); 217 op->dst_data_type = DataType::Deserialize(m["dst_data_type"].AsInt64()); 218 } 219 }; 220 221 class Concatenation 222 : public BuiltinOperator<ConcatenationOperator, 223 ::tflite::ConcatenationOptions, 224 ::tflite::BuiltinOptions_ConcatenationOptions> { 225 public: 226 using BuiltinOperator::BuiltinOperator; 227 flatbuffers::Offset<TfLiteOptions> WriteOptions( 228 const TocoOperator& op, 229 flatbuffers::FlatBufferBuilder* builder) const override { 230 return ::tflite::CreateConcatenationOptions(*builder, op.axis); 231 } 232 233 void ReadOptions(const TfLiteOptions& options, 234 TocoOperator* op) const override { 235 op->axis = options.axis(); 236 } 237 }; 238 239 class DepthToSpace : public CustomOperator<DepthToSpaceOperator> { 240 public: 241 using CustomOperator::CustomOperator; 242 void WriteOptions(const TocoOperator& op, 243 flexbuffers::Builder* fbb) const override { 244 fbb->Int("block_size", op.block_size); 245 } 246 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { 247 op->block_size = m["block_size"].AsInt64(); 248 } 249 }; 250 251 class FakeQuant : public CustomOperator<FakeQuantOperator> { 252 public: 253 using CustomOperator::CustomOperator; 254 void WriteOptions(const TocoOperator& op, 255 flexbuffers::Builder* fbb) const override { 256 fbb->Float("min", op.minmax->min); 257 fbb->Float("max", op.minmax->max); 258 } 259 void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { 260 auto* minmax = new MinMax; 261 minmax->min = m["min"].AsFloat(); 262 minmax->max = m["max"].AsFloat(); 263 op->minmax.reset(minmax); 264 } 265 }; 266 267 class FullyConnected 268 : public BuiltinOperator<FullyConnectedOperator, 269 ::tflite::FullyConnectedOptions, 270 ::tflite::BuiltinOptions_FullyConnectedOptions> { 271 public: 272 using BuiltinOperator::BuiltinOperator; 273 flatbuffers::Offset<TfLiteOptions> WriteOptions( 274 const TocoOperator& op, 275 flatbuffers::FlatBufferBuilder* builder) const override { 276 auto activation_function = 277 ActivationFunction::Serialize(op.fused_activation_function); 278 return ::tflite::CreateFullyConnectedOptions(*builder, activation_function); 279 } 280 281 void ReadOptions(const TfLiteOptions& options, 282 TocoOperator* op) const override { 283 op->fused_activation_function = 284 ActivationFunction::Deserialize(options.fused_activation_function()); 285 } 286 }; 287 288 class Gather : public BuiltinOperator<GatherOperator, ::tflite::GatherOptions, 289 ::tflite::BuiltinOptions_GatherOptions> { 290 public: 291 using BuiltinOperator::BuiltinOperator; 292 flatbuffers::Offset<TfLiteOptions> WriteOptions( 293 const TocoOperator& op, 294 flatbuffers::FlatBufferBuilder* builder) const override { 295 return ::tflite::CreateGatherOptions(*builder, op.axis); 296 } 297 298 void ReadOptions(const TfLiteOptions& options, 299 TocoOperator* op) const override { 300 op->axis = options.axis(); 301 } 302 }; 303 304 class Svdf : public BuiltinOperator<SvdfOperator, ::tflite::SVDFOptions, 305 ::tflite::BuiltinOptions_SVDFOptions> { 306 public: 307 using BuiltinOperator::BuiltinOperator; 308 flatbuffers::Offset<TfLiteOptions> WriteOptions( 309 const TocoOperator& op, 310 flatbuffers::FlatBufferBuilder* builder) const override { 311 auto activation_function = 312 ActivationFunction::Serialize(op.fused_activation_function); 313 return ::tflite::CreateSVDFOptions(*builder, op.rank, activation_function); 314 } 315 316 void ReadOptions(const TfLiteOptions& options, 317 TocoOperator* op) const override { 318 op->fused_activation_function = 319 ActivationFunction::Deserialize(options.fused_activation_function()); 320 op->rank = options.rank(); 321 } 322 }; 323 324 class L2Normalization 325 : public BuiltinOperator<L2NormalizationOperator, ::tflite::L2NormOptions, 326 ::tflite::BuiltinOptions_L2NormOptions> { 327 public: 328 using BuiltinOperator::BuiltinOperator; 329 flatbuffers::Offset<TfLiteOptions> WriteOptions( 330 const TocoOperator& op, 331 flatbuffers::FlatBufferBuilder* builder) const override { 332 auto activation_function = 333 ActivationFunction::Serialize(op.fused_activation_function); 334 return ::tflite::CreateL2NormOptions(*builder, activation_function); 335 } 336 337 void ReadOptions(const TfLiteOptions& options, 338 TocoOperator* op) const override { 339 op->fused_activation_function = 340 ActivationFunction::Deserialize(options.fused_activation_function()); 341 } 342 }; 343 344 class L2Pool : public BuiltinOperator<L2PoolOperator, ::tflite::Pool2DOptions, 345 ::tflite::BuiltinOptions_Pool2DOptions> { 346 public: 347 using BuiltinOperator::BuiltinOperator; 348 flatbuffers::Offset<TfLiteOptions> WriteOptions( 349 const TocoOperator& op, 350 flatbuffers::FlatBufferBuilder* builder) const override { 351 auto padding = Padding::Serialize(op.padding.type); 352 auto activation_function = 353 ActivationFunction::Serialize(op.fused_activation_function); 354 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, 355 op.stride_height, op.kwidth, 356 op.kheight, activation_function); 357 } 358 359 void ReadOptions(const TfLiteOptions& options, 360 TocoOperator* op) const override { 361 op->padding.type = Padding::Deserialize(options.padding()); 362 op->stride_width = options.stride_w(); 363 op->stride_height = options.stride_h(); 364 op->kwidth = options.filter_width(); 365 op->kheight = options.filter_height(); 366 op->fused_activation_function = 367 ActivationFunction::Deserialize(options.fused_activation_function()); 368 } 369 }; 370 371 class LocalResponseNormalization 372 : public BuiltinOperator< 373 LocalResponseNormalizationOperator, 374 ::tflite::LocalResponseNormalizationOptions, 375 ::tflite::BuiltinOptions_LocalResponseNormalizationOptions> { 376 public: 377 using BuiltinOperator::BuiltinOperator; 378 flatbuffers::Offset<TfLiteOptions> WriteOptions( 379 const TocoOperator& op, 380 flatbuffers::FlatBufferBuilder* builder) const override { 381 return ::tflite::CreateLocalResponseNormalizationOptions( 382 *builder, op.range, op.bias, op.alpha, op.beta); 383 } 384 385 void ReadOptions(const TfLiteOptions& options, 386 TocoOperator* op) const override { 387 op->range = options.radius(); 388 op->bias = options.bias(); 389 op->alpha = options.alpha(); 390 op->beta = options.beta(); 391 } 392 }; 393 394 class MaxPool : public BuiltinOperator<MaxPoolOperator, ::tflite::Pool2DOptions, 395 ::tflite::BuiltinOptions_Pool2DOptions> { 396 public: 397 using BuiltinOperator::BuiltinOperator; 398 flatbuffers::Offset<TfLiteOptions> WriteOptions( 399 const TocoOperator& op, 400 flatbuffers::FlatBufferBuilder* builder) const override { 401 auto padding = Padding::Serialize(op.padding.type); 402 auto activation_function = 403 ActivationFunction::Serialize(op.fused_activation_function); 404 return ::tflite::CreatePool2DOptions(*builder, padding, op.stride_width, 405 op.stride_height, op.kwidth, 406 op.kheight, activation_function); 407 } 408 409 void ReadOptions(const TfLiteOptions& options, 410 TocoOperator* op) const override { 411 op->padding.type = Padding::Deserialize(options.padding()); 412 op->stride_width = options.stride_w(); 413 op->stride_height = options.stride_h(); 414 op->kwidth = options.filter_width(); 415 op->kheight = options.filter_height(); 416 op->fused_activation_function = 417 ActivationFunction::Deserialize(options.fused_activation_function()); 418 } 419 }; 420 421 class Mul : public BuiltinOperator<MulOperator, ::tflite::MulOptions, 422 ::tflite::BuiltinOptions_MulOptions> { 423 public: 424 using BuiltinOperator::BuiltinOperator; 425 426 flatbuffers::Offset<TfLiteOptions> WriteOptions( 427 const TocoOperator& op, 428 flatbuffers::FlatBufferBuilder* builder) const override { 429 auto activation_function = 430 ActivationFunction::Serialize(op.fused_activation_function); 431 return ::tflite::CreateMulOptions(*builder, activation_function); 432 } 433 434 void ReadOptions(const TfLiteOptions& options, 435 TocoOperator* op) const override { 436 op->fused_activation_function = 437 ActivationFunction::Deserialize(options.fused_activation_function()); 438 } 439 }; 440 441 class Pad : public BuiltinOperator<PadOperator, ::tflite::PadOptions, 442 ::tflite::BuiltinOptions_PadOptions> { 443 public: 444 using BuiltinOperator::BuiltinOperator; 445 446 flatbuffers::Offset<TfLiteOptions> WriteOptions( 447 const TocoOperator& op, 448 flatbuffers::FlatBufferBuilder* builder) const override { 449 return ::tflite::CreatePadOptions(*builder); 450 } 451 452 void ReadOptions(const TfLiteOptions& options, 453 TocoOperator* op) const override {} 454 }; 455 456 class Reshape 457 : public BuiltinOperator<TensorFlowReshapeOperator, 458 ::tflite::ReshapeOptions, 459 ::tflite::BuiltinOptions_ReshapeOptions> { 460 public: 461 using BuiltinOperator::BuiltinOperator; 462 463 flatbuffers::Offset<TfLiteOptions> WriteOptions( 464 const TocoOperator& op, 465 flatbuffers::FlatBufferBuilder* builder) const override { 466 return ::tflite::CreateReshapeOptions(*builder, 467 builder->CreateVector(op.shape)); 468 } 469 470 void ReadOptions(const TfLiteOptions& options, 471 TocoOperator* op) const override { 472 op->shape.insert(op->shape.end(), options.new_shape()->begin(), 473 options.new_shape()->end()); 474 } 475 }; 476 477 class Softmax 478 : public BuiltinOperator<SoftmaxOperator, ::tflite::SoftmaxOptions, 479 ::tflite::BuiltinOptions_SoftmaxOptions> { 480 public: 481 using BuiltinOperator::BuiltinOperator; 482 flatbuffers::Offset<TfLiteOptions> WriteOptions( 483 const TocoOperator& op, 484 flatbuffers::FlatBufferBuilder* builder) const override { 485 return ::tflite::CreateSoftmaxOptions(*builder, op.beta); 486 } 487 488 void ReadOptions(const TfLiteOptions& options, 489 TocoOperator* op) const override { 490 op->beta = options.beta(); 491 } 492 }; 493 494 class SpaceToDepth 495 : public BuiltinOperator<SpaceToDepthOperator, 496 ::tflite::SpaceToDepthOptions, 497 ::tflite::BuiltinOptions_SpaceToDepthOptions> { 498 public: 499 using BuiltinOperator::BuiltinOperator; 500 flatbuffers::Offset<TfLiteOptions> WriteOptions( 501 const TocoOperator& op, 502 flatbuffers::FlatBufferBuilder* builder) const override { 503 return ::tflite::CreateSpaceToDepthOptions(*builder, op.block_size); 504 } 505 506 void ReadOptions(const TfLiteOptions& options, 507 TocoOperator* op) const override { 508 op->block_size = options.block_size(); 509 } 510 }; 511 512 class Transpose 513 : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions, 514 ::tflite::BuiltinOptions_TransposeOptions> { 515 public: 516 using BuiltinOperator::BuiltinOperator; 517 flatbuffers::Offset<TfLiteOptions> WriteOptions( 518 const TocoOperator& op, 519 flatbuffers::FlatBufferBuilder* builder) const override { 520 return ::tflite::CreateTransposeOptions(*builder); 521 } 522 523 void ReadOptions(const TfLiteOptions& options, 524 TocoOperator* op) const override {} 525 }; 526 527 class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions, 528 ::tflite::BuiltinOptions_LSTMOptions> { 529 public: 530 using BuiltinOperator::BuiltinOperator; 531 flatbuffers::Offset<TfLiteOptions> WriteOptions( 532 const TocoOperator& op, 533 flatbuffers::FlatBufferBuilder* builder) const override { 534 // Current toco converter only supports tanh, no clip. 535 return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/ 536 ::tflite::ActivationFunctionType_TANH, 537 /*cell_clip=*/0.0, 538 /*proj_clip=*/0.0); 539 } 540 541 void ReadOptions(const TfLiteOptions& options, 542 TocoOperator* op) const override { 543 // Only support tanh activation, so check that tflite type is tanh. 544 CHECK(options.fused_activation_function() == 545 ::tflite::ActivationFunctionType_TANH); 546 } 547 }; 548 549 class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions, 550 ::tflite::BuiltinOptions_MeanOptions> { 551 public: 552 using BuiltinOperator::BuiltinOperator; 553 flatbuffers::Offset<TfLiteOptions> WriteOptions( 554 const TocoOperator& op, 555 flatbuffers::FlatBufferBuilder* builder) const override { 556 return ::tflite::CreateMeanOptions(*builder, op.keep_dims); 557 } 558 559 void ReadOptions(const TfLiteOptions& options, 560 TocoOperator* op) const override { 561 op->keep_dims = options.keep_dims(); 562 } 563 }; 564 565 class ResizeBilinear 566 : public BuiltinOperator<ResizeBilinearOperator, 567 ::tflite::ResizeBilinearOptions, 568 ::tflite::BuiltinOptions_ResizeBilinearOptions> { 569 public: 570 using BuiltinOperator::BuiltinOperator; 571 flatbuffers::Offset<TfLiteOptions> WriteOptions( 572 const TocoOperator& op, 573 flatbuffers::FlatBufferBuilder* builder) const override { 574 return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners); 575 } 576 577 void ReadOptions(const TfLiteOptions& options, 578 TocoOperator* op) const override { 579 op->align_corners = options.align_corners(); 580 } 581 }; 582 583 class Squeeze 584 : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions, 585 ::tflite::BuiltinOptions_SqueezeOptions> { 586 public: 587 using BuiltinOperator::BuiltinOperator; 588 589 flatbuffers::Offset<TfLiteOptions> WriteOptions( 590 const TocoOperator& op, 591 flatbuffers::FlatBufferBuilder* builder) const override { 592 auto squeeze_dims = builder->CreateVector(op.squeeze_dims); 593 return ::tflite::CreateSqueezeOptions(*builder, squeeze_dims); 594 } 595 596 void ReadOptions(const TfLiteOptions& options, 597 TocoOperator* op) const override { 598 op->squeeze_dims.insert(op->squeeze_dims.end(), 599 options.squeeze_dims()->begin(), 600 options.squeeze_dims()->end()); 601 } 602 }; 603 604 class Split 605 : public BuiltinOperator<TensorFlowSplitOperator, ::tflite::SplitOptions, 606 ::tflite::BuiltinOptions_SplitOptions> { 607 public: 608 using BuiltinOperator::BuiltinOperator; 609 610 flatbuffers::Offset<TfLiteOptions> WriteOptions( 611 const TocoOperator& op, 612 flatbuffers::FlatBufferBuilder* builder) const override { 613 return ::tflite::CreateSplitOptions(*builder, op.num_split); 614 } 615 616 void ReadOptions(const TfLiteOptions& options, 617 TocoOperator* op) const override { 618 op->num_split = options.num_splits(); 619 } 620 }; 621 622 class StridedSlice 623 : public BuiltinOperator<StridedSliceOperator, 624 ::tflite::StridedSliceOptions, 625 ::tflite::BuiltinOptions_StridedSliceOptions> { 626 public: 627 using BuiltinOperator::BuiltinOperator; 628 flatbuffers::Offset<TfLiteOptions> WriteOptions( 629 const TocoOperator& op, 630 flatbuffers::FlatBufferBuilder* builder) const override { 631 return ::tflite::CreateStridedSliceOptions( 632 *builder, op.begin_mask, op.end_mask, op.ellipsis_mask, 633 op.new_axis_mask, op.shrink_axis_mask); 634 } 635 636 void ReadOptions(const TfLiteOptions& options, 637 TocoOperator* op) const override { 638 op->begin_mask = options.begin_mask(); 639 op->end_mask = options.end_mask(); 640 op->ellipsis_mask = options.ellipsis_mask(); 641 op->new_axis_mask = options.new_axis_mask(); 642 op->shrink_axis_mask = options.shrink_axis_mask(); 643 } 644 }; 645 646 class TopK_V2 : public BuiltinOperator<TopKV2Operator, ::tflite::TopKV2Options, 647 ::tflite::BuiltinOptions_TopKV2Options> { 648 public: 649 using BuiltinOperator::BuiltinOperator; 650 flatbuffers::Offset<TfLiteOptions> WriteOptions( 651 const TocoOperator& op, 652 flatbuffers::FlatBufferBuilder* builder) const override { 653 return ::tflite::CreateTopKV2Options(*builder); 654 } 655 656 void ReadOptions(const TfLiteOptions& options, 657 TocoOperator* op) const override {} 658 }; 659 660 class TensorFlowUnsupported : public BaseOperator { 661 public: 662 using BaseOperator::BaseOperator; 663 664 Options Serialize(const Operator& op, 665 flatbuffers::FlatBufferBuilder* builder) const override { 666 auto fbb = 667 WriteOptions(static_cast<const TensorFlowUnsupportedOperator&>(op)); 668 if (fbb) { 669 return Options::Custom(builder->CreateVector(fbb->GetBuffer())); 670 } else { 671 return Options::Custom(0); 672 } 673 } 674 675 std::unique_ptr<Operator> Deserialize( 676 const BuiltinOptions* builtin_options, 677 const CustomOptions* custom_options) const override { 678 auto op = absl::make_unique<TensorFlowUnsupportedOperator>(); 679 if (custom_options) { 680 auto flexbuffer_map = 681 flexbuffers::GetRoot(custom_options->data(), custom_options->size()) 682 .AsMap(); 683 ReadOptions(flexbuffer_map, op.get()); 684 } 685 return std::unique_ptr<Operator>(op.release()); 686 } 687 688 std::unique_ptr<flexbuffers::Builder> WriteOptions( 689 const TensorFlowUnsupportedOperator& op) const { 690 auto fbb = absl::make_unique<flexbuffers::Builder>(); 691 692 ::tensorflow::NodeDef node_def; 693 if (!node_def.ParseFromString(op.tensorflow_node_def)) { 694 LOG(ERROR) << "Failed to parse TensorFlow NodeDef"; 695 return std::unique_ptr<flexbuffers::Builder>(); 696 } 697 698 bool has_valid_attr = false; 699 size_t map_start = fbb->StartMap(); 700 for (const auto& pair : node_def.attr()) { 701 const char* key = pair.first.c_str(); 702 const auto& attr = pair.second; 703 switch (attr.value_case()) { 704 case ::tensorflow::AttrValue::kS: 705 fbb->String(key, attr.s()); 706 has_valid_attr = true; 707 break; 708 case ::tensorflow::AttrValue::kI: 709 fbb->Int(key, attr.i()); 710 has_valid_attr = true; 711 break; 712 case ::tensorflow::AttrValue::kF: 713 fbb->Float(key, attr.f()); 714 has_valid_attr = true; 715 break; 716 case ::tensorflow::AttrValue::kB: 717 fbb->Bool(key, attr.b()); 718 has_valid_attr = true; 719 break; 720 default: 721 LOG(WARNING) << "Ignoring unsupported attribute type with key '" 722 << key << "'"; 723 break; 724 } 725 } 726 if (!has_valid_attr) { 727 return std::unique_ptr<flexbuffers::Builder>(); 728 } 729 fbb->EndMap(map_start); 730 fbb->Finish(); 731 return std::unique_ptr<flexbuffers::Builder>(fbb.release()); 732 } 733 734 void ReadOptions(const flexbuffers::Map& m, 735 TensorFlowUnsupportedOperator* op) const { 736 ::tensorflow::NodeDef node_def; 737 auto attr = node_def.mutable_attr(); 738 739 const auto& keys = m.Keys(); 740 for (size_t i = 0; i < keys.size(); ++i) { 741 const auto key = keys[i].AsKey(); 742 const auto& value = m[key]; 743 switch (value.GetType()) { 744 case flexbuffers::TYPE_STRING: 745 (*attr)[key].set_s(value.AsString().c_str()); 746 break; 747 case flexbuffers::TYPE_INT: 748 (*attr)[key].set_i(value.AsInt64()); 749 break; 750 case flexbuffers::TYPE_FLOAT: 751 (*attr)[key].set_f(value.AsFloat()); 752 break; 753 case flexbuffers::TYPE_BOOL: 754 (*attr)[key].set_b(value.AsBool()); 755 break; 756 default: 757 LOG(WARNING) << "Ignoring unsupported attribute type with key '" 758 << key << "'"; 759 break; 760 } 761 } 762 node_def.SerializeToString(&op->tensorflow_node_def); 763 } 764 }; 765 766 namespace { 767 // Build a vector containing all the known operators. 768 std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { 769 std::vector<std::unique_ptr<BaseOperator>> ops; 770 771 // Builtin Operators. 772 ops.emplace_back(new Add(::tflite::BuiltinOperator_ADD, OperatorType::kAdd)); 773 ops.emplace_back(new Div(::tflite::BuiltinOperator_DIV, OperatorType::kDiv)); 774 ops.emplace_back(new Sub(::tflite::BuiltinOperator_SUB, OperatorType::kSub)); 775 ops.emplace_back(new AveragePool(::tflite::BuiltinOperator_AVERAGE_POOL_2D, 776 OperatorType::kAveragePool)); 777 ops.emplace_back( 778 new SpaceToBatchND(::tflite::BuiltinOperator_SPACE_TO_BATCH_ND, 779 OperatorType::kSpaceToBatchND)); 780 ops.emplace_back( 781 new BatchToSpaceND(::tflite::BuiltinOperator_BATCH_TO_SPACE_ND, 782 OperatorType::kBatchToSpaceND)); 783 ops.emplace_back(new Concatenation(::tflite::BuiltinOperator_CONCATENATION, 784 OperatorType::kConcatenation)); 785 ops.emplace_back( 786 new Convolution(::tflite::BuiltinOperator_CONV_2D, OperatorType::kConv)); 787 ops.emplace_back( 788 new DepthwiseConvolution(::tflite::BuiltinOperator_DEPTHWISE_CONV_2D, 789 OperatorType::kDepthwiseConv)); 790 ops.emplace_back(new FullyConnected(::tflite::BuiltinOperator_FULLY_CONNECTED, 791 OperatorType::kFullyConnected)); 792 ops.emplace_back( 793 new Gather(::tflite::BuiltinOperator_GATHER, OperatorType::kGather)); 794 ops.emplace_back( 795 new L2Normalization(::tflite::BuiltinOperator_L2_NORMALIZATION, 796 OperatorType::kL2Normalization)); 797 ops.emplace_back( 798 new L2Pool(::tflite::BuiltinOperator_L2_POOL_2D, OperatorType::kL2Pool)); 799 ops.emplace_back(new LocalResponseNormalization( 800 ::tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, 801 OperatorType::kLocalResponseNormalization)); 802 ops.emplace_back(new MaxPool(::tflite::BuiltinOperator_MAX_POOL_2D, 803 OperatorType::kMaxPool)); 804 ops.emplace_back(new Mul(::tflite::BuiltinOperator_MUL, OperatorType::kMul)); 805 ops.emplace_back(new Pad(::tflite::BuiltinOperator_PAD, OperatorType::kPad)); 806 ops.emplace_back(new Reshape(::tflite::BuiltinOperator_RESHAPE, 807 OperatorType::kTensorFlowReshape)); 808 ops.emplace_back( 809 new Softmax(::tflite::BuiltinOperator_SOFTMAX, OperatorType::kSoftmax)); 810 ops.emplace_back(new SpaceToDepth(::tflite::BuiltinOperator_SPACE_TO_DEPTH, 811 OperatorType::kSpaceToDepth)); 812 ops.emplace_back( 813 new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf)); 814 ops.emplace_back(new Transpose(::tflite::BuiltinOperator_TRANSPOSE, 815 OperatorType::kTranspose)); 816 ops.emplace_back( 817 new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean)); 818 ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR, 819 OperatorType::kResizeBilinear)); 820 ops.emplace_back( 821 new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); 822 ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT, 823 OperatorType::kTensorFlowSplit)); 824 ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, 825 OperatorType::kStridedSlice)); 826 ops.emplace_back( 827 new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2)); 828 ops.emplace_back( 829 new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell)); 830 831 // Custom Operators. 832 ops.emplace_back(new Cast("CAST", OperatorType::kCast)); 833 ops.emplace_back( 834 new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); 835 ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); 836 ops.emplace_back(new TensorFlowUnsupported( 837 "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported)); 838 839 // There operators are supported by Toco, but not by TF Lite, and has no 840 // attributes. 841 ops.emplace_back( 842 new SimpleOperator<AddNOperator>("ADDN", OperatorType::kAddN)); 843 ops.emplace_back(new SimpleOperator<NegOperator>("NEG", OperatorType::kNeg)); 844 ops.emplace_back(new SimpleOperator<TensorFlowRsqrtOperator>( 845 "RSQRT", OperatorType::kTensorFlowRsqrt)); 846 // Simple Operators. 847 ops.emplace_back(new SimpleOperator<DequantizeOperator>( 848 "DEQUANTIZE", OperatorType::kDequantize)); 849 ops.emplace_back( 850 new SimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor)); 851 ops.emplace_back( 852 new SimpleOperator<ReluOperator>("RELU", OperatorType::kRelu)); 853 ops.emplace_back( 854 new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1)); 855 ops.emplace_back( 856 new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6)); 857 ops.emplace_back(new SimpleOperator<LogisticOperator>( 858 "LOGISTIC", OperatorType::kLogistic)); 859 ops.emplace_back( 860 new SimpleOperator<TanhOperator>("TANH", OperatorType::kTanh)); 861 ops.emplace_back(new SimpleOperator<ExpOperator>("EXP", OperatorType::kExp)); 862 863 return ops; 864 } 865 } // namespace 866 867 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() { 868 std::map<OperatorType, std::unique_ptr<BaseOperator>> result; 869 870 std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList(); 871 for (auto& op : ops) { 872 result[op->type()] = std::move(op); 873 } 874 875 return result; 876 } 877 878 std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() { 879 std::map<string, std::unique_ptr<BaseOperator>> result; 880 881 std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList(); 882 for (auto& op : ops) { 883 result[op->name()] = std::move(op); 884 } 885 886 return result; 887 } 888 889 } // namespace tflite 890 891 } // namespace toco 892