# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""GRPC debug server for testing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import errno
import functools
import hashlib
import json
import os
import re
import shutil
import tempfile
import threading
import time

import portpicker

from tensorflow.core.debug import debug_service_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_utils
from tensorflow.python.debug.lib import grpc_debug_server
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.ops import variables
from tensorflow.python.util import compat


def _get_dump_file_path(dump_root, device_name, debug_node_name):
  """Get the file path of the dump file for a debug node.

  Args:
    dump_root: (str) Root dump directory.
    device_name: (str) Name of the device that the debug node resides on.
    debug_node_name: (str) Name of the debug node, e.g.,
      cross_entropy/Log:0:DebugIdentity.

  Returns:
    (str) Full path of the dump file.
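
  Example (an illustrative sketch; the per-device subdirectory name comes from
    `debug_data.device_name_to_device_path`):

    # Produces a path of the form
    # <dump_root>/<device path>/cross_entropy/Log_0_DebugIdentity_<microseconds>.
    path = _get_dump_file_path(
        "/tmp/dump_root",
        "/job:localhost/replica:0/task:0/cpu:0",
        "cross_entropy/Log:0:DebugIdentity")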
     58   """
     59 
     60   dump_root = os.path.join(
     61       dump_root, debug_data.device_name_to_device_path(device_name))
     62   if "/" in debug_node_name:
     63     dump_dir = os.path.join(dump_root, os.path.dirname(debug_node_name))
     64     dump_file_name = re.sub(":", "_", os.path.basename(debug_node_name))
     65   else:
     66     dump_dir = dump_root
     67     dump_file_name = re.sub(":", "_", debug_node_name)
     68 
     69   now_microsec = int(round(time.time() * 1000 * 1000))
     70   dump_file_name += "_%d" % now_microsec
     71 
     72   return os.path.join(dump_dir, dump_file_name)
     73 
     74 
     75 class EventListenerTestStreamHandler(
     76     grpc_debug_server.EventListenerBaseStreamHandler):
     77   """Implementation of EventListenerBaseStreamHandler that dumps to file."""
     78 
     79   def __init__(self, dump_dir, event_listener_servicer):
     80     super(EventListenerTestStreamHandler, self).__init__()
     81     self._dump_dir = dump_dir
     82     self._event_listener_servicer = event_listener_servicer
     83     if self._dump_dir:
     84       self._try_makedirs(self._dump_dir)
     85 
     86     self._grpc_path = None
     87     self._cached_graph_defs = []
     88     self._cached_graph_def_device_names = []
     89     self._cached_graph_def_wall_times = []
     90 
     91   def on_core_metadata_event(self, event):
     92     self._event_listener_servicer.toggle_watch()
     93 
     94     core_metadata = json.loads(event.log_message.message)
     95 
     96     if not self._grpc_path:
     97       grpc_path = core_metadata["grpc_path"]
     98       if grpc_path:
     99         if grpc_path.startswith("/"):
    100           grpc_path = grpc_path[1:]
    101       if self._dump_dir:
    102         self._dump_dir = os.path.join(self._dump_dir, grpc_path)
    103 
    104         # Write cached graph defs to filesystem.
    105         for graph_def, device_name, wall_time in zip(
    106             self._cached_graph_defs,
    107             self._cached_graph_def_device_names,
    108             self._cached_graph_def_wall_times):
    109           self._write_graph_def(graph_def, device_name, wall_time)
    110 
    111     if self._dump_dir:
    112       self._write_core_metadata_event(event)
    113     else:
    114       self._event_listener_servicer.core_metadata_json_strings.append(
    115           event.log_message.message)
    116 
    117   def on_graph_def(self, graph_def, device_name, wall_time):
    118     """Implementation of the tensor value-carrying Event proto callback.
    119 
    120     Args:
    121       graph_def: A GraphDef object.
    122       device_name: Name of the device on which the graph was created.
    123       wall_time: An epoch timestamp (in microseconds) for the graph.
    124     """
    125     if self._dump_dir:
    126       if self._grpc_path:
    127         self._write_graph_def(graph_def, device_name, wall_time)
    128       else:
    129         self._cached_graph_defs.append(graph_def)
    130         self._cached_graph_def_device_names.append(device_name)
    131         self._cached_graph_def_wall_times.append(wall_time)
    132     else:
    133       self._event_listener_servicer.partition_graph_defs.append(graph_def)
    134 
    135   def on_value_event(self, event):
    136     """Implementation of the tensor value-carrying Event proto callback.
    137 
    138     Writes the Event proto to the file system for testing. The path written to
    139     follows the same pattern as the file:// debug URLs of tfdbg, i.e., the
    140     name scope of the op becomes the directory structure under the dump root
    141     directory.
    142 
    143     Args:
    144       event: The Event proto carrying a tensor value.
    145 
    146     Returns:
      If the debug node belongs to the set of currently activated breakpoints,
      an `EventReply` proto will be returned.
    149     """
    150     if self._dump_dir:
    151       self._write_value_event(event)
    152     else:
    153       value = event.summary.value[0]
    154       tensor_value = debug_data.load_tensor_from_event(event)
    155       self._event_listener_servicer.debug_tensor_values[value.node_name].append(
    156           tensor_value)
    157 
    158       items = event.summary.value[0].node_name.split(":")
    159       node_name = items[0]
    160       output_slot = int(items[1])
    161       debug_op = items[2]
    162       if ((node_name, output_slot, debug_op) in
    163           self._event_listener_servicer.breakpoints):
    164         return debug_service_pb2.EventReply()
    165 
    166   def _try_makedirs(self, dir_path):
    167     if not os.path.isdir(dir_path):
    168       try:
    169         os.makedirs(dir_path)
    170       except OSError as error:
    171         if error.errno != errno.EEXIST:
    172           raise
    173 
    174   def _write_core_metadata_event(self, event):
    175     core_metadata_path = os.path.join(
    176         self._dump_dir,
    177         debug_data.METADATA_FILE_PREFIX + debug_data.CORE_METADATA_TAG +
    178         "_%d" % event.wall_time)
    179     self._try_makedirs(self._dump_dir)
    180     with open(core_metadata_path, "wb") as f:
    181       f.write(event.SerializeToString())
    182 
    183   def _write_graph_def(self, graph_def, device_name, wall_time):
    184     encoded_graph_def = graph_def.SerializeToString()
    185     graph_hash = int(hashlib.md5(encoded_graph_def).hexdigest(), 16)
    186     event = event_pb2.Event(graph_def=encoded_graph_def, wall_time=wall_time)
    187     graph_file_path = os.path.join(
    188         self._dump_dir,
    189         debug_data.device_name_to_device_path(device_name),
    190         debug_data.METADATA_FILE_PREFIX + debug_data.GRAPH_FILE_TAG +
    191         debug_data.HASH_TAG + "%d_%d" % (graph_hash, wall_time))
    192     self._try_makedirs(os.path.dirname(graph_file_path))
    193     with open(graph_file_path, "wb") as f:
    194       f.write(event.SerializeToString())
    195 
    196   def _write_value_event(self, event):
    197     value = event.summary.value[0]
    198 
    199     # Obtain the device name from the metadata.
    200     summary_metadata = event.summary.value[0].metadata
    201     if not summary_metadata.plugin_data:
    202       raise ValueError("The value lacks plugin data.")
    try:
      content = json.loads(compat.as_text(summary_metadata.plugin_data.content))
    except ValueError as err:
      raise ValueError(
          "Could not parse the plugin data content into JSON: %r (%r)" %
          (summary_metadata.plugin_data.content, err))
    device_name = content["device"]

    dump_full_path = _get_dump_file_path(
        self._dump_dir, device_name, value.node_name)
    self._try_makedirs(os.path.dirname(dump_full_path))
    with open(dump_full_path, "wb") as f:
      f.write(event.SerializeToString())


class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
  """An implementation of EventListenerBaseServicer for testing."""

  def __init__(self, server_port, dump_dir, toggle_watch_on_core_metadata=None):
    """Constructor of EventListenerTestServicer.

    Args:
      server_port: (int) The server port number.
      dump_dir: (str) The root directory to which the data files will be
        dumped. If empty or None, the received debug data will not be dumped
        to the file system; it will be stored in memory instead.
      toggle_watch_on_core_metadata: A list of
        (node_name, output_slot, debug_op) tuples to toggle the
        watchpoint status during the on_core_metadata calls (optional).
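
    Example (illustrative sketch; port 6064 and the node name are arbitrary):

      servicer = EventListenerTestServicer(server_port=6064, dump_dir=None)
      # With dump_dir=None, received tensors are kept in memory in
      # servicer.debug_tensor_values, keyed by the debug node name,
      # e.g. "cross_entropy/Log:0:DebugIdentity".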
    231     """
    232     self.core_metadata_json_strings = []
    233     self.partition_graph_defs = []
    234     self.debug_tensor_values = collections.defaultdict(list)
    235     self._initialize_toggle_watch_state(toggle_watch_on_core_metadata)
    236 
    237     grpc_debug_server.EventListenerBaseServicer.__init__(
    238         self, server_port,
    239         functools.partial(EventListenerTestStreamHandler, dump_dir, self))
    240 
    241     # Members for storing the graph ops traceback and source files.
    242     self._call_types = []
    243     self._call_keys = []
    244     self._origin_stacks = []
    245     self._origin_id_to_strings = []
    246     self._graph_tracebacks = []
    247     self._graph_versions = []
    248     self._source_files = []
    249 
    250   def _initialize_toggle_watch_state(self, toggle_watches):
    251     self._toggle_watches = toggle_watches
    252     self._toggle_watch_state = dict()
    253     if self._toggle_watches:
    254       for watch_key in self._toggle_watches:
    255         self._toggle_watch_state[watch_key] = False
    256 
    257   def toggle_watch(self):
    258     for watch_key in self._toggle_watch_state:
    259       node_name, output_slot, debug_op = watch_key
    260       if self._toggle_watch_state[watch_key]:
    261         self.request_unwatch(node_name, output_slot, debug_op)
    262       else:
    263         self.request_watch(node_name, output_slot, debug_op)
    264       self._toggle_watch_state[watch_key] = (
    265           not self._toggle_watch_state[watch_key])
    266 
    267   def clear_data(self):
    268     self.core_metadata_json_strings = []
    269     self.partition_graph_defs = []
    270     self.debug_tensor_values = collections.defaultdict(list)
    271     self._call_types = []
    272     self._call_keys = []
    273     self._origin_stacks = []
    274     self._origin_id_to_strings = []
    275     self._graph_tracebacks = []
    276     self._graph_versions = []
    277     self._source_files = []
    278 
    279   def SendTracebacks(self, request, context):
    280     self._call_types.append(request.call_type)
    281     self._call_keys.append(request.call_key)
    282     self._origin_stacks.append(request.origin_stack)
    283     self._origin_id_to_strings.append(request.origin_id_to_string)
    284     self._graph_tracebacks.append(request.graph_traceback)
    285     self._graph_versions.append(request.graph_version)
    286     return debug_service_pb2.EventReply()
    287 
    288   def SendSourceFiles(self, request, context):
    289     self._source_files.append(request)
    290     return debug_service_pb2.EventReply()
    291 
    292   def query_op_traceback(self, op_name):
    293     """Query the traceback of an op.
    294 
    295     Args:
    296       op_name: Name of the op to query.
    297 
    298     Returns:
    299       The traceback of the op, as a list of 3-tuples:
    300         (filename, lineno, function_name)
    301 
    302     Raises:
    303       ValueError: If the op cannot be found in the tracebacks received by the
    304         server so far.
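
    Example (illustrative sketch; "hidden/MatMul" is a hypothetical op name):

      traceback = servicer.query_op_traceback("hidden/MatMul")
      # traceback is a list of 3-tuples such as
      # [("/path/to/model.py", 42, "build_model"), ...].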
    305     """
    306     for op_log_proto in self._graph_tracebacks:
    307       for log_entry in op_log_proto.log_entries:
    308         if log_entry.name == op_name:
    309           return self._code_def_to_traceback(log_entry.code_def,
    310                                              op_log_proto.id_to_string)
    311     raise ValueError(
    312         "Op '%s' does not exist in the tracebacks received by the debug "
    313         "server." % op_name)
    314 
    315   def query_origin_stack(self):
    316     """Query the stack of the origin of the execution call.
    317 
    318     Returns:
    319       A `list` of all tracebacks. Each item corresponds to an execution call,
    320         i.e., a `SendTracebacks` request. Each item is a `list` of 3-tuples:
    321         (filename, lineno, function_name).
    322     """
    323     ret = []
    324     for stack, id_to_string in zip(
    325         self._origin_stacks, self._origin_id_to_strings):
    326       ret.append(self._code_def_to_traceback(stack, id_to_string))
    327     return ret
    328 
    329   def query_call_types(self):
    330     return self._call_types
    331 
    332   def query_call_keys(self):
    333     return self._call_keys
    334 
    335   def query_graph_versions(self):
    336     return self._graph_versions
    337 
    338   def query_source_file_line(self, file_path, lineno):
    339     """Query the content of a given line in a source file.
    340 
    341     Args:
    342       file_path: Path to the source file.
    343       lineno: Line number as an `int`.
    344 
    345     Returns:
    346       Content of the line as a string.
    347 
    348     Raises:
    349       ValueError: If no source file is found at the given file_path.
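
    Example (illustrative sketch; the file path and line number are
      hypothetical):

      line = servicer.query_source_file_line("/path/to/model.py", 42)
      # line holds the text of line 42 of model.py, as previously received
      # through a SendSourceFiles request.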
    350     """
    351     if not self._source_files:
    352       raise ValueError(
    353           "This debug server has not received any source file contents yet.")
    354     for source_files in self._source_files:
    355       for source_file_proto in source_files.source_files:
    356         if source_file_proto.file_path == file_path:
    357           return source_file_proto.lines[lineno - 1]
    raise ValueError(
        "Source file at path %s has not been received by the debug server." %
        file_path)

  def _code_def_to_traceback(self, code_def, id_to_string):
    return [(id_to_string[trace.file_id],
             trace.lineno,
             id_to_string[trace.function_id]) for trace in code_def.traces]


def start_server_on_separate_thread(dump_to_filesystem=True,
                                    server_start_delay_sec=0.0,
                                    poll_server=False,
                                    blocking=True,
                                    toggle_watch_on_core_metadata=None):
  """Create a test gRPC debug server and run it on a separate thread.

  Args:
    dump_to_filesystem: (bool) whether the debug server will dump debug data
      to the filesystem.
    server_start_delay_sec: (float) amount of time (in sec) by which to delay
      the server startup.
    poll_server: (bool) whether the server will be polled till success on
      startup.
    blocking: (bool) whether the server should be started in a blocking mode.
    toggle_watch_on_core_metadata: A list of
        (node_name, output_slot, debug_op) tuples to toggle the
        watchpoint status during the on_core_metadata calls (optional).

  Returns:
    server_port: (int) Port on which the server runs.
    debug_server_url: (str) grpc:// URL to the server.
    server_dump_dir: (str) The debug server's dump directory.
    server_thread: The server Thread object.
    server: The `EventListenerTestServicer` object.

  Raises:
    ValueError: If polling the server process for ready state is not successful
      within maximum polling count.
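
  Example (illustrative sketch; assumes the `stop_server()` method of the
    gRPC debug server base class for shutdown):

    port, url, dump_dir, thread, server = start_server_on_separate_thread()
    # Point tfdbg debug URLs at `url` and run the graph under test, then:
    server.stop_server().wait()
    thread.join()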
    397   """
    398   server_port = portpicker.pick_unused_port()
    399   debug_server_url = "grpc://localhost:%d" % server_port
    400 
    401   server_dump_dir = tempfile.mkdtemp() if dump_to_filesystem else None
    402   server = EventListenerTestServicer(
    403       server_port=server_port,
    404       dump_dir=server_dump_dir,
    405       toggle_watch_on_core_metadata=toggle_watch_on_core_metadata)
    406 
    407   def delay_then_run_server():
    408     time.sleep(server_start_delay_sec)
    409     server.run_server(blocking=blocking)
    410 
    411   server_thread = threading.Thread(target=delay_then_run_server)
    412   server_thread.start()
    413 
    414   if poll_server:
    415     if not _poll_server_till_success(
    416         50,
    417         0.2,
    418         debug_server_url,
    419         server_dump_dir,
    420         server,
    421         gpu_memory_fraction=0.1):
    422       raise ValueError(
    423           "Failed to start test gRPC debug server at port %d" % server_port)
    424     server.clear_data()
    425   return server_port, debug_server_url, server_dump_dir, server_thread, server
    426 
    427 
    428 def _poll_server_till_success(max_attempts,
    429                               sleep_per_poll_sec,
    430                               debug_server_url,
    431                               dump_dir,
    432                               server,
    433                               gpu_memory_fraction=1.0):
    434   """Poll server until success or exceeding max polling count.
    435 
    436   Args:
    max_attempts: (int) Maximum number of polling attempts.
    sleep_per_poll_sec: (float) How many seconds to sleep for after each
      unsuccessful poll.
    debug_server_url: (str) gRPC URL to the debug server.
    dump_dir: (str) Dump directory to look for files in. If None, will directly
      check data from the server object.
    server: The server object.
    gpu_memory_fraction: (float) Fraction of GPU memory to be
      allocated for the Session used in server polling.

  Returns:
    (bool) Whether the polling succeeded within max_attempts attempts.
    449   """
    450   poll_count = 0
    451 
    452   config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
    453       per_process_gpu_memory_fraction=gpu_memory_fraction))
    454   with session.Session(config=config) as sess:
    455     for poll_count in range(max_attempts):
    456       server.clear_data()
    457       print("Polling: poll_count = %d" % poll_count)
    458 
    459       x_init_name = "x_init_%d" % poll_count
    460       x_init = constant_op.constant([42.0], shape=[1], name=x_init_name)
    461       x = variables.Variable(x_init, name=x_init_name)
    462 
    463       run_options = config_pb2.RunOptions()
    464       debug_utils.add_debug_tensor_watch(
    465           run_options, x_init_name, 0, debug_urls=[debug_server_url])
    466       try:
    467         sess.run(x.initializer, options=run_options)
    468       except errors.FailedPreconditionError:
    469         pass
    470 
    471       if dump_dir:
    472         if os.path.isdir(
    473             dump_dir) and debug_data.DebugDumpDir(dump_dir).size > 0:
    474           shutil.rmtree(dump_dir)
    475           print("Poll succeeded.")
    476           return True
    477         else:
    478           print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec)
    479           time.sleep(sleep_per_poll_sec)
    480       else:
    481         if server.debug_tensor_values:
    482           print("Poll succeeded.")
    483           return True
    484         else:
    485           print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec)
    486           time.sleep(sleep_per_poll_sec)
    487 
    488     return False
    489