Home | History | Annotate | Download | only in utils
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "utils/tflite-model-executor.h"
     18 
     19 #include "utils/base/logging.h"
     20 #include "tensorflow/lite/kernels/register.h"
     21 
// Forward declarations of the builtin TensorFlow Lite op registration
// functions. They are declared here so that individual kernels can be
// referenced by name in RegisterSelectedOps() below; each one returns the
// TfLiteRegistration that is passed to MutableOpResolver::AddBuiltin().
namespace tflite {
namespace ops {
namespace builtin {
TfLiteRegistration* Register_ADD();
TfLiteRegistration* Register_CONCATENATION();
TfLiteRegistration* Register_CONV_2D();
TfLiteRegistration* Register_FULLY_CONNECTED();
TfLiteRegistration* Register_L2_NORMALIZATION();
TfLiteRegistration* Register_MUL();
TfLiteRegistration* Register_RESHAPE();
TfLiteRegistration* Register_SOFTMAX();
TfLiteRegistration* Register_GATHER();
TfLiteRegistration* Register_TRANSPOSE();
TfLiteRegistration* Register_SUB();
TfLiteRegistration* Register_DIV();
TfLiteRegistration* Register_STRIDED_SLICE();
TfLiteRegistration* Register_EXP();
TfLiteRegistration* Register_TOPK_V2();
TfLiteRegistration* Register_SPLIT();
TfLiteRegistration* Register_CAST();
TfLiteRegistration* Register_MAXIMUM();
TfLiteRegistration* Register_MINIMUM();
TfLiteRegistration* Register_NEG();
TfLiteRegistration* Register_SLICE();
TfLiteRegistration* Register_LOG();
TfLiteRegistration* Register_SUM();
TfLiteRegistration* Register_PACK();
TfLiteRegistration* Register_DEQUANTIZE();
TfLiteRegistration* Register_MEAN();
}  // namespace builtin
}  // namespace ops
}  // namespace tflite
     55 
     56 #ifdef TC3_WITH_ACTIONS_OPS
     57 #include "utils/tflite/dist_diversification.h"
     58 #include "utils/tflite/text_encoder.h"
     59 #include "utils/tflite/token_encoder.h"
     60 
     61 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
     62   resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
     63                        tflite::ops::builtin::Register_ADD(),
     64                        /*min_version=*/1,
     65                        /*max_version=*/2);
     66   resolver->AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
     67                        tflite::ops::builtin::Register_CONCATENATION(),
     68                        /*min_version=*/1,
     69                        /*max_version=*/2);
     70   resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
     71                        tflite::ops::builtin::Register_CONV_2D(),
     72                        /*min_version=*/1,
     73                        /*max_version=*/3);
     74   resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
     75                        tflite::ops::builtin::Register_FULLY_CONNECTED(),
     76                        /*min_version=*/1,
     77                        /*max_version=*/4);
     78   resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
     79                        tflite::ops::builtin::Register_L2_NORMALIZATION(),
     80                        /*min_version=*/1,
     81                        /*max_version=*/2);
     82   resolver->AddBuiltin(tflite::BuiltinOperator_MUL,
     83                        tflite::ops::builtin::Register_MUL());
     84   resolver->AddBuiltin(tflite::BuiltinOperator_RESHAPE,
     85                        tflite::ops::builtin::Register_RESHAPE());
     86   resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
     87                        tflite::ops::builtin::Register_SOFTMAX(),
     88                        /*min_version=*/1,
     89                        /*max_version=*/2);
     90   resolver->AddBuiltin(tflite::BuiltinOperator_GATHER,
     91                        tflite::ops::builtin::Register_GATHER(),
     92                        /*min_version=*/1,
     93                        /*max_version=*/2);
     94   resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
     95                        tflite::ops::builtin::Register_TRANSPOSE(),
     96                        /*min_version=*/1,
     97                        /*max_version=*/2);
     98   resolver->AddBuiltin(tflite::BuiltinOperator_SUB,
     99                        tflite::ops::builtin::Register_SUB(),
    100                        /*min_version=*/1,
    101                        /*max_version=*/2);
    102   resolver->AddBuiltin(tflite::BuiltinOperator_DIV,
    103                        tflite::ops::builtin::Register_DIV());
    104   resolver->AddBuiltin(tflite::BuiltinOperator_STRIDED_SLICE,
    105                        tflite::ops::builtin::Register_STRIDED_SLICE(),
    106                        /*min_version=*/1,
    107                        /*max_version=*/2);
    108   resolver->AddBuiltin(tflite::BuiltinOperator_EXP,
    109                        tflite::ops::builtin::Register_EXP());
    110   resolver->AddBuiltin(tflite::BuiltinOperator_TOPK_V2,
    111                        tflite::ops::builtin::Register_TOPK_V2(),
    112                        /*min_version=*/1,
    113                        /*max_version=*/2);
    114   resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT,
    115                        tflite::ops::builtin::Register_SPLIT(),
    116                        /*min_version=*/1,
    117                        /*max_version=*/3);
    118   resolver->AddBuiltin(tflite::BuiltinOperator_CAST,
    119                        tflite::ops::builtin::Register_CAST());
    120   resolver->AddBuiltin(tflite::BuiltinOperator_MAXIMUM,
    121                        tflite::ops::builtin::Register_MAXIMUM(),
    122                        /*min_version=*/1,
    123                        /*max_version=*/2);
    124   resolver->AddBuiltin(tflite::BuiltinOperator_MINIMUM,
    125                        tflite::ops::builtin::Register_MINIMUM(),
    126                        /*min_version=*/1,
    127                        /*max_version=*/2);
    128   resolver->AddBuiltin(tflite::BuiltinOperator_NEG,
    129                        tflite::ops::builtin::Register_NEG());
    130   resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
    131                        tflite::ops::builtin::Register_SLICE(),
    132                        /*min_version=*/1,
    133                        /*max_version=*/2);
    134   resolver->AddBuiltin(tflite::BuiltinOperator_LOG,
    135                        tflite::ops::builtin::Register_LOG());
    136   resolver->AddBuiltin(tflite::BuiltinOperator_SUM,
    137                        tflite::ops::builtin::Register_SUM());
    138   resolver->AddBuiltin(tflite::BuiltinOperator_PACK,
    139                        tflite::ops::builtin::Register_PACK(),
    140                        /*min_version=*/1,
    141                        /*max_version=*/2);
    142   resolver->AddBuiltin(tflite::BuiltinOperator_DEQUANTIZE,
    143                        tflite::ops::builtin::Register_DEQUANTIZE(),
    144                        /*min_version=*/1,
    145                        /*max_version=*/2);
    146   resolver->AddBuiltin(tflite::BuiltinOperator_MEAN,
    147                        tflite::ops::builtin::Register_MEAN());
    148 }
    149 #else
    150 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
    151   resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
    152                        tflite::ops::builtin::Register_FULLY_CONNECTED());
    153 }
    154 #endif  // TC3_WITH_ACTIONS_OPS
    155 
    156 namespace libtextclassifier3 {
    157 
    158 inline std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
    159 #ifdef TC3_USE_SELECTIVE_REGISTRATION
    160   std::unique_ptr<tflite::MutableOpResolver> resolver(
    161       new tflite::MutableOpResolver);
    162   RegisterSelectedOps(resolver.get());
    163 #else
    164   std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
    165       new tflite::ops::builtin::BuiltinOpResolver);
    166 #endif
    167 #ifdef TC3_WITH_ACTIONS_OPS
    168   resolver->AddCustom("DistanceDiversification",
    169                       tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
    170   resolver->AddCustom("TextEncoder",
    171                       tflite::ops::custom::Register_TEXT_ENCODER());
    172   resolver->AddCustom("TokenEncoder",
    173                       tflite::ops::custom::Register_TOKEN_ENCODER());
    174 #endif  // TC3_WITH_ACTIONS_OPS
    175   return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
    176 }
    177 
    178 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
    179     const tflite::Model* model_spec) {
    180   std::unique_ptr<const tflite::FlatBufferModel> model(
    181       tflite::FlatBufferModel::BuildFromModel(model_spec));
    182   if (!model || !model->initialized()) {
    183     TC3_LOG(ERROR) << "Could not build TFLite model from a model spec.";
    184     return nullptr;
    185   }
    186   return model;
    187 }
    188 
    189 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
    190     const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
    191   const tflite::Model* model =
    192       flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
    193   flatbuffers::Verifier verifier(model_spec_buffer->data(),
    194                                  model_spec_buffer->Length());
    195   if (!model->Verify(verifier)) {
    196     return nullptr;
    197   }
    198   return TfLiteModelFromModelSpec(model);
    199 }
    200 
