1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/debug/debug_io_utils.h" 17 18 #include <stddef.h> 19 #include <string.h> 20 #include <cmath> 21 #include <limits> 22 #include <utility> 23 #include <vector> 24 25 #ifndef PLATFORM_WINDOWS 26 #include "grpc++/create_channel.h" 27 #else 28 // winsock2.h is used in grpc, so Ws2_32.lib is needed 29 #pragma comment(lib, "Ws2_32.lib") 30 #endif // #ifndef PLATFORM_WINDOWS 31 32 #include "tensorflow/core/debug/debug_callback_registry.h" 33 #include "tensorflow/core/debug/debugger_event_metadata.pb.h" 34 #include "tensorflow/core/framework/graph.pb.h" 35 #include "tensorflow/core/framework/summary.pb.h" 36 #include "tensorflow/core/framework/tensor_shape.pb.h" 37 #include "tensorflow/core/lib/core/bits.h" 38 #include "tensorflow/core/lib/hash/hash.h" 39 #include "tensorflow/core/lib/io/path.h" 40 #include "tensorflow/core/lib/strings/str_util.h" 41 #include "tensorflow/core/lib/strings/stringprintf.h" 42 #include "tensorflow/core/platform/protobuf.h" 43 #include "tensorflow/core/util/event.pb.h" 44 45 #define GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR \ 46 return errors::Unimplemented( \ 47 kGrpcURLScheme, " debug URL scheme is not implemented on Windows yet.") 48 49 namespace tensorflow { 50 51 namespace { 52 53 // Creates an Event proto representing a chunk of a Tensor. This method only 54 // populates the field of the Event proto that represent the envelope 55 // informaion (e.g., timestmap, device_name, num_chunks, chunk_index, dtype, 56 // shape). It does not set the value.tensor field, which should be set by the 57 // caller separately. 58 Event PrepareChunkEventProto(const DebugNodeKey& debug_node_key, 59 const uint64 wall_time_us, const size_t num_chunks, 60 const size_t chunk_index, 61 const DataType& tensor_dtype, 62 const TensorShapeProto& tensor_shape) { 63 Event event; 64 event.set_wall_time(static_cast<double>(wall_time_us)); 65 Summary::Value* value = event.mutable_summary()->add_value(); 66 67 // Create the debug node_name in the Summary proto. 68 // For example, if tensor_name = "foo/node_a:0", and the debug_op is 69 // "DebugIdentity", the debug node_name in the Summary proto will be 70 // "foo/node_a:0:DebugIdentity". 71 value->set_node_name(debug_node_key.debug_node_name); 72 73 // Tag by the node name. This allows TensorBoard to quickly fetch data 74 // per op. 75 value->set_tag(debug_node_key.node_name); 76 77 // Store data within debugger metadata to be stored for each event. 78 third_party::tensorflow::core::debug::DebuggerEventMetadata metadata; 79 metadata.set_device(debug_node_key.device_name); 80 metadata.set_output_slot(debug_node_key.output_slot); 81 metadata.set_num_chunks(num_chunks); 82 metadata.set_chunk_index(chunk_index); 83 84 // Encode the data in JSON. 85 string json_output; 86 tensorflow::protobuf::util::JsonPrintOptions json_options; 87 json_options.always_print_primitive_fields = true; 88 auto status = tensorflow::protobuf::util::MessageToJsonString( 89 metadata, &json_output, json_options); 90 if (status.ok()) { 91 // Store summary metadata. Set the plugin to use this data as "debugger". 92 SummaryMetadata::PluginData* plugin_data = 93 value->mutable_metadata()->mutable_plugin_data(); 94 plugin_data->set_plugin_name(DebugIO::kDebuggerPluginName); 95 plugin_data->set_content(json_output); 96 } else { 97 LOG(WARNING) << "Failed to convert DebuggerEventMetadata proto to JSON. " 98 << "The debug_node_name is " << debug_node_key.debug_node_name 99 << "."; 100 } 101 102 value->mutable_tensor()->set_dtype(tensor_dtype); 103 *value->mutable_tensor()->mutable_tensor_shape() = tensor_shape; 104 105 return event; 106 } 107 108 // Translates the length of a string to number of bytes when the string is 109 // encoded as bytes in protobuf. Note that this makes a conservative estimate 110 // (i.e., an estimate that is usually too large, but never too small under the 111 // gRPC message size limit) of the Varint-encoded length, to workaround the lack 112 // of a portable length function. 113 const size_t StringValMaxBytesInProto(const string& str) { 114 #if defined(PLATFORM_GOOGLE) 115 return str.size() + DebugGrpcIO::kGrpcMaxVarintLengthSize; 116 #else 117 return str.size(); 118 #endif 119 } 120 121 // Breaks a string Tensor (represented as a TensorProto) as a vector of Event 122 // protos. 123 Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key, 124 const uint64 wall_time_us, 125 const size_t chunk_size_limit, 126 TensorProto* tensor_proto, 127 std::vector<Event>* events) { 128 const protobuf::RepeatedPtrField<string>& strs = tensor_proto->string_val(); 129 const size_t num_strs = strs.size(); 130 const size_t chunk_size_ub = chunk_size_limit > 0 131 ? chunk_size_limit 132 : std::numeric_limits<size_t>::max(); 133 134 // E.g., if cutoffs is {j, k, l}, the chunks will have index ranges: 135 // [0:a), [a:b), [c:<end>]. 136 std::vector<size_t> cutoffs; 137 size_t chunk_size = 0; 138 for (size_t i = 0; i < num_strs; ++i) { 139 // Take into account the extra bytes in proto buffer. 140 if (StringValMaxBytesInProto(strs[i]) > chunk_size_ub) { 141 return errors::FailedPrecondition( 142 "string value at index ", i, " from debug node ", 143 debug_node_key.debug_node_name, 144 " does not fit gRPC message size limit (", chunk_size_ub, ")"); 145 } 146 if (chunk_size + StringValMaxBytesInProto(strs[i]) > chunk_size_ub) { 147 cutoffs.push_back(i); 148 chunk_size = 0; 149 } 150 chunk_size += StringValMaxBytesInProto(strs[i]); 151 } 152 cutoffs.push_back(num_strs); 153 const size_t num_chunks = cutoffs.size(); 154 155 for (size_t i = 0; i < num_chunks; ++i) { 156 Event event = PrepareChunkEventProto(debug_node_key, wall_time_us, 157 num_chunks, i, tensor_proto->dtype(), 158 tensor_proto->tensor_shape()); 159 Summary::Value* value = event.mutable_summary()->mutable_value(0); 160 161 if (cutoffs.size() == 1) { 162 value->mutable_tensor()->mutable_string_val()->Swap( 163 tensor_proto->mutable_string_val()); 164 } else { 165 const size_t begin = (i == 0) ? 0 : cutoffs[i - 1]; 166 const size_t end = cutoffs[i]; 167 for (size_t j = begin; j < end; ++j) { 168 value->mutable_tensor()->add_string_val(strs[j]); 169 } 170 } 171 172 events->push_back(std::move(event)); 173 } 174 175 return Status::OK(); 176 } 177 178 // Encapsulates the tensor value inside a vector of Event protos. Large tensors 179 // are broken up to multiple protos to fit the chunk_size_limit. In each Event 180 // proto the field summary.tensor carries the content of the tensor. 181 // If chunk_size_limit <= 0, the tensor will not be broken into chunks, i.e., a 182 // length-1 vector will be returned, regardless of the size of the tensor. 183 Status WrapTensorAsEvents(const DebugNodeKey& debug_node_key, 184 const Tensor& tensor, const uint64 wall_time_us, 185 const size_t chunk_size_limit, 186 std::vector<Event>* events) { 187 TensorProto tensor_proto; 188 if (tensor.dtype() == DT_STRING) { 189 // Treat DT_STRING specially, so that tensor_util.MakeNdarray in Python can 190 // convert the TensorProto to string-type numpy array. MakeNdarray does not 191 // work with strings encoded by AsProtoTensorContent() in tensor_content. 192 tensor.AsProtoField(&tensor_proto); 193 194 TF_RETURN_IF_ERROR(WrapStringTensorAsEvents( 195 debug_node_key, wall_time_us, chunk_size_limit, &tensor_proto, events)); 196 } else { 197 tensor.AsProtoTensorContent(&tensor_proto); 198 199 const size_t total_length = tensor_proto.tensor_content().size(); 200 const size_t chunk_size_ub = 201 chunk_size_limit > 0 ? chunk_size_limit : total_length; 202 const size_t num_chunks = 203 (total_length == 0) 204 ? 1 205 : (total_length + chunk_size_ub - 1) / chunk_size_ub; 206 for (size_t i = 0; i < num_chunks; ++i) { 207 const size_t pos = i * chunk_size_ub; 208 const size_t len = 209 (i == num_chunks - 1) ? (total_length - pos) : chunk_size_ub; 210 Event event = PrepareChunkEventProto(debug_node_key, wall_time_us, 211 num_chunks, i, tensor_proto.dtype(), 212 tensor_proto.tensor_shape()); 213 event.mutable_summary() 214 ->mutable_value(0) 215 ->mutable_tensor() 216 ->set_tensor_content(tensor_proto.tensor_content().substr(pos, len)); 217 events->push_back(std::move(event)); 218 } 219 } 220 221 return Status::OK(); 222 } 223 224 // Appends an underscore and a timestamp to a file path. If the path already 225 // exists on the file system, append a hyphen and a 1-up index. Consecutive 226 // values of the index will be tried until the first unused one is found. 227 // TOCTOU race condition is not of concern here due to the fact that tfdbg 228 // sets parallel_iterations attribute of all while_loops to 1 to prevent 229 // the same node from between executed multiple times concurrently. 230 string AppendTimestampToFilePath(const string& in, const uint64 timestamp) { 231 string out = strings::StrCat(in, "_", timestamp); 232 233 uint64 i = 1; 234 while (Env::Default()->FileExists(out).ok()) { 235 out = strings::StrCat(in, "_", timestamp, "-", i); 236 ++i; 237 } 238 return out; 239 } 240 241 #ifndef PLATFORM_WINDOWS 242 // Publishes encoded GraphDef through a gRPC debugger stream, in chunks, 243 // conforming to the gRPC message size limit. 244 Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def, 245 const string& device_name, 246 const int64 wall_time, 247 const string& debug_url) { 248 const uint64 hash = ::tensorflow::Hash64(encoded_graph_def); 249 const size_t total_length = encoded_graph_def.size(); 250 const size_t num_chunks = 251 static_cast<size_t>(std::ceil(static_cast<float>(total_length) / 252 DebugGrpcIO::kGrpcMessageSizeLimitBytes)); 253 for (size_t i = 0; i < num_chunks; ++i) { 254 const size_t pos = i * DebugGrpcIO::kGrpcMessageSizeLimitBytes; 255 const size_t len = (i == num_chunks - 1) 256 ? (total_length - pos) 257 : DebugGrpcIO::kGrpcMessageSizeLimitBytes; 258 Event event; 259 event.set_wall_time(static_cast<double>(wall_time)); 260 // Prefix the chunk with 261 // <hash64>,<device_name>,<wall_time>|<index>|<num_chunks>|. 262 // TODO(cais): Use DebuggerEventMetadata to store device_name, num_chunks 263 // and chunk_index, instead. 264 event.set_graph_def(strings::StrCat(hash, ",", device_name, ",", wall_time, 265 "|", i, "|", num_chunks, "|", 266 encoded_graph_def.substr(pos, len))); 267 const Status s = DebugGrpcIO::SendEventProtoThroughGrpcStream( 268 event, debug_url, num_chunks - 1 == i); 269 if (!s.ok()) { 270 return errors::FailedPrecondition( 271 "Failed to send chunk ", i, " of ", num_chunks, 272 " of encoded GraphDef of size ", encoded_graph_def.size(), " bytes, ", 273 "due to: ", s.error_message()); 274 } 275 } 276 return Status::OK(); 277 } 278 #endif // #ifndef PLATFORM_WINDOWS 279 280 } // namespace 281 282 const char* const DebugIO::kDebuggerPluginName = "debugger"; 283 284 const char* const DebugIO::kCoreMetadataTag = "core_metadata_"; 285 286 const char* const DebugIO::kGraphTag = "graph_"; 287 288 const char* const DebugIO::kHashTag = "hash"; 289 290 Status ReadEventFromFile(const string& dump_file_path, Event* event) { 291 Env* env(Env::Default()); 292 293 string content; 294 uint64 file_size = 0; 295 296 Status s = env->GetFileSize(dump_file_path, &file_size); 297 if (!s.ok()) { 298 return s; 299 } 300 301 content.resize(file_size); 302 303 std::unique_ptr<RandomAccessFile> file; 304 s = env->NewRandomAccessFile(dump_file_path, &file); 305 if (!s.ok()) { 306 return s; 307 } 308 309 StringPiece result; 310 s = file->Read(0, file_size, &result, &(content)[0]); 311 if (!s.ok()) { 312 return s; 313 } 314 315 event->ParseFromString(content); 316 return Status::OK(); 317 } 318 319 const char* const DebugIO::kFileURLScheme = "file://"; 320 const char* const DebugIO::kGrpcURLScheme = "grpc://"; 321 const char* const DebugIO::kMemoryURLScheme = "memcbk://"; 322 323 // Publishes debug metadata to a set of debug URLs. 324 Status DebugIO::PublishDebugMetadata( 325 const int64 global_step, const int64 session_run_index, 326 const int64 executor_step_index, const std::vector<string>& input_names, 327 const std::vector<string>& output_names, 328 const std::vector<string>& target_nodes, 329 const std::unordered_set<string>& debug_urls) { 330 std::ostringstream oss; 331 332 // Construct a JSON string to carry the metadata. 333 oss << "{"; 334 oss << "\"global_step\":" << global_step << ","; 335 oss << "\"session_run_index\":" << session_run_index << ","; 336 oss << "\"executor_step_index\":" << executor_step_index << ","; 337 oss << "\"input_names\":["; 338 for (size_t i = 0; i < input_names.size(); ++i) { 339 oss << "\"" << input_names[i] << "\""; 340 if (i < input_names.size() - 1) { 341 oss << ","; 342 } 343 } 344 oss << "],"; 345 oss << "\"output_names\":["; 346 for (size_t i = 0; i < output_names.size(); ++i) { 347 oss << "\"" << output_names[i] << "\""; 348 if (i < output_names.size() - 1) { 349 oss << ","; 350 } 351 } 352 oss << "],"; 353 oss << "\"target_nodes\":["; 354 for (size_t i = 0; i < target_nodes.size(); ++i) { 355 oss << "\"" << target_nodes[i] << "\""; 356 if (i < target_nodes.size() - 1) { 357 oss << ","; 358 } 359 } 360 oss << "]"; 361 oss << "}"; 362 363 const string json_metadata = oss.str(); 364 Event event; 365 event.set_wall_time(static_cast<double>(Env::Default()->NowMicros())); 366 LogMessage* log_message = event.mutable_log_message(); 367 log_message->set_message(json_metadata); 368 369 Status status; 370 for (const string& url : debug_urls) { 371 if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) { 372 #ifndef PLATFORM_WINDOWS 373 Event grpc_event; 374 375 // Determine the path (if any) in the grpc:// URL, and add it as a field 376 // of the JSON string. 377 const string address = url.substr(strlen(DebugIO::kFileURLScheme)); 378 const string path = address.find("/") == string::npos 379 ? "" 380 : address.substr(address.find("/")); 381 grpc_event.set_wall_time(event.wall_time()); 382 LogMessage* log_message_grpc = grpc_event.mutable_log_message(); 383 log_message_grpc->set_message( 384 strings::StrCat(json_metadata.substr(0, json_metadata.size() - 1), 385 ",\"grpc_path\":\"", path, "\"}")); 386 387 status.Update( 388 DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url, true)); 389 #else 390 GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; 391 #endif 392 } else if (str_util::Lowercase(url).find(kFileURLScheme) == 0) { 393 const string dump_root_dir = url.substr(strlen(kFileURLScheme)); 394 const string core_metadata_path = AppendTimestampToFilePath( 395 io::JoinPath( 396 dump_root_dir, 397 strings::StrCat(DebugNodeKey::kMetadataFilePrefix, 398 DebugIO::kCoreMetadataTag, "sessionrun", 399 strings::Printf("%.14lld", session_run_index))), 400 Env::Default()->NowMicros()); 401 status.Update(DebugFileIO::DumpEventProtoToFile( 402 event, io::Dirname(core_metadata_path).ToString(), 403 io::Basename(core_metadata_path).ToString())); 404 } 405 } 406 407 return status; 408 } 409 410 Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key, 411 const Tensor& tensor, 412 const uint64 wall_time_us, 413 const gtl::ArraySlice<string>& debug_urls, 414 const bool gated_grpc) { 415 int32 num_failed_urls = 0; 416 std::vector<Status> fail_statuses; 417 for (const string& url : debug_urls) { 418 if (str_util::Lowercase(url).find(kFileURLScheme) == 0) { 419 const string dump_root_dir = url.substr(strlen(kFileURLScheme)); 420 421 Status s = DebugFileIO::DumpTensorToDir( 422 debug_node_key, tensor, wall_time_us, dump_root_dir, nullptr); 423 if (!s.ok()) { 424 num_failed_urls++; 425 fail_statuses.push_back(s); 426 } 427 } else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) { 428 #ifndef PLATFORM_WINDOWS 429 Status s = DebugGrpcIO::SendTensorThroughGrpcStream( 430 debug_node_key, tensor, wall_time_us, url, gated_grpc); 431 432 if (!s.ok()) { 433 num_failed_urls++; 434 fail_statuses.push_back(s); 435 } 436 #else 437 GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; 438 #endif 439 } else if (str_util::Lowercase(url).find(kMemoryURLScheme) == 0) { 440 const string dump_root_dir = url.substr(strlen(kMemoryURLScheme)); 441 auto* callback_registry = DebugCallbackRegistry::singleton(); 442 auto* callback = callback_registry->GetCallback(dump_root_dir); 443 CHECK(callback) << "No callback registered for: " << dump_root_dir; 444 (*callback)(debug_node_key, tensor); 445 } else { 446 return Status(error::UNAVAILABLE, 447 strings::StrCat("Invalid debug target URL: ", url)); 448 } 449 } 450 451 if (num_failed_urls == 0) { 452 return Status::OK(); 453 } else { 454 string error_message = strings::StrCat( 455 "Publishing to ", num_failed_urls, " of ", debug_urls.size(), 456 " debug target URLs failed, due to the following errors:"); 457 for (Status& status : fail_statuses) { 458 error_message = 459 strings::StrCat(error_message, " ", status.error_message(), ";"); 460 } 461 462 return Status(error::INTERNAL, error_message); 463 } 464 } 465 466 Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key, 467 const Tensor& tensor, 468 const uint64 wall_time_us, 469 const gtl::ArraySlice<string>& debug_urls) { 470 return PublishDebugTensor(debug_node_key, tensor, wall_time_us, debug_urls, 471 false); 472 } 473 474 Status DebugIO::PublishGraph(const Graph& graph, const string& device_name, 475 const std::unordered_set<string>& debug_urls) { 476 GraphDef graph_def; 477 graph.ToGraphDef(&graph_def); 478 479 string buf; 480 graph_def.SerializeToString(&buf); 481 482 const int64 now_micros = Env::Default()->NowMicros(); 483 Event event; 484 event.set_wall_time(static_cast<double>(now_micros)); 485 event.set_graph_def(buf); 486 487 Status status = Status::OK(); 488 for (const string& debug_url : debug_urls) { 489 if (debug_url.find(kFileURLScheme) == 0) { 490 const string dump_root_dir = 491 io::JoinPath(debug_url.substr(strlen(kFileURLScheme)), 492 DebugNodeKey::DeviceNameToDevicePath(device_name)); 493 const uint64 graph_hash = ::tensorflow::Hash64(buf); 494 const string file_name = 495 strings::StrCat(DebugNodeKey::kMetadataFilePrefix, DebugIO::kGraphTag, 496 DebugIO::kHashTag, graph_hash, "_", now_micros); 497 498 status.Update( 499 DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name)); 500 } else if (debug_url.find(kGrpcURLScheme) == 0) { 501 #ifndef PLATFORM_WINDOWS 502 status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros, 503 debug_url)); 504 #else 505 GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; 506 #endif 507 } 508 } 509 510 return status; 511 } 512 513 bool DebugIO::IsCopyNodeGateOpen( 514 const std::vector<DebugWatchAndURLSpec>& specs) { 515 #ifndef PLATFORM_WINDOWS 516 for (const DebugWatchAndURLSpec& spec : specs) { 517 if (!spec.gated_grpc || spec.url.compare(0, strlen(DebugIO::kGrpcURLScheme), 518 DebugIO::kGrpcURLScheme)) { 519 return true; 520 } else { 521 if (DebugGrpcIO::IsReadGateOpen(spec.url, spec.watch_key)) { 522 return true; 523 } 524 } 525 } 526 return false; 527 #else 528 return true; 529 #endif 530 } 531 532 bool DebugIO::IsDebugNodeGateOpen(const string& watch_key, 533 const std::vector<string>& debug_urls) { 534 #ifndef PLATFORM_WINDOWS 535 for (const string& debug_url : debug_urls) { 536 if (debug_url.compare(0, strlen(DebugIO::kGrpcURLScheme), 537 DebugIO::kGrpcURLScheme)) { 538 return true; 539 } else { 540 if (DebugGrpcIO::IsReadGateOpen(debug_url, watch_key)) { 541 return true; 542 } 543 } 544 } 545 return false; 546 #else 547 return true; 548 #endif 549 } 550 551 bool DebugIO::IsDebugURLGateOpen(const string& watch_key, 552 const string& debug_url) { 553 #ifndef PLATFORM_WINDOWS 554 if (debug_url.find(kGrpcURLScheme) != 0) { 555 return true; 556 } else { 557 return DebugGrpcIO::IsReadGateOpen(debug_url, watch_key); 558 } 559 #else 560 return true; 561 #endif 562 } 563 564 Status DebugIO::CloseDebugURL(const string& debug_url) { 565 if (debug_url.find(DebugIO::kGrpcURLScheme) == 0) { 566 #ifndef PLATFORM_WINDOWS 567 return DebugGrpcIO::CloseGrpcStream(debug_url); 568 #else 569 GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR; 570 #endif 571 } else { 572 // No-op for non-gRPC URLs. 573 return Status::OK(); 574 } 575 } 576 577 Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key, 578 const Tensor& tensor, 579 const uint64 wall_time_us, 580 const string& dump_root_dir, 581 string* dump_file_path) { 582 const string file_path = 583 GetDumpFilePath(dump_root_dir, debug_node_key, wall_time_us); 584 585 if (dump_file_path != nullptr) { 586 *dump_file_path = file_path; 587 } 588 589 return DumpTensorToEventFile(debug_node_key, tensor, wall_time_us, file_path); 590 } 591 592 string DebugFileIO::GetDumpFilePath(const string& dump_root_dir, 593 const DebugNodeKey& debug_node_key, 594 const uint64 wall_time_us) { 595 return AppendTimestampToFilePath( 596 io::JoinPath(dump_root_dir, debug_node_key.device_path, 597 strings::StrCat(debug_node_key.node_name, "_", 598 debug_node_key.output_slot, "_", 599 debug_node_key.debug_op)), 600 wall_time_us); 601 } 602 603 Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto, 604 const string& dir_name, 605 const string& file_name) { 606 Env* env(Env::Default()); 607 608 Status s = RecursiveCreateDir(env, dir_name); 609 if (!s.ok()) { 610 return Status(error::FAILED_PRECONDITION, 611 strings::StrCat("Failed to create directory ", dir_name, 612 ", due to: ", s.error_message())); 613 } 614 615 const string file_path = io::JoinPath(dir_name, file_name); 616 617 string event_str; 618 event_proto.SerializeToString(&event_str); 619 620 std::unique_ptr<WritableFile> f = nullptr; 621 TF_CHECK_OK(env->NewWritableFile(file_path, &f)); 622 f->Append(event_str).IgnoreError(); 623 TF_CHECK_OK(f->Close()); 624 625 return Status::OK(); 626 } 627 628 Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key, 629 const Tensor& tensor, 630 const uint64 wall_time_us, 631 const string& file_path) { 632 std::vector<Event> events; 633 TF_RETURN_IF_ERROR( 634 WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events)); 635 return DumpEventProtoToFile(events[0], io::Dirname(file_path).ToString(), 636 io::Basename(file_path).ToString()); 637 } 638 639 Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) { 640 if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) { 641 // The path already exists as a directory. Return OK right away. 642 return Status::OK(); 643 } 644 645 string parent_dir = io::Dirname(dir).ToString(); 646 if (!env->FileExists(parent_dir).ok()) { 647 // The parent path does not exist yet, create it first. 648 Status s = RecursiveCreateDir(env, parent_dir); // Recursive call 649 if (!s.ok()) { 650 return Status( 651 error::FAILED_PRECONDITION, 652 strings::StrCat("Failed to create directory ", parent_dir)); 653 } 654 } else if (env->FileExists(parent_dir).ok() && 655 !env->IsDirectory(parent_dir).ok()) { 656 // The path exists, but it is a file. 657 return Status(error::FAILED_PRECONDITION, 658 strings::StrCat("Failed to create directory ", parent_dir, 659 " because the path exists as a file ")); 660 } 661 662 env->CreateDir(dir).IgnoreError(); 663 // Guard against potential race in creating directories by doing a check 664 // after the CreateDir call. 665 if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) { 666 return Status::OK(); 667 } else { 668 return Status(error::ABORTED, 669 strings::StrCat("Failed to create directory ", parent_dir)); 670 } 671 } 672 673 #ifndef PLATFORM_WINDOWS 674 DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr) 675 : server_stream_addr_(server_stream_addr), 676 url_(strings::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {} 677 678 Status DebugGrpcChannel::Connect(const int64 timeout_micros) { 679 ::grpc::ChannelArguments args; 680 args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max()); 681 // Avoid problems where default reconnect backoff is too long (e.g., 20 s). 682 args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000); 683 channel_ = ::grpc::CreateCustomChannel( 684 server_stream_addr_, ::grpc::InsecureChannelCredentials(), args); 685 if (!channel_->WaitForConnected( 686 gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), 687 gpr_time_from_micros(timeout_micros, GPR_TIMESPAN)))) { 688 return errors::FailedPrecondition( 689 "Failed to connect to gRPC channel at ", server_stream_addr_, 690 " within a timeout of ", timeout_micros / 1e6, " s."); 691 } 692 stub_ = EventListener::NewStub(channel_); 693 reader_writer_ = stub_->SendEvents(&ctx_); 694 695 return Status::OK(); 696 } 697 698 bool DebugGrpcChannel::WriteEvent(const Event& event) { 699 mutex_lock l(mu_); 700 return reader_writer_->Write(event); 701 } 702 703 bool DebugGrpcChannel::ReadEventReply(EventReply* event_reply) { 704 mutex_lock l(mu_); 705 return reader_writer_->Read(event_reply); 706 } 707 708 void DebugGrpcChannel::ReceiveAndProcessEventReplies(const size_t max_replies) { 709 EventReply event_reply; 710 size_t num_replies = 0; 711 while ((max_replies == 0 || ++num_replies <= max_replies) && 712 ReadEventReply(&event_reply)) { 713 for (const EventReply::DebugOpStateChange& debug_op_state_change : 714 event_reply.debug_op_state_changes()) { 715 string watch_key = strings::StrCat(debug_op_state_change.node_name(), ":", 716 debug_op_state_change.output_slot(), 717 ":", debug_op_state_change.debug_op()); 718 DebugGrpcIO::SetDebugNodeKeyGrpcState(url_, watch_key, 719 debug_op_state_change.state()); 720 } 721 } 722 } 723 724 Status DebugGrpcChannel::ReceiveServerRepliesAndClose() { 725 reader_writer_->WritesDone(); 726 // Read all EventReply messages (if any) from the server. 727 ReceiveAndProcessEventReplies(0); 728 729 if (reader_writer_->Finish().ok()) { 730 return Status::OK(); 731 } else { 732 return Status(error::FAILED_PRECONDITION, 733 "Failed to close debug GRPC stream."); 734 } 735 } 736 737 mutex DebugGrpcIO::streams_mu(LINKER_INITIALIZED); 738 739 int64 DebugGrpcIO::channel_connection_timeout_micros = 900 * 1000 * 1000; 740 // TODO(cais): Make this configurable? 741 742 const size_t DebugGrpcIO::kGrpcMessageSizeLimitBytes = 4000 * 1024; 743 744 const size_t DebugGrpcIO::kGrpcMaxVarintLengthSize = 6; 745 746 std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* 747 DebugGrpcIO::GetStreamChannels() { 748 static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* 749 stream_channels = 750 new std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>(); 751 return stream_channels; 752 } 753 754 Status DebugGrpcIO::SendTensorThroughGrpcStream( 755 const DebugNodeKey& debug_node_key, const Tensor& tensor, 756 const uint64 wall_time_us, const string& grpc_stream_url, 757 const bool gated) { 758 if (gated && 759 !IsReadGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) { 760 return Status::OK(); 761 } else { 762 std::vector<Event> events; 763 TF_RETURN_IF_ERROR(WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 764 kGrpcMessageSizeLimitBytes, &events)); 765 for (const Event& event : events) { 766 TF_RETURN_IF_ERROR( 767 SendEventProtoThroughGrpcStream(event, grpc_stream_url)); 768 } 769 if (IsWriteGateOpen(grpc_stream_url, debug_node_key.debug_node_name)) { 770 DebugGrpcChannel* debug_grpc_channel = nullptr; 771 TF_RETURN_IF_ERROR( 772 GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel)); 773 debug_grpc_channel->ReceiveAndProcessEventReplies(1); 774 // TODO(cais): Support new tensor value carried in the EventReply for 775 // overriding the value of the tensor being published. 776 } 777 return Status::OK(); 778 } 779 } 780 781 Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream( 782 EventReply* event_reply, const string& grpc_stream_url) { 783 DebugGrpcChannel* debug_grpc_channel = nullptr; 784 TF_RETURN_IF_ERROR( 785 GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel)); 786 if (debug_grpc_channel->ReadEventReply(event_reply)) { 787 return Status::OK(); 788 } else { 789 return errors::Cancelled(strings::StrCat( 790 "Reading EventReply from stream URL ", grpc_stream_url, " failed.")); 791 } 792 } 793 794 Status DebugGrpcIO::GetOrCreateDebugGrpcChannel( 795 const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) { 796 const string addr_with_path = 797 grpc_stream_url.find(DebugIO::kGrpcURLScheme) == 0 798 ? grpc_stream_url.substr(strlen(DebugIO::kGrpcURLScheme)) 799 : grpc_stream_url; 800 const string server_stream_addr = 801 addr_with_path.substr(0, addr_with_path.find('/')); 802 { 803 mutex_lock l(streams_mu); 804 std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* 805 stream_channels = GetStreamChannels(); 806 if (stream_channels->find(grpc_stream_url) == stream_channels->end()) { 807 std::unique_ptr<DebugGrpcChannel> channel( 808 new DebugGrpcChannel(server_stream_addr)); 809 TF_RETURN_IF_ERROR(channel->Connect(channel_connection_timeout_micros)); 810 stream_channels->insert( 811 std::make_pair(grpc_stream_url, std::move(channel))); 812 } 813 *debug_grpc_channel = (*stream_channels)[grpc_stream_url].get(); 814 } 815 return Status::OK(); 816 } 817 818 Status DebugGrpcIO::SendEventProtoThroughGrpcStream( 819 const Event& event_proto, const string& grpc_stream_url, 820 const bool receive_reply) { 821 DebugGrpcChannel* debug_grpc_channel; 822 TF_RETURN_IF_ERROR( 823 GetOrCreateDebugGrpcChannel(grpc_stream_url, &debug_grpc_channel)); 824 825 bool write_ok = debug_grpc_channel->WriteEvent(event_proto); 826 if (!write_ok) { 827 return errors::Cancelled(strings::StrCat("Write event to stream URL ", 828 grpc_stream_url, " failed.")); 829 } 830 831 if (receive_reply) { 832 debug_grpc_channel->ReceiveAndProcessEventReplies(1); 833 } 834 835 return Status::OK(); 836 } 837 838 bool DebugGrpcIO::IsReadGateOpen(const string& grpc_debug_url, 839 const string& watch_key) { 840 const DebugNodeName2State* enabled_node_to_state = 841 GetEnabledDebugOpStatesAtUrl(grpc_debug_url); 842 return enabled_node_to_state->find(watch_key) != enabled_node_to_state->end(); 843 } 844 845 bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url, 846 const string& watch_key) { 847 const DebugNodeName2State* enabled_node_to_state = 848 GetEnabledDebugOpStatesAtUrl(grpc_debug_url); 849 auto it = enabled_node_to_state->find(watch_key); 850 if (it == enabled_node_to_state->end()) { 851 return false; 852 } else { 853 return it->second == EventReply::DebugOpStateChange::READ_WRITE; 854 } 855 } 856 857 Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) { 858 mutex_lock l(streams_mu); 859 860 std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* 861 stream_channels = GetStreamChannels(); 862 if (stream_channels->find(grpc_stream_url) != stream_channels->end()) { 863 // Stream of the specified address exists. Close it and remove it from 864 // record. 865 Status s = 866 (*stream_channels)[grpc_stream_url]->ReceiveServerRepliesAndClose(); 867 (*stream_channels).erase(grpc_stream_url); 868 return s; 869 } else { 870 // Stream of the specified address does not exist. No action. 871 return Status::OK(); 872 } 873 } 874 875 std::unordered_map<string, DebugGrpcIO::DebugNodeName2State>* 876 DebugGrpcIO::GetEnabledDebugOpStates() { 877 static std::unordered_map<string, DebugNodeName2State>* 878 enabled_debug_op_states = 879 new std::unordered_map<string, DebugNodeName2State>(); 880 return enabled_debug_op_states; 881 } 882 883 DebugGrpcIO::DebugNodeName2State* DebugGrpcIO::GetEnabledDebugOpStatesAtUrl( 884 const string& grpc_debug_url) { 885 static mutex* debug_ops_state_mu = new mutex(); 886 std::unordered_map<string, DebugNodeName2State>* states = 887 GetEnabledDebugOpStates(); 888 889 mutex_lock l(*debug_ops_state_mu); 890 if (states->find(grpc_debug_url) == states->end()) { 891 DebugNodeName2State url_enabled_debug_op_states; 892 (*states)[grpc_debug_url] = url_enabled_debug_op_states; 893 } 894 return &(*states)[grpc_debug_url]; 895 } 896 897 void DebugGrpcIO::SetDebugNodeKeyGrpcState( 898 const string& grpc_debug_url, const string& watch_key, 899 const EventReply::DebugOpStateChange::State new_state) { 900 DebugNodeName2State* states = GetEnabledDebugOpStatesAtUrl(grpc_debug_url); 901 if (new_state == EventReply::DebugOpStateChange::DISABLED) { 902 if (states->find(watch_key) == states->end()) { 903 LOG(ERROR) << "Attempt to disable a watch key that is not currently " 904 << "enabled at " << grpc_debug_url << ": " << watch_key; 905 } else { 906 states->erase(watch_key); 907 } 908 } else if (new_state != EventReply::DebugOpStateChange::STATE_UNSPECIFIED) { 909 (*states)[watch_key] = new_state; 910 } 911 } 912 913 void DebugGrpcIO::ClearEnabledWatchKeys() { 914 GetEnabledDebugOpStates()->clear(); 915 } 916 917 #endif // #ifndef PLATFORM_WINDOWS 918 919 } // namespace tensorflow 920