1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "utils/tflite-model-executor.h" 18 19 #include "utils/base/logging.h" 20 #include "tensorflow/lite/kernels/register.h" 21 22 // Forward declaration of custom TensorFlow Lite ops for registration. 23 namespace tflite { 24 namespace ops { 25 namespace builtin { 26 TfLiteRegistration* Register_ADD(); 27 TfLiteRegistration* Register_CONCATENATION(); 28 TfLiteRegistration* Register_CONV_2D(); 29 TfLiteRegistration* Register_FULLY_CONNECTED(); 30 TfLiteRegistration* Register_L2_NORMALIZATION(); 31 TfLiteRegistration* Register_MUL(); 32 TfLiteRegistration* Register_RESHAPE(); 33 TfLiteRegistration* Register_SOFTMAX(); 34 TfLiteRegistration* Register_GATHER(); 35 TfLiteRegistration* Register_TRANSPOSE(); 36 TfLiteRegistration* Register_SUB(); 37 TfLiteRegistration* Register_DIV(); 38 TfLiteRegistration* Register_STRIDED_SLICE(); 39 TfLiteRegistration* Register_EXP(); 40 TfLiteRegistration* Register_TOPK_V2(); 41 TfLiteRegistration* Register_SPLIT(); 42 TfLiteRegistration* Register_CAST(); 43 TfLiteRegistration* Register_MAXIMUM(); 44 TfLiteRegistration* Register_MINIMUM(); 45 TfLiteRegistration* Register_NEG(); 46 TfLiteRegistration* Register_SLICE(); 47 TfLiteRegistration* Register_LOG(); 48 TfLiteRegistration* Register_SUM(); 49 TfLiteRegistration* Register_PACK(); 50 TfLiteRegistration* Register_DEQUANTIZE(); 51 TfLiteRegistration* Register_MEAN(); 52 } // namespace builtin 53 } // namespace ops 54 } // namespace tflite 55 56 #ifdef TC3_WITH_ACTIONS_OPS 57 #include "utils/tflite/dist_diversification.h" 58 #include "utils/tflite/text_encoder.h" 59 #include "utils/tflite/token_encoder.h" 60 61 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) { 62 resolver->AddBuiltin(tflite::BuiltinOperator_ADD, 63 tflite::ops::builtin::Register_ADD(), 64 /*min_version=*/1, 65 /*max_version=*/2); 66 resolver->AddBuiltin(tflite::BuiltinOperator_CONCATENATION, 67 tflite::ops::builtin::Register_CONCATENATION(), 68 /*min_version=*/1, 69 /*max_version=*/2); 70 resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D, 71 tflite::ops::builtin::Register_CONV_2D(), 72 /*min_version=*/1, 73 /*max_version=*/3); 74 resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED, 75 tflite::ops::builtin::Register_FULLY_CONNECTED(), 76 /*min_version=*/1, 77 /*max_version=*/4); 78 resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION, 79 tflite::ops::builtin::Register_L2_NORMALIZATION(), 80 /*min_version=*/1, 81 /*max_version=*/2); 82 resolver->AddBuiltin(tflite::BuiltinOperator_MUL, 83 tflite::ops::builtin::Register_MUL()); 84 resolver->AddBuiltin(tflite::BuiltinOperator_RESHAPE, 85 tflite::ops::builtin::Register_RESHAPE()); 86 resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX, 87 tflite::ops::builtin::Register_SOFTMAX(), 88 /*min_version=*/1, 89 /*max_version=*/2); 90 resolver->AddBuiltin(tflite::BuiltinOperator_GATHER, 91 tflite::ops::builtin::Register_GATHER(), 92 /*min_version=*/1, 93 /*max_version=*/2); 94 resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE, 95 tflite::ops::builtin::Register_TRANSPOSE(), 96 /*min_version=*/1, 97 /*max_version=*/2); 98 resolver->AddBuiltin(tflite::BuiltinOperator_SUB, 99 tflite::ops::builtin::Register_SUB(), 100 /*min_version=*/1, 101 /*max_version=*/2); 102 resolver->AddBuiltin(tflite::BuiltinOperator_DIV, 103 tflite::ops::builtin::Register_DIV()); 104 resolver->AddBuiltin(tflite::BuiltinOperator_STRIDED_SLICE, 105 tflite::ops::builtin::Register_STRIDED_SLICE(), 106 /*min_version=*/1, 107 /*max_version=*/2); 108 resolver->AddBuiltin(tflite::BuiltinOperator_EXP, 109 tflite::ops::builtin::Register_EXP()); 110 resolver->AddBuiltin(tflite::BuiltinOperator_TOPK_V2, 111 tflite::ops::builtin::Register_TOPK_V2(), 112 /*min_version=*/1, 113 /*max_version=*/2); 114 resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT, 115 tflite::ops::builtin::Register_SPLIT(), 116 /*min_version=*/1, 117 /*max_version=*/3); 118 resolver->AddBuiltin(tflite::BuiltinOperator_CAST, 119 tflite::ops::builtin::Register_CAST()); 120 resolver->AddBuiltin(tflite::BuiltinOperator_MAXIMUM, 121 tflite::ops::builtin::Register_MAXIMUM(), 122 /*min_version=*/1, 123 /*max_version=*/2); 124 resolver->AddBuiltin(tflite::BuiltinOperator_MINIMUM, 125 tflite::ops::builtin::Register_MINIMUM(), 126 /*min_version=*/1, 127 /*max_version=*/2); 128 resolver->AddBuiltin(tflite::BuiltinOperator_NEG, 129 tflite::ops::builtin::Register_NEG()); 130 resolver->AddBuiltin(tflite::BuiltinOperator_SLICE, 131 tflite::ops::builtin::Register_SLICE(), 132 /*min_version=*/1, 133 /*max_version=*/2); 134 resolver->AddBuiltin(tflite::BuiltinOperator_LOG, 135 tflite::ops::builtin::Register_LOG()); 136 resolver->AddBuiltin(tflite::BuiltinOperator_SUM, 137 tflite::ops::builtin::Register_SUM()); 138 resolver->AddBuiltin(tflite::BuiltinOperator_PACK, 139 tflite::ops::builtin::Register_PACK(), 140 /*min_version=*/1, 141 /*max_version=*/2); 142 resolver->AddBuiltin(tflite::BuiltinOperator_DEQUANTIZE, 143 tflite::ops::builtin::Register_DEQUANTIZE(), 144 /*min_version=*/1, 145 /*max_version=*/2); 146 resolver->AddBuiltin(tflite::BuiltinOperator_MEAN, 147 tflite::ops::builtin::Register_MEAN()); 148 } 149 #else 150 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) { 151 resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED, 152 tflite::ops::builtin::Register_FULLY_CONNECTED()); 153 } 154 #endif // TC3_WITH_ACTIONS_OPS 155 156 namespace libtextclassifier3 { 157 158 inline std::unique_ptr<tflite::OpResolver> BuildOpResolver() { 159 #ifdef TC3_USE_SELECTIVE_REGISTRATION 160 std::unique_ptr<tflite::MutableOpResolver> resolver( 161 new tflite::MutableOpResolver); 162 RegisterSelectedOps(resolver.get()); 163 #else 164 std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver( 165 new tflite::ops::builtin::BuiltinOpResolver); 166 #endif 167 #ifdef TC3_WITH_ACTIONS_OPS 168 resolver->AddCustom("DistanceDiversification", 169 tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION()); 170 resolver->AddCustom("TextEncoder", 171 tflite::ops::custom::Register_TEXT_ENCODER()); 172 resolver->AddCustom("TokenEncoder", 173 tflite::ops::custom::Register_TOKEN_ENCODER()); 174 #endif // TC3_WITH_ACTIONS_OPS 175 return std::unique_ptr<tflite::OpResolver>(std::move(resolver)); 176 } 177 178 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec( 179 const tflite::Model* model_spec) { 180 std::unique_ptr<const tflite::FlatBufferModel> model( 181 tflite::FlatBufferModel::BuildFromModel(model_spec)); 182 if (!model || !model->initialized()) { 183 TC3_LOG(ERROR) << "Could not build TFLite model from a model spec."; 184 return nullptr; 185 } 186 return model; 187 } 188 189 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer( 190 const flatbuffers::Vector<uint8_t>* model_spec_buffer) { 191 const tflite::Model* model = 192 flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data()); 193 flatbuffers::Verifier verifier(model_spec_buffer->data(), 194 model_spec_buffer->Length()); 195 if (!model->Verify(verifier)) { 196 return nullptr; 197 } 198 return TfLiteModelFromModelSpec(model); 199 } 200 201 TfLiteModelExecutor::TfLiteModelExecutor( 202 std::unique_ptr<const tflite::FlatBufferModel> model) 203 : model_(std::move(model)), resolver_(BuildOpResolver()) {} 204 205 std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter() 206 const { 207 std::unique_ptr<tflite::Interpreter> interpreter; 208 tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter); 209 return interpreter; 210 } 211 212 template <> 213 void TfLiteModelExecutor::SetInput(const int input_index, 214 const std::vector<std::string>& input_data, 215 tflite::Interpreter* interpreter) const { 216 tflite::DynamicBuffer buf; 217 for (const std::string& s : input_data) { 218 buf.AddString(s.data(), s.length()); 219 } 220 buf.WriteToTensorAsVector( 221 interpreter->tensor(interpreter->inputs()[input_index])); 222 } 223 224 template <> 225 std::vector<tflite::StringRef> TfLiteModelExecutor::Output( 226 const int output_index, const tflite::Interpreter* interpreter) const { 227 const TfLiteTensor* output_tensor = 228 interpreter->tensor(interpreter->outputs()[output_index]); 229 const int num_strings = tflite::GetStringCount(output_tensor); 230 std::vector<tflite::StringRef> output(num_strings); 231 for (int i = 0; i < num_strings; i++) { 232 output[i] = tflite::GetString(output_tensor, i); 233 } 234 return output; 235 } 236 237 template <> 238 std::vector<std::string> TfLiteModelExecutor::Output( 239 const int output_index, const tflite::Interpreter* interpreter) const { 240 std::vector<std::string> output; 241 for (const tflite::StringRef& s : 242 Output<tflite::StringRef>(output_index, interpreter)) { 243 output.push_back(std::string(s.str, s.len)); 244 } 245 return output; 246 } 247 248 } // namespace libtextclassifier3 249