      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
     16 #define TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
     17 
     18 #include <functional>
     19 #include <initializer_list>
     20 #include <memory>
     21 #include <string>
     22 #include <unordered_map>
     23 #include <vector>
     24 
     25 #include "tensorflow/contrib/lite/toco/model_flags.pb.h"
     26 #include "tensorflow/contrib/lite/toco/runtime/types.h"
     27 #include "tensorflow/contrib/lite/toco/toco_types.h"
     28 #include "tensorflow/core/platform/logging.h"
     29 
     30 namespace toco {
     31 
     32 enum class OperatorType {
     33   kNone,
     34   // General-purpose neural network operators.
     35   kAdd,
     36   kAddN,
     37   kAveragePool,
     38   kBatchMatMul,
     39   kBatchNormalization,
     40   kConv,
     41   kConcatenation,
     42   kDepthwiseConv,
     43   kDepthToSpace,
     44   kSpaceToDepth,
     45   kDequantize,
     46   kDiv,
     47   kExp,
     48   kExpandDims,
     49   kFill,
     50   kFloorDiv,
     51   kFloorMod,
     52   kFullyConnected,
     53   kL2Normalization,
     54   kL2Pool,
     55   kLstmCell,
     56   kLocalResponseNormalization,
     57   kLogistic,
     58   kMaxPool,
     59   kFakeQuant,
     60   kMul,
     61   kRange,
     62   kRank,
     63   kRelu,
     64   kRelu1,
     65   kRelu6,
     66   kSoftmax,
     67   kLogSoftmax,
     68   kSub,
     69   kTanh,
     70   kTransposeConv,
     71   kCast,
     72   kFloor,
     73   kGather,
     74   kResizeBilinear,
     75   kSpaceToBatchND,
     76   kStack,
     77   kBatchToSpaceND,
     78   kPad,
     79   kStridedSlice,
     80   kSlice,
     81   kSqueeze,
     82   kMean,
     83   kArgMax,
     84   // The SVDF Op is a decomposition of a densely connected Op into
     85   // low rank filters. For details:
     86   // https://research.google.com/pubs/pub43813.html
     87   kSvdf,
     88   // Special operators used for importing TensorFlow nodes.
     89   // The general intent is to have some graph transformation either
     90   // drop them or rewrite them as general-purpose operators.
     91   kTensorFlowAll,
     92   kTensorFlowAssert,
     93   kTensorFlowConcat,
     94   kTensorFlowConcatV2,
     95   kTensorFlowGreater,
     96   kTensorFlowGreaterEqual,
     97   kTensorFlowIdentity,
     98   kTensorFlowLess,
     99   kTensorFlowLessEqual,
    100   kTensorFlowMax,
    101   kTensorFlowMaximum,
    102   kTensorFlowMin,
    103   kTensorFlowMinimum,
    104   kTensorFlowMatMul,
    105   kTensorFlowMerge,
    106   kNeg,
    107   kTensorFlowReshape,
    108   kTensorFlowRsqrt,
    109   kTensorFlowShape,
    110   kTensorFlowSplit,
    111   kTensorFlowSqrt,
    112   kTensorFlowSquare,
    113   kTensorFlowSum,
    114   kTensorFlowSwitch,
    115   kTensorFlowTile,
    116   kTranspose,
    117   kTopK_V2,
    118   // An unsupported TF operation. It's only needed to be able to represent TF
    119   // graph internally and is expected to be dropped by graph transformations.
    120   kTensorFlowUnsupported,
    121   // Finally, TensorFlow uses different conventions for axes ordering,
    122   // see AxesOrder, and this cannot always be resolved at the time of importing
    123   // nodes, as TensorFlow parameters may be constant-expression subgraphs
    124   // instead of being given as plain constant arrays. So we need to insert
    125   // special nodes in the graph to shuffle axes.
    126   kReorderAxes,
    127 };
    128 
// Helper to deal with TensorFlow arrays using a different ordering of
// dimensions ("axes") than our own.
    132 // TODO(benoitjacob): Ultimately, we shouldn't have any "ordering" of axes,
    133 // we should have associative arrays mapping symbolic axes identifiers (like
    134 // "output_depth") to dimensions. We would then not need this anymore.
    135 enum class AxesOrder {
    136   kOneAxis,  // one-dimensional array, one unique axis.
    137   kCR,       // column-major matrix storage order. Our standard.
    138   kRC,       // row-major matrix storage order. TensorFlow default.
    139   kOHWI,     // Our standard for conv weights
    140   kHWIO,     // TensorFlow conv weights
    141   k1HWO,     // Our standard for DepthwiseConv weights
    142   kHWIM,     // TensorFlow DepthwiseConv weights
    143   kNHWC,     // TensorFlow activations
    144 };
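
// Example (illustrative, numbers made up): the same 3x3 convolution kernel
// with 8 input channels and 16 output channels has shape
//   {16, 3, 3, 8}  in kOHWI order (our standard for conv weights), and
//   {3, 3, 8, 16}  in kHWIO order (the TensorFlow convention),
// so converting between the two is a pure permutation of axes, which is what
// the axes-reordering transformations perform.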
    145 
    146 // The type of the scalars in an array.
    147 // Note that that does not by itself tell whether the values in the array are
    148 // real (are literally interpreted as real numbers) or quantized (only acquire
    149 // a meaning as real numbers in conjunction with QuantizationParams).
    150 //
    151 // In practice though:
    152 //   float values are always real
    153 //   uint8 values are always quantized
    154 //   int32 values are either real or quantized (depending on whether
    155 //   QuantizationParams are present).
    156 //   other types are unused at the moment.
    157 //
    158 // kNone means that we don't know the data type yet, or that we don't care
    159 // because we'll be dropping the array anyway (e.g. some exotic array types
    160 // may be involved only in debug-only subgraphs that we may not be interested
    161 // in actually supporting).
    162 enum class ArrayDataType {
    163   kNone,  // 0
    164   kBool,
    165   kFloat,
    166   kInt8,
    167   kUint8,
    168   kInt16,  // 5
    169   kUint16,
    170   kInt32,
    171   kUint32,
    172   kInt64,
    173   kUint64,  // 10
    174   kString
    175 };
    176 
    177 // Compile-time logic to map ArrayDataType to the corresponding C++ scalar type
    178 template <ArrayDataType A>
    179 struct DataTypeImpl {};
    180 template <>
    181 struct DataTypeImpl<ArrayDataType::kNone> {
    182   typedef int Type;
    183 };
    184 template <>
    185 struct DataTypeImpl<ArrayDataType::kBool> {
    186   typedef bool Type;
    187 };
    188 template <>
    189 struct DataTypeImpl<ArrayDataType::kFloat> {
    190   typedef float Type;
    191 };
    192 template <>
    193 struct DataTypeImpl<ArrayDataType::kInt8> {
    194   typedef int8 Type;
    195 };
    196 template <>
    197 struct DataTypeImpl<ArrayDataType::kUint8> {
    198   typedef uint8 Type;
    199 };
    200 template <>
    201 struct DataTypeImpl<ArrayDataType::kInt16> {
    202   typedef int16 Type;
    203 };
    204 template <>
    205 struct DataTypeImpl<ArrayDataType::kUint16> {
    206   typedef uint16 Type;
    207 };
    208 template <>
    209 struct DataTypeImpl<ArrayDataType::kInt32> {
    210   typedef int32 Type;
    211 };
    212 template <>
    213 struct DataTypeImpl<ArrayDataType::kUint32> {
    214   typedef uint32 Type;
    215 };
    216 template <>
    217 struct DataTypeImpl<ArrayDataType::kInt64> {
    218   typedef int64 Type;
    219 };
    220 template <>
    221 struct DataTypeImpl<ArrayDataType::kUint64> {
    222   typedef uint64 Type;
    223 };
    224 template <>
    225 struct DataTypeImpl<ArrayDataType::kString> {
    226   typedef string Type;
    227 };
    228 
    229 template <ArrayDataType A>
    230 using DataType = typename DataTypeImpl<A>::Type;
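
// Example (illustrative sketch, not part of this header): since DataType<A>
// resolves at compile time, code templated on ArrayDataType can manipulate the
// matching C++ scalar type directly. The function name below is hypothetical.
//
//   template <ArrayDataType A>
//   DataType<A> SumOfValues(const std::vector<DataType<A>>& values) {
//     DataType<A> total = DataType<A>();
//     for (const auto& v : values) total += v;
//     return total;  // float for kFloat, int32 for kInt32, etc.
//   }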
    231 
    232 // Base class for type-specific buffer types.
    233 struct GenericBuffer {
    234   // Non-default-constructible: only ArrayDataType-specific subclass
    235   // objects may be constructed.
    236   GenericBuffer() = delete;
    237   // Non-copyable-or-movable: we should only store pointers-to-Buffer
    238   // in containers, not Operators themselves, so there should be no
    239   // copy or move.
    240   GenericBuffer(const GenericBuffer&) = delete;
    241   GenericBuffer(const GenericBuffer&&) = delete;
    242 
    243   // We need a virtual destructor so we can store pointers-to-Buffer
    244   // in containers and have the containers call the right subclass destructor.
    245   virtual ~GenericBuffer() {}
    246 
    247   const ArrayDataType type;
    248 
    249  protected:
    250   // Constructor used by subclasses for specific ArrayDataType's.
    251   explicit GenericBuffer(ArrayDataType t) : type(t) {}
    252 };
    253 
    254 // Type-specific buffer, containing type-specific storage.
    255 template <ArrayDataType A>
    256 struct Buffer : GenericBuffer {
    257   Buffer() : GenericBuffer(A) {}
    258 
    259   std::vector<DataType<A>> data;
    260 };
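
// Example (illustrative sketch): a Buffer is constructed for one specific
// ArrayDataType and filled through its 'data' vector; the base-class 'type'
// field records which specialization it is.
//
//   std::unique_ptr<GenericBuffer> buf(new Buffer<ArrayDataType::kFloat>);
//   static_cast<Buffer<ArrayDataType::kFloat>*>(buf.get())->data = {1.f, 2.f};
//   CHECK(buf->type == ArrayDataType::kFloat);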
    261 
    262 // Base class for all operator classes.
    263 struct Operator {
    264   // Non-default-constructible: only OperatorType-specific subclass
    265   // objects may be constructed.
    266   Operator() = delete;
    267   // Non-copyable-or-movable: we should only store pointers-to-Operator
    268   // in containers, not Operators themselves, so there should be no
    269   // copy or move.
    270   Operator(const Operator&) = delete;
    271   Operator(const Operator&&) = delete;
    272 
    273   // We need a virtual destructor so we can store pointers-to-Operator
    274   // in containers and have the containers call the right subclass destructor.
    275   virtual ~Operator() {}
    276 
    277   // The specific type of operator. Corresponds 1:1 to subclasses.
    278   const OperatorType type;
    279 
    280   // The activation function that may be fused into this operator,
    281   // or None if no activation function is fused.
    282   FusedActivationFunctionType fused_activation_function;
    283 
    284   // Input arrays: either activation arrays or constant array parameters.
    285   // We refer to them by their name, not by their address; the mapping of
    286   // names to addresses is given by the Model, which owns both Operator's and
    287   // Array's. Thus, an Operator on its own doesn't contain much information,
    288   // it is meant to be used in conjunction with the Model that owns it.
    289   std::vector<string> inputs;
    290 
    291   // Output activation arrays. Same comments as for inputs apply here too.
    292   std::vector<string> outputs;
    293 
// If true, the operator has more outputs than are listed in the 'outputs'
    295   // member. These need to be resolved by some graph transformation.
    296   // This flag is only here to indicate that an operator should not be
    297   // discarded as unused, even if from its 'outputs' member alone it
    298   // looks unused.
    299   bool unresolved_outputs = false;
    300 
    301  protected:
    302   // Constructor used by subclasses for specific OperatorType's.
    303   explicit Operator(OperatorType t)
    304       : type(t),
    305         fused_activation_function(FusedActivationFunctionType::kNone) {}
    306 };
    307 
    308 // Padding types for Conv-like operators. This is how padding is typically
    309 // specified in model files. But for inference, we will need to resolve this
    310 // to a FixedPadding, see below.
    311 enum class PaddingType { kNone, kSame, kValid };
    312 
    313 // Padding as resolved for a specific layer shape, as needed for inference.
    314 // For a given layer shape, a given padding type will resolve to a choice of
    315 // a number of padding rows and columns, which we call the padding height and
    316 // width respectively.
    317 struct FixedPadding {
    318   int width = 0;
    319   int height = 0;
    320 };
    321 
    322 // "Universal" padding struct containing both a generic PaddingType (as
    323 // represented in a model file), and a FixedPadding (as needed for inference).
    324 // The latter is resolved during the PropagateFixedSizes pass.
    325 struct Padding {
    326   FixedPadding& GetOrCreateFixedPadding() {
    327     if (!fixed) {
    328       FixedPadding* ptr = new FixedPadding;
    329       fixed = std::unique_ptr<FixedPadding>(ptr);
    330     }
    331     return *fixed;
    332   }
    333 
    334   Padding() : type(PaddingType::kNone) {}
    335   PaddingType type;
    336   std::unique_ptr<FixedPadding> fixed;
    337 };
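
// Example (illustrative sketch, values made up): a kSame padding is resolved
// to concrete sizes during the PropagateFixedSizes pass; the resolved values
// are stored in the FixedPadding created on demand.
//
//   Padding padding;
//   padding.type = PaddingType::kSame;
//   FixedPadding& fixed = padding.GetOrCreateFixedPadding();
//   fixed.height = 1;  // e.g. for a 3x3 kernel with stride 1
//   fixed.width = 1;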
    338 
    339 // "Convolutional" layer, as represented in model files.
    340 //
    341 // Inputs:
    342 //   inputs[0]: required: the input activations array
    343 //   inputs[1]: required: the Conv weights
    344 //   inputs[2]: optional: the bias vector, specifying the biases for each output
    345 //   channel.
    346 //
    347 // Outputs:
    348 //   outputs[0]: required: the output activations array
    349 //   outputs[1]: optional: the intermediate array of im2col-replicated input
    350 //                         activations. Present when targeting implementations
    351 //                         of Conv layers as Im2col+GEMM.
    352 //
    353 // TensorFlow equivalent: Conv2D
    354 struct ConvOperator : Operator {
    355   ConvOperator() : Operator(OperatorType::kConv) {}
    356   Padding padding;
    357   int stride_width = 0;
    358   int stride_height = 0;
    359   // A dilation_rate of 0 is invalid and this field is an optional attribute.
    360   // Thus initializing it to 1 to allow default conv behavior when the
    361   // attribute is not present.
    362   int dilation_rate = 1;
    363 };
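
// Example (illustrative sketch, array names are hypothetical): a concrete
// operator refers to its arrays by name; the Model owns the actual Array
// objects registered under those names.
//
//   auto* conv = new ConvOperator;
//   conv->inputs = {"input_activations", "conv_weights", "conv_bias"};
//   conv->outputs = {"conv_output"};
//   conv->stride_width = 2;
//   conv->stride_height = 2;
//   conv->padding.type = PaddingType::kSame;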
    364 
    365 // Depthwise-separable convolution operator.
    366 //
    367 // Inputs:
    368 //   inputs[0]: required: the input activations array
    369 //   inputs[1]: required: the DepthwiseConv weights
    370 //   inputs[2]: optional: the bias vector, specifying the biases for each output
    371 //   channel.
    372 //
    373 // TensorFlow equivalent: DepthwiseConv2dNative
    374 struct DepthwiseConvOperator : Operator {
    375   DepthwiseConvOperator() : Operator(OperatorType::kDepthwiseConv) {}
    376   Padding padding;
    377   int stride_height = 0;
    378   int stride_width = 0;
    379   int depth_multiplier = 0;
    380 };
    381 
    382 // Depth-to-space transform operator.
    383 //
    384 // Inputs:
    385 //   inputs[0]: required: the input activations array
    386 //
    387 // TensorFlow equivalent: DepthToSpace
    388 struct DepthToSpaceOperator : Operator {
    389   DepthToSpaceOperator() : Operator(OperatorType::kDepthToSpace) {}
    390   int block_size = 0;
    391 };
    392 
    393 // Space-to-depth transform operator.
    394 //
    395 // Inputs:
    396 //   inputs[0]: required: the input activations array
    397 //
    398 // TensorFlow equivalent: SpaceToDepth
    399 struct SpaceToDepthOperator : Operator {
    400   SpaceToDepthOperator() : Operator(OperatorType::kSpaceToDepth) {}
    401   int block_size = 0;
    402 };
    403 
    404 // Fully-connected operator.
    405 //
    406 // Inputs:
    407 //   inputs[0]: required: the input activations array
    408 //   inputs[1]: required: the FullyConnected weights
    409 //   inputs[2]: optional: the bias vector, specifying the biases for each output
    410 //   channel.
    411 //
    412 // TensorFlow equivalent: a pair consisting of a Reshape node reshaping the
    413 // input activations as a matrix, followed by a MatMul node.
    414 struct FullyConnectedOperator : Operator {
    415   FullyConnectedOperator() : Operator(OperatorType::kFullyConnected) {}
    416 };
    417 
// Dequantization operator, converting a quantized array of integers with
// quantization parameters specifying how these integers correspond to real
// numbers (see QuantizationParams) to an output activations array of
// floating-point values.
    423 //
    424 // In floating-point image models, there is typically a Dequantization operator
    425 // at the very beginning, converting the input image RGB data, consisting of
    426 // uint8 integer values, to floating-point input activations. That is where
    427 // image model parameters such as "mean_value" and "std_value" are typically
    428 // handled.
    429 //
// This is the only operator type that converts from quantized to
// floating-point, and there is at the moment no operator type at all to
// convert from floating-point to quantized. Every other operator does either
// float->float or quantized->quantized.
    436 //
    437 // Inputs:
    438 //   inputs[0]: required: the input quantized activations array
    439 //
    440 // TensorFlow equivalent: Dequantize
    441 struct DequantizeOperator : Operator {
    442   DequantizeOperator() : Operator(OperatorType::kDequantize) {}
    443 };
    444 
    445 // Batch-normalization operator.
    446 //
// We only support batch-normalization using pre-learned moments, so this is
// just computing (input - mean) * multiplier + offset. As such, this can be
    450 // expressed as a combination of Add and Mul nodes, and indeed this is how
    451 // we break it down during tooling for the purpose of fusing it into
    452 // other operators.
    453 //
    454 // Inputs:
    455 //   inputs[0]: required: the input activations array
    456 //   inputs[1]: required: the learned mean array
    457 //   inputs[2]: required: the learned multiplier array
    458 //   inputs[3]: required: the learned offset array
    459 //
    460 // TensorFlow equivalent: a combination of Add and Mul nodes
    461 struct BatchNormalizationOperator : Operator {
    462   BatchNormalizationOperator()
    463       : Operator(OperatorType::kBatchNormalization),
    464         global_normalization(false) {}
    465   bool global_normalization;
    466 };
    467 
    468 // L2-normalization operator.
    469 //
    470 // Inputs:
    471 //   inputs[0]: required: the input activations array
    472 //
// TensorFlow equivalent: none. In TensorFlow, L2 normalization is implemented
// by a sub-graph of operators implementing L2-normalization from lower-level
// arithmetic nodes; during tooling, we identify such sub-graphs and replace
// them by L2NormalizationOperator's. See IdentifyL2Normalization.
    478 struct L2NormalizationOperator : Operator {
    479   L2NormalizationOperator() : Operator(OperatorType::kL2Normalization) {}
    480 };
    481 
    482 // LSTM Cell operator.
    483 //
    484 // Inputs:
    485 //   inputs[0]: required: the input data array
    486 //   inputs[1]: required: the previous output activations array
    487 //   inputs[2]: required: the learned weights array
    488 //   inputs[3]: required: the learned biases array
    489 //   inputs[4]: required: the previous output state
    490 //   outputs[0]: required: the output activations array
    491 //   outputs[1]: required: the new state array
    492 //
    493 // TensorFlow equivalent: none. In TensorFlow, an LSTM is implemented
    494 // with a sub-graph of lower-level arithmetic nodes; during tooling, we identify
    495 // such sub-graphs and replace them with LstmCells. See IdentifyLstmCell().
    496 struct LstmCellOperator : Operator {
    497   enum Inputs {
    498     DATA_INPUT = 0,
    499     PREV_ACTIV_INPUT = 1,
    500     WEIGHTS_INPUT = 2,
    501     BIASES_INPUT = 3,
    502     PREV_STATE_INPUT = 4,
    503     NUM_INPUTS = 5
    504   };
    505   enum Outputs {
    506     ACTIV_OUTPUT = 0,
    507     STATE_OUTPUT = 1,
    508     CONCAT_TEMP = 2,
    509     ACTIV_TEMP = 3,
    510     NUM_OUTPUTS = 4
    511   };
    512   LstmCellOperator() : Operator(OperatorType::kLstmCell) {}
    513 };
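
// Example (illustrative sketch): the Inputs/Outputs enums above are indices
// into the 'inputs'/'outputs' name vectors of an LstmCellOperator, e.g.:
//
//   const LstmCellOperator& lstm = ...;
//   const string& weights_name = lstm.inputs[LstmCellOperator::WEIGHTS_INPUT];
//   const string& state_name = lstm.outputs[LstmCellOperator::STATE_OUTPUT];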
    514 
    515 // Element-wise multiplication operator.
    516 //
    517 // Inputs:
    518 //   inputs[0]: required: the left-hand side array
    519 //   inputs[1]: required: the right-hand side array
    520 //
    521 // TensorFlow equivalent: Mul
    522 struct MulOperator : Operator {
    523   MulOperator() : Operator(OperatorType::kMul) {}
    524 };
    525 
    526 // Element-wise Relu operator:
    527 //   x -> max(0, x)
    528 //
    529 // Inputs:
    530 //   inputs[0]: required: the input array
    531 //
    532 // TensorFlow equivalent: Relu
    533 struct ReluOperator : Operator {
    534   ReluOperator() : Operator(OperatorType::kRelu) {}
    535 };
    536 
    537 // Element-wise Relu1 operator:
    538 //   x -> min(max(x, -1), 1)
    539 //
    540 // Inputs:
    541 //   inputs[0]: required: the input array
    542 //
    543 // TensorFlow equivalent: none. We can construct the operator with Minimum
    544 // and Maximum operations
    545 struct Relu1Operator : Operator {
    546   Relu1Operator() : Operator(OperatorType::kRelu1) {}
    547 };
    548 
    549 // Element-wise Relu6 operator:
    550 //   x -> max(0, min(6, x))
    551 //
    552 // Inputs:
    553 //   inputs[0]: required: the input array
    554 //
    555 // TensorFlow equivalent: Relu6
    556 struct Relu6Operator : Operator {
    557   Relu6Operator() : Operator(OperatorType::kRelu6) {}
    558 };
    559 
    560 // Element-wise Logistic operator:
    561 //   x -> Logistic(x) = 1 / (1 + exp(-x))
    562 //
    563 // Inputs:
    564 //   inputs[0]: required: the input array
    565 //
    566 // TensorFlow equivalent: Sigmoid
    567 struct LogisticOperator : Operator {
    568   LogisticOperator() : Operator(OperatorType::kLogistic) {}
    569 };
    570 
    571 // Element-wise Tanh operator:
    572 //   x -> Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
    573 //
    574 // Inputs:
    575 //   inputs[0]: required: the input array
    576 //
    577 // TensorFlow equivalent: Tanh
    578 struct TanhOperator : Operator {
    579   TanhOperator() : Operator(OperatorType::kTanh) {}
    580 };
    581 
    582 // Element-wise addition operator.
    583 //
    584 // Inputs:
    585 //   inputs[0]: required: the left-hand side array
    586 //   inputs[1]: required: the right-hand side array
    587 //
    588 // TensorFlow equivalent: Add
    589 struct AddOperator : Operator {
    590   AddOperator() : Operator(OperatorType::kAdd) {}
    591 };
    592 
    593 // Element-wise addition operator for N inputs.
    594 //
    595 // Inputs:
    596 //   inputs[i]: The i-th array to add together to form the output.
    597 //
    598 // TensorFlow equivalent: AddN
    599 struct AddNOperator : Operator {
    600   AddNOperator() : Operator(OperatorType::kAddN) {}
    601 };
    602 
    603 // Concatenation operator: concatenates its inputs
    604 // along the axis.
    605 //
    606 // Inputs: this operator accepts any number >= 1 of inputs.
    607 //   inputs[i]: the i-th array to concatenate.
    608 //
    609 // TensorFlow equivalent: Concat.
    610 struct ConcatenationOperator : Operator {
    611   ConcatenationOperator() : Operator(OperatorType::kConcatenation) {}
    612   int axis = 0;
    613 };
    614 
    615 // Reordering dimensions. Used only during tooling to transform graphs from
    616 // the TensorFlow format.
    617 //
    618 // Inputs:
    619 //   inputs[0]: required: the input array
    620 //
    621 // TensorFlow equivalent: none. This is only useful to convert between formats.
    622 struct ReorderAxesOperator : Operator {
    623   ReorderAxesOperator() : Operator(OperatorType::kReorderAxes) {}
    624   AxesOrder input_axes_order;
    625   AxesOrder output_axes_order;
    626 };
    627 
    628 // Average-pooling operator.
    629 //
    630 // Inputs:
    631 //   inputs[0]: required: the input array
    632 //
    633 // TensorFlow equivalent: AveragePool
    634 struct AveragePoolOperator : Operator {
    635   AveragePoolOperator() : Operator(OperatorType::kAveragePool) {}
    636   Padding padding;
    637   int stride_height = 0;
    638   int stride_width = 0;
    639   int kheight = 0;
    640   int kwidth = 0;
    641 };
    642 
    643 // Local response normalization operator.
    644 //
    645 // Inputs:
    646 //   inputs[0]: required: the input array
    647 //
    648 // TensorFlow equivalent: LRN
    649 struct LocalResponseNormalizationOperator : Operator {
    650   LocalResponseNormalizationOperator()
    651       : Operator(OperatorType::kLocalResponseNormalization) {}
    652 
    653   int range = 0;
    654   float bias = 0.f;
    655   float alpha = 0.f;
    656   float beta = 0.f;
    657 };
    658 
    659 // Max-pooling operator.
    660 //
    661 // Inputs:
    662 //   inputs[0]: required: the input array
    663 //
    664 // TensorFlow equivalent: MaxPool
    665 struct MaxPoolOperator : Operator {
    666   MaxPoolOperator() : Operator(OperatorType::kMaxPool) {}
    667   Padding padding;
    668   int stride_height = 0;
    669   int stride_width = 0;
    670   int kheight = 0;
    671   int kwidth = 0;
    672 };
    673 
    674 // L2-pooling operator.
    675 //
    676 // Inputs:
    677 //   inputs[0]: required: the input array
    678 //
    679 // TensorFlow equivalent: none. Can be shimmed by squaring+avgpool+sqrt.
    680 struct L2PoolOperator : Operator {
    681   L2PoolOperator() : Operator(OperatorType::kL2Pool) {}
    682   Padding padding;
    683   int stride_height = 0;
    684   int stride_width = 0;
    685   int kheight = 0;
    686   int kwidth = 0;
    687 };
    688 
    689 // The expected [min, max] range of values in a given array.
    690 // Used for quantization only.
// This information typically comes from special nodes found in quantized
// models, see FakeQuantOperator, and is used during quantization to resolve
// actual quantization parameters (see QuantizationParams).
    695 struct MinMax {
    696   double min = 0.;
    697   double max = 0.;
    698 };
    699 
    700 inline bool operator==(const MinMax& m1, const MinMax& m2) {
    701   return m1.min == m2.min && m1.max == m2.max;
    702 }
    703 
    704 // Fake-quantization operator. This does two things:
    705 //   - Annotate its input and output arrays with MinMax information,
    706 //   - Arithmetic-wise, this operator rounds incoming activation values
    707 //     to the nearest representable value on the scale of 256
    708 //     values from the min to the max value dictated by its MinMax info.
    709 //
    710 // Inputs:
    711 //   inputs[0]: required: the input array
    712 //   inputs[1]: optional: the 'min' value, if it has not yet been resolved
    713 //              to a constant.
    714 //   inputs[2]: optional: the 'max' value, if it has not yet been resolved
    715 //              to a constant.
    716 //
    717 // TensorFlow equivalent: FakeQuantWithMinMaxVars, FakeQuantWithMinMaxArgs.
    718 struct FakeQuantOperator : Operator {
    719   FakeQuantOperator() : Operator(OperatorType::kFakeQuant) {}
    720   std::unique_ptr<MinMax> minmax;
    721 };
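
// Example (illustrative sketch, array names are hypothetical): a FakeQuant op
// whose [min, max] range is already known can carry it directly in 'minmax'
// instead of relying on inputs[1]/inputs[2]:
//
//   auto* fq = new FakeQuantOperator;
//   fq->inputs = {"activations"};
//   fq->outputs = {"activations_fakequant"};
//   fq->minmax.reset(new MinMax);
//   fq->minmax->min = -6.0;
//   fq->minmax->max = 6.0;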
    722 
    723 // Element-wise division operator.
    724 //
    725 // Inputs:
    726 //   inputs[0]: required: the left-hand side array
    727 //   inputs[1]: required: the right-hand side array
    728 //
    729 // TensorFlow equivalent: Div
    730 struct DivOperator : Operator {
    731   DivOperator() : Operator(OperatorType::kDiv) {}
    732 };
    733 
    734 // Element-wise identity (x->x) operator.
    735 //
    736 // Inputs:
    737 //   inputs[0]: required: the input array
    738 //
    739 // TensorFlow equivalent: Identity
    740 struct TensorFlowIdentityOperator : Operator {
    741   TensorFlowIdentityOperator() : Operator(OperatorType::kTensorFlowIdentity) {}
    742 };
    743 
    744 // Batch matrix multiplication operator. This comes from the (deprecated)
    745 // tf.batch_matmul or a tf.matmul that has rank 3. dims(0) is the batch count
    746 // and it can be trivially unrolled into a series of matmuls on each element.
    747 //
    748 // Inputs:
    749 //   inputs[0]: required: the left-hand side matrix
    750 //   inputs[1]: required: the right-hand side matrix
    751 //
    752 // TensorFlow equivalent: MatMul
    753 struct BatchMatMulOperator : Operator {
    754   BatchMatMulOperator() : Operator(OperatorType::kBatchMatMul) {}
    755 };
    756 
    757 // General matrix multiplication operator. We don't want to support general
    758 // matrix multiplication at inference time, so we resolve it during tooling
    759 // to more specific operator types, namely, FullyConnected.
    760 //
    761 // Inputs:
    762 //   inputs[0]: required: the left-hand side matrix
    763 //   inputs[1]: required: the right-hand side matrix
    764 //
    765 // TensorFlow equivalent: MatMul
    766 struct TensorFlowMatMulOperator : Operator {
    767   TensorFlowMatMulOperator() : Operator(OperatorType::kTensorFlowMatMul) {}
    768 };
    769 
    770 // Padding operator. Pads a tensor with zeros.
    771 //
    772 // Inputs:
    773 //   inputs[0]: required: the input array
    774 //   inputs[1]: required: the padding array
    775 //
// This operation pads `input` with zeros according to the `paddings` you
    777 // specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the
    778 // rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
    779 // how many zeros to add before the contents of `input` in that dimension, and
    780 // `paddings[D, 1]` indicates how many zeros to add after the contents of
    781 // `input` in that dimension.
    782 //
    783 // TensorFlow equivalent: Pad
    784 struct PadOperator : Operator {
    785   PadOperator() : Operator(OperatorType::kPad) {}
    786 
    787   std::vector<int> left_padding;
    788   std::vector<int> right_padding;
    789 };
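
// Example (illustrative numbers): padding a 4-D NHWC tensor with one row and
// one column of zeros on each spatial side corresponds to
//   left_padding  = {0, 1, 1, 0}
//   right_padding = {0, 1, 1, 0}
// i.e. one entry per dimension, in the same order as the input shape.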
    790 
    791 // Strided slice operator.
    792 //
    793 // Inputs:
    794 //   inputs[0]: required: the input array
    795 //   inputs[1]: required: the begin array
    796 //   inputs[2]: required: the end array
    797 //   inputs[3]: optional: the strides array
    798 //
    799 // TensorFlow equivalent: StridedSlice
    800 struct StridedSliceOperator : Operator {
    801   StridedSliceOperator() : Operator(OperatorType::kStridedSlice) {}
    802 
    803   std::vector<int> start_indices;
    804   std::vector<int> stop_indices;
    805   std::vector<int> strides;
    806 
    807   int begin_mask;
    808   int ellipsis_mask;
    809   int end_mask;
    810   int new_axis_mask;
    811   int shrink_axis_mask;
    812 };
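
// Example (illustrative numbers): taking rows 1..2 of a 2-D array while
// keeping all columns could be expressed as
//   start_indices = {1, 0}, stop_indices = {3, 0}, strides = {1, 1},
//   begin_mask = 2, end_mask = 2
// where bit i of begin_mask/end_mask marks a dimension whose start/stop value
// is ignored, so the full extent is taken along dimension 1.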
    813 
    814 // Reshaping operator, reshaping its input array to a two-dimensional shape
    815 // (a "matrix"). This is used in the TensorFlow format, in conjunction with
    816 // MatMul nodes, to implement fully-connected layers.
    817 //
    818 // Inputs:
    819 //   inputs[0]: required: the input array
    820 //
    821 // TensorFlow equivalent: Reshape --- except that we only support a special case
    822 // here, where the output shape is a matrix (2D) shape.
    823 struct TensorFlowReshapeOperator : Operator {
    824   TensorFlowReshapeOperator() : Operator(OperatorType::kTensorFlowReshape) {}
    825   std::vector<int> shape;
    826 };
    827 
    828 // Removes dimensions of size 1 from the shape of a tensor.
    829 // https://www.tensorflow.org/api_docs/python/tf/squeeze
    830 //
    831 // Inputs:
    832 //   inputs[0]: required: the input array
    833 //
    834 // TensorFlow equivalent: Squeeze
    835 struct SqueezeOperator : Operator {
    836   SqueezeOperator() : Operator(OperatorType::kSqueeze) {}
    837 
    838   std::vector<int> squeeze_dims;
    839 };
    840 
// Transpose convolution operator.
//
// Inputs:
//   inputs[0]: required: the input activations array
//   inputs[1]: required: the Conv weights
//   inputs[2]: optional: the bias vector, specifying the biases for each output
//   channel.
    845 //
    846 // Outputs:
    847 //   outputs[0]: required: the output activations array
    848 //
    849 // TensorFlow equivalent: Conv2DBackpropInput
    850 struct TransposeConvOperator : Operator {
    851   TransposeConvOperator() : Operator(OperatorType::kTransposeConv) {}
    852   Padding padding;
    853   int stride_width = 0;
    854   int stride_height = 0;
    855 };
    856 
    857 // Given a tensor input, this operation calculates element-wise exponential
    858 // (y = e^x).
    859 //
    860 // Inputs:
    861 //   inputs[0]: required: input tensor
    862 //
    863 // TensorFlow equivalent: Exp
    864 struct ExpOperator : Operator {
    865   ExpOperator() : Operator(OperatorType::kExp) {}
    866 };
    867 
    868 // Given a tensor input, this operation inserts a dimension of 1 at the
    869 // dimension index axis of input's shape. The dimension index axis starts at
    870 // zero; if you specify a negative number for axis it is counted backward from
    871 // the end.
    872 //
    873 // Inputs:
    874 //   inputs[0]: required: input tensor
    875 //   inputs[1]: required: 0-D (scalar). Specifies the dimension index at which
    876 //   to expand the shape of input
    877 //
    878 // TensorFlow equivalent: ExpandDims
    879 struct ExpandDimsOperator : Operator {
    880   ExpandDimsOperator() : Operator(OperatorType::kExpandDims) {}
    881 };
    882 
// Creates a tensor of shape dims and fills it with the given scalar value.
    884 // Output type will be the same as the given scalar value.
    885 //
    886 // Inputs:
    887 //   inputs[0]: required: 1-D (int32) - the shape of the output tensor
    888 //   inputs[1]: required: 0-D (scalar) - value to fill the tensor with
    889 //
    890 // TensorFlow equivalent: Fill
    891 struct FillOperator : Operator {
    892   FillOperator() : Operator(OperatorType::kFill) {}
    893 };
    894 
    895 // Element-wise floor division operator.
    896 //
    897 // Inputs:
    898 //   inputs[0]: required: the left-hand side array
    899 //   inputs[1]: required: the right-hand side array
    900 //
    901 // TensorFlow equivalent: FloorDiv
    902 struct FloorDivOperator : Operator {
    903   FloorDivOperator() : Operator(OperatorType::kFloorDiv) {}
    904 };
    905 
    906 // Element-wise floor mod operator.
    907 //
    908 // Inputs:
    909 //   inputs[0]: required: the left-hand side array
    910 //   inputs[1]: required: the right-hand side array
    911 //
    912 // TensorFlow equivalent: FloorMod
    913 struct FloorModOperator : Operator {
    914   FloorModOperator() : Operator(OperatorType::kFloorMod) {}
    915 };
    916 
    917 // Creates a sequence of numbers that begins at start and extends by increments
    918 // of delta up to but not including limit.
    919 //
    920 // The dtype of the resulting tensor is inferred from the inputs unless it is
    921 // provided explicitly.
    922 //
    923 // Inputs:
    924 //   inputs[0]: required: the start
    925 //   inputs[1]: required: the limit
    926 //   inputs[2]: required: the delta
    927 //
    928 // TensorFlow equivalent: Range
    929 struct RangeOperator : Operator {
    930   RangeOperator() : Operator(OperatorType::kRange) {}
    931   ArrayDataType dtype = ArrayDataType::kNone;
    932 };
    933 
    934 // Rank operator. Extracts the rank of the tensor.
    935 //
    936 // Inputs:
    937 //   inputs[0]: required: the input array
    938 //
    939 // This operation outputs a 0-D integer tensor representing the rank of
    940 // the input.
    941 //
    942 // TensorFlow equivalent: Rank.  We currently assume that the output is int32
    943 // and not int64.  The output type could be stored herein.
    944 struct RankOperator : Operator {
    945   RankOperator() : Operator(OperatorType::kRank) {}
    946 };
    947 
    948 // Element-wise negation (-x) operator.
    949 //
    950 // Inputs:
    951 //   inputs[0]: required: the input array
    952 //
    953 // TensorFlow equivalent: Neg
    954 struct NegOperator : Operator {
    955   NegOperator() : Operator(OperatorType::kNeg) {}
    956 };
    957 
    958 // Element-wise reciprocal-square-root (x^-0.5) operator.
    959 //
    960 // Inputs:
    961 //   inputs[0]: required: the input array
    962 //
    963 // TensorFlow equivalent: Rsqrt
    964 struct TensorFlowRsqrtOperator : Operator {
    965   TensorFlowRsqrtOperator() : Operator(OperatorType::kTensorFlowRsqrt) {}
    966 };
    967 
    968 // Stacks a list of rank-R tensors into one rank-(R+1) tensor.
    969 //
    970 // Packs the list of tensors in values into a tensor with rank one higher than
// each tensor in values, by packing them along the axis dimension. Given a
// list of length N of tensors of shape (A, B, C), packing along axis == 0
// yields an output of shape (N, A, B, C).
    973 //
    974 // Inputs: this operator accepts any number >= 1 of inputs.
    975 //   inputs[i]: the i-th array to merge.
    976 //
    977 // TensorFlow equivalent: Stack or Pack
    978 struct StackOperator : Operator {
    979   StackOperator() : Operator(OperatorType::kStack) {}
    980   int axis = 0;
    981 };
    982 
    983 // Shape operator. Extracts the shape of the tensor.
    984 //
    985 // Inputs:
    986 //   inputs[0]: required: the input array
    987 //
    988 // This operation outputs a 1-D integer tensor representing the shape of
    989 // the input.
    990 //
    991 // TensorFlow equivalent: Shape.  We currently assume that the output is int32
    992 // and not int64.  The output type could be stored herein.
    993 struct TensorFlowShapeOperator : Operator {
    994   TensorFlowShapeOperator() : Operator(OperatorType::kTensorFlowShape) {}
    995 };
    996 
    997 // Element-wise square-root (x^0.5) operator.
    998 //
    999 // Inputs:
   1000 //   inputs[0]: required: the input array
   1001 //
   1002 // TensorFlow equivalent: Sqrt
   1003 struct TensorFlowSqrtOperator : Operator {
   1004   TensorFlowSqrtOperator() : Operator(OperatorType::kTensorFlowSqrt) {}
   1005 };
   1006 
   1007 // Element-wise square (x*x) operator.
   1008 //
   1009 // Inputs:
   1010 //   inputs[0]: required: the input array
   1011 //
   1012 // TensorFlow equivalent: Square
   1013 struct TensorFlowSquareOperator : Operator {
   1014   TensorFlowSquareOperator() : Operator(OperatorType::kTensorFlowSquare) {}
   1015 };
   1016 
   1017 // Transposes a tensor.
   1018 //
   1019 // By default, this operation performs a regular matrix transpose on 2-D input
   1020 // tensors.
   1021 //
   1022 // Inputs:
   1023 //   inputs[0]: required: the input array
   1024 //
   1025 // TensorFlow equivalent: Transpose
   1026 struct TransposeOperator : Operator {
   1027   TransposeOperator() : Operator(OperatorType::kTranspose) {}
   1028   std::vector<int> perm;
   1029 };
   1030 
   1031 // Element-wise subtraction operator.
   1032 //
   1033 // Inputs:
   1034 //   inputs[0]: required: the left-hand side array
   1035 //   inputs[1]: required: the right-hand side array
   1036 //
   1037 // TensorFlow equivalent: Sub
   1038 struct SubOperator : Operator {
   1039   SubOperator() : Operator(OperatorType::kSub) {}
   1040 };
   1041 
// Global sum reduction: computes the sum of all the entries in the input array.
   1043 // Thus the output is "0-dimensional": it consists of a single scalar value.
   1044 //
   1045 // Inputs:
   1046 //   inputs[0]: required: the input array
   1047 //
   1048 // TensorFlow equivalent: Sum --- except that we only support the special case
   1049 // of global reduction across all dimensions.
   1050 struct TensorFlowSumOperator : Operator {
   1051   TensorFlowSumOperator() : Operator(OperatorType::kTensorFlowSum) {}
   1052   bool keep_dims = false;
   1053 };
   1054 
   1055 // TensorFlow Tile equivalent. Refer to TensorFlow documentation for details.
   1056 // Not fully supported, just a placeholder to handle TensorFlow graphs and
   1057 // support graph transformations to other operator types by matching sub-graphs.
   1058 struct TensorFlowTileOperator : Operator {
   1059   TensorFlowTileOperator() : Operator(OperatorType::kTensorFlowTile) {}
   1060 };
   1061 
   1062 // TensorFlow Slice equivalent. Refer to TensorFlow documentation for details.
   1063 struct SliceOperator : Operator {
   1064   SliceOperator() : Operator(OperatorType::kSlice) {}
   1065 
   1066   std::vector<int> begin;
   1067   std::vector<int> size;
   1068 };
   1069 
   1070 // TensorFlow Split equivalent. Refer to TensorFlow documentation for details.
   1071 // Not fully supported, just a placeholder to handle TensorFlow graphs and
   1072 // support graph transformations to other operator types by matching sub-graphs.
   1073 struct TensorFlowSplitOperator : Operator {
   1074   TensorFlowSplitOperator() : Operator(OperatorType::kTensorFlowSplit) {}
   1075   int num_split = 0;
   1076 };
   1077 
   1078 // TensorFlow Concat equivalent. Refer to TensorFlow documentation for details.
   1079 // Not fully supported, just a placeholder to handle TensorFlow graphs and
   1080 // support graph transformations to other operator types by matching sub-graphs.
   1081 // Concretely, once the concat dim becomes known, if it is the depth
   1082 // dimension then we can change this op into a DepthConcatenation op.
   1083 // Otherwise, we hope for some other graph transformation to drop this node.
   1084 struct TensorFlowConcatOperator : Operator {
   1085   TensorFlowConcatOperator() : Operator(OperatorType::kTensorFlowConcat) {}
   1086 };
   1087 
   1088 // TensorFlow ConcatV2 equivalent. Refer to TensorFlow documentation for
   1089 // details.
   1090 // Not fully supported, just a placeholder to handle TensorFlow graphs and
   1091 // support graph transformations to other operator types by matching sub-graphs.
   1092 // Concretely, once the concat dim becomes known, if it is the depth
   1093 // dimension then we can change this op into a DepthConcatenation op.
   1094 // Otherwise, we hope for some other graph transformation to drop this node.
   1095 struct TensorFlowConcatV2Operator : Operator {
   1096   TensorFlowConcatV2Operator() : Operator(OperatorType::kTensorFlowConcatV2) {}
   1097 };
   1098 
   1099 // TensorFlow Merge equivalent. Refer to TensorFlow documentation for details.
   1100 //
   1101 // Inputs: this operator accepts any number >= 1 of inputs.
   1102 //   inputs[i]: the i-th array to merge.
   1103 //
   1104 // It is expected that graph transformations will drop all but exactly one
   1105 // of the inputs, at which point the Merge node will be equivalent to an
   1106 // Identity node forwarding the remaining input.
   1107 //
   1108 // Note: We do not currently support runtime control flow: we only support
   1109 // control flow that can be resolved at tooling time (independently of input
   1110 // activations).
   1111 struct TensorFlowMergeOperator : Operator {
   1112   TensorFlowMergeOperator() : Operator(OperatorType::kTensorFlowMerge) {}
   1113 };
   1114 
   1115 // TensorFlow Switch equivalent. Refer to TensorFlow documentation for details.
   1116 //
   1117 // Inputs:
   1118 //   inputs[0]: required: the input array
   1119 //   inputs[1]: required: the boolean predicate, given as an array of size 1
   1120 //     and of type kBool, will determine which output gets selected.
   1121 //
   1122 // Outputs: a TensorFlow Switch node always has exactly two outputs. Depending
   1123 // on the boolean value that the input predicate resolves to (see note below),
   1124 // one or the other of the outputs will be 'selected': the input array will be
// forwarded to the 'selected output' as if by an Identity node, while the other
   1126 // output will be discarded, and any graph edge connecting that discarded output
   1127 // will be dropped. The rule for selecting outputs is as follows:
   1128 //   outputs[0] will be selected if the input predicate resolves to 'true'.
   1129 //   outputs[1] will be selected if the input predicate resolves to 'false'.
   1130 //
   1131 // Note: We do not currently support runtime control flow: we only support
   1132 // control flow that can be resolved at tooling time (independently of input
   1133 // activations).
   1134 struct TensorFlowSwitchOperator : Operator {
   1135   TensorFlowSwitchOperator() : Operator(OperatorType::kTensorFlowSwitch) {}
   1136 };
   1137 
   1138 // TensorFlow All equivalent. Refer to TensorFlow documentation for details.
   1139 // Not fully supported, just a placeholder to handle TensorFlow graphs and
   1140 // support graph transformations to other operator types by matching sub-graphs.
   1141 // Typically, this is only used as an input to an Assert node, so can be
   1142 // removed as an unused node as we drop Assert nodes.
   1143 struct TensorFlowAllOperator : Operator {
   1144   TensorFlowAllOperator() : Operator(OperatorType::kTensorFlowAll) {}
   1145 };
   1146 
   1147 // TensorFlow Assert equivalent. Refer to TensorFlow documentation for details.
   1148 // Not fully supported, just a placeholder to handle TensorFlow graphs and
   1149 // support graph transformations to other operator types by matching sub-graphs.
   1150 // Typically, we just drop Assert nodes.
   1151 struct TensorFlowAssertOperator : Operator {
   1152   TensorFlowAssertOperator() : Operator(OperatorType::kTensorFlowAssert) {}
   1153 };
   1154 
   1155 // TensorFlow Less equivalent. Refer to TensorFlow documentation for details.
   1156 // Not fully supported, just a placeholder to handle TensorFlow graphs and
   1157 // support graph transformations to other operator types by matching sub-graphs.
   1158 // Typically, this is only used as an input to an Assert node, so can be
   1159 // removed as an unused node as we drop Assert nodes.
   1160 struct TensorFlowLessOperator : Operator {
   1161   TensorFlowLessOperator() : Operator(OperatorType::kTensorFlowLess) {}
   1162 };
   1163 
   1164 // TensorFlow LessEqual equivalent. Refer to TensorFlow documentation for
   1165 // details.
   1166 // Not fully supported, just a placeholder to handle TensorFlow graphs and
   1167 // support graph transformations to other operator types by matching sub-graphs.
   1168 // Typically, this is only used as an input to an Assert node, so can be
   1169 // removed as an unused node as we drop Assert nodes.
   1170 struct TensorFlowLessEqualOperator : Operator {
   1171   TensorFlowLessEqualOperator()
   1172       : Operator(OperatorType::kTensorFlowLessEqual) {}
   1173 };
   1174 
// TensorFlow Greater equivalent. Refer to TensorFlow documentation for details.
   1176 // Not fully supported, just a placeholder to handle TensorFlow graphs and
   1177 // support graph transformations to other operator types by matching sub-graphs.
   1178 // Typically, this is only used as an input to an Assert node, so can be
   1179 // removed as an unused node as we drop Assert nodes.
   1180 struct TensorFlowGreaterOperator : Operator {
   1181   TensorFlowGreaterOperator() : Operator(OperatorType::kTensorFlowGreater) {}
   1182 };
   1183 
   1184 // TensorFlow GreaterEqual equivalent. Refer to TensorFlow documentation for
   1185 // details.
   1186 // Not fully supported, just a placeholder to handle TensorFlow graphs and
   1187 // support graph transformations to other operator types by matching sub-graphs.
   1188 // Typically, this is only used as an input to an Assert node, so can be
   1189 // removed as an unused node as we drop Assert nodes.
   1190 struct TensorFlowGreaterEqualOperator : Operator {
   1191   TensorFlowGreaterEqualOperator()
   1192       : Operator(OperatorType::kTensorFlowGreaterEqual) {}
   1193 };
   1194 
// Global max reduction: computes the max of all the entries in the input array.
   1196 // Thus the output is "0-dimensional": it consists of a single scalar value.
   1197 //
   1198 // Inputs:
   1199 //   inputs[0]: required: the input array
   1200 //
   1201 // TensorFlow equivalent: Max --- except that we only support the special case
   1202 // of global reduction across all dimensions.
   1203 struct TensorFlowMaxOperator : Operator {
   1204   TensorFlowMaxOperator() : Operator(OperatorType::kTensorFlowMax) {}
   1205   bool keep_dims = false;
   1206 };
   1207 
// Global min reduction: computes the min of all the entries in the input array.
   1209 // Thus the output is "0-dimensional": it consists of a single scalar value.
   1210 //
   1211 // Inputs:
   1212 //   inputs[0]: required: the input array
   1213 //
   1214 // TensorFlow equivalent: Min --- except that we only support the special case
   1215 // of global reduction across all dimensions.
   1216 struct TensorFlowMinOperator : Operator {
   1217   TensorFlowMinOperator() : Operator(OperatorType::kTensorFlowMin) {}
   1218   bool keep_dims = false;
   1219 };
   1220 
   1221 // Element-wise maximum operator. Currently it only supports scalar as
   1222 // the second operand.
   1223 //
   1224 // Inputs:
   1225 //   inputs[0]: required: the left-hand side array
   1226 //   inputs[1]: required: the right-hand side array
   1227 //
   1228 // TensorFlow equivalent: Maximum
   1229 struct TensorFlowMaximumOperator : Operator {
   1230   TensorFlowMaximumOperator() : Operator(OperatorType::kTensorFlowMaximum) {}
   1231 };
   1232 
   1233 // Element-wise minimum operator. Currently it only supports scalar as
   1234 // the second operand.
   1235 //
   1236 // Inputs:
   1237 //   inputs[0]: required: the left-hand side array
   1238 //   inputs[1]: required: the right-hand side array
   1239 //
   1240 // TensorFlow equivalent: Minimum
   1241 struct TensorFlowMinimumOperator : Operator {
   1242   TensorFlowMinimumOperator() : Operator(OperatorType::kTensorFlowMinimum) {}
   1243 };
   1244 
   1245 // General TF operation, unsupported by tf.mini. Expected to be dropped by
   1246 // graph transformations.
   1247 struct TensorFlowUnsupportedOperator : Operator {
   1248   TensorFlowUnsupportedOperator()
   1249       : Operator(OperatorType::kTensorFlowUnsupported) {}
   1250 
   1251   // The original TF operation type. Used for diagnostic purposes.
   1252   string tensorflow_op;
   1253   // A serialized tensorflow::NodeDef string.
   1254   string tensorflow_node_def;
   1255   // A boolean indicating if the unsupported op should be treated as quantized.
   1256   bool quantized = false;
   1257   // Output data types
   1258   std::vector<ArrayDataType> output_data_types;
   1259 };
   1260 
   1261 // Softmax activation function.
   1262 //
   1263 // Inputs:
   1264 //   inputs[0]: required: the input array
   1265 //
   1266 // TensorFlow equivalent: Softmax
   1267 struct SoftmaxOperator : Operator {
   1268   SoftmaxOperator() : Operator(OperatorType::kSoftmax) {}
   1269   float beta = 0.f;
   1270 };
   1271 
   1272 // LogSoftmax activation function.
   1273 //
   1274 // Inputs:
   1275 //   inputs[0]: required: the logits input array
   1276 //
   1277 // TensorFlow equivalent: LogSoftmax
   1278 struct LogSoftmaxOperator : Operator {
   1279   LogSoftmaxOperator() : Operator(OperatorType::kLogSoftmax) {}
   1280 };
   1281 
   1282 // Cast operator.
   1283 //
   1284 // Inputs:
   1285 //   inputs[0]: required: the input array
   1286 //
   1287 // TensorFlow equivalent: Cast
   1288 struct CastOperator : Operator {
   1289   CastOperator() : Operator(OperatorType::kCast) {}
   1290   ArrayDataType src_data_type = ArrayDataType::kNone;
   1291   ArrayDataType dst_data_type = ArrayDataType::kNone;
   1292 };
   1293 
   1294 // Floor operator.
   1295 //
   1296 // Inputs:
   1297 //   inputs[0]: required: the input array
   1298 //
   1299 // TensorFlow equivalent: Floor
   1300 struct FloorOperator : Operator {
   1301   FloorOperator() : Operator(OperatorType::kFloor) {}
   1302 };
   1303 
   1304 // Gather operator. It gathers slices from params according to indices.
   1305 // Only 1-D indices are supported at the moment.
   1306 //
   1307 // Inputs:
   1308 //   inputs[0]: required: the params array
   1309 //   inputs[1]: required: the indices to gather
   1310 //
   1311 // TensorFlow equivalent: Gather
   1312 struct GatherOperator : Operator {
   1313   GatherOperator() : Operator(OperatorType::kGather) {}
   1314   int axis = 0;
   1315   int input_rank = 0;
   1316 };
   1317 
   1318 // ArgMax operator. It returns the index of the maximum value along axis.
   1319 //
   1320 // Inputs:
   1321 //   inputs[0]: required: the input tensor
   1322 //
   1323 // TensorFlow equivalent: ArgMax
   1324 struct ArgMaxOperator : Operator {
   1325   ArgMaxOperator() : Operator(OperatorType::kArgMax) {}
   1326   ArrayDataType output_data_type = ArrayDataType::kInt64;
   1327 };
   1328 
   1329 // ResizeBilinear operator. It resizes input images with bilinear interpolation.
   1330 // It does not support align_corners at the moment.
   1331 //
   1332 // Inputs:
   1333 //   inputs[0]: required: the input array
   1334 //   inputs[1]: required: the new image size
   1335 //
   1336 // TensorFlow equivalent: ResizeBilinear
   1337 struct ResizeBilinearOperator : Operator {
   1338   ResizeBilinearOperator() : Operator(OperatorType::kResizeBilinear) {}
   1339 
   1340   bool align_corners = false;
   1341 };
   1342 
   1343 // SpaceToBatchND operator. It divides spatial dimensions into a grid of
   1344 // blocks and interleaves these blocks with the batch dimension. Currently,
   1345 // only 2-d blocks are supported.
   1346 //
   1347 // Inputs:
   1348 //   inputs[0]: required: the input array
   1349 //   inputs[1]: required: the block shape
   1350 //   inputs[2]: required: the paddings
   1351 //
   1352 // TensorFlow equivalent: SpaceToBatchND
   1353 struct SpaceToBatchNDOperator : Operator {
   1354   SpaceToBatchNDOperator() : Operator(OperatorType::kSpaceToBatchND) {}
   1355 
   1356   std::vector<int> block_shape;
   1357   std::vector<int> before_paddings;
   1358   std::vector<int> after_paddings;
   1359 };
   1360 
   1361 // BatchToSpaceND operator. Rearranges data from batch into blocks of
   1362 // spatial data. Currently, only 2-d blocks are supported. Cropping is not
   1363 // supported, either, and the crops array should be all zero.
   1364 //
   1365 // Inputs:
   1366 //   inputs[0]: required: the input array
   1367 //   inputs[1]: required: the block shape
   1368 //   inputs[2]: required: the crops
   1369 //
   1370 // TensorFlow equivalent: BatchToSpaceND
   1371 struct BatchToSpaceNDOperator : Operator {
   1372   BatchToSpaceNDOperator() : Operator(OperatorType::kBatchToSpaceND) {}
   1373 
   1374   std::vector<int> block_shape;
   1375   std::vector<int> before_crops;
   1376   std::vector<int> after_crops;
   1377 };
   1378 
   1379 // Mean operator.
   1380 //
   1381 // Inputs:
   1382 //   inputs[0]: required: the input array
   1383 //
   1384 // TensorFlow equivalent: Mean
   1385 struct MeanOperator : Operator {
   1386   MeanOperator() : Operator(OperatorType::kMean) {}
   1387 
   1388   std::vector<int> axis;
   1389   bool keep_dims = false;
   1390 };
   1391 
   1392 // Svdf operator:
   1393 //
   1394 // Inputs:
   1395 //   inputs[0]: required: the input array
   1396 //   inputs[1]: required: weights_feature
   1397 //   inputs[2]: required: weights_time
   1398 //   inputs[3]: optional: bias
   1399 struct SvdfOperator : Operator {
   1400   SvdfOperator() : Operator(OperatorType::kSvdf) {}
   1401   int rank;
   1402 };
   1403 
   1404 // TopKV2 operator.
   1405 //
   1406 // Inputs:
   1407 //    input tensor and top_k scalar.
   1408 struct TopKV2Operator : Operator {
   1409   TopKV2Operator() : Operator(OperatorType::kTopK_V2) {}
   1410 };
   1411 
   1412 // Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions is to
   1414 // be used for the transient array at hand. The 'start' and 'end' values are
   1415 // offsets from the start of the workspace buffer, expressed in bytes.
   1416 struct Alloc {
   1417   int start = 0;
   1418   int end = 0;
   1419 };
   1420 
   1421 inline bool operator<(const Alloc& a, const Alloc& b) {
   1422   return a.start < b.start;
   1423 }
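
// Example (illustrative sketch): keeping transient-array Allocs sorted by
// their start offset (using the operator< above) makes overlap checks between
// neighboring allocations straightforward.
//
//   std::sort(allocs.begin(), allocs.end());
//   for (size_t i = 1; i < allocs.size(); ++i) {
//     CHECK(allocs[i].start >= allocs[i - 1].end);  // intervals must not overlap
//   }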
   1424 
   1425 // Quantization parameters, determining the mapping of quantized values
   1426 // to real values (i.e. determining how quantized values are mathematically
   1427 // interpreted).
   1428 //
   1429 // The correspondence is as follows:
   1430 //
   1431 //   real_value = scale * (quantized_value - zero_point);
   1432 //
   1433 // In other words, zero_point designates which quantized value corresponds to
   1434 // the real 0 value, and scale designates the difference between the real values
   1435 // corresponding to consecutive quantized values differing by 1.
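        //
        // Example (illustrative only): with scale = 0.5 and zero_point = 128,
        //   quantized_value = 128  maps to  real_value = 0.5 * (128 - 128) = 0.0
        //   quantized_value = 130  maps to  real_value = 0.5 * (130 - 128) = 1.0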
   1436 struct QuantizationParams {
   1437   int32 zero_point = 0;
   1438   double scale = 0.;
   1439 };
   1440 
   1441 class Shape {
   1442  public:
   1443   // For Shape, we stick to half-way encapsulation for now:
   1444 // we hide the raw dims_ member, but expose it via raw accessors,
   1445 // because it is not at all easy to
   1446   // anticipate which flavor of more hermetic encapsulation would
   1447   // actually buy us future-proof-ness without being needlessly
   1448   // cumbersome.
   1449   Shape() {}
   1450   Shape(std::initializer_list<int> dim_list) : dims_(dim_list) {}
   1451 
   1452   void ReplaceDims(std::initializer_list<int> dim_list) {
   1453     dims_ = std::vector<int>(dim_list);
   1454   }
   1455 
   1456   const std::vector<int>& dims() const { return dims_; }
   1457   std::vector<int>* mutable_dims() { return &dims_; }
   1458   int dimensions_count() const { return dims_.size(); }
   1459 
   1460   // We still have that one convenience accessor to avoid
   1461   // the awkward double bracket issue:  shape.dims()[i].
   1462   int dims(int i) const { return dims_[i]; }
   1463 
   1464   bool operator==(const Shape& comp) const {
   1465     return (this->dims_ == comp.dims());
   1466   }
   1467 
   1468   bool operator!=(const Shape& comp) const { return !((*this) == comp); }
   1469 
   1470  private:
   1471   std::vector<int> dims_;
   1472 };
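
        // Shape usage sketch (illustrative only):
        //   Shape shape({1, 224, 224, 3});
        //   // shape.dimensions_count() == 4, shape.dims(3) == 3
        //   shape.ReplaceDims({1, 112, 112, 3});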
   1473 
   1474 // Array represents an array (either a constant parameter array or an
   1475 // activations array) in a Model.
   1476 struct Array {
   1477   template <ArrayDataType A>
   1478   const Buffer<A>& GetBuffer() const {
   1479     DCHECK(buffer);
   1480     DCHECK(buffer->type == A);
   1481     return *static_cast<const Buffer<A>*>(buffer.get());
   1482   }
   1483   template <ArrayDataType A>
   1484   Buffer<A>& GetMutableBuffer() {
   1485     if (!buffer) {
   1486       Buffer<A>* ptr = new Buffer<A>;
   1487       buffer = std::unique_ptr<GenericBuffer>(ptr);
   1488     }
   1489     DCHECK(buffer);
   1490     DCHECK(buffer->type == A);
   1491     return *static_cast<Buffer<A>*>(buffer.get());
   1492   }
   1493   Alloc& GetOrCreateAlloc() {
   1494     if (!alloc) {
   1495       alloc = std::unique_ptr<Alloc>(new Alloc);
   1496     }
   1497     return *alloc;
   1498   }
   1499   MinMax& GetOrCreateMinMax() {
   1500     if (!minmax) {
   1501       minmax = std::unique_ptr<MinMax>(new MinMax);
   1502     }
   1503     return *minmax;
   1504   }
   1505   MinMax& GetMinMax() const {
   1506     DCHECK(minmax);
   1507     return *minmax;
   1508   }
   1509   QuantizationParams& GetOrCreateQuantizationParams() {
   1510     if (!quantization_params) {
   1511       quantization_params =
   1512           std::unique_ptr<QuantizationParams>(new QuantizationParams);
   1513     }
   1514     return *quantization_params;
   1515   }
   1516   QuantizationParams& GetQuantizationParams() const {
   1517     DCHECK(quantization_params);
   1518     return *quantization_params;
   1519   }
   1520 
   1521   // The data type of the actual elements of this array, that is:
   1522   //  - If there is a buffer (see 'buffer' member), it must be of the same
   1523   //    type.
   1524   //  - If there is no buffer, meaning that this is a runtime (i.e. activations)
   1525   //    array, then this specifies the type of elements that there will be
   1526   //    at runtime.
   1527   //
   1528   // Note that this only specifies the storage type of elements; this does
   1529   // not specify whether these are to be treated as 'real' or 'quantized'
   1530   // values.
   1531   // That is decided by whether the 'quantization_params' member is null.
   1532   ArrayDataType data_type = ArrayDataType::kNone;
   1533   // The final value that data_type should have at the end of graph
   1534   // transformations
   1535   ArrayDataType final_data_type = ArrayDataType::kNone;
   1536   // The dimensions of this array --- this specifies both sizes and strides
   1537   // (the storage layout).
   1538   //
   1539   // Issues with shape handling that remain include:
   1540   //   - No way to distinguish between 0-dimensional dims and missing dims.
   1541   //   - No way to describe dims that may be runtime-variable.
   1542   //   - Addressing of dims by integer index differs in different graph formats
   1543   //     (TensorFlow vs. other frameworks vs. what we have informally grown
   1544   //     within toco).
   1545   //     This is currently quite messy; see ReorderAxesOperator which is how we
   1546   //     bridge some of these discrepancies at the moment. This is overdue for
   1547   //     a redesign; I'm thinking that it would be nice to have more flexible
   1548 //     dims that allow mapping 1:1, cleanly, dims as they are in various
   1549 //     formats, and then explicitly convert between different
   1550 //     conventions.
   1551 
   1552   // Proto-style accessors
   1553   bool has_shape() const { return array_shape != nullptr; }
   1554   const Shape& shape() const {
   1555     CHECK(has_shape());
   1556     return *array_shape;
   1557   }
   1558   Shape* mutable_shape() {
   1559     if (!array_shape) {
   1560       array_shape.reset(new Shape);
   1561     }
   1562     return array_shape.get();
   1563   }
   1564   void copy_shape(const Shape& src_shape) { *mutable_shape() = src_shape; }
   1565   void clear_shape() { array_shape = nullptr; }
   1566 
   1567   // The constant buffer backing this array. This is non-null if and only if
   1568   // this is a constant parameter array. Conversely, this is null for
   1569   // activations arrays.
   1570   //
   1571   // Note that this buffer is pure storage. In the case of quantized values,
   1572 // it only stores the quantized values; it does not know by itself about the
   1573 // quantization parameters necessary to interpret these values, which are
   1574 // in the separate 'quantization_params' field. In fact, this 'buffer' field
   1575 // does not even know whether values are quantized. It only has a data_type,
   1576 // which must equal the 'data_type' member here, and which only describes
   1577 // the storage type of the elements; it does not tell whether they are
   1578 // quantized, i.e. whether they are to be interpreted with quantization_params.
   1579   std::unique_ptr<GenericBuffer> buffer;
   1580   // Only for activation arrays (i.e. when 'buffer' is null).
   1581   // Only for code generation.
   1582   //
   1583   // Describes the allocation of this array within the workspace buffer
   1584 // allocated for all transient arrays.
   1586   std::unique_ptr<Alloc> alloc;
   1587   // Describes the [min, max] range of values
   1588   // to be assumed when determining quantization_params.
   1589   //
   1590   // Only used for quantization. In fact, only used for determining
   1591   // quantization_params.
   1592   //
   1593   // Used for both constant arrays (those having a 'buffer') and non-constant
   1594   // arrays (activations). Indeed, it is important to use the same min-max range
   1595   // as was used during training, even if that min-max range is slightly wrong
   1596   // w.r.t. actual buffer elements. Doing otherwise would defeat the point of
   1597   // re-training for quantization.
   1598   std::unique_ptr<MinMax> minmax;
   1599   // Quantization parameters. The non-null-ness of this pointer is what
   1600   // defines whether this array is quantized or not.
   1601   //
   1602   // If this is non-null, then these quantization parameters are to be used
   1603   // to assign a meaning as real numbers to the elements of this array.
   1604   std::unique_ptr<QuantizationParams> quantization_params;
   1605 
   1606  private:
   1607   std::unique_ptr<Shape> array_shape;
   1608 };
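
        // Array usage sketch (illustrative only; assumes the ArrayDataType::kFloat
        // entry and the Buffer<> 'data' member defined earlier in this header):
        //   Array array;
        //   array.data_type = ArrayDataType::kFloat;
        //   *array.mutable_shape() = Shape({2, 3});
        //   // Giving the array a constant buffer makes it a parameter array.
        //   array.GetMutableBuffer<ArrayDataType::kFloat>().data.resize(6, 0.f);
        //   // A MinMax range lets quantization_params be derived later.
        //   array.GetOrCreateMinMax();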
   1609 
   1610 // Our Model class represents an entire model (our "top-level" class).
   1611 // It owns everything.
   1612 class Model {
   1613  public:
   1614   using ArrayMap = std::unordered_map<string, std::unique_ptr<Array>>;
   1615 
   1616   bool HasArray(const string& name) const { return arrays.count(name) > 0; }
   1617   Array& GetArray(const string& name) const {
   1618     DCHECK(HasArray(name)) << "Array not found: " << name;
   1619     return *arrays.at(name);
   1620   }
   1621   Array& GetOrCreateArray(const string& name) {
   1622     // Make sure name is not used by an optional array
   1623     DCHECK(!optional_arrays.count(name));
   1624     if (!HasArray(name)) {
   1625       Array* ptr = new Array;
   1626       arrays[name] = std::unique_ptr<Array>(ptr);
   1627     }
   1628     Array& result = GetArray(name);
   1629     return result;
   1630   }
   1631   void CreateOptionalArray(const string& name) {
   1632     DCHECK(!arrays.count(name) && !optional_arrays.count(name));
   1633     optional_arrays.insert(name);
   1634   }
   1635   bool IsOptionalArray(const string& name) const {
   1636     return optional_arrays.count(name);
   1637   }
   1638 
   1639   // Note that this invalidates all array iterators.
   1640   void EraseArray(const string& name) { arrays.erase(name); }
   1641   void EraseArrays(std::function<bool(const string&)> discardable) {
   1642     for (auto it = arrays.begin(); it != arrays.end();) {
   1643       if (discardable(it->first)) {
   1644         it = arrays.erase(it);
   1645       } else {
   1646         ++it;
   1647       }
   1648     }
   1649   }
   1650   const ArrayMap& GetArrayMap() const { return arrays; }
   1651 
   1652 // Optional arrays are used for optional tensors:
   1653 // these tensors do not have data, but are referred to by reserved names as op inputs.
   1654   std::set<string> optional_arrays;
   1655 
   1656   // The list of operators. Notice how it's a list of unique_ptr's, implying
   1657   // that the Model is what owns Operator's and keeps them alive.
   1658   std::vector<std::unique_ptr<Operator>> operators;
   1659 
   1660   // Generic flags, a place where we combine information passed to us via
   1661   // command-line parameters (e.g. --input_width=N) with information that
   1662   // we may or may not find in the input model file.
   1663   ModelFlags flags;
   1664   // For code-generation only: required size of the transient_data buffer
   1665   std::size_t transient_data_size = 0;
   1666   // For code-generation only: required alignment of the transient_data buffer
   1667   std::size_t transient_data_alignment = 0;
   1668 
   1669  private:
   1670   // The associative array mapping names to Array's.
   1671   // Notice how it's a container of unique_ptr's, implying
   1672   // that the Model is what owns Array's and keeps them alive.
   1673   // The Operator's refer to these Array's by their name strings, not by their
   1674   // addresses. See Operator::inputs, Operator::outputs.
   1675   std::unordered_map<string, std::unique_ptr<Array>> arrays;
   1676 };
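
        // Model usage sketch (illustrative only): operators refer to arrays by name
        // (see Operator::inputs / Operator::outputs), and the Model owns both.
        //   Model model;
        //   model.GetOrCreateArray("input").data_type = ArrayDataType::kFloat;
        //   model.GetOrCreateArray("output");
        //   auto* mean = new MeanOperator;
        //   mean->inputs = {"input"};
        //   mean->outputs = {"output"};
        //   mean->axis = {1, 2};
        //   model.operators.emplace_back(mean);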
   1677 }  // namespace toco
   1678 
   1679 #endif  // TENSORFLOW_CONTRIB_LITE_TOCO_MODEL_H_
   1680