Home | History | Annotate | Download | only in debug
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
     17 #define TENSORFLOW_DEBUG_IO_UTILS_H_
     18 
     19 #include <cstddef>
     20 #include <functional>
     21 #include <memory>
     22 #include <string>
     23 #include <unordered_map>
     24 #include <unordered_set>
     25 #include <vector>
     26 
     27 #include "tensorflow/core/debug/debug_node_key.h"
     28 #include "tensorflow/core/framework/tensor.h"
     29 #include "tensorflow/core/graph/graph.h"
     30 #include "tensorflow/core/lib/core/status.h"
     31 #include "tensorflow/core/lib/gtl/array_slice.h"
     32 #include "tensorflow/core/platform/env.h"
     33 #include "tensorflow/core/util/event.pb.h"
     34 
     35 namespace tensorflow {
     36 
        // Presumably reads a serialized Event proto from the file at dump_file_path
        // into *event and reports success or failure through the returned Status.
        // NOTE(review): inferred from the name and signature; the definition is not
        // in this header -- confirm against the .cc file.
     37 Status ReadEventFromFile(const string& dump_file_path, Event* event);
     38 
        // Immutable record describing a debug op attached to a Copy node: its watch
        // key, its debug URL and whether it is marked gated_grpc. Consumed as a
        // vector by DebugIO::IsCopyNodeGateOpen(). All members are const and are set
        // once by the constructor.
     39 struct DebugWatchAndURLSpec {
     40   DebugWatchAndURLSpec(const string& watch_key, const string& url,
     41                        const bool gated_grpc)
     42       : watch_key(watch_key), url(url), gated_grpc(gated_grpc) {}
     43 
          // Debug tensor watch key. Elsewhere in this header the format is given as
          // tensor_name:debug_op, e.g., "Weights:0:DebugIdentity".
     44   const string watch_key;
          // Debug target URL of the debug op.
     45   const string url;
          // Whether the debug op has the attribute value gated_grpc == True.
     46   const bool gated_grpc;
     47 };
     48 
     49 // TODO(cais): Put static functions and members in a namespace, not a class.
        //
        // Collection of static helpers for publishing debug metadata, tensors and
        // graphs to debug target URLs, and for querying the gRPC gating status of
        // debug nodes and debug URLs. All members are static.
     50 class DebugIO {
     51  public:
     52   static const char* const kDebuggerPluginName;
     53 
          // Tag string constants; their values are defined outside this header.
     54   static const char* const kCoreMetadataTag;
     55   static const char* const kGraphTag;
     56   static const char* const kHashTag;
     57 
          // URL scheme constants for debug target URLs; their values are defined
          // outside this header. NOTE(review): presumably these correspond to URL
          // prefixes such as the "file://" and "grpc://" used in the examples
          // below -- confirm against the .cc file.
     58   static const char* const kFileURLScheme;
     59   static const char* const kGrpcURLScheme;
     60   static const char* const kMemoryURLScheme;
     61 
          // Publishes metadata to a set of debug URLs.
          // NOTE(review): the descriptions below are inferred from the parameter
          // names only; the definition is not visible in this header -- confirm.
          //
          // Args:
          //   global_step: Presumably the global step of the run.
          //   session_run_index: Presumably the index of the Session run.
          //   executor_step_index: Presumably the step index of the executor.
          //   input_names: Presumably the names of the inputs (feeds) of the run.
          //   output_names: Presumably the names of the outputs (fetches) of the run.
          //   target_nodes: Presumably the names of the target nodes of the run.
          //   debug_urls: The set of debug URLs to publish the metadata to.
     62   static Status PublishDebugMetadata(
     63       const int64 global_step, const int64 session_run_index,
     64       const int64 executor_step_index, const std::vector<string>& input_names,
     65       const std::vector<string>& output_names,
     66       const std::vector<string>& target_nodes,
     67       const std::unordered_set<string>& debug_urls);
     68 
     69   // Publishes a tensor to a debug target URL.
     70   //
     71   // Args:
     72   //   debug_node_key: A DebugNodeKey identifying the debug node.
     73   //   tensor: The Tensor object being published.
     74   //   wall_time_us: Time stamp for the Tensor. Unit: microseconds (us).
     75   //   debug_urls: An array of debug target URLs, e.g.,
     76   //     "file:///foo/tfdbg_dump", "grpc://localhost:11011"
     77   //   gated_grpc: Whether this call is subject to gRPC gating.
     78   static Status PublishDebugTensor(const DebugNodeKey& debug_node_key,
     79                                    const Tensor& tensor,
     80                                    const uint64 wall_time_us,
     81                                    const gtl::ArraySlice<string>& debug_urls,
     82                                    const bool gated_grpc);
     83 
     84   // Convenience overload of the method above for no gated_grpc by default.
     85   static Status PublishDebugTensor(const DebugNodeKey& debug_node_key,
     86                                    const Tensor& tensor,
     87                                    const uint64 wall_time_us,
     88                                    const gtl::ArraySlice<string>& debug_urls);
     89 
     90   // Publishes a graph to a set of debug URLs.
     91   //
     92   // Args:
     93   //   graph: The graph to be published.
          //   device_name: Name of the device associated with the graph.
          //     NOTE(review): inferred from the parameter name -- confirm.
     94   //   debug_urls: The set of debug URLs to publish the graph to.
     95   static Status PublishGraph(const Graph& graph, const string& device_name,
     96                              const std::unordered_set<string>& debug_urls);
     97 
     98   // Determines whether a copy node needs to perform deep-copy of input tensor.
     99   //
    100   // The input arguments contain sufficient information about the attached
    101   // downstream debug ops for this method to determine whether all the said
    102   // ops are disabled given the current status of the gRPC gating.
    103   //
    104   // Args:
    105   //   specs: A vector of DebugWatchAndURLSpec carrying information about the
    106   //     debug ops attached to the Copy node, their debug URLs and whether
    107   //     they have the attribute value gated_grpc == True.
    108   //
    109   // Returns:
    110   //   Whether any of the attached downstream debug ops is enabled given the
    111   //   current status of the gRPC gating.
    112   static bool IsCopyNodeGateOpen(
    113       const std::vector<DebugWatchAndURLSpec>& specs);
    114 
    115   // Determines whether a debug node needs to proceed given the current gRPC
    116   // gating status.
    117   //
    118   // Args:
    119   //   watch_key: debug tensor watch key, in the format of
    120   //     tensor_name:debug_op, e.g., "Weights:0:DebugIdentity".
    121   //   debug_urls: the debug URLs of the debug node.
    122   //
    123   // Returns:
    124   //   Whether this debug op should proceed.
    125   static bool IsDebugNodeGateOpen(const string& watch_key,
    126                                   const std::vector<string>& debug_urls);
    127 
    128   // Determines whether debug information should be sent through a grpc://
    129   // debug URL given the current gRPC gating status.
    130   //
    131   // Args:
    132   //   watch_key: debug tensor watch key, in the format of
    133   //     tensor_name:debug_op, e.g., "Weights:0:DebugIdentity".
    134   //   debug_url: the debug URL, e.g., "grpc://localhost:3333",
    135   //     "file:///tmp/tfdbg_1".
    136   //
    137   // Returns:
    138   //   Whether the sending of debug data to the debug_url should
    139   //     proceed.
    140   static bool IsDebugURLGateOpen(const string& watch_key,
    141                                  const string& debug_url);
    142 
          // Closes the given debug URL.
          // NOTE(review): presumably this releases whatever is held open for the
          // URL (e.g., a gRPC stream) -- confirm against the .cc file.
    143   static Status CloseDebugURL(const string& debug_url);
    144 };
    145 
    146 // Helper class for debug ops.
        // Provides static helpers for dumping tensors and Event protos to files.
    147 class DebugFileIO {
    148  public:
    149   // Encapsulates the Tensor in an Event protobuf and writes it to a directory.
    150   // The actual path of the dump file will be a concatenation of
    151   // dump_root_dir, tensor_name, along with the wall_time.
    152   //
    153   // For example:
    154   //   let dump_root_dir = "/tmp/tfdbg_dump",
    155   //       node_name = "foo/bar",
    156   //       output_slot = 0,
    157   //       debug_op = DebugIdentity,
    158   //       and wall_time_us = 1467891234512345,
    159   // the dump file will be generated at path:
    160   //   /tmp/tfdbg_dump/foo/bar_0_DebugIdentity_1467891234512345.
    161   //
    162   // Args:
    163   //   debug_node_key: A DebugNodeKey identifying the debug node.
          //   tensor: The Tensor to be dumped.
    164   //   wall_time_us: Wall time at which the Tensor is generated during graph
    165   //     execution. Unit: microseconds (us).
    166   //   dump_root_dir: Root directory for dumping the tensor.
    167   //   dump_file_path: The actual dump file path (output; passed as pointer).
    168   static Status DumpTensorToDir(const DebugNodeKey& debug_node_key,
    169                                 const Tensor& tensor, const uint64 wall_time_us,
    170                                 const string& dump_root_dir,
    171                                 string* dump_file_path);
    172 
    173   // Get the full path to the dump file.
    174   //
    175   // Args:
    176   //   dump_root_dir: The dump root directory, e.g., /tmp/tfdbg_dump
    177   //   debug_node_key: A DebugNodeKey identifying the debug node; it carries
    178   //     the node name (e.g., foo/bar/node_a), the output slot index of the
    179   //     said node (e.g., 0), and the name of the debug op (e.g.,
    180   //     DebugIdentity).
    181   //   wall_time_us: Time stamp of the dumped tensor, in microseconds (us).
    182   static string GetDumpFilePath(const string& dump_root_dir,
    183                                 const DebugNodeKey& debug_node_key,
    184                                 const uint64 wall_time_us);
    185 
    186   // Dumps an Event proto to a file.
    187   //
    188   // Args:
    189   //   event_proto: The Event proto to be dumped.
    190   //   dir_name: Directory path.
    191   //   file_name: Base file name.
    192   static Status DumpEventProtoToFile(const Event& event_proto,
    193                                      const string& dir_name,
    194                                      const string& file_name);
    195 
    196  private:
    197   // Encapsulates the Tensor in an Event protobuf and writes it to file.
    198   static Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
    199                                       const Tensor& tensor,
    200                                       const uint64 wall_time_us,
    201                                       const string& file_path);
    202 
          // Creates directory dir through env.
          // NOTE(review): presumably also creates any missing parent directories,
          // as the name suggests -- confirm against the .cc file.
    203   // Implemented ad hoc here for now.
    204   // TODO(cais): Replace with shared implementation once http://b/30497715 is
    205   // fixed.
    206   static Status RecursiveCreateDir(Env* env, const string& dir);
    207 };
    208 
    209 }  // namespace tensorflow
    210 
    211 namespace std {
    212 
        // std::hash specialization for DebugNodeKey. The hash is computed over the
        // string "<device_name>:<node_name>:<output_slot>:<debug_op>:".
    213 template <>
    214 struct hash<::tensorflow::DebugNodeKey> {
    215   size_t operator()(const ::tensorflow::DebugNodeKey& k) const {
    216     const auto joined = ::tensorflow::strings::StrCat(k.device_name, ":",
    217         k.node_name, ":", k.output_slot, ":", k.debug_op, ":");
    218     return ::tensorflow::Hash64(joined);
    219   }
    220 };
    221 
    222 }  // namespace std
    223 
    224 // TODO(cais): Support grpc:// debug URLs in open source once Python grpc
    225 //   genrule becomes available. See b/23796275.
    226 #ifndef PLATFORM_WINDOWS
    227 #include "tensorflow/core/debug/debug_service.grpc.pb.h"
    228 
    229 namespace tensorflow {
    230 
        // Client-side handle for one debug gRPC stream: holds the channel, the
        // EventListener stub and the Event/EventReply reader-writer used to talk to
        // a debug stream server.
    231 class DebugGrpcChannel {
    232  public:
    233   // Constructor of DebugGrpcChannel.
          //
          // Declared explicit so that a string is never silently converted into a
          // DebugGrpcChannel.
    234   //
    235   // Args:
    236   //   server_stream_addr: Address (host name and port) of the debug stream
    237   //     server implementing the EventListener service (see
    238   //     debug_service.proto). E.g., "127.0.0.1:12345".
    239   explicit DebugGrpcChannel(const string& server_stream_addr);
    240 
    241   virtual ~DebugGrpcChannel() {}
    242 
    243   // Attempt to establish connection with server.
    244   //
    245   // Args:
    246   //   timeout_micros: Timeout (in microseconds) for the attempt to establish
    247   //     the connection.
    248   //
    249   // Returns:
    250   //   OK Status iff connection is successfully established before timeout,
    251   //   otherwise return an error Status.
    252   Status Connect(const int64 timeout_micros);
    253 
    254   // Write an Event proto to the debug gRPC stream.
    255   //
    256   // Thread-safety: Safe with respect to other calls to the same method and
    257   //   calls to ReadEventReply() and Close().
    258   //
    259   // Args:
    260   //   event: The event proto to be written to the stream.
    261   //
    262   // Returns:
    263   //   True iff the write is successful.
    264   bool WriteEvent(const Event& event);
    265 
    266   // Read an EventReply proto from the debug gRPC stream.
    267   //
    268   // This method blocks and waits for an EventReply from the server.
    269   // Thread-safety: Safe with respect to other calls to the same method and
    270   //   calls to WriteEvent() and Close().
    271   //
    272   // Args:
    273   //   event_reply: the to-be-modified EventReply proto (passed as pointer).
    274   //
    275   // Returns:
    276   //   True iff the read is successful.
    277   bool ReadEventReply(EventReply* event_reply);
    278 
    279   // Receive and process EventReply protos from the gRPC debug server.
    280   //
    281   // The processing includes setting debug watch key states using the
    282   // DebugOpStateChange fields of the EventReply.
    283   //
    284   // Args:
    285   //   max_replies: Maximum number of replies to receive. Will receive all
    286   //     remaining replies iff max_replies == 0.
    287   void ReceiveAndProcessEventReplies(size_t max_replies);
    288 
    289   // Receive EventReplies from server (if any) and close the stream and the
    290   // channel.
    291   Status ReceiveServerRepliesAndClose();
    292 
    293  private:
          // Address of the debug stream server (see the constructor).
    294   string server_stream_addr_;
          // NOTE(review): presumably the grpc:// URL form of server_stream_addr_ --
          // confirm against the .cc file.
    295   string url_;
          // gRPC client-side state for the stream: context, channel, service stub
          // and the bidirectional Event/EventReply reader-writer.
    296   ::grpc::ClientContext ctx_;
    297   std::shared_ptr<::grpc::Channel> channel_;
    298   std::unique_ptr<EventListener::Stub> stub_;
    299   std::unique_ptr<::grpc::ClientReaderWriterInterface<Event, EventReply>>
    300       reader_writer_;
    301 
          // NOTE(review): presumably serializes access to the stream, backing the
          // thread-safety notes on WriteEvent()/ReadEventReply() -- confirm.
    302   mutex mu_;
    303 };
    304 
        // Collection of static helpers for sending debug data through debug gRPC
        // streams and for tracking, per gRPC debug URL, the gating state of debug
        // watch keys. All members are static.
    305 class DebugGrpcIO {
    306  public:
          // Size-related constants; their values are defined outside this header.
          // NOTE(review): presumably the gRPC message size limit (in bytes) and the
          // maximum length of a varint length prefix -- confirm against the .cc file.
    307   static const size_t kGrpcMessageSizeLimitBytes;
    308   static const size_t kGrpcMaxVarintLengthSize;
    309 
    310   // Sends a tensor through a debug gRPC stream.
          //
          // Args:
          //   debug_node_key: A DebugNodeKey identifying the debug node.
          //   tensor: The Tensor to be sent.
          //   wall_time_us: Time stamp for the Tensor, in microseconds (us).
          //   grpc_stream_url: The grpc:// URL of the stream to use.
          //   gated: NOTE(review): presumably whether the send is subject to gRPC
          //     gating, like gated_grpc in DebugIO::PublishDebugTensor() -- confirm.
    311   static Status SendTensorThroughGrpcStream(const DebugNodeKey& debug_node_key,
    312                                             const Tensor& tensor,
    313                                             const uint64 wall_time_us,
    314                                             const string& grpc_stream_url,
    315                                             const bool gated);
    316 
    317   // Sends an Event proto through a debug gRPC stream.
    318   // Thread-safety: Safe with respect to other calls to the same method and
    319   // calls to CloseGrpcStream().
    320   //
    321   // Args:
    322   //   event_proto: The Event proto to be sent.
    323   //   grpc_stream_url: The grpc:// URL of the stream to use, e.g.,
    324   //     "grpc://localhost:11011", "localhost:22022".
    325   //   receive_reply: Whether an EventReply proto will be read after event_proto
    326   //     is sent and before the function returns.
    327   //
    328   // Returns:
    329   //   The Status of the operation.
    330   static Status SendEventProtoThroughGrpcStream(
    331       const Event& event_proto, const string& grpc_stream_url,
    332       const bool receive_reply = false);
    333 
    334   // Receive an EventReply proto through a debug gRPC stream.
    335   static Status ReceiveEventReplyProtoThroughGrpcStream(
    336       EventReply* event_reply, const string& grpc_stream_url);
    337 
    338   // Check whether a debug watch key is read-activated at a given gRPC URL.
    339   static bool IsReadGateOpen(const string& grpc_debug_url,
    340                              const string& watch_key);
    341 
    342   // Check whether a debug watch key is write-activated (i.e., read- and
    343   // write-activated) at a given gRPC URL.
    344   static bool IsWriteGateOpen(const string& grpc_debug_url,
    345                               const string& watch_key);
    346 
    347   // Closes a gRPC stream to the given address, if it exists.
    348   // Thread-safety: Safe with respect to other calls to the same method and
    349   // calls to SendTensorThroughGrpcStream().
    350   static Status CloseGrpcStream(const string& grpc_stream_url);
    351 
    352   // Set the gRPC state of a debug node key.
    353   // TODO(cais): Include device information in watch_key.
    354   static void SetDebugNodeKeyGrpcState(
    355       const string& grpc_debug_url, const string& watch_key,
    356       const EventReply::DebugOpStateChange::State new_state);
    357 
    358  private:
          // Map from a debug op (watch key) name to its DebugOpStateChange state.
    359   using DebugNodeName2State =
    360       std::unordered_map<string, EventReply::DebugOpStateChange::State>;
    361 
    362   // Returns a global map from grpc debug URLs to the corresponding
    363   // DebugGrpcChannels.
    364   static std::unordered_map<string, std::unique_ptr<DebugGrpcChannel>>*
    365   GetStreamChannels();
    366 
    367   // Get a DebugGrpcChannel object at a given URL, creating one if necessary.
    368   //
    369   // Args:
    370   //   grpc_stream_url: grpc:// URL of the stream, e.g., "grpc://localhost:6064"
    371   //   debug_grpc_channel: A pointer to the DebugGrpcChannel object, passed as
    372   //     a pointer to the pointer. The DebugGrpcChannel object is owned
    373   //     statically elsewhere, not by the caller of this function.
    374   //
    375   // Returns:
    376   //   Status of this operation.
    377   static Status GetOrCreateDebugGrpcChannel(
    378       const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel);
    379 
    380   // Returns a map from debug URL to a map from debug op name to enabled state.
    381   static std::unordered_map<string, DebugNodeName2State>*
    382   GetEnabledDebugOpStates();
    383 
    384   // Returns a map from debug op names to enabled state, for a given debug URL.
    385   static DebugNodeName2State* GetEnabledDebugOpStatesAtUrl(
    386       const string& grpc_debug_url);
    387 
    388   // Clear enabled debug op state from all debug URLs (if any).
    389   static void ClearEnabledWatchKeys();
    390 
          // NOTE(review): presumably guards the global stream-channel map returned
          // by GetStreamChannels() -- confirm against the .cc file.
    391   static mutex streams_mu;
          // NOTE(review): presumably the timeout, in microseconds, used when
          // connecting a DebugGrpcChannel (see DebugGrpcChannel::Connect()) --
          // confirm against the .cc file.
    392   static int64 channel_connection_timeout_micros;
    393 
          // Classes granted access to the private members above (tests, judging by
          // their names).
    394   friend class GrpcDebugTest;
    395   friend class DebugNumericSummaryOpTest;
    396 };
    397 
    398 }  // namespace tensorflow
    399 #endif  // #ifndef(PLATFORM_WINDOWS)
    400 
    401 #endif  // TENSORFLOW_DEBUG_IO_UTILS_H_
    402