1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/graph/graph_constructor.h" 17 18 #include <algorithm> 19 #include <set> 20 #include <string> 21 #include <unordered_map> 22 #include <vector> 23 24 #include "tensorflow/core/common_runtime/shape_refiner.h" 25 #include "tensorflow/core/framework/function.h" 26 #include "tensorflow/core/framework/function.pb.h" 27 #include "tensorflow/core/framework/graph.pb.h" 28 #include "tensorflow/core/framework/node_def.pb.h" 29 #include "tensorflow/core/framework/node_def_util.h" 30 #include "tensorflow/core/framework/tensor_shape.pb.h" 31 #include "tensorflow/core/framework/types.h" 32 #include "tensorflow/core/framework/versions.h" 33 #include "tensorflow/core/framework/versions.pb.h" 34 #include "tensorflow/core/graph/algorithm.h" 35 #include "tensorflow/core/graph/graph.h" 36 #include "tensorflow/core/graph/tensor_id.h" 37 #include "tensorflow/core/lib/core/errors.h" 38 #include "tensorflow/core/lib/gtl/flatmap.h" 39 #include "tensorflow/core/lib/gtl/flatset.h" 40 #include "tensorflow/core/lib/gtl/inlined_vector.h" 41 #include "tensorflow/core/lib/strings/scanner.h" 42 #include "tensorflow/core/lib/strings/str_util.h" 43 #include "tensorflow/core/platform/logging.h" 44 #include "tensorflow/core/public/version.h" 45 46 namespace tensorflow { 47 48 namespace { 49 inline bool IsMerge(const NodeDef& node_def) { 50 return node_def.op() == "Merge" || node_def.op() == "RefMerge"; 51 } 52 53 inline bool IsNextIteration(const NodeDef& node_def) { 54 return node_def.op() == "NextIteration" || 55 node_def.op() == "RefNextIteration"; 56 } 57 58 bool IsValidNodeName(StringPiece s, bool allow_internal_ops) { 59 using ::tensorflow::strings::Scanner; 60 return Scanner(s) 61 .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE 62 : Scanner::LETTER_DIGIT_DOT) 63 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) 64 .Eos() 65 .GetResult(); 66 } 67 68 class GraphConstructor { 69 public: 70 struct Options { 71 Options(const GraphConstructorOptions& in) // NOLINT(runtime/explicit) 72 : allow_internal_ops(in.allow_internal_ops), 73 expect_device_spec(in.expect_device_spec), 74 importing(false), 75 validate_colocation_constraints(false) {} 76 Options(const ImportGraphDefOptions& in) // NOLINT(runtime/explicit) 77 : allow_internal_ops(false), 78 expect_device_spec(false), 79 prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/") 80 ? in.prefix 81 : in.prefix + "/"), 82 uniquify_names(in.uniquify_names), 83 uniquify_prefix(in.uniquify_prefix), 84 input_map(in.input_map.begin(), in.input_map.end()), 85 skip_mapped_nodes(in.skip_mapped_nodes), 86 control_dependencies(in.control_dependencies), 87 return_tensors(in.return_tensors.begin(), in.return_tensors.end()), 88 return_nodes(in.return_nodes), 89 importing(true), 90 validate_colocation_constraints(in.validate_colocation_constraints), 91 validate_shape(in.validate_shape), 92 default_device(in.default_device) {} 93 94 bool allow_internal_ops; 95 bool expect_device_spec; 96 97 string prefix; 98 bool uniquify_names; 99 bool uniquify_prefix; 100 std::map<TensorId, TensorId> input_map; 101 bool skip_mapped_nodes; 102 std::vector<string> control_dependencies; 103 std::vector<TensorId> return_tensors; 104 std::vector<string> return_nodes; 105 106 // TODO(ashankar): This bool exists to separate out functionality required 107 // to make ImportGraphDef a close equivalent of Python's import_graph_def 108 // without affecting the behavior of ConvertGraphDefToGraph at the time 109 // ImportGraphDef was added. 110 // 111 // That said, the functionality here (shape and op validation) seems 112 // applicable to ConvertGraphDefToGraph as well, so make an attempt to 113 // remove this. 114 bool importing; 115 bool validate_colocation_constraints; 116 bool validate_shape = true; 117 118 string default_device; 119 }; 120 121 typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice; 122 123 // versions and library may be nullptr 124 static Status Construct( 125 const Options& opts, NodeDefSlice node_defs, const VersionDef* versions, 126 const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner, 127 std::vector<std::pair<Node*, int>>* return_tensors, 128 std::vector<Node*>* return_nodes, 129 std::vector<SafeTensorId>* missing_unused_input_map_keys) { 130 if (versions) { 131 TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION, 132 TF_GRAPH_DEF_VERSION_MIN_PRODUCER, 133 "GraphDef", "graph")); 134 } 135 GraphConstructor c(opts, node_defs, versions, library, g, refiner, 136 return_tensors, return_nodes, 137 missing_unused_input_map_keys); 138 const Status s = c.TryImport(); 139 if (!s.ok()) c.Undo(); 140 return s; 141 } 142 143 private: 144 GraphConstructor(const Options& opts, NodeDefSlice node_defs, 145 const VersionDef* versions, 146 const FunctionDefLibrary* library, Graph* g, 147 ShapeRefiner* refiner, 148 std::vector<std::pair<Node*, int>>* return_tensors, 149 std::vector<Node*>* return_nodes, 150 std::vector<SafeTensorId>* missing_unused_input_map_keys) 151 : opts_(opts), 152 node_defs_(node_defs), 153 versions_(versions), 154 library_(library), 155 g_(g), 156 original_versions_(g->versions()), 157 prefix_(opts.prefix), 158 refiner_(refiner), 159 return_tensors_(return_tensors), 160 return_nodes_(return_nodes), 161 missing_unused_input_map_keys_(missing_unused_input_map_keys) {} 162 163 Status TryImport() { 164 TF_RETURN_IF_ERROR(EnsureNoNameCollisions()); 165 TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies()); 166 TF_RETURN_IF_ERROR(BuildNodeIndex()); 167 TF_RETURN_IF_ERROR(InitFromEdges()); 168 TF_RETURN_IF_ERROR(Convert()); 169 TF_RETURN_IF_ERROR(AddBackEdges()); 170 TF_RETURN_IF_ERROR(UpdateVersionDef()); 171 TF_RETURN_IF_ERROR(PopulateReturnTensors()); 172 TF_RETURN_IF_ERROR(PopulateReturnNodes()); 173 TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys()); 174 UpdateUniquifiedColocationNames(); 175 FixupSourceAndSinkEdges(g_); 176 return Status::OK(); 177 } 178 179 Status EnsureNoNameCollisions(); 180 Status ValidateInputMapAndControlDependencies(); 181 Status BuildNodeIndex(); 182 Status InitFromEdges(); 183 Status Convert(); 184 Status AddBackEdges(); 185 Status UpdateVersionDef(); 186 Status PopulateReturnTensors(); 187 Status PopulateReturnNodes(); 188 Status PopulateMissingUnusedInputMapKeys(); 189 190 void Undo(); 191 192 Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped); 193 Status ValidateColocationConstraints(const NodeDef& node_def); 194 Status MakeNode(const NodeDef& node_def, Node** node); 195 Status MakeEdge(Node* src, int output_index, Node* dst, int input_index); 196 Status ValidateShape(Node* node); 197 Status ModifyNodeDefForImport(NodeDef* node_def); 198 // Modifies node_def's inputs according to opts_.input_map. 199 // input_already_exists is a pre-initialized vector of length 200 // node_def->input_size(). This function will mark inputs that are remapped to 201 // true. 202 void RemapNodeDefInputs(NodeDef* node_def, 203 std::vector<bool>* input_already_exists); 204 // input_already_exists is a pre-initialized vector of length 205 // node_def->input_size(). This function will add and mark control inputs as 206 // true. 207 void AddControlDependencies(NodeDef* node_def, 208 std::vector<bool>* input_already_exists); 209 void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists, 210 NodeDef* node_def); 211 212 // Modifies `node_def` if its name isn't unique, or if any of its inputs' 213 // names have been uniquified. This must be called in topological order on all 214 // nodes. 215 void UniquifyNames(const std::vector<bool>& input_already_exists, 216 NodeDef* node_def); 217 218 // Updates any constructed nodes' colocation group names if the name has been 219 // updated by UniquifyNames. This is called after all the nodes have been 220 // constructed so all the names have been uniquified if necessary. 221 void UpdateUniquifiedColocationNames(); 222 223 // Returns true if `name` already exists in `g_` (either as a node name or 224 // prefix). 225 bool NameExistsInGraph(StringPiece name); 226 227 // Returns true if `name` already exists in the GraphDef being imported 228 // (either as a node name or prefix). 229 bool NameExistsInGraphDef(StringPiece name); 230 231 // Returns a unique version of `original_name`, or `original_name` if it's 232 // already unique in the graph. 233 string FindUniqueName(StringPiece original_name); 234 235 // Decrement pending count for users of `processed` and add the ones that now 236 // have all of their pending inputs satisfied to `ready_`. 237 void UpdatePendingCountAndReady(int processed); 238 239 // From constructor 240 const Options opts_; 241 const NodeDefSlice node_defs_; 242 const VersionDef* versions_; 243 const FunctionDefLibrary* library_; 244 Graph* g_; 245 const VersionDef original_versions_; 246 247 // A copy of opts_.prefix, possibly uniquified. 248 string prefix_; 249 250 ShapeRefiner* refiner_; 251 252 // May be null. Not owned. 253 std::vector<std::pair<Node*, int>>* return_tensors_; 254 255 // May be null. Not owned. 256 std::vector<Node*>* return_nodes_; 257 258 // May be null. Not owned. 259 std::vector<SafeTensorId>* missing_unused_input_map_keys_; 260 261 // Intermediate datastructure used to populate 262 // `missing_unused_input_map_keys_`. 263 std::set<TensorId> used_input_map_keys_; 264 265 // Mapping from node name to the index within node_defs_. 266 struct NodeInfo { 267 explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {} 268 // std::unordered_map<> requires that we have a default constructor. 269 NodeInfo() : NodeInfo(-1) {} 270 int gdef_index; 271 Node* node; // nullptr until the NodeDef is converted to a Node. 272 }; 273 gtl::FlatMap<StringPiece, NodeInfo, StringPieceHasher> gdef_nodes_; 274 275 // Prefixes already used in the GraphDef being imported. 276 gtl::FlatSet<StringPiece, StringPieceHasher> gdef_prefixes_; 277 278 // Mapping from node name to the existing node in g_. 279 gtl::FlatMap<StringPiece, Node*, StringPieceHasher> existing_nodes_; 280 281 // Prefixes already used in the graph. 282 gtl::FlatSet<StringPiece, StringPieceHasher> existing_prefixes_; 283 284 // Imported node names that have been uniquified. The key is the original 285 // name, the value is the new unique name. 286 gtl::FlatMap<string, string> uniquified_names_; 287 288 // Index of NodeDefs in node_defs_ with all inputs already converted. We use a 289 // (sorted) set so nodes are created in the order defined in the GraphDef. 290 std::set<int> ready_; 291 292 // Mapping between index within node_defs_ and the number of inputs that 293 // still need to be converted. 294 std::vector<int> pending_count_; 295 296 // Mapping between index within node_defs_ and the index within node_defs_ of 297 // all nodes it outputs to. 298 std::vector<gtl::InlinedVector<int, 4>> outputs_; 299 300 // Used in the conversion from node_defs_ to g_ to represent the ith input 301 // of a node. 302 struct InputInfo { 303 explicit InputInfo(const string& node_name, Node* n, int i) 304 : name(node_name), node(n), index(i) {} 305 // Use string instead of StringPiece so we don't have to manage lifetime 306 string name; 307 Node* node; 308 int index; 309 }; 310 311 // Used in the conversion from node_defs_ to g_ to represent an edge from 312 // the node named 'name' to node 'n'. 313 struct EdgeInfo { 314 explicit EdgeInfo(const string& name, int i1, Node* n, int i2) 315 : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {} 316 // Use string instead of StringPiece so we don't have to manage lifetime 317 string src_name; 318 int src_index; 319 Node* dst_node; 320 int dst_index; 321 }; 322 std::vector<EdgeInfo> back_edges_; 323 }; 324 325 void GraphConstructor::UpdatePendingCountAndReady(int processed) { 326 // We didn't consider NextIteration->Merge edges when computing 327 // pending_counts_ so we should not have to consider it here either. 328 bool is_next_iteration = IsNextIteration(*node_defs_[processed]); 329 for (size_t i = 0; i < outputs_[processed].size(); ++i) { 330 const int output = outputs_[processed][i]; 331 bool is_next_iteration_to_merge_edge = 332 is_next_iteration && IsMerge(*node_defs_[output]); 333 if (!is_next_iteration_to_merge_edge) { 334 int* current_pending_count = &pending_count_[output]; 335 CHECK_GT(*current_pending_count, 0); 336 (*current_pending_count)--; 337 if (*current_pending_count == 0) { 338 ready_.insert(output); 339 } 340 } 341 } 342 } 343 344 // This could be expensive but we don't expect to call it often, if at all (only 345 // if there are multiple nodes in g_ with the same name) 346 bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map, 347 const StringPiece& node_name) { 348 for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) { 349 if (iter->second.first == node_name) return true; 350 } 351 return false; 352 } 353 354 bool NodeNameInValues(const std::vector<string>& control_dependencies, 355 const StringPiece& node_name) { 356 return std::find(control_dependencies.begin(), control_dependencies.end(), 357 node_name) != control_dependencies.end(); 358 } 359 360 // Adds any prefixes of `node_name` (not including the full name itself) to 361 // `prefixes`. 362 void AddPrefixes(StringPiece node_name, 363 gtl::FlatSet<StringPiece, StringPieceHasher>* prefixes) { 364 size_t idx = -1; 365 while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) { 366 prefixes->insert(node_name.substr(0, idx)); 367 } 368 } 369 370 Status GraphConstructor::EnsureNoNameCollisions() { 371 existing_nodes_.reserve(g_->num_nodes()); 372 // Populate existing_nodes_ and existing_prefixes_. 373 for (Node* n : g_->nodes()) { 374 bool already_exists = !existing_nodes_.insert({n->name(), n}).second; 375 if (already_exists) { 376 if (NodeNameInValues(opts_.input_map, n->name())) { 377 return errors::InvalidArgument( 378 "cannot resolve input_map because multiple nodes exist with name '", 379 n->name(), "'"); 380 } 381 if (NodeNameInValues(opts_.control_dependencies, n->name())) { 382 return errors::InvalidArgument( 383 "cannot resolve control_dependencies because multiple nodes exist " 384 "with name '", 385 n->name(), "'"); 386 } 387 } 388 AddPrefixes(n->name(), &existing_prefixes_); 389 } 390 if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) { 391 for (const NodeDef* n : node_defs_) { 392 const string& name = n->name(); 393 if (NameExistsInGraph(name)) { 394 return errors::InvalidArgument("Node name '", name, 395 "' already exists in the Graph"); 396 } 397 } 398 } else if (!prefix_.empty()) { 399 StringPiece prefix_no_slash(prefix_); 400 prefix_no_slash.remove_suffix(1); 401 if (!IsValidNodeName(prefix_no_slash, false)) { 402 return errors::InvalidArgument("Imported node name prefix '", prefix_, 403 "' would lead to invalid node names"); 404 } 405 if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) { 406 prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/"); 407 } 408 } 409 return Status::OK(); 410 } 411 412 Status GraphConstructor::ValidateInputMapAndControlDependencies() { 413 for (const auto& mapping : opts_.input_map) { 414 TensorId src = mapping.first; 415 TensorId dst = mapping.second; 416 if (existing_nodes_.count(dst.first) == 0) { 417 return errors::InvalidArgument( 418 "node '", dst.first, "' in input_map does not exist in graph ", 419 "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")"); 420 } 421 if ((src.second == Graph::kControlSlot) != 422 (dst.second == Graph::kControlSlot)) { 423 return errors::InvalidArgument("input_map entry ", src.ToString(), "->", 424 dst.ToString(), " between ", 425 "control edge and non-control edge"); 426 } 427 } 428 for (const string& node : opts_.control_dependencies) { 429 if (existing_nodes_.count(node) == 0) { 430 return errors::InvalidArgument( 431 "node '", node, 432 "' in control_dependencies does not exist in " 433 "graph"); 434 } 435 } 436 return Status::OK(); 437 } 438 439 Status GraphConstructor::BuildNodeIndex() { 440 // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_. 441 for (int n = 0; n < node_defs_.size(); ++n) { 442 const NodeDef& node_def = *node_defs_[n]; 443 if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) { 444 return errors::InvalidArgument( 445 "Node '", node_def.name(), 446 "': Node name contains invalid characters"); 447 } 448 if (!gdef_nodes_ 449 .insert(std::make_pair(StringPiece(node_def.name()), NodeInfo(n))) 450 .second) { 451 return errors::InvalidArgument("Node '", node_def.name(), 452 "' is not unique"); 453 } 454 // Validate the operation's type. 455 if (node_def.op().empty()) { 456 return errors::InvalidArgument("Node '", node_def.name(), 457 "' does not specify an operation"); 458 } 459 if (opts_.expect_device_spec && node_def.device().empty()) { 460 return errors::InvalidArgument("Node '", node_def.name(), 461 "' is missing a device specification"); 462 } 463 // Validate control edges at end 464 bool in_control_dependence = false; 465 for (int i = 0; i < node_def.input_size(); ++i) { 466 StringPiece input_name = node_def.input(i); 467 if (!input_name.empty() && str_util::StartsWith(input_name, "^")) { 468 in_control_dependence = true; 469 } else if (in_control_dependence) { 470 return errors::InvalidArgument( 471 "Node '", node_def.name(), 472 "': Control dependencies must come after regular dependencies"); 473 } 474 } 475 // Update gdef_prefixes_. 476 AddPrefixes(node_def.name(), &gdef_prefixes_); 477 } 478 return Status::OK(); 479 } 480 481 std::unordered_set<string> GetNextIterationNodes( 482 const GraphConstructor::NodeDefSlice& node_defs) { 483 std::unordered_set<string> next_iteration_nodes; 484 485 for (int n = 0; n < node_defs.size(); ++n) { 486 const NodeDef& node_def = *node_defs[n]; 487 if (IsNextIteration(node_def)) { 488 next_iteration_nodes.insert(node_def.name()); 489 } 490 } 491 492 return next_iteration_nodes; 493 } 494 495 Status GraphConstructor::InitFromEdges() { 496 const int num_nodes = node_defs_.size(); 497 pending_count_.reserve(num_nodes); 498 outputs_.resize(num_nodes); 499 std::unordered_set<string> next_iteration_nodes_ = 500 GetNextIterationNodes(node_defs_); 501 502 // Parse the inputs for each node. 503 for (int n = 0; n < num_nodes; ++n) { 504 const NodeDef& node_def = *node_defs_[n]; 505 int pending_count = node_def.input_size(); 506 if (IsMerge(node_def)) { 507 // Cycles in the graph are only allowed for while loops. A while loop is 508 // identified by an edge from a NextIteration node to a Merge node. For 509 // such Merge nodes, only wait for one non-control input before 510 // considering the node ready to process in Convert(). 511 int32 num_control_edges = 0; 512 bool has_loop_back_edge = false; 513 for (int i = 0; i < node_def.input_size(); ++i) { 514 StringPiece input_name(node_def.input(i)); 515 if (str_util::StartsWith(input_name, "^")) { 516 num_control_edges++; 517 } else { 518 TensorId id(ParseTensorName(input_name)); 519 if (next_iteration_nodes_.find(string(id.first)) != 520 next_iteration_nodes_.end()) { 521 has_loop_back_edge = true; 522 } 523 } 524 } 525 if (has_loop_back_edge) { 526 pending_count = num_control_edges + 1; 527 } 528 } 529 for (int i = 0; i < node_def.input_size(); ++i) { 530 StringPiece input_name = node_def.input(i); 531 TensorId id(ParseTensorName(input_name)); 532 if (opts_.input_map.count(id) == 0) { 533 // If an input is not mapped, then the input should appear in the graph 534 // being imported. 535 auto iter = gdef_nodes_.find(id.first); 536 if (iter == gdef_nodes_.end()) { 537 return errors::InvalidArgument("Node '", node_def.name(), 538 "': Unknown input node '", 539 node_def.input(i), "'"); 540 } 541 outputs_[iter->second.gdef_index].push_back(n); 542 } else { 543 // This input is mapped to an existing edge. Therefore this input is 544 // as good as being already processed. 545 --pending_count; 546 DCHECK_GE(pending_count, 0); 547 } 548 } 549 if (pending_count == 0) { 550 ready_.insert(n); 551 } 552 pending_count_.push_back(pending_count); 553 } 554 return Status::OK(); 555 } 556 557 Status GraphConstructor::ValidateColocationConstraints( 558 const NodeDef& node_def) { 559 if (!opts_.validate_colocation_constraints || !opts_.importing) 560 return Status::OK(); 561 const auto iter = node_def.attr().find(kColocationAttrName); 562 if (iter == node_def.attr().end()) return Status::OK(); 563 for (const string& c : iter->second.list().s()) { 564 StringPiece s(c); 565 if (str_util::ConsumePrefix(&s, kColocationGroupPrefix) && 566 gdef_nodes_.find(s) == gdef_nodes_.end()) { 567 return errors::InvalidArgument( 568 "Node '", node_def.name(), 569 "' expects to be colocated with unknown node '", s, "'"); 570 } 571 } 572 return Status::OK(); 573 } 574 575 Status GraphConstructor::MakeNode(const NodeDef& node_def, Node** node) { 576 // Add the node to the graph. 577 Status status; 578 *node = g_->AddNode(node_def, &status); 579 if (!status.ok()) return status; 580 if (opts_.expect_device_spec) { 581 (*node)->set_assigned_device_name(node_def.device()); 582 } 583 return Status::OK(); 584 } 585 586 Status GraphConstructor::ValidateShape(Node* node) { 587 if (!opts_.importing || !opts_.validate_shape) return Status::OK(); 588 TF_RETURN_IF_ERROR(refiner_->AddNode(node)); 589 // For nodes with the _output_shapes attribute, override the shape. 590 std::vector<TensorShapeProto> shape_attrs; 591 const char* kAttrName = "_output_shapes"; 592 if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) { 593 // No _output_shapes attribute, the AddNode call above was sufficient. 594 return Status::OK(); 595 } 596 auto* ic = refiner_->GetContext(node); 597 DCHECK(ic != nullptr) 598 << "ShapeRefiner::AddNode() should have created the InferenceContext"; 599 if (shape_attrs.size() < node->num_outputs()) { 600 return errors::InvalidArgument( 601 "Node '", node->name(), "' has ", node->num_outputs(), 602 " outputs but the ", kAttrName, " attribute specifies shapes for ", 603 shape_attrs.size(), " outputs"); 604 } 605 // NOTE(skyewm): we don't raise an error here because some users depend on 606 // this behavior, even though it's unsafe. 607 // TODO(b/74619486): raise an error. 608 if (shape_attrs.size() > node->num_outputs()) { 609 LOG(WARNING) << "Node '" << node->name() << "' has " << node->num_outputs() 610 << " outputs but the " << kAttrName 611 << " attribute specifies shapes for " << shape_attrs.size() 612 << " outputs. Output shapes may be inaccurate."; 613 } 614 for (int i = 0; i < node->num_outputs(); ++i) { 615 const TensorShapeProto& p = shape_attrs[i]; 616 shape_inference::ShapeHandle h; 617 Status s = ic->MakeShapeFromShapeProto(p, &h); 618 if (!s.ok()) { 619 return errors::InvalidArgument("Node '", node->name(), " has an invalid ", 620 kAttrName, " attribute (shape #", i, 621 " error:'", s.error_message(), "'"); 622 } 623 s = refiner_->SetShape(node, i, h); 624 if (!s.ok()) { 625 // If the output shape is incompatible with what is inferred 626 // by the graph for a very specific whitelist of ops, then we 627 // ignore this output shape. This can happen if there is a 628 // bug in the shape function for some operation, and the 629 // serialized graph def has the incorrect shape set when 630 // running on a newer binary with the fixed shape function. 631 // This is an escape hatch that allows us to correct shape 632 // functions that are not critical to correct execution but 633 // would cause graphs to fail if imported after correcting. 634 // 635 const string& op = node->type_string(); 636 const std::vector<string> whitelist = { 637 // To be removed after 2017/03/08. 638 "RandomShuffleQueue", 639 "PaddingFIFOQueue", 640 "FIFOQueue", 641 "PriorityQueue", 642 "QueueSize", 643 "Stack", 644 "Barrier", 645 "BarrierReadySize", 646 "BarrierIncompleteSize", 647 "HashTable", 648 "MutableHashTable", 649 "MutableHashTableOfTensors", 650 "Mutex", 651 "CuckooTable", 652 "IndexTable", 653 "WholeFileReader", 654 "TextLineReader", 655 "FixedLengthRecordReader", 656 "TFRecordReader", 657 "IdentityReader", 658 "RefSwitch", 659 "RefEnter", 660 "RefNextIteration", 661 "RefMerge", 662 "RefIdentity", 663 "LMDBReader", 664 // To be removed after 2017/04/24. 665 "ConditionalAccumulator", 666 "SparseConditionalAccumulator", 667 "Table", 668 }; 669 if (std::find(whitelist.begin(), whitelist.end(), op) == 670 whitelist.end()) { 671 return errors::InvalidArgument( 672 "Node '", node->name(), "' has an ", kAttrName, 673 " attribute inconsistent with the GraphDef for output #", i, ": ", 674 s.error_message()); 675 } 676 } 677 } 678 node->ClearAttr(kAttrName); 679 return Status::OK(); 680 } 681 682 Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) { 683 const OpDef* op_def; 684 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); 685 AddDefaultsToNodeDef(*op_def, node_def); 686 TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def)); 687 if (versions_) { 688 TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions_->producer())); 689 } 690 return Status::OK(); 691 } 692 693 void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def, 694 std::vector<bool>* input_already_exists) { 695 // Remove 'inputs_to_remove' from 'node_def' 696 NodeDef copy; 697 copy.mutable_input()->Reserve(node_def->input_size() - 698 inputs_to_remove.size()); 699 for (int i = 0, j = 0; i < node_def->input_size(); ++i) { 700 if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) { 701 ++j; 702 } else { 703 copy.add_input()->swap(*node_def->mutable_input(i)); 704 } 705 } 706 node_def->mutable_input()->Swap(copy.mutable_input()); 707 // Remove 'inputs_to_remove' from 'input_already_exists' 708 for (int idx : inputs_to_remove) { 709 input_already_exists->erase(input_already_exists->begin() + idx); 710 } 711 DCHECK_EQ(input_already_exists->size(), node_def->input_size()); 712 } 713 714 void GraphConstructor::RemapNodeDefInputs( 715 NodeDef* node_def, std::vector<bool>* input_already_exists) { 716 DCHECK_EQ(input_already_exists->size(), node_def->input_size()); 717 std::set<TensorId> control_inputs; 718 std::vector<int> inputs_to_remove; 719 720 for (int i = 0; i < node_def->input_size(); ++i) { 721 auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i))); 722 if (iter == opts_.input_map.end()) continue; 723 used_input_map_keys_.insert(iter->first); 724 725 TensorId new_input = iter->second; 726 if (new_input.second == Graph::kControlSlot) { 727 // Check if we've already remapped a different input to new_input, and if 728 // so remove this input. 729 if (control_inputs.count(new_input) > 0) { 730 inputs_to_remove.push_back(i); 731 continue; 732 } 733 control_inputs.insert(new_input); 734 } 735 node_def->set_input(i, new_input.ToString()); 736 (*input_already_exists)[i] = true; 737 } 738 if (!inputs_to_remove.empty()) { 739 RemoveInputs(inputs_to_remove, node_def, input_already_exists); 740 } 741 } 742 743 void GraphConstructor::AddControlDependencies( 744 NodeDef* node_def, std::vector<bool>* input_already_exists) { 745 // To avoid adding redundant control dependencies to every imported node, skip 746 // nodes that will inherit the dependencies from another imported node. 747 bool inherits_deps = false; 748 for (int i = 0; i < node_def->input_size(); ++i) { 749 // Assume we won't inherit dependencies from remapped inputs that already 750 // exist in the graph. Even if we're wrong, we'll only add redundant 751 // dependencies. 752 if ((*input_already_exists)[i]) continue; 753 754 // If this input is a backedge, assume we won't inherit the dependencies. 755 // TODO(skyewm): we have many redundant ParseTensorName calls. It could be 756 // worth optimizing these. 757 TensorId id(ParseTensorName(node_def->input(i))); 758 auto iter = gdef_nodes_.find(id.first); 759 DCHECK(iter != gdef_nodes_.end()) << id.first; 760 if (iter->second.node == nullptr) { 761 // Input hasn't been created yet, indicating it's a backedge. 762 continue; 763 } 764 inherits_deps = true; 765 } 766 if (inherits_deps) return; 767 768 // node_def either has no inputs or all remapped inputs, add the control 769 // dependencies 770 for (const string& control_dep : opts_.control_dependencies) { 771 string input = TensorId(control_dep, Graph::kControlSlot).ToString(); 772 bool found = false; 773 for (int i = node_def->input_size() - 1; i >= 0; --i) { 774 const string& node_input = node_def->input(i); 775 if (node_input[0] != '^') { 776 // Control inputs are at the end. Break when we reach the non-control 777 // inputs. 778 break; 779 } 780 if (node_input == input) { 781 // Control dependency already exists 782 found = true; 783 break; 784 } 785 } 786 if (found) { 787 continue; 788 } 789 node_def->add_input(input); 790 input_already_exists->push_back(true); 791 } 792 } 793 794 void GraphConstructor::AddPrefixToNodeDef( 795 const std::vector<bool>& input_already_exists, NodeDef* node_def) { 796 if (prefix_.empty()) return; 797 node_def->set_name(strings::StrCat(prefix_, node_def->name())); 798 // Update names of input nodes 799 for (int i = 0; i < node_def->input_size(); ++i) { 800 // Skip remapped inputs (which already exist in g_ and are not being 801 // imported). 802 if (input_already_exists[i]) continue; 803 StringPiece input(node_def->input(i)); 804 if (str_util::ConsumePrefix(&input, "^")) { 805 node_def->set_input(i, strings::StrCat("^", prefix_, input)); 806 } else { 807 node_def->set_input(i, strings::StrCat(prefix_, input)); 808 } 809 } 810 // Update names of colocation groups 811 if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) { 812 auto* list = 813 node_def->mutable_attr()->at(kColocationAttrName).mutable_list(); 814 for (int i = 0; i < list->s_size(); ++i) { 815 StringPiece v(list->s(i)); 816 if (str_util::ConsumePrefix(&v, kColocationGroupPrefix)) { 817 list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v)); 818 } 819 } 820 } 821 } 822 823 void GraphConstructor::UniquifyNames( 824 const std::vector<bool>& input_already_exists, NodeDef* node_def) { 825 if (NameExistsInGraph(node_def->name())) { 826 string old_name = node_def->name(); 827 node_def->set_name(FindUniqueName(node_def->name())); 828 uniquified_names_[old_name] = node_def->name(); 829 // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with 830 // `name` because we guarantee the original NodeDef names are unique, 831 // meaning we won't generate this name again. 832 } 833 for (int i = 0; i < node_def->input_size(); ++i) { 834 // Skip remapped inputs (which already exist in g_ and are not being 835 // imported). 836 if (input_already_exists[i]) continue; 837 TensorId id = ParseTensorName(node_def->input(i)); 838 // We require that UniquifyNames() is called on all NodeDefs in topological 839 // order. This guarantees that node_def's inputs will already be uniquified 840 // if necessary. 841 auto iter = uniquified_names_.find(string(id.first)); 842 if (iter == uniquified_names_.end()) continue; 843 id.first = iter->second; 844 node_def->set_input(i, id.ToString()); 845 } 846 } 847 848 void GraphConstructor::UpdateUniquifiedColocationNames() { 849 for (const auto& pair : gdef_nodes_) { 850 Node* node = pair.second.node; 851 if (node == nullptr) continue; 852 std::vector<string> coloc_values; 853 Status status = 854 GetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values); 855 if (!status.ok()) continue; 856 bool updated = false; 857 for (int i = 0; i < coloc_values.size(); ++i) { 858 StringPiece val(coloc_values[i]); 859 if (str_util::ConsumePrefix(&val, kColocationGroupPrefix)) { 860 auto name_pair = uniquified_names_.find(string(val)); 861 if (name_pair == uniquified_names_.end()) continue; 862 updated = true; 863 coloc_values[i] = 864 strings::StrCat(kColocationGroupPrefix, name_pair->second); 865 } 866 } 867 if (updated) { 868 node->AddAttr(kColocationAttrName, coloc_values); 869 } 870 } 871 } 872 873 bool GraphConstructor::NameExistsInGraph(StringPiece name) { 874 if (existing_nodes_.find(name) != existing_nodes_.end()) return true; 875 if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true; 876 return false; 877 } 878 879 bool GraphConstructor::NameExistsInGraphDef(StringPiece name) { 880 if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true; 881 if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true; 882 return false; 883 } 884 885 string GraphConstructor::FindUniqueName(StringPiece original_name) { 886 string name(original_name); 887 int count = 0; 888 // Check that any generated names don't collide with imported NodeDefs (as 889 // well as nodes in g_). 890 while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) { 891 name = strings::StrCat(original_name, "_", ++count); 892 } 893 return name; 894 } 895 896 Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def, 897 bool* is_node_mapped) { 898 const OpDef* op_def; 899 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def)); 900 for (int i = 0; i < op_def->output_arg_size(); ++i) { 901 if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) { 902 *is_node_mapped = false; 903 return Status::OK(); 904 } 905 } 906 *is_node_mapped = true; 907 return Status::OK(); 908 } 909 910 Status GraphConstructor::Convert() { 911 // Import functions before adding nodes, since imported nodes may refer to 912 // functions 913 if (library_) { 914 TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library_)); 915 } 916 917 std::vector<InputInfo> inputs; 918 int processed = 0; 919 920 std::vector<bool> input_already_exists; 921 922 // Process the NodeDefs in topological order. 923 // (InitFromEdges() sets this up by filling in ready_ with nodes that have no 924 // inputs, pending_counts_ with the number of inputs for each node and 925 // outputs_ with the outputs of each node). 926 while (!ready_.empty()) { 927 int o = *ready_.begin(); 928 ready_.erase(ready_.begin()); 929 ++processed; 930 inputs.clear(); 931 bool has_data_back_edge = false; 932 933 const NodeDef& original_node_def = *node_defs_[o]; 934 NodeDef imported_node_def; 935 const NodeDef* node_def; 936 937 // input_already_exists[i] is true iff the i-th input of the node we're 938 // importing refers to a preexisting node in g_ (i.e. input[i] existed prior 939 // to importing node_defs_). Conversely, input_already_exists[i] is false 940 // iff the input refers to a node in node_defs_. 941 input_already_exists.clear(); 942 input_already_exists.resize(original_node_def.input_size(), false); 943 944 if (opts_.importing) { 945 if (opts_.skip_mapped_nodes) { 946 bool is_node_mapped = false; 947 TF_RETURN_IF_ERROR( 948 IsNodeFullyMapped(original_node_def, &is_node_mapped)); 949 if (is_node_mapped) { 950 // Skip this node after updating pending_count_ for outputs 951 UpdatePendingCountAndReady(o); 952 continue; 953 } 954 } 955 956 // TODO(ashankar): The line below means an additional copy of the 957 // NodeDef, which can be expensive if the NodeDef contains large tensors 958 // in it. Might make sense to change the API for ImportGraphDef to take 959 // a mutable GraphDef* and avoid the copying. 960 imported_node_def = original_node_def; 961 if (!opts_.input_map.empty()) { 962 // Note that input_already_exists can shrink here 963 RemapNodeDefInputs(&imported_node_def, &input_already_exists); 964 } 965 if (!opts_.control_dependencies.empty()) { 966 // Note that input_already_exists can grow here 967 AddControlDependencies(&imported_node_def, &input_already_exists); 968 } 969 if (!opts_.default_device.empty() && imported_node_def.device().empty()) { 970 imported_node_def.set_device(opts_.default_device); 971 } 972 973 node_def = &imported_node_def; 974 } else { 975 node_def = &original_node_def; 976 } 977 978 DCHECK_EQ(node_def->input_size(), input_already_exists.size()); 979 TF_RETURN_IF_ERROR(ValidateColocationConstraints(*node_def)); 980 for (int i = 0; i < node_def->input_size(); ++i) { 981 TensorId id(ParseTensorName(node_def->input(i))); 982 Node* src_node; 983 int src_index; 984 985 if (!input_already_exists[i]) { 986 // Locate input in newly-imported nodes 987 auto iter = gdef_nodes_.find(id.first); 988 DCHECK(iter != gdef_nodes_.end()) << id.first; 989 src_node = iter->second.node; 990 src_index = id.second; 991 if (src_node == nullptr) has_data_back_edge = true; 992 } else { 993 // Input refers to preexistng node in graph 994 auto iter = existing_nodes_.find(id.first); 995 DCHECK(iter != existing_nodes_.end()) << id.first; 996 src_node = iter->second; 997 src_index = id.second; 998 } 999 1000 if (src_node != nullptr && src_index >= src_node->num_outputs()) { 1001 return errors::InvalidArgument( 1002 "Node '", node_def->name(), "': Connecting to invalid output ", 1003 id.second, " of source node ", id.first, " which has ", 1004 src_node->num_outputs(), " outputs"); 1005 } 1006 1007 inputs.emplace_back(string(id.first), src_node, src_index); 1008 } 1009 1010 if (has_data_back_edge && !IsMerge(*node_def)) { 1011 return errors::InvalidArgument( 1012 "Node '", node_def->name(), 1013 "' had a back edge, but only Merge nodes can have back edges."); 1014 } 1015 1016 Node* node; 1017 if (opts_.importing) { 1018 if (!prefix_.empty()) { 1019 AddPrefixToNodeDef(input_already_exists, &imported_node_def); 1020 } 1021 // Note: no need to uniquify names if the prefix already guarantees 1022 // uniqueness 1023 if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) { 1024 UniquifyNames(input_already_exists, &imported_node_def); 1025 } 1026 TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def)); 1027 } 1028 TF_RETURN_IF_ERROR(MakeNode(*node_def, &node)); 1029 // Use original_node_def so name StringPiece remains valid 1030 gdef_nodes_[original_node_def.name()].node = node; 1031 1032 // Add edges from inputs to *node to the graph. 1033 for (size_t i = 0; i < inputs.size(); ++i) { 1034 if (inputs[i].node == nullptr) { 1035 // Record this back edge, which will be added after all nodes 1036 // are created. 1037 back_edges_.emplace_back(inputs[i].name, inputs[i].index, node, i); 1038 } else if (inputs[i].index == Graph::kControlSlot) { 1039 g_->AddControlEdge(inputs[i].node, node); 1040 } else { 1041 TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i)); 1042 } 1043 } 1044 1045 TF_RETURN_IF_ERROR(ValidateShape(node)); 1046 1047 // Update pending_count_ for outputs. 1048 UpdatePendingCountAndReady(o); 1049 } 1050 1051 if (processed < node_defs_.size()) { 1052 LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed) 1053 << " NODES IN A CYCLE"; 1054 for (int64 i = 0; i < node_defs_.size(); i++) { 1055 if (pending_count_[i] != 0) { 1056 LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i]) 1057 << " WITH PENDING COUNT = " << pending_count_[i]; 1058 } 1059 } 1060 return errors::InvalidArgument(node_defs_.size() - processed, 1061 " nodes in a cycle"); 1062 } 1063 1064 return Status::OK(); 1065 } 1066 1067 Status GraphConstructor::AddBackEdges() { 1068 // Add the back edges after all nodes are created. 1069 for (auto e : back_edges_) { 1070 Node* src_node = gdef_nodes_[e.src_name].node; 1071 if (e.src_index == Graph::kControlSlot) { 1072 g_->AddControlEdge(src_node, e.dst_node); 1073 } else { 1074 TF_RETURN_IF_ERROR( 1075 MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index)); 1076 } 1077 1078 VLOG(2) << "Add back edge: " << src_node->name() << " -> " 1079 << e.dst_node->name(); 1080 } 1081 return Status::OK(); 1082 } 1083 1084 Status GraphConstructor::UpdateVersionDef() { 1085 if (versions_ == nullptr) return Status::OK(); 1086 1087 if (!opts_.importing) { 1088 g_->set_versions(*versions_); 1089 return Status::OK(); 1090 } 1091 VersionDef versions = g_->versions(); 1092 versions.set_producer(std::min(versions.producer(), versions_->producer())); 1093 versions.set_min_consumer( 1094 std::max(versions.min_consumer(), versions_->min_consumer())); 1095 if (versions_->bad_consumers_size() > 0) { 1096 std::set<int> bad(versions.bad_consumers().begin(), 1097 versions.bad_consumers().end()); 1098 bad.insert(versions_->bad_consumers().begin(), 1099 versions_->bad_consumers().end()); 1100 versions.clear_bad_consumers(); 1101 for (int v : bad) { 1102 versions.add_bad_consumers(v); 1103 } 1104 } 1105 g_->set_versions(versions); 1106 return Status::OK(); 1107 } 1108 1109 Status GraphConstructor::PopulateReturnTensors() { 1110 if (opts_.return_tensors.empty()) return Status::OK(); 1111 for (const TensorId& id : opts_.return_tensors) { 1112 auto iter = opts_.input_map.find(id); 1113 if (iter == opts_.input_map.end()) { 1114 // Locate id in imported nodes 1115 auto iter = gdef_nodes_.find(id.first); 1116 if (iter == gdef_nodes_.end()) { 1117 return errors::InvalidArgument("Requested return tensor '", 1118 id.ToString(), 1119 "' not found in graph def"); 1120 } 1121 int num_outputs = iter->second.node->num_outputs(); 1122 if ((id.second < 0 || id.second >= num_outputs) && 1123 id.second != Graph::kControlSlot) { 1124 return errors::InvalidArgument("Invalid return output ", id.second, 1125 " of node '", id.first, "', which has ", 1126 num_outputs, " output(s)"); 1127 } 1128 return_tensors_->push_back({iter->second.node, id.second}); 1129 } else { 1130 // id was remapped to existing node 1131 TensorId remapped_id = iter->second; 1132 DCHECK_GT(existing_nodes_.count(remapped_id.first), 0); 1133 Node* node = existing_nodes_[remapped_id.first]; 1134 return_tensors_->push_back({node, remapped_id.second}); 1135 } 1136 } 1137 return Status::OK(); 1138 } 1139 1140 Status GraphConstructor::PopulateReturnNodes() { 1141 if (opts_.return_nodes.empty()) return Status::OK(); 1142 for (StringPiece name : opts_.return_nodes) { 1143 auto iter = gdef_nodes_.find(name); 1144 if (iter == gdef_nodes_.end()) { 1145 return errors::InvalidArgument("Requested return node '", name, 1146 "' not found in graph def"); 1147 } 1148 return_nodes_->push_back(iter->second.node); 1149 } 1150 return Status::OK(); 1151 } 1152 1153 Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { 1154 if (missing_unused_input_map_keys_ == nullptr) return Status::OK(); 1155 for (const auto& input_map_pair : opts_.input_map) { 1156 TensorId key = input_map_pair.first; 1157 if (used_input_map_keys_.count(key) > 0) continue; 1158 1159 auto pair = gdef_nodes_.find(key.first); 1160 if (pair == gdef_nodes_.end()) { 1161 // key's node doesn't exist in GraphDef 1162 missing_unused_input_map_keys_->push_back(key); 1163 continue; 1164 } 1165 1166 // Check that key's index is in bounds. Get the number of outputs from the 1167 // NodeDef, rather than the imported Node, since the Node may not exist if 1168 // opts_.skip_mapped_nodes is true. 1169 const NodeDef* node_def = node_defs_[pair->second.gdef_index]; 1170 const OpDef* op_def; 1171 TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); 1172 int num_outputs; 1173 TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs)); 1174 if (key.second >= num_outputs) { 1175 // key's index out of bounds 1176 missing_unused_input_map_keys_->push_back(key); 1177 } 1178 } 1179 return Status::OK(); 1180 } 1181 1182 void GraphConstructor::Undo() { 1183 for (const auto& iter : gdef_nodes_) { 1184 if (iter.second.node != nullptr) { 1185 g_->RemoveNode(iter.second.node); 1186 } 1187 } 1188 g_->set_versions(original_versions_); 1189 } 1190 1191 Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, 1192 int input_index) { 1193 DataType src_out = src->output_type(output_index); 1194 DataType dst_in = dst->input_type(input_index); 1195 if (!TypesCompatible(dst_in, src_out)) { 1196 return errors::InvalidArgument( 1197 "Input ", input_index, " of node ", dst->name(), " was passed ", 1198 DataTypeString(src_out), " from ", src->name(), ":", output_index, 1199 " incompatible with expected ", DataTypeString(dst_in), "."); 1200 } 1201 g_->AddEdge(src, output_index, dst, input_index); 1202 return Status::OK(); 1203 } 1204 1205 } // namespace 1206 1207 Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, 1208 const GraphDef& gdef, Graph* g) { 1209 ShapeRefiner refiner(gdef.versions().producer(), g->op_registry()); 1210 return GraphConstructor::Construct( 1211 opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner, 1212 /*return_tensors=*/nullptr, /*return_nodes=*/nullptr, 1213 /*missing_unused_input_map_keys=*/nullptr); 1214 } 1215 1216 Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, 1217 gtl::ArraySlice<NodeDef> nodes, Graph* g) { 1218 ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry()); 1219 // TODO(irving): Copy will go away once NodeInfo exists 1220 std::vector<const NodeDef*> node_defs; 1221 for (const auto& n : nodes) { 1222 node_defs.push_back(&n); 1223 } 1224 return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g, 1225 &refiner, /*return_tensors=*/nullptr, 1226 /*return_nodes=*/nullptr, 1227 /*missing_unused_input_map_keys=*/nullptr); 1228 } 1229 1230 Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, 1231 Graph* g, ShapeRefiner* refiner, 1232 ImportGraphDefResults* results) { 1233 if (!opts.return_tensors.empty()) { 1234 if (results == nullptr) { 1235 return errors::InvalidArgument( 1236 "results argument to ImportGraphDef() must be non-null if " 1237 "opts.return_tensors is non-empty"); 1238 } 1239 } 1240 1241 if (!opts.return_nodes.empty()) { 1242 if (opts.skip_mapped_nodes) { 1243 return errors::InvalidArgument( 1244 "Requesting return_nodes with skip_mapped_nodes set is not currently " 1245 "supported"); 1246 } 1247 if (results == nullptr) { 1248 return errors::InvalidArgument( 1249 "results argument to ImportGraphDef() must be non-null if " 1250 "opts.return_nodes is non-empty"); 1251 } 1252 } 1253 1254 if (results != nullptr) { 1255 if (!results->return_tensors.empty() || !results->return_nodes.empty() || 1256 !results->missing_unused_input_map_keys.empty()) { 1257 return errors::InvalidArgument( 1258 "All fields in results argument to ImportGraphDef() must be empty."); 1259 } 1260 } 1261 1262 ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry()); 1263 if (refiner == nullptr) { 1264 refiner = &default_refiner; 1265 } else { 1266 // Log a warning if we are importing a GraphDef at an older 1267 // producer version after already having added non-source/sink 1268 // nodes to the graph in the past. 1269 if (gdef.versions().producer() > 0 && 1270 gdef.versions().producer() < refiner->graph_def_version() && 1271 g->num_nodes() > 2) { 1272 LOG(WARNING) << "Importing a graph with a lower producer version " 1273 << gdef.versions().producer() 1274 << " into an existing graph with producer version " 1275 << refiner->graph_def_version() << ". Shape inference will " 1276 << "have run different parts of the graph with different " 1277 << "producer versions."; 1278 } 1279 } 1280 1281 // Set the graph def version of the refiner as the min of the 1282 // current value and the version from the graph we are about to 1283 // import. 1284 // 1285 // Note: to match Run() semantics, we should re-run shape inference 1286 // on the entire graph if the producer version has changed. For now 1287 // we log the warning above. 1288 refiner->set_graph_def_version( 1289 std::min(refiner->graph_def_version(), gdef.versions().producer())); 1290 1291 if (results == nullptr) { 1292 return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(), 1293 &gdef.library(), g, refiner, nullptr, 1294 nullptr, nullptr); 1295 } else { 1296 return GraphConstructor::Construct( 1297 opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner, 1298 &results->return_tensors, &results->return_nodes, 1299 &results->missing_unused_input_map_keys); 1300 } 1301 } 1302 1303 void CopyGraph(const Graph& src, Graph* dest) { 1304 for (Node* n : dest->nodes()) { 1305 CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty"; 1306 } 1307 1308 // Copy GraphDef versions 1309 dest->set_versions(src.versions()); 1310 1311 // Copy the nodes 1312 std::unordered_map<const Node*, Node*> 1313 node_map; // "Node in src" -> "Node in *dest" 1314 node_map[src.source_node()] = dest->source_node(); 1315 node_map[src.sink_node()] = dest->sink_node(); 1316 for (Node* n : src.op_nodes()) { 1317 node_map[n] = dest->CopyNode(n); 1318 } 1319 1320 // Copy the edges 1321 for (const Edge* e : src.edges()) { 1322 Node* src_copy = node_map[e->src()]; 1323 Node* dst_copy = node_map[e->dst()]; 1324 dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); 1325 } 1326 } 1327 1328 } // namespace tensorflow 1329