1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for debugger functionalities in tf.Session with grpc:// URLs. 16 17 This test focus on grpc:// debugging of distributed (gRPC) sessions. 18 """ 19 from __future__ import absolute_import 20 from __future__ import division 21 from __future__ import print_function 22 23 import json 24 import subprocess 25 import sys 26 import time 27 28 import portpicker 29 from six.moves import xrange # pylint: disable=redefined-builtin 30 31 from tensorflow.core.protobuf import config_pb2 32 from tensorflow.python.client import session 33 from tensorflow.python.debug.lib import debug_utils 34 from tensorflow.python.debug.lib import grpc_debug_test_server 35 from tensorflow.python.debug.wrappers import framework 36 from tensorflow.python.debug.wrappers import grpc_wrapper 37 from tensorflow.python.framework import ops 38 from tensorflow.python.framework import test_util 39 from tensorflow.python.ops import math_ops 40 from tensorflow.python.ops import state_ops 41 from tensorflow.python.ops import variables 42 from tensorflow.python.platform import googletest 43 from tensorflow.python.platform import test 44 from tensorflow.python.platform import tf_logging 45 46 47 class DistributedSessionDebugTest(test_util.TensorFlowTestCase): 48 """Test the debugging of distributed sessions.""" 49 50 PER_PROC_GPU_MEMORY_FRACTION = 0.1 51 POLLING_INTERVAL_SEC = 0.025 52 53 @classmethod 54 def setUpClass(cls): 55 gpu_memory_fraction_opt = ( 56 "--gpu_memory_fraction=%f" % cls.PER_PROC_GPU_MEMORY_FRACTION) 57 58 worker_port = portpicker.pick_unused_port() 59 cluster_spec = "worker|localhost:%d" % worker_port 60 tf_logging.info("cluster_spec: %s", cluster_spec) 61 62 server_bin = test.test_src_dir_path( 63 "tools/dist_test/server/grpc_tensorflow_server") 64 65 cls.server_target = "grpc://localhost:%d" % worker_port 66 67 cls.server_procs = {} 68 cls.server_procs["worker"] = subprocess.Popen( 69 [ 70 server_bin, 71 "--cluster_spec=%s" % cluster_spec, 72 "--job_name=worker", 73 "--task_id=0", 74 gpu_memory_fraction_opt, 75 ], 76 stdout=sys.stdout, 77 stderr=sys.stderr) 78 79 # Start debug server in-process, on separate thread. 80 (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread, 81 cls.debug_server 82 ) = grpc_debug_test_server.start_server_on_separate_thread( 83 dump_to_filesystem=False) 84 tf_logging.info("debug server url: %s", cls.debug_server_url) 85 86 cls.session_config = config_pb2.ConfigProto( 87 gpu_options=config_pb2.GPUOptions( 88 per_process_gpu_memory_fraction=cls.PER_PROC_GPU_MEMORY_FRACTION)) 89 90 @classmethod 91 def tearDownClass(cls): 92 for key in cls.server_procs: 93 cls.server_procs[key].terminate() 94 cls.debug_server.stop_server().wait() 95 cls.debug_server_thread.join() 96 97 def setUp(self): 98 pass 99 100 def tearDown(self): 101 self.debug_server.clear_data() 102 103 def _pollingAssertDebugTensorValuesAllClose(self, expected_values, 104 debug_tensor_name): 105 """Poll debug_server till tensor appears and matches expected values.""" 106 while (debug_tensor_name not in self.debug_server.debug_tensor_values or 107 len(self.debug_server.debug_tensor_values) < len(expected_values)): 108 time.sleep(self.POLLING_INTERVAL_SEC) 109 self.assertAllClose( 110 expected_values, 111 self.debug_server.debug_tensor_values[debug_tensor_name]) 112 113 def _createGraph(self): 114 """Create graph for testing. 115 116 Returns: 117 Python Graph object. 118 """ 119 with ops.Graph().as_default() as graph: 120 with ops.device("/job:worker/task:0/cpu:0"): 121 self.a = variables.Variable(10.0, name="a") 122 self.b = variables.Variable(100.0, name="b") 123 self.inc_a = state_ops.assign_add(self.a, 2.0, name="inc_a") 124 self.dec_b = state_ops.assign_add(self.b, -5.0, name="dec_b") 125 self.p = math_ops.multiply(self.inc_a, self.dec_b, name="p") 126 self.q = math_ops.negative(self.p, name="q") 127 return graph 128 129 def testDistributedRunWithGatedGrpcCommunicatesWithDebugServerCorrectly(self): 130 graph = self._createGraph() 131 with session.Session( 132 config=self.session_config, graph=graph, 133 target=self.server_target) as sess: 134 sess.run(self.a.initializer) 135 sess.run(self.b.initializer) 136 137 run_options = config_pb2.RunOptions() 138 debug_utils.watch_graph( 139 run_options, 140 sess.graph, 141 node_name_regex_whitelist=r"a", 142 debug_ops=["DebugIdentity"], 143 debug_urls=[self.debug_server_url]) 144 145 # Test gated_grpc for an op located on the worker, i.e., on the same 146 # host as where MasterSession is. 147 # TODO(cais): gRPC gating of debug ops does not work on partition graphs 148 # not located on MasterSession hosts (e.g., parameter servers) yet. Make 149 # it work. 150 debug_utils.watch_graph( 151 run_options, 152 sess.graph, 153 node_name_regex_whitelist=r"p", 154 debug_ops=["DebugIdentity(gated_grpc=True)"], 155 debug_urls=[self.debug_server_url]) 156 157 for i in xrange(4): 158 if i % 2 == 0: 159 self.debug_server.request_watch("p", 0, "DebugIdentity") 160 else: 161 self.debug_server.request_unwatch("p", 0, "DebugIdentity") 162 163 expected_p = (10.0 + 2.0 * (i + 1)) * (100.0 - 5.0 * (i + 1)) 164 self.assertAllClose(-expected_p, sess.run(self.q, options=run_options)) 165 166 self.assertEqual(1, len(self.debug_server.core_metadata_json_strings)) 167 core_metadata = json.loads( 168 self.debug_server.core_metadata_json_strings[0]) 169 self.assertEqual([], core_metadata["input_names"]) 170 self.assertEqual(["q:0"], core_metadata["output_names"]) 171 self.assertEqual(i, core_metadata["executor_step_index"]) 172 173 if i == 0: 174 self.assertEqual(1, len(self.debug_server.partition_graph_defs)) 175 176 # Tensor "a" is from a PS. It may take longer to arrive due to the fact 177 # that the stream connection between the PS and the debug server is 178 # persistent and not torn down at the end of each Session.run() 179 self._pollingAssertDebugTensorValuesAllClose([10.0 + 2.0 * i], 180 "a:0:DebugIdentity") 181 182 # Due to the gRPC gating of the debug op for "p", the debug tensor 183 # should be available on odd-indexed runs. 184 if i % 2 == 0: 185 self.assertAllClose( 186 [expected_p], 187 self.debug_server.debug_tensor_values["p:0:DebugIdentity"]) 188 else: 189 self.assertNotIn("p:0:DebugIdentity", 190 self.debug_server.debug_tensor_values) 191 192 self.assertNotIn("b:0:DebugIdentity", 193 self.debug_server.debug_tensor_values) 194 self.debug_server.clear_data() 195 196 def testDistributedRunWithGrpcDebugWrapperWorks(self): 197 graph = self._createGraph() 198 with session.Session( 199 config=self.session_config, graph=graph, 200 target=self.server_target) as sess: 201 sess.run(self.a.initializer) 202 sess.run(self.b.initializer) 203 204 def watch_fn(feeds, fetch_keys): 205 del feeds, fetch_keys 206 return framework.WatchOptions( 207 debug_ops=["DebugIdentity"], 208 node_name_regex_whitelist=r"p") 209 sess = grpc_wrapper.GrpcDebugWrapperSession( 210 sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) 211 212 for i in xrange(4): 213 expected_p = (10.0 + 2.0 * (i + 1)) * (100.0 - 5.0 * (i + 1)) 214 self.assertAllClose(-expected_p, sess.run(self.q)) 215 216 if i == 0: 217 self.assertEqual(1, len(self.debug_server.partition_graph_defs)) 218 219 self.assertAllClose( 220 [expected_p], 221 self.debug_server.debug_tensor_values["p:0:DebugIdentity"]) 222 self.assertNotIn("b:0:DebugIdentity", 223 self.debug_server.debug_tensor_values) 224 self.debug_server.clear_data() 225 226 227 if __name__ == "__main__": 228 googletest.main() 229