1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/optimizers/static_schedule.h" 17 #include <deque> 18 #include "tensorflow/core/framework/attr_value.pb.h" 19 #include "tensorflow/core/grappler/costs/graph_properties.h" 20 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" 21 #include "tensorflow/core/grappler/costs/virtual_placer.h" 22 #include "tensorflow/core/grappler/op_types.h" 23 #include "tensorflow/core/grappler/utils.h" 24 #include "tensorflow/core/lib/core/errors.h" 25 #include "tensorflow/core/lib/strings/strcat.h" 26 27 namespace tensorflow { 28 namespace grappler { 29 30 static Costs::NanoSeconds PredictExecutionTime( 31 const GraphProperties& properties, const OpLevelCostEstimator& estimator, 32 const VirtualPlacer& placer, const NodeDef& node) { 33 OpContext op_context; 34 op_context.op_info.set_op(node.op()); 35 *op_context.op_info.mutable_attr() = node.attr(); 36 37 std::vector<OpInfo::TensorProperties> inputs = 38 properties.GetInputProperties(node.name()); 39 for (auto& input : inputs) { 40 op_context.op_info.add_inputs()->Swap(&input); 41 } 42 43 DeviceProperties device = placer.get_device(node); 44 op_context.op_info.mutable_device()->Swap(&device); 45 46 Costs::NanoSeconds estimate = 47 estimator.PredictCosts(op_context).execution_time; 48 49 // Make sure our estimates are at least one nanosecond per node. 50 return std::max(estimate, Costs::NanoSeconds(1)); 51 } 52 53 Status EstimateEarliestExecutionTimes( 54 const GrapplerItem& item, const Cluster* cluster, 55 std::unordered_map<const NodeDef*, Costs::NanoSeconds>* completion_times) { 56 std::unordered_map<string, const NodeDef*> name_map; 57 std::unordered_map<const NodeDef*, int> pending_inputs; 58 std::deque<const NodeDef*> ready_nodes; 59 for (const NodeDef& node : item.graph.node()) { 60 name_map[node.name()] = &node; 61 if (node.input_size() == 0) { 62 ready_nodes.push_back(&node); 63 (*completion_times)[&node] = 0; 64 } else if (IsMerge(node)) { 65 // Merge nodes are processed as soon as one of the input becomes 66 // available. 67 pending_inputs[&node] = 1; 68 } else { 69 pending_inputs[&node] = node.input_size(); 70 } 71 } 72 73 std::unordered_map<const NodeDef*, std::vector<const NodeDef*>> fanouts; 74 for (const NodeDef& node : item.graph.node()) { 75 for (const string& input : node.input()) { 76 string node_name = NodeName(input); 77 auto it = name_map.find(node_name); 78 if (it == name_map.end()) { 79 return errors::InvalidArgument( 80 strings::StrCat("Unknown input node ", input)); 81 } 82 const NodeDef* fanin = it->second; 83 fanouts[fanin].push_back(&node); 84 } 85 } 86 name_map.clear(); 87 88 GraphProperties properties(item); 89 TF_RETURN_IF_ERROR(properties.InferStatically(true)); 90 OpLevelCostEstimator estimator; 91 VirtualPlacer placer(cluster); 92 93 while (!ready_nodes.empty()) { 94 const NodeDef* node = ready_nodes.front(); 95 ready_nodes.pop_front(); 96 97 Costs::NanoSeconds execution_time = 98 PredictExecutionTime(properties, estimator, placer, *node); 99 Costs::NanoSeconds completion_time = 100 execution_time + (*completion_times)[node]; 101 (*completion_times)[node] = completion_time; 102 103 for (const NodeDef* fanout : fanouts[node]) { 104 int pending = pending_inputs[fanout]; 105 if (pending == 0) { 106 // Already processed. Avoid going through loops more than once. 107 continue; 108 } else if (pending == 1) { 109 ready_nodes.push_back(fanout); 110 } 111 pending_inputs[fanout]--; 112 113 Costs::NanoSeconds ready_time = 114 std::max(completion_time, (*completion_times)[fanout]); 115 (*completion_times)[fanout] = ready_time; 116 } 117 } 118 119 return Status::OK(); 120 } 121 122 Status EstimateRequiredTimes( 123 const GrapplerItem& item, const Cluster* cluster, 124 const std::unordered_map<const NodeDef*, Costs::NanoSeconds>& 125 execution_times, 126 std::unordered_map<const NodeDef*, Costs::NanoSeconds>* required_times) { 127 std::unordered_map<string, const NodeDef*> name_map; 128 for (const NodeDef& node : item.graph.node()) { 129 name_map[node.name()] = &node; 130 (*required_times)[&node] = Costs::NanoSeconds::max(); 131 } 132 133 std::unordered_map<const NodeDef*, int> pending_fanouts; 134 for (const NodeDef& node : item.graph.node()) { 135 for (const string& input : node.input()) { 136 string node_name = NodeName(input); 137 auto it = name_map.find(node_name); 138 if (it == name_map.end()) { 139 return errors::InvalidArgument( 140 strings::StrCat("Unknown input node ", input)); 141 } 142 const NodeDef* fanin = it->second; 143 pending_fanouts[fanin] += 1; 144 } 145 } 146 std::deque<const NodeDef*> ready_nodes; 147 for (const NodeDef& node : item.graph.node()) { 148 if (pending_fanouts[&node] == 0) { 149 auto it = execution_times.find(&node); 150 if (it != execution_times.end()) { 151 (*required_times)[&node] = it->second; 152 } 153 ready_nodes.push_back(&node); 154 } 155 } 156 GraphProperties properties(item); 157 TF_RETURN_IF_ERROR(properties.InferStatically(true)); 158 OpLevelCostEstimator estimator; 159 VirtualPlacer placer(cluster); 160 161 while (!ready_nodes.empty()) { 162 const NodeDef* node = ready_nodes.front(); 163 ready_nodes.pop_front(); 164 165 Costs::NanoSeconds execution_time = 166 PredictExecutionTime(properties, estimator, placer, *node); 167 Costs::NanoSeconds required_time = (*required_times)[node] - execution_time; 168 169 for (const string& fanin_name : node->input()) { 170 const NodeDef* fanin = name_map[NodeName(fanin_name)]; 171 (*required_times)[fanin] = 172 std::min((*required_times)[fanin], required_time); 173 174 int pending = pending_fanouts[fanin]; 175 if (pending == 0) { 176 // Already processed. Avoid going through loops more than once. 177 continue; 178 } else if (pending == 1) { 179 ready_nodes.push_back(fanin); 180 } 181 pending_fanouts[fanin]--; 182 } 183 } 184 185 return Status::OK(); 186 } 187 188 } // end namespace grappler 189 } // end namespace tensorflow 190