1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/graph/graph_constructor.h" 17 18 #include <algorithm> 19 #include <set> 20 #include <string> 21 #include <unordered_map> 22 #include <vector> 23 24 #include "tensorflow/core/common_runtime/shape_refiner.h" 25 #include "tensorflow/core/framework/function.h" 26 #include "tensorflow/core/framework/function.pb.h" 27 #include "tensorflow/core/framework/graph.pb.h" 28 #include "tensorflow/core/framework/node_def.pb.h" 29 #include "tensorflow/core/framework/node_def_util.h" 30 #include "tensorflow/core/framework/tensor_shape.pb.h" 31 #include "tensorflow/core/framework/types.h" 32 #include "tensorflow/core/framework/versions.h" 33 #include "tensorflow/core/framework/versions.pb.h" 34 #include "tensorflow/core/graph/algorithm.h" 35 #include "tensorflow/core/graph/graph.h" 36 #include "tensorflow/core/graph/tensor_id.h" 37 #include "tensorflow/core/lib/core/errors.h" 38 #include "tensorflow/core/lib/gtl/inlined_vector.h" 39 #include "tensorflow/core/lib/strings/scanner.h" 40 #include "tensorflow/core/platform/logging.h" 41 #include "tensorflow/core/public/version.h" 42 43 namespace tensorflow { 44 45 namespace { 46 inline bool IsMerge(const NodeDef& node_def) { 47 return node_def.op() == "Merge" || node_def.op() == "RefMerge"; 48 } 49 50 inline bool IsNextIteration(const NodeDef& node_def) { 51 return node_def.op() == "NextIteration" || 52 node_def.op() == "RefNextIteration"; 53 } 54 55 bool IsValidNodeName(StringPiece s, bool allow_internal_ops) { 56 using ::tensorflow::strings::Scanner; 57 return Scanner(s) 58 .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE 59 : Scanner::LETTER_DIGIT_DOT) 60 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) 61 .Eos() 62 .GetResult(); 63 } 64 65 class GraphConstructor { 66 public: 67 struct Options { 68 Options(const GraphConstructorOptions& in) // NOLINT(runtime/explicit) 69 : allow_internal_ops(in.allow_internal_ops), 70 expect_device_spec(in.expect_device_spec), 71 importing(false), 72 validate_colocation_constraints(false) {} 73 Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit) 74 : allow_internal_ops(false), 75 expect_device_spec(false), 76 prefix(in.prefix.empty() || StringPiece(in.prefix).ends_with("/") 77 ? in.prefix 78 : in.prefix + "/"), 79 uniquify_names(in.uniquify_names), 80 uniquify_prefix(in.uniquify_prefix), 81 input_map(in.input_map), 82 skip_mapped_nodes(in.skip_mapped_nodes), 83 control_dependencies(in.control_dependencies), 84 return_tensors(in.return_tensors), 85 return_nodes(in.return_nodes), 86 importing(true), 87 validate_colocation_constraints(in.validate_colocation_constraints), 88 validate_shape(in.validate_shape) {} 89 90 bool allow_internal_ops; 91 bool expect_device_spec; 92 93 string prefix; 94 bool uniquify_names; 95 bool uniquify_prefix; 96 std::map<TensorId, TensorId> input_map; 97 bool skip_mapped_nodes; 98 std::vector<string> control_dependencies; 99 std::vector<TensorId> return_tensors; 100 std::vector<string> return_nodes; 101 102 // TODO(ashankar): This bool exists to separate out functionality required 103 // to make ImportGraphDef a close equivalent of Python's import_graph_def 104 // without affecting the behavior of ConvertGraphDefToGraph at the time 105 // ImportGraphDef was added. 106 // 107 // That said, the functionality here (shape and op validation) seems 108 // applicable to ConvertGraphDefToGraph as well, so make an attempt to 109 // remove this. 110 bool importing; 111 bool validate_colocation_constraints; 112 bool validate_shape = true; 113 }; 114 115 typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice; 116 117 // versions and library may be nullptr 118 static Status Construct( 119 const Options& opts, NodeDefSlice node_defs, const VersionDef* versions, 120 const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, 121 std::vector<std::pair<Node*, int>>* return_tensors, 122 std::vector<Node*>* return_nodes, 123 std::vector<TensorId>* missing_unused_input_map_keys) { 124 if (versions) { 125 TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION, 126 TF_GRAPH_DEF_VERSION_MIN_PRODUCER, 127 "GraphDef", "graph")); 128 } 129 GraphConstructor c(opts, node_defs, versions, library, g, refiner, 130 return_tensors, return_nodes, 131 missing_unused_input_map_keys); 132 const Status s = c.TryImport(); 133 if (!s.ok()) c.Undo(); 134 return s; 135 } 136 137 private: 138 GraphConstructor(const Options& opts, NodeDefSlice node_defs, 139 const VersionDef* versions, 140 const FunctionDefLibrary* library, Graph* g, 141 ShapeRefiner* refiner, 142 std::vector<std::pair<Node*, int>>* return_tensors, 143 std::vector<Node*>* return_nodes, 144 std::vector<TensorId>* missing_unused_input_map_keys) 145 : opts_(opts), 146 node_defs_(node_defs), 147 versions_(versions), 148 library_(library), 149 g_(g), 150 original_versions_(g->versions()), 151 prefix_(opts.prefix), 152 refiner_(refiner), 153 return_tensors_(return_tensors), 154 return_nodes_(return_nodes), 155 missing_unused_input_map_keys_(missing_unused_input_map_keys) {} 156 157 Status TryImport() { 158 TF_RETURN_IF_ERROR(EnsureNoNameCollisions()); 159 TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies()); 160 TF_RETURN_IF_ERROR(BuildNodeIndex()); 161 TF_RETURN_IF_ERROR(InitFromEdges()); 162 TF_RETURN_IF_ERROR(Convert()); 163 TF_RETURN_IF_ERROR(AddBackEdges()); 164 TF_RETURN_IF_ERROR(UpdateVersionDef()); 165 TF_RETURN_IF_ERROR(PopulateReturnTensors()); 166 TF_RETURN_IF_ERROR(PopulateReturnNodes()); 167 TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys()); 168 UpdateUniquifiedColocationNames(); 169 FixupSourceAndSinkEdges(g_); 170 return Status::OK(); 171 } 172 173 Status EnsureNoNameCollisions(); 174 Status ValidateInputMapAndControlDependencies(); 175 Status BuildNodeIndex(); 176 Status InitFromEdges(); 177 Status Convert(); 178 Status AddBackEdges(); 179 Status UpdateVersionDef(); 180 Status PopulateReturnTensors(); 181 Status PopulateReturnNodes(); 182 Status PopulateMissingUnusedInputMapKeys(); 183 184 void Undo(); 185 186 Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped); 187 Status ValidateColocationConstraints(const NodeDef& node_def); 188 Status MakeNode(const NodeDef& node_def, Node** node); 189 Status MakeEdge(Node* src, int output_index, Node* dst, int input_index); 190 Status ValidateShape(Node* node); 191 Status ModifyNodeDefForImport(NodeDef* node_def); 192 // Modifies node_def's inputs according to opts_.input_map. 193 // input_already_exists is a pre-initialized vector of length 194 // node_def->input_size(). This function will mark inputs that are remapped to 195 // true. 196 void RemapNodeDefInputs(NodeDef* node_def, 197 std::vector<bool>* input_already_exists); 198 // input_already_exists is a pre-initialized vector of length 199 // node_def->input_size(). This function will add and mark control inputs as 200 // true. 201 void AddControlDependencies(NodeDef* node_def, 202 std::vector<bool>* input_already_exists); 203 void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists, 204 NodeDef* node_def); 205 206 // Modifies `node_def` if its name isn't unique, or if any of its inputs' 207 // names have been uniquified. This must be called in topological order on all 208 // nodes. 209 void UniquifyNames(const std::vector<bool>& input_already_exists, 210 NodeDef* node_def); 211 212 // Updates any constructed nodes' colocation group names if the name has been 213 // updated by UniquifyNames. This is called after all the nodes have been 214 // constructed so all the names have been uniquified if necessary. 215 void UpdateUniquifiedColocationNames(); 216 217 // Returns true if `name` already exists in `g_` (either as a node name or 218 // prefix). 219 bool NameExistsInGraph(StringPiece name); 220 221 // Returns true if `name` already exists in the GraphDef being imported 222 // (either as a node name or prefix). 223 bool NameExistsInGraphDef(StringPiece name); 224 225 // Returns a unique version of `original_name`, or `original_name` if it's 226 // already unique in the graph. 227 string FindUniqueName(StringPiece original_name); 228 229 // From constructor 230 const Options opts_; 231 const NodeDefSlice node_defs_; 232 const VersionDef* versions_; 233 const FunctionDefLibrary* library_; 234 Graph* g_; 235 const VersionDef original_versions_; 236 237 // A copy of opts_.prefix, possibly uniquified. 238 string prefix_; 239 240 ShapeRefiner* refiner_; 241 242 // May be null. Not owned. 243 std::vector<std::pair<Node*, int>>* return_tensors_; 244 245 // May be null. Not owned. 246 std::vector<Node*>* return_nodes_; 247 248 // May be null. Not owned. 249 std::vector<TensorId>* missing_unused_input_map_keys_; 250 251 // Intermediate datastructure used to populate 252 // `missing_unused_input_map_keys_`. 253 std::set<TensorId> used_input_map_keys_; 254 255 // Mapping from node name to the index within node_defs_. 256 struct NodeInfo { 257 explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {} 258 // std::unordered_map<> requires that we have a default constructor. 259 NodeInfo() : NodeInfo(-1) {} 260 int gdef_index; 261 Node* node; // nullptr until the NodeDef is converted to a Node. 262 }; 263 // TODO(vrv): Profile this data structure to see if we should use an 264 // alternative implementation of std::unordered_map. 265 std::unordered_map<StringPiece, NodeInfo, StringPieceHasher> gdef_nodes_; 266 267 // Prefixes already used in the GraphDef being imported. 268 std::unordered_set<StringPiece, StringPieceHasher> gdef_prefixes_; 269 270 // Mapping from node name to the existing node in g_. 271 std::unordered_map<StringPiece, Node*, StringPieceHasher> existing_nodes_; 272 273 // Prefixes already used in the graph. 274 std::unordered_set<StringPiece, StringPieceHasher> existing_prefixes_; 275 276 // Imported node names that have been uniquified. The key is the original 277 // name, the value is the new unique name. 278 std::unordered_map<string, string> uniquified_names_; 279 280 // Index of NodeDefs in node_defs_ with all inputs already converted. 281 std::vector<int> ready_; 282 283 // Mapping between index within node_defs_ and the number of inputs that 284 // still need to be converted. 285 std::vector<int> pending_count_; 286 287 // Mapping between index within node_defs_ and the index within node_defs_ of 288 // all nodes it outputs to. 289 std::vector<gtl::InlinedVector<int, 4>> outputs_; 290 291 // Used in the conversion from node_defs_ to g_ to represent the ith input 292 // of a node. 293 struct InputInfo { 294 explicit InputInfo(const string& node_name, Node* n, int i) 295 : name(node_name), node(n), index(i) {} 296 // Use string instead of StringPiece so we don't have to manage lifetime 297 string name; 298 Node* node; 299 int index; 300 }; 301 302 // Used in the conversion from node_defs_ to g_ to represent an edge from 303 // the node named 'name' to node 'n'. 304 struct EdgeInfo { 305 explicit EdgeInfo(const string& name, int i1, Node* n, int i2) 306 : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {} 307 // Use string instead of StringPiece so we don't have to manage lifetime 308 string src_name; 309 int src_index; 310 Node* dst_node; 311 int dst_index; 312 }; 313 std::vector<EdgeInfo> back_edges_; 314 }; 315 316 // This could be expensive but we don't expect to call it often, if at all (only 317 // if there are multiple nodes in g_ with the same name) 318 bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map, 319 const StringPiece& node_name) { 320 for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) { 321 if (iter->second.first == node_name) return true; 322 } 323 return false; 324 } 325 326 bool NodeNameInValues(const std::vector<string>& control_dependencies, 327 const StringPiece& node_name) { 328 return std::find(control_dependencies.begin(), control_dependencies.end(), 329 node_name) != control_dependencies.end(); 330 } 331 332 // Adds any prefixes of `node_name` (not including the full name itself) to 333 // `prefixes`. 334 void AddPrefixes(StringPiece node_name, 335 std::unordered_set<StringPiece, StringPieceHasher>* prefixes) { 336 size_t idx = -1; 337 while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) { 338 prefixes->insert(node_name.substr(0, idx)); 339 } 340 } 341 342 Status GraphConstructor::EnsureNoNameCollisions() { 343 existing_nodes_.reserve(g_->num_nodes()); 344 // Populate existing_nodes_ and existing_prefixes_. 345 for (Node* n : g_->nodes()) { 346 bool already_exists = !existing_nodes_.insert({n->name(), n}).second; 347 if (already_exists) { 348 if (NodeNameInValues(opts_.input_map, n->name())) { 349 return errors::InvalidArgument( 350 "cannot resolve input_map because multiple nodes exist with name '", 351 n->name(), "'"); 352 } 353 if (NodeNameInValues(opts_.control_dependencies, n->name())) { 354 return errors::InvalidArgument( 355 "cannot resolve control_dependencies because multiple nodes exist " 356 "with name '", 357 n->name(), "'"); 358 } 359 } 360 AddPrefixes(n->name(), &existing_prefixes_); 361 } 362 if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) { 363 for (const NodeDef* n : node_defs_) { 364 const string& name = n->name(); 365 if (NameExistsInGraph(name)) { 366 return errors::InvalidArgument("Node name '", name, 367 "' already exists in the Graph"); 368 } 369 } 370 } else if (!prefix_.empty()) { 371 StringPiece prefix_no_slash(prefix_); 372 prefix_no_slash.remove_suffix(1); 373 if (!IsValidNodeName(prefix_no_slash, false)) { 374 return errors::InvalidArgument("Imported node name prefix '", prefix_, 375 "' would lead to invalid node names"); 376 } 377 if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) { 378 prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/"); 379 } 380 } 381 return Status::OK(); 382 } 383 384 Status GraphConstructor::ValidateInputMapAndControlDependencies() { 385 for (const auto& mapping : opts_.input_map) { 386 TensorId src = mapping.first; 387 TensorId dst = mapping.second; 388 if (existing_nodes_.count(dst.first) == 0) { 389 return errors::InvalidArgument( 390 "node '", dst.first, "' in input_map does not exist in graph ", 391 "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")"); 392 } 393 if ((src.second == Graph::kControlSlot) != 394 (dst.second == Graph::kControlSlot)) { 395 return errors::InvalidArgument("input_map entry ", src.ToString(), "->", 396 dst.ToString(), " between ", 397 "control edge and non-control edge"); 398 } 399 } 400 for (const string& node : opts_.control_dependencies) { 401 if (existing_nodes_.count(node) == 0) { 402 return errors::InvalidArgument( 403 "node '", node, 404 "' in control_dependencies does not exist in " 405 "graph"); 406 } 407 } 408 return Status::OK(); 409 } 410 411 Status GraphConstructor::BuildNodeIndex() { 412 // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_. 413 for (int n = 0; n < node_defs_.size(); ++n) { 414 const NodeDef& node_def = *node_defs_[n]; 415 if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) { 416 return errors::InvalidArgument( 417 "Node '", node_def.name(), 418 "': Node name contains invalid characters"); 419 } 420 if (!gdef_nodes_ 421 .insert(std::make_pair(StringPiece(node_def.name()), NodeInfo(n))) 422 .second) { 423 return errors::InvalidArgument("Node '", node_def.name(), 424 "' is not unique"); 425 } 426 // Validate the operation's type. 427 if (node_def.op().empty()) { 428 return errors::InvalidArgument("Node '", node_def.name(), 429 "' does not specify an operation"); 430 } 431 if (opts_.expect_device_spec && node_def.device().empty()) { 432 return errors::InvalidArgument("Node '", node_def.name(), 433 "' is missing a device specification"); 434 } 435 // Validate control edges at end 436 bool in_control_dependence = false; 437 for (int i = 0; i < node_def.input_size(); ++i) { 438 StringPiece input_name = node_def.input(i); 439 if (!input_name.empty() && input_name.starts_with("^")) { 440 in_control_dependence = true; 441 } else if (in_control_dependence) { 442 return errors::InvalidArgument( 443 "Node '", node_def.name(), 444 "': Control dependencies must come after regular dependencies"); 445 } 446 } 447 // Update gdef_prefixes_. 448 AddPrefixes(node_def.name(), &gdef_prefixes_); 449 } 450 return Status::OK(); 451 } 452 453 std::unordered_set<string> GetNextIterationNodes( 454 const GraphConstructor::NodeDefSlice& node_defs) { 455 std::unordered_set<string> next_iteration_nodes; 456 457 for (int n = 0; n < node_defs.size(); ++n) { 458 const NodeDef& node_def = *node_defs[n]; 459 if (IsNextIteration(node_def)) { 460 next_iteration_nodes.insert(node_def.name()); 461 } 462 } 463 464 return next_iteration_nodes; 465 } 466 467 Status GraphConstructor::InitFromEdges() { 468 const int num_nodes = node_defs_.size(); 469 pending_count_.reserve(num_nodes); 470 outputs_.resize(num_nodes); 471 std::unordered_set<string> next_iteration_nodes_ = 472 GetNextIterationNodes(node_defs_); 473 474 // Parse the inputs for each node. 475 for (int n = 0; n < num_nodes; ++n) { 476 const NodeDef& node_def = *node_defs_[n]; 477 int pending_count = node_def.input_size(); 478 if (IsMerge(node_def)) { 479 // Cycles in the graph are only allowed for while loops. A while loop is 480 // identified by an edge from a NextIteration node to a Merge node. For 481 // such Merge nodes, only wait for one non-control input before 482 // considering the node ready to process in Convert(). 483 int32 num_control_edges = 0; 484 bool has_loop_back_edge = false; 485 for (int i = 0; i < node_def.input_size(); ++i) { 486 StringPiece input_name(node_def.input(i)); 487 if (input_name.starts_with("^")) { 488 num_control_edges++; 489 } else { 490 TensorId id(ParseTensorName(input_name)); 491 if (next_iteration_nodes_.find(id.first.ToString()) != 492 next_iteration_nodes_.end()) { 493 has_loop_back_edge = true; 494 } 495 } 496 } 497 if (has_loop_back_edge) { 498 pending_count = num_control_edges + 1; 499 } 500 } 501 for (int i = 0; i < node_def.input_size(); ++i) { 502 StringPiece input_name = node_def.input(i); 503 TensorId id(ParseTensorName(input_name)); 504 if (opts_.input_map.count(id) == 0) { 505 // If an input is not mapped, then the input should appear in the graph 506 // being imported. 507 auto iter = gdef_nodes_.find(id.first); 508 if (iter == gdef_nodes_.end()) { 509 return errors::InvalidArgument("Node '", node_def.name(), 510 "': Unknown input node '", 511 node_def.input(i), "'"); 512 } 513 outputs_[iter->second.gdef_index].push_back(n); 514 } else { 515 // This input is mapped to an existing edge. Therefore this input is 516 // as good as being already processed. 517 --pending_count; 518 DCHECK_GE(pending_count, 0); 519 } 520 } 521 if (pending_count == 0) { 522 ready_.push_back(n); 523 } 524 pending_count_.push_back(pending_count); 525 } 526 return Status::OK(); 527 } 528 529 Status GraphConstructor::ValidateColocationConstraints( 530 const NodeDef& node_def) { 531 if (!opts_.validate_colocation_constraints || !opts_.importing) 532 return Status::OK(); 533 const auto iter = node_def.attr().find(kColocationAttrName); 534 if (iter == node_def.attr().end()) return Status::OK(); 535 for (const string& c : iter->second.list().s()) { 536 StringPiece s(c); 537 if (s.Consume(kColocationGroupPrefix) && 538 gdef_nodes_.find(s) == gdef_nodes_.end()) { 539 return errors::InvalidArgument( 540 "Node '", node_def.name(), 541 "' expects to be colocated with unknown node '", s, "'"); 542 } 543 } 544 return Status::OK(); 545 } 546 547 Status GraphConstructor::MakeNode(const NodeDef& node_def, Node** node) { 548 // Add the node to the graph. 549 Status status; 550 *node = g_->AddNode(node_def, &status); 551 if (!status.ok()) return status; 552 if (opts_.expect_device_spec) { 553 (*node)->set_assigned_device_name(node_def.device()); 554 } 555 return Status::OK(); 556 } 557 558 Status GraphConstructor::ValidateShape(Node* node) { 559 if (!opts_.importing || !opts_.validate_shape) return Status::OK(); 560 TF_RETURN_IF_ERROR(refiner_->AddNode(node)); 561 // For nodes with the _output_shapes attribute, override the shape. 562 std::vector<TensorShapeProto> shape_attrs; 563 const char* kAttrName = "_output_shapes"; 564 if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) { 565 // No _output_shapes attribute, the AddNode call above was sufficient. 566 return Status::OK(); 567 } 568 auto* ic = refiner_->GetContext(node); 569 DCHECK(ic != nullptr) 570 << "ShapeRefiner::AddNode() should have created the InferenceContext"; 571 if (shape_attrs.size() != node->num_outputs()) { 572 return errors::InvalidArgument( 573 "Node '", node->name(), "' has ", node->num_outputs(), 574 " outputs but the ", kAttrName, " attribute specifies shapes for ", 575 shape_attrs.size(), " outputs"); 576 } 577 for (int i = 0; i < shape_attrs.size(); ++i) { 578 const TensorShapeProto& p = shape_attrs[i]; 579 shape_inference::ShapeHandle h; 580 Status s = ic->MakeShapeFromShapeProto(p, &h); 581 if (!s.ok()) { 582 return errors::InvalidArgument("Node '", node->name(), " has an invalid ", 583 kAttrName, " attribute (shape #", i, 584 " error:'", s.error_message(), "'"); 585 } 586 s = refiner_->SetShape(node, i, h); 587 if (!s.ok()) { 588 // If the output shape is incompatible with what is inferred 589 // by the graph for a very specific whitelist of ops, then we 590 // ignore this output shape. This can happen if there is a 591 // bug in the shape function for some operation, and the 592 // serialized graph def has the incorrect shape set when 593 // running on a newer binary with the fixed shape function. 594 // This is an escape hatch that allows us to correct shape 595 // functions that are not critical to correct execution but 596 // would cause graphs to fail if imported after correcting. 597 // 598 const string& op = node->type_string(); 599 const std::vector<string> whitelist = { 600 // To be removed after 2017/03/08. 601 "RandomShuffleQueue", 602 "PaddingFIFOQueue", 603 "FIFOQueue", 604 "PriorityQueue", 605 "QueueSize", 606 "Stack", 607 "Barrier", 608 "BarrierReadySize", 609 "BarrierIncompleteSize", 610 "HashTable", 611 "MutableHashTable", 612 "MutableHashTableOfTensors", 613 "Mutex", 614 "CuckooTable", 615 "IndexTable", 616 "WholeFileReader", 617 "TextLineReader", 618 "FixedLengthRecordReader", 619 "TFRecordReader", 620 "IdentityReader", 621 "RefSwitch", 622 "RefEnter", 623 "RefNextIteration", 624 "RefMerge", 625 "RefIdentity", 626 "LMDBReader", 627 // To be removed after 2017/04/24. 628 "ConditionalAccumulator", 629 "SparseConditionalAccumulator", 630 "Table", 631 }; 632 if (std::find(whitelist.begin(), whitelist.end(), op) == 633 whitelist.end()) { 634 return errors::InvalidArgument( 635 "Node '", node->name(), "' has an ", kAttrName, 636 " attribute inconsistent with the GraphDef for output #", i, ": ", 637 s.error_message()); 638 } 639 } 640 } 641 node->ClearAttr(kAttrName); 642 return Status::OK(); 643 } 644 645 Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) { 646 const OpDef* op_def; 647 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); 648 AddDefaultsToNodeDef(*op_def, node_def); 649 TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def)); 650 if (versions_) { 651 TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions_->producer())); 652 } 653 return Status::OK(); 654 } 655 656 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def, 657 std::vector<bool>* input_already_exists) { 658 // Remove 'inputs_to_remove' from 'node_def' 659 // TODO(skyewm): is there a better way to do this? 660 std::vector<string> inputs; 661 inputs.reserve(node_def->input_size()); 662 for (int i = 0; i < node_def->input_size(); ++i) { 663 inputs.push_back(node_def->input(i)); 664 } 665 node_def->clear_input(); 666 for (int i = 0, j = 0; i < inputs.size(); ++i) { 667 if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) { 668 ++j; 669 } else { 670 node_def->add_input(inputs[i]); 671 } 672 } 673 // Remove 'inputs_to_remove' from 'input_already_exists' 674 for (int idx : inputs_to_remove) { 675 input_already_exists->erase(input_already_exists->begin() + idx); 676 } 677 DCHECK_EQ(input_already_exists->size(), node_def->input_size()); 678 } 679 680 void GraphConstructor::RemapNodeDefInputs( 681 NodeDef* node_def, std::vector<bool>* input_already_exists) { 682 DCHECK_EQ(input_already_exists->size(), node_def->input_size()); 683 std::set<TensorId> control_inputs; 684 std::vector<int> inputs_to_remove; 685 686 for (int i = 0; i < node_def->input_size(); ++i) { 687 auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i))); 688 if (iter == opts_.input_map.end()) continue; 689 used_input_map_keys_.insert(iter->first); 690 691 TensorId new_input = iter->second; 692 if (new_input.second == Graph::kControlSlot) { 693 // Check if we've already remapped a different input to new_input, and if 694 // so remove this input. 695 if (control_inputs.count(new_input) > 0) { 696 inputs_to_remove.push_back(i); 697 continue; 698 } 699 control_inputs.insert(new_input); 700 } 701 node_def->set_input(i, new_input.ToString()); 702 (*input_already_exists)[i] = true; 703 } 704 if (!inputs_to_remove.empty()) { 705 RemoveInputs(inputs_to_remove, node_def, input_already_exists); 706 } 707 } 708 709 void GraphConstructor::AddControlDependencies( 710 NodeDef* node_def, std::vector<bool>* input_already_exists) { 711 // To avoid adding redundant control dependencies to every imported node, skip 712 // nodes that will inherit the dependencies from another imported node. 713 bool inherits_deps = false; 714 for (int i = 0; i < node_def->input_size(); ++i) { 715 // Assume we won't inherit dependencies from remapped inputs that already 716 // exist in the graph. Even if we're wrong, we'll only add redundant 717 // dependencies. 718 if ((*input_already_exists)[i]) continue; 719 720 // If this input is a backedge, assume we won't inherit the dependencies. 721 // TODO(skyewm): we have many redundant ParseTensorName calls. It could be 722 // worth optimizing these. 723 TensorId id(ParseTensorName(node_def->input(i))); 724 auto iter = gdef_nodes_.find(id.first); 725 DCHECK(iter != gdef_nodes_.end()) << id.first; 726 if (iter->second.node == nullptr) { 727 // Input hasn't been created yet, indicating it's a backedge. 728 continue; 729 } 730 inherits_deps = true; 731 } 732 if (inherits_deps) return; 733 734 // node_def either has no inputs or all remapped inputs, add the control 735 // dependencies 736 for (const string& control_dep : opts_.control_dependencies) { 737 string input = TensorId(control_dep, Graph::kControlSlot).ToString(); 738 const protobuf::RepeatedPtrField<string>& inputs = node_def->input(); 739 if (std::find(inputs.begin(), inputs.end(), input) != inputs.end()) { 740 // Control dependency already exists 741 continue; 742 } 743 node_def->add_input(input); 744 input_already_exists->push_back(true); 745 } 746 } 747 748 void GraphConstructor::AddPrefixToNodeDef( 749 const std::vector<bool>& input_already_exists, NodeDef* node_def) { 750 if (prefix_.empty()) return; 751 node_def->set_name(strings::StrCat(prefix_, node_def->name())); 752 // Update names of input nodes 753 for (int i = 0; i < node_def->input_size(); ++i) { 754 StringPiece input(node_def->input(i)); 755 // Skip remapped inputs (which already exist in g_ and are not being 756 // imported). 757 if (input_already_exists[i]) continue; 758 if (input.Consume("^")) { 759 node_def->set_input(i, strings::StrCat("^", prefix_, input)); 760 } else { 761 node_def->set_input(i, strings::StrCat(prefix_, input)); 762 } 763 } 764 // Update names of colocation groups 765 if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) { 766 auto* list = 767 node_def->mutable_attr()->at(kColocationAttrName).mutable_list(); 768 for (int i = 0; i < list->s_size(); ++i) { 769 StringPiece v(list->s(i)); 770 if (v.Consume(kColocationGroupPrefix)) { 771 list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v)); 772 } 773 } 774 } 775 } 776 777 void GraphConstructor::UniquifyNames( 778 const std::vector<bool>& input_already_exists, NodeDef* node_def) { 779 if (NameExistsInGraph(node_def->name())) { 780 string old_name = node_def->name(); 781 node_def->set_name(FindUniqueName(node_def->name())); 782 uniquified_names_[old_name] = node_def->name(); 783 // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with 784 // `name` because we guarantee the original NodeDef names are unique, 785 // meaning we won't generate this name again. 786 } 787 for (int i = 0; i < node_def->input_size(); ++i) { 788 // Skip remapped inputs (which already exist in g_ and are not being 789 // imported). 790 if (input_already_exists[i]) continue; 791 TensorId id = ParseTensorName(node_def->input(i)); 792 // We require that UniquifyNames() is called on all NodeDefs in topological 793 // order. This guarantees that node_def's inputs will already be uniquified 794 // if necessary. 795 auto iter = uniquified_names_.find(id.first.ToString()); 796 if (iter == uniquified_names_.end()) continue; 797 id.first = iter->second; 798 node_def->set_input(i, id.ToString()); 799 } 800 } 801 802 void GraphConstructor::UpdateUniquifiedColocationNames() { 803 for (const auto& pair : gdef_nodes_) { 804 Node* node = pair.second.node; 805 if (node == nullptr) continue; 806 std::vector<string> coloc_values; 807 Status status = 808 GetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values); 809 if (!status.ok()) continue; 810 bool updated = false; 811 for (int i = 0; i < coloc_values.size(); ++i) { 812 StringPiece val(coloc_values[i]); 813 if (val.Consume(kColocationGroupPrefix)) { 814 const auto& name_pair = uniquified_names_.find(val.ToString()); 815 if (name_pair == uniquified_names_.end()) continue; 816 updated = true; 817 coloc_values[i] = 818 strings::StrCat(kColocationGroupPrefix, name_pair->second); 819 } 820 } 821 if (updated) { 822 node->AddAttr(kColocationAttrName, coloc_values); 823 } 824 } 825 } 826 827 bool GraphConstructor::NameExistsInGraph(StringPiece name) { 828 if (existing_nodes_.find(name) != existing_nodes_.end()) return true; 829 if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true; 830 return false; 831 } 832 833 bool GraphConstructor::NameExistsInGraphDef(StringPiece name) { 834 if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true; 835 if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true; 836 return false; 837 } 838 839 string GraphConstructor::FindUniqueName(StringPiece original_name) { 840 string name = original_name.ToString(); 841 int count = 0; 842 // Check that any generated names don't collide with imported NodeDefs (as 843 // well as nodes in g_). 844 while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) { 845 name = strings::StrCat(original_name, "_", ++count); 846 } 847 return name; 848 } 849 850 Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def, 851 bool* is_node_mapped) { 852 const OpDef* op_def; 853 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def)); 854 for (int i = 0; i < op_def->output_arg_size(); ++i) { 855 if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) { 856 *is_node_mapped = false; 857 return Status::OK(); 858 } 859 } 860 *is_node_mapped = true; 861 return Status::OK(); 862 } 863 864 namespace { 865 866 void UpdatePendingCountAndReady( 867 const std::vector<gtl::InlinedVector<int, 4>>& outputs, int o, 868 std::vector<int>* pending_count, std::vector<int>* ready) { 869 for (size_t i = 0; i < outputs[o].size(); ++i) { 870 const int output = outputs[o][i]; 871 (*pending_count)[output]--; 872 if ((*pending_count)[output] == 0) { 873 ready->push_back(output); 874 } 875 } 876 } 877 878 } // anonymous namespace 879 880 Status GraphConstructor::Convert() { 881 // Import functions before adding nodes, since imported nodes may refer to 882 // functions 883 if (library_) { 884 TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library_)); 885 } 886 887 std::vector<InputInfo> inputs; 888 int processed = 0; 889 890 std::vector<bool> input_already_exists; 891 892 // Process the NodeDefs in topological order. 893 // (InitFromEdges() sets this up by filling in ready_ with nodes that have no 894 // inputs, pending_counts_ with the number of inputs for each node and 895 // outputs_ with the outputs of each node). 896 while (!ready_.empty()) { 897 int o = ready_.back(); 898 ready_.pop_back(); 899 ++processed; 900 inputs.clear(); 901 bool has_data_back_edge = false; 902 903 const NodeDef& original_node_def = *node_defs_[o]; 904 NodeDef imported_node_def; 905 const NodeDef* node_def; 906 907 // input_already_exists[i] is true iff the i-th input of the node we're 908 // importing refers to a preexisting node in g_ (i.e. input[i] existed prior 909 // to importing node_defs_). Conversely, input_already_exists[i] is false 910 // iff the input refers to a node in node_defs_. 911 input_already_exists.clear(); 912 input_already_exists.resize(original_node_def.input_size(), false); 913 914 if (opts_.importing) { 915 if (opts_.skip_mapped_nodes) { 916 bool is_node_mapped = false; 917 TF_RETURN_IF_ERROR( 918 IsNodeFullyMapped(original_node_def, &is_node_mapped)); 919 if (is_node_mapped) { 920 // Skip this node after updating pending_count_ for outputs 921 UpdatePendingCountAndReady(outputs_, o, &pending_count_, &ready_); 922 continue; 923 } 924 } 925 926 // TODO(ashankar): The line below means an additional copy of the NodeDef, 927 // which can be expensive if the NodeDef contains large tensors in it. 928 // Might make sense to change the API for ImportGraphDef to take a mutable 929 // GraphDef* and avoid the copying. 930 imported_node_def = original_node_def; 931 if (!opts_.input_map.empty()) { 932 // Note that input_already_exists can shrink here 933 RemapNodeDefInputs(&imported_node_def, &input_already_exists); 934 } 935 if (!opts_.control_dependencies.empty()) { 936 // Note that input_already_exists can grow here 937 AddControlDependencies(&imported_node_def, &input_already_exists); 938 } 939 node_def = &imported_node_def; 940 } else { 941 node_def = &original_node_def; 942 } 943 944 DCHECK_EQ(node_def->input_size(), input_already_exists.size()); 945 TF_RETURN_IF_ERROR(ValidateColocationConstraints(*node_def)); 946 for (int i = 0; i < node_def->input_size(); ++i) { 947 TensorId id(ParseTensorName(node_def->input(i))); 948 Node* src_node; 949 int src_index; 950 951 if (!input_already_exists[i]) { 952 // Locate input in newly-imported nodes 953 auto iter = gdef_nodes_.find(id.first); 954 DCHECK(iter != gdef_nodes_.end()) << id.first; 955 src_node = iter->second.node; 956 src_index = id.second; 957 if (src_node == nullptr) has_data_back_edge = true; 958 } else { 959 // Input refers to preexistng node in graph 960 auto iter = existing_nodes_.find(id.first); 961 DCHECK(iter != existing_nodes_.end()) << id.first; 962 src_node = iter->second; 963 src_index = id.second; 964 } 965 966 if (src_node != nullptr && src_index >= src_node->num_outputs()) { 967 return errors::InvalidArgument( 968 "Node '", node_def->name(), "': Connecting to invalid output ", 969 id.second, " of source node ", id.first, " which has ", 970 src_node->num_outputs(), " outputs"); 971 } 972 973 inputs.push_back(InputInfo(id.first.ToString(), src_node, src_index)); 974 } 975 976 if (has_data_back_edge && !IsMerge(*node_def)) { 977 return errors::InvalidArgument( 978 "Node '", node_def->name(), 979 "' had a back edge, but only Merge nodes can have back edges."); 980 } 981 982 Node* node; 983 if (opts_.importing) { 984 if (!prefix_.empty()) { 985 AddPrefixToNodeDef(input_already_exists, &imported_node_def); 986 } 987 // Note: no need to uniquify names if the prefix already guarantees 988 // uniqueness 989 if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) { 990 UniquifyNames(input_already_exists, &imported_node_def); 991 } 992 TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def)); 993 } 994 TF_RETURN_IF_ERROR(MakeNode(*node_def, &node)); 995 // Use original_node_def so name StringPiece remains valid 996 gdef_nodes_[original_node_def.name()].node = node; 997 998 // Add edges from inputs to *node to the graph. 999 for (size_t i = 0; i < inputs.size(); ++i) { 1000 if (inputs[i].node == nullptr) { 1001 // Record this back edge, which will be added after all nodes 1002 // are created. 1003 back_edges_.push_back( 1004 EdgeInfo(inputs[i].name, inputs[i].index, node, i)); 1005 } else if (inputs[i].index == Graph::kControlSlot) { 1006 g_->AddControlEdge(inputs[i].node, node); 1007 } else { 1008 TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i)); 1009 } 1010 } 1011 1012 // Function shape inference is supported on an opt-in basis per 1013 // ShapeRefiner. 1014 if (refiner_->function_shape_inference_supported() || 1015 g_->flib_def().Find(node_def->name()) == nullptr) { 1016 TF_RETURN_IF_ERROR(ValidateShape(node)); 1017 } 1018 1019 // Update pending_count_ for outputs. 1020 UpdatePendingCountAndReady(outputs_, o, &pending_count_, &ready_); 1021 } 1022 1023 if (processed < node_defs_.size()) { 1024 return errors::InvalidArgument(node_defs_.size() - processed, 1025 " nodes in a cycle"); 1026 } 1027 1028 return Status::OK(); 1029 } 1030 1031 Status GraphConstructor::AddBackEdges() { 1032 // Add the back edges after all nodes are created. 1033 for (auto e : back_edges_) { 1034 Node* src_node = gdef_nodes_[e.src_name].node; 1035 if (e.src_index == Graph::kControlSlot) { 1036 g_->AddControlEdge(src_node, e.dst_node); 1037 } else { 1038 TF_RETURN_IF_ERROR( 1039 MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index)); 1040 } 1041 1042 VLOG(2) << "Add back edge: " << src_node->name() << " -> " 1043 << e.dst_node->name(); 1044 } 1045 return Status::OK(); 1046 } 1047 1048 Status GraphConstructor::UpdateVersionDef() { 1049 if (versions_ == nullptr) return Status::OK(); 1050 1051 if (!opts_.importing) { 1052 g_->set_versions(*versions_); 1053 return Status::OK(); 1054 } 1055 VersionDef versions = g_->versions(); 1056 versions.set_producer(std::min(versions.producer(), versions_->producer())); 1057 versions.set_min_consumer( 1058 std::max(versions.min_consumer(), versions_->min_consumer())); 1059 if (versions_->bad_consumers_size() > 0) { 1060 std::set<int> bad(versions.bad_consumers().begin(), 1061 versions.bad_consumers().end()); 1062 bad.insert(versions_->bad_consumers().begin(), 1063 versions_->bad_consumers().end()); 1064 versions.clear_bad_consumers(); 1065 for (int v : bad) { 1066 versions.add_bad_consumers(v); 1067 } 1068 } 1069 g_->set_versions(versions); 1070 return Status::OK(); 1071 } 1072 1073 Status GraphConstructor::PopulateReturnTensors() { 1074 if (opts_.return_tensors.empty()) return Status::OK(); 1075 for (const TensorId& id : opts_.return_tensors) { 1076 auto iter = opts_.input_map.find(id); 1077 if (iter == opts_.input_map.end()) { 1078 // Locate id in imported nodes 1079 auto iter = gdef_nodes_.find(id.first); 1080 if (iter == gdef_nodes_.end()) { 1081 return errors::InvalidArgument("Requested return tensor '", 1082 id.ToString(), 1083 "' not found in graph def"); 1084 } 1085 int num_outputs = iter->second.node->num_outputs(); 1086 if ((id.second < 0 || id.second >= num_outputs) && 1087 id.second != Graph::kControlSlot) { 1088 return errors::InvalidArgument("Invalid return output ", id.second, 1089 " of node '", id.first, "', which has ", 1090 num_outputs, " output(s)"); 1091 } 1092 return_tensors_->push_back({iter->second.node, id.second}); 1093 } else { 1094 // id was remapped to existing node 1095 TensorId remapped_id = iter->second; 1096 DCHECK_GT(existing_nodes_.count(remapped_id.first), 0); 1097 Node* node = existing_nodes_[remapped_id.first]; 1098 return_tensors_->push_back({node, remapped_id.second}); 1099 } 1100 } 1101 return Status::OK(); 1102 } 1103 1104 Status GraphConstructor::PopulateReturnNodes() { 1105 if (opts_.return_nodes.empty()) return Status::OK(); 1106 for (StringPiece name : opts_.return_nodes) { 1107 auto iter = gdef_nodes_.find(name); 1108 if (iter == gdef_nodes_.end()) { 1109 return errors::InvalidArgument("Requested return node '", name, 1110 "' not found in graph def"); 1111 } 1112 return_nodes_->push_back(iter->second.node); 1113 } 1114 return Status::OK(); 1115 } 1116 1117 Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { 1118 if (missing_unused_input_map_keys_ == nullptr) return Status::OK(); 1119 for (const auto& input_map_pair : opts_.input_map) { 1120 TensorId key = input_map_pair.first; 1121 if (used_input_map_keys_.count(key) > 0) continue; 1122 1123 auto pair = gdef_nodes_.find(key.first); 1124 if (pair == gdef_nodes_.end()) { 1125 // key's node doesn't exist in GraphDef 1126 missing_unused_input_map_keys_->push_back(key); 1127 continue; 1128 } 1129 1130 // Check that key's index is in bounds. Get the number of outputs from the 1131 // NodeDef, rather than the imported Node, since the Node may not exist if 1132 // opts_.skip_mapped_nodes is true. 1133 const NodeDef* node_def = node_defs_[pair->second.gdef_index]; 1134 const OpDef* op_def; 1135 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); 1136 if (key.second >= op_def->output_arg_size()) { 1137 // key's index out of bounds 1138 missing_unused_input_map_keys_->push_back(key); 1139 } 1140 } 1141 return Status::OK(); 1142 } 1143 1144 void GraphConstructor::Undo() { 1145 for (const auto& iter : gdef_nodes_) { 1146 if (iter.second.node != nullptr) { 1147 g_->RemoveNode(iter.second.node); 1148 } 1149 } 1150 g_->set_versions(original_versions_); 1151 } 1152 1153 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, 1154 int input_index) { 1155 DataType src_out = src->output_type(output_index); 1156 DataType dst_in = dst->input_type(input_index); 1157 if (!TypesCompatible(dst_in, src_out)) { 1158 return errors::InvalidArgument( 1159 "Input ", input_index, " of node ", dst->name(), " was passed ", 1160 DataTypeString(src_out), " from ", src->name(), ":", output_index, 1161 " incompatible with expected ", DataTypeString(dst_in), "."); 1162 } 1163 g_->AddEdge(src, output_index, dst, input_index); 1164 return Status::OK(); 1165 } 1166 1167 } // namespace 1168 1169 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, 1170 const GraphDef& gdef, Graph* g) { 1171 ShapeRefiner refiner(gdef.versions().producer(), g->op_registry()); 1172 return GraphConstructor::Construct( 1173 opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner, 1174 /*return_tensors=*/nullptr, /*return_nodes=*/nullptr, 1175 /*missing_unused_input_map_keys=*/nullptr); 1176 } 1177 1178 Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, 1179 gtl::ArraySlice<NodeDef> nodes, Graph* g) { 1180 ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry()); 1181 // TODO(irving): Copy will go away once NodeInfo exists 1182 std::vector<const NodeDef*> node_defs; 1183 for (const auto& n : nodes) { 1184 node_defs.push_back(&n); 1185 } 1186 return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g, 1187 &refiner, /*return_tensors=*/nullptr, 1188 /*return_nodes=*/nullptr, 1189 /*missing_unused_input_map_keys=*/nullptr); 1190 } 1191 1192 Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, 1193 Graph* g, ShapeRefiner* refiner, 1194 ImportGraphDefResults* results) { 1195 if (!opts.return_tensors.empty()) { 1196 if (results == nullptr) { 1197 return errors::InvalidArgument( 1198 "results argument to ImportGraphDef() must be non-null if " 1199 "opts.return_tensors is non-empty"); 1200 } 1201 } 1202 1203 if (!opts.return_nodes.empty()) { 1204 if (opts.skip_mapped_nodes) { 1205 return errors::InvalidArgument( 1206 "Requesting return_nodes with skip_mapped_nodes set is not currently " 1207 "supported"); 1208 } 1209 if (results == nullptr) { 1210 return errors::InvalidArgument( 1211 "results argument to ImportGraphDef() must be non-null if " 1212 "opts.return_nodes is non-empty"); 1213 } 1214 } 1215 1216 if (results != nullptr) { 1217 if (!results->return_tensors.empty() || !results->return_nodes.empty() || 1218 !results->missing_unused_input_map_keys.empty()) { 1219 return errors::InvalidArgument( 1220 "All fields in results argument to ImportGraphDef() must be empty."); 1221 } 1222 } 1223 1224 ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry()); 1225 if (refiner == nullptr) { 1226 refiner = &default_refiner; 1227 } else { 1228 // Log a warning if we are importing a GraphDef at an older 1229 // producer version after already having added non-source/sink 1230 // nodes to the graph in the past. 1231 if (gdef.versions().producer() > 0 && 1232 gdef.versions().producer() < refiner->graph_def_version() && 1233 g->num_nodes() > 2) { 1234 LOG(WARNING) << "Importing a graph with a lower producer version " 1235 << gdef.versions().producer() 1236 << " into an existing graph with producer version " 1237 << refiner->graph_def_version() << ". Shape inference will " 1238 << "have run different parts of the graph with different " 1239 << "producer versions."; 1240 } 1241 } 1242 1243 // Set the graph def version of the refiner as the min of the 1244 // current value and the version from the graph we are about to 1245 // import. 1246 // 1247 // Note: to match Run() semantics, we should re-run shape inference 1248 // on the entire graph if the producer version has changed. For now 1249 // we log the warning above. 1250 refiner->set_graph_def_version( 1251 std::min(refiner->graph_def_version(), gdef.versions().producer())); 1252 1253 if (results == nullptr) { 1254 return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(), 1255 &gdef.library(), g, refiner, nullptr, 1256 nullptr, nullptr); 1257 } else { 1258 return GraphConstructor::Construct( 1259 opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner, 1260 &results->return_tensors, &results->return_nodes, 1261 &results->missing_unused_input_map_keys); 1262 } 1263 } 1264 1265 void CopyGraph(const Graph& src, Graph* dest) { 1266 for (Node* n : dest->nodes()) { 1267 CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty"; 1268 } 1269 1270 // Copy GraphDef versions 1271 dest->set_versions(src.versions()); 1272 1273 // Copy the nodes 1274 std::unordered_map<Node*, Node*> 1275 node_map; // "Node in src" -> "Node in *dest" 1276 node_map[src.source_node()] = dest->source_node(); 1277 node_map[src.sink_node()] = dest->sink_node(); 1278 for (Node* n : src.op_nodes()) { 1279 node_map[n] = dest->CopyNode(n); 1280 } 1281 1282 // Copy the edges 1283 for (const Edge* e : src.edges()) { 1284 Node* src_copy = node_map[e->src()]; 1285 Node* dst_copy = node_map[e->dst()]; 1286 dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); 1287 } 1288 } 1289 1290 } // namespace tensorflow 1291