Home | History | Annotate | Download | only in profiler
      1 /* Copyright 2017 The TensorFlow Authors All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // Usage: capture_tpu_profile --service_addr="localhost:8466" --logdir=/tmp/log
     17 //
// Initiates TPU profiling on the TPUProfiler service at service_addr, then
// receives the profile data and dumps it to a TensorBoard log directory.
     20 
#include "grpc++/grpc++.h"

#include <algorithm>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h"
#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h"
#include "tensorflow/contrib/tpu/profiler/version.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"
     34 #include "tensorflow/core/util/command_line_flags.h"
     35 
     36 namespace tensorflow {
     37 namespace tpu {
     38 namespace {
     39 
     40 using ::tensorflow::TPUProfiler;
     41 
     42 constexpr uint64 kMaxEvents = 1000000;
     43 
     44 string GetCurrentTimeStampAsString() {
     45   char s[128];
     46   std::time_t t = std::time(nullptr);
     47   CHECK_NE(std::strftime(s, sizeof(s), "%F_%T", std::localtime(&t)), 0);
     48   return s;
     49 }
     50 
     51 Status ValidateHostPortPair(const string& host_port) {
     52   uint32 port;
     53   std::vector<string> parts = str_util::Split(host_port, ':');
     54   // Must be host:port, port must be a number, host must not contain a '/',
     55   // host also must not be empty.
     56   if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) ||
     57       parts[0].find("/") != string::npos || parts[0].empty()) {
     58     return errors::InvalidArgument("Could not interpret \"", host_port,
     59                                    "\" as a host-port pair.");
     60   }
     61   return Status::OK();
     62 }
     63 
     64 ProfileResponse Profile(const string& service_addr, int duration_ms,
     65                         const ProfileOptions& opts) {
     66   ProfileRequest request;
     67   request.set_duration_ms(duration_ms);
     68   request.set_max_events(kMaxEvents);
     69   request.add_tools("input_pipeline");
     70   request.add_tools("overview_page");
     71   *request.mutable_opts() = opts;
     72   std::cout << "Limiting the number of trace events to " << kMaxEvents
     73             << std::endl;
     74   ::grpc::ClientContext context;
     75   ::grpc::ChannelArguments channel_args;
     76   // TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available.
     77   // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their
     78   // `ValidateHostPortPair` checks for empty host string case.
     79   channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
     80                       std::numeric_limits<int32>::max());
     81   std::unique_ptr<TPUProfiler::Stub> stub =
     82       TPUProfiler::NewStub(::grpc::CreateCustomChannel(
     83           "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
     84           channel_args));
     85   ProfileResponse response;
     86   TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response)));
     87   return response;
     88 }
     89 
     90 }  // namespace
     91 }  // namespace tpu
     92 }  // namespace tensorflow
     93 
     94 int main(int argc, char** argv) {
     95   tensorflow::string FLAGS_service_addr;
     96   tensorflow::string FLAGS_logdir;
     97   int FLAGS_duration_ms = 2000;
     98   int FLAGS_num_tracing_attempts = 3;
     99   bool FLAGS_include_dataset_ops = true;
    100   std::vector<tensorflow::Flag> flag_list = {
    101       tensorflow::Flag("service_addr", &FLAGS_service_addr,
    102                        "Address of TPU profiler service e.g. localhost:8466"),
    103       tensorflow::Flag("logdir", &FLAGS_logdir,
    104                        "Path of TensorBoard log directory e.g. /tmp/tb_log, "
    105                        "gs://tb_bucket"),
    106       tensorflow::Flag("duration_ms", &FLAGS_duration_ms,
    107                        "Duration of tracing in ms. Default is 2000ms."),
    108       tensorflow::Flag("num_tracing_attempts", &FLAGS_num_tracing_attempts,
    109                        "Automatically retry N times when no trace event "
    110                        "is collected. Default is 3."),
    111       tensorflow::Flag("include_dataset_ops", &FLAGS_include_dataset_ops,
    112                        "Set to false to profile longer TPU device traces."),
    113   };
    114 
    115   std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION
    116             << std::endl;
    117 
    118   tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
    119   bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
    120   if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) {
    121     std::cout << usage.c_str() << std::endl;
    122     return 2;
    123   }
    124   tensorflow::Status status =
    125       tensorflow::tpu::ValidateHostPortPair(FLAGS_service_addr);
    126   if (!status.ok()) {
    127     std::cout << status.error_message() << std::endl;
    128     std::cout << usage.c_str() << std::endl;
    129     return 2;
    130   }
    131   tensorflow::port::InitMain(argv[0], &argc, &argv);
    132 
    133   // Sets the minimum duration_ms and tracing attempts to one.
    134   int duration_ms = std::max(FLAGS_duration_ms, 1);
    135   int remaining_attempts = std::max(FLAGS_num_tracing_attempts, 1);
    136   tensorflow::ProfileOptions opts;
    137   opts.set_include_dataset_ops(FLAGS_include_dataset_ops);
    138   tensorflow::ProfileResponse response;
    139 
    140   while (true) {
    141     std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. "
    142               << "Remaining attempt(s): " << remaining_attempts-- << std::endl;
    143     response = tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms, opts);
    144     if (remaining_attempts <= 0 || !response.encoded_trace().empty()) break;
    145     std::cout << "No trace event is collected. Automatically retrying."
    146               << std::endl
    147               << std::endl;
    148   }
    149 
    150   if (response.encoded_trace().empty()) {
    151     std::cout << "No trace event is collected after "
    152               << FLAGS_num_tracing_attempts << " attempt(s). "
    153               << "Perhaps, you want to try again (with more attempts?)."
    154               << std::endl
    155               << "Tip: increase number of attempts with --num_tracing_attempts."
    156               << std::endl;
    157     // Don't dump profile data if no trace is collected.
    158     return 0;
    159   }
    160 
    161   // Use the current timestamp as the run name.
    162   tensorflow::string run = tensorflow::tpu::GetCurrentTimeStampAsString();
    163   TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile(
    164       FLAGS_logdir, run, response, &std::cout));
    165   // Print this at the end so that it's not buried in irrelevant LOG messages.
    166   std::cout
    167       << "NOTE: using the trace duration " << duration_ms << "ms." << std::endl
    168       << "Set an appropriate duration (with --duration_ms) if you "
    169          "don't see a full step in your trace or the captured trace is too "
    170          "large."
    171       << std::endl;
    172 }
    173