Home | History | Annotate | Download | only in internal
      1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/profiler/internal/tfprof_op.h"
     17 
     18 #include <stdio.h>
     19 #include <utility>
     20 
     21 #include "tensorflow/core/lib/strings/strcat.h"
     22 #include "tensorflow/core/lib/strings/stringprintf.h"
     23 #include "tensorflow/core/platform/regexp.h"
     24 #include "tensorflow/core/profiler/internal/tfprof_constants.h"
     25 #include "tensorflow/core/profiler/internal/tfprof_tensor.h"
     26 
     27 namespace tensorflow {
     28 namespace tfprof {
     29 namespace {
     30 string FormatToalExecTime(const ShowMultiNode* node,
     31                           const ShowMultiNode* root) {
     32   double accu_pct = 0.0;
     33   double pct = 0.0;
     34   if (node->proto().total_exec_micros() > 0) {
     35     accu_pct = 100.0 * node->proto().total_exec_micros() /
     36                root->proto().total_exec_micros();
     37     pct =
     38         100.0 * node->proto().exec_micros() / root->proto().total_exec_micros();
     39   }
     40 
     41   return strings::Printf(
     42       "%30s", strings::Printf("%s (%.2f%%, %.2f%%)",
     43                               FormatTime(node->proto().exec_micros()).c_str(),
     44                               accu_pct, pct)
     45                   .c_str());
     46 }
     47 string FormatCPUExecTime(const ShowMultiNode* node, const ShowMultiNode* root) {
     48   double accu_pct = 0.0;
     49   double pct = 0.0;
     50   if (node->proto().total_cpu_exec_micros() > 0) {
     51     accu_pct = 100.0 * node->proto().total_cpu_exec_micros() /
     52                root->proto().total_cpu_exec_micros();
     53     pct = 100.0 * node->proto().cpu_exec_micros() /
     54           root->proto().total_cpu_exec_micros();
     55   }
     56 
     57   return strings::Printf(
     58       "%30s",
     59       strings::Printf("%s (%.2f%%, %.2f%%)",
     60                       FormatTime(node->proto().cpu_exec_micros()).c_str(),
     61                       accu_pct, pct)
     62           .c_str());
     63 }
     64 string FormatAcceleratorExecTime(const ShowMultiNode* node,
     65                                  const ShowMultiNode* root) {
     66   double accu_pct = 0.0;
     67   double pct = 0.0;
     68   if (node->proto().total_accelerator_exec_micros() > 0) {
     69     accu_pct = 100.0 * node->proto().total_accelerator_exec_micros() /
     70                root->proto().total_accelerator_exec_micros();
     71     pct = 100.0 * node->proto().accelerator_exec_micros() /
     72           root->proto().total_accelerator_exec_micros();
     73   }
     74 
     75   return strings::Printf(
     76       "%30s", strings::Printf(
     77                   "%s (%.2f%%, %.2f%%)",
     78                   FormatTime(node->proto().accelerator_exec_micros()).c_str(),
     79                   accu_pct, pct)
     80                   .c_str());
     81 }
     82 }  // namespace
     83 
     84 void TFOp::AddNode(TFGraphNode* node) {
     85   const string& op = node->op();
     86   if (tfcnodes_map_.find(op) == tfcnodes_map_.end()) {
     87     tfcnodes_map_[op] =
     88         std::unique_ptr<TFMultiGraphNode>(new TFMultiGraphNode(op));
     89   }
     90   TFMultiGraphNode* tfcnode = tfcnodes_map_[op].get();
     91   tfcnode->AddGraphNode(node);
     92 }
     93 
     94 void TFOp::Build() {
     95   for (auto& tn : tfcnodes_map_) {
     96     cnodes_map_[tn.first] =
     97         std::unique_ptr<OpNode>(new OpNode(tn.second.get()));
     98   }
     99 
    100   tfcnodes_map_[kTFProfRoot] =
    101       std::unique_ptr<TFMultiGraphNode>(new TFMultiGraphNode(kTFProfRoot));
    102   root_.reset(new OpNode(tfcnodes_map_[kTFProfRoot].get()));
    103 }
    104 
    105 const ShowMultiNode* TFOp::ShowInternal(const Options& opts,
    106                                         Timeline* timeline) {
    107   root_->ResetTotalStats();
    108   if (opts.output_type == kOutput[3]) {
    109     fprintf(stderr, "Only 'code' view supports pprof output now.\n");
    110     return root_.get();
    111   }
    112   if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) {
    113     root_->formatted_str = FormatNode(root_.get(), root_.get(), opts);
    114   }
    115   if (timeline) {
    116     fprintf(stderr,
    117             "op view doesn't support timeline yet. "
    118             "Consider graph/scope/code view.\n");
    119     return root_.get();
    120   }
    121   if (cnodes_map_.empty()) {
    122     return root_.get();
    123   }
    124 
    125   std::vector<OpNode*> nodes;
    126   for (auto& n : cnodes_map_) {
    127     n.second->account = ReAccount(n.second.get(), opts);
    128     n.second->ResetTotalStats();
    129     n.second->AddSelfToTotalStats();
    130     nodes.push_back(n.second.get());
    131   }
    132   nodes = SortNodes(nodes, opts);
    133   // pre keeps track of previous visited node.
    134   OpNode* pre = nullptr;
    135   std::vector<OpNode*> account_nodes;
    136   for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
    137     if ((*it)->account) {
    138       if (pre) (*it)->AggregateTotalStats(pre);
    139       account_nodes.push_back(*it);
    140       pre = *it;
    141     }
    142   }
    143   std::reverse(std::begin(account_nodes), std::end(account_nodes));
    144   if (pre) {
    145     root_->AggregateTotalStats(pre);
    146   }
    147 
    148   // Perform the display and optionally redo accounting.
    149   int64 depth = 0;
    150   std::vector<OpNode*> show_nodes;
    151   int64 start = SearchRoot(account_nodes, opts.start_name_regexes);
    152   for (int64 i = start; i < account_nodes.size(); ++i, ++depth) {
    153     OpNode* n = account_nodes[i];
    154     if (ShouldTrim(n, opts.trim_name_regexes) || depth > opts.max_depth) {
    155       break;
    156     }
    157     n->show = ShouldShow(n, opts, depth);
    158     if (n->show) show_nodes.push_back(n);
    159   }
    160 
    161   pre = nullptr;
    162   for (auto it = show_nodes.rbegin(); it != show_nodes.rend(); ++it) {
    163     if (opts.account_displayed_op_only) {
    164       (*it)->ResetTotalStats();
    165       (*it)->AddSelfToTotalStats();
    166       if (pre) (*it)->AggregateTotalStats(pre);
    167     }
    168     pre = *it;
    169   }
    170   if (opts.account_displayed_op_only) {
    171     root_->ResetTotalStats();
    172     if (pre) {
    173       root_->AggregateTotalStats(pre);
    174     }
    175   }
    176   if (opts.output_type == kOutput[1] || opts.output_type == kOutput[2]) {
    177     string display_str = FormatLegend(opts);
    178     for (OpNode* node : show_nodes) {
    179       display_str += FormatNode(node, root_.get(), opts);
    180     }
    181     // In op view, we don't show root (total). But it will still in proto.
    182     // TODO(xpan): Is it the right choice?
    183     root_->formatted_str = display_str;
    184   }
    185   // Populate the chidren field.
    186   auto* pre_pb = root_->mutable_proto();
    187   for (auto& show_node : show_nodes) {
    188     pre_pb->clear_children();
    189     pre_pb->add_children()->Swap(show_node->mutable_proto());
    190     pre_pb = pre_pb->mutable_children(0);
    191   }
    192   return root_.get();
    193 }
    194 
    195 int64 TFOp::SearchRoot(const std::vector<OpNode*> nodes,
    196                        const std::vector<string>& regexes) {
    197   if (regexes.empty() || (regexes.size() == 1 && regexes[0] == ".*")) {
    198     return 0;
    199   }
    200   int64 i = 0;
    201   for (; i < nodes.size(); ++i) {
    202     for (const string& regex : regexes) {
    203       if (RE2::FullMatch(nodes[i]->name(), regex)) {
    204         return i;
    205       }
    206     }
    207   }
    208   return i;
    209 }
    210 
    211 string TFOp::FormatMemoryNode(int64 node_total_bytes, int64 root_total_bytes,
    212                               int64 node_bytes) const {
    213   double accu_pct = 0.0;
    214   double pct = 0.0;
    215   if (node_bytes > 0) {
    216     accu_pct = 100.0 * node_total_bytes / root_total_bytes;
    217     pct = 100.0 * node_bytes / root_total_bytes;
    218   }
    219   return strings::Printf(
    220       "%30s", strings::Printf("%s (%.2f%%, %.2f%%)",
    221                               FormatMemory(node_bytes).c_str(), accu_pct, pct)
    222                   .c_str());
    223 }
    224 
    225 string TFOp::FormatNode(OpNode* node, OpNode* root, const Options& opts) const {
    226   std::vector<string> attrs;
    227 
    228   if (opts.select.find(kShown[0]) != opts.select.end()) {
    229     attrs.push_back(FormatMemoryNode(node->proto().total_requested_bytes(),
    230                                      root->proto().total_requested_bytes(),
    231                                      node->proto().requested_bytes()));
    232   }
    233 
    234   if (opts.select.find(kShown[11]) != opts.select.end()) {
    235     attrs.push_back(FormatMemoryNode(node->proto().total_peak_bytes(),
    236                                      root->proto().total_peak_bytes(),
    237                                      node->proto().peak_bytes()));
    238   }
    239 
    240   if (opts.select.find(kShown[12]) != opts.select.end()) {
    241     attrs.push_back(FormatMemoryNode(node->proto().total_residual_bytes(),
    242                                      root->proto().total_residual_bytes(),
    243                                      node->proto().residual_bytes()));
    244   }
    245   if (opts.select.find(kShown[13]) != opts.select.end()) {
    246     attrs.push_back(FormatMemoryNode(node->proto().total_output_bytes(),
    247                                      root->proto().total_output_bytes(),
    248                                      node->proto().output_bytes()));
    249   }
    250 
    251   if (opts.select.find(kShown[1]) != opts.select.end()) {
    252     attrs.push_back(FormatToalExecTime(node, root));
    253     attrs.push_back(FormatAcceleratorExecTime(node, root));
    254     attrs.push_back(FormatCPUExecTime(node, root));
    255   }
    256   if (opts.select.find(kShown[9]) != opts.select.end() &&
    257       opts.select.find(kShown[1]) == opts.select.end()) {
    258     attrs.push_back(FormatAcceleratorExecTime(node, root));
    259   }
    260   if (opts.select.find(kShown[10]) != opts.select.end() &&
    261       opts.select.find(kShown[1]) == opts.select.end()) {
    262     attrs.push_back(FormatCPUExecTime(node, root));
    263   }
    264   if (opts.select.find(kShown[2]) != opts.select.end()) {
    265     double accu_pct = 0.0;
    266     double pct = 0.0;
    267     if (node->proto().total_parameters() > 0) {
    268       accu_pct = 100.0 * node->proto().total_parameters() /
    269                  root->proto().total_parameters();
    270       pct =
    271           100.0 * node->proto().parameters() / root->proto().total_parameters();
    272     }
    273     attrs.push_back(strings::Printf(
    274         "%30s",
    275         strings::Printf("%s params (%.2f%%, %.2f%%)",
    276                         FormatNumber(node->proto().parameters()).c_str(),
    277                         accu_pct, pct)
    278             .c_str()));
    279   }
    280 
    281   if (opts.select.find(kShown[3]) != opts.select.end()) {
    282     double accu_pct = 0.0;
    283     double pct = 0.0;
    284     if (node->proto().total_float_ops() > 0) {
    285       accu_pct = 100.0 * node->proto().total_float_ops() /
    286                  root->proto().total_float_ops();
    287       pct = 100.0 * node->proto().float_ops() / root->proto().total_float_ops();
    288     }
    289 
    290     attrs.push_back(strings::Printf(
    291         "%30s", strings::Printf("%s float_ops (%.2f%%, %.2f%%)",
    292                                 FormatNumber(node->proto().float_ops()).c_str(),
    293                                 accu_pct, pct)
    294                     .c_str()));
    295   }
    296 
    297   if (opts.select.find(kShown[5]) != opts.select.end()) {
    298     attrs.push_back(str_util::Join(node->node->devices(), "|"));
    299   }
    300 
    301   if (opts.select.find(kShown[6]) != opts.select.end()) {
    302     std::set<string> op_types = node->node->op_types();
    303     attrs.push_back(str_util::Join(op_types, "|"));
    304   }
    305 
    306   if (opts.select.find(kShown[7]) != opts.select.end()) {
    307     int64 total_runs = 0;
    308     for (const auto& gnode : node->proto().graph_nodes()) {
    309       total_runs += gnode.run_count();
    310     }
    311     attrs.push_back(strings::Printf(
    312         "%10s",
    313         strings::Printf("%lld|%d", total_runs, node->proto().graph_nodes_size())
    314             .c_str()));
    315   }
    316 
    317   string node_str = strings::Printf("%-25s%s\n", node->name().c_str(),
    318                                     str_util::Join(attrs, ", ").c_str());
    319 
    320   if (opts.select.find(kShown[8]) != opts.select.end()) {
    321     string input_shape_str = FormatInputShapes(node->proto());
    322     if (!input_shape_str.empty()) {
    323       node_str = strings::Printf("%s\n%s\n\n", node_str.c_str(),
    324                                  input_shape_str.c_str());
    325     }
    326   }
    327   return node_str;
    328 }
    329 }  // namespace tfprof
    330 }  // namespace tensorflow
    331