Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
     18 #define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
     19 
#include <algorithm>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "common/embedding-network-package.pb.h"
#include "common/embedding-network-params.h"
#include "common/embedding-network.pb.h"
#include "common/float16.h"
#include "common/little-endian-data.h"
#include "common/task-context.h"
#include "common/task-spec.pb.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"
     35 
     36 namespace libtextclassifier {
     37 namespace nlp_core {
     38 
     39 // A wrapper class that owns and exposes an EmbeddingNetworkProto message via
     40 // the EmbeddingNetworkParams interface.
     41 //
     42 // The EmbeddingNetworkParams interface encapsulates the weight matrices of the
     43 // embeddings, hidden and softmax layers as transposed versions of their
     44 // counterparts in the original EmbeddingNetworkProto. The matrices in the proto
     45 // passed to this class' constructor must likewise already have been transposed.
     46 // See embedding-network-params.h for details.
     47 class EmbeddingNetworkParamsFromProto : public EmbeddingNetworkParams {
     48  public:
     49   // Constructor that takes ownership of the provided proto. See class-comment
     50   // for the requirements that certain weight matrices must satisfy.
     51   explicit EmbeddingNetworkParamsFromProto(
     52       std::unique_ptr<EmbeddingNetworkProto> proto)
     53       : proto_(std::move(proto)) {
     54     valid_ = true;
     55 
     56     // Initialize these vectors to have the required number of elements
     57     // regardless of quantization status. This is to support the unlikely case
     58     // where only some embeddings are quantized, along with the fact that
     59     // EmbeddingNetworkParams interface accesses them by index.
     60     embeddings_quant_scales_.resize(proto_->embeddings_size());
     61     embeddings_quant_weights_.resize(proto_->embeddings_size());
     62     for (int i = 0; i < proto_->embeddings_size(); ++i) {
     63       MatrixParams *embedding = proto_->mutable_embeddings()->Mutable(i);
     64       if (!embedding->is_quantized()) {
     65         continue;
     66       }
     67 
     68       bool success = FillVectorFromDataBytesInLittleEndian(
     69           embedding->bytes_for_quantized_values(),
     70           embedding->rows() * embedding->cols(),
     71           &(embeddings_quant_weights_[i]));
     72       if (!success) {
     73         TC_LOG(ERROR) << "Problem decoding quant_weights for embeddings #" << i;
     74         valid_ = false;
     75       }
     76 
     77       // The repeated field bytes_for_quantized_values uses a lot of memory.
     78       // Since it's no longer necessary (and we own the proto), we clear it.
     79       embedding->clear_bytes_for_quantized_values();
     80 
     81       success = FillVectorFromDataBytesInLittleEndian(
     82           embedding->bytes_for_col_scales(),
     83           embedding->rows(),
     84           &(embeddings_quant_scales_[i]));
     85       if (!success) {
     86         TC_LOG(ERROR) << "Problem decoding col_scales for embeddings #" << i;
     87         valid_ = false;
     88       }
     89 
     90       // See comments for clear_bytes_for_quantized_values().
     91       embedding->clear_bytes_for_col_scales();
     92     }
     93   }
     94 
     95   const TaskSpec *GetTaskSpec() override {
     96     if (!proto_) {
     97       return nullptr;
     98     }
     99     auto extension_id = task_spec_in_embedding_network_proto;
    100     if (proto_->HasExtension(extension_id)) {
    101       return &(proto_->GetExtension(extension_id));
    102     } else {
    103       TC_LOG(ERROR) << "Unable to get TaskSpec from EmbeddingNetworkProto";
    104       return nullptr;
    105     }
    106   }
    107 
    108   // Returns true if these params are valid.  False otherwise (e.g., if the
    109   // original proto data was corrupted).
    110   bool is_valid() { return valid_; }
    111 
    112  protected:
    113   int embeddings_size() const override { return proto_->embeddings_size(); }
    114 
    115   int embeddings_num_rows(int i) const override {
    116     TC_DCHECK(InRange(i, embeddings_size()));
    117     return proto_->embeddings(i).rows();
    118   }
    119 
    120   int embeddings_num_cols(int i) const override {
    121     TC_DCHECK(InRange(i, embeddings_size()));
    122     return proto_->embeddings(i).cols();
    123   }
    124 
    125   const void *embeddings_weights(int i) const override {
    126     TC_DCHECK(InRange(i, embeddings_size()));
    127     if (proto_->embeddings(i).is_quantized()) {
    128       return static_cast<const void *>(embeddings_quant_weights_.at(i).data());
    129     } else {
    130       return static_cast<const void *>(proto_->embeddings(i).value().data());
    131     }
    132   }
    133 
    134   QuantizationType embeddings_quant_type(int i) const override {
    135     TC_DCHECK(InRange(i, embeddings_size()));
    136     return proto_->embeddings(i).is_quantized() ? QuantizationType::UINT8
    137                                                 : QuantizationType::NONE;
    138   }
    139 
    140   const float16 *embeddings_quant_scales(int i) const override {
    141     TC_DCHECK(InRange(i, embeddings_size()));
    142     return proto_->embeddings(i).is_quantized()
    143                ? embeddings_quant_scales_.at(i).data()
    144                : nullptr;
    145   }
    146 
    147   int hidden_size() const override { return proto_->hidden_size(); }
    148 
    149   int hidden_num_rows(int i) const override {
    150     TC_DCHECK(InRange(i, hidden_size()));
    151     return proto_->hidden(i).rows();
    152   }
    153 
    154   int hidden_num_cols(int i) const override {
    155     TC_DCHECK(InRange(i, hidden_size()));
    156     return proto_->hidden(i).cols();
    157   }
    158 
    159   const void *hidden_weights(int i) const override {
    160     TC_DCHECK(InRange(i, hidden_size()));
    161     return proto_->hidden(i).value().data();
    162   }
    163 
    164   int hidden_bias_size() const override { return proto_->hidden_bias_size(); }
    165 
    166   int hidden_bias_num_rows(int i) const override {
    167     TC_DCHECK(InRange(i, hidden_bias_size()));
    168     return proto_->hidden_bias(i).rows();
    169   }
    170 
    171   int hidden_bias_num_cols(int i) const override {
    172     TC_DCHECK(InRange(i, hidden_bias_size()));
    173     return proto_->hidden_bias(i).cols();
    174   }
    175 
    176   const void *hidden_bias_weights(int i) const override {
    177     TC_DCHECK(InRange(i, hidden_bias_size()));
    178     return proto_->hidden_bias(i).value().data();
    179   }
    180 
    181   int softmax_size() const override { return proto_->has_softmax() ? 1 : 0; }
    182 
    183   int softmax_num_rows(int i) const override {
    184     TC_DCHECK(InRange(i, softmax_size()));
    185     return proto_->has_softmax() ? proto_->softmax().rows() : 0;
    186   }
    187 
    188   int softmax_num_cols(int i) const override {
    189     TC_DCHECK(InRange(i, softmax_size()));
    190     return proto_->has_softmax() ? proto_->softmax().cols() : 0;
    191   }
    192 
    193   const void *softmax_weights(int i) const override {
    194     TC_DCHECK(InRange(i, softmax_size()));
    195     return proto_->has_softmax() ? proto_->softmax().value().data() : nullptr;
    196   }
    197 
    198   int softmax_bias_size() const override {
    199     return proto_->has_softmax_bias() ? 1 : 0;
    200   }
    201 
    202   int softmax_bias_num_rows(int i) const override {
    203     TC_DCHECK(InRange(i, softmax_bias_size()));
    204     return proto_->has_softmax_bias() ? proto_->softmax_bias().rows() : 0;
    205   }
    206 
    207   int softmax_bias_num_cols(int i) const override {
    208     TC_DCHECK(InRange(i, softmax_bias_size()));
    209     return proto_->has_softmax_bias() ? proto_->softmax_bias().cols() : 0;
    210   }
    211 
    212   const void *softmax_bias_weights(int i) const override {
    213     TC_DCHECK(InRange(i, softmax_bias_size()));
    214     return proto_->has_softmax_bias() ? proto_->softmax_bias().value().data()
    215                                       : nullptr;
    216   }
    217 
    218   int embedding_num_features_size() const override {
    219     return proto_->embedding_num_features_size();
    220   }
    221 
    222   int embedding_num_features(int i) const override {
    223     TC_DCHECK(InRange(i, embedding_num_features_size()));
    224     return proto_->embedding_num_features(i);
    225   }
    226 
    227  private:
    228   std::unique_ptr<EmbeddingNetworkProto> proto_;
    229 
    230   // True if these params are valid.  May be false if the original proto was
    231   // corrupted.  We prefer to set this to false to CHECK-failing.
    232   bool valid_;
    233 
    234   // When the embeddings are quantized, these members are used to store their
    235   // numeric values using the types expected by the rest of the class. Due to
    236   // technical reasons, the proto stores this info using larger types (i.e.,
    237   // more bits).
    238   std::vector<std::vector<float16>> embeddings_quant_scales_;
    239   std::vector<std::vector<uint8>> embeddings_quant_weights_;
    240 };
    241 
    242 }  // namespace nlp_core
    243 }  // namespace libtextclassifier
    244 
    245 #endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
    246