Home | History | Annotate | Download | only in profiler
      1 /* Copyright 2016 The TensorFlow Authors All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <stdio.h>
     17 #include <stdlib.h>
     18 #include <memory>
     19 #include <set>
     20 #include <string>
     21 #include <utility>
     22 #include <vector>
     23 
     24 #include "linenoise.h"
     25 #include "tensorflow/c/c_api.h"
     26 #include "tensorflow/c/checkpoint_reader.h"
     27 #include "tensorflow/core/framework/graph.pb.h"
     28 #include "tensorflow/core/framework/types.h"
     29 #include "tensorflow/core/lib/core/errors.h"
     30 #include "tensorflow/core/lib/strings/str_util.h"
     31 #include "tensorflow/core/platform/env.h"
     32 #include "tensorflow/core/platform/init_main.h"
     33 #include "tensorflow/core/platform/protobuf.h"
     34 #include "tensorflow/core/profiler/internal/advisor/tfprof_advisor.h"
     35 #include "tensorflow/core/profiler/internal/tfprof_stats.h"
     36 #include "tensorflow/core/profiler/internal/tfprof_utils.h"
     37 #include "tensorflow/core/profiler/tfprof_log.pb.h"
     38 #include "tensorflow/core/profiler/tfprof_options.h"
     39 #include "tensorflow/core/protobuf/config.pb.h"
     40 #include "tensorflow/core/util/command_line_flags.h"
     41 
     42 namespace tensorflow {
     43 namespace tfprof {
     44 void completion(const char* buf, linenoiseCompletions* lc) {
     45   string buf_str = buf;
     46   if (buf_str.find(" ") == buf_str.npos) {
     47     for (const char* opt : kCmds) {
     48       if (string(opt).find(buf_str) == 0) {
     49         linenoiseAddCompletion(lc, opt);
     50       }
     51     }
     52     return;
     53   }
     54 
     55   string prefix;
     56   int last_dash = buf_str.find_last_of(' ');
     57   if (last_dash != string::npos) {
     58     prefix = buf_str.substr(0, last_dash + 1);
     59     buf_str = buf_str.substr(last_dash + 1, kint32max);
     60   }
     61   for (const char* opt : kOptions) {
     62     if (string(opt).find(buf_str) == 0) {
     63       linenoiseAddCompletion(lc, (prefix + opt).c_str());
     64     }
     65   }
     66 }
     67 
     68 int Run(int argc, char** argv) {
     69   string FLAGS_profile_path = "";
     70   string FLAGS_graph_path = "";
     71   string FLAGS_run_meta_path = "";
     72   string FLAGS_op_log_path = "";
     73   string FLAGS_checkpoint_path = "";
     74   int32 FLAGS_max_depth = 10;
     75   int64 FLAGS_min_bytes = 0;
     76   int64 FLAGS_min_peak_bytes = 0;
     77   int64 FLAGS_min_residual_bytes = 0;
     78   int64 FLAGS_min_output_bytes = 0;
     79   int64 FLAGS_min_micros = 0;
     80   int64 FLAGS_min_accelerator_micros = 0;
     81   int64 FLAGS_min_cpu_micros = 0;
     82   int64 FLAGS_min_params = 0;
     83   int64 FLAGS_min_float_ops = 0;
     84   int64 FLAGS_min_occurrence = 0;
     85   int64 FLAGS_step = -1;
     86   string FLAGS_order_by = "name";
     87   string FLAGS_account_type_regexes = ".*";
     88   string FLAGS_start_name_regexes = ".*";
     89   string FLAGS_trim_name_regexes = "";
     90   string FLAGS_show_name_regexes = ".*";
     91   string FLAGS_hide_name_regexes;
     92   bool FLAGS_account_displayed_op_only = false;
     93   string FLAGS_select = "micros";
     94   string FLAGS_output = "";
     95   for (int i = 0; i < argc; i++) {
     96     fprintf(stderr, "%s\n", argv[i]);
     97   }
     98 
     99   std::vector<Flag> flag_list = {
    100       Flag("profile_path", &FLAGS_profile_path, "Profile binary file name."),
    101       Flag("graph_path", &FLAGS_graph_path, "GraphDef proto text file name"),
    102       Flag("run_meta_path", &FLAGS_run_meta_path,
    103            "Comma-separated list of RunMetadata proto binary "
    104            "files. Each file is given step number 0,1,2,etc"),
    105       Flag("op_log_path", &FLAGS_op_log_path,
    106            "tensorflow::tfprof::OpLogProto proto binary file name"),
    107       Flag("checkpoint_path", &FLAGS_checkpoint_path,
    108            "TensorFlow Checkpoint file name"),
    109       Flag("max_depth", &FLAGS_max_depth, "max depth"),
    110       Flag("min_bytes", &FLAGS_min_bytes, "min_bytes"),
    111       Flag("min_peak_bytes", &FLAGS_min_peak_bytes, "min_peak_bytes"),
    112       Flag("min_residual_bytes", &FLAGS_min_residual_bytes,
    113            "min_residual_bytes"),
    114       Flag("min_output_bytes", &FLAGS_min_output_bytes, "min_output_bytes"),
    115       Flag("min_micros", &FLAGS_min_micros, "min micros"),
    116       Flag("min_accelerator_micros", &FLAGS_min_accelerator_micros,
    117            "min acclerator_micros"),
    118       Flag("min_cpu_micros", &FLAGS_min_cpu_micros, "min_cpu_micros"),
    119       Flag("min_params", &FLAGS_min_params, "min params"),
    120       Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"),
    121       Flag("min_occurrence", &FLAGS_min_occurrence, "min occurrence"),
    122       Flag("step", &FLAGS_step,
    123            "The stats of which step to use. By default average"),
    124       Flag("order_by", &FLAGS_order_by, "order by"),
    125       Flag("account_type_regexes", &FLAGS_start_name_regexes,
    126            "start name regexes"),
    127       Flag("trim_name_regexes", &FLAGS_trim_name_regexes, "trim name regexes"),
    128       Flag("show_name_regexes", &FLAGS_show_name_regexes, "show name regexes"),
    129       Flag("hide_name_regexes", &FLAGS_hide_name_regexes, "hide name regexes"),
    130       Flag("account_displayed_op_only", &FLAGS_account_displayed_op_only,
    131            "account displayed op only"),
    132       Flag("select", &FLAGS_select, "select"),
    133       Flag("output", &FLAGS_output, "output"),
    134   };
    135   string usage = Flags::Usage(argv[0], flag_list);
    136   bool parse_ok = Flags::Parse(&argc, argv, flag_list);
    137   if (!parse_ok) {
    138     printf("%s", usage.c_str());
    139     return (2);
    140   }
    141   port::InitMain(argv[0], &argc, &argv);
    142 
    143   if (!FLAGS_profile_path.empty() &&
    144       (!FLAGS_graph_path.empty() || !FLAGS_run_meta_path.empty())) {
    145     fprintf(stderr,
    146             "--profile_path is set, do not set --graph_path or "
    147             "--run_meta_path\n");
    148     return 1;
    149   }
    150 
    151   std::vector<string> account_type_regexes =
    152       str_util::Split(FLAGS_account_type_regexes, ',', str_util::SkipEmpty());
    153   std::vector<string> start_name_regexes =
    154       str_util::Split(FLAGS_start_name_regexes, ',', str_util::SkipEmpty());
    155   std::vector<string> trim_name_regexes =
    156       str_util::Split(FLAGS_trim_name_regexes, ',', str_util::SkipEmpty());
    157   std::vector<string> show_name_regexes =
    158       str_util::Split(FLAGS_show_name_regexes, ',', str_util::SkipEmpty());
    159   std::vector<string> hide_name_regexes =
    160       str_util::Split(FLAGS_hide_name_regexes, ',', str_util::SkipEmpty());
    161   std::vector<string> select =
    162       str_util::Split(FLAGS_select, ',', str_util::SkipEmpty());
    163 
    164   string output_type;
    165   std::map<string, string> output_options;
    166   Status s = ParseOutput(FLAGS_output, &output_type, &output_options);
    167   CHECK(s.ok()) << s.ToString();
    168 
    169   string cmd = "";
    170   if (argc == 1 && FLAGS_graph_path.empty() && FLAGS_profile_path.empty() &&
    171       FLAGS_run_meta_path.empty()) {
    172     PrintHelp();
    173     return 0;
    174   } else if (argc > 1) {
    175     if (string(argv[1]) == kCmds[6]) {
    176       PrintHelp();
    177       return 0;
    178     }
    179     if (string(argv[1]) == kCmds[0] || string(argv[1]) == kCmds[1] ||
    180         string(argv[1]) == kCmds[2] || string(argv[1]) == kCmds[3] ||
    181         string(argv[1]) == kCmds[4]) {
    182       cmd = argv[1];
    183     }
    184   }
    185 
    186   printf("Reading Files...\n");
    187   std::unique_ptr<checkpoint::CheckpointReader> ckpt_reader;
    188   TF_Status* status = TF_NewStatus();
    189   if (!FLAGS_checkpoint_path.empty()) {
    190     ckpt_reader.reset(
    191         new checkpoint::CheckpointReader(FLAGS_checkpoint_path, status));
    192     if (TF_GetCode(status) != TF_OK) {
    193       fprintf(stderr, "%s\n", TF_Message(status));
    194       TF_DeleteStatus(status);
    195       return 1;
    196     }
    197     TF_DeleteStatus(status);
    198   }
    199 
    200   std::unique_ptr<TFStats> tf_stat;
    201   if (!FLAGS_profile_path.empty()) {
    202     tf_stat.reset(new TFStats(FLAGS_profile_path, std::move(ckpt_reader)));
    203   } else {
    204     printf(
    205         "Try to use a single --profile_path instead of "
    206         "graph_path,op_log_path,run_meta_path\n");
    207     std::unique_ptr<GraphDef> graph(new GraphDef());
    208     if (!FLAGS_graph_path.empty()) {
    209       s = ReadProtoFile(Env::Default(), FLAGS_graph_path, graph.get(), false);
    210       if (!s.ok()) {
    211         fprintf(stderr, "Failed to read graph_path: %s\n",
    212                 s.ToString().c_str());
    213         return 1;
    214       }
    215     }
    216 
    217     std::unique_ptr<OpLogProto> op_log(new OpLogProto());
    218     if (!FLAGS_op_log_path.empty()) {
    219       string op_log_str;
    220       s = ReadFileToString(Env::Default(), FLAGS_op_log_path, &op_log_str);
    221       if (!s.ok()) {
    222         fprintf(stderr, "Failed to read op_log_path: %s\n",
    223                 s.ToString().c_str());
    224         return 1;
    225       }
    226       if (!ParseProtoUnlimited(op_log.get(), op_log_str)) {
    227         fprintf(stderr, "Failed to parse op_log_path\n");
    228         return 1;
    229       }
    230     }
    231     tf_stat.reset(new TFStats(std::move(graph), nullptr, std::move(op_log),
    232                               std::move(ckpt_reader)));
    233 
    234     std::vector<string> run_meta_files =
    235         str_util::Split(FLAGS_run_meta_path, ',', str_util::SkipEmpty());
    236     for (int i = 0; i < run_meta_files.size(); ++i) {
    237       std::unique_ptr<RunMetadata> run_meta(new RunMetadata());
    238       s = ReadProtoFile(Env::Default(), run_meta_files[i], run_meta.get(),
    239                         true);
    240       if (!s.ok()) {
    241         fprintf(stderr, "Failed to read run_meta_path %s. Status: %s\n",
    242                 run_meta_files[i].c_str(), s.ToString().c_str());
    243         return 1;
    244       }
    245       tf_stat->AddRunMeta(i, std::move(run_meta));
    246       fprintf(stdout, "run graph coverage: %.2f\n", tf_stat->run_coverage());
    247     }
    248   }
    249 
    250   if (cmd == kCmds[4]) {
    251     tf_stat->BuildAllViews();
    252     Advisor(tf_stat.get()).Advise(Advisor::DefaultOptions());
    253     return 0;
    254   }
    255 
    256   Options opts(
    257       FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_peak_bytes,
    258       FLAGS_min_residual_bytes, FLAGS_min_output_bytes, FLAGS_min_micros,
    259       FLAGS_min_accelerator_micros, FLAGS_min_cpu_micros, FLAGS_min_params,
    260       FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_step, FLAGS_order_by,
    261       account_type_regexes, start_name_regexes, trim_name_regexes,
    262       show_name_regexes, hide_name_regexes, FLAGS_account_displayed_op_only,
    263       select, output_type, output_options);
    264 
    265   if (cmd == kCmds[2] || cmd == kCmds[3]) {
    266     tf_stat->BuildView(cmd);
    267     tf_stat->ShowMultiGraphNode(cmd, opts);
    268     return 0;
    269   } else if (cmd == kCmds[0] || cmd == kCmds[1]) {
    270     tf_stat->BuildView(cmd);
    271     tf_stat->ShowGraphNode(cmd, opts);
    272     return 0;
    273   }
    274 
    275   linenoiseSetCompletionCallback(completion);
    276   linenoiseHistoryLoad(".tfprof_history.txt");
    277 
    278   bool looped = false;
    279   while (true) {
    280     char* line = linenoise("tfprof> ");
    281     if (line == nullptr) {
    282       if (!looped) {
    283         fprintf(stderr,
    284                 "Cannot start interative shell, "
    285                 "use 'bazel-bin' instead of 'bazel run'.\n");
    286       }
    287       break;
    288     }
    289     looped = true;
    290     string line_s = line;
    291     free(line);
    292 
    293     if (line_s.empty()) {
    294       printf("%s", opts.ToString().c_str());
    295       continue;
    296     }
    297     linenoiseHistoryAdd(line_s.c_str());
    298     linenoiseHistorySave(".tfprof_history.txt");
    299 
    300     Options new_opts = opts;
    301     Status s = ParseCmdLine(line_s, &cmd, &new_opts);
    302     if (!s.ok()) {
    303       fprintf(stderr, "E: %s\n", s.ToString().c_str());
    304       continue;
    305     }
    306     if (cmd == kCmds[5]) {
    307       opts = new_opts;
    308     } else if (cmd == kCmds[6]) {
    309       PrintHelp();
    310     } else if (cmd == kCmds[2] || cmd == kCmds[3]) {
    311       tf_stat->BuildView(cmd);
    312       tf_stat->ShowMultiGraphNode(cmd, new_opts);
    313     } else if (cmd == kCmds[0] || cmd == kCmds[1]) {
    314       tf_stat->BuildView(cmd);
    315       tf_stat->ShowGraphNode(cmd, new_opts);
    316     } else if (cmd == kCmds[4]) {
    317       tf_stat->BuildAllViews();
    318       Advisor(tf_stat.get()).Advise(Advisor::DefaultOptions());
    319     }
    320   }
    321   return 0;
    322 }
    323 }  // namespace tfprof
    324 }  // namespace tensorflow
    325 
    326 int main(int argc, char** argv) { return tensorflow::tfprof::Run(argc, argv); }
    327