/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph.h"

#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/while_context.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

const int Graph::kControlSlot = -1;

struct NodeProperties {
  NodeProperties(const OpDef* op_def, const NodeDef& node_def,
                 const DataTypeSlice inputs, const DataTypeSlice outputs)
      : op_def(op_def),
        node_def(node_def),
        input_types(inputs.begin(), inputs.end()),
        output_types(outputs.begin(), outputs.end()) {}

  const OpDef* op_def;  // not owned
  NodeDef node_def;
  const DataTypeVector input_types;
  const DataTypeVector output_types;
};

// Node

#define REF_CLASS(key, value) \
  {key, value}, { "Ref" key, value }
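// For example, REF_CLASS("Switch", NC_SWITCH) expands to the two entries
//   {"Switch", NC_SWITCH}, {"RefSwitch", NC_SWITCH}
// so an op and its "Ref"-prefixed variant map to the same NodeClass.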

const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
    *new std::unordered_map<string, Node::NodeClass>({
        // Keep in same order as NodeClass values
        REF_CLASS("Switch", NC_SWITCH),
        REF_CLASS("Merge", NC_MERGE),
        REF_CLASS("Enter", NC_ENTER),
        REF_CLASS("Exit", NC_EXIT),
        REF_CLASS("NextIteration", NC_NEXT_ITERATION),
        {"LoopCond", NC_LOOP_COND},
        {"ControlTrigger", NC_CONTROL_TRIGGER},
        {"_Send", NC_SEND},
        {"_HostSend", NC_HOST_SEND},
        {"_Recv", NC_RECV},
        {"_HostRecv", NC_HOST_RECV},
        {"Const", NC_CONSTANT},
        {"HostConst", NC_CONSTANT},
        {"Variable", NC_VARIABLE},
        {"VariableV2", NC_VARIABLE},
        REF_CLASS("Identity", NC_IDENTITY),
        {"GetSessionHandle", NC_GET_SESSION_HANDLE},
        {"GetSessionHandleV2", NC_GET_SESSION_HANDLE},
        {"GetSessionTensor", NC_GET_SESSION_TENSOR},
        {"DeleteSessionTensor", NC_DELETE_SESSION_TENSOR},
        {"Size", NC_METADATA},
        {"Shape", NC_METADATA},
        {"Rank", NC_METADATA},
        {"_ScopedAllocator", NC_SCOPED_ALLOCATOR},
        {"CollectiveReduce", NC_COLLECTIVE},
        {"CollectiveBcastSend", NC_COLLECTIVE},
        {"CollectiveBcastRecv", NC_COLLECTIVE},
        {"FakeParam", NC_FAKE_PARAM},
        {"PartitionedCall", NC_PARTITIONED_CALL},
        {"StatefulPartitionedCall", NC_PARTITIONED_CALL},
        // Not using the constants defined in FunctionLibraryDefinition for
        // the four ops below because the Android inference library does not
        // link tf.function-related files.
        {"_Arg", NC_ARG},
        {"_DeviceArg", NC_ARG},
        {"_Retval", NC_RETVAL},
        {"_DeviceRetval", NC_RETVAL},
    });

#undef REF_CLASS

Node::NodeClass Node::GetNodeClassForOp(const string& ts) {
  auto it = kNodeClassTable.find(ts);
  if (it != kNodeClassTable.end()) {
    return it->second;
  } else {
    return NC_OTHER;
  }
}

string Node::DebugString() const {
  string ret = strings::StrCat("{name:'", name(), "' id:", id_);
  if (IsSource()) {
    strings::StrAppend(&ret, " source}");
  } else if (IsSink()) {
    strings::StrAppend(&ret, " sink}");
  } else {
    strings::StrAppend(&ret, " op device:");
    strings::StrAppend(&ret, "{", assigned_device_name(), "}");
    strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
  }
  return ret;
}

Node::Node()
    : id_(-1),
      cost_id_(-1),
      class_(NC_UNINITIALIZED),
      props_(nullptr),
      assigned_device_name_index_(0),
      while_ctx_(nullptr) {}

void Node::Initialize(int id, int cost_id,
                      std::shared_ptr<NodeProperties> props) {
  DCHECK_EQ(id_, -1);
  DCHECK(in_edges_.empty());
  DCHECK(out_edges_.empty());
  id_ = id;
  cost_id_ = cost_id;

  props_ = std::move(props);
  // Initialize class_ based on the op type string.
  class_ = GetNodeClassForOp(props_->node_def.op());
}

void Node::Clear() {
  in_edges_.clear();
  out_edges_.clear();
  id_ = -1;
  cost_id_ = -1;
  class_ = NC_UNINITIALIZED;
  props_.reset();
  assigned_device_name_index_ = 0;
}

void Node::UpdateProperties() {
  DataTypeVector inputs;
  DataTypeVector outputs;
  Status status =
      InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to update node properties: " << status;
    return;
  }
  props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
                                            inputs, outputs);
}

const string& Node::name() const { return props_->node_def.name(); }
const string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
const OpDef& Node::op_def() const { return *props_->op_def; }

int32 Node::num_inputs() const { return props_->input_types.size(); }
DataType Node::input_type(int32 i) const { return props_->input_types[i]; }
const DataTypeVector& Node::input_types() const { return props_->input_types; }

int32 Node::num_outputs() const { return props_->output_types.size(); }
DataType Node::output_type(int32 o) const { return props_->output_types[o]; }
const DataTypeVector& Node::output_types() const {
  return props_->output_types;
}

AttrSlice Node::attrs() const { return AttrSlice(def()); }

const protobuf::RepeatedPtrField<string>& Node::requested_inputs() const {
  return def().input();
}

const string& Node::requested_device() const { return def().device(); }

gtl::iterator_range<NeighborIter> Node::out_nodes() const {
  return gtl::make_range(NeighborIter(out_edges_.begin(), false),
                         NeighborIter(out_edges_.end(), false));
}

gtl::iterator_range<NeighborIter> Node::in_nodes() const {
  return gtl::make_range(NeighborIter(in_edges_.begin(), true),
                         NeighborIter(in_edges_.end(), true));
}

void Node::MaybeCopyOnWrite() {
  // NodeProperties may be shared between Nodes. Make a copy if so.
  if (!props_.unique()) {
    props_ = std::make_shared<NodeProperties>(*props_);
  }
}
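
// Illustrative sketch of the copy-on-write behavior above; `g` and `a` are
// assumed (hypothetical) to be a valid Graph and a Node it owns:
//   Node* b = g.CopyNode(a);  // `a` and `b` now share one NodeProperties.
//   b->set_name("b");         // MaybeCopyOnWrite() clones the shared
//                             // properties first, so `a` is unchanged.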

AttrValue* Node::AddAttrHelper(const string& name) {
  MaybeCopyOnWrite();
  return &((*props_->node_def.mutable_attr())[name]);
}

void Node::ClearAttr(const string& name) {
  MaybeCopyOnWrite();
  (*props_->node_def.mutable_attr()).erase(name);
}

void Node::set_name(string name) {
  MaybeCopyOnWrite();
  props_->node_def.set_name(std::move(name));
}

void Node::set_requested_device(const string& device) {
  MaybeCopyOnWrite();
  props_->node_def.set_device(device);
}

void Node::set_original_node_names(const std::vector<string>& names) {
  MaybeCopyOnWrite();
  props_->node_def.mutable_experimental_debug_info()
      ->clear_original_node_names();
  if (!names.empty()) {
    *props_->node_def.mutable_experimental_debug_info()
         ->mutable_original_node_names() = {names.begin(), names.end()};
  }
}

Status Node::input_edge(int idx, const Edge** e) const {
  if (idx < 0 || idx >= num_inputs()) {
    return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ",
                                   name(), " only has ", num_inputs(),
                                   " inputs.");
  }

  // This does a linear search over the edges.  In the common case,
  // the number of elements is small enough that this search isn't
  // expensive.  Should it become a bottleneck, one can make an
  // optimization where, if the number of edges is small, we use
  // linear iteration, and if the number of edges is large, we perform
  // an indexing step during construction that keeps an array of Edges
  // indexed by pointer.  This would keep the size of each Node small
  // in the common case but make this function faster when the number
  // of edges is large.
  for (const Edge* edge : in_edges()) {
    if (edge->dst_input() == idx) {
      *e = edge;
      return Status::OK();
    }
  }

  return errors::NotFound("Could not find input edge ", idx, " for ", name());
}

// Returns a vector of the non-control input edges to a node, indexed by input
// number.
Status Node::input_edges(std::vector<const Edge*>* input_edges) const {
  input_edges->clear();
  input_edges->resize(num_inputs(), nullptr);

  for (const Edge* edge : in_edges()) {
    if (edge->IsControlEdge()) continue;
    if (edge->dst_input() < 0 || edge->dst_input() >= num_inputs()) {
      return errors::Internal("Invalid edge input number ", edge->dst_input());
    }
    if ((*input_edges)[edge->dst_input()] != nullptr) {
      return errors::Internal("Duplicate edge input number: ",
                              edge->dst_input());
    }
    (*input_edges)[edge->dst_input()] = edge;
  }

  for (int i = 0; i < num_inputs(); ++i) {
    if ((*input_edges)[i] == nullptr) {
      return errors::InvalidArgument("Missing edge input number: ", i);
    }
  }
  return Status::OK();
}

Status Node::input_node(int idx, Node** n) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  if (e == nullptr) {
    *n = nullptr;
  } else {
    *n = e->src();
  }
  return Status::OK();
}

Status Node::input_node(int idx, const Node** const_n) const {
  Node* n;
  TF_RETURN_IF_ERROR(input_node(idx, &n));
  *const_n = n;
  return Status::OK();
}

Status Node::input_tensor(int idx, OutputTensor* t) const {
  const Edge* e;
  TF_RETURN_IF_ERROR(input_edge(idx, &e));
  DCHECK(e != nullptr);
  *t = OutputTensor(e->src(), e->src_output());
  return Status::OK();
}
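
// Illustrative use of the accessors above, inside a function that returns
// Status (`n` is an assumed Node* with at least one data input):
//   const Edge* e;
//   TF_RETURN_IF_ERROR(n->input_edge(0, &e));  // edge feeding input 0
//   Node* producer;
//   TF_RETURN_IF_ERROR(n->input_node(0, &producer));  // its source node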

// NodeDebugInfo

NodeDebugInfo::NodeDebugInfo(const Node& n) : NodeDebugInfo(n.def()) {}
NodeDebugInfo::NodeDebugInfo(const NodeDef& ndef) : name(ndef.name()) {
  if (ndef.has_experimental_debug_info()) {
    const auto& names = ndef.experimental_debug_info().original_node_names();
    original_node_names.assign(names.begin(), names.end());
  }
}

// InputTensor

bool InputTensor::operator==(const InputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 InputTensor::Hash::operator()(InputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}

// OutputTensor

bool OutputTensor::operator==(const OutputTensor& other) const {
  return node == other.node && index == other.index;
}

uint64 OutputTensor::Hash::operator()(OutputTensor const& s) const {
  return Hash64Combine(std::hash<const Node*>()(s.node),
                       std::hash<int>()(s.index));
}
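
// These Hash functors let tensor endpoints serve as keys in hashed
// containers, e.g. (illustrative):
//   std::unordered_set<OutputTensor, OutputTensor::Hash> visited;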

// Graph

Graph::Graph(const OpRegistryInterface* ops)
    : ops_(ops, FunctionDefLibrary()),
      versions_(new VersionDef),
      arena_(8 << 10 /* 8 KiB */) {
  versions_->set_producer(TF_GRAPH_DEF_VERSION);
  versions_->set_min_consumer(TF_GRAPH_DEF_VERSION_MIN_CONSUMER);

  // Initialize the name interning table for assigned_device_name.
  device_names_.push_back("");
  DCHECK_EQ(0, InternDeviceName(""));

  // Source and sink have no endpoints, just control edges.
  NodeDef def;
  def.set_name("_SOURCE");
  def.set_op("NoOp");
  Status status;
  Node* source = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(source->id(), kSourceId);

  def.set_name("_SINK");
  Node* sink = AddNode(def, &status);
  TF_CHECK_OK(status);
  CHECK_EQ(sink->id(), kSinkId);

  AddControlEdge(source, sink);
}

Graph::Graph(const FunctionLibraryDefinition& flib_def)
    : Graph(flib_def.default_registry()) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (flib_def.ToProto().function_size() > 0 &&
      versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  Status s = ops_.AddLibrary(flib_def);
  CHECK(s.ok()) << s.error_message();
}

Graph::~Graph() {
  // Manually call the destructors for all the Nodes we constructed using
  // placement new.
  for (Node* node : nodes_) {
    if (node != nullptr) {
      node->~Node();
    }
  }
  for (Node* node : free_nodes_) {
    node->~Node();
  }
  // Edges are allocated from the arena and have no destructor, so there is
  // no need to destroy them here.
}

const VersionDef& Graph::versions() const { return *versions_; }
void Graph::set_versions(const VersionDef& versions) { *versions_ = versions; }

Node* Graph::AddNode(const NodeDef& node_def, Status* status) {
  const OpDef* op_def;
  status->Update(ops_.LookUpOpDef(node_def.op(), &op_def));
  if (!status->ok()) return nullptr;

  DataTypeVector inputs;
  DataTypeVector outputs;
  status->Update(InOutTypesForNode(node_def, *op_def, &inputs, &outputs));
  if (!status->ok()) {
    *status = AttachDef(*status, node_def);
    return nullptr;
  }

  Node* node = AllocateNode(
      std::make_shared<NodeProperties>(op_def, node_def, inputs, outputs),
      nullptr);
  return node;
}
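
// Illustrative sketch of AddNode(), mirroring the constructor above
// (assumes `g` is a Graph whose registry includes "NoOp"):
//   NodeDef def;
//   def.set_name("my_noop");
//   def.set_op("NoOp");
//   Status s;
//   Node* n = g.AddNode(def, &s);  // n == nullptr iff !s.ok()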

Node* Graph::CopyNode(const Node* node) {
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());
  Node* copy = AllocateNode(node->props_, node);
  copy->set_assigned_device_name(node->assigned_device_name());

  // Since the OpDef of a function may be owned by the Graph that owns 'node',
  // look up the OpDef in the target graph again. If it differs, clone the
  // node properties with the updated OpDef.
  const OpDef* op_def;
  TF_CHECK_OK(ops_.LookUpOpDef(node->type_string(), &op_def));
  if (op_def != node->props_->op_def) {
    copy->MaybeCopyOnWrite();
    copy->props_->op_def = op_def;
  }

  return copy;
}

void Graph::RemoveNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  DCHECK(!node->IsSource());
  DCHECK(!node->IsSink());

  // Remove any edges involving this node.
  while (!node->in_edges_.empty()) {
    RemoveEdge(*node->in_edges_.begin());
  }
  while (!node->out_edges_.empty()) {
    RemoveEdge(*node->out_edges_.begin());
  }
  ReleaseNode(node);
}

const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
  TF_DCHECK_OK(IsValidNode(source)) << source->DebugString();
  TF_DCHECK_OK(IsValidNode(dest)) << dest->DebugString();

  // source/sink must only be linked via control slots, and
  // control slots must only be linked to control slots.
  if (source == source_node() || dest == sink_node() || x == kControlSlot ||
      y == kControlSlot) {
    DCHECK_EQ(x, kControlSlot) << source->DebugString();
    DCHECK_EQ(y, kControlSlot) << dest->DebugString();
  }

  Edge* e = nullptr;
  if (free_edges_.empty()) {
    e = new (arena_.Alloc(sizeof(Edge))) Edge;  // placement new
  } else {
    e = free_edges_.back();
    free_edges_.pop_back();
  }
  e->id_ = edges_.size();
  e->src_ = source;
  e->dst_ = dest;
  e->src_output_ = x;
  e->dst_input_ = y;
  CHECK(source->out_edges_.insert(e).second);
  CHECK(dest->in_edges_.insert(e).second);
  edges_.push_back(e);
  ++num_edges_;
  return e;
}
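
// Illustrative sketch (assumes `g`, `src`, and `dst` are valid, with
// output 0 of `src` and input 0 of `dst` unconnected):
//   const Edge* data = g.AddEdge(src, 0, dst, 0);
//   const Edge* ctrl =
//       g.AddEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);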

void Graph::RemoveEdge(const Edge* e) {
  TF_DCHECK_OK(IsValidNode(e->src_)) << e->src_->DebugString();
  TF_DCHECK_OK(IsValidNode(e->dst_)) << e->dst_->DebugString();
  CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
  CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
  CHECK_EQ(e, edges_[e->id_]);
  CHECK_GT(num_edges_, 0);

  edges_[e->id_] = nullptr;

  Edge* del = const_cast<Edge*>(e);
  del->src_ = nullptr;
  del->dst_ = nullptr;
  del->id_ = -1;
  del->src_output_ = kControlSlot - 1;
  del->dst_input_ = kControlSlot - 1;
  free_edges_.push_back(del);
  --num_edges_;
}

const Edge* Graph::AddControlEdge(Node* source, Node* dest,
                                  bool allow_duplicates) {
  if (!allow_duplicates) {
    for (const Edge* edge : dest->in_edges()) {
      if (edge->IsControlEdge() && edge->src() == source) {
        // The requested edge already exists.
        return nullptr;
      }
    }
  }
  // Modify dest's NodeDef if necessary.
  if (!source->IsSource() && !dest->IsSink() && !allow_duplicates) {
    // Check if this input is already in dest's NodeDef.
    const string new_input = strings::StrCat("^", source->name());
    bool input_exists = false;
    for (const string& input : dest->props_->node_def.input()) {
      if (input == new_input) {
        input_exists = true;
        break;
      }
    }
    if (!input_exists) {
      dest->MaybeCopyOnWrite();
      dest->props_->node_def.add_input(new_input);
    }
  }
  return AddEdge(source, kControlSlot, dest, kControlSlot);
}
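
// Illustrative sketch (assumes `g`, `a`, and `b` are valid non-source/sink
// nodes with no existing control edge between them):
//   g.AddControlEdge(a, b, /*allow_duplicates=*/false);  // adds "^a" input
//   g.AddControlEdge(a, b, /*allow_duplicates=*/false);  // nullptr: duplicate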

void Graph::RemoveControlEdge(const Edge* e) {
  if (!e->src_->IsSource() && !e->dst_->IsSink()) {
    e->dst_->MaybeCopyOnWrite();
    string e_src_name = strings::StrCat("^", e->src_->name());
    auto* inputs = e->dst_->props_->node_def.mutable_input();
    for (auto it = inputs->begin(); it != inputs->end(); ++it) {
      if (*it == e_src_name) {
        inputs->erase(it);
        break;
      }
    }
  }
  RemoveEdge(e);
}

namespace {
const Edge* FindEdge(const Node* dst, int index) {
  for (const Edge* e : dst->in_edges()) {
    if (e->dst_input() == index) return e;
  }
  return nullptr;
}
}  // namespace

Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
                         int dst_index) {
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  const Edge* e = FindEdge(dst, dst_index);
  if (e == nullptr) {
    return errors::InvalidArgument("Couldn't find edge to ",
                                   FormatNodeForError(*dst));
  }
  RemoveEdge(e);
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  (*dst->props_->node_def.mutable_input())[dst_index] =
      strings::StrCat(new_src->name(), ":", new_src_index);
  return Status::OK();
}

Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) {
  if (dst->type_string() != "While") {
    return errors::Internal(
        "dst argument to AddWhileInputHack should be a While op, got: ",
        dst->DebugString());
  }
  TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
  // Find the current number of data inputs. We'll add the new edge to the next
  // missing data input.
  int dst_index = 0;
  for (const Edge* edge : dst->in_edges()) {
    if (edge->IsControlEdge()) continue;
    ++dst_index;
  }
  TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index));
  AddEdge(new_src, new_src_index, dst, dst_index);
  dst->MaybeCopyOnWrite();
  dst->props_->node_def.add_input(
      strings::StrCat(new_src->name(), ":", new_src_index));
  return Status::OK();
}

Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
  // Need a new-enough consumer to support the functions we add to the graph.
  if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
    versions_->set_min_consumer(12);
  }
  return ops_.AddLibrary(fdef_lib);
}

namespace {

void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
  if (src_slot == Graph::kControlSlot) {
    dst->add_input(strings::StrCat("^", src_name));
  } else if (src_slot == 0) {
    dst->add_input(src_name.data(), src_name.size());
  } else {
    dst->add_input(strings::StrCat(src_name, ":", src_slot));
  }
}
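
// For reference, AddInput() produces the three NodeDef input string forms:
//   ("x", 0)                    -> "x"
//   ("x", 2)                    -> "x:2"
//   ("x", Graph::kControlSlot)  -> "^x"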

}  // namespace

void Graph::ToGraphDef(GraphDef* graph_def) const {
  ToGraphDefSubRange(graph_def, 0);
}

GraphDef Graph::ToGraphDefDebug() const {
  GraphDef ret;
  ToGraphDef(&ret);
  return ret;
}

void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
  graph_def->Clear();
  *graph_def->mutable_versions() = versions();
  *graph_def->mutable_library() = ops_.ToProto();

  graph_def->mutable_node()->Reserve(std::max(1, num_nodes() - from_node_id));

  std::vector<const Edge*>
      inputs;  // Construct this outside the loop for speed.
  for (auto id = from_node_id; id < num_node_ids(); ++id) {
    const Node* node = FindNodeId(id);
    if (node == nullptr || !node->IsOp()) continue;
    NodeDef* node_def = graph_def->add_node();
    *node_def = node->def();

    // Use the node's assigned device, if any, instead of the device requested
    // in the NodeDef.
    if (!node->assigned_device_name().empty()) {
      node_def->set_device(node->assigned_device_name());
    }

    // Get the inputs for this Node.  We make sure control inputs are
    // after data inputs, as required by GraphDef.
    inputs.clear();
    inputs.resize(node->num_inputs(), nullptr);
    for (const Edge* edge : node->in_edges()) {
      if (edge->IsControlEdge()) {
        inputs.push_back(edge);
      } else {
        CHECK(inputs[edge->dst_input()] == nullptr)
            << "Edge " << edge->src()->DebugString() << ":"
            << edge->dst()->DebugString() << " with dst_input "
            << edge->dst_input() << " already had an input edge "
            << inputs[edge->dst_input()]->src()->DebugString() << ":"
            << inputs[edge->dst_input()]->dst()->DebugString();

        inputs[edge->dst_input()] = edge;
      }
    }
    // Sort the control inputs for more predictable serialization.
    std::sort(inputs.begin() + node->num_inputs(), inputs.end(),
              [](const Edge* a, const Edge* b) -> bool {
                return a->src()->name() < b->src()->name();
              });
    node_def->clear_input();
    node_def->mutable_input()->Reserve(inputs.size());

    for (size_t i = 0; i < inputs.size(); ++i) {
      const Edge* edge = inputs[i];
      if (edge == nullptr) {
        if (i < node->requested_inputs().size()) {
          node_def->add_input(node->requested_inputs()[i]);
        } else {
          node_def->add_input("");
        }
      } else {
        const Node* src = edge->src();
        if (!src->IsOp()) continue;
        AddInput(node_def, src->name(), edge->src_output());
      }
    }
  }
}

string Graph::NewName(StringPiece prefix) {
  return strings::StrCat(prefix, "/_", name_counter_++);
}
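
// Illustrative: the first call to NewName("gradients") returns
// "gradients/_0", the next NewName() call uses "_1", and so on; all prefixes
// share name_counter_, so generated suffixes never repeat within this graph.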

Status Graph::IsValidNode(const Node* node) const {
  if (node == nullptr) {
    return errors::InvalidArgument("Node is null");
  }
  const int id = node->id();
  if (id < 0) {
    return errors::InvalidArgument("node id ", id, " is less than zero");
  }
  if (static_cast<size_t>(id) >= nodes_.size()) {
    return errors::InvalidArgument(
        "node id ", id, " is >= the number of nodes in the graph, ",
        nodes_.size());
  }
  if (nodes_[id] != node) {
    return errors::InvalidArgument("Node with id ", id,
                                   " is different from the passed in node. "
                                   "Does it belong to a different graph?");
  }
  return Status::OK();
}

Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_outputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of outputs: ", node->num_outputs(),
                              ") does not have output ", idx);
  }
  return Status::OK();
}

Status Graph::IsValidInputTensor(const Node* node, int idx) const {
  TF_RETURN_IF_ERROR(IsValidNode(node));
  if (idx >= node->num_inputs() || idx < 0) {
    return errors::OutOfRange("Node '", node->name(), "' (type: '",
                              node->op_def().name(),
                              "', num of inputs: ", node->num_inputs(),
                              ") does not have input ", idx);
  }
  return Status::OK();
}

Node* Graph::AllocateNode(std::shared_ptr<NodeProperties> props,
                          const Node* cost_node) {
  Node* node = nullptr;
  if (free_nodes_.empty()) {
    node = new (arena_.Alloc(sizeof(Node))) Node;  // placement new
  } else {
    node = free_nodes_.back();
    free_nodes_.pop_back();
  }
  node->graph_ = this;
  const int id = nodes_.size();
  int cost_id = cost_node ? cost_node->cost_id() : id;
  node->Initialize(id, cost_id, std::move(props));
  nodes_.push_back(node);
  ++num_nodes_;
  return node;
}

void Graph::ReleaseNode(Node* node) {
  TF_DCHECK_OK(IsValidNode(node)) << node->DebugString();
  nodes_[node->id()] = nullptr;
  free_nodes_.push_back(node);
  --num_nodes_;
  node->Clear();
}

// Ensures that 'device_name' is present in the device name table, and returns
// the index of that device name. The index is stable, and can be used in
// calls to Node::set_assigned_device_name_index().
int Graph::InternDeviceName(const string& device_name) {
  // Special case, very common.  Also, this allows us to use a single map
  // lookup below, instead of two.  The 'if (index_cell > 0)' test below
  // relies on this check.
  if (device_name.empty()) {
    return 0;
  }

  int& index_cell = device_names_map_[device_name];
  if (index_cell > 0) {
    return index_cell;
  }

  const int index = device_names_map_.size();
  index_cell = index;
  device_names_.push_back(device_name);
  return index;
}
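
// Illustrative: interning is idempotent and stable (`g` is an assumed Graph):
//   int i = g.InternDeviceName("/device:CPU:0");  // first use: new index > 0
//   int j = g.InternDeviceName("/device:CPU:0");  // same string: j == i
//   int k = g.InternDeviceName("");               // always 0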

Status Graph::AddWhileContext(StringPiece frame_name,
                              std::vector<Node*> enter_nodes,
                              std::vector<Node*> exit_nodes,
                              OutputTensor cond_output,
                              std::vector<OutputTensor> body_inputs,
                              std::vector<OutputTensor> body_outputs,
                              WhileContext** result) {
  auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
      string(frame_name),
      WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
                   cond_output, std::move(body_inputs),
                   std::move(body_outputs))));
  if (!pair.second) {
    *result = nullptr;
    return errors::InvalidArgument("WhileContext with frame name '", frame_name,
                                   "' already exists");
  }
  *result = &pair.first->second;
  return Status::OK();
}

std::unordered_map<string, Node*> Graph::BuildNodeNameIndex() const {
  std::unordered_map<string, Node*> result;
  for (Node* n : nodes()) {
    result[n->name()] = n;
  }
  return result;
}

string Edge::DebugString() const {
  return strings::Printf("[id=%d %s:%d -> %s:%d]", id_, src_->name().c_str(),
                         src_output_, dst_->name().c_str(), dst_input_);
}

}  // namespace tensorflow