Home | History | Annotate | Download | only in lib
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Communicating tracebacks and source code with debug server."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import socket
     22 
     23 import grpc
     24 
     25 from tensorflow.core.debug import debug_service_pb2
     26 from tensorflow.core.protobuf import debug_pb2
     27 from tensorflow.python.debug.lib import common
     28 from tensorflow.python.debug.lib import debug_service_pb2_grpc
     29 from tensorflow.python.debug.lib import source_utils
     30 from tensorflow.python.platform import gfile
     31 from tensorflow.python.profiler import tfprof_logger
     32 
     33 
     34 def _load_debugged_source_file(file_path, source_file_proto):
     35   file_stat = gfile.Stat(file_path)
     36   source_file_proto.host = socket.gethostname()
     37   source_file_proto.file_path = file_path
     38   source_file_proto.last_modified = file_stat.mtime_nsec
     39   source_file_proto.bytes = file_stat.length
     40   try:
     41     with gfile.Open(file_path, "r") as f:
     42       source_file_proto.lines.extend(f.read().splitlines())
     43   except IOError:
     44     pass
     45 
     46 
     47 def _string_to_id(string, string_to_id):
     48   if string not in string_to_id:
     49     string_to_id[string] = len(string_to_id)
     50   return string_to_id[string]
     51 
     52 
     53 def _format_origin_stack(origin_stack, call_traceback_proto):
     54   """Format a traceback stack for a `CallTraceback` proto.
     55 
     56   Args:
     57     origin_stack: The stack list as returned by `traceback.extract_stack()`.
     58     call_traceback_proto: A `CallTraceback` proto whose fields are to be
     59       populated.
     60   """
     61   string_to_id = dict()
     62   string_to_id[None] = 0
     63   for frame in origin_stack:
     64     file_path, lineno, func_name, line_text = frame
     65     call_traceback_proto.origin_stack.traces.add(
     66         file_id=_string_to_id(file_path, string_to_id),
     67         lineno=lineno,
     68         function_id=_string_to_id(func_name, string_to_id),
     69         line_id=_string_to_id(line_text, string_to_id))
     70 
     71   id_to_string = call_traceback_proto.origin_id_to_string
     72   for key, value in string_to_id.items():
     73     id_to_string[value] = key if key is not None else ""
     74 
     75 
     76 def _source_file_paths_outside_tensorflow_py_library(code_defs, id_to_string):
     77   """Extract source file paths outside TensorFlow Python library.
     78 
     79   Args:
     80     code_defs: An iterable of `CodeDef` protos, i.e., an iterable of stack
     81       traces.
     82     id_to_string: A proto map from integer ids to strings.
     83 
     84   Returns:
     85     An iterable of source file paths outside the TensorFlow Python library.
     86   """
     87   file_ids = set()
     88   for code_def in code_defs:
     89     for trace in code_def.traces:
     90       file_ids.add(trace.file_id)
     91   non_tf_files = (id_to_string[file_id] for file_id in file_ids)
     92   non_tf_files = (
     93       f for f in non_tf_files
     94       if not source_utils.guess_is_tensorflow_py_library(f) and gfile.Exists(f))
     95   return non_tf_files
     96 
     97 
     98 def _send_call_tracebacks(destinations,
     99                           origin_stack,
    100                           is_eager_execution=False,
    101                           call_key=None,
    102                           graph=None,
    103                           send_source=True):
    104   """Send the tracebacks of a TensorFlow execution call.
    105 
    106   To gRPC debug server(s). This applies to graph execution (`tf.Session.run()`)
    107   calls and eager execution calls.
    108 
    109   If `send_source`, also sends the underlying source files outside the
    110   TensorFlow library.
    111 
    112   Args:
    113     destinations: gRPC destination addresses, a `str` or a `list` of `str`s,
    114       e.g., "localhost:4242". If a `list`, gRPC requests containing the same
    115       `CallTraceback` proto payload will be sent to all the destinations.
    116     origin_stack: The traceback stack for the origin of the execution call. For
    117       graph execution, this is the traceback of the `tf.Session.run()`
    118       invocation. For eager execution, this is the traceback of the Python
    119       line that executes the eager opertion.
    120     is_eager_execution: (`bool`) whether an eager execution call (i.e., not a
    121       `tf.Session.run` or derived methods) is being sent.
    122     call_key: The key of the execution call, as a string. For graph execution,
    123       this is a string describing the feeds, fetches (and targets) names of the
    124       `tf.Session.run` call. For eager execution, this is ignored.
    125     graph: A Python `tf.Graph` object (i.e., *not* a `tf.GraphDef`), which
    126       contains op tracebacks, if applicable.
    127     send_source: Whether the source files involved in the op tracebacks but
    128       outside the TensorFlow library are to be sent.
    129   """
    130   if not isinstance(destinations, list):
    131     destinations = [destinations]
    132   # Strip grpc:// prefix, if any is present.
    133   destinations = [
    134       dest[len(common.GRPC_URL_PREFIX):]
    135       if dest.startswith(common.GRPC_URL_PREFIX) else dest
    136       for dest in destinations]
    137 
    138   call_type = (debug_service_pb2.CallTraceback.EAGER_EXECUTION
    139                if is_eager_execution
    140                else debug_service_pb2.CallTraceback.GRAPH_EXECUTION)
    141   graph_traceback = tfprof_logger.merge_default_with_oplog(
    142       graph, add_trainable_var=False) if graph else None
    143   call_traceback = debug_service_pb2.CallTraceback(
    144       call_type=call_type, call_key=call_key, graph_traceback=graph_traceback,
    145       graph_version=graph.version if graph else None)
    146 
    147   _format_origin_stack(origin_stack, call_traceback)
    148 
    149   if send_source:
    150     source_file_paths = set()
    151     source_file_paths.update(_source_file_paths_outside_tensorflow_py_library(
    152         (log_entry.code_def for log_entry
    153          in call_traceback.graph_traceback.log_entries),
    154         call_traceback.graph_traceback.id_to_string))
    155     source_file_paths.update(_source_file_paths_outside_tensorflow_py_library(
    156         [call_traceback.origin_stack], call_traceback.origin_id_to_string))
    157 
    158     debugged_source_files = debug_pb2.DebuggedSourceFiles()
    159     for file_path in source_file_paths:
    160       _load_debugged_source_file(
    161           file_path, debugged_source_files.source_files.add())
    162 
    163   for destination in destinations:
    164     channel = grpc.insecure_channel(destination)
    165     stub = debug_service_pb2_grpc.EventListenerStub(channel)
    166     stub.SendTracebacks(call_traceback)
    167     if send_source:
    168       stub.SendSourceFiles(debugged_source_files)
    169 
    170 
    171 def send_graph_tracebacks(destinations,
    172                           run_key,
    173                           origin_stack,
    174                           graph,
    175                           send_source=True):
    176   """Send the tracebacks of a graph execution call to debug server(s).
    177 
    178   Args:
    179     destinations: gRPC destination addresses, a `str` or a `list` of `str`s,
    180       e.g., "localhost:4242". If a `list`, gRPC requests containing the same
    181       `CallTraceback` proto payload will be sent to all the destinations.
    182     run_key: A string describing the feeds, fetches (and targets) names of the
    183       `tf.Session.run` call.
    184     origin_stack: The traceback of the `tf.Session.run()` invocation.
    185     graph: A Python `tf.Graph` object (i.e., *not* a `tf.GraphDef`), which
    186       contains op tracebacks.
    187     send_source: Whether the source files involved in the op tracebacks but
    188       outside the TensorFlow library are to be sent.
    189   """
    190   _send_call_tracebacks(
    191       destinations, origin_stack, is_eager_execution=False, call_key=run_key,
    192       graph=graph, send_source=send_source)
    193 
    194 
    195 def send_eager_tracebacks(destinations,
    196                           origin_stack,
    197                           send_source=True):
    198   """Send the tracebacks of an eager execution call to debug server(s).
    199 
    200   Args:
    201     destinations: gRPC destination addresses, a `str` or a `list` of `str`s,
    202       e.g., "localhost:4242". If a `list`, gRPC requests containing the same
    203     origin_stack: The traceback of the eager operation invocation.
    204     send_source: Whether the source files involved in the op tracebacks but
    205       outside the TensorFlow library are to be sent.
    206   """
    207   _send_call_tracebacks(
    208       destinations, origin_stack, is_eager_execution=True,
    209       send_source=send_source)
    210