Home | History | Annotate | Download | only in internal
      1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/profiler/internal/tfprof_code.h"
     17 
     18 #include <stdio.h>
     19 #include <utility>
     20 
     21 #include "tensorflow/c/c_api.h"
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/lib/io/path.h"
     24 #include "tensorflow/core/lib/io/zlib_compression_options.h"
     25 #include "tensorflow/core/lib/io/zlib_outputbuffer.h"
     26 #include "tensorflow/core/lib/strings/str_util.h"
     27 #include "tensorflow/core/lib/strings/strcat.h"
     28 #include "tensorflow/core/lib/strings/stringprintf.h"
     29 #include "tensorflow/core/platform/regexp.h"
     30 #include "tensorflow/core/profiler/internal/tfprof_constants.h"
     31 
     32 namespace tensorflow {
     33 namespace tfprof {
     34 namespace {
     35 
     36 const char* const kGradientSuffix = " (gradient)";
     37 
     38 // Convert to Trace proto into a short readable string.
     39 string GetTraceString(const CallStack::Trace& trace) {
     40   string ntrace = io::Basename(trace.file()).ToString();
     41   ntrace += strings::StrCat(":", trace.lineno());
     42   if (trace.function().length() < 20) {
     43     ntrace += ":" + trace.function();
     44   } else {
     45     ntrace += ":" + trace.function().substr(0, 17) + "...";
     46   }
     47   return ntrace;
     48 }
     49 
     // Infers the forward op name from a gradient op name.
     //
     // Given a forward operation named `op`, its gradient ops are named
     // "...gradients/op_grad/...". If `name` matches that pattern, stores `op`
     // in *forward_name and returns true. Otherwise returns false and leaves
     // *forward_name untouched.
     // TODO(xpan): This is hacky.
     bool IsGradNode(const string& name, string* forward_name) {
       constexpr char kGradPrefix[] = "gradients/";
       constexpr char kGradSuffix[] = "_grad/";

       const size_t grad_prefix = name.find(kGradPrefix);
       if (grad_prefix == string::npos) {
         return false;
       }
       const size_t start = grad_prefix + sizeof(kGradPrefix) - 1;
       // Only look for the suffix after the prefix. An "_grad/" appearing
       // earlier in the name would make (suffix - start) negative, which wraps
       // around for the unsigned size_t and yields a bogus forward name.
       const size_t grad_suffix = name.find(kGradSuffix, start);
       if (grad_suffix == string::npos || grad_suffix == start) {
         // No suffix, or an empty forward name ("gradients/_grad/").
         return false;
       }
       *forward_name = name.substr(start, grad_suffix - start);
       return true;
     }
     67 
     // StringTable maps each distinct string to an id. Ids are assigned in
     // first-seen order and are indices into strings().
     class StringTable {
      public:
       StringTable() {
         // Pprof requires first entry in string_table to be ''.
         string_id_[""] = 0;
         all_strings_.push_back("");
       }

       // Returns the id of `str`. If `str` has not been seen before, appends
       // it to the table and returns the newly assigned id.
       uint64 GetIndex(const string& str) {
         auto idx = string_id_.find(str);
         if (idx != string_id_.end()) {
           return idx->second;
         }
         // The new id is the string's position in all_strings_, which keeps
         // string_id_ and all_strings_ consistent by construction.
         const uint64 id = all_strings_.size();
         all_strings_.push_back(str);
         string_id_.insert(std::make_pair(str, id));
         return id;
       }

       // All interned strings, indexed by id.
       const std::vector<string>& strings() const { return all_strings_; }

      private:
       std::map<string, uint64> string_id_;
       std::vector<string> all_strings_;
     };
     95 
     96 // FunctionTable maps each function to an id.
     97 class FunctionTable {
     98  public:
     99   explicit FunctionTable(StringTable* string_table)
    100       : string_table_(string_table) {}
    101 
    102   // Returns the index of a function. If not found, adds a function proto
    103   // and returns the function index.
    104   uint64 GetIndex(const string& file_path, const string& func_name,
    105                   uint64 func_start_line) {
    106     auto key = std::tuple<string, string, uint64>(file_path, func_name,
    107                                                   func_start_line);
    108     auto idx = function_table_.find(key);
    109     if (idx != function_table_.end()) {
    110       return idx->second.id();
    111     }
    112     pprof::Function* func_pb = &function_table_[key];
    113     // function index should start from 1.
    114     func_pb->set_id(function_table_.size());
    115 
    116     string file_base = io::Basename(file_path).ToString();
    117     file_base = file_base.substr(0, file_base.find_last_of("."));
    118     func_pb->set_name(
    119         string_table_->GetIndex(strings::StrCat(file_base, ":", func_name)));
    120     func_pb->set_filename(string_table_->GetIndex(file_path));
    121     func_pb->set_start_line(func_start_line);
    122     return func_pb->id();
    123   }
    124 
    125   const std::map<std::tuple<string, string, uint64>, pprof::Function>&
    126   functions() const {
    127     return function_table_;
    128   }
    129 
    130  private:
    131   StringTable* string_table_;
    132   std::map<std::tuple<string, string, uint64>, pprof::Function> function_table_;
    133 };
    134 
    135 // LocationTable maps each function call to an id.
    136 class LocationTable {
    137  public:
    138   explicit LocationTable(FunctionTable* function_table)
    139       : function_table_(function_table) {}
    140 
    141   // Returns the index of a function call localtion. If not found, adds a
    142   // location proto and returns the location index.
    143   uint64 GetIndex(const string& file_path, uint64 line_number,
    144                   const string& called_function_name,
    145                   const string& called_file_path,
    146                   uint64 called_func_start_line) {
    147     auto key = std::tuple<string, string, uint64>(
    148         file_path, called_function_name, line_number);
    149 
    150     auto idx = location_table_.find(key);
    151     if (idx != location_table_.end()) {
    152       return idx->second.id();
    153     }
    154     pprof::Location* location_pb = &location_table_[key];
    155     location_pb->set_id(location_table_.size());
    156     pprof::Line* line_pb = location_pb->add_line();
    157     line_pb->set_function_id(function_table_->GetIndex(
    158         called_file_path, called_function_name, called_func_start_line));
    159     line_pb->set_line(line_number);
    160     return location_pb->id();
    161   }
    162 
    163   const std::map<std::tuple<string, string, uint64>, pprof::Location>&
    164   locations() const {
    165     return location_table_;
    166   }
    167 
    168  private:
    169   FunctionTable* function_table_;
    170   std::map<std::tuple<string, string, uint64>, pprof::Location> location_table_;
    171 };
    172 
    173 // Samples stores samples of all calls. A sample is a single call trace,
    174 // that is, the call path from top caller to the leaf callee.
    175 class Samples {
    176  public:
    177   explicit Samples(StringTable* string_table, const Options* opts)
    178       : string_table_(string_table), opts_(opts) {}
    179 
    180   // 'node' is the leaf of the displayed trace. It includes all graph nodes
    181   // created by it. 'location_ids' contains
    182   // the call stack, from callee to caller.
    183   // This method adds the statistics of graph nodes created by the python
    184   // call.
    185   void Add(const CodeNode* node, const std::vector<uint64>& location_ids) {
    186     // displayed leaf might not be true leaf. Retrive the true leaves for
    187     // stats.
    188     std::vector<const CodeNode*> all_leaf = FetchAllLeaf(node);
    189     CHECK(!all_leaf.empty()) << node->name();
    190 
    191     for (const CodeNode* cn : all_leaf) {
    192       for (auto gn_it : cn->node->graph_nodes()) {
    193         const TFGraphNode* gn = gn_it.second;
    194         string name = gn->name();
    195         // Generate a new trace name, in case the name is taken.
    196         while (sample_table_.find(name) != sample_table_.end()) {
    197           name += '@';
    198         }
    199         pprof::Sample* sample_pb = &sample_table_[name];
    200         for (uint64 id : location_ids) {
    201           sample_pb->mutable_location_id()->Add(id);
    202         }
    203         pprof::Label* label_pb = sample_pb->mutable_label()->Add();
    204         label_pb->set_key(string_table_->GetIndex("graph node:"));
    205         label_pb->set_str(string_table_->GetIndex(gn->name()));
    206 
    207         sample_pb->mutable_value()->Add(1);
    208         string type = *opts_->select.begin();
    209         if (type == kShown[1]) {
    210           sample_pb->mutable_value()->Add(gn->exec_micros(node->node->step()));
    211         } else if (type == kShown[9]) {
    212           sample_pb->mutable_value()->Add(
    213               gn->accelerator_exec_micros(node->node->step()));
    214         } else if (type == kShown[10]) {
    215           sample_pb->mutable_value()->Add(
    216               gn->cpu_exec_micros(node->node->step()));
    217         } else if (type == kShown[0]) {
    218           sample_pb->mutable_value()->Add(
    219               gn->requested_bytes(node->node->step()));
    220         } else if (type == kShown[11]) {
    221           sample_pb->mutable_value()->Add(gn->peak_bytes(node->node->step()));
    222         } else if (type == kShown[12]) {
    223           sample_pb->mutable_value()->Add(
    224               gn->residual_bytes(node->node->step()));
    225         } else if (type == kShown[13]) {
    226           sample_pb->mutable_value()->Add(gn->output_bytes(node->node->step()));
    227         } else if (type == kShown[2]) {
    228           sample_pb->mutable_value()->Add(gn->parameters());
    229         } else if (type == kShown[3]) {
    230           sample_pb->mutable_value()->Add(gn->float_ops(node->node->step()));
    231         } else {
    232           fprintf(stderr, "pprof doesn't support -select=%s\n", type.c_str());
    233         }
    234       }
    235     }
    236   }
    237 
    238   const std::map<string, pprof::Sample>& samples() const {
    239     return sample_table_;
    240   }
    241 
    242  private:
    243   std::vector<const CodeNode*> FetchAllLeaf(const CodeNode* root) {
    244     if (root->children.empty()) {
    245       return {root};
    246     }
    247     std::vector<const CodeNode*> ret;
    248     for (auto& n : root->children) {
    249       std::vector<const CodeNode*> nodes = FetchAllLeaf(n);
    250       ret.insert(ret.end(), nodes.begin(), nodes.end());
    251     }
    252     return ret;
    253   }
    254 
    255   StringTable* string_table_;
    256   const Options* opts_;
    257   std::map<string, pprof::Sample> sample_table_;
    258 };
    259 
    260 class PprofProfileImpl : public PprofProfile {
    261  public:
    262   explicit PprofProfileImpl(const Options* opts)
    263       : opts_(opts),
    264         func_table_(new FunctionTable(&string_table_)),
    265         loc_table_(new LocationTable(func_table_.get())),
    266         samples_(new Samples(&string_table_, opts)) {}
    267 
    268   uint64 AddLocation(const CodeNode* callee, const CodeNode* caller) override {
    269     const string& file_path = caller->file();
    270     uint64 lineno = caller->lineno();
    271     const string& callee_file_path = callee->file();
    272     const string& callee_function = callee->function();
    273     uint64 callee_func_start_line = callee->func_start_line();
    274 
    275     return loc_table_->GetIndex(file_path, lineno, callee_function,
    276                                 callee_file_path, callee_func_start_line);
    277   }
    278 
    279   void AddSample(const CodeNode* leaf, std::vector<uint64>* call_ids) override {
    280     std::vector<uint64> reversed_call_ids;
    281     std::reverse_copy(call_ids->begin(), call_ids->end(),
    282                       std::back_inserter(reversed_call_ids));
    283     samples_->Add(leaf, reversed_call_ids);
    284   }
    285 
    286   Status WritePprofProfile(const string& filename) override {
    287     pprof::Profile profile_pb;
    288     Build(&profile_pb);
    289 
    290     std::unique_ptr<WritableFile> file;
    291     Status s = Env::Default()->NewWritableFile(filename, &file);
    292     if (!s.ok()) return s;
    293 
    294     int32 buf_size = 1024 * 1024;
    295     io::ZlibOutputBuffer* zlib_output_buffer = new io::ZlibOutputBuffer(
    296         file.get(), buf_size, buf_size, io::ZlibCompressionOptions::GZIP());
    297     s = zlib_output_buffer->Init();
    298     if (!s.ok()) return s;
    299     s = zlib_output_buffer->Append(profile_pb.SerializeAsString());
    300     if (!s.ok()) return s;
    301     s = zlib_output_buffer->Close();
    302     if (!s.ok()) return s;
    303     fprintf(stdout, "\nRun pprof -png --nodecount=100 --sample_index=1 <%s>\n",
    304             filename.c_str());
    305     return s;
    306   }
    307 
    308  private:
    309   void Build(pprof::Profile* profile_pb) {
    310     string sample_type_description = "count";
    311     auto sample_type = profile_pb->mutable_sample_type()->Add();
    312     sample_type->set_type(string_table_.GetIndex(sample_type_description));
    313     sample_type->set_unit(string_table_.GetIndex("count"));
    314 
    315     string type = *opts_->select.begin();
    316     sample_type_description = type;
    317     sample_type = profile_pb->mutable_sample_type()->Add();
    318     sample_type->set_type(string_table_.GetIndex(sample_type_description));
    319     if (type == kShown[1] || type == kShown[9] || type == kShown[10]) {
    320       sample_type->set_unit(string_table_.GetIndex("microseconds"));
    321       if (type == kShown[1]) {
    322         profile_pb->mutable_comment()->Add(string_table_.GetIndex(
    323             "Sum of accelerator execution time and cpu execution time."));
    324       } else if (type == kShown[9]) {
    325         profile_pb->mutable_comment()->Add(
    326             string_table_.GetIndex("Accelerator execution time."));
    327       } else if (type == kShown[10]) {
    328         profile_pb->mutable_comment()->Add(
    329             string_table_.GetIndex("CPU execution time."));
    330       }
    331     } else if (type == kShown[0]) {
    332       sample_type->set_unit(string_table_.GetIndex("bytes"));
    333       profile_pb->mutable_comment()->Add(
    334           string_table_.GetIndex("Sum of operation total memory requests, "
    335                                  "excluding deallocations."));
    336     } else if (type == kShown[11]) {
    337       sample_type->set_unit(string_table_.GetIndex("bytes"));
    338       profile_pb->mutable_comment()->Add(
    339           string_table_.GetIndex("Sum of operation peak memory usage."));
    340     } else if (type == kShown[12]) {
    341       sample_type->set_unit(string_table_.GetIndex("bytes"));
    342       profile_pb->mutable_comment()->Add(string_table_.GetIndex(
    343           "Sum of operation allocated memory after finish."));
    344     } else if (type == kShown[13]) {
    345       sample_type->set_unit(string_table_.GetIndex("bytes"));
    346       profile_pb->mutable_comment()->Add(
    347           string_table_.GetIndex("Sum of operation output size."));
    348     } else if (type == kShown[2]) {
    349       sample_type->set_unit(string_table_.GetIndex("count"));
    350       profile_pb->mutable_comment()->Add(
    351           string_table_.GetIndex("Model parameters."));
    352     } else if (type == kShown[3]) {
    353       sample_type->set_unit(string_table_.GetIndex("count"));
    354       profile_pb->mutable_comment()->Add(string_table_.GetIndex(
    355           "Model float operations (Only available if defined)."));
    356     } else {
    357       fprintf(stderr, "pprof doesn't support selecting: %s\n", type.c_str());
    358     }
    359 
    360     for (const string& str : string_table_.strings()) {
    361       *profile_pb->mutable_string_table()->Add() = str;
    362     }
    363     for (const auto& sample_it : samples_->samples()) {
    364       // TODO(xpan): Consider swap.
    365       profile_pb->mutable_sample()->Add()->MergeFrom(sample_it.second);
    366     }
    367     for (const auto& function_it : func_table_->functions()) {
    368       profile_pb->mutable_function()->Add()->MergeFrom(function_it.second);
    369     }
    370     for (const auto& location_it : loc_table_->locations()) {
    371       profile_pb->mutable_location()->Add()->MergeFrom(location_it.second);
    372     }
    373   }
    374 
    375   const Options* opts_;
    376   StringTable string_table_;
    377   std::unique_ptr<FunctionTable> func_table_;
    378   std::unique_ptr<LocationTable> loc_table_;
    379   std::unique_ptr<Samples> samples_;
    380 };
    381 }  // namespace
    382 
    383 void TFCode::AddNode(TFGraphNode* node) {
    384   if (!node->call_stack() || node->call_stack()->traces().empty()) {
    385     return;
    386   }
    387   // We infer the forward operation name from gradient op name. So, we can
    388   // map gradient op traces to forward op traces.
    389   // E.g. gradient node of 'inp_1/Conv2D' would be 'gradients/inp_1/Conv2D_grad.
    390   string forward_name;
    391   if (IsGradNode(node->name(), &forward_name)) {
    392     auto grad_nodes_it = grad_nodes_.find(forward_name);
    393     if (grad_nodes_it != grad_nodes_.end()) {
    394       grad_nodes_it->second.push_back(node);
    395     } else {
    396       grad_nodes_.insert(
    397           std::pair<string, std::vector<TFGraphNode*>>(forward_name, {node}));
    398     }
    399     return;
    400   } else {
    401     forward_nodes_[node->name()] = node;
    402   }
    403 
    404   if (!root_) {
    405     graph_root_.reset(new TFMultiGraphNode(kTFProfRoot));
    406     root_.reset(new CodeNode(graph_root_.get(), nullptr, ""));
    407   }
    408 
    409   CodeNode* pre_code_node = root_.get();
    410   // TODO(xpan): Consider to release CodeDef after TFCode is built. It
    411   // takes a lot of memory.
    412   std::set<string> traces;
    413   for (int i = 0; i < node->call_stack()->traces().size(); ++i) {
    414     // Unlike op name, which is globally unique, trace name is only unique
    415     // w.r.t. it's parent.
    416     const string& trace = GetTraceString(node->call_stack()->traces().at(i));
    417     traces.insert(trace);
    418     pre_code_node = pre_code_node->AddChildren(
    419         trace, &node->call_stack()->traces().at(i), "");
    420     if (i == node->call_stack()->traces().size() - 1) {
    421       pre_code_node->node->AddGraphNode(node);
    422     }
    423   }
    424 }
    425 
    426 void TFCode::Build() {
    427   int64 unaccounted_nodes = 0;
    428   for (auto it : grad_nodes_) {
    429     const string& forward_name = it.first;
    430     auto forward_it = forward_nodes_.find(forward_name);
    431     if (forward_it == forward_nodes_.end()) {
    432       unaccounted_nodes += 1;
    433       continue;
    434     }
    435     TFGraphNode* fn = forward_it->second;
    436     CodeNode* leaf = nullptr;
    437     CodeNode* pre_code_node = root_.get();
    438     for (int i = 0; i < fn->call_stack()->traces().size(); ++i) {
    439       const string& trace =
    440           GetTraceString(fn->call_stack()->traces().at(i)) + kGradientSuffix;
    441       pre_code_node = pre_code_node->AddChildren(
    442           trace, &fn->call_stack()->traces().at(i), kGradientSuffix);
    443       if (i == fn->call_stack()->traces().size() - 1) {
    444         leaf = pre_code_node;
    445       }
    446     }
    447     for (TFGraphNode* gn : it.second) {
    448       leaf->node->AddGraphNode(gn);
    449     }
    450   }
    451   if (unaccounted_nodes > 0) {
    452     fprintf(stderr, "%lld gradient nodes not accounted\n", unaccounted_nodes);
    453   }
    454 }
    455 
    456 const ShowMultiNode* TFCode::ShowInternal(const Options& opts,
    457                                           Timeline* timeline) {
    458   root_->ResetTotalStats();
    459   if (opts.output_type == kOutput[3]) {
    460     if (opts.select.size() != 1) {
    461       fprintf(stderr, "Can only select 1 attribute for pprof output.\n");
    462       return root_.get();
    463     }
    464     string select = *opts.select.begin();
    465     if (select != kShown[0] && select != kShown[1] && select != kShown[2] &&
    466         select != kShown[3] && select != kShown[9] && select != kShown[10] &&
    467         select != kShown[11] && select != kShown[12] && select != kShown[13]) {
    468       fprintf(stderr, "pprof doesn't support -select=%s\n", select.c_str());
    469       return root_.get();
    470     }
    471   }
    472   if (opts.account_displayed_op_only) {
    473     fprintf(stderr, "Note: code view ignores account_displayed_op_only\n");
    474   }
    475 
    476   std::vector<CodeNode*> roots = Account(root_->children, opts);
    477   root_->show_children.clear();
    478   for (CodeNode* n : roots) {
    479     root_->AggregateTotalStats(n);
    480   }
    481 
    482   if (opts.start_name_regexes.size() != 1 ||
    483       opts.start_name_regexes[0] != ".*") {
    484     roots = SearchRoot(roots, opts.start_name_regexes);
    485   }
    486 
    487   root_->show_children.assign(roots.begin(), roots.end());
    488 
    489   CodeNode* root = PrintScope({root_.get()}, opts, 1, 0)[0];
    490 
    491   root->formatted_str = FormatLegend(opts) + root->formatted_str;
    492 
    493   if (opts.output_type == kOutput[3]) {
    494     std::vector<uint64> call_ids;
    495     pprof_profile_.reset(new PprofProfileImpl(&opts));
    496     Format(root, root->show_children, opts, &root->formatted_str,
    497            root->mutable_proto(), &call_ids);
    498     Status s = pprof_profile_->WritePprofProfile(
    499         opts.output_options.at(kPprofOpts[0]));
    500     if (!s.ok()) {
    501       fprintf(stderr, "%s\n", s.ToString().c_str());
    502     }
    503   } else {
    504     Format(root, root->show_children, opts, &root->formatted_str,
    505            root->mutable_proto(), nullptr);
    506     if (timeline) {
    507       timeline->GenerateCodeTimeline(root);
    508     }
    509   }
    510   return root;
    511 }
    512 
    513 void TFCode::Format(const CodeNode* root, const std::vector<CodeNode*>& nodes,
    514                     const Options& opts, string* display_str,
    515                     MultiGraphNodeProto* proto, std::vector<uint64>* call_ids) {
    516   if (nodes.empty() && root->has_trace() && opts.output_type == kOutput[3]) {
    517     pprof_profile_->AddSample(root, call_ids);
    518   }
    519 
    520   for (CodeNode* node : nodes) {
    521     if (root->has_trace() && opts.output_type == kOutput[3]) {
    522       uint64 loc_id = pprof_profile_->AddLocation(node, root);
    523       call_ids->push_back(loc_id);
    524     }
    525     display_str->append(node->formatted_str);
    526     MultiGraphNodeProto* child = proto->add_children();
    527     child->MergeFrom(node->proto());
    528     Format(node, node->show_children, opts, display_str, child, call_ids);
    529     if (root->has_trace() && opts.output_type == kOutput[3]) {
    530       call_ids->pop_back();
    531     }
    532   }
    533 }
    534 
    535 std::vector<CodeNode*> TFCode::SearchRoot(std::vector<CodeNode*> roots,
    536                                           const std::vector<string>& regexes) {
    537   std::vector<CodeNode*> res;
    538   if (roots.empty()) {
    539     return res;
    540   }
    541   for (CodeNode* root : roots) {
    542     bool match_start_node = false;
    543     for (const string& regex : regexes) {
    544       if (RE2::FullMatch(root->name(), regex)) {
    545         res.push_back(root);
    546         match_start_node = true;
    547         break;
    548       }
    549     }
    550     if (match_start_node) {
    551       // Found a start node at this branch, no need to continue.
    552       continue;
    553     }
    554     std::vector<CodeNode*> nroots = SearchRoot(root->show_children, regexes);
    555     res.insert(res.end(), nroots.begin(), nroots.end());
    556   }
    557   return res;
    558 }
    559 
    560 std::vector<CodeNode*> TFCode::PrintScope(const std::vector<CodeNode*> roots,
    561                                           const Options& opts, int depth,
    562                                           int last_ident) {
    563   std::vector<CodeNode*> show_nodes;
    564 
    565   for (CodeNode* node : roots) {
    566     if (ShouldTrim(node, opts.trim_name_regexes) || depth > opts.max_depth) {
    567       continue;
    568     }
    569     int ident = last_ident;
    570     bool show = ShouldShow(node, opts, depth);
    571     if (show) ident += 2;
    572 
    573     std::vector<CodeNode*> show_cnodes =
    574         PrintScope(node->show_children, opts, depth + 1, ident);
    575     if (show) {
    576       node->show_children.clear();
    577 
    578       show_cnodes = SortNodes(show_cnodes, opts);
    579       for (CodeNode* sc : show_cnodes) {
    580         node->show_children.push_back(sc);
    581       }
    582 
    583       node->formatted_str = FormatNode(node, opts, last_ident);
    584 
    585       if (opts.select.find(kShown[4]) != opts.select.end()) {
    586         fprintf(stderr, "code view has no tensor value to show\n");
    587       }
    588       show_nodes.push_back(node);
    589     } else {
    590       show_nodes.insert(show_nodes.end(), show_cnodes.begin(),
    591                         show_cnodes.end());
    592     }
    593   }
    594   return show_nodes;
    595 }
    596 
    597 std::vector<CodeNode*> TFCode::Account(const std::vector<CodeNode*>& roots,
    598                                        const Options& opts) {
    599   std::vector<CodeNode*> act_nodes;
    600 
    601   for (CodeNode* node : roots) {
    602     node->ResetTotalStats();
    603     std::vector<CodeNode*> act_cnodes = Account(node->children, opts);
    604     node->account = ReAccount(node, opts);
    605     if (node->account || !act_cnodes.empty()) {
    606       node->show_children.clear();
    607       node->ResetTotalStats();
    608       node->AddSelfToTotalStats();
    609       for (CodeNode* c : act_cnodes) {
    610         node->AggregateTotalStats(c);
    611         node->show_children.push_back(c);
    612       }
    613       act_nodes.push_back(node);
    614     }
    615   }
    616   return act_nodes;
    617 }
    618 
    619 string TFCode::FormatNodeMemory(CodeNode* node, int64 bytes,
    620                                 int64 total_bytes) const {
    621   string memory = FormatMemory(total_bytes);
    622   if (node->account) {
    623     memory = FormatMemory(bytes) + "/" + memory;
    624   } else {
    625     memory = "--/" + memory;
    626   }
    627   return memory;
    628 }
    629 
    630 string TFCode::FormatNode(CodeNode* node, const Options& opts,
    631                           int64 indent) const {
    632   std::vector<string> attrs;
    633   if (opts.select.find(kShown[0]) != opts.select.end()) {
    634     attrs.push_back(FormatNodeMemory(node, node->proto().requested_bytes(),
    635                                      node->proto().total_requested_bytes()));
    636   }
    637   if (opts.select.find(kShown[11]) != opts.select.end()) {
    638     attrs.push_back(FormatNodeMemory(node, node->proto().peak_bytes(),
    639                                      node->proto().total_peak_bytes()));
    640   }
    641   if (opts.select.find(kShown[12]) != opts.select.end()) {
    642     attrs.push_back(FormatNodeMemory(node, node->proto().residual_bytes(),
    643                                      node->proto().total_residual_bytes()));
    644   }
    645   if (opts.select.find(kShown[13]) != opts.select.end()) {
    646     attrs.push_back(FormatNodeMemory(node, node->proto().output_bytes(),
    647                                      node->proto().total_output_bytes()));
    648   }
    649 
    650   std::vector<string> time_attrs = FormatTimes(node, opts);
    651   attrs.insert(attrs.end(), time_attrs.begin(), time_attrs.end());
    652 
    653   if (opts.select.find(kShown[2]) != opts.select.end()) {
    654     string params = FormatNumber(node->proto().total_parameters()) + " params";
    655     if (node->account) {
    656       params = FormatNumber(node->proto().parameters()) + "/" + params;
    657     } else {
    658       params = "--/" + params;
    659     }
    660     attrs.push_back(params);
    661   }
    662 
    663   if (opts.select.find(kShown[3]) != opts.select.end()) {
    664     string fops = FormatNumber(node->proto().total_float_ops()) + " flops";
    665     if (node->account) {
    666       fops = FormatNumber(node->proto().float_ops()) + "/" + fops;
    667     } else {
    668       fops = "--/" + fops;
    669     }
    670     attrs.push_back(fops);
    671   }
    672 
    673   if (opts.select.find(kShown[5]) != opts.select.end() &&
    674       !node->node->devices().empty()) {
    675     attrs.push_back(str_util::Join(node->node->devices(), "|"));
    676   }
    677   if (opts.select.find(kShown[6]) != opts.select.end()) {
    678     std::set<string> op_types = node->node->op_types();
    679     attrs.push_back(str_util::Join(op_types, "|"));
    680   }
    681   if (opts.select.find(kShown[7]) != opts.select.end()) {
    682     // TODO(xpan): Make op count available in code view?
    683     attrs.push_back(strings::Printf("%s N/A in code view", kShown[7]));
    684   }
    685   if (opts.select.find(kShown[8]) != opts.select.end()) {
    686     attrs.push_back(strings::Printf("%s N/A in code view", kShown[8]));
    687   }
    688 
    689   return strings::Printf("%s%s (%s)\n", string(indent, ' ').c_str(),
    690                          node->name().c_str(),
    691                          str_util::Join(attrs, ", ").c_str());
    692 }
    693 }  // namespace tfprof
    694 }  // namespace tensorflow
    695