1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/grappler_item.h" 17 18 #include <unordered_map> 19 #include <unordered_set> 20 #include <vector> 21 22 #include "tensorflow/core/framework/attr_value.pb.h" 23 #include "tensorflow/core/framework/node_def.pb.h" 24 #include "tensorflow/core/grappler/op_types.h" 25 #include "tensorflow/core/grappler/utils.h" 26 27 namespace tensorflow { 28 namespace grappler { 29 30 GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef&& graphDef) { 31 id = other.id; 32 feed = other.feed; 33 fetch = other.fetch; 34 init_ops = other.init_ops; 35 expected_init_time = other.expected_init_time; 36 save_op = other.save_op; 37 restore_op = other.restore_op; 38 save_restore_loc_tensor = other.save_restore_loc_tensor; 39 queue_runners = other.queue_runners; 40 graph.Swap(&graphDef); 41 } 42 43 std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const { 44 return ComputeTransitiveFanin(graph, fetch); 45 } 46 47 std::vector<const NodeDef*> GrapplerItem::EnqueueOpsFanin() const { 48 std::vector<string> enqueue_ops; 49 for (const auto& queue_runner : queue_runners) { 50 for (const string& enqueue_op : queue_runner.enqueue_op_name()) { 51 enqueue_ops.push_back(enqueue_op); 52 } 53 } 54 return ComputeTransitiveFanin(graph, enqueue_ops); 55 } 56 57 std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const { 58 return ComputeTransitiveFanin(graph, init_ops); 59 } 60 61 std::vector<const NodeDef*> GrapplerItem::MainVariables() const { 62 std::vector<const NodeDef*> fanin = ComputeTransitiveFanin(graph, init_ops); 63 std::vector<const NodeDef*> vars; 64 for (const NodeDef* node : fanin) { 65 if (IsVariable(*node)) { 66 vars.push_back(node); 67 } 68 } 69 return vars; 70 } 71 72 std::unordered_set<string> GrapplerItem::NodesToPreserve() const { 73 std::unordered_set<string> result; 74 for (const string& f : fetch) { 75 VLOG(1) << "Add fetch " << f; 76 result.insert(NodeName(f)); 77 } 78 for (const auto& f : feed) { 79 VLOG(1) << "Add feed " << f.first; 80 result.insert(NodeName(f.first)); 81 } 82 for (const auto& node : init_ops) { 83 result.insert(NodeName(node)); 84 } 85 if (!save_op.empty()) { 86 result.insert(NodeName(save_op)); 87 } 88 if (!restore_op.empty()) { 89 result.insert(NodeName(restore_op)); 90 } 91 if (!save_restore_loc_tensor.empty()) { 92 result.insert(NodeName(save_restore_loc_tensor)); 93 } 94 95 for (const auto& queue_runner : queue_runners) { 96 for (const string& enqueue_op : queue_runner.enqueue_op_name()) { 97 result.insert(NodeName(enqueue_op)); 98 } 99 if (!queue_runner.close_op_name().empty()) { 100 result.insert(NodeName(queue_runner.close_op_name())); 101 } 102 if (!queue_runner.cancel_op_name().empty()) { 103 result.insert(NodeName(queue_runner.cancel_op_name())); 104 } 105 } 106 return result; 107 } 108 109 std::vector<const NodeDef*> ComputeTransitiveFanin( 110 const GraphDef& graph, const std::vector<string>& terminal_nodes) { 111 bool ill_formed = false; 112 std::vector<const NodeDef*> result = 113 ComputeTransitiveFanin(graph, terminal_nodes, &ill_formed); 114 CHECK(!ill_formed); 115 return result; 116 } 117 118 std::vector<const NodeDef*> ComputeTransitiveFanin( 119 const GraphDef& graph, const std::vector<string>& terminal_nodes, 120 bool* ill_formed) { 121 *ill_formed = false; 122 std::unordered_map<string, const NodeDef*> name_to_node; 123 std::unordered_map<string, const NodeDef*> name_to_send; 124 for (const auto& node : graph.node()) { 125 name_to_node[node.name()] = &node; 126 if (node.op() == "_Send") { 127 const auto& attr = node.attr(); 128 name_to_send[attr.at("tensor_name").s()] = &node; 129 } 130 } 131 132 std::vector<const NodeDef*> queue; 133 for (const string& root : terminal_nodes) { 134 const NodeDef* node = name_to_node[NodeName(root)]; 135 if (!node) { 136 *ill_formed = true; 137 VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root; 138 return {}; 139 } 140 queue.push_back(node); 141 } 142 143 std::vector<const NodeDef*> result; 144 std::unordered_set<const NodeDef*> visited; 145 146 while (!queue.empty()) { 147 const NodeDef* node = queue.back(); 148 queue.pop_back(); 149 if (!visited.insert(node).second) { 150 // The node has already been visited. 151 continue; 152 } 153 result.push_back(node); 154 for (const string& input : node->input()) { 155 const NodeDef* in = name_to_node[NodeName(input)]; 156 if (!in) { 157 VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input; 158 *ill_formed = true; 159 return {}; 160 } 161 queue.push_back(in); 162 } 163 if (node->op() == "_Recv") { 164 const auto& attr = node->attr(); 165 const NodeDef* send = name_to_send[attr.at("tensor_name").s()]; 166 if (send) { 167 queue.push_back(send); 168 } 169 // Subgraph after partitioning may have either _Send or _Recv, not both. 170 // So, we do not set ill_formed for missing _Send. 171 } 172 } 173 return result; 174 } 175 176 } // end namespace grappler 177 } // end namespace tensorflow 178