1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/utils/scc.h" 17 #include <stack> 18 #include <unordered_map> 19 #include <unordered_set> 20 #include <vector> 21 #include "tensorflow/core/framework/node_def.pb.h" 22 #include "tensorflow/core/grappler/op_types.h" 23 #include "tensorflow/core/grappler/utils.h" 24 25 namespace tensorflow { 26 namespace grappler { 27 28 // Data structure used to store data for Tarjan's Strongly Connected 29 // Components algorithm. 30 struct SCCNodeData { 31 SCCNodeData() 32 : node(nullptr), 33 index(-1), 34 lowlink(-1), 35 onstack(false), 36 caller(nullptr), 37 caller_loop_location(-1) {} 38 void ResetStack(int new_index, SCCNodeData* new_caller) { 39 index = new_index; 40 lowlink = new_index; 41 onstack = true; 42 caller = new_caller; 43 caller_loop_location = 0; 44 } 45 const NodeDef* node; 46 int index; 47 int lowlink; 48 bool onstack; 49 std::vector<SCCNodeData*> children; 50 // StrongConnect "call stack" storage. 51 SCCNodeData* caller; // Node calling StrongConnect 52 int caller_loop_location; // Index in parent StrongConnect for loop 53 }; 54 55 // Core DFS step of Tarjan's Strongly Connected Component algorithm 56 // (implemented using iteration instead of recursion). 57 void StrongConnect(SCCNodeData* v, std::stack<SCCNodeData*>* stack, int* index, 58 std::unordered_map<const NodeDef*, int>* components, 59 int* scc_index) { 60 // Iterative version of Tarjan's StrongConnect function. 61 // The "call stack" state is composed of a SCCNodeData's caller and 62 // caller_loop_location properties. 63 v->ResetStack(*index /* index */, nullptr /* caller */); 64 ++*index; 65 stack->push(v); 66 67 // No one put v on a StrongConnect call stack, reset caller values. 68 v->caller = nullptr; 69 v->caller_loop_location = 0; 70 71 SCCNodeData* last = v; 72 while (true) { 73 if (last->caller_loop_location < last->children.size()) { 74 // Recursive equivalent: Looping over the children of v (possibly 75 // continuing at v->caller_loop_location after having finished a 76 // recursive call. 77 SCCNodeData* w = last->children[last->caller_loop_location]; 78 ++(last->caller_loop_location); // For loop iterator increment 79 if (w->index == -1) { 80 w->ResetStack(*index /* index */, last /* caller */); 81 ++*index; 82 stack->push(w); 83 last = w; 84 } else if (w->onstack == true) { 85 last->lowlink = std::min(last->lowlink, w->index); 86 } 87 } else { 88 // At the end of v's children 89 if (last->lowlink == last->index) { 90 // v is the root of a strongly connected component 91 SCCNodeData* top; 92 while (true) { 93 top = stack->top(); 94 stack->pop(); 95 top->onstack = false; 96 (*components)[top->node] = *scc_index; 97 if (top == last) { 98 break; 99 } 100 } 101 ++*scc_index; 102 } 103 104 // Go up the recursive call stack 105 SCCNodeData* next_last = last->caller; 106 if (next_last == nullptr) { 107 // All nodes have been seen; finished. 108 break; 109 } else { 110 next_last->lowlink = std::min(next_last->lowlink, last->lowlink); 111 last = next_last; 112 } 113 } 114 } 115 } 116 117 // This is an implementation of Tarjan's Strongly Connected Components 118 // DFS algorithm. Most of the hard work is done in the function 119 // StrongConnect, which is an iterative reimplementation of the 120 // recursive version described here: 121 // https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm 122 // 123 // The edges for the purpose of this algorithm are directed from input 124 // to op (the reverse of the declarations of the NodeDef, which 125 // contain in-edges) 126 void StronglyConnectedComponents( 127 const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components, 128 int* num_components) { 129 std::stack<SCCNodeData*> stack; 130 std::unordered_map<string, SCCNodeData*> name_to_data; 131 std::vector<SCCNodeData> node_data_container; 132 node_data_container.reserve(graph.node_size()); 133 std::unordered_map<const NodeDef*, SCCNodeData*> node_to_data; 134 135 for (const NodeDef& node : graph.node()) { 136 SCCNodeData node_data; 137 node_data.node = &node; 138 node_data_container.push_back(node_data); 139 name_to_data[node.name()] = &(*node_data_container.rbegin()); 140 node_to_data[&node] = &(*node_data_container.rbegin()); 141 } 142 143 // Create a list of top-level parents (add them to object queue) 144 // Also create a mapping from nodes to their children. 145 for (const NodeDef& node : graph.node()) { 146 for (const string& input : node.input()) { 147 name_to_data[NodeName(input)]->children.push_back(node_to_data[&node]); 148 } 149 } 150 151 components->clear(); 152 *num_components = 0; 153 int index = 0; 154 for (auto& v : node_data_container) { 155 if (v.index == -1) { 156 // Node has not yet been visited. Start a DFS at v. 157 StrongConnect(&v, &stack, &index, components, num_components); 158 } 159 } 160 161 std::vector<int> counts_per_component(*num_components, 0); 162 for (auto& component : *components) { 163 DCHECK(component.second >= 0); 164 DCHECK(component.second < *num_components); 165 counts_per_component[component.second]++; 166 } 167 bool has_single_element_component = false; 168 for (auto& component : *components) { 169 if (counts_per_component[component.second] == 1) { 170 component.second = -1; 171 (*num_components)--; 172 has_single_element_component = true; 173 } 174 } 175 if (has_single_element_component) { 176 (*num_components) += 1; 177 } 178 } 179 180 int IdentifyLoops(const GraphDef& graph, 181 std::unordered_map<const NodeDef*, std::vector<int>>* loops) { 182 int num_components = 0; 183 std::unordered_map<const NodeDef*, int> components; 184 StronglyConnectedComponents(graph, &components, &num_components); 185 if (num_components <= 1) { 186 if (!components.empty() && components.begin()->second == -1) { 187 return 0; 188 } 189 } 190 191 std::unordered_map<int, std::vector<const NodeDef*>> component_ids; 192 for (const auto it : components) { 193 int id = it.second; 194 if (id < 0) { 195 continue; 196 } 197 component_ids[id].push_back(it.first); 198 } 199 200 int loop_id = 0; 201 for (const auto& component : component_ids) { 202 const std::vector<const NodeDef*>& component_nodes = component.second; 203 std::vector<std::pair<NodeDef*, string>> next_iter_nodes; 204 GraphDef subgraph; 205 206 for (const auto& component_node : component_nodes) { 207 NodeDef* node = subgraph.add_node(); 208 *node = *component_node; 209 if (IsNextIteration(*node)) { 210 CHECK_EQ(1, node->input_size()); 211 next_iter_nodes.emplace_back(node, node->input(0)); 212 } 213 } 214 if (next_iter_nodes.size() == 1) { 215 for (const auto& component_node : component_nodes) { 216 (*loops)[component_node].push_back(loop_id); 217 } 218 ++loop_id; 219 } else { 220 for (int i = 0; i < next_iter_nodes.size(); ++i) { 221 for (int j = 0; j < next_iter_nodes.size(); ++j) { 222 next_iter_nodes[j].first->clear_input(); 223 if (i == j) { 224 *next_iter_nodes[j].first->add_input() = next_iter_nodes[j].second; 225 } 226 } 227 int num_components = 0; 228 std::unordered_map<const NodeDef*, int> components; 229 StronglyConnectedComponents(subgraph, &components, &num_components); 230 CHECK_EQ(1, num_components); 231 for (const auto it : components) { 232 int id = it.second; 233 if (id < 0) { 234 continue; 235 } 236 (*loops)[it.first].push_back(loop_id); 237 } 238 ++loop_id; 239 } 240 } 241 } 242 243 return loop_id; 244 } 245 246 } // namespace grappler 247 } // namespace tensorflow 248