1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/graph/graph.h" 17 18 #include <vector> 19 #include "tensorflow/core/framework/graph.pb.h" 20 #include "tensorflow/core/framework/node_def.pb.h" 21 #include "tensorflow/core/framework/node_def_util.h" 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/versions.pb.h" 24 #include "tensorflow/core/graph/while_context.h" 25 #include "tensorflow/core/lib/core/errors.h" 26 #include "tensorflow/core/lib/gtl/map_util.h" 27 #include "tensorflow/core/lib/hash/hash.h" 28 #include "tensorflow/core/lib/strings/strcat.h" 29 #include "tensorflow/core/lib/strings/stringprintf.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/public/version.h" 32 33 namespace tensorflow { 34 35 const int Graph::kControlSlot = -1; 36 37 struct NodeProperties { 38 public: 39 NodeProperties(const OpDef* op_def, const NodeDef& node_def, 40 const DataTypeSlice inputs, const DataTypeSlice outputs) 41 : op_def(op_def), 42 node_def(node_def), 43 input_types(inputs.begin(), inputs.end()), 44 output_types(outputs.begin(), outputs.end()) {} 45 46 const OpDef* op_def; // not owned 47 NodeDef node_def; 48 const DataTypeVector input_types; 49 const DataTypeVector output_types; 50 }; 51 52 // Node 53 54 #define REF_CLASS(key, value) \ 55 {key, value}, { "Ref" key, value } 56 57 const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable = 58 *new std::unordered_map<string, Node::NodeClass>({ 59 // Keep in same order as NodeClass values 60 REF_CLASS("Switch", NC_SWITCH), 61 REF_CLASS("Merge", NC_MERGE), 62 REF_CLASS("Enter", NC_ENTER), 63 REF_CLASS("Exit", NC_EXIT), 64 REF_CLASS("NextIteration", NC_NEXT_ITERATION), 65 {"LoopCond", NC_LOOP_COND}, 66 {"ControlTrigger", NC_CONTROL_TRIGGER}, 67 {"_Send", NC_SEND}, 68 {"_HostSend", NC_HOST_SEND}, 69 {"_Recv", NC_RECV}, 70 {"_HostRecv", NC_HOST_RECV}, 71 {"Const", NC_CONSTANT}, 72 {"HostConst", NC_CONSTANT}, 73 {"Variable", NC_VARIABLE}, 74 {"VariableV2", NC_VARIABLE}, 75 REF_CLASS("Identity", NC_IDENTITY), 76 {"GetSessionHandle", NC_GET_SESSION_HANDLE}, 77 {"GetSessionHandleV2", NC_GET_SESSION_HANDLE}, 78 {"GetSessionTensor", NC_GET_SESSION_TENSOR}, 79 {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR}, 80 {"Size", NC_METADATA}, 81 {"Shape", NC_METADATA}, 82 {"Rank", NC_METADATA}, 83 {"_ScopedAllocator", NC_SCOPED_ALLOCATOR}, 84 {"CollectiveReduce", NC_COLLECTIVE}, 85 {"CollectiveBcastSend", NC_COLLECTIVE}, 86 {"CollectiveBcastRecv", NC_COLLECTIVE}, 87 {"FakeParam", NC_FAKE_PARAM}, 88 {"PartitionedCall", NC_PARTITIONED_CALL}, 89 {"StatefulPartitionedCall", NC_PARTITIONED_CALL}, 90 // Not using the constants defined in FunctionLibraryDefinition for the 91 // 4 ops below because android inference library does not link 92 // tf.function related files. 93 {"_Arg", NC_ARG}, 94 {"_DeviceArg", NC_ARG}, 95 {"_Retval", NC_RETVAL}, 96 {"_DeviceRetval", NC_RETVAL}, 97 }); 98 99 #undef REF_CLASS 100 101 Node::NodeClass Node::GetNodeClassForOp(const string& ts) { 102 auto it = kNodeClassTable.find(ts); 103 if (it != kNodeClassTable.end()) { 104 return it->second; 105 } else { 106 return NC_OTHER; 107 } 108 } 109 110 string Node::DebugString() const { 111 string ret = strings::StrCat("{name:'", name(), "' id:", id_); 112 if (IsSource()) { 113 strings::StrAppend(&ret, " source}"); 114 } else if (IsSink()) { 115 strings::StrAppend(&ret, " sink}"); 116 } else { 117 strings::StrAppend(&ret, " op device:"); 118 strings::StrAppend(&ret, "{", assigned_device_name(), "}"); 119 strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}"); 120 } 121 return ret; 122 } 123 124 Node::Node() 125 : id_(-1), 126 cost_id_(-1), 127 class_(NC_UNINITIALIZED), 128 props_(nullptr), 129 assigned_device_name_index_(0), 130 while_ctx_(nullptr) {} 131 132 void Node::Initialize(int id, int cost_id, 133 std::shared_ptr<NodeProperties> props) { 134 DCHECK_EQ(id_, -1); 135 DCHECK(in_edges_.empty()); 136 DCHECK(out_edges_.empty()); 137 id_ = id; 138 cost_id_ = cost_id; 139 140 props_ = std::move(props); 141 // Initialize the class_ based on the type string 142 class_ = GetNodeClassForOp(props_->node_def.op()); 143 } 144 145 void Node::Clear() { 146 in_edges_.clear(); 147 out_edges_.clear(); 148 id_ = -1; 149 cost_id_ = -1; 150 class_ = NC_UNINITIALIZED; 151 props_.reset(); 152 assigned_device_name_index_ = 0; 153 } 154 155 void Node::UpdateProperties() { 156 DataTypeVector inputs; 157 DataTypeVector outputs; 158 Status status = 159 InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs); 160 if (!status.ok()) { 161 LOG(ERROR) << "Failed at updating node: " << status; 162 return; 163 } 164 props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def, 165 inputs, outputs); 166 } 167 168 const string& Node::name() const { return props_->node_def.name(); } 169 const string& Node::type_string() const { return props_->node_def.op(); } 170 const NodeDef& Node::def() const { return props_->node_def; } 171 const OpDef& Node::op_def() const { return *props_->op_def; } 172 173 int32 Node::num_inputs() const { return props_->input_types.size(); } 174 DataType Node::input_type(int32 i) const { return props_->input_types[i]; } 175 const DataTypeVector& Node::input_types() const { return props_->input_types; } 176 177 int32 Node::num_outputs() const { return props_->output_types.size(); } 178 DataType Node::output_type(int32 o) const { return props_->output_types[o]; } 179 const DataTypeVector& Node::output_types() const { 180 return props_->output_types; 181 } 182 183 AttrSlice Node::attrs() const { return AttrSlice(def()); } 184 185 const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const { 186 return def().input(); 187 } 188 189 const string& Node::requested_device() const { return def().device(); } 190 191 gtl::iterator_range<NeighborIter> Node::out_nodes() const { 192 return gtl::make_range(NeighborIter(out_edges_.begin(), false), 193 NeighborIter(out_edges_.end(), false)); 194 } 195 196 gtl::iterator_range<NeighborIter> Node::in_nodes() const { 197 return gtl::make_range(NeighborIter(in_edges_.begin(), true), 198 NeighborIter(in_edges_.end(), true)); 199 } 200 201 void Node::MaybeCopyOnWrite() { 202 // NodeProperties may be shared between Nodes. Make a copy if so. 203 if (!props_.unique()) { 204 props_ = std::make_shared<NodeProperties>(*props_); 205 } 206 } 207 208 AttrValue* Node::AddAttrHelper(const string& name) { 209 MaybeCopyOnWrite(); 210 return &((*props_->node_def.mutable_attr())[name]); 211 } 212 213 void Node::ClearAttr(const string& name) { 214 MaybeCopyOnWrite(); 215 (*props_->node_def.mutable_attr()).erase(name); 216 } 217 218 void Node::set_name(string name) { 219 MaybeCopyOnWrite(); 220 props_->node_def.set_name(std::move(name)); 221 } 222 223 void Node::set_requested_device(const string& device) { 224 MaybeCopyOnWrite(); 225 props_->node_def.set_device(device); 226 } 227 228 void Node::set_original_node_names(const std::vector<string>& names) { 229 MaybeCopyOnWrite(); 230 props_->node_def.mutable_experimental_debug_info() 231 ->clear_original_node_names(); 232 if (!names.empty()) { 233 *props_->node_def.mutable_experimental_debug_info() 234 ->mutable_original_node_names() = {names.begin(), names.end()}; 235 } 236 } 237 238 Status Node::input_edge(int idx, const Edge** e) const { 239 if (idx < 0 || idx >= num_inputs()) { 240 return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ", 241 name(), " only has ", num_inputs(), 242 " inputs."); 243 } 244 245 // This does a linear search over the edges. In the common case, 246 // the number of elements is small enough that this search isn't 247 // expensive. Should it become a bottleneck, one can make an 248 // optimization where, if the number of edges is small, we use 249 // linear iteration, and if the number of edges is large, we perform 250 // an indexing step during construction that keeps an array of Edges 251 // indexed by pointer. This would keep the size of each Node small 252 // in the common case but make this function faster when the number 253 // of edges is large. 254 for (const Edge* edge : in_edges()) { 255 if (edge->dst_input() == idx) { 256 *e = edge; 257 return Status::OK(); 258 } 259 } 260 261 return errors::NotFound("Could not find input edge ", idx, " for ", name()); 262 } 263 264 // Returns a vector of the non-control input edges to a node, indexed by ID. 265 Status Node::input_edges(std::vector<const Edge*>* input_edges) const { 266 input_edges->clear(); 267 input_edges->resize(num_inputs(), nullptr); 268 269 for (const Edge* edge : in_edges()) { 270 if (edge->IsControlEdge()) continue; 271 if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) { 272 return errors::Internal("Invalid edge input number ", edge->dst_input()); 273 } 274 if ((*input_edges)[edge->dst_input()] != nullptr) { 275 return errors::Internal("Duplicate edge input number: ", 276 edge->dst_input()); 277 } 278 (*input_edges)[edge->dst_input()] = edge; 279 } 280 281 for (int i = 0; i < num_inputs(); ++i) { 282 if ((*input_edges)[i] == nullptr) { 283 return errors::InvalidArgument("Missing edge input number: ", i); 284 } 285 } 286 return Status::OK(); 287 } 288 289 Status Node::input_node(int idx, Node** n) const { 290 const Edge* e; 291 TF_RETURN_IF_ERROR(input_edge(idx, &e)); 292 if (e == nullptr) { 293 *n = nullptr; 294 } else { 295 *n = e->src(); 296 } 297 return Status::OK(); 298 } 299 300 Status Node::input_node(int idx, const Node** const_n) const { 301 Node* n; 302 TF_RETURN_IF_ERROR(input_node(idx, &n)); 303 *const_n = n; 304 return Status::OK(); 305 } 306 307 Status Node::input_tensor(int idx, OutputTensor* t) const { 308 const Edge* e; 309 TF_RETURN_IF_ERROR(input_edge(idx, &e)); 310 DCHECK(e != nullptr); 311 *t = OutputTensor(e->src(), e->src_output()); 312 return Status::OK(); 313 } 314 315 // NodeDebugInfo 316 317 NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {} 318 NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef) : name(ndef.name()) { 319 if (ndef.has_experimental_debug_info()) { 320 const auto& names = ndef.experimental_debug_info().original_node_names(); 321 original_node_names.assign(names.begin(), names.end()); 322 } 323 } 324 325 // InputTensor 326 327 bool InputTensor::operator==(const InputTensor& other) const { 328 return node == other.node && index == other.index; 329 } 330 331 uint64 InputTensor::Hash::operator()(InputTensor const& s) const { 332 return Hash64Combine(std::hash<const Node*>()(s.node), 333 std::hash<int>()(s.index)); 334 } 335 336 // OutputTensor 337 338 bool OutputTensor::operator==(const OutputTensor& other) const { 339 return node == other.node && index == other.index; 340 } 341 342 uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const { 343 return Hash64Combine(std::hash<const Node*>()(s.node), 344 std::hash<int>()(s.index)); 345 } 346 347 // Graph 348 349 Graph::Graph(const OpRegistryInterface* ops) 350 : ops_(ops, FunctionDefLibrary()), 351 versions_(new VersionDef), 352 arena_(8 << 10 /* 8kB */) { 353 versions_->set_producer(TF_GRAPH_DEF_VERSION); 354 versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER); 355 356 // Initialize the name interning table for assigned_device_name. 357 device_names_.push_back(""); 358 DCHECK_EQ(0, InternDeviceName("")); 359 360 // Source and sink have no endpoints, just control edges. 361 NodeDef def; 362 def.set_name("_SOURCE"); 363 def.set_op("NoOp"); 364 Status status; 365 Node* source = AddNode(def, &status); 366 TF_CHECK_OK(status); 367 CHECK_EQ(source->id(), kSourceId); 368 369 def.set_name("_SINK"); 370 Node* sink = AddNode(def, &status); 371 TF_CHECK_OK(status); 372 CHECK_EQ(sink->id(), kSinkId); 373 374 AddControlEdge(source, sink); 375 } 376 377 Graph::Graph(const FunctionLibraryDefinition& flib_def) 378 : Graph(flib_def.default_registry()) { 379 // Need a new-enough consumer to support the functions we add to the graph. 380 if (flib_def.ToProto().function_size() > 0 && 381 versions_->min_consumer() < 12) { 382 versions_->set_min_consumer(12); 383 } 384 Status s = ops_.AddLibrary(flib_def); 385 CHECK(s.ok()) << s.error_message(); 386 } 387 388 Graph::~Graph() { 389 // Manually call the destructors for all the Nodes we constructed using 390 // placement new. 391 for (Node* node : nodes_) { 392 if (node != nullptr) { 393 node->~Node(); 394 } 395 } 396 for (Node* node : free_nodes_) { 397 node->~Node(); 398 } 399 // Edges have no destructor, and we arena-allocated them, so no need to 400 // destroy them. 401 } 402 403 const VersionDef& Graph::versions() const { return *versions_; } 404 void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; } 405 406 Node* Graph::AddNode(const NodeDef& node_def, Status* status) { 407 const OpDef* op_def; 408 status->Update(ops_.LookUpOpDef(node_def.op(), &op_def)); 409 if (!status->ok()) return nullptr; 410 411 DataTypeVector inputs; 412 DataTypeVector outputs; 413 status->Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs)); 414 if (!status->ok()) { 415 *status = AttachDef(*status, node_def); 416 return nullptr; 417 } 418 419 Node* node = AllocateNode( 420 std::make_shared<NodeProperties>(op_def, node_def, inputs, outputs), 421 nullptr); 422 return node; 423 } 424 425 Node* Graph::CopyNode(const Node* node) { 426 DCHECK(!node->IsSource()); 427 DCHECK(!node->IsSink()); 428 Node* copy = AllocateNode(node->props_, node); 429 copy->set_assigned_device_name(node->assigned_device_name()); 430 431 // Since the OpDef of a function may be owned by the Graph that owns 'node', 432 // relookup the OpDef in the target graph. If it differs, then clone the 433 // node properties with the updated OpDef. 434 const OpDef* op_def; 435 TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def)); 436 if (op_def != node->props_->op_def) { 437 copy->MaybeCopyOnWrite(); 438 copy->props_->op_def = op_def; 439 } 440 441 return copy; 442 } 443 444 void Graph::RemoveNode(Node* node) { 445 TF_DCHECK_OK(IsValidNode(node)) << node->DebugString(); 446 DCHECK(!node->IsSource()); 447 DCHECK(!node->IsSink()); 448 449 // Remove any edges involving this node. 450 while (!node->in_edges_.empty()) { 451 RemoveEdge(*node->in_edges_.begin()); 452 } 453 while (!node->out_edges_.empty()) { 454 RemoveEdge(*node->out_edges_.begin()); 455 } 456 ReleaseNode(node); 457 } 458 459 const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) { 460 TF_DCHECK_OK(IsValidNode(source)) << source->DebugString(); 461 TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString(); 462 463 // source/sink must only be linked via control slots, and 464 // control slots must only be linked to control slots. 465 if (source == source_node() || dest == sink_node() || x == kControlSlot || 466 y == kControlSlot) { 467 DCHECK_EQ(x, kControlSlot) << source->DebugString(); 468 DCHECK_EQ(y, kControlSlot) << dest->DebugString(); 469 } 470 471 Edge* e = nullptr; 472 if (free_edges_.empty()) { 473 e = new (arena_.Alloc(sizeof(Edge))) Edge; // placement new 474 } else { 475 e = free_edges_.back(); 476 free_edges_.pop_back(); 477 } 478 e->id_ = edges_.size(); 479 e->src_ = source; 480 e->dst_ = dest; 481 e->src_output_ = x; 482 e->dst_input_ = y; 483 CHECK(source->out_edges_.insert(e).second); 484 CHECK(dest->in_edges_.insert(e).second); 485 edges_.push_back(e); 486 ++num_edges_; 487 return e; 488 } 489 490 void Graph::RemoveEdge(const Edge* e) { 491 TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString(); 492 TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString(); 493 CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1}); 494 CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1}); 495 CHECK_EQ(e, edges_[e->id_]); 496 CHECK_GT(num_edges_, 0); 497 498 edges_[e->id_] = nullptr; 499 500 Edge* del = const_cast<Edge*>(e); 501 del->src_ = nullptr; 502 del->dst_ = nullptr; 503 del->id_ = -1; 504 del->src_output_ = kControlSlot - 1; 505 del->dst_input_ = kControlSlot - 1; 506 free_edges_.push_back(del); 507 --num_edges_; 508 } 509 510 const Edge* Graph::AddControlEdge(Node* source, Node* dest, 511 bool allow_duplicates) { 512 if (!allow_duplicates) { 513 for (const Edge* edge : dest->in_edges()) { 514 if (edge->IsControlEdge() && edge->src() == source) { 515 // The requested edge already exists. 516 return nullptr; 517 } 518 } 519 } 520 // Modify dest's NodeDef if necessary. 521 if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) { 522 // Check if this input is already in dest's NodeDef. 523 const string new_input = strings::StrCat("^", source->name()); 524 bool input_exists = false; 525 for (const string& input : dest->props_->node_def.input()) { 526 if (input == new_input) { 527 input_exists = true; 528 break; 529 } 530 } 531 if (!input_exists) { 532 dest->MaybeCopyOnWrite(); 533 dest->props_->node_def.add_input(new_input); 534 } 535 } 536 return AddEdge(source, kControlSlot, dest, kControlSlot); 537 } 538 539 void Graph::RemoveControlEdge(const Edge* e) { 540 if (!e->src_->IsSource() && !e->dst_->IsSink()) { 541 e->dst_->MaybeCopyOnWrite(); 542 string e_src_name = strings::StrCat("^", e->src_->name()); 543 auto* inputs = e->dst_->props_->node_def.mutable_input(); 544 for (auto it = inputs->begin(); it != inputs->end(); ++it) { 545 if (*it == e_src_name) { 546 inputs->erase(it); 547 break; 548 } 549 } 550 } 551 RemoveEdge(e); 552 } 553 554 namespace { 555 const Edge* FindEdge(const Node* dst, int index) { 556 for (const Edge* e : dst->in_edges()) { 557 if (e->dst_input() == index) return e; 558 } 559 return nullptr; 560 } 561 } // namespace 562 563 Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst, 564 int dst_index) { 565 TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index)); 566 TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index)); 567 const Edge* e = FindEdge(dst, dst_index); 568 if (e == nullptr) { 569 return errors::InvalidArgument("Couldn't find edge to ", 570 FormatNodeForError(*dst)); 571 } 572 RemoveEdge(e); 573 AddEdge(new_src, new_src_index, dst, dst_index); 574 dst->MaybeCopyOnWrite(); 575 (*dst->props_->node_def.mutable_input())[dst_index] = 576 strings::StrCat(new_src->name(), ":", new_src_index); 577 return Status::OK(); 578 } 579 580 Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) { 581 if (dst->type_string() != "While") { 582 return errors::Internal( 583 "dst argument to AddWhileEdgeHack should be a While op, got: ", 584 dst->DebugString()); 585 } 586 TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index)); 587 // Find the current number of data inputs. We'll add the new edge to the next 588 // missing data input. 589 int dst_index = 0; 590 for (const Edge* edge : dst->in_edges()) { 591 if (edge->IsControlEdge()) continue; 592 ++dst_index; 593 } 594 TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index)); 595 AddEdge(new_src, new_src_index, dst, dst_index); 596 dst->MaybeCopyOnWrite(); 597 dst->props_->node_def.add_input( 598 strings::StrCat(new_src->name(), ":", new_src_index)); 599 return Status::OK(); 600 } 601 602 Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { 603 // Need a new-enough consumer to support the functions we add to the graph. 604 if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) { 605 versions_->set_min_consumer(12); 606 } 607 return ops_.AddLibrary(fdef_lib); 608 } 609 610 namespace { 611 612 void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { 613 if (src_slot == Graph::kControlSlot) { 614 dst->add_input(strings::StrCat("^", src_name)); 615 } else if (src_slot == 0) { 616 dst->add_input(src_name.data(), src_name.size()); 617 } else { 618 dst->add_input(strings::StrCat(src_name, ":", src_slot)); 619 } 620 } 621 622 } // namespace 623 624 void Graph::ToGraphDef(GraphDef* graph_def) const { 625 ToGraphDefSubRange(graph_def, 0); 626 } 627 628 GraphDef Graph::ToGraphDefDebug() const { 629 GraphDef ret; 630 ToGraphDef(&ret); 631 return ret; 632 } 633 634 void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const { 635 graph_def->Clear(); 636 *graph_def->mutable_versions() = versions(); 637 *graph_def->mutable_library() = ops_.ToProto(); 638 639 graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id)); 640 641 std::vector<const Edge*> 642 inputs; // Construct this outside the loop for speed. 643 for (auto id = from_node_id; id < num_node_ids(); ++id) { 644 const Node* node = FindNodeId(id); 645 if (node == nullptr || !node->IsOp()) continue; 646 NodeDef* node_def = graph_def->add_node(); 647 *node_def = node->def(); 648 649 // Use the node's assigned device, if any, instead of the device requested 650 // in the NodeDef. 651 if (!node->assigned_device_name().empty()) { 652 node_def->set_device(node->assigned_device_name()); 653 } 654 655 // Get the inputs for this Node. We make sure control inputs are 656 // after data inputs, as required by GraphDef. 657 inputs.clear(); 658 inputs.resize(node->num_inputs(), nullptr); 659 for (const Edge* edge : node->in_edges()) { 660 if (edge->IsControlEdge()) { 661 inputs.push_back(edge); 662 } else { 663 CHECK(inputs[edge->dst_input()] == nullptr) 664 << "Edge " << edge->src()->DebugString() << ":" 665 << edge->dst()->DebugString() << " with dst_input " 666 << edge->dst_input() << " and had pre-existing input edge " 667 << inputs[edge->dst_input()]->src()->DebugString() << ":" 668 << inputs[edge->dst_input()]->dst()->DebugString(); 669 670 inputs[edge->dst_input()] = edge; 671 } 672 } 673 // Sort the control inputs for more predictable serialization. 674 std::sort(inputs.begin() + node->num_inputs(), inputs.end(), 675 [](const Edge* a, const Edge* b) -> bool { 676 return a->src()->name() < b->src()->name(); 677 }); 678 node_def->clear_input(); 679 node_def->mutable_input()->Reserve(inputs.size()); 680 681 for (size_t i = 0; i < inputs.size(); ++i) { 682 const Edge* edge = inputs[i]; 683 if (edge == nullptr) { 684 if (i < node->requested_inputs().size()) { 685 node_def->add_input(node->requested_inputs()[i]); 686 } else { 687 node_def->add_input(""); 688 } 689 } else { 690 const Node* src = edge->src(); 691 if (!src->IsOp()) continue; 692 AddInput(node_def, src->name(), edge->src_output()); 693 } 694 } 695 } 696 } 697 698 string Graph::NewName(StringPiece prefix) { 699 return strings::StrCat(prefix, "/_", name_counter_++); 700 } 701 702 Status Graph::IsValidNode(const Node* node) const { 703 if (node == nullptr) { 704 return errors::InvalidArgument("Node is null"); 705 } 706 const int id = node->id(); 707 if (id < 0) { 708 return errors::InvalidArgument("node id ", id, " is less than zero"); 709 } 710 if (static_cast<size_t>(id) >= nodes_.size()) { 711 return errors::InvalidArgument( 712 "node id ", id, " is >= than number of nodes in graph ", nodes_.size()); 713 } 714 if (nodes_[id] != node) { 715 return errors::InvalidArgument("Node with id ", id, 716 " is different from the passed in node. " 717 "Does it belong to a different graph?"); 718 } 719 return Status::OK(); 720 } 721 722 Status Graph::IsValidOutputTensor(const Node* node, int idx) const { 723 TF_RETURN_IF_ERROR(IsValidNode(node)); 724 if (idx >= node->num_outputs() || idx < 0) { 725 return errors::OutOfRange("Node '", node->name(), "' (type: '", 726 node->op_def().name(), 727 "', num of outputs: ", node->num_outputs(), 728 ") does not have ", "output ", idx); 729 } 730 return Status::OK(); 731 } 732 733 Status Graph::IsValidInputTensor(const Node* node, int idx) const { 734 TF_RETURN_IF_ERROR(IsValidNode(node)); 735 if (idx >= node->num_inputs() || idx < 0) { 736 return errors::OutOfRange("Node '", node->name(), "' (type: '", 737 node->op_def().name(), 738 "', num of inputs: ", node->num_inputs(), 739 ") does not have ", "input ", idx); 740 } 741 return Status::OK(); 742 } 743 744 Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props, 745 const Node* cost_node) { 746 Node* node = nullptr; 747 if (free_nodes_.empty()) { 748 node = new (arena_.Alloc(sizeof(Node))) Node; // placement new 749 } else { 750 node = free_nodes_.back(); 751 free_nodes_.pop_back(); 752 } 753 node->graph_ = this; 754 const int id = nodes_.size(); 755 int cost_id = cost_node ? cost_node->cost_id() : id; 756 node->Initialize(id, cost_id, std::move(props)); 757 nodes_.push_back(node); 758 ++num_nodes_; 759 return node; 760 } 761 762 void Graph::ReleaseNode(Node* node) { 763 TF_DCHECK_OK(IsValidNode(node)) << node->DebugString(); 764 nodes_[node->id()] = nullptr; 765 free_nodes_.push_back(node); 766 --num_nodes_; 767 node->Clear(); 768 } 769 770 // Ensures that 'device_name' is present in the device name table, and returns 771 // the index of that device name. The index is stable, and can be used in 772 // calls to Node::set_assigned_device_name_index(). 773 int Graph::InternDeviceName(const string& device_name) { 774 // Special case, very common. Also, this allows us to use a single map 775 // lookup below, instead of two. The 'if (index_cell > 0)' test below 776 // relies on this check. 777 if (device_name.empty()) { 778 return 0; 779 } 780 781 int& index_cell = device_names_map_[device_name]; 782 if (index_cell > 0) { 783 return index_cell; 784 } 785 786 const int index = device_names_map_.size(); 787 index_cell = index; 788 device_names_.push_back(device_name); 789 return index; 790 } 791 792 Status Graph::AddWhileContext(StringPiece frame_name, 793 std::vector<Node*> enter_nodes, 794 std::vector<Node*> exit_nodes, 795 OutputTensor cond_output, 796 std::vector<OutputTensor> body_inputs, 797 std::vector<OutputTensor> body_outputs, 798 WhileContext** result) { 799 auto pair = while_ctxs_.insert(std::pair<string, WhileContext>( 800 string(frame_name), 801 WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes), 802 cond_output, std::move(body_inputs), 803 std::move(body_outputs)))); 804 if (!pair.second) { 805 *result = nullptr; 806 return errors::InvalidArgument("WhileContext with frame name '", frame_name, 807 "' already exists"); 808 } 809 *result = &pair.first->second; 810 return Status::OK(); 811 } 812 813 std::unordered_map<string, Node*> Graph::BuildNodeNameIndex() const { 814 std::unordered_map<string, Node*> result; 815 for (Node* n : nodes()) { 816 result[n->name()] = n; 817 } 818 return result; 819 } 820 821 string Edge::DebugString() const { 822 return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(), 823 src_output_, dst_->name().c_str(), dst_input_); 824 } 825 826 } // namespace tensorflow 827