1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """GRPC debug server for testing.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import collections 21 import errno 22 import functools 23 import hashlib 24 import json 25 import os 26 import re 27 import shutil 28 import tempfile 29 import threading 30 import time 31 32 import portpicker 33 34 from tensorflow.core.debug import debug_service_pb2 35 from tensorflow.core.protobuf import config_pb2 36 from tensorflow.core.util import event_pb2 37 from tensorflow.python.client import session 38 from tensorflow.python.debug.lib import debug_data 39 from tensorflow.python.debug.lib import debug_utils 40 from tensorflow.python.debug.lib import grpc_debug_server 41 from tensorflow.python.framework import constant_op 42 from tensorflow.python.framework import errors 43 from tensorflow.python.ops import variables 44 from tensorflow.python.util import compat 45 46 47 def _get_dump_file_path(dump_root, device_name, debug_node_name): 48 """Get the file path of the dump file for a debug node. 49 50 Args: 51 dump_root: (str) Root dump directory. 52 device_name: (str) Name of the device that the debug node resides on. 53 debug_node_name: (str) Name of the debug node, e.g., 54 cross_entropy/Log:0:DebugIdentity. 55 56 Returns: 57 (str) Full path of the dump file. 58 """ 59 60 dump_root = os.path.join( 61 dump_root, debug_data.device_name_to_device_path(device_name)) 62 if "/" in debug_node_name: 63 dump_dir = os.path.join(dump_root, os.path.dirname(debug_node_name)) 64 dump_file_name = re.sub(":", "_", os.path.basename(debug_node_name)) 65 else: 66 dump_dir = dump_root 67 dump_file_name = re.sub(":", "_", debug_node_name) 68 69 now_microsec = int(round(time.time() * 1000 * 1000)) 70 dump_file_name += "_%d" % now_microsec 71 72 return os.path.join(dump_dir, dump_file_name) 73 74 75 class EventListenerTestStreamHandler( 76 grpc_debug_server.EventListenerBaseStreamHandler): 77 """Implementation of EventListenerBaseStreamHandler that dumps to file.""" 78 79 def __init__(self, dump_dir, event_listener_servicer): 80 super(EventListenerTestStreamHandler, self).__init__() 81 self._dump_dir = dump_dir 82 self._event_listener_servicer = event_listener_servicer 83 if self._dump_dir: 84 self._try_makedirs(self._dump_dir) 85 86 self._grpc_path = None 87 self._cached_graph_defs = [] 88 self._cached_graph_def_device_names = [] 89 self._cached_graph_def_wall_times = [] 90 91 def on_core_metadata_event(self, event): 92 self._event_listener_servicer.toggle_watch() 93 94 core_metadata = json.loads(event.log_message.message) 95 96 if not self._grpc_path: 97 grpc_path = core_metadata["grpc_path"] 98 if grpc_path: 99 if grpc_path.startswith("/"): 100 grpc_path = grpc_path[1:] 101 if self._dump_dir: 102 self._dump_dir = os.path.join(self._dump_dir, grpc_path) 103 104 # Write cached graph defs to filesystem. 105 for graph_def, device_name, wall_time in zip( 106 self._cached_graph_defs, 107 self._cached_graph_def_device_names, 108 self._cached_graph_def_wall_times): 109 self._write_graph_def(graph_def, device_name, wall_time) 110 111 if self._dump_dir: 112 self._write_core_metadata_event(event) 113 else: 114 self._event_listener_servicer.core_metadata_json_strings.append( 115 event.log_message.message) 116 117 def on_graph_def(self, graph_def, device_name, wall_time): 118 """Implementation of the tensor value-carrying Event proto callback. 119 120 Args: 121 graph_def: A GraphDef object. 122 device_name: Name of the device on which the graph was created. 123 wall_time: An epoch timestamp (in microseconds) for the graph. 124 """ 125 if self._dump_dir: 126 if self._grpc_path: 127 self._write_graph_def(graph_def, device_name, wall_time) 128 else: 129 self._cached_graph_defs.append(graph_def) 130 self._cached_graph_def_device_names.append(device_name) 131 self._cached_graph_def_wall_times.append(wall_time) 132 else: 133 self._event_listener_servicer.partition_graph_defs.append(graph_def) 134 135 def on_value_event(self, event): 136 """Implementation of the tensor value-carrying Event proto callback. 137 138 Writes the Event proto to the file system for testing. The path written to 139 follows the same pattern as the file:// debug URLs of tfdbg, i.e., the 140 name scope of the op becomes the directory structure under the dump root 141 directory. 142 143 Args: 144 event: The Event proto carrying a tensor value. 145 146 Returns: 147 If the debug node belongs to the set of currently activated breakpoints, 148 a `EventReply` proto will be returned. 149 """ 150 if self._dump_dir: 151 self._write_value_event(event) 152 else: 153 value = event.summary.value[0] 154 tensor_value = debug_data.load_tensor_from_event(event) 155 self._event_listener_servicer.debug_tensor_values[value.node_name].append( 156 tensor_value) 157 158 items = event.summary.value[0].node_name.split(":") 159 node_name = items[0] 160 output_slot = int(items[1]) 161 debug_op = items[2] 162 if ((node_name, output_slot, debug_op) in 163 self._event_listener_servicer.breakpoints): 164 return debug_service_pb2.EventReply() 165 166 def _try_makedirs(self, dir_path): 167 if not os.path.isdir(dir_path): 168 try: 169 os.makedirs(dir_path) 170 except OSError as error: 171 if error.errno != errno.EEXIST: 172 raise 173 174 def _write_core_metadata_event(self, event): 175 core_metadata_path = os.path.join( 176 self._dump_dir, 177 debug_data.METADATA_FILE_PREFIX + debug_data.CORE_METADATA_TAG + 178 "_%d" % event.wall_time) 179 self._try_makedirs(self._dump_dir) 180 with open(core_metadata_path, "wb") as f: 181 f.write(event.SerializeToString()) 182 183 def _write_graph_def(self, graph_def, device_name, wall_time): 184 encoded_graph_def = graph_def.SerializeToString() 185 graph_hash = int(hashlib.md5(encoded_graph_def).hexdigest(), 16) 186 event = event_pb2.Event(graph_def=encoded_graph_def, wall_time=wall_time) 187 graph_file_path = os.path.join( 188 self._dump_dir, 189 debug_data.device_name_to_device_path(device_name), 190 debug_data.METADATA_FILE_PREFIX + debug_data.GRAPH_FILE_TAG + 191 debug_data.HASH_TAG + "%d_%d" % (graph_hash, wall_time)) 192 self._try_makedirs(os.path.dirname(graph_file_path)) 193 with open(graph_file_path, "wb") as f: 194 f.write(event.SerializeToString()) 195 196 def _write_value_event(self, event): 197 value = event.summary.value[0] 198 199 # Obtain the device name from the metadata. 200 summary_metadata = event.summary.value[0].metadata 201 if not summary_metadata.plugin_data: 202 raise ValueError("The value lacks plugin data.") 203 try: 204 content = json.loads(compat.as_text(summary_metadata.plugin_data.content)) 205 except ValueError as err: 206 raise ValueError("Could not parse content into JSON: %r, %r" % (content, 207 err)) 208 device_name = content["device"] 209 210 dump_full_path = _get_dump_file_path( 211 self._dump_dir, device_name, value.node_name) 212 self._try_makedirs(os.path.dirname(dump_full_path)) 213 with open(dump_full_path, "wb") as f: 214 f.write(event.SerializeToString()) 215 216 217 class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): 218 """An implementation of EventListenerBaseServicer for testing.""" 219 220 def __init__(self, server_port, dump_dir, toggle_watch_on_core_metadata=None): 221 """Constructor of EventListenerTestServicer. 222 223 Args: 224 server_port: (int) The server port number. 225 dump_dir: (str) The root directory to which the data files will be 226 dumped. If empty or None, the received debug data will not be dumped 227 to the file system: they will be stored in memory instead. 228 toggle_watch_on_core_metadata: A list of 229 (node_name, output_slot, debug_op) tuples to toggle the 230 watchpoint status during the on_core_metadata calls (optional). 231 """ 232 self.core_metadata_json_strings = [] 233 self.partition_graph_defs = [] 234 self.debug_tensor_values = collections.defaultdict(list) 235 self._initialize_toggle_watch_state(toggle_watch_on_core_metadata) 236 237 grpc_debug_server.EventListenerBaseServicer.__init__( 238 self, server_port, 239 functools.partial(EventListenerTestStreamHandler, dump_dir, self)) 240 241 # Members for storing the graph ops traceback and source files. 242 self._call_types = [] 243 self._call_keys = [] 244 self._origin_stacks = [] 245 self._origin_id_to_strings = [] 246 self._graph_tracebacks = [] 247 self._graph_versions = [] 248 self._source_files = None 249 250 def _initialize_toggle_watch_state(self, toggle_watches): 251 self._toggle_watches = toggle_watches 252 self._toggle_watch_state = dict() 253 if self._toggle_watches: 254 for watch_key in self._toggle_watches: 255 self._toggle_watch_state[watch_key] = False 256 257 def toggle_watch(self): 258 for watch_key in self._toggle_watch_state: 259 node_name, output_slot, debug_op = watch_key 260 if self._toggle_watch_state[watch_key]: 261 self.request_unwatch(node_name, output_slot, debug_op) 262 else: 263 self.request_watch(node_name, output_slot, debug_op) 264 self._toggle_watch_state[watch_key] = ( 265 not self._toggle_watch_state[watch_key]) 266 267 def clear_data(self): 268 self.core_metadata_json_strings = [] 269 self.partition_graph_defs = [] 270 self.debug_tensor_values = collections.defaultdict(list) 271 self._call_types = [] 272 self._call_keys = [] 273 self._origin_stacks = [] 274 self._origin_id_to_strings = [] 275 self._graph_tracebacks = [] 276 self._graph_versions = [] 277 self._source_files = None 278 279 def SendTracebacks(self, request, context): 280 self._call_types.append(request.call_type) 281 self._call_keys.append(request.call_key) 282 self._origin_stacks.append(request.origin_stack) 283 self._origin_id_to_strings.append(request.origin_id_to_string) 284 self._graph_tracebacks.append(request.graph_traceback) 285 self._graph_versions.append(request.graph_version) 286 return debug_service_pb2.EventReply() 287 288 def SendSourceFiles(self, request, context): 289 self._source_files = request 290 return debug_service_pb2.EventReply() 291 292 def query_op_traceback(self, op_name): 293 """Query the traceback of an op. 294 295 Args: 296 op_name: Name of the op to query. 297 298 Returns: 299 The traceback of the op, as a list of 3-tuples: 300 (filename, lineno, function_name) 301 302 Raises: 303 ValueError: If the op cannot be found in the tracebacks received by the 304 server so far. 305 """ 306 for op_log_proto in self._graph_tracebacks: 307 for log_entry in op_log_proto.log_entries: 308 if log_entry.name == op_name: 309 return self._code_def_to_traceback(log_entry.code_def, 310 op_log_proto.id_to_string) 311 raise ValueError( 312 "Op '%s' does not exist in the tracebacks received by the debug " 313 "server." % op_name) 314 315 def query_origin_stack(self): 316 """Query the stack of the origin of the execution call. 317 318 Returns: 319 A `list` of all tracebacks. Each item corresponds to an execution call, 320 i.e., a `SendTracebacks` request. Each item is a `list` of 3-tuples: 321 (filename, lineno, function_name). 322 """ 323 ret = [] 324 for stack, id_to_string in zip( 325 self._origin_stacks, self._origin_id_to_strings): 326 ret.append(self._code_def_to_traceback(stack, id_to_string)) 327 return ret 328 329 def query_call_types(self): 330 return self._call_types 331 332 def query_call_keys(self): 333 return self._call_keys 334 335 def query_graph_versions(self): 336 return self._graph_versions 337 338 def query_source_file_line(self, file_path, lineno): 339 """Query the content of a given line in a source file. 340 341 Args: 342 file_path: Path to the source file. 343 lineno: Line number as an `int`. 344 345 Returns: 346 Content of the line as a string. 347 348 Raises: 349 ValueError: If no source file is found at the given file_path. 350 """ 351 if not self._source_files: 352 raise ValueError( 353 "This debug server has not received any source file contents yet.") 354 for source_file_proto in self._source_files.source_files: 355 if source_file_proto.file_path == file_path: 356 return source_file_proto.lines[lineno - 1] 357 raise ValueError( 358 "Source file at path %s has not been received by the debug server", 359 file_path) 360 361 def _code_def_to_traceback(self, code_def, id_to_string): 362 return [(id_to_string[trace.file_id], 363 trace.lineno, 364 id_to_string[trace.function_id]) for trace in code_def.traces] 365 366 367 def start_server_on_separate_thread(dump_to_filesystem=True, 368 server_start_delay_sec=0.0, 369 poll_server=False, 370 blocking=True, 371 toggle_watch_on_core_metadata=None): 372 """Create a test gRPC debug server and run on a separate thread. 373 374 Args: 375 dump_to_filesystem: (bool) whether the debug server will dump debug data 376 to the filesystem. 377 server_start_delay_sec: (float) amount of time (in sec) to delay the server 378 start up for. 379 poll_server: (bool) whether the server will be polled till success on 380 startup. 381 blocking: (bool) whether the server should be started in a blocking mode. 382 toggle_watch_on_core_metadata: A list of 383 (node_name, output_slot, debug_op) tuples to toggle the 384 watchpoint status during the on_core_metadata calls (optional). 385 386 Returns: 387 server_port: (int) Port on which the server runs. 388 debug_server_url: (str) grpc:// URL to the server. 389 server_dump_dir: (str) The debug server's dump directory. 390 server_thread: The server Thread object. 391 server: The `EventListenerTestServicer` object. 392 393 Raises: 394 ValueError: If polling the server process for ready state is not successful 395 within maximum polling count. 396 """ 397 server_port = portpicker.pick_unused_port() 398 debug_server_url = "grpc://localhost:%d" % server_port 399 400 server_dump_dir = tempfile.mkdtemp() if dump_to_filesystem else None 401 server = EventListenerTestServicer( 402 server_port=server_port, 403 dump_dir=server_dump_dir, 404 toggle_watch_on_core_metadata=toggle_watch_on_core_metadata) 405 406 def delay_then_run_server(): 407 time.sleep(server_start_delay_sec) 408 server.run_server(blocking=blocking) 409 410 server_thread = threading.Thread(target=delay_then_run_server) 411 server_thread.start() 412 413 if poll_server: 414 if not _poll_server_till_success( 415 50, 416 0.2, 417 debug_server_url, 418 server_dump_dir, 419 server, 420 gpu_memory_fraction=0.1): 421 raise ValueError( 422 "Failed to start test gRPC debug server at port %d" % server_port) 423 server.clear_data() 424 return server_port, debug_server_url, server_dump_dir, server_thread, server 425 426 427 def _poll_server_till_success(max_attempts, 428 sleep_per_poll_sec, 429 debug_server_url, 430 dump_dir, 431 server, 432 gpu_memory_fraction=1.0): 433 """Poll server until success or exceeding max polling count. 434 435 Args: 436 max_attempts: (int) How many times to poll at maximum 437 sleep_per_poll_sec: (float) How many seconds to sleep for after each 438 unsuccessful poll. 439 debug_server_url: (str) gRPC URL to the debug server. 440 dump_dir: (str) Dump directory to look for files in. If None, will directly 441 check data from the server object. 442 server: The server object. 443 gpu_memory_fraction: (float) Fraction of GPU memory to be 444 allocated for the Session used in server polling. 445 446 Returns: 447 (bool) Whether the polling succeeded within max_polls attempts. 448 """ 449 poll_count = 0 450 451 config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions( 452 per_process_gpu_memory_fraction=gpu_memory_fraction)) 453 with session.Session(config=config) as sess: 454 for poll_count in range(max_attempts): 455 server.clear_data() 456 print("Polling: poll_count = %d" % poll_count) 457 458 x_init_name = "x_init_%d" % poll_count 459 x_init = constant_op.constant([42.0], shape=[1], name=x_init_name) 460 x = variables.Variable(x_init, name=x_init_name) 461 462 run_options = config_pb2.RunOptions() 463 debug_utils.add_debug_tensor_watch( 464 run_options, x_init_name, 0, debug_urls=[debug_server_url]) 465 try: 466 sess.run(x.initializer, options=run_options) 467 except errors.FailedPreconditionError: 468 pass 469 470 if dump_dir: 471 if os.path.isdir( 472 dump_dir) and debug_data.DebugDumpDir(dump_dir).size > 0: 473 shutil.rmtree(dump_dir) 474 print("Poll succeeded.") 475 return True 476 else: 477 print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec) 478 time.sleep(sleep_per_poll_sec) 479 else: 480 if server.debug_tensor_values: 481 print("Poll succeeded.") 482 return True 483 else: 484 print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec) 485 time.sleep(sleep_per_poll_sec) 486 487 return False 488