Home | History | Annotate | Download | only in debug
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/debug/debug_io_utils.h"
     17 
     18 #include <stddef.h>
     19 #include <string.h>
     20 #include <cmath>
     21 #include <limits>
     22 #include <utility>
     23 #include <vector>
     24 
     25 #ifndef PLATFORM_WINDOWS
     26 #include "grpc++/create_channel.h"
     27 #else
     28 // winsock2.h is used in grpc, so Ws2_32.lib is needed
     29 #pragma comment(lib, "Ws2_32.lib")
     30 #endif  // #ifndef PLATFORM_WINDOWS
     31 
     32 #include "tensorflow/core/debug/debug_callback_registry.h"
     33 #include "tensorflow/core/debug/debugger_event_metadata.pb.h"
     34 #include "tensorflow/core/framework/graph.pb.h"
     35 #include "tensorflow/core/framework/summary.pb.h"
     36 #include "tensorflow/core/framework/tensor_shape.pb.h"
     37 #include "tensorflow/core/lib/core/bits.h"
     38 #include "tensorflow/core/lib/hash/hash.h"
     39 #include "tensorflow/core/lib/io/path.h"
     40 #include "tensorflow/core/lib/strings/str_util.h"
     41 #include "tensorflow/core/lib/strings/stringprintf.h"
     42 #include "tensorflow/core/platform/protobuf.h"
     43 #include "tensorflow/core/util/event.pb.h"
     44 
     45 #define GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR \
     46   return errors::Unimplemented(              \
     47       kGrpcURLScheme, " debug URL scheme is not implemented on Windows yet.")
     48 
     49 namespace tensorflow {
     50 
     51 namespace {
     52 
     53 // Creates an Event proto representing a chunk of a Tensor. This method only
     54 // populates the field of the Event proto that represent the envelope
     55 // informaion (e.g., timestmap, device_name, num_chunks, chunk_index, dtype,
     56 // shape). It does not set the value.tensor field, which should be set by the
     57 // caller separately.
     58 Event PrepareChunkEventProto(const DebugNodeKey& debug_node_key,
     59                              const uint64 wall_time_us, const size_t num_chunks,
     60                              const size_t chunk_index,
     61                              const DataType& tensor_dtype,
     62                              const TensorShapeProto& tensor_shape) {
     63   Event event;
     64   event.set_wall_time(static_cast<double>(wall_time_us));
     65   Summary::Value* value = event.mutable_summary()->add_value();
     66 
     67   // Create the debug node_name in the Summary proto.
     68   // For example, if tensor_name = "foo/node_a:0", and the debug_op is
     69   // "DebugIdentity", the debug node_name in the Summary proto will be
     70   // "foo/node_a:0:DebugIdentity".
     71   value->set_node_name(debug_node_key.debug_node_name);
     72 
     73   // Tag by the node name. This allows TensorBoard to quickly fetch data
     74   // per op.
     75   value->set_tag(debug_node_key.node_name);
     76 
     77   // Store data within debugger metadata to be stored for each event.
     78   third_party::tensorflow::core::debug::DebuggerEventMetadata metadata;
     79   metadata.set_device(debug_node_key.device_name);
     80   metadata.set_output_slot(debug_node_key.output_slot);
     81   metadata.set_num_chunks(num_chunks);
     82   metadata.set_chunk_index(chunk_index);
     83 
     84   // Encode the data in JSON.
     85   string json_output;
     86   tensorflow::protobuf::util::JsonPrintOptions json_options;
     87   json_options.always_print_primitive_fields = true;
     88   auto status = tensorflow::protobuf::util::MessageToJsonString(
     89       metadata, &json_output, json_options);
     90   if (status.ok()) {
     91     // Store summary metadata. Set the plugin to use this data as "debugger".
     92     SummaryMetadata::PluginData* plugin_data =
     93         value->mutable_metadata()->mutable_plugin_data();
     94     plugin_data->set_plugin_name(DebugIO::kDebuggerPluginName);
     95     plugin_data->set_content(json_output);
     96   } else {
     97     LOG(WARNING) << "Failed to convert DebuggerEventMetadata proto to JSON. "
     98                  << "The debug_node_name is " << debug_node_key.debug_node_name
     99                  << ".";
    100   }
    101 
    102   value->mutable_tensor()->set_dtype(tensor_dtype);
    103   *value->mutable_tensor()->mutable_tensor_shape() = tensor_shape;
    104 
    105   return event;
    106 }
    107 
    108 // Translates the length of a string to number of bytes when the string is
    109 // encoded as bytes in protobuf. Note that this makes a conservative estimate
    110 // (i.e., an estimate that is usually too large, but never too small under the
    111 // gRPC message size limit) of the Varint-encoded length, to workaround the lack
    112 // of a portable length function.
    113 const size_t StringValMaxBytesInProto(const string& str) {
    114 #if defined(PLATFORM_GOOGLE)
    115   return str.size() + DebugGrpcIO::kGrpcMaxVarintLengthSize;
    116 #else
    117   return str.size();
    118 #endif
    119 }
    120 
    121 // Breaks a string Tensor (represented as a TensorProto) as a vector of Event
    122 // protos.
    123 Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key,
    124                                 const uint64 wall_time_us,
    125                                 const size_t chunk_size_limit,
    126                                 TensorProto* tensor_proto,
    127                                 std::vector<Event>* events) {
    128   const protobuf::RepeatedPtrField<string>& strs = tensor_proto->string_val();
    129   const size_t num_strs = strs.size();
    130   const size_t chunk_size_ub = chunk_size_limit > 0
    131                                    ? chunk_size_limit
    132                                    : std::numeric_limits<size_t>::max();
    133 
    134   // E.g., if cutoffs is {j, k, l}, the chunks will have index ranges:
    135   //   [0:a), [a:b), [c:<end>].
    136   std::vector<size_t> cutoffs;
    137   size_t chunk_size = 0;
    138   for (size_t i = 0; i < num_strs; ++i) {
    139     // Take into account the extra bytes in proto buffer.
    140     if (StringValMaxBytesInProto(strs[i]) > chunk_size_ub) {
    141       return errors::FailedPrecondition(
    142           "string value at index ", i, " from debug node ",
    143           debug_node_key.debug_node_name,
    144           " does not fit gRPC message size limit (", chunk_size_ub, ")");
    145     }
    146     if (chunk_size + StringValMaxBytesInProto(strs[i]) > chunk_size_ub) {
    147       cutoffs.push_back(i);
    148       chunk_size = 0;
    149     }
    150     chunk_size += StringValMaxBytesInProto(strs[i]);
    151   }
    152   cutoffs.push_back(num_strs);
    153   const size_t num_chunks = cutoffs.size();
    154 
    155   for (size_t i = 0; i < num_chunks; ++i) {
    156     Event event = PrepareChunkEventProto(debug_node_key, wall_time_us,
    157                                          num_chunks, i, tensor_proto->dtype(),
    158                                          tensor_proto->tensor_shape());
    159     Summary::Value* value = event.mutable_summary()->mutable_value(0);
    160 
    161     if (cutoffs.size() == 1) {
    162       value->mutable_tensor()->mutable_string_val()->Swap(
    163           tensor_proto->mutable_string_val());
    164     } else {
    165       const size_t begin = (i == 0) ? 0 : cutoffs[i - 1];
    166       const size_t end = cutoffs[i];
    167       for (size_t j = begin; j < end; ++j) {
    168         value->mutable_tensor()->add_string_val(strs[j]);
    169       }
    170     }
    171 
    172     events->push_back(std::move(event));
    173   }
    174 
    175   return Status::OK();
    176 }
    177 
    178 // Encapsulates the tensor value inside a vector of Event protos. Large tensors
    179 // are broken up to multiple protos to fit the chunk_size_limit. In each Event
    180 // proto the field summary.tensor carries the content of the tensor.
    181 // If chunk_size_limit <= 0, the tensor will not be broken into chunks, i.e., a
    182 // length-1 vector will be returned, regardless of the size of the tensor.
    183 Status WrapTensorAsEvents(const DebugNodeKey& debug_node_key,
    184                           const Tensor& tensor, const uint64 wall_time_us,
    185                           const size_t chunk_size_limit,
    186                           std::vector<Event>* events) {
    187   TensorProto tensor_proto;
    188   if (tensor.dtype() == DT_STRING) {
    189     // Treat DT_STRING specially, so that tensor_util.MakeNdarray in Python can
    190     // convert the TensorProto to string-type numpy array. MakeNdarray does not
    191     // work with strings encoded by AsProtoTensorContent() in tensor_content.
    192     tensor.AsProtoField(&tensor_proto);
    193 
    194     TF_RETURN_IF_ERROR(WrapStringTensorAsEvents(
    195         debug_node_key, wall_time_us, chunk_size_limit, &tensor_proto, events));
    196   } else {
    197     tensor.AsProtoTensorContent(&tensor_proto);
    198 
    199     const size_t total_length = tensor_proto.tensor_content().size();
    200     const size_t chunk_size_ub =
    201         chunk_size_limit > 0 ? chunk_size_limit : total_length;
    202     const size_t num_chunks =
    203         (total_length == 0)
    204             ? 1
    205             : (total_length + chunk_size_ub - 1) / chunk_size_ub;
    206     for (size_t i = 0; i < num_chunks; ++i) {
    207       const size_t pos = i * chunk_size_ub;
    208       const size_t len =
    209           (i == num_chunks - 1) ? (total_length - pos) : chunk_size_ub;
    210       Event event = PrepareChunkEventProto(debug_node_key, wall_time_us,
    211                                            num_chunks, i, tensor_proto.dtype(),
    212                                            tensor_proto.tensor_shape());
    213       event.mutable_summary()
    214           ->mutable_value(0)
    215           ->mutable_tensor()
    216           ->set_tensor_content(tensor_proto.tensor_content().substr(pos, len));
    217       events->push_back(std::move(event));
    218     }
    219   }
    220 
    221   return Status::OK();
    222 }
    223 
    224 // Appends an underscore and a timestamp to a file path. If the path already
    225 // exists on the file system, append a hyphen and a 1-up index. Consecutive
    226 // values of the index will be tried until the first unused one is found.
    227 // TOCTOU race condition is not of concern here due to the fact that tfdbg
    228 // sets parallel_iterations attribute of all while_loops to 1 to prevent
    229 // the same node from between executed multiple times concurrently.
    230 string AppendTimestampToFilePath(const string& in, const uint64 timestamp) {
    231   string out = strings::StrCat(in, "_", timestamp);
    232 
    233   uint64 i = 1;
    234   while (Env::Default()->FileExists(out).ok()) {
    235     out = strings::StrCat(in, "_", timestamp, "-", i);
    236     ++i;
    237   }
    238   return out;
    239 }
    240 
    241 #ifndef PLATFORM_WINDOWS
    242 // Publishes encoded GraphDef through a gRPC debugger stream, in chunks,
    243 // conforming to the gRPC message size limit.
    244 Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
    245                                       const string& device_name,
    246                                       const int64 wall_time,
    247                                       const string& debug_url) {
    248   const uint64 hash = ::tensorflow::Hash64(encoded_graph_def);
    249   const size_t total_length = encoded_graph_def.size();
    250   const size_t num_chunks =
    251       static_cast<size_t>(std::ceil(static_cast<float>(total_length) /
    252                                     DebugGrpcIO::kGrpcMessageSizeLimitBytes));
    253   for (size_t i = 0; i < num_chunks; ++i) {
    254     const size_t pos = i * DebugGrpcIO::kGrpcMessageSizeLimitBytes;
    255     const size_t len = (i == num_chunks - 1)
    256                            ? (total_length - pos)
    257                            : DebugGrpcIO::kGrpcMessageSizeLimitBytes;
    258     Event event;
    259     event.set_wall_time(static_cast<double>(wall_time));
    260     // Prefix the chunk with
    261     //   <hash64>,<device_name>,<wall_time>|<index>|<num_chunks>|.
    262     // TODO(cais): Use DebuggerEventMetadata to store device_name, num_chunks
    263     // and chunk_index, instead.
    264     event.set_graph_def(strings::StrCat(hash, ",", device_name, ",", wall_time,
    265                                         "|", i, "|", num_chunks, "|",
    266                                         encoded_graph_def.substr(pos, len)));
    267     const Status s = DebugGrpcIO::SendEventProtoThroughGrpcStream(
    268         event, debug_url, num_chunks - 1 == i);
    269     if (!s.ok()) {
    270       return errors::FailedPrecondition(
    271           "Failed to send chunk ", i, " of ", num_chunks,
    272           " of encoded GraphDef of size ", encoded_graph_def.size(), " bytes, ",
    273           "due to: ", s.error_message());
    274     }
    275   }
    276   return Status::OK();
    277 }
    278 #endif  // #ifndef PLATFORM_WINDOWS
    279 
    280 }  // namespace
    281 
    282 const char* const DebugIO::kDebuggerPluginName = "debugger";
    283 
    284 const char* const DebugIO::kCoreMetadataTag = "core_metadata_";
    285 
    286 const char* const DebugIO::kGraphTag = "graph_";
    287 
    288 const char* const DebugIO::kHashTag = "hash";
    289 
    290 Status ReadEventFromFile(const string& dump_file_path, Event* event) {
    291   Env* env(Env::Default());
    292 
    293   string content;
    294   uint64 file_size = 0;
    295 
    296   Status s = env->GetFileSize(dump_file_path, &file_size);
    297   if (!s.ok()) {
    298     return s;
    299   }
    300 
    301   content.resize(file_size);
    302 
    303   std::unique_ptr<RandomAccessFile> file;
    304   s = env->NewRandomAccessFile(dump_file_path, &file);
    305   if (!s.ok()) {
    306     return s;
    307   }
    308 
    309   StringPiece result;
    310   s = file->Read(0, file_size, &result, &(content)[0]);
    311   if (!s.ok()) {
    312     return s;
    313   }
    314 
    315   event->ParseFromString(content);
    316   return Status::OK();
    317 }
    318 
    319 const char* const DebugIO::kFileURLScheme = "file://";
    320 const char* const DebugIO::kGrpcURLScheme = "grpc://";
    321 const char* const DebugIO::kMemoryURLScheme = "memcbk://";
    322 
    323 // Publishes debug metadata to a set of debug URLs.
    324 Status DebugIO::PublishDebugMetadata(
    325     const int64 global_step, const int64 session_run_index,
    326     const int64 executor_step_index, const std::vector<string>& input_names,
    327     const std::vector<string>& output_names,
    328     const std::vector<string>& target_nodes,
    329     const std::unordered_set<string>& debug_urls) {
    330   std::ostringstream oss;
    331 
    332   // Construct a JSON string to carry the metadata.
    333   oss << "{";
    334   oss << "\"global_step\":" << global_step << ",";
    335   oss << "\"session_run_index\":" << session_run_index << ",";
    336   oss << "\"executor_step_index\":" << executor_step_index << ",";
    337   oss << "\"input_names\":[";
    338   for (size_t i = 0; i < input_names.size(); ++i) {
    339     oss << "\"" << input_names[i] << "\"";
    340     if (i < input_names.size() - 1) {
    341       oss << ",";
    342     }
    343   }
    344   oss << "],";
    345   oss << "\"output_names\":[";
    346   for (size_t i = 0; i < output_names.size(); ++i) {
    347     oss << "\"" << output_names[i] << "\"";
    348     if (i < output_names.size() - 1) {
    349       oss << ",";
    350     }
    351   }
    352   oss << "],";
    353   oss << "\"target_nodes\":[";
    354   for (size_t i = 0; i < target_nodes.size(); ++i) {
    355     oss << "\"" << target_nodes[i] << "\"";
    356     if (i < target_nodes.size() - 1) {
    357       oss << ",";
    358     }
    359   }
    360   oss << "]";
    361   oss << "}";
    362 
    363   const string json_metadata = oss.str();
    364   Event event;
    365   event.set_wall_time(static_cast<double>(Env::Default()->NowMicros()));
    366   LogMessage* log_message = event.mutable_log_message();
    367   log_message->set_message(json_metadata);
    368 
    369   Status status;
    370   for (const string& url : debug_urls) {
    371     if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
    372 #ifndef PLATFORM_WINDOWS
    373       Event grpc_event;
    374 
    375       // Determine the path (if any) in the grpc:// URL, and add it as a field
    376       // of the JSON string.
    377       const string address = url.substr(strlen(DebugIO::kFileURLScheme));
    378       const string path = address.find("/") == string::npos
    379                               ? ""
    380                               : address.substr(address.find("/"));
    381       grpc_event.set_wall_time(event.wall_time());
    382       LogMessage* log_message_grpc = grpc_event.mutable_log_message();
    383       log_message_grpc->set_message(
    384           strings::StrCat(json_metadata.substr(0, json_metadata.size() - 1),
    385                           ",\"grpc_path\":\"", path, "\"}"));
    386 
    387       status.Update(
    388           DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url, true));
    389 #else
    390       GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
    391 #endif
    392     } else if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
    393       const string dump_root_dir = url.substr(strlen(kFileURLScheme));
    394       const string core_metadata_path = AppendTimestampToFilePath(
    395           io::JoinPath(
    396               dump_root_dir,
    397               strings::StrCat(DebugNodeKey::kMetadataFilePrefix,
    398                               DebugIO::kCoreMetadataTag, "sessionrun",
    399                               strings::Printf("%.14lld", session_run_index))),
    400           Env::Default()->NowMicros());
    401       status.Update(DebugFileIO::DumpEventProtoToFile(
    402           event, io::Dirname(core_metadata_path).ToString(),
    403           io::Basename(core_metadata_path).ToString()));
    404     }
    405   }
    406 
    407   return status;
    408 }
    409 
    410 Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
    411                                    const Tensor& tensor,
    412                                    const uint64 wall_time_us,
    413                                    const gtl::ArraySlice<string>& debug_urls,
    414                                    const bool gated_grpc) {
    415   int32 num_failed_urls = 0;
    416   std::vector<Status> fail_statuses;
    417   for (const string& url : debug_urls) {
    418     if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
    419       const string dump_root_dir = url.substr(strlen(kFileURLScheme));
    420 
    421       Status s = DebugFileIO::DumpTensorToDir(
    422           debug_node_key, tensor, wall_time_us, dump_root_dir, nullptr);
    423       if (!s.ok()) {
    424         num_failed_urls++;
    425         fail_statuses.push_back(s);
    426       }
    427     } else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
    428 #ifndef PLATFORM_WINDOWS
    429       Status s = DebugGrpcIO::SendTensorThroughGrpcStream(
    430           debug_node_key, tensor, wall_time_us, url, gated_grpc);
    431 
    432       if (!s.ok()) {
    433         num_failed_urls++;
    434         fail_statuses.push_back(s);
    435       }
    436 #else
    437       GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
    438 #endif
    439     } else if (str_util::Lowercase(url).find(kMemoryURLScheme) == 0) {
    440       const string dump_root_dir = url.substr(strlen(kMemoryURLScheme));
    441       auto* callback_registry = DebugCallbackRegistry::singleton();
    442       auto* callback = callback_registry->GetCallback(dump_root_dir);
    443       CHECK(callback) << "No callback registered for: " << dump_root_dir;
    444       (*callback)(debug_node_key, tensor);
    445     } else {
    446       return Status(error::UNAVAILABLE,
    447                     strings::StrCat("Invalid debug target URL: ", url));
    448     }
    449   }
    450 
    451   if (num_failed_urls == 0) {
    452     return Status::OK();
    453   } else {
    454     string error_message = strings::StrCat(
    455         "Publishing to ", num_failed_urls, " of ", debug_urls.size(),
    456         " debug target URLs failed, due to the following errors:");
    457     for (Status& status : fail_statuses) {
    458       error_message =
    459           strings::StrCat(error_message, " ", status.error_message(), ";");
    460     }
    461 
    462     return Status(error::INTERNAL, error_message);
    463   }
    464 }
    465 
    466 Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
    467                                    const Tensor& tensor,
    468                                    const uint64 wall_time_us,
    469                                    const gtl::ArraySlice<string>& debug_urls) {
    470   return PublishDebugTensor(debug_node_key, tensor, wall_time_us, debug_urls,
    471                             false);
    472 }
    473 
    474 Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
    475                              const std::unordered_set<string>& debug_urls) {
    476   GraphDef graph_def;
    477   graph.ToGraphDef(&graph_def);
    478 
    479   string buf;
    480   graph_def.SerializeToString(&buf);
    481 
    482   const int64 now_micros = Env::Default()->NowMicros();
    483   Event event;
    484   event.set_wall_time(static_cast<double>(now_micros));
    485   event.set_graph_def(buf);
    486 
    487   Status status = Status::OK();
    488   for (const string& debug_url : debug_urls) {
    489     if (debug_url.find(kFileURLScheme) == 0) {
    490       const string dump_root_dir =
    491           io::JoinPath(debug_url.substr(strlen(kFileURLScheme)),
    492                        DebugNodeKey::DeviceNameToDevicePath(device_name));
    493       const uint64 graph_hash = ::tensorflow::Hash64(buf);
    494       const string file_name =
    495           strings::StrCat(DebugNodeKey::kMetadataFilePrefix, DebugIO::kGraphTag,
    496                           DebugIO::kHashTag, graph_hash, "_", now_micros);
    497 
    498       status.Update(
    499           DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name));
    500     } else if (debug_url.find(kGrpcURLScheme) == 0) {
    501 #ifndef PLATFORM_WINDOWS
    502       status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros,
    503                                                    debug_url));
    504 #else
    505       GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
    506 #endif
    507     }
    508   }
    509 
    510   return status;
    511 }
    512 
    513 bool DebugIO::IsCopyNodeGateOpen(
    514     const std::vector<DebugWatchAndURLSpec>& specs) {
    515 #ifndef PLATFORM_WINDOWS
    516   for (const DebugWatchAndURLSpec& spec : specs) {
    517     if (!spec.gated_grpc || spec.url.compare(0, strlen(DebugIO::kGrpcURLScheme),
    518                                              DebugIO::kGrpcURLScheme)) {
    519       return true;
    520     } else {
    521       if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) {
    522         return true;
    523       }
    524     }
    525   }
    526   return false;
    527 #else
    528   return true;
    529 #endif
    530 }
    531 
    532 bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
    533                                   const std::vector<string>& debug_urls) {
    534 #ifndef PLATFORM_WINDOWS
    535   for (const string& debug_url : debug_urls) {
    536     if (debug_url.compare(0, strlen(DebugIO::kGrpcURLScheme),
    537                           DebugIO::kGrpcURLScheme)) {
    538       return true;
    539     } else {
    540       if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) {
    541         return true;
    542       }
    543     }
    544   }
    545   return false;
    546 #else
    547   return true;
    548 #endif
    549 }
    550 
    551 bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
    552                                  const string& debug_url) {
    553 #ifndef PLATFORM_WINDOWS
    554   if (debug_url.find(kGrpcURLScheme) != 0) {
    555     return true;
    556   } else {
    557     return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key);
    558   }
    559 #else
    560   return true;
    561 #endif
    562 }
    563 
    564 Status DebugIO::CloseDebugURL(const string& debug_url) {
    565   if (debug_url.find(DebugIO::kGrpcURLScheme) == 0) {
    566 #ifndef PLATFORM_WINDOWS
    567     return DebugGrpcIO::CloseGrpcStream(debug_url);
    568 #else
    569     GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
    570 #endif
    571   } else {
    572     // No-op for non-gRPC URLs.
    573     return Status::OK();
    574   }
    575 }
    576 
    577 Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key,
    578                                     const Tensor& tensor,
    579                                     const uint64 wall_time_us,
    580                                     const string& dump_root_dir,
    581                                     string* dump_file_path) {
    582   const string file_path =
    583       GetDumpFilePath(dump_root_dir, debug_node_key, wall_time_us);
    584 
    585   if (dump_file_path != nullptr) {
    586     *dump_file_path = file_path;
    587   }
    588 
    589   return DumpTensorToEventFile(debug_node_key, tensor, wall_time_us, file_path);
    590 }
    591 
    592 string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
    593                                     const DebugNodeKey& debug_node_key,
    594                                     const uint64 wall_time_us) {
    595   return AppendTimestampToFilePath(
    596       io::JoinPath(dump_root_dir, debug_node_key.device_path,
    597                    strings::StrCat(debug_node_key.node_name, "_",
    598                                    debug_node_key.output_slot, "_",
    599                                    debug_node_key.debug_op)),
    600       wall_time_us);
    601 }
    602 
    603 Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto,
    604                                          const string& dir_name,
    605                                          const string& file_name) {
    606   Env* env(Env::Default());
    607 
    608   Status s = RecursiveCreateDir(env, dir_name);
    609   if (!s.ok()) {
    610     return Status(error::FAILED_PRECONDITION,
    611                   strings::StrCat("Failed to create directory  ", dir_name,
    612                                   ", due to: ", s.error_message()));
    613   }
    614 
    615   const string file_path = io::JoinPath(dir_name, file_name);
    616 
    617   string event_str;
    618   event_proto.SerializeToString(&event_str);
    619 
    620   std::unique_ptr<WritableFile> f = nullptr;
    621   TF_CHECK_OK(env->NewWritableFile(file_path, &f));
    622   f->Append(event_str).IgnoreError();
    623   TF_CHECK_OK(f->Close());
    624 
    625   return Status::OK();
    626 }
    627 
    628 Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
    629                                           const Tensor& tensor,
    630                                           const uint64 wall_time_us,
    631                                           const string& file_path) {
    632   std::vector<Event> events;
    633   TF_RETURN_IF_ERROR(
    634       WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events));
    635   return DumpEventProtoToFile(events[0], io::Dirname(file_path).ToString(),
    636                               io::Basename(file_path).ToString());
    637 }
    638 
    639 Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
    640   if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
    641     // The path already exists as a directory. Return OK right away.
    642     return Status::OK();
    643   }
    644 
    645   string parent_dir = io::Dirname(dir).ToString();
    646   if (!env->FileExists(parent_dir).ok()) {
    647     // The parent path does not exist yet, create it first.
    648     Status s = RecursiveCreateDir(env, parent_dir);  // Recursive call
    649     if (!s.ok()) {
    650       return Status(
    651           error::FAILED_PRECONDITION,
    652           strings::StrCat("Failed to create directory  ", parent_dir));
    653     }
    654   } else if (env->FileExists(parent_dir).ok() &&
    655              !env->IsDirectory(parent_dir).ok()) {
    656     // The path exists, but it is a file.
    657     return Status(error::FAILED_PRECONDITION,
    658                   strings::StrCat("Failed to create directory  ", parent_dir,
    659                                   " because the path exists as a file "));
    660   }
    661 
    662   env->CreateDir(dir).IgnoreError();
    663   // Guard against potential race in creating directories by doing a check
    664   // after the CreateDir call.
    665   if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) {
    666     return Status::OK();
    667   } else {
    668     return Status(error::ABORTED,
    669                   strings::StrCat("Failed to create directory  ", parent_dir));
    670   }
    671 }
    672 
    673 #ifndef PLATFORM_WINDOWS
    674 DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
    675     : server_stream_addr_(server_stream_addr),
    676       url_(strings::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {}
    677 
    678 Status DebugGrpcChannel::Connect(const int64 timeout_micros) {
    679   ::grpc::ChannelArguments args;
    680   args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
    681   // Avoid problems where default reconnect backoff is too long (e.g., 20 s).
    682   args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000);
    683   channel_ = ::grpc::CreateCustomChannel(
    684       server_stream_addr_, ::grpc::InsecureChannelCredentials(), args);
    685   if (!channel_->WaitForConnected(
    686           gpr_time_add(gpr_now(GPR_CLOCK_REALTIME),
    687                        gpr_time_from_micros(timeout_micros, GPR_TIMESPAN)))) {
    688     return errors::FailedPrecondition(
    689         "Failed to connect to gRPC channel at ", server_stream_addr_,
    690         " within a timeout of ", timeout_micros / 1e6, " s.");
    691   }
    692   stub_ = EventListener::NewStub(channel_);
    693   reader_writer_ = stub_->SendEvents(&ctx_);
    694 
    695   return Status::OK();
    696 }
    697 
    698 bool DebugGrpcChannel::WriteEvent(const Event& event) {
    699   mutex_lock l(mu_);
    700   return reader_writer_->Write(event);
    701 }
    702 
    703 bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) {
    704   mutex_lock l(mu_);
    705   return reader_writer_->Read(event_reply);
    706 }
    707 
    708 void DebugGrpcChannel::ReceiveAndProcessEventReplies(const size_t max_replies) {
    709   EventReply event_reply;
    710   size_t num_replies = 0;
    711   while ((max_replies == 0 || ++num_replies <= max_replies) &&
    712          ReadEventReply(&event_reply)) {
    713     for (const EventReply::DebugOpStateChange& debug_op_state_change :
    714          event_reply.debug_op_state_changes()) {
    715       string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":",
    716                                          debug_op_state_change.output_slot(),
    717                                          ":", debug_op_state_change.debug_op());
    718       DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key,
    719                                             debug_op_state_change.state());
    720     }
    721   }
    722 }
    723 
    724 Status DebugGrpcChannel::ReceiveServerRepliesAndClose() {
    725   reader_writer_->WritesDone();
    726   // Read all EventReply messages (if any) from the server.
    727   ReceiveAndProcessEventReplies(0);
    728 
    729   if (reader_writer_->Finish().ok()) {
    730     return Status::OK();
    731   } else {
    732     return Status(error::FAILED_PRECONDITION,
    733                   "Failed to close debug GRPC stream.");
    734   }
    735 }
    736 
