1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "common/embedding-network.h" 18 19 #include <math.h> 20 21 #include "common/simple-adder.h" 22 #include "util/base/integral_types.h" 23 #include "util/base/logging.h" 24 25 namespace libtextclassifier { 26 namespace nlp_core { 27 28 namespace { 29 30 // Returns true if and only if matrix does not use any quantization. 31 bool CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) { 32 if (matrix.quant_type != QuantizationType::NONE) { 33 TC_LOG(ERROR) << "Unsupported quantization"; 34 TC_DCHECK(false); // Crash in debug mode. 35 return false; 36 } 37 return true; 38 } 39 40 // Initializes a Matrix object with the parameters from the MatrixParams 41 // source_matrix. source_matrix should not use quantization. 42 // 43 // Returns true on success, false on error. 44 bool InitNonQuantizedMatrix(const EmbeddingNetworkParams::Matrix &source_matrix, 45 EmbeddingNetwork::Matrix *mat) { 46 mat->resize(source_matrix.rows); 47 48 // Before we access the weights as floats, we need to check that they are 49 // really floats, i.e., no quantization is used. 50 if (!CheckNoQuantization(source_matrix)) return false; 51 const float *weights = 52 reinterpret_cast<const float *>(source_matrix.elements); 53 for (int r = 0; r < source_matrix.rows; ++r) { 54 (*mat)[r] = EmbeddingNetwork::VectorWrapper(weights, source_matrix.cols); 55 weights += source_matrix.cols; 56 } 57 return true; 58 } 59 60 // Initializes a VectorWrapper object with the parameters from the MatrixParams 61 // source_matrix. source_matrix should have exactly one column and should not 62 // use quantization. 63 // 64 // Returns true on success, false on error. 65 bool InitNonQuantizedVector(const EmbeddingNetworkParams::Matrix &source_matrix, 66 EmbeddingNetwork::VectorWrapper *vector) { 67 if (source_matrix.cols != 1) { 68 TC_LOG(ERROR) << "wrong #cols " << source_matrix.cols; 69 return false; 70 } 71 if (!CheckNoQuantization(source_matrix)) { 72 TC_LOG(ERROR) << "unsupported quantization"; 73 return false; 74 } 75 // Before we access the weights as floats, we need to check that they are 76 // really floats, i.e., no quantization is used. 77 if (!CheckNoQuantization(source_matrix)) return false; 78 const float *weights = 79 reinterpret_cast<const float *>(source_matrix.elements); 80 *vector = EmbeddingNetwork::VectorWrapper(weights, source_matrix.rows); 81 return true; 82 } 83 84 // Computes y = weights * Relu(x) + b where Relu is optionally applied. 85 template <typename ScaleAdderClass> 86 bool SparseReluProductPlusBias(bool apply_relu, 87 const EmbeddingNetwork::Matrix &weights, 88 const EmbeddingNetwork::VectorWrapper &b, 89 const VectorSpan<float> &x, 90 EmbeddingNetwork::Vector *y) { 91 // Check that dimensions match. 92 if ((x.size() != weights.size()) || weights.empty()) { 93 TC_LOG(ERROR) << x.size() << " != " << weights.size(); 94 return false; 95 } 96 if (weights[0].size() != b.size()) { 97 TC_LOG(ERROR) << weights[0].size() << " != " << b.size(); 98 return false; 99 } 100 101 y->assign(b.data(), b.data() + b.size()); 102 ScaleAdderClass adder(y->data(), y->size()); 103 104 const int x_size = x.size(); 105 for (int i = 0; i < x_size; ++i) { 106 const float &scale = x[i]; 107 if (apply_relu) { 108 if (scale > 0) { 109 adder.LazyScaleAdd(weights[i].data(), scale); 110 } 111 } else { 112 adder.LazyScaleAdd(weights[i].data(), scale); 113 } 114 } 115 return true; 116 } 117 } // namespace 118 119 bool EmbeddingNetwork::ConcatEmbeddings( 120 const std::vector<FeatureVector> &feature_vectors, Vector *concat) const { 121 concat->resize(concat_layer_size_); 122 123 // Invariant 1: feature_vectors contains exactly one element for each 124 // embedding space. That element is itself a FeatureVector, which may be 125 // empty, but it should be there. 126 if (feature_vectors.size() != embedding_matrices_.size()) { 127 TC_LOG(ERROR) << feature_vectors.size() 128 << " != " << embedding_matrices_.size(); 129 return false; 130 } 131 132 // "es_index" stands for "embedding space index". 133 for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) { 134 // Access is safe by es_index loop bounds and Invariant 1. 135 EmbeddingMatrix *const embedding_matrix = 136 embedding_matrices_[es_index].get(); 137 if (embedding_matrix == nullptr) { 138 // Should not happen, hence our terse log error message. 139 TC_LOG(ERROR) << es_index; 140 return false; 141 } 142 143 // Access is safe due to es_index loop bounds. 144 const FeatureVector &feature_vector = feature_vectors[es_index]; 145 146 // Access is safe by es_index loop bounds, Invariant 1, and Invariant 2. 147 const int concat_offset = concat_offset_[es_index]; 148 149 if (!GetEmbeddingInternal(feature_vector, embedding_matrix, concat_offset, 150 concat->data(), concat->size())) { 151 TC_LOG(ERROR) << es_index; 152 return false; 153 } 154 } 155 return true; 156 } 157 158 bool EmbeddingNetwork::GetEmbedding(const FeatureVector &feature_vector, 159 int es_index, float *embedding) const { 160 EmbeddingMatrix *const embedding_matrix = embedding_matrices_[es_index].get(); 161 if (embedding_matrix == nullptr) { 162 // Should not happen, hence our terse log error message. 163 TC_LOG(ERROR) << es_index; 164 return false; 165 } 166 return GetEmbeddingInternal(feature_vector, embedding_matrix, 0, embedding, 167 embedding_matrices_[es_index]->dim()); 168 } 169 170 bool EmbeddingNetwork::GetEmbeddingInternal( 171 const FeatureVector &feature_vector, 172 EmbeddingMatrix *const embedding_matrix, const int concat_offset, 173 float *concat, int concat_size) const { 174 const int embedding_dim = embedding_matrix->dim(); 175 const bool is_quantized = 176 embedding_matrix->quant_type() != QuantizationType::NONE; 177 const int num_features = feature_vector.size(); 178 for (int fi = 0; fi < num_features; ++fi) { 179 // Both accesses below are safe due to loop bounds for fi. 180 const FeatureType *feature_type = feature_vector.type(fi); 181 const FeatureValue feature_value = feature_vector.value(fi); 182 const int feature_offset = 183 concat_offset + feature_type->base() * embedding_dim; 184 185 // Code below updates max(0, embedding_dim) elements from concat, starting 186 // with index feature_offset. Check below ensures these updates are safe. 187 if ((feature_offset < 0) || 188 (feature_offset + embedding_dim > concat_size)) { 189 TC_LOG(ERROR) << fi << ": " << feature_offset << " " << embedding_dim 190 << " " << concat_size; 191 return false; 192 } 193 194 // Pointer to float / uint8 weights for relevant embedding. 195 const void *embedding_data; 196 197 // Multiplier for each embedding weight. 198 float multiplier; 199 200 if (feature_type->is_continuous()) { 201 // Continuous features (encoded as FloatFeatureValue). 202 FloatFeatureValue float_feature_value(feature_value); 203 const int id = float_feature_value.id; 204 embedding_matrix->get_embedding(id, &embedding_data, &multiplier); 205 multiplier *= float_feature_value.weight; 206 } else { 207 // Discrete features: every present feature has implicit value 1.0. 208 // Hence, after we grab the multiplier below, we don't multiply it by 209 // any weight. 210 embedding_matrix->get_embedding(feature_value, &embedding_data, 211 &multiplier); 212 } 213 214 // Weighted embeddings will be added starting from this address. 215 float *concat_ptr = concat + feature_offset; 216 217 if (is_quantized) { 218 const uint8 *quant_weights = 219 reinterpret_cast<const uint8 *>(embedding_data); 220 for (int i = 0; i < embedding_dim; ++i, ++quant_weights, ++concat_ptr) { 221 // 128 is bias for UINT8 quantization, only one we currently support. 222 *concat_ptr += (static_cast<int>(*quant_weights) - 128) * multiplier; 223 } 224 } else { 225 const float *weights = reinterpret_cast<const float *>(embedding_data); 226 for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) { 227 *concat_ptr += *weights * multiplier; 228 } 229 } 230 } 231 return true; 232 } 233 234 bool EmbeddingNetwork::ComputeLogits(const VectorSpan<float> &input, 235 Vector *scores) const { 236 return EmbeddingNetwork::ComputeLogitsInternal(input, scores); 237 } 238 239 bool EmbeddingNetwork::ComputeLogits(const Vector &input, 240 Vector *scores) const { 241 return EmbeddingNetwork::ComputeLogitsInternal(input, scores); 242 } 243 244 bool EmbeddingNetwork::ComputeLogitsInternal(const VectorSpan<float> &input, 245 Vector *scores) const { 246 return FinishComputeFinalScoresInternal<SimpleAdder>(input, scores); 247 } 248 249 template <typename ScaleAdderClass> 250 bool EmbeddingNetwork::FinishComputeFinalScoresInternal( 251 const VectorSpan<float> &input, Vector *scores) const { 252 // This vector serves as an alternating storage for activations of the 253 // different layers. We can't use just one vector here because all of the 254 // activations of the previous layer are needed for computation of 255 // activations of the next one. 256 std::vector<Vector> h_storage(2); 257 258 // Compute pre-logits activations. 259 VectorSpan<float> h_in(input); 260 Vector *h_out; 261 for (int i = 0; i < hidden_weights_.size(); ++i) { 262 const bool apply_relu = i > 0; 263 h_out = &(h_storage[i % 2]); 264 h_out->resize(hidden_bias_[i].size()); 265 if (!SparseReluProductPlusBias<ScaleAdderClass>( 266 apply_relu, hidden_weights_[i], hidden_bias_[i], h_in, h_out)) { 267 return false; 268 } 269 h_in = VectorSpan<float>(*h_out); 270 } 271 272 // Compute logit scores. 273 if (!SparseReluProductPlusBias<ScaleAdderClass>( 274 true, softmax_weights_, softmax_bias_, h_in, scores)) { 275 return false; 276 } 277 278 return true; 279 } 280 281 bool EmbeddingNetwork::ComputeFinalScores( 282 const std::vector<FeatureVector> &features, Vector *scores) const { 283 return ComputeFinalScores(features, {}, scores); 284 } 285 286 bool EmbeddingNetwork::ComputeFinalScores( 287 const std::vector<FeatureVector> &features, 288 const std::vector<float> extra_inputs, Vector *scores) const { 289 // If we haven't successfully initialized, return without doing anything. 290 if (!is_valid()) return false; 291 292 Vector concat; 293 if (!ConcatEmbeddings(features, &concat)) return false; 294 295 if (!extra_inputs.empty()) { 296 concat.reserve(concat.size() + extra_inputs.size()); 297 for (int i = 0; i < extra_inputs.size(); i++) { 298 concat.push_back(extra_inputs[i]); 299 } 300 } 301 302 scores->resize(softmax_bias_.size()); 303 return ComputeLogits(concat, scores); 304 } 305 306 EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model) { 307 // We'll set valid_ to true only if construction is successful. If we detect 308 // an error along the way, we log an informative message and return early, but 309 // we do not crash. 310 valid_ = false; 311 312 // Fill embedding_matrices_, concat_offset_, and concat_layer_size_. 313 const int num_embedding_spaces = model->GetNumEmbeddingSpaces(); 314 int offset_sum = 0; 315 for (int i = 0; i < num_embedding_spaces; ++i) { 316 concat_offset_.push_back(offset_sum); 317 const EmbeddingNetworkParams::Matrix matrix = model->GetEmbeddingMatrix(i); 318 if (matrix.quant_type != QuantizationType::UINT8) { 319 TC_LOG(ERROR) << "Unsupported quantization for embedding #" << i << ": " 320 << static_cast<int>(matrix.quant_type); 321 return; 322 } 323 324 // There is no way to accomodate an empty embedding matrix. E.g., there is 325 // no way for get_embedding to return something that can be read safely. 326 // Hence, we catch that error here and return early. 327 if (matrix.rows == 0) { 328 TC_LOG(ERROR) << "Empty embedding matrix #" << i; 329 return; 330 } 331 embedding_matrices_.emplace_back(new EmbeddingMatrix(matrix)); 332 const int embedding_dim = embedding_matrices_.back()->dim(); 333 offset_sum += embedding_dim * model->GetNumFeaturesInEmbeddingSpace(i); 334 } 335 concat_layer_size_ = offset_sum; 336 337 // Invariant 2 (trivial by the code above). 338 TC_DCHECK_EQ(concat_offset_.size(), embedding_matrices_.size()); 339 340 const int num_hidden_layers = model->GetNumHiddenLayers(); 341 if (num_hidden_layers < 1) { 342 TC_LOG(ERROR) << "Wrong number of hidden layers: " << num_hidden_layers; 343 return; 344 } 345 hidden_weights_.resize(num_hidden_layers); 346 hidden_bias_.resize(num_hidden_layers); 347 348 for (int i = 0; i < num_hidden_layers; ++i) { 349 const EmbeddingNetworkParams::Matrix matrix = 350 model->GetHiddenLayerMatrix(i); 351 const EmbeddingNetworkParams::Matrix bias = model->GetHiddenLayerBias(i); 352 if (!InitNonQuantizedMatrix(matrix, &hidden_weights_[i]) || 353 !InitNonQuantizedVector(bias, &hidden_bias_[i])) { 354 TC_LOG(ERROR) << "Bad hidden layer #" << i; 355 return; 356 } 357 } 358 359 if (!model->HasSoftmaxLayer()) { 360 TC_LOG(ERROR) << "Missing softmax layer"; 361 return; 362 } 363 const EmbeddingNetworkParams::Matrix softmax = model->GetSoftmaxMatrix(); 364 const EmbeddingNetworkParams::Matrix softmax_bias = model->GetSoftmaxBias(); 365 if (!InitNonQuantizedMatrix(softmax, &softmax_weights_) || 366 !InitNonQuantizedVector(softmax_bias, &softmax_bias_)) { 367 TC_LOG(ERROR) << "Bad softmax layer"; 368 return; 369 } 370 371 // Everything looks good. 372 valid_ = true; 373 } 374 375 int EmbeddingNetwork::EmbeddingSize(int es_index) const { 376 return embedding_matrices_[es_index]->dim(); 377 } 378 379 } // namespace nlp_core 380 } // namespace libtextclassifier 381