Home | History | Annotate | Download | only in lib
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """GRPC debug server for testing."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import collections
     21 import errno
     22 import functools
     23 import hashlib
     24 import json
     25 import os
     26 import re
     27 import shutil
     28 import tempfile
     29 import threading
     30 import time
     31 
     32 import portpicker
     33 
     34 from tensorflow.core.debug import debug_service_pb2
     35 from tensorflow.core.protobuf import config_pb2
     36 from tensorflow.core.util import event_pb2
     37 from tensorflow.python.client import session
     38 from tensorflow.python.debug.lib import debug_data
     39 from tensorflow.python.debug.lib import debug_utils
     40 from tensorflow.python.debug.lib import grpc_debug_server
     41 from tensorflow.python.framework import constant_op
     42 from tensorflow.python.framework import errors
     43 from tensorflow.python.ops import variables
     44 from tensorflow.python.util import compat
     45 
     46 
     47 def _get_dump_file_path(dump_root, device_name, debug_node_name):
     48   """Get the file path of the dump file for a debug node.
     49 
     50   Args:
     51     dump_root: (str) Root dump directory.
     52     device_name: (str) Name of the device that the debug node resides on.
     53     debug_node_name: (str) Name of the debug node, e.g.,
     54       cross_entropy/Log:0:DebugIdentity.
     55 
     56   Returns:
     57     (str) Full path of the dump file.
     58   """
     59 
     60   dump_root = os.path.join(
     61       dump_root, debug_data.device_name_to_device_path(device_name))
     62   if "/" in debug_node_name:
     63     dump_dir = os.path.join(dump_root, os.path.dirname(debug_node_name))
     64     dump_file_name = re.sub(":", "_", os.path.basename(debug_node_name))
     65   else:
     66     dump_dir = dump_root
     67     dump_file_name = re.sub(":", "_", debug_node_name)
     68 
     69   now_microsec = int(round(time.time() * 1000 * 1000))
     70   dump_file_name += "_%d" % now_microsec
     71 
     72   return os.path.join(dump_dir, dump_file_name)
     73 
     74 
     75 class EventListenerTestStreamHandler(
     76     grpc_debug_server.EventListenerBaseStreamHandler):
     77   """Implementation of EventListenerBaseStreamHandler that dumps to file."""
     78 
     79   def __init__(self, dump_dir, event_listener_servicer):
     80     super(EventListenerTestStreamHandler, self).__init__()
     81     self._dump_dir = dump_dir
     82     self._event_listener_servicer = event_listener_servicer
     83     if self._dump_dir:
     84       self._try_makedirs(self._dump_dir)
     85 
     86     self._grpc_path = None
     87     self._cached_graph_defs = []
     88     self._cached_graph_def_device_names = []
     89     self._cached_graph_def_wall_times = []
     90 
     91   def on_core_metadata_event(self, event):
     92     self._event_listener_servicer.toggle_watch()
     93 
     94     core_metadata = json.loads(event.log_message.message)
     95 
     96     if not self._grpc_path:
     97       grpc_path = core_metadata["grpc_path"]
     98       if grpc_path:
     99         if grpc_path.startswith("/"):
    100           grpc_path = grpc_path[1:]
    101       if self._dump_dir:
    102         self._dump_dir = os.path.join(self._dump_dir, grpc_path)
    103 
    104         # Write cached graph defs to filesystem.
    105         for graph_def, device_name, wall_time in zip(
    106             self._cached_graph_defs,
    107             self._cached_graph_def_device_names,
    108             self._cached_graph_def_wall_times):
    109           self._write_graph_def(graph_def, device_name, wall_time)
    110 
    111     if self._dump_dir:
    112       self._write_core_metadata_event(event)
    113     else:
    114       self._event_listener_servicer.core_metadata_json_strings.append(
    115           event.log_message.message)
    116 
    117   def on_graph_def(self, graph_def, device_name, wall_time):
    118     """Implementation of the tensor value-carrying Event proto callback.
    119 
    120     Args:
    121       graph_def: A GraphDef object.
    122       device_name: Name of the device on which the graph was created.
    123       wall_time: An epoch timestamp (in microseconds) for the graph.
    124     """
    125     if self._dump_dir:
    126       if self._grpc_path:
    127         self._write_graph_def(graph_def, device_name, wall_time)
    128       else:
    129         self._cached_graph_defs.append(graph_def)
    130         self._cached_graph_def_device_names.append(device_name)
    131         self._cached_graph_def_wall_times.append(wall_time)
    132     else:
    133       self._event_listener_servicer.partition_graph_defs.append(graph_def)
    134 
    135   def on_value_event(self, event):
    136     """Implementation of the tensor value-carrying Event proto callback.
    137 
    138     Writes the Event proto to the file system for testing. The path written to
    139     follows the same pattern as the file:// debug URLs of tfdbg, i.e., the
    140     name scope of the op becomes the directory structure under the dump root
    141     directory.
    142 
    143     Args:
    144       event: The Event proto carrying a tensor value.
    145 
    146     Returns:
    147       If the debug node belongs to the set of currently activated breakpoints,
    148       a `EventReply` proto will be returned.
    149     """
    150     if self._dump_dir:
    151       self._write_value_event(event)
    152     else:
    153       value = event.summary.value[0]
    154       tensor_value = debug_data.load_tensor_from_event(event)
    155       self._event_listener_servicer.debug_tensor_values[value.node_name].append(
    156           tensor_value)
    157 
    158       items = event.summary.value[0].node_name.split(":")
    159       node_name = items[0]
    160       output_slot = int(items[1])
    161       debug_op = items[2]
    162       if ((node_name, output_slot, debug_op) in
    163           self._event_listener_servicer.breakpoints):
    164         return debug_service_pb2.EventReply()
    165 
    166   def _try_makedirs(self, dir_path):
    167     if not os.path.isdir(dir_path):
    168       try:
    169         os.makedirs(dir_path)
    170       except OSError as error:
    171         if error.errno != errno.EEXIST:
    172           raise
    173 
    174   def _write_core_metadata_event(self, event):
    175     core_metadata_path = os.path.join(
    176         self._dump_dir,
    177         debug_data.METADATA_FILE_PREFIX + debug_data.CORE_METADATA_TAG +
    178         "_%d" % event.wall_time)
    179     self._try_makedirs(self._dump_dir)
    180     with open(core_metadata_path, "wb") as f:
    181       f.write(event.SerializeToString())
    182 
    183   def _write_graph_def(self, graph_def, device_name, wall_time):
    184     encoded_graph_def = graph_def.SerializeToString()
    185     graph_hash = int(hashlib.md5(encoded_graph_def).hexdigest(), 16)
    186     event = event_pb2.Event(graph_def=encoded_graph_def, wall_time=wall_time)
    187     graph_file_path = os.path.join(
    188         self._dump_dir,
    189         debug_data.device_name_to_device_path(device_name),
    190         debug_data.METADATA_FILE_PREFIX + debug_data.GRAPH_FILE_TAG +
    191         debug_data.HASH_TAG + "%d_%d" % (graph_hash, wall_time))
    192     self._try_makedirs(os.path.dirname(graph_file_path))
    193     with open(graph_file_path, "wb") as f:
    194       f.write(event.SerializeToString())
    195 
    196   def _write_value_event(self, event):
    197     value = event.summary.value[0]
    198 
    199     # Obtain the device name from the metadata.
    200     summary_metadata = event.summary.value[0].metadata
    201     if not summary_metadata.plugin_data:
    202       raise ValueError("The value lacks plugin data.")
    203     try:
    204       content = json.loads(compat.as_text(summary_metadata.plugin_data.content))
    205     except ValueError as err:
    206       raise ValueError("Could not parse content into JSON: %r, %r" % (content,
    207                                                                       err))
    208     device_name = content["device"]
    209 
    210     dump_full_path = _get_dump_file_path(
    211         self._dump_dir, device_name, value.node_name)
    212     self._try_makedirs(os.path.dirname(dump_full_path))
    213     with open(dump_full_path, "wb") as f:
    214       f.write(event.SerializeToString())
    215 
    216 
    217 class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
    218   """An implementation of EventListenerBaseServicer for testing."""
    219 
    220   def __init__(self, server_port, dump_dir, toggle_watch_on_core_metadata=None):
    221     """Constructor of EventListenerTestServicer.
    222 
    223     Args:
    224       server_port: (int) The server port number.
    225       dump_dir: (str) The root directory to which the data files will be
    226         dumped. If empty or None, the received debug data will not be dumped
    227         to the file system: they will be stored in memory instead.
    228       toggle_watch_on_core_metadata: A list of
    229         (node_name, output_slot, debug_op) tuples to toggle the
    230         watchpoint status during the on_core_metadata calls (optional).
    231     """
    232     self.core_metadata_json_strings = []
    233     self.partition_graph_defs = []
    234     self.debug_tensor_values = collections.defaultdict(list)
    235     self._initialize_toggle_watch_state(toggle_watch_on_core_metadata)
    236 
    237     grpc_debug_server.EventListenerBaseServicer.__init__(
    238         self, server_port,
    239         functools.partial(EventListenerTestStreamHandler, dump_dir, self))
    240 
    241     # Members for storing the graph ops traceback and source files.
    242     self._call_types = []
    243     self._call_keys = []
    244     self._origin_stacks = []
    245     self._origin_id_to_strings = []
    246     self._graph_tracebacks = []
    247     self._graph_versions = []
    248     self._source_files = None
    249 
    250   def _initialize_toggle_watch_state(self, toggle_watches):
    251     self._toggle_watches = toggle_watches
    252     self._toggle_watch_state = dict()
    253     if self._toggle_watches:
    254       for watch_key in self._toggle_watches:
    255         self._toggle_watch_state[watch_key] = False
    256 
    257   def toggle_watch(self):
    258     for watch_key in self._toggle_watch_state:
    259       node_name, output_slot, debug_op = watch_key
    260       if self._toggle_watch_state[watch_key]:
    261         self.request_unwatch(node_name, output_slot, debug_op)
    262       else:
    263         self.request_watch(node_name, output_slot, debug_op)
    264       self._toggle_watch_state[watch_key] = (
    265           not self._toggle_watch_state[watch_key])
    266 
    267   def clear_data(self):
    268     self.core_metadata_json_strings = []
    269     self.partition_graph_defs = []
    270     self.debug_tensor_values = collections.defaultdict(list)
    271     self._call_types = []
    272     self._call_keys = []
    273     self._origin_stacks = []
    274     self._origin_id_to_strings = []
    275     self._graph_tracebacks = []
    276     self._graph_versions = []
    277     self._source_files = None
    278 
    279   def SendTracebacks(self, request, context):
    280     self._call_types.append(request.call_type)
    281     self._call_keys.append(request.call_key)
    282     self._origin_stacks.append(request.origin_stack)
    283     self._origin_id_to_strings.append(request.origin_id_to_string)
    284     self._graph_tracebacks.append(request.graph_traceback)
    285     self._graph_versions.append(request.graph_version)
    286     return debug_service_pb2.EventReply()
    287 
    288   def SendSourceFiles(self, request, context):
    289     self._source_files = request
    290     return debug_service_pb2.EventReply()
    291 
    292   def query_op_traceback(self, op_name):
    293     """Query the traceback of an op.
    294 
    295     Args:
    296       op_name: Name of the op to query.
    297 
    298     Returns:
    299       The traceback of the op, as a list of 3-tuples:
    300         (filename, lineno, function_name)
    301 
    302     Raises:
    303       ValueError: If the op cannot be found in the tracebacks received by the
    304         server so far.
    305     """
    306     for op_log_proto in self._graph_tracebacks:
    307       for log_entry in op_log_proto.log_entries:
    308         if log_entry.name == op_name:
    309           return self._code_def_to_traceback(log_entry.code_def,
    310                                              op_log_proto.id_to_string)
    311     raise ValueError(
    312         "Op '%s' does not exist in the tracebacks received by the debug "
    313         "server." % op_name)
    314 
    315   def query_origin_stack(self):
    316     """Query the stack of the origin of the execution call.
    317 
    318     Returns:
    319       A `list` of all tracebacks. Each item corresponds to an execution call,
    320         i.e., a `SendTracebacks` request. Each item is a `list` of 3-tuples:
    321         (filename, lineno, function_name).
    322     """
    323     ret = []
    324     for stack, id_to_string in zip(
    325         self._origin_stacks, self._origin_id_to_strings):
    326       ret.append(self._code_def_to_traceback(stack, id_to_string))
    327     return ret
    328 
    329   def query_call_types(self):
    330     return self._call_types
    331 
    332   def query_call_keys(self):
    333     return self._call_keys
    334 
    335   def query_graph_versions(self):
    336     return self._graph_versions
    337 
    338   def query_source_file_line(self, file_path, lineno):
    339     """Query the content of a given line in a source file.
    340 
    341     Args:
    342       file_path: Path to the source file.
    343       lineno: Line number as an `int`.
    344 
    345     Returns:
    346       Content of the line as a string.
    347 
    348     Raises:
    349       ValueError: If no source file is found at the given file_path.
    350     """
    351     if not self._source_files:
    352       raise ValueError(
    353           "This debug server has not received any source file contents yet.")
    354     for source_file_proto in self._source_files.source_files:
    355       if source_file_proto.file_path == file_path:
    356         return source_file_proto.lines[lineno - 1]
    357     raise ValueError(
    358         "Source file at path %s has not been received by the debug server",
    359         file_path)
    360 
    361   def _code_def_to_traceback(self, code_def, id_to_string):
    362     return [(id_to_string[trace.file_id],
    363              trace.lineno,
    364              id_to_string[trace.function_id]) for trace in code_def.traces]
    365 
    366 
    367 def start_server_on_separate_thread(dump_to_filesystem=True,
    368                                     server_start_delay_sec=0.0,
    369                                     poll_server=False,
    370                                     blocking=True,
    371                                     toggle_watch_on_core_metadata=None):
    372   """Create a test gRPC debug server and run on a separate thread.
    373 
    374   Args:
    375     dump_to_filesystem: (bool) whether the debug server will dump debug data
    376       to the filesystem.
    377     server_start_delay_sec: (float) amount of time (in sec) to delay the server
    378       start up for.
    379     poll_server: (bool) whether the server will be polled till success on
    380       startup.
    381     blocking: (bool) whether the server should be started in a blocking mode.
    382     toggle_watch_on_core_metadata: A list of
    383         (node_name, output_slot, debug_op) tuples to toggle the
    384         watchpoint status during the on_core_metadata calls (optional).
    385 
    386   Returns:
    387     server_port: (int) Port on which the server runs.
    388     debug_server_url: (str) grpc:// URL to the server.
    389     server_dump_dir: (str) The debug server's dump directory.
    390     server_thread: The server Thread object.
    391     server: The `EventListenerTestServicer` object.
    392 
    393   Raises:
    394     ValueError: If polling the server process for ready state is not successful
    395       within maximum polling count.
    396   """
    397   server_port = portpicker.pick_unused_port()
    398   debug_server_url = "grpc://localhost:%d" % server_port
    399 
    400   server_dump_dir = tempfile.mkdtemp() if dump_to_filesystem else None
    401   server = EventListenerTestServicer(
    402       server_port=server_port,
    403       dump_dir=server_dump_dir,
    404       toggle_watch_on_core_metadata=toggle_watch_on_core_metadata)
    405 
    406   def delay_then_run_server():
    407     time.sleep(server_start_delay_sec)
    408     server.run_server(blocking=blocking)
    409 
    410   server_thread = threading.Thread(target=delay_then_run_server)
    411   server_thread.start()
    412 
    413   if poll_server:
    414     if not _poll_server_till_success(
    415         50,
    416         0.2,
    417         debug_server_url,
    418         server_dump_dir,
    419         server,
    420         gpu_memory_fraction=0.1):
    421       raise ValueError(
    422           "Failed to start test gRPC debug server at port %d" % server_port)
    423     server.clear_data()
    424   return server_port, debug_server_url, server_dump_dir, server_thread, server
    425 
    426 
    427 def _poll_server_till_success(max_attempts,
    428                               sleep_per_poll_sec,
    429                               debug_server_url,
    430                               dump_dir,
    431                               server,
    432                               gpu_memory_fraction=1.0):
    433   """Poll server until success or exceeding max polling count.
    434 
    435   Args:
    436     max_attempts: (int) How many times to poll at maximum
    437     sleep_per_poll_sec: (float) How many seconds to sleep for after each
    438       unsuccessful poll.
    439     debug_server_url: (str) gRPC URL to the debug server.
    440     dump_dir: (str) Dump directory to look for files in. If None, will directly
    441       check data from the server object.
    442     server: The server object.
    443     gpu_memory_fraction: (float) Fraction of GPU memory to be
    444       allocated for the Session used in server polling.
    445 
    446   Returns:
    447     (bool) Whether the polling succeeded within max_polls attempts.
    448   """
    449   poll_count = 0
    450 
    451   config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
    452       per_process_gpu_memory_fraction=gpu_memory_fraction))
    453   with session.Session(config=config) as sess:
    454     for poll_count in range(max_attempts):
    455       server.clear_data()
    456       print("Polling: poll_count = %d" % poll_count)
    457 
    458       x_init_name = "x_init_%d" % poll_count
    459       x_init = constant_op.constant([42.0], shape=[1], name=x_init_name)
    460       x = variables.Variable(x_init, name=x_init_name)
    461 
    462       run_options = config_pb2.RunOptions()
    463       debug_utils.add_debug_tensor_watch(
    464           run_options, x_init_name, 0, debug_urls=[debug_server_url])
    465       try:
    466         sess.run(x.initializer, options=run_options)
    467       except errors.FailedPreconditionError:
    468         pass
    469 
    470       if dump_dir:
    471         if os.path.isdir(
    472             dump_dir) and debug_data.DebugDumpDir(dump_dir).size > 0:
    473           shutil.rmtree(dump_dir)
    474           print("Poll succeeded.")
    475           return True
    476         else:
    477           print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec)
    478           time.sleep(sleep_per_poll_sec)
    479       else:
    480         if server.debug_tensor_values:
    481           print("Poll succeeded.")
    482           return True
    483         else:
    484           print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec)
    485           time.sleep(sleep_per_poll_sec)
    486 
    487     return False
    488