Home | History | Annotate | Download | only in lib
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unit tests for source_remote."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 import traceback
     23 
     24 from tensorflow.core.debug import debug_service_pb2
     25 from tensorflow.python.client import session
     26 from tensorflow.python.debug.lib import grpc_debug_test_server
     27 from tensorflow.python.debug.lib import source_remote
     28 from tensorflow.python.debug.lib import source_utils
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.framework import test_util
     31 from tensorflow.python.ops import math_ops
     32 # Import resource_variable_ops for the variables-to-tensor implicit conversion.
     33 from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
     34 from tensorflow.python.ops import variables
     35 from tensorflow.python.platform import googletest
     36 from tensorflow.python.platform import test
     37 from tensorflow.python.util import tf_inspect
     38 
     39 
     40 def line_number_above():
     41   return tf_inspect.stack()[1][2] - 1
     42 
     43 
     44 class SendTracebacksTest(test_util.TensorFlowTestCase):
     45 
     46   @classmethod
     47   def setUpClass(cls):
     48     test_util.TensorFlowTestCase.setUpClass()
     49     (cls._server_port, cls._debug_server_url, cls._server_dump_dir,
     50      cls._server_thread,
     51      cls._server) = grpc_debug_test_server.start_server_on_separate_thread(
     52          poll_server=True)
     53     cls._server_address = "localhost:%d" % cls._server_port
     54     (cls._server_port_2, cls._debug_server_url_2, cls._server_dump_dir_2,
     55      cls._server_thread_2,
     56      cls._server_2) = grpc_debug_test_server.start_server_on_separate_thread()
     57     cls._server_address_2 = "localhost:%d" % cls._server_port_2
     58     cls._curr_file_path = os.path.normpath(os.path.abspath(__file__))
     59 
     60   @classmethod
     61   def tearDownClass(cls):
     62     # Stop the test server and join the thread.
     63     cls._server.stop_server().wait()
     64     cls._server_thread.join()
     65     cls._server_2.stop_server().wait()
     66     cls._server_thread_2.join()
     67     test_util.TensorFlowTestCase.tearDownClass()
     68 
     69   def tearDown(self):
     70     ops.reset_default_graph()
     71     self._server.clear_data()
     72     self._server_2.clear_data()
     73     super(SendTracebacksTest, self).tearDown()
     74 
     75   def _findFirstTraceInsideTensorFlowPyLibrary(self, op):
     76     """Find the first trace of an op that belongs to the TF Python library."""
     77     for trace in op.traceback:
     78       if source_utils.guess_is_tensorflow_py_library(trace[0]):
     79         return trace
     80 
     81   def testSendGraphTracebacksToSingleDebugServer(self):
     82     this_func_name = "testSendGraphTracebacksToSingleDebugServer"
     83     with session.Session() as sess:
     84       a = variables.Variable(21.0, name="a")
     85       a_lineno = line_number_above()
     86       b = variables.Variable(2.0, name="b")
     87       b_lineno = line_number_above()
     88       math_ops.add(a, b, name="x")
     89       x_lineno = line_number_above()
     90 
     91       send_stack = traceback.extract_stack()
     92       send_lineno = line_number_above()
     93       source_remote.send_graph_tracebacks(
     94           self._server_address, "dummy_run_key", send_stack, sess.graph)
     95 
     96       tb = self._server.query_op_traceback("a")
     97       self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
     98       tb = self._server.query_op_traceback("b")
     99       self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
    100       tb = self._server.query_op_traceback("x")
    101       self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)
    102 
    103       self.assertIn(
    104           (self._curr_file_path, send_lineno, this_func_name),
    105           self._server.query_origin_stack()[-1])
    106 
    107       self.assertEqual(
    108           "      a = variables.Variable(21.0, name=\"a\")",
    109           self._server.query_source_file_line(__file__, a_lineno))
    110       # Files in the TensorFlow code base shouldn not have been sent.
    111       tf_trace_file_path = self._findFirstTraceInsideTensorFlowPyLibrary(a.op)
    112       with self.assertRaises(ValueError):
    113         self._server.query_source_file_line(tf_trace_file_path, 0)
    114       self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
    115                        self._server.query_call_types())
    116       self.assertEqual(["dummy_run_key"], self._server.query_call_keys())
    117       self.assertEqual(
    118           [sess.graph.version], self._server.query_graph_versions())
    119 
    120   def testSendGraphTracebacksToTwoDebugServers(self):
    121     this_func_name = "testSendGraphTracebacksToTwoDebugServers"
    122     with session.Session() as sess:
    123       a = variables.Variable(21.0, name="two/a")
    124       a_lineno = line_number_above()
    125       b = variables.Variable(2.0, name="two/b")
    126       b_lineno = line_number_above()
    127       x = math_ops.add(a, b, name="two/x")
    128       x_lineno = line_number_above()
    129 
    130       send_traceback = traceback.extract_stack()
    131       send_lineno = line_number_above()
    132       source_remote.send_graph_tracebacks(
    133           [self._server_address, self._server_address_2],
    134           "dummy_run_key", send_traceback, sess.graph)
    135 
    136       servers = [self._server, self._server_2]
    137       for server in servers:
    138         tb = server.query_op_traceback("two/a")
    139         self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
    140         tb = server.query_op_traceback("two/b")
    141         self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
    142         tb = server.query_op_traceback("two/x")
    143         self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)
    144 
    145         self.assertIn(
    146             (self._curr_file_path, send_lineno, this_func_name),
    147             server.query_origin_stack()[-1])
    148 
    149         self.assertEqual(
    150             "      x = math_ops.add(a, b, name=\"two/x\")",
    151             server.query_source_file_line(__file__, x_lineno))
    152         tf_trace_file_path = self._findFirstTraceInsideTensorFlowPyLibrary(x.op)
    153         with self.assertRaises(ValueError):
    154           server.query_source_file_line(tf_trace_file_path, 0)
    155         self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
    156                          server.query_call_types())
    157         self.assertEqual(["dummy_run_key"], server.query_call_keys())
    158         self.assertEqual([sess.graph.version], server.query_graph_versions())
    159 
    160   def testSourceFileSizeExceedsGrpcMessageLengthLimit(self):
    161     """In case source file size exceeds the grpc message length limit.
    162 
    163     it ought not to have been sent to the server.
    164     """
    165     this_func_name = "testSourceFileSizeExceedsGrpcMessageLengthLimit"
    166 
    167     # Patch the method to simulate a very small message length limit.
    168     with test.mock.patch.object(
    169         source_remote, "grpc_message_length_bytes", return_value=2):
    170       with session.Session() as sess:
    171         a = variables.Variable(21.0, name="two/a")
    172         a_lineno = line_number_above()
    173         b = variables.Variable(2.0, name="two/b")
    174         b_lineno = line_number_above()
    175         x = math_ops.add(a, b, name="two/x")
    176         x_lineno = line_number_above()
    177 
    178         send_traceback = traceback.extract_stack()
    179         send_lineno = line_number_above()
    180         source_remote.send_graph_tracebacks(
    181             [self._server_address, self._server_address_2],
    182             "dummy_run_key", send_traceback, sess.graph)
    183 
    184         servers = [self._server, self._server_2]
    185         for server in servers:
    186           # Even though the source file content is not sent, the traceback
    187           # should have been sent.
    188           tb = server.query_op_traceback("two/a")
    189           self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
    190           tb = server.query_op_traceback("two/b")
    191           self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
    192           tb = server.query_op_traceback("two/x")
    193           self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)
    194 
    195           self.assertIn(
    196               (self._curr_file_path, send_lineno, this_func_name),
    197               server.query_origin_stack()[-1])
    198 
    199           tf_trace_file_path = (
    200               self._findFirstTraceInsideTensorFlowPyLibrary(x.op))
    201           # Verify that the source content is not sent to the server.
    202           with self.assertRaises(ValueError):
    203             self._server.query_source_file_line(tf_trace_file_path, 0)
    204 
    205   def testSendEagerTracebacksToSingleDebugServer(self):
    206     this_func_name = "testSendEagerTracebacksToSingleDebugServer"
    207     send_traceback = traceback.extract_stack()
    208     send_lineno = line_number_above()
    209     source_remote.send_eager_tracebacks(self._server_address, send_traceback)
    210 
    211     self.assertEqual([debug_service_pb2.CallTraceback.EAGER_EXECUTION],
    212                      self._server.query_call_types())
    213     self.assertIn((self._curr_file_path, send_lineno, this_func_name),
    214                   self._server.query_origin_stack()[-1])
    215 
    216 
    217 if __name__ == "__main__":
    218   googletest.main()
    219