1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" 16 17 #include <string> 18 #include <unordered_map> 19 #include <vector> 20 21 #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster.h" 22 #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/cluster_utils.h" 23 #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_svdf.h" 24 #include "tensorflow/contrib/lite/toco/tooling_util.h" 25 #include "tensorflow/core/framework/attr_value.pb.h" 26 #include "tensorflow/core/framework/function.pb.h" 27 #include "tensorflow/core/framework/graph.pb.h" 28 #include "tensorflow/core/framework/node_def.pb.h" 29 30 namespace toco { 31 32 using tensorflow::GraphDef; 33 using tensorflow::NodeDef; 34 35 void AddNodeToGraph(const NodeDef& node, 36 const std::vector<string>& cluster_names, GraphDef* graph) { 37 NodeDef* new_node = graph->add_node(); 38 new_node->set_op(node.op()); 39 new_node->set_name(node.name()); 40 new_node->set_device(node.device()); 41 // If the inputs are coming from a node which belongs to another cluster, then 42 // those inputs are renamed to the source cluster name. Otherwise the original 43 // input name is used. 44 for (const string& node_input : node.input()) { 45 bool input_from_cluster = false; 46 for (const string& cluster_name : cluster_names) { 47 if (StrContains(node_input, cluster_name) && 48 !StrContains(node.name(), cluster_name)) { 49 new_node->add_input(cluster_name); 50 input_from_cluster = true; 51 break; 52 } 53 } 54 if (!input_from_cluster) { 55 new_node->add_input(node_input); 56 } 57 } 58 for (const auto& attr : node.attr()) { 59 (*new_node->mutable_attr())[attr.first] = attr.second; 60 } 61 } 62 63 bool FindCluster(const ClusterFactoryInterface& cluster_factory, 64 const GraphDef& graph_def, 65 std::unordered_map<string, bool>* is_node_in_cluster, 66 std::vector<std::unique_ptr<Cluster>>* clusters) { 67 for (const NodeDef& node : graph_def.node()) { 68 // If the node is not assigned to any cluster, then we check if it belong to 69 // the cluster_factory. 70 bool node_in_cluster = (*is_node_in_cluster)[node.name()]; 71 if (!node_in_cluster) { 72 std::unique_ptr<Cluster> cluster = 73 cluster_factory.CreateCluster(node, graph_def); 74 if (cluster) { 75 // Label all the nodes in is_node_in_cluster which are in this cluster 76 // as belonged to this cluster. 77 for (const NodeDef* cluster_node : cluster->GetNodes()) { 78 (*is_node_in_cluster)[cluster_node->name()] = true; 79 } 80 clusters->push_back(std::move(cluster)); 81 } 82 } 83 } 84 return (!clusters->empty()); 85 } 86 87 std::unique_ptr<GraphDef> MaybeResolveClusters( 88 const GraphDef& graph_def, 89 const std::vector<ClusterFactoryInterface*>& cluster_factories) { 90 std::unique_ptr<GraphDef> pruned_graph(new GraphDef); 91 // The structure to keep track of which cluster each node is assigned to, and 92 // to initialize them to all un-assigned, 93 std::unordered_map<string, bool> is_node_in_cluster; 94 for (const NodeDef& node : graph_def.node()) { 95 is_node_in_cluster[node.name()] = false; 96 } 97 98 std::vector<string> cluster_names; 99 std::vector<std::unique_ptr<Cluster>> all_clusters; 100 // Find the clusters for all available cluster factories. 101 for (const ClusterFactoryInterface* cluster_factory : cluster_factories) { 102 std::vector<std::unique_ptr<Cluster>> clusters; 103 if (FindCluster(*cluster_factory, graph_def, &is_node_in_cluster, 104 &clusters)) { 105 for (auto itr = clusters.begin(); itr != clusters.end(); ++itr) { 106 cluster_names.push_back((*itr)->GetName()); 107 (*itr)->CreateNodes(); 108 all_clusters.push_back(std::move(*itr)); 109 } 110 } 111 } 112 113 for (const std::unique_ptr<Cluster>& cluster : all_clusters) { 114 for (const std::unique_ptr<tensorflow::NodeDef>& src_node : 115 cluster->GetNewNodes()) { 116 // Add it to the output GraphDef. 117 AddNodeToGraph(*src_node, cluster_names, pruned_graph.get()); 118 } 119 } 120 121 // Add any node which is not part of a cluster. 122 for (const NodeDef& node : graph_def.node()) { 123 bool node_in_cluster = is_node_in_cluster[node.name()]; 124 if (!node_in_cluster) { 125 AddNodeToGraph(node, cluster_names, pruned_graph.get()); 126 } 127 } 128 129 if (pruned_graph->node_size() == 0) { 130 return nullptr; 131 } else { 132 return pruned_graph; 133 } 134 } 135 136 std::unique_ptr<GraphDef> MaybeReplaceCompositeSubgraph( 137 const GraphDef& tf_graph) { 138 SvdfClusterFactory svdf_cluster_factory; 139 140 std::vector<ClusterFactoryInterface*> cluster_factories; 141 cluster_factories.push_back(&svdf_cluster_factory); 142 143 std::unique_ptr<GraphDef> pruned_graph = 144 MaybeResolveClusters(tf_graph, cluster_factories); 145 146 // Copy function definitions 147 *(pruned_graph->mutable_library()) = tf_graph.library(); 148 return pruned_graph; 149 } 150 151 } // end namespace toco 152