1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_DEBUG_IO_UTILS_H_ 17 #define TENSORFLOW_DEBUG_IO_UTILS_H_ 18 19 #include <cstddef> 20 #include <functional> 21 #include <memory> 22 #include <string> 23 #include <unordered_map> 24 #include <unordered_set> 25 #include <vector> 26 27 #include "tensorflow/core/debug/debug_node_key.h" 28 #include "tensorflow/core/framework/tensor.h" 29 #include "tensorflow/core/graph/graph.h" 30 #include "tensorflow/core/lib/core/status.h" 31 #include "tensorflow/core/lib/gtl/array_slice.h" 32 #include "tensorflow/core/platform/env.h" 33 #include "tensorflow/core/util/event.pb.h" 34 35 namespace tensorflow { 36 37 Status ReadEventFromFile(const string& dump_file_path, Event* event); 38 39 struct DebugWatchAndURLSpec { 40 DebugWatchAndURLSpec(const string& watch_key, const string& url, 41 const bool gated_grpc) 42 : watch_key(watch_key), url(url), gated_grpc(gated_grpc) {} 43 44 const string watch_key; 45 const string url; 46 const bool gated_grpc; 47 }; 48 49 // TODO(cais): Put static functions and members in a namespace, not a class. 50 class DebugIO { 51 public: 52 static const char* const kDebuggerPluginName; 53 54 static const char* const kCoreMetadataTag; 55 static const char* const kGraphTag; 56 static const char* const kHashTag; 57 58 static const char* const kFileURLScheme; 59 static const char* const kGrpcURLScheme; 60 static const char* const kMemoryURLScheme; 61 62 static Status PublishDebugMetadata( 63 const int64 global_step, const int64 session_run_index, 64 const int64 executor_step_index, const std::vector<string>& input_names, 65 const std::vector<string>& output_names, 66 const std::vector<string>& target_nodes, 67 const std::unordered_set<string>& debug_urls); 68 69 // Publishes a tensor to a debug target URL. 70 // 71 // Args: 72 // debug_node_key: A DebugNodeKey identifying the debug node. 73 // tensor: The Tensor object being published. 74 // wall_time_us: Time stamp for the Tensor. Unit: microseconds (us). 75 // debug_urls: An array of debug target URLs, e.g., 76 // "file:///foo/tfdbg_dump", "grpc://localhost:11011" 77 // gated_grpc: Whether this call is subject to gRPC gating. 78 static Status PublishDebugTensor(const DebugNodeKey& debug_node_key, 79 const Tensor& tensor, 80 const uint64 wall_time_us, 81 const gtl::ArraySlice<string>& debug_urls, 82 const bool gated_grpc); 83 84 // Convenience overload of the method above for no gated_grpc by default. 85 static Status PublishDebugTensor(const DebugNodeKey& debug_node_key, 86 const Tensor& tensor, 87 const uint64 wall_time_us, 88 const gtl::ArraySlice<string>& debug_urls); 89 90 // Publishes a graph to a set of debug URLs. 91 // 92 // Args: 93 // graph: The graph to be published. 94 // debug_urls: The set of debug URLs to publish the graph to. 95 static Status PublishGraph(const Graph& graph, const string& device_name, 96 const std::unordered_set<string>& debug_urls); 97 98 // Determines whether a copy node needs to perform deep-copy of input tensor. 99 // 100 // The input arguments contain sufficient information about the attached 101 // downstream debug ops for this method to determine whether all the said 102 // ops are disabled given the current status of the gRPC gating. 103 // 104 // Args: 105 // specs: A vector of DebugWatchAndURLSpec carrying information about the 106 // debug ops attached to the Copy node, their debug URLs and whether 107 // they have the attribute value gated_grpc == True. 108 // 109 // Returns: 110 // Whether any of the attached downstream debug ops is enabled given the 111 // current status of the gRPC gating. 112 static bool IsCopyNodeGateOpen( 113 const std::vector<DebugWatchAndURLSpec>& specs); 114 115 // Determines whether a debug node needs to proceed given the current gRPC 116 // gating status. 117 // 118 // Args: 119 // watch_key: debug tensor watch key, in the format of 120 // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity". 121 // debug_urls: the debug URLs of the debug node. 122 // 123 // Returns: 124 // Whether this debug op should proceed. 125 static bool IsDebugNodeGateOpen(const string& watch_key, 126 const std::vector<string>& debug_urls); 127 128 // Determines whether debug information should be sent through a grpc:// 129 // debug URL given the current gRPC gating status. 130 // 131 // Args: 132 // watch_key: debug tensor watch key, in the format of 133 // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity". 134 // debug_url: the debug URL, e.g., "grpc://localhost:3333", 135 // "file:///tmp/tfdbg_1". 136 // 137 // Returns: 138 // Whether the sending of debug data to the debug_url should 139 // proceed. 140 static bool IsDebugURLGateOpen(const string& watch_key, 141 const string& debug_url); 142 143 static Status CloseDebugURL(const string& debug_url); 144 }; 145 146 // Helper class for debug ops. 147 class DebugFileIO { 148 public: 149 // Encapsulates the Tensor in an Event protobuf and write it to a directory. 150 // The actual path of the dump file will be a contactenation of 151 // dump_root_dir, tensor_name, along with the wall_time. 152 // 153 // For example: 154 // let dump_root_dir = "/tmp/tfdbg_dump", 155 // node_name = "foo/bar", 156 // output_slot = 0, 157 // debug_op = DebugIdentity, 158 // and wall_time_us = 1467891234512345, 159 // the dump file will be generated at path: 160 // /tmp/tfdbg_dump/foo/bar_0_DebugIdentity_1467891234512345. 161 // 162 // Args: 163 // debug_node_key: A DebugNodeKey identifying the debug node. 164 // wall_time_us: Wall time at which the Tensor is generated during graph 165 // execution. Unit: microseconds (us). 166 // dump_root_dir: Root directory for dumping the tensor. 167 // dump_file_path: The actual dump file path (passed as reference). 168 static Status DumpTensorToDir(const DebugNodeKey& debug_node_key, 169 const Tensor& tensor, const uint64 wall_time_us, 170 const string& dump_root_dir, 171 string* dump_file_path); 172 173 // Get the full path to the dump file. 174 // 175 // Args: 176 // dump_root_dir: The dump root directory, e.g., /tmp/tfdbg_dump 177 // node_name: Name of the node from which the dumped tensor is generated, 178 // e.g., foo/bar/node_a 179 // output_slot: Output slot index of the said node, e.g., 0. 180 // debug_op: Name of the debug op, e.g., DebugIdentity. 181 // wall_time_us: Time stamp of the dumped tensor, in microseconds (us). 182 static string GetDumpFilePath(const string& dump_root_dir, 183 const DebugNodeKey& debug_node_key, 184 const uint64 wall_time_us); 185 186 // Dumps an Event proto to a file. 187 // 188 // Args: 189 // event_prot: The Event proto to be dumped. 190 // dir_name: Directory path. 191 // file_name: Base file name. 192 static Status DumpEventProtoToFile(const Event& event_proto, 193 const string& dir_name, 194 const string& file_name); 195 196 private: 197 // Encapsulates the Tensor in an Event protobuf and write it to file. 198 static Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key, 199 const Tensor& tensor, 200 const uint64 wall_time_us, 201 const string& file_path); 202 203 // Implemented ad hoc here for now. 204 // TODO(cais): Replace with shared implementation once http://b/30497715 is 205 // fixed. 206 static Status RecursiveCreateDir(Env* env, const string& dir); 207 }; 208 209 } // namespace tensorflow 210 211 namespace std { 212 213 template <> 214 struct hash<::tensorflow::DebugNodeKey> { 215 size_t operator()(const ::tensorflow::DebugNodeKey& k) const { 216 return ::tensorflow::Hash64( 217 ::tensorflow::strings::StrCat(k.device_name, ":", k.node_name, ":", 218 k.output_slot, ":", k.debug_op, ":")); 219 } 220 }; 221 222 } // namespace std 223 224 // TODO(cais): Support grpc:// debug URLs in open source once Python grpc 225 // genrule becomes available. See b/23796275. 226 #ifndef PLATFORM_WINDOWS 227 #include "tensorflow/core/debug/debug_service.grpc.pb.h" 228 229 namespace tensorflow { 230 231 class DebugGrpcChannel { 232 public: 233 // Constructor of DebugGrpcChannel. 234 // 235 // Args: 236 // server_stream_addr: Address (host name and port) of the debug stream 237 // server implementing the EventListener service (see 238 // debug_service.proto). E.g., "127.0.0.1:12345". 239 DebugGrpcChannel(const string& server_stream_addr); 240 241 virtual ~DebugGrpcChannel() {} 242 243 // Attempt to establish connection with server. 244 // 245 // Args: 246 // timeout_micros: Timeout (in microseconds) for the attempt to establish 247 // the connection. 248 // 249 // Returns: 250 // OK Status iff connection is successfully established before timeout, 251 // otherwise return an error Status. 252 Status Connect(const int64 timeout_micros); 253 254 // Write an Event proto to the debug gRPC stream. 255 // 256 // Thread-safety: Safe with respect to other calls to the same method and 257 // calls to ReadEventReply() and Close(). 258 // 259 // Args: 260 // event: The event proto to be written to the stream. 261 // 262 // Returns: 263 // True iff the write is successful. 264 bool WriteEvent(const Event& event); 265 266 // Read an EventReply proto from the debug gRPC stream. 267 // 268 // This method blocks and waits for an EventReply from the server. 269 // Thread-safety: Safe with respect to other calls to the same method and 270 // calls to WriteEvent() and Close(). 271 // 272 // Args: 273 // event_reply: the to-be-modified EventReply proto passed as reference. 274 // 275 // Returns: 276 // True iff the read is successful. 277 bool ReadEventReply(EventReply* event_reply); 278 279 // Receive and process EventReply protos from the gRPC debug server. 280 // 281 // The processing includes setting debug watch key states using the 282 // DebugOpStateChange fields of the EventReply. 283 // 284 // Args: 285 // max_replies: Maximum number of replies to receive. Will receive all 286 // remaining replies iff max_replies == 0. 287 void ReceiveAndProcessEventReplies(size_t max_replies); 288 289 // Receive EventReplies from server (if any) and close the stream and the 290 // channel. 291 Status ReceiveServerRepliesAndClose(); 292 293 private: 294 string server_stream_addr_; 295 string url_; 296 ::grpc::ClientContext ctx_; 297 std::shared_ptr<::grpc::Channel> channel_; 298 std::unique_ptr<EventListener::Stub> stub_; 299 std::unique_ptr<::grpc::ClientReaderWriterInterface<Event, EventReply>> 300 reader_writer_; 301 302 mutex mu_; 303 }; 304 305 class DebugGrpcIO { 306 public: 307 static const size_t kGrpcMessageSizeLimitBytes; 308 static const size_t kGrpcMaxVarintLengthSize; 309 310 // Sends a tensor through a debug gRPC stream. 311 static Status SendTensorThroughGrpcStream(const DebugNodeKey& debug_node_key, 312 const Tensor& tensor, 313 const uint64 wall_time_us, 314 const string& grpc_stream_url, 315 const bool gated); 316 317 // Sends an Event proto through a debug gRPC stream. 318 // Thread-safety: Safe with respect to other calls to the same method and 319 // calls to CloseGrpcStream(). 320 // 321 // Args: 322 // event_proto: The Event proto to be sent. 323 // grpc_stream_url: The grpc:// URL of the stream to use, e.g., 324 // "grpc://localhost:11011", "localhost:22022". 325 // receive_reply: Whether an EventReply proto will be read after event_proto 326 // is sent and before the function returns. 327 // 328 // Returns: 329 // The Status of the operation. 330 static Status SendEventProtoThroughGrpcStream( 331 const Event& event_proto, const string& grpc_stream_url, 332 const bool receive_reply = false); 333 334 // Receive an EventReply proto through a debug gRPC stream. 335 static Status ReceiveEventReplyProtoThroughGrpcStream( 336 EventReply* event_reply, const string& grpc_stream_url); 337 338 // Check whether a debug watch key is read-activated at a given gRPC URL. 339 static bool IsReadGateOpen(const string& grpc_debug_url, 340 const string& watch_key); 341 342 // Check whether a debug watch key is write-activated (i.e., read- and 343 // write-activated) at a given gRPC URL. 344 static bool IsWriteGateOpen(const string& grpc_debug_url, 345 const string& watch_key); 346 347 // Closes a gRPC stream to the given address, if it exists. 348 // Thread-safety: Safe with respect to other calls to the same method and 349 // calls to SendTensorThroughGrpcStream(). 350 static Status CloseGrpcStream(const string& grpc_stream_url); 351 352 // Set the gRPC state of a debug node key. 353 // TODO(cais): Include device information in watch_key. 354 static void SetDebugNodeKeyGrpcState( 355 const string& grpc_debug_url, const string& watch_key, 356 const EventReply::DebugOpStateChange::State new_state); 357 358 private: 359 using DebugNodeName2State = 360 std::unordered_map<string, EventReply::DebugOpStateChange::State>; 361 362 // Returns a global map from grpc debug URLs to the corresponding 363 // DebugGrpcChannels. 364 static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>* 365 GetStreamChannels(); 366 367 // Get a DebugGrpcChannel object at a given URL, creating one if necessary. 368 // 369 // Args: 370 // grpc_stream_url: grpc:// URL of the stream, e.g., "grpc://localhost:6064" 371 // debug_grpc_channel: A pointer to the DebugGrpcChannel object, passed as a 372 // a pointer to the pointer. The DebugGrpcChannel object is owned 373 // statically elsewhere, not by the caller of this function. 374 // 375 // Returns: 376 // Status of this operation. 377 static Status GetOrCreateDebugGrpcChannel( 378 const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel); 379 380 // Returns a map from debug URL to a map from debug op name to enabled state. 381 static std::unordered_map<string, DebugNodeName2State>* 382 GetEnabledDebugOpStates(); 383 384 // Returns a map from debug op names to enabled state, for a given debug URL. 385 static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl( 386 const string& grpc_debug_url); 387 388 // Clear enabled debug op state from all debug URLs (if any). 389 static void ClearEnabledWatchKeys(); 390 391 static mutex streams_mu; 392 static int64 channel_connection_timeout_micros; 393 394 friend class GrpcDebugTest; 395 friend class DebugNumericSummaryOpTest; 396 }; 397 398 } // namespace tensorflow 399 #endif // #ifndef(PLATFORM_WINDOWS) 400 401 #endif // TENSORFLOW_DEBUG_IO_UTILS_H_ 402