Home | History | Annotate | Download | only in grappler
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/python/grappler/cost_analyzer.h"
     17 
     18 #include <iomanip>
     19 #include "tensorflow/core/grappler/costs/utils.h"
     20 #include "tensorflow/core/grappler/grappler_item.h"
     21 #include "tensorflow/core/lib/core/status.h"
     22 
     23 namespace tensorflow {
     24 namespace grappler {
     25 
     26 CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster,
     27                            const string& suffix)
     28     : item_(&item),
     29       measure_estimator_(cluster, 10, 0),
     30       analytical_estimator_(cluster, false),
     31       suffix_(suffix) {}
     32 
     33 Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report) {
     34   GatherCosts();
     35   PreprocessCosts();
     36   AnalyzeCosts();
     37   PrintAnalysis(os, per_node_report);
     38   return Status::OK();
     39 }
     40 
     41 void CostAnalyzer::PredictCosts(CostEstimator* cost_estimator,
     42                                 CostGraphDef* cost_graph, int64* total_time) {
     43   TF_CHECK_OK(cost_estimator->Initialize(*item_));
     44   Costs costs;
     45   const Status status =
     46       cost_estimator->PredictCosts(item_->graph, cost_graph, &costs);
     47   *total_time = costs.execution_time.count();
     48   if (!status.ok()) {
     49     LOG(ERROR) << "Could not estimate the cost for item " << item_->id << ": "
     50                << status.error_message();
     51     return;
     52   }
     53 }
     54 
     55 void CostAnalyzer::GatherCosts() {
     56   CostGraphDef cost_graph_measured;
     57   PredictCosts(&measure_estimator_, &cost_graph_measured,
     58                &total_time_measured_);
     59   VLOG(1) << "Graph size: " << item_->graph.node_size();
     60   VLOG(1) << "cost_graph_measured size: " << cost_graph_measured.node_size();
     61 
     62   CostGraphDef cost_graph_analytical;
     63   PredictCosts(&analytical_estimator_, &cost_graph_analytical,
     64                &total_time_analytical_);
     65   VLOG(1) << "cost_graph_analytical size: "
     66           << cost_graph_analytical.node_size();
     67 
     68   CostGraphDef cost_graph_analytical_filtered;
     69   CostGraphDef cost_graph_measured_filtered;
     70   std::map<string, const CostGraphDef_Node*> measured_nodes;
     71   for (const auto& node : cost_graph_measured.node()) {
     72     measured_nodes[node.name()] = &node;
     73   }
     74   for (const auto& node : cost_graph_analytical.node()) {
     75     auto it = measured_nodes.find(node.name());
     76     // Filter the nodes that are not the cost nodes returned by
     77     // MeasuringCostEstimator.
     78     if (it == measured_nodes.end()) {
     79       continue;
     80     }
     81     auto added_node_analytical = cost_graph_analytical_filtered.add_node();
     82     auto added_node_measured = cost_graph_measured_filtered.add_node();
     83     *added_node_analytical = node;
     84     *added_node_measured = *(it->second);
     85   }
     86   VLOG(1) << "cost_graph_analytical_filtered size: "
     87           << cost_graph_analytical_filtered.node_size();
     88 
     89   // TODO(yaozhang): add a test to make sure that op_perf_analytical_ and
     90   // op_perf_ cover the same set of nodes.
     91   op_perf_analytical_ = CostGraphToOpPerformanceData(
     92       cost_graph_analytical_filtered, item_->graph);
     93   op_perf_ =
     94       CostGraphToOpPerformanceData(cost_graph_measured_filtered, item_->graph);
     95 }
     96 
     97 void CostAnalyzer::PreprocessCosts() {
     98   for (int i = 0; i < op_perf_.op_performance_size(); i++) {
     99     OpPerformance* perf = op_perf_.mutable_op_performance(i);
    100     const OpPerformance& analytical = op_perf_analytical_.op_performance(i);
    101     perf->set_compute_time(analytical.compute_time());
    102     perf->set_memory_time(analytical.memory_time());
    103     double measured_cost = perf->compute_cost();
    104 
    105     double analytical_compute_cost = analytical.compute_time();
    106     if (analytical_compute_cost == 0) {
    107       // Negative infinity indidates unavailable data.
    108       perf->set_compute_efficiency(-INFINITY);
    109     } else {
    110       perf->set_compute_efficiency(analytical_compute_cost / measured_cost);
    111     }
    112 
    113     double analytical_memory_cost = analytical.memory_time();
    114     if (analytical_memory_cost == 0) {
    115       // Negative infinity indidates unavailable data.
    116       perf->set_memory_efficiency(-INFINITY);
    117     } else {
    118       perf->set_memory_efficiency(analytical_memory_cost / measured_cost);
    119     }
    120   }
    121 }
    122 
    123 
    124 void CostAnalyzer::SortOpsByTime(std::map<string, OpPerfSummary> ops) {
    125   for (const auto& op : ops) {
    126     ops_.push_back(op.second);
    127   }
    128   struct CompareByTime {
    129     bool operator()(const OpPerfSummary& a, const OpPerfSummary& b) const {
    130       return a.time > b.time;
    131     }
    132   };
    133   std::stable_sort(ops_.begin(), ops_.end(), CompareByTime());
    134 }
    135 
    136 void CostAnalyzer::AnalyzeCosts() {
    137   std::map<string, OpPerfSummary> ops;
    138   for (const auto& op_perf : op_perf_.op_performance()) {
    139     string op_name = op_perf.op().op();
    140     ops[op_name].count++;
    141     ops[op_name].time += op_perf.compute_cost();
    142     ops[op_name].compute_time += op_perf.compute_time();
    143     ops[op_name].memory_time += op_perf.memory_time();
    144     ops[op_name].time_upper += op_perf.compute_time() + op_perf.memory_time();
    145     ops[op_name].time_lower +=
    146         std::max(op_perf.compute_time(), op_perf.memory_time());
    147     ops[op_name].name = op_name;
    148   }
    149   SortOpsByTime(ops);
    150 
    151   total_time_measured_serialized_ = 0;
    152   total_time_analytical_upper_ = 0;
    153   total_time_analytical_lower_ = 0;
    154   for (const auto& op : ops_) {
    155     total_time_measured_serialized_ += op.time;
    156     total_time_analytical_upper_ += op.time_upper;
    157     total_time_analytical_lower_ += op.time_lower;
    158   }
    159 }
    160 
    161 void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report) const {
    162   os << std::endl;
    163   os << std::left << std::setw(50)
    164      << "Total time measured in ns (serialized): " << std::right
    165      << std::setw(20) << total_time_measured_serialized_ << std::endl;
    166   os << std::left << std::setw(50)
    167      << "Total time measured in ns (actual): " << std::right << std::setw(20)
    168      << total_time_measured_ << std::endl;
    169   os << std::left << std::setw(50)
    170      << "Total time analytical in ns (upper bound): " << std::right
    171      << std::setw(20) << total_time_analytical_upper_ << std::endl;
    172   os << std::left << std::setw(50)
    173      << "Total time analytical in ns (lower bound): " << std::right
    174      << std::setw(20) << total_time_analytical_lower_ << std::endl;
    175   double efficiency_upper = static_cast<double>(total_time_analytical_upper_) /
    176                             static_cast<double>(total_time_measured_);
    177   os << std::left << std::setw(50)
    178      << "Overall efficiency (analytical upper/actual): " << std::right
    179      << std::setw(20) << efficiency_upper << std::endl;
    180   double efficiency_lower = static_cast<double>(total_time_analytical_lower_) /
    181                             static_cast<double>(total_time_measured_);
    182   os << std::left << std::setw(50)
    183      << "Overall efficiency (analytical lower/actual): " << std::right
    184      << std::setw(20) << efficiency_lower << std::endl;
    185   os << std::endl;
    186 
    187   int width = 35;
    188   int width_narrow = 15;
    189   int width_wide = 20;
    190   os << std::setw(width + 1) << "Op,";
    191   os << std::setw(width_narrow + 1) << "Count,";
    192   os << std::setw(width_wide + 1) << "Measured time (ns),";
    193   os << std::setw(width_narrow + 2) << "Time percent,";
    194   os << std::setw(width_narrow + 2) << "Acc percent,";
    195   os << std::setw(width_wide + 1) << "Analytical upper,";
    196   os << std::setw(width_wide + 1) << "Analytical lower,";
    197   os << std::setw(width_narrow + 2) << "Overall eff";
    198   os << std::setw(width_narrow + 2) << "Compute eff";
    199   os << std::setw(width_narrow + 2) << "Memory eff" << std::endl;
    200   float acc_percent = 0;
    201   for (const auto& op : ops_) {
    202     double percent = static_cast<double>(op.time) /
    203                      static_cast<double>(total_time_measured_serialized_);
    204     double eff =
    205         static_cast<double>(op.time_upper) / static_cast<double>(op.time);
    206     double compute_eff =
    207         static_cast<double>(op.compute_time) / static_cast<double>(op.time);
    208     double memory_eff =
    209         static_cast<double>(op.memory_time) / static_cast<double>(op.time);
    210     os << std::setw(width) << op.name << ",";
    211     os << std::setw(width_narrow) << op.count << ",";
    212     os << std::setw(width_wide) << op.time << ",";
    213     os << std::setw(width_narrow) << std::setprecision(2) << percent * 100
    214        << "%,";
    215     acc_percent += percent;
    216     os << std::setw(width_narrow) << std::setprecision(2) << acc_percent * 100
    217        << "%,";
    218     os << std::setw(width_wide) << op.time_upper << ",";
    219     os << std::setw(width_wide) << op.time_lower << ",";
    220     os << std::setw(width_narrow) << std::setprecision(2) << eff * 100 << "%,";
    221     os << std::setw(width_narrow) << std::setprecision(2) << compute_eff * 100
    222        << "%,";
    223     os << std::setw(width_narrow) << std::setprecision(2) << memory_eff * 100
    224        << "%,";
    225     os << std::endl;
    226   }
    227   os << std::endl;
    228 
    229   if (per_node_report) {
    230     os << "Below is the per-node report:" << std::endl;
    231     os << op_perf_.DebugString();
    232   }
    233 }
    234 
    235 }  // end namespace grappler
    236 }  // end namespace tensorflow
    237