1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Unit tests for source_remote.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 import traceback 23 24 from tensorflow.core.debug import debug_service_pb2 25 from tensorflow.python.client import session 26 from tensorflow.python.debug.lib import grpc_debug_test_server 27 from tensorflow.python.debug.lib import source_remote 28 from tensorflow.python.debug.lib import source_utils 29 from tensorflow.python.framework import ops 30 from tensorflow.python.framework import test_util 31 from tensorflow.python.ops import math_ops 32 # Import resource_variable_ops for the variables-to-tensor implicit conversion. 33 from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import 34 from tensorflow.python.ops import variables 35 from tensorflow.python.platform import googletest 36 from tensorflow.python.util import tf_inspect 37 38 39 def line_number_above(): 40 return tf_inspect.stack()[1][2] - 1 41 42 43 class SendTracebacksTest(test_util.TensorFlowTestCase): 44 45 @classmethod 46 def setUpClass(cls): 47 test_util.TensorFlowTestCase.setUpClass() 48 (cls._server_port, cls._debug_server_url, cls._server_dump_dir, 49 cls._server_thread, 50 cls._server) = grpc_debug_test_server.start_server_on_separate_thread() 51 cls._server_address = "localhost:%d" % cls._server_port 52 (cls._server_port_2, cls._debug_server_url_2, cls._server_dump_dir_2, 53 cls._server_thread_2, 54 cls._server_2) = grpc_debug_test_server.start_server_on_separate_thread() 55 cls._server_address_2 = "localhost:%d" % cls._server_port_2 56 cls._curr_file_path = os.path.normpath(os.path.abspath(__file__)) 57 58 @classmethod 59 def tearDownClass(cls): 60 # Stop the test server and join the thread. 61 cls._server.stop_server().wait() 62 cls._server_thread.join() 63 cls._server_2.stop_server().wait() 64 cls._server_thread_2.join() 65 test_util.TensorFlowTestCase.tearDownClass() 66 67 def tearDown(self): 68 ops.reset_default_graph() 69 self._server.clear_data() 70 self._server_2.clear_data() 71 super(SendTracebacksTest, self).tearDown() 72 73 def _findFirstTraceInsideTensorFlowPyLibrary(self, op): 74 """Find the first trace of an op that belongs to the TF Python library.""" 75 for trace in op.traceback: 76 if source_utils.guess_is_tensorflow_py_library(trace[0]): 77 return trace 78 79 def testSendGraphTracebacksToSingleDebugServer(self): 80 this_func_name = "testSendGraphTracebacksToSingleDebugServer" 81 with session.Session() as sess: 82 a = variables.Variable(21.0, name="a") 83 a_lineno = line_number_above() 84 b = variables.Variable(2.0, name="b") 85 b_lineno = line_number_above() 86 math_ops.add(a, b, name="x") 87 x_lineno = line_number_above() 88 89 send_stack = traceback.extract_stack() 90 send_lineno = line_number_above() 91 source_remote.send_graph_tracebacks( 92 self._server_address, "dummy_run_key", send_stack, sess.graph) 93 94 tb = self._server.query_op_traceback("a") 95 self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb) 96 tb = self._server.query_op_traceback("b") 97 self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb) 98 tb = self._server.query_op_traceback("x") 99 self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb) 100 101 self.assertIn( 102 (self._curr_file_path, send_lineno, this_func_name), 103 self._server.query_origin_stack()[-1]) 104 105 self.assertEqual( 106 " a = variables.Variable(21.0, name=\"a\")", 107 self._server.query_source_file_line(__file__, a_lineno)) 108 # Files in the TensorFlow code base shouldn not have been sent. 109 tf_trace_file_path = self._findFirstTraceInsideTensorFlowPyLibrary(a.op) 110 with self.assertRaises(ValueError): 111 self._server.query_source_file_line(tf_trace_file_path, 0) 112 self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION], 113 self._server.query_call_types()) 114 self.assertEqual(["dummy_run_key"], self._server.query_call_keys()) 115 self.assertEqual( 116 [sess.graph.version], self._server.query_graph_versions()) 117 118 def testSendGraphTracebacksToTwoDebugServers(self): 119 this_func_name = "testSendGraphTracebacksToTwoDebugServers" 120 with session.Session() as sess: 121 a = variables.Variable(21.0, name="two/a") 122 a_lineno = line_number_above() 123 b = variables.Variable(2.0, name="two/b") 124 b_lineno = line_number_above() 125 x = math_ops.add(a, b, name="two/x") 126 x_lineno = line_number_above() 127 128 send_traceback = traceback.extract_stack() 129 send_lineno = line_number_above() 130 source_remote.send_graph_tracebacks( 131 [self._server_address, self._server_address_2], 132 "dummy_run_key", send_traceback, sess.graph) 133 134 servers = [self._server, self._server_2] 135 for server in servers: 136 tb = server.query_op_traceback("two/a") 137 self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb) 138 tb = server.query_op_traceback("two/b") 139 self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb) 140 tb = server.query_op_traceback("two/x") 141 self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb) 142 143 self.assertIn( 144 (self._curr_file_path, send_lineno, this_func_name), 145 server.query_origin_stack()[-1]) 146 147 self.assertEqual( 148 " x = math_ops.add(a, b, name=\"two/x\")", 149 server.query_source_file_line(__file__, x_lineno)) 150 tf_trace_file_path = self._findFirstTraceInsideTensorFlowPyLibrary(x.op) 151 with self.assertRaises(ValueError): 152 server.query_source_file_line(tf_trace_file_path, 0) 153 self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION], 154 server.query_call_types()) 155 self.assertEqual(["dummy_run_key"], server.query_call_keys()) 156 self.assertEqual([sess.graph.version], server.query_graph_versions()) 157 158 def testSendEagerTracebacksToSingleDebugServer(self): 159 this_func_name = "testSendEagerTracebacksToSingleDebugServer" 160 send_traceback = traceback.extract_stack() 161 send_lineno = line_number_above() 162 source_remote.send_eager_tracebacks(self._server_address, send_traceback) 163 164 self.assertEqual([debug_service_pb2.CallTraceback.EAGER_EXECUTION], 165 self._server.query_call_types()) 166 self.assertIn((self._curr_file_path, send_lineno, this_func_name), 167 self._server.query_origin_stack()[-1]) 168 169 170 if __name__ == "__main__": 171 googletest.main() 172