      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/common_runtime/executor.h"
     17 
     18 #include <atomic>
     19 #include <deque>
     20 #include <memory>
     21 #include <string>
     22 #include <unordered_map>
     23 #include <vector>
     24 
     25 #include "tensorflow/core/common_runtime/costmodel_manager.h"
     26 #include "tensorflow/core/common_runtime/executor_factory.h"
     27 #include "tensorflow/core/common_runtime/pending_counts.h"
     28 #include "tensorflow/core/common_runtime/step_stats_collector.h"
     29 #include "tensorflow/core/framework/allocation_description.pb.h"
     30 #include "tensorflow/core/framework/allocator.h"
     31 #include "tensorflow/core/framework/cancellation.h"
     32 #include "tensorflow/core/framework/collective.h"
     33 #include "tensorflow/core/framework/control_flow.h"
     34 #include "tensorflow/core/framework/device_attributes.pb.h"
     35 #include "tensorflow/core/framework/graph.pb.h"
     36 #include "tensorflow/core/framework/log_memory.h"
     37 #include "tensorflow/core/framework/node_def_util.h"
     38 #include "tensorflow/core/framework/op_kernel.h"
     39 #include "tensorflow/core/framework/op_segment.h"
     40 #include "tensorflow/core/framework/step_stats.pb.h"
     41 #include "tensorflow/core/framework/tensor.h"
     42 #include "tensorflow/core/framework/tensor_reference.h"
     43 #include "tensorflow/core/framework/types.h"
     44 #include "tensorflow/core/framework/types.pb.h"
     45 #include "tensorflow/core/graph/edgeset.h"
     46 #include "tensorflow/core/lib/core/errors.h"
     47 #include "tensorflow/core/lib/core/notification.h"
     48 #include "tensorflow/core/lib/core/stringpiece.h"
     49 #include "tensorflow/core/lib/core/threadpool.h"
     50 #include "tensorflow/core/lib/gtl/flatmap.h"
     51 #include "tensorflow/core/lib/gtl/flatset.h"
     52 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     53 #include "tensorflow/core/lib/gtl/manual_constructor.h"
     54 #include "tensorflow/core/lib/gtl/stl_util.h"
     55 #include "tensorflow/core/lib/hash/hash.h"
     56 #include "tensorflow/core/lib/strings/str_util.h"
     57 #include "tensorflow/core/lib/strings/stringprintf.h"
     58 #include "tensorflow/core/platform/context.h"
     59 #include "tensorflow/core/platform/env.h"
     60 #include "tensorflow/core/platform/logging.h"
     61 #include "tensorflow/core/platform/macros.h"
     62 #include "tensorflow/core/platform/mutex.h"
     63 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
     64 #include "tensorflow/core/platform/thread_annotations.h"
     65 #include "tensorflow/core/platform/tracing.h"
     66 #include "tensorflow/core/platform/types.h"
     67 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
     68 
     69 namespace tensorflow {
     70 namespace {
     71 
     72 // 1-D, 0 element tensor.
     73 static const Tensor* const kEmptyTensor = new Tensor;
     74 
     75 bool IsInitializationOp(const Node* node) {
     76   return node->op_def().allows_uninitialized_input();
     77 }
     78 
     79 // Helper routines for collecting step stats.
     80 namespace nodestats {
     81 inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
     82 
     83 void SetScheduled(NodeExecStatsInterface* stats, int64 micros) {
     84   if (!stats) return;
     85   stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
     86 }
     87 
     88 void SetAllStart(NodeExecStatsInterface* stats) {
     89   if (!stats) return;
     90   stats->RecordExecutorStarted();
     91 }
     92 
     93 void SetOpStart(NodeExecStatsInterface* stats) {
     94   if (!stats) return;
     95   stats->RecordComputeStarted();
     96 }
     97 
     98 void SetOpEnd(NodeExecStatsInterface* stats) {
     99   if (!stats) return;
    100   stats->RecordComputeEnded();
    101 }
    102 
    103 void SetAllEnd(NodeExecStatsInterface* stats) {
    104   if (!stats) return;
    105   stats->RecordExecutorEnded();
    106 }
    107 
    108 void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
    109   if (!stats) return;
    110   stats->SetOutput(slot, v);
    111 }
    112 
    113 void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
    114   if (!stats) return;
    115   stats->SetMemory(ctx);
    116 }
    117 
    118 void SetReferencedTensors(NodeExecStatsInterface* stats,
    119                           const TensorReferenceVector& tensors) {
    120   if (!stats) return;
    121   stats->SetReferencedTensors(tensors);
    122 }
    123 
    124 }  // namespace nodestats
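
        // For orientation only: a sketch (not taken verbatim from the executor code
        // below) of the order in which these helpers are typically invoked around a
        // single node's execution. Every helper is a no-op when `stats` is nullptr,
        // so callers never need to guard the calls themselves.
        //
        //   nodestats::SetScheduled(stats, scheduled_usec);  // hypothetical local
        //   nodestats::SetAllStart(stats);
        //   nodestats::SetOpStart(stats);
        //   // ... run the kernel ...
        //   nodestats::SetOpEnd(stats);
        //   nodestats::SetOutput(stats, slot, output_tensor);
        //   nodestats::SetAllEnd(stats);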
    125 
    126 class ExecutorImpl;
    127 class GraphView;
    128 
    129 struct EdgeInfo {
    130   int dst_id;
    131   int output_slot : 31;
    132   // true if this is the last info for output_slot in the EdgeInfo list.
    133   bool is_last : 1;
    134   int input_slot;
    135 };
    136 
    137 // Time the execution of kernels (in CPU cycles).  Used to dynamically identify
    138 // inexpensive kernels which can be dispatched inline.
    139 struct KernelTimer {
    140   uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();
    141 
    142   uint64 ElapsedCycles() {
    143     return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles;
    144   }
    145 };
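
        // Minimal usage sketch (illustrative; `device`, `op_kernel` and `ctx` are
        // hypothetical locals): wrap a synchronous kernel launch and read back the
        // elapsed cycle count to decide whether this kernel is cheap enough to be
        // dispatched inline next time.
        //
        //   KernelTimer timer;
        //   device->Compute(op_kernel, &ctx);
        //   const uint64 elapsed = timer.ElapsedCycles();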
    146 
    147 struct NodeItem {
    148   NodeItem() {}
    149 
    150   // A graph node.
    151   const Node* node = nullptr;
    152 
    153   // The kernel for this node.
    154   OpKernel* kernel = nullptr;
    155 
    156   bool kernel_is_async : 1;      // True iff kernel->AsAsync() != nullptr
    157   bool is_merge : 1;             // True iff IsMerge(node)
    158   bool is_enter : 1;             // True iff IsEnter(node)
    159   bool is_constant_enter : 1;    // True iff IsEnter(node) and
    160                                  // node->GetAttr("is_constant") == true.
    161   bool is_exit : 1;              // True iff IsExit(node)
    162   bool is_control_trigger : 1;   // True iff IsControlTrigger(node)
    163   bool is_sink : 1;              // True iff IsSink(node)
    164   // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
    165   bool is_enter_exit_or_next_iter : 1;
    166 
    167   // Cached values of node->num_inputs() and node->num_outputs(), to
    168   // avoid levels of indirection.
    169   int num_inputs;
    170   int num_outputs;
    171 
    172   // ExecutorImpl::tensors_[input_start] is the 1st positional input
    173   // for this node.
    174   int input_start = 0;
    175 
    176   // Number of output edges.
    177   size_t num_output_edges;
    178 
    179   PendingCounts::Handle pending_id;
    180 
    181   const EdgeInfo* output_edge_list() const { return output_edge_base(); }
    182 
    183   // ith output edge.
    184   const EdgeInfo& output_edge(int i) const {
    185     DCHECK_GE(i, 0);
    186     DCHECK_LT(i, num_output_edges);
    187     return output_edge_base()[i];
    188   }
    189 
    190   DataType input_type(int i) const {
    191     DCHECK_LT(i, num_inputs);
    192     return static_cast<DataType>(input_type_base()[i]);
    193   }
    194   DataType output_type(int i) const {
    195     DCHECK_LT(i, num_outputs);
    196     return static_cast<DataType>(output_type_base()[i]);
    197   }
    198 
    199   // Return array of per-output allocator attributes.
    200   const AllocatorAttributes* output_attrs() const { return output_attr_base(); }
    201 
    202   // Return array of expected input indices from which each output should
    203   // be forwarded:
    204   // kNeverForward (-2) for DO NOT FORWARD (must allocate).
    205   // kNoReservation (-1) for no expected forwarding.
    206   // 0... for forward from that input.
    207   const int* forward_from() const { return forward_from_base(); }
    208 
    209  private:
    210   friend class GraphView;
    211 
    212   // Variable length section starts immediately after *this
    213   // (uint8 is enough for DataType).
    214   //   EdgeInfo            out_edges[num_out_edges];
    215   //   AllocatorAttributes output_attr[num_outputs];
    216   //   int                 forward_from[num_outputs];
    217   //   uint8               input_type[num_inputs];
    218   //   uint8               output_type[num_outputs];
    219 
    220   // Return pointer to variable length section.
    221   char* var() const {
    222     return const_cast<char*>(reinterpret_cast<const char*>(this) +
    223                              sizeof(NodeItem));
    224   }
    225 
    226   EdgeInfo* output_edge_base() const {
    227     return reinterpret_cast<EdgeInfo*>(var());
    228   }
    229   AllocatorAttributes* output_attr_base() const {
    230     return reinterpret_cast<AllocatorAttributes*>(var() + sizeof(EdgeInfo) *
    231                                                               num_output_edges);
    232   }
    233   int* forward_from_base() const {
    234     return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
    235                                   sizeof(AllocatorAttributes) * num_outputs);
    236   }
    237   uint8* input_type_base() const {
    238     return reinterpret_cast<uint8*>(
    239         var() + sizeof(EdgeInfo) * num_output_edges +
    240         sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs);
    241   }
    242   uint8* output_type_base() const {
    243     return reinterpret_cast<uint8*>(
    244         var() + sizeof(EdgeInfo) * num_output_edges +
    245         sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs +
    246         sizeof(uint8) * num_inputs);
    247   }
    248 
    249   TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
    250 };
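
        // To make the variable-length layout concrete, consider a hypothetical node
        // with num_output_edges == 2, num_outputs == 1 and num_inputs == 1, and
        // assume sizeof(EdgeInfo) == 12, sizeof(int) == 4 and
        // sizeof(AllocatorAttributes) == A on the target ABI. The trailing sections
        // then start at these offsets from var():
        //
        //   output_edge_base()   var() + 0              // 2 * sizeof(EdgeInfo)
        //   output_attr_base()   var() + 24             // 1 * A
        //   forward_from_base()  var() + 24 + A         // 1 * sizeof(int)
        //   input_type_base()    var() + 24 + A + 4     // 1 * sizeof(uint8)
        //   output_type_base()   var() + 24 + A + 4 + 1 // 1 * sizeof(uint8)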
    251 
    252 typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
    253 typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
    254 typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
    255 
    256 // Immutable view of a Graph organized for efficient execution.
    257 class GraphView {
    258  public:
    259   GraphView() : space_(nullptr) {}
    260   ~GraphView();
    261 
    262   void Initialize(const Graph* g);
    263   Status SetAllocAttrs(const Graph* g, const Device* device);
    264   void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
    265 
    266   NodeItem* node(size_t id) const {
    267     DCHECK_GE(id, 0);
    268     DCHECK_LT(id, num_nodes_);
    269     uint32 offset = node_offsets_[id];
    270     return ((offset == kuint32max)
    271                 ? nullptr
    272                 : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
    273   }
    274 
    275  private:
    276   char* InitializeNode(char* ptr, const Node* n);
    277   size_t NodeItemBytes(const Node* n);
    278 
    279   int32 num_nodes_ = 0;
    280   uint32* node_offsets_ = nullptr;  // array of size "graph_.num_node_ids()"
    281   // node_offsets_[id] holds the byte offset for node w/ "id" in space_
    282 
    283   char* space_;  // NodeItem objects are allocated here
    284 
    285   TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
    286 };
    287 
    288 class ExecutorImpl : public Executor {
    289  public:
    290   ExecutorImpl(const LocalExecutorParams& p, std::unique_ptr<const Graph> g)
    291       : params_(p), graph_(std::move(g)), gview_() {
    292     CHECK(p.create_kernel != nullptr);
    293     CHECK(p.delete_kernel != nullptr);
    294   }
    295 
    296   ~ExecutorImpl() override {
    297     for (int i = 0; i < graph_->num_node_ids(); i++) {
    298       NodeItem* item = gview_.node(i);
    299       if (item != nullptr) {
    300         params_.delete_kernel(item->kernel);
    301       }
    302     }
    303     for (auto fiter : frame_info_) {
    304       delete fiter.second;
    305     }
    306   }
    307 
    308   Status Initialize();
    309 
    310   // Process all Nodes in the current graph, attempting to infer the
    311   // memory allocation attributes to be used wherever they may allocate
    312   // a tensor buffer.
    313   Status SetAllocAttrs();
    314 
    315   void RunAsync(const Args& args, DoneCallback done) override;
    316 
    317  private:
    318   friend class ExecutorState;
    319 
    320   struct ControlFlowInfo {
    321     gtl::FlatSet<string> unique_frame_names;
    322     std::vector<string> frame_names;
    323   };
    324 
    325   struct FrameInfo {
    326     FrameInfo()
    327         : input_count(0),
    328           total_inputs(0),
    329           pending_counts(nullptr),
    330           nodes(nullptr) {}
    331 
    332     // The total number of inputs to a frame.
    333     int input_count;
    334 
    335     // The total number of input tensors of a frame.
    336     // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame.
    337     int total_inputs;
    338 
    339     // Used to determine the next place to allocate space in the
    340     // pending_counts data structure we'll eventually construct
    341     PendingCounts::Layout pending_counts_layout;
    342 
    343     // Each frame has its own PendingCounts only for the nodes in the frame.
    344     PendingCounts* pending_counts;  // Owned
    345 
    346     // The nodes in a frame. Used only for debugging.
    347     std::vector<const Node*>* nodes;  // Owned
    348 
    349     ~FrameInfo() {
    350       delete pending_counts;
    351       delete nodes;
    352     }
    353   };
    354 
    355   static Status BuildControlFlowInfo(const Graph* graph,
    356                                      ControlFlowInfo* cf_info);
    357   void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info);
    358 
    359   FrameInfo* EnsureFrameInfo(const string& fname) {
    360     auto slot = &frame_info_[fname];
    361     if (*slot == nullptr) {
    362       *slot = new FrameInfo;
    363     }
    364     return *slot;
    365   }
    366 
    367   // Owned.
    368   LocalExecutorParams params_;
    369   std::unique_ptr<const Graph> graph_;
    370   GraphView gview_;
    371 
    372   // A cached value of params_.device->RequiresRecordingAccessedTensors().
    373   bool device_record_tensor_accesses_ = false;
    374 
    375   // Root nodes (with no in edges) that should form the initial ready queue
    376   std::vector<const Node*> root_nodes_;
    377 
    378   // Mapping from frame name to static information about the frame.
    379   // TODO(yuanbyu): We could cache it along with the graph so as to avoid
    380   // the overhead of constructing it for each executor instance.
    381   gtl::FlatMap<string, FrameInfo*> frame_info_;
    382 
    383   TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
    384 };
    385 
    386 // Infer memory allocation attributes of a node n's output,
    387 // based on its use node dst.  Note that dst might not be directly
    388 // connected to n by a single edge, but might be a downstream
    389 // consumer of n's output by reference.  *attr is updated with any
    390 // necessary attributes.
    391 Status InferAllocAttr(const Node* n, const Node* dst,
    392                       const DeviceNameUtils::ParsedName& local_dev_name,
    393                       AllocatorAttributes* attr);
    394 
    395 GraphView::~GraphView() {
    396   static_assert(std::is_trivially_destructible<AllocatorAttributes>::value,
    397                 "Update code if AllocatorAttributes gains a destructor");
    398   static_assert(std::is_trivially_destructible<EdgeInfo>::value,
    399                 "Update code if EdgeInfo gains a destructor");
    400   for (int i = 0; i < num_nodes_; i++) {
    401     NodeItem* n = node(i);
    402     if (n != nullptr) {
    403       n->NodeItem::~NodeItem();
    404       // Memory for "n" itself is held in space_ & gets cleaned up below
    405     }
    406   }
    407   delete[] node_offsets_;
    408   delete[] space_;
    409 }
    410 
    411 size_t GraphView::NodeItemBytes(const Node* n) {
    412   const size_t num_output_edges = n->out_edges().size();
    413   const int num_inputs = n->num_inputs();
    414   const int num_outputs = n->num_outputs();
    415 
    416   // Compute number of bytes needed for NodeItem and variable length data.
    417   // We do not subtract sizeof(var) since num_inputs/num_outputs might
    418   // both be zero.
    419   const size_t raw_bytes =
    420       sizeof(NodeItem)                             // Fixed
    421       + num_output_edges * sizeof(EdgeInfo)        // output_edges[...]
    422       + num_outputs * sizeof(AllocatorAttributes)  // output_attr[...]
    423       + num_outputs * sizeof(int)                  // forward_from[num_outputs]
    424       + num_inputs * sizeof(uint8)                 // input_type[num_inputs]
    425       + num_outputs * sizeof(uint8);               // output_type[num_outputs]
    426   static constexpr size_t kItemAlignment = sizeof(NodeItem*);
    427   static_assert(kItemAlignment % alignof(NodeItem) == 0,
    428                 "NodeItem must be aligned with kItemAlignment");
    429   static_assert(kItemAlignment % alignof(EdgeInfo) == 0,
    430                 "EdgeInfo must be aligned with kItemAlignment");
    431   static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0,
    432                 "AllocatorAttributes must be aligned with kItemAlignment");
    433   static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0,
    434                 "NodeItem must be aligned with EdgeInfo");
    435   static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0,
    436                 "NodeItem must be aligned with AllocatorAttributes");
    437   static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0,
    438                 "EdgeInfo must be aligned with AllocatorAttributes");
    439   const size_t bytes =
    440       ((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment;
    441   return bytes;
    442 }
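
        // A small worked example of the final rounding step: on a 64-bit build
        // kItemAlignment == sizeof(NodeItem*) == 8, so a hypothetical raw_bytes of
        // 61 becomes ((61 + 7) / 8) * 8 == 64. Every NodeItem therefore begins at an
        // 8-byte-aligned offset inside space_, which is what the alignment CHECK in
        // InitializeNode below relies on.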
    443 
    444 char* GraphView::InitializeNode(char* ptr, const Node* n) {
    445   const int id = n->id();
    446   CHECK(node_offsets_[id] == kuint32max);  // Initial value in constructor
    447 
    448   const size_t bytes = NodeItemBytes(n);
    449   constexpr size_t kItemAlignment = sizeof(NodeItem*);
    450   CHECK_EQ(reinterpret_cast<uintptr_t>(ptr) % kItemAlignment, 0);
    451   NodeItem* item = reinterpret_cast<NodeItem*>(ptr);
    452 
    453   // We store a 32-bit offset relative to the beginning of space_, so that we
    454   // only need an array of 32-bit values to map from node id to the NodeItem*,
    455   // (versus 64 bits on most machines if we just stored an array of NodeItem*
    456   // pointers). Casting to int64 is needed on 32-bit CPUs to avoid comparing
    457   // values as "int" vs "size_t" in CHECK_LE.
    458   CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
    459   const uint32 offset = static_cast<uint32>(ptr - space_);
    460   node_offsets_[id] = offset;
    461   ptr += bytes;
    462 
    463   const size_t num_output_edges = n->out_edges().size();
    464   const int num_inputs = n->num_inputs();
    465   const int num_outputs = n->num_outputs();
    466 
    467   new (item) NodeItem();
    468   item->num_inputs = num_inputs;
    469   item->num_outputs = num_outputs;
    470   item->num_output_edges = num_output_edges;
    471 
    472   // Fill output edges.
    473   // Keep track of the last EdgeInfo in the EdgeInfo array that references
    474   // a given output slot.  For all but the last, we need to do a copy of the
    475   // Tensor when propagating results downstream in the graph, but for the
    476   // last one, we can just do a move of the Tensor object to propagate it.
    477   gtl::InlinedVector<EdgeInfo*, 4> last_indices(num_outputs, nullptr);
    478   EdgeInfo* dst_edge = item->output_edge_base();
    479   for (auto e : n->out_edges()) {
    480     dst_edge->dst_id = e->dst()->id();
    481     CHECK_LE(e->src_output(), 0x3FFFFFFF);  // Must fit in 31 bits
    482     dst_edge->output_slot = e->src_output();
    483     dst_edge->is_last = false;
    484     const int output_slot = dst_edge->output_slot;
    485     if (output_slot >= 0) {
    486       last_indices[output_slot] = dst_edge;
    487     }
    488     dst_edge->input_slot = e->dst_input();
    489     dst_edge++;
    490   }
    491   for (EdgeInfo* edge_info : last_indices) {
    492     if (edge_info != nullptr) {
    493       edge_info->is_last = true;
    494     }
    495   }
    496 
    497   AllocatorAttributes* output_attrs = item->output_attr_base();
    498   for (int i = 0; i < num_outputs; i++) {
    499     new (&output_attrs[i]) AllocatorAttributes();
    500   }
    501 
    502   DCHECK_LT(DataType_MAX, 255);  // Must fit in uint8
    503   uint8* input_types = item->input_type_base();
    504   for (int i = 0; i < num_inputs; i++) {
    505     input_types[i] = static_cast<uint8>(n->input_type(i));
    506     DCHECK_EQ(item->input_type(i), n->input_type(i));
    507   }
    508 
    509   // Check ScopedAllocatorAttrs and forward_from.  Also assign output_types.
    510   {
    511     std::vector<int> forward_input;
    512     Status fwd_status =
    513         GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
    514     std::vector<int> scoped_allocator_attrs;
    515     Status sa_status =
    516         GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
    517 
    518     int* forward_from = item->forward_from_base();
    519     uint8* output_types = item->output_type_base();
    520     for (int i = 0; i < num_outputs; ++i) {
    521       output_types[i] = static_cast<uint8>(n->output_type(i));
    522       DCHECK_EQ(item->output_type(i), n->output_type(i));
    523 
    524       forward_from[i] = OpKernelContext::Params::kNoReservation;
    525       if (sa_status.ok()) {
    526         for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
    527           if (scoped_allocator_attrs[j] == i) {
    528             // This output slot must be explicitly allocated from a
    529             // ScopedAllocator.
    530             forward_from[i] = OpKernelContext::Params::kNeverForward;
    531             DCHECK_EQ(output_attrs[i].scope_id, 0);
    532             output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
    533           }
    534         }
    535       }
    536       if (fwd_status.ok() &&
    537           forward_from[i] == OpKernelContext::Params::kNoReservation) {
    538         DCHECK_EQ(forward_input.size() % 2, 0);
    539         for (int j = 0; j < forward_input.size(); j += 2) {
    540           if (forward_input[j + 1] == i) {
    541             DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
    542             forward_from[i] = forward_input[j];
    543             break;
    544           }
    545         }
    546       }
    547     }
    548   }
    549 
    550   return ptr;
    551 }
    552 
    553 void GraphView::Initialize(const Graph* g) {
    554   CHECK(node_offsets_ == nullptr);
    555   const int num_nodes = g->num_node_ids();
    556   num_nodes_ = num_nodes;
    557   size_t total_bytes = 0;
    558   for (const Node* n : g->nodes()) {
    559     total_bytes += NodeItemBytes(n);
    560   }
    561 
    562   node_offsets_ = new uint32[num_nodes];
    563   for (int i = 0; i < num_nodes; i++) {
    564     node_offsets_[i] = kuint32max;
    565   }
    566 
    567   space_ = new char[total_bytes];  // NodeItem objects are allocated here
    568   char* ptr = space_;
    569   for (const Node* n : g->nodes()) {
    570     ptr = InitializeNode(ptr, n);
    571   }
    572   CHECK_EQ(ptr, space_ + total_bytes);
    573 }
    574 
    575 void GetMaxPendingCounts(const Node* n, size_t* max_pending,
    576                          size_t* max_dead_count) {
    577   const size_t num_in_edges = n->in_edges().size();
    578   size_t initial_count;
    579   if (IsMerge(n)) {
    580     // A Merge node waits on all of its control inputs, so we initialize
    581     // the pending count to the number of control edges.
    582     int32 num_control_edges = 0;
    583     for (const Edge* edge : n->in_edges()) {
    584       if (edge->IsControlEdge()) {
    585         num_control_edges++;
    586       }
    587     }
    588     // Use bit 0 to indicate if we are waiting for a ready live data input.
    589     initial_count = 1 + (num_control_edges << 1);
    590   } else {
    591     initial_count = num_in_edges;
    592   }
    593 
    594   *max_pending = initial_count;
    595   *max_dead_count = num_in_edges;
    596 }
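
        // Example of the Merge encoding (hypothetical node): with 3 control in-edges
        // and 2 data in-edges, initial_count = 1 + (3 << 1) = 7. Bit 0 says "still
        // waiting for a live data input" and the remaining bits count unarrived
        // control edges, while *max_dead_count is just the total number of
        // in-edges, 5.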
    597 
    598 Status ExecutorImpl::Initialize() {
    599   gview_.Initialize(graph_.get());
    600 
    601   // Build the information about frames in this subgraph.
    602   ControlFlowInfo cf_info;
    603   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info));
    604 
    605   // Cache this value so we make this virtual function call once, rather
    606   // than O(# steps * # nodes per step) times.
    607   device_record_tensor_accesses_ =
    608       params_.device->RequiresRecordingAccessedTensors();
    609 
    610   for (auto& it : cf_info.unique_frame_names) {
    611     EnsureFrameInfo(it)->nodes = new std::vector<const Node*>;
    612   }
    613 
    614   // Preprocess every node in the graph to create an instance of op
    615   // kernel for each node.
    616   for (const Node* n : graph_->nodes()) {
    617     const int id = n->id();
    618     const string& frame_name = cf_info.frame_names[id];
    619     FrameInfo* frame_info = EnsureFrameInfo(frame_name);
    620 
    621     // See if this node is a root node, and if so, add to root_nodes_.
    622     if (n->in_edges().empty()) {
    623       root_nodes_.push_back(n);
    624     }
    625 
    626     NodeItem* item = gview_.node(id);
    627     item->node = n;
    628 
    629     item->input_start = frame_info->total_inputs;
    630     frame_info->total_inputs += n->num_inputs();
    631 
    632     Status s = params_.create_kernel(n->def(), &item->kernel);
    633     if (!s.ok()) {
    634       item->kernel = nullptr;
    635       s = AttachDef(s, *n);
    636       LOG(ERROR) << "Executor failed to create kernel. " << s;
    637       return s;
    638     }
    639     CHECK(item->kernel);
    640     item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    641     item->is_merge = IsMerge(n);
    642     item->is_enter = IsEnter(n);
    643     if (item->is_enter) {
    644       bool is_constant_enter;
    645       TF_RETURN_IF_ERROR(
    646           GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
    647       item->is_constant_enter = is_constant_enter;
    648     } else {
    649       item->is_constant_enter = false;
    650     }
    651     item->is_exit = IsExit(n);
    652     item->is_control_trigger = IsControlTrigger(n);
    653     item->is_sink = IsSink(n);
    654     item->is_enter_exit_or_next_iter =
    655         (IsEnter(n) || IsExit(n) || IsNextIteration(n));
    656 
    657     // Compute the maximum values we'll store for this node in the
    658     // pending counts data structure, and allocate a handle in
    659     // that frame's pending counts data structure that has enough
    660     // space to store these maximal count values.
    661     size_t max_pending, max_dead;
    662     GetMaxPendingCounts(n, &max_pending, &max_dead);
    663     item->pending_id =
    664         frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);
    665 
    666     // Initialize static information about the frames in the graph.
    667     frame_info->nodes->push_back(n);
    668     if (IsEnter(n)) {
    669       string enter_name;
    670       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
    671       EnsureFrameInfo(enter_name)->input_count++;
    672     }
    673   }
    674 
    675   // Initialize PendingCounts only after item->pending_id is initialized for
    676   // all nodes.
    677   InitializePending(graph_.get(), cf_info);
    678 
    679   return gview_.SetAllocAttrs(graph_.get(), params_.device);
    680 }
    681 
    682 // If a Node has been marked to use a ScopedAllocator x for output i, then
    683 // sc_attr will contain the subsequence (i, x) at an even offset.  This function
    684 // extracts and transfers that ScopedAllocator id to alloc_attr.  For now, we
    685 // only allow one ScopedAllocator use per Node.
    686 bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
    687                                 int output_index,
    688                                 AllocatorAttributes* alloc_attr) {
    689   DCHECK_LE(2, sc_attr.size());
    690   for (int i = 0; i < sc_attr.size(); i += 2) {
    691     if (sc_attr[i] == output_index) {
    692       CHECK_EQ(alloc_attr->scope_id, 0);
    693       alloc_attr->scope_id = sc_attr[i + 1];
    694       return true;
    695     }
    696   }
    697   return false;
    698 }
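
        // Example with hypothetical attribute values: sc_attr == {0, 51, 2, 52}
        // means output 0 must come from ScopedAllocator 51 and output 2 from
        // ScopedAllocator 52. ExtractScopedAllocatorAttr(sc_attr, 2, &attr) sets
        // attr->scope_id = 52 and returns true; output_index 1 leaves attr untouched
        // and returns false.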
    699 
    700 void GraphView::SetScopedAllocatorAttrs(
    701     const std::vector<const Node*>& sa_nodes) {
    702   for (const Node* sa : sa_nodes) {
    703     NodeItem* sa_item = node(sa->id());
    704     AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
    705     // Control edges out of the ScopedAllocator should lead to its use
    706     // instances, but may also include a few other nodes.
    707     for (const auto& e : sa->out_edges()) {
    708       if (!e->IsControlEdge()) {
    709         continue;
    710       }
    711       Node* use_node = e->dst();
    712       NodeItem* item = node(use_node->id());
    713       AllocatorAttributes* use_attrs = item->output_attr_base();
    714       std::vector<int> scoped_allocator_attrs;
    715       Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
    716                              &scoped_allocator_attrs);
    717       if (!s.ok()) {
    718         VLOG(2) << "Failed to find expected ScopedAllocator attr on "
    719                 << use_node->name();
    720         continue;
    721       }
    722       // There can be more than one output using ScopedAllocation, but this
    723       // analysis assumes they use the same ScopedAllocator.
    724       for (const auto& e : use_node->out_edges()) {
    725         if (!e->IsControlEdge()) {
    726           AllocatorAttributes attr;
    727           if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
    728                                          e->src_output(), &attr)) {
    729             // Set the scope_id on this use instance node.
    730             (use_attrs + e->src_output())->Merge(attr);
    731             // Propagate the other attributes of this node back to the SA node.
    732             attr = *(use_attrs + e->src_output());
    733             attr.scope_id = 0;
    734             sa_attrs->Merge(attr);
    735           }
    736         }
    737       }
    738     }
    739   }
    740 }
    741 
    742 Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
    743   Status s;
    744   DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
    745 
    746   std::vector<const Node*> scoped_allocator_instances;
    747   for (const Node* n : g->nodes()) {
    748     NodeItem* item = node(n->id());
    749     AllocatorAttributes* attrs = item->output_attr_base();
    750     if (IsScopedAllocator(n)) {
    751       scoped_allocator_instances.push_back(n);
    752     }
    753 
    754     // Examine the out edges of each node looking for special use
    755     // cases that may affect memory allocation attributes.
    756     for (const auto& e : n->out_edges()) {
    757       if (!e->IsControlEdge()) {
    758         AllocatorAttributes attr;
    759         s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
    760         if (!s.ok()) return s;
    761         if (attr.value != 0 || attr.scope_id != 0) {
    762           attrs[e->src_output()].Merge(attr);
    763         }
    764       }
    765     }
    766 
    767     for (int out = 0; out < n->num_outputs(); out++) {
    768       const OpKernel* op_kernel = item->kernel;
    769       DCHECK_LT(out, op_kernel->output_memory_types().size());
    770       bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
    771       if (on_host) {
    772         AllocatorAttributes h;
    773         h.set_on_host(on_host);
    774         attrs[out].Merge(h);
    775       }
    776     }
    777   }
    778   SetScopedAllocatorAttrs(scoped_allocator_instances);
    779   return s;
    780 }
    781 
    782 Status InferAllocAttr(const Node* n, const Node* dst,
    783                       const DeviceNameUtils::ParsedName& local_dev_name,
    784                       AllocatorAttributes* attr) {
    785   Status s;
    786   // Note that it's possible for *n to be a Recv and *dst to be a Send,
    787   // so these two cases are not mutually exclusive.
    788   if (IsRecv(n)) {
    789     string src_name;
    790     s = GetNodeAttr(n->attrs(), "send_device", &src_name);
    791     if (!s.ok()) return s;
    792     DeviceNameUtils::ParsedName parsed_src_name;
    793     if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
    794       s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
    795                            n->name());
    796       return s;
    797     }
    798     if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
    799       // Value is going to be the sink of an RPC.
    800       attr->set_nic_compatible(true);
    801       VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
    802     } else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) &&
    803                parsed_src_name.type != "CPU") {
    804       // Value is going to be the sink of a local DMA from GPU to CPU (or
    805       // other types of accelerators).
    806       attr->set_gpu_compatible(true);
    807       VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
    808     } else {
    809       VLOG(2) << "default alloc case local type " << local_dev_name.type
    810               << " remote type " << parsed_src_name.type;
    811     }
    812   }
    813   if (IsSend(dst)) {
    814     string dst_name;
    815     s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
    816     if (!s.ok()) return s;
    817     DeviceNameUtils::ParsedName parsed_dst_name;
    818     if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
    819       s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
    820                            n->name());
    821       return s;
    822     }
    823     if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
    824       // Value is going to be the source of an RPC.
    825       attr->set_nic_compatible(true);
    826       VLOG(2) << "node " << n->name() << " is the source of an RPC out";
    827     } else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) &&
    828                parsed_dst_name.type != "CPU") {
    829       // Value is going to be the source of a local DMA from CPU to GPU (or
    830       // other types of accelerators).
    831       // Note that this does not cover the case where the allocation of the
    832       // output tensor is not generated by the src: n.
    833       attr->set_gpu_compatible(true);
    834       VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
    835     } else {
    836       VLOG(2) << "default alloc case local type " << local_dev_name.type
    837               << " remote type " << parsed_dst_name.type;
    838     }
    839   }
    840   if (n->IsCollective()) {
    841     // We'll make the sweeping assumption that any collective op is going
    842     // to be involved in network i/o.
    843     attr->set_nic_compatible(true);
    844   }
    845   return s;
    846 }
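
        // Two illustrative cases (device names are hypothetical): a _Recv on
        // "/job:worker/task:0/device:CPU:0" whose send_device lives on a different
        // job gets nic_compatible set, because the value arrives over the network; a
        // _Recv on CPU whose sender is a GPU in the same address space gets
        // gpu_compatible set instead, since the value is the sink of a local
        // device-to-host copy.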
    847 
    848 // The state associated with one invocation of ExecutorImpl::Run.
    849 // ExecutorState dispatches nodes when they become ready and keeps
    850 // track of how many predecessors of a node have not yet completed (pending_).
    851 class ExecutorState {
    852  public:
    853   ExecutorState(const Executor::Args& args, ExecutorImpl* impl);
    854   ~ExecutorState();
    855 
    856   void RunAsync(Executor::DoneCallback done);
    857 
    858  private:
    859   // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
    860   // TODO(yuanbyu): A better way to do "has_value"?
    861   struct Entry {
    862     Entry() {}
    863     Entry(const Entry& other)
    864         : ref(other.ref),
    865           ref_mu(other.ref_mu),
    866           has_value(other.has_value),
    867           val_field_is_set(other.val_field_is_set),
    868           alloc_attr(other.alloc_attr),
    869           device_context(other.device_context) {
    870       if (val_field_is_set) {
    871         val.Init(*other.val);
    872       }
    873     }
    874     ~Entry() {
    875       if (val_field_is_set) val.Destroy();
    876     }
    877 
    878     Entry& operator=(const Entry& other) {
    879       if (val_field_is_set) {
    880         val.Destroy();
    881       }
    882       ref = other.ref;
    883       ref_mu = other.ref_mu;
    884       has_value = other.has_value;
    885       val_field_is_set = other.val_field_is_set;
    886       alloc_attr = other.alloc_attr;
    887       device_context = other.device_context;
    888       if (val_field_is_set) {
    889         val.Init(*other.val);
    890       }
    891       return *this;
    892     }
    893 
    894     Entry& operator=(Entry&& other) {
    895       if (val_field_is_set) {
    896         val.Destroy();
    897       }
    898       ref = other.ref;
    899       ref_mu = other.ref_mu;
    900       has_value = other.has_value;
    901       val_field_is_set = other.val_field_is_set;
    902       alloc_attr = other.alloc_attr;
    903       device_context = other.device_context;
    904       if (val_field_is_set) {
    905         val.Init(std::move(*other.val));
    906       }
    907       return *this;
    908     }
    909 
    910     // Clears the <val> field.
    911     void ClearVal() {
    912       if (val_field_is_set) {
    913         val.Destroy();
    914         val_field_is_set = false;
    915         has_value = false;
    916       }
    917     }
    918 
    919     // A tensor value, if val_field_is_set.
    920     ManualConstructor<Tensor> val;
    921 
    922     Tensor* ref = nullptr;    // A tensor reference.
    923     mutex* ref_mu = nullptr;  // mutex for *ref if ref is not nullptr.
    924 
    925     // Whether the value exists, either in <val> or <ref>.
    926     bool has_value = false;
    927 
    928     bool val_field_is_set = false;
    929 
    930     // The attributes of the allocator that creates the tensor.
    931     AllocatorAttributes alloc_attr;
    932 
    933     // Every entry carries an optional DeviceContext containing
    934     // Device-specific information about how the Tensor was produced.
    935     DeviceContext* device_context = nullptr;
    936   };
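
          // Rough summary of how an Entry is used (a sketch, not an invariant
          // enforced by the struct itself): !has_value means "no input yet";
          // has_value && val_field_is_set means the tensor lives in *val
          // (pass-by-value); has_value with ref != nullptr means the tensor is
          // reached through *ref while holding *ref_mu (pass-by-reference).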
    937 
    938   // Contains a value for [node->id()] for the device context assigned by the
    939   // device at the beginning of a step.
    940   DeviceContextMap device_context_map_;
    941 
    942   struct TaggedNode;
    943   typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
    944   typedef gtl::InlinedVector<Entry, 4> EntryVector;
    945 
    946   struct IterationState {
    947     explicit IterationState(const PendingCounts* pending_counts,
    948                             int total_input_tensors)
    949         : input_tensors(new Entry[total_input_tensors]),
    950           outstanding_ops(0),
    951           outstanding_frame_count(0),
    952           counts_(*pending_counts) {  // Initialize with copy of *pending_counts
    953     }
    954 
    955     // The state of an iteration.
    956 
    957     // One copy per iteration. For iteration k, i-th node's j-th input is in
    958     // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either
    959     // a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
    960     //
    961     // NOTE: No need to protect input_tensors[i] by any locks because it
    962     // is allocated exactly once. Each element is written once by the
    963     // source node of an edge and is cleared by the destination of the same
    964     // edge. The latter node is never run concurrently with the former node.
    965     Entry* input_tensors;
    966 
    967     // The number of outstanding ops for each iteration.
    968     size_t outstanding_ops;
    969 
    970     // The number of outstanding frames for each iteration.
    971     int outstanding_frame_count;
    972     int pending(PendingCounts::Handle h) { return counts_.pending(h); }
    973     int decrement_pending(PendingCounts::Handle h, int v) {
    974       return counts_.decrement_pending(h, v);
    975     }
    976     // Mark a merge node as live
    977     // REQUIRES: Node corresponding to "h" is a merge node
    978     void mark_live(PendingCounts::Handle h) { counts_.mark_live(h); }
    979     // Mark a node to show that processing has started.
    980     void mark_started(PendingCounts::Handle h) { counts_.mark_started(h); }
    981     // Mark a node to show that processing has completed.
    982     void mark_completed(PendingCounts::Handle h) { counts_.mark_completed(h); }
    983     PendingCounts::NodeState node_state(PendingCounts::Handle h) {
    984       return counts_.node_state(h);
    985     }
    986 
    987     int dead_count(PendingCounts::Handle h) { return counts_.dead_count(h); }
    988     void increment_dead_count(PendingCounts::Handle h) {
    989       counts_.increment_dead_count(h);
    990     }
    991     void adjust_for_activation(PendingCounts::Handle h, bool increment_dead,
    992                                int* pending_result, int* dead_result) {
    993       counts_.adjust_for_activation(h, increment_dead, pending_result,
    994                                     dead_result);
    995     }
    996 
    997     ~IterationState() { delete[] input_tensors; }
    998 
    999    private:
   1000     PendingCounts counts_;
   1001   };
   1002 
   1003   struct FrameState {
   1004     explicit FrameState(const ExecutorImpl* impl, int parallel_iters)
   1005         : executor(impl),
   1006           max_parallel_iterations(parallel_iters),
   1007           num_outstanding_iterations(1) {}
   1008 
   1009     // A new frame is created for each loop. Execution starts at iteration 0.
   1010     // When a value at iteration 0 passes through a NextIteration node,
   1011     // iteration 1 is created and starts running. Note that iteration 0 may
   1012     // still be running so multiple iterations may run in parallel. The
   1013     // frame maintains the state of iterations in several data structures
   1014     // such as pending_count and input_tensors. When iteration 0 completes,
   1015     // we garbage collect the state of iteration 0.
   1016     //
   1017     // A frame instance is considered "done" and can be garbage collected
   1018     // if all its inputs have entered and all its iterations are "done".
   1019     //
   1020     // A frame manages the live iterations of an iterative computation.
   1021     // Iteration i is considered "done" when there are no outstanding ops,
   1022     // frames at iteration i are done, all recvs for this iteration are
   1023     // completed, and iteration i-1 is done. For iteration 0, we instead
   1024     // wait for there to be no more pending inputs of the frame.
   1025     //
   1026     // Frames and iterations are garbage collected once they are done.
   1027     // The state we need to keep around is highly dependent on the
   1028     // parallelism enabled by the scheduler. We may want to have the
   1029     // scheduler dynamically control the outstanding number of live
   1030     // parallel frames and iterations. To reduce the state space, the
   1031     // scheduler might want to schedule ops in inner frames first and
   1032     // lower iterations first.
   1033     //
   1034     // This frame state is mostly initialized lazily on demand so we
   1035     // don't introduce unnecessary overhead.
   1036 
   1037     // The executor the frame is in.
   1038     const ExecutorImpl* executor = nullptr;
   1039 
   1040     // The name of this frame, which is the concatenation of its parent
   1041     // frame name, the iteration of the parent frame when this frame was
   1042     // created, and the value of the attr 'frame_name'.
   1043     string frame_name;
   1044 
   1045     // The unique id for this frame. Generated by fingerprinting
   1046     // frame_name.
   1047     uint64 frame_id;
   1048 
   1049     // The iteration id of its parent frame when this frame is created.
   1050     // -1 if there is no parent frame. The frame_name/parent_iter pair
   1051     // uniquely identifies this FrameState.
   1052     int64 parent_iter = -1;
   1053 
   1054     // The FrameState of its parent frame.
   1055     FrameState* parent_frame = nullptr;
   1056 
   1057     // The maximum allowed number of parallel iterations.
   1058     const int max_parallel_iterations;
   1059 
   1060     // The number of inputs this frame is still waiting for.
   1061     int num_pending_inputs = 0;
   1062 
   1063     // The highest iteration number we have reached so far in this frame.
   1064     int64 iteration_count GUARDED_BY(mu) = 0;
   1065 
   1066     // The number of outstanding iterations.
   1067     int num_outstanding_iterations GUARDED_BY(mu) = 1;
   1068 
   1069     // The active iteration states of this frame.
   1070     gtl::InlinedVector<IterationState*, 12> iterations;
   1071 
   1072     // The NextIteration nodes to enter a new iteration. If the number of
   1073     // outstanding iterations reaches the limit, we will defer the start of
   1074     // the next iteration until the number of outstanding iterations falls
   1075     // below the limit.
   1076     std::vector<std::pair<const Node*, Entry>> next_iter_roots GUARDED_BY(mu);
   1077 
   1078     // The values of the loop invariants for this loop. They are added into
   1079     // this list as they "enter" the frame. When a loop invariant enters,
   1080     // we make it available to all active iterations. When the frame starts
   1081     // a new iteration, we make all the current loop invariants available
   1082     // to the new iteration.
   1083     std::vector<std::pair<const Node*, Entry>> inv_values GUARDED_BY(mu);
   1084 
   1085     // The list of dead exit nodes for the current highest iteration. We
   1086     // will only "execute" the dead exits of the final iteration.
   1087     std::vector<const Node*> dead_exits GUARDED_BY(mu);
   1088 
   1089     // Static information specific to this frame.
   1090     PendingCounts* pending_counts = nullptr;
   1091     int total_input_tensors = 0;
   1092     std::vector<const Node*>* nodes = nullptr;
   1093 
   1094     // Lock ordering: ExecutorState.mu_ < mu;
   1095     // during structured traversal: parent_frame->mu < mu.
   1096     mutex mu;
   1097 
   1098     void InitializeFrameInfo(const string& enter_name) {
   1099       auto it_frame_info = executor->frame_info_.find(enter_name);
   1100       DCHECK(it_frame_info != executor->frame_info_.end());
   1101       ExecutorImpl::FrameInfo* finfo = it_frame_info->second;
   1102       pending_counts = finfo->pending_counts;
   1103       total_input_tensors = finfo->total_inputs;
   1104       num_pending_inputs = finfo->input_count;
   1105       nodes = finfo->nodes;
   1106     }
   1107 
   1108     inline IterationState* GetIteration(int64 iter)
   1109         EXCLUSIVE_LOCKS_REQUIRED(mu) {
   1110       size_t index = iter % iterations.size();
   1111       return iterations[index];
   1112     }
   1113 
   1114     inline void SetIteration(int64 iter, IterationState* state)
   1115         EXCLUSIVE_LOCKS_REQUIRED(mu) {
   1116       size_t index = iter % iterations.size();
   1117       DCHECK(state == nullptr || iterations[index] == nullptr);
   1118       iterations[index] = state;
   1119     }
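
            // GetIteration/SetIteration treat `iterations` as a ring buffer:
            // iteration `iter` occupies slot iter % iterations.size(). For example,
            // with 12 slots iteration 14 maps to slot 2, which is safe as long as
            // at most iterations.size() iterations are live at once, since a slot
            // is reused only after its previous occupant has been cleaned up.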
   1120 
   1121     // Decrement the outstanding op count and clean up the iterations in the
   1122     // frame. Return true iff the execution of the frame is done.
   1123     inline bool DecrementOutstandingOps(const GraphView* gview, int64 iter,
   1124                                         TaggedNodeSeq* ready) {
   1125       mutex_lock l(mu);
   1126       return DecrementOutstandingOpsLocked(gview, iter, ready);
   1127     }
   1128 
   1129     // Decrement the outstanding op count and clean up the iterations in the
   1130     // frame. Return true iff the execution of the frame is done.
   1131     inline bool DecrementOutstandingOpsLocked(const GraphView* gview,
   1132                                               int64 iter, TaggedNodeSeq* ready)
   1133         EXCLUSIVE_LOCKS_REQUIRED(mu) {
   1134       IterationState* istate = GetIteration(iter);
   1135       istate->outstanding_ops--;
   1136       if (istate->outstanding_ops != 0) {
   1137         return false;
   1138       } else {
   1139         return CleanupIterations(gview, iter, ready);
   1140       }
   1141     }
   1142 
   1143     // Returns true if the computation in the frame is completed.
   1144     inline bool IsFrameDone() EXCLUSIVE_LOCKS_REQUIRED(mu) {
   1145       return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
   1146     }
   1147 
   1148     // Returns true if the iteration of the frame is completed.
   1149     bool IsIterationDone(int64 iter) EXCLUSIVE_LOCKS_REQUIRED(mu);
   1150 
   1151     // Increments the iteration id. If this is a new iteration, initialize it.
   1152     void IncrementIteration(const GraphView* gview, TaggedNodeSeq* ready)
   1153         EXCLUSIVE_LOCKS_REQUIRED(mu);
   1154 
   1155     // Activate all the deferred NextIteration nodes in a new iteration.
   1156     void ActivateNexts(const GraphView* gview, int64 iter, TaggedNodeSeq* ready)
   1157         EXCLUSIVE_LOCKS_REQUIRED(mu);
   1158 
   1159     // Activate all the current loop invariants in a new iteration.
   1160     void ActivateLoopInvs(const GraphView* gview, int64 iter,
   1161                           TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
   1162 
   1163     // Add a new loop invariant and make it available to all active
   1164     // iterations.
   1165     void AddLoopInv(const NodeItem* item, const Entry& value,
   1166                     TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
   1167 
   1168     // Activate the successors of a node. Contents of *outputs are left in an
   1169     // indeterminate state after returning from this method.
   1170     void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter,
   1171                        EntryVector* outputs, TaggedNodeSeq* ready)
   1172         EXCLUSIVE_LOCKS_REQUIRED(mu);
   1173 
   1174     // Cleanup iterations of this frame starting from iteration iter.
   1175     bool CleanupIterations(const GraphView* gview, int64 iter,
   1176                            TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
   1177 
   1178     ~FrameState() {
   1179       for (size_t i = 0; i < iterations.size(); ++i) {
   1180         delete iterations[i];
   1181         iterations[i] = nullptr;
   1182       }
   1183     }
   1184   };
   1185 
   1186   // A tagged node: <frame*, iter, node*>.
   1187   struct TaggedNode {
   1188     const Node* node = nullptr;
   1189     FrameState* input_frame = nullptr;
   1190     int64 input_iter = -1;
   1191     bool is_dead = false;
   1192 
   1193     TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter,
   1194                bool dead) {
   1195       node = t_node;
   1196       input_frame = in_frame;
   1197       input_iter = in_iter;
   1198       is_dead = dead;
   1199     }
   1200   };
   1201 
   1202   // A drop-in replacement for std::deque<TaggedNode>.  We typically don't
   1203   // have that many nodes in the ready queue, so we just use a vector and
   1204   // don't free up memory from the queue as we consume nodes.
   1205   class TaggedNodeReadyQueue {
   1206    public:
   1207     TaggedNodeReadyQueue() : front_index_(0) {}
   1208 
   1209     void push_back(TaggedNode node) { ready_.push_back(node); }
   1210     TaggedNode front() const {
   1211       DCHECK_LT(front_index_, ready_.size());
   1212       return ready_[front_index_];
   1213     }
   1214     void pop_front() {
   1215       DCHECK_LT(front_index_, ready_.size());
   1216       front_index_++;
   1217       if ((front_index_ == ready_.size()) || (front_index_ > 16384)) {
   1218         if (front_index_ == ready_.size()) {
   1219           ready_.clear();
   1220         } else {
   1221           // Lots of unused entries at beginning of vector: move everything
   1222           // down to start of vector.
   1223           ready_.erase(ready_.begin(), ready_.begin() + front_index_);
   1224         }
   1225         front_index_ = 0;
   1226       }
   1227     }
   1228     bool empty() const { return ready_.empty(); }
   1229     const TaggedNode* begin() const { return ready_.begin() + front_index_; }
   1230     const TaggedNode* end() const { return ready_.end(); }
   1231 
   1232    private:
   1233     gtl::InlinedVector<TaggedNode, 16> ready_;
   1234     int front_index_;
   1235   };
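
          // Illustrative use of the queue (the surrounding scheduler code is not
          // shown here): it behaves like a FIFO whose storage is compacted only
          // when it drains or once more than 16384 consumed entries pile up at
          // the front.
          //
          //   TaggedNodeReadyQueue inline_ready;
          //   inline_ready.push_back(some_tagged_node);  // hypothetical TaggedNode
          //   while (!inline_ready.empty()) {
          //     TaggedNode next = inline_ready.front();
          //     inline_ready.pop_front();
          //     // ... process `next` ...
          //   }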
   1236 
   1237   struct AsyncState;
   1238 
   1239   const bool vlog_;  // true if VLOG_IS_ON(1). Used to check vlog cheaply.
   1240 
   1241   // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
   1242   const bool log_memory_;
   1243 
   1244   int64 step_id_;
   1245   // Not owned.
   1246   Rendezvous* rendezvous_;
   1247   CollectiveExecutor* collective_executor_ = nullptr;
   1248   SessionState* session_state_;
   1249   string session_handle_;
   1250   TensorStore* tensor_store_;
   1251   // Step-local container.
   1252   ScopedStepContainer* step_container_;
   1253   StepStatsCollectorInterface* const stats_collector_;
   1254   const tracing::EventCollector* const event_collector_;
   1255   Context context_;
   1256 
   1257   // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
   1258   // instead of a pointer?  (avoids having to delete).
   1259   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
   1260   CallFrameInterface* call_frame_;
   1261   const ExecutorImpl* impl_;
   1262   CancellationManager* cancellation_manager_;
   1263   Executor::Args::Runner runner_;
   1264   bool sync_on_finish_;
   1265   const bool trace_using_annotations_;
   1266 
   1267   // Owned.
   1268 
   1269   // A flag that is set on error after the frame state has been
   1270   // dumped for diagnostic purposes.
   1271   bool dumped_on_error_ = false;
   1272 
   1273   // The root frame in which the execution of this step is started.
   1274   FrameState* root_frame_;
   1275 
   1276   // Invoked when the execution finishes.
   1277   Executor::DoneCallback done_cb_;
   1278 
   1279   std::atomic_int_fast32_t num_outstanding_ops_;
   1280 
   1281   // Available via OpKernelContext to every OpKernel invocation.
   1282   mutex num_deferred_ops_mu_;
   1283   condition_variable no_deferred_ops_cv_;
   1284   int64 num_deferred_ops_ GUARDED_BY(num_deferred_ops_mu_) = 0;
   1285 
   1286   mutex mu_;
   1287   Status status_ GUARDED_BY(mu_);
   1288 
   1289   // Mapping from frame name to outstanding frames. A new frame is created
   1290   // at some iteration of an active frame. So the unique key for the new
   1291   // child frame is composed of the name of the parent frame, the iteration
   1292   // number at which the parent frame is creating the new frame, and the
    1293   // name of the new frame from its NodeDef.
   1294   gtl::FlatMap<string, FrameState*> outstanding_frames_ GUARDED_BY(mu_);
   1295 
    1296   // Composes the unique name of a frame.
   1297   inline string MakeFrameName(FrameState* frame, int64 iter_id,
   1298                               const string& name) {
   1299     return strings::StrCat(frame->frame_name, ";", iter_id, ";", name);
   1300   }
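           // For example (hypothetical names): an Enter node with frame_name
           // "while_1" reached at iteration 3 of a parent frame named "outer"
           // yields the key "outer;3;while_1". The root frame's name is the
           // empty string, so its direct children get keys like ";0;while_1".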
   1301 
    1302   // Finds an existing child frame, or creates a new one, in frame 'frame'
    1303   // at iteration 'iter'.
   1304   void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node,
   1305                               FrameState** child);
   1306 
   1307   // Delete a frame. Called when the frame is done.
   1308   void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);
   1309 
   1310   // Cleanup frames and iterations starting from frame/iter. Called when
   1311   // a child frame is done.
   1312   void CleanupFramesIterations(FrameState* frame, int64 iter,
   1313                                TaggedNodeSeq* ready);
   1314 
    1315   // Process a ready node in the current thread.
   1316   void Process(TaggedNode node, int64 scheduled_nsec);
   1317 
   1318   // Before invoking item->kernel, fills in its "inputs".
   1319   Status PrepareInputs(const NodeItem& item, Entry* first_input,
   1320                        TensorValueVec* inputs,
   1321                        DeviceContextVec* input_device_contexts,
   1322                        AllocatorAttributeVec* input_alloc_attrs,
   1323                        bool* is_input_dead);
   1324 
   1325   // After item->kernel computation is done, processes its outputs.
   1326   Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
   1327                         EntryVector* outputs, NodeExecStatsInterface* stats);
   1328 
   1329   // After processing the outputs, propagates the outputs to their dsts.
   1330   // Contents of *outputs are left in an indeterminate state after
   1331   // returning from this method.
   1332   void PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* item,
   1333                         EntryVector* outputs, TaggedNodeSeq* ready);
   1334 
   1335   // "node" just finishes. Takes ownership of "stats". Returns true if
   1336   // execution has completed.
   1337   bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
   1338                 NodeExecStatsInterface* stats,
   1339                 TaggedNodeReadyQueue* inline_ready);
   1340 
   1341   // Schedule all the expensive nodes in 'ready', and put all the inexpensive
   1342   // nodes in 'ready' into 'inline_ready'.
   1343   void ScheduleReady(const TaggedNodeSeq& ready,
   1344                      TaggedNodeReadyQueue* inline_ready);
   1345 
   1346   // For debugging/logging only.
   1347   inline void MaybeMarkCompleted(FrameState* frame, int64 iter, int64 id);
   1348 
   1349   // Provide debugging output about an outstanding node in the executor.
   1350   void DumpPendingNodeState(const int node_id, const Entry* input_vector,
   1351                             bool show_nodes_with_no_ready_inputs);
   1352   void DumpActiveNodeState(const int node_id, const Entry* input_vector);
   1353 
   1354   // Provide debugging output about an outstanding iteration in the executor.
   1355   void DumpIterationState(const FrameState* frame, IterationState* iteration);
   1356 
   1357   // Provide debugging output of the state of the executor.
   1358   void DumpState();
   1359   const Tensor* GetTensorValueForDump(const Entry& input);
   1360 
   1361   // Clean up when this executor is done.
   1362   void Finish();
   1363   // Schedule Finish() on a separate thread if it needs to wait for deferred
   1364   // async ops to complete; otherwise run it on the current thread.
   1365   void ScheduleFinish();
   1366 
    1367   // A standalone routine for this expression so that we can mark it
    1368   // NO_THREAD_SAFETY_ANALYSIS. Reading this without the lock is safe
    1369   // because the iterations array never resizes, and this particular
    1370   // iteration's array element will not be changed out from under us
    1371   // while the iteration is still alive.
   1372   Entry* GetInputTensors(FrameState* input_frame,
   1373                          int64 input_iter) const NO_THREAD_SAFETY_ANALYSIS {
   1374     return input_frame->GetIteration(input_iter)->input_tensors;
   1375   }
   1376 };
   1377 
   1378 ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
   1379     : vlog_(VLOG_IS_ON(1)),
   1380       log_memory_(LogMemory::IsEnabled()),
   1381       step_id_(args.step_id),
   1382       rendezvous_(args.rendezvous),
   1383       collective_executor_(args.collective_executor),
   1384       session_state_(args.session_state),
   1385       session_handle_(args.session_handle),
   1386       tensor_store_(args.tensor_store),
   1387       step_container_(args.step_container),
   1388       stats_collector_(args.stats_collector),
   1389       event_collector_(
   1390           tracing::GetEventCollector(tracing::EventCategory::kCompute)),
   1391       context_(ContextKind::kThread),
   1392       slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
   1393       call_frame_(args.call_frame),
   1394       impl_(impl),
   1395       cancellation_manager_(args.cancellation_manager),
   1396       runner_(args.runner),
   1397       sync_on_finish_(args.sync_on_finish),
   1398       trace_using_annotations_(impl->params_.device->TraceUsingAnnotations()),
   1399       num_outstanding_ops_(0) {
    1400   // The entire execution starts in iteration 0 of the root frame, so
    1401   // create the root frame and the state for iteration 0 here.
    1402   // We assume root_frame_->frame_name is the empty string.
   1403   root_frame_ = new FrameState(impl_, 1);
   1404   root_frame_->frame_id = 0;  // must be 0
   1405   root_frame_->InitializeFrameInfo(root_frame_->frame_name);
   1406 
   1407   // Initialize iteration 0.
   1408   root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
   1409   root_frame_->iterations[0] = new IterationState(
   1410       root_frame_->pending_counts, root_frame_->total_input_tensors);
   1411 
   1412   outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
   1413 }
   1414 
   1415 ExecutorState::~ExecutorState() {
   1416   for (auto name_frame : outstanding_frames_) {
   1417     delete name_frame.second;
   1418   }
   1419   for (auto it : device_context_map_) {
   1420     it->Unref();
   1421   }
   1422   delete slice_reader_cache_;
   1423 }
   1424 
   1425 Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,
   1426                                           ControlFlowInfo* cf_info) {
   1427   const int num_nodes = g->num_node_ids();
   1428   cf_info->frame_names.resize(num_nodes);
   1429   std::vector<Node*> parent_nodes;
   1430   parent_nodes.resize(num_nodes);
   1431   std::vector<bool> visited;
   1432   visited.resize(num_nodes);
   1433 
   1434   string frame_name;
   1435   std::deque<Node*> ready;
   1436 
   1437   // Initialize with the root nodes.
   1438   for (Node* n : g->nodes()) {
   1439     if (n->in_edges().empty()) {
   1440       visited[n->id()] = true;
   1441       cf_info->unique_frame_names.insert(frame_name);
   1442       ready.push_back(n);
   1443     }
   1444   }
   1445 
   1446   while (!ready.empty()) {
   1447     Node* curr_node = ready.front();
   1448     int curr_id = curr_node->id();
   1449     ready.pop_front();
   1450 
   1451     Node* parent = nullptr;
   1452     if (IsEnter(curr_node)) {
   1453       // Enter a child frame.
   1454       TF_RETURN_IF_ERROR(
   1455           GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
   1456       parent = curr_node;
   1457     } else if (IsExit(curr_node)) {
   1458       // Exit to the parent frame.
   1459       parent = parent_nodes[curr_id];
   1460       frame_name = cf_info->frame_names[parent->id()];
   1461       parent = parent_nodes[parent->id()];
   1462     } else {
   1463       parent = parent_nodes[curr_id];
   1464       frame_name = cf_info->frame_names[curr_id];
   1465     }
   1466 
   1467     for (const Edge* out_edge : curr_node->out_edges()) {
   1468       Node* out = out_edge->dst();
   1469       const int out_id = out->id();
   1470 
   1471       // Add to ready queue if not visited.
   1472       bool is_visited = visited[out_id];
   1473       if (!is_visited) {
   1474         ready.push_back(out);
   1475         visited[out_id] = true;
   1476 
   1477         // Process the node 'out'.
   1478         cf_info->frame_names[out_id] = frame_name;
   1479         parent_nodes[out_id] = parent;
   1480         cf_info->unique_frame_names.insert(frame_name);
   1481       }
   1482     }
   1483   }
   1484 
   1485   return Status::OK();
   1486 }
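         // A small worked example of the BFS above (hypothetical node names): for
         // the chain Const -> Enter(frame_name="loop") -> Add -> Exit -> Identity,
         //   Const and Enter end up tagged with the root frame name "" (frame
         //     names are recorded on the destination of each edge, so the Enter
         //     node itself still belongs to its parent frame),
         //   Add and Exit are tagged with "loop",
         //   Identity is tagged with "" again, because processing the Exit
         //     restores the parent frame's name for its successors.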
   1487 
   1488 void ExecutorImpl::InitializePending(const Graph* graph,
   1489                                      const ControlFlowInfo& cf_info) {
   1490   for (auto& it : cf_info.unique_frame_names) {
   1491     FrameInfo* finfo = EnsureFrameInfo(it);
   1492     PendingCounts* counts = new PendingCounts(finfo->pending_counts_layout);
   1493     DCHECK_EQ(finfo->pending_counts, nullptr);
   1494     finfo->pending_counts = counts;
   1495   }
   1496   for (const Node* n : graph->nodes()) {
   1497     const int id = n->id();
   1498     const string& name = cf_info.frame_names[id];
   1499     size_t max_pending, max_dead;
   1500     GetMaxPendingCounts(n, &max_pending, &max_dead);
   1501     const NodeItem* item = gview_.node(id);
   1502     PendingCounts* counts = EnsureFrameInfo(name)->pending_counts;
   1503     counts->set_initial_count(item->pending_id, max_pending);
   1504   }
   1505 }
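         // In short, InitializePending gives each distinct frame its own
         // PendingCounts instance (built from that frame's precomputed layout) and
         // seeds every node's initial count from GetMaxPendingCounts; a node in a
         // given iteration becomes ready once its pending count reaches zero.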
   1506 
   1507 void ExecutorState::RunAsync(Executor::DoneCallback done) {
   1508   const Graph* graph = impl_->graph_.get();
   1509   TaggedNodeSeq ready;
   1510 
   1511   // Ask the device to fill in the device context map.
   1512   Device* device = impl_->params_.device;
   1513   const Status fill_status =
   1514       device->FillContextMap(graph, &device_context_map_);
   1515   if (!fill_status.ok()) {
   1516     delete this;
   1517     done(fill_status);
   1518     return;
   1519   }
   1520 
   1521   // Initialize the ready queue.
   1522   for (const Node* n : impl_->root_nodes_) {
   1523     DCHECK_EQ(n->in_edges().size(), 0);
   1524     ready.push_back(TaggedNode{n, root_frame_, 0, false});
   1525   }
   1526   if (ready.empty()) {
   1527     delete this;
   1528     done(Status::OK());
   1529   } else {
   1530     num_outstanding_ops_ = ready.size();
   1531     root_frame_->iterations[0]->outstanding_ops = ready.size();
   1532     done_cb_ = std::move(done);
    1533     // Schedule all the ready ops to run in the thread pool.
   1534     ScheduleReady(ready, nullptr);
   1535   }
   1536 }
   1537 
   1538 // State kept alive for executing an asynchronous node in another
    1539 // thread.  NOTE: We need to make a copy of p.inputs,
    1540 // p.input_device_contexts, and p.input_alloc_attrs for asynchronous
    1541 // kernels because OpKernelContext methods like input_type(i) need
    1542 // the params to point to valid input vectors. This is not an issue for
    1543 // sync kernels because these vectors are kept on the stack.
   1544 struct ExecutorState::AsyncState {
   1545   AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
   1546              const NodeItem* _item, Entry* _first_input,
   1547              NodeExecStatsInterface* _stats)
   1548       : saved_inputs(*p.inputs),
   1549         saved_input_device_contexts(*p.input_device_contexts),
   1550         saved_input_alloc_attrs(*p.input_alloc_attrs),
   1551         params(p),
   1552         tagged_node(_tagged_node),
   1553         item(_item),
   1554         first_input(_first_input),
    1555         // ParamsButClearingEigenGPUDevice does the equivalent of
   1556         //   params.eigen_gpu_device = nullptr;
   1557         ctx(ParamsButClearingEigenGPUDevice(&params), item->num_outputs),
   1558         stats(_stats) {
   1559     params.inputs = &saved_inputs;
   1560     params.input_device_contexts = &saved_input_device_contexts;
   1561     params.input_alloc_attrs = &saved_input_alloc_attrs;
   1562   }
   1563 
   1564   TensorValueVec saved_inputs;
   1565   DeviceContextVec saved_input_device_contexts;
   1566   AllocatorAttributeVec saved_input_alloc_attrs;
   1567   OpKernelContext::Params params;
   1568   TaggedNode tagged_node;
   1569   const NodeItem* item;
   1570   Entry* first_input;
   1571   OpKernelContext ctx;
   1572   NodeExecStatsInterface* stats;
   1573 
   1574  private:
   1575   OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
   1576       OpKernelContext::Params* p) {
   1577     // Ensure OpKernelContext constructor will make a new eigen GPU device if
   1578     // necessary.
   1579     p->eigen_gpu_device = nullptr;  // Force allocation
   1580     return p;
   1581   }
   1582 };
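         // Why the saved_* copies above matter, condensed from Process() below: by
         // the time an AsyncOpKernel's done-callback fires, Process() may have
         // returned, and the stack-allocated `inputs`, `input_device_contexts`, and
         // `input_alloc_attrs` vectors that params pointed at are gone. AsyncState
         // re-points params at its own copies so the deferred OpKernelContext keeps
         // reading valid data:
         //
         //   AsyncState* state =
         //       new AsyncState(params, tagged_node, &item, first_input, stats);
         //   device->ComputeAsync(async, &state->ctx, /*done=*/[state]() {
         //     // ... ProcessOutputs / PropagateOutputs / NodeDone ...
         //     delete state;
         //   });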
   1583 
   1584 // Returns true if `item` might be traced by the given trace and event
   1585 // collectors. Returns false only if `item` definitely will not be traced.
   1586 bool MightTrace(const NodeItem& item,
   1587                 const tracing::EventCollector* event_collector,
   1588                 bool using_annotations) {
    1589   // Tracing will only be enabled if either `event_collector` is non-null,
   1590   // or `trace_collector` is non-null and enabled for this particular kernel.
   1591   // Although `tracing::ScopedActivity`,
   1592   // `tracing::ScopedAnnotation`, and `tracing::ScopedRegion` check subsets of
   1593   // these properties internally in their constructors, the cost of passing the
   1594   // necessary arguments to them can be significant, so we avoid constructing
   1595   // them in the common case (when we know they will not be used).
   1596   if (event_collector != nullptr) {
   1597     return true;
   1598   }
   1599   auto* trace_collector = tracing::GetTraceCollector();
   1600   if (trace_collector) {
   1601     if (using_annotations) {
   1602       return trace_collector->IsEnabledForAnnotations();
   1603     } else {
   1604       return trace_collector->IsEnabledForActivities(
   1605           item.kernel->IsExpensive());
   1606     }
   1607   }
   1608   return false;
   1609 }
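         // In other words, MightTrace() returns true iff an EventCollector is
         // registered for compute events, or a TraceCollector exists and is enabled
         // either for annotations (when using annotations) or for activities at
         // this kernel's expensiveness level.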
   1610 
   1611 void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
   1612   WithContext wc(context_);
   1613   const GraphView& gview = impl_->gview_;
   1614   TaggedNodeSeq ready;
   1615   TaggedNodeReadyQueue inline_ready;
   1616 
   1617   // Parameters passed to OpKernel::Compute.
   1618   TensorValueVec inputs;
   1619   DeviceContextVec input_device_contexts;
   1620   AllocatorAttributeVec input_alloc_attrs;
   1621 
   1622   OpKernelContext::Params params;
   1623   params.step_id = step_id_;
   1624   Device* device = impl_->params_.device;
   1625   params.device = device;
   1626   params.log_memory = log_memory_;
   1627   params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
   1628   params.rendezvous = rendezvous_;
   1629   params.collective_executor = collective_executor_;
   1630   params.session_state = session_state_;
   1631   params.session_handle = session_handle_;
   1632   params.tensor_store = tensor_store_;
   1633   params.cancellation_manager = cancellation_manager_;
   1634   params.call_frame = call_frame_;
   1635   params.function_library = impl_->params_.function_library;
   1636   params.resource_manager = device->resource_manager();
   1637   params.step_container = step_container_;
   1638   params.slice_reader_cache = slice_reader_cache_;
   1639   params.inputs = &inputs;
   1640   params.input_device_contexts = &input_device_contexts;
   1641   params.input_alloc_attrs = &input_alloc_attrs;
   1642   params.runner = &runner_;
   1643   params.stats_collector = stats_collector_;
   1644   params.inc_num_deferred_ops_function = [this]() {
   1645     mutex_lock lock(num_deferred_ops_mu_);
   1646     num_deferred_ops_++;
   1647   };
   1648   params.dec_num_deferred_ops_function = [this]() {
   1649     mutex_lock lock(num_deferred_ops_mu_);
   1650     num_deferred_ops_--;
   1651     if (num_deferred_ops_ == 0) {
   1652       no_deferred_ops_cv_.notify_all();
   1653     }
   1654   };
   1655 
   1656   Status s;
   1657   NodeExecStatsInterface* stats = nullptr;
   1658 
   1659   EntryVector outputs;
   1660   bool completed = false;
   1661   inline_ready.push_back(tagged_node);
   1662   while (!inline_ready.empty()) {
   1663     tagged_node = inline_ready.front();
   1664     inline_ready.pop_front();
   1665     const Node* node = tagged_node.node;
   1666     FrameState* input_frame = tagged_node.input_frame;
   1667     const int64 input_iter = tagged_node.input_iter;
   1668     const int id = node->id();
   1669     const NodeItem& item = *gview.node(id);
   1670 
   1671     // TODO(misard) Replace with a finer-grain enabling flag once we
   1672     // add better optional debugging support.
   1673     if (vlog_ && VLOG_IS_ON(1)) {
   1674       mutex_lock l(input_frame->mu);
   1675       input_frame->GetIteration(input_iter)->mark_started(item.pending_id);
   1676     }
   1677 
   1678     // Set the device_context for this node id, if it exists.
   1679     if (id < device_context_map_.size()) {
   1680       params.op_device_context = device_context_map_[id];
   1681     }
   1682 
   1683     params.track_allocations = false;
   1684     stats = nullptr;
   1685     if (stats_collector_ && !tagged_node.is_dead) {
   1686       stats = stats_collector_->CreateNodeExecStats(node);
    1687       // Track allocations if and only if we are collecting statistics,
    1688       // and the `stats` object expects allocations to be tracked.
   1689       params.track_allocations = stats ? stats->TrackAllocations() : false;
   1690       nodestats::SetScheduled(stats, scheduled_nsec);
   1691       nodestats::SetAllStart(stats);
   1692     }
   1693 
   1694     if (vlog_) {
   1695       VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
   1696               << SummarizeNode(*node) << (tagged_node.is_dead ? " is dead" : "")
   1697               << " device: " << device->name();
   1698     }
   1699 
   1700     Entry* input_tensors = GetInputTensors(input_frame, input_iter);
   1701     Entry* first_input = input_tensors + item.input_start;
   1702     outputs.clear();
   1703 
   1704     TensorReferenceVector accessed_tensors;
   1705     DeviceContext* device_context = nullptr;
   1706     // Only execute this node if it is not dead or it is a send/recv
   1707     // transfer node. For transfer nodes, we need to propagate the "dead"
   1708     // bit even when the node is dead.
   1709     bool launched_asynchronously = false;
   1710     if (tagged_node.is_dead && !IsTransferNode(node)) {
   1711       outputs.resize(item.num_outputs);
   1712     } else {
   1713       // Prepares inputs.
   1714       bool is_input_dead = false;
   1715       s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
   1716                         &input_alloc_attrs, &is_input_dead);
   1717       if (!s.ok()) {
   1718         // Clear inputs.
   1719         int num_inputs = item.num_inputs;
   1720         for (int i = 0; i < num_inputs; ++i) {
   1721           (first_input + i)->ClearVal();
   1722         }
   1723         MaybeMarkCompleted(input_frame, input_iter, id);
   1724         // Continue to process the nodes in 'inline_ready'.
   1725         completed = NodeDone(s, item.node, ready, stats, &inline_ready);
   1726         continue;
   1727       }
   1728 
   1729       // Set up compute params.
   1730       OpKernel* op_kernel = item.kernel;
   1731       params.op_kernel = op_kernel;
   1732       params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
   1733       params.is_input_dead = is_input_dead;
   1734       params.output_attr_array = item.output_attrs();
   1735       params.forward_from_array = item.forward_from();
   1736 
   1737       if (item.kernel_is_async) {
   1738         // Asynchronous computes.
   1739         AsyncOpKernel* async = item.kernel->AsAsync();
   1740         DCHECK(async != nullptr);
   1741         launched_asynchronously = true;
   1742         AsyncState* state =
   1743             new AsyncState(params, tagged_node, &item, first_input, stats);
   1744 
   1745         auto done = [this, state]() {
   1746           Device* device = impl_->params_.device;
   1747           NodeExecStatsInterface* stats = state->stats;  // Shorthand
   1748           Entry* first_input = state->first_input;       // Shorthand
   1749 
   1750           nodestats::SetOpEnd(stats);
   1751           EntryVector outputs;
   1752           Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
   1753           nodestats::SetMemory(stats, &state->ctx);
   1754           if (vlog_) {
   1755             VLOG(2) << "Async kernel done: " << state->item->node->id()
   1756                     << " step " << step_id_ << " "
   1757                     << SummarizeNode(*state->item->node)
   1758                     << (state->tagged_node.is_dead ? " is dead" : "")
   1759                     << " device: " << device->name();
   1760           }
   1761 
   1762           // Clears inputs.
   1763           const int num_inputs = state->item->num_inputs;
   1764           for (int i = 0; i < num_inputs; ++i) {
   1765             (first_input + i)->ClearVal();
   1766           }
   1767           FrameState* input_frame = state->tagged_node.input_frame;
   1768           const int64 input_iter = state->tagged_node.input_iter;
   1769           const int id = state->tagged_node.node->id();
   1770           MaybeMarkCompleted(input_frame, input_iter, id);
   1771           TaggedNodeSeq ready;
   1772           if (s.ok()) {
   1773             PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
   1774           }
   1775           outputs.clear();
   1776           if (s.ok() && impl_->device_record_tensor_accesses_) {
   1777             // Get the list of all tensors accessed during the execution
   1778             TensorReferenceVector accessed;
   1779             state->ctx.retrieve_accessed_tensors(&accessed);
   1780             nodestats::SetReferencedTensors(stats, accessed);
   1781             // callee takes ownership of the vector
   1782             device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
   1783                                                  accessed);
   1784           }
   1785           const bool completed =
   1786               NodeDone(s, state->item->node, ready, stats, nullptr);
   1787           delete state;
   1788           if (completed) ScheduleFinish();
   1789         };
   1790         nodestats::SetOpStart(stats);
   1791         device->ComputeAsync(async, &state->ctx, done);
   1792       } else {
   1793         // Synchronous computes.
   1794         OpKernelContext ctx(&params, item.num_outputs);
   1795         nodestats::SetOpStart(stats);
   1796 
   1797         if (TF_PREDICT_FALSE(
   1798                 MightTrace(item, event_collector_, trace_using_annotations_))) {
   1799           const string& op_name = op_kernel->name();
   1800           tracing::ScopedRegion region(tracing::EventCategory::kCompute,
   1801                                        op_name);
   1802           if (trace_using_annotations_) {
   1803             // The OpKernel may create child activities (such as GPU kernel
   1804             // launches), so use a `ScopedAnnotation` to relate these activities
   1805             // in the trace.
   1806             tracing::ScopedAnnotation activity(
   1807                 op_name, strings::StrCat(op_kernel->type_string(),
   1808                                          "#id=", step_id_, "#"));
   1809             device->Compute(op_kernel, &ctx);
   1810           } else {
   1811             // Use the cheaper `ScopedActivity` to trace just the OpKernel
   1812             // execution.
   1813             tracing::ScopedActivity activity(
   1814                 op_name,
   1815                 strings::StrCat(op_kernel->type_string(), "#id=", step_id_,
   1816                                 "#"),
   1817                 item.kernel->IsExpensive());
   1818             device->Compute(op_kernel, &ctx);
   1819           }
   1820         } else {
   1821           // In the common case, avoid creating any tracing objects.
   1822           if (op_kernel->IsExpensive()) {
   1823             KernelTimer timer;
   1824             device->Compute(op_kernel, &ctx);
   1825             op_kernel->UpdateCostEstimate(timer.ElapsedCycles());
   1826           } else {
   1827             device->Compute(op_kernel, &ctx);
   1828           }
   1829         }
   1830 
   1831         nodestats::SetOpEnd(stats);
   1832         s = ProcessOutputs(item, &ctx, &outputs, stats);
   1833         if (s.ok() && impl_->device_record_tensor_accesses_) {
   1834           // Get the list of all tensors accessed during the execution
   1835           ctx.retrieve_accessed_tensors(&accessed_tensors);
   1836           device_context = ctx.op_device_context();
   1837         }
   1838         nodestats::SetMemory(stats, &ctx);
   1839       }
   1840     }
   1841 
   1842     if (!launched_asynchronously) {
   1843       if (vlog_) {
   1844         VLOG(2) << "Synchronous kernel done: " << id << " step "
   1845                 << params.step_id << " " << SummarizeNode(*node)
   1846                 << (tagged_node.is_dead ? " is dead: " : "")
   1847                 << " device: " << device->name();
   1848       }
   1849 
   1850       // Clears inputs.
   1851       const int num_inputs = item.num_inputs;
   1852       for (int i = 0; i < num_inputs; ++i) {
   1853         (first_input + i)->ClearVal();
   1854       }
   1855       MaybeMarkCompleted(input_frame, input_iter, id);
   1856       // Propagates outputs.
   1857       if (s.ok()) {
   1858         PropagateOutputs(tagged_node, &item, &outputs, &ready);
   1859       }
   1860       outputs.clear();
   1861       if (!accessed_tensors.empty()) {
   1862         nodestats::SetReferencedTensors(stats, accessed_tensors);
   1863         // device_context is set above in synchronous computes
   1864         device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
   1865       }
   1866       if (stats) {
   1867         scheduled_nsec = nodestats::NowInNsec();
   1868       }
   1869       // Postprocess.
   1870       completed = NodeDone(s, item.node, ready, stats, &inline_ready);
   1871     }
   1872   }  // while !inline_ready.empty()
   1873 
   1874   // This thread of computation is done if completed = true.
   1875   if (completed) ScheduleFinish();
   1876 }
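         // A bird's-eye view of the per-node pipeline that Process() drives:
         //
         //   PrepareInputs(item, ...)         // materialize/dereference input Entries
         //   device->Compute / ComputeAsync   // run the kernel, sync or async
         //   ProcessOutputs(item, &ctx, ...)  // validate dtypes, move outputs to Entries
         //   PropagateOutputs(...)            // hand outputs to successors, fill 'ready'
         //   NodeDone(s, node, ready, ...)    // update counters, schedule 'ready'
         //   ScheduleFinish()                 // only once the last outstanding op is done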
   1877 
   1878 Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
   1879                                     TensorValueVec* inputs,
   1880                                     DeviceContextVec* input_device_contexts,
   1881                                     AllocatorAttributeVec* input_alloc_attrs,
   1882                                     bool* is_input_dead) {
   1883   const Node* node = item.node;
   1884 
   1885   inputs->clear();
   1886   inputs->resize(item.num_inputs);
   1887   input_device_contexts->clear();
   1888   input_device_contexts->resize(item.num_inputs);
   1889   input_alloc_attrs->clear();
   1890   input_alloc_attrs->resize(item.num_inputs);
   1891 
   1892   *is_input_dead = false;
   1893 
   1894   bool is_merge = item.is_merge;
   1895   for (int i = 0; i < item.num_inputs; ++i) {
   1896     const bool expect_ref = IsRefType(item.input_type(i));
   1897     Entry* entry = first_input + i;
   1898     (*input_device_contexts)[i] = entry->device_context;
   1899     (*input_alloc_attrs)[i] = entry->alloc_attr;
   1900 
   1901     // i-th input.
   1902     TensorValue* inp = &(*inputs)[i];
   1903 
   1904     // Only merge and transfer nodes can have no-value inputs.
   1905     if (!entry->has_value) {
   1906       if (!is_merge) {
   1907         DCHECK(IsTransferNode(node)) << node->name() << " - input " << i;
   1908         DCHECK(!entry->val_field_is_set) << node->name() << " - input " << i;
   1909         entry->has_value = true;
   1910         entry->val_field_is_set = true;
   1911         entry->val.Init(*kEmptyTensor);
   1912         inp->tensor = entry->val.get();
   1913         *is_input_dead = true;
   1914       }
   1915       continue;
   1916     }
   1917     if (entry->ref == nullptr) {
   1918       if (expect_ref) {
   1919         return AttachDef(
   1920             errors::InvalidArgument(i, "-th input expects a ref type"),
   1921             item.kernel->def());
   1922       }
   1923       inp->tensor = entry->val.get();
   1924     } else {
   1925       {
   1926         tf_shared_lock ml(*entry->ref_mu);
   1927         if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
   1928           return AttachDef(errors::FailedPrecondition(
   1929                                "Attempting to use uninitialized value ",
   1930                                item.kernel->requested_input(i)),
   1931                            item.kernel->def());
   1932         }
   1933       }
   1934       if (expect_ref) {
   1935         inp->mutex_if_ref = entry->ref_mu;
   1936         inp->tensor = entry->ref;
   1937       } else {
   1938         // Automatically deref the tensor ref when the op expects a
   1939         // tensor but is given a ref to a tensor.  Need to deref it
   1940         // under the mutex.
   1941         {
   1942           tf_shared_lock l(*(entry->ref_mu));
   1943           DCHECK(!entry->val_field_is_set);
   1944           entry->val.Init(*entry->ref);
   1945           entry->val_field_is_set = true;
   1946         }
   1947         entry->ref = nullptr;
   1948         entry->ref_mu = nullptr;
   1949 
   1950         inp->tensor = entry->val.get();
   1951         // The dtype of entry->ref could have been changed by another operation
   1952         // that ran after the operation that "produced" it executed, so
   1953         // re-validate that the type of the dereferenced tensor matches the
   1954         // expected input type.
   1955         if (item.input_type(i) != inp->tensor->dtype()) {
   1956           return AttachDef(
   1957               errors::InvalidArgument(
   1958                   i, "-th input expects type ",
   1959                   DataTypeString(item.input_type(i)),
   1960                   " but automatically dereferenced input tensor has type ",
   1961                   DataTypeString(inp->tensor->dtype())),
   1962               item.kernel->def());
   1963         }
   1964       }
   1965     }
   1966   }
   1967   return Status::OK();
   1968 }
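         // PrepareInputs handles four cases for each input slot:
         //   1. No value present: allowed only for Merge and transfer nodes; a
         //      transfer node gets an empty tensor and *is_input_dead is set.
         //   2. A value is present and a value is expected: pass entry->val as is.
         //   3. A ref is present and a ref is expected: pass the ref and its mutex.
         //   4. A ref is present but a value is expected: dereference under the
         //      mutex into entry->val, then re-validate the dtype, since the ref's
         //      dtype may have been changed by another op after it was produced.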
   1969 
   1970 Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
   1971                                      EntryVector* outputs,
   1972                                      NodeExecStatsInterface* stats) {
   1973   const Node* node = item.node;
   1974   DCHECK_EQ(0, outputs->size());
   1975   outputs->resize(item.num_outputs);
   1976 
   1977   Status s = ctx->status();
   1978   if (!s.ok()) {
   1979     s = AttachDef(s, item.kernel->def());
   1980     // TODO(misard) Replace with a finer-grain enabling flag once we
   1981     // add better optional debugging support.
   1982     if (vlog_ && VLOG_IS_ON(1)) {
   1983       LOG(WARNING) << this << " Compute status: " << s;
   1984       DumpState();
   1985     }
   1986     if (s.code() == error::RESOURCE_EXHAUSTED) {
   1987       if (stats_collector_) {
   1988         string err = stats_collector_->ReportAllocsOnResourceExhausted(
   1989             s.error_message());
   1990         s = Status(s.code(), strings::StrCat(s.error_message(), err));
   1991       } else {
   1992         s = Status(
   1993             s.code(),
   1994             strings::StrCat(
   1995                 s.error_message(),
   1996                 "\nHint: If you want to see a list of allocated tensors when "
   1997                 "OOM happens, add report_tensor_allocations_upon_oom "
   1998                 "to RunOptions for current allocation info.\n"));
   1999       }
   2000     }
   2001     return s;
   2002   }
   2003 
   2004   // Get the device_context for this node id, if it exists.
   2005   DeviceContext* device_context = nullptr;
   2006   if (node->id() < device_context_map_.size()) {
   2007     device_context = device_context_map_[node->id()];
   2008   }
   2009 
   2010   for (int i = 0; i < item.num_outputs; ++i) {
   2011     const TensorValue val = ctx->release_output(i);
   2012     if (val.tensor == nullptr) {
   2013       // Unless it's a Switch or a Recv, the node must produce a
    2014       // tensor value at the i-th output.
   2015       if (!IsSwitch(node) && !IsRecv(node)) {
   2016         s.Update(errors::Internal("Missing ", i, "-th output from ",
   2017                                   FormatNodeForError(*node)));
   2018       }
   2019     } else {
   2020       Entry* out = &((*outputs)[i]);
   2021 
   2022       // Set the device context of the output entry.
   2023       out->device_context = device_context;
   2024 
   2025       // Set the allocator attributes of the output entry.
   2026       out->alloc_attr = ctx->output_alloc_attr(i);
   2027 
   2028       // Sanity check of output tensor types.
   2029       DataType dtype;
   2030       if (val.is_ref()) {
   2031         tf_shared_lock ml(*val.mutex_if_ref);
   2032         dtype = MakeRefType(val->dtype());
   2033       } else {
   2034         dtype = val->dtype();
   2035       }
   2036       if (dtype == item.output_type(i)) {
   2037         if (stats && val.tensor->IsInitialized()) {
   2038           nodestats::SetOutput(stats, i, val.tensor);
   2039         }
   2040         if (val.is_ref()) {
   2041           out->has_value = true;
   2042           out->ref = val.tensor;
   2043           out->ref_mu = val.mutex_if_ref;
   2044           if (log_memory_) {
   2045             Tensor to_log;
   2046             {
   2047               // Dereference the tensor under the lock.
   2048               tf_shared_lock l(*out->ref_mu);
   2049               to_log = *out->ref;
   2050             }
   2051             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
   2052                                           ctx->step_id(), i, to_log);
   2053           }
   2054         } else {
    2055           // NOTE that std::move is used here, so val.tensor is left in an
    2056           // uninitialized state (val.tensor->IsInitialized() returns false).
   2057           DCHECK(!out->val_field_is_set);
   2058           out->has_value = true;
   2059           out->val_field_is_set = true;
   2060           out->val.Init(std::move(*val.tensor));
   2061           if (log_memory_) {
   2062             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
   2063                                           ctx->step_id(), i, *out->val);
   2064           }
   2065         }
   2066       } else {
   2067         s.Update(errors::Internal("Output ", i, " of type ",
   2068                                   DataTypeString(dtype),
   2069                                   " does not match declared output type ",
   2070                                   DataTypeString(item.output_type(i)),
   2071                                   " for node ", FormatNodeForError(*node)));
   2072       }
   2073     }
   2074     if (!val.is_ref()) {
    2075       // If OpKernelContext returned outputs via pass-by-value, we
    2076       // wouldn't need to delete the tensor here.
   2077       delete val.tensor;
   2078     }
   2079   }
   2080   return s;
   2081 }
   2082 
   2083 void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
   2084                                      const NodeItem* item, EntryVector* outputs,
   2085                                      TaggedNodeSeq* ready) {
   2086   auto activity_handle =
   2087       [&]() -> std::unique_ptr<tracing::TraceCollector::Handle> {
   2088     auto* trace_collector = tracing::GetTraceCollector();
   2089     if (TF_PREDICT_FALSE(trace_collector != nullptr &&
   2090                          trace_collector->IsEnabledForActivities(
   2091                              false /* is_expensive */))) {
   2092       const string& op_name = item->kernel->name();
   2093       // Intentionally using ExecutorPropagateOutputs as the first key so that
   2094       // users are aware that it's not the op invocation.
   2095       return trace_collector->CreateActivityHandle(
   2096           "ExecutorPropagateOutputs",
   2097           strings::StrCat(op_name, "#id=", step_id_, "#"),
   2098           false /* is_expensive */);
   2099     } else {
   2100       return nullptr;
   2101     }
   2102   }();
   2103 
   2104   const Node* node = tagged_node.node;
   2105   FrameState* input_frame = tagged_node.input_frame;
   2106   const int64 input_iter = tagged_node.input_iter;
   2107   const bool is_dead = tagged_node.is_dead;
   2108 
   2109   // Propagates outputs along out edges, and puts newly ready nodes
   2110   // into the ready queue.
   2111   ready->clear();
   2112   bool is_frame_done = false;
   2113   FrameState* output_frame = input_frame;
   2114   int64 output_iter = input_iter;
   2115 
   2116   if (!item->is_enter_exit_or_next_iter) {
    2117     // Fast path for node types that don't need special handling.
   2118     DCHECK_EQ(input_frame, output_frame);
   2119     // Normal path for most nodes
   2120     mutex_lock l(input_frame->mu);
   2121     output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
   2122     is_frame_done = input_frame->DecrementOutstandingOpsLocked(
   2123         &impl_->gview_, input_iter, ready);
   2124   } else if (item->is_enter) {
   2125     FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
   2126     output_iter = 0;
   2127     {
   2128       const NodeItem* item = impl_->gview_.node(node->id());
   2129       mutex_lock l(output_frame->mu);
   2130       if (item->is_constant_enter) {
   2131         // Propagate to all active iterations if this is a loop invariant.
   2132         output_frame->AddLoopInv(item, (*outputs)[0], ready);
   2133       } else {
   2134         output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
   2135       }
   2136       output_frame->num_pending_inputs--;
   2137     }
   2138     is_frame_done =
   2139         input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready);
   2140   } else if (item->is_exit) {
   2141     if (is_dead) {
   2142       mutex_lock l(input_frame->mu);
   2143       // Stop and remember this node if it is a dead exit.
   2144       if (input_iter == input_frame->iteration_count) {
   2145         input_frame->dead_exits.push_back(node);
   2146       }
   2147       is_frame_done = input_frame->DecrementOutstandingOpsLocked(
   2148           &impl_->gview_, input_iter, ready);
   2149     } else {
   2150       output_frame = input_frame->parent_frame;
   2151       output_iter = input_frame->parent_iter;
   2152       {
   2153         mutex_lock l(output_frame->mu);
   2154         output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
   2155       }
   2156       is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_,
   2157                                                            input_iter, ready);
   2158     }
   2159   } else {
   2160     DCHECK(IsNextIteration(node));
   2161     mutex_lock l(input_frame->mu);
   2162     if (is_dead) {
   2163       // Stop the deadness propagation.
   2164       output_frame = nullptr;
   2165     } else {
   2166       if (input_iter == input_frame->iteration_count &&
   2167           input_frame->num_outstanding_iterations ==
   2168               input_frame->max_parallel_iterations) {
   2169         // Reached the maximum for parallel iterations.
   2170         input_frame->next_iter_roots.push_back({node, (*outputs)[0]});
   2171         output_frame = nullptr;
   2172       } else {
   2173         // If this is a new iteration, start it.
   2174         if (input_iter == input_frame->iteration_count) {
   2175           input_frame->IncrementIteration(&impl_->gview_, ready);
   2176         }
   2177         output_iter = input_iter + 1;
   2178       }
   2179     }
   2180     if (output_frame != nullptr) {
    2181       // NextIteration propagates within the same frame, to the next iteration.
   2182       DCHECK(input_frame == output_frame);
   2183       output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
   2184     }
   2185     is_frame_done = input_frame->DecrementOutstandingOpsLocked(
   2186         &impl_->gview_, input_iter, ready);
   2187   }
   2188 
   2189   // At this point, this node is completely done. We also know if the
   2190   // completion of this node makes its frame completed.
   2191   if (is_frame_done) {
   2192     FrameState* parent_frame = input_frame->parent_frame;
   2193     const int64 parent_iter = input_frame->parent_iter;
   2194     DeleteFrame(input_frame, ready);
   2195     if (parent_frame != nullptr) {
    2196       // The completion of this frame may cause completions in its parent
    2197       // frame, so clean things up recursively.
   2198       CleanupFramesIterations(parent_frame, parent_iter, ready);
   2199     }
   2200   }
   2201 }
   2202 
   2203 bool ExecutorState::NodeDone(const Status& s, const Node* node,
   2204                              const TaggedNodeSeq& ready,
   2205                              NodeExecStatsInterface* stats,
   2206                              TaggedNodeReadyQueue* inline_ready) {
   2207   nodestats::SetAllEnd(stats);
   2208   if (stats) {
   2209     if (stats_collector_) {
   2210       stats->Done(impl_->params_.device->name());
   2211     } else {
   2212       delete stats;
   2213     }
   2214   }
   2215 
   2216   bool abort_run = false;
   2217   if (!s.ok()) {
   2218     // Some error happened. This thread of computation is done.
   2219     mutex_lock l(mu_);
   2220     if (status_.ok()) {
   2221       abort_run = true;
   2222       status_ = s;
   2223     }
   2224   }
   2225   if (abort_run) {
   2226     TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
   2227     if (rendezvous_) {
   2228       rendezvous_->StartAbort(s);
   2229     }
   2230     if (collective_executor_) {
   2231       collective_executor_->StartAbort(s);
   2232     }
   2233     if (cancellation_manager_) {
   2234       cancellation_manager_->StartCancel();
   2235     }
   2236   }
   2237 
   2238   bool completed = false;
   2239   const size_t ready_size = ready.size();
   2240   if (ready_size == 0 || !s.ok()) {
   2241     completed = (num_outstanding_ops_.fetch_sub(1) == 1);
   2242   } else if (ready_size > 1) {
   2243     num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
   2244   }
   2245 
   2246   // Schedule the ready nodes in 'ready'.
   2247   if (s.ok()) {
   2248     ScheduleReady(ready, inline_ready);
   2249   }
   2250   return completed;
   2251 }
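         // A worked example of the accounting above (hypothetical numbers): with
         // num_outstanding_ops_ == 5, a node that finishes with 3 newly ready
         // successors hands its own slot to one of them and adds
         // ready_size - 1 == 2, leaving the counter at 7. If instead 'ready' is
         // empty or the node failed, fetch_sub(1) runs; a return value of 1 means
         // the counter just hit zero, so this was the last outstanding op and the
         // caller proceeds to ScheduleFinish().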
   2252 
   2253 void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
   2254                                   TaggedNodeReadyQueue* inline_ready) {
   2255   if (ready.empty()) return;
   2256 
   2257   int64 scheduled_nsec = 0;
   2258   if (stats_collector_) {
   2259     scheduled_nsec = nodestats::NowInNsec();
   2260   }
   2261 
   2262   if (inline_ready == nullptr) {
    2263     // Schedule all the ready ops to run in the thread pool.
   2264     for (auto& tagged_node : ready) {
   2265       runner_([=]() { Process(tagged_node, scheduled_nsec); });
   2266     }
   2267     return;
   2268   }
   2269 
   2270   const GraphView& gview = impl_->gview_;
   2271   const TaggedNode* curr_expensive_node = nullptr;
   2272   for (auto& tagged_node : ready) {
   2273     const NodeItem& item = *gview.node(tagged_node.node->id());
   2274     if (tagged_node.is_dead || !item.kernel->IsExpensive()) {
   2275       // Inline this inexpensive node.
   2276       inline_ready->push_back(tagged_node);
   2277     } else {
   2278       if (curr_expensive_node) {
   2279         // Dispatch to another thread since there is plenty of work to
   2280         // do for this thread.
   2281         runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
   2282                           scheduled_nsec));
   2283       }
   2284       curr_expensive_node = &tagged_node;
   2285     }
   2286   }
   2287   if (curr_expensive_node) {
   2288     if (inline_ready->empty()) {
    2289       // Tail recursion optimization: run this expensive node on the
    2290       // current thread instead of dispatching it.
   2290       inline_ready->push_back(*curr_expensive_node);
   2291     } else {
   2292       // There are inline nodes to run already. We dispatch this expensive
   2293       // node to other thread.
   2294       runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
   2295                         scheduled_nsec));
   2296     }
   2297   }
   2298 }
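         // The scheduling policy above, illustrated (hypothetical ready set): given
         // ready = {cheap_a, expensive_b, cheap_c, expensive_d} and a non-null
         // inline_ready,
         //   cheap_a, cheap_c -> pushed onto inline_ready and run on this thread;
         //   expensive_b      -> dispatched to the threadpool via runner_ once a
         //                       second expensive node is encountered;
         //   expensive_d      -> the last expensive node; also dispatched, unless
         //                       inline_ready ended up empty, in which case it is
         //                       inlined as a tail call to keep this thread busy.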
   2299 
   2300 inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter,
   2301                                               int64 node_id) {
   2302   // TODO(misard) Replace with a finer-grain enabling flag once we
   2303   // add better optional debugging support.
   2304   if (vlog_ && VLOG_IS_ON(1)) {
   2305     const NodeItem* item = impl_->gview_.node(node_id);
   2306     mutex_lock l(frame->mu);
   2307     frame->GetIteration(iter)->mark_completed(item->pending_id);
   2308   }
   2309 }
   2310 
   2311 const Tensor* ExecutorState::GetTensorValueForDump(const Entry& input) {
   2312   if (!input.has_value) {
   2313     return kEmptyTensor;
   2314   } else if (input.ref == nullptr) {
   2315     return input.val.get();
   2316   } else {
   2317     return input.ref;
   2318   }
   2319 }
   2320 
   2321 void ExecutorState::DumpPendingNodeState(
   2322     const int node_id, const Entry* input_vector,
   2323     const bool show_nodes_with_no_ready_inputs) {
   2324   const NodeItem& node_item = *impl_->gview_.node(node_id);
   2325   const Node& node = *node_item.node;
   2326   const int input_base = node_item.input_start;
   2327   if (!show_nodes_with_no_ready_inputs) {
   2328     bool has_ready_input = false;
   2329     for (int i = 0; i < node.num_inputs(); ++i) {
   2330       const Entry& input = input_vector[input_base + i];
   2331       const Tensor* tensor = GetTensorValueForDump(input);
   2332       if (tensor->IsInitialized()) {
   2333         has_ready_input = true;
   2334         break;
   2335       }
   2336     }
   2337     if (!has_ready_input) {
   2338       return;
   2339     }
   2340   }
   2341   LOG(WARNING) << "    Pending Node: " << node.DebugString();
   2342   for (int i = 0; i < node.num_inputs(); ++i) {
   2343     const Entry& input = input_vector[input_base + i];
   2344     const Tensor* tensor = GetTensorValueForDump(input);
   2345     if (tensor->IsInitialized()) {
   2346       LOG(WARNING) << "      Input " << i << ": "
   2347                    << strings::StrCat(
   2348                           "Tensor<type: ", DataTypeString(tensor->dtype()),
   2349                           " shape: ", tensor->shape().DebugString(), ">");
   2350     } else {
   2351       LOG(WARNING) << "      Input " << i << ": not present";
   2352     }
   2353   }
   2354 }
   2355 
   2356 void ExecutorState::DumpActiveNodeState(const int node_id,
   2357                                         const Entry* input_vector) {
   2358   const NodeItem& node_item = *impl_->gview_.node(node_id);
   2359   const Node& node = *node_item.node;
   2360   LOG(WARNING) << "    Active Node: " << node.DebugString();
   2361   const int input_base = node_item.input_start;
   2362   for (int i = 0; i < node.num_inputs(); ++i) {
   2363     const Entry& input = input_vector[input_base + i];
   2364     const Tensor* tensor = GetTensorValueForDump(input);
   2365     if (tensor->IsInitialized()) {
   2366       LOG(WARNING) << "      Input " << i << ": "
   2367                    << strings::StrCat(
   2368                           "Tensor<type: ", DataTypeString(tensor->dtype()),
   2369                           " shape: ", tensor->shape().DebugString(), ">");
   2370     } else {
   2371       LOG(WARNING) << "      Input " << i << ": not present";
   2372     }
   2373   }
   2374 }
   2375 
   2376 void ExecutorState::DumpIterationState(const FrameState* frame,
   2377                                        IterationState* iteration) {
   2378   const std::vector<const Node*>* nodes = frame->nodes;
   2379   // Dump any waiting nodes that are holding on to tensors.
   2380   for (const Node* node : *nodes) {
   2381     const int node_id = node->id();
   2382     PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
   2383     if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
   2384         iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
   2385       DumpPendingNodeState(node_id, iteration->input_tensors, false);
   2386     }
   2387   }
   2388   // Then the active nodes.
   2389   for (const Node* node : *nodes) {
   2390     const int node_id = node->id();
   2391     PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
   2392     if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
   2393       DumpActiveNodeState(node_id, iteration->input_tensors);
   2394     }
   2395   }
   2396   // Show all input tensors in use.
   2397   const int total_input_tensors = frame->total_input_tensors;
   2398   size_t total_bytes = 0;
   2399   for (int i = 0; i < total_input_tensors; ++i) {
   2400     const Entry& input = iteration->input_tensors[i];
   2401     const Tensor* tensor = GetTensorValueForDump(input);
   2402     if (tensor->IsInitialized()) {
   2403       LOG(WARNING) << "    Input " << i << ": "
   2404                    << strings::StrCat(
   2405                           "Tensor<type: ", DataTypeString(tensor->dtype()),
   2406                           " shape: ", tensor->shape().DebugString(),
   2407                           ", bytes: ", tensor->TotalBytes(), ">");
   2408       total_bytes += tensor->TotalBytes();
   2409     }
   2410   }
   2411   LOG(WARNING) << "    Total bytes " << total_bytes;
   2412 }
   2413 
   2414 void ExecutorState::DumpState() {
   2415   mutex_lock l(mu_);
   2416   if (!dumped_on_error_) {
   2417     LOG(WARNING) << "Dumping state";
   2418     for (auto& frame : outstanding_frames_) {
   2419       LOG(WARNING) << frame.first;
   2420       FrameState* frame_state = frame.second;
   2421       mutex_lock frame_lock(frame_state->mu);
   2422       for (IterationState* iteration : frame_state->iterations) {
   2423         LOG(WARNING) << "  Iteration:";
   2424         DumpIterationState(frame_state, iteration);
   2425       }
   2426     }
   2427     dumped_on_error_ = true;
   2428   }
   2429 }
   2430 
   2431 void ExecutorState::ScheduleFinish() {
   2432   int num_deferred_ops;
   2433   {
   2434     mutex_lock lock(num_deferred_ops_mu_);
   2435     num_deferred_ops = num_deferred_ops_;
   2436   }
   2437   if (num_deferred_ops > 0) {
   2438     // Finish() may be blocked waiting for deferred async ops to complete. The
   2439     // execution of deferred async ops may be waiting for non-enqueued ops of
   2440     // other executors to complete. So running Finish() on the current thread
   2441     // (inter-op threadpool thread) may lead to a deadlock due to threadpool
   2442     // exhaustion. Instead, we run it on a separate thread to unblock the
   2443     // threadpool thread.
   2444     Env::Default()->SchedClosure([this]() { Finish(); });
   2445   } else {
   2446     Finish();
   2447   }
   2448 }
   2449 
   2450 void ExecutorState::Finish() {
   2451   mu_.lock();
   2452   auto status = status_;
   2453   auto done_cb = std::move(done_cb_);
   2454   auto runner = std::move(runner_);
   2455   mu_.unlock();
   2456   CHECK(done_cb != nullptr);
   2457   Device* device = impl_->params_.device;
   2458 
   2459   // There are several potential race conditions below. To name a few:
   2460   // 1. Even if the device's status is OK at the precise moment when
   2461   // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus()
   2462   // is called below, caused by work enqueued onto the same device by other
   2463   // concurrent ExecutorState objects.
   2464   // 2. Some implementations of Device::RefreshStatus, such as
   2465   // XlaDevice::RefreshStatus, may be inherently racy because it releases the
   2466   // device mutex after a stream pointer is acquired and before the stream is
   2467   // queried for status.
   2468   // 3. It's the same for some implementations of Device::Sync, such as
   2469   // XlaDevice::Sync.
   2470   //
   2471   // However, these race conditions are acceptable because a stream (and
   2472   // therefore an XlaDevice) can only go from OK to not-OK, never the opposite,
   2473   // which means we will at worst report errors when there isn't any, never the
   2474   // opposite.
   2475 
   2476   // If inc_num_deferred_ops_function has ever been called, ExecutorState must
   2477   // wait for all corresponding dec_num_deferred_ops_function calls to happen
   2478   // regardless of status. This ensures that dec_num_deferred_ops_function can
   2479   // safely use ExecutorState's resources.
   2480   {
   2481     mutex_lock lock(num_deferred_ops_mu_);
   2482     while (num_deferred_ops_ > 0) {
   2483       no_deferred_ops_cv_.wait(lock);
   2484     }
   2485   }
   2486 
    2487   // An early exit for devices that don't allow sync on completion. Ops that
    2488   // run on these devices should have used num_deferred_ops correctly to
    2489   // ensure the device has finished all relevant work at this point.
   2490   if (!device->AllowsSyncOnCompletion()) {
   2491     status.Update(device->RefreshStatus());
   2492     if (!status.ok()) {
   2493       // In device async execution mode, it's possible for device execution to
   2494       // lag behind ExecutorState scheduling so much that this is the first
   2495       // place a device execution error surfaces.
   2496       // If so, all ExecutorState::NodeDone calls have already happened with OK
   2497       // status. This is the last defense where StartCancel must be called to
   2498       // abort all computation still running on any device.
   2499       // TODO(b/124523000): Always call Finish in a separate thread, so even if
   2500       // StartCancel blocks the current thread's execution, we won't encounter
   2501       // deadlocks caused by inter-op thread exhaustion.
   2502       if (cancellation_manager_) {
   2503         cancellation_manager_->StartCancel();
   2504       }
   2505     }
   2506     delete this;
   2507     runner([=]() { done_cb(status); });
   2508     return;
   2509   }
   2510 
   2511   if (sync_on_finish_ && status.ok()) {
   2512     // Block until the device has finished all queued operations. For
   2513     // devices like GPUs that continue to execute Ops after their Compute
   2514     // methods have completed, this ensures that control is not returned to
   2515     // the user until the step (and its side-effects) has actually completed.
   2516     device->Sync([=](Status new_status) mutable {
   2517       status.Update(new_status);
   2518       delete this;
   2519       runner([=]() { done_cb(status); });
   2520     });
   2521   } else {
   2522     delete this;
   2523     runner([=]() { done_cb(status); });
   2524   }
   2525 }
   2526 
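// Added note (editorial, based on the code below): looks up the child frame
// for (frame, iter, node) by name, creating it if it does not exist yet.
// Creation uses a check / create-unlocked / re-check pattern: if another
// thread installs a frame with the same name first, the speculatively created
// FrameState is discarded.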
   2527 void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
   2528                                            const Node* node,
   2529                                            FrameState** child) {
   2530   // Get the child frame name.
   2531   string enter_name;
   2532   Status s = GetNodeAttr(node->attrs(), "frame_name", &enter_name);
   2533   DCHECK(s.ok()) << s;
   2534   const string child_name = MakeFrameName(frame, iter, enter_name);
   2535 
   2536   {
   2537     mutex_lock executor_lock(mu_);
   2538     auto it = outstanding_frames_.find(child_name);
   2539     if (it != outstanding_frames_.end()) {
   2540       *child = it->second;
   2541       return;
   2542     }
   2543   }
   2544 
   2545   // Need to create a new frame instance.
   2546   // Note that this new frame instance is created without holding any locks.
   2547   if (vlog_) VLOG(2) << "Create frame: " << child_name;
   2548 
   2549   int parallel_iters;
   2550   s = GetNodeAttr(node->attrs(), "parallel_iterations", &parallel_iters);
   2551   DCHECK(s.ok()) << s;
   2552   FrameState* temp = new FrameState(impl_, parallel_iters);
   2553   temp->frame_name = child_name;
   2554   temp->frame_id = Hash64(child_name);
   2555   temp->parent_frame = frame;
   2556   temp->parent_iter = iter;
   2557   temp->InitializeFrameInfo(enter_name);
   2558 
   2559   // 'iterations' is a fixed-length circular buffer.
   2560   temp->iterations.resize(temp->max_parallel_iterations + 1);
   2561   // Initialize iteration 0.
   2562   temp->iterations[0] =
   2563       new IterationState(temp->pending_counts, temp->total_input_tensors);
   2564 
   2565   {
   2566     mutex_lock executor_lock(mu_);
   2567     auto it = outstanding_frames_.find(child_name);
   2568     if (it != outstanding_frames_.end()) {
   2569       *child = it->second;
   2570     } else {
   2571       mutex_lock frame_lock(frame->mu);
   2572       frame->GetIteration(iter)->outstanding_frame_count++;
   2573       outstanding_frames_[child_name] = temp;
   2574       *child = temp;
   2575       temp = nullptr;
   2576     }
   2577   }
   2578   delete temp;  // Only deletes temp if it lost the race (nullptr otherwise).
   2579 }
   2580 
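// Added note (editorial, based on the code below): deletes `frame`. Before
// doing so, any dead Exit values recorded in frame->dead_exits are propagated
// to the parent frame's pending counts, possibly making nodes in the parent
// iteration ready.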
   2581 void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
   2582   // First, propagate dead_exits (if any) to the parent frame.
   2583   FrameState* parent_frame = frame->parent_frame;
   2584   const int64 parent_iter = frame->parent_iter;
   2585   if (parent_frame != nullptr) {
   2586     mutex_lock parent_frame_lock(parent_frame->mu);
   2587     // Propagate all the dead exits to the parent frame.
   2588     mutex_lock this_frame_lock(frame->mu);
   2589     for (const Node* node : frame->dead_exits) {
   2590       auto parent_iter_state = parent_frame->GetIteration(parent_iter);
   2591       for (const Edge* e : node->out_edges()) {
   2592         const Node* dst_node = e->dst();
   2593 
   2594         const auto dst_pending_id =
   2595             impl_->gview_.node(dst_node->id())->pending_id;
   2596 
   2597         // TODO(yuanbyu): We don't need this if we require the subgraph
   2598         // given to an executor not to contain a sink node.
   2599         if (dst_node->IsSink()) continue;
   2600 
   2601         bool dst_dead = true;
   2602         bool dst_ready = false;
   2603         // We know this is a dead input to dst.
   2604         if (IsMerge(dst_node)) {
   2605           if (e->IsControlEdge()) {
   2606             parent_iter_state->decrement_pending(dst_pending_id, 2);
   2607             int count = parent_iter_state->pending(dst_pending_id);
   2608             int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
   2609             dst_dead = (dead_cnt == dst_node->num_inputs());
   2610             dst_ready = (count == 0) || ((count == 1) && dst_dead);
   2611           } else {
   2612             parent_iter_state->increment_dead_count(dst_pending_id);
   2613             const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
   2614             dst_dead = (dead_cnt == dst_node->num_inputs());
   2615             dst_ready =
   2616                 (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
   2617           }
   2618         } else {
   2619           parent_iter_state->increment_dead_count(dst_pending_id);
   2620           dst_ready =
   2621               (parent_iter_state->decrement_pending(dst_pending_id, 1) == 0);
   2622         }
   2623         if (dst_ready) {
   2624           if (IsControlTrigger(dst_node)) dst_dead = false;
   2625           ready->emplace_back(dst_node, parent_frame, parent_iter, dst_dead);
   2626           parent_iter_state->outstanding_ops++;
   2627         }
   2628       }
   2629     }
   2630   }
   2631 
   2632   // Delete the frame.
   2633   const string& frame_name = frame->frame_name;
   2634   if (vlog_) VLOG(2) << "Delete frame " << frame_name;
   2635   {
   2636     mutex_lock executor_lock(mu_);
   2637     outstanding_frames_.erase(frame_name);
   2638   }
   2639   delete frame;
   2640 }
   2641 
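// Added note (editorial, based on the code below): decrements the outstanding
// child-frame count of (frame, iter) and cleans up completed iterations. If
// the frame itself becomes done, it is deleted and the cleanup is applied
// recursively to its parent frame.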
   2642 void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
   2643                                             TaggedNodeSeq* ready) {
   2644   bool is_frame_done = false;
   2645   {
   2646     mutex_lock frame_lock(frame->mu);
   2647     frame->GetIteration(iter)->outstanding_frame_count--;
   2648     is_frame_done = frame->CleanupIterations(&impl_->gview_, iter, ready);
   2649   }
   2650   if (is_frame_done) {
   2651     FrameState* parent_frame = frame->parent_frame;
   2652     const int64 parent_iter = frame->parent_iter;
   2653     DeleteFrame(frame, ready);
   2654     if (parent_frame != nullptr) {
   2655       // The completion of this frame may cause completions in its parent frame.
   2656       // So clean things up recursively.
   2657       CleanupFramesIterations(parent_frame, parent_iter, ready);
   2658     }
   2659   }
   2660 }
   2661 
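// Added note (editorial, based on the code below): propagates the outputs of
// a completed (or dead) node along its output edges, updating the destination
// nodes' pending/dead counts and appending any node that becomes ready to
// `ready`.
//
// Rough illustration of the Merge counting scheme used below (the counts are
// set up during executor initialization, outside this function): a Merge's
// pending count is expected to start at 1 + 2 * num_control_edges, e.g. 5
// (binary 101) for two control inputs. Each arriving control edge subtracts
// 2, and the first live data input clears the low bit via mark_live(), so the
// node fires when the count reaches 0, or reaches 1 with every data input
// dead.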
   2662 void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
   2663                                               const bool is_dead, int64 iter,
   2664                                               EntryVector* outputs,
   2665                                               TaggedNodeSeq* ready) {
   2666   const GraphView& gview = executor->gview_;
   2667   IterationState* iter_state = GetIteration(iter);
   2668   const size_t num_output_edges = item->num_output_edges;
   2669   const EdgeInfo* edges = item->output_edge_list();
   2670   Entry* input_tensors = iter_state->input_tensors;
   2671   for (size_t out_index = 0; out_index < num_output_edges; out_index++) {
   2672     const EdgeInfo& e = edges[out_index];
   2673     const int dst_id = e.dst_id;
   2674     const NodeItem* dst_item = gview.node(dst_id);
   2675     const PendingCounts::Handle dst_pending_id = dst_item->pending_id;
   2676     const int src_slot = e.output_slot;
   2677 
   2678     // TODO(yuanbyu): We don't need this if we require the subgraph
   2679     // given to an executor not to contain a sink node.
   2680     if (dst_item->is_sink) continue;
   2681 
   2682     bool dst_dead = false;
   2683     bool dst_ready = false;
   2684     // True iff this input for dst is needed. We only set this input for
   2685     // dst if this flag is true. This is needed to make the thread safety
   2686     // analysis happy.
   2687     const bool is_control_edge = (src_slot == Graph::kControlSlot);
   2688     bool dst_need_input = !is_control_edge;
   2689     if (dst_item->is_merge) {
   2690       // A merge node is ready if all control inputs have arrived and either
   2691       // a) a live data input becomes available or b) all data inputs are
   2692       // dead. For Merge, pending's LSB is set iff no live data input has
   2693       // arrived yet; mark_live() below clears it when the first one does.
   2694       if (is_control_edge) {
   2695         iter_state->decrement_pending(dst_pending_id, 2);
   2696         int count = iter_state->pending(dst_pending_id);
   2697         int dead_cnt = iter_state->dead_count(dst_pending_id);
   2698         dst_dead = (dead_cnt == dst_item->num_inputs);
   2699         dst_ready = (count == 0) || ((count == 1) && dst_dead);
   2700       } else {
   2701         if ((*outputs)[src_slot].has_value) {
   2702           // This is a live data input.
   2703           int count = iter_state->pending(dst_pending_id);
   2704           iter_state->mark_live(dst_pending_id);
   2705           // Only the first live edge sets the input and (potentially)
   2706           // triggers execution. The low bit of count is set if and
   2707           // only if no live input has been used yet (mark_live clears
   2708           // it). The node should be started if and only if this is
   2709           // the first live input and there are no pending control
   2710           // edges, i.e. count == 1.
   2711           dst_ready = (count == 1);
   2712           dst_need_input = ((count & 0x1) == 1);
   2713         } else {
   2714           // This is a dead data input. Note that dst_node is dead if the
   2715           // source node is a dead Enter. We need this to properly handle a
   2716           // while loop on the untaken branch of a conditional.
   2717           // TODO(yuanbyu): This is a bit hacky, but a good solution for
   2718           // now.
   2719           iter_state->increment_dead_count(dst_pending_id);
   2720           const int dead_cnt = iter_state->dead_count(dst_pending_id);
   2721           dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
   2722           dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
   2723           dst_need_input = false;
   2724         }
   2725       }
   2726     } else {
   2727       const bool increment_dead =
   2728           (is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value));
   2729       int pending, dead;
   2730       iter_state->adjust_for_activation(dst_pending_id, increment_dead,
   2731                                         &pending, &dead);
   2732       dst_dead = (dead > 0);
   2733       dst_ready = (pending == 0);
   2734     }
   2735 
   2736     if (dst_need_input) {
   2737       const int dst_slot = e.input_slot;
   2738       const int dst_loc = dst_item->input_start + dst_slot;
   2739       if (e.is_last) {
   2740         input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
   2741       } else {
   2742         input_tensors[dst_loc] = (*outputs)[src_slot];
   2743       }
   2744     }
   2745 
   2746     // Add dst to the ready queue if it's ready
   2747     if (dst_ready) {
   2748       if (dst_item->is_control_trigger) dst_dead = false;
   2749       ready->emplace_back(dst_item->node, this, iter, dst_dead);
   2750       iter_state->outstanding_ops++;
   2751     }
   2752   }
   2753 }
   2754 
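// Added note (editorial, based on the code below): activates, in iteration
// `iter`, the values recorded in next_iter_roots, i.e. outputs of
// NextIteration nodes that completed before this iteration existed and
// therefore had to be deferred.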
   2755 void ExecutorState::FrameState::ActivateNexts(const GraphView* gview,
   2756                                               int64 iter,
   2757                                               TaggedNodeSeq* ready) {
   2758   // Propagate the deferred NextIteration nodes to the new iteration.
   2759   for (auto& node_entry : next_iter_roots) {
   2760     const Node* node = node_entry.first;
   2761     const Entry& entry = node_entry.second;
   2762     const bool is_dead = !entry.has_value;
   2763     const NodeItem* item = gview->node(node->id());
   2764     EntryVector outputs{entry};
   2765     ActivateNodes(item, is_dead, iter, &outputs, ready);
   2766   }
   2767   next_iter_roots.clear();
   2768 }
   2769 
   2770 void ExecutorState::FrameState::ActivateLoopInvs(const GraphView* gview,
   2771                                                  int64 iter,
   2772                                                  TaggedNodeSeq* ready) {
   2773   // Propagate loop invariants to the new iteration.
   2774   for (auto& node_entry : inv_values) {
   2775     const Node* node = node_entry.first;
   2776     const Entry& entry = node_entry.second;
   2777     const bool is_dead = !entry.has_value;
   2778     const NodeItem* item = gview->node(node->id());
   2779     EntryVector outputs{entry};
   2780     ActivateNodes(item, is_dead, iter, &outputs, ready);
   2781   }
   2782 }
   2783 
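// Added note (editorial, based on the code below): records a loop-invariant
// value so that ActivateLoopInvs() can replay it in iterations created later,
// and immediately activates it in every iteration that already exists.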
   2784 void ExecutorState::FrameState::AddLoopInv(const NodeItem* item,
   2785                                            const Entry& entry,
   2786                                            TaggedNodeSeq* ready) {
   2787   // Store this value.
   2788   inv_values.push_back({item->node, entry});
   2789 
   2790   // Make this value available to all iterations.
   2791   const bool is_dead = !entry.has_value;
   2792   for (int i = 0; i <= iteration_count; ++i) {
   2793     EntryVector outputs{entry};
   2794     ActivateNodes(item, is_dead, i, &outputs, ready);
   2795   }
   2796 }
   2797 
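// Added note (editorial, based on the code below): an iteration is done when
// it has no outstanding ops and no outstanding child frames, and either it is
// iteration 0 with no pending frame inputs or its predecessor iteration has
// already been cleaned up.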
   2798 bool ExecutorState::FrameState::IsIterationDone(int64 iter) {
   2799   IterationState* iter_state = GetIteration(iter);
   2800   if (iter_state->outstanding_ops == 0 &&
   2801       iter_state->outstanding_frame_count == 0) {
   2802     if (iter == 0) {
   2803       // All inputs from the enclosing frame have arrived.
   2804       return num_pending_inputs == 0;
   2805     } else {
   2806       // The preceding iteration is deleted (and therefore done).
   2807       return (GetIteration(iter - 1) == nullptr);
   2808     }
   2809   }
   2810   return false;
   2811 }
   2812 
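// Added note (editorial, based on the code below): creates the state for the
// next iteration, then seeds it with the deferred NextIteration values and
// the loop invariants collected so far.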
   2813 void ExecutorState::FrameState::IncrementIteration(const GraphView* gview,
   2814                                                    TaggedNodeSeq* ready) {
   2815   iteration_count++;
   2816   const int64 next_iter = iteration_count;
   2817 
   2818   // Initialize the next iteration.
   2819   IterationState* iter_state =
   2820       new IterationState(pending_counts, total_input_tensors);
   2821   SetIteration(next_iter, iter_state);
   2822   num_outstanding_iterations++;
   2823   dead_exits.clear();
   2824 
   2825   // Activate the successors of the deferred roots in the new iteration.
   2826   ActivateNexts(gview, next_iter, ready);
   2827 
   2828   // Activate the loop invariants in the new iteration.
   2829   ActivateLoopInvs(gview, next_iter, ready);
   2830 }
   2831 
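// Added note (editorial, based on the code below): starting at `iter`,
// deletes consecutive iterations that are done, starting a deferred iteration
// whenever NextIteration values are waiting. Returns true iff the whole frame
// is now done.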
   2832 bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview,
   2833                                                   int64 iter,
   2834                                                   TaggedNodeSeq* ready) {
   2835   int64 curr_iter = iter;
   2836   while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) {
   2837     // Delete the iteration curr_iter.
   2838     delete GetIteration(curr_iter);
   2839     SetIteration(curr_iter, nullptr);
   2840     --num_outstanding_iterations;
   2841     ++curr_iter;
   2842 
   2843     // When one iteration is completed, we check for a deferred iteration
   2844     // and start it if there is one.
   2845     if (!next_iter_roots.empty()) {
   2846       IncrementIteration(gview, ready);
   2847     }
   2848   }
   2849   return IsFrameDone();
   2850 }
   2851 
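// Added note (editorial, based on the code above): the ExecutorState created
// here owns itself; it is deleted in ExecutorState::Finish() once the step
// completes.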
   2852 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
   2853   (new ExecutorState(args, this))->RunAsync(std::move(done));
   2854 }
   2855 
   2856 }  // namespace
   2857 
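// Added note (editorial, based on the code below): creates the default
// (local) executor for `graph`. On success the caller takes ownership of
// *executor; on failure *executor is left untouched.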
   2858 Status NewLocalExecutor(const LocalExecutorParams& params,
   2859                         std::unique_ptr<const Graph> graph,
   2860                         Executor** executor) {
   2861   ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
   2862   const Status s = impl->Initialize();
   2863   if (s.ok()) {
   2864     *executor = impl;
   2865   } else {
   2866     delete impl;
   2867   }
   2868   return s;
   2869 }
   2870 
   2871 Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
   2872                              const NodeDef& ndef, int graph_def_version,
   2873                              OpKernel** kernel) {
   2874   const auto device_type = DeviceType(device->attributes().device_type());
   2875   auto allocator = device->GetAllocator(AllocatorAttributes());
   2876   return CreateOpKernel(device_type, device, allocator, flib, ndef,
   2877                         graph_def_version, kernel);
   2878 }
   2879 
   2880 void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
   2881 
   2882 namespace {
   2883 
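// Added note (editorial, based on the code below): registers the local
// executor above as the executor factory used for both the empty name and
// "DEFAULT".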
   2884 class DefaultExecutorRegistrar {
   2885  public:
   2886   DefaultExecutorRegistrar() {
   2887     Factory* factory = new Factory;
   2888     ExecutorFactory::Register("", factory);
   2889     ExecutorFactory::Register("DEFAULT", factory);
   2890   }
   2891 
   2892  private:
   2893   class Factory : public ExecutorFactory {
   2894     Status NewExecutor(const LocalExecutorParams& params,
   2895                        std::unique_ptr<const Graph> graph,
   2896                        std::unique_ptr<Executor>* out_executor) override {
   2897       Executor* ret = nullptr;
   2898       TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(graph), &ret));
   2899       out_executor->reset(ret);
   2900       return Status::OK();
   2901     }
   2902   };
   2903 };
   2904 static DefaultExecutorRegistrar registrar;
   2905 
   2906 }  // namespace
   2907 
   2908 }  // namespace tensorflow
   2909