Home | History | Annotate | Download | only in lib
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """gRPC debug server in Python."""
     16 # pylint: disable=g-bad-import-order
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import json
     23 import threading
     24 import time
     25 
     26 from concurrent import futures
     27 import grpc
     28 from six.moves import queue
     29 
     30 from tensorflow.core.debug import debug_service_pb2
     31 from tensorflow.core.framework import graph_pb2
     32 from tensorflow.python.debug.lib import debug_graphs
     33 from tensorflow.python.debug.lib import debug_service_pb2_grpc
     34 from tensorflow.python.platform import tf_logging as logging
     35 from tensorflow.python.util import compat
     36 
     # Lightweight, hashable record identifying one debug watch: the watched node,
     # the index of its output tensor, and the name of the debug op attached to it.
     DebugWatch = collections.namedtuple(
         typename="DebugWatch",
         field_names=("node_name", "output_slot", "debug_op"))
     39 
     40 
     def _state_change(new_state, node_name, output_slot, debug_op):
       """Build a `DebugOpStateChange` proto describing one debug-op state change.

       Args:
         new_state: The `EventReply.DebugOpStateChange` state value to request.
         node_name: Name of the node the watched tensor belongs to.
         output_slot: Output slot index of the watched tensor.
         debug_op: Name of the debug op.

       Returns:
         A populated `debug_service_pb2.EventReply.DebugOpStateChange` proto.
       """
       return debug_service_pb2.EventReply.DebugOpStateChange(
           state=new_state,
           node_name=node_name,
           output_slot=output_slot,
           debug_op=debug_op)
     48 
     49 
     class EventListenerBaseStreamHandler(object):
       """Per-stream handler of EventListener gRPC streams.

       This base class only defines the callback interface: every `on_*` method
       raises `NotImplementedError` and is meant to be overridden by a subclass.
       """

       def __init__(self):
         """Constructor of EventListenerBaseStreamHandler; keeps no state."""

       def on_core_metadata_event(self, event):
         """Callback for the core metadata event of a stream.

         Args:
           event: The Event proto that carries a JSON string in its
             `log_message.message` field.

         Returns:
           `None` or an `EventReply` proto to be sent back to the client. If `None`,
           an `EventReply` proto constructed with the default no-arg constructor
           will be sent back to the client.

         Raises:
           NotImplementedError: Always, in this base class.
         """
         message = ("on_core_metadata_event() is not implemented in the base "
                    "servicer class")
         raise NotImplementedError(message)

       def on_graph_def(self, graph_def, device_name, wall_time):
         """Callback for a fully received GraphDef.

         The GraphDef arrives through the gRPC stream encoded as bytes in the
         `graph_def` field of Event protos.

         Args:
           graph_def: A GraphDef object.
           device_name: Name of the device on which the graph was created.
           wall_time: An epoch timestamp (in microseconds) for the graph.

         Returns:
           `None` or an `EventReply` proto to be sent back to the client. If `None`,
           an `EventReply` proto constructed with the default no-arg constructor
           will be sent back to the client.

         Raises:
           NotImplementedError: Always, in this base class.
         """
         message = "on_graph_def() is not implemented in the base servicer class"
         raise NotImplementedError(message)

       def on_value_event(self, event):
         """Callback for an Event proto that carries a tensor value.

         The tensor is stored in the `summary.value[0]` field of the Event proto.

         Args:
           event: The Event proto from the stream to be processed.

         Returns:
           `None` or an `EventReply` proto. The servicer's `SendEvents` sends a
           reply back to the client only when the returned value is not `None`.

         Raises:
           NotImplementedError: Always, in this base class.
         """
         message = "on_value_event() is not implemented in the base servicer class"
         raise NotImplementedError(message)
    101 
    102 
    class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
      """Base Python class for gRPC debug server.

      Receives `Event` protos from debugged TensorFlow runtimes through
      `SendEvents`, reassembles chunked tensors and GraphDefs, and dispatches them
      to a per-stream handler created from `stream_handler_class`. Watchpoint and
      breakpoint requests made through `request_watch()` / `request_unwatch()` are
      queued and added to the next `EventReply` sent back to a client.
      """

      def __init__(self, server_port, stream_handler_class):
        """Constructor.

        Args:
          server_port: (int) Port number to bind to.
          stream_handler_class: A subclass of `EventListenerBaseStreamHandler`
            that will be used to construct stream handler objects during
            `SendEvents` calls.
        """
        self._server_port = server_port
        self._stream_handler_class = stream_handler_class

        # Guards the started/stop-requested state transitions performed by
        # `run_server()` and `stop_server()`.
        self._server_lock = threading.Lock()
        self._server_started = False
        self._stop_requested = False

        # `DebugOpStateChange` protos waiting to be sent to a client; filled by
        # `request_watch()` / `request_unwatch()` and drained by
        # `_process_debug_op_state_changes()`.
        self._debug_ops_state_change_queue = queue.Queue()
        # `DebugWatch` tuples collected from the GraphDefs received so far.
        self._gated_grpc_debug_watches = set()
        # (node_name, output_slot, debug_op) keys currently in READ_WRITE state.
        self._breakpoints = set()

      def SendEvents(self, request_iterator, context):
        """Implementation of the SendEvents service method.

        Receives a stream of `Event` protos from the client (the debug ops) and
        routes each one to the matching callback of a stream handler, which is
        created lazily when the first event of the stream arrives:
          * tensor-value events -> `on_value_event()` (after chunk reassembly),
          * GraphDef events -> `on_graph_def()` (after chunk reassembly),
          * log-message events -> `on_core_metadata_event()`.

        Args:
          request_iterator: The incoming stream of Event protos.
          context: Server context.

        Raises:
          ValueError: If more than one core metadata event is received in the
            stream.

        Yields:
          `EventReply` protos carrying any queued debug-op state changes. A reply
          is yielded for every completed GraphDef and for the core metadata
          event; for a tensor value, a reply is yielded only if
          `on_value_event()` returned a non-`None` `EventReply`.
        """
        core_metadata_count = 0

        # Per-stream reassembly state: each dict maps a key identifying one
        # GraphDef / tensor to the list of its chunks received so far.
        graph_def_chunks = {}
        tensor_chunks = {}

        stream_handler = None
        for event in request_iterator:
          if stream_handler is None:
            stream_handler = self._stream_handler_class()

          if event.summary and event.summary.value:
            # An Event proto carrying a tensor value (possibly one chunk of it).
            maybe_tensor_event = self._process_tensor_event_in_chunks(
                event, tensor_chunks)
            if maybe_tensor_event is not None:
              event_reply = stream_handler.on_value_event(maybe_tensor_event)
              if event_reply is not None:
                yield self._process_debug_op_state_changes(event_reply)
          elif event.graph_def:
            # GraphDef-carrying Event (possibly one chunk of it).
            maybe_graph_def, maybe_device_name, maybe_wall_time = (
                self._process_encoded_graph_def_in_chunks(
                    event, graph_def_chunks))
            if maybe_graph_def is not None:
              reply = stream_handler.on_graph_def(
                  maybe_graph_def, maybe_device_name, maybe_wall_time)
              yield self._process_debug_op_state_changes(reply)
          elif event.log_message.message:
            # Core metadata-carrying Event.
            core_metadata_count += 1
            if core_metadata_count > 1:
              raise ValueError(
                  "Expected one core metadata event; received multiple")
            reply = stream_handler.on_core_metadata_event(event)
            yield self._process_debug_op_state_changes(reply)

      def _process_debug_op_state_changes(self, event_reply=None):
        """Dequeue and process all the queued debug-op state change protos.

        Include all the debug-op state change protos in a `EventReply` proto and
        keep `self._breakpoints` in sync with the requested states.

        Args:
          event_reply: An `EventReply` to add the `DebugOpStateChange` protos to,
            or `None`.

        Returns:
          An `EventReply` proto with the dequeued `DebugOpStateChange` protos (if
            any) added.
        """
        if event_reply is None:
          event_reply = debug_service_pb2.EventReply()
        state_enum = debug_service_pb2.EventReply.DebugOpStateChange

        while True:
          # Several streams may drain the queue concurrently. A non-blocking get
          # avoids hanging forever when another thread empties the queue between
          # an `empty()` check and a blocking `get()`.
          try:
            state_change = self._debug_ops_state_change_queue.get_nowait()
          except queue.Empty:
            break

          debug_node_key = (state_change.node_name, state_change.output_slot,
                            state_change.debug_op)
          if state_change.state == state_enum.READ_WRITE:
            logging.info("Adding breakpoint %s:%d:%s", state_change.node_name,
                         state_change.output_slot, state_change.debug_op)
            self._breakpoints.add(debug_node_key)
          elif state_change.state == state_enum.READ_ONLY:
            logging.info("Adding watchpoint %s:%d:%s", state_change.node_name,
                         state_change.output_slot, state_change.debug_op)
            # A watchpoint replaces a breakpoint on the same key, if any.
            self._breakpoints.discard(debug_node_key)
          elif state_change.state == state_enum.DISABLED:
            logging.info("Removing watchpoint or breakpoint: %s:%d:%s",
                         state_change.node_name, state_change.output_slot,
                         state_change.debug_op)
            if debug_node_key in self._breakpoints:
              self._breakpoints.discard(debug_node_key)
            else:
              logging.warning(
                  "Attempting to remove a non-existent debug node key: %s",
                  debug_node_key)
          event_reply.debug_op_state_changes.add().CopyFrom(state_change)
        return event_reply

      def _process_tensor_event_in_chunks(self, event, tensor_chunks):
        """Possibly reassemble event chunks.

        Due to gRPC's message size limit, a large tensor can be encapsulated in
        multiple Event proto chunks to be sent through the debugger stream. This
        method keeps track of the chunks that have arrived, reassembles all
        chunks corresponding to a tensor when they have arrived and returns the
        reassembled Event proto.

        Args:
          event: The single Event proto that has arrived.
          tensor_chunks: A dict used to keep track of the Event protos that have
            arrived but haven't been reassembled. Entries are removed from it
            once the corresponding tensor has been reassembled.

        Returns:
          If all Event protos corresponding to a tensor have arrived, returns the
          reassembled Event proto. Otherwise, return None.
        """
        value = event.summary.value[0]
        debugger_plugin_metadata = json.loads(
            compat.as_text(value.metadata.plugin_data.content))
        device_name = debugger_plugin_metadata["device"]
        num_chunks = debugger_plugin_metadata["numChunks"]
        chunk_index = debugger_plugin_metadata["chunkIndex"]

        if num_chunks <= 1:
          # Not chunked: the event already holds the whole tensor.
          return event

        debug_node_name = value.node_name
        timestamp = int(event.wall_time)
        tensor_key = "%s_%s_%d" % (device_name, debug_node_name, timestamp)

        chunks = tensor_chunks.setdefault(tensor_key, [None] * num_chunks)
        # Binary tensors only need their TensorProto kept; string tensors keep
        # the whole Event so that the first chunk can serve as the merge target.
        if value.tensor.tensor_content:
          chunks[chunk_index] = value.tensor
        elif value.tensor.string_val:
          chunks[chunk_index] = event

        if None in chunks:
          # Still waiting for at least one chunk.
          return None

        if value.tensor.tensor_content:
          event.summary.value[0].tensor.tensor_content = b"".join(
              chunk.tensor_content for chunk in chunks)
          del tensor_chunks[tensor_key]
          return event
        elif value.tensor.string_val:
          merged_event = chunks[0]
          merged_strings = merged_event.summary.value[0].tensor.string_val
          for chunk in chunks[1:]:
            merged_strings.extend(list(chunk.summary.value[0].tensor.string_val))
          # Drop the bookkeeping entry so that completed string tensors do not
          # accumulate for the lifetime of the stream.
          del tensor_chunks[tensor_key]
          return merged_event
        return None

      def _process_encoded_graph_def_in_chunks(self,
                                               event,
                                               graph_def_chunks):
        """Process an Event proto containing a chunk of encoded GraphDef.

        The `graph_def` field of the event is expected to have the layout
        "<graph_def_hash>,<device_name>,<wall_time>|<chunk_index>|<num_chunks>|"
        followed by the bytes of the chunk.

        Args:
          event: the Event proto containing the chunk of encoded GraphDef.
          graph_def_chunks: A dict mapping keys for GraphDefs (i.e.,
            "<graph_def_hash>,<device_name>,<wall_time>") to a list of chunks of
            encoded GraphDefs. Entries are removed once a GraphDef is complete.

        Returns:
          If all chunks of the GraphDef have arrived,
            return decoded GraphDef proto, device name, wall_time.
          Otherwise,
            return None, None, None.

        Raises:
          ValueError: If the header of the chunk does not have the layout
            described above.
        """
        # Split off the three "|"-terminated header fields only; the payload is
        # arbitrary bytes and may itself contain "|".
        graph_def_hash_device_timestamp, chunk_index, num_chunks, payload = (
            event.graph_def.split(b"|", 3))
        chunk_index = int(chunk_index)
        num_chunks = int(num_chunks)

        chunks = graph_def_chunks.setdefault(
            graph_def_hash_device_timestamp, [None] * num_chunks)
        chunks[chunk_index] = payload

        # Test for the `None` placeholder rather than truthiness, so that an
        # empty payload is not mistaken for a chunk that has not arrived yet.
        if None in chunks:
          return None, None, None

        key_fields = graph_def_hash_device_timestamp.split(b",")
        device_name = key_fields[1]
        wall_time = int(key_fields[2])
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(b"".join(chunks))
        del graph_def_chunks[graph_def_hash_device_timestamp]
        self._process_graph_def(graph_def)
        return graph_def, device_name, wall_time

      def _process_graph_def(self, graph_def):
        """Record the gated-gRPC debug watches found in a GraphDef.

        Args:
          graph_def: A GraphDef proto whose debug nodes with a true `gated_grpc`
            attribute are added to `self._gated_grpc_debug_watches`.
        """
        for node_def in graph_def.node:
          if (debug_graphs.is_debug_node(node_def.name) and
              node_def.attr["gated_grpc"].b):
            node_name, output_slot, _, debug_op = (
                debug_graphs.parse_debug_node_name(node_def.name))
            self._gated_grpc_debug_watches.add(
                DebugWatch(node_name, output_slot, debug_op))

      def run_server(self, blocking=True):
        """Start running the server.

        Args:
          blocking: If `True`, block until `stop_server()` is invoked.

        Raises:
          ValueError: If server stop has already been requested, or if the server
            has already started running.
        """
        with self._server_lock:
          if self._stop_requested:
            raise ValueError("Server has already stopped")
          if self._server_started:
            raise ValueError("Server has already started running")

          self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
          debug_service_pb2_grpc.add_EventListenerServicer_to_server(self,
                                                                     self.server)
          self.server.add_insecure_port("[::]:%d" % self._server_port)
          self.server.start()
          self._server_started = True

        if blocking:
          # Poll outside the lock so that `stop_server()` can acquire it.
          while not self._stop_requested:
            time.sleep(1.0)

      def stop_server(self, grace=1.0):
        """Request server stopping.

        Once stopped, server cannot be stopped or started again. This method is
        non-blocking. Call `wait()` on the returned event to block until the
        server has completely stopped.

        Args:
          grace: Grace period in seconds to be used when calling `server.stop()`.

        Raises:
          ValueError: If server stop has already been requested, or if the server
            has not started running yet.

        Returns:
          A threading.Event that will be set when the server has completely
          stopped.
        """
        with self._server_lock:
          if not self._server_started:
            raise ValueError("Server has not started running")
          if self._stop_requested:
            raise ValueError("Server has already stopped")

          self._stop_requested = True
          return self.server.stop(grace=grace)

      def request_watch(self, node_name, output_slot, debug_op, breakpoint=False):
        """Request enabling a debug tensor watchpoint or breakpoint.

        This will let the server send a EventReply to the client side
        (i.e., the debugged TensorFlow runtime process) to request adding a watch
        key (i.e., <node_name>:<output_slot>:<debug_op>) to the list of enabled
        watch keys. The list applies only to debug ops with the attribute
        gated_grpc=True.

        To disable the watch, use `request_unwatch()`.

        Args:
          node_name: (`str`) name of the node that the to-be-watched tensor
            belongs to, e.g., "hidden/Weights".
          output_slot: (`int`) output slot index of the tensor to watch.
          debug_op: (`str`) name of the debug op to enable. This should not
            include any attribute substrings.
          breakpoint: (`bool`) Iff `True`, the debug op will block and wait until
            it receives an `EventReply` response from the server. The
            `EventReply` proto may carry a TensorProto that modifies the value of
            the debug op's output tensor.
        """
        state_enum = debug_service_pb2.EventReply.DebugOpStateChange
        new_state = state_enum.READ_WRITE if breakpoint else state_enum.READ_ONLY
        self._debug_ops_state_change_queue.put(
            _state_change(new_state, node_name, output_slot, debug_op))

      def request_unwatch(self, node_name, output_slot, debug_op):
        """Request disabling a debug tensor watchpoint or breakpoint.

        This is the opposite of `request_watch()`.

        Args:
          node_name: (`str`) name of the node that the watched tensor belongs
            to, e.g., "hidden/Weights".
          output_slot: (`int`) output slot index of the watched tensor.
          debug_op: (`str`) name of the debug op to disable. This should not
            include any attribute substrings.
        """
        self._debug_ops_state_change_queue.put(
            _state_change(
                debug_service_pb2.EventReply.DebugOpStateChange.DISABLED,
                node_name, output_slot, debug_op))

      @property
      def breakpoints(self):
        """Get a set of the currently-activated breakpoints.

        Returns:
          A `set` of 3-tuples: (node_name, output_slot, debug_op), e.g.,
            {("MatMul", 0, "DebugIdentity")}.
        """
        return self._breakpoints

      def gated_grpc_debug_watches(self):
        """Get the list of debug watches with attribute gated_grpc=True.

        Since the server receives `GraphDef` from the debugged runtime, it can
        only return such debug watches that it has received so far.

        Returns:
          A `list` of `DebugWatch` `namedtuples` representing the debug watches
          with gated_grpc=True. Each `namedtuple` element has the attributes:
            `node_name` as a `str`,
            `output_slot` as an `int`,
            `debug_op` as a `str`.
        """
        return list(self._gated_grpc_debug_watches)

      def SendTracebacks(self, request, context):
        """Base implementation of the handling of SendTracebacks calls.

        The base implementation does nothing with the incoming request.
        Override in an implementation of the server if necessary.

        Args:
          request: A `CallTraceback` proto, containing information about the
            type (e.g., graph vs. eager execution) and source-code traceback of
            the call and (any) associated `tf.Graph`s.
          context: Server context.

        Returns:
          A `EventReply` proto.
        """
        return debug_service_pb2.EventReply()

      def SendSourceFiles(self, request, context):
        """Base implementation of the handling of SendSourceFiles calls.

        The base implementation does nothing with the incoming request.
        Override in an implementation of the server if necessary.

        Args:
          request: A `DebuggedSourceFiles` proto, containing the path, content,
            size and last-modified timestamp of source files.
          context: Server context.

        Returns:
          A `EventReply` proto.
        """
        return debug_service_pb2.EventReply()
    494