Home | History | Annotate | Download | only in utils
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/grappler/utils/scc.h"
     17 #include <stack>
     18 #include <unordered_map>
     19 #include <unordered_set>
     20 #include <vector>
     21 #include "tensorflow/core/framework/node_def.pb.h"
     22 #include "tensorflow/core/grappler/op_types.h"
     23 #include "tensorflow/core/grappler/utils.h"
     24 
     25 namespace tensorflow {
     26 namespace grappler {
     27 
     28 // Data structure used to store data for Tarjan's Strongly Connected
     29 // Components algorithm.
     30 struct SCCNodeData {
     31   SCCNodeData()
     32       : node(nullptr),
     33         index(-1),
     34         lowlink(-1),
     35         onstack(false),
     36         caller(nullptr),
     37         caller_loop_location(-1) {}
     38   void ResetStack(int new_index, SCCNodeData* new_caller) {
     39     index = new_index;
     40     lowlink = new_index;
     41     onstack = true;
     42     caller = new_caller;
     43     caller_loop_location = 0;
     44   }
     45   const NodeDef* node;
     46   int index;
     47   int lowlink;
     48   bool onstack;
     49   std::vector<SCCNodeData*> children;
     50   // StrongConnect "call stack" storage.
     51   SCCNodeData* caller;       // Node calling StrongConnect
     52   int caller_loop_location;  // Index in parent StrongConnect for loop
     53 };
     54 
     55 // Core DFS step of Tarjan's Strongly Connected Component algorithm
     56 // (implemented using iteration instead of recursion).
     57 void StrongConnect(SCCNodeData* v, std::stack<SCCNodeData*>* stack, int* index,
     58                    std::unordered_map<const NodeDef*, int>* components,
     59                    int* scc_index) {
     60   // Iterative version of Tarjan's StrongConnect function.
     61   // The "call stack" state is composed of a SCCNodeData's caller and
     62   // caller_loop_location properties.
     63   v->ResetStack(*index /* index */, nullptr /* caller */);
     64   ++*index;
     65   stack->push(v);
     66 
     67   // No one put v on a StrongConnect call stack, reset caller values.
     68   v->caller = nullptr;
     69   v->caller_loop_location = 0;
     70 
     71   SCCNodeData* last = v;
     72   while (true) {
     73     if (last->caller_loop_location < last->children.size()) {
     74       // Recursive equivalent: Looping over the children of v (possibly
     75       // continuing at v->caller_loop_location after having finished a
     76       // recursive call.
     77       SCCNodeData* w = last->children[last->caller_loop_location];
     78       ++(last->caller_loop_location);  // For loop iterator increment
     79       if (w->index == -1) {
     80         w->ResetStack(*index /* index */, last /* caller */);
     81         ++*index;
     82         stack->push(w);
     83         last = w;
     84       } else if (w->onstack == true) {
     85         last->lowlink = std::min(last->lowlink, w->index);
     86       }
     87     } else {
     88       // At the end of v's children
     89       if (last->lowlink == last->index) {
     90         // v is the root of a strongly connected component
     91         SCCNodeData* top;
     92         while (true) {
     93           top = stack->top();
     94           stack->pop();
     95           top->onstack = false;
     96           (*components)[top->node] = *scc_index;
     97           if (top == last) {
     98             break;
     99           }
    100         }
    101         ++*scc_index;
    102       }
    103 
    104       // Go up the recursive call stack
    105       SCCNodeData* next_last = last->caller;
    106       if (next_last == nullptr) {
    107         // All nodes have been seen; finished.
    108         break;
    109       } else {
    110         next_last->lowlink = std::min(next_last->lowlink, last->lowlink);
    111         last = next_last;
    112       }
    113     }
    114   }
    115 }
    116 
    117 // This is an implementation of Tarjan's Strongly Connected Components
    118 // DFS algorithm.  Most of the hard work is done in the function
    119 // StrongConnect, which is an iterative reimplementation of the
    120 // recursive version described here:
    121 //   https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
    122 //
    123 // The edges for the purpose of this algorithm are directed from input
    124 // to op (the reverse of the declarations of the NodeDef, which
    125 // contain in-edges)
    126 void StronglyConnectedComponents(
    127     const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components,
    128     int* num_components) {
    129   std::stack<SCCNodeData*> stack;
    130   std::unordered_map<string, SCCNodeData*> name_to_data;
    131   std::vector<SCCNodeData> node_data_container;
    132   node_data_container.reserve(graph.node_size());
    133   std::unordered_map<const NodeDef*, SCCNodeData*> node_to_data;
    134 
    135   for (const NodeDef& node : graph.node()) {
    136     SCCNodeData node_data;
    137     node_data.node = &node;
    138     node_data_container.push_back(node_data);
    139     name_to_data[node.name()] = &(*node_data_container.rbegin());
    140     node_to_data[&node] = &(*node_data_container.rbegin());
    141   }
    142 
    143   // Create a list of top-level parents (add them to object queue)
    144   // Also create a mapping from nodes to their children.
    145   for (const NodeDef& node : graph.node()) {
    146     for (const string& input : node.input()) {
    147       name_to_data[NodeName(input)]->children.push_back(node_to_data[&node]);
    148     }
    149   }
    150 
    151   components->clear();
    152   *num_components = 0;
    153   int index = 0;
    154   for (auto& v : node_data_container) {
    155     if (v.index == -1) {
    156       // Node has not yet been visited.  Start a DFS at v.
    157       StrongConnect(&v, &stack, &index, components, num_components);
    158     }
    159   }
    160 
    161   std::vector<int> counts_per_component(*num_components, 0);
    162   for (auto& component : *components) {
    163     DCHECK(component.second >= 0);
    164     DCHECK(component.second < *num_components);
    165     counts_per_component[component.second]++;
    166   }
    167   bool has_single_element_component = false;
    168   for (auto& component : *components) {
    169     if (counts_per_component[component.second] == 1) {
    170       component.second = -1;
    171       (*num_components)--;
    172       has_single_element_component = true;
    173     }
    174   }
    175   if (has_single_element_component) {
    176     (*num_components) += 1;
    177   }
    178 }
    179 
    180 int IdentifyLoops(const GraphDef& graph,
    181                   std::unordered_map<const NodeDef*, std::vector<int>>* loops) {
    182   int num_components = 0;
    183   std::unordered_map<const NodeDef*, int> components;
    184   StronglyConnectedComponents(graph, &components, &num_components);
    185   if (num_components <= 1) {
    186     if (!components.empty() && components.begin()->second == -1) {
    187       return 0;
    188     }
    189   }
    190 
    191   std::unordered_map<int, std::vector<const NodeDef*>> component_ids;
    192   for (const auto it : components) {
    193     int id = it.second;
    194     if (id < 0) {
    195       continue;
    196     }
    197     component_ids[id].push_back(it.first);
    198   }
    199 
    200   int loop_id = 0;
    201   for (const auto& component : component_ids) {
    202     const std::vector<const NodeDef*>& component_nodes = component.second;
    203     std::vector<std::pair<NodeDef*, string>> next_iter_nodes;
    204     GraphDef subgraph;
    205 
    206     for (const auto& component_node : component_nodes) {
    207       NodeDef* node = subgraph.add_node();
    208       *node = *component_node;
    209       if (IsNextIteration(*node)) {
    210         CHECK_EQ(1, node->input_size());
    211         next_iter_nodes.emplace_back(node, node->input(0));
    212       }
    213     }
    214     if (next_iter_nodes.size() == 1) {
    215       for (const auto& component_node : component_nodes) {
    216         (*loops)[component_node].push_back(loop_id);
    217       }
    218       ++loop_id;
    219     } else {
    220       for (int i = 0; i < next_iter_nodes.size(); ++i) {
    221         for (int j = 0; j < next_iter_nodes.size(); ++j) {
    222           next_iter_nodes[j].first->clear_input();
    223           if (i == j) {
    224             *next_iter_nodes[j].first->add_input() = next_iter_nodes[j].second;
    225           }
    226         }
    227         int num_components = 0;
    228         std::unordered_map<const NodeDef*, int> components;
    229         StronglyConnectedComponents(subgraph, &components, &num_components);
    230         CHECK_EQ(1, num_components);
    231         for (const auto it : components) {
    232           int id = it.second;
    233           if (id < 0) {
    234             continue;
    235           }
    236           (*loops)[it.first].push_back(loop_id);
    237         }
    238         ++loop_id;
    239       }
    240     }
    241   }
    242 
    243   return loop_id;
    244 }
    245 
    246 }  // namespace grappler
    247 }  // namespace tensorflow
    248