1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <memory> 17 #include <vector> 18 19 #include "tensorflow/core/framework/attr_value.pb.h" 20 #include "tensorflow/core/framework/function.h" 21 #include "tensorflow/core/framework/op.h" 22 #include "tensorflow/core/framework/op_def.pb.h" 23 #include "tensorflow/core/framework/types.h" 24 #include "tensorflow/core/grappler/utils.h" 25 #include "tensorflow/core/lib/strings/numbers.h" 26 #include "tensorflow/core/lib/strings/scanner.h" 27 #include "tensorflow/core/lib/strings/strcat.h" 28 #include "tensorflow/core/platform/notification.h" 29 30 namespace tensorflow { 31 namespace grappler { 32 33 NodeMap::NodeMap(GraphDef* graph) { 34 CHECK(graph != nullptr); 35 for (int i = 0; i < graph->node_size(); i++) { 36 NodeDef* node = graph->mutable_node(i); 37 const string& node_name = node->name(); 38 auto rslt = nodes_.emplace(node_name, node); 39 // Check that the graph doesn't contain multiple nodes with the same name. 40 if (!rslt.second) { 41 LOG(WARNING) << "Duplicated node in the graph: " << node_name; 42 } 43 for (const auto& input : node->input()) { 44 outputs_[NodeName(input)].insert(nodes_[node_name]); 45 } 46 } 47 } 48 49 void NodeMap::RemoveNode(const string& name) { 50 nodes_.erase(NodeName(name)); 51 outputs_.erase(NodeName(name)); 52 } 53 54 NodeDef* NodeMap::GetNode(const string& name) const { 55 const string node_name = NodeName(name); 56 auto it = nodes_.find(node_name); 57 if (it == nodes_.end()) { 58 return nullptr; 59 } 60 return it->second; 61 } 62 63 bool NodeMap::NodeExists(const string& name) const { 64 const string node_name = NodeName(name); 65 return nodes_.find(node_name) != nodes_.end(); 66 } 67 68 const std::set<NodeDef*>& NodeMap::GetOutputs(const string& node_name) const { 69 auto it = outputs_.find(node_name); 70 if (it == outputs_.end()) { 71 return empty_set_; 72 } 73 return it->second; 74 } 75 76 void NodeMap::AddNode(const string& node_name, NodeDef* node) { 77 auto ret = nodes_.emplace(node_name, CHECK_NOTNULL(node)); 78 CHECK(ret.second) << "Pair (" << node_name << "," << node 79 << ") is not inserted because the same key already exists."; 80 } 81 82 void NodeMap::AddOutput(const string& node_name, const string& output_name) { 83 auto output_node = nodes_[NodeName(output_name)]; 84 CHECK(output_node) << "Output node " << output_name 85 << " is missing in NodeMap."; 86 outputs_[node_name].insert(output_node); 87 } 88 89 void NodeMap::RemoveOutput(const string& node_name, const string& output_name) { 90 outputs_[node_name].erase(nodes_[NodeName(output_name)]); 91 } 92 93 void NodeMap::UpdateInput(const string& node_name, const string& old_input_name, 94 const string& new_input_name) { 95 RemoveOutput(NodeName(old_input_name), node_name); 96 AddOutput(NodeName(new_input_name), node_name); 97 } 98 99 void NodeMap::RemoveInputs(const string& node_name) { 100 auto node = nodes_[node_name]; 101 for (const auto& input : node->input()) { 102 RemoveOutput(NodeName(input), node->name()); 103 } 104 } 105 106 void NodeMap::RemoveOutputs(const string& node_name) { 107 outputs_.erase(node_name); 108 } 109 110 void NodeMap::UpdateOutput(const string& node_name, 111 const string& old_output_name, 112 const string& new_output_name) { 113 std::set<NodeDef*>& outputs = outputs_[node_name]; 114 outputs.erase(nodes_[NodeName(old_output_name)]); 115 outputs.insert(nodes_[NodeName(new_output_name)]); 116 } 117 118 bool IsSameInput(const string& name1, const string& name2) { 119 if (name1 == name2) { 120 return true; 121 } 122 int position1; 123 string node1 = ParseNodeName(name1, &position1); 124 int position2; 125 string node2 = ParseNodeName(name2, &position2); 126 return (position1 == position2) && (node1 == node2); 127 } 128 129 string ParseNodeName(const string& name, int* position) { 130 // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any) 131 // to get a node name. 132 strings::Scanner scan(name); 133 scan.ZeroOrOneLiteral("^") 134 .RestartCapture() 135 .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE) 136 .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); 137 StringPiece capture; 138 StringPiece remaining; 139 if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) { 140 *position = 0; 141 return ""; 142 } else { 143 if (name[0] == '^') { 144 *position = -1; 145 } else if (remaining.empty()) { 146 *position = 0; 147 } else { 148 // Skip the first ':' character. 149 CHECK(strings::safe_strto32(remaining.substr(1), position)); 150 } 151 return capture.ToString(); 152 } 153 } 154 155 bool IsControlInput(const string& name) { 156 return !name.empty() && name[0] == '^'; 157 } 158 159 string NodeName(const string& name) { 160 int position; 161 return ParseNodeName(name, &position); 162 } 163 164 int NodePosition(const string& name) { 165 int position; 166 ParseNodeName(name, &position); 167 return position; 168 } 169 170 string AddPrefixToNodeName(const string& name, const string& prefix, 171 const string& delimiter) { 172 if (!name.empty()) { 173 if (name[0] == '^') { 174 return strings::StrCat("^", prefix, delimiter, name.substr(1)); 175 } 176 } 177 return strings::StrCat(prefix, delimiter, name); 178 } 179 180 string AddPrefixToNodeName(const string& name, const string& prefix) { 181 return AddPrefixToNodeName(name, prefix, "/"); 182 } 183 184 bool ExecuteWithTimeout(std::function<void()> fn, const int64 timeout_in_ms, 185 thread::ThreadPool* const thread_pool) { 186 if (timeout_in_ms <= 0) { 187 fn(); 188 return true; 189 } 190 auto done = std::make_shared<Notification>(); 191 thread_pool->Schedule([done, fn]() { 192 fn(); 193 done->Notify(); 194 }); 195 const bool notified = 196 WaitForNotificationWithTimeout(done.get(), timeout_in_ms * 1000); 197 return notified; 198 } 199 200 string AsControlDependency(const NodeDef& node) { 201 return strings::StrCat("^", node.name()); 202 } 203 204 string AsControlDependency(const string& node_name) { 205 CHECK(!node_name.empty()); 206 return (!node_name.empty() && node_name[0] == '^') 207 ? node_name 208 : strings::StrCat("^", node_name); 209 } 210 211 int NumOutputs(const NodeDef& node, GraphDef* graph) { 212 int num_outputs = 0; 213 const OpDef* op_def = nullptr; 214 auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); 215 if (status.ok()) { 216 for (const auto& output : op_def->output_arg()) { 217 if (!output.type_list_attr().empty()) { 218 num_outputs += 219 node.attr().at(output.type_list_attr()).list().type_size(); 220 } else if (!output.number_attr().empty()) { 221 num_outputs += node.attr().at(output.number_attr()).i(); 222 } else { 223 num_outputs++; 224 } 225 } 226 } else { 227 FunctionLibraryDefinition fdef(OpRegistry::Global(), graph->library()); 228 auto status = fdef.LookUpOpDef(node.op(), &op_def); 229 if (status.ok()) { 230 num_outputs = op_def->output_arg_size(); 231 } 232 } 233 return num_outputs; 234 } 235 236 int NumNonControlInputs(const NodeDef& node) { 237 int num_inputs = node.input_size(); 238 for (const string& input : node.input()) { 239 if (IsControlInput(input)) { 240 --num_inputs; 241 } 242 } 243 return num_inputs; 244 } 245 246 int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) { 247 int num_outputs = 0; 248 for (const NodeDef* output : node_map.GetOutputs(node.name())) { 249 for (const string& node_as_input : output->input()) { 250 if (IsControlInput(node_as_input)) { 251 break; 252 } 253 if (NodeName(node_as_input) == node.name()) { 254 ++num_outputs; 255 } 256 } 257 } 258 return num_outputs; 259 } 260 261 // Returns the data type in attribute `attr_name` of `node`. If that attribute 262 // doesn't exist, returns DT_INVALID. 263 DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) { 264 if (!node.attr().count(attr_name)) { 265 return DT_INVALID; 266 } 267 const auto& attr = node.attr().at(attr_name); 268 if (attr.value_case() != AttrValue::kType) { 269 return DT_INVALID; 270 } 271 return attr.type(); 272 } 273 274 NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map, 275 bool follow_control_input, 276 const std::function<bool(const NodeDef&)>& pred_fn) { 277 const NodeDef* current = &source; 278 const NodeDef* next = current; 279 while (next == &source || (next != nullptr && pred_fn(*next))) { 280 current = next; 281 if (current->input_size() == 0 || 282 (!follow_control_input && IsControlInput(current->input(0)))) { 283 break; 284 } 285 next = node_map.GetNode(current->input(0)); 286 if (next == nullptr) { 287 LOG(ERROR) << "Node not found: " << current->input(0); 288 } 289 } 290 return const_cast<NodeDef*>(current); 291 } 292 293 // Every permutation is a product of one or more cycles. Iterate over the cycles 294 // in the permutation, and convert each of those into a product of 295 // transpositions (swaps): https://en.wikipedia.org/wiki/Cyclic_permutation 296 void PermuteNodesInPlace(GraphDef* graph, std::vector<int>* permutation, 297 bool invert_permutation) { 298 CHECK_EQ(graph->node_size(), permutation->size()); 299 std::vector<int> inv_perm(permutation->size(), 0); 300 if (invert_permutation) { 301 for (size_t n = 0; n < permutation->size(); ++n) { 302 inv_perm[(*permutation)[n]] = n; 303 } 304 permutation->swap(inv_perm); 305 } 306 for (std::size_t n = 0; n + 1 < permutation->size(); ++n) { 307 while (n != (*permutation)[n]) { 308 std::size_t r = (*permutation)[n]; 309 graph->mutable_node()->SwapElements(n, r); 310 std::swap((*permutation)[n], (*permutation)[r]); 311 } 312 } 313 } 314 315 void DedupControlInputs(NodeDef* node) { 316 std::unordered_set<string> inputs; 317 int pos = 0; 318 while (pos < node->input_size()) { 319 const string& input = node->input(pos); 320 if (!inputs.insert(NodeName(input)).second && IsControlInput(input)) { 321 node->mutable_input()->SwapElements(pos, node->input_size() - 1); 322 node->mutable_input()->RemoveLast(); 323 } else { 324 ++pos; 325 } 326 } 327 } 328 329 namespace { 330 template <typename T> 331 inline void STLSortAndRemoveDuplicates(T* v) { 332 std::sort(v->begin(), v->end()); 333 v->erase(std::unique(v->begin(), v->end()), v->end()); 334 } 335 } // namespace 336 337 Status SimpleGraphView::Initialize(const GraphDef& graph, bool dedup_inputs, 338 bool dedup_outputs) { 339 const int num_nodes = graph.node_size(); 340 inputs_.clear(); 341 inputs_.resize(num_nodes); 342 outputs_.clear(); 343 outputs_.resize(num_nodes); 344 name_to_index_.clear(); 345 name_to_index_.reserve(num_nodes); 346 index_to_name_.clear(); 347 index_to_name_.reserve(num_nodes); 348 349 // Build map from name to index and vice versa. 350 for (int node_idx = 0; node_idx < num_nodes; ++node_idx) { 351 const NodeDef& node = graph.node(node_idx); 352 name_to_index_.emplace(node.name(), node_idx); 353 index_to_name_.push_back(node.name()); 354 } 355 356 // Build forward and reverse adjacency lists. 357 for (int node_idx = 0; node_idx < num_nodes; ++node_idx) { 358 const NodeDef& node = graph.node(node_idx); 359 inputs_[node_idx].reserve(node.input_size()); 360 for (const string& input : node.input()) { 361 auto it = name_to_index_.find(NodeName(input)); 362 if (it == name_to_index_.end()) { 363 return errors::InvalidArgument("Non-existent input ", input, 364 " for node ", node.name()); 365 } 366 const int input_idx = it->second; 367 inputs_[node_idx].push_back(input_idx); 368 outputs_[input_idx].push_back(node_idx); 369 } 370 if (dedup_inputs) { 371 // Dedup the input list while it's still hot in cache. 372 STLSortAndRemoveDuplicates(&inputs_[node_idx]); 373 } 374 } 375 376 // Dedup outputs. 377 if (dedup_outputs) { 378 for (int node_idx = 0; node_idx < num_nodes; ++node_idx) { 379 STLSortAndRemoveDuplicates(&outputs_[node_idx]); 380 } 381 } 382 return Status::OK(); 383 } 384 385 string SimpleGraphView::PrintToString() const { 386 string str; 387 for (int i = 0; i < num_nodes(); ++i) { 388 strings::StrAppend(&str, "Node ", i, "'", node_name(i), "'\n", "Inputs: ["); 389 for (int input : inputs(i)) { 390 strings::StrAppend(&str, input, " '", node_name(input), "', "); 391 } 392 strings::StrAppend(&str, "]\n", "Outputs: ["); 393 for (int j = 0; j < outputs(i).size(); ++j) { 394 const int output = outputs(i)[j]; 395 if (j > 0) { 396 strings::StrAppend(&str, ", "); 397 } 398 strings::StrAppend(&str, output, " '", node_name(output), "'"); 399 } 400 strings::StrAppend(&str, "]\n"); 401 } 402 return str; 403 } 404 405 } // end namespace grappler 406 } // end namespace tensorflow 407