1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/optimizers/graph_rewriter.h" 17 #include <unordered_map> 18 #include <unordered_set> 19 #include "tensorflow/core/framework/function.pb.h" 20 #include "tensorflow/core/framework/node_def.pb.h" 21 #include "tensorflow/core/framework/node_def_util.h" 22 #include "tensorflow/core/framework/op.h" 23 #include "tensorflow/core/framework/op_def.pb.h" 24 #include "tensorflow/core/grappler/grappler_item.h" 25 #include "tensorflow/core/grappler/op_types.h" 26 #include "tensorflow/core/grappler/utils.h" 27 28 namespace tensorflow { 29 namespace grappler { 30 31 GraphRewriter::GraphRewriter(const GrapplerItem& item) { 32 OpRegistryInterface* op_registry = OpRegistry::Global(); 33 for (auto& node : item.graph.node()) { 34 NodeInfo* info = new NodeInfo(); 35 info->def = &node; 36 37 const OpRegistrationData* op_reg_data = nullptr; 38 Status s = op_registry->LookUp(node.op(), &op_reg_data); 39 // TODO(bsteiner): make this not a best-effort lookup and evaluation? 40 if (s.ok()) { 41 DataTypeVector inputs; 42 s = InOutTypesForNode(node, op_reg_data->op_def, &inputs, &info->outputs); 43 if (!s.ok()) { 44 info->outputs.clear(); 45 } 46 } 47 48 nodes_[node.name()].reset(info); 49 } 50 51 std::unordered_set<string> function_names; 52 for (const auto& function : item.graph.library().function()) { 53 function_names.insert(function.signature().name()); 54 } 55 56 for (auto& node : item.graph.node()) { 57 RecordConnectivity(node, function_names); 58 } 59 } 60 61 void GraphRewriter::ForwardInputs( 62 const NodeDef& original_node, 63 const std::unordered_set<const NodeDef*>& nodes_to_delete, 64 NodeDef* new_node) { 65 ForwardInputsInternal(original_node, nodes_to_delete, false, new_node); 66 if (!new_node->name().empty()) { 67 optimized_nodes_[new_node->name()] = new_node; 68 } 69 // Reorder inputs such that control inputs come after regular inputs. 70 int pos = 0; 71 for (int i = 0; i < new_node->input_size(); ++i) { 72 if (!IsControlInput(new_node->input(i))) { 73 new_node->mutable_input()->SwapElements(pos, i); 74 ++pos; 75 } 76 } 77 DedupControlInputs(new_node); 78 } 79 80 bool GraphRewriter::DrivesControlDependency(const NodeDef& node) const { 81 return control_dependency_drivers_.find(&node) != 82 control_dependency_drivers_.end(); 83 } 84 85 bool GraphRewriter::FeedsMerge(const NodeDef& node) const { 86 return merge_feeders_.find(&node) != merge_feeders_.end(); 87 } 88 89 bool GraphRewriter::IsDrivenByControlDependency(const NodeDef& node) const { 90 for (const auto& input : node.input()) { 91 CHECK(!input.empty()); 92 if (input[0] == '^') { 93 return true; 94 } 95 } 96 return false; 97 } 98 99 bool GraphRewriter::IsConnectedToFunction(const NodeDef& node) const { 100 return function_neighbors_.find(&node) != function_neighbors_.end(); 101 } 102 103 bool GraphRewriter::IsDrivenByAnotherDevice(const NodeDef& node) const { 104 return cross_device_receivers_.find(&node) != cross_device_receivers_.end(); 105 } 106 107 bool GraphRewriter::ReceivesRefValue(const NodeDef& node) const { 108 return ref_receivers_.find(&node) != ref_receivers_.end(); 109 } 110 111 bool GraphRewriter::IsDrivenBySwitch(const NodeDef& node) const { 112 return switch_receivers_.find(&node) != switch_receivers_.end(); 113 } 114 115 bool GraphRewriter::RemovalIncreasesEdgeCount(const NodeDef& node) const { 116 const int in_degree = node.input_size(); 117 auto itr = nodes_.find(node.name()); 118 if (itr == nodes_.end()) { 119 return true; 120 } 121 const int out_degree = itr->second->out_degree; 122 return in_degree * out_degree > in_degree + out_degree; 123 } 124 125 void GraphRewriter::RecordConnectivity( 126 const NodeDef& node, const std::unordered_set<string>& function_names) { 127 const bool is_function = 128 function_names.find(node.op()) != function_names.end(); 129 130 bool ref_receiver = false; 131 bool switch_receiver = false; 132 for (const auto& input : node.input()) { 133 int position = 0; 134 string input_node_name = ParseNodeName(input, &position); 135 auto itr = nodes_.find(input_node_name); 136 if (itr == nodes_.end()) { 137 continue; 138 } 139 140 NodeInfo* fanin_info = itr->second.get(); 141 const NodeDef* fanin = fanin_info->def; 142 if (IsMerge(node)) { 143 merge_feeders_.insert(fanin); 144 } 145 // Update out_degree of fanin. 146 ++fanin_info->out_degree; 147 if (position < 0) { 148 // This is a control edge 149 control_dependency_drivers_.insert(fanin); 150 } else { 151 // This is a regular edge 152 if (function_names.find(fanin->op()) != function_names.end()) { 153 function_neighbors_.insert(&node); 154 } 155 if (is_function) { 156 function_neighbors_.insert(fanin); 157 } 158 if (IsSwitch(*fanin)) { 159 switch_receiver = true; 160 } 161 if (position < fanin_info->outputs.size() && 162 IsRefType(fanin_info->outputs[position])) { 163 ref_receiver = true; 164 } 165 } 166 if (fanin->device() != node.device()) { 167 cross_device_receivers_.insert(&node); 168 } 169 } 170 171 if (ref_receiver) { 172 ref_receivers_.insert(&node); 173 } 174 if (switch_receiver) { 175 switch_receivers_.insert(&node); 176 } 177 } 178 179 void GraphRewriter::ForwardInputsInternal( 180 const NodeDef& node, 181 const std::unordered_set<const NodeDef*>& nodes_to_delete, 182 bool add_as_control, NodeDef* new_node) { 183 // To speed things up, use the optimized version of the node if 184 // available. 185 auto itr = optimized_nodes_.find(node.name()); 186 if (itr != optimized_nodes_.end()) { 187 for (const string& input : itr->second->input()) { 188 *new_node->add_input() = 189 add_as_control ? AsControlDependency(NodeName(input)) : input; 190 } 191 return; 192 } 193 for (const auto& input : node.input()) { 194 const string input_node_name = NodeName(input); 195 auto itr = nodes_.find(input_node_name); 196 if (itr == nodes_.end()) { 197 // Invalid input, preserve it as is. 198 *new_node->add_input() = 199 add_as_control ? AsControlDependency(NodeName(input)) : input; 200 continue; 201 } 202 const NodeDef* input_node = itr->second->def; 203 if (nodes_to_delete.find(input_node) != nodes_to_delete.end()) { 204 ForwardInputsInternal(*input_node, nodes_to_delete, 205 add_as_control || IsControlInput(input), new_node); 206 } else { 207 *new_node->add_input() = 208 add_as_control ? AsControlDependency(NodeName(input)) : input; 209 } 210 } 211 } 212 213 } // end namespace grappler 214 } // end namespace tensorflow 215