Home | History | Annotate | Download | only in flex
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/lite/delegates/flex/kernel.h"
     16 
     17 #include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
     18 #include "tensorflow/core/common_runtime/eager/context.h"
     19 #include "tensorflow/core/common_runtime/eager/execute.h"
     20 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
     21 #include "tensorflow/core/framework/node_def.pb.h"
     22 #include "tensorflow/core/framework/node_def_util.h"
     23 #include "tensorflow/core/lib/core/errors.h"
     24 #include "tensorflow/lite/builtin_ops.h"
     25 #include "tensorflow/lite/c/c_api_internal.h"
     26 #include "tensorflow/lite/context_util.h"
     27 #include "tensorflow/lite/delegates/flex/delegate_data.h"
     28 #include "tensorflow/lite/delegates/flex/util.h"
     29 #include "tensorflow/lite/kernels/kernel_util.h"
     30 #include "tensorflow/lite/profiling/profiler.h"
     31 #include "tensorflow/lite/string.h"
     32 
     33 // Note: this is part of TF Lite's Flex delegation code which is to be
     34 // completed soon.
     35 
     36 // This is the TF Lite op that is created by the flex delegate to handle
     37 // execution of a supported subgraph. The usual flow is that the delegate
     38 // informs the interpreter of supported nodes in a graph, and each supported
     39 // subgraph is replaced with one instance of this kernel.
     40 //
     41 // The kernel is initialized with TfLiteDelegateParams from which we retrieve
     42 // the global EagerContext and BufferMap, as well as a list of inputs and
     43 // outputs to the subgraph. Those are used to build the OpData, with a list of
     44 // TensorFlow Ops that should be executed in order (which we call an OpNode).
     45 //
     46 // For each node included in the subgraph, we query the interpreter and
     47 // retrieve the associated NodeDef, which is then used to configure the
     48 // corresponding TensorFlow/Eager Op.
     49 
     50 namespace tflite {
     51 namespace flex {
     52 namespace kernel {
     53 
     54 struct OpNode;
     55 
     56 // Represents the origin of a given tensor as a reference to the output
     57 // of an upstream node.
     58 struct TensorSource {
     59   OpNode* node;
     60   int node_output_index;
     61 };
     62 
     63 // A list of inputs of a given node of the TensorFlow/Eager graph.
     64 class OpInputs {
     65  public:
     66   explicit OpInputs(const TfLiteIntArray* indexes) {
     67     for (int index : TfLiteIntArrayView(indexes)) {
     68       inputs_.push_back(index);
     69     }
     70     forwardable_.resize(inputs_.size());
     71   }
     72   ~OpInputs() {}
     73 
     74   int Size() const { return inputs_.size(); }
     75 
     76   int TfLiteIndex(int i) const { return inputs_[i]; }
     77 
     78   // Given a map relating tensors to the node that originates them, populate a
     79   // list of sources for the tensors in this class.
     80   void InitializeTensorSources(
     81       const std::map<int, TensorSource>& tflite_tensor_sources) {
     82     sources_.clear();
     83     for (int i : inputs_) {
     84       auto it = tflite_tensor_sources.find(i);
     85       if (it == tflite_tensor_sources.end()) {
     86         sources_.push_back({nullptr, 0});
     87       } else {
     88         sources_.push_back(it->second);
     89       }
     90     }
     91   }
     92 
     93   void SetForwardable(int i, bool v) { forwardable_[i] = v; }
     94 
     95   bool IsForwardable(int i) const { return forwardable_[i]; }
     96 
     97   TensorSource GetTensorSource(int i) const { return sources_[i]; }
     98 
     99  private:
    100   std::vector<int> inputs_;
    101   std::vector<TensorSource> sources_;
    102 
    103   // List of tensors that can be used by TF in its forwarding optimization.
    104   // Doing so allows an input tensor to be modified and used as the output
    105   // tensor. The delegate takes care of not holding any references to tensors
    106   // in this list while Eager is executing the corresponding op.
    107   std::vector<int> forwardable_;
    108 };
    109 
    110 // A list of outputs of a given node of the TensorFlow/Eager graph, along with
    111 // the actual outputs of the EagerOperation.
    112 class OpOutputs {
    113  public:
    114   explicit OpOutputs(const TfLiteIntArray* indexes) {
    115     for (int index : TfLiteIntArrayView(indexes)) {
    116       outputs_.push_back(index);
    117     }
    118     vector_.resize(outputs_.size());
    119   }
    120   ~OpOutputs() { ResetTensorHandles(); }
    121 
    122   // Stores information about which of the tensors in this class are also
    123   // outputs of the sugbraph.
    124   void InitializeGraphOutputs(const std::set<int>& subgraph_outputs) {
    125     subgraph_outputs_.clear();
    126     for (int i : outputs_) {
    127       subgraph_outputs_.push_back(subgraph_outputs.count(i) > 0);
    128     }
    129   }
    130 
    131   // Returns true if the tensor given by index 'i' is an output of the entire
    132   // subgraph.
    133   bool IsSubgraphOutput(int i) const { return subgraph_outputs_[i]; }
    134 
    135   // Returns a handle to a given tensor and, optionally, remove it from the
    136   // internal vector.
    137   tensorflow::TensorHandle* GetHandle(int i, bool remove) {
    138     auto* handle = vector_[i];
    139     if (!remove) {
    140       handle->Ref();
    141     } else {
    142       // Don't increase the ref-count. Instead, simply take it out of the
    143       // vector.
    144       vector_[i] = nullptr;
    145     }
    146     return handle;
    147   }
    148 
    149   int Size() const { return outputs_.size(); }
    150 
    151   int TfLiteIndex(int i) const { return outputs_[i]; }
    152 
    153   // Carefully unreference all the handles in the eager output vector.
    154   void ResetTensorHandles() {
    155     for (int i = 0; i < vector_.size(); ++i) {
    156       if (vector_[i]) {
    157         vector_[i]->Unref();
    158         vector_[i] = nullptr;
    159       }
    160     }
    161   }
    162 
    163   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>*
    164   GetTensorHandles() {
    165     return &vector_;
    166   }
    167 
    168  private:
    169   std::vector<int> outputs_;
    170   std::vector<bool> subgraph_outputs_;
    171   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
    172 };
    173 
    174 // A single node within the larger 'op'. Note that this kernel executes many
    175 // TensorFlow ops within a single TF Lite op.
    176 class OpNode {
    177  public:
    178   OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
    179       : inputs_(inputs), outputs_(outputs) {}
    180   ~OpNode() {
    181     if (op_) ClearEagerInputs();
    182   }
    183 
    184   const string& name() const { return name_; }
    185   void set_name(const string& name) { name_ = name; }
    186 
    187   int index() const { return index_; }
    188   void set_index(int index) { index_ = index; }
    189 
    190   const tensorflow::NodeDef& nodedef() const { return nodedef_; }
    191 
    192   const OpInputs& inputs() const { return inputs_; }
    193   OpInputs* mutable_inputs() { return &inputs_; }
    194 
    195   const OpOutputs& outputs() const { return outputs_; }
    196   OpOutputs* mutable_outputs() { return &outputs_; }
    197 
    198   int NumInputs() const { return inputs_.Size(); }
    199   int NumOutputs() const { return outputs_.Size(); }
    200 
    201   tensorflow::EagerOperation* op() { return op_.get(); }
    202 
    203   tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
    204                                        int custom_initial_data_size) {
    205     if (!custom_initial_data) {
    206       return tensorflow::errors::Internal(
    207           "Cannot convert empty data into a valid NodeDef");
    208     }
    209     // The flexbuffer contains a vector where the first elements is the
    210     // op name and the second is a serialized NodeDef.
    211     const flexbuffers::Vector& v =
    212         flexbuffers::GetRoot(
    213             reinterpret_cast<const uint8_t*>(custom_initial_data),
    214             custom_initial_data_size)
    215             .AsVector();
    216 
    217     name_ = v[0].AsString().str();
    218     if (!nodedef_.ParseFromString(v[1].AsString().str())) {
    219       nodedef_.Clear();
    220       return tensorflow::errors::Internal(
    221           "Failed to parse data into a valid NodeDef");
    222     }
    223 
    224     // Fill NodeDef with defaults if it's a valid op.
    225     const tensorflow::OpRegistrationData* op_reg_data;
    226     TF_RETURN_IF_ERROR(
    227         tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data));
    228     AddDefaultsToNodeDef(op_reg_data->op_def, &nodedef_);
    229 
    230     return tensorflow::Status::OK();
    231   }
    232 
    233   // Build thew new EagerOperation. In case of error, the returned 'op' is
    234   // guaranteed to be 'nullptr'.
    235   tensorflow::Status BuildEagerOp(tensorflow::EagerContext* eager_context) {
    236     op_.reset();
    237 
    238     const tensorflow::AttrTypeMap* attr_types;
    239     bool is_function = false;
    240     TF_RETURN_WITH_CONTEXT_IF_ERROR(
    241         tensorflow::AttrTypeMapForOp(name_.c_str(), &attr_types, &is_function),
    242         " (while processing attributes of '", name_, "')");
    243     if (is_function) {
    244       return tensorflow::errors::NotFound(
    245           "Operation '", name_,
    246           "' is not registered.  (while processing attributes of '", name_,
    247           "')");
    248     }
    249 
    250     op_.reset(new tensorflow::EagerOperation(eager_context, name_.c_str(),
    251                                              /*is_function=*/false,
    252                                              attr_types));
    253 
    254     op_->MutableAttrs()->NumInputs(inputs_.Size());
    255     for (const auto& attr : nodedef_.attr()) {
    256       op_->MutableAttrs()->Set(attr.first, attr.second);
    257     }
    258 
    259     // Precalculating a cache key saves about 10% of inference time for very
    260     // small models.
    261     tensorflow::Device* device = op_->Device();
    262     op_->MutableAttrs()->CacheKey(device == nullptr ? "unspecified"
    263                                                     : device->name());
    264 
    265     return tensorflow::Status::OK();
    266   }
    267 
    268   void ClearEagerInputs() {
    269     for (tensorflow::TensorHandle* h : *op_->MutableInputs()) {
    270       if (h) h->Unref();
    271     }
    272     op_->MutableInputs()->clear();
    273   }
    274 
    275   tensorflow::Status BuildEagerInputs(const BufferMap* buffer_map) {
    276     for (int i = 0; i < inputs_.Size(); ++i) {
    277       int input_index = inputs_.TfLiteIndex(i);
    278       TensorSource s = inputs_.GetTensorSource(i);
    279       if (!s.node) {
    280         // This input is not produced by this Eager subgraph (it could be a TF
    281         // Lite native buffer, or could be produced by a separater subgraph). We
    282         // need to fetch it from the delegate's buffer_map.
    283         if (!buffer_map->HasTensor(input_index)) {
    284           return tensorflow::errors::Internal(
    285               "Cannot read from invalid tensor index ", input_index);
    286         }
    287         auto* handle = new tensorflow::TensorHandle(
    288             buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr);
    289         op_->MutableInputs()->push_back(handle);
    290       } else {
    291         // If this is a forwardable tensor, we will remove it from the previous
    292         // op's list, giving TF the opportunity to reuse its buffer.
    293         bool unref_handle = inputs_.IsForwardable(i);
    294         auto* handle =
    295             s.node->outputs_.GetHandle(s.node_output_index, unref_handle);
    296         op_->MutableInputs()->push_back(handle);
    297       }
    298     }
    299     return tensorflow::Status::OK();
    300   }
    301 
    302   tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map) {
    303     auto* handles = outputs_.GetTensorHandles();
    304     for (int i = 0; i < outputs_.Size(); ++i) {
    305       if (outputs_.IsSubgraphOutput(i)) {
    306         const tensorflow::Tensor* tensor = nullptr;
    307         TF_RETURN_IF_ERROR(handles->at(i)->Tensor(&tensor));
    308         buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i), *tensor);
    309       }
    310     }
    311     return tensorflow::Status::OK();
    312   }
    313 
    314  private:
    315   OpNode(const OpNode&) = delete;
    316   OpNode& operator=(const OpNode&) = delete;
    317 
    318   // The name of the TensorFlow op to execute.
    319   string name_;
    320   // Index of this node into TF Lite's operator list.
    321   int index_;
    322   // The corresponding NodeDef, containing the attributes for the op.
    323   tensorflow::NodeDef nodedef_;
    324   // List of inputs, as TF Lite tensor indices.
    325   OpInputs inputs_;
    326   // List of outputs, as TF Lite tensor indices.
    327   OpOutputs outputs_;
    328 
    329   std::unique_ptr<tensorflow::EagerOperation> op_;
    330 };
    331 
    332 // Executes the TensorFlow op given by 'op_name', with the attributes specified
    333 // in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'.
    334 tensorflow::Status ExecuteFlexOp(TfLiteContext* context, BufferMap* buffer_map,
    335                                  OpNode* node_data) {
    336   TF_RETURN_WITH_CONTEXT_IF_ERROR(node_data->BuildEagerInputs(buffer_map),
    337                                   " (while executing '", node_data->name(),
    338                                   "' via Eager)");
    339 
    340   node_data->mutable_outputs()->ResetTensorHandles();
    341   int num_retvals = node_data->NumOutputs();
    342   TF_RETURN_WITH_CONTEXT_IF_ERROR(
    343       EagerExecute(node_data->op(),
    344                    node_data->mutable_outputs()->GetTensorHandles(),
    345                    &num_retvals),
    346       " (while executing '", node_data->name(), "' via Eager)");
    347 
    348   if (num_retvals != node_data->NumOutputs()) {
    349     return tensorflow::errors::Internal(
    350         "Unexpected number of outputs from EagerExecute");
    351   }
    352 
    353   TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map));
    354 
    355   node_data->ClearEagerInputs();
    356 
    357   return tensorflow::Status::OK();
    358 }
    359 
    360 // The larger 'op', which contains all the nodes in a supported subgraph.
    361 struct OpData {
    362   tensorflow::EagerContext* eager_context;
    363   BufferMap* buffer_map;
    364   std::vector<std::unique_ptr<OpNode>> nodes;
    365   std::vector<int> subgraph_inputs;
    366   std::vector<int> subgraph_outputs;
    367 };
    368 
    369 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
    370   auto* op_data = new OpData;
    371 
    372   const TfLiteDelegateParams* params =
    373       reinterpret_cast<const TfLiteDelegateParams*>(buffer);
    374   CHECK(params);
    375   CHECK(params->delegate);
    376   CHECK(params->delegate->data_);
    377   op_data->eager_context =
    378       reinterpret_cast<DelegateData*>(params->delegate->data_)
    379           ->GetEagerContext();
    380   op_data->buffer_map = reinterpret_cast<DelegateData*>(params->delegate->data_)
    381                             ->GetBufferMap(context);
    382 
    383   CHECK(params->output_tensors);
    384   std::set<int> output_set;
    385   for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
    386     op_data->subgraph_outputs.push_back(tensor_index);
    387     output_set.insert(tensor_index);
    388   }
    389 
    390   CHECK(params->input_tensors);
    391   for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
    392     op_data->subgraph_inputs.push_back(tensor_index);
    393   }
    394 
    395   op_data->nodes.reserve(params->nodes_to_replace->size);
    396 
    397   CHECK(params->nodes_to_replace);
    398   tensorflow::Status status;
    399   for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
    400     TfLiteNode* node;
    401     TfLiteRegistration* reg;
    402     context->GetNodeAndRegistration(context, node_index, &node, &reg);
    403 
    404     op_data->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
    405     OpNode& node_data = *op_data->nodes.back();
    406 
    407     node_data.set_index(node_index);
    408     node_data.set_name("");
    409 
    410     status = node_data.InitializeNodeDef(node->custom_initial_data,
    411                                          node->custom_initial_data_size);
    412     if (!status.ok()) break;
    413     status = node_data.BuildEagerOp(op_data->eager_context);
    414     if (!status.ok()) break;
    415   }
    416 
    417   if (ConvertStatus(context, status) != kTfLiteOk) {
    418     // We can't return an error from this function but ConvertStatus will
    419     // report them and we will stop processing in Prepare() if anything went
    420     // wrong.
    421     return op_data;
    422   }
    423 
    424   // Given a TfLite tensor index, return the OpNode that produces it,
    425   // along with it index into that OpNodes list of outputs.
    426   std::map<int, TensorSource> tflite_tensor_sources;
    427 
    428   // Find out how each tensor is produced. This does not account for
    429   // tensors that are not produce by eager ops.
    430   for (auto& node_data : op_data->nodes) {
    431     node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
    432     for (int i = 0; i < node_data->outputs().Size(); ++i) {
    433       int output_index = node_data->outputs().TfLiteIndex(i);
    434       tflite_tensor_sources[output_index] = TensorSource{node_data.get(), i};
    435     }
    436   }
    437 
    438   // For each node, resolve the inputs, so we can keep pointers to the nodes
    439   // that produces them.
    440   for (auto& node_data : op_data->nodes) {
    441     node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
    442   }
    443 
    444   return op_data;
    445 }
    446 
    447 void Free(TfLiteContext* context, void* buffer) {
    448   delete reinterpret_cast<OpData*>(buffer);
    449 }
    450 
    451 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
    452   const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
    453   TF_LITE_ENSURE_MSG(
    454       context, op_data->eager_context != nullptr,
    455       "Failed to initialize eager context. This often happens when a CPU "
    456       "device has not been registered, presumably because some symbols from "
    457       "tensorflow/core:core_cpu_impl were not linked into the binary.");
    458 
    459   // We will keep track of the number of references to each tensor in the
    460   // graph, so we can make them "forwardable" if there is only one reference.
    461   std::map<int, int> tensor_ref_count;
    462 
    463   // Whenever we find a constant tensor, insert it in the buffer map.
    464   BufferMap* buffer_map = op_data->buffer_map;
    465   for (auto tensor_index : op_data->subgraph_inputs) {
    466     TfLiteTensor* tensor = &context->tensors[tensor_index];
    467     if (IsConstantTensor(tensor)) {
    468       if (!buffer_map->HasTensor(tensor_index)) {
    469         buffer_map->SetFromTfLite(tensor_index, tensor);
    470       }
    471     }
    472 
    473     // Input tensors should never be forwarded so we increment their ref counts
    474     // twice: once for this graph and another for the possibility of them being
    475     // used by another subgraph, or being an output of the full graph.
    476     tensor_ref_count[tensor_index] += 2;
    477   }
    478 
    479   // All output tensors are allocated by TensorFlow/Eager, so we
    480   // mark them as kTfLiteDynamic.
    481   for (auto tensor_index : op_data->subgraph_outputs) {
    482     SetTensorToDynamic(&context->tensors[tensor_index]);
    483     ++tensor_ref_count[tensor_index];
    484   }
    485 
    486   for (const auto& node_data : op_data->nodes) {
    487     if (node_data->nodedef().op().empty()) {
    488       context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
    489                            node_data->name().c_str());
    490       return kTfLiteError;
    491     }
    492     TF_LITE_ENSURE(context, node_data->op());
    493 
    494     for (int i = 0; i < node_data->inputs().Size(); ++i) {
    495       ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
    496     }
    497   }
    498 
    499   // All tensors that are referenced exactly once are marked as "forwardable",
    500   // meaning that we will allow TensorFlow to reuse its buffer as the output of
    501   // an op.
    502   for (auto& node_data : op_data->nodes) {
    503     for (int i = 0; i < node_data->inputs().Size(); ++i) {
    504       bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
    505       node_data->mutable_inputs()->SetForwardable(i, f);
    506     }
    507   }
    508 
    509   return kTfLiteOk;
    510 }
    511 
    512 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
    513   auto* op_data = reinterpret_cast<OpData*>(node->user_data);
    514   BufferMap* buffer_map = op_data->buffer_map;
    515 
    516   // Insert a tensor in the buffer map for all inputs that are not constant.
    517   // Constants were handled in Prepare() already.
    518   for (auto tensor_index : op_data->subgraph_inputs) {
    519     TfLiteTensor* tensor = &context->tensors[tensor_index];
    520     if (!IsConstantTensor(tensor)) {
    521       // If this tensor is part of an earlier TF subgraph we should not add it
    522       // to the BufferMap again, because TF already knows about it and its
    523       // contents are kept automatically up-to-date.
    524       if (!buffer_map->IsTensorFlowTensor(tensor_index)) {
    525         buffer_map->SetFromTfLite(tensor_index, tensor);
    526       }
    527     }
    528   }
    529 
    530   // Execute the TensorFlow Ops sequentially.
    531   for (auto& node_data : op_data->nodes) {
    532     SCOPED_TAGGED_OPERATOR_PROFILE(
    533         reinterpret_cast<profiling::Profiler*>(context->profiler),
    534         node_data->name().c_str(), node_data->index());
    535 
    536     auto status = ExecuteFlexOp(context, buffer_map, node_data.get());
    537     TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
    538   }
    539 
    540   for (auto tensor_index : op_data->subgraph_outputs) {
    541     if (!buffer_map->HasTensor(tensor_index)) {
    542       context->ReportError(context, "Cannot write to invalid tensor index %d",
    543                            tensor_index);
    544       return kTfLiteError;
    545     }
    546 
    547     TfLiteTensor* tensor = &context->tensors[tensor_index];
    548     TF_LITE_ENSURE_OK(
    549         context,
    550         CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor));
    551     tensor->buffer_handle = tensor_index;
    552     tensor->data_is_stale = true;
    553   }
    554 
    555   return kTfLiteOk;
    556 }
    557 
    558 }  // namespace kernel
    559 
    560 TfLiteRegistration GetKernel() {
    561   TfLiteRegistration registration{&kernel::Init,    &kernel::Free,
    562                                   &kernel::Prepare, &kernel::Eval,
    563                                   nullptr,          kTfLiteBuiltinDelegate};
    564   return registration;
    565 }
    566 
    567 }  // namespace flex
    568 }  // namespace tflite
    569