Home | History | Annotate | Download | only in graph
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/graph/algorithm.h"
     17 
     18 #include <algorithm>
     19 #include <deque>
     20 #include <vector>
     21 
     22 #include "tensorflow/core/platform/logging.h"
     23 
     24 namespace tensorflow {
     25 namespace {
     26 template <typename T>
     27 void DFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
     28                    const std::function<void(T)>& enter,
     29                    const std::function<void(T)>& leave,
     30                    const NodeComparator& stable_comparator,
     31                    const EdgeFilter& edge_filter) {
     32   // Stack of work to do.
     33   struct Work {
     34     T node;
     35     bool leave;  // Are we entering or leaving n?
     36   };
     37   std::vector<Work> stack(start.size());
     38   for (int i = 0; i < start.size(); ++i) {
     39     stack[i] = Work{start[i], false};
     40   }
     41 
     42   std::vector<bool> visited(g.num_node_ids(), false);
     43   while (!stack.empty()) {
     44     Work w = stack.back();
     45     stack.pop_back();
     46 
     47     T n = w.node;
     48     if (w.leave) {
     49       leave(n);
     50       continue;
     51     }
     52 
     53     if (visited[n->id()]) continue;
     54     visited[n->id()] = true;
     55     if (enter) enter(n);
     56 
     57     // Arrange to call leave(n) when all done with descendants.
     58     if (leave) stack.push_back(Work{n, true});
     59 
     60     auto add_work = [&visited, &stack](Node* out) {
     61       if (!visited[out->id()]) {
     62         // Note; we must not mark as visited until we actually process it.
     63         stack.push_back(Work{out, false});
     64       }
     65     };
     66 
     67     if (stable_comparator) {
     68       std::vector<Node*> nodes_sorted;
     69       for (const Edge* out_edge : n->out_edges()) {
     70         if (!edge_filter || edge_filter(*out_edge)) {
     71           nodes_sorted.emplace_back(out_edge->dst());
     72         }
     73       }
     74       std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
     75       for (Node* out : nodes_sorted) {
     76         add_work(out);
     77       }
     78     } else {
     79       for (const Edge* out_edge : n->out_edges()) {
     80         if (!edge_filter || edge_filter(*out_edge)) {
     81           add_work(out_edge->dst());
     82         }
     83       }
     84     }
     85   }
     86 }
     87 }  // namespace
     88 
     89 void DFS(const Graph& g, const std::function<void(Node*)>& enter,
     90          const std::function<void(Node*)>& leave,
     91          const NodeComparator& stable_comparator,
     92          const EdgeFilter& edge_filter) {
     93   DFSFromHelper(g, {g.source_node()}, enter, leave, stable_comparator,
     94                 edge_filter);
     95 }
     96 
     97 void DFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
     98              const std::function<void(Node*)>& enter,
     99              const std::function<void(Node*)>& leave,
    100              const NodeComparator& stable_comparator,
    101              const EdgeFilter& edge_filter) {
    102   DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
    103 }
    104 
    105 void DFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
    106              const std::function<void(const Node*)>& enter,
    107              const std::function<void(const Node*)>& leave,
    108              const NodeComparator& stable_comparator,
    109              const EdgeFilter& edge_filter) {
    110   DFSFromHelper(g, start, enter, leave, stable_comparator, edge_filter);
    111 }
    112 
    113 void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter,
    114                 const std::function<void(Node*)>& leave,
    115                 const NodeComparator& stable_comparator) {
    116   ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator);
    117 }
    118 
    119 namespace {
    120 
    121 template <typename T>
    122 void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start,
    123                           const std::function<void(T)>& enter,
    124                           const std::function<void(T)>& leave,
    125                           const NodeComparator& stable_comparator) {
    126   // Stack of work to do.
    127   struct Work {
    128     T node;
    129     bool leave;  // Are we entering or leaving n?
    130   };
    131   std::vector<Work> stack(start.size());
    132   for (int i = 0; i < start.size(); ++i) {
    133     stack[i] = Work{start[i], false};
    134   }
    135 
    136   std::vector<bool> visited(g.num_node_ids(), false);
    137   while (!stack.empty()) {
    138     Work w = stack.back();
    139     stack.pop_back();
    140 
    141     T n = w.node;
    142     if (w.leave) {
    143       leave(n);
    144       continue;
    145     }
    146 
    147     if (visited[n->id()]) continue;
    148     visited[n->id()] = true;
    149     if (enter) enter(n);
    150 
    151     // Arrange to call leave(n) when all done with descendants.
    152     if (leave) stack.push_back(Work{n, true});
    153 
    154     auto add_work = [&visited, &stack](T out) {
    155       if (!visited[out->id()]) {
    156         // Note; we must not mark as visited until we actually process it.
    157         stack.push_back(Work{out, false});
    158       }
    159     };
    160 
    161     if (stable_comparator) {
    162       std::vector<T> nodes_sorted;
    163       for (const Edge* in_edge : n->in_edges()) {
    164         nodes_sorted.emplace_back(in_edge->src());
    165       }
    166       std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator);
    167       for (T in : nodes_sorted) {
    168         add_work(in);
    169       }
    170     } else {
    171       for (const Edge* in_edge : n->in_edges()) {
    172         add_work(in_edge->src());
    173       }
    174     }
    175   }
    176 }
    177 
    178 }  // namespace
    179 
    180 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start,
    181                     const std::function<void(const Node*)>& enter,
    182                     const std::function<void(const Node*)>& leave,
    183                     const NodeComparator& stable_comparator) {
    184   ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
    185 }
    186 
    187 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start,
    188                     const std::function<void(Node*)>& enter,
    189                     const std::function<void(Node*)>& leave,
    190                     const NodeComparator& stable_comparator) {
    191   ReverseDFSFromHelper(g, start, enter, leave, stable_comparator);
    192 }
    193 
    194 void GetPostOrder(const Graph& g, std::vector<Node*>* order,
    195                   const NodeComparator& stable_comparator,
    196                   const EdgeFilter& edge_filter) {
    197   order->clear();
    198   DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator,
    199       edge_filter);
    200 }
    201 
    202 void GetReversePostOrder(const Graph& g, std::vector<Node*>* order,
    203                          const NodeComparator& stable_comparator,
    204                          const EdgeFilter& edge_filter) {
    205   GetPostOrder(g, order, stable_comparator, edge_filter);
    206   std::reverse(order->begin(), order->end());
    207 }
    208 
    209 bool PruneForReverseReachability(Graph* g,
    210                                  std::unordered_set<const Node*> visited) {
    211   // Compute set of nodes that we need to traverse in order to reach
    212   // the nodes in "nodes" by performing a breadth-first search from those
    213   // nodes, and accumulating the visited nodes.
    214   std::deque<const Node*> queue;
    215   for (const Node* n : visited) {
    216     VLOG(2) << "Reverse reach init: " << n->name();
    217     queue.push_back(n);
    218   }
    219   while (!queue.empty()) {
    220     const Node* n = queue.front();
    221     queue.pop_front();
    222     for (const Node* in : n->in_nodes()) {
    223       if (visited.insert(in).second) {
    224         queue.push_back(in);
    225         VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name();
    226       }
    227     }
    228   }
    229 
    230   // Make a pass over the graph to remove nodes not in "visited"
    231   std::vector<Node*> all_nodes;
    232   all_nodes.reserve(g->num_nodes());
    233   for (Node* n : g->nodes()) {
    234     all_nodes.push_back(n);
    235   }
    236 
    237   bool any_removed = false;
    238   for (Node* n : all_nodes) {
    239     if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {
    240       g->RemoveNode(n);
    241       any_removed = true;
    242     }
    243   }
    244 
    245   return any_removed;
    246 }
    247 
    248 bool FixupSourceAndSinkEdges(Graph* g) {
    249   // Connect all nodes with no incoming edges to source.
    250   // Connect all nodes with no outgoing edges to sink.
    251   bool changed = false;
    252   for (Node* n : g->nodes()) {
    253     if (!n->IsSource() && n->in_edges().empty()) {
    254       g->AddControlEdge(g->source_node(), n,
    255                         true /* skip test for duplicates */);
    256       changed = true;
    257     }
    258     if (!n->IsSink() && n->out_edges().empty()) {
    259       g->AddControlEdge(n, g->sink_node(), true /* skip test for duplicates */);
    260       changed = true;
    261     }
    262   }
    263   return changed;
    264 }
    265 
    266 }  // namespace tensorflow
    267