Home | History | Annotate | Download | only in lib
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for debugger functionalities in tf.Session with grpc:// URLs.
     16 
     17 This test file focuses on the grpc:// debugging of local (non-distributed)
     18 tf.Sessions.
     19 """
     20 from __future__ import absolute_import
     21 from __future__ import division
     22 from __future__ import print_function
     23 
     24 import os
     25 import shutil
     26 
     27 import numpy as np
     28 from six.moves import xrange  # pylint: disable=redefined-builtin
     29 
     30 from tensorflow.core.protobuf import config_pb2
     31 from tensorflow.core.protobuf import rewriter_config_pb2
     32 from tensorflow.python.client import session
     33 from tensorflow.python.debug.lib import debug_data
     34 from tensorflow.python.debug.lib import debug_utils
     35 from tensorflow.python.debug.lib import grpc_debug_test_server
     36 from tensorflow.python.debug.lib import session_debug_testlib
     37 from tensorflow.python.debug.wrappers import framework
     38 from tensorflow.python.debug.wrappers import grpc_wrapper
     39 from tensorflow.python.debug.wrappers import hooks
     40 from tensorflow.python.framework import constant_op
     41 from tensorflow.python.framework import dtypes
     42 from tensorflow.python.framework import ops
     43 from tensorflow.python.framework import test_util
     44 from tensorflow.python.ops import array_ops
     45 from tensorflow.python.ops import math_ops
     46 from tensorflow.python.ops import state_ops
     47 from tensorflow.python.ops import variables
     48 from tensorflow.python.platform import googletest
     49 from tensorflow.python.platform import test
     50 from tensorflow.python.platform import tf_logging
     51 from tensorflow.python.training import monitored_session
     52 
     53 
     54 def no_rewrite_session_config():
     55   rewriter_config = rewriter_config_pb2.RewriterConfig(
     56       disable_model_pruning=True,
     57       arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
     58       dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
     59   graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
     60   return config_pb2.ConfigProto(graph_options=graph_options)
     61 
     62 
     63 class GrpcDebugServerTest(test_util.TensorFlowTestCase):
     64 
     65   def testRepeatedRunServerRaisesException(self):
     66     (_, _, _, server_thread,
     67      server) = grpc_debug_test_server.start_server_on_separate_thread(
     68          poll_server=True)
     69     # The server is started asynchronously. It needs to be polled till its state
     70     # has become started.
     71 
     72     with self.assertRaisesRegexp(
     73         ValueError, "Server has already started running"):
     74       server.run_server()
     75 
     76     server.stop_server().wait()
     77     server_thread.join()
     78 
     79   def testRepeatedStopServerRaisesException(self):
     80     (_, _, _, server_thread,
     81      server) = grpc_debug_test_server.start_server_on_separate_thread(
     82          poll_server=True)
     83     server.stop_server().wait()
     84     server_thread.join()
     85 
     86     with self.assertRaisesRegexp(ValueError, "Server has already stopped"):
     87       server.stop_server().wait()
     88 
     89   def testRunServerAfterStopRaisesException(self):
     90     (_, _, _, server_thread,
     91      server) = grpc_debug_test_server.start_server_on_separate_thread(
     92          poll_server=True)
     93     server.stop_server().wait()
     94     server_thread.join()
     95 
     96     with self.assertRaisesRegexp(ValueError, "Server has already stopped"):
     97       server.run_server()
     98 
     99   def testStartServerWithoutBlocking(self):
    100     (_, _, _, server_thread,
    101      server) = grpc_debug_test_server.start_server_on_separate_thread(
    102          poll_server=True, blocking=False)
    103     # The thread that starts the server shouldn't block, so we should be able to
    104     # join it before stopping the server.
    105     server_thread.join()
    106     server.stop_server().wait()
    107 
    108 
    109 class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
    110 
    111   @classmethod
    112   def setUpClass(cls):
    113     session_debug_testlib.SessionDebugTestBase.setUpClass()
    114     (cls._server_port, cls._debug_server_url, cls._server_dump_dir,
    115      cls._server_thread,
    116      cls._server) = grpc_debug_test_server.start_server_on_separate_thread()
    117 
    118   @classmethod
    119   def tearDownClass(cls):
    120     # Stop the test server and join the thread.
    121     cls._server.stop_server().wait()
    122     cls._server_thread.join()
    123 
    124     session_debug_testlib.SessionDebugTestBase.tearDownClass()
    125 
    126   def setUp(self):
    127     # Override the dump root as the test server's dump directory.
    128     self._dump_root = self._server_dump_dir
    129 
    130   def tearDown(self):
    131     if os.path.isdir(self._server_dump_dir):
    132       shutil.rmtree(self._server_dump_dir)
    133     session_debug_testlib.SessionDebugTestBase.tearDown(self)
    134 
    135   def _debug_urls(self, run_number=None):
    136     return ["grpc://localhost:%d" % self._server_port]
    137 
    138   def _debug_dump_dir(self, run_number=None):
    139     if run_number is None:
    140       return self._dump_root
    141     else:
    142       return os.path.join(self._dump_root, "run_%d" % run_number)
    143 
    144   def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException(self):
    145     sess = session.Session(config=no_rewrite_session_config())
    146     with self.assertRaisesRegexp(
    147         TypeError, "Expected type str or list in grpc_debug_server_addresses"):
    148       grpc_wrapper.GrpcDebugWrapperSession(sess, 1337)
    149 
    150   def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException2(self):
    151     sess = session.Session(config=no_rewrite_session_config())
    152     with self.assertRaisesRegexp(
    153         TypeError, "Expected type str in list grpc_debug_server_addresses"):
    154       grpc_wrapper.GrpcDebugWrapperSession(sess, ["localhost:1337", 1338])
    155 
    156   def testUseInvalidWatchFnTypeWithGrpcDebugWrapperSessionRaisesException(self):
    157     sess = session.Session(config=no_rewrite_session_config())
    158     with self.assertRaises(TypeError):
    159       grpc_wrapper.GrpcDebugWrapperSession(
    160           sess, "localhost:%d" % self._server_port, watch_fn="foo")
    161 
    162   def testGrpcDebugWrapperSessionWithoutWatchFnWorks(self):
    163     u = variables.Variable(2.1, name="u")
    164     v = variables.Variable(20.0, name="v")
    165     w = math_ops.multiply(u, v, name="w")
    166 
    167     sess = session.Session(config=no_rewrite_session_config())
    168     sess.run(u.initializer)
    169     sess.run(v.initializer)
    170 
    171     sess = grpc_wrapper.GrpcDebugWrapperSession(
    172         sess, "localhost:%d" % self._server_port)
    173     w_result = sess.run(w)
    174     self.assertAllClose(42.0, w_result)
    175 
    176     dump = debug_data.DebugDumpDir(self._dump_root)
    177     self.assertEqual(5, dump.size)
    178     self.assertAllClose([2.1], dump.get_tensors("u", 0, "DebugIdentity"))
    179     self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
    180     self.assertAllClose([20.0], dump.get_tensors("v", 0, "DebugIdentity"))
    181     self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
    182     self.assertAllClose([42.0], dump.get_tensors("w", 0, "DebugIdentity"))
    183 
    184   def testGrpcDebugWrapperSessionWithWatchFnWorks(self):
    185     def watch_fn(feeds, fetch_keys):
    186       del feeds, fetch_keys
    187       return ["DebugIdentity", "DebugNumericSummary"], r".*/read", None
    188 
    189     u = variables.Variable(2.1, name="u")
    190     v = variables.Variable(20.0, name="v")
    191     w = math_ops.multiply(u, v, name="w")
    192 
    193     sess = session.Session(config=no_rewrite_session_config())
    194     sess.run(u.initializer)
    195     sess.run(v.initializer)
    196 
    197     sess = grpc_wrapper.GrpcDebugWrapperSession(
    198         sess, "localhost:%d" % self._server_port, watch_fn=watch_fn)
    199     w_result = sess.run(w)
    200     self.assertAllClose(42.0, w_result)
    201 
    202     dump = debug_data.DebugDumpDir(self._dump_root)
    203     self.assertEqual(4, dump.size)
    204     self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
    205     self.assertEqual(
    206         14, len(dump.get_tensors("u/read", 0, "DebugNumericSummary")[0]))
    207     self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
    208     self.assertEqual(
    209         14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
    210 
    211   def testGrpcDebugHookWithStatelessWatchFnWorks(self):
    212     # Perform some set up. Specifically, construct a simple TensorFlow graph and
    213     # create a watch function for certain ops.
    214     def watch_fn(feeds, fetch_keys):
    215       del feeds, fetch_keys
    216       return framework.WatchOptions(
    217           debug_ops=["DebugIdentity", "DebugNumericSummary"],
    218           node_name_regex_whitelist=r".*/read",
    219           op_type_regex_whitelist=None,
    220           tolerate_debug_op_creation_failures=True)
    221 
    222     u = variables.Variable(2.1, name="u")
    223     v = variables.Variable(20.0, name="v")
    224     w = math_ops.multiply(u, v, name="w")
    225 
    226     sess = session.Session(config=no_rewrite_session_config())
    227     sess.run(u.initializer)
    228     sess.run(v.initializer)
    229 
    230     # Create a hook. One could use this hook with say a tflearn Estimator.
    231     # However, we use a HookedSession in this test to avoid depending on the
    232     # internal implementation of Estimators.
    233     grpc_debug_hook = hooks.GrpcDebugHook(
    234         ["localhost:%d" % self._server_port], watch_fn=watch_fn)
    235     sess = monitored_session._HookedSession(sess, [grpc_debug_hook])
    236 
    237     # Run the hooked session. This should stream tensor data to the GRPC
    238     # endpoints.
    239     w_result = sess.run(w)
    240 
    241     # Verify that the hook monitored the correct tensors.
    242     self.assertAllClose(42.0, w_result)
    243     dump = debug_data.DebugDumpDir(self._dump_root)
    244     self.assertEqual(4, dump.size)
    245     self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
    246     self.assertEqual(
    247         14, len(dump.get_tensors("u/read", 0, "DebugNumericSummary")[0]))
    248     self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
    249     self.assertEqual(
    250         14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
    251 
    252   def testTensorBoardDebugHookWorks(self):
    253     u = variables.Variable(2.1, name="u")
    254     v = variables.Variable(20.0, name="v")
    255     w = math_ops.multiply(u, v, name="w")
    256 
    257     sess = session.Session(config=no_rewrite_session_config())
    258     sess.run(u.initializer)
    259     sess.run(v.initializer)
    260 
    261     grpc_debug_hook = hooks.TensorBoardDebugHook(
    262         ["localhost:%d" % self._server_port])
    263     sess = monitored_session._HookedSession(sess, [grpc_debug_hook])
    264 
    265     # Activate watch point on a tensor before calling sess.run().
    266     self._server.request_watch("u/read", 0, "DebugIdentity")
    267     self.assertAllClose(42.0, sess.run(w))
    268 
    269     # self.assertAllClose(42.0, sess.run(w))
    270     dump = debug_data.DebugDumpDir(self._dump_root)
    271     self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
    272 
    273     # Check that the server has received the stack trace.
    274     self.assertTrue(self._server.query_op_traceback("u"))
    275     self.assertTrue(self._server.query_op_traceback("u/read"))
    276     self.assertTrue(self._server.query_op_traceback("v"))
    277     self.assertTrue(self._server.query_op_traceback("v/read"))
    278     self.assertTrue(self._server.query_op_traceback("w"))
    279 
    280     # Check that the server has received the python file content.
    281     # Query an arbitrary line to make sure that is the case.
    282     with open(__file__, "rt") as this_source_file:
    283       first_line = this_source_file.readline().strip()
    284       self.assertEqual(
    285           first_line, self._server.query_source_file_line(__file__, 1))
    286 
    287     self._server.clear_data()
    288     # Call sess.run() again, and verify that this time the traceback and source
    289     # code is not sent, because the graph version is not newer.
    290     self.assertAllClose(42.0, sess.run(w))
    291     with self.assertRaises(ValueError):
    292       self._server.query_op_traceback("delta_1")
    293     with self.assertRaises(ValueError):
    294       self._server.query_source_file_line(__file__, 1)
    295 
    296   def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self):
    297     u = variables.Variable(2.1, name="u")
    298     v = variables.Variable(20.0, name="v")
    299     w = math_ops.multiply(u, v, name="w")
    300 
    301     sess = session.Session(config=no_rewrite_session_config())
    302     sess.run(variables.global_variables_initializer())
    303 
    304     grpc_debug_hook = hooks.TensorBoardDebugHook(
    305         ["localhost:%d" % self._server_port],
    306         send_traceback_and_source_code=False)
    307     sess = monitored_session._HookedSession(sess, [grpc_debug_hook])
    308 
    309     # Activate watch point on a tensor before calling sess.run().
    310     self._server.request_watch("u/read", 0, "DebugIdentity")
    311     self.assertAllClose(42.0, sess.run(w))
    312 
    313     # Check that the server has _not_ received any tracebacks, as a result of
    314     # the disabling above.
    315     with self.assertRaisesRegexp(
    316         ValueError, r"Op .*u/read.* does not exist"):
    317       self.assertTrue(self._server.query_op_traceback("u/read"))
    318     with self.assertRaisesRegexp(
    319         ValueError, r".* has not received any source file"):
    320       self._server.query_source_file_line(__file__, 1)
    321 
    322   def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self):
    323     hooks.GrpcDebugHook(["grpc://foo:42424"])
    324     hooks.GrpcDebugHook(["foo:42424"])
    325 
    326 
    327 class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
    328 
    329   @classmethod
    330   def setUpClass(cls):
    331     (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread,
    332      cls.debug_server
    333     ) = grpc_debug_test_server.start_server_on_separate_thread(
    334         dump_to_filesystem=False)
    335     tf_logging.info("debug server url: %s", cls.debug_server_url)
    336 
    337   @classmethod
    338   def tearDownClass(cls):
    339     cls.debug_server.stop_server().wait()
    340     cls.debug_server_thread.join()
    341 
    342   def tearDown(self):
    343     ops.reset_default_graph()
    344     self.debug_server.clear_data()
    345 
    346   def testSendingLargeGraphDefsWorks(self):
    347     with self.test_session(
    348         use_gpu=True, config=no_rewrite_session_config()) as sess:
    349       u = variables.Variable(42.0, name="original_u")
    350       for _ in xrange(50 * 1000):
    351         u = array_ops.identity(u)
    352       sess.run(variables.global_variables_initializer())
    353 
    354       def watch_fn(fetches, feeds):
    355         del fetches, feeds
    356         return framework.WatchOptions(
    357             debug_ops=["DebugIdentity"],
    358             node_name_regex_whitelist=r"original_u")
    359       sess = grpc_wrapper.GrpcDebugWrapperSession(
    360           sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
    361       self.assertAllClose(42.0, sess.run(u))
    362 
    363       self.assertAllClose(
    364           [42.0],
    365           self.debug_server.debug_tensor_values["original_u:0:DebugIdentity"])
    366       self.assertEqual(2 if test.is_gpu_available() else 1,
    367                        len(self.debug_server.partition_graph_defs))
    368       max_graph_def_size = max([
    369           len(graph_def.SerializeToString())
    370           for graph_def in self.debug_server.partition_graph_defs])
    371       self.assertGreater(max_graph_def_size, 4 * 1024 * 1024)
    372 
    373   def testSendingLargeFloatTensorWorks(self):
    374     with self.test_session(
    375         use_gpu=True, config=no_rewrite_session_config()) as sess:
    376       u_init_val_array = list(xrange(1200 * 1024))
    377       # Size: 4 * 1200 * 1024 = 4800k > 4M
    378 
    379       u_init = constant_op.constant(
    380           u_init_val_array, dtype=dtypes.float32, name="u_init")
    381       u = variables.Variable(u_init, name="u")
    382 
    383       def watch_fn(fetches, feeds):
    384         del fetches, feeds  # Unused by this watch_fn.
    385         return framework.WatchOptions(
    386             debug_ops=["DebugIdentity"],
    387             node_name_regex_whitelist=r"u_init")
    388       sess = grpc_wrapper.GrpcDebugWrapperSession(
    389           sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
    390       sess.run(u.initializer)
    391 
    392       self.assertAllEqual(
    393           u_init_val_array,
    394           self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
    395 
    396   def testSendingStringTensorWithAlmostTooLargeStringsWorks(self):
    397     with self.test_session(
    398         use_gpu=True, config=no_rewrite_session_config()) as sess:
    399       u_init_val = [
    400           b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""]
    401       u_init = constant_op.constant(
    402           u_init_val, dtype=dtypes.string, name="u_init")
    403       u = variables.Variable(u_init, name="u")
    404 
    405       def watch_fn(fetches, feeds):
    406         del fetches, feeds
    407         return framework.WatchOptions(
    408             debug_ops=["DebugIdentity"],
    409             node_name_regex_whitelist=r"u_init")
    410       sess = grpc_wrapper.GrpcDebugWrapperSession(
    411           sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
    412       sess.run(u.initializer)
    413 
    414       self.assertAllEqual(
    415           u_init_val,
    416           self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
    417 
    418   def testSendingLargeStringTensorWorks(self):
    419     with self.test_session(
    420         use_gpu=True, config=no_rewrite_session_config()) as sess:
    421       strs_total_size_threshold = 5000 * 1024
    422       cum_size = 0
    423       u_init_val_array = []
    424       while cum_size < strs_total_size_threshold:
    425         strlen = np.random.randint(200)
    426         u_init_val_array.append(b"A" * strlen)
    427         cum_size += strlen
    428 
    429       u_init = constant_op.constant(
    430           u_init_val_array, dtype=dtypes.string, name="u_init")
    431       u = variables.Variable(u_init, name="u")
    432 
    433       def watch_fn(fetches, feeds):
    434         del fetches, feeds
    435         return framework.WatchOptions(
    436             debug_ops=["DebugIdentity"],
    437             node_name_regex_whitelist=r"u_init")
    438       sess = grpc_wrapper.GrpcDebugWrapperSession(
    439           sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
    440       sess.run(u.initializer)
    441 
    442       self.assertAllEqual(
    443           u_init_val_array,
    444           self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
    445 
    446   def testSendingEmptyFloatTensorWorks(self):
    447     with self.test_session(
    448         use_gpu=True, config=no_rewrite_session_config()) as sess:
    449       u_init = constant_op.constant(
    450           [], dtype=dtypes.float32, shape=[0], name="u_init")
    451       u = variables.Variable(u_init, name="u")
    452 
    453       def watch_fn(fetches, feeds):
    454         del fetches, feeds
    455         return framework.WatchOptions(
    456             debug_ops=["DebugIdentity"],
    457             node_name_regex_whitelist=r"u_init")
    458       sess = grpc_wrapper.GrpcDebugWrapperSession(
    459           sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
    460       sess.run(u.initializer)
    461 
    462       u_init_value = self.debug_server.debug_tensor_values[
    463           "u_init:0:DebugIdentity"][0]
    464       self.assertEqual(np.float32, u_init_value.dtype)
    465       self.assertEqual(0, len(u_init_value))
    466 
    467   def testSendingEmptyStringTensorWorks(self):
    468     with self.test_session(
    469         use_gpu=True, config=no_rewrite_session_config()) as sess:
    470       u_init = constant_op.constant(
    471           [], dtype=dtypes.string, shape=[0], name="u_init")
    472       u = variables.Variable(u_init, name="u")
    473 
    474       def watch_fn(fetches, feeds):
    475         del fetches, feeds
    476         return framework.WatchOptions(
    477             debug_ops=["DebugIdentity"],
    478             node_name_regex_whitelist=r"u_init")
    479       sess = grpc_wrapper.GrpcDebugWrapperSession(
    480           sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
    481       sess.run(u.initializer)
    482 
    483       u_init_value = self.debug_server.debug_tensor_values[
    484           "u_init:0:DebugIdentity"][0]
    485       self.assertEqual(np.object, u_init_value.dtype)
    486       self.assertEqual(0, len(u_init_value))
    487 
    488 
    489 class SessionDebugConcurrentTest(
    490     session_debug_testlib.DebugConcurrentRunCallsTest):
    491 
    492   @classmethod
    493   def setUpClass(cls):
    494     session_debug_testlib.SessionDebugTestBase.setUpClass()
    495     (cls._server_port, cls._debug_server_url, cls._server_dump_dir,
    496      cls._server_thread,
    497      cls._server) = grpc_debug_test_server.start_server_on_separate_thread()
    498 
    499   @classmethod
    500   def tearDownClass(cls):
    501     # Stop the test server and join the thread.
    502     cls._server.stop_server().wait()
    503     cls._server_thread.join()
    504     session_debug_testlib.SessionDebugTestBase.tearDownClass()
    505 
    506   def setUp(self):
    507     self._num_concurrent_runs = 3
    508     self._dump_roots = []
    509     for i in range(self._num_concurrent_runs):
    510       self._dump_roots.append(
    511           os.path.join(self._server_dump_dir, "thread%d" % i))
    512 
    513   def tearDown(self):
    514     ops.reset_default_graph()
    515     if os.path.isdir(self._server_dump_dir):
    516       shutil.rmtree(self._server_dump_dir)
    517 
    518   def _get_concurrent_debug_urls(self):
    519     urls = []
    520     for i in range(self._num_concurrent_runs):
    521       urls.append(self._debug_server_url + "/thread%d" % i)
    522     return urls
    523 
    524 
    525 class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
    526   """Test server gating of debug ops."""
    527 
    528   @classmethod
    529   def setUpClass(cls):
    530     (cls._server_port_1, cls._debug_server_url_1, _, cls._server_thread_1,
    531      cls._server_1) = grpc_debug_test_server.start_server_on_separate_thread(
    532          dump_to_filesystem=False)
    533     (cls._server_port_2, cls._debug_server_url_2, _, cls._server_thread_2,
    534      cls._server_2) = grpc_debug_test_server.start_server_on_separate_thread(
    535          dump_to_filesystem=False)
    536     cls._servers_and_threads = [(cls._server_1, cls._server_thread_1),
    537                                 (cls._server_2, cls._server_thread_2)]
    538 
    539   @classmethod
    540   def tearDownClass(cls):
    541     for server, thread in cls._servers_and_threads:
    542       server.stop_server().wait()
    543       thread.join()
    544 
    545   def tearDown(self):
    546     ops.reset_default_graph()
    547     self._server_1.clear_data()
    548     self._server_2.clear_data()
    549 
    550   def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self):
    551     with session.Session(config=no_rewrite_session_config()) as sess:
    552       v_1 = variables.Variable(50.0, name="v_1")
    553       v_2 = variables.Variable(-50.0, name="v_1")
    554       delta_1 = constant_op.constant(5.0, name="delta_1")
    555       delta_2 = constant_op.constant(-5.0, name="delta_2")
    556       inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
    557       inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")
    558 
    559       sess.run([v_1.initializer, v_2.initializer])
    560 
    561       run_metadata = config_pb2.RunMetadata()
    562       run_options = config_pb2.RunOptions(output_partition_graphs=True)
    563       debug_utils.watch_graph(
    564           run_options,
    565           sess.graph,
    566           debug_ops=["DebugIdentity(gated_grpc=true)",
    567                      "DebugNumericSummary(gated_grpc=true)"],
    568           debug_urls=[self._debug_server_url_1])
    569 
    570       for i in xrange(4):
    571         self._server_1.clear_data()
    572 
    573         if i % 2 == 0:
    574           self._server_1.request_watch("delta_1", 0, "DebugIdentity")
    575           self._server_1.request_watch("delta_2", 0, "DebugIdentity")
    576           self._server_1.request_unwatch("delta_1", 0, "DebugNumericSummary")
    577           self._server_1.request_unwatch("delta_2", 0, "DebugNumericSummary")
    578         else:
    579           self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
    580           self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")
    581           self._server_1.request_watch("delta_1", 0, "DebugNumericSummary")
    582           self._server_1.request_watch("delta_2", 0, "DebugNumericSummary")
    583 
    584         sess.run([inc_v_1, inc_v_2],
    585                  options=run_options, run_metadata=run_metadata)
    586 
    587         # Watched debug tensors are:
    588         #   Run 0: delta_[1,2]:0:DebugIdentity
    589         #   Run 1: delta_[1,2]:0:DebugNumericSummary
    590         #   Run 2: delta_[1,2]:0:DebugIdentity
    591         #   Run 3: delta_[1,2]:0:DebugNumericSummary
    592         self.assertEqual(2, len(self._server_1.debug_tensor_values))
    593         if i % 2 == 0:
    594           self.assertAllClose(
    595               [5.0],
    596               self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
    597           self.assertAllClose(
    598               [-5.0],
    599               self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
    600         else:
    601           self.assertAllClose(
    602               [[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 5.0, 5.0, 5.0,
    603                 0.0, 1.0, 0.0]],
    604               self._server_1.debug_tensor_values[
    605                   "delta_1:0:DebugNumericSummary"])
    606           self.assertAllClose(
    607               [[1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -5.0, -5.0, -5.0,
    608                 0.0, 1.0, 0.0]],
    609               self._server_1.debug_tensor_values[
    610                   "delta_2:0:DebugNumericSummary"])
    611 
    612   def testToggleWatchesOnCoreMetadata(self):
    613     (_, debug_server_url, _, server_thread,
    614      server) = grpc_debug_test_server.start_server_on_separate_thread(
    615          dump_to_filesystem=False,
    616          toggle_watch_on_core_metadata=[("toggled_1", 0, "DebugIdentity"),
    617                                         ("toggled_2", 0, "DebugIdentity")])
    618     self._servers_and_threads.append((server, server_thread))
    619 
    620     with session.Session(config=no_rewrite_session_config()) as sess:
    621       v_1 = variables.Variable(50.0, name="v_1")
    622       v_2 = variables.Variable(-50.0, name="v_1")
    623       # These two nodes have names that match those in the
    624       # toggle_watch_on_core_metadata argument used when calling
    625       # start_server_on_separate_thread().
    626       toggled_1 = constant_op.constant(5.0, name="toggled_1")
    627       toggled_2 = constant_op.constant(-5.0, name="toggled_2")
    628       inc_v_1 = state_ops.assign_add(v_1, toggled_1, name="inc_v_1")
    629       inc_v_2 = state_ops.assign_add(v_2, toggled_2, name="inc_v_2")
    630 
    631       sess.run([v_1.initializer, v_2.initializer])
    632 
    633       run_metadata = config_pb2.RunMetadata()
    634       run_options = config_pb2.RunOptions(output_partition_graphs=True)
    635       debug_utils.watch_graph(
    636           run_options,
    637           sess.graph,
    638           debug_ops=["DebugIdentity(gated_grpc=true)"],
    639           debug_urls=[debug_server_url])
    640 
    641       for i in xrange(4):
    642         server.clear_data()
    643 
    644         sess.run([inc_v_1, inc_v_2],
    645                  options=run_options, run_metadata=run_metadata)
    646 
    647         if i % 2 == 0:
    648           self.assertEqual(2, len(server.debug_tensor_values))
    649           self.assertAllClose(
    650               [5.0],
    651               server.debug_tensor_values["toggled_1:0:DebugIdentity"])
    652           self.assertAllClose(
    653               [-5.0],
    654               server.debug_tensor_values["toggled_2:0:DebugIdentity"])
    655         else:
    656           self.assertEqual(0, len(server.debug_tensor_values))
    657 
    def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self):
      """Checks that watches toggled on two debug servers stay independent.

      Server 1 toggles a watch on "delta" while server 2 toggles a watch on
      "v". On the runs where the watches are requested, each server must hold
      exactly one tensor value (the one it asked for); on the remaining runs
      both servers must hold none.
      """
      with session.Session(config=no_rewrite_session_config()) as sess:
        var = variables.Variable(50.0, name="v")
        step = constant_op.constant(5.0, name="delta")
        increment = state_ops.assign_add(var, step, name="inc_v")

        sess.run(var.initializer)

        metadata = config_pb2.RunMetadata()
        options = config_pb2.RunOptions(output_partition_graphs=True)
        # Both servers are listed as debug URLs for the same gated debug op.
        debug_utils.watch_graph(
            options,
            sess.graph,
            debug_ops=["DebugIdentity(gated_grpc=true)"],
            debug_urls=[self._debug_server_url_1, self._debug_server_url_2])

        server_1 = self._server_1
        server_2 = self._server_2

        for run_index in xrange(4):
          # Watches are requested on even runs and withdrawn on odd runs.
          watching = run_index % 2 == 0

          server_1.clear_data()
          server_2.clear_data()

          if watching:
            toggle_1 = server_1.request_watch
            toggle_2 = server_2.request_watch
          else:
            toggle_1 = server_1.request_unwatch
            toggle_2 = server_2.request_unwatch
          toggle_1("delta", 0, "DebugIdentity")
          toggle_2("v", 0, "DebugIdentity")

          sess.run(increment, options=options, run_metadata=metadata)

          if not watching:
            self.assertEqual(0, len(server_1.debug_tensor_values))
            self.assertEqual(0, len(server_2.debug_tensor_values))
            continue

          # Each server holds only the single tensor it asked for.
          self.assertEqual(1, len(server_1.debug_tensor_values))
          self.assertEqual(1, len(server_2.debug_tensor_values))
          self.assertAllClose(
              [5.0], server_1.debug_tensor_values["delta:0:DebugIdentity"])
          self.assertAllClose(
              [50 + 5.0 * run_index],
              server_2.debug_tensor_values["v:0:DebugIdentity"])
    699 
    def testToggleBreakpointsWorks(self):
      """Checks that breakpoints can be switched on and off between runs.

      Runs 0 and 2 request breakpoints on both delta constants; runs 1 and 3
      withdraw them. Every run must return the expected variable values, and
      the server's tensor values and breakpoint set must match the toggling.
      """
      breakpoint_ops = ("delta_1", "delta_2")

      with session.Session(config=no_rewrite_session_config()) as sess:
        pos_var = variables.Variable(50.0, name="v_1")
        neg_var = variables.Variable(-50.0, name="v_2")
        pos_step = constant_op.constant(5.0, name="delta_1")
        neg_step = constant_op.constant(-5.0, name="delta_2")
        inc_pos = state_ops.assign_add(pos_var, pos_step, name="inc_v_1")
        inc_neg = state_ops.assign_add(neg_var, neg_step, name="inc_v_2")

        sess.run([pos_var.initializer, neg_var.initializer])

        metadata = config_pb2.RunMetadata()
        options = config_pb2.RunOptions(output_partition_graphs=True)
        debug_utils.watch_graph(
            options,
            sess.graph,
            debug_ops=["DebugIdentity(gated_grpc=true)"],
            debug_urls=[self._debug_server_url_1])

        for run_index in xrange(4):
          # Breakpoints are enabled on runs 0 and 2, disabled on runs 1 and 3.
          breakpoints_on = run_index % 2 == 0

          self._server_1.clear_data()

          for op_name in breakpoint_ops:
            if breakpoints_on:
              self._server_1.request_watch(
                  op_name, 0, "DebugIdentity", breakpoint=True)
            else:
              self._server_1.request_unwatch(op_name, 0, "DebugIdentity")

          results = sess.run(
              [inc_pos, inc_neg], options=options, run_metadata=metadata)
          steps_done = run_index + 1
          self.assertAllClose(
              [50.0 + 5.0 * steps_done, -50 - 5.0 * steps_done], results)

          if not breakpoints_on:
            # The unwatch requests should have left no breakpoint registered.
            self.assertSetEqual(set(), self._server_1.breakpoints)
            continue

          # The server should have received both published debug tensors; the
          # breakpoints are expected to have been unblocked by EventReply
          # responses from the server, since the run returned.
          received = self._server_1.debug_tensor_values
          self.assertAllClose([5.0], received["delta_1:0:DebugIdentity"])
          self.assertAllClose([-5.0], received["delta_2:0:DebugIdentity"])
          # The request_watch calls should have registered both breakpoints.
          self.assertSetEqual(
              {(op_name, 0, "DebugIdentity") for op_name in breakpoint_ops},
              self._server_1.breakpoints)
    756 
    def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
      """Checks breakpoint toggling through TensorBoardDebugWrapperSession.

      Besides the tensor values and breakpoint bookkeeping, this also checks
      that op tracebacks and source file content are available on the server
      after the first run only (the server data is cleared before every run).
      """
      breakpoint_ops = ("delta_1", "delta_2")

      with session.Session(config=no_rewrite_session_config()) as sess:
        pos_var = variables.Variable(50.0, name="v_1")
        neg_var = variables.Variable(-50.0, name="v_2")
        pos_step = constant_op.constant(5.0, name="delta_1")
        neg_step = constant_op.constant(-5.0, name="delta_2")
        inc_pos = state_ops.assign_add(pos_var, pos_step, name="inc_v_1")
        inc_neg = state_ops.assign_add(neg_var, neg_step, name="inc_v_2")

        sess.run([pos_var.initializer, neg_var.initializer])

        # The TensorBoardDebugWrapperSession should add a DebugIdentity debug
        # op with attribute gated_grpc=True for every tensor in the graph.
        wrapped_sess = grpc_wrapper.TensorBoardDebugWrapperSession(
            sess, self._debug_server_url_1)

        for run_index in xrange(4):
          # Breakpoints are enabled on runs 0 and 2, disabled on runs 1 and 3.
          breakpoints_on = run_index % 2 == 0

          self._server_1.clear_data()

          for op_name in breakpoint_ops:
            if breakpoints_on:
              self._server_1.request_watch(
                  op_name, 0, "DebugIdentity", breakpoint=True)
            else:
              self._server_1.request_unwatch(op_name, 0, "DebugIdentity")

          results = wrapped_sess.run([inc_pos, inc_neg])
          steps_done = run_index + 1
          self.assertAllClose(
              [50.0 + 5.0 * steps_done, -50 - 5.0 * steps_done], results)

          if breakpoints_on:
            # The server should have received both published debug tensors;
            # the breakpoints are expected to have been unblocked by
            # EventReply responses from the server, since the run returned.
            received = self._server_1.debug_tensor_values
            self.assertAllClose([5.0], received["delta_1:0:DebugIdentity"])
            self.assertAllClose([-5.0], received["delta_2:0:DebugIdentity"])
            # NOTE(review): the set of registered breakpoints is not asserted
            # on these runs, unlike in testToggleBreakpointsWorks -- confirm
            # that this omission is intentional.
          else:
            # The unwatch requests should have left no breakpoint registered.
            self.assertSetEqual(set(), self._server_1.breakpoints)

          if run_index == 0:
            # The first run should have delivered the op stack traces.
            for op_name in ("delta_1", "delta_2", "inc_v_1", "inc_v_2"):
              self.assertTrue(self._server_1.query_op_traceback(op_name))
            # It should also have delivered the Python source; compare the
            # first line of this file as an arbitrary sample.
            with open(__file__, "rt") as src:
              expected_first_line = src.readline().strip()
            self.assertEqual(
                expected_first_line,
                self._server_1.query_source_file_line(__file__, 1))
          else:
            # Tracebacks and sources are sent with the first run only, and the
            # test server's data is cleared at the top of every iteration, so
            # later queries are expected to raise.
            with self.assertRaises(ValueError):
              self._server_1.query_op_traceback("delta_1")
            with self.assertRaises(ValueError):
              self._server_1.query_source_file_line(__file__, 1)
    828 
    def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self):
      """Checks that send_traceback_and_source_code=False suppresses sending.

      With the option disabled on the wrapper session, every query for an op
      traceback or a source file line on the debug server must raise
      ValueError, on the first run as well as on later ones.
      """
      with session.Session(config=no_rewrite_session_config()) as sess:
        v_1 = variables.Variable(50.0, name="v_1")
        v_2 = variables.Variable(-50.0, name="v_2")
        delta_1 = constant_op.constant(5.0, name="delta_1")
        delta_2 = constant_op.constant(-5.0, name="delta_2")
        inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
        inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

        sess.run(variables.global_variables_initializer())

        # Disable the sending of traceback and source code.
        sess = grpc_wrapper.TensorBoardDebugWrapperSession(
            sess, self._debug_server_url_1,
            send_traceback_and_source_code=False)

        for i in xrange(4):
          self._server_1.clear_data()

          if i == 0:
            # The breakpoint is requested once, before the first run.
            self._server_1.request_watch(
                "delta_1", 0, "DebugIdentity", breakpoint=True)

          output = sess.run([inc_v_1, inc_v_2])
          self.assertAllClose(
              [50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

          # No op traceback or source code should have been received by the
          # debug server due to the disabling above. The queries are called
          # directly: they are expected to raise, so wrapping them in another
          # assertion would never be evaluated on the passing path.
          with self.assertRaisesRegexp(
              ValueError, r"Op .*delta_1.* does not exist"):
            self._server_1.query_op_traceback("delta_1")
          with self.assertRaisesRegexp(
              ValueError, r".* has not received any source file"):
            self._server_1.query_source_file_line(__file__, 1)
    862 
    def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
      """Checks which watches gated_grpc_debug_watches() reports.

      Two watches are added: one with gated_grpc=true on "delta" and one
      without the attribute on "v". After a debugged run, only the former is
      expected to be reported by the server.
      """
      with session.Session() as sess:
        var = variables.Variable(50.0, name="v")
        step = constant_op.constant(5.0, name="delta")
        increment = state_ops.assign_add(var, step, name="inc_v")

        sess.run(var.initializer)

        # No debugged run has happened yet, so no watch should be known.
        self.assertEqual([], self._server_1.gated_grpc_debug_watches())

        metadata = config_pb2.RunMetadata()
        options = config_pb2.RunOptions(output_partition_graphs=True)
        server_urls = [self._debug_server_url_1]
        # Gated watch on delta:0.
        debug_utils.add_debug_tensor_watch(
            options,
            "delta",
            output_slot=0,
            debug_ops=["DebugNumericSummary(gated_grpc=true)"],
            debug_urls=server_urls)
        # Watch on v:0 without the gated_grpc attribute.
        debug_utils.add_debug_tensor_watch(
            options,
            "v",
            output_slot=0,
            debug_ops=["DebugIdentity"],
            debug_urls=server_urls)
        sess.run(increment, options=options, run_metadata=metadata)

        # Only the watch declared with gated_grpc=true should be reported.
        self.assertEqual(1, len(self._server_1.gated_grpc_debug_watches()))
        reported = self._server_1.gated_grpc_debug_watches()[0]
        self.assertEqual("delta", reported.node_name)
        self.assertEqual(0, reported.output_slot)
        self.assertEqual("DebugNumericSummary", reported.debug_op)
    894 
    895 
    class DelayedDebugServerTest(test_util.TensorFlowTestCase):
      """Tests for a debug server that starts after the client session."""

      def testDebuggedSessionRunWorksWithDelayedDebugServerStartup(self):
        """Test debugged Session.run() tolerates delayed debug server startup."""
        ops.reset_default_graph()

        # Launch the debug server on its own thread, asking it to wait 2.0
        # seconds before starting. Only the port, the thread and the server
        # object of the returned 5-tuple are used below.
        (port, _, _, delayed_thread,
         delayed_server) = grpc_debug_test_server.start_server_on_separate_thread(
             server_start_delay_sec=2.0, dump_to_filesystem=False)

        with self.test_session() as sess:
          initial_value = constant_op.constant(42.0, name="a_init")
          var = variables.Variable(initial_value, name="a")

          def identity_watch_fn(fetches, feeds):
            # The arguments are deliberately ignored: the same watch options
            # are returned for every run.
            del fetches, feeds
            return framework.WatchOptions(debug_ops=["DebugIdentity"])

          debugged_sess = grpc_wrapper.GrpcDebugWrapperSession(
              sess, "localhost:%d" % port, watch_fn=identity_watch_fn)
          debugged_sess.run(var.initializer)

          # The initial value must have reached the server despite the delay.
          received = delayed_server.debug_tensor_values
          self.assertAllClose([42.0], received["a_init:0:DebugIdentity"])

        # Shut the server down and wait for its thread to finish.
        delayed_server.stop_server().wait()
        delayed_thread.join()
    923 
    924 
    # Script entry point: run this module's test cases via googletest.
    if __name__ == "__main__":
      googletest.main()
    927