/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph_constructor.h"

#include <algorithm>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

namespace {
inline bool IsMerge(const NodeDef& node_def) {
  return node_def.op() == "Merge" || node_def.op() == "RefMerge";
}

inline bool IsNextIteration(const NodeDef& node_def) {
  return node_def.op() == "NextIteration" ||
         node_def.op() == "RefNextIteration";
}

bool IsValidNodeName(StringPiece s, bool allow_internal_ops) {
  using ::tensorflow::strings::Scanner;
  return Scanner(s)
      .One(allow_internal_ops ? Scanner::LETTER_DIGIT_DOT_UNDERSCORE
                              : Scanner::LETTER_DIGIT_DOT)
      .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
      .Eos()
      .GetResult();
}
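// For example (a non-authoritative reading of the rule above): "foo/bar_1"
// and "x.y" are valid; "_Recv" is valid only when allow_internal_ops is true,
// since only internal names may begin with an underscore; the empty string is
// never valid.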

class GraphConstructor {
 public:
  struct Options {
    Options(const GraphConstructorOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(in.allow_internal_ops),
          expect_device_spec(in.expect_device_spec),
          importing(false),
          validate_colocation_constraints(false) {}
    Options(const ImportGraphDefOptions& in)  // NOLINT(runtime/explicit)
        : allow_internal_ops(false),
          expect_device_spec(false),
          prefix(in.prefix.empty() || StringPiece(in.prefix).ends_with("/")
                     ? in.prefix
                     : in.prefix + "/"),
          uniquify_names(in.uniquify_names),
          uniquify_prefix(in.uniquify_prefix),
          input_map(in.input_map),
          skip_mapped_nodes(in.skip_mapped_nodes),
          control_dependencies(in.control_dependencies),
          return_tensors(in.return_tensors),
          return_nodes(in.return_nodes),
          importing(true),
          validate_colocation_constraints(in.validate_colocation_constraints),
          validate_shape(in.validate_shape) {}

    bool allow_internal_ops;
    bool expect_device_spec;

    string prefix;
    bool uniquify_names;
    bool uniquify_prefix;
    std::map<TensorId, TensorId> input_map;
    bool skip_mapped_nodes;
    std::vector<string> control_dependencies;
    std::vector<TensorId> return_tensors;
    std::vector<string> return_nodes;

    // TODO(ashankar): This bool exists to separate out functionality required
    // to make ImportGraphDef a close equivalent of Python's import_graph_def
    // without affecting the behavior of ConvertGraphDefToGraph at the time
    // ImportGraphDef was added.
    //
    // That said, the functionality here (shape and op validation) seems
    // applicable to ConvertGraphDefToGraph as well, so make an attempt to
    // remove this.
    bool importing;
    bool validate_colocation_constraints;
    bool validate_shape = true;
  };

  typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;

  // versions and library may be nullptr
  static Status Construct(
      const Options& opts, NodeDefSlice node_defs, const VersionDef* versions,
      const FunctionDefLibrary* library, Graph* g, ShapeRefiner* refiner,
      std::vector<std::pair<Node*, int>>* return_tensors,
      std::vector<Node*>* return_nodes,
      std::vector<TensorId>* missing_unused_input_map_keys) {
    if (versions) {
      TF_RETURN_IF_ERROR(CheckVersions(*versions, TF_GRAPH_DEF_VERSION,
                                       TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
                                       "GraphDef", "graph"));
    }
    GraphConstructor c(opts, node_defs, versions, library, g, refiner,
                       return_tensors, return_nodes,
                       missing_unused_input_map_keys);
    const Status s = c.TryImport();
    if (!s.ok()) c.Undo();
    return s;
  }

 private:
  GraphConstructor(const Options& opts, NodeDefSlice node_defs,
                   const VersionDef* versions,
                   const FunctionDefLibrary* library, Graph* g,
                   ShapeRefiner* refiner,
                   std::vector<std::pair<Node*, int>>* return_tensors,
                   std::vector<Node*>* return_nodes,
                   std::vector<TensorId>* missing_unused_input_map_keys)
      : opts_(opts),
        node_defs_(node_defs),
        versions_(versions),
        library_(library),
        g_(g),
        original_versions_(g->versions()),
        prefix_(opts.prefix),
        refiner_(refiner),
        return_tensors_(return_tensors),
        return_nodes_(return_nodes),
        missing_unused_input_map_keys_(missing_unused_input_map_keys) {}

  Status TryImport() {
    TF_RETURN_IF_ERROR(EnsureNoNameCollisions());
    TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies());
    TF_RETURN_IF_ERROR(BuildNodeIndex());
    TF_RETURN_IF_ERROR(InitFromEdges());
    TF_RETURN_IF_ERROR(Convert());
    TF_RETURN_IF_ERROR(AddBackEdges());
    TF_RETURN_IF_ERROR(UpdateVersionDef());
    TF_RETURN_IF_ERROR(PopulateReturnTensors());
    TF_RETURN_IF_ERROR(PopulateReturnNodes());
    TF_RETURN_IF_ERROR(PopulateMissingUnusedInputMapKeys());
    UpdateUniquifiedColocationNames();
    FixupSourceAndSinkEdges(g_);
    return Status::OK();
  }

  Status EnsureNoNameCollisions();
  Status ValidateInputMapAndControlDependencies();
  Status BuildNodeIndex();
  Status InitFromEdges();
  Status Convert();
  Status AddBackEdges();
  Status UpdateVersionDef();
  Status PopulateReturnTensors();
  Status PopulateReturnNodes();
  Status PopulateMissingUnusedInputMapKeys();

  void Undo();

  Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped);
  Status ValidateColocationConstraints(const NodeDef& node_def);
  Status MakeNode(const NodeDef& node_def, Node** node);
  Status MakeEdge(Node* src, int output_index, Node* dst, int input_index);
  Status ValidateShape(Node* node);
  Status ModifyNodeDefForImport(NodeDef* node_def);
  // Modifies node_def's inputs according to opts_.input_map.
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function marks each remapped input as true in
  // input_already_exists.
  void RemapNodeDefInputs(NodeDef* node_def,
                          std::vector<bool>* input_already_exists);
  // input_already_exists is a pre-initialized vector of length
  // node_def->input_size(). This function appends the control inputs from
  // opts_.control_dependencies and marks them as true in input_already_exists.
  void AddControlDependencies(NodeDef* node_def,
                              std::vector<bool>* input_already_exists);
  void AddPrefixToNodeDef(const std::vector<bool>& input_already_exists,
                          NodeDef* node_def);

  // Modifies `node_def` if its name isn't unique, or if any of its inputs'
  // names have been uniquified. This must be called in topological order on all
  // nodes.
  void UniquifyNames(const std::vector<bool>& input_already_exists,
                     NodeDef* node_def);

  // Updates any constructed nodes' colocation group names if the name has been
  // updated by UniquifyNames. This is called after all the nodes have been
  // constructed so all the names have been uniquified if necessary.
  void UpdateUniquifiedColocationNames();

  // Returns true if `name` already exists in `g_` (either as a node name or
  // prefix).
  bool NameExistsInGraph(StringPiece name);

  // Returns true if `name` already exists in the GraphDef being imported
  // (either as a node name or prefix).
  bool NameExistsInGraphDef(StringPiece name);

  // Returns a unique version of `original_name`, or `original_name` if it's
  // already unique in the graph.
  string FindUniqueName(StringPiece original_name);

  // From constructor
  const Options opts_;
  const NodeDefSlice node_defs_;
  const VersionDef* versions_;
  const FunctionDefLibrary* library_;
  Graph* g_;
  const VersionDef original_versions_;

  // A copy of opts_.prefix, possibly uniquified.
  string prefix_;

  ShapeRefiner* refiner_;

  // May be null. Not owned.
  std::vector<std::pair<Node*, int>>* return_tensors_;

  // May be null. Not owned.
  std::vector<Node*>* return_nodes_;

  // May be null. Not owned.
  std::vector<TensorId>* missing_unused_input_map_keys_;

  // Intermediate data structure used to populate
  // `missing_unused_input_map_keys_`.
  std::set<TensorId> used_input_map_keys_;

  // Mapping from node name to the index within node_defs_.
  struct NodeInfo {
    explicit NodeInfo(int i) : gdef_index(i), node(nullptr) {}
    // std::unordered_map<> requires that we have a default constructor.
    NodeInfo() : NodeInfo(-1) {}
    int gdef_index;
    Node* node;  // nullptr until the NodeDef is converted to a Node.
  };
  // TODO(vrv): Profile this data structure to see if we should use an
  // alternative implementation of std::unordered_map.
  std::unordered_map<StringPiece, NodeInfo, StringPieceHasher> gdef_nodes_;

  // Prefixes already used in the GraphDef being imported.
  std::unordered_set<StringPiece, StringPieceHasher> gdef_prefixes_;

  // Mapping from node name to the existing node in g_.
  std::unordered_map<StringPiece, Node*, StringPieceHasher> existing_nodes_;

  // Prefixes already used in the graph.
  std::unordered_set<StringPiece, StringPieceHasher> existing_prefixes_;

  // Imported node names that have been uniquified. The key is the original
  // name, the value is the new unique name.
  std::unordered_map<string, string> uniquified_names_;

  // Index of NodeDefs in node_defs_ with all inputs already converted.
  std::vector<int> ready_;

  // Mapping between index within node_defs_ and the number of inputs that
  // still need to be converted.
  std::vector<int> pending_count_;

  // Mapping from an index within node_defs_ to the indices within node_defs_
  // of all nodes it outputs to.
  std::vector<gtl::InlinedVector<int, 4>> outputs_;

  // Used in the conversion from node_defs_ to g_ to represent the ith input
  // of a node.
  struct InputInfo {
    explicit InputInfo(const string& node_name, Node* n, int i)
        : name(node_name), node(n), index(i) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string name;
    Node* node;
    int index;
  };

  // Used in the conversion from node_defs_ to g_ to represent an edge from
  // the node named 'name' to node 'n'.
  struct EdgeInfo {
    explicit EdgeInfo(const string& name, int i1, Node* n, int i2)
        : src_name(name), src_index(i1), dst_node(n), dst_index(i2) {}
    // Use string instead of StringPiece so we don't have to manage lifetime
    string src_name;
    int src_index;
    Node* dst_node;
    int dst_index;
  };
  std::vector<EdgeInfo> back_edges_;
};

// This could be expensive but we don't expect to call it often, if at all (only
// if there are multiple nodes in g_ with the same name)
bool NodeNameInValues(const std::map<TensorId, TensorId>& input_map,
                      const StringPiece& node_name) {
  for (auto iter = input_map.begin(); iter != input_map.end(); ++iter) {
    if (iter->second.first == node_name) return true;
  }
  return false;
}

bool NodeNameInValues(const std::vector<string>& control_dependencies,
                      const StringPiece& node_name) {
  return std::find(control_dependencies.begin(), control_dependencies.end(),
                   node_name) != control_dependencies.end();
}

// Adds any prefixes of `node_name` (not including the full name itself) to
// `prefixes`.
void AddPrefixes(StringPiece node_name,
                 std::unordered_set<StringPiece, StringPieceHasher>* prefixes) {
  size_t idx = -1;
  while ((idx = node_name.find('/', idx + 1)) != StringPiece::npos) {
    prefixes->insert(node_name.substr(0, idx));
  }
}
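// For example, AddPrefixes("a/b/c", &prefixes) inserts "a" and "a/b" into
// `prefixes` (but not "a/b/c" itself).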

Status GraphConstructor::EnsureNoNameCollisions() {
  existing_nodes_.reserve(g_->num_nodes());
  // Populate existing_nodes_ and existing_prefixes_.
  for (Node* n : g_->nodes()) {
    bool already_exists = !existing_nodes_.insert({n->name(), n}).second;
    if (already_exists) {
      if (NodeNameInValues(opts_.input_map, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve input_map because multiple nodes exist with name '",
            n->name(), "'");
      }
      if (NodeNameInValues(opts_.control_dependencies, n->name())) {
        return errors::InvalidArgument(
            "cannot resolve control_dependencies because multiple nodes exist "
            "with name '",
            n->name(), "'");
      }
    }
    AddPrefixes(n->name(), &existing_prefixes_);
  }
  if (prefix_.empty() && opts_.importing && !opts_.uniquify_names) {
    for (const NodeDef* n : node_defs_) {
      const string& name = n->name();
      if (NameExistsInGraph(name)) {
        return errors::InvalidArgument("Node name '", name,
                                       "' already exists in the Graph");
      }
    }
  } else if (!prefix_.empty()) {
    StringPiece prefix_no_slash(prefix_);
    prefix_no_slash.remove_suffix(1);
    if (!IsValidNodeName(prefix_no_slash, false)) {
      return errors::InvalidArgument("Imported node name prefix '", prefix_,
                                     "' would lead to invalid node names");
    }
    if (NameExistsInGraph(prefix_no_slash) && opts_.uniquify_prefix) {
      prefix_ = strings::StrCat(FindUniqueName(prefix_no_slash), "/");
    }
  }
  return Status::OK();
}

Status GraphConstructor::ValidateInputMapAndControlDependencies() {
  for (const auto& mapping : opts_.input_map) {
    TensorId src = mapping.first;
    TensorId dst = mapping.second;
    if (existing_nodes_.count(dst.first) == 0) {
      return errors::InvalidArgument(
          "node '", dst.first, "' in input_map does not exist in graph ",
          "(input_map entry: ", src.ToString(), "->", dst.ToString(), ")");
    }
    if ((src.second == Graph::kControlSlot) !=
        (dst.second == Graph::kControlSlot)) {
      return errors::InvalidArgument("input_map entry ", src.ToString(), "->",
                                     dst.ToString(), " between ",
                                     "control edge and non-control edge");
    }
  }
  for (const string& node : opts_.control_dependencies) {
    if (existing_nodes_.count(node) == 0) {
      return errors::InvalidArgument(
          "node '", node,
          "' in control_dependencies does not exist in "
          "graph");
    }
  }
  return Status::OK();
}

Status GraphConstructor::BuildNodeIndex() {
  // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_.
  for (int n = 0; n < node_defs_.size(); ++n) {
    const NodeDef& node_def = *node_defs_[n];
    if (!IsValidNodeName(node_def.name(), opts_.allow_internal_ops)) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "': Node name contains invalid characters");
    }
    if (!gdef_nodes_
             .insert(std::make_pair(StringPiece(node_def.name()), NodeInfo(n)))
             .second) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is not unique");
    }
    // Validate the operation's type.
    if (node_def.op().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' does not specify an operation");
    }
    if (opts_.expect_device_spec && node_def.device().empty()) {
      return errors::InvalidArgument("Node '", node_def.name(),
                                     "' is missing a device specification");
    }
    // Validate that control inputs come after all regular (data) inputs.
    bool in_control_dependence = false;
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      if (!input_name.empty() && input_name.starts_with("^")) {
        in_control_dependence = true;
      } else if (in_control_dependence) {
        return errors::InvalidArgument(
            "Node '", node_def.name(),
            "': Control dependencies must come after regular dependencies");
      }
    }
    // Update gdef_prefixes_.
    AddPrefixes(node_def.name(), &gdef_prefixes_);
  }
  return Status::OK();
}

std::unordered_set<string> GetNextIterationNodes(
    const GraphConstructor::NodeDefSlice& node_defs) {
  std::unordered_set<string> next_iteration_nodes;

  for (int n = 0; n < node_defs.size(); ++n) {
    const NodeDef& node_def = *node_defs[n];
    if (IsNextIteration(node_def)) {
      next_iteration_nodes.insert(node_def.name());
    }
  }

  return next_iteration_nodes;
}

Status GraphConstructor::InitFromEdges() {
  const int num_nodes = node_defs_.size();
  pending_count_.reserve(num_nodes);
  outputs_.resize(num_nodes);
  std::unordered_set<string> next_iteration_nodes_ =
      GetNextIterationNodes(node_defs_);

  // Parse the inputs for each node.
  for (int n = 0; n < num_nodes; ++n) {
    const NodeDef& node_def = *node_defs_[n];
    int pending_count = node_def.input_size();
    if (IsMerge(node_def)) {
      // Cycles in the graph are only allowed for while loops. A while loop is
      // identified by an edge from a NextIteration node to a Merge node. For
      // such Merge nodes, only wait for one non-control input before
      // considering the node ready to process in Convert().
      int32 num_control_edges = 0;
      bool has_loop_back_edge = false;
      for (int i = 0; i < node_def.input_size(); ++i) {
        StringPiece input_name(node_def.input(i));
        if (input_name.starts_with("^")) {
          num_control_edges++;
        } else {
          TensorId id(ParseTensorName(input_name));
          if (next_iteration_nodes_.find(id.first.ToString()) !=
              next_iteration_nodes_.end()) {
            has_loop_back_edge = true;
          }
        }
      }
      if (has_loop_back_edge) {
        pending_count = num_control_edges + 1;
      }
    }
    for (int i = 0; i < node_def.input_size(); ++i) {
      StringPiece input_name = node_def.input(i);
      TensorId id(ParseTensorName(input_name));
      if (opts_.input_map.count(id) == 0) {
        // If an input is not mapped, then the input should appear in the graph
        // being imported.
        auto iter = gdef_nodes_.find(id.first);
        if (iter == gdef_nodes_.end()) {
          return errors::InvalidArgument("Node '", node_def.name(),
                                         "': Unknown input node '",
                                         node_def.input(i), "'");
        }
        outputs_[iter->second.gdef_index].push_back(n);
      } else {
        // This input is mapped to an existing edge. Therefore this input is
        // as good as being already processed.
        --pending_count;
        DCHECK_GE(pending_count, 0);
      }
    }
    if (pending_count == 0) {
      ready_.push_back(n);
    }
    pending_count_.push_back(pending_count);
  }
  return Status::OK();
}

Status GraphConstructor::ValidateColocationConstraints(
    const NodeDef& node_def) {
  if (!opts_.validate_colocation_constraints || !opts_.importing)
    return Status::OK();
  const auto iter = node_def.attr().find(kColocationAttrName);
  if (iter == node_def.attr().end()) return Status::OK();
  for (const string& c : iter->second.list().s()) {
    StringPiece s(c);
    if (s.Consume(kColocationGroupPrefix) &&
        gdef_nodes_.find(s) == gdef_nodes_.end()) {
      return errors::InvalidArgument(
          "Node '", node_def.name(),
          "' expects to be colocated with unknown node '", s, "'");
    }
  }
  return Status::OK();
}
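// Note: each colocation attribute entry has the form kColocationGroupPrefix
// followed by a node name (conventionally rendered as "loc:@<node name>"), so
// the Consume() above strips the prefix and leaves just the referenced node
// name.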

Status GraphConstructor::MakeNode(const NodeDef& node_def, Node** node) {
  // Add the node to the graph.
  Status status;
  *node = g_->AddNode(node_def, &status);
  if (!status.ok()) return status;
  if (opts_.expect_device_spec) {
    (*node)->set_assigned_device_name(node_def.device());
  }
  return Status::OK();
}

Status GraphConstructor::ValidateShape(Node* node) {
  if (!opts_.importing || !opts_.validate_shape) return Status::OK();
  TF_RETURN_IF_ERROR(refiner_->AddNode(node));
  // For nodes with the _output_shapes attribute, override the shape.
  std::vector<TensorShapeProto> shape_attrs;
  const char* kAttrName = "_output_shapes";
  if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) {
    // No _output_shapes attribute, the AddNode call above was sufficient.
    return Status::OK();
  }
  auto* ic = refiner_->GetContext(node);
  DCHECK(ic != nullptr)
      << "ShapeRefiner::AddNode() should have created the InferenceContext";
  if (shape_attrs.size() != node->num_outputs()) {
    return errors::InvalidArgument(
        "Node '", node->name(), "' has ", node->num_outputs(),
        " outputs but the ", kAttrName, " attribute specifies shapes for ",
        shape_attrs.size(), " outputs");
  }
  for (int i = 0; i < shape_attrs.size(); ++i) {
    const TensorShapeProto& p = shape_attrs[i];
    shape_inference::ShapeHandle h;
    Status s = ic->MakeShapeFromShapeProto(p, &h);
    if (!s.ok()) {
      return errors::InvalidArgument("Node '", node->name(),
                                     "' has an invalid ", kAttrName,
                                     " attribute (shape #", i, " error: '",
                                     s.error_message(), "')");
    }
    s = refiner_->SetShape(node, i, h);
    if (!s.ok()) {
      // If the output shape is incompatible with what is inferred
      // by the graph for a very specific whitelist of ops, then we
      // ignore this output shape.  This can happen if there is a
      // bug in the shape function for some operation, and the
      // serialized graph def has the incorrect shape set when
      // running on a newer binary with the fixed shape function.
      // This is an escape hatch that allows us to correct shape
      // functions that are not critical to correct execution but
      // would cause graphs to fail if imported after correcting.
      //
      const string& op = node->type_string();
      const std::vector<string> whitelist = {
          // To be removed after 2017/03/08.
          "RandomShuffleQueue",
          "PaddingFIFOQueue",
          "FIFOQueue",
          "PriorityQueue",
          "QueueSize",
          "Stack",
          "Barrier",
          "BarrierReadySize",
          "BarrierIncompleteSize",
          "HashTable",
          "MutableHashTable",
          "MutableHashTableOfTensors",
          "Mutex",
          "CuckooTable",
          "IndexTable",
          "WholeFileReader",
          "TextLineReader",
          "FixedLengthRecordReader",
          "TFRecordReader",
          "IdentityReader",
          "RefSwitch",
          "RefEnter",
          "RefNextIteration",
          "RefMerge",
          "RefIdentity",
          "LMDBReader",
          // To be removed after 2017/04/24.
          "ConditionalAccumulator",
          "SparseConditionalAccumulator",
          "Table",
      };
      if (std::find(whitelist.begin(), whitelist.end(), op) ==
          whitelist.end()) {
        return errors::InvalidArgument(
            "Node '", node->name(), "' has an ", kAttrName,
            " attribute inconsistent with the GraphDef for output #", i, ": ",
            s.error_message());
      }
    }
  }
  node->ClearAttr(kAttrName);
  return Status::OK();
}

Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) {
  const OpDef* op_def;
  TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
  AddDefaultsToNodeDef(*op_def, node_def);
  TF_RETURN_IF_ERROR(ValidateNodeDef(*node_def, *op_def));
  if (versions_) {
    TF_RETURN_IF_ERROR(CheckOpDeprecation(*op_def, versions_->producer()));
  }
  return Status::OK();
}

void RemoveInputs(const std::vector<int>& inputs_to_remove, NodeDef* node_def,
                  std::vector<bool>* input_already_exists) {
  // Remove 'inputs_to_remove' from 'node_def'
  // TODO(skyewm): is there a better way to do this?
  std::vector<string> inputs;
  inputs.reserve(node_def->input_size());
  for (int i = 0; i < node_def->input_size(); ++i) {
    inputs.push_back(node_def->input(i));
  }
  node_def->clear_input();
  for (int i = 0, j = 0; i < inputs.size(); ++i) {
    if (j < inputs_to_remove.size() && i == inputs_to_remove[j]) {
      ++j;
    } else {
      node_def->add_input(inputs[i]);
    }
  }
  // Remove 'inputs_to_remove' from 'input_already_exists'. Erase in reverse
  // order so that the earlier indices remain valid after each erase.
  for (auto it = inputs_to_remove.rbegin(); it != inputs_to_remove.rend();
       ++it) {
    input_already_exists->erase(input_already_exists->begin() + *it);
  }
  DCHECK_EQ(input_already_exists->size(), node_def->input_size());
}
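// For example, with node inputs {"a", "b", "c", "d"} and inputs_to_remove
// {1, 3} (as produced in ascending order by RemapNodeDefInputs below), the
// node is left with inputs {"a", "c"} and input_already_exists shrinks to
// match.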

void GraphConstructor::RemapNodeDefInputs(
    NodeDef* node_def, std::vector<bool>* input_already_exists) {
  DCHECK_EQ(input_already_exists->size(), node_def->input_size());
  std::set<TensorId> control_inputs;
  std::vector<int> inputs_to_remove;

  for (int i = 0; i < node_def->input_size(); ++i) {
    auto iter = opts_.input_map.find(ParseTensorName(node_def->input(i)));
    if (iter == opts_.input_map.end()) continue;
    used_input_map_keys_.insert(iter->first);

    TensorId new_input = iter->second;
    if (new_input.second == Graph::kControlSlot) {
      // Check if we've already remapped a different input to new_input, and if
      // so remove this input.
      if (control_inputs.count(new_input) > 0) {
        inputs_to_remove.push_back(i);
        continue;
      }
      control_inputs.insert(new_input);
    }
    node_def->set_input(i, new_input.ToString());
    (*input_already_exists)[i] = true;
  }
  if (!inputs_to_remove.empty()) {
    RemoveInputs(inputs_to_remove, node_def, input_already_exists);
  }
}

void GraphConstructor::AddControlDependencies(
    NodeDef* node_def, std::vector<bool>* input_already_exists) {
  // To avoid adding redundant control dependencies to every imported node, skip
  // nodes that will inherit the dependencies from another imported node.
  bool inherits_deps = false;
  for (int i = 0; i < node_def->input_size(); ++i) {
    // Assume we won't inherit dependencies from remapped inputs that already
    // exist in the graph. Even if we're wrong, we'll only add redundant
    // dependencies.
    if ((*input_already_exists)[i]) continue;

    // If this input is a backedge, assume we won't inherit the dependencies.
    // TODO(skyewm): we have many redundant ParseTensorName calls. It could be
    // worth optimizing these.
    TensorId id(ParseTensorName(node_def->input(i)));
    auto iter = gdef_nodes_.find(id.first);
    DCHECK(iter != gdef_nodes_.end()) << id.first;
    if (iter->second.node == nullptr) {
      // Input hasn't been created yet, indicating it's a backedge.
      continue;
    }
    inherits_deps = true;
  }
  if (inherits_deps) return;

  // node_def either has no inputs, or all of its inputs are remapped or
  // backedges; add the control dependencies directly.
  for (const string& control_dep : opts_.control_dependencies) {
    string input = TensorId(control_dep, Graph::kControlSlot).ToString();
    const protobuf::RepeatedPtrField<string>& inputs = node_def->input();
    if (std::find(inputs.begin(), inputs.end(), input) != inputs.end()) {
      // Control dependency already exists
      continue;
    }
    node_def->add_input(input);
    input_already_exists->push_back(true);
  }
}
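// Note: TensorId(name, Graph::kControlSlot).ToString() renders as "^name",
// which is the textual form of a control input in a NodeDef.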

void GraphConstructor::AddPrefixToNodeDef(
    const std::vector<bool>& input_already_exists, NodeDef* node_def) {
  if (prefix_.empty()) return;
  node_def->set_name(strings::StrCat(prefix_, node_def->name()));
  // Update names of input nodes
  for (int i = 0; i < node_def->input_size(); ++i) {
    StringPiece input(node_def->input(i));
    // Skip remapped inputs (which already exist in g_ and are not being
    // imported).
    if (input_already_exists[i]) continue;
    if (input.Consume("^")) {
      node_def->set_input(i, strings::StrCat("^", prefix_, input));
    } else {
      node_def->set_input(i, strings::StrCat(prefix_, input));
    }
  }
  // Update names of colocation groups
  if (node_def->attr().find(kColocationAttrName) != node_def->attr().end()) {
    auto* list =
        node_def->mutable_attr()->at(kColocationAttrName).mutable_list();
    for (int i = 0; i < list->s_size(); ++i) {
      StringPiece v(list->s(i));
      if (v.Consume(kColocationGroupPrefix)) {
        list->set_s(i, strings::StrCat(kColocationGroupPrefix, prefix_, v));
      }
    }
  }
}

void GraphConstructor::UniquifyNames(
    const std::vector<bool>& input_already_exists, NodeDef* node_def) {
  if (NameExistsInGraph(node_def->name())) {
    string old_name = node_def->name();
    node_def->set_name(FindUniqueName(node_def->name()));
    uniquified_names_[old_name] = node_def->name();
    // Note that we don't have to update gdef_nodes_ or gdef_prefixes_ with
    // `name` because we guarantee the original NodeDef names are unique,
    // meaning we won't generate this name again.
  }
  for (int i = 0; i < node_def->input_size(); ++i) {
    // Skip remapped inputs (which already exist in g_ and are not being
    // imported).
    if (input_already_exists[i]) continue;
    TensorId id = ParseTensorName(node_def->input(i));
    // We require that UniquifyNames() is called on all NodeDefs in topological
    // order. This guarantees that node_def's inputs will already be uniquified
    // if necessary.
    auto iter = uniquified_names_.find(id.first.ToString());
    if (iter == uniquified_names_.end()) continue;
    id.first = iter->second;
    node_def->set_input(i, id.ToString());
  }
}

void GraphConstructor::UpdateUniquifiedColocationNames() {
  for (const auto& pair : gdef_nodes_) {
    Node* node = pair.second.node;
    if (node == nullptr) continue;
    std::vector<string> coloc_values;
    Status status =
        GetNodeAttr(node->attrs(), kColocationAttrName, &coloc_values);
    if (!status.ok()) continue;
    bool updated = false;
    for (int i = 0; i < coloc_values.size(); ++i) {
      StringPiece val(coloc_values[i]);
      if (val.Consume(kColocationGroupPrefix)) {
        const auto& name_pair = uniquified_names_.find(val.ToString());
        if (name_pair == uniquified_names_.end()) continue;
        updated = true;
        coloc_values[i] =
            strings::StrCat(kColocationGroupPrefix, name_pair->second);
      }
    }
    if (updated) {
      node->AddAttr(kColocationAttrName, coloc_values);
    }
  }
}

bool GraphConstructor::NameExistsInGraph(StringPiece name) {
  if (existing_nodes_.find(name) != existing_nodes_.end()) return true;
  if (existing_prefixes_.find(name) != existing_prefixes_.end()) return true;
  return false;
}

bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
  if (gdef_nodes_.find(name) != gdef_nodes_.end()) return true;
  if (gdef_prefixes_.find(name) != gdef_prefixes_.end()) return true;
  return false;
}

string GraphConstructor::FindUniqueName(StringPiece original_name) {
  string name = original_name.ToString();
  int count = 0;
  // Check that any generated names don't collide with imported NodeDefs (as
  // well as nodes in g_).
  while (NameExistsInGraph(name) || (count > 0 && NameExistsInGraphDef(name))) {
    name = strings::StrCat(original_name, "_", ++count);
  }
  return name;
}
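// For example, if "foo" already exists, FindUniqueName("foo") tries "foo_1",
// "foo_2", ... until it finds a name free in both the existing graph and the
// GraphDef being imported.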

Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def,
                                           bool* is_node_mapped) {
  const OpDef* op_def;
  TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def));
  for (int i = 0; i < op_def->output_arg_size(); ++i) {
    if (opts_.input_map.find({node_def.name(), i}) == opts_.input_map.end()) {
      *is_node_mapped = false;
      return Status::OK();
    }
  }
  *is_node_mapped = true;
  return Status::OK();
}

namespace {

void UpdatePendingCountAndReady(
    const std::vector<gtl::InlinedVector<int, 4>>& outputs, int o,
    std::vector<int>* pending_count, std::vector<int>* ready) {
  for (size_t i = 0; i < outputs[o].size(); ++i) {
    const int output = outputs[o][i];
    (*pending_count)[output]--;
    if ((*pending_count)[output] == 0) {
      ready->push_back(output);
    }
  }
}

}  // anonymous namespace

Status GraphConstructor::Convert() {
  // Import functions before adding nodes, since imported nodes may refer to
  // functions
  if (library_) {
    TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(*library_));
  }

  std::vector<InputInfo> inputs;
  int processed = 0;

  std::vector<bool> input_already_exists;

  // Process the NodeDefs in topological order.
  // (InitFromEdges() sets this up by filling in ready_ with nodes that have no
  // pending inputs, pending_count_ with the number of inputs for each node,
  // and outputs_ with the outputs of each node.)
  while (!ready_.empty()) {
    int o = ready_.back();
    ready_.pop_back();
    ++processed;
    inputs.clear();
    bool has_data_back_edge = false;

    const NodeDef& original_node_def = *node_defs_[o];
    NodeDef imported_node_def;
    const NodeDef* node_def;

    // input_already_exists[i] is true iff the i-th input of the node we're
    // importing refers to a preexisting node in g_ (i.e. input[i] existed prior
    // to importing node_defs_).  Conversely, input_already_exists[i] is false
    // iff the input refers to a node in node_defs_.
    input_already_exists.clear();
    input_already_exists.resize(original_node_def.input_size(), false);

    if (opts_.importing) {
      if (opts_.skip_mapped_nodes) {
        bool is_node_mapped = false;
        TF_RETURN_IF_ERROR(
            IsNodeFullyMapped(original_node_def, &is_node_mapped));
        if (is_node_mapped) {
          // Skip this node after updating pending_count_ for outputs
          UpdatePendingCountAndReady(outputs_, o, &pending_count_, &ready_);
          continue;
        }
      }

      // TODO(ashankar): The line below means an additional copy of the NodeDef,
      // which can be expensive if the NodeDef contains large tensors in it.
      // Might make sense to change the API for ImportGraphDef to take a mutable
      // GraphDef* and avoid the copying.
      imported_node_def = original_node_def;
      if (!opts_.input_map.empty()) {
        // Note that input_already_exists can shrink here
        RemapNodeDefInputs(&imported_node_def, &input_already_exists);
      }
      if (!opts_.control_dependencies.empty()) {
        // Note that input_already_exists can grow here
        AddControlDependencies(&imported_node_def, &input_already_exists);
      }
      node_def = &imported_node_def;
    } else {
      node_def = &original_node_def;
    }

    DCHECK_EQ(node_def->input_size(), input_already_exists.size());
    TF_RETURN_IF_ERROR(ValidateColocationConstraints(*node_def));
    for (int i = 0; i < node_def->input_size(); ++i) {
      TensorId id(ParseTensorName(node_def->input(i)));
      Node* src_node;
      int src_index;

      if (!input_already_exists[i]) {
        // Locate input in newly-imported nodes
        auto iter = gdef_nodes_.find(id.first);
        DCHECK(iter != gdef_nodes_.end()) << id.first;
        src_node = iter->second.node;
        src_index = id.second;
        if (src_node == nullptr) has_data_back_edge = true;
      } else {
        // Input refers to a preexisting node in the graph
        auto iter = existing_nodes_.find(id.first);
        DCHECK(iter != existing_nodes_.end()) << id.first;
        src_node = iter->second;
        src_index = id.second;
      }

      if (src_node != nullptr && src_index >= src_node->num_outputs()) {
        return errors::InvalidArgument(
            "Node '", node_def->name(), "': Connecting to invalid output ",
            id.second, " of source node ", id.first, " which has ",
            src_node->num_outputs(), " outputs");
      }

      inputs.push_back(InputInfo(id.first.ToString(), src_node, src_index));
    }

    if (has_data_back_edge && !IsMerge(*node_def)) {
      return errors::InvalidArgument(
          "Node '", node_def->name(),
          "' had a back edge, but only Merge nodes can have back edges.");
    }

    Node* node;
    if (opts_.importing) {
      if (!prefix_.empty()) {
        AddPrefixToNodeDef(input_already_exists, &imported_node_def);
      }
      // Note: no need to uniquify names if the prefix already guarantees
      // uniqueness
      if (opts_.uniquify_names && (prefix_.empty() || !opts_.uniquify_prefix)) {
        UniquifyNames(input_already_exists, &imported_node_def);
      }
      TF_RETURN_IF_ERROR(ModifyNodeDefForImport(&imported_node_def));
    }
    TF_RETURN_IF_ERROR(MakeNode(*node_def, &node));
    // Use original_node_def so name StringPiece remains valid
    gdef_nodes_[original_node_def.name()].node = node;

    // Add edges from inputs to *node to the graph.
    for (size_t i = 0; i < inputs.size(); ++i) {
      if (inputs[i].node == nullptr) {
        // Record this back edge, which will be added after all nodes
        // are created.
        back_edges_.push_back(
            EdgeInfo(inputs[i].name, inputs[i].index, node, i));
      } else if (inputs[i].index == Graph::kControlSlot) {
        g_->AddControlEdge(inputs[i].node, node);
      } else {
        TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
      }
    }

    // Function shape inference is supported on an opt-in basis per
    // ShapeRefiner.
    if (refiner_->function_shape_inference_supported() ||
        g_->flib_def().Find(node_def->name()) == nullptr) {
      TF_RETURN_IF_ERROR(ValidateShape(node));
    }

    // Update pending_count_ for outputs.
    UpdatePendingCountAndReady(outputs_, o, &pending_count_, &ready_);
  }

  if (processed < node_defs_.size()) {
    return errors::InvalidArgument(node_defs_.size() - processed,
                                   " nodes in a cycle");
  }

  return Status::OK();
}

Status GraphConstructor::AddBackEdges() {
  // Add the back edges after all nodes are created.
  for (auto e : back_edges_) {
    Node* src_node = gdef_nodes_[e.src_name].node;
    if (e.src_index == Graph::kControlSlot) {
      g_->AddControlEdge(src_node, e.dst_node);
    } else {
      TF_RETURN_IF_ERROR(
          MakeEdge(src_node, e.src_index, e.dst_node, e.dst_index));
    }

    VLOG(2) << "Add back edge: " << src_node->name() << " -> "
            << e.dst_node->name();
  }
  return Status::OK();
}

Status GraphConstructor::UpdateVersionDef() {
  if (versions_ == nullptr) return Status::OK();

  if (!opts_.importing) {
    g_->set_versions(*versions_);
    return Status::OK();
  }
  VersionDef versions = g_->versions();
  versions.set_producer(std::min(versions.producer(), versions_->producer()));
  versions.set_min_consumer(
      std::max(versions.min_consumer(), versions_->min_consumer()));
  if (versions_->bad_consumers_size() > 0) {
    std::set<int> bad(versions.bad_consumers().begin(),
                      versions.bad_consumers().end());
    bad.insert(versions_->bad_consumers().begin(),
               versions_->bad_consumers().end());
    versions.clear_bad_consumers();
    for (int v : bad) {
      versions.add_bad_consumers(v);
    }
  }
  g_->set_versions(versions);
  return Status::OK();
}

Status GraphConstructor::PopulateReturnTensors() {
  if (opts_.return_tensors.empty()) return Status::OK();
  for (const TensorId& id : opts_.return_tensors) {
    auto iter = opts_.input_map.find(id);
    if (iter == opts_.input_map.end()) {
      // Locate id in imported nodes
      auto iter = gdef_nodes_.find(id.first);
      if (iter == gdef_nodes_.end()) {
        return errors::InvalidArgument("Requested return tensor '",
                                       id.ToString(),
                                       "' not found in graph def");
      }
      int num_outputs = iter->second.node->num_outputs();
      if ((id.second < 0 || id.second >= num_outputs) &&
          id.second != Graph::kControlSlot) {
        return errors::InvalidArgument("Invalid return output ", id.second,
                                       " of node '", id.first, "', which has ",
                                       num_outputs, " output(s)");
      }
      return_tensors_->push_back({iter->second.node, id.second});
    } else {
      // id was remapped to existing node
      TensorId remapped_id = iter->second;
      DCHECK_GT(existing_nodes_.count(remapped_id.first), 0);
      Node* node = existing_nodes_[remapped_id.first];
      return_tensors_->push_back({node, remapped_id.second});
    }
  }
  return Status::OK();
}

Status GraphConstructor::PopulateReturnNodes() {
  if (opts_.return_nodes.empty()) return Status::OK();
  for (StringPiece name : opts_.return_nodes) {
    auto iter = gdef_nodes_.find(name);
    if (iter == gdef_nodes_.end()) {
      return errors::InvalidArgument("Requested return node '", name,
                                     "' not found in graph def");
    }
    return_nodes_->push_back(iter->second.node);
  }
  return Status::OK();
}

Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
  if (missing_unused_input_map_keys_ == nullptr) return Status::OK();
  for (const auto& input_map_pair : opts_.input_map) {
    TensorId key = input_map_pair.first;
    if (used_input_map_keys_.count(key) > 0) continue;

    auto pair = gdef_nodes_.find(key.first);
    if (pair == gdef_nodes_.end()) {
      // key's node doesn't exist in GraphDef
      missing_unused_input_map_keys_->push_back(key);
      continue;
    }

    // Check that key's index is in bounds. Get the number of outputs from the
    // NodeDef, rather than the imported Node, since the Node may not exist if
    // opts_.skip_mapped_nodes is true.
    const NodeDef* node_def = node_defs_[pair->second.gdef_index];
    const OpDef* op_def;
    TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
    if (key.second >= op_def->output_arg_size()) {
      // key's index out of bounds
      missing_unused_input_map_keys_->push_back(key);
    }
  }
  return Status::OK();
}

void GraphConstructor::Undo() {
  for (const auto& iter : gdef_nodes_) {
    if (iter.second.node != nullptr) {
      g_->RemoveNode(iter.second.node);
    }
  }
  g_->set_versions(original_versions_);
}

Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst,
                                  int input_index) {
  DataType src_out = src->output_type(output_index);
  DataType dst_in = dst->input_type(input_index);
  if (!TypesCompatible(dst_in, src_out)) {
    return errors::InvalidArgument(
        "Input ", input_index, " of node ", dst->name(), " was passed ",
        DataTypeString(src_out), " from ", src->name(), ":", output_index,
        " incompatible with expected ", DataTypeString(dst_in), ".");
  }
  g_->AddEdge(src, output_index, dst, input_index);
  return Status::OK();
}

}  // namespace

Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                              const GraphDef& gdef, Graph* g) {
  ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
  return GraphConstructor::Construct(
      opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
      /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
      /*missing_unused_input_map_keys=*/nullptr);
}
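// A minimal usage sketch of ConvertGraphDefToGraph (illustrative only;
// assumes `graph_def` was parsed elsewhere and that all referenced ops are
// registered):
//
//   Graph graph(OpRegistry::Global());
//   GraphConstructorOptions opts;
//   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, &graph));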

Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts,
                              gtl::ArraySlice<NodeDef> nodes, Graph* g) {
  ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry());
  // TODO(irving): Copy will go away once NodeInfo exists
  std::vector<const NodeDef*> node_defs;
  for (const auto& n : nodes) {
    node_defs.push_back(&n);
  }
  return GraphConstructor::Construct(opts, node_defs, nullptr, nullptr, g,
                                     &refiner, /*return_tensors=*/nullptr,
                                     /*return_nodes=*/nullptr,
                                     /*missing_unused_input_map_keys=*/nullptr);
}

Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
                      Graph* g, ShapeRefiner* refiner,
                      ImportGraphDefResults* results) {
  if (!opts.return_tensors.empty()) {
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_tensors is non-empty");
    }
  }

  if (!opts.return_nodes.empty()) {
    if (opts.skip_mapped_nodes) {
      return errors::InvalidArgument(
          "Requesting return_nodes with skip_mapped_nodes set is not currently "
          "supported");
    }
    if (results == nullptr) {
      return errors::InvalidArgument(
          "results argument to ImportGraphDef() must be non-null if "
          "opts.return_nodes is non-empty");
    }
  }

  if (results != nullptr) {
    if (!results->return_tensors.empty() || !results->return_nodes.empty() ||
        !results->missing_unused_input_map_keys.empty()) {
      return errors::InvalidArgument(
          "All fields in results argument to ImportGraphDef() must be empty.");
    }
  }

  ShapeRefiner default_refiner(gdef.versions().producer(), g->op_registry());
  if (refiner == nullptr) {
    refiner = &default_refiner;
  } else {
    // Log a warning if we are importing a GraphDef at an older
    // producer version after already having added non-source/sink
    // nodes to the graph in the past.
    if (gdef.versions().producer() > 0 &&
        gdef.versions().producer() < refiner->graph_def_version() &&
        g->num_nodes() > 2) {
      LOG(WARNING) << "Importing a graph with a lower producer version "
                   << gdef.versions().producer()
                   << " into an existing graph with producer version "
                   << refiner->graph_def_version() << ". Shape inference will "
                   << "have run different parts of the graph with different "
                   << "producer versions.";
    }
  }

  // Set the graph def version of the refiner as the min of the
  // current value and the version from the graph we are about to
  // import.
  //
  // Note: to match Run() semantics, we should re-run shape inference
  // on the entire graph if the producer version has changed.  For now
  // we log the warning above.
  refiner->set_graph_def_version(
      std::min(refiner->graph_def_version(), gdef.versions().producer()));

  if (results == nullptr) {
    return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
                                       &gdef.library(), g, refiner, nullptr,
                                       nullptr, nullptr);
  } else {
    return GraphConstructor::Construct(
        opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
        &results->return_tensors, &results->return_nodes,
        &results->missing_unused_input_map_keys);
  }
}
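// A minimal usage sketch of ImportGraphDef (illustrative only; `graph_def`
// and `graph` are hypothetical, pre-existing objects):
//
//   ImportGraphDefOptions opts;
//   opts.prefix = "imported";
//   ImportGraphDefResults results;
//   TF_RETURN_IF_ERROR(ImportGraphDef(opts, graph_def, &graph,
//                                     /*refiner=*/nullptr, &results));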

void CopyGraph(const Graph& src, Graph* dest) {
  for (Node* n : dest->nodes()) {
    CHECK(n->IsSource() || n->IsSink()) << "*dest must be empty";
  }

  // Copy GraphDef versions
  dest->set_versions(src.versions());

  // Copy the nodes
  std::unordered_map<Node*, Node*>
      node_map;  // "Node in src" -> "Node in *dest"
  node_map[src.source_node()] = dest->source_node();
  node_map[src.sink_node()] = dest->sink_node();
  for (Node* n : src.op_nodes()) {
    node_map[n] = dest->CopyNode(n);
  }

  // Copy the edges
  for (const Edge* e : src.edges()) {
    Node* src_copy = node_map[e->src()];
    Node* dst_copy = node_map[e->dst()];
    dest->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }
}

}  // namespace tensorflow