/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A Graph describes a set of computations that are to be
// performed, as well as the dependencies between those
// computations. The basic model is a DAG (directed acyclic graph) with
// * internal nodes representing computational operations to be performed;
// * edges representing dependencies, indicating that the target may only
//   be executed once the source has completed; and
// * predefined "source" (start) and "sink" (finish) nodes -- the source
//   should be the only node that doesn't depend on anything, and the sink
//   should be the only node that nothing depends on.
//
// Note: Node ids are intended to be relatively dense in the
// 0..max_id range, but there may be gaps since ids won't be reused.
//
// Note: Some dependencies between operations are due to one operation
// consuming the output of another. In fact, operations can produce
// multiple outputs and consume multiple inputs, and some
// optimizations will care about which specific outputs are connected
// to which specific inputs.  We therefore represent the data dependency
// between output O of node A and input I of node B using
// "input index" and "output index" labels per edge.

#ifndef TENSORFLOW_GRAPH_GRAPH_H_
#define TENSORFLOW_GRAPH_GRAPH_H_

#include <functional>
#include <string>
#include <vector>
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

class Edge;
class EdgeSetTest;
class Graph;
class GraphDef;
class Node;
class VersionDef;
class WhileContext;

class NeighborIter;    // Declared below
class NodeIter;        // Declared below
class NodeProperties;  // Defined in .cc

class Node {
 public:
  string DebugString() const;
  int id() const { return id_; }
  int cost_id() const { return cost_id_; }
  const string& name() const;
  const string& type_string() const;

  // def() provides the NodeDef the user supplied, but the specifics
  // of this Node may have changed due to placement, optimization, etc.
  // In particular:
  // * def().name() will match name();
  // * def().op() will match type_string() and op_def().name();
  // * def().input() is not reliable, use "in_edges()" below instead;
  // * def().device() is the "user's requested device" and may not match
  //   the actual assigned device, see assigned_device_name() below;
  // * def().attr() is authoritative.
  // TODO(irving): Replace with NodeInfo.
  const NodeDef& def() const;
  const OpDef& op_def() const;

  // input and output types
  int32 num_inputs() const;
  DataType input_type(int32 i) const;
  const DataTypeVector& input_types() const;

  int32 num_outputs() const;
  DataType output_type(int32 o) const;
  const DataTypeVector& output_types() const;

  // The device requested by the user.  For the actual assigned device,
  // use assigned_device_name() below.
  const string& requested_device() const;

  // This changes the user-requested device but not necessarily the device on
  // which the operation will run.
  void set_requested_device(const string& device);

  // This gives the device the runtime has assigned this node to.  If
  // you want the device the user requested, use def().device() instead.
  // TODO(josh11b): Validate that the assigned_device, if not empty:
  // fully specifies a device, and satisfies def().device().
  // TODO(josh11b): Move assigned_device_name outside of Node into a
  // NodeId->DeviceName map.
  const string& assigned_device_name() const;
  void set_assigned_device_name(const string& device_name);
  bool has_assigned_device_name() const {
    return assigned_device_name_index_ > 0;
  }
  int assigned_device_name_index() const { return assigned_device_name_index_; }
  void set_assigned_device_name_index(int index);

  // Read only access to attributes
  AttrSlice attrs() const;

  // Inputs requested by the NodeDef.  For the actual inputs, use in_edges.
  const protobuf::RepeatedPtrField<string>& requested_inputs() const;

  // Get the neighboring nodes via edges either in or out of this node.
  gtl::iterator_range<NeighborIter> in_nodes() const;
  gtl::iterator_range<NeighborIter> out_nodes() const;
  const EdgeSet& in_edges() const { return in_edges_; }
  const EdgeSet& out_edges() const { return out_edges_; }

  // Node type helpers.
  bool IsSource() const { return id() == 0; }
  bool IsSink() const { return id() == 1; }
  // Anything other than the special Source & Sink nodes.
  bool IsOp() const { return id() > 1; }

  // Node class helpers
  bool IsSwitch() const { return class_ == NC_SWITCH; }
  bool IsMerge() const { return class_ == NC_MERGE; }
  bool IsEnter() const { return class_ == NC_ENTER; }
  bool IsExit() const { return class_ == NC_EXIT; }
  bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; }
  bool IsLoopCond() const { return class_ == NC_LOOP_COND; }
  bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; }
  bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; }
  bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; }
  bool IsConstant() const { return class_ == NC_CONSTANT; }
  bool IsVariable() const { return class_ == NC_VARIABLE; }
  bool IsIdentity() const { return class_ == NC_IDENTITY; }
  bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; }
  bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; }
  bool IsDeleteSessionTensor() const {
    return class_ == NC_DELETE_SESSION_TENSOR;
  }
  bool IsControlFlow() const {
    return (class_ != NC_OTHER) &&  // Fast path
           (IsSwitch() || IsMerge() || IsEnter() || IsExit() ||
            IsNextIteration());
  }
  bool IsHostSend() const { return class_ == NC_HOST_SEND; }
  bool IsHostRecv() const { return class_ == NC_HOST_RECV; }

  bool IsMetadata() const { return class_ == NC_METADATA; }

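  // Example usage of AddAttr() below (a sketch; "my_attr" is a hypothetical
  // attr name that this node's op would need to declare):
  //   node->AddAttr("my_attr", 42);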
  template <typename T>
  void AddAttr(const string& name, const T& val) {
    SetAttrValue(val, AddAttrHelper(name));
  }

  void ClearAttr(const string& name);

  // Returns into '*e' the edge connecting to the 'idx' input of this Node.
  Status input_edge(int idx, const Edge** e) const;

  // Returns into '*edges' the input data edges of this Node, indexed by input
  // number. Does not return control edges.
  Status input_edges(std::vector<const Edge*>* edges) const;

  // Returns into '*n' the node that has an output connected to the
  // 'idx' input of this Node.
  Status input_node(int idx, const Node** n) const;
  Status input_node(int idx, Node** n) const;
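  //
  // Example usage of input_edge() and input_node() (a sketch; assumes this
  // node has at least one input and that the caller returns Status, using
  // the TF_RETURN_IF_ERROR macro from tensorflow/core/lib/core/errors.h):
  //   const Edge* edge = nullptr;
  //   TF_RETURN_IF_ERROR(node->input_edge(0, &edge));
  //   const Node* producer = nullptr;
  //   TF_RETURN_IF_ERROR(node->input_node(0, &producer));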

  WhileContext* while_ctx() const { return while_ctx_; }
  void set_while_ctx(WhileContext* while_ctx) {
    DCHECK(IsExit());
    DCHECK(while_ctx_ == nullptr);
    while_ctx_ = while_ctx;
  }

 private:
  friend class Graph;
  Node();

  NodeProperties* properties() const { return props_.get(); }

  void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props);

  // Releases memory from props_, in addition to restoring *this to its
  // uninitialized state.
  void Clear();

  // Make a copy of the Node's props_ if props_ is shared with
  // other nodes. This must be called before mutating properties,
  // e.g. in AddAttr.
  void MaybeCopyOnWrite();

  AttrValue* AddAttrHelper(const string& name);

  // A set of mutually exclusive classes for different kinds of nodes;
  // class_ is initialized in the Node::Initialize routine based on the
  // node's type_string().
  enum NodeClass {
    NC_UNINITIALIZED,
    NC_SWITCH,
    NC_MERGE,
    NC_ENTER,
    NC_EXIT,
    NC_NEXT_ITERATION,
    NC_LOOP_COND,
    NC_CONTROL_TRIGGER,
    NC_SEND,
    NC_HOST_SEND,
    NC_RECV,
    NC_HOST_RECV,
    NC_CONSTANT,
    NC_VARIABLE,
    NC_IDENTITY,
    NC_GET_SESSION_HANDLE,
    NC_GET_SESSION_TENSOR,
    NC_DELETE_SESSION_TENSOR,
    NC_METADATA,
    NC_OTHER  // Not a special kind of node
  };

  static const std::unordered_map<string, NodeClass>& kNodeClassTable;

  static NodeClass GetNodeClassForOp(const string& ts);

  int id_;       // -1 until Initialize() is called
  int cost_id_;  // -1 if there is no corresponding cost accounting node
  NodeClass class_;

  EdgeSet in_edges_;
  EdgeSet out_edges_;

  // NOTE(skyewm): inheriting from core::RefCounted may have a slight
  // performance benefit over using shared_ptr, at the cost of manual ref
  // counting
  std::shared_ptr<NodeProperties> props_;

  // Index within Graph::device_names_ of the name of the device assigned
  // to perform this computation.
  int assigned_device_name_index_;

  // A back-pointer to the Graph that owns this node.  Currently, this exists
  // solely to allow Node::[set_]assigned_device_name() to work. However, if all
  // callers of Node::[set_]assigned_device_name() are modified to use the
  // equivalent methods defined directly on Graph, then we can remove this
  // field and reclaim that memory.
  Graph* graph_;

  // Set if this is an exit node of a while loop with an associated
  // WhileContext. Otherwise null. (This is only set for exit nodes because
  // they're the first nodes of a loop encountered while creating the gradient
  // graph. Exit nodes that are part of while loop gradient graphs will not have
  // this set.)
  WhileContext* while_ctx_;

  TF_DISALLOW_COPY_AND_ASSIGN(Node);
};

// Represents an input of a node, i.e., the `index`-th input to `node`.
struct InputTensor {
  const Node* node;
  int index;

  InputTensor(const Node* n, int i) : node(n), index(i) {}
  InputTensor() : node(nullptr), index(0) {}
};

// Represents an output of a node, i.e., the `index`-th output of `node`. Note
// that a single `OutputTensor` can correspond to multiple `Edge`s if the output
// is consumed by multiple destination nodes.
struct OutputTensor {
  const Node* node;
  int index;

  OutputTensor(const Node* n, int i) : node(n), index(i) {}
  OutputTensor() : node(nullptr), index(0) {}
};

class Edge {
 public:
  Node* src() const { return src_; }
  Node* dst() const { return dst_; }
  int id() const { return id_; }

  // Return the index of the source output that produces the data
  // carried by this edge.  The special value kControlSlot is used
  // for control dependencies.
  int src_output() const { return src_output_; }

  // Return the index of the destination input that consumes the data
  // carried by this edge.  The special value kControlSlot is used
  // for control dependencies.
  int dst_input() const { return dst_input_; }

  // Return true iff this is an edge that indicates a control-flow
  // (as opposed to a data-flow) dependency.
  bool IsControlEdge() const;
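  //
  // Example usage (a sketch): skip control edges when walking a node's inputs.
  //   for (const Edge* e : node->in_edges()) {
  //     if (e->IsControlEdge()) continue;
  //     // e->src_output() and e->dst_input() are valid slot indices here.
  //   }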

  string DebugString() const;

 private:
  Edge() {}

  friend class EdgeSetTest;
  friend class Graph;
  Node* src_;
  Node* dst_;
  int id_;
  int src_output_;
  int dst_input_;
};

// Allows for iteration of the edges of a Graph, by iterating the underlying
// Graph.edges_ vector while skipping over null entries.
class GraphEdgesIterable {
 private:
  const std::vector<Edge*>& edges_;

 public:
  explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
      : edges_(edges) {}

  typedef Edge* value_type;

  class const_iterator {
   private:
    // The underlying iterator.
    std::vector<value_type>::const_iterator iter_;

    // The end of the underlying iterator.
    std::vector<value_type>::const_iterator end_;

    // Advances iter_ until it reaches a non-null item, or reaches the end.
    void apply_filter() {
      while (iter_ != end_ && *iter_ == nullptr) {
        ++iter_;
      }
    }

   public:
    const_iterator(std::vector<value_type>::const_iterator iter,
                   std::vector<value_type>::const_iterator end)
        : iter_(iter), end_(end) {
      apply_filter();
    }

    bool operator==(const const_iterator& other) const {
      return iter_ == other.iter_;
    }

    bool operator!=(const const_iterator& other) const {
      return iter_ != other.iter_;
    }

    // This is the prefix increment operator (++x), which is the operator
    // used by C++ range iteration (for (x : y) ...).  We intentionally do not
    // provide a postfix increment operator.
    const_iterator& operator++() {
      ++iter_;
      apply_filter();
      return *this;
    }

    value_type operator*() { return *iter_; }
  };

  const_iterator begin() {
    return const_iterator(edges_.begin(), edges_.end());
  }
  const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
};

// Thread compatible but not thread safe.
class Graph {
 public:
  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in `registry`. `registry`'s lifetime must be
  // at least that of the constructed graph.
  explicit Graph(const OpRegistryInterface* registry);

  // Constructs a graph with a single SOURCE (always id kSourceId) and a
  // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK.
  //
  // The graph can hold ops found in `flib_def`. Unlike the constructor taking
  // an OpRegistryInterface, this constructor copies the function definitions
  // in `flib_def`, so its lifetime may be shorter than that of the graph. The
  // OpRegistryInterface backing `flib_def` must still outlive the graph,
  // however.
  explicit Graph(const FunctionLibraryDefinition& flib_def);

  ~Graph();

  static const int kControlSlot;

  // The GraphDef version range of this graph (see graph.proto).
  const VersionDef& versions() const;
  void set_versions(const VersionDef& versions);

  // Adds a new node to this graph, and returns it. Infers the Op and
  // input/output types for the node. *this owns the returned instance.
  // Returns nullptr and sets *status on error.
  Node* AddNode(const NodeDef& node_def, Status* status);
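  //
  // Example usage (a sketch; NodeDefBuilder comes from
  // tensorflow/core/framework/node_def_builder.h, which this header does not
  // include, and "NoOp" is simply a registered op with no inputs or attrs):
  //   NodeDef node_def;
  //   TF_CHECK_OK(NodeDefBuilder("my_no_op", "NoOp").Finalize(&node_def));
  //   Status status;
  //   Node* node = graph.AddNode(node_def, &status);
  //   if (node == nullptr) { /* inspect `status` for the error */ }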

  // Copies *node, which may belong to another graph, to a new node,
  // which is returned.  Does not copy any edges.  *this owns the
  // returned instance.
  Node* CopyNode(Node* node);

  // Removes a node from this graph, including all edges from or to it.
  // *node should not be accessed after calling this function.
  // REQUIRES: node->IsOp()
  void RemoveNode(Node* node);

  // Adds an edge that connects the xth output of `source` to the yth input of
  // `dest` and returns it. Does not update dest's NodeDef.
  const Edge* AddEdge(Node* source, int x, Node* dest, int y);

  // Adds a control edge (no data flows along this edge) that connects `source`
  // to `dest`. If `dest`'s NodeDef is missing the corresponding control input,
  // adds the control input.
  //
  // If such a control edge already exists and `allow_duplicates` is false, no
  // edge is added and the function returns nullptr. Otherwise the edge is
  // unconditionally created and returned. The NodeDef is not updated if
  // `allow_duplicates` is true.
  // TODO(skyewm): allow_duplicates is needed only by graph_partition.cc.
  // Figure out if we can do away with it.
  const Edge* AddControlEdge(Node* source, Node* dest,
                             bool allow_duplicates = false);
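  //
  // Example usage of AddEdge() and AddControlEdge() (a sketch; `a` and `b`
  // are Node* already in this graph, and output 0 of `a` is type-compatible
  // with input 0 of `b`):
  //   const Edge* data_edge = graph.AddEdge(a, 0, b, 0);
  //   const Edge* ctrl_edge = graph.AddControlEdge(a, b);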

  // Removes edge from the graph. Does not update the destination node's
  // NodeDef.
  // REQUIRES: The edge must exist.
  void RemoveEdge(const Edge* edge);

  // Removes control edge `edge` from the graph. Note that this also updates
  // the corresponding NodeDef to reflect the change.
  // REQUIRES: The control edge must exist.
  void RemoveControlEdge(const Edge* e);
  // Updates the input to a node.  The existing edge to `dst` is removed and an
  // edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
  // is also updated.
  Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index);
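  //
  // Example usage of UpdateEdge() (a sketch): reroute input 0 of `dst` so it
  // is fed by output 0 of `new_src` instead of its current producer.
  //   TF_RETURN_IF_ERROR(graph.UpdateEdge(new_src, 0, dst, 0));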

  // Adds the function and gradient definitions in `fdef_lib` to this graph's op
  // registry. Ignores duplicate functions, and returns a bad status if an
  // imported function differs from an existing function or op with the same
  // name.
  Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);

  // The number of live nodes in the graph.
  //
  // Because nodes can be removed from the graph, num_nodes() is often
  // smaller than num_node_ids(). If one needs to create an array of
  // nodes indexed by node ids, num_node_ids() should be used as the
  // array's size.
  int num_nodes() const { return num_nodes_; }

  // The number of live nodes in the graph, excluding the Source and Sink nodes.
  int num_op_nodes() const {
    DCHECK_GE(num_nodes_, 2);
    return num_nodes_ - 2;
  }

  // The number of live edges in the graph.
  //
  // Because edges can be removed from the graph, num_edges() is often
  // smaller than num_edge_ids(). If one needs to create an array of
  // edges indexed by edge ids, num_edge_ids() should be used as the
  // array's size.
  int num_edges() const { return num_edges_; }

  // Serialize the nodes starting at `from_node_id` to a GraphDef.
  void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;

  // Serialize to a GraphDef.
  void ToGraphDef(GraphDef* graph_def) const;
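  //
  // Example usage (a sketch):
  //   GraphDef graph_def;
  //   graph.ToGraphDef(&graph_def);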

  // This version can be called from a debugger to inspect the graph content.
  // Use the previous version outside of a debug context for efficiency
  // reasons.
  //
  // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is
  // not defined in some TensorFlow builds.
  GraphDef ToGraphDefDebug() const;

  // Generates a new node name with the specified prefix that is unique
  // across this graph.
  string NewName(StringPiece prefix);
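  //
  // Example usage (a sketch; the exact suffix appended to the prefix is an
  // implementation detail):
  //   string tmp_name = graph.NewName("my_prefix");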

  // Access to the list of all nodes.  Example usage:
  //   for (Node* node : graph.nodes()) { ... }
  gtl::iterator_range<NodeIter> nodes() const;

  // Access to the list of all nodes, excluding the Source and Sink nodes.
  gtl::iterator_range<NodeIter> op_nodes() const;

  // Returns one more than the maximum id assigned to any node.
  int num_node_ids() const { return nodes_.size(); }

  // Returns the node associated with an id, or nullptr if no node
  // with that id (the node with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_node_ids().
  Node* FindNodeId(int id) const { return nodes_[id]; }

  // Returns one more than the maximum id assigned to any edge.
  int num_edge_ids() const { return edges_.size(); }

  // Returns the Edge associated with an id, or nullptr if no edge
  // with that id (the edge with that id was removed and the id has
  // not yet been re-used). *this owns the returned instance.
  // REQUIRES: 0 <= id < num_edge_ids().
  const Edge* FindEdgeId(int id) const { return edges_[id]; }

  // Access to the set of all edges.  Example usage:
  //   for (const Edge* e : graph.edges()) { ... }
  GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }

  // The pre-defined nodes.
  enum { kSourceId = 0, kSinkId = 1 };
  Node* source_node() const { return FindNodeId(kSourceId); }
  Node* sink_node() const { return FindNodeId(kSinkId); }

  const OpRegistryInterface* op_registry() const { return &ops_; }
  const FunctionLibraryDefinition& flib_def() const { return ops_; }

  void CheckDeviceNameIndex(int index) {
    DCHECK_GE(index, 0);
    DCHECK_LT(index, static_cast<int>(device_names_.size()));
  }

  int InternDeviceName(const string& device_name);

  const string& get_assigned_device_name(const Node& node) const {
    return device_names_[node.assigned_device_name_index()];
  }

  void set_assigned_device_name_index(Node* node, int device_name_index) {
    CheckDeviceNameIndex(device_name_index);
    node->assigned_device_name_index_ = device_name_index;
  }

  void set_assigned_device_name(Node* node, const string& device_name) {
    node->assigned_device_name_index_ = InternDeviceName(device_name);
  }
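  //
  // Example usage of the device-name interning accessors above (a sketch; the
  // device string is purely illustrative):
  //   const int idx = graph.InternDeviceName("/job:worker/task:0/cpu:0");
  //   graph.set_assigned_device_name_index(node, idx);
  //   // Or, without managing the index directly:
  //   graph.set_assigned_device_name(node, "/job:worker/task:0/cpu:0");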

  // Returns OK if `node` is non-null and belongs to this graph
  Status IsValidNode(const Node* node) const;

  // Returns OK if IsValidNode(`node`) and `idx` is less than
  // node->num_outputs()
  Status IsValidOutputTensor(const Node* node, int idx) const;

  // Returns OK if IsValidNode(`node`) and `idx` is less than
  // node->num_inputs()
  Status IsValidInputTensor(const Node* node, int idx) const;

  // Create and return a new WhileContext owned by this graph. This is called
  // when a new while loop is created. `frame_name` must be unique among
  // WhileContexts in this graph.
  Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes,
                         std::vector<Node*> exit_nodes,
                         OutputTensor cond_output,
                         std::vector<OutputTensor> body_inputs,
                         std::vector<OutputTensor> body_outputs,
                         WhileContext** result);

  // TODO(josh11b): uint64 hash() const;

 private:
  // If cost_node is non-null, then cost accounting (in CostModel)
  // will be associated with that node rather than the new one being
  // created.
  //
  // Ownership of the returned Node is not transferred to caller.
  Node* AllocateNode(std::shared_ptr<NodeProperties> props,
                     const Node* cost_node);
  void ReleaseNode(Node* node);

  // Registry of all known ops, including functions.
  FunctionLibraryDefinition ops_;

  // GraphDef versions
  const std::unique_ptr<VersionDef> versions_;

  // Allocator which will give us good locality.
  core::Arena arena_;

  // Map from node ids to allocated nodes.  nodes_[id] may be nullptr if
  // the node with that id was removed from the graph.
  std::vector<Node*> nodes_;

  // Number of nodes alive.
  int64 num_nodes_ = 0;

  // Map from edge ids to allocated edges.  edges_[id] may be nullptr if
  // the edge with that id was removed from the graph.
  std::vector<Edge*> edges_;

  // The number of entries in edges_ that are not nullptr.
  int num_edges_ = 0;

  // Allocated but free nodes and edges.
  std::vector<Node*> free_nodes_;
  std::vector<Edge*> free_edges_;

  // For generating unique names.
  int name_counter_ = 0;

  // In most graphs, the number of unique values used for the
  // Node::assigned_device_name() property is quite small.  If the graph is
  // large, then this duplication of values can consume a significant amount of
  // memory.  Instead, we represent the same information using an interning
  // table, which consists of a vector of unique strings (device_names_), as
  // well as a map (device_names_map_) from unique strings to indices within
  // the unique string table.
  //
  // The InternDeviceName() method handles adding a new entry into the table,
  // or locating the index of an existing entry.
  //
  // The fact that Node::assigned_device_name() is implemented using an
  // interning table is intentionally public.  This allows algorithms that
  // frequently access this field to do so efficiently, especially for the case
  // where the assigned_device_name of one Node is copied directly from that
  // of another Node.

  // A table of the unique assigned device names.  Indices do NOT correspond
  // to node IDs.  Index 0 is always the empty string.
  std::vector<string> device_names_;

  // Maps unique device names to indices within device_names_.
  std::unordered_map<string, int> device_names_map_;

  // All the while contexts owned by this graph, keyed by frame name,
  // corresponding to all the while loops contained in this graph (including
  // nested loops). The stored contexts are usually accessed via
  // AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
  std::map<string, WhileContext> while_ctxs_;

  // Searches through edges_ for the Edge whose destination node and index
  // match `dst` and `index`. Such an edge must exist in the graph.
  const Edge* FindEdge(const Node* dst, int index);

  TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};

// TODO(josh11b): We may want to support keeping an index on various
// node/edge attributes in a graph, particularly node names.

// Helper routines

inline bool IsSource(const Node* node) { return node->IsSource(); }
inline bool IsSink(const Node* node) { return node->IsSink(); }
inline bool IsSwitch(const Node* node) { return node->IsSwitch(); }
inline bool IsMerge(const Node* node) { return node->IsMerge(); }
inline bool IsEnter(const Node* node) { return node->IsEnter(); }
inline bool IsExit(const Node* node) { return node->IsExit(); }
inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); }
inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); }
inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); }
inline bool IsSend(const Node* node) { return node->IsSend(); }
inline bool IsRecv(const Node* node) { return node->IsRecv(); }
inline bool IsHostSend(const Node* node) { return node->IsHostSend(); }
inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); }

// True for Nodes that mediate the transfer of values between processes.
inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); }

inline bool IsConstant(const Node* node) { return node->IsConstant(); }
inline bool IsVariable(const Node* node) { return node->IsVariable(); }
inline bool IsIdentity(const Node* node) { return node->IsIdentity(); }

// Returns true iff 'n' is a control flow node.
inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); }

// Returns true if the node only depends on its input's metadata
// (shape).  Specifically, returns true for "Size", "Shape" and "Rank" ops.
inline bool IsMetadata(const Node* n) { return n->IsMetadata(); }

inline bool IsHostMemoryPreserving(const Node* node) {
  return IsIdentity(node) || IsControlFlow(node);
}

// Iterator for stepping through the nodes of a graph.
class NodeIter {
 public:
  NodeIter(const Graph* graph, int id);
  bool operator==(const NodeIter& rhs);
  bool operator!=(const NodeIter& rhs);
  void operator++();
  Node* operator*();
  Node* operator->();

 private:
  // Invariant:
  //   id_ == graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr
  const Graph* graph_;
  int id_;
};

// Iterator for stepping through the neighbors of a node.
class NeighborIter {
 public:
  NeighborIter(EdgeSet::const_iterator iter, bool incoming);
  bool operator==(const NeighborIter& rhs);
  bool operator!=(const NeighborIter& rhs);
  void operator++();
  Node* operator*();
  Node* operator->();

 private:
  EdgeSet::const_iterator iter_;
  bool incoming_;
};

// IMPLEMENTATION DETAILS, PLEASE IGNORE

inline NodeIter::NodeIter(const Graph* graph, int id)
    : graph_(graph), id_(id) {}

inline bool NodeIter::operator==(const NodeIter& rhs) {
  DCHECK(graph_ == rhs.graph_);
  return id_ == rhs.id_;
}

inline bool NodeIter::operator!=(const NodeIter& rhs) {
  return !(*this == rhs);
}

inline void NodeIter::operator++() {
  while (1) {
    DCHECK_LE(id_, graph_->num_node_ids());
    ++id_;
    if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) {
      return;
    }
  }
}

inline Node* NodeIter::operator*() { return graph_->FindNodeId(id_); }

inline Node* NodeIter::operator->() { return graph_->FindNodeId(id_); }

inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming)
    : iter_(iter), incoming_(incoming) {}

inline bool NeighborIter::operator==(const NeighborIter& rhs) {
  return iter_ == rhs.iter_ && incoming_ == rhs.incoming_;
}

inline bool NeighborIter::operator!=(const NeighborIter& rhs) {
  return !(*this == rhs);
}

inline void NeighborIter::operator++() { ++iter_; }

inline Node* NeighborIter::operator*() {
  const Edge* e = *iter_;
  return incoming_ ? e->src() : e->dst();
}

inline Node* NeighborIter::operator->() {
  const Edge* e = *iter_;
  return incoming_ ? e->src() : e->dst();
}

inline bool Edge::IsControlEdge() const {
  // Note that if either src_output_ or dst_input_ is kControlSlot,
  // so is the other one (AddEdge checks this).
  return src_output_ == Graph::kControlSlot;
}

inline gtl::iterator_range<NodeIter> Graph::nodes() const {
  // Note that NodeId 0 is always valid since we don't let the source
  // node be removed from the graph.
  return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids()));
}

inline gtl::iterator_range<NodeIter> Graph::op_nodes() const {
  // Note that NodeId 0 is always valid since we don't let the source
  // node be removed from the graph.
  //
  // The current implementation of Graph maintains the invariant that the
  // first two nodes are the source and sink nodes, and all other nodes are op
  // nodes. This method (op_nodes()) relies on this invariant.
  NodeIter begin(this, 0);
  NodeIter end(this, num_node_ids());
  if (begin != end) {
    ++begin;
  }
  if (begin != end) {
    ++begin;
  }
  return gtl::make_range(begin, end);
}

inline void Node::set_assigned_device_name_index(int index) {
  graph_->CheckDeviceNameIndex(index);
  assigned_device_name_index_ = index;
}

inline void Node::set_assigned_device_name(const string& device_name) {
  graph_->set_assigned_device_name(this, device_name);
}

inline const string& Node::assigned_device_name() const {
  return graph_->get_assigned_device_name(*this);
}

}  // namespace tensorflow

#endif  // TENSORFLOW_GRAPH_GRAPH_H_