Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "common/embedding-network.h"
     18 
     19 #include <math.h>
     20 
     21 #include "common/simple-adder.h"
     22 #include "util/base/integral_types.h"
     23 #include "util/base/logging.h"
     24 
     25 namespace libtextclassifier {
     26 namespace nlp_core {
     27 
     28 namespace {
     29 
     30 // Returns true if and only if matrix does not use any quantization.
     31 bool CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) {
     32   if (matrix.quant_type != QuantizationType::NONE) {
     33     TC_LOG(ERROR) << "Unsupported quantization";
     34     TC_DCHECK(false);  // Crash in debug mode.
     35     return false;
     36   }
     37   return true;
     38 }
     39 
     40 // Initializes a Matrix object with the parameters from the MatrixParams
     41 // source_matrix.  source_matrix should not use quantization.
     42 //
     43 // Returns true on success, false on error.
     44 bool InitNonQuantizedMatrix(const EmbeddingNetworkParams::Matrix &source_matrix,
     45                             EmbeddingNetwork::Matrix *mat) {
     46   mat->resize(source_matrix.rows);
     47 
     48   // Before we access the weights as floats, we need to check that they are
     49   // really floats, i.e., no quantization is used.
     50   if (!CheckNoQuantization(source_matrix)) return false;
     51   const float *weights =
     52       reinterpret_cast<const float *>(source_matrix.elements);
     53   for (int r = 0; r < source_matrix.rows; ++r) {
     54     (*mat)[r] = EmbeddingNetwork::VectorWrapper(weights, source_matrix.cols);
     55     weights += source_matrix.cols;
     56   }
     57   return true;
     58 }
     59 
     60 // Initializes a VectorWrapper object with the parameters from the MatrixParams
     61 // source_matrix.  source_matrix should have exactly one column and should not
     62 // use quantization.
     63 //
     64 // Returns true on success, false on error.
     65 bool InitNonQuantizedVector(const EmbeddingNetworkParams::Matrix &source_matrix,
     66                             EmbeddingNetwork::VectorWrapper *vector) {
     67   if (source_matrix.cols != 1) {
     68     TC_LOG(ERROR) << "wrong #cols " << source_matrix.cols;
     69     return false;
     70   }
     71   if (!CheckNoQuantization(source_matrix)) {
     72     TC_LOG(ERROR) << "unsupported quantization";
     73     return false;
     74   }
     75   // Before we access the weights as floats, we need to check that they are
     76   // really floats, i.e., no quantization is used.
     77   if (!CheckNoQuantization(source_matrix)) return false;
     78   const float *weights =
     79       reinterpret_cast<const float *>(source_matrix.elements);
     80   *vector = EmbeddingNetwork::VectorWrapper(weights, source_matrix.rows);
     81   return true;
     82 }
     83 
     84 // Computes y = weights * Relu(x) + b where Relu is optionally applied.
     85 template <typename ScaleAdderClass>
     86 bool SparseReluProductPlusBias(bool apply_relu,
     87                                const EmbeddingNetwork::Matrix &weights,
     88                                const EmbeddingNetwork::VectorWrapper &b,
     89                                const VectorSpan<float> &x,
     90                                EmbeddingNetwork::Vector *y) {
     91   // Check that dimensions match.
     92   if ((x.size() != weights.size()) || weights.empty()) {
     93     TC_LOG(ERROR) << x.size() << " != " << weights.size();
     94     return false;
     95   }
     96   if (weights[0].size() != b.size()) {
     97     TC_LOG(ERROR) << weights[0].size() << " != " << b.size();
     98     return false;
     99   }
    100 
    101   y->assign(b.data(), b.data() + b.size());
    102   ScaleAdderClass adder(y->data(), y->size());
    103 
    104   const int x_size = x.size();
    105   for (int i = 0; i < x_size; ++i) {
    106     const float &scale = x[i];
    107     if (apply_relu) {
    108       if (scale > 0) {
    109         adder.LazyScaleAdd(weights[i].data(), scale);
    110       }
    111     } else {
    112       adder.LazyScaleAdd(weights[i].data(), scale);
    113     }
    114   }
    115   return true;
    116 }
    117 }  // namespace
    118 
    119 bool EmbeddingNetwork::ConcatEmbeddings(
    120     const std::vector<FeatureVector> &feature_vectors, Vector *concat) const {
    121   concat->resize(concat_layer_size_);
    122 
    123   // Invariant 1: feature_vectors contains exactly one element for each
    124   // embedding space.  That element is itself a FeatureVector, which may be
    125   // empty, but it should be there.
    126   if (feature_vectors.size() != embedding_matrices_.size()) {
    127     TC_LOG(ERROR) << feature_vectors.size()
    128                   << " != " << embedding_matrices_.size();
    129     return false;
    130   }
    131 
    132   // "es_index" stands for "embedding space index".
    133   for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
    134     // Access is safe by es_index loop bounds and Invariant 1.
    135     EmbeddingMatrix *const embedding_matrix =
    136         embedding_matrices_[es_index].get();
    137     if (embedding_matrix == nullptr) {
    138       // Should not happen, hence our terse log error message.
    139       TC_LOG(ERROR) << es_index;
    140       return false;
    141     }
    142 
    143     // Access is safe due to es_index loop bounds.
    144     const FeatureVector &feature_vector = feature_vectors[es_index];
    145 
    146     // Access is safe by es_index loop bounds, Invariant 1, and Invariant 2.
    147     const int concat_offset = concat_offset_[es_index];
    148 
    149     if (!GetEmbeddingInternal(feature_vector, embedding_matrix, concat_offset,
    150                               concat->data(), concat->size())) {
    151       TC_LOG(ERROR) << es_index;
    152       return false;
    153     }
    154   }
    155   return true;
    156 }
    157 
    158 bool EmbeddingNetwork::GetEmbedding(const FeatureVector &feature_vector,
    159                                     int es_index, float *embedding) const {
    160   EmbeddingMatrix *const embedding_matrix = embedding_matrices_[es_index].get();
    161   if (embedding_matrix == nullptr) {
    162     // Should not happen, hence our terse log error message.
    163     TC_LOG(ERROR) << es_index;
    164     return false;
    165   }
    166   return GetEmbeddingInternal(feature_vector, embedding_matrix, 0, embedding,
    167                               embedding_matrices_[es_index]->dim());
    168 }
    169 
