/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_
#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_

#include "common/embedding-network-package.pb.h"
#include "common/embedding-network-params.h"
#include "common/embedding-network.pb.h"
#include "common/memory_image/memory-image-reader.h"
#include "util/base/integral_types.h"

namespace libtextclassifier {
namespace nlp_core {

// EmbeddingNetworkParams backed by a memory image.
//
// In this context, a memory image is like an EmbeddingNetworkProto, but with
// all repeated weights (>99% of the size) directly usable (with no parsing
// required).
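//
// Example usage (a minimal sketch; the placeholder initializations below stand
// for however the caller obtains the memory image bytes, e.g., by
// memory-mapping a model file):
//
//   const void *image_start = ...;   // first byte of the memory image
//   uint64 image_num_bytes = ...;    // size of the image, in bytes
//   EmbeddingNetworkParamsFromImage params(image_start, image_num_bytes);
//   const TaskSpec *task_spec = params.GetTaskSpec();  // may be nullptr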
class EmbeddingNetworkParamsFromImage : public EmbeddingNetworkParams {
 public:
  // Constructs an EmbeddingNetworkParamsFromImage, using the memory image that
  // starts at address start and contains num_bytes bytes.
  EmbeddingNetworkParamsFromImage(const void *start, uint64 num_bytes)
      : memory_reader_(start, num_bytes),
        trimmed_proto_(memory_reader_.trimmed_proto()) {
    embeddings_blob_offset_ = 0;

    hidden_blob_offset_ = embeddings_blob_offset_ + embeddings_size();
    if (trimmed_proto_.embeddings_size() &&
        trimmed_proto_.embeddings(0).is_quantized()) {
      // Adjust for quantization: each quantized matrix takes two blobs (instead
      // of one): one for the quantized values and one for the scales.
      hidden_blob_offset_ += embeddings_size();
    }

    hidden_bias_blob_offset_ = hidden_blob_offset_ + hidden_size();
    softmax_blob_offset_ = hidden_bias_blob_offset_ + hidden_bias_size();
    softmax_bias_blob_offset_ = softmax_blob_offset_ + softmax_size();
  }
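  // Sketch of the data blob layout implied by the offsets computed above, for
  // a model whose E embedding matrices are quantized:
  //
  //   blob 0            embeddings(0) quantized values
  //   blob 1            embeddings(0) scales
  //   ...
  //   blob 2*E - 1      embeddings(E-1) scales
  //   blob 2*E          hidden(0) weights          <- hidden_blob_offset_
  //   ...
  //   followed by the hidden biases, the softmax matrix, and the softmax bias,
  //   starting at hidden_bias_blob_offset_, softmax_blob_offset_, and
  //   softmax_bias_blob_offset_, respectively.
  //
  // Without quantization, each embedding matrix takes a single blob, so the
  // hidden weights start at blob E instead.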
  ~EmbeddingNetworkParamsFromImage() override {}

  const TaskSpec *GetTaskSpec() override {
    auto extension_id = task_spec_in_embedding_network_proto;
    if (trimmed_proto_.HasExtension(extension_id)) {
      return &(trimmed_proto_.GetExtension(extension_id));
    } else {
      return nullptr;
    }
  }

 protected:
  int embeddings_size() const override {
    return trimmed_proto_.embeddings_size();
  }

  int embeddings_num_rows(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return trimmed_proto_.embeddings(i).rows();
  }

  int embeddings_num_cols(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    return trimmed_proto_.embeddings(i).cols();
  }

  const void *embeddings_weights(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    const int blob_index = trimmed_proto_.embeddings(i).is_quantized()
                               ? (embeddings_blob_offset_ + 2 * i)
                               : (embeddings_blob_offset_ + i);
    DataBlobView data_blob_view = memory_reader_.data_blob_view(blob_index);
    return data_blob_view.data();
  }

  QuantizationType embeddings_quant_type(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    if (trimmed_proto_.embeddings(i).is_quantized()) {
      return QuantizationType::UINT8;
    } else {
      return QuantizationType::NONE;
    }
  }

  const float16 *embeddings_quant_scales(int i) const override {
    TC_DCHECK(InRange(i, embeddings_size()));
    if (trimmed_proto_.embeddings(i).is_quantized()) {
      // Each embedding matrix has two attached data blobs (hence the "2 * i"):
      // one blob with the quantized values and (immediately after it, hence
      // the "+ 1") one blob with the scales.
      int blob_index = embeddings_blob_offset_ + 2 * i + 1;
      DataBlobView data_blob_view = memory_reader_.data_blob_view(blob_index);
      return reinterpret_cast<const float16 *>(data_blob_view.data());
    } else {
      return nullptr;
    }
  }

  int hidden_size() const override { return trimmed_proto_.hidden_size(); }

  int hidden_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return trimmed_proto_.hidden(i).rows();
  }

  int hidden_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    return trimmed_proto_.hidden(i).cols();
  }

  const void *hidden_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_size()));
    DataBlobView data_blob_view =
        memory_reader_.data_blob_view(hidden_blob_offset_ + i);
    return data_blob_view.data();
  }

  int hidden_bias_size() const override {
    return trimmed_proto_.hidden_bias_size();
  }

  int hidden_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return trimmed_proto_.hidden_bias(i).rows();
  }

  int hidden_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    return trimmed_proto_.hidden_bias(i).cols();
  }

  const void *hidden_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, hidden_bias_size()));
    DataBlobView data_blob_view =
        memory_reader_.data_blob_view(hidden_bias_blob_offset_ + i);
    return data_blob_view.data();
  }

  int softmax_size() const override {
    return trimmed_proto_.has_softmax() ? 1 : 0;
  }

  int softmax_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return trimmed_proto_.softmax().rows();
  }

  int softmax_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    return trimmed_proto_.softmax().cols();
  }

  const void *softmax_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_size()));
    DataBlobView data_blob_view =
        memory_reader_.data_blob_view(softmax_blob_offset_ + i);
    return data_blob_view.data();
  }

  int softmax_bias_size() const override {
    return trimmed_proto_.has_softmax_bias() ? 1 : 0;
  }

  int softmax_bias_num_rows(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return trimmed_proto_.softmax_bias().rows();
  }

  int softmax_bias_num_cols(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    return trimmed_proto_.softmax_bias().cols();
  }

  const void *softmax_bias_weights(int i) const override {
    TC_DCHECK(InRange(i, softmax_bias_size()));
    DataBlobView data_blob_view =
        memory_reader_.data_blob_view(softmax_bias_blob_offset_ + i);
    return data_blob_view.data();
  }

  int embedding_num_features_size() const override {
    return trimmed_proto_.embedding_num_features_size();
  }

  int embedding_num_features(int i) const override {
    TC_DCHECK(InRange(i, embedding_num_features_size()));
    return trimmed_proto_.embedding_num_features(i);
  }

 private:
  MemoryImageReader<EmbeddingNetworkProto> memory_reader_;

  const EmbeddingNetworkProto &trimmed_proto_;

  // 0-based offsets in the list of data blobs for the different MatrixParams
  // fields.  E.g., the 1st hidden MatrixParams has its weights stored in the
  // data blob number hidden_blob_offset_, the 2nd one in hidden_blob_offset_ +
  // 1, and so on.
  int embeddings_blob_offset_;
  int hidden_blob_offset_;
  int hidden_bias_blob_offset_;
  int softmax_blob_offset_;
  int softmax_bias_blob_offset_;
};

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_