1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Deserialization infrastructure for tflite. Provides functionality 16 // to go from a serialized tflite model in flatbuffer format to an 17 // interpreter. 18 // 19 // using namespace tflite; 20 // StderrReporter error_reporter; 21 // auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite", 22 // &error_reporter); 23 // MyOpResolver resolver; // You need to subclass OpResolver to provide 24 // // implementations. 25 // InterpreterBuilder builder(*model, resolver); 26 // std::unique_ptr<Interpreter> interpreter; 27 // if(builder(&interpreter) == kTfLiteOk) { 28 // .. run model inference with interpreter 29 // } 30 // 31 // OpResolver must be defined to provide your kernel implementations to the 32 // interpreter. This is environment specific and may consist of just the builtin 33 // ops, or some custom operators you defined to extend tflite. 34 #ifndef TENSORFLOW_CONTRIB_LITE_MODEL_H_ 35 #define TENSORFLOW_CONTRIB_LITE_MODEL_H_ 36 37 #include <memory> 38 #include "tensorflow/contrib/lite/error_reporter.h" 39 #include "tensorflow/contrib/lite/interpreter.h" 40 #include "tensorflow/contrib/lite/schema/schema_generated.h" 41 42 namespace tflite { 43 44 // An RAII object that represents a read-only tflite model, copied from disk, 45 // or mmapped. This uses flatbuffers as the serialization format. 46 class FlatBufferModel { 47 public: 48 // Builds a model based on a file. Returns a nullptr in case of failure. 49 static std::unique_ptr<FlatBufferModel> BuildFromFile( 50 const char* filename, 51 ErrorReporter* error_reporter = DefaultErrorReporter()); 52 53 // Builds a model based on a pre-loaded flatbuffer. The caller retains 54 // ownership of the buffer and should keep it alive until the returned object 55 // is destroyed. Returns a nullptr in case of failure. 56 static std::unique_ptr<FlatBufferModel> BuildFromBuffer( 57 const char* buffer, size_t buffer_size, 58 ErrorReporter* error_reporter = DefaultErrorReporter()); 59 60 // Builds a model directly from a flatbuffer pointer. The caller retains 61 // ownership of the buffer and should keep it alive until the returned object 62 // is destroyed. Returns a nullptr in case of failure. 63 static std::unique_ptr<FlatBufferModel> BuildFromModel( 64 const tflite::Model* model_spec, 65 ErrorReporter* error_reporter = DefaultErrorReporter()); 66 67 // Releases memory or unmaps mmaped meory. 68 ~FlatBufferModel(); 69 70 // Copying or assignment is disallowed to simplify ownership semantics. 71 FlatBufferModel(const FlatBufferModel&) = delete; 72 FlatBufferModel& operator=(const FlatBufferModel&) = delete; 73 74 bool initialized() const { return model_ != nullptr; } 75 const tflite::Model* operator->() const { return model_; } 76 const tflite::Model* GetModel() const { return model_; } 77 ErrorReporter* error_reporter() const { return error_reporter_; } 78 const Allocation* allocation() const { return allocation_; } 79 80 // Returns true if the model identifier is correct (otherwise false and 81 // reports an error). 82 bool CheckModelIdentifier() const; 83 84 private: 85 // Loads a model from `filename`. If `mmap_file` is true then use mmap, 86 // otherwise make a copy of the model in a buffer. 87 // 88 // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be 89 // used. 90 explicit FlatBufferModel( 91 const char* filename, bool mmap_file = true, 92 ErrorReporter* error_reporter = DefaultErrorReporter(), 93 bool use_nnapi = false); 94 95 // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has 96 // to remain alive and unchanged until the end of this flatbuffermodel's 97 // lifetime. 98 // 99 // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be 100 // used. 101 FlatBufferModel(const char* ptr, size_t num_bytes, 102 ErrorReporter* error_reporter = DefaultErrorReporter()); 103 104 // Loads a model from Model flatbuffer. The `model` has to remain alive and 105 // unchanged until the end of this flatbuffermodel's lifetime. 106 FlatBufferModel(const Model* model, ErrorReporter* error_reporter); 107 108 // Flatbuffer traverser pointer. (Model* is a pointer that is within the 109 // allocated memory of the data allocated by allocation's internals. 110 const tflite::Model* model_ = nullptr; 111 ErrorReporter* error_reporter_; 112 Allocation* allocation_ = nullptr; 113 }; 114 115 // Abstract interface that returns TfLiteRegistrations given op codes or custom 116 // op names. This is the mechanism that ops being referenced in the flatbuffer 117 // model are mapped to executable function pointers (TfLiteRegistrations). 118 class OpResolver { 119 public: 120 // Finds the op registration for a builtin operator by enum code. 121 virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; 122 // Finds the op registration of a custom operator by op name. 123 virtual TfLiteRegistration* FindOp(const char* op) const = 0; 124 virtual ~OpResolver() {} 125 }; 126 127 // Build an interpreter capable of interpreting `model`. 128 // 129 // model: a scoped model whose lifetime must be at least as long as 130 // the interpreter. In principle multiple interpreters can be made from 131 // a single model. 132 // op_resolver: An instance that implements the Resolver interface which maps 133 // custom op names and builtin op codes to op registrations. 134 // reportError: a functor that is called to report errors that handles 135 // printf var arg semantics. The lifetime of the reportError object must 136 // be greater than or equal to the Interpreter created by operator(). 137 // 138 // Returns a kTfLiteOk when successful and sets interpreter to a valid 139 // Interpreter. Note: the user must ensure the model lifetime is at least as 140 // long as interpreter's lifetime. 141 class InterpreterBuilder { 142 public: 143 InterpreterBuilder(const FlatBufferModel& model, 144 const OpResolver& op_resolver); 145 // Builds an interpreter given only the raw flatbuffer Model object (instead 146 // of a FlatBufferModel). Mostly used for testing. 147 // If `error_reporter` is null, then DefaultErrorReporter() is used. 148 InterpreterBuilder(const ::tflite::Model* model, 149 const OpResolver& op_resolver, 150 ErrorReporter* error_reporter = DefaultErrorReporter()); 151 InterpreterBuilder(const InterpreterBuilder&) = delete; 152 InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; 153 TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter); 154 155 private: 156 TfLiteStatus BuildLocalIndexToRegistrationMapping(); 157 TfLiteStatus ParseNodes( 158 const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators, 159 Interpreter* interpreter); 160 TfLiteStatus ParseTensors( 161 const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers, 162 const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors, 163 Interpreter* interpreter); 164 165 const ::tflite::Model* model_; 166 const OpResolver& op_resolver_; 167 ErrorReporter* error_reporter_; 168 169 std::vector<TfLiteRegistration*> flatbuffer_op_index_to_registration_; 170 std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_; 171 const Allocation* allocation_ = nullptr; 172 }; 173 174 } // namespace tflite 175 176 #endif // TENSORFLOW_CONTRIB_LITE_MODEL_H_ 177