1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Unit tests for source_utils.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 import shutil 23 import tempfile 24 25 import numpy as np 26 27 from tensorflow.core.protobuf import config_pb2 28 from tensorflow.python.client import session 29 from tensorflow.python.debug.lib import debug_data 30 from tensorflow.python.debug.lib import debug_utils 31 from tensorflow.python.debug.lib import source_utils 32 from tensorflow.python.framework import constant_op 33 from tensorflow.python.framework import ops 34 from tensorflow.python.framework import test_util 35 from tensorflow.python.ops import control_flow_ops 36 from tensorflow.python.ops import math_ops 37 # Import resource_variable_ops for the variables-to-tensor implicit conversion. 38 from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import 39 from tensorflow.python.ops import variables 40 from tensorflow.python.platform import googletest 41 from tensorflow.python.util import tf_inspect 42 43 44 def line_number_above(): 45 return tf_inspect.stack()[1][2] - 1 46 47 48 class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase): 49 50 def setUp(self): 51 self.curr_file_path = os.path.normpath(os.path.abspath(__file__)) 52 53 def tearDown(self): 54 ops.reset_default_graph() 55 56 def testGuessedBaseDirIsProbablyCorrect(self): 57 self.assertEqual("tensorflow", 58 os.path.basename(source_utils._TENSORFLOW_BASEDIR)) 59 60 def testUnitTestFileReturnsFalse(self): 61 self.assertFalse( 62 source_utils.guess_is_tensorflow_py_library(self.curr_file_path)) 63 64 def testSourceUtilModuleReturnsTrue(self): 65 self.assertTrue( 66 source_utils.guess_is_tensorflow_py_library(source_utils.__file__)) 67 68 def testFileInPythonKernelsPathReturnsTrue(self): 69 x = constant_op.constant(42.0, name="x") 70 self.assertTrue( 71 source_utils.guess_is_tensorflow_py_library(x.op.traceback[-1][0])) 72 73 def testNonPythonFileRaisesException(self): 74 with self.assertRaisesRegexp(ValueError, r"is not a Python source file"): 75 source_utils.guess_is_tensorflow_py_library( 76 os.path.join(os.path.dirname(self.curr_file_path), "foo.cc")) 77 78 79 class SourceHelperTest(test_util.TensorFlowTestCase): 80 81 def createAndRunGraphHelper(self): 82 """Create and run a TensorFlow Graph to generate debug dumps. 83 84 This is intentionally done in separate method, to make it easier to test 85 the stack-top mode of source annotation. 86 """ 87 88 self.dump_root = self.get_temp_dir() 89 self.curr_file_path = os.path.abspath( 90 tf_inspect.getfile(tf_inspect.currentframe())) 91 92 # Run a simple TF graph to generate some debug dumps that can be used in 93 # source annotation. 94 with session.Session() as sess: 95 self.u_init = constant_op.constant( 96 np.array([[5.0, 3.0], [-1.0, 0.0]]), shape=[2, 2], name="u_init") 97 self.u_init_line_number = line_number_above() 98 99 self.u = variables.Variable(self.u_init, name="u") 100 self.u_line_number = line_number_above() 101 102 self.v_init = constant_op.constant( 103 np.array([[2.0], [-1.0]]), shape=[2, 1], name="v_init") 104 self.v_init_line_number = line_number_above() 105 106 self.v = variables.Variable(self.v_init, name="v") 107 self.v_line_number = line_number_above() 108 109 self.w = math_ops.matmul(self.u, self.v, name="w") 110 self.w_line_number = line_number_above() 111 112 sess.run(self.u.initializer) 113 sess.run(self.v.initializer) 114 115 run_options = config_pb2.RunOptions(output_partition_graphs=True) 116 debug_utils.watch_graph( 117 run_options, sess.graph, debug_urls=["file://%s" % self.dump_root]) 118 run_metadata = config_pb2.RunMetadata() 119 sess.run(self.w, options=run_options, run_metadata=run_metadata) 120 121 self.dump = debug_data.DebugDumpDir( 122 self.dump_root, partition_graphs=run_metadata.partition_graphs) 123 self.dump.set_python_graph(sess.graph) 124 125 def setUp(self): 126 self.createAndRunGraphHelper() 127 self.helper_line_number = line_number_above() 128 129 def tearDown(self): 130 if os.path.isdir(self.dump_root): 131 shutil.rmtree(self.dump_root) 132 ops.reset_default_graph() 133 134 def testAnnotateWholeValidSourceFileGivesCorrectResult(self): 135 source_annotation = source_utils.annotate_source(self.dump, 136 self.curr_file_path) 137 138 self.assertIn(self.u_init.op.name, 139 source_annotation[self.u_init_line_number]) 140 self.assertIn(self.u.op.name, source_annotation[self.u_line_number]) 141 self.assertIn(self.v_init.op.name, 142 source_annotation[self.v_init_line_number]) 143 self.assertIn(self.v.op.name, source_annotation[self.v_line_number]) 144 self.assertIn(self.w.op.name, source_annotation[self.w_line_number]) 145 146 # In the non-stack-top (default) mode, the helper line should be annotated 147 # with all the ops as well. 148 self.assertIn(self.u_init.op.name, 149 source_annotation[self.helper_line_number]) 150 self.assertIn(self.u.op.name, source_annotation[self.helper_line_number]) 151 self.assertIn(self.v_init.op.name, 152 source_annotation[self.helper_line_number]) 153 self.assertIn(self.v.op.name, source_annotation[self.helper_line_number]) 154 self.assertIn(self.w.op.name, source_annotation[self.helper_line_number]) 155 156 def testAnnotateWithStackTopGivesCorrectResult(self): 157 source_annotation = source_utils.annotate_source( 158 self.dump, self.curr_file_path, file_stack_top=True) 159 160 self.assertIn(self.u_init.op.name, 161 source_annotation[self.u_init_line_number]) 162 self.assertIn(self.u.op.name, source_annotation[self.u_line_number]) 163 self.assertIn(self.v_init.op.name, 164 source_annotation[self.v_init_line_number]) 165 self.assertIn(self.v.op.name, source_annotation[self.v_line_number]) 166 self.assertIn(self.w.op.name, source_annotation[self.w_line_number]) 167 168 # In the stack-top mode, the helper line should not have been annotated. 169 self.assertNotIn(self.helper_line_number, source_annotation) 170 171 def testAnnotateSubsetOfLinesGivesCorrectResult(self): 172 source_annotation = source_utils.annotate_source( 173 self.dump, 174 self.curr_file_path, 175 min_line=self.u_line_number, 176 max_line=self.u_line_number + 1) 177 178 self.assertIn(self.u.op.name, source_annotation[self.u_line_number]) 179 self.assertNotIn(self.v_line_number, source_annotation) 180 181 def testAnnotateDumpedTensorsGivesCorrectResult(self): 182 source_annotation = source_utils.annotate_source( 183 self.dump, self.curr_file_path, do_dumped_tensors=True) 184 185 # Note: Constant Tensors u_init and v_init may not get dumped due to 186 # constant-folding. 187 self.assertIn(self.u.name, source_annotation[self.u_line_number]) 188 self.assertIn(self.v.name, source_annotation[self.v_line_number]) 189 self.assertIn(self.w.name, source_annotation[self.w_line_number]) 190 191 self.assertNotIn(self.u.op.name, source_annotation[self.u_line_number]) 192 self.assertNotIn(self.v.op.name, source_annotation[self.v_line_number]) 193 self.assertNotIn(self.w.op.name, source_annotation[self.w_line_number]) 194 195 self.assertIn(self.u.name, source_annotation[self.helper_line_number]) 196 self.assertIn(self.v.name, source_annotation[self.helper_line_number]) 197 self.assertIn(self.w.name, source_annotation[self.helper_line_number]) 198 199 def testCallingAnnotateSourceWithoutPythonGraphRaisesException(self): 200 self.dump.set_python_graph(None) 201 with self.assertRaises(ValueError): 202 source_utils.annotate_source(self.dump, self.curr_file_path) 203 204 def testCallingAnnotateSourceOnUnrelatedSourceFileDoesNotError(self): 205 # Create an unrelated source file. 206 unrelated_source_path = tempfile.mktemp() 207 with open(unrelated_source_path, "wt") as source_file: 208 source_file.write("print('hello, world')\n") 209 210 self.assertEqual({}, 211 source_utils.annotate_source(self.dump, 212 unrelated_source_path)) 213 214 # Clean up unrelated source file. 215 os.remove(unrelated_source_path) 216 217 218 class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase): 219 220 def createAndRunGraphWithWhileLoop(self): 221 """Create and run a TensorFlow Graph with a while loop to generate dumps.""" 222 223 self.dump_root = self.get_temp_dir() 224 self.curr_file_path = os.path.abspath( 225 tf_inspect.getfile(tf_inspect.currentframe())) 226 227 # Run a simple TF graph to generate some debug dumps that can be used in 228 # source annotation. 229 with session.Session() as sess: 230 loop_body = lambda i: math_ops.add(i, 2) 231 self.traceback_first_line = line_number_above() 232 233 loop_cond = lambda i: math_ops.less(i, 16) 234 235 i = constant_op.constant(10, name="i") 236 loop = control_flow_ops.while_loop(loop_cond, loop_body, [i]) 237 238 run_options = config_pb2.RunOptions(output_partition_graphs=True) 239 debug_utils.watch_graph( 240 run_options, sess.graph, debug_urls=["file://%s" % self.dump_root]) 241 run_metadata = config_pb2.RunMetadata() 242 sess.run(loop, options=run_options, run_metadata=run_metadata) 243 244 self.dump = debug_data.DebugDumpDir( 245 self.dump_root, partition_graphs=run_metadata.partition_graphs) 246 self.dump.set_python_graph(sess.graph) 247 248 def setUp(self): 249 self.createAndRunGraphWithWhileLoop() 250 251 def tearDown(self): 252 if os.path.isdir(self.dump_root): 253 shutil.rmtree(self.dump_root) 254 ops.reset_default_graph() 255 256 def testGenerateSourceList(self): 257 source_list = source_utils.list_source_files_against_dump(self.dump) 258 259 # Assert that the file paths are sorted and unique. 260 file_paths = [item[0] for item in source_list] 261 self.assertEqual(sorted(file_paths), file_paths) 262 self.assertEqual(len(set(file_paths)), len(file_paths)) 263 264 # Assert that each item of source_list has length 6. 265 for item in source_list: 266 self.assertTrue(isinstance(item, tuple)) 267 self.assertEqual(6, len(item)) 268 269 # The while loop body should have executed 3 times. The following table 270 # lists the tensors and how many times each of them is dumped. 271 # Tensor name # of times dumped: 272 # i:0 1 273 # while/Enter:0 1 274 # while/Merge:0 4 275 # while/Merge:1 4 276 # while/Less/y:0 4 277 # while/Less:0 4 278 # while/LoopCond:0 4 279 # while/Switch:0 1 280 # while/Swtich:1 3 281 # while/Identity:0 3 282 # while/Add/y:0 3 283 # while/Add:0 3 284 # while/NextIteration:0 3 285 # while/Exit:0 1 286 # ---------------------------- 287 # (Total) 39 288 # 289 # The total number of nodes is 12. 290 # The total number of tensors is 14 (2 of the nodes have 2 outputs: 291 # while/Merge, while/Switch). 292 293 _, is_tf_py_library, num_nodes, num_tensors, num_dumps, first_line = ( 294 source_list[file_paths.index(self.curr_file_path)]) 295 self.assertFalse(is_tf_py_library) 296 self.assertEqual(12, num_nodes) 297 self.assertEqual(14, num_tensors) 298 self.assertEqual(39, num_dumps) 299 self.assertEqual(self.traceback_first_line, first_line) 300 301 def testGenerateSourceListWithNodeNameFilter(self): 302 source_list = source_utils.list_source_files_against_dump( 303 self.dump, node_name_regex_whitelist=r"while/Add.*") 304 305 # Assert that the file paths are sorted. 306 file_paths = [item[0] for item in source_list] 307 self.assertEqual(sorted(file_paths), file_paths) 308 self.assertEqual(len(set(file_paths)), len(file_paths)) 309 310 # Assert that each item of source_list has length 4. 311 for item in source_list: 312 self.assertTrue(isinstance(item, tuple)) 313 self.assertEqual(6, len(item)) 314 315 # Due to the node-name filtering the result should only contain 2 nodes 316 # and 2 tensors. The total number of dumped tensors should be 6: 317 # while/Add/y:0 3 318 # while/Add:0 3 319 _, is_tf_py_library, num_nodes, num_tensors, num_dumps, _ = ( 320 source_list[file_paths.index(self.curr_file_path)]) 321 self.assertFalse(is_tf_py_library) 322 self.assertEqual(2, num_nodes) 323 self.assertEqual(2, num_tensors) 324 self.assertEqual(6, num_dumps) 325 326 def testGenerateSourceListWithPathRegexFilter(self): 327 curr_file_basename = os.path.basename(self.curr_file_path) 328 source_list = source_utils.list_source_files_against_dump( 329 self.dump, 330 path_regex_whitelist=( 331 ".*" + curr_file_basename.replace(".", "\\.") + "$")) 332 333 self.assertEqual(1, len(source_list)) 334 (file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps, 335 first_line) = source_list[0] 336 self.assertEqual(self.curr_file_path, file_path) 337 self.assertFalse(is_tf_py_library) 338 self.assertEqual(12, num_nodes) 339 self.assertEqual(14, num_tensors) 340 self.assertEqual(39, num_dumps) 341 self.assertEqual(self.traceback_first_line, first_line) 342 343 344 if __name__ == "__main__": 345 googletest.main() 346