Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
     18 #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
     19 
     20 #include <vector>
     21 
     22 #include "lang_id/common/embedding-network-params.h"
     23 #include "lang_id/common/fel/feature-extractor.h"
     24 
     25 namespace libtextclassifier3 {
     26 namespace mobile {
     27 
     28 // Classifier using a hand-coded feed-forward neural network.
     29 //
     30 // No gradient computation, just inference.
     31 //
     32 // Based on the more general nlp_saft::EmbeddingNetwork (without ::mobile).
     33 //
     34 // Classification works as follows:
     35 //
     36 // Discrete features -> Embeddings -> Concatenation -> Hidden+ -> Softmax
     37 //
     38 // In words: given some discrete features, this class extracts the embeddings
     39 // for these features, concatenates them, passes them through one or more hidden
     40 // layers (each layer uses Relu) and next through a softmax layer that computes
     41 // an unnormalized score for each possible class.  Note: there is always a
     42 // softmax layer at the end.
     43 class EmbeddingNetwork {
     44  public:
     45   // Constructs an embedding network using the parameters from model.
     46   //
     47   // Note: model should stay alive for at least the lifetime of this
     48   // EmbeddingNetwork object.
     49   explicit EmbeddingNetwork(const EmbeddingNetworkParams *model);
     50 
     51   virtual ~EmbeddingNetwork() {}
     52 
     53   // Runs forward computation to fill scores with unnormalized output unit
     54   // scores. This is useful for making predictions.
     55   void ComputeFinalScores(const std::vector<FeatureVector> &features,
     56                           std::vector<float> *scores) const;
     57 
     58   // Same as above, but allows specification of extra extra neural network
     59   // inputs that will be appended to the embedding vector build from features.
     60   void ComputeFinalScores(const std::vector<FeatureVector> &features,
     61                           const std::vector<float> &extra_inputs,
     62                           std::vector<float> *scores) const;
     63 
     64  private:
     65   // Constructs the concatenated input embedding vector in place in output
     66   // vector concat.
     67   void ConcatEmbeddings(const std::vector<FeatureVector> &features,
     68                         std::vector<float> *concat) const;
     69 
     70   // Pointer to the model object passed to the constructor.  Not owned.
     71   const EmbeddingNetworkParams *model_;
     72 
     73   // Network parameters.
     74 
     75   // One weight matrix for each embedding.
     76   std::vector<EmbeddingNetworkParams::Matrix> embedding_matrices_;
     77 
     78   // embedding_row_size_in_bytes_[i] is the size (in bytes) of a row from
     79   // embedding_matrices_[i].  We precompute this in order to quickly find the
     80   // beginning of the k-th row from an embedding matrix (which is stored in
     81   // row-major order).
     82   std::vector<int> embedding_row_size_in_bytes_;
     83 
     84   // concat_offset_[i] is the input layer offset for i-th embedding space.
     85   std::vector<int> concat_offset_;
     86 
     87   // Size of the input ("concatenation") layer.
     88   int concat_layer_size_ = 0;
     89 
     90   // One weight matrix and one vector of bias weights for each layer of neurons.
     91   // Last layer is the softmax layer, the previous ones are the hidden layers.
     92   std::vector<EmbeddingNetworkParams::Matrix> layer_weights_;
     93   std::vector<EmbeddingNetworkParams::Matrix> layer_bias_;
     94 };
     95 
     96 }  // namespace mobile
     97 }  // namespace nlp_saft
     98 
     99 #endif  // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_NETWORK_H_
    100