Home | History | Annotate | Download | only in lib
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Unit tests for source_remote."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 import traceback
     23 
     24 from tensorflow.core.debug import debug_service_pb2
     25 from tensorflow.python.client import session
     26 from tensorflow.python.debug.lib import grpc_debug_test_server
     27 from tensorflow.python.debug.lib import source_remote
     28 from tensorflow.python.debug.lib import source_utils
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.framework import test_util
     31 from tensorflow.python.ops import math_ops
     32 # Import resource_variable_ops for the variables-to-tensor implicit conversion.
     33 from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
     34 from tensorflow.python.ops import variables
     35 from tensorflow.python.platform import googletest
     36 from tensorflow.python.util import tf_inspect
     37 
     38 
     39 def line_number_above():
     40   return tf_inspect.stack()[1][2] - 1
     41 
     42 
     43 class SendTracebacksTest(test_util.TensorFlowTestCase):
     44 
     45   @classmethod
     46   def setUpClass(cls):
     47     test_util.TensorFlowTestCase.setUpClass()
     48     (cls._server_port, cls._debug_server_url, cls._server_dump_dir,
     49      cls._server_thread,
     50      cls._server) = grpc_debug_test_server.start_server_on_separate_thread()
     51     cls._server_address = "localhost:%d" % cls._server_port
     52     (cls._server_port_2, cls._debug_server_url_2, cls._server_dump_dir_2,
     53      cls._server_thread_2,
     54      cls._server_2) = grpc_debug_test_server.start_server_on_separate_thread()
     55     cls._server_address_2 = "localhost:%d" % cls._server_port_2
     56     cls._curr_file_path = os.path.normpath(os.path.abspath(__file__))
     57 
     58   @classmethod
     59   def tearDownClass(cls):
     60     # Stop the test server and join the thread.
     61     cls._server.stop_server().wait()
     62     cls._server_thread.join()
     63     cls._server_2.stop_server().wait()
     64     cls._server_thread_2.join()
     65     test_util.TensorFlowTestCase.tearDownClass()
     66 
     67   def tearDown(self):
     68     ops.reset_default_graph()
     69     self._server.clear_data()
     70     self._server_2.clear_data()
     71     super(SendTracebacksTest, self).tearDown()
     72 
     73   def _findFirstTraceInsideTensorFlowPyLibrary(self, op):
     74     """Find the first trace of an op that belongs to the TF Python library."""
     75     for trace in op.traceback:
     76       if source_utils.guess_is_tensorflow_py_library(trace[0]):
     77         return trace
     78 
     79   def testSendGraphTracebacksToSingleDebugServer(self):
     80     this_func_name = "testSendGraphTracebacksToSingleDebugServer"
     81     with session.Session() as sess:
     82       a = variables.Variable(21.0, name="a")
     83       a_lineno = line_number_above()
     84       b = variables.Variable(2.0, name="b")
     85       b_lineno = line_number_above()
     86       math_ops.add(a, b, name="x")
     87       x_lineno = line_number_above()
     88 
     89       send_stack = traceback.extract_stack()
     90       send_lineno = line_number_above()
     91       source_remote.send_graph_tracebacks(
     92           self._server_address, "dummy_run_key", send_stack, sess.graph)
     93 
     94       tb = self._server.query_op_traceback("a")
     95       self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
     96       tb = self._server.query_op_traceback("b")
     97       self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
     98       tb = self._server.query_op_traceback("x")
     99       self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)
    100 
    101       self.assertIn(
    102           (self._curr_file_path, send_lineno, this_func_name),
    103           self._server.query_origin_stack()[-1])
    104 
    105       self.assertEqual(
    106           "      a = variables.Variable(21.0, name=\"a\")",
    107           self._server.query_source_file_line(__file__, a_lineno))
    108       # Files in the TensorFlow code base shouldn not have been sent.
    109       tf_trace_file_path = self._findFirstTraceInsideTensorFlowPyLibrary(a.op)
    110       with self.assertRaises(ValueError):
    111         self._server.query_source_file_line(tf_trace_file_path, 0)
    112       self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
    113                        self._server.query_call_types())
    114       self.assertEqual(["dummy_run_key"], self._server.query_call_keys())
    115       self.assertEqual(
    116           [sess.graph.version], self._server.query_graph_versions())
    117 
    118   def testSendGraphTracebacksToTwoDebugServers(self):
    119     this_func_name = "testSendGraphTracebacksToTwoDebugServers"
    120     with session.Session() as sess:
    121       a = variables.Variable(21.0, name="two/a")
    122       a_lineno = line_number_above()
    123       b = variables.Variable(2.0, name="two/b")
    124       b_lineno = line_number_above()
    125       x = math_ops.add(a, b, name="two/x")
    126       x_lineno = line_number_above()
    127 
    128       send_traceback = traceback.extract_stack()
    129       send_lineno = line_number_above()
    130       source_remote.send_graph_tracebacks(
    131           [self._server_address, self._server_address_2],
    132           "dummy_run_key", send_traceback, sess.graph)
    133 
    134       servers = [self._server, self._server_2]
    135       for server in servers:
    136         tb = server.query_op_traceback("two/a")
    137         self.assertIn((self._curr_file_path, a_lineno, this_func_name), tb)
    138         tb = server.query_op_traceback("two/b")
    139         self.assertIn((self._curr_file_path, b_lineno, this_func_name), tb)
    140         tb = server.query_op_traceback("two/x")
    141         self.assertIn((self._curr_file_path, x_lineno, this_func_name), tb)
    142 
    143         self.assertIn(
    144             (self._curr_file_path, send_lineno, this_func_name),
    145             server.query_origin_stack()[-1])
    146 
    147         self.assertEqual(
    148             "      x = math_ops.add(a, b, name=\"two/x\")",
    149             server.query_source_file_line(__file__, x_lineno))
    150         tf_trace_file_path = self._findFirstTraceInsideTensorFlowPyLibrary(x.op)
    151         with self.assertRaises(ValueError):
    152           server.query_source_file_line(tf_trace_file_path, 0)
    153         self.assertEqual([debug_service_pb2.CallTraceback.GRAPH_EXECUTION],
    154                          server.query_call_types())
    155         self.assertEqual(["dummy_run_key"], server.query_call_keys())
    156         self.assertEqual([sess.graph.version], server.query_graph_versions())
    157 
    158   def testSendEagerTracebacksToSingleDebugServer(self):
    159     this_func_name = "testSendEagerTracebacksToSingleDebugServer"
    160     send_traceback = traceback.extract_stack()
    161     send_lineno = line_number_above()
    162     source_remote.send_eager_tracebacks(self._server_address, send_traceback)
    163 
    164     self.assertEqual([debug_service_pb2.CallTraceback.EAGER_EXECUTION],
    165                      self._server.query_call_types())
    166     self.assertIn((self._curr_file_path, send_lineno, this_func_name),
    167                   self._server.query_origin_stack()[-1])
    168 
    169 
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's googletest runner.
  googletest.main()
    172