// Takes ownership of `model` and builds the op resolver via BuildOpResolver().
// Both are later dereferenced by CreateInterpreter().
TfLiteModelExecutor::TfLiteModelExecutor(
    std::unique_ptr<const tflite::FlatBufferModel> model)
    : model_(std::move(model)), resolver_(BuildOpResolver()) {}
    204 
    205 std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter()
    206     const {
    207   std::unique_ptr<tflite::Interpreter> interpreter;
    208   tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter);
    209   return interpreter;
    210 }
    211 
    212 template <>
    213 void TfLiteModelExecutor::SetInput(const int input_index,
    214                                    const std::vector<std::string>& input_data,
    215                                    tflite::Interpreter* interpreter) const {
    216   tflite::DynamicBuffer buf;
    217   for (const std::string& s : input_data) {
    218     buf.AddString(s.data(), s.length());
    219   }
    220   buf.WriteToTensorAsVector(
    221       interpreter->tensor(interpreter->inputs()[input_index]));
    222 }
    223 
    224 template <>
    225 std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
    226     const int output_index, const tflite::Interpreter* interpreter) const {
    227   const TfLiteTensor* output_tensor =
    228       interpreter->tensor(interpreter->outputs()[output_index]);
    229   const int num_strings = tflite::GetStringCount(output_tensor);
    230   std::vector<tflite::StringRef> output(num_strings);
    231   for (int i = 0; i < num_strings; i++) {
    232     output[i] = tflite::GetString(output_tensor, i);
    233   }
    234   return output;
    235 }
    236 
    237 template <>
    238 std::vector<std::string> TfLiteModelExecutor::Output(
    239     const int output_index, const tflite::Interpreter* interpreter) const {
    240   std::vector<std::string> output;
    241   for (const tflite::StringRef& s :
    242        Output<tflite::StringRef>(output_index, interpreter)) {
    243     output.push_back(std::string(s.str, s.len));
    244   }
    245   return output;
    246 }
    247 
    248 }  // namespace libtextclassifier3
    249