Home | History | Annotate | Download | only in tensorflow_graph_matching
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h"
     16 
     17 #include <string>
     18 #include <unordered_map>
     19 #include <vector>
     20 
     21 #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h"
     22 #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h"
     23 #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h"
     24 #include "tensorflow/contrib/lite/toco/tooling_util.h"
     25 #include "tensorflow/core/framework/attr_value.pb.h"
     26 #include "tensorflow/core/framework/function.pb.h"
     27 #include "tensorflow/core/framework/graph.pb.h"
     28 #include "tensorflow/core/framework/node_def.pb.h"
     29 
     30 namespace toco {
     31 
     32 using tensorflow::GraphDef;
     33 using tensorflow::NodeDef;
     34 
     35 void AddNodeToGraph(const NodeDef& node,
     36                     const std::vector<string>& cluster_names, GraphDef* graph) {
     37   NodeDef* new_node = graph->add_node();
     38   new_node->set_op(node.op());
     39   new_node->set_name(node.name());
     40   new_node->set_device(node.device());
     41   // If the inputs are coming from a node which belongs to another cluster, then
     42   // those inputs are renamed to the source cluster name. Otherwise the original
     43   // input name is used.
     44   for (const string& node_input : node.input()) {
     45     bool input_from_cluster = false;
     46     for (const string& cluster_name : cluster_names) {
     47       if (StrContains(node_input, cluster_name) &&
     48           !StrContains(node.name(), cluster_name)) {
     49         new_node->add_input(cluster_name);
     50         input_from_cluster = true;
     51         break;
     52       }
     53     }
     54     if (!input_from_cluster) {
     55       new_node->add_input(node_input);
     56     }
     57   }
     58   for (const auto& attr : node.attr()) {
     59     (*new_node->mutable_attr())[attr.first] = attr.second;
     60   }
     61 }
     62 
     63 bool FindCluster(const ClusterFactoryInterface& cluster_factory,
     64                  const GraphDef& graph_def,
     65                  std::unordered_map<string, bool>* is_node_in_cluster,
     66                  std::vector<std::unique_ptr<Cluster>>* clusters) {
     67   for (const NodeDef& node : graph_def.node()) {
     68     // If the node is not assigned to any cluster, then we check if it belong to
     69     // the cluster_factory.
     70     bool node_in_cluster = (*is_node_in_cluster)[node.name()];
     71     if (!node_in_cluster) {
     72       std::unique_ptr<Cluster> cluster =
     73           cluster_factory.CreateCluster(node, graph_def);
     74       if (cluster) {
     75         // Label all the nodes in is_node_in_cluster which are in this cluster
     76         // as belonged to this cluster.
     77         for (const NodeDef* cluster_node : cluster->GetNodes()) {
     78           (*is_node_in_cluster)[cluster_node->name()] = true;
     79         }
     80         clusters->push_back(std::move(cluster));
     81       }
     82     }
     83   }
     84   return (!clusters->empty());
     85 }
     86 
     87 std::unique_ptr<GraphDef> MaybeResolveClusters(
     88     const GraphDef& graph_def,
     89     const std::vector<ClusterFactoryInterface*>& cluster_factories) {
     90   std::unique_ptr<GraphDef> pruned_graph(new GraphDef);
     91   // The structure to keep track of which cluster each node is assigned to, and
     92   // to initialize them to all un-assigned,
     93   std::unordered_map<string, bool> is_node_in_cluster;
     94   for (const NodeDef& node : graph_def.node()) {
     95     is_node_in_cluster[node.name()] = false;
     96   }
     97 
     98   std::vector<string> cluster_names;
     99   std::vector<std::unique_ptr<Cluster>> all_clusters;
    100   // Find the clusters for all available cluster factories.
    101   for (const ClusterFactoryInterface* cluster_factory : cluster_factories) {
    102     std::vector<std::unique_ptr<Cluster>> clusters;
    103     if (FindCluster(*cluster_factory, graph_def, &is_node_in_cluster,
    104                     &clusters)) {
    105       for (auto itr = clusters.begin(); itr != clusters.end(); ++itr) {
    106         cluster_names.push_back((*itr)->GetName());
    107         (*itr)->CreateNodes();
    108         all_clusters.push_back(std::move(*itr));
    109       }
    110     }
    111   }
    112 
    113   for (const std::unique_ptr<Cluster>& cluster : all_clusters) {
    114     for (const std::unique_ptr<tensorflow::NodeDef>& src_node :
    115          cluster->GetNewNodes()) {
    116       // Add it to the output GraphDef.
    117       AddNodeToGraph(*src_node, cluster_names, pruned_graph.get());
    118     }
    119   }
    120 
    121   // Add any node which is not part of a cluster.
    122   for (const NodeDef& node : graph_def.node()) {
    123     bool node_in_cluster = is_node_in_cluster[node.name()];
    124     if (!node_in_cluster) {
    125       AddNodeToGraph(node, cluster_names, pruned_graph.get());
    126     }
    127   }
    128 
    129   if (pruned_graph->node_size() == 0) {
    130     return nullptr;
    131   } else {
    132     return pruned_graph;
    133   }
    134 }
    135 
    136 std::unique_ptr<GraphDef> MaybeReplaceCompositeSubgraph(
    137     const GraphDef& tf_graph) {
    138   SvdfClusterFactory svdf_cluster_factory;
    139 
    140   std::vector<ClusterFactoryInterface*> cluster_factories;
    141   cluster_factories.push_back(&svdf_cluster_factory);
    142 
    143   std::unique_ptr<GraphDef> pruned_graph =
    144       MaybeResolveClusters(tf_graph, cluster_factories);
    145 
    146   // Copy function definitions
    147   *(pruned_graph->mutable_library()) = tf_graph.library();
    148   return pruned_graph;
    149 }
    150 
    151 }  // end namespace toco
    152