/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lang_id/common/embedding-network.h"

#include <vector>

#include "lang_id/common/lite_base/integral-types.h"
#include "lang_id/common/lite_base/logging.h"

namespace libtextclassifier3 {
namespace mobile {
namespace {

void CheckNoQuantization(const EmbeddingNetworkParams::Matrix &matrix) {
  SAFTM_CHECK_EQ(static_cast<int>(QuantizationType::NONE),
                 static_cast<int>(matrix.quant_type))
      << "Quantization not allowed here";
}

int GetMatrixRowSizeInBytes(const EmbeddingNetworkParams::Matrix &matrix) {
  int cols = matrix.cols;
  QuantizationType quant_type = matrix.quant_type;
  switch (quant_type) {
    case QuantizationType::NONE:
      return cols * sizeof(float);
    case QuantizationType::UINT8:
      return cols * sizeof(uint8);
    case QuantizationType::UINT4:
      SAFTM_DCHECK_EQ(cols % 2, 0) << "UINT4 with odd #cols = " << cols;
      return cols / 2;
    case QuantizationType::FLOAT16:
      return cols * sizeof(float16);
    default:
      SAFTM_LOG(FATAL) << "Unknown quant type: "
                       << static_cast<int>(quant_type);
  }
}

// Computes y = weights * Relu(x) + b, where Relu is optionally applied.
//
// weights and b are the weight matrix and the bias vector, respectively, of a
// neural network layer.
//
// Note: in the research literature, Relu (the activation function) is usually
// the last part of a neural layer. From that perspective, this function
// computes the Relu part of the previous layer (if any), followed by the
// first half (the computation of the state) of the current layer.
//
// Note: weights is expected to be the transposed version of the real weight
// matrix. Hence, instead of computing a linear combination of the columns of
// weights, we compute a linear combination of its rows; but we are mindful
// that these rows are in fact the columns of the original weight matrix.
void SparseReluProductPlusBias(bool apply_relu,
                               const EmbeddingNetworkParams::Matrix &weights,
                               const EmbeddingNetworkParams::Matrix &b,
                               const std::vector<float> &x,
                               std::vector<float> *y) {
  // Initialize y to b. b is a column matrix (i.e., b.cols == 1); we already
  // CHECK-ed that in the EmbeddingNetwork constructor.
  const float *b_start = reinterpret_cast<const float *>(b.elements);
  SAFTM_DCHECK_EQ(b.cols, 1);
  y->assign(b_start, b_start + b.rows);

  float *const y_data = y->data();
  const int y_size = y->size();
  SAFTM_CHECK_EQ(weights.cols, y_size);
  const int x_size = x.size();
  SAFTM_CHECK_EQ(weights.rows, x_size);

  // NOTE: the code below reads x_size * y_size elements from weights; these
  // reads are safe as long as weights.elements contains weights.rows *
  // weights.cols elements (where the element size depends on the quantization
  // type). That requirement is checked by the params provider, e.g., by
  // EmbeddingNetworkParamsFromFlatbuffer.
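
  // Illustrative sketch (editorial note, not from the original sources): with
  // the transposed layout described above, the loops below effectively
  // accumulate
  //
  //   y[j] = b[j] + sum_i Relu(x[i]) * weights[i][j]   (if apply_relu)
  //   y[j] = b[j] + sum_i x[i] * weights[i][j]         (otherwise)
  //
  // where weights[i][j] is element (j, i) of the original (non-transposed)
  // weight matrix. Skipping rows whose scale x[i] <= 0 is what makes the
  // product "sparse" when Relu is applied.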

  // There is some code duplication between the two main cases of the switch
  // below: the idea was to "lift" the switch outside the loops, to reduce the
  // number of tests at runtime.
  switch (weights.quant_type) {
    case QuantizationType::NONE: {
      // We compute a linear combination of the rows from |weights|, using
      // elements of x (optionally, Relu(x)) as scaling factors (the i-th row
      // gets multiplied by x[i] before being added with the other rows).
      // Note: elements of |weights| are stored in row-major order: first the
      // elements of row #0, next the elements of row #1, etc. In the comments
      // below, we write "weights[i][j]" to refer to the j-th element from the
      // i-th row of weights.
      const float *weight_ptr =
          reinterpret_cast<const float *>(weights.elements);
      for (int i = 0; i < x_size; ++i) {
        // Invariant 1: weight_ptr points to the beginning of the i-th row
        // from weights (i.e., weights[i][0]).
        const float scale = x[i];
        if (!apply_relu || (scale > 0)) {
          for (int j = 0; j < y_size; ++j, ++weight_ptr) {
            // Invariant 2: weight_ptr points to weights[i][j].
            y_data[j] += (*weight_ptr) * scale;
          }
        } else {
          // We don't update y_data, but we still have to move weight_ptr to
          // the next row (to satisfy Invariant 1). We do this by adding
          // y_size == weights.cols (see the earlier CHECK_EQ).
          weight_ptr += y_size;
        }
      }
      break;
    }
    case QuantizationType::FLOAT16: {
      // See the comments for the QuantizationType::NONE case: the code is
      // almost identical, except for float16 (instead of float) weights and
      // the Float16To32 conversion. We could unify these two cases using a
      // template, but since this is a critical loop, we don't want to risk
      // that, e.g., inlining of the conversion function doesn't happen.
      const float16 *weight_ptr =
          reinterpret_cast<const float16 *>(weights.elements);
      for (int i = 0; i < x_size; ++i) {
        const float scale = x[i];
        if (!apply_relu || (scale > 0)) {
          for (int j = 0; j < y_size; ++j, ++weight_ptr) {
            y_data[j] += Float16To32(*weight_ptr) * scale;
          }
        } else {
          weight_ptr += y_size;
        }
      }
      break;
    }
    default:
      SAFTM_LOG(FATAL) << "Unsupported weights quantization type: "
                       << static_cast<int>(weights.quant_type);
  }
}

}  // namespace

void EmbeddingNetwork::ConcatEmbeddings(
    const std::vector<FeatureVector> &feature_vectors,
    std::vector<float> *concat) const {
  concat->resize(concat_layer_size_);

  // "es_index" stands for "embedding space index".
  for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
    const int concat_offset = concat_offset_[es_index];

    const EmbeddingNetworkParams::Matrix &embedding_matrix =
        embedding_matrices_[es_index];
    const int embedding_dim = embedding_matrix.cols;
    const int embedding_row_size_in_bytes =
        embedding_row_size_in_bytes_[es_index];

    const FeatureVector &feature_vector = feature_vectors[es_index];
    const int num_features = feature_vector.size();
    for (int fi = 0; fi < num_features; ++fi) {
      const FeatureType *feature_type = feature_vector.type(fi);
      int feature_offset =
          concat_offset + feature_type->base() * embedding_dim;
      SAFTM_CHECK_LE(feature_offset + embedding_dim, concat->size());
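
      // Illustrative example (editorial note, with made-up numbers): if this
      // embedding space starts at concat_offset == 32, feature_type->base()
      // is 2, and embedding_dim == 16, then feature_offset == 32 + 2 * 16 ==
      // 64, i.e., this feature's weighted embedding row is accumulated into
      // concat[64..79].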

      // Weighted embeddings will be added starting from this address.
      float *concat_ptr = concat->data() + feature_offset;

      // Multiplier for each embedding weight. Includes the feature weight
      // (for continuous features) and the quantization scale (for quantized
      // embeddings).
      float multiplier;
      int feature_id;
      const FeatureValue feature_value = feature_vector.value(fi);
      if (feature_type->is_continuous()) {
        // Continuous features (encoded as FloatFeatureValue).
        FloatFeatureValue float_feature_value(feature_value);
        feature_id = float_feature_value.id;
        multiplier = float_feature_value.weight;
      } else {
        // Discrete features: every present feature has an implicit value 1.0.
        feature_id = feature_value;
        multiplier = 1.0;
      }

      SAFTM_CHECK_GE(feature_id, 0);
      SAFTM_CHECK_LT(feature_id, embedding_matrix.rows);

      // Pointer to the float / uint8 weights for the relevant embedding.
      const void *embedding_data =
          (reinterpret_cast<const char *>(embedding_matrix.elements) +
           feature_id * embedding_row_size_in_bytes);

      switch (embedding_matrix.quant_type) {
        case QuantizationType::NONE: {
          const float *weights =
              reinterpret_cast<const float *>(embedding_data);
          for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
            *concat_ptr += *weights * multiplier;
          }
          break;
        }
        case QuantizationType::UINT8: {
          multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
          const uint8 *quant_weights =
              reinterpret_cast<const uint8 *>(embedding_data);
          for (int i = 0; i < embedding_dim;
               ++i, ++quant_weights, ++concat_ptr) {
            // 128 is the bias for UINT8 quantization.
            *concat_ptr +=
                (static_cast<int>(*quant_weights) - 128) * multiplier;
          }
          break;
        }
        case QuantizationType::UINT4: {
          multiplier *= Float16To32(embedding_matrix.quant_scales[feature_id]);
          const uint8 *quant_weights =
              reinterpret_cast<const uint8 *>(embedding_data);
          for (int i = 0; i < embedding_dim / 2; ++i, ++quant_weights) {
            const uint8 qq = *quant_weights;
            // Each byte packs two 4-bit weights: the high nibble goes to
            // concat_ptr[0], the low nibble to concat_ptr[1]. OR-ing in 0x08
            // restores the midpoint of each 16-wide bucket before subtracting
            // the 128 bias.
            concat_ptr[0] +=
                (static_cast<int>((qq & 0xF0) | 0x08) - 128) * multiplier;
            concat_ptr[1] +=
                (static_cast<int>(((qq & 0x0F) << 4) | 0x08) - 128) *
                multiplier;
            concat_ptr += 2;
          }
          break;
        }
        default:
          // We already checked (in GetMatrixRowSizeInBytes) that each
          // embedding matrix has a known quantization type. Hence, DLOG is
          // enough here.
          SAFTM_DLOG(ERROR) << "Unknown embeddings quantization type "
                            << static_cast<int>(embedding_matrix.quant_type);
          break;
      }
    }
  }
}

void EmbeddingNetwork::ComputeFinalScores(
    const std::vector<FeatureVector> &features,
    std::vector<float> *scores) const {
  ComputeFinalScores(features, {}, scores);
}

void EmbeddingNetwork::ComputeFinalScores(
    const std::vector<FeatureVector> &features,
    const std::vector<float> &extra_inputs, std::vector<float> *scores) const {
  // Construct the input layer for our feed-forward neural network (FFNN).
  std::vector<float> input;
  ConcatEmbeddings(features, &input);
  if (!extra_inputs.empty()) {
    input.reserve(input.size() + extra_inputs.size());
    for (int i = 0; i < extra_inputs.size(); i++) {
      input.push_back(extra_inputs[i]);
    }
  }

  // Propagate the input through all layers of our FFNN.

  // Alternating storage for the activations of the different layers. We can't
  // use a single vector because all activations of the previous layer are
  // required when computing the activations of the next one.
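  // Illustrative trace (editorial note): for a model with two hidden layers
  // plus the softmax layer (num_layers == 3), the loop below reads |input|
  // and writes storage[0] (layer 0, no Relu), then reads storage[0] and
  // writes storage[1] (layer 1, Relu applied to storage[0]), and finally
  // reads storage[1] and writes directly into |scores| (layer 2).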
  std::vector<float> storage[2];
  const std::vector<float> *v_in = &input;
  const int num_layers = layer_weights_.size();
  for (int i = 0; i < num_layers; ++i) {
    std::vector<float> *v_out = nullptr;
    if (i == num_layers - 1) {
      // Final layer: write the results directly into |scores|.
      v_out = scores;
    } else {
      // Hidden layer: write the results into the alternating storage. The
      // i % 2 trick ensures the alternation.
      v_out = &(storage[i % 2]);
    }
    const bool apply_relu = i > 0;
    SparseReluProductPlusBias(
        apply_relu, layer_weights_[i], layer_bias_[i], *v_in, v_out);
    v_in = v_out;
  }
}

EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model)
    : model_(model) {
  int offset_sum = 0;
  for (int i = 0; i < model_->embedding_num_features_size(); ++i) {
    concat_offset_.push_back(offset_sum);
    EmbeddingNetworkParams::Matrix matrix = model_->GetEmbeddingMatrix(i);
    offset_sum += matrix.cols * model_->embedding_num_features(i);

    // NOTE: each Matrix is a small struct that doesn't own the actual matrix
    // weights. Hence, the push_back below is fast.
    embedding_matrices_.push_back(matrix);
    embedding_row_size_in_bytes_.push_back(GetMatrixRowSizeInBytes(matrix));
  }
  concat_layer_size_ = offset_sum;

  SAFTM_CHECK_EQ(model_->hidden_size(), model_->hidden_bias_size());
  for (int i = 0; i < model_->hidden_size(); ++i) {
    layer_weights_.push_back(model_->GetHiddenLayerMatrix(i));

    EmbeddingNetworkParams::Matrix bias = model_->GetHiddenLayerBias(i);
    SAFTM_CHECK_EQ(1, bias.cols);
    CheckNoQuantization(bias);
    layer_bias_.push_back(bias);
  }

  SAFTM_CHECK(model_->HasSoftmax());
  layer_weights_.push_back(model_->GetSoftmaxMatrix());

  EmbeddingNetworkParams::Matrix softmax_bias = model_->GetSoftmaxBias();
  SAFTM_CHECK_EQ(1, softmax_bias.cols);
  CheckNoQuantization(softmax_bias);
  layer_bias_.push_back(softmax_bias);
}

}  // namespace mobile
}  // namespace libtextclassifier3