Home | History | Annotate | Download | only in utils
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Contains classes that can execute different models/parts of a model.
     18 
     19 #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
     20 #define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
     21 
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"
     31 
     32 namespace libtextclassifier3 {
     33 
     34 std::unique_ptr<tflite::OpResolver> BuildOpResolver();
     35 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
     36     const tflite::Model*);
     37 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
     38     const flatbuffers::Vector<uint8_t>*);
     39 
     40 // Executor for the text selection prediction and classification models.
     41 class TfLiteModelExecutor {
     42  public:
     43   static std::unique_ptr<TfLiteModelExecutor> FromModelSpec(
     44       const tflite::Model* model_spec) {
     45     auto model = TfLiteModelFromModelSpec(model_spec);
     46     if (!model) {
     47       return nullptr;
     48     }
     49     return std::unique_ptr<TfLiteModelExecutor>(
     50         new TfLiteModelExecutor(std::move(model)));
     51   }
     52 
     53   static std::unique_ptr<TfLiteModelExecutor> FromBuffer(
     54       const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
     55     auto model = TfLiteModelFromBuffer(model_spec_buffer);
     56     if (!model) {
     57       return nullptr;
     58     }
     59     return std::unique_ptr<TfLiteModelExecutor>(
     60         new TfLiteModelExecutor(std::move(model)));
     61   }
     62 
     63   // Creates an Interpreter for the model that serves as a scratch-pad for the
     64   // inference. The Interpreter is NOT thread-safe.
     65   std::unique_ptr<tflite::Interpreter> CreateInterpreter() const;
     66 
     67   template <typename T>
     68   void SetInput(const int input_index, const TensorView<T>& input_data,
     69                 tflite::Interpreter* interpreter) const {
     70     input_data.copy_to(interpreter->typed_input_tensor<T>(input_index),
     71                        input_data.size());
     72   }
     73 
     74   template <typename T>
     75   void SetInput(const int input_index, const std::vector<T>& input_data,
     76                 tflite::Interpreter* interpreter) const {
     77     std::copy(input_data.begin(), input_data.end(),
     78               interpreter->typed_input_tensor<T>(input_index));
     79   }
     80 
     81   template <typename T>
     82   void SetInput(const int input_index, const T input_value,
     83                 tflite::Interpreter* interpreter) const {
     84     TfLiteTensor* input_tensor =
     85         interpreter->tensor(interpreter->inputs()[input_index]);
     86     switch (input_tensor->type) {
     87       case kTfLiteFloat32:
     88         *(input_tensor->data.f) = input_value;
     89         break;
     90       case kTfLiteInt32:
     91         *(input_tensor->data.i32) = input_value;
     92         break;
     93       case kTfLiteUInt8:
     94         *(input_tensor->data.uint8) = input_value;
     95         break;
     96       case kTfLiteInt64:
     97         *(input_tensor->data.i64) = input_value;
     98         break;
     99       case kTfLiteBool:
    100         *(input_tensor->data.b) = input_value;
    101         break;
    102       case kTfLiteInt16:
    103         *(input_tensor->data.i16) = input_value;
    104         break;
    105       case kTfLiteInt8:
    106         *(input_tensor->data.int8) = input_value;
    107         break;
    108       default:
    109         break;
    110     }
    111   }
    112 
    113   template <typename T>
    114   TensorView<T> OutputView(const int output_index,
    115                            const tflite::Interpreter* interpreter) const {
    116     const TfLiteTensor* output_tensor =
    117         interpreter->tensor(interpreter->outputs()[output_index]);
    118     return TensorView<T>(interpreter->typed_output_tensor<T>(output_index),
    119                          std::vector<int>(output_tensor->dims->data,
    120                                           output_tensor->dims->data +
    121                                               output_tensor->dims->size));
    122   }
    123 
    124   template <typename T>
    125   std::vector<T> Output(const int output_index,
    126                         const tflite::Interpreter* interpreter) const {
    127     TensorView<T> output_view = OutputView<T>(output_index, interpreter);
    128     return std::vector<T>(output_view.data(),
    129                           output_view.data() + output_view.size());
    130   }
    131 
    132  protected:
    133   explicit TfLiteModelExecutor(
    134       std::unique_ptr<const tflite::FlatBufferModel> model);
    135 
    136   std::unique_ptr<const tflite::FlatBufferModel> model_;
    137   std::unique_ptr<tflite::OpResolver> resolver_;
    138 };
    139 
    140 template <>
    141 void TfLiteModelExecutor::SetInput(const int input_index,
    142                                    const std::vector<std::string>& input_data,
    143                                    tflite::Interpreter* interpreter) const;
    144 
    145 template <>
    146 std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
    147     const int output_index, const tflite::Interpreter* interpreter) const;
    148 
    149 template <>
    150 std::vector<std::string> TfLiteModelExecutor::Output(
    151     const int output_index, const tflite::Interpreter* interpreter) const;
    152 
    153 }  // namespace libtextclassifier3
    154 
    155 #endif  // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
    156