Home | History | Annotate | Download | only in lib
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unit tests for source_utils."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
import ast
import os
import shutil
import sys
import tempfile
     24 
     25 import numpy as np
     26 
     27 from tensorflow.core.protobuf import config_pb2
     28 from tensorflow.python.client import session
     29 from tensorflow.python.debug.lib import debug_data
     30 from tensorflow.python.debug.lib import debug_utils
     31 from tensorflow.python.debug.lib import source_utils
     32 from tensorflow.python.framework import constant_op
     33 from tensorflow.python.framework import ops
     34 from tensorflow.python.framework import test_util
     35 from tensorflow.python.ops import control_flow_ops
     36 from tensorflow.python.ops import math_ops
     37 # Import resource_variable_ops for the variables-to-tensor implicit conversion.
     38 from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
     39 from tensorflow.python.ops import variables
     40 from tensorflow.python.platform import googletest
     41 from tensorflow.python.util import tf_inspect
     42 
     43 
     44 def line_number_above():
     45   return tf_inspect.stack()[1][2] - 1
     46 
     47 
     48 class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):
     49 
     50   def setUp(self):
     51     self.curr_file_path = os.path.normpath(os.path.abspath(__file__))
     52 
     53   def tearDown(self):
     54     ops.reset_default_graph()
     55 
     56   def testGuessedBaseDirIsProbablyCorrect(self):
     57     self.assertEqual("tensorflow",
     58                      os.path.basename(source_utils._TENSORFLOW_BASEDIR))
     59 
     60   def testUnitTestFileReturnsFalse(self):
     61     self.assertFalse(
     62         source_utils.guess_is_tensorflow_py_library(self.curr_file_path))
     63 
     64   def testSourceUtilModuleReturnsTrue(self):
     65     self.assertTrue(
     66         source_utils.guess_is_tensorflow_py_library(source_utils.__file__))
     67 
     68   def testFileInPythonKernelsPathReturnsTrue(self):
     69     x = constant_op.constant(42.0, name="x")
     70     self.assertTrue(
     71         source_utils.guess_is_tensorflow_py_library(x.op.traceback[-1][0]))
     72 
     73   def testNonPythonFileRaisesException(self):
     74     with self.assertRaisesRegexp(ValueError, r"is not a Python source file"):
     75       source_utils.guess_is_tensorflow_py_library(
     76           os.path.join(os.path.dirname(self.curr_file_path), "foo.cc"))
     77 
     78 
     79 class SourceHelperTest(test_util.TensorFlowTestCase):
     80 
     81   def createAndRunGraphHelper(self):
     82     """Create and run a TensorFlow Graph to generate debug dumps.
     83 
     84     This is intentionally done in separate method, to make it easier to test
     85     the stack-top mode of source annotation.
     86     """
     87 
     88     self.dump_root = self.get_temp_dir()
     89     self.curr_file_path = os.path.abspath(
     90         tf_inspect.getfile(tf_inspect.currentframe()))
     91 
     92     # Run a simple TF graph to generate some debug dumps that can be used in
     93     # source annotation.
     94     with session.Session() as sess:
     95       self.u_init = constant_op.constant(
     96           np.array([[5.0, 3.0], [-1.0, 0.0]]), shape=[2, 2], name="u_init")
     97       self.u_init_line_number = line_number_above()
     98 
     99       self.u = variables.Variable(self.u_init, name="u")
    100       self.u_line_number = line_number_above()
    101 
    102       self.v_init = constant_op.constant(
    103           np.array([[2.0], [-1.0]]), shape=[2, 1], name="v_init")
    104       self.v_init_line_number = line_number_above()
    105 
    106       self.v = variables.Variable(self.v_init, name="v")
    107       self.v_line_number = line_number_above()
    108 
    109       self.w = math_ops.matmul(self.u, self.v, name="w")
    110       self.w_line_number = line_number_above()
    111 
    112       sess.run(self.u.initializer)
    113       sess.run(self.v.initializer)
    114 
    115       run_options = config_pb2.RunOptions(output_partition_graphs=True)
    116       debug_utils.watch_graph(
    117           run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
    118       run_metadata = config_pb2.RunMetadata()
    119       sess.run(self.w, options=run_options, run_metadata=run_metadata)
    120 
    121       self.dump = debug_data.DebugDumpDir(
    122           self.dump_root, partition_graphs=run_metadata.partition_graphs)
    123       self.dump.set_python_graph(sess.graph)
    124 
    125   def setUp(self):
    126     self.createAndRunGraphHelper()
    127     self.helper_line_number = line_number_above()
    128 
    129   def tearDown(self):
    130     if os.path.isdir(self.dump_root):
    131       shutil.rmtree(self.dump_root)
    132     ops.reset_default_graph()
    133 
    134   def testAnnotateWholeValidSourceFileGivesCorrectResult(self):
    135     source_annotation = source_utils.annotate_source(self.dump,
    136                                                      self.curr_file_path)
    137 
    138     self.assertIn(self.u_init.op.name,
    139                   source_annotation[self.u_init_line_number])
    140     self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    141     self.assertIn(self.v_init.op.name,
    142                   source_annotation[self.v_init_line_number])
    143     self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
    144     self.assertIn(self.w.op.name, source_annotation[self.w_line_number])
    145 
    146     # In the non-stack-top (default) mode, the helper line should be annotated
    147     # with all the ops as well.
    148     self.assertIn(self.u_init.op.name,
    149                   source_annotation[self.helper_line_number])
    150     self.assertIn(self.u.op.name, source_annotation[self.helper_line_number])
    151     self.assertIn(self.v_init.op.name,
    152                   source_annotation[self.helper_line_number])
    153     self.assertIn(self.v.op.name, source_annotation[self.helper_line_number])
    154     self.assertIn(self.w.op.name, source_annotation[self.helper_line_number])
    155 
    156   def testAnnotateWithStackTopGivesCorrectResult(self):
    157     source_annotation = source_utils.annotate_source(
    158         self.dump, self.curr_file_path, file_stack_top=True)
    159 
    160     self.assertIn(self.u_init.op.name,
    161                   source_annotation[self.u_init_line_number])
    162     self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    163     self.assertIn(self.v_init.op.name,
    164                   source_annotation[self.v_init_line_number])
    165     self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
    166     self.assertIn(self.w.op.name, source_annotation[self.w_line_number])
    167 
    168     # In the stack-top mode, the helper line should not have been annotated.
    169     self.assertNotIn(self.helper_line_number, source_annotation)
    170 
    171   def testAnnotateSubsetOfLinesGivesCorrectResult(self):
    172     source_annotation = source_utils.annotate_source(
    173         self.dump,
    174         self.curr_file_path,
    175         min_line=self.u_line_number,
    176         max_line=self.u_line_number + 1)
    177 
    178     self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
    179     self.assertNotIn(self.v_line_number, source_annotation)
    180 
    181   def testAnnotateDumpedTensorsGivesCorrectResult(self):
    182     source_annotation = source_utils.annotate_source(
    183         self.dump, self.curr_file_path, do_dumped_tensors=True)
    184 
    185     # Note: Constant Tensors u_init and v_init may not get dumped due to
    186     #   constant-folding.
    187     self.assertIn(self.u.name, source_annotation[self.u_line_number])
    188     self.assertIn(self.v.name, source_annotation[self.v_line_number])
    189     self.assertIn(self.w.name, source_annotation[self.w_line_number])
    190 
    191     self.assertNotIn(self.u.op.name, source_annotation[self.u_line_number])
    192     self.assertNotIn(self.v.op.name, source_annotation[self.v_line_number])
    193     self.assertNotIn(self.w.op.name, source_annotation[self.w_line_number])
    194 
    195     self.assertIn(self.u.name, source_annotation[self.helper_line_number])
    196     self.assertIn(self.v.name, source_annotation[self.helper_line_number])
    197     self.assertIn(self.w.name, source_annotation[self.helper_line_number])
    198 
    199   def testCallingAnnotateSourceWithoutPythonGraphRaisesException(self):
    200     self.dump.set_python_graph(None)
    201     with self.assertRaises(ValueError):
    202       source_utils.annotate_source(self.dump, self.curr_file_path)
    203 
    204   def testCallingAnnotateSourceOnUnrelatedSourceFileDoesNotError(self):
    205     # Create an unrelated source file.
    206     unrelated_source_path = tempfile.mktemp()
    207     with open(unrelated_source_path, "wt") as source_file:
    208       source_file.write("print('hello, world')\n")
    209 
    210     self.assertEqual({},
    211                      source_utils.annotate_source(self.dump,
    212                                                   unrelated_source_path))
    213 
    214     # Clean up unrelated source file.
    215     os.remove(unrelated_source_path)
    216 
    217 
    218 class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):
    219 
    220   def createAndRunGraphWithWhileLoop(self):
    221     """Create and run a TensorFlow Graph with a while loop to generate dumps."""
    222 
    223     self.dump_root = self.get_temp_dir()
    224     self.curr_file_path = os.path.abspath(
    225         tf_inspect.getfile(tf_inspect.currentframe()))
    226 
    227     # Run a simple TF graph to generate some debug dumps that can be used in
    228     # source annotation.
    229     with session.Session() as sess:
    230       loop_body = lambda i: math_ops.add(i, 2)
    231       self.traceback_first_line = line_number_above()
    232 
    233       loop_cond = lambda i: math_ops.less(i, 16)
    234 
    235       i = constant_op.constant(10, name="i")
    236       loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])
    237 
    238       run_options = config_pb2.RunOptions(output_partition_graphs=True)
    239       debug_utils.watch_graph(
    240           run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
    241       run_metadata = config_pb2.RunMetadata()
    242       sess.run(loop, options=run_options, run_metadata=run_metadata)
    243 
    244       self.dump = debug_data.DebugDumpDir(
    245           self.dump_root, partition_graphs=run_metadata.partition_graphs)
    246       self.dump.set_python_graph(sess.graph)
    247 
    248   def setUp(self):
    249     self.createAndRunGraphWithWhileLoop()
    250 
    251   def tearDown(self):
    252     if os.path.isdir(self.dump_root):
    253       shutil.rmtree(self.dump_root)
    254     ops.reset_default_graph()
    255 
    256   def testGenerateSourceList(self):
    257     source_list = source_utils.list_source_files_against_dump(self.dump)
    258 
    259     # Assert that the file paths are sorted and unique.
    260     file_paths = [item[0] for item in source_list]
    261     self.assertEqual(sorted(file_paths), file_paths)
    262     self.assertEqual(len(set(file_paths)), len(file_paths))
    263 
    264     # Assert that each item of source_list has length 6.
    265     for item in source_list:
    266       self.assertTrue(isinstance(item, tuple))
    267       self.assertEqual(6, len(item))
    268 
    269     # The while loop body should have executed 3 times. The following table
    270     # lists the tensors and how many times each of them is dumped.
    271     #   Tensor name            # of times dumped:
    272     #   i:0                    1
    273     #   while/Enter:0          1
    274     #   while/Merge:0          4
    275     #   while/Merge:1          4
    276     #   while/Less/y:0         4
    277     #   while/Less:0           4
    278     #   while/LoopCond:0       4
    279     #   while/Switch:0         1
    280     #   while/Swtich:1         3
    281     #   while/Identity:0       3
    282     #   while/Add/y:0          3
    283     #   while/Add:0            3
    284     #   while/NextIteration:0  3
    285     #   while/Exit:0           1
    286     # ----------------------------
    287     #   (Total)                39
    288     #
    289     # The total number of nodes is 12.
    290     # The total number of tensors is 14 (2 of the nodes have 2 outputs:
    291     #   while/Merge, while/Switch).
    292 
    293     _, is_tf_py_library, num_nodes, num_tensors, num_dumps, first_line = (
    294         source_list[file_paths.index(self.curr_file_path)])
    295     self.assertFalse(is_tf_py_library)
    296     self.assertEqual(12, num_nodes)
    297     self.assertEqual(14, num_tensors)
    298     self.assertEqual(39, num_dumps)
    299     self.assertEqual(self.traceback_first_line, first_line)
    300 
    301   def testGenerateSourceListWithNodeNameFilter(self):
    302     source_list = source_utils.list_source_files_against_dump(
    303         self.dump, node_name_regex_whitelist=r"while/Add.*")
    304 
    305     # Assert that the file paths are sorted.
    306     file_paths = [item[0] for item in source_list]
    307     self.assertEqual(sorted(file_paths), file_paths)
    308     self.assertEqual(len(set(file_paths)), len(file_paths))
    309 
    310     # Assert that each item of source_list has length 4.
    311     for item in source_list:
    312       self.assertTrue(isinstance(item, tuple))
    313       self.assertEqual(6, len(item))
    314 
    315     # Due to the node-name filtering the result should only contain 2 nodes
    316     # and 2 tensors. The total number of dumped tensors should be 6:
    317     #   while/Add/y:0          3
    318     #   while/Add:0            3
    319     _, is_tf_py_library, num_nodes, num_tensors, num_dumps, _ = (
    320         source_list[file_paths.index(self.curr_file_path)])
    321     self.assertFalse(is_tf_py_library)
    322     self.assertEqual(2, num_nodes)
    323     self.assertEqual(2, num_tensors)
    324     self.assertEqual(6, num_dumps)
    325 
    326   def testGenerateSourceListWithPathRegexFilter(self):
    327     curr_file_basename = os.path.basename(self.curr_file_path)
    328     source_list = source_utils.list_source_files_against_dump(
    329         self.dump,
    330         path_regex_whitelist=(
    331             ".*" + curr_file_basename.replace(".", "\\.") + "$"))
    332 
    333     self.assertEqual(1, len(source_list))
    334     (file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps,
    335      first_line) = source_list[0]
    336     self.assertEqual(self.curr_file_path, file_path)
    337     self.assertFalse(is_tf_py_library)
    338     self.assertEqual(12, num_nodes)
    339     self.assertEqual(14, num_tensors)
    340     self.assertEqual(39, num_dumps)
    341     self.assertEqual(self.traceback_first_line, first_line)
    342 
    343 
if __name__ == "__main__":
  # When executed as a script, hand control to TensorFlow's googletest runner,
  # which is presumably what discovers and runs the TestCase classes above.
  googletest.main()
    346