1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """GRPC debug server for testing.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import collections 21 import errno 22 import functools 23 import hashlib 24 import json 25 import os 26 import re 27 import shutil 28 import tempfile 29 import threading 30 import time 31 32 import portpicker 33 34 from tensorflow.core.debug import debug_service_pb2 35 from tensorflow.core.protobuf import config_pb2 36 from tensorflow.core.util import event_pb2 37 from tensorflow.python.client import session 38 from tensorflow.python.debug.lib import debug_data 39 from tensorflow.python.debug.lib import debug_utils 40 from tensorflow.python.debug.lib import grpc_debug_server 41 from tensorflow.python.framework import constant_op 42 from tensorflow.python.framework import errors 43 from tensorflow.python.ops import variables 44 from tensorflow.python.util import compat 45 46 47 def _get_dump_file_path(dump_root, device_name, debug_node_name): 48 """Get the file path of the dump file for a debug node. 49 50 Args: 51 dump_root: (str) Root dump directory. 52 device_name: (str) Name of the device that the debug node resides on. 53 debug_node_name: (str) Name of the debug node, e.g., 54 cross_entropy/Log:0:DebugIdentity. 55 56 Returns: 57 (str) Full path of the dump file. 58 """ 59 60 dump_root = os.path.join( 61 dump_root, debug_data.device_name_to_device_path(device_name)) 62 if "/" in debug_node_name: 63 dump_dir = os.path.join(dump_root, os.path.dirname(debug_node_name)) 64 dump_file_name = re.sub(":", "_", os.path.basename(debug_node_name)) 65 else: 66 dump_dir = dump_root 67 dump_file_name = re.sub(":", "_", debug_node_name) 68 69 now_microsec = int(round(time.time() * 1000 * 1000)) 70 dump_file_name += "_%d" % now_microsec 71 72 return os.path.join(dump_dir, dump_file_name) 73 74 75 class EventListenerTestStreamHandler( 76 grpc_debug_server.EventListenerBaseStreamHandler): 77 """Implementation of EventListenerBaseStreamHandler that dumps to file.""" 78 79 def __init__(self, dump_dir, event_listener_servicer): 80 super(EventListenerTestStreamHandler, self).__init__() 81 self._dump_dir = dump_dir 82 self._event_listener_servicer = event_listener_servicer 83 if self._dump_dir: 84 self._try_makedirs(self._dump_dir) 85 86 self._grpc_path = None 87 self._cached_graph_defs = [] 88 self._cached_graph_def_device_names = [] 89 self._cached_graph_def_wall_times = [] 90 91 def on_core_metadata_event(self, event): 92 self._event_listener_servicer.toggle_watch() 93 94 core_metadata = json.loads(event.log_message.message) 95 96 if not self._grpc_path: 97 grpc_path = core_metadata["grpc_path"] 98 if grpc_path: 99 if grpc_path.startswith("/"): 100 grpc_path = grpc_path[1:] 101 if self._dump_dir: 102 self._dump_dir = os.path.join(self._dump_dir, grpc_path) 103 104 # Write cached graph defs to filesystem. 105 for graph_def, device_name, wall_time in zip( 106 self._cached_graph_defs, 107 self._cached_graph_def_device_names, 108 self._cached_graph_def_wall_times): 109 self._write_graph_def(graph_def, device_name, wall_time) 110 111 if self._dump_dir: 112 self._write_core_metadata_event(event) 113 else: 114 self._event_listener_servicer.core_metadata_json_strings.append( 115 event.log_message.message) 116 117 def on_graph_def(self, graph_def, device_name, wall_time): 118 """Implementation of the tensor value-carrying Event proto callback. 119 120 Args: 121 graph_def: A GraphDef object. 122 device_name: Name of the device on which the graph was created. 123 wall_time: An epoch timestamp (in microseconds) for the graph. 124 """ 125 if self._dump_dir: 126 if self._grpc_path: 127 self._write_graph_def(graph_def, device_name, wall_time) 128 else: 129 self._cached_graph_defs.append(graph_def) 130 self._cached_graph_def_device_names.append(device_name) 131 self._cached_graph_def_wall_times.append(wall_time) 132 else: 133 self._event_listener_servicer.partition_graph_defs.append(graph_def) 134 135 def on_value_event(self, event): 136 """Implementation of the tensor value-carrying Event proto callback. 137 138 Writes the Event proto to the file system for testing. The path written to 139 follows the same pattern as the file:// debug URLs of tfdbg, i.e., the 140 name scope of the op becomes the directory structure under the dump root 141 directory. 142 143 Args: 144 event: The Event proto carrying a tensor value. 145 146 Returns: 147 If the debug node belongs to the set of currently activated breakpoints, 148 a `EventReply` proto will be returned. 149 """ 150 if self._dump_dir: 151 self._write_value_event(event) 152 else: 153 value = event.summary.value[0] 154 tensor_value = debug_data.load_tensor_from_event(event) 155 self._event_listener_servicer.debug_tensor_values[value.node_name].append( 156 tensor_value) 157 158 items = event.summary.value[0].node_name.split(":") 159 node_name = items[0] 160 output_slot = int(items[1]) 161 debug_op = items[2] 162 if ((node_name, output_slot, debug_op) in 163 self._event_listener_servicer.breakpoints): 164 return debug_service_pb2.EventReply() 165 166 def _try_makedirs(self, dir_path): 167 if not os.path.isdir(dir_path): 168 try: 169 os.makedirs(dir_path) 170 except OSError as error: 171 if error.errno != errno.EEXIST: 172 raise 173 174 def _write_core_metadata_event(self, event): 175 core_metadata_path = os.path.join( 176 self._dump_dir, 177 debug_data.METADATA_FILE_PREFIX + debug_data.CORE_METADATA_TAG + 178 "_%d" % event.wall_time) 179 self._try_makedirs(self._dump_dir) 180 with open(core_metadata_path, "wb") as f: 181 f.write(event.SerializeToString()) 182 183 def _write_graph_def(self, graph_def, device_name, wall_time): 184 encoded_graph_def = graph_def.SerializeToString() 185 graph_hash = int(hashlib.md5(encoded_graph_def).hexdigest(), 16) 186 event = event_pb2.Event(graph_def=encoded_graph_def, wall_time=wall_time) 187 graph_file_path = os.path.join( 188 self._dump_dir, 189 debug_data.device_name_to_device_path(device_name), 190 debug_data.METADATA_FILE_PREFIX + debug_data.GRAPH_FILE_TAG + 191 debug_data.HASH_TAG + "%d_%d" % (graph_hash, wall_time)) 192 self._try_makedirs(os.path.dirname(graph_file_path)) 193 with open(graph_file_path, "wb") as f: 194 f.write(event.SerializeToString()) 195 196 def _write_value_event(self, event): 197 value = event.summary.value[0] 198 199 # Obtain the device name from the metadata. 200 summary_metadata = event.summary.value[0].metadata 201 if not summary_metadata.plugin_data: 202 raise ValueError("The value lacks plugin data.") 203 try: 204 content = json.loads(compat.as_text(summary_metadata.plugin_data.content)) 205 except ValueError as err: 206 raise ValueError("Could not parse content into JSON: %r, %r" % (content, 207 err)) 208 device_name = content["device"] 209 210 dump_full_path = _get_dump_file_path( 211 self._dump_dir, device_name, value.node_name) 212 self._try_makedirs(os.path.dirname(dump_full_path)) 213 with open(dump_full_path, "wb") as f: 214 f.write(event.SerializeToString()) 215 216 217 class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer): 218 """An implementation of EventListenerBaseServicer for testing.""" 219 220 def __init__(self, server_port, dump_dir, toggle_watch_on_core_metadata=None): 221 """Constructor of EventListenerTestServicer. 222 223 Args: 224 server_port: (int) The server port number. 225 dump_dir: (str) The root directory to which the data files will be 226 dumped. If empty or None, the received debug data will not be dumped 227 to the file system: they will be stored in memory instead. 228 toggle_watch_on_core_metadata: A list of 229 (node_name, output_slot, debug_op) tuples to toggle the 230 watchpoint status during the on_core_metadata calls (optional). 231 """ 232 self.core_metadata_json_strings = [] 233 self.partition_graph_defs = [] 234 self.debug_tensor_values = collections.defaultdict(list) 235 self._initialize_toggle_watch_state(toggle_watch_on_core_metadata) 236 237 grpc_debug_server.EventListenerBaseServicer.__init__( 238 self, server_port, 239 functools.partial(EventListenerTestStreamHandler, dump_dir, self)) 240 241 # Members for storing the graph ops traceback and source files. 242 self._call_types = [] 243 self._call_keys = [] 244 self._origin_stacks = [] 245 self._origin_id_to_strings = [] 246 self._graph_tracebacks = [] 247 self._graph_versions = [] 248 self._source_files = [] 249 250 def _initialize_toggle_watch_state(self, toggle_watches): 251 self._toggle_watches = toggle_watches 252 self._toggle_watch_state = dict() 253 if self._toggle_watches: 254 for watch_key in self._toggle_watches: 255 self._toggle_watch_state[watch_key] = False 256 257 def toggle_watch(self): 258 for watch_key in self._toggle_watch_state: 259 node_name, output_slot, debug_op = watch_key 260 if self._toggle_watch_state[watch_key]: 261 self.request_unwatch(node_name, output_slot, debug_op) 262 else: 263 self.request_watch(node_name, output_slot, debug_op) 264 self._toggle_watch_state[watch_key] = ( 265 not self._toggle_watch_state[watch_key]) 266 267 def clear_data(self): 268 self.core_metadata_json_strings = [] 269 self.partition_graph_defs = [] 270 self.debug_tensor_values = collections.defaultdict(list) 271 self._call_types = [] 272 self._call_keys = [] 273 self._origin_stacks = [] 274 self._origin_id_to_strings = [] 275 self._graph_tracebacks = [] 276 self._graph_versions = [] 277 self._source_files = [] 278 279 def SendTracebacks(self, request, context): 280 self._call_types.append(request.call_type) 281 self._call_keys.append(request.call_key) 282 self._origin_stacks.append(request.origin_stack) 283 self._origin_id_to_strings.append(request.origin_id_to_string) 284 self._graph_tracebacks.append(request.graph_traceback) 285 self._graph_versions.append(request.graph_version) 286 return debug_service_pb2.EventReply() 287 288 def SendSourceFiles(self, request, context): 289 self._source_files.append(request) 290 return debug_service_pb2.EventReply() 291 292 def query_op_traceback(self, op_name): 293 """Query the traceback of an op. 294 295 Args: 296 op_name: Name of the op to query. 297 298 Returns: 299 The traceback of the op, as a list of 3-tuples: 300 (filename, lineno, function_name) 301 302 Raises: 303 ValueError: If the op cannot be found in the tracebacks received by the 304 server so far. 305 """ 306 for op_log_proto in self._graph_tracebacks: 307 for log_entry in op_log_proto.log_entries: 308 if log_entry.name == op_name: 309 return self._code_def_to_traceback(log_entry.code_def, 310 op_log_proto.id_to_string) 311 raise ValueError( 312 "Op '%s' does not exist in the tracebacks received by the debug " 313 "server." % op_name) 314 315 def query_origin_stack(self): 316 """Query the stack of the origin of the execution call. 317 318 Returns: 319 A `list` of all tracebacks. Each item corresponds to an execution call, 320 i.e., a `SendTracebacks` request. Each item is a `list` of 3-tuples: 321 (filename, lineno, function_name). 322 """ 323 ret = [] 324 for stack, id_to_string in zip( 325 self._origin_stacks, self._origin_id_to_strings): 326 ret.append(self._code_def_to_traceback(stack, id_to_string)) 327 return ret 328 329 def query_call_types(self): 330 return self._call_types 331 332 def query_call_keys(self): 333 return self._call_keys 334 335 def query_graph_versions(self): 336 return self._graph_versions 337 338 def query_source_file_line(self, file_path, lineno): 339 """Query the content of a given line in a source file. 340 341 Args: 342 file_path: Path to the source file. 343 lineno: Line number as an `int`. 344 345 Returns: 346 Content of the line as a string. 347 348 Raises: 349 ValueError: If no source file is found at the given file_path. 350 """ 351 if not self._source_files: 352 raise ValueError( 353 "This debug server has not received any source file contents yet.") 354 for source_files in self._source_files: 355 for source_file_proto in source_files.source_files: 356 if source_file_proto.file_path == file_path: 357 return source_file_proto.lines[lineno - 1] 358 raise ValueError( 359 "Source file at path %s has not been received by the debug server", 360 file_path) 361 362 def _code_def_to_traceback(self, code_def, id_to_string): 363 return [(id_to_string[trace.file_id], 364 trace.lineno, 365 id_to_string[trace.function_id]) for trace in code_def.traces] 366 367 368 def start_server_on_separate_thread(dump_to_filesystem=True, 369 server_start_delay_sec=0.0, 370 poll_server=False, 371 blocking=True, 372 toggle_watch_on_core_metadata=None): 373 """Create a test gRPC debug server and run on a separate thread. 374 375 Args: 376 dump_to_filesystem: (bool) whether the debug server will dump debug data 377 to the filesystem. 378 server_start_delay_sec: (float) amount of time (in sec) to delay the server 379 start up for. 380 poll_server: (bool) whether the server will be polled till success on 381 startup. 382 blocking: (bool) whether the server should be started in a blocking mode. 383 toggle_watch_on_core_metadata: A list of 384 (node_name, output_slot, debug_op) tuples to toggle the 385 watchpoint status during the on_core_metadata calls (optional). 386 387 Returns: 388 server_port: (int) Port on which the server runs. 389 debug_server_url: (str) grpc:// URL to the server. 390 server_dump_dir: (str) The debug server's dump directory. 391 server_thread: The server Thread object. 392 server: The `EventListenerTestServicer` object. 393 394 Raises: 395 ValueError: If polling the server process for ready state is not successful 396 within maximum polling count. 397 """ 398 server_port = portpicker.pick_unused_port() 399 debug_server_url = "grpc://localhost:%d" % server_port 400 401 server_dump_dir = tempfile.mkdtemp() if dump_to_filesystem else None 402 server = EventListenerTestServicer( 403 server_port=server_port, 404 dump_dir=server_dump_dir, 405 toggle_watch_on_core_metadata=toggle_watch_on_core_metadata) 406 407 def delay_then_run_server(): 408 time.sleep(server_start_delay_sec) 409 server.run_server(blocking=blocking) 410 411 server_thread = threading.Thread(target=delay_then_run_server) 412 server_thread.start() 413 414 if poll_server: 415 if not _poll_server_till_success( 416 50, 417 0.2, 418 debug_server_url, 419 server_dump_dir, 420 server, 421 gpu_memory_fraction=0.1): 422 raise ValueError( 423 "Failed to start test gRPC debug server at port %d" % server_port) 424 server.clear_data() 425 return server_port, debug_server_url, server_dump_dir, server_thread, server 426 427 428 def _poll_server_till_success(max_attempts, 429 sleep_per_poll_sec, 430 debug_server_url, 431 dump_dir, 432 server, 433 gpu_memory_fraction=1.0): 434 """Poll server until success or exceeding max polling count. 435 436 Args: 437 max_attempts: (int) How many times to poll at maximum 438 sleep_per_poll_sec: (float) How many seconds to sleep for after each 439 unsuccessful poll. 440 debug_server_url: (str) gRPC URL to the debug server. 441 dump_dir: (str) Dump directory to look for files in. If None, will directly 442 check data from the server object. 443 server: The server object. 444 gpu_memory_fraction: (float) Fraction of GPU memory to be 445 allocated for the Session used in server polling. 446 447 Returns: 448 (bool) Whether the polling succeeded within max_polls attempts. 449 """ 450 poll_count = 0 451 452 config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions( 453 per_process_gpu_memory_fraction=gpu_memory_fraction)) 454 with session.Session(config=config) as sess: 455 for poll_count in range(max_attempts): 456 server.clear_data() 457 print("Polling: poll_count = %d" % poll_count) 458 459 x_init_name = "x_init_%d" % poll_count 460 x_init = constant_op.constant([42.0], shape=[1], name=x_init_name) 461 x = variables.Variable(x_init, name=x_init_name) 462 463 run_options = config_pb2.RunOptions() 464 debug_utils.add_debug_tensor_watch( 465 run_options, x_init_name, 0, debug_urls=[debug_server_url]) 466 try: 467 sess.run(x.initializer, options=run_options) 468 except errors.FailedPreconditionError: 469 pass 470 471 if dump_dir: 472 if os.path.isdir( 473 dump_dir) and debug_data.DebugDumpDir(dump_dir).size > 0: 474 shutil.rmtree(dump_dir) 475 print("Poll succeeded.") 476 return True 477 else: 478 print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec) 479 time.sleep(sleep_per_poll_sec) 480 else: 481 if server.debug_tensor_values: 482 print("Poll succeeded.") 483 return True 484 else: 485 print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec) 486 time.sleep(sleep_per_poll_sec) 487 488 return False 489