Home | History | Annotate | Download | only in memory_image
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_
     18 #define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_
     19 
     20 #include "common/embedding-network-package.pb.h"
     21 #include "common/embedding-network-params.h"
     22 #include "common/embedding-network.pb.h"
     23 #include "common/memory_image/memory-image-reader.h"
     24 #include "util/base/integral_types.h"
     25 
     26 namespace libtextclassifier {
     27 namespace nlp_core {
     28 
     29 // EmbeddingNetworkParams backed by a memory image.
     30 //
     31 // In this context, a memory image is like an EmbeddingNetworkProto, but with
     32 // all repeated weights (>99% of the size) directly usable (with no parsing
     33 // required).
     34 class EmbeddingNetworkParamsFromImage : public EmbeddingNetworkParams {
     35  public:
     36   // Constructs an EmbeddingNetworkParamsFromImage, using the memory image that
     37   // starts at address start and contains num_bytes bytes.
     38   EmbeddingNetworkParamsFromImage(const void *start, uint64 num_bytes)
     39       : memory_reader_(start, num_bytes),
     40         trimmed_proto_(memory_reader_.trimmed_proto()) {
     41     embeddings_blob_offset_ = 0;
     42 
     43     hidden_blob_offset_ = embeddings_blob_offset_ + embeddings_size();
     44     if (trimmed_proto_.embeddings_size() &&
     45         trimmed_proto_.embeddings(0).is_quantized()) {
     46       // Adjust for quantization: each quantized matrix takes two blobs (instead
     47       // of one): one for the quantized values and one for the scales.
     48       hidden_blob_offset_ += embeddings_size();
     49     }
     50 
     51     hidden_bias_blob_offset_ = hidden_blob_offset_ + hidden_size();
     52     softmax_blob_offset_ = hidden_bias_blob_offset_ + hidden_bias_size();
     53     softmax_bias_blob_offset_ = softmax_blob_offset_ + softmax_size();
     54   }
     55 
     56   ~EmbeddingNetworkParamsFromImage() override {}
     57 
     58   const TaskSpec *GetTaskSpec() override {
     59     auto extension_id = task_spec_in_embedding_network_proto;
     60     if (trimmed_proto_.HasExtension(extension_id)) {
     61       return &(trimmed_proto_.GetExtension(extension_id));
     62     } else {
     63       return nullptr;
     64     }
     65   }
     66 
     67  protected:
     68   int embeddings_size() const override {
     69     return trimmed_proto_.embeddings_size();
     70   }
     71 
     72   int embeddings_num_rows(int i) const override {
     73     TC_DCHECK(InRange(i, embeddings_size()));
     74     return trimmed_proto_.embeddings(i).rows();
     75   }
     76 
     77   int embeddings_num_cols(int i) const override {
     78     TC_DCHECK(InRange(i, embeddings_size()));
     79     return trimmed_proto_.embeddings(i).cols();
     80   }
     81 
     82   const void *embeddings_weights(int i) const override {
     83     TC_DCHECK(InRange(i, embeddings_size()));
     84     const int blob_index = trimmed_proto_.embeddings(i).is_quantized()
     85                                ? (embeddings_blob_offset_ + 2 * i)
     86                                : (embeddings_blob_offset_ + i);
     87     DataBlobView data_blob_view = memory_reader_.data_blob_view(blob_index);
     88     return data_blob_view.data();
     89   }
     90 
     91   QuantizationType embeddings_quant_type(int i) const override {
     92     TC_DCHECK(InRange(i, embeddings_size()));
     93     if (trimmed_proto_.embeddings(i).is_quantized()) {
     94       return QuantizationType::UINT8;
     95     } else {
     96       return QuantizationType::NONE;
     97     }
     98   }
     99 
    100   const float16 *embeddings_quant_scales(int i) const override {
    101     TC_DCHECK(InRange(i, embeddings_size()));
    102     if (trimmed_proto_.embeddings(i).is_quantized()) {
    103       // Each embedding matrix has two atttached data blobs (hence the "2 * i"):
    104       // one blob with the quantized values and (immediately after it, hence the
    105       // "+ 1") one blob with the scales.
    106       int blob_index = embeddings_blob_offset_ + 2 * i + 1;
    107       DataBlobView data_blob_view = memory_reader_.data_blob_view(blob_index);
    108       return reinterpret_cast<const float16 *>(data_blob_view.data());
    109     } else {
    110       return nullptr;
    111     }
    112   }
    113 
    114   int hidden_size() const override { return trimmed_proto_.hidden_size(); }
    115 
    116   int hidden_num_rows(int i) const override {
    117     TC_DCHECK(InRange(i, hidden_size()));
    118     return trimmed_proto_.hidden(i).rows();
    119   }
    120 
    121   int hidden_num_cols(int i) const override {
    122     TC_DCHECK(InRange(i, hidden_size()));
    123     return trimmed_proto_.hidden(i).cols();
    124   }
    125 
    126   const void *hidden_weights(int i) const override {
    127     TC_DCHECK(InRange(i, hidden_size()));
    128     DataBlobView data_blob_view =
    129         memory_reader_.data_blob_view(hidden_blob_offset_ + i);
    130     return data_blob_view.data();
    131   }
    132 
    133   int hidden_bias_size() const override {
    134     return trimmed_proto_.hidden_bias_size();
    135   }
    136 
    137   int hidden_bias_num_rows(int i) const override {
    138     TC_DCHECK(InRange(i, hidden_bias_size()));
    139     return trimmed_proto_.hidden_bias(i).rows();
    140   }
    141 
    142   int hidden_bias_num_cols(int i) const override {
    143     TC_DCHECK(InRange(i, hidden_bias_size()));
    144     return trimmed_proto_.hidden_bias(i).cols();
    145   }
    146 
    147   const void *hidden_bias_weights(int i) const override {
    148     TC_DCHECK(InRange(i, hidden_bias_size()));
    149     DataBlobView data_blob_view =
    150         memory_reader_.data_blob_view(hidden_bias_blob_offset_ + i);
    151     return data_blob_view.data();
    152   }
    153 
    154   int softmax_size() const override {
    155     return trimmed_proto_.has_softmax() ? 1 : 0;
    156   }
    157 
    158   int softmax_num_rows(int i) const override {
    159     TC_DCHECK(InRange(i, softmax_size()));
    160     return trimmed_proto_.softmax().rows();
    161   }
    162 
    163   int softmax_num_cols(int i) const override {
    164     TC_DCHECK(InRange(i, softmax_size()));
    165     return trimmed_proto_.softmax().cols();
    166   }
    167 
    168   const void *softmax_weights(int i) const override {
    169     TC_DCHECK(InRange(i, softmax_size()));
    170     DataBlobView data_blob_view =
    171         memory_reader_.data_blob_view(softmax_blob_offset_ + i);
    172     return data_blob_view.data();
    173   }
    174 
    175   int softmax_bias_size() const override {
    176     return trimmed_proto_.has_softmax_bias() ? 1 : 0;
    177   }
    178 
    179   int softmax_bias_num_rows(int i) const override {
    180     TC_DCHECK(InRange(i, softmax_bias_size()));
    181     return trimmed_proto_.softmax_bias().rows();
    182   }
    183 
    184   int softmax_bias_num_cols(int i) const override {
    185     TC_DCHECK(InRange(i, softmax_bias_size()));
    186     return trimmed_proto_.softmax_bias().cols();
    187   }
    188 
    189   const void *softmax_bias_weights(int i) const override {
    190     TC_DCHECK(InRange(i, softmax_bias_size()));
    191     DataBlobView data_blob_view =
    192         memory_reader_.data_blob_view(softmax_bias_blob_offset_ + i);
    193     return data_blob_view.data();
    194   }
    195 
    196   int embedding_num_features_size() const override {
    197     return trimmed_proto_.embedding_num_features_size();
    198   }
    199 
    200   int embedding_num_features(int i) const override {
    201     TC_DCHECK(InRange(i, embedding_num_features_size()));
    202     return trimmed_proto_.embedding_num_features(i);
    203   }
    204 
    205  private:
    206   MemoryImageReader<EmbeddingNetworkProto> memory_reader_;
    207 
    208   const EmbeddingNetworkProto &trimmed_proto_;
    209 
    210   // 0-based offsets in the list of data blobs for the different MatrixParams
    211   // fields.  E.g., the 1st hidden MatrixParams has its weights stored in the
    212   // data blob number hidden_blob_offset_, the 2nd one in hidden_blob_offset_ +
    213   // 1, and so on.
    214   int embeddings_blob_offset_;
    215   int hidden_blob_offset_;
    216   int hidden_bias_blob_offset_;
    217   int softmax_blob_offset_;
    218   int softmax_bias_blob_offset_;
    219 };
    220 
    221 }  // namespace nlp_core
    222 }  // namespace libtextclassifier
    223 
    224 #endif  // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_
    225