/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/common/embedding-network.h"

#include "lang_id/common/lite_base/integral-types.h"
#include "lang_id/common/lite_base/logging.h"

namespace libtextclassifier3 {
namespace mobile {
namespace {

void CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) {
  SAFTM_CHECK_EQ(static_cast<int>(QuantizationType::NONE),
                 static_cast<int>(matrix.quant_type))
      << "Quantization not allowed here";
}

int GetMatrixRowSizeInBytes(const EmbeddingNetworkParams::Matrix &matrix) {
  int cols = matrix.cols;
  QuantizationType quant_type = matrix.quant_type;
  switch (quant_type) {
    case QuantizationType::NONE:
      return cols * sizeof(float);
    case QuantizationType::UINT8:
      return cols * sizeof(uint8);
    case QuantizationType::UINT4:
      SAFTM_DCHECK_EQ(cols % 2, 0) << "UINT4 with odd #cols = " << cols;
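      // Two 4-bit values are packed into each byte, so a row of cols UINT4
      // weights occupies cols / 2 bytes.  Illustrative example: cols == 16
      // yields 8 bytes per row, versus 16 * sizeof(float) == 64 bytes for the
      // unquantized (NONE) case.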
      return cols / 2;
    case QuantizationType::FLOAT16:
      return cols * sizeof(float16);
    default:
      SAFTM_LOG(FATAL) << "Unknown quant type: "
                       << static_cast<int>(quant_type);
  }
}

// Computes y = weights * Relu(x) + b, where Relu is optionally applied.
//
// weights and b are, respectively, the weight matrix and the bias vector of a
// neural network layer.
//
// Note: in the research literature, Relu (the activation function) is usually
// the last part of a neural layer.  From that perspective, this function
// computes the Relu part of the previous layer (if any) and then the first half
// (the computation of the state) of the current layer.
//
// Note: weights is expected to be the transposed version of the real weight
// matrix.  Hence, instead of computing a linear combination of the columns of
// weights, we compute a linear combination of its rows; but we are mindful that
// these rows are the columns of the original matrix.
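//
// Illustrative example: with x_size == 2, y_size == 3 and apply_relu == true,
// an input x == {-1.0, 2.0} contributes nothing from row 0 of weights
// (Relu(-1.0) == 0) and yields y[j] == b[j] + 2.0 * weights[1][j] for
// j in {0, 1, 2}, where weights[i][j] denotes the j-th element of row i.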
void SparseReluProductPlusBias(bool apply_relu,
                               const EmbeddingNetworkParams::Matrix &weights,
                               const EmbeddingNetworkParams::Matrix &b,
                               const std::vector<float> &x,
                               std::vector<float> *y) {
  // Initialize y to b.  b is a column matrix (i.e., b.cols == 1); we already
  // CHECK-ed that in the EmbeddingNetwork constructor.
  const float *b_start = reinterpret_cast<const float *>(b.elements);
  SAFTM_DCHECK_EQ(b.cols, 1);
  y->assign(b_start, b_start + b.rows);

  float *const y_data = y->data();
  const int y_size = y->size();
  SAFTM_CHECK_EQ(weights.cols, y_size);
  const int x_size = x.size();
  SAFTM_CHECK_EQ(weights.rows, x_size);

  // NOTE: the code below reads x_size * y_size elements from weights; these
  // reads are safe as long as weights.elements contains weights.rows *
  // weights.cols elements (where the element size depends on the quantization
  // type).  That requirement is checked by the params provider, e.g., by
  // EmbeddingNetworkParamsFromFlatbuffer.

  // There is some code duplication between the two main cases of the switch
  // below: the idea was to "lift" the switch outside the loops, to reduce the
  // number of tests at runtime.
  switch (weights.quant_type) {
    case QuantizationType::NONE: {
      // We compute a linear combination of the rows from |weights|, using
      // elements of x (optionally, Relu(x)) as scaling factors (the i-th row
      // gets multiplied by x[i] before being added with the other rows).  Note:
      // elements of |weights| are stored in row-major order: first the elements
      // of row #0, next the elements of row #1, etc.  In the comments below, we
      // write "weights[i][j]" to refer to the j-th element from the i-th row of
      // weights.
      const float *weight_ptr =
          reinterpret_cast<const float *>(weights.elements);
      for (int i = 0; i < x_size; ++i) {
        // Invariant 1: weight_ptr points to the beginning of the i-th row from
        // weights (i.e., weights[i][0]).
        const float scale = x[i];
        if (!apply_relu || (scale > 0)) {
          for (int j = 0; j < y_size; ++j, ++weight_ptr) {
            // Invariant 2: weight_ptr points to weights[i][j].
            y_data[j] += (*weight_ptr) * scale;
          }
        } else {
          // We don't update y_data, but we still have to move weight_ptr to the
          // next row (to satisfy Invariant 1).  We do this by adding y_size ==
          // weights.cols (see the earlier CHECK_EQ).
          weight_ptr += y_size;
        }
      }
      break;
    }
    case QuantizationType::FLOAT16: {
      // See the comments for the QuantizationType::NONE case: the code is
      // almost identical, except for float16 (instead of float) and the
      // Float16To32 conversion.  We could unify these two cases using a
      // template, but since this is a critical loop, we don't want to risk
      // that, e.g., inlining of the conversion function doesn't happen.
      const float16 *weight_ptr =
          reinterpret_cast<const float16 *>(weights.elements);
      for (int i = 0; i < x_size; ++i) {
        const float scale = x[i];
        if (!apply_relu || (scale > 0)) {
          for (int j = 0; j < y_size; ++j, ++weight_ptr) {
            y_data[j] += Float16To32(*weight_ptr) * scale;
          }
        } else {
          weight_ptr += y_size;
        }
      }
      break;
    }
    default:
      SAFTM_LOG(FATAL) << "Unsupported weights quantization type: "
                       << static_cast<int>(weights.quant_type);
  }
}
}  // namespace

void EmbeddingNetwork::ConcatEmbeddings(
    const std::vector<FeatureVector> &feature_vectors,
    std::vector<float> *concat) const {
  concat->resize(concat_layer_size_);

  // "es_index" stands for "embedding space index".
  for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
    const int concat_offset = concat_offset_[es_index];

    const EmbeddingNetworkParams::Matrix &embedding_matrix =
        embedding_matrices_[es_index];
    const int embedding_dim = embedding_matrix.cols;
    const int embedding_row_size_in_bytes =
        embedding_row_size_in_bytes_[es_index];

    const FeatureVector &feature_vector = feature_vectors[es_index];
    const int num_features = feature_vector.size();
    for (int fi = 0; fi < num_features; ++fi) {
      const FeatureType *feature_type = feature_vector.type(fi);
      int feature_offset = concat_offset + feature_type->base() * embedding_dim;
      SAFTM_CHECK_LE(feature_offset + embedding_dim, concat->size());
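      // Illustrative example: with concat_offset == 32, embedding_dim == 16
      // and feature_type->base() == 3, the weighted embedding row is added to
      // concat[80 .. 95] (feature_offset == 32 + 3 * 16 == 80).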

      // Weighted embeddings will be added starting from this address.
      float *concat_ptr = concat->data() + feature_offset;

      // Multiplier for each embedding weight.  Includes feature weight (for
      // continuous features) and quantization scale (for quantized embeddings).
      float multiplier;
      int feature_id;
      const FeatureValue feature_value = feature_vector.value(fi);
      if (feature_type->is_continuous()) {
        // Continuous features (encoded as FloatFeatureValue).
        FloatFeatureValue float_feature_value(feature_value);
        feature_id = float_feature_value.id;
        multiplier = float_feature_value.weight;
      } else {
        // Discrete features: every present feature has implicit value 1.0.
        feature_id = feature_value;
        multiplier = 1.0;
      }

      SAFTM_CHECK_GE(feature_id, 0);
      SAFTM_CHECK_LT(feature_id, embedding_matrix.rows);

      // Pointer to float / uint8 weights for relevant embedding.
      const void *embedding_data =
          (reinterpret_cast<const char *>(embedding_matrix.elements) +
           feature_id * embedding_row_size_in_bytes);

      switch (embedding_matrix.quant_type) {
        case QuantizationType::NONE: {
          const float *weights =
              reinterpret_cast<const float *>(embedding_data);
          for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
            *concat_ptr += *weights * multiplier;
          }
          break;
        }
        case QuantizationType::UINT8: {
          multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
          const uint8 *quant_weights =
              reinterpret_cast<const uint8 *>(embedding_data);
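          // Each stored byte encodes a weight with a bias of 128; the per-row
          // quant scale has been folded into |multiplier| above.  Illustrative
          // example: a byte of 200, a quant scale of 0.5 and a feature weight
          // of 1.0 contribute (200 - 128) * 0.5 == 36.0f to *concat_ptr.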
          for (int i = 0; i < embedding_dim;
               ++i, ++quant_weights, ++concat_ptr) {
            // 128 is the bias for UINT8 quantization.
            *concat_ptr +=
                (static_cast<int>(*quant_weights) - 128) * multiplier;
          }
          break;
        }
        case QuantizationType::UINT4: {
          multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
          const uint8 *quant_weights =
              reinterpret_cast<const uint8 *>(embedding_data);
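          // Each byte packs two 4-bit weights, with the first weight in the
          // high nibble.  OR-ing in 0x08 places the decoded value at the
          // center of its 16-wide bucket before the 128 bias is subtracted.
          // Illustrative example: qq == 0xA3 contributes
          // (0xA8 - 128) * multiplier == 40 * multiplier to concat_ptr[0] and
          // (0x38 - 128) * multiplier == -72 * multiplier to concat_ptr[1].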
          for (int i = 0; i < embedding_dim / 2; ++i, ++quant_weights) {
            const uint8 qq = *quant_weights;
            concat_ptr[0] +=
                (static_cast<int>((qq & 0xF0) | 0x08) - 128) * multiplier;
            concat_ptr[1] +=
                (static_cast<int>(((qq & 0x0F) << 4) | 0x08) - 128) *
                multiplier;
            concat_ptr += 2;
          }
          break;
        }
        default:
          // We already checked (in GetMatrixRowSizeInBytes) that each embedding
          // matrix has a known quantization type.  Hence, DLOG is enough here.
          SAFTM_DLOG(ERROR) << "Unknown embeddings quantization type "
                            << static_cast<int>(embedding_matrix.quant_type);
          break;
      }
    }
  }
}

void EmbeddingNetwork::ComputeFinalScores(
    const std::vector<FeatureVector> &features,
    std::vector<float> *scores) const {
  ComputeFinalScores(features, {}, scores);
}

void EmbeddingNetwork::ComputeFinalScores(
    const std::vector<FeatureVector> &features,
    const std::vector<float> &extra_inputs, std::vector<float> *scores) const {
  // Construct the input layer for our feed-forward neural network (FFNN).
  std::vector<float> input;
  ConcatEmbeddings(features, &input);
  if (!extra_inputs.empty()) {
    input.reserve(input.size() + extra_inputs.size());
    for (int i = 0; i < extra_inputs.size(); i++) {
      input.push_back(extra_inputs[i]);
    }
  }

  // Propagate input through all layers of our FFNN.

  // Alternating storage for activations of the different layers.  We can't use
  // a single vector because all activations of the previous layer are required
  // when computing the activations of the next one.
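  // Illustrative example: with three layers, layer 0 reads |input| and writes
  // storage[0], layer 1 reads storage[0] and writes storage[1], and the final
  // layer reads storage[1] and writes directly into |scores|.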
  std::vector<float> storage[2];
  const std::vector<float> *v_in = &input;
  const int num_layers = layer_weights_.size();
  for (int i = 0; i < num_layers; ++i) {
    std::vector<float> *v_out = nullptr;
    if (i == num_layers - 1) {
      // Final layer: write results directly into |scores|.
      v_out = scores;
    } else {
      // Hidden layer: write results into the alternating storage.  The i % 2
      // trick ensures the alternation.
      v_out = &(storage[i % 2]);
    }
    const bool apply_relu = i > 0;
    SparseReluProductPlusBias(
        apply_relu, layer_weights_[i], layer_bias_[i], *v_in, v_out);
    v_in = v_out;
  }
}

EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model)
    : model_(model) {
  int offset_sum = 0;
  for (int i = 0; i < model_->embedding_num_features_size(); ++i) {
    concat_offset_.push_back(offset_sum);
    EmbeddingNetworkParams::Matrix matrix = model_->GetEmbeddingMatrix(i);
    offset_sum += matrix.cols * model_->embedding_num_features(i);

    // NOTE: each Matrix is a small struct that doesn't own the actual matrix
    // weights.  Hence, the push_back below is fast.
    embedding_matrices_.push_back(matrix);
    embedding_row_size_in_bytes_.push_back(GetMatrixRowSizeInBytes(matrix));
  }
  concat_layer_size_ = offset_sum;
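  // Illustrative example: two embedding spaces, each with 16-dimensional
  // embeddings and 2 features, produce concat_offset_ == {0, 32} and
  // concat_layer_size_ == 64.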

  SAFTM_CHECK_EQ(model_->hidden_size(), model_->hidden_bias_size());
  for (int i = 0; i < model_->hidden_size(); ++i) {
    layer_weights_.push_back(model_->GetHiddenLayerMatrix(i));

    EmbeddingNetworkParams::Matrix bias = model_->GetHiddenLayerBias(i);
    SAFTM_CHECK_EQ(1, bias.cols);
    CheckNoQuantization(bias);
    layer_bias_.push_back(bias);
  }

  SAFTM_CHECK(model_->HasSoftmax());
  layer_weights_.push_back(model_->GetSoftmaxMatrix());

  EmbeddingNetworkParams::Matrix softmax_bias = model_->GetSoftmaxBias();
  SAFTM_CHECK_EQ(1, softmax_bias.cols);
  CheckNoQuantization(softmax_bias);
  layer_bias_.push_back(softmax_bias);
}

}  // namespace mobile
}  // namespace libtextclassifier3