1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "annotator/model-executor.h" 18 19 #include "annotator/quantization.h" 20 #include "utils/base/logging.h" 21 22 namespace libtextclassifier3 { 23 24 TensorView<float> ModelExecutor::ComputeLogits( 25 const TensorView<float>& features, tflite::Interpreter* interpreter) const { 26 if (!interpreter) { 27 return TensorView<float>::Invalid(); 28 } 29 interpreter->ResizeInputTensor(kInputIndexFeatures, features.shape()); 30 if (interpreter->AllocateTensors() != kTfLiteOk) { 31 TC3_VLOG(1) << "Allocation failed."; 32 return TensorView<float>::Invalid(); 33 } 34 35 SetInput<float>(kInputIndexFeatures, features, interpreter); 36 37 if (interpreter->Invoke() != kTfLiteOk) { 38 TC3_VLOG(1) << "Interpreter failed."; 39 return TensorView<float>::Invalid(); 40 } 41 42 return OutputView<float>(kOutputIndexLogits, interpreter); 43 } 44 45 std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::FromBuffer( 46 const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size, 47 int quantization_bits, 48 const Model_::EmbeddingPruningMask* embedding_pruning_mask) { 49 std::unique_ptr<TfLiteModelExecutor> executor = 50 TfLiteModelExecutor::FromBuffer(model_spec_buffer); 51 if (!executor) { 52 TC3_LOG(ERROR) << "Could not load TFLite model for embeddings."; 53 return nullptr; 54 } 55 56 std::unique_ptr<tflite::Interpreter> interpreter = 57 executor->CreateInterpreter(); 58 if (!interpreter) { 59 TC3_LOG(ERROR) << "Could not build TFLite interpreter for embeddings."; 60 return nullptr; 61 } 62 63 if (interpreter->tensors_size() != 2) { 64 return nullptr; 65 } 66 const TfLiteTensor* embeddings = interpreter->tensor(0); 67 if (embeddings->dims->size != 2) { 68 return nullptr; 69 } 70 int num_buckets = embeddings->dims->data[0]; 71 const TfLiteTensor* scales = interpreter->tensor(1); 72 if (scales->dims->size != 2 || scales->dims->data[0] != num_buckets || 73 scales->dims->data[1] != 1) { 74 return nullptr; 75 } 76 int bytes_per_embedding = embeddings->dims->data[1]; 77 if (!CheckQuantizationParams(bytes_per_embedding, quantization_bits, 78 embedding_size)) { 79 TC3_LOG(ERROR) << "Mismatch in quantization parameters."; 80 return nullptr; 81 } 82 83 return std::unique_ptr<TFLiteEmbeddingExecutor>(new TFLiteEmbeddingExecutor( 84 std::move(executor), quantization_bits, num_buckets, bytes_per_embedding, 85 embedding_size, scales, embeddings, std::move(interpreter), 86 embedding_pruning_mask)); 87 } 88 89 TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor( 90 std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits, 91 int num_buckets, int bytes_per_embedding, int output_embedding_size, 92 const TfLiteTensor* scales, const TfLiteTensor* embeddings, 93 std::unique_ptr<tflite::Interpreter> interpreter, 94 const Model_::EmbeddingPruningMask* embedding_pruning_mask) 95 : executor_(std::move(executor)), 96 quantization_bits_(quantization_bits), 97 num_buckets_(num_buckets), 98 bytes_per_embedding_(bytes_per_embedding), 99 output_embedding_size_(output_embedding_size), 100 scales_(scales), 101 embeddings_(embeddings), 102 interpreter_(std::move(interpreter)) { 103 if ((embedding_pruning_mask != nullptr) && 104 (embedding_pruning_mask->enabled())) { 105 for (int i = 0; i < embedding_pruning_mask->pruning_mask()->size(); i++) { 106 pruning_mask_.push_back((*(embedding_pruning_mask->pruning_mask()))[i]); 107 } 108 ComputePrefixCounts(); 109 full_num_buckets_ = embedding_pruning_mask->full_num_buckets(); 110 pruned_row_bucket_id_ = embedding_pruning_mask->pruned_row_bucket_id(); 111 } else { 112 full_num_buckets_ = num_buckets; 113 } 114 } 115 116 void TFLiteEmbeddingExecutor::ComputePrefixCounts() { 117 // Pre-compute the prefix sums. 118 // For each i in {0, 1,...,pruning_mask_.size()-1}, we compute number of 1s 119 // in binary representations of the uint64 values in pruning_mask_ before 120 // index i. We set pruned_row_bucket_id_ to the total number of 1s 121 // in binary representations of all values in pruning_mask_. 122 int count = 0; 123 for (const uint64 mask : pruning_mask_) { 124 prefix_counts_.push_back(count); 125 count += __builtin_popcountll(mask); 126 } 127 } 128 129 int TFLiteEmbeddingExecutor::PruneBucketId(int bucket_id) const { 130 // Implements auxiliary data structure for computing the pruned index of a 131 // given bucket_id. 132 // If bucket_id is present in pruning_mask_, we compute floor(bucket_id/64), 133 // look it up in the auxiliary array prefix_counts_, and add to it the number 134 // of 1s before before bucket_id % 64 in the 64-bit sequence 135 // pruning_mask_[floor(bucket_id/64)]. 136 // If bucket_id is absent from pruning_mask_, we return pruned_row_bucket_id_. 137 const int bucket_id_major = bucket_id >> 6; 138 const int bucket_id_minor = bucket_id & 63; 139 uint64_t one = 1; 140 if (!(pruning_mask_[bucket_id_major] & (one << bucket_id_minor))) 141 return pruned_row_bucket_id_; 142 const uint64 zero = 0; 143 uint64 minor_mask; 144 if (bucket_id_minor == 0) 145 minor_mask = zero; 146 else 147 minor_mask = ((~zero) >> (64 - bucket_id_minor)); 148 return prefix_counts_[bucket_id_major] + 149 __builtin_popcountll(pruning_mask_[bucket_id_major] & minor_mask); 150 } 151 152 bool TFLiteEmbeddingExecutor::AddEmbedding( 153 const TensorView<int>& sparse_features, float* dest, int dest_size) const { 154 if (dest_size != output_embedding_size_) { 155 TC3_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: " 156 << dest_size << " " << output_embedding_size_; 157 return false; 158 } 159 const int num_sparse_features = sparse_features.size(); 160 for (int i = 0; i < num_sparse_features; ++i) { 161 const int bucket_id = sparse_features.data()[i]; 162 int full_num_buckets; 163 if (!pruning_mask_.empty()) { 164 full_num_buckets = full_num_buckets_; 165 } else { 166 full_num_buckets = num_buckets_; 167 } 168 if (bucket_id >= full_num_buckets) { 169 return false; 170 } 171 int final_bucket_id; 172 if (!pruning_mask_.empty()) { 173 final_bucket_id = PruneBucketId(bucket_id); 174 } else { 175 final_bucket_id = bucket_id; 176 } 177 if (!DequantizeAdd(scales_->data.f, embeddings_->data.uint8, 178 bytes_per_embedding_, num_sparse_features, 179 quantization_bits_, final_bucket_id, dest, dest_size)) { 180 return false; 181 } 182 } 183 return true; 184 } 185 186 } // namespace libtextclassifier3 187