// Accumulates the (weighted) embeddings of all features from feature_vector
// into the concat buffer, starting at concat_offset.  concat must hold at
// least concat_size floats; each feature updates embedding_dim consecutive
// elements.  Supports both float and UINT8-quantized embedding matrices.
// Returns false (after logging) if a feature's destination range would fall
// outside [0, concat_size).
bool EmbeddingNetwork::GetEmbeddingInternal(
    const FeatureVector &feature_vector,
    EmbeddingMatrix *const embedding_matrix, const int concat_offset,
    float *concat, int concat_size) const {
  const int embedding_dim = embedding_matrix->dim();
  const bool is_quantized =
      embedding_matrix->quant_type() != QuantizationType::NONE;
  const int num_features = feature_vector.size();
  for (int fi = 0; fi < num_features; ++fi) {
    // Both accesses below are safe due to loop bounds for fi.
    const FeatureType *feature_type = feature_vector.type(fi);
    const FeatureValue feature_value = feature_vector.value(fi);
    // Each feature type owns a contiguous band of embedding_dim floats in the
    // concatenation, located via the type's base index.
    const int feature_offset =
        concat_offset + feature_type->base() * embedding_dim;

    // Code below updates max(0, embedding_dim) elements from concat, starting
    // with index feature_offset.  Check below ensures these updates are safe.
    if ((feature_offset < 0) ||
        (feature_offset + embedding_dim > concat_size)) {
      TC_LOG(ERROR) << fi << ": " << feature_offset << " " << embedding_dim
                    << " " << concat_size;
      return false;
    }

    // Pointer to float / uint8 weights for relevant embedding.
    const void *embedding_data;

    // Multiplier for each embedding weight.
    float multiplier;

    if (feature_type->is_continuous()) {
      // Continuous features (encoded as FloatFeatureValue).  The value packs
      // both a row id and a float weight; the weight scales the embedding.
      FloatFeatureValue float_feature_value(feature_value);
      const int id = float_feature_value.id;
      embedding_matrix->get_embedding(id, &embedding_data, &multiplier);
      multiplier *= float_feature_value.weight;
    } else {
      // Discrete features: every present feature has implicit value 1.0.
      // Hence, after we grab the multiplier below, we don't multiply it by
      // any weight.
      embedding_matrix->get_embedding(feature_value, &embedding_data,
                                      &multiplier);
    }

    // Weighted embeddings will be added starting from this address.
    float *concat_ptr = concat + feature_offset;

    if (is_quantized) {
      const uint8 *quant_weights =
          reinterpret_cast<const uint8 *>(embedding_data);
      for (int i = 0; i < embedding_dim; ++i, ++quant_weights, ++concat_ptr) {
        // 128 is bias for UINT8 quantization, only one we currently support.
        *concat_ptr += (static_cast<int>(*quant_weights) - 128) * multiplier;
      }
    } else {
      // Non-quantized path: weights are plain floats.
      const float *weights = reinterpret_cast<const float *>(embedding_data);
      for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
        *concat_ptr += *weights * multiplier;
      }
    }
  }
  return true;
}
    233 
    234 bool EmbeddingNetwork::ComputeLogits(const VectorSpan<float> &input,
    235                                      Vector *scores) const {
    236   return EmbeddingNetwork::ComputeLogitsInternal(input, scores);
    237 }
    238 
    239 bool EmbeddingNetwork::ComputeLogits(const Vector &input,
    240                                      Vector *scores) const {
    241   return EmbeddingNetwork::ComputeLogitsInternal(input, scores);
    242 }
    243 
    244 bool EmbeddingNetwork::ComputeLogitsInternal(const VectorSpan<float> &input,
    245                                              Vector *scores) const {
    246   return FinishComputeFinalScoresInternal<SimpleAdder>(input, scores);
    247 }
    248 
