1 /* Copyright 2017 The TensorFlow Authors All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Usage: capture_tpu_profile --service_addr="localhost:8466" --logdir=/tmp/log 17 // 18 // Initiates a TPU profiling on the TPUProfiler service at service_addr, 19 // receives and dumps the profile data to a tensorboard log directory. 20 21 #include "grpc++/grpc++.h" 22 23 #include <cstdio> 24 #include <ctime> 25 #include <vector> 26 27 #include "tensorflow/contrib/tpu/profiler/dump_tpu_profile.h" 28 #include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h" 29 #include "tensorflow/contrib/tpu/profiler/version.h" 30 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 31 #include "tensorflow/core/lib/core/errors.h" 32 #include "tensorflow/core/lib/strings/str_util.h" 33 #include "tensorflow/core/platform/init_main.h" 34 #include "tensorflow/core/util/command_line_flags.h" 35 36 namespace tensorflow { 37 namespace tpu { 38 namespace { 39 40 using ::tensorflow::TPUProfiler; 41 42 constexpr uint64 kMaxEvents = 1000000; 43 44 string GetCurrentTimeStampAsString() { 45 char s[128]; 46 std::time_t t = std::time(nullptr); 47 CHECK_NE(std::strftime(s, sizeof(s), "%F_%T", std::localtime(&t)), 0); 48 return s; 49 } 50 51 Status ValidateHostPortPair(const string& host_port) { 52 uint32 port; 53 std::vector<string> parts = str_util::Split(host_port, ':'); 54 // Must be host:port, port must be a number, host must not contain a '/', 55 // host also must not be empty. 56 if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) || 57 parts[0].find("/") != string::npos || parts[0].empty()) { 58 return errors::InvalidArgument("Could not interpret \"", host_port, 59 "\" as a host-port pair."); 60 } 61 return Status::OK(); 62 } 63 64 ProfileResponse Profile(const string& service_addr, int duration_ms, 65 const ProfileOptions& opts) { 66 ProfileRequest request; 67 request.set_duration_ms(duration_ms); 68 request.set_max_events(kMaxEvents); 69 request.add_tools("input_pipeline"); 70 request.add_tools("overview_page"); 71 *request.mutable_opts() = opts; 72 std::cout << "Limiting the number of trace events to " << kMaxEvents 73 << std::endl; 74 ::grpc::ClientContext context; 75 ::grpc::ChannelArguments channel_args; 76 // TODO(ioeric): use `SetMaxReceiveMessageSize` instead once it's available. 77 // TODO(qiuminxu): use `NewHostPortGrpcChannel` instead once their 78 // `ValidateHostPortPair` checks for empty host string case. 79 channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, 80 std::numeric_limits<int32>::max()); 81 std::unique_ptr<TPUProfiler::Stub> stub = 82 TPUProfiler::NewStub(::grpc::CreateCustomChannel( 83 "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), 84 channel_args)); 85 ProfileResponse response; 86 TF_QCHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response))); 87 return response; 88 } 89 90 } // namespace 91 } // namespace tpu 92 } // namespace tensorflow 93 94 int main(int argc, char** argv) { 95 tensorflow::string FLAGS_service_addr; 96 tensorflow::string FLAGS_logdir; 97 int FLAGS_duration_ms = 2000; 98 int FLAGS_num_tracing_attempts = 3; 99 bool FLAGS_include_dataset_ops = true; 100 std::vector<tensorflow::Flag> flag_list = { 101 tensorflow::Flag("service_addr", &FLAGS_service_addr, 102 "Address of TPU profiler service e.g. localhost:8466"), 103 tensorflow::Flag("logdir", &FLAGS_logdir, 104 "Path of TensorBoard log directory e.g. /tmp/tb_log, " 105 "gs://tb_bucket"), 106 tensorflow::Flag("duration_ms", &FLAGS_duration_ms, 107 "Duration of tracing in ms. Default is 2000ms."), 108 tensorflow::Flag("num_tracing_attempts", &FLAGS_num_tracing_attempts, 109 "Automatically retry N times when no trace event " 110 "is collected. Default is 3."), 111 tensorflow::Flag("include_dataset_ops", &FLAGS_include_dataset_ops, 112 "Set to false to profile longer TPU device traces."), 113 }; 114 115 std::cout << "Welcome to the Cloud TPU Profiler v" << TPU_PROFILER_VERSION 116 << std::endl; 117 118 tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); 119 bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); 120 if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) { 121 std::cout << usage.c_str() << std::endl; 122 return 2; 123 } 124 tensorflow::Status status = 125 tensorflow::tpu::ValidateHostPortPair(FLAGS_service_addr); 126 if (!status.ok()) { 127 std::cout << status.error_message() << std::endl; 128 std::cout << usage.c_str() << std::endl; 129 return 2; 130 } 131 tensorflow::port::InitMain(argv[0], &argc, &argv); 132 133 // Sets the minimum duration_ms and tracing attempts to one. 134 int duration_ms = std::max(FLAGS_duration_ms, 1); 135 int remaining_attempts = std::max(FLAGS_num_tracing_attempts, 1); 136 tensorflow::ProfileOptions opts; 137 opts.set_include_dataset_ops(FLAGS_include_dataset_ops); 138 tensorflow::ProfileResponse response; 139 140 while (true) { 141 std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. " 142 << "Remaining attempt(s): " << remaining_attempts-- << std::endl; 143 response = tensorflow::tpu::Profile(FLAGS_service_addr, duration_ms, opts); 144 if (remaining_attempts <= 0 || !response.encoded_trace().empty()) break; 145 std::cout << "No trace event is collected. Automatically retrying." 146 << std::endl 147 << std::endl; 148 } 149 150 if (response.encoded_trace().empty()) { 151 std::cout << "No trace event is collected after " 152 << FLAGS_num_tracing_attempts << " attempt(s). " 153 << "Perhaps, you want to try again (with more attempts?)." 154 << std::endl 155 << "Tip: increase number of attempts with --num_tracing_attempts." 156 << std::endl; 157 // Don't dump profile data if no trace is collected. 158 return 0; 159 } 160 161 // Use the current timestamp as the run name. 162 tensorflow::string run = tensorflow::tpu::GetCurrentTimeStampAsString(); 163 TF_CHECK_OK(tensorflow::tpu::WriteTensorboardTPUProfile( 164 FLAGS_logdir, run, response, &std::cout)); 165 // Print this at the end so that it's not buried in irrelevant LOG messages. 166 std::cout 167 << "NOTE: using the trace duration " << duration_ms << "ms." << std::endl 168 << "Set an appropriate duration (with --duration_ms) if you " 169 "don't see a full step in your trace or the captured trace is too " 170 "large." 171 << std::endl; 172 } 173