1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "model-executor.h" 18 19 #include "quantization.h" 20 #include "util/base/logging.h" 21 22 namespace libtextclassifier2 { 23 namespace internal { 24 bool FromModelSpec(const tflite::Model* model_spec, 25 std::unique_ptr<const tflite::FlatBufferModel>* model) { 26 *model = tflite::FlatBufferModel::BuildFromModel(model_spec); 27 if (!(*model) || !(*model)->initialized()) { 28 TC_LOG(ERROR) << "Could not build TFLite model from a model spec. "; 29 return false; 30 } 31 return true; 32 } 33 } // namespace internal 34 35 std::unique_ptr<tflite::Interpreter> ModelExecutor::CreateInterpreter() const { 36 std::unique_ptr<tflite::Interpreter> interpreter; 37 tflite::InterpreterBuilder(*model_, builtins_)(&interpreter); 38 return interpreter; 39 } 40 41 std::unique_ptr<TFLiteEmbeddingExecutor> TFLiteEmbeddingExecutor::Instance( 42 const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size, 43 int quantization_bits) { 44 const tflite::Model* model_spec = 45 flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data()); 46 flatbuffers::Verifier verifier(model_spec_buffer->data(), 47 model_spec_buffer->Length()); 48 std::unique_ptr<const tflite::FlatBufferModel> model; 49 if (!model_spec->Verify(verifier) || 50 !internal::FromModelSpec(model_spec, &model)) { 51 TC_LOG(ERROR) << "Could not load TFLite model."; 52 return nullptr; 53 } 54 55 std::unique_ptr<tflite::Interpreter> interpreter; 56 tflite::ops::builtin::BuiltinOpResolver builtins; 57 tflite::InterpreterBuilder(*model, builtins)(&interpreter); 58 if (!interpreter) { 59 TC_LOG(ERROR) << "Could not build TFLite interpreter for embeddings."; 60 return nullptr; 61 } 62 63 if (interpreter->tensors_size() != 2) { 64 return nullptr; 65 } 66 const TfLiteTensor* embeddings = interpreter->tensor(0); 67 if (embeddings->dims->size != 2) { 68 return nullptr; 69 } 70 int num_buckets = embeddings->dims->data[0]; 71 const TfLiteTensor* scales = interpreter->tensor(1); 72 if (scales->dims->size != 2 || scales->dims->data[0] != num_buckets || 73 scales->dims->data[1] != 1) { 74 return nullptr; 75 } 76 int bytes_per_embedding = embeddings->dims->data[1]; 77 if (!CheckQuantizationParams(bytes_per_embedding, quantization_bits, 78 embedding_size)) { 79 TC_LOG(ERROR) << "Mismatch in quantization parameters."; 80 return nullptr; 81 } 82 83 return std::unique_ptr<TFLiteEmbeddingExecutor>(new TFLiteEmbeddingExecutor( 84 std::move(model), quantization_bits, num_buckets, bytes_per_embedding, 85 embedding_size, scales, embeddings, std::move(interpreter))); 86 } 87 88 TFLiteEmbeddingExecutor::TFLiteEmbeddingExecutor( 89 std::unique_ptr<const tflite::FlatBufferModel> model, int quantization_bits, 90 int num_buckets, int bytes_per_embedding, int output_embedding_size, 91 const TfLiteTensor* scales, const TfLiteTensor* embeddings, 92 std::unique_ptr<tflite::Interpreter> interpreter) 93 : model_(std::move(model)), 94 quantization_bits_(quantization_bits), 95 num_buckets_(num_buckets), 96 bytes_per_embedding_(bytes_per_embedding), 97 output_embedding_size_(output_embedding_size), 98 scales_(scales), 99 embeddings_(embeddings), 100 interpreter_(std::move(interpreter)) {} 101 102 bool TFLiteEmbeddingExecutor::AddEmbedding( 103 const TensorView<int>& sparse_features, float* dest, int dest_size) const { 104 if (dest_size != output_embedding_size_) { 105 TC_LOG(ERROR) << "Mismatching dest_size and output_embedding_size: " 106 << dest_size << " " << output_embedding_size_; 107 return false; 108 } 109 const int num_sparse_features = sparse_features.size(); 110 for (int i = 0; i < num_sparse_features; ++i) { 111 const int bucket_id = sparse_features.data()[i]; 112 if (bucket_id >= num_buckets_) { 113 return false; 114 } 115 116 if (!DequantizeAdd(scales_->data.f, embeddings_->data.uint8, 117 bytes_per_embedding_, num_sparse_features, 118 quantization_bits_, bucket_id, dest, dest_size)) { 119 return false; 120 } 121 } 122 return true; 123 } 124 125 TensorView<float> ComputeLogitsHelper(const int input_index_features, 126 const int output_index_logits, 127 const TensorView<float>& features, 128 tflite::Interpreter* interpreter) { 129 if (!interpreter) { 130 return TensorView<float>::Invalid(); 131 } 132 interpreter->ResizeInputTensor(input_index_features, features.shape()); 133 if (interpreter->AllocateTensors() != kTfLiteOk) { 134 TC_VLOG(1) << "Allocation failed."; 135 return TensorView<float>::Invalid(); 136 } 137 138 TfLiteTensor* features_tensor = 139 interpreter->tensor(interpreter->inputs()[input_index_features]); 140 int size = 1; 141 for (int i = 0; i < features_tensor->dims->size; ++i) { 142 size *= features_tensor->dims->data[i]; 143 } 144 features.copy_to(features_tensor->data.f, size); 145 146 if (interpreter->Invoke() != kTfLiteOk) { 147 TC_VLOG(1) << "Interpreter failed."; 148 return TensorView<float>::Invalid(); 149 } 150 151 TfLiteTensor* logits_tensor = 152 interpreter->tensor(interpreter->outputs()[output_index_logits]); 153 154 std::vector<int> output_shape(logits_tensor->dims->size); 155 for (int i = 0; i < logits_tensor->dims->size; ++i) { 156 output_shape[i] = logits_tensor->dims->data[i]; 157 } 158 159 return TensorView<float>(logits_tensor->data.f, output_shape); 160 } 161 162 } // namespace libtextclassifier2 163