# tensorflow/python/debug/lib/source_utils.py (scrape navigation header removed)
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Classes and functions that help to inspect Python source w.r.t. TF graphs."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import os
     23 import re
     24 
     25 import numpy as np
     26 
     27 from tensorflow.python.debug.lib import profiling
     28 
     29 
# Root directory of the TensorFlow source tree: four directory levels up from
# this file (tensorflow/python/debug/lib/source_utils.py -> the directory
# that contains the top-level "tensorflow" package).
_TENSORFLOW_BASEDIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.dirname(
        os.path.normpath(os.path.abspath(__file__))))))
     33 
     34 UNCOMPILED_SOURCE_SUFFIXES = (".py")
     35 COMPILED_SOURCE_SUFFIXES = (".pyc", ".pyo")
     36 
     37 
     38 def _norm_abs_path(file_path):
     39   return os.path.normpath(os.path.abspath(file_path))
     40 
     41 
     42 def is_extension_uncompiled_python_source(file_path):
     43   _, extension = os.path.splitext(file_path)
     44   return extension.lower() in UNCOMPILED_SOURCE_SUFFIXES
     45 
     46 
     47 def is_extension_compiled_python_source(file_path):
     48   _, extension = os.path.splitext(file_path)
     49   return extension.lower() in COMPILED_SOURCE_SUFFIXES
     50 
     51 
     52 def _convert_watch_key_to_tensor_name(watch_key):
     53   return watch_key[:watch_key.rfind(":")]
     54 
     55 
     56 def guess_is_tensorflow_py_library(py_file_path):
     57   """Guess whether a Python source file is a part of the tensorflow library.
     58 
     59   Special cases:
     60     1) Returns False for unit-test files in the library (*_test.py),
     61     2) Returns False for files under python/debug/examples.
     62 
     63   Args:
     64     py_file_path: full path of the Python source file in question.
     65 
     66   Returns:
     67     (`bool`) Whether the file is a part of the tensorflow library.
     68 
     69   Raises:
     70     ValueError: if the extension name of py_file_path does not indicate a Python
     71       source file (compiled or uncomplied).
     72   """
     73   if (not is_extension_uncompiled_python_source(py_file_path) and
     74       not is_extension_compiled_python_source(py_file_path)):
     75     raise ValueError(
     76         "Input file path (%s) is not a Python source file." % py_file_path)
     77   py_file_path = _norm_abs_path(py_file_path)
     78 
     79   return (py_file_path.startswith(_TENSORFLOW_BASEDIR) and
     80           not py_file_path.endswith("_test.py") and
     81           not os.path.dirname(py_file_path).endswith(
     82               os.path.normpath("python/debug/examples")))
     83 
     84 
     85 def load_source(source_file_path):
     86   with open(source_file_path, "rU") as f:
     87     source_text = f.read()
     88   source_lines = source_text.split("\n")
     89   line_num_width = int(np.ceil(np.log10(len(source_lines)))) + 3
     90   return source_lines, line_num_width
     91 
     92 
     93 def annotate_source(dump,
     94                     source_file_path,
     95                     do_dumped_tensors=False,
     96                     file_stack_top=False,
     97                     min_line=None,
     98                     max_line=None):
     99   """Annotate a Python source file with a list of ops created at each line.
    100 
    101   (The annotation doesn't change the source file itself.)
    102 
    103   Args:
    104     dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
    105       has been loaded.
    106     source_file_path: (`str`) Path to the source file being annotated.
    107     do_dumped_tensors: (`str`) Whether dumped Tensors, instead of ops are to be
    108       used to annotate the source file.
    109     file_stack_top: (`bool`) Whether only the top stack trace in the
    110       specified source file is to be annotated.
    111     min_line: (`None` or `int`) The 1-based line to start annotate the source
    112       file from (inclusive).
    113     max_line: (`None` or `int`) The 1-based line number to end the annotation
    114       at (exclusive).
    115 
    116   Returns:
    117     A `dict` mapping 1-based line number to a list of op name(s) created at
    118       that line, or tensor names if `do_dumped_tensors` is True.
    119 
    120   Raises:
    121     ValueError: If the dump object does not have a Python graph set.
    122   """
    123 
    124   py_graph = dump.python_graph
    125   if not py_graph:
    126     raise ValueError("Cannot perform source annotation due to a lack of set "
    127                      "Python graph in the dump object")
    128 
    129   source_file_path = _norm_abs_path(source_file_path)
    130 
    131   line_to_op_names = {}
    132   for op in py_graph.get_operations():
    133     for file_path, line_number, _, _ in reversed(dump.node_traceback(op.name)):
    134       if (min_line is not None and line_number < min_line or
    135           max_line is not None and line_number >= max_line):
    136         continue
    137 
    138       if _norm_abs_path(file_path) != source_file_path:
    139         continue
    140 
    141       if do_dumped_tensors:
    142         watch_keys = dump.debug_watch_keys(op.name)
    143         # Convert watch keys to unique Tensor names.
    144         items_to_append = list(
    145             set(map(_convert_watch_key_to_tensor_name, watch_keys)))
    146       else:
    147         items_to_append = [op.name]
    148 
    149       if line_number in line_to_op_names:
    150         line_to_op_names[line_number].extend(items_to_append)
    151       else:
    152         line_to_op_names[line_number] = items_to_append
    153 
    154       if file_stack_top:
    155         break
    156 
    157   return line_to_op_names
    158 
    159 
    160 def list_source_files_against_dump(dump,
    161                                    path_regex_whitelist=None,
    162                                    node_name_regex_whitelist=None):
    163   """Generate a list of source files with information regarding ops and tensors.
    164 
    165   Args:
    166     dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
    167       has been loaded.
    168     path_regex_whitelist: A regular-expression filter for source file path.
    169     node_name_regex_whitelist: A regular-expression filter for node names.
    170 
    171   Returns:
    172     A list of tuples regarding the Python source files involved in constructing
    173     the ops and tensors contained in `dump`. Each tuple is:
    174       (source_file_path, is_tf_library, num_nodes, num_tensors, num_dumps,
    175        first_line)
    176 
    177       is_tf_library: (`bool`) A guess of whether the file belongs to the
    178         TensorFlow Python library.
    179       num_nodes: How many nodes were created by lines of this source file.
    180         These include nodes with dumps and those without.
    181       num_tensors: How many Tensors were created by lines of this source file.
    182         These include Tensors with dumps and those without.
    183       num_dumps: How many debug Tensor dumps were from nodes (and Tensors)
    184         that were created by this source file.
    185       first_line: The first line number (1-based) that created any nodes or
    186         Tensors in this source file.
    187 
    188     The list is sorted by ascending order of source_file_path.
    189 
    190   Raises:
    191     ValueError: If the dump object does not have a Python graph set.
    192   """
    193 
    194   py_graph = dump.python_graph
    195   if not py_graph:
    196     raise ValueError("Cannot generate source list due to a lack of set "
    197                      "Python graph in the dump object")
    198 
    199   path_to_node_names = collections.defaultdict(set)
    200   path_to_tensor_names = collections.defaultdict(set)
    201   path_to_first_line = {}
    202   tensor_name_to_num_dumps = {}
    203 
    204   path_regex = (re.compile(path_regex_whitelist)
    205                 if path_regex_whitelist else None)
    206   node_name_regex = (re.compile(node_name_regex_whitelist)
    207                      if node_name_regex_whitelist else None)
    208 
    209   to_skip_file_paths = set()
    210   for op in py_graph.get_operations():
    211     if node_name_regex and not node_name_regex.match(op.name):
    212       continue
    213 
    214     for file_path, line_number, _, _ in dump.node_traceback(op.name):
    215       file_path = _norm_abs_path(file_path)
    216       if (file_path in to_skip_file_paths or
    217           path_regex and not path_regex.match(file_path) or
    218           not os.path.isfile(file_path)):
    219         to_skip_file_paths.add(file_path)
    220         continue
    221 
    222       path_to_node_names[file_path].add(op.name)
    223       if file_path in path_to_first_line:
    224         if path_to_first_line[file_path] > line_number:
    225           path_to_first_line[file_path] = line_number
    226       else:
    227         path_to_first_line[file_path] = line_number
    228 
    229       for output_tensor in op.outputs:
    230         tensor_name = output_tensor.name
    231         path_to_tensor_names[file_path].add(tensor_name)
    232 
    233       watch_keys = dump.debug_watch_keys(op.name)
    234       for watch_key in watch_keys:
    235         node_name, output_slot, debug_op = watch_key.split(":")
    236         tensor_name = "%s:%s" % (node_name, output_slot)
    237         if tensor_name not in tensor_name_to_num_dumps:
    238           tensor_name_to_num_dumps[tensor_name] = len(
    239               dump.get_tensors(node_name, int(output_slot), debug_op))
    240 
    241   path_to_num_dumps = {}
    242   for path in path_to_tensor_names:
    243     path_to_num_dumps[path] = sum(
    244         tensor_name_to_num_dumps.get(tensor_name, 0)
    245         for tensor_name in path_to_tensor_names[path])
    246 
    247   output = []
    248   for file_path in path_to_node_names:
    249     output.append((
    250         file_path,
    251         guess_is_tensorflow_py_library(file_path),
    252         len(path_to_node_names.get(file_path, {})),
    253         len(path_to_tensor_names.get(file_path, {})),
    254         path_to_num_dumps.get(file_path, 0),
    255         path_to_first_line[file_path]))
    256 
    257   return sorted(output, key=lambda x: x[0])
    258 
    259 
    260 def annotate_source_against_profile(profile_data,
    261                                     source_file_path,
    262                                     node_name_filter=None,
    263                                     op_type_filter=None,
    264                                     min_line=None,
    265                                     max_line=None):
    266   """Annotate a Python source file with profiling information at each line.
    267 
    268   (The annotation doesn't change the source file itself.)
    269 
    270   Args:
    271     profile_data: (`list` of `ProfileDatum`) A list of `ProfileDatum`.
    272     source_file_path: (`str`) Path to the source file being annotated.
    273     node_name_filter: Regular expression to filter by node name.
    274     op_type_filter: Regular expression to filter by op type.
    275     min_line: (`None` or `int`) The 1-based line to start annotate the source
    276       file from (inclusive).
    277     max_line: (`None` or `int`) The 1-based line number to end the annotation
    278       at (exclusive).
    279 
    280   Returns:
    281     A `dict` mapping 1-based line number to a the namedtuple
    282       `profiling.LineOrFuncProfileSummary`.
    283   """
    284 
    285   source_file_path = _norm_abs_path(source_file_path)
    286 
    287   node_name_regex = re.compile(node_name_filter) if node_name_filter else None
    288   op_type_regex = re.compile(op_type_filter) if op_type_filter else None
    289 
    290   line_to_profile_summary = {}
    291   for profile_datum in profile_data:
    292     if not profile_datum.file_path:
    293       continue
    294 
    295     if _norm_abs_path(profile_datum.file_path) != source_file_path:
    296       continue
    297 
    298     if (min_line is not None and profile_datum.line_number < min_line or
    299         max_line is not None and profile_datum.line_number >= max_line):
    300       continue
    301 
    302     if (node_name_regex and
    303         not node_name_regex.match(profile_datum.node_exec_stats.node_name)):
    304       continue
    305 
    306     if op_type_regex and not op_type_regex.match(profile_datum.op_type):
    307       continue
    308 
    309     if profile_datum.line_number not in line_to_profile_summary:
    310       line_to_profile_summary[profile_datum.line_number] = (
    311           profiling.AggregateProfile(profile_datum))
    312     else:
    313       line_to_profile_summary[profile_datum.line_number].add(profile_datum)
    314 
    315   return line_to_profile_summary
    316