      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/common_runtime/executor.h"
     17 
     18 #include <atomic>
     19 #include <deque>
     20 #include <memory>
     21 #include <string>
     22 #include <unordered_map>
     23 #include <vector>
     24 
     25 #include "tensorflow/core/common_runtime/costmodel_manager.h"
     26 #include "tensorflow/core/common_runtime/pending_counts.h"
     27 #include "tensorflow/core/common_runtime/step_stats_collector.h"
     28 #include "tensorflow/core/framework/allocation_description.pb.h"
     29 #include "tensorflow/core/framework/allocator.h"
     30 #include "tensorflow/core/framework/cancellation.h"
     31 #include "tensorflow/core/framework/control_flow.h"
     32 #include "tensorflow/core/framework/device_attributes.pb.h"
     33 #include "tensorflow/core/framework/graph.pb.h"
     34 #include "tensorflow/core/framework/log_memory.h"
     35 #include "tensorflow/core/framework/node_def_util.h"
     36 #include "tensorflow/core/framework/op_kernel.h"
     37 #include "tensorflow/core/framework/op_segment.h"
     38 #include "tensorflow/core/framework/step_stats.pb.h"
     39 #include "tensorflow/core/framework/tensor.h"
     40 #include "tensorflow/core/framework/tensor_reference.h"
     41 #include "tensorflow/core/framework/types.h"
     42 #include "tensorflow/core/framework/types.pb.h"
     43 #include "tensorflow/core/graph/edgeset.h"
     44 #include "tensorflow/core/lib/core/errors.h"
     45 #include "tensorflow/core/lib/core/notification.h"
     46 #include "tensorflow/core/lib/core/stringpiece.h"
     47 #include "tensorflow/core/lib/gtl/flatmap.h"
     48 #include "tensorflow/core/lib/gtl/flatset.h"
     49 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     50 #include "tensorflow/core/lib/gtl/manual_constructor.h"
     51 #include "tensorflow/core/lib/gtl/stl_util.h"
     52 #include "tensorflow/core/lib/hash/hash.h"
     53 #include "tensorflow/core/lib/strings/str_util.h"
     54 #include "tensorflow/core/lib/strings/stringprintf.h"
     55 #include "tensorflow/core/platform/logging.h"
     56 #include "tensorflow/core/platform/macros.h"
     57 #include "tensorflow/core/platform/mutex.h"
     58 #include "tensorflow/core/platform/thread_annotations.h"
     59 #include "tensorflow/core/platform/tracing.h"
     60 #include "tensorflow/core/platform/types.h"
     61 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
     62 
     63 namespace tensorflow {
     64 namespace {
     65 
     66 // 1-D, 0 element tensor.
     67 static const Tensor* const kEmptyTensor = new Tensor;
     68 
     69 bool IsInitializationOp(const Node* node) {
     70   return node->op_def().allows_uninitialized_input();
     71 }
     72 
      73 // Sets the timeline_label field of *stats, using data from *node.
     74 // Returns true iff the node is a transfer node.
     75 // TODO(tucker): merge with the DetailText function in session.cc
     76 // in a common location.
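         // For example (with illustrative node and device names), a Send node
         // ends up with a timeline label like:
         //   "[cpu 1.2MB] send_x = _Send(x_tensor @/job:worker/replica:0/task:0/cpu:0"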
     77 bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
     78   bool is_transfer_node = false;
     79   if (!stats) {
     80     return is_transfer_node;
     81   }
     82   string memory;
     83   for (auto& all : stats->stats()->memory()) {
     84     int64 tot = all.total_bytes();
     85     if (tot >= 0.1 * 1048576.0) {
     86       int64 peak = all.peak_bytes();
     87       if (peak > 0) {
     88         memory =
     89             strings::StrCat(memory, "[", all.allocator_name(),
     90                             strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
     91                                             peak / 1048576.0));
     92       } else {
     93         memory = strings::StrCat(memory, "[", all.allocator_name(),
     94                                  strings::Printf(" %.1fMB] ", tot / 1048576.0));
     95       }
     96     }
     97   }
     98   const AttrSlice attrs = node->attrs();
     99   string text;
    100   if (IsSend(node)) {
    101     string tensor_name;
    102     TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
    103     string recv_device;
    104     TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
    105     text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
    106                            "(", tensor_name, " @", recv_device);
    107     is_transfer_node = true;
    108   } else if (IsRecv(node)) {
    109     string tensor_name;
    110     TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
    111     string send_device;
    112     TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
    113     text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
    114                            "(", tensor_name, " @", send_device);
    115     is_transfer_node = true;
    116   } else {
    117     text =
    118         strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
    119                         str_util::Join(node->requested_inputs(), ", "), ")");
    120   }
    121   stats->stats()->set_timeline_label(text);
    122   return is_transfer_node;
    123 }
    124 
    125 // Helper routines for collecting step stats.
    126 namespace nodestats {
    127 inline int64 NowInUsec() { return Env::Default()->NowMicros(); }
    128 
    129 void SetScheduled(NodeExecStatsWrapper* stats, int64 t) {
    130   if (!stats) return;
    131   stats->stats()->set_scheduled_micros(t);
    132 }
    133 
    134 void SetAllStart(NodeExecStatsWrapper* stats) {
    135   if (!stats) return;
    136   stats->stats()->set_all_start_micros(NowInUsec());
    137 }
    138 
    139 void SetOpStart(NodeExecStatsWrapper* stats) {
    140   if (!stats) return;
    141   NodeExecStats* nt = stats->stats();
    142   DCHECK_NE(nt->all_start_micros(), 0);
    143   nt->set_op_start_rel_micros(NowInUsec() - nt->all_start_micros());
    144 }
    145 
    146 void SetOpEnd(NodeExecStatsWrapper* stats) {
    147   if (!stats) return;
    148   NodeExecStats* nt = stats->stats();
    149   DCHECK_NE(nt->all_start_micros(), 0);
    150   nt->set_op_end_rel_micros(NowInUsec() - nt->all_start_micros());
    151 }
    152 
    153 void SetAllEnd(NodeExecStatsWrapper* stats) {
    154   if (!stats) return;
    155   NodeExecStats* nt = stats->stats();
    156   DCHECK_NE(nt->all_start_micros(), 0);
    157   nt->set_all_end_rel_micros(NowInUsec() - nt->all_start_micros());
    158 }
    159 
    160 void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
    161   if (!stats) return;
    162   DCHECK(v);
    163   NodeOutput* no = stats->stats()->add_output();
    164   no->set_slot(slot);
    165   v->FillDescription(no->mutable_tensor_description());
    166 }
    167 
    168 void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) {
    169   if (!stats) return;
    170 
    171   for (const auto& allocator_pair : ctx->wrapped_allocators()) {
    172     stats->AddAllocation(allocator_pair.first, allocator_pair.second);
    173   }
    174   auto* ms = stats->stats()->mutable_memory_stats();
    175   ms->set_temp_memory_size(ctx->temp_memory_allocated());
    176   for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
    177     ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
    178   }
    179   ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
    180 }
    181 
    182 void SetReferencedTensors(NodeExecStatsWrapper* stats,
    183                           const TensorReferenceVector& tensors) {
    184   if (!stats) return;
     185   // Be careful not to increment the reference count on any tensor
     186   // while recording the information.
    187   for (size_t i = 0; i < tensors.size(); ++i) {
    188     AllocationDescription* description =
    189         stats->stats()->add_referenced_tensor();
    190     tensors.at(i).FillDescription(description);
    191   }
    192 }
    193 
    194 }  // namespace nodestats
    195 
    196 class ExecutorImpl;
    197 class GraphView;
    198 
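         // Describes one outgoing edge of a node in the packed NodeItem
         // representation below: the id of the destination node, the output slot
         // of the source node that feeds the edge, and the input slot of the
         // destination that receives it. output_slot uses 31 bits so that
         // is_last fits in the remaining bit of the same word.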
    199 struct EdgeInfo {
    200   int dst_id;
    201   int output_slot : 31;
    202   // true if this is the last info for output_slot in the EdgeInfo list.
    203   bool is_last : 1;
    204   int input_slot;
    205 };
    206 
    207 struct NodeItem {
    208   NodeItem() {}
    209 
    210   // A graph node.
    211   const Node* node = nullptr;
    212 
    213   // The kernel for this node.
    214   OpKernel* kernel = nullptr;
    215 
    216   bool kernel_is_expensive : 1;  // True iff kernel->IsExpensive()
    217   bool kernel_is_async : 1;      // True iff kernel->AsAsync() != nullptr
    218   bool is_merge : 1;             // True iff IsMerge(node)
    219   bool is_enter : 1;             // True iff IsEnter(node)
    220   bool is_exit : 1;              // True iff IsExit(node)
    221   bool is_control_trigger : 1;   // True iff IsControlTrigger(node)
    222   bool is_sink : 1;              // True iff IsSink(node)
    223   // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
    224   bool is_enter_exit_or_next_iter : 1;
    225 
    226   // Cached values of node->num_inputs() and node->num_outputs(), to
    227   // avoid levels of indirection.
    228   int num_inputs;
    229   int num_outputs;
    230 
    231   // ExecutorImpl::tensors_[input_start] is the 1st positional input
    232   // for this node.
    233   int input_start = 0;
    234 
    235   // Number of output edges.
    236   size_t num_output_edges;
    237 
    238   PendingCounts::Handle pending_id;
    239 
    240   const EdgeInfo* output_edge_list() const { return output_edge_base(); }
    241 
    242   // ith output edge.
    243   const EdgeInfo& output_edge(int i) const {
    244     DCHECK_GE(i, 0);
    245     DCHECK_LT(i, num_output_edges);
    246     return output_edge_base()[i];
    247   }
    248 
    249   DataType input_type(int i) const {
    250     DCHECK_LT(i, num_inputs);
    251     return static_cast<DataType>(input_type_base()[i]);
    252   }
    253   DataType output_type(int i) const {
    254     DCHECK_LT(i, num_outputs);
    255     return static_cast<DataType>(output_type_base()[i]);
    256   }
    257 
    258   // Return array of per-output allocator attributes.
    259   const AllocatorAttributes* output_attrs() const { return output_attr_base(); }
    260 
    261  private:
    262   friend class GraphView;
    263 
    264   // Variable length section starts immediately after *this
    265   // (uint8 is enough for DataType).
    266   //   EdgeInfo            out_edges[num_out_edges];
    267   //   AllocatorAttributes output_attr[num_outputs];
    268   //   uint8               input_type[num_inputs];
    269   //   uint8               output_type[num_outputs];
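           //
           // For example, a node with 3 output edges, 1 input and 2 outputs needs
           //   sizeof(NodeItem) + 3 * sizeof(EdgeInfo)
           //     + 2 * sizeof(AllocatorAttributes) + 1 + 2
           // bytes, rounded up to kItemAlignment by NodeItemBytes().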
    270 
    271   // Return pointer to variable length section.
    272   char* var() const {
    273     return const_cast<char*>(reinterpret_cast<const char*>(this) +
    274                              sizeof(NodeItem));
    275   }
    276 
    277   EdgeInfo* output_edge_base() const {
    278     return reinterpret_cast<EdgeInfo*>(var());
    279   }
    280   AllocatorAttributes* output_attr_base() const {
    281     return reinterpret_cast<AllocatorAttributes*>(var() + sizeof(EdgeInfo) *
    282                                                               num_output_edges);
    283   }
    284   uint8* input_type_base() const {
    285     return reinterpret_cast<uint8*>(var() +
    286                                     sizeof(EdgeInfo) * num_output_edges +
    287                                     sizeof(AllocatorAttributes) * num_outputs);
    288   }
    289   uint8* output_type_base() const {
    290     return reinterpret_cast<uint8*>(
    291         var() + sizeof(EdgeInfo) * num_output_edges +
    292         sizeof(AllocatorAttributes) * num_outputs + sizeof(uint8) * num_inputs);
    293   }
    294 
    295   TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
    296 };
    297 
    298 typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
    299 typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
    300 typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
    301 
    302 // Immutable view of a Graph organized for efficient execution.
    303 class GraphView {
    304  public:
    305   GraphView() : space_(nullptr) {}
    306   ~GraphView();
    307 
    308   void Initialize(const Graph* g);
    309   Status SetAllocAttrs(const Graph* g, const Device* device);
    310 
    311   NodeItem* node(size_t id) const {
    312     DCHECK_GE(id, 0);
    313     DCHECK_LT(id, num_nodes_);
    314     uint32 offset = node_offsets_[id];
    315     return ((offset == kuint32max)
    316                 ? nullptr
    317                 : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
    318   }
    319 
    320  private:
    321   char* InitializeNode(char* ptr, const Node* n);
    322   size_t NodeItemBytes(const Node* n);
    323 
    324   int32 num_nodes_ = 0;
    325   uint32* node_offsets_ = nullptr;  // array of size "graph_.num_node_ids()"
    326   // node_offsets_[id] holds the byte offset for node w/ "id" in space_
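           // An offset of kuint32max means there is no NodeItem for that id
           // (see node() above; Initialize() fills the array with kuint32max).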
    327 
    328   char* space_;  // NodeItem objects are allocated here
    329 
    330   TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
    331 };
    332 
    333 class ExecutorImpl : public Executor {
    334  public:
    335   ExecutorImpl(const LocalExecutorParams& p, std::unique_ptr<const Graph> g)
    336       : params_(p), graph_(std::move(g)), gview_() {
    337     CHECK(p.create_kernel != nullptr);
    338     CHECK(p.delete_kernel != nullptr);
    339   }
    340 
    341   ~ExecutorImpl() override {
    342     for (int i = 0; i < graph_->num_node_ids(); i++) {
    343       NodeItem* item = gview_.node(i);
    344       if (item != nullptr) {
    345         params_.delete_kernel(item->kernel);
    346       }
    347     }
    348     for (auto fiter : frame_info_) {
    349       delete fiter.second;
    350     }
    351   }
    352 
    353   Status Initialize();
    354 
    355   // Process all Nodes in the current graph, attempting to infer the
    356   // memory allocation attributes to be used wherever they may allocate
    357   // a tensor buffer.
    358   Status SetAllocAttrs();
    359 
    360   void RunAsync(const Args& args, DoneCallback done) override;
    361 
    362  private:
    363   friend class ExecutorState;
    364 
    365   struct ControlFlowInfo {
    366     gtl::FlatSet<string> unique_frame_names;
    367     std::vector<string> frame_names;
    368   };
    369 
    370   struct FrameInfo {
    371     FrameInfo()
    372         : input_count(0),
    373           total_inputs(0),
    374           pending_counts(nullptr),
    375           nodes(nullptr) {}
    376 
     377     // The total number of inputs to a frame (i.e. the number of Enter nodes).
    378     int input_count;
    379 
    380     // The total number of input tensors of a frame.
    381     // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame.
    382     int total_inputs;
    383 
    384     // Used to determine the next place to allocate space in the
    385     // pending_counts data structure we'll eventually construct
    386     PendingCounts::Layout pending_counts_layout;
    387 
    388     // Each frame has its own PendingCounts only for the nodes in the frame.
    389     PendingCounts* pending_counts;  // Owned
    390 
    391     // The nodes in a frame. Used only for debugging.
    392     std::vector<const Node*>* nodes;  // Owned
    393 
    394     ~FrameInfo() {
    395       delete pending_counts;
    396       delete nodes;
    397     }
    398   };
    399 
    400   static Status BuildControlFlowInfo(const Graph* graph,
    401                                      ControlFlowInfo* cf_info);
    402   void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info);
    403 
    404   FrameInfo* EnsureFrameInfo(const string& fname) {
    405     auto slot = &frame_info_[fname];
    406     if (*slot == nullptr) {
    407       *slot = new FrameInfo;
    408     }
    409     return *slot;
    410   }
    411 
    412   // Owned.
    413   LocalExecutorParams params_;
    414   std::unique_ptr<const Graph> graph_;
    415   GraphView gview_;
    416 
     417   // A cached value of params_.device->RequiresRecordingAccessedTensors().
    418   bool device_record_tensor_accesses_ = false;
    419 
    420   // Root nodes (with no in edges) that should form the initial ready queue
    421   std::vector<const Node*> root_nodes_;
    422 
    423   // Mapping from frame name to static information about the frame.
     424   // TODO(yuanbyu): We could cache it along with the graph so as to avoid
    425   // the overhead of constructing it for each executor instance.
    426   gtl::FlatMap<string, FrameInfo*> frame_info_;
    427 
    428   TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
    429 };
    430 
    431 // Infer memory allocation attributes of a node n's output,
    432 // based on its use node dst.  Note that dst might not be directly
    433 // connected to n by a single edge, but might be a downstream
    434 // consumer of n's output by reference.  *attr is updated with any
    435 // necessary attributes.
    436 Status InferAllocAttr(const Node* n, const Node* dst,
    437                       const DeviceNameUtils::ParsedName& local_dev_name,
    438                       AllocatorAttributes* attr);
    439 
    440 GraphView::~GraphView() {
    441   static_assert(std::is_trivially_destructible<AllocatorAttributes>::value,
    442                 "Update code if AllocatorAttributes gains a destructor");
    443   static_assert(std::is_trivially_destructible<EdgeInfo>::value,
    444                 "Update code if EdgeInfo gains a destructor");
    445   for (int i = 0; i < num_nodes_; i++) {
    446     NodeItem* n = node(i);
    447     if (n != nullptr) {
    448       n->NodeItem::~NodeItem();
    449       // Memory for "n" itself is held in space_ & gets cleaned up below
    450     }
    451   }
    452   delete[] node_offsets_;
    453   delete[] space_;
    454 }
    455 
    456 size_t GraphView::NodeItemBytes(const Node* n) {
    457   const size_t num_output_edges = n->out_edges().size();
    458   const int num_inputs = n->num_inputs();
    459   const int num_outputs = n->num_outputs();
    460 
    461   // Compute number of bytes needed for NodeItem and variable length data.
    462   // We do not subtract sizeof(var) since num_inputs/num_outputs might
    463   // both be zero.
    464   const size_t raw_bytes =
    465       sizeof(NodeItem)                             // Fixed
    466       + num_output_edges * sizeof(EdgeInfo)        // output_edges[...]
    467       + num_outputs * sizeof(AllocatorAttributes)  // output_attr[...]
    468       + num_inputs * sizeof(uint8)                 // input_type[num_inputs]
    469       + num_outputs * sizeof(uint8);               // output_type[num_outputs]
    470   static constexpr size_t kItemAlignment = sizeof(NodeItem*);
    471   static_assert(kItemAlignment % alignof(NodeItem) == 0,
    472                 "NodeItem must be aligned with kItemAlignment");
    473   static_assert(kItemAlignment % alignof(EdgeInfo) == 0,
    474                 "EdgeInfo must be aligned with kItemAlignment");
    475   static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0,
    476                 "AllocatorAttributes must be aligned with kItemAlignment");
    477   static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0,
    478                 "NodeItem must be aligned with EdgeInfo");
    479   static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0,
    480                 "NodeItem must be aligned with AllocatorAttributes");
    481   static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0,
    482                 "EdgeInfo must be aligned with AllocatorAttributes");
    483   const size_t bytes =
    484       ((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment;
    485   return bytes;
    486 }
    487 
    488 char* GraphView::InitializeNode(char* ptr, const Node* n) {
    489   const int id = n->id();
    490   CHECK(node_offsets_[id] == kuint32max);  // Initial value in constructor
    491 
    492   const size_t bytes = NodeItemBytes(n);
    493   constexpr size_t kItemAlignment = sizeof(NodeItem*);
    494   CHECK_EQ(reinterpret_cast<uintptr_t>(ptr) % kItemAlignment, 0);
    495   NodeItem* item = reinterpret_cast<NodeItem*>(ptr);
    496 
    497   // We store a 32-bit offset relative to the beginning of space_, so that we
    498   // only need an array of 32-bit values to map from node id to the NodeItem*,
    499   // (versus 64 bits on most machines if we just stored an array of NodeItem*
    500   // pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
    501   // values as "int" vs "size_t" in CHECK_LE.
    502   CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
    503   const uint32 offset = static_cast<uint32>(ptr - space_);
    504   node_offsets_[id] = offset;
    505   ptr += bytes;
    506 
    507   const size_t num_output_edges = n->out_edges().size();
    508   const int num_inputs = n->num_inputs();
    509   const int num_outputs = n->num_outputs();
    510 
    511   new (item) NodeItem();
    512   item->num_inputs = num_inputs;
    513   item->num_outputs = num_outputs;
    514   item->num_output_edges = num_output_edges;
    515 
    516   // Fill output edges.
    517   // Keep track of the last EdgeInfo in the EdgeInfo array that references
    518   // a given output slot.  For all but the last, we need to do a copy of the
    519   // Tensor when propagating results downstream in the graph, but for the
    520   // last one, we can just do a move of the Tensor object to propagate it.
    521   gtl::InlinedVector<EdgeInfo*, 4> last_indices(num_outputs, nullptr);
    522   EdgeInfo* dst_edge = item->output_edge_base();
    523   for (auto e : n->out_edges()) {
    524     dst_edge->dst_id = e->dst()->id();
    525     CHECK_LE(e->src_output(), 0x3FFFFFFF);  // Must fit in 31 bits
    526     dst_edge->output_slot = e->src_output();
    527     dst_edge->is_last = false;
    528     const int output_slot = dst_edge->output_slot;
    529     if (output_slot >= 0) {
    530       last_indices[output_slot] = dst_edge;
    531     }
    532     dst_edge->input_slot = e->dst_input();
    533     dst_edge++;
    534   }
    535   for (EdgeInfo* edge_info : last_indices) {
    536     if (edge_info != nullptr) {
    537       edge_info->is_last = true;
    538     }
    539   }
    540 
    541   AllocatorAttributes* output_attrs = item->output_attr_base();
    542   for (int i = 0; i < num_outputs; i++) {
    543     new (&output_attrs[i]) AllocatorAttributes();
    544   }
    545 
    546   DCHECK_LT(DataType_MAX, 255);  // Must fit in uint8
    547   uint8* input_types = item->input_type_base();
    548   for (int i = 0; i < num_inputs; i++) {
    549     input_types[i] = static_cast<uint8>(n->input_type(i));
    550     DCHECK_EQ(item->input_type(i), n->input_type(i));
    551   }
    552 
    553   uint8* output_types = item->output_type_base();
    554   for (int i = 0; i < num_outputs; i++) {
    555     output_types[i] = static_cast<uint8>(n->output_type(i));
    556     DCHECK_EQ(item->output_type(i), n->output_type(i));
    557   }
    558   return ptr;
    559 }
    560 
    561 void GraphView::Initialize(const Graph* g) {
    562   CHECK(node_offsets_ == nullptr);
    563   const int num_nodes = g->num_node_ids();
    564   num_nodes_ = num_nodes;
    565   size_t total_bytes = 0;
    566   for (const Node* n : g->nodes()) {
    567     total_bytes += NodeItemBytes(n);
    568   }
    569 
    570   node_offsets_ = new uint32[num_nodes];
    571   for (int i = 0; i < num_nodes; i++) {
    572     node_offsets_[i] = kuint32max;
    573   }
    574 
    575   space_ = new char[total_bytes];  // NodeItem objects are allocated here
    576   char* ptr = space_;
    577   for (const Node* n : g->nodes()) {
    578     ptr = InitializeNode(ptr, n);
    579   }
    580   CHECK_EQ(ptr, space_ + total_bytes);
    581 }
    582 
    583 void GetMaxPendingCounts(const Node* n, size_t* max_pending,
    584                          size_t* max_dead_count) {
    585   const size_t num_in_edges = n->in_edges().size();
    586   size_t initial_count;
    587   if (IsMerge(n)) {
     588     // A Merge node waits for all of its control inputs, so we initialize
     589     // the pending count based on the number of control edges.
    590     int32 num_control_edges = 0;
    591     for (const Edge* edge : n->in_edges()) {
    592       if (edge->IsControlEdge()) {
    593         num_control_edges++;
    594       }
    595     }
    596     // Use bit 0 to indicate if we are waiting for a ready live data input.
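             // For example, a Merge with two control inputs starts with
             // initial_count = 1 + (2 << 1) = 5: bit 0 set (still waiting for a
             // live data input) and the control-edge count in the upper bits.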
    597     initial_count = 1 + (num_control_edges << 1);
    598   } else {
    599     initial_count = num_in_edges;
    600   }
    601 
    602   *max_pending = initial_count;
    603   *max_dead_count = num_in_edges;
    604 }
    605 
    606 Status ExecutorImpl::Initialize() {
    607   gview_.Initialize(graph_.get());
    608 
    609   // Build the information about frames in this subgraph.
    610   ControlFlowInfo cf_info;
    611   TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info));
    612 
    613   // Cache this value so we make this virtual function call once, rather
     614   // than O(# steps * # nodes per step) times.
    615   device_record_tensor_accesses_ =
    616       params_.device->RequiresRecordingAccessedTensors();
    617 
    618   for (auto& it : cf_info.unique_frame_names) {
    619     EnsureFrameInfo(it)->nodes = new std::vector<const Node*>;
    620   }
    621 
    622   // Preprocess every node in the graph to create an instance of op
    623   // kernel for each node.
    624   for (const Node* n : graph_->nodes()) {
    625     const int id = n->id();
    626     const string& frame_name = cf_info.frame_names[id];
    627     FrameInfo* frame_info = EnsureFrameInfo(frame_name);
    628 
    629     // See if this node is a root node, and if so, add to root_nodes_.
    630     if (n->in_edges().empty()) {
    631       root_nodes_.push_back(n);
    632     }
    633 
    634     NodeItem* item = gview_.node(id);
    635     item->node = n;
    636 
    637     item->input_start = frame_info->total_inputs;
    638     frame_info->total_inputs += n->num_inputs();
    639 
    640     Status s = params_.create_kernel(n->def(), &item->kernel);
    641     if (!s.ok()) {
    642       item->kernel = nullptr;
    643       s = AttachDef(s, *n);
    644       LOG(ERROR) << "Executor failed to create kernel. " << s;
    645       return s;
    646     }
    647     CHECK(item->kernel);
    648     item->kernel_is_expensive = item->kernel->IsExpensive();
    649     item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    650     item->is_merge = IsMerge(n);
    651     item->is_enter = IsEnter(n);
    652     item->is_exit = IsExit(n);
    653     item->is_control_trigger = IsControlTrigger(n);
    654     item->is_sink = IsSink(n);
    655     item->is_enter_exit_or_next_iter =
    656         (IsEnter(n) || IsExit(n) || IsNextIteration(n));
    657 
    658     // Compute the maximum values we'll store for this node in the
    659     // pending counts data structure, and allocate a handle in
    660     // that frame's pending counts data structure that has enough
    661     // space to store these maximal count values.
    662     size_t max_pending, max_dead;
    663     GetMaxPendingCounts(n, &max_pending, &max_dead);
    664     item->pending_id =
    665         frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);
    666 
    667     // Initialize static information about the frames in the graph.
    668     frame_info->nodes->push_back(n);
    669     if (IsEnter(n)) {
    670       string enter_name;
    671       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
    672       EnsureFrameInfo(enter_name)->input_count++;
    673     }
    674   }
    675 
    676   // Initialize PendingCounts only after item->pending_id is initialized for
    677   // all nodes.
    678   InitializePending(graph_.get(), cf_info);
    679 
    680   return gview_.SetAllocAttrs(graph_.get(), params_.device);
    681 }
    682 
    683 Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
    684   Status s;
    685   DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
    686 
    687   for (const Node* n : g->nodes()) {
    688     NodeItem* item = node(n->id());
    689     AllocatorAttributes* attrs = item->output_attr_base();
    690 
    691     // Examine the out edges of each node looking for special use
    692     // cases that may affect memory allocation attributes.
    693     for (auto e : n->out_edges()) {
    694       if (!e->IsControlEdge()) {
    695         AllocatorAttributes attr;
    696         s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
    697         if (!s.ok()) return s;
    698         if (attr.value != 0) {
    699           attrs[e->src_output()].Merge(attr);
    700         }
    701       }
    702     }
    703 
    704     for (int out = 0; out < n->num_outputs(); out++) {
    705       const OpKernel* op_kernel = item->kernel;
    706       DCHECK_LT(out, op_kernel->output_memory_types().size());
    707       bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
    708       if (on_host) {
    709         AllocatorAttributes h;
    710         h.set_on_host(on_host);
    711         attrs[out].Merge(h);
    712       }
    713     }
    714   }
    715   return s;
    716 }
    717 
    718 Status InferAllocAttr(const Node* n, const Node* dst,
    719                       const DeviceNameUtils::ParsedName& local_dev_name,
    720                       AllocatorAttributes* attr) {
    721   Status s;
    722   // Note that it's possible for *n to be a Recv and *dst to be a Send,
    723   // so these two cases are not mutually exclusive.
    724   if (IsRecv(n)) {
    725     string src_name;
    726     s = GetNodeAttr(n->attrs(), "send_device", &src_name);
    727     if (!s.ok()) return s;
    728     DeviceNameUtils::ParsedName parsed_src_name;
    729     if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
    730       s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
    731                            n->name());
    732       return s;
    733     }
    734     if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
    735       // Value is going to be the sink of an RPC.
    736       attr->set_nic_compatible(true);
    737       VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
    738     } else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) &&
    739                parsed_src_name.type != "CPU") {
    740       // Value is going to be the sink of a local DMA from GPU to CPU (or other
    741       // types of accelerators).
    742       attr->set_gpu_compatible(true);
    743       VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
    744     } else {
    745       VLOG(2) << "default alloc case local type " << local_dev_name.type
    746               << " remote type " << parsed_src_name.type;
    747     }
    748   }
    749   if (IsSend(dst)) {
    750     string dst_name;
    751     s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
    752     if (!s.ok()) return s;
    753     DeviceNameUtils::ParsedName parsed_dst_name;
    754     if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
    755       s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
    756                            n->name());
    757       return s;
    758     }
    759     if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
    760       // Value is going to be the source of an RPC.
    761       attr->set_nic_compatible(true);
    762       VLOG(2) << "node " << n->name() << " is the source of an RPC out";
    763     } else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) &&
    764                parsed_dst_name.type != "CPU") {
    765       // Value is going to be the source of a local DMA from CPU to GPU (or
    766       // other types of accelerators).
    767       // Note that this does not cover the case where the allocation of the
    768       // output tensor is not generated by the src: n.
    769       attr->set_gpu_compatible(true);
    770       VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
    771     } else {
    772       VLOG(2) << "default alloc case local type " << local_dev_name.type
    773               << " remote type " << parsed_dst_name.type;
    774     }
    775   }
    776   return s;
    777 }
    778 
    779 // The state associated with one invocation of ExecutorImpl::Run.
    780 // ExecutorState dispatches nodes when they become ready and keeps
     781 // track of how many predecessors of a node have not yet completed (pending_).
    782 class ExecutorState {
    783  public:
    784   ExecutorState(const Executor::Args& args, ExecutorImpl* impl);
    785   ~ExecutorState();
    786 
    787   void RunAsync(Executor::DoneCallback done);
    788 
    789  private:
    790   // Either a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
    791   // TODO(yuanbyu): A better way to do "has_value"?
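           // At most one of <val> (owned, constructed in place when
           // val_field_is_set) and <ref> (borrowed, guarded by *ref_mu) holds the
           // tensor; has_value records whether either of them has been produced.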
    792   struct Entry {
    793     Entry() {}
    794     Entry(const Entry& other)
    795         : ref(other.ref),
    796           ref_mu(other.ref_mu),
    797           has_value(other.has_value),
    798           val_field_is_set(other.val_field_is_set),
    799           alloc_attr(other.alloc_attr),
    800           device_context(other.device_context) {
    801       if (val_field_is_set) {
    802         val.Init(*other.val);
    803       }
    804     }
    805     ~Entry() {
    806       if (val_field_is_set) val.Destroy();
    807     }
    808 
    809     Entry& operator=(const Entry& other) {
    810       if (val_field_is_set) {
    811         val.Destroy();
    812       }
    813       ref = other.ref;
    814       ref_mu = other.ref_mu;
    815       has_value = other.has_value;
    816       val_field_is_set = other.val_field_is_set;
    817       alloc_attr = other.alloc_attr;
    818       device_context = other.device_context;
    819       if (val_field_is_set) {
    820         val.Init(*other.val);
    821       }
    822       return *this;
    823     }
    824 
    825     Entry& operator=(Entry&& other) {
    826       if (val_field_is_set) {
    827         val.Destroy();
    828       }
    829       ref = other.ref;
    830       ref_mu = other.ref_mu;
    831       has_value = other.has_value;
    832       val_field_is_set = other.val_field_is_set;
    833       alloc_attr = other.alloc_attr;
    834       device_context = other.device_context;
    835       if (val_field_is_set) {
    836         val.Init(std::move(*other.val));
    837       }
    838       return *this;
    839     }
    840 
    841     // Clears the <val> field.
    842     void ClearVal() {
    843       if (val_field_is_set) {
    844         val.Destroy();
    845         val_field_is_set = false;
    846         has_value = false;
    847       }
    848     }
    849 
    850     // A tensor value, if val_field_is_set.
    851     ManualConstructor<Tensor> val;
    852 
    853     Tensor* ref = nullptr;    // A tensor reference.
    854     mutex* ref_mu = nullptr;  // mutex for *ref if ref is not nullptr.
    855 
    856     // Whether the value exists, either in <val> or <ref>.
    857     bool has_value = false;
    858 
    859     bool val_field_is_set = false;
    860 
    861     // The attributes of the allocator that creates the tensor.
    862     AllocatorAttributes alloc_attr;
    863 
    864     // Every entry carries an optional DeviceContext containing
    865     // Device-specific information about how the Tensor was produced.
    866     DeviceContext* device_context = nullptr;
    867   };
    868 
     869   // The device context assigned by the device at the beginning of a step,
     870   // indexed by node->id().
    871   DeviceContextMap device_context_map_;
    872 
    873   struct TaggedNode;
    874   typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
    875   typedef gtl::InlinedVector<Entry, 4> EntryVector;
    876 
    877   struct IterationState {
    878     explicit IterationState(const PendingCounts* pending_counts,
    879                             int total_input_tensors)
    880         : input_tensors(new Entry[total_input_tensors]),
    881           outstanding_ops(0),
    882           outstanding_frame_count(0),
    883           counts_(*pending_counts) {  // Initialize with copy of *pending_counts
    884     }
    885 
    886     // The state of an iteration.
    887 
    888     // One copy per iteration. For iteration k, i-th node's j-th input is in
    889     // input_tensors[k][impl_->nodes[i].input_start + j]. An entry is either
    890     // a tensor pointer (pass-by-reference) or a tensor (pass-by-value).
    891     //
    892     // NOTE: No need to protect input_tensors[i] by any locks because it
     893     // is resized once. Each element of input_tensors is written once by the
    894     // source node of an edge and is cleared by the destination of the same
    895     // edge. The latter node is never run concurrently with the former node.
    896     Entry* input_tensors;
    897 
    898     // The number of outstanding ops for each iteration.
    899     size_t outstanding_ops;
    900 
    901     // The number of outstanding frames for each iteration.
    902     int outstanding_frame_count;
    903     int pending(PendingCounts::Handle h) { return counts_.pending(h); }
    904     int decrement_pending(PendingCounts::Handle h, int v) {
    905       return counts_.decrement_pending(h, v);
    906     }
    907     // Mark a merge node as live
    908     // REQUIRES: Node corresponding to "h" is a merge node
    909     void mark_live(PendingCounts::Handle h) { counts_.mark_live(h); }
    910     // Mark a node to show that processing has started.
    911     void mark_started(PendingCounts::Handle h) { counts_.mark_started(h); }
    912     // Mark a node to show that processing has completed.
    913     void mark_completed(PendingCounts::Handle h) { counts_.mark_completed(h); }
    914     PendingCounts::NodeState node_state(PendingCounts::Handle h) {
    915       return counts_.node_state(h);
    916     }
    917 
    918     int dead_count(PendingCounts::Handle h) { return counts_.dead_count(h); }
    919     void increment_dead_count(PendingCounts::Handle h) {
    920       counts_.increment_dead_count(h);
    921     }
    922     void adjust_for_activation(PendingCounts::Handle h, bool increment_dead,
    923                                int* pending_result, int* dead_result) {
    924       counts_.adjust_for_activation(h, increment_dead, pending_result,
    925                                     dead_result);
    926     }
    927 
    928     ~IterationState() { delete[] input_tensors; }
    929 
    930    private:
    931     PendingCounts counts_;
    932   };
    933 
    934   struct FrameState {
    935     explicit FrameState(const ExecutorImpl* impl, int parallel_iters)
    936         : executor(impl),
    937           max_parallel_iterations(parallel_iters),
    938           num_outstanding_iterations(1) {}
    939 
    940     // A new frame is created for each loop. Execution starts at iteration 0.
    941     // When a value at iteration 0 passes through a NextIteration node,
    942     // iteration 1 is created and starts running. Note that iteration 0 may
    943     // still be running so multiple iterations may run in parallel. The
    944     // frame maintains the state of iterations in several data structures
    945     // such as pending_count and input_tensors. When iteration 0 completes,
    946     // we garbage collect the state of iteration 0.
    947     //
    948     // A frame instance is considered "done" and can be garbage collected
    949     // if all its inputs have entered and all its iterations are "done".
    950     //
    951     // A frame manages the live iterations of an iterative computation.
    952     // Iteration i is considered "done" when there are no outstanding ops,
    953     // frames at iteration i are done, all recvs for this iteration are
    954     // completed, and iteration i-1 is done. For iteration 0, we instead
    955     // wait for there to be no more pending inputs of the frame.
    956     //
    957     // Frames and iterations are garbage collected once they are done.
    958     // The state we need to keep around is highly dependent on the
    959     // parallelism enabled by the scheduler. We may want to have the
    960     // scheduler dynamically control the outstanding number of live
    961     // parallel frames and iterations. To reduce the state space, the
    962     // scheduler might want to schedule ops in inner frames first and
    963     // lower iterations first.
    964     //
    965     // This frame state is mostly initialized lazily on demand so we
    966     // don't introduce unnecessary overhead.
    967 
    968     // The executor the frame is in.
    969     const ExecutorImpl* executor = nullptr;
    970 
    971     // The name of this frame, which is the concatenation of its parent
    972     // frame name, the iteration of the parent frame when this frame was
    973     // created, and the value of the attr 'frame_name'.
    974     string frame_name;
    975 
    976     // The unique id for this frame. Generated by fingerprinting
    977     // frame_name.
    978     uint64 frame_id;
    979 
    980     // The iteration id of its parent frame when this frame is created.
    981     // -1 if there is no parent frame. The frame_name/parent_iter pair
    982     // uniquely identifies this FrameState.
    983     int64 parent_iter = -1;
    984 
    985     // The FrameState of its parent frame.
    986     FrameState* parent_frame = nullptr;
    987 
    988     // The maximum allowed number of parallel iterations.
    989     const int max_parallel_iterations;
    990 
     991     // The number of inputs this frame is still waiting for.
    992     int num_pending_inputs = 0;
    993 
    994     // The highest iteration number we have reached so far in this frame.
    995     int64 iteration_count GUARDED_BY(mu) = 0;
    996 
    997     // The number of outstanding iterations.
    998     int num_outstanding_iterations GUARDED_BY(mu) = 1;
    999 
   1000     // The active iteration states of this frame.
   1001     gtl::InlinedVector<IterationState*, 12> iterations;
   1002 
   1003     // The NextIteration nodes to enter a new iteration. If the number of
   1004     // outstanding iterations reaches the limit, we will defer the start of
   1005     // the next iteration until the number of outstanding iterations falls
   1006     // below the limit.
   1007     std::vector<std::pair<const Node*, Entry>> next_iter_roots GUARDED_BY(mu);
   1008 
   1009     // The values of the loop invariants for this loop. They are added into
   1010     // this list as they "enter" the frame. When a loop invariant enters,
   1011     // we make it available to all active iterations. When the frame starts
   1012     // a new iteration, we make all the current loop invariants available
   1013     // to the new iteration.
   1014     std::vector<std::pair<const Node*, Entry>> inv_values GUARDED_BY(mu);
   1015 
   1016     // The list of dead exit nodes for the current highest iteration. We
   1017     // will only "execute" the dead exits of the final iteration.
   1018     std::vector<const Node*> dead_exits GUARDED_BY(mu);
   1019 
   1020     // Static information specific to this frame.
   1021     PendingCounts* pending_counts = nullptr;
   1022     int total_input_tensors = 0;
   1023     std::vector<const Node*>* nodes = nullptr;
   1024 
   1025     // Lock ordering: ExecutorState.mu_ < mu.
   1026     mutex mu;
   1027 
   1028     void InitializeFrameInfo(const string& enter_name) {
   1029       auto it_frame_info = executor->frame_info_.find(enter_name);
   1030       DCHECK(it_frame_info != executor->frame_info_.end());
   1031       ExecutorImpl::FrameInfo* finfo = it_frame_info->second;
   1032       pending_counts = finfo->pending_counts;
   1033       total_input_tensors = finfo->total_inputs;
   1034       num_pending_inputs = finfo->input_count;
   1035       nodes = finfo->nodes;
   1036     }
   1037 
   1038     inline IterationState* GetIteration(int64 iter)
   1039         EXCLUSIVE_LOCKS_REQUIRED(mu) {
   1040       size_t index = iter % iterations.size();
   1041       return iterations[index];
   1042     }
   1043 
   1044     inline void SetIteration(int64 iter, IterationState* state)
   1045         EXCLUSIVE_LOCKS_REQUIRED(mu) {
   1046       size_t index = iter % iterations.size();
   1047       DCHECK(state == nullptr || iterations[index] == nullptr);
   1048       iterations[index] = state;
   1049     }
   1050 
   1051     // Decrement the outstanding op count and clean up the iterations in the
   1052     // frame. Return true iff the execution of the frame is done.
   1053     inline bool DecrementOutstandingOps(const GraphView* gview, int64 iter,
   1054                                         TaggedNodeSeq* ready) {
   1055       mutex_lock l(mu);
   1056       return DecrementOutstandingOpsLocked(gview, iter, ready);
   1057     }
   1058 
   1059     // Decrement the outstanding op count and clean up the iterations in the
   1060     // frame. Return true iff the execution of the frame is done.
   1061     inline bool DecrementOutstandingOpsLocked(const GraphView* gview,
   1062                                               int64 iter, TaggedNodeSeq* ready)
   1063         EXCLUSIVE_LOCKS_REQUIRED(mu) {
   1064       IterationState* istate = GetIteration(iter);
   1065       istate->outstanding_ops--;
   1066       if (istate->outstanding_ops != 0) {
   1067         return false;
   1068       } else {
   1069         return CleanupIterations(gview, iter, ready);
   1070       }
   1071     }
   1072 
   1073     // Returns true if the computation in the frame is completed.
   1074     inline bool IsFrameDone() EXCLUSIVE_LOCKS_REQUIRED(mu) {
   1075       return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
   1076     }
   1077 
   1078     // Returns true if the iteration of the frame is completed.
   1079     bool IsIterationDone(int64 iter) EXCLUSIVE_LOCKS_REQUIRED(mu);
   1080 
   1081     // Increments the iteration id. If this is a new iteration, initialize it.
   1082     void IncrementIteration(const GraphView* gview, TaggedNodeSeq* ready)
   1083         EXCLUSIVE_LOCKS_REQUIRED(mu);
   1084 
   1085     // Activate all the deferred NextIteration nodes in a new iteration.
   1086     void ActivateNexts(const GraphView* gview, int64 iter, TaggedNodeSeq* ready)
   1087         EXCLUSIVE_LOCKS_REQUIRED(mu);
   1088 
   1089     // Activate all the current loop invariants in a new iteration.
   1090     void ActivateLoopInvs(const GraphView* gview, int64 iter,
   1091                           TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
   1092 
   1093     // Add a new loop invariant and make it available to all active iterations.
   1094     void AddLoopInv(const NodeItem* item, const Entry& value,
   1095                     TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
   1096 
   1097     // Activate the successors of a node. Contents of *outputs are left in an
   1098     // indeterminate state after returning from this method.
   1099     void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter,
   1100                        EntryVector* outputs, TaggedNodeSeq* ready)
   1101         EXCLUSIVE_LOCKS_REQUIRED(mu);
   1102 
   1103     // Cleanup iterations of this frame starting from iteration iter.
   1104     bool CleanupIterations(const GraphView* gview, int64 iter,
   1105                            TaggedNodeSeq* ready) EXCLUSIVE_LOCKS_REQUIRED(mu);
   1106 
   1107     ~FrameState() {
   1108       for (size_t i = 0; i < iterations.size(); ++i) {
   1109         delete iterations[i];
   1110         iterations[i] = nullptr;
   1111       }
   1112     }
   1113   };
   1114 
   1115   // A tagged node: <frame*, iter, node*>.
   1116   struct TaggedNode {
   1117     const Node* node = nullptr;
   1118     FrameState* input_frame = nullptr;
   1119     int64 input_iter = -1;
   1120     bool is_dead = false;
   1121 
   1122     TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter,
   1123                bool dead) {
   1124       node = t_node;
   1125       input_frame = in_frame;
   1126       input_iter = in_iter;
   1127       is_dead = dead;
   1128     }
   1129   };
   1130 
   1131   // A drop-in replacement for std::deque<TaggedNode>.  We typically don't
   1132   // have that many nodes in the ready queue, so we just use a vector and
   1133   // don't free up memory from the queue as we consume nodes.
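           // pop_front() only advances front_index_; the backing vector is
           // reclaimed once it is fully drained or once more than 16384 consumed
           // entries have accumulated at its front.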
   1134   class TaggedNodeReadyQueue {
   1135    public:
   1136     TaggedNodeReadyQueue() : front_index_(0) {}
   1137 
   1138     void push_back(TaggedNode node) { ready_.push_back(node); }
   1139     TaggedNode front() const {
   1140       DCHECK_LT(front_index_, ready_.size());
   1141       return ready_[front_index_];
   1142     }
   1143     void pop_front() {
   1144       DCHECK_LT(front_index_, ready_.size());
   1145       front_index_++;
   1146       if ((front_index_ == ready_.size()) || (front_index_ > 16384)) {
   1147         if (front_index_ == ready_.size()) {
   1148           ready_.clear();
   1149         } else {
   1150           // Lots of unused entries at beginning of vector: move everything down
   1151           // to start of vector.
   1152           ready_.erase(ready_.begin(), ready_.begin() + front_index_);
   1153         }
   1154         front_index_ = 0;
   1155       }
   1156     }
   1157     bool empty() const { return ready_.empty(); }
   1158     const TaggedNode* begin() const { return ready_.begin() + front_index_; }
   1159     const TaggedNode* end() const { return ready_.end(); }
   1160 
   1161    private:
   1162     gtl::InlinedVector<TaggedNode, 16> ready_;
   1163     int front_index_;
   1164   };
   1165 
   1166   struct AsyncState;
   1167 
   1168   const bool vlog_;  // true if VLOG_IS_ON(1). Used to check vlog cheaply.
   1169 
   1170   // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply.
   1171   const bool log_memory_;
   1172 
   1173   int64 step_id_;
   1174   // Not owned.
   1175   Rendezvous* rendezvous_;
   1176   SessionState* session_state_;
   1177   TensorStore* tensor_store_;
   1178   // Step-local container.
   1179   ScopedStepContainer* step_container_;
   1180   StepStatsCollector* stats_collector_;
   1181   // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
   1182   // instead of a pointer?  (avoids having to delete).
   1183   checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
   1184   CallFrameInterface* call_frame_;
   1185   const ExecutorImpl* impl_;
   1186   CancellationManager* cancellation_manager_;
   1187   Executor::Args::Runner runner_;
   1188   bool sync_on_finish_;
   1189 
   1190   // Owned.
   1191 
   1192   // A flag that is set on error after the frame state has been
   1193   // dumped for diagnostic purposes.
   1194   bool dumped_on_error_ = false;
   1195 
   1196   // The root frame in which the execution of this step is started.
   1197   FrameState* root_frame_;
   1198 
   1199   // Invoked when the execution finishes.
   1200   Executor::DoneCallback done_cb_;
   1201 
   1202   std::atomic_int_fast32_t num_outstanding_ops_;
   1203 
   1204   mutex mu_;
   1205   Status status_ GUARDED_BY(mu_);
   1206 
   1207   // Mapping from frame name to outstanding frames. A new frame is created
   1208   // at some iteration of an active frame. So the unique key for the new
   1209   // child frame is composed of the name of the parent frame, the iteration
   1210   // number at which the parent frame is creating the new frame, and the
   1211   // name of the new frame from nodedef.
   1212   gtl::FlatMap<string, FrameState*> outstanding_frames_ GUARDED_BY(mu_);
   1213 
   1214   // The unique name of a frame.
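           // For example (with an illustrative 'frame_name' attr), a child frame
           // "while_ctx" created at iteration 3 of the root frame (whose name is
           // "") is named ";3;while_ctx".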
   1215   inline string MakeFrameName(FrameState* frame, int64 iter_id,
   1216                               const string& name) {
   1217     return strings::StrCat(frame->frame_name, ";", iter_id, ";", name);
   1218   }
   1219 
   1220   // Find an existing or create a new child frame in the frame 'frame' at
   1221   // iteration 'iter'.
   1222   void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node,
   1223                               FrameState** child);
   1224 
   1225   // Delete a frame. Called when the frame is done.
   1226   void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);
   1227 
   1228   // Cleanup frames and iterations starting from frame/iter. Called when
   1229   // a child frame is done.
   1230   void CleanupFramesIterations(FrameState* frame, int64 iter,
   1231                                TaggedNodeSeq* ready);
   1232 
   1233   // Process a ready node in current thread.
   1234   void Process(TaggedNode node, int64 scheduled_usec);
   1235 
   1236   // Before invoking item->kernel, fills in its "inputs".
   1237   Status PrepareInputs(const NodeItem& item, Entry* first_input,
   1238                        TensorValueVec* inputs,
   1239                        DeviceContextVec* input_device_contexts,
   1240                        AllocatorAttributeVec* input_alloc_attrs,
   1241                        bool* is_input_dead);
   1242 
   1243   // After item->kernel computation is done, processes its outputs.
   1244   Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
   1245                         EntryVector* outputs, NodeExecStatsWrapper* stats);
   1246 
   1247   // After processing the outputs, propagates the outputs to their dsts.
   1248   // Contents of *outputs are left in an indeterminate state after
   1249   // returning from this method.
   1250   void PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* item,
   1251                         EntryVector* outputs, TaggedNodeSeq* ready);
   1252 
   1253   // "node" just finishes. Takes ownership of "stats". Returns true if
   1254   // execution has completed.
   1255   bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
   1256                 NodeExecStatsWrapper* stats,
   1257                 TaggedNodeReadyQueue* inline_ready);
   1258 
   1259   // Schedule all the expensive nodes in 'ready', and put all the inexpensive
   1260   // nodes in 'ready' into 'inline_ready'.
   1261   void ScheduleReady(const TaggedNodeSeq& ready,
   1262                      TaggedNodeReadyQueue* inline_ready);
   1263 
   1264   // For debugging/logging only.
   1265   inline void MaybeMarkCompleted(FrameState* frame, int64 iter, int64 id);
   1266 
   1267   // Provide debugging output about an outstanding node in the executor.
   1268   void DumpPendingNodeState(const int node_id, const Entry* input_vector,
   1269                             bool show_nodes_with_no_ready_inputs);
   1270   void DumpActiveNodeState(const int node_id, const Entry* input_vector);
   1271 
   1272   // Provide debugging output about an outstanding iteration in the executor.
   1273   void DumpIterationState(const FrameState* frame, IterationState* iteration);
   1274 
   1275   // Provide debugging output of the state of the executor.
   1276   void DumpState();
   1277   const Tensor* GetTensorValueForDump(const Entry& input);
   1278 
   1279   // Clean up when this executor is done.
   1280   void Finish();
   1281 
   1282   // A standalone routine for this expression so that we can express
   1283   // that we don't want thread safety analysis on this reference (it's
   1284   // safe to do without the lock because the iterations array never
   1285   // resizes and this particular iteration's array element will not
   1286   // be changed out from under us because the iteration is still alive).
   1287   Entry* GetInputTensors(FrameState* input_frame,
   1288                          int64 input_iter) const NO_THREAD_SAFETY_ANALYSIS {
   1289     return input_frame->GetIteration(input_iter)->input_tensors;
   1290   }
   1291 };
   1292 
   1293 ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
   1294     : vlog_(VLOG_IS_ON(1)),
   1295       log_memory_(LogMemory::IsEnabled()),
   1296       step_id_(args.step_id),
   1297       rendezvous_(args.rendezvous),
   1298       session_state_(args.session_state),
   1299       tensor_store_(args.tensor_store),
   1300       step_container_(args.step_container),
   1301       stats_collector_(args.stats_collector),
   1302       slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
   1303       call_frame_(args.call_frame),
   1304       impl_(impl),
   1305       cancellation_manager_(args.cancellation_manager),
   1306       runner_(args.runner),
   1307       sync_on_finish_(args.sync_on_finish),
   1308       num_outstanding_ops_(0) {
    1309   // We start the entire execution in iteration 0 of the root frame,
    1310   // so create the root frame and the state for iteration 0 here.
   1311   // We assume root_frame_->frame_name.empty().
   1312   root_frame_ = new FrameState(impl_, 1);
   1313   root_frame_->frame_id = 0;  // must be 0
   1314   root_frame_->InitializeFrameInfo(root_frame_->frame_name);
   1315 
   1316   // Initialize iteration 0.
   1317   root_frame_->iterations.resize(root_frame_->max_parallel_iterations);
   1318   root_frame_->iterations[0] = new IterationState(
   1319       root_frame_->pending_counts, root_frame_->total_input_tensors);
   1320 
   1321   outstanding_frames_.insert({root_frame_->frame_name, root_frame_});
   1322 }
   1323 
   1324 ExecutorState::~ExecutorState() {
   1325   for (auto name_frame : outstanding_frames_) {
   1326     delete name_frame.second;
   1327   }
   1328   for (auto it : device_context_map_) {
   1329     it->Unref();
   1330   }
   1331   delete slice_reader_cache_;
   1332 }
   1333 
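         // Computes, for every node in 'g', the name of the frame it executes in,
         // using a breadth-first traversal from the root nodes: an Enter node starts
         // a new frame named by its "frame_name" attr, an Exit node returns to the
         // parent frame of its matching Enter, and every other node inherits the
         // frame of its predecessor.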
   1334 Status ExecutorImpl::BuildControlFlowInfo(const Graph* g,
   1335                                           ControlFlowInfo* cf_info) {
   1336   const int num_nodes = g->num_node_ids();
   1337   cf_info->frame_names.resize(num_nodes);
   1338   std::vector<Node*> parent_nodes;
   1339   parent_nodes.resize(num_nodes);
   1340   std::vector<bool> visited;
   1341   visited.resize(num_nodes);
   1342 
   1343   string frame_name;
   1344   std::deque<Node*> ready;
   1345 
   1346   // Initialize with the root nodes.
   1347   for (Node* n : g->nodes()) {
   1348     if (n->in_edges().empty()) {
   1349       visited[n->id()] = true;
   1350       cf_info->unique_frame_names.insert(frame_name);
   1351       ready.push_back(n);
   1352     }
   1353   }
   1354 
   1355   while (!ready.empty()) {
   1356     Node* curr_node = ready.front();
   1357     int curr_id = curr_node->id();
   1358     ready.pop_front();
   1359 
   1360     Node* parent = nullptr;
   1361     if (IsEnter(curr_node)) {
   1362       // Enter a child frame.
   1363       TF_RETURN_IF_ERROR(
   1364           GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
   1365       parent = curr_node;
   1366     } else if (IsExit(curr_node)) {
   1367       // Exit to the parent frame.
   1368       parent = parent_nodes[curr_id];
   1369       frame_name = cf_info->frame_names[parent->id()];
   1370       parent = parent_nodes[parent->id()];
   1371     } else {
   1372       parent = parent_nodes[curr_id];
   1373       frame_name = cf_info->frame_names[curr_id];
   1374     }
   1375 
   1376     for (const Edge* out_edge : curr_node->out_edges()) {
   1377       Node* out = out_edge->dst();
   1378       const int out_id = out->id();
   1379 
   1380       // Add to ready queue if not visited.
   1381       bool is_visited = visited[out_id];
   1382       if (!is_visited) {
   1383         ready.push_back(out);
   1384         visited[out_id] = true;
   1385 
   1386         // Process the node 'out'.
   1387         cf_info->frame_names[out_id] = frame_name;
   1388         parent_nodes[out_id] = parent;
   1389         cf_info->unique_frame_names.insert(frame_name);
   1390       }
   1391     }
   1392   }
   1393 
   1394   return Status::OK();
   1395 }
   1396 
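         // Allocates one PendingCounts structure per unique frame and seeds every
         // node's initial pending count (as computed by GetMaxPendingCounts) in the
         // PendingCounts of the frame the node belongs to.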
   1397 void ExecutorImpl::InitializePending(const Graph* graph,
   1398                                      const ControlFlowInfo& cf_info) {
   1399   for (auto& it : cf_info.unique_frame_names) {
   1400     FrameInfo* finfo = EnsureFrameInfo(it);
   1401     PendingCounts* counts = new PendingCounts(finfo->pending_counts_layout);
   1402     DCHECK_EQ(finfo->pending_counts, nullptr);
   1403     finfo->pending_counts = counts;
   1404   }
   1405   for (const Node* n : graph->nodes()) {
   1406     const int id = n->id();
   1407     const string& name = cf_info.frame_names[id];
   1408     size_t max_pending, max_dead;
   1409     GetMaxPendingCounts(n, &max_pending, &max_dead);
   1410     const NodeItem* item = gview_.node(id);
   1411     PendingCounts* counts = EnsureFrameInfo(name)->pending_counts;
   1412     counts->set_initial_count(item->pending_id, max_pending);
   1413   }
   1414 }
   1415 
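         // Starts the step: asks the device for its per-node DeviceContexts, seeds
         // the ready queue with the graph's root nodes, and schedules them on the
         // thread pool. Invokes 'done' immediately if there is nothing to run.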
   1416 void ExecutorState::RunAsync(Executor::DoneCallback done) {
   1417   const Graph* graph = impl_->graph_.get();
   1418   TaggedNodeSeq ready;
   1419 
   1420   // Ask the device to fill in the device context map.
   1421   Device* device = impl_->params_.device;
   1422   const Status fill_status =
   1423       device->FillContextMap(graph, &device_context_map_);
   1424   if (!fill_status.ok()) {
   1425     done(fill_status);
   1426     return;
   1427   }
   1428 
   1429   // Initialize the ready queue.
   1430   for (const Node* n : impl_->root_nodes_) {
   1431     DCHECK_EQ(n->in_edges().size(), 0);
   1432     ready.push_back(TaggedNode{n, root_frame_, 0, false});
   1433   }
   1434   if (ready.empty()) {
   1435     done(Status::OK());
   1436   } else {
   1437     num_outstanding_ops_ = ready.size();
   1438     root_frame_->iterations[0]->outstanding_ops = ready.size();
   1439     done_cb_ = std::move(done);
   1440     // Schedule to run all the ready ops in thread pool.
   1441     ScheduleReady(ready, nullptr);
   1442   }
   1443 }
   1444 
   1445 // State kept alive for executing an asynchronous node in another
    1446 // thread.  NOTE: We need to make copies of p.inputs,
    1447 // p.input_device_contexts, and p.input_alloc_attrs for asynchronous
    1448 // kernels because OpKernelContext methods like input_type(i) need the
    1449 // params to point to valid input vectors. This is not an issue for
    1450 // sync kernels because these vectors are kept on the stack.
   1451 struct ExecutorState::AsyncState {
   1452   AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
   1453              const NodeItem* _item, Entry* _first_input,
   1454              NodeExecStatsWrapper* _stats)
   1455       : saved_inputs(*p.inputs),
   1456         saved_input_device_contexts(*p.input_device_contexts),
   1457         saved_input_alloc_attrs(*p.input_alloc_attrs),
   1458         params(p),
   1459         tagged_node(_tagged_node),
   1460         item(_item),
   1461         first_input(_first_input),
   1462         // ParamsButClearingEigenGPUDevice does equivalent of
   1463         //   params.eigen_gpu_device = nullptr;
   1464         ctx(ParamsButClearingEigenGPUDevice(&params), item->num_outputs),
   1465         stats(_stats) {
   1466     params.inputs = &saved_inputs;
   1467     params.input_device_contexts = &saved_input_device_contexts;
   1468     params.input_alloc_attrs = &saved_input_alloc_attrs;
   1469   }
   1470 
   1471   TensorValueVec saved_inputs;
   1472   DeviceContextVec saved_input_device_contexts;
   1473   AllocatorAttributeVec saved_input_alloc_attrs;
   1474   OpKernelContext::Params params;
   1475   TaggedNode tagged_node;
   1476   const NodeItem* item;
   1477   Entry* first_input;
   1478   OpKernelContext ctx;
   1479   NodeExecStatsWrapper* stats;
   1480 
   1481  private:
   1482   OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
   1483       OpKernelContext::Params* p) {
   1484     // Ensure OpKernelContext constructor will make a new eigen GPU device if
   1485     // necessary.
   1486     p->eigen_gpu_device = nullptr;  // Force allocation
   1487     return p;
   1488   }
   1489 };
   1490 
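         // Drains 'inline_ready', starting from 'tagged_node': for each node we
         // prepare its inputs, run the kernel (synchronously, or asynchronously via
         // ComputeAsync), process and propagate its outputs, and then keep running
         // newly readied inexpensive nodes inline while expensive ones are handed
         // back to the thread pool by ScheduleReady.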
   1491 void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_usec) {
   1492   const GraphView& gview = impl_->gview_;
   1493   TaggedNodeSeq ready;
   1494   TaggedNodeReadyQueue inline_ready;
   1495 
   1496   // Parameters passed to OpKernel::Compute.
   1497   TensorValueVec inputs;
   1498   DeviceContextVec input_device_contexts;
   1499   AllocatorAttributeVec input_alloc_attrs;
   1500 
   1501   OpKernelContext::Params params;
   1502   params.step_id = step_id_;
   1503   Device* device = impl_->params_.device;
   1504   params.device = device;
   1505   params.log_memory = log_memory_;
   1506   params.record_tensor_accesses = impl_->device_record_tensor_accesses_;
   1507   params.rendezvous = rendezvous_;
   1508   params.session_state = session_state_;
   1509   params.tensor_store = tensor_store_;
   1510   params.cancellation_manager = cancellation_manager_;
   1511   params.call_frame = call_frame_;
   1512   params.function_library = impl_->params_.function_library;
   1513   params.resource_manager = device->resource_manager();
   1514   params.step_container = step_container_;
   1515   params.slice_reader_cache = slice_reader_cache_;
   1516   params.inputs = &inputs;
   1517   params.input_device_contexts = &input_device_contexts;
   1518   params.input_alloc_attrs = &input_alloc_attrs;
   1519   params.runner = &runner_;
   1520   params.stats_collector = stats_collector_;
   1521 
   1522   Status s;
   1523   NodeExecStatsWrapper* stats = nullptr;
   1524   EntryVector outputs;
   1525   bool completed = false;
   1526   inline_ready.push_back(tagged_node);
   1527   while (!inline_ready.empty()) {
   1528     tagged_node = inline_ready.front();
   1529     inline_ready.pop_front();
   1530     const Node* node = tagged_node.node;
   1531     FrameState* input_frame = tagged_node.input_frame;
   1532     const int64 input_iter = tagged_node.input_iter;
   1533     const int id = node->id();
   1534     const NodeItem& item = *gview.node(id);
   1535 
   1536     // TODO(misard) Replace with a finer-grain enabling flag once we
   1537     // add better optional debugging support.
   1538     if (vlog_ && VLOG_IS_ON(1)) {
   1539       mutex_lock l(input_frame->mu);
   1540       input_frame->GetIteration(input_iter)->mark_started(item.pending_id);
   1541     }
   1542 
   1543     // Set the device_context for this node id, if it exists.
   1544     if (id < device_context_map_.size()) {
   1545       params.op_device_context = device_context_map_[id];
   1546     }
   1547 
   1548     params.track_allocations = false;
   1549     stats = nullptr;
   1550     if (stats_collector_ && !tagged_node.is_dead) {
    1551       // Track allocations if and only if we are collecting statistics.
   1552       params.track_allocations = true;
   1553       stats = new NodeExecStatsWrapper;
   1554       stats->stats()->set_node_name(node->name());
   1555       nodestats::SetScheduled(stats, scheduled_usec);
   1556       nodestats::SetAllStart(stats);
   1557     }
   1558 
   1559     if (vlog_) {
   1560       VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
   1561               << SummarizeNode(*node) << " is dead: " << tagged_node.is_dead;
   1562     }
   1563 
   1564     Entry* input_tensors = GetInputTensors(input_frame, input_iter);
   1565     Entry* first_input = input_tensors + item.input_start;
   1566     outputs.clear();
   1567 
   1568     TensorReferenceVector accessed_tensors;
   1569     DeviceContext* device_context = nullptr;
   1570     // Only execute this node if it is not dead or it is a send/recv
   1571     // transfer node. For transfer nodes, we need to propagate the "dead"
   1572     // bit even when the node is dead.
   1573     bool launched_asynchronously = false;
   1574     if (tagged_node.is_dead && !IsTransferNode(node)) {
   1575       outputs.resize(item.num_outputs);
   1576     } else {
   1577       // Prepares inputs.
   1578       bool is_input_dead = false;
   1579       s = PrepareInputs(item, first_input, &inputs, &input_device_contexts,
   1580                         &input_alloc_attrs, &is_input_dead);
   1581       if (!s.ok()) {
   1582         // Clear inputs.
   1583         int num_inputs = item.num_inputs;
   1584         for (int i = 0; i < num_inputs; ++i) {
   1585           (first_input + i)->ClearVal();
   1586         }
   1587         MaybeMarkCompleted(input_frame, input_iter, id);
   1588         // Continue to process the nodes in 'inline_ready'.
   1589         completed = NodeDone(s, item.node, ready, stats, &inline_ready);
   1590         continue;
   1591       }
   1592 
   1593       // Set up compute params.
   1594       OpKernel* op_kernel = item.kernel;
   1595       params.op_kernel = op_kernel;
   1596       params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
   1597       params.is_input_dead = is_input_dead;
   1598       params.output_attr_array = item.output_attrs();
   1599 
   1600       if (item.kernel_is_async) {
   1601         // Asynchronous computes.
   1602         AsyncOpKernel* async = item.kernel->AsAsync();
   1603         DCHECK(async != nullptr);
   1604         launched_asynchronously = true;
   1605         AsyncState* state =
   1606             new AsyncState(params, tagged_node, &item, first_input, stats);
   1607 
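                 // Completion callback for the async kernel; it mirrors the
                 // synchronous post-processing below: process outputs, clear inputs,
                 // propagate to consumers, and finish the step if this was the last
                 // outstanding op.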
   1608         auto done = [this, state]() {
   1609           Device* device = impl_->params_.device;
   1610           NodeExecStatsWrapper* stats = state->stats;  // Shorthand
   1611           Entry* first_input = state->first_input;     // Shorthand
   1612 
   1613           nodestats::SetOpEnd(stats);
   1614           EntryVector outputs;
   1615           Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
   1616           nodestats::SetMemory(stats, &state->ctx);
   1617           if (vlog_) {
   1618             VLOG(2) << "Async kernel done: " << state->item->node->id()
   1619                     << " step " << step_id_ << " "
   1620                     << SummarizeNode(*state->item->node)
   1621                     << " is dead: " << state->tagged_node.is_dead;
   1622           }
   1623 
   1624           // Clears inputs.
   1625           const int num_inputs = state->item->num_inputs;
   1626           for (int i = 0; i < num_inputs; ++i) {
   1627             (first_input + i)->ClearVal();
   1628           }
   1629           FrameState* input_frame = state->tagged_node.input_frame;
   1630           const int64 input_iter = state->tagged_node.input_iter;
   1631           const int id = state->tagged_node.node->id();
   1632           MaybeMarkCompleted(input_frame, input_iter, id);
   1633           TaggedNodeSeq ready;
   1634           if (s.ok()) {
   1635             PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
   1636           }
   1637           outputs.clear();
   1638           if (s.ok() && impl_->device_record_tensor_accesses_) {
   1639             // Get the list of all tensors accessed during the execution
   1640             TensorReferenceVector accessed;
   1641             state->ctx.retrieve_accessed_tensors(&accessed);
   1642             nodestats::SetReferencedTensors(stats, accessed);
   1643             // callee takes ownership of the vector
   1644             device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
   1645                                                  accessed);
   1646           }
   1647           const bool completed =
   1648               NodeDone(s, state->item->node, ready, stats, nullptr);
   1649           delete state;
   1650           if (completed) Finish();
   1651         };
   1652         nodestats::SetOpStart(stats);
   1653         device->ComputeAsync(async, &state->ctx, done);
   1654       } else {
   1655         // Synchronous computes.
   1656         OpKernelContext ctx(&params, item.num_outputs);
   1657         nodestats::SetOpStart(stats);
   1658         device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
   1659         nodestats::SetOpEnd(stats);
   1660         s = ProcessOutputs(item, &ctx, &outputs, stats);
   1661         if (s.ok() && impl_->device_record_tensor_accesses_) {
   1662           // Get the list of all tensors accessed during the execution
   1663           ctx.retrieve_accessed_tensors(&accessed_tensors);
   1664           device_context = ctx.op_device_context();
   1665         }
   1666         nodestats::SetMemory(stats, &ctx);
   1667       }
   1668     }
   1669 
   1670     if (!launched_asynchronously) {
   1671       if (vlog_) {
   1672         VLOG(2) << "Synchronous kernel done: " << id << " step "
   1673                 << params.step_id << " " << SummarizeNode(*node)
   1674                 << " is dead: " << tagged_node.is_dead;
   1675       }
   1676 
   1677       // Clears inputs.
   1678       const int num_inputs = item.num_inputs;
   1679       for (int i = 0; i < num_inputs; ++i) {
   1680         (first_input + i)->ClearVal();
   1681       }
   1682       MaybeMarkCompleted(input_frame, input_iter, id);
   1683       // Propagates outputs.
   1684       if (s.ok()) {
   1685         PropagateOutputs(tagged_node, &item, &outputs, &ready);
   1686       }
   1687       outputs.clear();
   1688       if (!accessed_tensors.empty()) {
   1689         nodestats::SetReferencedTensors(stats, accessed_tensors);
   1690         // device_context is set above in synchronous computes
   1691         device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
   1692       }
   1693       if (stats) {
   1694         scheduled_usec = nodestats::NowInUsec();
   1695       }
   1696       // Postprocess.
   1697       completed = NodeDone(s, item.node, ready, stats, &inline_ready);
   1698     }
   1699   }  // while !inline_ready.empty()
   1700 
   1701   // This thread of computation is done if completed = true.
   1702   if (completed) Finish();
   1703 }
   1704 
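         // For each input slot of 'item', fills in the TensorValue, DeviceContext
         // and AllocatorAttributes from the corresponding Entry. Ref-typed inputs
         // are passed through when the kernel expects a ref; otherwise they are
         // dereferenced under the ref mutex and the resulting dtype is re-validated
         // against the declared input type.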
   1705 Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
   1706                                     TensorValueVec* inputs,
   1707                                     DeviceContextVec* input_device_contexts,
   1708                                     AllocatorAttributeVec* input_alloc_attrs,
   1709                                     bool* is_input_dead) {
   1710   const Node* node = item.node;
   1711 
   1712   inputs->clear();
   1713   inputs->resize(item.num_inputs);
   1714   input_device_contexts->clear();
   1715   input_device_contexts->resize(item.num_inputs);
   1716   input_alloc_attrs->clear();
   1717   input_alloc_attrs->resize(item.num_inputs);
   1718 
   1719   *is_input_dead = false;
   1720 
   1721   bool is_merge = item.is_merge;
   1722   for (int i = 0; i < item.num_inputs; ++i) {
   1723     const bool expect_ref = IsRefType(item.input_type(i));
   1724     Entry* entry = first_input + i;
   1725     (*input_device_contexts)[i] = entry->device_context;
   1726     (*input_alloc_attrs)[i] = entry->alloc_attr;
   1727 
   1728     // i-th input.
   1729     TensorValue* inp = &(*inputs)[i];
   1730 
   1731     // Only merge and transfer nodes can have no-value inputs.
   1732     if (!entry->has_value) {
   1733       if (!is_merge) {
   1734         DCHECK(IsTransferNode(node)) << node->name() << " - input " << i;
   1735         DCHECK(!entry->val_field_is_set) << node->name() << " - input " << i;
   1736         entry->has_value = true;
   1737         entry->val_field_is_set = true;
   1738         entry->val.Init(*kEmptyTensor);
   1739         inp->tensor = entry->val.get();
   1740         *is_input_dead = true;
   1741       }
   1742       continue;
   1743     }
   1744     if (entry->ref == nullptr) {
   1745       if (expect_ref) {
   1746         return AttachDef(
   1747             errors::InvalidArgument(i, "-th input expects a ref type"),
   1748             item.kernel->def());
   1749       }
   1750       inp->tensor = entry->val.get();
   1751     } else {
   1752       {
   1753         mutex_lock ml(*entry->ref_mu);
   1754         if (!entry->ref->IsInitialized() && !IsInitializationOp(item.node)) {
   1755           return AttachDef(errors::FailedPrecondition(
   1756                                "Attempting to use uninitialized value ",
   1757                                item.kernel->requested_input(i)),
   1758                            item.kernel->def());
   1759         }
   1760       }
   1761       if (expect_ref) {
   1762         inp->mutex_if_ref = entry->ref_mu;
   1763         inp->tensor = entry->ref;
   1764       } else {
   1765         // Automatically deref the tensor ref when the op expects a
   1766         // tensor but is given a ref to a tensor.  Need to deref it
   1767         // under the mutex.
   1768         {
   1769           mutex_lock l(*(entry->ref_mu));
   1770           DCHECK(!entry->val_field_is_set);
   1771           entry->val.Init(*entry->ref);
   1772           entry->val_field_is_set = true;
   1773         }
   1774         entry->ref = nullptr;
   1775         entry->ref_mu = nullptr;
   1776 
   1777         inp->tensor = entry->val.get();
   1778         // The dtype of entry->ref could have been changed by another operation
   1779         // that ran after the operation that "produced" it executed, so
   1780         // re-validate that the type of the dereferenced tensor matches the
   1781         // expected input type.
   1782         if (item.input_type(i) != inp->tensor->dtype()) {
   1783           return AttachDef(
   1784               errors::InvalidArgument(
   1785                   i, "-th input expects type ",
   1786                   DataTypeString(item.input_type(i)),
   1787                   " but automatically dereferenced input tensor has type ",
   1788                   DataTypeString(inp->tensor->dtype())),
   1789               item.kernel->def());
   1790         }
   1791       }
   1792     }
   1793   }
   1794   return Status::OK();
   1795 }
   1796 
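         // Releases the kernel's outputs from 'ctx' into '*outputs', checking each
         // produced dtype against the declared output type. On RESOURCE_EXHAUSTED,
         // an allocation report (if stats are collected) or a hint about
         // report_tensor_allocations_upon_oom is appended to the error message.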
   1797 Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
   1798                                      EntryVector* outputs,
   1799                                      NodeExecStatsWrapper* stats) {
   1800   const Node* node = item.node;
   1801   DCHECK_EQ(0, outputs->size());
   1802   outputs->resize(item.num_outputs);
   1803 
   1804   Status s = ctx->status();
   1805   if (!s.ok()) {
   1806     s = AttachDef(s, item.kernel->def());
   1807     // TODO(misard) Replace with a finer-grain enabling flag once we
   1808     // add better optional debugging support.
   1809     if (vlog_ && VLOG_IS_ON(1)) {
   1810       LOG(WARNING) << this << " Compute status: " << s;
   1811       DumpState();
   1812     }
   1813     if (s.code() == error::RESOURCE_EXHAUSTED) {
   1814       if (stats_collector_) {
   1815         string err = stats_collector_->ReportAllocsOnResourceExhausted(
   1816             s.error_message());
   1817         s = Status(s.code(), strings::StrCat(s.error_message(), err));
   1818       } else {
   1819         s = Status(
   1820             s.code(),
   1821             strings::StrCat(
   1822                 s.error_message(),
   1823                 "\nHint: If you want to see a list of allocated tensors when "
   1824                 "OOM happens, add report_tensor_allocations_upon_oom "
   1825                 "to RunOptions for current allocation info.\n"));
   1826       }
   1827     }
   1828     return s;
   1829   }
   1830 
   1831   // Get the device_context for this node id, if it exists.
   1832   DeviceContext* device_context = nullptr;
   1833   if (node->id() < device_context_map_.size()) {
   1834     device_context = device_context_map_[node->id()];
   1835   }
   1836 
   1837   // Experimental: debugger (tfdb) access to intermediate node completion.
   1838   if (item.num_outputs == 0 && impl_->params_.node_outputs_cb != nullptr) {
   1839     // If the node has no output, invoke the callback with output slot set to
   1840     // -1, signifying that this is a no-output node.
   1841     s.Update(impl_->params_.node_outputs_cb(item.node->name(), -1, nullptr,
   1842                                             false, ctx));
   1843   }
   1844 
   1845   for (int i = 0; i < item.num_outputs; ++i) {
   1846     const TensorValue val = ctx->release_output(i);
   1847     if (*ctx->is_output_dead() || val.tensor == nullptr) {
   1848       // Unless it's a Switch or a Recv, the node must produce a
    1849       // tensor value at its i-th output.
   1850       if (!IsSwitch(node) && !IsRecv(node)) {
   1851         s.Update(errors::Internal("Missing ", i, "-th output from ",
   1852                                   SummarizeNode(*node)));
   1853       }
   1854     } else {
   1855       Entry* out = &((*outputs)[i]);
   1856 
   1857       // Set the device context of the output entry.
   1858       out->device_context = device_context;
   1859 
   1860       // Set the allocator attributes of the output entry.
   1861       out->alloc_attr = ctx->output_alloc_attr(i);
   1862 
   1863       // Sanity check of output tensor types.
   1864       DataType dtype;
   1865       if (val.is_ref()) {
   1866         mutex_lock ml(*val.mutex_if_ref);
   1867         dtype = MakeRefType(val->dtype());
   1868       } else {
   1869         dtype = val->dtype();
   1870       }
   1871       if (dtype == item.output_type(i)) {
   1872         if (stats && val.tensor->IsInitialized()) {
   1873           nodestats::SetOutput(stats, i, val.tensor);
   1874         }
   1875         if (val.is_ref()) {
   1876           out->has_value = true;
   1877           out->ref = val.tensor;
   1878           out->ref_mu = val.mutex_if_ref;
   1879           if (log_memory_) {
   1880             Tensor to_log;
   1881             {
   1882               // Dereference the tensor under the lock.
   1883               mutex_lock l(*out->ref_mu);
   1884               to_log = *out->ref;
   1885             }
   1886             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
   1887                                           ctx->step_id(), i, to_log);
   1888           }
   1889 
   1890           // Experimental: debugger (tfdb) access to intermediate node
   1891           // outputs.
   1892           if (impl_->params_.node_outputs_cb != nullptr) {
   1893             s.Update(impl_->params_.node_outputs_cb(item.node->name(), i,
   1894                                                     out->ref, true, ctx));
   1895           }
   1896         } else {
   1897           // NOTE that std::move is used here, so val.tensor goes to
    1898           // uninitialized state (val.tensor->IsInitialized() returns false).
   1899           DCHECK(!out->val_field_is_set);
   1900           out->has_value = true;
   1901           out->val_field_is_set = true;
   1902           out->val.Init(std::move(*val.tensor));
   1903           if (log_memory_) {
   1904             LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
   1905                                           ctx->step_id(), i, *out->val);
   1906           }
   1907 
   1908           // Experimental: debugger access to intermediate node outputs.
   1909           if (impl_->params_.node_outputs_cb != nullptr) {
   1910             s.Update(impl_->params_.node_outputs_cb(
   1911                 item.node->name(), i, out->val.get(), false, ctx));
   1912           }
   1913         }
   1914       } else {
   1915         s.Update(errors::Internal("Output ", i, " of type ",
   1916                                   DataTypeString(dtype),
   1917                                   " does not match declared output type ",
   1918                                   DataTypeString(item.output_type(i)),
   1919                                   " for node ", SummarizeNode(*node)));
   1920       }
   1921     }
   1922     if (!val.is_ref()) {
    1923       // If OpKernelContext ever returns outputs via pass-by-value, this
    1924       // delete will no longer be needed.
   1925       delete val.tensor;
   1926     }
   1927   }
   1928   return s;
   1929 }
   1930 
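         // Four cases, depending on the node type:
         //  - Most nodes: activate consumers in the same frame and iteration.
         //  - Enter: activate consumers in iteration 0 of the child frame; a loop
         //    invariant is made available to all iterations of that frame.
         //  - Exit: activate consumers in the parent frame, or record the node as a
         //    dead exit if the dead bit is set.
         //  - NextIteration: activate consumers in the next iteration of the same
         //    frame, deferring the value if the parallel-iteration limit is reached.
         // If this node's completion makes its frame done, the frame is deleted and
         // completion is propagated recursively to the parent frame.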
   1931 void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
   1932                                      const NodeItem* item, EntryVector* outputs,
   1933                                      TaggedNodeSeq* ready) {
   1934   const Node* node = tagged_node.node;
   1935   FrameState* input_frame = tagged_node.input_frame;
   1936   const int64 input_iter = tagged_node.input_iter;
   1937   const bool is_dead = tagged_node.is_dead;
   1938 
   1939   // Propagates outputs along out edges, and puts newly ready nodes
   1940   // into the ready queue.
   1941   ready->clear();
   1942   bool is_frame_done = false;
   1943   FrameState* output_frame = input_frame;
   1944   int64 output_iter = input_iter;
   1945 
   1946   if (!item->is_enter_exit_or_next_iter) {
    1947       // Fast path for node types that don't need special handling.
   1948     DCHECK_EQ(input_frame, output_frame);
   1949     // Normal path for most nodes
   1950     mutex_lock l(input_frame->mu);
   1951     output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
   1952     is_frame_done = input_frame->DecrementOutstandingOpsLocked(
   1953         &impl_->gview_, input_iter, ready);
   1954   } else if (item->is_enter) {
   1955     bool is_constant;
   1956     const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
   1957     DCHECK(s.ok()) << s;
   1958     FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
   1959     output_iter = 0;
   1960     {
   1961       const NodeItem* item = impl_->gview_.node(node->id());
   1962       mutex_lock l(output_frame->mu);
   1963       if (is_constant) {
   1964         // Propagate to all active iterations if this is a loop invariant.
   1965         output_frame->AddLoopInv(item, (*outputs)[0], ready);
   1966       } else {
   1967         output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
   1968       }
   1969       output_frame->num_pending_inputs--;
   1970     }
   1971     is_frame_done =
   1972         input_frame->DecrementOutstandingOps(&impl_->gview_, input_iter, ready);
   1973   } else if (item->is_exit) {
   1974     if (is_dead) {
   1975       mutex_lock l(input_frame->mu);
   1976       // Stop and remember this node if it is a dead exit.
   1977       if (input_iter == input_frame->iteration_count) {
   1978         input_frame->dead_exits.push_back(node);
   1979       }
   1980       is_frame_done = input_frame->DecrementOutstandingOpsLocked(
   1981           &impl_->gview_, input_iter, ready);
   1982     } else {
   1983       output_frame = input_frame->parent_frame;
   1984       output_iter = input_frame->parent_iter;
   1985       {
   1986         mutex_lock l(output_frame->mu);
   1987         output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
   1988       }
   1989       is_frame_done = input_frame->DecrementOutstandingOps(&impl_->gview_,
   1990                                                            input_iter, ready);
   1991     }
   1992   } else {
   1993     DCHECK(IsNextIteration(node));
   1994     mutex_lock l(input_frame->mu);
   1995     if (is_dead) {
   1996       // Stop the deadness propagation.
   1997       output_frame = nullptr;
   1998     } else {
   1999       if (input_iter == input_frame->iteration_count &&
   2000           input_frame->num_outstanding_iterations ==
   2001               input_frame->max_parallel_iterations) {
   2002         // Reached the maximum for parallel iterations.
   2003         input_frame->next_iter_roots.push_back({node, (*outputs)[0]});
   2004         output_frame = nullptr;
   2005       } else {
   2006         // If this is a new iteration, start it.
   2007         if (input_iter == input_frame->iteration_count) {
   2008           input_frame->IncrementIteration(&impl_->gview_, ready);
   2009         }
   2010         output_iter = input_iter + 1;
   2011       }
   2012     }
   2013     if (output_frame != nullptr) {
    2014       // A live NextIteration stays in this frame; activate its consumers in 'output_iter'.
   2015       DCHECK(input_frame == output_frame);
   2016       output_frame->ActivateNodes(item, is_dead, output_iter, outputs, ready);
   2017     }
   2018     is_frame_done = input_frame->DecrementOutstandingOpsLocked(
   2019         &impl_->gview_, input_iter, ready);
   2020   }
   2021 
   2022   // At this point, this node is completely done. We also know if the
   2023   // completion of this node makes its frame completed.
   2024   if (is_frame_done) {
   2025     FrameState* parent_frame = input_frame->parent_frame;
   2026     const int64 parent_iter = input_frame->parent_iter;
   2027     DeleteFrame(input_frame, ready);
   2028     if (parent_frame != nullptr) {
    2029       // The completion of this frame may cause completions in its parent
    2030       // frame, so clean things up recursively.
   2031       CleanupFramesIterations(parent_frame, parent_iter, ready);
   2032     }
   2033   }
   2034 }
   2035 
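         // On the first failure, records the error status and aborts the rendezvous
         // and the cancellation manager so other in-flight ops can bail out quickly.
         // Also adjusts num_outstanding_ops_ by the number of newly readied nodes,
         // and returns true when this was the last outstanding op of the step.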
   2036 bool ExecutorState::NodeDone(const Status& s, const Node* node,
   2037                              const TaggedNodeSeq& ready,
   2038                              NodeExecStatsWrapper* stats,
   2039                              TaggedNodeReadyQueue* inline_ready) {
   2040   nodestats::SetAllEnd(stats);
   2041   if (stats_collector_ != nullptr && !SetTimelineLabel(node, stats)) {
   2042     // Only record non-transfer nodes.
   2043     // Transfers 'stats' ownership to 'stats_collector_'.
   2044     stats_collector_->Save(impl_->params_.device->name(), stats);
   2045   } else if (stats) {
   2046     delete stats;
   2047   }
   2048 
   2049   bool abort_run = false;
   2050   if (!s.ok()) {
   2051     // Some error happened. This thread of computation is done.
   2052     mutex_lock l(mu_);
   2053     if (status_.ok()) {
   2054       abort_run = true;
   2055       status_ = s;
   2056     }
   2057   }
   2058   if (abort_run) {
   2059     TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
   2060     if (rendezvous_) {
   2061       rendezvous_->StartAbort(s);
   2062     }
   2063     if (cancellation_manager_) {
   2064       cancellation_manager_->StartCancel();
   2065     }
   2066   }
   2067 
   2068   bool completed = false;
   2069   const size_t ready_size = ready.size();
   2070   if (ready_size == 0 || !s.ok()) {
   2071     completed = (num_outstanding_ops_.fetch_sub(1) == 1);
   2072   } else if (ready_size > 1) {
   2073     num_outstanding_ops_.fetch_add(ready_size - 1, std::memory_order_relaxed);
   2074   }
   2075 
   2076   // Schedule the ready nodes in 'ready'.
   2077   if (s.ok()) {
   2078     ScheduleReady(ready, inline_ready);
   2079   }
   2080   return completed;
   2081 }
   2082 
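         // If 'inline_ready' is nullptr, every ready node is dispatched to the
         // thread pool. Otherwise inexpensive (and dead) nodes are queued for inline
         // execution on the current thread, while expensive nodes go to the thread
         // pool; the last expensive node is kept inline when nothing else is queued,
         // which avoids an unnecessary thread hop.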
   2083 void ExecutorState::ScheduleReady(const TaggedNodeSeq& ready,
   2084                                   TaggedNodeReadyQueue* inline_ready) {
   2085   if (ready.empty()) return;
   2086 
   2087   int64 scheduled_usec = 0;
   2088   if (stats_collector_) {
   2089     scheduled_usec = nodestats::NowInUsec();
   2090   }
   2091   if (inline_ready == nullptr) {
   2092     // Schedule to run all the ready ops in thread pool.
   2093     for (auto& tagged_node : ready) {
   2094       runner_([=]() { Process(tagged_node, scheduled_usec); });
   2095     }
   2096     return;
   2097   }
   2098   const GraphView& gview = impl_->gview_;
   2099   const TaggedNode* curr_expensive_node = nullptr;
   2100   for (auto& tagged_node : ready) {
   2101     const NodeItem& item = *gview.node(tagged_node.node->id());
   2102     if (tagged_node.is_dead || !item.kernel_is_expensive) {
   2103       // Inline this inexpensive node.
   2104       inline_ready->push_back(tagged_node);
   2105     } else {
   2106       if (curr_expensive_node) {
   2107         // Dispatch to another thread since there is plenty of work to
   2108         // do for this thread.
   2109         runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
   2110                           scheduled_usec));
   2111       }
   2112       curr_expensive_node = &tagged_node;
   2113     }
   2114   }
   2115   if (curr_expensive_node) {
   2116     if (inline_ready->empty()) {
   2117       // Tail recursion optimization
   2118       inline_ready->push_back(*curr_expensive_node);
   2119     } else {
   2120       // There are inline nodes to run already. We dispatch this expensive
    2121       // node to another thread.
   2122       runner_(std::bind(&ExecutorState::Process, this, *curr_expensive_node,
   2123                         scheduled_usec));
   2124     }
   2125   }
   2126 }
   2127 
   2128 inline void ExecutorState::MaybeMarkCompleted(FrameState* frame, int64 iter,
   2129                                               int64 node_id) {
   2130   // TODO(misard) Replace with a finer-grain enabling flag once we
   2131   // add better optional debugging support.
   2132   if (vlog_ && VLOG_IS_ON(1)) {
   2133     const NodeItem* item = impl_->gview_.node(node_id);
   2134     mutex_lock l(frame->mu);
   2135     frame->GetIteration(iter)->mark_completed(item->pending_id);
   2136   }
   2137 }
   2138 
   2139 const Tensor* ExecutorState::GetTensorValueForDump(const Entry& input) {
   2140   if (!input.has_value) {
   2141     return kEmptyTensor;
   2142   } else if (input.ref == nullptr) {
   2143     return input.val.get();
   2144   } else {
   2145     return input.ref;
   2146   }
   2147 }
   2148 
   2149 void ExecutorState::DumpPendingNodeState(
   2150     const int node_id, const Entry* input_vector,
   2151     const bool show_nodes_with_no_ready_inputs) {
   2152   const NodeItem& node_item = *impl_->gview_.node(node_id);
   2153   const Node& node = *node_item.node;
   2154   const int input_base = node_item.input_start;
   2155   if (!show_nodes_with_no_ready_inputs) {
   2156     bool has_ready_input = false;
   2157     for (int i = 0; i < node.num_inputs(); ++i) {
   2158       const Entry& input = input_vector[input_base + i];
   2159       const Tensor* tensor = GetTensorValueForDump(input);
   2160       if (tensor->IsInitialized()) {
   2161         has_ready_input = true;
   2162         break;
   2163       }
   2164     }
   2165     if (!has_ready_input) {
   2166       return;
   2167     }
   2168   }
   2169   LOG(WARNING) << "    Pending Node: " << node.DebugString();
   2170   for (int i = 0; i < node.num_inputs(); ++i) {
   2171     const Entry& input = input_vector[input_base + i];
   2172     const Tensor* tensor = GetTensorValueForDump(input);
   2173     if (tensor->IsInitialized()) {
   2174       LOG(WARNING) << "      Input " << i << ": "
   2175                    << strings::StrCat(
   2176                           "Tensor<type: ", DataTypeString(tensor->dtype()),
   2177                           " shape: ", tensor->shape().DebugString(), ">");
   2178     } else {
   2179       LOG(WARNING) << "      Input " << i << ": not present";
   2180     }
   2181   }
   2182 }
   2183 
   2184 void ExecutorState::DumpActiveNodeState(const int node_id,
   2185                                         const Entry* input_vector) {
   2186   const NodeItem& node_item = *impl_->gview_.node(node_id);
   2187   const Node& node = *node_item.node;
   2188   LOG(WARNING) << "    Active Node: " << node.DebugString();
   2189   const int input_base = node_item.input_start;
   2190   for (int i = 0; i < node.num_inputs(); ++i) {
   2191     const Entry& input = input_vector[input_base + i];
   2192     const Tensor* tensor = GetTensorValueForDump(input);
   2193     if (tensor->IsInitialized()) {
   2194       LOG(WARNING) << "      Input " << i << ": "
   2195                    << strings::StrCat(
   2196                           "Tensor<type: ", DataTypeString(tensor->dtype()),
   2197                           " shape: ", tensor->shape().DebugString(), ">");
   2198     } else {
   2199       LOG(WARNING) << "      Input " << i << ": not present";
   2200     }
   2201   }
   2202 }
   2203 
   2204 void ExecutorState::DumpIterationState(const FrameState* frame,
   2205                                        IterationState* iteration) {
   2206   const std::vector<const Node*>* nodes = frame->nodes;
   2207   // Dump any waiting nodes that are holding on to tensors.
   2208   for (const Node* node : *nodes) {
   2209     const int node_id = node->id();
   2210     PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
   2211     if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
   2212         iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
   2213       DumpPendingNodeState(node_id, iteration->input_tensors, false);
   2214     }
   2215   }
   2216   // Then the active nodes.
   2217   for (const Node* node : *nodes) {
   2218     const int node_id = node->id();
   2219     PendingCounts::Handle pending_id = impl_->gview_.node(node_id)->pending_id;
   2220     if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
   2221       DumpActiveNodeState(node_id, iteration->input_tensors);
   2222     }
   2223   }
   2224   // Show all input tensors in use.
   2225   const int total_input_tensors = frame->total_input_tensors;
   2226   size_t total_bytes = 0;
   2227   for (int i = 0; i < total_input_tensors; ++i) {
   2228     const Entry& input = iteration->input_tensors[i];
   2229     const Tensor* tensor = GetTensorValueForDump(input);
   2230     if (tensor->IsInitialized()) {
   2231       LOG(WARNING) << "    Input " << i << ": "
   2232                    << strings::StrCat(
   2233                           "Tensor<type: ", DataTypeString(tensor->dtype()),
   2234                           " shape: ", tensor->shape().DebugString(),
   2235                           ", bytes: ", tensor->TotalBytes(), ">");
   2236       total_bytes += tensor->TotalBytes();
   2237     }
   2238   }
   2239   LOG(WARNING) << "    Total bytes " << total_bytes;
   2240 }
   2241 
   2242 void ExecutorState::DumpState() {
   2243   mutex_lock l(mu_);
   2244   if (!dumped_on_error_) {
   2245     LOG(WARNING) << "Dumping state";
   2246     for (auto& frame : outstanding_frames_) {
   2247       LOG(WARNING) << frame.first;
   2248       FrameState* frame_state = frame.second;
   2249       mutex_lock frame_lock(frame_state->mu);
   2250       for (IterationState* iteration : frame_state->iterations) {
   2251         LOG(WARNING) << "  Iteration:";
   2252         DumpIterationState(frame_state, iteration);
   2253       }
   2254     }
   2255     dumped_on_error_ = true;
   2256   }
   2257 }
   2258 
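         // Collects the final status, optionally blocks on Device::Sync when
         // sync_on_finish_ is set, deletes 'this', and then invokes the done
         // callback via the saved runner.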
   2259 void ExecutorState::Finish() {
   2260   mu_.lock();
   2261   auto status = status_;
   2262   auto done_cb = std::move(done_cb_);
   2263   auto runner = std::move(runner_);
   2264   mu_.unlock();
   2265   if (sync_on_finish_ && status.ok()) {
   2266     // Block until the device has finished all queued operations. For
   2267     // devices like GPUs that continue to execute Ops after their Compute
   2268     // methods have completed, this ensures that control is not returned to
   2269     // the user until the step (and its side-effects) has actually completed.
   2270     status = impl_->params_.device->Sync();
   2271   }
   2272   delete this;
   2273   CHECK(done_cb != nullptr);
   2274   runner([=]() { done_cb(status); });
   2275 }
   2276 
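         // The child frame is looked up under 'mu_' first; if it does not exist yet,
         // it is constructed outside the lock and inserted with a second locked
         // lookup, so a concurrent creator may win the race and the loser's instance
         // is simply deleted.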
   2277 void ExecutorState::FindOrCreateChildFrame(FrameState* frame, int64 iter,
   2278                                            const Node* node,
   2279                                            FrameState** child) {
   2280   // Get the child frame name.
   2281   string enter_name;
   2282   Status s = GetNodeAttr(node->attrs(), "frame_name", &enter_name);
   2283   DCHECK(s.ok()) << s;
   2284   const string child_name = MakeFrameName(frame, iter, enter_name);
   2285 
   2286   {
   2287     mutex_lock executor_lock(mu_);
   2288     auto it = outstanding_frames_.find(child_name);
   2289     if (it != outstanding_frames_.end()) {
   2290       *child = it->second;
   2291       return;
   2292     }
   2293   }
   2294 
   2295   // Need to create a new frame instance.
    2296   // Note that this new frame instance is created without holding any locks.
   2297   if (vlog_) VLOG(2) << "Create frame: " << child_name;
   2298 
   2299   int parallel_iters;
   2300   s = GetNodeAttr(node->attrs(), "parallel_iterations", &parallel_iters);
   2301   DCHECK(s.ok()) << s;
   2302   FrameState* temp = new FrameState(impl_, parallel_iters);
   2303   temp->frame_name = child_name;
   2304   temp->frame_id = Hash64(child_name);
   2305   temp->parent_frame = frame;
   2306   temp->parent_iter = iter;
   2307   temp->InitializeFrameInfo(enter_name);
   2308 
   2309   // 'iterations' is a fixed-length circular buffer.
   2310   temp->iterations.resize(temp->max_parallel_iterations + 1);
   2311   // Initialize iteration 0.
   2312   temp->iterations[0] =
   2313       new IterationState(temp->pending_counts, temp->total_input_tensors);
   2314 
   2315   {
   2316     mutex_lock executor_lock(mu_);
   2317     auto it = outstanding_frames_.find(child_name);
   2318     if (it != outstanding_frames_.end()) {
   2319       *child = it->second;
   2320     } else {
   2321       mutex_lock frame_lock(frame->mu);
   2322       frame->GetIteration(iter)->outstanding_frame_count++;
   2323       outstanding_frames_[child_name] = temp;
   2324       *child = temp;
   2325       temp = nullptr;
   2326     }
   2327   }
   2328   delete temp;  // Not used so delete it.
   2329 }
   2330 
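         // Propagates any dead exits recorded in 'frame' to their destinations in
         // the parent frame, adjusting the destinations' pending/dead counts and
         // adding newly ready ones to 'ready', then unregisters and frees the frame.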
   2331 void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
   2332   // First, propagate dead_exits (if any) to the parent frame.
   2333   FrameState* parent_frame = frame->parent_frame;
   2334   const int64 parent_iter = frame->parent_iter;
   2335   if (parent_frame != nullptr) {
    2336     mutex_lock parent_frame_lock(parent_frame->mu);
   2337     // Propagate all the dead exits to the parent frame.
   2338     for (const Node* node : frame->dead_exits) {
   2339       auto parent_iter_state = parent_frame->GetIteration(parent_iter);
   2340       for (const Edge* e : node->out_edges()) {
   2341         const Node* dst_node = e->dst();
   2342 
   2343         const auto dst_pending_id =
   2344             impl_->gview_.node(dst_node->id())->pending_id;
   2345 
   2346         // TODO(yuanbyu): We don't need this if we require the subgraph
   2347         // given to an executor not to contain a sink node.
   2348         if (dst_node->IsSink()) continue;
   2349 
   2350         bool dst_dead = true;
   2351         bool dst_ready = false;
   2352         // We know this is a dead input to dst.
   2353         if (IsMerge(dst_node)) {
   2354           if (e->IsControlEdge()) {
   2355             parent_iter_state->decrement_pending(dst_pending_id, 2);
   2356             int count = parent_iter_state->pending(dst_pending_id);
   2357             int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
   2358             dst_dead = (dead_cnt == dst_node->num_inputs());
   2359             dst_ready = (count == 0) || ((count == 1) && dst_dead);
   2360           } else {
   2361             parent_iter_state->increment_dead_count(dst_pending_id);
   2362             const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
   2363             dst_dead = (dead_cnt == dst_node->num_inputs());
   2364             dst_ready =
   2365                 (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
   2366           }
   2367         } else {
   2368           parent_iter_state->increment_dead_count(dst_pending_id);
   2369           dst_ready =
   2370               (parent_iter_state->decrement_pending(dst_pending_id, 1) == 0);
   2371         }
   2372         if (dst_ready) {
   2373           if (IsControlTrigger(dst_node)) dst_dead = false;
   2374           ready->push_back(
   2375               TaggedNode(dst_node, parent_frame, parent_iter, dst_dead));
   2376           parent_iter_state->outstanding_ops++;
   2377         }
   2378       }
   2379     }
   2380   }
   2381 
   2382   // Delete the frame.
   2383   const string& frame_name = frame->frame_name;
   2384   if (vlog_) VLOG(2) << "Delete frame " << frame_name;
   2385   {
   2386     mutex_lock executor_lock(mu_);
   2387     outstanding_frames_.erase(frame_name);
   2388   }
   2389   delete frame;
   2390 }
   2391 
   2392 void ExecutorState::CleanupFramesIterations(FrameState* frame, int64 iter,
   2393                                             TaggedNodeSeq* ready) {
   2394   bool is_frame_done = false;
   2395   {
   2396     mutex_lock frame_lock(frame->mu);
   2397     frame->GetIteration(iter)->outstanding_frame_count--;
   2398     is_frame_done = frame->CleanupIterations(&impl_->gview_, iter, ready);
   2399   }
   2400   if (is_frame_done) {
   2401     FrameState* parent_frame = frame->parent_frame;
   2402     const int64 parent_iter = frame->parent_iter;
   2403     DeleteFrame(frame, ready);
   2404     if (parent_frame != nullptr) {
    2405       // The completion of this frame may cause completions in its parent
    2406       // frame, so clean things up recursively.
   2407       CleanupFramesIterations(parent_frame, parent_iter, ready);
   2408     }
   2409   }
   2410 }
   2411 
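         // For every output edge of 'item': update the destination's pending/dead
         // counts in this iteration (control edges into a Merge decrement the count
         // by 2, and a live data input clears the low "live" bit via mark_live),
         // move or copy the produced Entry into the destination's input slot when it
         // is needed, and push destinations that became ready onto 'ready'.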
   2412 void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
   2413                                               const bool is_dead, int64 iter,
   2414                                               EntryVector* outputs,
   2415                                               TaggedNodeSeq* ready) {
   2416   const GraphView& gview = executor->gview_;
   2417   IterationState* iter_state = GetIteration(iter);
   2418   const size_t num_output_edges = item->num_output_edges;
   2419   const EdgeInfo* edges = item->output_edge_list();
   2420   Entry* input_tensors = iter_state->input_tensors;
   2421   for (size_t out_index = 0; out_index < num_output_edges; out_index++) {
   2422     const EdgeInfo& e = edges[out_index];
   2423     const int dst_id = e.dst_id;
   2424     const NodeItem* dst_item = gview.node(dst_id);
   2425     const PendingCounts::Handle dst_pending_id = dst_item->pending_id;
   2426     const int src_slot = e.output_slot;
   2427 
   2428     // TODO(yuanbyu): We don't need this if we require the subgraph
   2429     // given to an executor not to contain a sink node.
   2430     if (dst_item->is_sink) continue;
   2431 
   2432     bool dst_dead = false;
   2433     bool dst_ready = false;
   2434     // True iff this input for dst is needed. We only set this input for
   2435     // dst if this flag is true. This is needed to make the thread safety
   2436     // analysis happy.
   2437     const bool is_control_edge = (src_slot == Graph::kControlSlot);
   2438     bool dst_need_input = !is_control_edge;
   2439     if (dst_item->is_merge) {
   2440       // A merge node is ready if all control inputs have arrived and either
   2441       // a) a live data input becomes available or b) all data inputs are
   2442       // dead. For Merge, pending's LSB is set iff a live data input has
   2443       // arrived.
   2444       if (is_control_edge) {
   2445         iter_state->decrement_pending(dst_pending_id, 2);
   2446         int count = iter_state->pending(dst_pending_id);
   2447         int dead_cnt = iter_state->dead_count(dst_pending_id);
   2448         dst_dead = (dead_cnt == dst_item->num_inputs);
   2449         dst_ready = (count == 0) || ((count == 1) && dst_dead);
   2450       } else {
   2451         if ((*outputs)[src_slot].has_value) {
   2452           // This is a live data input.
   2453           int count = iter_state->pending(dst_pending_id);
   2454           iter_state->mark_live(dst_pending_id);
   2455           // Only the first live edge sets the input and (potentially)
   2456           // triggers execution. The low bit of count is set if and
   2457           // only if no live input has been used yet (mark_live clears
   2458           // it). The node should be started if and only if this is
   2459           // the first live input and there are no pending control
   2460           // edges, i.e. count == 1.
   2461           dst_ready = (count == 1);
   2462           dst_need_input = ((count & 0x1) == 1);
   2463         } else {
   2464           // This is a dead data input. Note that dst_node is dead if node is
    2465           // a dead Enter. We need this to properly handle a while loop on
   2466           // the untaken branch of a conditional.
   2467           // TODO(yuanbyu): This is a bit hacky, but a good solution for
   2468           // now.
   2469           iter_state->increment_dead_count(dst_pending_id);
   2470           const int dead_cnt = iter_state->dead_count(dst_pending_id);
   2471           dst_dead = (dead_cnt == dst_item->num_inputs) || item->is_enter;
   2472           dst_ready = (iter_state->pending(dst_pending_id) == 1) && dst_dead;
   2473           dst_need_input = false;
   2474         }
   2475       }
   2476     } else {
   2477       const bool increment_dead =
   2478           (is_dead || (!is_control_edge && !(*outputs)[src_slot].has_value));
   2479       int pending, dead;
   2480       iter_state->adjust_for_activation(dst_pending_id, increment_dead,
   2481                                         &pending, &dead);
   2482       dst_dead = (dead > 0);
   2483       dst_ready = (pending == 0);
   2484     }
   2485 
   2486     if (dst_need_input) {
   2487       const int dst_slot = e.input_slot;
   2488       const int dst_loc = dst_item->input_start + dst_slot;
   2489       if (e.is_last) {
   2490         input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
   2491       } else {
   2492         input_tensors[dst_loc] = (*outputs)[src_slot];
   2493       }
   2494     }
   2495 
   2496     // Add dst to the ready queue if it's ready
   2497     if (dst_ready) {
   2498       if (dst_item->is_control_trigger) dst_dead = false;
   2499       ready->push_back(TaggedNode(dst_item->node, this, iter, dst_dead));
   2500       iter_state->outstanding_ops++;
   2501     }
   2502   }
   2503 }

void ExecutorState::FrameState::ActivateNexts(const GraphView* gview,
                                              int64 iter,
                                              TaggedNodeSeq* ready) {
  // Propagate the deferred NextIteration nodes to the new iteration.
  for (auto& node_entry : next_iter_roots) {
    const Node* node = node_entry.first;
    const Entry& entry = node_entry.second;
    const bool is_dead = !entry.has_value;
    const NodeItem* item = gview->node(node->id());
    EntryVector outputs{entry};
    ActivateNodes(item, is_dead, iter, &outputs, ready);
  }
  next_iter_roots.clear();
}

void ExecutorState::FrameState::ActivateLoopInvs(const GraphView* gview,
                                                 int64 iter,
                                                 TaggedNodeSeq* ready) {
  // Propagate loop invariants to the new iteration.
  for (auto& node_entry : inv_values) {
    const Node* node = node_entry.first;
    const Entry& entry = node_entry.second;
    const bool is_dead = !entry.has_value;
    const NodeItem* item = gview->node(node->id());
    EntryVector outputs{entry};
    ActivateNodes(item, is_dead, iter, &outputs, ready);
  }
}

void ExecutorState::FrameState::AddLoopInv(const NodeItem* item,
                                           const Entry& entry,
                                           TaggedNodeSeq* ready) {
  // Store this value.
  inv_values.push_back({item->node, entry});

  // Make this value available to all iterations.
  const bool is_dead = !entry.has_value;
  for (int i = 0; i <= iteration_count; ++i) {
    EntryVector outputs{entry};
    ActivateNodes(item, is_dead, i, &outputs, ready);
  }
}

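// Returns true iff iteration `iter` of this frame is done: it has no
// outstanding ops or nested frames, and everything feeding it (the frame's
// pending inputs for iteration 0, otherwise the preceding iteration) has
// also completed.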
bool ExecutorState::FrameState::IsIterationDone(int64 iter) {
  IterationState* iter_state = GetIteration(iter);
  if (iter_state->outstanding_ops == 0 &&
      iter_state->outstanding_frame_count == 0) {
    if (iter == 0) {
      // Iteration 0 is done only when this frame has received all of its
      // inputs from the enclosing frame.
      return num_pending_inputs == 0;
    } else {
      // The preceding iteration is deleted (and therefore done).
      return (GetIteration(iter - 1) == nullptr);
    }
  }
  return false;
}

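// Creates the state for the next iteration of this frame and activates, in
// that new iteration, the deferred NextIteration nodes and the loop
// invariants.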
void ExecutorState::FrameState::IncrementIteration(const GraphView* gview,
                                                   TaggedNodeSeq* ready) {
  iteration_count++;
  const int64 next_iter = iteration_count;

  // Initialize the next iteration.
  IterationState* iter_state =
      new IterationState(pending_counts, total_input_tensors);
  SetIteration(next_iter, iter_state);
  num_outstanding_iterations++;
  dead_exits.clear();

  // Activate the successors of the deferred roots in the new iteration.
  ActivateNexts(gview, next_iter, ready);

  // Activate the loop invariants in the new iteration.
  ActivateLoopInvs(gview, next_iter, ready);
}

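// Deletes iterations of this frame starting from `iter` for as long as they
// are done, starting a deferred iteration whenever one becomes possible.
// Returns true iff the frame itself is done.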
bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview,
                                                  int64 iter,
                                                  TaggedNodeSeq* ready) {
  int64 curr_iter = iter;
  while (curr_iter <= iteration_count && IsIterationDone(curr_iter)) {
    // Delete the iteration curr_iter.
    delete GetIteration(curr_iter);
    SetIteration(curr_iter, nullptr);
    --num_outstanding_iterations;
    ++curr_iter;

    // When one iteration completes, check whether a deferred iteration is
    // waiting, and start it if so.
    if (!next_iter_roots.empty()) {
      IncrementIteration(gview, ready);
    }
  }
  return IsFrameDone();
}

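// Starts asynchronous execution of one step. The ExecutorState allocated here
// manages its own lifetime and invokes `done` with the step's final status
// when the step completes.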
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
  (new ExecutorState(args, this))->RunAsync(std::move(done));
}

}  // end namespace

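// Creates an executor that runs `graph` using the device and kernel-creation
// callbacks in `params`. On success the caller takes ownership of
// `*executor`; on failure the partially built implementation is deleted and
// `*executor` is left untouched.
//
// Minimal usage sketch (hypothetical caller: `params`, `graph` and
// `rendezvous` are assumed to be set up elsewhere, e.g. by a session):
//
//   Executor* raw = nullptr;
//   TF_CHECK_OK(NewLocalExecutor(params, std::move(graph), &raw));
//   std::unique_ptr<Executor> executor(raw);
//
//   Executor::Args args;
//   args.step_id = 1;              // arbitrary step id for this sketch
//   args.rendezvous = rendezvous;  // assumed to be provided by the caller
//   args.runner = [](Executor::Args::Closure c) { c(); };  // run inline
//   TF_CHECK_OK(executor->Run(args));  // synchronous wrapper over RunAsync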
Status NewLocalExecutor(const LocalExecutorParams& params,
                        std::unique_ptr<const Graph> graph,
                        Executor** executor) {
  ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph));
  const Status s = impl->Initialize();
  if (s.ok()) {
    *executor = impl;
  } else {
    delete impl;
  }
  return s;
}

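// Creates a kernel for `ndef` on `device`. Unlike kernels obtained through an
// OpSegment, the result is not cached; the caller owns it and should release
// it with DeleteNonCachedKernel().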
Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
                             const NodeDef& ndef, int graph_def_version,
                             OpKernel** kernel) {
  const auto device_type = DeviceType(device->attributes().device_type());
  auto allocator = device->GetAllocator(AllocatorAttributes());
  return CreateOpKernel(device_type, device, allocator, flib, ndef,
                        graph_def_version, kernel);
}

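// Deletes a kernel returned by CreateNonCachedKernel().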
void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }

}  // end namespace tensorflow