1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Unit tests for source_remote.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 import traceback 23 24 from tensorflow.core.debug import debug_service_pb2 25 from tensorflow.python.client import session 26 from tensorflow.python.debug.lib import grpc_debug_test_server 27 from tensorflow.python.debug.lib import source_remote 28 from tensorflow.python.debug.lib import source_utils 29 from tensorflow.python.framework import ops 30 from tensorflow.python.framework import test_util 31 from tensorflow.python.ops import math_ops 32 # Import resource_variable_ops for the variables-to-tensor implicit conversion. 33 from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import 34 from tensorflow.python.ops import variables 35 from tensorflow.python.platform import googletest 36 from tensorflow.python.platform import test 37 from tensorflow.python.util import tf_inspect 38 39 40 def line_number_above(): 41 return tf_inspect.stack()[1][2] - 1 42 43 44 class SendTracebacksTest(test_util.TensorFlowTestCase): 45 46 @classmethod 47 def setUpClass(cls): 48 test_util.TensorFlowTestCase.setUpClass() 49 (cls._server_port, cls._debug_server_url, cls._server_dump_dir, 50 cls._server_thread, 51 cls._server) = grpc_debug_test_server.start_server_on_separate_thread( 52 poll_server=True) 53 cls._server_address = "localhost:%d" % cls._server_port 54 (cls._server_port_2, cls._debug_server_url_2, cls._server_dump_dir_2, 55 cls._server_thread_2, 56 cls._server_2) = grpc_debug_test_server.start_server_on_separate_thread() 57 cls._server_address_2 = "localhost:%d" % cls._server_port_2 58 cls._curr_file_path = os.path.normpath(os.path.abspath(__file__)) 59 60 @classmethod 61 def tearDownClass(cls): 62 # Stop the test server and join the thread. 63 cls._server.stop_server().wait() 64 cls._server_thread.join() 65 cls._server_2.stop_server().wait() 66 cls._server_thread_2.join() 67 test_util.TensorFlowTestCase.tearDownClass() 68 69 def tearDown(self): 70 ops.reset_default_graph() 71 self._server.clear_data() 72 self._server_2.clear_data() 73 super(SendTracebacksTest, self).tearDown() 74 75 def _findFirstTraceInsideTensorFlowPyLibrary(self, op): 76 """Find the first trace of an op that belongs to the TF Python library.""" 77 for trace in op.traceback: 78 if source_utils.guess_is_tensorflow_py_library(trace[0]): 79 return trace 80 81 def testSendGraphTracebacksToSingleDebugServer(self): 82 this_func_name = "testSendGraphTracebacksToSingleDebugServer" 83 with session.Session() as sess: 84 a = variables.Variable(21.0, name="a") 85 a_lineno = line_number_above() 86 b = variables.Variable(2.0, name="b") 87 b_lineno = line_number_above() 88 math_ops.add(a, b, name="x") 89 x_lineno = line_number_above() 90 91 send_stack = traceback.extract_stack() 92 send_lineno = line_number_above() 93 source_remote.send_graph_tracebacks( 94 self._server_address, "dummy_run_key", send_stack, sess.graph) 95 96 tb = self._server.query_op_traceback("a") 97 self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb) 98 tb = self._server.query_op_traceback("b") 99 self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb) 100 tb = self._server.query_op_traceback("x") 101 self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb) 102 103 self.assertIn( 104 (self._curr_file_path, send_lineno, this_func_name), 105 self._server.query_origin_stack()[-1]) 106 107 self.assertEqual( 108 " a = variables.Variable(21.0, name=\"a\")", 109 self._server.query_source_file_line(__file__, a_lineno)) 110 # Files in the TensorFlow code base shouldn not have been sent. 111 tf_trace_file_path = self._findFirstTraceInsideTensorFlowPyLibrary(a.op) 112 with self.assertRaises(ValueError): 113 self._server.query_source_file_line(tf_trace_file_path, 0) 114 self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION], 115 self._server.query_call_types()) 116 self.assertEqual(["dummy_run_key"], self._server.query_call_keys()) 117 self.assertEqual( 118 [sess.graph.version], self._server.query_graph_versions()) 119 120 def testSendGraphTracebacksToTwoDebugServers(self): 121 this_func_name = "testSendGraphTracebacksToTwoDebugServers" 122 with session.Session() as sess: 123 a = variables.Variable(21.0, name="two/a") 124 a_lineno = line_number_above() 125 b = variables.Variable(2.0, name="two/b") 126 b_lineno = line_number_above() 127 x = math_ops.add(a, b, name="two/x") 128 x_lineno = line_number_above() 129 130 send_traceback = traceback.extract_stack() 131 send_lineno = line_number_above() 132 source_remote.send_graph_tracebacks( 133 [self._server_address, self._server_address_2], 134 "dummy_run_key", send_traceback, sess.graph) 135 136 servers = [self._server, self._server_2] 137 for server in servers: 138 tb = server.query_op_traceback("two/a") 139 self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb) 140 tb = server.query_op_traceback("two/b") 141 self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb) 142 tb = server.query_op_traceback("two/x") 143 self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb) 144 145 self.assertIn( 146 (self._curr_file_path, send_lineno, this_func_name), 147 server.query_origin_stack()[-1]) 148 149 self.assertEqual( 150 " x = math_ops.add(a, b, name=\"two/x\")", 151 server.query_source_file_line(__file__, x_lineno)) 152 tf_trace_file_path = self._findFirstTraceInsideTensorFlowPyLibrary(x.op) 153 with self.assertRaises(ValueError): 154 server.query_source_file_line(tf_trace_file_path, 0) 155 self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION], 156 server.query_call_types()) 157 self.assertEqual(["dummy_run_key"], server.query_call_keys()) 158 self.assertEqual([sess.graph.version], server.query_graph_versions()) 159 160 def testSourceFileSizeExceedsGrpcMessageLengthLimit(self): 161 """In case source file size exceeds the grpc message length limit. 162 163 it ought not to have been sent to the server. 164 """ 165 this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit" 166 167 # Patch the method to simulate a very small message length limit. 168 with test.mock.patch.object( 169 source_remote, "grpc_message_length_bytes", return_value=2): 170 with session.Session() as sess: 171 a = variables.Variable(21.0, name="two/a") 172 a_lineno = line_number_above() 173 b = variables.Variable(2.0, name="two/b") 174 b_lineno = line_number_above() 175 x = math_ops.add(a, b, name="two/x") 176 x_lineno = line_number_above() 177 178 send_traceback = traceback.extract_stack() 179 send_lineno = line_number_above() 180 source_remote.send_graph_tracebacks( 181 [self._server_address, self._server_address_2], 182 "dummy_run_key", send_traceback, sess.graph) 183 184 servers = [self._server, self._server_2] 185 for server in servers: 186 # Even though the source file content is not sent, the traceback 187 # should have been sent. 188 tb = server.query_op_traceback("two/a") 189 self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb) 190 tb = server.query_op_traceback("two/b") 191 self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb) 192 tb = server.query_op_traceback("two/x") 193 self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb) 194 195 self.assertIn( 196 (self._curr_file_path, send_lineno, this_func_name), 197 server.query_origin_stack()[-1]) 198 199 tf_trace_file_path = ( 200 self._findFirstTraceInsideTensorFlowPyLibrary(x.op)) 201 # Verify that the source content is not sent to the server. 202 with self.assertRaises(ValueError): 203 self._server.query_source_file_line(tf_trace_file_path, 0) 204 205 def testSendEagerTracebacksToSingleDebugServer(self): 206 this_func_name = "testSendEagerTracebacksToSingleDebugServer" 207 send_traceback = traceback.extract_stack() 208 send_lineno = line_number_above() 209 source_remote.send_eager_tracebacks(self._server_address, send_traceback) 210 211 self.assertEqual([debug_service_pb2.CallTraceback.EAGER_EXECUTION], 212 self._server.query_call_types()) 213 self.assertIn((self._curr_file_path, send_lineno, this_func_name), 214 self._server.query_origin_stack()[-1]) 215 216 217 if __name__ == "__main__": 218 googletest.main() 219