1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/graph/algorithm.h" 17 18 #include <algorithm> 19 #include <deque> 20 #include <vector> 21 22 #include "tensorflow/core/platform/logging.h" 23 24 namespace tensorflow { 25 26 void DFS(const Graph& g, const std::function<void(Node*)>& enter, 27 const std::function<void(Node*)>& leave, 28 const NodeComparator& stable_comparator) { 29 // Stack of work to do. 30 struct Work { 31 Node* node; 32 bool leave; // Are we entering or leaving n? 33 }; 34 std::vector<Work> stack; 35 stack.push_back(Work{g.source_node(), false}); 36 37 std::vector<bool> visited(g.num_node_ids(), false); 38 while (!stack.empty()) { 39 Work w = stack.back(); 40 stack.pop_back(); 41 42 Node* n = w.node; 43 if (w.leave) { 44 leave(n); 45 continue; 46 } 47 48 if (visited[n->id()]) continue; 49 visited[n->id()] = true; 50 if (enter) enter(n); 51 52 // Arrange to call leave(n) when all done with descendants. 53 if (leave) stack.push_back(Work{n, true}); 54 55 gtl::iterator_range<NeighborIter> nodes = n->out_nodes(); 56 auto add_work = [&visited, &stack](Node* out) { 57 if (!visited[out->id()]) { 58 // Note; we must not mark as visited until we actually process it. 59 stack.push_back(Work{out, false}); 60 } 61 }; 62 63 if (stable_comparator) { 64 std::vector<Node*> nodes_sorted; 65 for (Node* out : nodes) { 66 nodes_sorted.emplace_back(out); 67 } 68 std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator); 69 for (Node* out : nodes_sorted) { 70 add_work(out); 71 } 72 } else { 73 for (Node* out : nodes) { 74 add_work(out); 75 } 76 } 77 } 78 } 79 80 void ReverseDFS(const Graph& g, const std::function<void(Node*)>& enter, 81 const std::function<void(Node*)>& leave, 82 const NodeComparator& stable_comparator) { 83 ReverseDFSFrom(g, {g.sink_node()}, enter, leave, stable_comparator); 84 } 85 86 namespace { 87 88 template <typename T> 89 void ReverseDFSFromHelper(const Graph& g, gtl::ArraySlice<T> start, 90 const std::function<void(T)>& enter, 91 const std::function<void(T)>& leave, 92 const NodeComparator& stable_comparator) { 93 // Stack of work to do. 94 struct Work { 95 T node; 96 bool leave; // Are we entering or leaving n? 97 }; 98 std::vector<Work> stack(start.size()); 99 for (int i = 0; i < start.size(); ++i) { 100 stack[i] = Work{start[i], false}; 101 } 102 103 std::vector<bool> visited(g.num_node_ids(), false); 104 while (!stack.empty()) { 105 Work w = stack.back(); 106 stack.pop_back(); 107 108 T n = w.node; 109 if (w.leave) { 110 leave(n); 111 continue; 112 } 113 114 if (visited[n->id()]) continue; 115 visited[n->id()] = true; 116 if (enter) enter(n); 117 118 // Arrange to call leave(n) when all done with descendants. 119 if (leave) stack.push_back(Work{n, true}); 120 121 gtl::iterator_range<NeighborIter> nodes = n->in_nodes(); 122 123 auto add_work = [&visited, &stack](T out) { 124 if (!visited[out->id()]) { 125 // Note; we must not mark as visited until we actually process it. 126 stack.push_back(Work{out, false}); 127 } 128 }; 129 130 if (stable_comparator) { 131 std::vector<T> nodes_sorted; 132 for (T in : nodes) { 133 nodes_sorted.emplace_back(in); 134 } 135 std::sort(nodes_sorted.begin(), nodes_sorted.end(), stable_comparator); 136 for (T in : nodes_sorted) { 137 add_work(in); 138 } 139 } else { 140 for (T in : nodes) { 141 add_work(in); 142 } 143 } 144 } 145 } 146 147 } // namespace 148 149 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<const Node*> start, 150 const std::function<void(const Node*)>& enter, 151 const std::function<void(const Node*)>& leave, 152 const NodeComparator& stable_comparator) { 153 ReverseDFSFromHelper(g, start, enter, leave, stable_comparator); 154 } 155 156 void ReverseDFSFrom(const Graph& g, gtl::ArraySlice<Node*> start, 157 const std::function<void(Node*)>& enter, 158 const std::function<void(Node*)>& leave, 159 const NodeComparator& stable_comparator) { 160 ReverseDFSFromHelper(g, start, enter, leave, stable_comparator); 161 } 162 163 void GetPostOrder(const Graph& g, std::vector<Node*>* order, 164 const NodeComparator& stable_comparator) { 165 order->clear(); 166 DFS(g, nullptr, [order](Node* n) { order->push_back(n); }, stable_comparator); 167 } 168 169 void GetReversePostOrder(const Graph& g, std::vector<Node*>* order, 170 const NodeComparator& stable_comparator) { 171 GetPostOrder(g, order, stable_comparator); 172 std::reverse(order->begin(), order->end()); 173 } 174 175 bool PruneForReverseReachability(Graph* g, 176 std::unordered_set<const Node*> visited) { 177 // Compute set of nodes that we need to traverse in order to reach 178 // the nodes in "nodes" by performing a breadth-first search from those 179 // nodes, and accumulating the visited nodes. 180 std::deque<const Node*> queue; 181 for (const Node* n : visited) { 182 VLOG(2) << "Reverse reach init: " << n->name(); 183 queue.push_back(n); 184 } 185 while (!queue.empty()) { 186 const Node* n = queue.front(); 187 queue.pop_front(); 188 for (const Node* in : n->in_nodes()) { 189 if (visited.insert(in).second) { 190 queue.push_back(in); 191 VLOG(2) << "Reverse reach : " << n->name() << " from " << in->name(); 192 } 193 } 194 } 195 196 // Make a pass over the graph to remove nodes not in "visited" 197 std::vector<Node*> all_nodes; 198 all_nodes.reserve(g->num_nodes()); 199 for (Node* n : g->nodes()) { 200 all_nodes.push_back(n); 201 } 202 203 bool any_removed = false; 204 for (Node* n : all_nodes) { 205 if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) { 206 g->RemoveNode(n); 207 any_removed = true; 208 } 209 } 210 211 return any_removed; 212 } 213 214 bool FixupSourceAndSinkEdges(Graph* g) { 215 // Connect all nodes with no incoming edges to source. 216 // Connect all nodes with no outgoing edges to sink. 217 bool changed = false; 218 for (Node* n : g->nodes()) { 219 if (!n->IsSource() && n->in_edges().empty()) { 220 g->AddControlEdge(g->source_node(), n); 221 changed = true; 222 } 223 if (!n->IsSink() && n->out_edges().empty()) { 224 g->AddControlEdge(n, g->sink_node()); 225 changed = true; 226 } 227 } 228 return changed; 229 } 230 231 } // namespace tensorflow 232