Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/lite/kernels/register_ref.h"
     17 #include "tensorflow/lite/util.h"
     18 
     19 namespace tflite {
     20 namespace ops {
     21 
     22 namespace custom {
     23 
     24 TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
     25 TfLiteRegistration* Register_MFCC();
     26 TfLiteRegistration* Register_DETECTION_POSTPROCESS();
     27 
     28 }  // namespace custom
     29 
     30 namespace builtin {
     31 
     32 // TODO(yunluli): Some of the registries, e.g. Tanh(), could only invoke
     33 // optimized kernels. Add a _REF() variant for them.
     34 TfLiteRegistration* Register_ABS();
     35 TfLiteRegistration* Register_RELU();
     36 TfLiteRegistration* Register_RELU_N1_TO_1();
     37 TfLiteRegistration* Register_RELU6();
     38 TfLiteRegistration* Register_TANH_REF();
     39 TfLiteRegistration* Register_LOGISTIC_REF();
     40 TfLiteRegistration* Register_AVERAGE_POOL_REF();
     41 TfLiteRegistration* Register_MAX_POOL_REF();
     42 TfLiteRegistration* Register_L2_POOL_REF();
     43 TfLiteRegistration* Register_CONVOLUTION_REF();
     44 TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF();
     45 TfLiteRegistration* Register_SVDF();
     46 TfLiteRegistration* Register_RNN();
     47 TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN();
     48 TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN();
     49 TfLiteRegistration* Register_EMBEDDING_LOOKUP();
     50 TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE();
     51 TfLiteRegistration* Register_FULLY_CONNECTED_REF();
     52 TfLiteRegistration* Register_LSH_PROJECTION();
     53 TfLiteRegistration* Register_HASHTABLE_LOOKUP();
     54 TfLiteRegistration* Register_SOFTMAX();
     55 TfLiteRegistration* Register_CONCATENATION_REF();
     56 TfLiteRegistration* Register_ADD_REF();
     57 TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF();
     58 TfLiteRegistration* Register_DIV_REF();
     59 TfLiteRegistration* Register_SUB_REF();
     60 TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF();
     61 TfLiteRegistration* Register_MUL_REF();
     62 TfLiteRegistration* Register_L2NORM_REF();
     63 TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_REF();
     64 TfLiteRegistration* Register_LSTM();
     65 TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM();
     66 TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
     67 TfLiteRegistration* Register_PAD_REF();
     68 TfLiteRegistration* Register_PADV2_REF();
     69 TfLiteRegistration* Register_RESHAPE();
     70 TfLiteRegistration* Register_RESIZE_BILINEAR_REF();
     71 TfLiteRegistration* Register_RESIZE_NEAREST_NEIGHBOR_REF();
     72 TfLiteRegistration* Register_SKIP_GRAM();
     73 TfLiteRegistration* Register_SPACE_TO_DEPTH_REF();
     74 TfLiteRegistration* Register_GATHER();
     75 TfLiteRegistration* Register_TRANSPOSE_REF();
     76 TfLiteRegistration* Register_MEAN_REF();
     77 TfLiteRegistration* Register_SPLIT();
     78 TfLiteRegistration* Register_SPLIT_V();
     79 TfLiteRegistration* Register_SQUEEZE();
     80 TfLiteRegistration* Register_STRIDED_SLICE_REF();
     81 TfLiteRegistration* Register_EXP();
     82 TfLiteRegistration* Register_TOPK_V2();
     83 TfLiteRegistration* Register_LOG();
     84 TfLiteRegistration* Register_LOG_SOFTMAX_REF();
     85 TfLiteRegistration* Register_CAST();
     86 TfLiteRegistration* Register_DEQUANTIZE();
     87 TfLiteRegistration* Register_PRELU();
     88 TfLiteRegistration* Register_MAXIMUM();
     89 TfLiteRegistration* Register_MINIMUM();
     90 TfLiteRegistration* Register_ARG_MAX();
     91 TfLiteRegistration* Register_ARG_MIN();
     92 TfLiteRegistration* Register_GREATER();
     93 TfLiteRegistration* Register_GREATER_EQUAL();
     94 TfLiteRegistration* Register_LESS();
     95 TfLiteRegistration* Register_LESS_EQUAL();
     96 TfLiteRegistration* Register_FLOOR_REF();
     97 TfLiteRegistration* Register_TILE();
     98 TfLiteRegistration* Register_NEG();
     99 TfLiteRegistration* Register_SUM();
    100 TfLiteRegistration* Register_REDUCE_PROD();
    101 TfLiteRegistration* Register_REDUCE_MAX();
    102 TfLiteRegistration* Register_REDUCE_MIN();
    103 TfLiteRegistration* Register_REDUCE_ANY();
    104 TfLiteRegistration* Register_SELECT();
    105 TfLiteRegistration* Register_SLICE_REF();
    106 TfLiteRegistration* Register_SIN();
    107 TfLiteRegistration* Register_TRANSPOSECONV_REF();
    108 TfLiteRegistration* Register_EXPAND_DIMS();
    109 TfLiteRegistration* Register_SPARSE_TO_DENSE();
    110 TfLiteRegistration* Register_EQUAL();
    111 TfLiteRegistration* Register_NOT_EQUAL();
    112 TfLiteRegistration* Register_SQRT();
    113 TfLiteRegistration* Register_RSQRT();
    114 TfLiteRegistration* Register_SHAPE();
    115 TfLiteRegistration* Register_POW();
    116 TfLiteRegistration* Register_FAKE_QUANT();
    117 TfLiteRegistration* Register_PACK();
    118 TfLiteRegistration* Register_ONE_HOT();
    119 TfLiteRegistration* Register_LOGICAL_OR();
    120 TfLiteRegistration* Register_LOGICAL_AND();
    121 TfLiteRegistration* Register_LOGICAL_NOT();
    122 TfLiteRegistration* Register_UNPACK();
    123 TfLiteRegistration* Register_FLOOR_DIV();
    124 TfLiteRegistration* Register_SQUARE();
    125 TfLiteRegistration* Register_ZEROS_LIKE();
    126 TfLiteRegistration* Register_FLOOR_MOD();
    127 TfLiteRegistration* Register_RANGE();
    128 TfLiteRegistration* Register_LEAKY_RELU();
    129 TfLiteRegistration* Register_SQUARED_DIFFERENCE();
    130 TfLiteRegistration* Register_FILL();
    131 TfLiteRegistration* Register_MIRROR_PAD();
    132 
    133 namespace {
    134 
    135 TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
    136   context->ReportError(
    137       context,
    138       "Regular TensorFlow ops are not supported by this interpreter. Make sure "
    139       "you invoke the Flex delegate before inference.");
    140   return kTfLiteError;
    141 }
    142 
    143 }  // namespace
    144 
    145 const TfLiteRegistration* BuiltinRefOpResolver::FindOp(
    146     tflite::BuiltinOperator op, int version) const {
    147   return MutableOpResolver::FindOp(op, version);
    148 }
    149 
    150 const TfLiteRegistration* BuiltinRefOpResolver::FindOp(const char* op,
    151                                                        int version) const {
    152   // Return the NULL Op for all ops whose name start with "Flex", allowing
    153   // the interpreter to delegate their execution.
    154   if (IsFlexOp(op)) {
    155     static TfLiteRegistration null_op{
    156         nullptr, nullptr, &UnsupportedTensorFlowOp,
    157         nullptr, nullptr, BuiltinOperator_CUSTOM,
    158         "Flex",  1};
    159     return &null_op;
    160   }
    161   return MutableOpResolver::FindOp(op, version);
    162 }
    163 
    164 BuiltinRefOpResolver::BuiltinRefOpResolver() {
    165   AddBuiltin(BuiltinOperator_ABS, Register_ABS());
    166   AddBuiltin(BuiltinOperator_RELU, Register_RELU());
    167   AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1());
    168   AddBuiltin(BuiltinOperator_RELU6, Register_RELU6());
    169   AddBuiltin(BuiltinOperator_TANH, Register_TANH_REF());
    170   AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC_REF());
    171   AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_REF());
    172   AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_REF());
    173   AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_REF());
    174   AddBuiltin(BuiltinOperator_CONV_2D, Register_CONVOLUTION_REF());
    175   AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D,
    176              Register_DEPTHWISE_CONVOLUTION_REF(),
    177              /* min_version */ 1,
    178              /* max_version */ 2);
    179   AddBuiltin(BuiltinOperator_SVDF, Register_SVDF());
    180   AddBuiltin(BuiltinOperator_RNN, Register_RNN());
    181   AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
    182              Register_BIDIRECTIONAL_SEQUENCE_RNN());
    183   AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
    184              Register_UNIDIRECTIONAL_SEQUENCE_RNN());
    185   AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP());
    186   AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE,
    187              Register_EMBEDDING_LOOKUP_SPARSE());
    188   AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED_REF(),
    189              /* min_version */ 1,
    190              /* max_version */ 2);
    191   AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
    192   AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
    193   AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX());
    194   AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION_REF());
    195   AddBuiltin(BuiltinOperator_ADD, Register_ADD_REF());
    196   AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND,
    197              Register_SPACE_TO_BATCH_ND_REF());
    198   AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND,
    199              Register_BATCH_TO_SPACE_ND_REF());
    200   AddBuiltin(BuiltinOperator_MUL, Register_MUL_REF());
    201   AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2NORM_REF());
    202   AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION,
    203              Register_LOCAL_RESPONSE_NORM_REF());
    204   AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1,
    205              /* max_version */ 2);
    206   AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
    207              Register_BIDIRECTIONAL_SEQUENCE_LSTM(), /* min_version */ 1,
    208              /* max_version */ 2);
    209   AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
    210              Register_UNIDIRECTIONAL_SEQUENCE_LSTM());
    211   AddBuiltin(BuiltinOperator_PAD, Register_PAD_REF());
    212   AddBuiltin(BuiltinOperator_PADV2, Register_PADV2_REF());
    213   AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE());
    214   AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR_REF());
    215   AddBuiltin(BuiltinOperator_RESIZE_NEAREST_NEIGHBOR,
    216              Register_RESIZE_NEAREST_NEIGHBOR_REF());
    217   AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
    218   AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH_REF());
    219   AddBuiltin(BuiltinOperator_GATHER, Register_GATHER());
    220   AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE_REF());
    221   AddBuiltin(BuiltinOperator_MEAN, Register_MEAN_REF());
    222   AddBuiltin(BuiltinOperator_DIV, Register_DIV_REF());
    223   AddBuiltin(BuiltinOperator_SUB, Register_SUB_REF());
    224   AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT());
    225   AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
    226   AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
    227   AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE_REF());
    228   AddBuiltin(BuiltinOperator_EXP, Register_EXP());
    229   AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2());
    230   AddBuiltin(BuiltinOperator_LOG, Register_LOG());
    231   AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX_REF());
    232   AddBuiltin(BuiltinOperator_CAST, Register_CAST());
    233   AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE(),
    234              /* min_version */ 1,
    235              /* max_version */ 2);
    236   AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
    237   AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
    238   AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
    239   AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
    240   AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
    241   AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
    242   AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
    243   AddBuiltin(BuiltinOperator_LESS, Register_LESS());
    244   AddBuiltin(BuiltinOperator_LESS_EQUAL, Register_LESS_EQUAL());
    245   AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR_REF());
    246   AddBuiltin(BuiltinOperator_NEG, Register_NEG());
    247   AddBuiltin(BuiltinOperator_SELECT, Register_SELECT());
    248   AddBuiltin(BuiltinOperator_SLICE, Register_SLICE_REF());
    249   AddBuiltin(BuiltinOperator_SIN, Register_SIN());
    250   AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSECONV_REF());
    251   AddBuiltin(BuiltinOperator_TILE, Register_TILE());
    252   AddBuiltin(BuiltinOperator_SUM, Register_SUM());
    253   AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD());
    254   AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX());
    255   AddBuiltin(BuiltinOperator_REDUCE_MIN, Register_REDUCE_MIN());
    256   AddBuiltin(BuiltinOperator_REDUCE_ANY, Register_REDUCE_ANY());
    257   AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS());
    258   AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE());
    259   AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL());
    260   AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL());
    261   AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
    262   AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
    263   AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
    264   AddBuiltin(BuiltinOperator_POW, Register_POW());
    265   AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
    266   AddBuiltin(BuiltinOperator_PACK, Register_PACK());
    267   AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT());
    268   AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
    269   AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
    270   AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
    271   AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
    272   AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
    273   AddBuiltin(BuiltinOperator_SQUARE, Register_SQUARE());
    274   AddBuiltin(BuiltinOperator_ZEROS_LIKE, Register_ZEROS_LIKE());
    275   AddBuiltin(BuiltinOperator_FLOOR_MOD, Register_FLOOR_MOD());
    276   AddBuiltin(BuiltinOperator_RANGE, Register_RANGE());
    277   AddBuiltin(BuiltinOperator_LEAKY_RELU, Register_LEAKY_RELU());
    278   AddBuiltin(BuiltinOperator_SQUARED_DIFFERENCE, Register_SQUARED_DIFFERENCE());
    279   AddBuiltin(BuiltinOperator_FILL, Register_FILL());
    280   AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
    281 
    282   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
    283   // custom ops aren't always included by default.
    284   // AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
    285   // AddCustom("AudioSpectrogram",
    286   //          tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
    287   AddCustom("TFLite_Detection_PostProcess",
    288             tflite::ops::custom::Register_DETECTION_POSTPROCESS());
    289 }
    290 
    291 }  // namespace builtin
    292 }  // namespace ops
    293 }  // namespace tflite
    294