1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/lite/delegates/flex/kernel.h" 16 17 #include "flatbuffers/flexbuffers.h" // TF:flatbuffers 18 #include "tensorflow/core/common_runtime/eager/context.h" 19 #include "tensorflow/core/common_runtime/eager/execute.h" 20 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 21 #include "tensorflow/core/framework/node_def.pb.h" 22 #include "tensorflow/core/framework/node_def_util.h" 23 #include "tensorflow/core/lib/core/errors.h" 24 #include "tensorflow/lite/builtin_ops.h" 25 #include "tensorflow/lite/c/c_api_internal.h" 26 #include "tensorflow/lite/context_util.h" 27 #include "tensorflow/lite/delegates/flex/delegate_data.h" 28 #include "tensorflow/lite/delegates/flex/util.h" 29 #include "tensorflow/lite/kernels/kernel_util.h" 30 #include "tensorflow/lite/profiling/profiler.h" 31 #include "tensorflow/lite/string.h" 32 33 // Note: this is part of TF Lite's Flex delegation code which is to be 34 // completed soon. 35 36 // This is the TF Lite op that is created by the flex delegate to handle 37 // execution of a supported subgraph. The usual flow is that the delegate 38 // informs the interpreter of supported nodes in a graph, and each supported 39 // subgraph is replaced with one instance of this kernel. 40 // 41 // The kernel is initialized with TfLiteDelegateParams from which we retrieve 42 // the global EagerContext and BufferMap, as well as a list of inputs and 43 // outputs to the subgraph. Those are used to build the OpData, with a list of 44 // TensorFlow Ops that should be executed in order (which we call an OpNode). 45 // 46 // For each node included in the subgraph, we query the interpreter and 47 // retrieve the associated NodeDef, which is then used to configure the 48 // corresponding TensorFlow/Eager Op. 49 50 namespace tflite { 51 namespace flex { 52 namespace kernel { 53 54 struct OpNode; 55 56 // Represents the origin of a given tensor as a reference to the output 57 // of an upstream node. 58 struct TensorSource { 59 OpNode* node; 60 int node_output_index; 61 }; 62 63 // A list of inputs of a given node of the TensorFlow/Eager graph. 64 class OpInputs { 65 public: 66 explicit OpInputs(const TfLiteIntArray* indexes) { 67 for (int index : TfLiteIntArrayView(indexes)) { 68 inputs_.push_back(index); 69 } 70 forwardable_.resize(inputs_.size()); 71 } 72 ~OpInputs() {} 73 74 int Size() const { return inputs_.size(); } 75 76 int TfLiteIndex(int i) const { return inputs_[i]; } 77 78 // Given a map relating tensors to the node that originates them, populate a 79 // list of sources for the tensors in this class. 80 void InitializeTensorSources( 81 const std::map<int, TensorSource>& tflite_tensor_sources) { 82 sources_.clear(); 83 for (int i : inputs_) { 84 auto it = tflite_tensor_sources.find(i); 85 if (it == tflite_tensor_sources.end()) { 86 sources_.push_back({nullptr, 0}); 87 } else { 88 sources_.push_back(it->second); 89 } 90 } 91 } 92 93 void SetForwardable(int i, bool v) { forwardable_[i] = v; } 94 95 bool IsForwardable(int i) const { return forwardable_[i]; } 96 97 TensorSource GetTensorSource(int i) const { return sources_[i]; } 98 99 private: 100 std::vector<int> inputs_; 101 std::vector<TensorSource> sources_; 102 103 // List of tensors that can be used by TF in its forwarding optimization. 104 // Doing so allows an input tensor to be modified and used as the output 105 // tensor. The delegate takes care of not holding any references to tensors 106 // in this list while Eager is executing the corresponding op. 107 std::vector<int> forwardable_; 108 }; 109 110 // A list of outputs of a given node of the TensorFlow/Eager graph, along with 111 // the actual outputs of the EagerOperation. 112 class OpOutputs { 113 public: 114 explicit OpOutputs(const TfLiteIntArray* indexes) { 115 for (int index : TfLiteIntArrayView(indexes)) { 116 outputs_.push_back(index); 117 } 118 vector_.resize(outputs_.size()); 119 } 120 ~OpOutputs() { ResetTensorHandles(); } 121 122 // Stores information about which of the tensors in this class are also 123 // outputs of the sugbraph. 124 void InitializeGraphOutputs(const std::set<int>& subgraph_outputs) { 125 subgraph_outputs_.clear(); 126 for (int i : outputs_) { 127 subgraph_outputs_.push_back(subgraph_outputs.count(i) > 0); 128 } 129 } 130 131 // Returns true if the tensor given by index 'i' is an output of the entire 132 // subgraph. 133 bool IsSubgraphOutput(int i) const { return subgraph_outputs_[i]; } 134 135 // Returns a handle to a given tensor and, optionally, remove it from the 136 // internal vector. 137 tensorflow::TensorHandle* GetHandle(int i, bool remove) { 138 auto* handle = vector_[i]; 139 if (!remove) { 140 handle->Ref(); 141 } else { 142 // Don't increase the ref-count. Instead, simply take it out of the 143 // vector. 144 vector_[i] = nullptr; 145 } 146 return handle; 147 } 148 149 int Size() const { return outputs_.size(); } 150 151 int TfLiteIndex(int i) const { return outputs_[i]; } 152 153 // Carefully unreference all the handles in the eager output vector. 154 void ResetTensorHandles() { 155 for (int i = 0; i < vector_.size(); ++i) { 156 if (vector_[i]) { 157 vector_[i]->Unref(); 158 vector_[i] = nullptr; 159 } 160 } 161 } 162 163 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>* 164 GetTensorHandles() { 165 return &vector_; 166 } 167 168 private: 169 std::vector<int> outputs_; 170 std::vector<bool> subgraph_outputs_; 171 tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_; 172 }; 173 174 // A single node within the larger 'op'. Note that this kernel executes many 175 // TensorFlow ops within a single TF Lite op. 176 class OpNode { 177 public: 178 OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs) 179 : inputs_(inputs), outputs_(outputs) {} 180 ~OpNode() { 181 if (op_) ClearEagerInputs(); 182 } 183 184 const string& name() const { return name_; } 185 void set_name(const string& name) { name_ = name; } 186 187 int index() const { return index_; } 188 void set_index(int index) { index_ = index; } 189 190 const tensorflow::NodeDef& nodedef() const { return nodedef_; } 191 192 const OpInputs& inputs() const { return inputs_; } 193 OpInputs* mutable_inputs() { return &inputs_; } 194 195 const OpOutputs& outputs() const { return outputs_; } 196 OpOutputs* mutable_outputs() { return &outputs_; } 197 198 int NumInputs() const { return inputs_.Size(); } 199 int NumOutputs() const { return outputs_.Size(); } 200 201 tensorflow::EagerOperation* op() { return op_.get(); } 202 203 tensorflow::Status InitializeNodeDef(const void* custom_initial_data, 204 int custom_initial_data_size) { 205 if (!custom_initial_data) { 206 return tensorflow::errors::Internal( 207 "Cannot convert empty data into a valid NodeDef"); 208 } 209 // The flexbuffer contains a vector where the first elements is the 210 // op name and the second is a serialized NodeDef. 211 const flexbuffers::Vector& v = 212 flexbuffers::GetRoot( 213 reinterpret_cast<const uint8_t*>(custom_initial_data), 214 custom_initial_data_size) 215 .AsVector(); 216 217 name_ = v[0].AsString().str(); 218 if (!nodedef_.ParseFromString(v[1].AsString().str())) { 219 nodedef_.Clear(); 220 return tensorflow::errors::Internal( 221 "Failed to parse data into a valid NodeDef"); 222 } 223 224 // Fill NodeDef with defaults if it's a valid op. 225 const tensorflow::OpRegistrationData* op_reg_data; 226 TF_RETURN_IF_ERROR( 227 tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data)); 228 AddDefaultsToNodeDef(op_reg_data->op_def, &nodedef_); 229 230 return tensorflow::Status::OK(); 231 } 232 233 // Build thew new EagerOperation. In case of error, the returned 'op' is 234 // guaranteed to be 'nullptr'. 235 tensorflow::Status BuildEagerOp(tensorflow::EagerContext* eager_context) { 236 op_.reset(); 237 238 const tensorflow::AttrTypeMap* attr_types; 239 bool is_function = false; 240 TF_RETURN_WITH_CONTEXT_IF_ERROR( 241 tensorflow::AttrTypeMapForOp(name_.c_str(), &attr_types, &is_function), 242 " (while processing attributes of '", name_, "')"); 243 if (is_function) { 244 return tensorflow::errors::NotFound( 245 "Operation '", name_, 246 "' is not registered. (while processing attributes of '", name_, 247 "')"); 248 } 249 250 op_.reset(new tensorflow::EagerOperation(eager_context, name_.c_str(), 251 /*is_function=*/false, 252 attr_types)); 253 254 op_->MutableAttrs()->NumInputs(inputs_.Size()); 255 for (const auto& attr : nodedef_.attr()) { 256 op_->MutableAttrs()->Set(attr.first, attr.second); 257 } 258 259 // Precalculating a cache key saves about 10% of inference time for very 260 // small models. 261 tensorflow::Device* device = op_->Device(); 262 op_->MutableAttrs()->CacheKey(device == nullptr ? "unspecified" 263 : device->name()); 264 265 return tensorflow::Status::OK(); 266 } 267 268 void ClearEagerInputs() { 269 for (tensorflow::TensorHandle* h : *op_->MutableInputs()) { 270 if (h) h->Unref(); 271 } 272 op_->MutableInputs()->clear(); 273 } 274 275 tensorflow::Status BuildEagerInputs(const BufferMap* buffer_map) { 276 for (int i = 0; i < inputs_.Size(); ++i) { 277 int input_index = inputs_.TfLiteIndex(i); 278 TensorSource s = inputs_.GetTensorSource(i); 279 if (!s.node) { 280 // This input is not produced by this Eager subgraph (it could be a TF 281 // Lite native buffer, or could be produced by a separater subgraph). We 282 // need to fetch it from the delegate's buffer_map. 283 if (!buffer_map->HasTensor(input_index)) { 284 return tensorflow::errors::Internal( 285 "Cannot read from invalid tensor index ", input_index); 286 } 287 auto* handle = new tensorflow::TensorHandle( 288 buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr); 289 op_->MutableInputs()->push_back(handle); 290 } else { 291 // If this is a forwardable tensor, we will remove it from the previous 292 // op's list, giving TF the opportunity to reuse its buffer. 293 bool unref_handle = inputs_.IsForwardable(i); 294 auto* handle = 295 s.node->outputs_.GetHandle(s.node_output_index, unref_handle); 296 op_->MutableInputs()->push_back(handle); 297 } 298 } 299 return tensorflow::Status::OK(); 300 } 301 302 tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map) { 303 auto* handles = outputs_.GetTensorHandles(); 304 for (int i = 0; i < outputs_.Size(); ++i) { 305 if (outputs_.IsSubgraphOutput(i)) { 306 const tensorflow::Tensor* tensor = nullptr; 307 TF_RETURN_IF_ERROR(handles->at(i)->Tensor(&tensor)); 308 buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i), *tensor); 309 } 310 } 311 return tensorflow::Status::OK(); 312 } 313 314 private: 315 OpNode(const OpNode&) = delete; 316 OpNode& operator=(const OpNode&) = delete; 317 318 // The name of the TensorFlow op to execute. 319 string name_; 320 // Index of this node into TF Lite's operator list. 321 int index_; 322 // The corresponding NodeDef, containing the attributes for the op. 323 tensorflow::NodeDef nodedef_; 324 // List of inputs, as TF Lite tensor indices. 325 OpInputs inputs_; 326 // List of outputs, as TF Lite tensor indices. 327 OpOutputs outputs_; 328 329 std::unique_ptr<tensorflow::EagerOperation> op_; 330 }; 331 332 // Executes the TensorFlow op given by 'op_name', with the attributes specified 333 // in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'. 334 tensorflow::Status ExecuteFlexOp(TfLiteContext* context, BufferMap* buffer_map, 335 OpNode* node_data) { 336 TF_RETURN_WITH_CONTEXT_IF_ERROR(node_data->BuildEagerInputs(buffer_map), 337 " (while executing '", node_data->name(), 338 "' via Eager)"); 339 340 node_data->mutable_outputs()->ResetTensorHandles(); 341 int num_retvals = node_data->NumOutputs(); 342 TF_RETURN_WITH_CONTEXT_IF_ERROR( 343 EagerExecute(node_data->op(), 344 node_data->mutable_outputs()->GetTensorHandles(), 345 &num_retvals), 346 " (while executing '", node_data->name(), "' via Eager)"); 347 348 if (num_retvals != node_data->NumOutputs()) { 349 return tensorflow::errors::Internal( 350 "Unexpected number of outputs from EagerExecute"); 351 } 352 353 TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map)); 354 355 node_data->ClearEagerInputs(); 356 357 return tensorflow::Status::OK(); 358 } 359 360 // The larger 'op', which contains all the nodes in a supported subgraph. 361 struct OpData { 362 tensorflow::EagerContext* eager_context; 363 BufferMap* buffer_map; 364 std::vector<std::unique_ptr<OpNode>> nodes; 365 std::vector<int> subgraph_inputs; 366 std::vector<int> subgraph_outputs; 367 }; 368 369 void* Init(TfLiteContext* context, const char* buffer, size_t length) { 370 auto* op_data = new OpData; 371 372 const TfLiteDelegateParams* params = 373 reinterpret_cast<const TfLiteDelegateParams*>(buffer); 374 CHECK(params); 375 CHECK(params->delegate); 376 CHECK(params->delegate->data_); 377 op_data->eager_context = 378 reinterpret_cast<DelegateData*>(params->delegate->data_) 379 ->GetEagerContext(); 380 op_data->buffer_map = reinterpret_cast<DelegateData*>(params->delegate->data_) 381 ->GetBufferMap(context); 382 383 CHECK(params->output_tensors); 384 std::set<int> output_set; 385 for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) { 386 op_data->subgraph_outputs.push_back(tensor_index); 387 output_set.insert(tensor_index); 388 } 389 390 CHECK(params->input_tensors); 391 for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) { 392 op_data->subgraph_inputs.push_back(tensor_index); 393 } 394 395 op_data->nodes.reserve(params->nodes_to_replace->size); 396 397 CHECK(params->nodes_to_replace); 398 tensorflow::Status status; 399 for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { 400 TfLiteNode* node; 401 TfLiteRegistration* reg; 402 context->GetNodeAndRegistration(context, node_index, &node, ®); 403 404 op_data->nodes.emplace_back(new OpNode(node->inputs, node->outputs)); 405 OpNode& node_data = *op_data->nodes.back(); 406 407 node_data.set_index(node_index); 408 node_data.set_name(""); 409 410 status = node_data.InitializeNodeDef(node->custom_initial_data, 411 node->custom_initial_data_size); 412 if (!status.ok()) break; 413 status = node_data.BuildEagerOp(op_data->eager_context); 414 if (!status.ok()) break; 415 } 416 417 if (ConvertStatus(context, status) != kTfLiteOk) { 418 // We can't return an error from this function but ConvertStatus will 419 // report them and we will stop processing in Prepare() if anything went 420 // wrong. 421 return op_data; 422 } 423 424 // Given a TfLite tensor index, return the OpNode that produces it, 425 // along with it index into that OpNodes list of outputs. 426 std::map<int, TensorSource> tflite_tensor_sources; 427 428 // Find out how each tensor is produced. This does not account for 429 // tensors that are not produce by eager ops. 430 for (auto& node_data : op_data->nodes) { 431 node_data->mutable_outputs()->InitializeGraphOutputs(output_set); 432 for (int i = 0; i < node_data->outputs().Size(); ++i) { 433 int output_index = node_data->outputs().TfLiteIndex(i); 434 tflite_tensor_sources[output_index] = TensorSource{node_data.get(), i}; 435 } 436 } 437 438 // For each node, resolve the inputs, so we can keep pointers to the nodes 439 // that produces them. 440 for (auto& node_data : op_data->nodes) { 441 node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources); 442 } 443 444 return op_data; 445 } 446 447 void Free(TfLiteContext* context, void* buffer) { 448 delete reinterpret_cast<OpData*>(buffer); 449 } 450 451 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 452 const auto* op_data = reinterpret_cast<OpData*>(node->user_data); 453 TF_LITE_ENSURE_MSG( 454 context, op_data->eager_context != nullptr, 455 "Failed to initialize eager context. This often happens when a CPU " 456 "device has not been registered, presumably because some symbols from " 457 "tensorflow/core:core_cpu_impl were not linked into the binary."); 458 459 // We will keep track of the number of references to each tensor in the 460 // graph, so we can make them "forwardable" if there is only one reference. 461 std::map<int, int> tensor_ref_count; 462 463 // Whenever we find a constant tensor, insert it in the buffer map. 464 BufferMap* buffer_map = op_data->buffer_map; 465 for (auto tensor_index : op_data->subgraph_inputs) { 466 TfLiteTensor* tensor = &context->tensors[tensor_index]; 467 if (IsConstantTensor(tensor)) { 468 if (!buffer_map->HasTensor(tensor_index)) { 469 buffer_map->SetFromTfLite(tensor_index, tensor); 470 } 471 } 472 473 // Input tensors should never be forwarded so we increment their ref counts 474 // twice: once for this graph and another for the possibility of them being 475 // used by another subgraph, or being an output of the full graph. 476 tensor_ref_count[tensor_index] += 2; 477 } 478 479 // All output tensors are allocated by TensorFlow/Eager, so we 480 // mark them as kTfLiteDynamic. 481 for (auto tensor_index : op_data->subgraph_outputs) { 482 SetTensorToDynamic(&context->tensors[tensor_index]); 483 ++tensor_ref_count[tensor_index]; 484 } 485 486 for (const auto& node_data : op_data->nodes) { 487 if (node_data->nodedef().op().empty()) { 488 context->ReportError(context, "Invalid NodeDef in Flex op '%s'", 489 node_data->name().c_str()); 490 return kTfLiteError; 491 } 492 TF_LITE_ENSURE(context, node_data->op()); 493 494 for (int i = 0; i < node_data->inputs().Size(); ++i) { 495 ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)]; 496 } 497 } 498 499 // All tensors that are referenced exactly once are marked as "forwardable", 500 // meaning that we will allow TensorFlow to reuse its buffer as the output of 501 // an op. 502 for (auto& node_data : op_data->nodes) { 503 for (int i = 0; i < node_data->inputs().Size(); ++i) { 504 bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1); 505 node_data->mutable_inputs()->SetForwardable(i, f); 506 } 507 } 508 509 return kTfLiteOk; 510 } 511 512 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 513 auto* op_data = reinterpret_cast<OpData*>(node->user_data); 514 BufferMap* buffer_map = op_data->buffer_map; 515 516 // Insert a tensor in the buffer map for all inputs that are not constant. 517 // Constants were handled in Prepare() already. 518 for (auto tensor_index : op_data->subgraph_inputs) { 519 TfLiteTensor* tensor = &context->tensors[tensor_index]; 520 if (!IsConstantTensor(tensor)) { 521 // If this tensor is part of an earlier TF subgraph we should not add it 522 // to the BufferMap again, because TF already knows about it and its 523 // contents are kept automatically up-to-date. 524 if (!buffer_map->IsTensorFlowTensor(tensor_index)) { 525 buffer_map->SetFromTfLite(tensor_index, tensor); 526 } 527 } 528 } 529 530 // Execute the TensorFlow Ops sequentially. 531 for (auto& node_data : op_data->nodes) { 532 SCOPED_TAGGED_OPERATOR_PROFILE( 533 reinterpret_cast<profiling::Profiler*>(context->profiler), 534 node_data->name().c_str(), node_data->index()); 535 536 auto status = ExecuteFlexOp(context, buffer_map, node_data.get()); 537 TF_LITE_ENSURE_OK(context, ConvertStatus(context, status)); 538 } 539 540 for (auto tensor_index : op_data->subgraph_outputs) { 541 if (!buffer_map->HasTensor(tensor_index)) { 542 context->ReportError(context, "Cannot write to invalid tensor index %d", 543 tensor_index); 544 return kTfLiteError; 545 } 546 547 TfLiteTensor* tensor = &context->tensors[tensor_index]; 548 TF_LITE_ENSURE_OK( 549 context, 550 CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor)); 551 tensor->buffer_handle = tensor_index; 552 tensor->data_is_stale = true; 553 } 554 555 return kTfLiteOk; 556 } 557 558 } // namespace kernel 559 560 TfLiteRegistration GetKernel() { 561 TfLiteRegistration registration{&kernel::Init, &kernel::Free, 562 &kernel::Prepare, &kernel::Eval, 563 nullptr, kTfLiteBuiltinDelegate}; 564 return registration; 565 } 566 567 } // namespace flex 568 } // namespace tflite 569