// Guards the channel registry returned by GetStreamChannels(). The code that
// looks up, creates, or closes channels in this file holds it while doing so.
mutex DebugGrpcIO::streams_mu(LINKER_INITIALIZED);

// Timeout passed to DebugGrpcChannel::Connect() when a new channel is created:
// 900 * 1000 * 1000 us = 900 s (15 minutes).
int64 DebugGrpcIO::channel_connection_timeout_micros = 900 * 1000 * 1000;
// TODO(cais): Make this configurable?

// Size limit, in bytes, passed to WrapTensorAsEvents() when converting a tensor
// into Event protos for the gRPC stream (4000 KiB). Presumably this caps the
// size of each Event; confirm against WrapTensorAsEvents().
const size_t DebugGrpcIO::kGrpcMessageSizeLimitBytes = 4000 * 1024;

// NOTE(review): not used in this section. Presumably the number of bytes
// reserved for a varint length prefix when sizing messages. Confirm against
// its users.
const size_t DebugGrpcIO::kGrpcMaxVarintLengthSize = 6;
    745 
    746 std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
    747 DebugGrpcIO::GetStreamChannels() {
    748   static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
    749       stream_channels =
    750           new std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>();
    751   return stream_channels;
    752 }
    753 
    754 Status DebugGrpcIO::SendTensorThroughGrpcStream(
    755     const DebugNodeKey& debug_node_key, const Tensor& tensor,
    756     const uint64 wall_time_us, const string& grpc_stream_url,
    757     const bool gated) {
    758   if (gated &&
    759       !IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
    760     return Status::OK();
    761   } else {
    762     std::vector<Event> events;
    763     TF_RETURN_IF_ERROR(WrapTensorAsEvents(debug_node_key, tensor, wall_time_us,
    764                                           kGrpcMessageSizeLimitBytes, &events));
    765     for (const Event& event : events) {
    766       TF_RETURN_IF_ERROR(
    767           SendEventProtoThroughGrpcStream(event, grpc_stream_url));
    768     }
    769     if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) {
    770       DebugGrpcChannel* debug_grpc_channel = nullptr;
    771       TF_RETURN_IF_ERROR(
    772           GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
    773       debug_grpc_channel->ReceiveAndProcessEventReplies(1);
    774       // TODO(cais): Support new tensor value carried in the EventReply for
    775       // overriding the value of the tensor being published.
    776     }
    777     return Status::OK();
    778   }
    779 }
    780 
    781 Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream(
    782     EventReply* event_reply, const string& grpc_stream_url) {
    783   DebugGrpcChannel* debug_grpc_channel = nullptr;
    784   TF_RETURN_IF_ERROR(
    785       GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
    786   if (debug_grpc_channel->ReadEventReply(event_reply)) {
    787     return Status::OK();
    788   } else {
    789     return errors::Cancelled(strings::StrCat(
    790         "Reading EventReply from stream URL ", grpc_stream_url, " failed."));
    791   }
    792 }
    793 
    794 Status DebugGrpcIO::GetOrCreateDebugGrpcChannel(
    795     const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) {
    796   const string addr_with_path =
    797       grpc_stream_url.find(DebugIO::kGrpcURLScheme) == 0
    798           ? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme))
    799           : grpc_stream_url;
    800   const string server_stream_addr =
    801       addr_with_path.substr(0, addr_with_path.find('/'));
    802   {
    803     mutex_lock l(streams_mu);
    804     std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
    805         stream_channels = GetStreamChannels();
    806     if (stream_channels->find(grpc_stream_url) == stream_channels->end()) {
    807       std::unique_ptr<DebugGrpcChannel> channel(
    808           new DebugGrpcChannel(server_stream_addr));
    809       TF_RETURN_IF_ERROR(channel->Connect(channel_connection_timeout_micros));
    810       stream_channels->insert(
    811           std::make_pair(grpc_stream_url, std::move(channel)));
    812     }
    813     *debug_grpc_channel = (*stream_channels)[grpc_stream_url].get();
    814   }
    815   return Status::OK();
    816 }
    817 
    818 Status DebugGrpcIO::SendEventProtoThroughGrpcStream(
    819     const Event& event_proto, const string& grpc_stream_url,
    820     const bool receive_reply) {
    821   DebugGrpcChannel* debug_grpc_channel;
    822   TF_RETURN_IF_ERROR(
    823       GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel));
    824 
    825   bool write_ok = debug_grpc_channel->WriteEvent(event_proto);
    826   if (!write_ok) {
    827     return errors::Cancelled(strings::StrCat("Write event to stream URL ",
    828                                              grpc_stream_url, " failed."));
    829   }
    830 
    831   if (receive_reply) {
    832     debug_grpc_channel->ReceiveAndProcessEventReplies(1);
    833   }
    834 
    835   return Status::OK();
    836 }
    837 
    838 bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url,
    839                                  const string& watch_key) {
    840   const DebugNodeName2State* enabled_node_to_state =
    841       GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
    842   return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end();
    843 }
    844 
    845 bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url,
    846                                   const string& watch_key) {
    847   const DebugNodeName2State* enabled_node_to_state =
    848       GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
    849   auto it = enabled_node_to_state->find(watch_key);
    850   if (it == enabled_node_to_state->end()) {
    851     return false;
    852   } else {
    853     return it->second == EventReply::DebugOpStateChange::READ_WRITE;
    854   }
    855 }
    856 
    857 Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) {
    858   mutex_lock l(streams_mu);
    859 
    860   std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
    861       stream_channels = GetStreamChannels();
    862   if (stream_channels->find(grpc_stream_url) != stream_channels->end()) {
    863     // Stream of the specified address exists. Close it and remove it from
    864     // record.
    865     Status s =
    866         (*stream_channels)[grpc_stream_url]->ReceiveServerRepliesAndClose();
    867     (*stream_channels).erase(grpc_stream_url);
    868     return s;
    869   } else {
    870     // Stream of the specified address does not exist. No action.
    871     return Status::OK();
    872   }
    873 }
    874 
    875 std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>*
    876 DebugGrpcIO::GetEnabledDebugOpStates() {
    877   static std::unordered_map<string, DebugNodeName2State>*
    878       enabled_debug_op_states =
    879           new std::unordered_map<string, DebugNodeName2State>();
    880   return enabled_debug_op_states;
    881 }
    882 
    883 DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl(
    884     const string& grpc_debug_url) {
    885   static mutex* debug_ops_state_mu = new mutex();
    886   std::unordered_map<string, DebugNodeName2State>* states =
    887       GetEnabledDebugOpStates();
    888 
    889   mutex_lock l(*debug_ops_state_mu);
    890   if (states->find(grpc_debug_url) == states->end()) {
    891     DebugNodeName2State url_enabled_debug_op_states;
    892     (*states)[grpc_debug_url] = url_enabled_debug_op_states;
    893   }
    894   return &(*states)[grpc_debug_url];
    895 }
    896 
    897 void DebugGrpcIO::SetDebugNodeKeyGrpcState(
    898     const string& grpc_debug_url, const string& watch_key,
    899     const EventReply::DebugOpStateChange::State new_state) {
    900   DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url);
    901   if (new_state == EventReply::DebugOpStateChange::DISABLED) {
    902     if (states->find(watch_key) == states->end()) {
    903       LOG(ERROR) << "Attempt to disable a watch key that is not currently "
    904                  << "enabled at " << grpc_debug_url << ": " << watch_key;
    905     } else {
    906       states->erase(watch_key);
    907     }
    908   } else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) {
    909     (*states)[watch_key] = new_state;
    910   }
    911 }
    912 
    913 void DebugGrpcIO::ClearEnabledWatchKeys() {
    914   GetEnabledDebugOpStates()->clear();
    915 }
    916 
    917 #endif  // #ifndef PLATFORM_WINDOWS
    918 
    919 }  // namespace tensorflow
    920