/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.

#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_

#include <memory>

#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "utils/tflite-model-executor.h"

namespace libtextclassifier3 {

// Executor for the text selection prediction and classification models.
class ModelExecutor : public TfLiteModelExecutor {
 public:
  // Creates an executor from a flatbuffer model spec. Returns nullptr in case
  // of an error.
  static std::unique_ptr<ModelExecutor> FromModelSpec(
      const tflite::Model* model_spec) {
    auto model = TfLiteModelFromModelSpec(model_spec);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
  }

  // Creates an executor from a buffer holding the serialized model. Returns
  // nullptr in case of an error.
  static std::unique_ptr<ModelExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
    auto model = TfLiteModelFromBuffer(model_spec_buffer);
    if (!model) {
      return nullptr;
    }
    return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
  }

  // Runs inference on the given feature tensor and returns a view of the
  // resulting logits.
  TensorView<float> ComputeLogits(const TensorView<float>& features,
                                  tflite::Interpreter* interpreter) const;

 protected:
  explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
      : TfLiteModelExecutor(std::move(model)) {}

  static const int kInputIndexFeatures = 0;
  static const int kOutputIndexLogits = 0;
};
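
// Example usage (an illustrative sketch, not part of this header; it assumes
// a valid `model_spec_buffer`, a `features` tensor, and that the base
// TfLiteModelExecutor from utils/tflite-model-executor.h exposes the
// CreateInterpreter() factory used below):
//
//   std::unique_ptr<ModelExecutor> executor =
//       ModelExecutor::FromBuffer(model_spec_buffer);
//   if (executor != nullptr) {
//     std::unique_ptr<tflite::Interpreter> interpreter =
//         executor->CreateInterpreter();
//     const TensorView<float> logits =
//         executor->ComputeLogits(features, interpreter.get());
//   }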

// Executor for embedding sparse features into a dense vector.
class EmbeddingExecutor {
 public:
  virtual ~EmbeddingExecutor() {}

  // Embeds the sparse_features into a dense embedding and adds it
  // element-wise to the dest vector, which holds dest_size floats.
  virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                            int dest_size) const = 0;

  // Returns true when the model is ready to be used, false otherwise.
  virtual bool IsReady() const { return true; }
};
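
// A minimal sketch of a concrete EmbeddingExecutor, purely illustrative (the
// class name, the lookup table, and its layout are assumptions, not part of
// the library); it sums the table rows selected by the sparse features:
//
//   class TableEmbeddingExecutor : public EmbeddingExecutor {
//    public:
//     // table: row-major [num_buckets x embedding_size] lookup table.
//     TableEmbeddingExecutor(const float* table, int num_buckets,
//                            int embedding_size)
//         : table_(table),
//           num_buckets_(num_buckets),
//           embedding_size_(embedding_size) {}
//
//     bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
//                       int dest_size) const override {
//       if (dest_size != embedding_size_) return false;
//       for (int i = 0; i < sparse_features.size(); ++i) {
//         const int bucket = sparse_features.data()[i];
//         if (bucket < 0 || bucket >= num_buckets_) return false;
//         const float* row = table_ + bucket * embedding_size_;
//         for (int j = 0; j < dest_size; ++j) dest[j] += row[j];
//       }
//       return true;
//     }
//
//    private:
//     const float* table_;
//     int num_buckets_;
//     int embedding_size_;
//   };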

// Embedding executor backed by a TFLite model that holds a (quantized)
// embedding table, with optional support for a pruned embedding space.
class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
 public:
  // Creates an executor from a buffer holding the serialized model. Returns
  // nullptr in case of an error.
  static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
      int quantization_bits,
      const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);

  // Embeds the sparse_features into a dense embedding and adds it
  // element-wise to the dest vector.
  bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                    int dest_size) const override;

  // Auxiliary function that computes the prefix counts used by the efficient
  // mask-indexing data structure.
  void ComputePrefixCounts();

  // Maps a bucket id of the full embedding space to its row in the pruned
  // embedding table, using the efficient mask-indexing data structure.
  int PruneBucketId(int bucket_id) const;

 protected:
  explicit TFLiteEmbeddingExecutor(
      std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
      int num_buckets, int bytes_per_embedding, int output_embedding_size,
      const TfLiteTensor* scales, const TfLiteTensor* embeddings,
      std::unique_ptr<tflite::Interpreter> interpreter,
      const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);

  std::unique_ptr<TfLiteModelExecutor> executor_;

  int quantization_bits_;
  int num_buckets_ = -1;
  int bytes_per_embedding_ = -1;
  int output_embedding_size_ = -1;
  const TfLiteTensor* scales_ = nullptr;
  const TfLiteTensor* embeddings_ = nullptr;

  // NOTE: This interpreter is used in a read-only way (as storage for the
  // model parameters), and is thus still thread-safe.
  std::unique_ptr<tflite::Interpreter> interpreter_;

  std::vector<uint64> pruning_mask_;
  std::vector<uint16> prefix_counts_;
  int full_num_buckets_ = -1;

  // Index of the row of the embedding table corresponding to all pruned
  // buckets.
  int pruned_row_bucket_id_ = -1;
};
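
// The pruning members above implement a rank query over a bitmask:
// pruning_mask_ marks which buckets of the full embedding space survived
// pruning, prefix_counts_ caches how many bits are set before each 64-bit
// word, and a popcount over the remaining low bits completes the lookup
// (buckets whose bit is unset presumably all map to pruned_row_bucket_id_).
// A free-standing sketch of this technique follows; it is illustrative only,
// and the bit layout and types used by the actual implementation may differ:
//
//   #include <cstdint>
//   #include <vector>
//
//   // Number of mask bits set strictly below bucket_id, i.e. the pruned-table
//   // row of bucket_id when its own bit is set. Assumes bit i of word i/64
//   // (LSB-first) encodes whether bucket i survived pruning.
//   inline int RankBelow(const std::vector<uint64_t>& mask,
//                        const std::vector<uint16_t>& prefix_counts,
//                        int bucket_id) {
//     const int word = bucket_id / 64;
//     const uint64_t low_bits =
//         mask[word] & ((uint64_t{1} << (bucket_id % 64)) - 1);
//     return prefix_counts[word] +
//            __builtin_popcountll(low_bits);  // GCC/Clang intrinsic.
//   }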

}  // namespace libtextclassifier3

#endif  // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_