// Feeds input through the hidden layers and the softmax layer, writing the
// resulting logit scores into scores.  Returns false if any layer's
// dimensions are inconsistent (logged by SparseReluProductPlusBias).
template <typename ScaleAdderClass>
bool EmbeddingNetwork::FinishComputeFinalScoresInternal(
    const VectorSpan<float> &input, Vector *scores) const {
  // This vector serves as an alternating storage for activations of the
  // different layers. We can't use just one vector here because all of the
  // activations of the previous layer are needed for computation of
  // activations of the next one.
  std::vector<Vector> h_storage(2);

  // Compute pre-logits activations.  h_in views the previous layer's output
  // (initially the raw input); h_out receives the current layer's output.
  VectorSpan<float> h_in(input);
  Vector *h_out;
  for (int i = 0; i < hidden_weights_.size(); ++i) {
    // Relu is applied to the *input* of each layer except the first, whose
    // input is the raw feature concatenation.
    const bool apply_relu = i > 0;
    // Alternate between the two storage buffers so h_in (which aliases the
    // previous iteration's h_out) stays valid while we write h_out.
    h_out = &(h_storage[i % 2]);
    h_out->resize(hidden_bias_[i].size());
    if (!SparseReluProductPlusBias<ScaleAdderClass>(
            apply_relu, hidden_weights_[i], hidden_bias_[i], h_in, h_out)) {
      return false;
    }
    h_in = VectorSpan<float>(*h_out);
  }

  // Compute logit scores.  Relu is always applied to the last hidden layer's
  // activations before the softmax weights.
  if (!SparseReluProductPlusBias<ScaleAdderClass>(
          true, softmax_weights_, softmax_bias_, h_in, scores)) {
    return false;
  }

  return true;
}
    280 
    281 bool EmbeddingNetwork::ComputeFinalScores(
    282     const std::vector<FeatureVector> &features, Vector *scores) const {
    283   return ComputeFinalScores(features, {}, scores);
    284 }
    285 
    286 bool EmbeddingNetwork::ComputeFinalScores(
    287     const std::vector<FeatureVector> &features,
    288     const std::vector<float> extra_inputs, Vector *scores) const {
    289   // If we haven't successfully initialized, return without doing anything.
    290   if (!is_valid()) return false;
    291 
    292   Vector concat;
    293   if (!ConcatEmbeddings(features, &concat)) return false;
    294 
    295   if (!extra_inputs.empty()) {
    296     concat.reserve(concat.size() + extra_inputs.size());
    297     for (int i = 0; i < extra_inputs.size(); i++) {
    298       concat.push_back(extra_inputs[i]);
    299     }
    300   }
    301 
    302   scores->resize(softmax_bias_.size());
    303   return ComputeLogits(concat, scores);
    304 }
    305 
