Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/graph/algorithm.h"
     17 
     18 #include <algorithm>
     19 #include <deque>
     20 #include <vector>
     21 
     22 #include "tensorflow/core/platform/logging.h"
     23 
     24 namespace tensorflow {
     25 
     26 void DFS(const Graph& g, const std::function<void(Node*)>& enter,
     27          const std::function<void(Node*)>& leave,
     28          const NodeComparator& stable_comparator) {
     29   // Stack of work to do.
     30   struct Work {
     31     Node* node;
     32     bool leave;  // Are we entering or leaving n?
     33   };
     34   std::vector<Work> stack;
     35   stack.push_back(Work{g.source_node(), false});
     36 
     37   std::vector<bool> visited(g.num_node_ids(), false);
     38   while (!stack.empty()) {
     39     Work w = stack.back();
     40     stack.pop_back();
     41 
     42     Node* n = w.node;
     43     if (w.leave) {
     44       leave(n);
     45       continue;
     46     }
     47 
     48     if (visited[n->id()]) continue;
     49     visited[n->id()] = true;
     50     if (enter) enter(n);
     51 
     52     // Arrange to call leave(n) when all done with descendants.
     53     if (leave) stack.push_back(Work{n, true});
     54 
     55     gtl::iterator_range<NeighborIter> nodes = n->out_nodes();
     56     auto add_work = [&visited, &stack](Node* out) {
     57       if (!visited[out->id()]) {
     58         // Note; we must not mark as visited until we actually process it.
     59         stack.push_back(Work{out, false});
     60       }
     61     };
     62 
     63     if (stable_comparator) {
     64       std::vector<Node*> nodes_sorted;
     65       for (Node* out : nodes) {
     66         nodes_sorted.emplace_back(out);
     67       }
     68       std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
     69       for (Node* out : nodes_sorted) {
     70         add_work(out);
     71       }
     72     } else {
     73       for (Node* out : nodes) {
     74         add_work(out);
     75       }
     76     }
     77   }
     78 }
     79 
     80 void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
     81                 const std::function<void(Node*)>& leave,
     82                 const NodeComparator& stable_comparator) {
     83   ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator);
     84 }
     85 
     86 namespace {
     87 
     88 template <typename T>
     89 void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
     90                           const std::function<void(T)>& enter,
     91                           const std::function<void(T)>& leave,
     92                           const NodeComparator& stable_comparator) {
     93   // Stack of work to do.
     94   struct Work {
     95     T node;
     96     bool leave;  // Are we entering or leaving n?
     97   };
     98   std::vector<Work> stack(start.size());
     99   for (int i = 0; i < start.size(); ++i) {
    100     stack[i] = Work{start[i], false};
    101   }
    102 
    103   std::vector<bool> visited(g.num_node_ids(), false);
    104   while (!stack.empty()) {
    105     Work w = stack.back();
    106     stack.pop_back();
    107 
    108     T n = w.node;
    109     if (w.leave) {
    110       leave(n);
    111       continue;
    112     }
    113 
    114     if (visited[n->id()]) continue;
    115     visited[n->id()] = true;
    116     if (enter) enter(n);
    117 
    118     // Arrange to call leave(n) when all done with descendants.
    119     if (leave) stack.push_back(Work{n, true});
    120 
    121     gtl::iterator_range<NeighborIter> nodes = n->in_nodes();
    122 
    123     auto add_work = [&visited, &stack](T out) {
    124       if (!visited[out->id()]) {
    125         // Note; we must not mark as visited until we actually process it.
    126         stack.push_back(Work{out, false});
    127       }
    128     };
    129 
    130     if (stable_comparator) {
    131       std::vector<T> nodes_sorted;
    132       for (T in : nodes) {
    133         nodes_sorted.emplace_back(in);
    134       }
    135       std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
    136       for (T in : nodes_sorted) {
    137         add_work(in);
    138       }
    139     } else {
    140       for (T in : nodes) {
    141         add_work(in);
    142       }
    143     }
    144   }
    145 }
    146 
    147 }  // namespace
    148 
    149 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
    150                     const std::function<void(const Node*)>& enter,
    151                     const std::function<void(const Node*)>& leave,
    152                     const NodeComparator& stable_comparator) {
    153   ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
    154 }
    155 
    156 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
    157                     const std::function<void(Node*)>& enter,
    158                     const std::function<void(Node*)>& leave,
    159                     const NodeComparator& stable_comparator) {
    160   ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
    161 }
    162 
    163 void GetPostOrder(const Graph& g, std::vector<Node*>* order,
    164                   const NodeComparator& stable_comparator) {
    165   order->clear();
    166   DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator);
    167 }
    168 
    169 void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
    170                          const NodeComparator& stable_comparator) {
    171   GetPostOrder(g, order, stable_comparator);
    172   std::reverse(order->begin(), order->end());
    173 }
    174 
    175 bool PruneForReverseReachability(Graph* g,
    176                                  std::unordered_set<const Node*> visited) {
    177   // Compute set of nodes that we need to traverse in order to reach
    178   // the nodes in "nodes" by performing a breadth-first search from those
    179   // nodes, and accumulating the visited nodes.
    180   std::deque<const Node*> queue;
    181   for (const Node* n : visited) {
    182     VLOG(2) << "Reverse reach init: " << n->name();
    183     queue.push_back(n);
    184   }
    185   while (!queue.empty()) {
    186     const Node* n = queue.front();
    187     queue.pop_front();
    188     for (const Node* in : n->in_nodes()) {
    189       if (visited.insert(in).second) {
    190         queue.push_back(in);
    191         VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name();
    192       }
    193     }
    194   }
    195 
    196   // Make a pass over the graph to remove nodes not in "visited"
    197   std::vector<Node*> all_nodes;
    198   all_nodes.reserve(g->num_nodes());
    199   for (Node* n : g->nodes()) {
    200     all_nodes.push_back(n);
    201   }
    202 
    203   bool any_removed = false;
    204   for (Node* n : all_nodes) {
    205     if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {
    206       g->RemoveNode(n);
    207       any_removed = true;
    208     }
    209   }
    210 
    211   return any_removed;
    212 }
    213 
    214 bool FixupSourceAndSinkEdges(Graph* g) {
    215   // Connect all nodes with no incoming edges to source.
    216   // Connect all nodes with no outgoing edges to sink.
    217   bool changed = false;
    218   for (Node* n : g->nodes()) {
    219     if (!n->IsSource() && n->in_edges().empty()) {
    220       g->AddControlEdge(g->source_node(), n);
    221       changed = true;
    222     }
    223     if (!n->IsSink() && n->out_edges().empty()) {
    224       g->AddControlEdge(n, g->sink_node());
    225       changed = true;
    226     }
    227   }
    228   return changed;
    229 }
    230 
    231 }  // namespace tensorflow
    232