Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_
     18 #define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_
     19 
     20 #include <memory>
     21 #include <vector>
     22 
     23 #include "common/embedding-network-params.h"
     24 #include "common/feature-extractor.h"
     25 #include "common/vector-span.h"
     26 #include "util/base/integral_types.h"
     27 #include "util/base/logging.h"
     28 #include "util/base/macros.h"
     29 
     30 namespace libtextclassifier {
     31 namespace nlp_core {
     32 
     33 // Classifier using a hand-coded feed-forward neural network.
     34 //
     35 // No gradient computation, just inference.
     36 //
     37 // Classification works as follows:
     38 //
     39 // Discrete features -> Embeddings -> Concatenation -> Hidden+ -> Softmax
     40 //
     41 // In words: given some discrete features, this class extracts the embeddings
     42 // for these features, concatenates them, passes them through one or two hidden
     43 // layers (each layer uses Relu) and next through a softmax layer that computes
     44 // an unnormalized score for each possible class.  Note: there is always a
     45 // softmax layer.
     46 class EmbeddingNetwork {
     47  public:
     48   // Class used to represent an embedding matrix.  Each row is the embedding on
     49   // a vocabulary element.  Number of columns = number of embedding dimensions.
     50   class EmbeddingMatrix {
     51    public:
     52     explicit EmbeddingMatrix(const EmbeddingNetworkParams::Matrix source_matrix)
     53         : rows_(source_matrix.rows),
     54           cols_(source_matrix.cols),
     55           quant_type_(source_matrix.quant_type),
     56           data_(source_matrix.elements),
     57           row_size_in_bytes_(GetRowSizeInBytes(cols_, quant_type_)),
     58           quant_scales_(source_matrix.quant_scales) {}
     59 
     60     // Returns vocabulary size; one embedding for each vocabulary element.
     61     int size() const { return rows_; }
     62 
     63     // Returns number of weights in embedding of each vocabulary element.
     64     int dim() const { return cols_; }
     65 
     66     // Returns quantization type for this embedding matrix.
     67     QuantizationType quant_type() const { return quant_type_; }
     68 
     69     // Gets embedding for k-th vocabulary element: on return, sets *data to
     70     // point to the embedding weights and *scale to the quantization scale (1.0
     71     // if no quantization).
     72     void get_embedding(int k, const void **data, float *scale) const {
     73       if ((k < 0) || (k >= size())) {
     74         TC_LOG(ERROR) << "Index outside [0, " << size() << "): " << k;
     75 
     76         // In debug mode, crash.  In prod, pretend that k is 0.
     77         TC_DCHECK(false);
     78         k = 0;
     79       }
     80       *data = reinterpret_cast<const char *>(data_) + k * row_size_in_bytes_;
     81       if (quant_type_ == QuantizationType::NONE) {
     82         *scale = 1.0;
     83       } else {
     84         *scale = Float16To32(quant_scales_[k]);
     85       }
     86     }
     87 
     88    private:
     89     static int GetRowSizeInBytes(int cols, QuantizationType quant_type) {
     90       switch (quant_type) {
     91         case QuantizationType::NONE:
     92           return cols * sizeof(float);
     93         case QuantizationType::UINT8:
     94           return cols * sizeof(uint8);
     95         default:
     96           TC_LOG(ERROR) << "Unknown quant type: "
     97                         << static_cast<int>(quant_type);
     98           return 0;
     99       }
    100     }
    101 
    102     // Vocabulary size.
    103     const int rows_;
    104 
    105     // Number of elements in each embedding.
    106     const int cols_;
    107 
    108     const QuantizationType quant_type_;
    109 
    110     // Pointer to the embedding weights, in row-major order.  This is a pointer
    111     // to an array of floats / uint8, depending on the quantization type.
    112     // Not owned.
    113     const void *const data_;
    114 
    115     // Number of bytes for one row.  Used to jump to next row in data_.
    116     const int row_size_in_bytes_;
    117 
    118     // Pointer to quantization scales.  nullptr if no quantization.  Otherwise,
    119     // quant_scales_[i] is scale for embedding of i-th vocabulary element.
    120     const float16 *const quant_scales_;
    121 
    122     TC_DISALLOW_COPY_AND_ASSIGN(EmbeddingMatrix);
    123   };
    124 
    125   // An immutable vector that doesn't own the memory that stores the underlying
    126   // floats.  Can be used e.g., as a wrapper around model weights stored in the
    127   // static memory.
    128   class VectorWrapper {
    129    public:
    130     VectorWrapper() : VectorWrapper(nullptr, 0) {}
    131 
    132     // Constructs a vector wrapper around the size consecutive floats that start
    133     // at address data.  Note: the underlying data should be alive for at least
    134     // the lifetime of this VectorWrapper object.  That's trivially true if data
    135     // points to statically allocated data :)
    136     VectorWrapper(const float *data, int size) : data_(data), size_(size) {}
    137 
    138     int size() const { return size_; }
    139 
    140     const float *data() const { return data_; }
    141 
    142    private:
    143     const float *data_;  // Not owned.
    144     int size_;
    145 
    146     // Doesn't own anything, so it can be copied and assigned at will :)
    147   };
    148 
    149   typedef std::vector<VectorWrapper> Matrix;
    150   typedef std::vector<float> Vector;
    151 
    152   // Constructs an embedding network using the parameters from model.
    153   //
    154   // Note: model should stay alive for at least the lifetime of this
    155   // EmbeddingNetwork object.
    156   explicit EmbeddingNetwork(const EmbeddingNetworkParams *model);
    157 
    158   virtual ~EmbeddingNetwork() {}
    159 
    160   // Returns true if this EmbeddingNetwork object has been correctly constructed
    161   // and is ready to use.  Idea: in case of errors, mark this EmbeddingNetwork
    162   // object as invalid, but do not crash.
    163   bool is_valid() const { return valid_; }
    164 
    165   // Runs forward computation to fill scores with unnormalized output unit
    166   // scores. This is useful for making predictions.
    167   //
    168   // Returns true on success, false on error (e.g., if !is_valid()).
    169   bool ComputeFinalScores(const std::vector<FeatureVector> &features,
    170                           Vector *scores) const;
    171 
    172   // Same as above, but allows specification of extra neural network inputs that
    173   // will be appended to the embedding vector build from features.
    174   bool ComputeFinalScores(const std::vector<FeatureVector> &features,
    175                           const std::vector<float> extra_inputs,
    176                           Vector *scores) const;
    177 
    178   // Constructs the concatenated input embedding vector in place in output
    179   // vector concat.  Returns true on success, false on error.
    180   bool ConcatEmbeddings(const std::vector<FeatureVector> &features,
    181                         Vector *concat) const;
    182 
    183   // Sums embeddings for all features from |feature_vector| and adds result
    184   // to values from the array pointed-to by |output|.  Embeddings for continuous
    185   // features are weighted by the feature weight.
    186   //
    187   // NOTE: output should point to an array of EmbeddingSize(es_index) floats.
    188   bool GetEmbedding(const FeatureVector &feature_vector, int es_index,
    189                     float *embedding) const;
    190 
    191   // Runs the feed-forward neural network for |input| and computes logits for
    192   // softmax layer.
    193   bool ComputeLogits(const Vector &input, Vector *scores) const;
    194 
    195   // Same as above but uses a view of the feature vector.
    196   bool ComputeLogits(const VectorSpan<float> &input, Vector *scores) const;
    197 
    198   // Returns the size (the number of columns) of the embedding space es_index.
    199   int EmbeddingSize(int es_index) const;
    200 
    201  protected:
    202   // Builds an embedding for given feature vector, and places it from
    203   // concat_offset to the concat vector.
    204   bool GetEmbeddingInternal(const FeatureVector &feature_vector,
    205                             EmbeddingMatrix *embedding_matrix,
    206                             int concat_offset, float *concat,
    207                             int embedding_size) const;
    208 
    209   // Templated function that computes the logit scores given the concatenated
    210   // input embeddings.
    211   bool ComputeLogitsInternal(const VectorSpan<float> &concat,
    212                              Vector *scores) const;
    213 
    214   // Computes the softmax scores (prior to normalization) from the concatenated
    215   // representation.  Returns true on success, false on error.
    216   template <typename ScaleAdderClass>
    217   bool FinishComputeFinalScoresInternal(const VectorSpan<float> &concat,
    218                                         Vector *scores) const;
    219 
    220   // Set to true on successful construction, false otherwise.
    221   bool valid_ = false;
    222 
    223   // Network parameters.
    224 
    225   // One weight matrix for each embedding space.
    226   std::vector<std::unique_ptr<EmbeddingMatrix>> embedding_matrices_;
    227 
    228   // concat_offset_[i] is the input layer offset for i-th embedding space.
    229   std::vector<int> concat_offset_;
    230 
    231   // Size of the input ("concatenation") layer.
    232   int concat_layer_size_;
    233 
    234   // One weight matrix and one vector of bias weights for each hiden layer.
    235   std::vector<Matrix> hidden_weights_;
    236   std::vector<VectorWrapper> hidden_bias_;
    237 
    238   // Weight matrix and bias vector for the softmax layer.
    239   Matrix softmax_weights_;
    240   VectorWrapper softmax_bias_;
    241 };
    242 
    243 }  // namespace nlp_core
    244 }  // namespace libtextclassifier
    245 
    246 #endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_H_
    247