1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for debugger functionalities in tf.Session with grpc:// URLs. 16 17 This test file focuses on the grpc:// debugging of local (non-distributed) 18 tf.Sessions. 19 """ 20 from __future__ import absolute_import 21 from __future__ import division 22 from __future__ import print_function 23 24 import os 25 import shutil 26 27 import numpy as np 28 from six.moves import xrange # pylint: disable=redefined-builtin 29 30 from tensorflow.core.protobuf import config_pb2 31 from tensorflow.core.protobuf import rewriter_config_pb2 32 from tensorflow.python.client import session 33 from tensorflow.python.debug.lib import debug_data 34 from tensorflow.python.debug.lib import debug_utils 35 from tensorflow.python.debug.lib import grpc_debug_test_server 36 from tensorflow.python.debug.lib import session_debug_testlib 37 from tensorflow.python.debug.wrappers import framework 38 from tensorflow.python.debug.wrappers import grpc_wrapper 39 from tensorflow.python.debug.wrappers import hooks 40 from tensorflow.python.framework import constant_op 41 from tensorflow.python.framework import dtypes 42 from tensorflow.python.framework import ops 43 from tensorflow.python.framework import test_util 44 from tensorflow.python.ops import array_ops 45 from tensorflow.python.ops import math_ops 46 from tensorflow.python.ops import state_ops 47 from tensorflow.python.ops import variables 48 from tensorflow.python.platform import googletest 49 from tensorflow.python.platform import test 50 from tensorflow.python.platform import tf_logging 51 from tensorflow.python.training import monitored_session 52 53 54 def no_rewrite_session_config(): 55 rewriter_config = rewriter_config_pb2.RewriterConfig( 56 disable_model_pruning=True, 57 arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, 58 dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF) 59 graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config) 60 return config_pb2.ConfigProto(graph_options=graph_options) 61 62 63 class GrpcDebugServerTest(test_util.TensorFlowTestCase): 64 65 def testRepeatedRunServerRaisesException(self): 66 (_, _, _, server_thread, 67 server) = grpc_debug_test_server.start_server_on_separate_thread( 68 poll_server=True) 69 # The server is started asynchronously. It needs to be polled till its state 70 # has become started. 71 72 with self.assertRaisesRegexp( 73 ValueError, "Server has already started running"): 74 server.run_server() 75 76 server.stop_server().wait() 77 server_thread.join() 78 79 def testRepeatedStopServerRaisesException(self): 80 (_, _, _, server_thread, 81 server) = grpc_debug_test_server.start_server_on_separate_thread( 82 poll_server=True) 83 server.stop_server().wait() 84 server_thread.join() 85 86 with self.assertRaisesRegexp(ValueError, "Server has already stopped"): 87 server.stop_server().wait() 88 89 def testRunServerAfterStopRaisesException(self): 90 (_, _, _, server_thread, 91 server) = grpc_debug_test_server.start_server_on_separate_thread( 92 poll_server=True) 93 server.stop_server().wait() 94 server_thread.join() 95 96 with self.assertRaisesRegexp(ValueError, "Server has already stopped"): 97 server.run_server() 98 99 def testStartServerWithoutBlocking(self): 100 (_, _, _, server_thread, 101 server) = grpc_debug_test_server.start_server_on_separate_thread( 102 poll_server=True, blocking=False) 103 # The thread that starts the server shouldn't block, so we should be able to 104 # join it before stopping the server. 105 server_thread.join() 106 server.stop_server().wait() 107 108 109 class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): 110 111 @classmethod 112 def setUpClass(cls): 113 session_debug_testlib.SessionDebugTestBase.setUpClass() 114 (cls._server_port, cls._debug_server_url, cls._server_dump_dir, 115 cls._server_thread, 116 cls._server) = grpc_debug_test_server.start_server_on_separate_thread() 117 118 @classmethod 119 def tearDownClass(cls): 120 # Stop the test server and join the thread. 121 cls._server.stop_server().wait() 122 cls._server_thread.join() 123 124 session_debug_testlib.SessionDebugTestBase.tearDownClass() 125 126 def setUp(self): 127 # Override the dump root as the test server's dump directory. 128 self._dump_root = self._server_dump_dir 129 130 def tearDown(self): 131 if os.path.isdir(self._server_dump_dir): 132 shutil.rmtree(self._server_dump_dir) 133 session_debug_testlib.SessionDebugTestBase.tearDown(self) 134 135 def _debug_urls(self, run_number=None): 136 return ["grpc://localhost:%d" % self._server_port] 137 138 def _debug_dump_dir(self, run_number=None): 139 if run_number is None: 140 return self._dump_root 141 else: 142 return os.path.join(self._dump_root, "run_%d" % run_number) 143 144 def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException(self): 145 sess = session.Session(config=no_rewrite_session_config()) 146 with self.assertRaisesRegexp( 147 TypeError, "Expected type str or list in grpc_debug_server_addresses"): 148 grpc_wrapper.GrpcDebugWrapperSession(sess, 1337) 149 150 def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException2(self): 151 sess = session.Session(config=no_rewrite_session_config()) 152 with self.assertRaisesRegexp( 153 TypeError, "Expected type str in list grpc_debug_server_addresses"): 154 grpc_wrapper.GrpcDebugWrapperSession(sess, ["localhost:1337", 1338]) 155 156 def testUseInvalidWatchFnTypeWithGrpcDebugWrapperSessionRaisesException(self): 157 sess = session.Session(config=no_rewrite_session_config()) 158 with self.assertRaises(TypeError): 159 grpc_wrapper.GrpcDebugWrapperSession( 160 sess, "localhost:%d" % self._server_port, watch_fn="foo") 161 162 def testGrpcDebugWrapperSessionWithoutWatchFnWorks(self): 163 u = variables.Variable(2.1, name="u") 164 v = variables.Variable(20.0, name="v") 165 w = math_ops.multiply(u, v, name="w") 166 167 sess = session.Session(config=no_rewrite_session_config()) 168 sess.run(u.initializer) 169 sess.run(v.initializer) 170 171 sess = grpc_wrapper.GrpcDebugWrapperSession( 172 sess, "localhost:%d" % self._server_port) 173 w_result = sess.run(w) 174 self.assertAllClose(42.0, w_result) 175 176 dump = debug_data.DebugDumpDir(self._dump_root) 177 self.assertEqual(5, dump.size) 178 self.assertAllClose([2.1], dump.get_tensors("u", 0, "DebugIdentity")) 179 self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity")) 180 self.assertAllClose([20.0], dump.get_tensors("v", 0, "DebugIdentity")) 181 self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity")) 182 self.assertAllClose([42.0], dump.get_tensors("w", 0, "DebugIdentity")) 183 184 def testGrpcDebugWrapperSessionWithWatchFnWorks(self): 185 def watch_fn(feeds, fetch_keys): 186 del feeds, fetch_keys 187 return ["DebugIdentity", "DebugNumericSummary"], r".*/read", None 188 189 u = variables.Variable(2.1, name="u") 190 v = variables.Variable(20.0, name="v") 191 w = math_ops.multiply(u, v, name="w") 192 193 sess = session.Session(config=no_rewrite_session_config()) 194 sess.run(u.initializer) 195 sess.run(v.initializer) 196 197 sess = grpc_wrapper.GrpcDebugWrapperSession( 198 sess, "localhost:%d" % self._server_port, watch_fn=watch_fn) 199 w_result = sess.run(w) 200 self.assertAllClose(42.0, w_result) 201 202 dump = debug_data.DebugDumpDir(self._dump_root) 203 self.assertEqual(4, dump.size) 204 self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity")) 205 self.assertEqual( 206 14, len(dump.get_tensors("u/read", 0, "DebugNumericSummary")[0])) 207 self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity")) 208 self.assertEqual( 209 14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0])) 210 211 def testGrpcDebugHookWithStatelessWatchFnWorks(self): 212 # Perform some set up. Specifically, construct a simple TensorFlow graph and 213 # create a watch function for certain ops. 214 def watch_fn(feeds, fetch_keys): 215 del feeds, fetch_keys 216 return framework.WatchOptions( 217 debug_ops=["DebugIdentity", "DebugNumericSummary"], 218 node_name_regex_whitelist=r".*/read", 219 op_type_regex_whitelist=None, 220 tolerate_debug_op_creation_failures=True) 221 222 u = variables.Variable(2.1, name="u") 223 v = variables.Variable(20.0, name="v") 224 w = math_ops.multiply(u, v, name="w") 225 226 sess = session.Session(config=no_rewrite_session_config()) 227 sess.run(u.initializer) 228 sess.run(v.initializer) 229 230 # Create a hook. One could use this hook with say a tflearn Estimator. 231 # However, we use a HookedSession in this test to avoid depending on the 232 # internal implementation of Estimators. 233 grpc_debug_hook = hooks.GrpcDebugHook( 234 ["localhost:%d" % self._server_port], watch_fn=watch_fn) 235 sess = monitored_session._HookedSession(sess, [grpc_debug_hook]) 236 237 # Run the hooked session. This should stream tensor data to the GRPC 238 # endpoints. 239 w_result = sess.run(w) 240 241 # Verify that the hook monitored the correct tensors. 242 self.assertAllClose(42.0, w_result) 243 dump = debug_data.DebugDumpDir(self._dump_root) 244 self.assertEqual(4, dump.size) 245 self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity")) 246 self.assertEqual( 247 14, len(dump.get_tensors("u/read", 0, "DebugNumericSummary")[0])) 248 self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity")) 249 self.assertEqual( 250 14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0])) 251 252 def testTensorBoardDebugHookWorks(self): 253 u = variables.Variable(2.1, name="u") 254 v = variables.Variable(20.0, name="v") 255 w = math_ops.multiply(u, v, name="w") 256 257 sess = session.Session(config=no_rewrite_session_config()) 258 sess.run(u.initializer) 259 sess.run(v.initializer) 260 261 grpc_debug_hook = hooks.TensorBoardDebugHook( 262 ["localhost:%d" % self._server_port]) 263 sess = monitored_session._HookedSession(sess, [grpc_debug_hook]) 264 265 # Activate watch point on a tensor before calling sess.run(). 266 self._server.request_watch("u/read", 0, "DebugIdentity") 267 self.assertAllClose(42.0, sess.run(w)) 268 269 # self.assertAllClose(42.0, sess.run(w)) 270 dump = debug_data.DebugDumpDir(self._dump_root) 271 self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity")) 272 273 # Check that the server has received the stack trace. 274 self.assertTrue(self._server.query_op_traceback("u")) 275 self.assertTrue(self._server.query_op_traceback("u/read")) 276 self.assertTrue(self._server.query_op_traceback("v")) 277 self.assertTrue(self._server.query_op_traceback("v/read")) 278 self.assertTrue(self._server.query_op_traceback("w")) 279 280 # Check that the server has received the python file content. 281 # Query an arbitrary line to make sure that is the case. 282 with open(__file__, "rt") as this_source_file: 283 first_line = this_source_file.readline().strip() 284 self.assertEqual( 285 first_line, self._server.query_source_file_line(__file__, 1)) 286 287 self._server.clear_data() 288 # Call sess.run() again, and verify that this time the traceback and source 289 # code is not sent, because the graph version is not newer. 290 self.assertAllClose(42.0, sess.run(w)) 291 with self.assertRaises(ValueError): 292 self._server.query_op_traceback("delta_1") 293 with self.assertRaises(ValueError): 294 self._server.query_source_file_line(__file__, 1) 295 296 def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self): 297 u = variables.Variable(2.1, name="u") 298 v = variables.Variable(20.0, name="v") 299 w = math_ops.multiply(u, v, name="w") 300 301 sess = session.Session(config=no_rewrite_session_config()) 302 sess.run(variables.global_variables_initializer()) 303 304 grpc_debug_hook = hooks.TensorBoardDebugHook( 305 ["localhost:%d" % self._server_port], 306 send_traceback_and_source_code=False) 307 sess = monitored_session._HookedSession(sess, [grpc_debug_hook]) 308 309 # Activate watch point on a tensor before calling sess.run(). 310 self._server.request_watch("u/read", 0, "DebugIdentity") 311 self.assertAllClose(42.0, sess.run(w)) 312 313 # Check that the server has _not_ received any tracebacks, as a result of 314 # the disabling above. 315 with self.assertRaisesRegexp( 316 ValueError, r"Op .*u/read.* does not exist"): 317 self.assertTrue(self._server.query_op_traceback("u/read")) 318 with self.assertRaisesRegexp( 319 ValueError, r".* has not received any source file"): 320 self._server.query_source_file_line(__file__, 1) 321 322 def testConstructGrpcDebugHookWithOrWithouGrpcInUrlWorks(self): 323 hooks.GrpcDebugHook(["grpc://foo:42424"]) 324 hooks.GrpcDebugHook(["foo:42424"]) 325 326 327 class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase): 328 329 @classmethod 330 def setUpClass(cls): 331 (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread, 332 cls.debug_server 333 ) = grpc_debug_test_server.start_server_on_separate_thread( 334 dump_to_filesystem=False) 335 tf_logging.info("debug server url: %s", cls.debug_server_url) 336 337 @classmethod 338 def tearDownClass(cls): 339 cls.debug_server.stop_server().wait() 340 cls.debug_server_thread.join() 341 342 def tearDown(self): 343 ops.reset_default_graph() 344 self.debug_server.clear_data() 345 346 def testSendingLargeGraphDefsWorks(self): 347 with self.test_session( 348 use_gpu=True, config=no_rewrite_session_config()) as sess: 349 u = variables.Variable(42.0, name="original_u") 350 for _ in xrange(50 * 1000): 351 u = array_ops.identity(u) 352 sess.run(variables.global_variables_initializer()) 353 354 def watch_fn(fetches, feeds): 355 del fetches, feeds 356 return framework.WatchOptions( 357 debug_ops=["DebugIdentity"], 358 node_name_regex_whitelist=r"original_u") 359 sess = grpc_wrapper.GrpcDebugWrapperSession( 360 sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) 361 self.assertAllClose(42.0, sess.run(u)) 362 363 self.assertAllClose( 364 [42.0], 365 self.debug_server.debug_tensor_values["original_u:0:DebugIdentity"]) 366 self.assertEqual(2 if test.is_gpu_available() else 1, 367 len(self.debug_server.partition_graph_defs)) 368 max_graph_def_size = max([ 369 len(graph_def.SerializeToString()) 370 for graph_def in self.debug_server.partition_graph_defs]) 371 self.assertGreater(max_graph_def_size, 4 * 1024 * 1024) 372 373 def testSendingLargeFloatTensorWorks(self): 374 with self.test_session( 375 use_gpu=True, config=no_rewrite_session_config()) as sess: 376 u_init_val_array = list(xrange(1200 * 1024)) 377 # Size: 4 * 1200 * 1024 = 4800k > 4M 378 379 u_init = constant_op.constant( 380 u_init_val_array, dtype=dtypes.float32, name="u_init") 381 u = variables.Variable(u_init, name="u") 382 383 def watch_fn(fetches, feeds): 384 del fetches, feeds # Unused by this watch_fn. 385 return framework.WatchOptions( 386 debug_ops=["DebugIdentity"], 387 node_name_regex_whitelist=r"u_init") 388 sess = grpc_wrapper.GrpcDebugWrapperSession( 389 sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) 390 sess.run(u.initializer) 391 392 self.assertAllEqual( 393 u_init_val_array, 394 self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) 395 396 def testSendingStringTensorWithAlmostTooLargeStringsWorks(self): 397 with self.test_session( 398 use_gpu=True, config=no_rewrite_session_config()) as sess: 399 u_init_val = [ 400 b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""] 401 u_init = constant_op.constant( 402 u_init_val, dtype=dtypes.string, name="u_init") 403 u = variables.Variable(u_init, name="u") 404 405 def watch_fn(fetches, feeds): 406 del fetches, feeds 407 return framework.WatchOptions( 408 debug_ops=["DebugIdentity"], 409 node_name_regex_whitelist=r"u_init") 410 sess = grpc_wrapper.GrpcDebugWrapperSession( 411 sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) 412 sess.run(u.initializer) 413 414 self.assertAllEqual( 415 u_init_val, 416 self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) 417 418 def testSendingLargeStringTensorWorks(self): 419 with self.test_session( 420 use_gpu=True, config=no_rewrite_session_config()) as sess: 421 strs_total_size_threshold = 5000 * 1024 422 cum_size = 0 423 u_init_val_array = [] 424 while cum_size < strs_total_size_threshold: 425 strlen = np.random.randint(200) 426 u_init_val_array.append(b"A" * strlen) 427 cum_size += strlen 428 429 u_init = constant_op.constant( 430 u_init_val_array, dtype=dtypes.string, name="u_init") 431 u = variables.Variable(u_init, name="u") 432 433 def watch_fn(fetches, feeds): 434 del fetches, feeds 435 return framework.WatchOptions( 436 debug_ops=["DebugIdentity"], 437 node_name_regex_whitelist=r"u_init") 438 sess = grpc_wrapper.GrpcDebugWrapperSession( 439 sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) 440 sess.run(u.initializer) 441 442 self.assertAllEqual( 443 u_init_val_array, 444 self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0]) 445 446 def testSendingEmptyFloatTensorWorks(self): 447 with self.test_session( 448 use_gpu=True, config=no_rewrite_session_config()) as sess: 449 u_init = constant_op.constant( 450 [], dtype=dtypes.float32, shape=[0], name="u_init") 451 u = variables.Variable(u_init, name="u") 452 453 def watch_fn(fetches, feeds): 454 del fetches, feeds 455 return framework.WatchOptions( 456 debug_ops=["DebugIdentity"], 457 node_name_regex_whitelist=r"u_init") 458 sess = grpc_wrapper.GrpcDebugWrapperSession( 459 sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) 460 sess.run(u.initializer) 461 462 u_init_value = self.debug_server.debug_tensor_values[ 463 "u_init:0:DebugIdentity"][0] 464 self.assertEqual(np.float32, u_init_value.dtype) 465 self.assertEqual(0, len(u_init_value)) 466 467 def testSendingEmptyStringTensorWorks(self): 468 with self.test_session( 469 use_gpu=True, config=no_rewrite_session_config()) as sess: 470 u_init = constant_op.constant( 471 [], dtype=dtypes.string, shape=[0], name="u_init") 472 u = variables.Variable(u_init, name="u") 473 474 def watch_fn(fetches, feeds): 475 del fetches, feeds 476 return framework.WatchOptions( 477 debug_ops=["DebugIdentity"], 478 node_name_regex_whitelist=r"u_init") 479 sess = grpc_wrapper.GrpcDebugWrapperSession( 480 sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) 481 sess.run(u.initializer) 482 483 u_init_value = self.debug_server.debug_tensor_values[ 484 "u_init:0:DebugIdentity"][0] 485 self.assertEqual(np.object, u_init_value.dtype) 486 self.assertEqual(0, len(u_init_value)) 487 488 489 class SessionDebugConcurrentTest( 490 session_debug_testlib.DebugConcurrentRunCallsTest): 491 492 @classmethod 493 def setUpClass(cls): 494 session_debug_testlib.SessionDebugTestBase.setUpClass() 495 (cls._server_port, cls._debug_server_url, cls._server_dump_dir, 496 cls._server_thread, 497 cls._server) = grpc_debug_test_server.start_server_on_separate_thread() 498 499 @classmethod 500 def tearDownClass(cls): 501 # Stop the test server and join the thread. 502 cls._server.stop_server().wait() 503 cls._server_thread.join() 504 session_debug_testlib.SessionDebugTestBase.tearDownClass() 505 506 def setUp(self): 507 self._num_concurrent_runs = 3 508 self._dump_roots = [] 509 for i in range(self._num_concurrent_runs): 510 self._dump_roots.append( 511 os.path.join(self._server_dump_dir, "thread%d" % i)) 512 513 def tearDown(self): 514 ops.reset_default_graph() 515 if os.path.isdir(self._server_dump_dir): 516 shutil.rmtree(self._server_dump_dir) 517 518 def _get_concurrent_debug_urls(self): 519 urls = [] 520 for i in range(self._num_concurrent_runs): 521 urls.append(self._debug_server_url + "/thread%d" % i) 522 return urls 523 524 525 class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase): 526 """Test server gating of debug ops.""" 527 528 @classmethod 529 def setUpClass(cls): 530 (cls._server_port_1, cls._debug_server_url_1, _, cls._server_thread_1, 531 cls._server_1) = grpc_debug_test_server.start_server_on_separate_thread( 532 dump_to_filesystem=False) 533 (cls._server_port_2, cls._debug_server_url_2, _, cls._server_thread_2, 534 cls._server_2) = grpc_debug_test_server.start_server_on_separate_thread( 535 dump_to_filesystem=False) 536 cls._servers_and_threads = [(cls._server_1, cls._server_thread_1), 537 (cls._server_2, cls._server_thread_2)] 538 539 @classmethod 540 def tearDownClass(cls): 541 for server, thread in cls._servers_and_threads: 542 server.stop_server().wait() 543 thread.join() 544 545 def tearDown(self): 546 ops.reset_default_graph() 547 self._server_1.clear_data() 548 self._server_2.clear_data() 549 550 def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self): 551 with session.Session(config=no_rewrite_session_config()) as sess: 552 v_1 = variables.Variable(50.0, name="v_1") 553 v_2 = variables.Variable(-50.0, name="v_1") 554 delta_1 = constant_op.constant(5.0, name="delta_1") 555 delta_2 = constant_op.constant(-5.0, name="delta_2") 556 inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1") 557 inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2") 558 559 sess.run([v_1.initializer, v_2.initializer]) 560 561 run_metadata = config_pb2.RunMetadata() 562 run_options = config_pb2.RunOptions(output_partition_graphs=True) 563 debug_utils.watch_graph( 564 run_options, 565 sess.graph, 566 debug_ops=["DebugIdentity(gated_grpc=true)", 567 "DebugNumericSummary(gated_grpc=true)"], 568 debug_urls=[self._debug_server_url_1]) 569 570 for i in xrange(4): 571 self._server_1.clear_data() 572 573 if i % 2 == 0: 574 self._server_1.request_watch("delta_1", 0, "DebugIdentity") 575 self._server_1.request_watch("delta_2", 0, "DebugIdentity") 576 self._server_1.request_unwatch("delta_1", 0, "DebugNumericSummary") 577 self._server_1.request_unwatch("delta_2", 0, "DebugNumericSummary") 578 else: 579 self._server_1.request_unwatch("delta_1", 0, "DebugIdentity") 580 self._server_1.request_unwatch("delta_2", 0, "DebugIdentity") 581 self._server_1.request_watch("delta_1", 0, "DebugNumericSummary") 582 self._server_1.request_watch("delta_2", 0, "DebugNumericSummary") 583 584 sess.run([inc_v_1, inc_v_2], 585 options=run_options, run_metadata=run_metadata) 586 587 # Watched debug tensors are: 588 # Run 0: delta_[1,2]:0:DebugIdentity 589 # Run 1: delta_[1,2]:0:DebugNumericSummary 590 # Run 2: delta_[1,2]:0:DebugIdentity 591 # Run 3: delta_[1,2]:0:DebugNumericSummary 592 self.assertEqual(2, len(self._server_1.debug_tensor_values)) 593 if i % 2 == 0: 594 self.assertAllClose( 595 [5.0], 596 self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"]) 597 self.assertAllClose( 598 [-5.0], 599 self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"]) 600 else: 601 self.assertAllClose( 602 [[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 5.0, 5.0, 5.0, 603 0.0, 1.0, 0.0]], 604 self._server_1.debug_tensor_values[ 605 "delta_1:0:DebugNumericSummary"]) 606 self.assertAllClose( 607 [[1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -5.0, -5.0, -5.0, 608 0.0, 1.0, 0.0]], 609 self._server_1.debug_tensor_values[ 610 "delta_2:0:DebugNumericSummary"]) 611 612 def testToggleWatchesOnCoreMetadata(self): 613 (_, debug_server_url, _, server_thread, 614 server) = grpc_debug_test_server.start_server_on_separate_thread( 615 dump_to_filesystem=False, 616 toggle_watch_on_core_metadata=[("toggled_1", 0, "DebugIdentity"), 617 ("toggled_2", 0, "DebugIdentity")]) 618 self._servers_and_threads.append((server, server_thread)) 619 620 with session.Session(config=no_rewrite_session_config()) as sess: 621 v_1 = variables.Variable(50.0, name="v_1") 622 v_2 = variables.Variable(-50.0, name="v_1") 623 # These two nodes have names that match those in the 624 # toggle_watch_on_core_metadata argument used when calling 625 # start_server_on_separate_thread(). 626 toggled_1 = constant_op.constant(5.0, name="toggled_1") 627 toggled_2 = constant_op.constant(-5.0, name="toggled_2") 628 inc_v_1 = state_ops.assign_add(v_1, toggled_1, name="inc_v_1") 629 inc_v_2 = state_ops.assign_add(v_2, toggled_2, name="inc_v_2") 630 631 sess.run([v_1.initializer, v_2.initializer]) 632 633 run_metadata = config_pb2.RunMetadata() 634 run_options = config_pb2.RunOptions(output_partition_graphs=True) 635 debug_utils.watch_graph( 636 run_options, 637 sess.graph, 638 debug_ops=["DebugIdentity(gated_grpc=true)"], 639 debug_urls=[debug_server_url]) 640 641 for i in xrange(4): 642 server.clear_data() 643 644 sess.run([inc_v_1, inc_v_2], 645 options=run_options, run_metadata=run_metadata) 646 647 if i % 2 == 0: 648 self.assertEqual(2, len(server.debug_tensor_values)) 649 self.assertAllClose( 650 [5.0], 651 server.debug_tensor_values["toggled_1:0:DebugIdentity"]) 652 self.assertAllClose( 653 [-5.0], 654 server.debug_tensor_values["toggled_2:0:DebugIdentity"]) 655 else: 656 self.assertEqual(0, len(server.debug_tensor_values)) 657 658 def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self): 659 with session.Session(config=no_rewrite_session_config()) as sess: 660 v = variables.Variable(50.0, name="v") 661 delta = constant_op.constant(5.0, name="delta") 662 inc_v = state_ops.assign_add(v, delta, name="inc_v") 663 664 sess.run(v.initializer) 665 666 run_metadata = config_pb2.RunMetadata() 667 run_options = config_pb2.RunOptions(output_partition_graphs=True) 668 debug_utils.watch_graph( 669 run_options, 670 sess.graph, 671 debug_ops=["DebugIdentity(gated_grpc=true)"], 672 debug_urls=[self._debug_server_url_1, self._debug_server_url_2]) 673 674 for i in xrange(4): 675 self._server_1.clear_data() 676 self._server_2.clear_data() 677 678 if i % 2 == 0: 679 self._server_1.request_watch("delta", 0, "DebugIdentity") 680 self._server_2.request_watch("v", 0, "DebugIdentity") 681 else: 682 self._server_1.request_unwatch("delta", 0, "DebugIdentity") 683 self._server_2.request_unwatch("v", 0, "DebugIdentity") 684 685 sess.run(inc_v, options=run_options, run_metadata=run_metadata) 686 687 if i % 2 == 0: 688 self.assertEqual(1, len(self._server_1.debug_tensor_values)) 689 self.assertEqual(1, len(self._server_2.debug_tensor_values)) 690 self.assertAllClose( 691 [5.0], 692 self._server_1.debug_tensor_values["delta:0:DebugIdentity"]) 693 self.assertAllClose( 694 [50 + 5.0 * i], 695 self._server_2.debug_tensor_values["v:0:DebugIdentity"]) 696 else: 697 self.assertEqual(0, len(self._server_1.debug_tensor_values)) 698 self.assertEqual(0, len(self._server_2.debug_tensor_values)) 699 700 def testToggleBreakpointsWorks(self): 701 with session.Session(config=no_rewrite_session_config()) as sess: 702 v_1 = variables.Variable(50.0, name="v_1") 703 v_2 = variables.Variable(-50.0, name="v_2") 704 delta_1 = constant_op.constant(5.0, name="delta_1") 705 delta_2 = constant_op.constant(-5.0, name="delta_2") 706 inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1") 707 inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2") 708 709 sess.run([v_1.initializer, v_2.initializer]) 710 711 run_metadata = config_pb2.RunMetadata() 712 run_options = config_pb2.RunOptions(output_partition_graphs=True) 713 debug_utils.watch_graph( 714 run_options, 715 sess.graph, 716 debug_ops=["DebugIdentity(gated_grpc=true)"], 717 debug_urls=[self._debug_server_url_1]) 718 719 for i in xrange(4): 720 self._server_1.clear_data() 721 722 if i in (0, 2): 723 # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2. 724 self._server_1.request_watch( 725 "delta_1", 0, "DebugIdentity", breakpoint=True) 726 self._server_1.request_watch( 727 "delta_2", 0, "DebugIdentity", breakpoint=True) 728 else: 729 # Disable the breakpoint in runs 1 and 3. 730 self._server_1.request_unwatch("delta_1", 0, "DebugIdentity") 731 self._server_1.request_unwatch("delta_2", 0, "DebugIdentity") 732 733 output = sess.run([inc_v_1, inc_v_2], 734 options=run_options, run_metadata=run_metadata) 735 self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output) 736 737 if i in (0, 2): 738 # During runs 0 and 2, the server should have received the published 739 # debug tensor delta:0:DebugIdentity. The breakpoint should have been 740 # unblocked by EventReply reponses from the server. 741 self.assertAllClose( 742 [5.0], 743 self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"]) 744 self.assertAllClose( 745 [-5.0], 746 self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"]) 747 # After the runs, the server should have properly registered the 748 # breakpoints due to the request_unwatch calls. 749 self.assertSetEqual({("delta_1", 0, "DebugIdentity"), 750 ("delta_2", 0, "DebugIdentity")}, 751 self._server_1.breakpoints) 752 else: 753 # After the end of runs 1 and 3, the server has received the requests 754 # to disable the breakpoint at delta:0:DebugIdentity. 755 self.assertSetEqual(set(), self._server_1.breakpoints) 756 757 def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self): 758 with session.Session(config=no_rewrite_session_config()) as sess: 759 v_1 = variables.Variable(50.0, name="v_1") 760 v_2 = variables.Variable(-50.0, name="v_2") 761 delta_1 = constant_op.constant(5.0, name="delta_1") 762 delta_2 = constant_op.constant(-5.0, name="delta_2") 763 inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1") 764 inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2") 765 766 sess.run([v_1.initializer, v_2.initializer]) 767 768 # The TensorBoardDebugWrapperSession should add a DebugIdentity debug op 769 # with attribute gated_grpc=True for every tensor in the graph. 770 sess = grpc_wrapper.TensorBoardDebugWrapperSession( 771 sess, self._debug_server_url_1) 772 773 for i in xrange(4): 774 self._server_1.clear_data() 775 776 if i in (0, 2): 777 # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2. 778 self._server_1.request_watch( 779 "delta_1", 0, "DebugIdentity", breakpoint=True) 780 self._server_1.request_watch( 781 "delta_2", 0, "DebugIdentity", breakpoint=True) 782 else: 783 # Disable the breakpoint in runs 1 and 3. 784 self._server_1.request_unwatch("delta_1", 0, "DebugIdentity") 785 self._server_1.request_unwatch("delta_2", 0, "DebugIdentity") 786 787 output = sess.run([inc_v_1, inc_v_2]) 788 self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output) 789 790 if i in (0, 2): 791 # During runs 0 and 2, the server should have received the published 792 # debug tensor delta:0:DebugIdentity. The breakpoint should have been 793 # unblocked by EventReply reponses from the server. 794 self.assertAllClose( 795 [5.0], 796 self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"]) 797 self.assertAllClose( 798 [-5.0], 799 self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"]) 800 # After the runs, the server should have properly registered the 801 # breakpoints. 802 else: 803 # After the end of runs 1 and 3, the server has received the requests 804 # to disable the breakpoint at delta:0:DebugIdentity. 805 self.assertSetEqual(set(), self._server_1.breakpoints) 806 807 if i == 0: 808 # Check that the server has received the stack trace. 809 self.assertTrue(self._server_1.query_op_traceback("delta_1")) 810 self.assertTrue(self._server_1.query_op_traceback("delta_2")) 811 self.assertTrue(self._server_1.query_op_traceback("inc_v_1")) 812 self.assertTrue(self._server_1.query_op_traceback("inc_v_2")) 813 # Check that the server has received the python file content. 814 # Query an arbitrary line to make sure that is the case. 815 with open(__file__, "rt") as this_source_file: 816 first_line = this_source_file.readline().strip() 817 self.assertEqual( 818 first_line, self._server_1.query_source_file_line(__file__, 1)) 819 else: 820 # In later Session.run() calls, the traceback shouldn't have been sent 821 # because it is already sent in the 1st call. So calling 822 # query_op_traceback() should lead to an exception, because the test 823 # debug server clears the data at the beginning of every iteration. 824 with self.assertRaises(ValueError): 825 self._server_1.query_op_traceback("delta_1") 826 with self.assertRaises(ValueError): 827 self._server_1.query_source_file_line(__file__, 1) 828 829 def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self): 830 with session.Session(config=no_rewrite_session_config()) as sess: 831 v_1 = variables.Variable(50.0, name="v_1") 832 v_2 = variables.Variable(-50.0, name="v_2") 833 delta_1 = constant_op.constant(5.0, name="delta_1") 834 delta_2 = constant_op.constant(-5.0, name="delta_2") 835 inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1") 836 inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2") 837 838 sess.run(variables.global_variables_initializer()) 839 840 # Disable the sending of traceback and source code. 841 sess = grpc_wrapper.TensorBoardDebugWrapperSession( 842 sess, self._debug_server_url_1, send_traceback_and_source_code=False) 843 844 for i in xrange(4): 845 self._server_1.clear_data() 846 847 if i == 0: 848 self._server_1.request_watch( 849 "delta_1", 0, "DebugIdentity", breakpoint=True) 850 851 output = sess.run([inc_v_1, inc_v_2]) 852 self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output) 853 854 # No op traceback or source code should have been received by the debug 855 # server due to the disabling above. 856 with self.assertRaisesRegexp( 857 ValueError, r"Op .*delta_1.* does not exist"): 858 self.assertTrue(self._server_1.query_op_traceback("delta_1")) 859 with self.assertRaisesRegexp( 860 ValueError, r".* has not received any source file"): 861 self._server_1.query_source_file_line(__file__, 1) 862 863 def testGetGrpcDebugWatchesReturnsCorrectAnswer(self): 864 with session.Session() as sess: 865 v = variables.Variable(50.0, name="v") 866 delta = constant_op.constant(5.0, name="delta") 867 inc_v = state_ops.assign_add(v, delta, name="inc_v") 868 869 sess.run(v.initializer) 870 871 # Before any debugged runs, the server should be aware of no debug 872 # watches. 873 self.assertEqual([], self._server_1.gated_grpc_debug_watches()) 874 875 run_metadata = config_pb2.RunMetadata() 876 run_options = config_pb2.RunOptions(output_partition_graphs=True) 877 debug_utils.add_debug_tensor_watch( 878 run_options, "delta", output_slot=0, 879 debug_ops=["DebugNumericSummary(gated_grpc=true)"], 880 debug_urls=[self._debug_server_url_1]) 881 debug_utils.add_debug_tensor_watch( 882 run_options, "v", output_slot=0, 883 debug_ops=["DebugIdentity"], 884 debug_urls=[self._debug_server_url_1]) 885 sess.run(inc_v, options=run_options, run_metadata=run_metadata) 886 887 # After the first run, the server should have noted the debug watches 888 # for which gated_grpc == True, but not the ones with gated_grpc == False. 889 self.assertEqual(1, len(self._server_1.gated_grpc_debug_watches())) 890 debug_watch = self._server_1.gated_grpc_debug_watches()[0] 891 self.assertEqual("delta", debug_watch.node_name) 892 self.assertEqual(0, debug_watch.output_slot) 893 self.assertEqual("DebugNumericSummary", debug_watch.debug_op) 894 895 896 class DelayedDebugServerTest(test_util.TensorFlowTestCase): 897 898 def testDebuggedSessionRunWorksWithDelayedDebugServerStartup(self): 899 """Test debugged Session.run() tolerates delayed debug server startup.""" 900 ops.reset_default_graph() 901 902 # Start a debug server asynchronously, with a certain amount of delay. 903 (debug_server_port, _, _, server_thread, 904 debug_server) = grpc_debug_test_server.start_server_on_separate_thread( 905 server_start_delay_sec=2.0, dump_to_filesystem=False) 906 907 with self.test_session() as sess: 908 a_init = constant_op.constant(42.0, name="a_init") 909 a = variables.Variable(a_init, name="a") 910 911 def watch_fn(fetches, feeds): 912 del fetches, feeds 913 return framework.WatchOptions(debug_ops=["DebugIdentity"]) 914 915 sess = grpc_wrapper.GrpcDebugWrapperSession( 916 sess, "localhost:%d" % debug_server_port, watch_fn=watch_fn) 917 sess.run(a.initializer) 918 self.assertAllClose( 919 [42.0], debug_server.debug_tensor_values["a_init:0:DebugIdentity"]) 920 921 debug_server.stop_server().wait() 922 server_thread.join() 923 924 925 if __name__ == "__main__": 926 googletest.main() 927