1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/python/grappler/cost_analyzer.h" 17 18 #include <iomanip> 19 #include "tensorflow/core/grappler/costs/utils.h" 20 #include "tensorflow/core/grappler/grappler_item.h" 21 #include "tensorflow/core/lib/core/status.h" 22 23 namespace tensorflow { 24 namespace grappler { 25 26 CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster, 27 const string& suffix) 28 : item_(&item), 29 measure_estimator_(cluster, 10, 0), 30 analytical_estimator_(cluster, false), 31 suffix_(suffix) {} 32 33 Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report) { 34 GatherCosts(); 35 PreprocessCosts(); 36 AnalyzeCosts(); 37 PrintAnalysis(os, per_node_report); 38 return Status::OK(); 39 } 40 41 void CostAnalyzer::PredictCosts(CostEstimator* cost_estimator, 42 CostGraphDef* cost_graph, int64* total_time) { 43 TF_CHECK_OK(cost_estimator->Initialize(*item_)); 44 Costs costs; 45 const Status status = 46 cost_estimator->PredictCosts(item_->graph, cost_graph, &costs); 47 *total_time = costs.execution_time.count(); 48 if (!status.ok()) { 49 LOG(ERROR) << "Could not estimate the cost for item " << item_->id << ": " 50 << status.error_message(); 51 return; 52 } 53 } 54 55 void CostAnalyzer::GatherCosts() { 56 CostGraphDef cost_graph_measured; 57 PredictCosts(&measure_estimator_, &cost_graph_measured, 58 &total_time_measured_); 59 VLOG(1) << "Graph size: " << item_->graph.node_size(); 60 VLOG(1) << "cost_graph_measured size: " << cost_graph_measured.node_size(); 61 62 CostGraphDef cost_graph_analytical; 63 PredictCosts(&analytical_estimator_, &cost_graph_analytical, 64 &total_time_analytical_); 65 VLOG(1) << "cost_graph_analytical size: " 66 << cost_graph_analytical.node_size(); 67 68 CostGraphDef cost_graph_analytical_filtered; 69 CostGraphDef cost_graph_measured_filtered; 70 std::map<string, const CostGraphDef_Node*> measured_nodes; 71 for (const auto& node : cost_graph_measured.node()) { 72 measured_nodes[node.name()] = &node; 73 } 74 for (const auto& node : cost_graph_analytical.node()) { 75 auto it = measured_nodes.find(node.name()); 76 // Filter the nodes that are not the cost nodes returned by 77 // MeasuringCostEstimator. 78 if (it == measured_nodes.end()) { 79 continue; 80 } 81 auto added_node_analytical = cost_graph_analytical_filtered.add_node(); 82 auto added_node_measured = cost_graph_measured_filtered.add_node(); 83 *added_node_analytical = node; 84 *added_node_measured = *(it->second); 85 } 86 VLOG(1) << "cost_graph_analytical_filtered size: " 87 << cost_graph_analytical_filtered.node_size(); 88 89 // TODO(yaozhang): add a test to make sure that op_perf_analytical_ and 90 // op_perf_ cover the same set of nodes. 91 op_perf_analytical_ = CostGraphToOpPerformanceData( 92 cost_graph_analytical_filtered, item_->graph); 93 op_perf_ = 94 CostGraphToOpPerformanceData(cost_graph_measured_filtered, item_->graph); 95 } 96 97 void CostAnalyzer::PreprocessCosts() { 98 for (int i = 0; i < op_perf_.op_performance_size(); i++) { 99 OpPerformance* perf = op_perf_.mutable_op_performance(i); 100 const OpPerformance& analytical = op_perf_analytical_.op_performance(i); 101 perf->set_compute_time(analytical.compute_time()); 102 perf->set_memory_time(analytical.memory_time()); 103 double measured_cost = perf->compute_cost(); 104 105 double analytical_compute_cost = analytical.compute_time(); 106 if (analytical_compute_cost == 0) { 107 // Negative infinity indidates unavailable data. 108 perf->set_compute_efficiency(-INFINITY); 109 } else { 110 perf->set_compute_efficiency(analytical_compute_cost / measured_cost); 111 } 112 113 double analytical_memory_cost = analytical.memory_time(); 114 if (analytical_memory_cost == 0) { 115 // Negative infinity indidates unavailable data. 116 perf->set_memory_efficiency(-INFINITY); 117 } else { 118 perf->set_memory_efficiency(analytical_memory_cost / measured_cost); 119 } 120 } 121 } 122 123 124 void CostAnalyzer::SortOpsByTime(std::map<string, OpPerfSummary> ops) { 125 for (const auto& op : ops) { 126 ops_.push_back(op.second); 127 } 128 struct CompareByTime { 129 bool operator()(const OpPerfSummary& a, const OpPerfSummary& b) const { 130 return a.time > b.time; 131 } 132 }; 133 std::stable_sort(ops_.begin(), ops_.end(), CompareByTime()); 134 } 135 136 void CostAnalyzer::AnalyzeCosts() { 137 std::map<string, OpPerfSummary> ops; 138 for (const auto& op_perf : op_perf_.op_performance()) { 139 string op_name = op_perf.op().op(); 140 ops[op_name].count++; 141 ops[op_name].time += op_perf.compute_cost(); 142 ops[op_name].compute_time += op_perf.compute_time(); 143 ops[op_name].memory_time += op_perf.memory_time(); 144 ops[op_name].time_upper += op_perf.compute_time() + op_perf.memory_time(); 145 ops[op_name].time_lower += 146 std::max(op_perf.compute_time(), op_perf.memory_time()); 147 ops[op_name].name = op_name; 148 } 149 SortOpsByTime(ops); 150 151 total_time_measured_serialized_ = 0; 152 total_time_analytical_upper_ = 0; 153 total_time_analytical_lower_ = 0; 154 for (const auto& op : ops_) { 155 total_time_measured_serialized_ += op.time; 156 total_time_analytical_upper_ += op.time_upper; 157 total_time_analytical_lower_ += op.time_lower; 158 } 159 } 160 161 void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report) const { 162 os << std::endl; 163 os << std::left << std::setw(50) 164 << "Total time measured in ns (serialized): " << std::right 165 << std::setw(20) << total_time_measured_serialized_ << std::endl; 166 os << std::left << std::setw(50) 167 << "Total time measured in ns (actual): " << std::right << std::setw(20) 168 << total_time_measured_ << std::endl; 169 os << std::left << std::setw(50) 170 << "Total time analytical in ns (upper bound): " << std::right 171 << std::setw(20) << total_time_analytical_upper_ << std::endl; 172 os << std::left << std::setw(50) 173 << "Total time analytical in ns (lower bound): " << std::right 174 << std::setw(20) << total_time_analytical_lower_ << std::endl; 175 double efficiency_upper = static_cast<double>(total_time_analytical_upper_) / 176 static_cast<double>(total_time_measured_); 177 os << std::left << std::setw(50) 178 << "Overall efficiency (analytical upper/actual): " << std::right 179 << std::setw(20) << efficiency_upper << std::endl; 180 double efficiency_lower = static_cast<double>(total_time_analytical_lower_) / 181 static_cast<double>(total_time_measured_); 182 os << std::left << std::setw(50) 183 << "Overall efficiency (analytical lower/actual): " << std::right 184 << std::setw(20) << efficiency_lower << std::endl; 185 os << std::endl; 186 187 int width = 35; 188 int width_narrow = 15; 189 int width_wide = 20; 190 os << std::setw(width + 1) << "Op,"; 191 os << std::setw(width_narrow + 1) << "Count,"; 192 os << std::setw(width_wide + 1) << "Measured time (ns),"; 193 os << std::setw(width_narrow + 2) << "Time percent,"; 194 os << std::setw(width_narrow + 2) << "Acc percent,"; 195 os << std::setw(width_wide + 1) << "Analytical upper,"; 196 os << std::setw(width_wide + 1) << "Analytical lower,"; 197 os << std::setw(width_narrow + 2) << "Overall eff"; 198 os << std::setw(width_narrow + 2) << "Compute eff"; 199 os << std::setw(width_narrow + 2) << "Memory eff" << std::endl; 200 float acc_percent = 0; 201 for (const auto& op : ops_) { 202 double percent = static_cast<double>(op.time) / 203 static_cast<double>(total_time_measured_serialized_); 204 double eff = 205 static_cast<double>(op.time_upper) / static_cast<double>(op.time); 206 double compute_eff = 207 static_cast<double>(op.compute_time) / static_cast<double>(op.time); 208 double memory_eff = 209 static_cast<double>(op.memory_time) / static_cast<double>(op.time); 210 os << std::setw(width) << op.name << ","; 211 os << std::setw(width_narrow) << op.count << ","; 212 os << std::setw(width_wide) << op.time << ","; 213 os << std::setw(width_narrow) << std::setprecision(2) << percent * 100 214 << "%,"; 215 acc_percent += percent; 216 os << std::setw(width_narrow) << std::setprecision(2) << acc_percent * 100 217 << "%,"; 218 os << std::setw(width_wide) << op.time_upper << ","; 219 os << std::setw(width_wide) << op.time_lower << ","; 220 os << std::setw(width_narrow) << std::setprecision(2) << eff * 100 << "%,"; 221 os << std::setw(width_narrow) << std::setprecision(2) << compute_eff * 100 222 << "%,"; 223 os << std::setw(width_narrow) << std::setprecision(2) << memory_eff * 100 224 << "%,"; 225 os << std::endl; 226 } 227 os << std::endl; 228 229 if (per_node_report) { 230 os << "Below is the per-node report:" << std::endl; 231 os << op_perf_.DebugString(); 232 } 233 } 234 235 } // end namespace grappler 236 } // end namespace tensorflow 237