Home | History | Annotate | Download | only in utils

Lines Matching defs:tflite

17 #include "utils/tflite-model-executor.h"
23 namespace tflite {
54 } // namespace tflite
57 #include "utils/tflite/dist_diversification.h"
58 #include "utils/tflite/text_encoder.h"
59 #include "utils/tflite/token_encoder.h"
61 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
62 resolver->AddBuiltin(tflite::BuiltinOperator_ADD,
63 tflite::ops::builtin::Register_ADD(),
66 resolver->AddBuiltin(tflite::BuiltinOperator_CONCATENATION,
67 tflite::ops::builtin::Register_CONCATENATION(),
70 resolver->AddBuiltin(tflite::BuiltinOperator_CONV_2D,
71 tflite::ops::builtin::Register_CONV_2D(),
74 resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
75 tflite::ops::builtin::Register_FULLY_CONNECTED(),
78 resolver->AddBuiltin(tflite::BuiltinOperator_L2_NORMALIZATION,
79 tflite::ops::builtin::Register_L2_NORMALIZATION(),
82 resolver->AddBuiltin(tflite::BuiltinOperator_MUL,
83 tflite::ops::builtin::Register_MUL());
84 resolver->AddBuiltin(tflite::BuiltinOperator_RESHAPE,
85 tflite::ops::builtin::Register_RESHAPE());
86 resolver->AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
87 tflite::ops::builtin::Register_SOFTMAX(),
90 resolver->AddBuiltin(tflite::BuiltinOperator_GATHER,
91 tflite::ops::builtin::Register_GATHER(),
94 resolver->AddBuiltin(tflite::BuiltinOperator_TRANSPOSE,
95 tflite::ops::builtin::Register_TRANSPOSE(),
98 resolver->AddBuiltin(tflite::BuiltinOperator_SUB,
99 tflite::ops::builtin::Register_SUB(),
102 resolver->AddBuiltin(tflite::BuiltinOperator_DIV,
103 tflite::ops::builtin::Register_DIV());
104 resolver->AddBuiltin(tflite::BuiltinOperator_STRIDED_SLICE,
105 tflite::ops::builtin::Register_STRIDED_SLICE(),
108 resolver->AddBuiltin(tflite::BuiltinOperator_EXP,
109 tflite::ops::builtin::Register_EXP());
110 resolver->AddBuiltin(tflite::BuiltinOperator_TOPK_V2,
111 tflite::ops::builtin::Register_TOPK_V2(),
114 resolver->AddBuiltin(tflite::BuiltinOperator_SPLIT,
115 tflite::ops::builtin::Register_SPLIT(),
118 resolver->AddBuiltin(tflite::BuiltinOperator_CAST,
119 tflite::ops::builtin::Register_CAST());
120 resolver->AddBuiltin(tflite::BuiltinOperator_MAXIMUM,
121 tflite::ops::builtin::Register_MAXIMUM(),
124 resolver->AddBuiltin(tflite::BuiltinOperator_MINIMUM,
125 tflite::ops::builtin::Register_MINIMUM(),
128 resolver->AddBuiltin(tflite::BuiltinOperator_NEG,
129 tflite::ops::builtin::Register_NEG());
130 resolver->AddBuiltin(tflite::BuiltinOperator_SLICE,
131 tflite::ops::builtin::Register_SLICE(),
134 resolver->AddBuiltin(tflite::BuiltinOperator_LOG,
135 tflite::ops::builtin::Register_LOG());
136 resolver->AddBuiltin(tflite::BuiltinOperator_SUM,
137 tflite::ops::builtin::Register_SUM());
138 resolver->AddBuiltin(tflite::BuiltinOperator_PACK,
139 tflite::ops::builtin::Register_PACK(),
142 resolver->AddBuiltin(tflite::BuiltinOperator_DEQUANTIZE,
143 tflite::ops::builtin::Register_DEQUANTIZE(),
146 resolver->AddBuiltin(tflite::BuiltinOperator_MEAN,
147 tflite::ops::builtin::Register_MEAN());
150 void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
151 resolver->AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
152 tflite::ops::builtin::Register_FULLY_CONNECTED());
158 inline std::unique_ptr<tflite::OpResolver> BuildOpResolver() {
160 std::unique_ptr<tflite::MutableOpResolver> resolver(
161 new tflite::MutableOpResolver);
164 std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver(
165 new tflite::ops::builtin::BuiltinOpResolver);
169 tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
171 tflite::ops::custom::Register_TEXT_ENCODER());
173 tflite::ops::custom::Register_TOKEN_ENCODER());
175 return std::unique_ptr<tflite::OpResolver>(std::move(resolver));
178 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
179 const tflite::Model* model_spec) {
180 std::unique_ptr<const tflite::FlatBufferModel> model(
181 tflite::FlatBufferModel::BuildFromModel(model_spec));
183 TC3_LOG(ERROR) << "Could not build TFLite model from a model spec.";
189 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
191 const tflite::Model* model =
192 flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
202 std::unique_ptr<const tflite::FlatBufferModel> model)
205 std::unique_ptr<tflite::Interpreter> TfLiteModelExecutor::CreateInterpreter()
207 std::unique_ptr<tflite::Interpreter> interpreter;
208 tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter);
215 tflite::Interpreter* interpreter) const {
216 tflite::DynamicBuffer buf;
225 std::vector<tflite::StringRef> TfLiteModelExecutor::Output(
226 const int output_index, const tflite::Interpreter* interpreter) const {
229 const int num_strings = tflite::GetStringCount(output_tensor);
230 std::vector<tflite::StringRef> output(num_strings);
232 output[i] = tflite::GetString(output_tensor, i);
239 const int output_index, const tflite::Interpreter* interpreter) const {
241 for (const tflite::StringRef& s :
242 Output<tflite::StringRef>(output_index, interpreter)) {