/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A Graph describes a set of computations that are to be
// performed, as well as the dependencies between those
// computations. The basic model is a DAG (directed acyclic graph) with
// * internal nodes representing computational operations to be performed;
// * edges representing dependencies, indicating the target may only be
//   executed once the source has completed; and
// * predefined "source" (start) and "sink" (finish) nodes -- the source
//   should be the only node that doesn't depend on anything, and the sink
//   should be the only node that nothing depends on.
//
// Note: Node ids are intended to be relatively dense in the
// 0..max_id range, but there may be gaps since ids won't be reused.
//
// Note: Some dependencies between operations are due to one operation
// consuming the output of another. In fact operations can produce
// multiple outputs and consume multiple inputs, and some
// optimizations will care about which specific outputs are connected
// to which specific inputs.  We therefore represent a data dependency
// between output O of node A and input I of node B using
// "input index" and "output index" labels per edge.
36 37 #ifndef TENSORFLOW_GRAPH_GRAPH_H_ 38 #define TENSORFLOW_GRAPH_GRAPH_H_ 39 40 #include <functional> 41 #include <string> 42 #include <vector> 43 #include "tensorflow/core/framework/function.h" 44 #include "tensorflow/core/framework/op.h" 45 #include "tensorflow/core/framework/types.h" 46 #include "tensorflow/core/graph/edgeset.h" 47 #include "tensorflow/core/lib/core/arena.h" 48 #include "tensorflow/core/lib/core/refcount.h" 49 #include "tensorflow/core/lib/core/status.h" 50 #include "tensorflow/core/lib/gtl/iterator_range.h" 51 #include "tensorflow/core/platform/logging.h" 52 #include "tensorflow/core/platform/macros.h" 53 #include "tensorflow/core/platform/types.h" 54 55 namespace tensorflow { 56 57 class Edge; 58 class EdgeSetTest; 59 class Graph; 60 class GraphDef; 61 class Node; 62 class VersionDef; 63 class WhileContext; 64 65 class NeighborIter; // Declared below 66 class NodeIter; // Declared below 67 class NodeProperties; // Defined in .cc 68 69 class Node { 70 public: 71 string DebugString() const; 72 int id() const { return id_; } 73 int cost_id() const { return cost_id_; } 74 const string& name() const; 75 const string& type_string() const; 76 77 // def() provides the NodeDef the user supplied, but the specifics 78 // of this Node may have changed due to placement, optimization, etc. 79 // In particular: 80 // * def().name() will match name(); 81 // * def().op() will match type_string() and op_def().name(); 82 // * def().input() is not reliable, use "in_edges()" below instead; 83 // * def().device() is the "user's requested device" and may not match 84 // the actual assigned device, see assigned_device_name() below; 85 // * def().attr() is authoritative. 86 // TODO(irving): Replace with NodeInfo. 
87 const NodeDef& def() const; 88 const OpDef& op_def() const; 89 90 // input and output types 91 int32 num_inputs() const; 92 DataType input_type(int32 i) const; 93 const DataTypeVector& input_types() const; 94 95 int32 num_outputs() const; 96 DataType output_type(int32 o) const; 97 const DataTypeVector& output_types() const; 98 99 // The device requested by the user. For the actual assigned device, 100 // use assigned_device_name() below. 101 const string& requested_device() const; 102 103 // This changes the user requested device but not necessarily the device that 104 // on which the operation will run. 105 void set_requested_device(const string& device); 106 107 // This gives the device the runtime has assigned this node to. If 108 // you want the device the user requested, use def().device() instead. 109 // TODO(josh11b): Validate that the assigned_device, if not empty: 110 // fully specifies a device, and satisfies def().device(). 111 // TODO(josh11b): Move assigned_device_name outside of Node into a 112 // NodeId->DeviceName map. 113 const string& assigned_device_name() const; 114 void set_assigned_device_name(const string& device_name); 115 bool has_assigned_device_name() const { 116 return assigned_device_name_index_ > 0; 117 } 118 int assigned_device_name_index() const { return assigned_device_name_index_; } 119 void set_assigned_device_name_index(int index); 120 121 // Read only access to attributes 122 AttrSlice attrs() const; 123 124 // Inputs requested by the NodeDef. For the actual inputs, use in_edges. 125 const protobuf::RepeatedPtrField<string>& requested_inputs() const; 126 127 // Get the neighboring nodes via edges either in or out of this node. 128 gtl::iterator_range<NeighborIter> in_nodes() const; 129 gtl::iterator_range<NeighborIter> out_nodes() const; 130 const EdgeSet& in_edges() const { return in_edges_; } 131 const EdgeSet& out_edges() const { return out_edges_; } 132 133 // Node type helpers. 
134 bool IsSource() const { return id() == 0; } 135 bool IsSink() const { return id() == 1; } 136 // Anything other than the special Source & Sink nodes. 137 bool IsOp() const { return id() > 1; } 138 139 // Node class helpers 140 bool IsSwitch() const { return class_ == NC_SWITCH; } 141 bool IsMerge() const { return class_ == NC_MERGE; } 142 bool IsEnter() const { return class_ == NC_ENTER; } 143 bool IsExit() const { return class_ == NC_EXIT; } 144 bool IsNextIteration() const { return class_ == NC_NEXT_ITERATION; } 145 bool IsLoopCond() const { return class_ == NC_LOOP_COND; } 146 bool IsControlTrigger() const { return class_ == NC_CONTROL_TRIGGER; } 147 bool IsSend() const { return class_ == NC_SEND || class_ == NC_HOST_SEND; } 148 bool IsRecv() const { return class_ == NC_RECV || class_ == NC_HOST_RECV; } 149 bool IsConstant() const { return class_ == NC_CONSTANT; } 150 bool IsVariable() const { return class_ == NC_VARIABLE; } 151 bool IsIdentity() const { return class_ == NC_IDENTITY; } 152 bool IsGetSessionHandle() const { return class_ == NC_GET_SESSION_HANDLE; } 153 bool IsGetSessionTensor() const { return class_ == NC_GET_SESSION_TENSOR; } 154 bool IsDeleteSessionTensor() const { 155 return class_ == NC_DELETE_SESSION_TENSOR; 156 } 157 bool IsControlFlow() const { 158 return (class_ != NC_OTHER) && // Fast path 159 (IsSwitch() || IsMerge() || IsEnter() || IsExit() || 160 IsNextIteration()); 161 } 162 bool IsHostSend() const { return class_ == NC_HOST_SEND; } 163 bool IsHostRecv() const { return class_ == NC_HOST_RECV; } 164 165 bool IsMetadata() const { return class_ == NC_METADATA; } 166 167 template <typename T> 168 void AddAttr(const string& name, const T& val) { 169 SetAttrValue(val, AddAttrHelper(name)); 170 } 171 172 void ClearAttr(const string& name); 173 174 // Returns into '*e' the edge connecting to the 'idx' input of this Node. 
175 Status input_edge(int idx, const Edge** e) const; 176 177 // Returns into '*edges' the input data edges of this Node, indexed by input 178 // number. Does not return control edges. 179 Status input_edges(std::vector<const Edge*>* edges) const; 180 181 // Returns into '*n' the node that has an output connected to the 182 // 'idx' input of this Node. 183 Status input_node(int idx, const Node** n) const; 184 Status input_node(int idx, Node** n) const; 185 186 WhileContext* while_ctx() const { return while_ctx_; } 187 void set_while_ctx(WhileContext* while_ctx) { 188 DCHECK(IsExit()); 189 DCHECK(while_ctx_ == nullptr); 190 while_ctx_ = while_ctx; 191 } 192 193 private: 194 friend class Graph; 195 Node(); 196 197 NodeProperties* properties() const { return props_.get(); } 198 199 void Initialize(int id, int cost_id, std::shared_ptr<NodeProperties> props); 200 201 // Releases memory from props_, in addition to restoring *this to its 202 // uninitialized state. 203 void Clear(); 204 205 // Make a copy of the Node's props_ if props_ is shared with 206 // other nodes. This must be called before mutating properties, 207 // e.g. in AddAttr. 208 void MaybeCopyOnWrite(); 209 210 AttrValue* AddAttrHelper(const string& name); 211 212 // A set of mutually exclusive classes for different kinds of nodes, 213 // class_ is initialized in the Node::Initialize routine based on the 214 // node's type_string(). 
215 enum NodeClass { 216 NC_UNINITIALIZED, 217 NC_SWITCH, 218 NC_MERGE, 219 NC_ENTER, 220 NC_EXIT, 221 NC_NEXT_ITERATION, 222 NC_LOOP_COND, 223 NC_CONTROL_TRIGGER, 224 NC_SEND, 225 NC_HOST_SEND, 226 NC_RECV, 227 NC_HOST_RECV, 228 NC_CONSTANT, 229 NC_VARIABLE, 230 NC_IDENTITY, 231 NC_GET_SESSION_HANDLE, 232 NC_GET_SESSION_TENSOR, 233 NC_DELETE_SESSION_TENSOR, 234 NC_METADATA, 235 NC_OTHER // Not a special kind of node 236 }; 237 238 static const std::unordered_map<string, NodeClass>& kNodeClassTable; 239 240 static NodeClass GetNodeClassForOp(const string& ts); 241 242 int id_; // -1 until Initialize() is called 243 int cost_id_; // -1 if there is no corresponding cost accounting node 244 NodeClass class_; 245 246 EdgeSet in_edges_; 247 EdgeSet out_edges_; 248 249 // NOTE(skyewm): inheriting from core::RefCounted may have a slight 250 // performance benefit over using shared_ptr, at the cost of manual ref 251 // counting 252 std::shared_ptr<NodeProperties> props_; 253 254 // Index within Graph::device_names_ of the name of device assigned 255 // to perform this computation. 256 int assigned_device_name_index_; 257 258 // A back-pointer to the Graph that owns this node. Currently, this exists 259 // solely to allow Node::[set_]assigned_device_name() to work. However, if all 260 // callers of Node::[set_]assigned_device_name() are modified to use the 261 // equivalent methods defined directly on Graph, then we can remove this 262 // field and reclaim that memory. 263 Graph* graph_; 264 265 // Set if this is an exit node of a while loop with an associated 266 // WhileContext. Otherwise null. (This is only set for exit nodes because 267 // they're the first nodes of a loop encountered while creating the gradient 268 // graph. Exit nodes that are part of while loop gradient graphs will not have 269 // this set.) 270 WhileContext* while_ctx_; 271 272 TF_DISALLOW_COPY_AND_ASSIGN(Node); 273 }; 274 275 // Represents an input of a node, i.e., the `index`-th input to `node`. 
276 struct InputTensor { 277 const Node* node; 278 int index; 279 280 InputTensor(const Node* n, int i) : node(n), index(i) {} 281 InputTensor() : node(nullptr), index(0) {} 282 }; 283 284 // Represents an output of a node, i.e., the `index`-th output of `node`. Note 285 // that a single `OutputTensor` can correspond to multiple `Edge`s if the output 286 // is consumed by multiple destination nodes. 287 struct OutputTensor { 288 const Node* node; 289 int index; 290 291 OutputTensor(const Node* n, int i) : node(n), index(i) {} 292 OutputTensor() : node(nullptr), index(0) {} 293 }; 294 295 class Edge { 296 public: 297 Node* src() const { return src_; } 298 Node* dst() const { return dst_; } 299 int id() const { return id_; } 300 301 // Return the index of the source output that produces the data 302 // carried by this edge. The special value kControlSlot is used 303 // for control dependencies. 304 int src_output() const { return src_output_; } 305 306 // Return the index of the destination input that consumes the data 307 // carried by this edge. The special value kControlSlot is used 308 // for control dependencies. 309 int dst_input() const { return dst_input_; } 310 311 // Return true iff this is an edge that indicates a control-flow 312 // (as opposed to a data-flow) dependency. 313 bool IsControlEdge() const; 314 315 string DebugString() const; 316 317 private: 318 Edge() {} 319 320 friend class EdgeSetTest; 321 friend class Graph; 322 Node* src_; 323 Node* dst_; 324 int id_; 325 int src_output_; 326 int dst_input_; 327 }; 328 329 // Allows for iteration of the edges of a Graph, by iterating the underlying 330 // Graph.edges_ vector while skipping over null entries. 331 class GraphEdgesIterable { 332 private: 333 const std::vector<Edge*>& edges_; 334 335 public: 336 explicit GraphEdgesIterable(const std::vector<Edge*>& edges) 337 : edges_(edges) {} 338 339 typedef Edge* value_type; 340 341 class const_iterator { 342 private: 343 // The underlying iterator. 
344 std::vector<value_type>::const_iterator iter_; 345 346 // The end of the underlying iterator. 347 std::vector<value_type>::const_iterator end_; 348 349 // Advances iter_ until it reaches a non-null item, or reaches the end. 350 void apply_filter() { 351 while (iter_ != end_ && *iter_ == nullptr) { 352 ++iter_; 353 } 354 } 355 356 public: 357 const_iterator(std::vector<value_type>::const_iterator iter, 358 std::vector<value_type>::const_iterator end) 359 : iter_(iter), end_(end) { 360 apply_filter(); 361 } 362 363 bool operator==(const const_iterator& other) const { 364 return iter_ == other.iter_; 365 } 366 367 bool operator!=(const const_iterator& other) const { 368 return iter_ != other.iter_; 369 } 370 371 // This is the prefix increment operator (++x), which is the operator 372 // used by C++ range iteration (for (x : y) ...). We intentionally do not 373 // provide a postfix increment operator. 374 const_iterator& operator++() { 375 ++iter_; 376 apply_filter(); 377 return *this; 378 } 379 380 value_type operator*() { return *iter_; } 381 }; 382 383 const_iterator begin() { 384 return const_iterator(edges_.begin(), edges_.end()); 385 } 386 const_iterator end() { return const_iterator(edges_.end(), edges_.end()); } 387 }; 388 389 // Thread compatible but not thread safe. 390 class Graph { 391 public: 392 // Constructs a graph with a single SOURCE (always id kSourceId) and a 393 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK. 394 // 395 // The graph can hold ops found in registry. `registry`s lifetime must be at 396 // least that of the constructed graph's. 397 explicit Graph(const OpRegistryInterface* registry); 398 399 // Constructs a graph with a single SOURCE (always id kSourceId) and a 400 // single SINK (always id kSinkId) node, and an edge from SOURCE->SINK. 401 // 402 // The graph can hold ops found in `flib_def`. 
Unlike the constructor taking 403 // an OpRegistryInterface, this constructor copies the function definitions in 404 // `flib_def` so its lifetime may be shorter than that of the graph's. The 405 // OpRegistryInterface backing `flib_def` must still have the lifetime of the 406 // graph though. 407 explicit Graph(const FunctionLibraryDefinition& flib_def); 408 409 ~Graph(); 410 411 static const int kControlSlot; 412 413 // The GraphDef version range of this graph (see graph.proto). 414 const VersionDef& versions() const; 415 void set_versions(const VersionDef& versions); 416 417 // Adds a new node to this graph, and returns it. Infers the Op and 418 // input/output types for the node. *this owns the returned instance. 419 // Returns nullptr and sets *status on error. 420 Node* AddNode(const NodeDef& node_def, Status* status); 421 422 // Copies *node, which may belong to another graph, to a new node, 423 // which is returned. Does not copy any edges. *this owns the 424 // returned instance. 425 Node* CopyNode(Node* node); 426 427 // Removes a node from this graph, including all edges from or to it. 428 // *node should not be accessed after calling this function. 429 // REQUIRES: node->IsOp() 430 void RemoveNode(Node* node); 431 432 // Adds an edge that connects the xth output of `source` to the yth input of 433 // `dest` and returns it. Does not update dest's NodeDef. 434 const Edge* AddEdge(Node* source, int x, Node* dest, int y); 435 436 // Adds a control edge (no data flows along this edge) that connects `source` 437 // to `dest`. If `dest`s NodeDef is missing the corresponding control input, 438 // adds the control input. 439 // 440 // If such a control edge already exists and `allow_duplicates` is false, no 441 // edge is added and the function returns nullptr. Otherwise the edge is 442 // unconditionally created and returned. The NodeDef is not updated if 443 // `allow_duplicates` is true. 
444 // TODO(skyewm): // TODO(skyewm): allow_duplicates is needed only by 445 // graph_partition.cc. Figure out if we can do away with it. 446 const Edge* AddControlEdge(Node* source, Node* dest, 447 bool allow_duplicates = false); 448 449 // Removes edge from the graph. Does not update the destination node's 450 // NodeDef. 451 // REQUIRES: The edge must exist. 452 void RemoveEdge(const Edge* edge); 453 454 // Removes control edge `edge` from the graph. Note that this also updates 455 // the corresponding NodeDef to reflect the change. 456 // REQUIRES: The control edge must exist. 457 void RemoveControlEdge(const Edge* e); 458 // Updates the input to a node. The existing edge to `dst` is removed and an 459 // edge from `new_src` to `dst` is created. The NodeDef associated with `dst` 460 // is also updated. 461 Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index); 462 463 // Adds the function and gradient definitions in `fdef_lib` to this graph's op 464 // registry. Ignores duplicate functions, and returns a bad status if an 465 // imported function differs from an existing function or op with the same 466 // name. 467 Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib); 468 469 // The number of live nodes in the graph. 470 // 471 // Because nodes can be removed from the graph, num_nodes() is often 472 // smaller than num_node_ids(). If one needs to create an array of 473 // nodes indexed by node ids, num_node_ids() should be used as the 474 // array's size. 475 int num_nodes() const { return num_nodes_; } 476 477 // The number of live nodes in the graph, excluding the Source and Sink nodes. 478 int num_op_nodes() const { 479 DCHECK_GE(num_nodes_, 2); 480 return num_nodes_ - 2; 481 } 482 483 // The number of live edges in the graph. 484 // 485 // Because edges can be removed from the graph, num_edges() is often 486 // smaller than num_edge_ids(). 
If one needs to create an array of 487 // edges indexed by edge ids, num_edge_ids() should be used as the 488 // array's size. 489 int num_edges() const { return num_edges_; } 490 491 // Serialize the nodes starting at `from_node_id` to a GraphDef. 492 void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const; 493 494 // Serialize to a GraphDef. 495 void ToGraphDef(GraphDef* graph_def) const; 496 497 // This version can be called from debugger to inspect the graph content. 498 // Use the previous version outside debug context for efficiency reasons. 499 // 500 // Note: We do not expose a DebugString() API, since GraphDef.DebugString() is 501 // not defined in some TensorFlow builds. 502 GraphDef ToGraphDefDebug() const; 503 504 // Generate new node name with the specified prefix that is unique 505 // across this graph. 506 string NewName(StringPiece prefix); 507 508 // Access to the list of all nodes. Example usage: 509 // for (Node* node : graph.nodes()) { ... } 510 gtl::iterator_range<NodeIter> nodes() const; 511 512 // Access to the list of all nodes, excluding the Source and Sink nodes. 513 gtl::iterator_range<NodeIter> op_nodes() const; 514 515 // Returns one more than the maximum id assigned to any node. 516 int num_node_ids() const { return nodes_.size(); } 517 518 // Returns the node associated with an id, or nullptr if no node 519 // with that id (the node with that id was removed and the id has 520 // not yet been re-used). *this owns the returned instance. 521 // REQUIRES: 0 <= id < num_node_ids(). 522 Node* FindNodeId(int id) const { return nodes_[id]; } 523 524 // Returns one more than the maximum id assigned to any edge. 525 int num_edge_ids() const { return edges_.size(); } 526 527 // Returns the Edge associated with an id, or nullptr if no edge 528 // with that id (the node with that id was removed and the id has 529 // not yet been re-used). *this owns the returned instance. 530 // REQUIRES: 0 <= id < num_node_ids(). 
531 const Edge* FindEdgeId(int id) const { return edges_[id]; } 532 533 // Access to the set of all edges. Example usage: 534 // for (const Edge* e : graph.edges()) { ... } 535 GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); } 536 537 // The pre-defined nodes. 538 enum { kSourceId = 0, kSinkId = 1 }; 539 Node* source_node() const { return FindNodeId(kSourceId); } 540 Node* sink_node() const { return FindNodeId(kSinkId); } 541 542 const OpRegistryInterface* op_registry() const { return &ops_; } 543 const FunctionLibraryDefinition& flib_def() const { return ops_; } 544 545 void CheckDeviceNameIndex(int index) { 546 DCHECK_GE(index, 0); 547 DCHECK_LT(index, static_cast<int>(device_names_.size())); 548 } 549 550 int InternDeviceName(const string& device_name); 551 552 const string& get_assigned_device_name(const Node& node) const { 553 return device_names_[node.assigned_device_name_index()]; 554 } 555 556 void set_assigned_device_name_index(Node* node, int device_name_index) { 557 CheckDeviceNameIndex(device_name_index); 558 node->assigned_device_name_index_ = device_name_index; 559 } 560 561 void set_assigned_device_name(Node* node, const string& device_name) { 562 node->assigned_device_name_index_ = InternDeviceName(device_name); 563 } 564 565 // Returns OK if `node` is non-null and belongs to this graph 566 Status IsValidNode(const Node* node) const; 567 568 // Returns OK if IsValidNode(`node`) and `idx` is less than 569 // node->num_outputs() 570 Status IsValidOutputTensor(const Node* node, int idx) const; 571 572 // Returns OK if IsValidNode(`node`) and `idx` is less than 573 // node->num_inputs() 574 Status IsValidInputTensor(const Node* node, int idx) const; 575 576 // Create and return a new WhileContext owned by this graph. This is called 577 // when a new while loop is created. `frame_name` must be unique among 578 // WhileContexts in this graph. 
579 Status AddWhileContext(StringPiece frame_name, std::vector<Node*> enter_nodes, 580 std::vector<Node*> exit_nodes, 581 OutputTensor cond_output, 582 std::vector<OutputTensor> body_inputs, 583 std::vector<OutputTensor> body_outputs, 584 WhileContext** result); 585 586 // TODO(josh11b): uint64 hash() const; 587 588 private: 589 // If cost_node is non-null, then cost accounting (in CostModel) 590 // will be associated with that node rather than the new one being 591 // created. 592 // 593 // Ownership of the returned Node is not transferred to caller. 594 Node* AllocateNode(std::shared_ptr<NodeProperties> props, 595 const Node* cost_node); 596 void ReleaseNode(Node* node); 597 598 // Registry of all known ops, including functions. 599 FunctionLibraryDefinition ops_; 600 601 // GraphDef versions 602 const std::unique_ptr<VersionDef> versions_; 603 604 // Allocator which will give us good locality. 605 core::Arena arena_; 606 607 // Map from node ids to allocated nodes. nodes_[id] may be nullptr if 608 // the node with that id was removed from the graph. 609 std::vector<Node*> nodes_; 610 611 // Number of nodes alive. 612 int64 num_nodes_ = 0; 613 614 // Map from edge ids to allocated edges. edges_[id] may be nullptr if 615 // the edge with that id was removed from the graph. 616 std::vector<Edge*> edges_; 617 618 // The number of entries in edges_ that are not nullptr. 619 int num_edges_ = 0; 620 621 // Allocated but free nodes and edges. 622 std::vector<Node*> free_nodes_; 623 std::vector<Edge*> free_edges_; 624 625 // For generating unique names. 626 int name_counter_ = 0; 627 628 // In most graphs, the number of unique values used for the 629 // Node::assigned_device_name() property is quite small. If the graph is 630 // large, then this duplication of values can consume a significant amount of 631 // memory. 
Instead, we represent the same information using an interning 632 // table, which consists of a vector of unique strings (device_names_), as 633 // well a map (device_names_map_) from unique strings to indices within the 634 // unique string table. 635 // 636 // The InternDeviceName() method handles adding a new entry into the table, 637 // or locating the index of an existing entry. 638 // 639 // The fact that Node::assigned_device_name() is implemented using an 640 // interning table is intentionally public. This allows algorithms that 641 // frequently access this field to do so efficiently, especially for the case 642 // where the assigned_device_name of one Node is copied directly from that 643 // of another Node. 644 645 // A table of the unique assigned device names. Indices do NOT correspond 646 // to node IDs. Index 0 is always the empty string. 647 std::vector<string> device_names_; 648 649 // Maps unique device names to indices within device_names_[i]. 650 std::unordered_map<string, int> device_names_map_; 651 652 // All the while contexts owned by this graph, keyed by frame name, 653 // corresponding to all the while loops contained in this graph (including 654 // nested loops). The stored contexts are usually accessed via 655 // AddWhileContext() or Node::while_ctx(), but this manages the lifetime. 656 std::map<string, WhileContext> while_ctxs_; 657 658 // Searches through edges_ for the Edge whose destination node and index 659 // matches dst. An edge with destination `dst` must exist in the graph. 660 const Edge* FindEdge(const Node* dst, int index); 661 662 TF_DISALLOW_COPY_AND_ASSIGN(Graph); 663 }; 664 665 // TODO(josh11b): We may want to support keeping an index on various 666 // node/edge attributes in a graph, particularly node names. 
667 668 // Helper routines 669 670 inline bool IsSource(const Node* node) { return node->IsSource(); } 671 inline bool IsSink(const Node* node) { return node->IsSink(); } 672 inline bool IsSwitch(const Node* node) { return node->IsSwitch(); } 673 inline bool IsMerge(const Node* node) { return node->IsMerge(); } 674 inline bool IsEnter(const Node* node) { return node->IsEnter(); } 675 inline bool IsExit(const Node* node) { return node->IsExit(); } 676 inline bool IsNextIteration(const Node* n) { return n->IsNextIteration(); } 677 inline bool IsLoopCond(const Node* node) { return node->IsLoopCond(); } 678 inline bool IsControlTrigger(const Node* n) { return n->IsControlTrigger(); } 679 inline bool IsSend(const Node* node) { return node->IsSend(); } 680 inline bool IsRecv(const Node* node) { return node->IsRecv(); } 681 inline bool IsHostSend(const Node* node) { return node->IsHostSend(); } 682 inline bool IsHostRecv(const Node* node) { return node->IsHostRecv(); } 683 684 // True for Nodes that mediate the transfer of values between processes. 685 inline bool IsTransferNode(const Node* n) { return IsSend(n) || IsRecv(n); } 686 687 inline bool IsConstant(const Node* node) { return node->IsConstant(); } 688 inline bool IsVariable(const Node* node) { return node->IsVariable(); } 689 inline bool IsIdentity(const Node* node) { return node->IsIdentity(); } 690 691 // Returns true iff 'n' is a control flow node. 692 inline bool IsControlFlow(const Node* n) { return n->IsControlFlow(); } 693 694 // Returns true if the node only depends on its input's metadata 695 // (shape). Specifically, returns true for "Size", "Shape" and "Rank" ops. 696 inline bool IsMetadata(const Node* n) { return n->IsMetadata(); } 697 698 inline bool IsHostMemoryPreserving(const Node* node) { 699 return IsIdentity(node) || IsControlFlow(node); 700 } 701 702 // Iterator for stepping through the nodes of a graph. 
703 class NodeIter { 704 public: 705 NodeIter(const Graph* graph, int id); 706 bool operator==(const NodeIter& rhs); 707 bool operator!=(const NodeIter& rhs); 708 void operator++(); 709 Node* operator*(); 710 Node* operator->(); 711 712 private: 713 // Invariant: id_ == graph_->num_node_ids() || graph_->FindId(id_) != nullptr 714 const Graph* graph_; 715 int id_; 716 }; 717 718 // Iterator for stepping through the neighbors of a node. 719 class NeighborIter { 720 public: 721 NeighborIter(EdgeSet::const_iterator iter, bool incoming); 722 bool operator==(const NeighborIter& rhs); 723 bool operator!=(const NeighborIter& rhs); 724 void operator++(); 725 Node* operator*(); 726 Node* operator->(); 727 728 private: 729 EdgeSet::const_iterator iter_; 730 bool incoming_; 731 }; 732 733 // IMPLEMENTATION DETAILS, PLEASE IGNORE 734 735 inline NodeIter::NodeIter(const Graph* graph, int id) 736 : graph_(graph), id_(id) {} 737 738 inline bool NodeIter::operator==(const NodeIter& rhs) { 739 DCHECK(graph_ == rhs.graph_); 740 return id_ == rhs.id_; 741 } 742 743 inline bool NodeIter::operator!=(const NodeIter& rhs) { 744 return !(*this == rhs); 745 } 746 747 inline void NodeIter::operator++() { 748 while (1) { 749 DCHECK_LE(id_, graph_->num_node_ids()); 750 ++id_; 751 if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) { 752 return; 753 } 754 } 755 } 756 757 inline Node* NodeIter::operator*() { return graph_->FindNodeId(id_); } 758 759 inline Node* NodeIter::operator->() { return graph_->FindNodeId(id_); } 760 761 inline NeighborIter::NeighborIter(EdgeSet::const_iterator iter, bool incoming) 762 : iter_(iter), incoming_(incoming) {} 763 764 inline bool NeighborIter::operator==(const NeighborIter& rhs) { 765 return iter_ == rhs.iter_ && incoming_ == rhs.incoming_; 766 } 767 768 inline bool NeighborIter::operator!=(const NeighborIter& rhs) { 769 return !(*this == rhs); 770 } 771 772 inline void NeighborIter::operator++() { ++iter_; } 773 774 inline Node* 
NeighborIter::operator*() { 775 const Edge* e = *iter_; 776 return incoming_ ? e->src() : e->dst(); 777 } 778 779 inline Node* NeighborIter::operator->() { 780 const Edge* e = *iter_; 781 return incoming_ ? e->src() : e->dst(); 782 } 783 784 inline bool Edge::IsControlEdge() const { 785 // Note that if either src_output_ or dst_input_ is kControlSlot, 786 // so is the other one (AddEdge checks this). 787 return src_output_ == Graph::kControlSlot; 788 } 789 790 inline gtl::iterator_range<NodeIter> Graph::nodes() const { 791 // Note that NodeId 0 is always valid since we don't let the source 792 // node be removed from the graph. 793 return gtl::make_range(NodeIter(this, 0), NodeIter(this, num_node_ids())); 794 } 795 796 inline gtl::iterator_range<NodeIter> Graph::op_nodes() const { 797 // Note that NodeId 0 is always valid since we don't let the source 798 // node be removed from the graph. 799 // 800 // The current implementation of Graph maintains the invariant that the 801 // first two nodes are the source and sink nodes, and all other nodes are op 802 // nodes. This method (op_nodes()) relies on this invariant. 803 NodeIter begin(this, 0); 804 NodeIter end(this, num_node_ids()); 805 if (begin != end) { 806 ++begin; 807 } 808 if (begin != end) { 809 ++begin; 810 } 811 return gtl::make_range(begin, end); 812 } 813 814 inline void Node::set_assigned_device_name_index(int index) { 815 graph_->CheckDeviceNameIndex(index); 816 assigned_device_name_index_ = index; 817 } 818 819 inline void Node::set_assigned_device_name(const string& device_name) { 820 graph_->set_assigned_device_name(this, device_name); 821 } 822 823 inline const string& Node::assigned_device_name() const { 824 return graph_->get_assigned_device_name(*this); 825 } 826 827 } // namespace tensorflow 828 829 #endif // TENSORFLOW_GRAPH_GRAPH_H_ 830