Home | History | Annotate | Download | only in cli
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Shared functions and classes for tfdbg command-line interface."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import math
     21 
     22 import numpy as np
     23 import six
     24 
     25 from tensorflow.python.debug.cli import command_parser
     26 from tensorflow.python.debug.cli import debugger_cli_common
     27 from tensorflow.python.debug.cli import tensor_format
     28 from tensorflow.python.debug.lib import common
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.ops import variables
     31 from tensorflow.python.platform import gfile
     32 
# Short alias for debugger_cli_common.RichLine, used throughout this module to
# build rich-text lines (e.g. RL(text, font_attr)).
RL = debugger_cli_common.RichLine
     34 
     35 # Default threshold number of elements above which ellipses will be used
     36 # when printing the value of the tensor.
     37 DEFAULT_NDARRAY_DISPLAY_THRESHOLD = 2000
     38 
     39 COLOR_BLACK = "black"
     40 COLOR_BLUE = "blue"
     41 COLOR_CYAN = "cyan"
     42 COLOR_GRAY = "gray"
     43 COLOR_GREEN = "green"
     44 COLOR_MAGENTA = "magenta"
     45 COLOR_RED = "red"
     46 COLOR_WHITE = "white"
     47 COLOR_YELLOW = "yellow"
     48 
     49 TIME_UNIT_US = "us"
     50 TIME_UNIT_MS = "ms"
     51 TIME_UNIT_S = "s"
     52 TIME_UNITS = [TIME_UNIT_US, TIME_UNIT_MS, TIME_UNIT_S]
     53 
     54 
     55 def bytes_to_readable_str(num_bytes, include_b=False):
     56   """Generate a human-readable string representing number of bytes.
     57 
     58   The units B, kB, MB and GB are used.
     59 
     60   Args:
     61     num_bytes: (`int` or None) Number of bytes.
     62     include_b: (`bool`) Include the letter B at the end of the unit.
     63 
     64   Returns:
     65     (`str`) A string representing the number of bytes in a human-readable way,
     66       including a unit at the end.
     67   """
     68 
     69   if num_bytes is None:
     70     return str(num_bytes)
     71   if num_bytes < 1024:
     72     result = "%d" % num_bytes
     73   elif num_bytes < 1048576:
     74     result = "%.2fk" % (num_bytes / 1024.0)
     75   elif num_bytes < 1073741824:
     76     result = "%.2fM" % (num_bytes / 1048576.0)
     77   else:
     78     result = "%.2fG" % (num_bytes / 1073741824.0)
     79 
     80   if include_b:
     81     result += "B"
     82   return result
     83 
     84 
     85 def time_to_readable_str(value_us, force_time_unit=None):
     86   """Convert time value to human-readable string.
     87 
     88   Args:
     89     value_us: time value in microseconds.
     90     force_time_unit: force the output to use the specified time unit. Must be
     91       in TIME_UNITS.
     92 
     93   Returns:
     94     Human-readable string representation of the time value.
     95 
     96   Raises:
     97     ValueError: if force_time_unit value is not in TIME_UNITS.
     98   """
     99   if not value_us:
    100     return "0"
    101   if force_time_unit:
    102     if force_time_unit not in TIME_UNITS:
    103       raise ValueError("Invalid time unit: %s" % force_time_unit)
    104     order = TIME_UNITS.index(force_time_unit)
    105     time_unit = force_time_unit
    106     return "{:.10g}{}".format(value_us / math.pow(10.0, 3*order), time_unit)
    107   else:
    108     order = min(len(TIME_UNITS) - 1, int(math.log(value_us, 10) / 3))
    109     time_unit = TIME_UNITS[order]
    110     return "{:.3g}{}".format(value_us / math.pow(10.0, 3*order), time_unit)
    111 
    112 
    113 def parse_ranges_highlight(ranges_string):
    114   """Process ranges highlight string.
    115 
    116   Args:
    117     ranges_string: (str) A string representing a numerical range of a list of
    118       numerical ranges. See the help info of the -r flag of the print_tensor
    119       command for more details.
    120 
    121   Returns:
    122     An instance of tensor_format.HighlightOptions, if range_string is a valid
    123       representation of a range or a list of ranges.
    124   """
    125 
    126   ranges = None
    127 
    128   def ranges_filter(x):
    129     r = np.zeros(x.shape, dtype=bool)
    130     for range_start, range_end in ranges:
    131       r = np.logical_or(r, np.logical_and(x >= range_start, x <= range_end))
    132 
    133     return r
    134 
    135   if ranges_string:
    136     ranges = command_parser.parse_ranges(ranges_string)
    137     return tensor_format.HighlightOptions(
    138         ranges_filter, description=ranges_string)
    139   else:
    140     return None
    141 
    142 
    143 def numpy_printoptions_from_screen_info(screen_info):
    144   if screen_info and "cols" in screen_info:
    145     return {"linewidth": screen_info["cols"]}
    146   else:
    147     return {}
    148 
    149 
    150 def format_tensor(tensor,
    151                   tensor_name,
    152                   np_printoptions,
    153                   print_all=False,
    154                   tensor_slicing=None,
    155                   highlight_options=None,
    156                   include_numeric_summary=False,
    157                   write_path=None):
    158   """Generate formatted str to represent a tensor or its slices.
    159 
    160   Args:
    161     tensor: (numpy ndarray) The tensor value.
    162     tensor_name: (str) Name of the tensor, e.g., the tensor's debug watch key.
    163     np_printoptions: (dict) Numpy tensor formatting options.
    164     print_all: (bool) Whether the tensor is to be displayed in its entirety,
    165       instead of printing ellipses, even if its number of elements exceeds
    166       the default numpy display threshold.
    167       (Note: Even if this is set to true, the screen output can still be cut
    168        off by the UI frontend if it consist of more lines than the frontend
    169        can handle.)
    170     tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
    171       None, no slicing will be performed on the tensor.
    172     highlight_options: (tensor_format.HighlightOptions) options to highlight
    173       elements of the tensor. See the doc of tensor_format.format_tensor()
    174       for more details.
    175     include_numeric_summary: Whether a text summary of the numeric values (if
    176       applicable) will be included.
    177     write_path: A path to save the tensor value (after any slicing) to
    178       (optional). `numpy.save()` is used to save the value.
    179 
    180   Returns:
    181     An instance of `debugger_cli_common.RichTextLines` representing the
    182     (potentially sliced) tensor.
    183   """
    184 
    185   if tensor_slicing:
    186     # Validate the indexing.
    187     value = command_parser.evaluate_tensor_slice(tensor, tensor_slicing)
    188     sliced_name = tensor_name + tensor_slicing
    189   else:
    190     value = tensor
    191     sliced_name = tensor_name
    192 
    193   auxiliary_message = None
    194   if write_path:
    195     with gfile.Open(write_path, "wb") as output_file:
    196       np.save(output_file, value)
    197     line = debugger_cli_common.RichLine("Saved value to: ")
    198     line += debugger_cli_common.RichLine(write_path, font_attr="bold")
    199     line += " (%sB)" % bytes_to_readable_str(gfile.Stat(write_path).length)
    200     auxiliary_message = debugger_cli_common.rich_text_lines_from_rich_line_list(
    201         [line, debugger_cli_common.RichLine("")])
    202 
    203   if print_all:
    204     np_printoptions["threshold"] = value.size
    205   else:
    206     np_printoptions["threshold"] = DEFAULT_NDARRAY_DISPLAY_THRESHOLD
    207 
    208   return tensor_format.format_tensor(
    209       value,
    210       sliced_name,
    211       include_metadata=True,
    212       include_numeric_summary=include_numeric_summary,
    213       auxiliary_message=auxiliary_message,
    214       np_printoptions=np_printoptions,
    215       highlight_options=highlight_options)
    216 
    217 
    218 def error(msg):
    219   """Generate a RichTextLines output for error.
    220 
    221   Args:
    222     msg: (str) The error message.
    223 
    224   Returns:
    225     (debugger_cli_common.RichTextLines) A representation of the error message
    226       for screen output.
    227   """
    228 
    229   return debugger_cli_common.rich_text_lines_from_rich_line_list([
    230       RL("ERROR: " + msg, COLOR_RED)])
    231 
    232 
    233 def _recommend_command(command, description, indent=2, create_link=False):
    234   """Generate a RichTextLines object that describes a recommended command.
    235 
    236   Args:
    237     command: (str) The command to recommend.
    238     description: (str) A description of what the command does.
    239     indent: (int) How many spaces to indent in the beginning.
    240     create_link: (bool) Whether a command link is to be applied to the command
    241       string.
    242 
    243   Returns:
    244     (RichTextLines) Formatted text (with font attributes) for recommending the
    245       command.
    246   """
    247 
    248   indent_str = " " * indent
    249 
    250   if create_link:
    251     font_attr = [debugger_cli_common.MenuItem("", command), "bold"]
    252   else:
    253     font_attr = "bold"
    254 
    255   lines = [RL(indent_str) + RL(command, font_attr) + ":",
    256            indent_str + "  " + description]
    257 
    258   return debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
    259 
    260 
    261 def get_tfdbg_logo():
    262   """Make an ASCII representation of the tfdbg logo."""
    263 
    264   lines = [
    265       "",
    266       "TTTTTT FFFF DDD  BBBB   GGG ",
    267       "  TT   F    D  D B   B G    ",
    268       "  TT   FFF  D  D BBBB  G  GG",
    269       "  TT   F    D  D B   B G   G",
    270       "  TT   F    DDD  BBBB   GGG ",
    271       "",
    272   ]
    273   return debugger_cli_common.RichTextLines(lines)
    274 
    275 
    276 _HORIZONTAL_BAR = "======================================"
    277 
    278 
    279 def get_run_start_intro(run_call_count,
    280                         fetches,
    281                         feed_dict,
    282                         tensor_filters,
    283                         is_callable_runner=False):
    284   """Generate formatted intro for run-start UI.
    285 
    286   Args:
    287     run_call_count: (int) Run call counter.
    288     fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
    289       for more details.
    290     feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
    291       for more details.
    292     tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
    293       callable.
    294     is_callable_runner: (bool) whether a runner returned by
    295         Session.make_callable is being run.
    296 
    297   Returns:
    298     (RichTextLines) Formatted intro message about the `Session.run()` call.
    299   """
    300 
    301   fetch_lines = common.get_flattened_names(fetches)
    302 
    303   if not feed_dict:
    304     feed_dict_lines = [debugger_cli_common.RichLine("  (Empty)")]
    305   else:
    306     feed_dict_lines = []
    307     for feed_key in feed_dict:
    308       feed_key_name = common.get_graph_element_name(feed_key)
    309       feed_dict_line = debugger_cli_common.RichLine("  ")
    310       feed_dict_line += debugger_cli_common.RichLine(
    311           feed_key_name,
    312           debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name))
    313       # Surround the name string with quotes, because feed_key_name may contain
    314       # spaces in some cases, e.g., SparseTensors.
    315       feed_dict_lines.append(feed_dict_line)
    316   feed_dict_lines = debugger_cli_common.rich_text_lines_from_rich_line_list(
    317       feed_dict_lines)
    318 
    319   out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR)
    320   if is_callable_runner:
    321     out.append("Running a runner returned by Session.make_callable()")
    322   else:
    323     out.append("Session.run() call #%d:" % run_call_count)
    324     out.append("")
    325     out.append("Fetch(es):")
    326     out.extend(debugger_cli_common.RichTextLines(
    327         ["  " + line for line in fetch_lines]))
    328     out.append("")
    329     out.append("Feed dict:")
    330     out.extend(feed_dict_lines)
    331   out.append(_HORIZONTAL_BAR)
    332   out.append("")
    333   out.append("Select one of the following commands to proceed ---->")
    334 
    335   out.extend(
    336       _recommend_command(
    337           "run",
    338           "Execute the run() call with debug tensor-watching",
    339           create_link=True))
    340   out.extend(
    341       _recommend_command(
    342           "run -n",
    343           "Execute the run() call without debug tensor-watching",
    344           create_link=True))
    345   out.extend(
    346       _recommend_command(
    347           "run -t <T>",
    348           "Execute run() calls (T - 1) times without debugging, then "
    349           "execute run() once more with debugging and drop back to the CLI"))
    350   out.extend(
    351       _recommend_command(
    352           "run -f <filter_name>",
    353           "Keep executing run() calls until a dumped tensor passes a given, "
    354           "registered filter (conditional breakpoint mode)"))
    355 
    356   more_lines = ["    Registered filter(s):"]
    357   if tensor_filters:
    358     filter_names = []
    359     for filter_name in tensor_filters:
    360       filter_names.append(filter_name)
    361       command_menu_node = debugger_cli_common.MenuItem(
    362           "", "run -f %s" % filter_name)
    363       more_lines.append(RL("        * ") + RL(filter_name, command_menu_node))
    364   else:
    365     more_lines.append("        (None)")
    366 
    367   out.extend(
    368       debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines))
    369 
    370   out.extend(
    371       _recommend_command(
    372           "invoke_stepper",
    373           "Use the node-stepper interface, which allows you to interactively "
    374           "step through nodes involved in the graph run() call and "
    375           "inspect/modify their values", create_link=True))
    376 
    377   out.append("")
    378 
    379   out.append_rich_line(RL("For more details, see ") +
    380                        RL("help.", debugger_cli_common.MenuItem("", "help")) +
    381                        ".")
    382   out.append("")
    383 
    384   # Make main menu for the run-start intro.
    385   menu = debugger_cli_common.Menu()
    386   menu.append(debugger_cli_common.MenuItem("run", "run"))
    387   menu.append(debugger_cli_common.MenuItem(
    388       "invoke_stepper", "invoke_stepper"))
    389   menu.append(debugger_cli_common.MenuItem("exit", "exit"))
    390   out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu
    391 
    392   return out
    393 
    394 
    395 def get_run_short_description(run_call_count,
    396                               fetches,
    397                               feed_dict,
    398                               is_callable_runner=False):
    399   """Get a short description of the run() call.
    400 
    401   Args:
    402     run_call_count: (int) Run call counter.
    403     fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
    404       for more details.
    405     feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
    406       for more details.
    407     is_callable_runner: (bool) whether a runner returned by
    408         Session.make_callable is being run.
    409 
    410   Returns:
    411     (str) A short description of the run() call, including information about
    412       the fetche(s) and feed(s).
    413   """
    414   if is_callable_runner:
    415     return "runner from make_callable()"
    416 
    417   description = "run #%d: " % run_call_count
    418 
    419   if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
    420     description += "1 fetch (%s); " % common.get_graph_element_name(fetches)
    421   else:
    422     # Could be (nested) list, tuple, dict or namedtuple.
    423     num_fetches = len(common.get_flattened_names(fetches))
    424     if num_fetches > 1:
    425       description += "%d fetches; " % num_fetches
    426     else:
    427       description += "%d fetch; " % num_fetches
    428 
    429   if not feed_dict:
    430     description += "0 feeds"
    431   else:
    432     if len(feed_dict) == 1:
    433       for key in feed_dict:
    434         description += "1 feed (%s)" % (
    435             key if isinstance(key, six.string_types) or not hasattr(key, "name")
    436             else key.name)
    437     else:
    438       description += "%d feeds" % len(feed_dict)
    439 
    440   return description
    441 
    442 
    443 def get_error_intro(tf_error):
    444   """Generate formatted intro for TensorFlow run-time error.
    445 
    446   Args:
    447     tf_error: (errors.OpError) TensorFlow run-time error object.
    448 
    449   Returns:
    450     (RichTextLines) Formatted intro message about the run-time OpError, with
    451       sample commands for debugging.
    452   """
    453 
    454   op_name = tf_error.op.name
    455 
    456   intro_lines = [
    457       "--------------------------------------",
    458       RL("!!! An error occurred during the run !!!", "blink"),
    459       "",
    460       "You may use the following commands to debug:",
    461   ]
    462 
    463   out = debugger_cli_common.rich_text_lines_from_rich_line_list(intro_lines)
    464 
    465   out.extend(
    466       _recommend_command("ni -a -d -t %s" % op_name,
    467                          "Inspect information about the failing op.",
    468                          create_link=True))
    469   out.extend(
    470       _recommend_command("li -r %s" % op_name,
    471                          "List inputs to the failing op, recursively.",
    472                          create_link=True))
    473 
    474   out.extend(
    475       _recommend_command(
    476           "lt",
    477           "List all tensors dumped during the failing run() call.",
    478           create_link=True))
    479 
    480   more_lines = [
    481       "",
    482       "Op name:    " + op_name,
    483       "Error type: " + str(type(tf_error)),
    484       "",
    485       "Details:",
    486       str(tf_error),
    487       "",
    488       "WARNING: Using client GraphDef due to the error, instead of "
    489       "executor GraphDefs.",
    490       "--------------------------------------",
    491       "",
    492   ]
    493 
    494   out.extend(debugger_cli_common.RichTextLines(more_lines))
    495 
    496   return out
    497