Home | History | Annotate | Download | only in lite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 // Deserialization infrastructure for tflite. Provides functionality
     16 // to go from a serialized tflite model in flatbuffer format to an
     17 // interpreter.
     18 //
     19 // using namespace tflite;
     20 // StderrReporter error_reporter;
     21 // auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite",
     22 //                                             &error_reporter);
     23 // MyOpResolver resolver;  // You need to subclass OpResolver to provide
     24 //                         // implementations.
     25 // InterpreterBuilder builder(*model, resolver);
     26 // std::unique_ptr<Interpreter> interpreter;
     27 // if(builder(&interpreter) == kTfLiteOk) {
     28 //   .. run model inference with interpreter
     29 // }
     30 //
     31 // OpResolver must be defined to provide your kernel implementations to the
     32 // interpreter. This is environment specific and may consist of just the builtin
     33 // ops, or some custom operators you defined to extend tflite.
     34 #ifndef TENSORFLOW_CONTRIB_LITE_MODEL_H_
     35 #define TENSORFLOW_CONTRIB_LITE_MODEL_H_
     36 
     37 #include <memory>
     38 #include "tensorflow/contrib/lite/error_reporter.h"
     39 #include "tensorflow/contrib/lite/interpreter.h"
     40 #include "tensorflow/contrib/lite/schema/schema_generated.h"
     41 
     42 namespace tflite {
     43 
// An RAII object that represents a read-only tflite model, copied from disk,
// or mmapped. This uses flatbuffers as the serialization format.
class FlatBufferModel {
 public:
  // Builds a model based on a file. Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromFile(
      const char* filename,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Builds a model based on a pre-loaded flatbuffer. The caller retains
  // ownership of the buffer and should keep it alive until the returned object
  // is destroyed. Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromBuffer(
      const char* buffer, size_t buffer_size,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Builds a model directly from a flatbuffer pointer. The caller retains
  // ownership of the buffer and should keep it alive until the returned object
  // is destroyed. Returns a nullptr in case of failure.
  static std::unique_ptr<FlatBufferModel> BuildFromModel(
      const tflite::Model* model_spec,
      ErrorReporter* error_reporter = DefaultErrorReporter());

  // Releases memory or unmaps mmapped memory.
  ~FlatBufferModel();

  // Copying or assignment is disallowed to simplify ownership semantics.
  FlatBufferModel(const FlatBufferModel&) = delete;
  FlatBufferModel& operator=(const FlatBufferModel&) = delete;

  // True when the underlying flatbuffer Model pointer was successfully set.
  bool initialized() const { return model_ != nullptr; }
  // Direct access to the root flatbuffer Model; may be nullptr if
  // construction failed (check initialized() first).
  const tflite::Model* operator->() const { return model_; }
  const tflite::Model* GetModel() const { return model_; }
  // Error reporter used by this model (never owned by this class).
  ErrorReporter* error_reporter() const { return error_reporter_; }
  // Underlying storage backing the flatbuffer; may be nullptr when the model
  // was built from an externally owned Model pointer.
  const Allocation* allocation() const { return allocation_; }

  // Returns true if the model identifier is correct (otherwise false and
  // reports an error).
  bool CheckModelIdentifier() const;

 private:
  // Loads a model from `filename`. If `mmap_file` is true then use mmap,
  // otherwise make a copy of the model in a buffer.
  //
  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
  // used.
  explicit FlatBufferModel(
      const char* filename, bool mmap_file = true,
      ErrorReporter* error_reporter = DefaultErrorReporter(),
      bool use_nnapi = false);

  // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has
  // to remain alive and unchanged until the end of this flatbuffermodel's
  // lifetime.
  //
  // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be
  // used.
  FlatBufferModel(const char* ptr, size_t num_bytes,
                  ErrorReporter* error_reporter = DefaultErrorReporter());

  // Loads a model from Model flatbuffer. The `model` has to remain alive and
  // unchanged until the end of this flatbuffermodel's lifetime.
  FlatBufferModel(const Model* model, ErrorReporter* error_reporter);

  // Flatbuffer traverser pointer. (Model* is a pointer that is within the
  // allocated memory of the data allocated by allocation's internals.)
  const tflite::Model* model_ = nullptr;
  // Where errors are reported to; not owned.
  ErrorReporter* error_reporter_;
  // Owns/describes the raw model bytes (file mmap or buffer copy);
  // nullptr when the Model pointer is externally owned.
  Allocation* allocation_ = nullptr;
};
    114 
    115 // Abstract interface that returns TfLiteRegistrations given op codes or custom
    116 // op names. This is the mechanism that ops being referenced in the flatbuffer
    117 // model are mapped to executable function pointers (TfLiteRegistrations).
    118 class OpResolver {
    119  public:
    120   // Finds the op registration for a builtin operator by enum code.
    121   virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0;
    122   // Finds the op registration of a custom operator by op name.
    123   virtual TfLiteRegistration* FindOp(const char* op) const = 0;
    124   virtual ~OpResolver() {}
    125 };
    126 
// Build an interpreter capable of interpreting `model`.
//
// model: a scoped model whose lifetime must be at least as long as
//   the interpreter. In principle multiple interpreters can be made from
//   a single model.
// op_resolver: An instance that implements the OpResolver interface which maps
//   custom op names and builtin op codes to op registrations.
// error_reporter: Receives printf-style error messages. Its lifetime must
//   be greater than or equal to the Interpreter created by operator().
//
// Returns a kTfLiteOk when successful and sets interpreter to a valid
// Interpreter. Note: the user must ensure the model lifetime is at least as
// long as interpreter's lifetime.
class InterpreterBuilder {
 public:
  InterpreterBuilder(const FlatBufferModel& model,
                     const OpResolver& op_resolver);
  // Builds an interpreter given only the raw flatbuffer Model object (instead
  // of a FlatBufferModel). Mostly used for testing.
  // If `error_reporter` is null, then DefaultErrorReporter() is used.
  InterpreterBuilder(const ::tflite::Model* model,
                     const OpResolver& op_resolver,
                     ErrorReporter* error_reporter = DefaultErrorReporter());
  // Copying or assignment is disallowed to simplify ownership semantics.
  InterpreterBuilder(const InterpreterBuilder&) = delete;
  InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
  // Performs the actual build; on success (kTfLiteOk) `*interpreter` holds
  // the newly created Interpreter.
  TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);

 private:
  // Resolves every op code referenced by the flatbuffer model into a
  // TfLiteRegistration via op_resolver_, filling the index maps below.
  TfLiteStatus BuildLocalIndexToRegistrationMapping();
  // Translates the flatbuffer operator list into interpreter nodes.
  TfLiteStatus ParseNodes(
      const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
      Interpreter* interpreter);
  // Translates the flatbuffer tensor/buffer tables into interpreter tensors.
  TfLiteStatus ParseTensors(
      const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
      const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
      Interpreter* interpreter);

  // Root of the flatbuffer model being built from; not owned.
  const ::tflite::Model* model_;
  // Maps op codes/names to registrations; not owned.
  const OpResolver& op_resolver_;
  // Where errors are reported to; not owned.
  ErrorReporter* error_reporter_;

  // Maps a flatbuffer op-code index to its resolved registration / builtin
  // operator type.
  std::vector<TfLiteRegistration*> flatbuffer_op_index_to_registration_;
  std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
  // Backing storage of the model, when built from a FlatBufferModel; not
  // owned. nullptr when built from a raw Model pointer.
  const Allocation* allocation_ = nullptr;
};
    173 
    174 }  // namespace tflite
    175 
    176 #endif  // TENSORFLOW_CONTRIB_LITE_MODEL_H_
    177