// Constructs the network from model parameters: loads the (UINT8-quantized)
// embedding matrices and the (non-quantized) hidden and softmax layers.
// valid_ remains false if any part of the model is malformed.
EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model) {
  // We'll set valid_ to true only if construction is successful.  If we detect
  // an error along the way, we log an informative message and return early, but
  // we do not crash.
  valid_ = false;

  // Fill embedding_matrices_, concat_offset_, and concat_layer_size_.
  const int num_embedding_spaces = model->GetNumEmbeddingSpaces();
  // Running total of floats emitted by the embedding spaces processed so far;
  // becomes each space's offset into the concatenation vector.
  int offset_sum = 0;
  for (int i = 0; i < num_embedding_spaces; ++i) {
    concat_offset_.push_back(offset_sum);
    const EmbeddingNetworkParams::Matrix matrix = model->GetEmbeddingMatrix(i);
    // Embedding matrices are required to be UINT8-quantized (the only
    // quantization scheme currently supported).
    if (matrix.quant_type != QuantizationType::UINT8) {
      TC_LOG(ERROR) << "Unsupported quantization for embedding #" << i << ": "
                    << static_cast<int>(matrix.quant_type);
      return;
    }

    // There is no way to accommodate an empty embedding matrix.  E.g., there is
    // no way for get_embedding to return something that can be read safely.
    // Hence, we catch that error here and return early.
    if (matrix.rows == 0) {
      TC_LOG(ERROR) << "Empty embedding matrix #" << i;
      return;
    }
    embedding_matrices_.emplace_back(new EmbeddingMatrix(matrix));
    // Each feature in this space contributes embedding_dim floats.
    const int embedding_dim = embedding_matrices_.back()->dim();
    offset_sum += embedding_dim * model->GetNumFeaturesInEmbeddingSpace(i);
  }
  concat_layer_size_ = offset_sum;

  // Invariant 2 (trivial by the code above).
  TC_DCHECK_EQ(concat_offset_.size(), embedding_matrices_.size());

  const int num_hidden_layers = model->GetNumHiddenLayers();
  if (num_hidden_layers < 1) {
    TC_LOG(ERROR) << "Wrong number of hidden layers: " << num_hidden_layers;
    return;
  }
  hidden_weights_.resize(num_hidden_layers);
  hidden_bias_.resize(num_hidden_layers);

  // Hidden layers must be plain float (non-quantized) matrices/vectors.
  for (int i = 0; i < num_hidden_layers; ++i) {
    const EmbeddingNetworkParams::Matrix matrix =
        model->GetHiddenLayerMatrix(i);
    const EmbeddingNetworkParams::Matrix bias = model->GetHiddenLayerBias(i);
    if (!InitNonQuantizedMatrix(matrix, &hidden_weights_[i]) ||
        !InitNonQuantizedVector(bias, &hidden_bias_[i])) {
      TC_LOG(ERROR) << "Bad hidden layer #" << i;
      return;
    }
  }

  if (!model->HasSoftmaxLayer()) {
    TC_LOG(ERROR) << "Missing softmax layer";
    return;
  }
  // Like the hidden layers, the softmax layer must be non-quantized.
  const EmbeddingNetworkParams::Matrix softmax = model->GetSoftmaxMatrix();
  const EmbeddingNetworkParams::Matrix softmax_bias = model->GetSoftmaxBias();
  if (!InitNonQuantizedMatrix(softmax, &softmax_weights_) ||
      !InitNonQuantizedVector(softmax_bias, &softmax_bias_)) {
    TC_LOG(ERROR) << "Bad softmax layer";
    return;
  }

  // Everything looks good.
  valid_ = true;
}
    374 
    375 int EmbeddingNetwork::EmbeddingSize(int es_index) const {
    376   return embedding_matrices_[es_index]->dim();
    377 }
    378 
    379 }  // namespace nlp_core
    380 }  // namespace libtextclassifier
    381