1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the 'License'); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an 'AS IS' BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ====================================== 15 """Operations for handling session logging and shutdown notifications.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import threading 22 23 import time 24 from google.protobuf import text_format 25 26 from tensorflow.core.protobuf import config_pb2 27 from tensorflow.core.util import event_pb2 28 from tensorflow.python.client import session as session_lib 29 from tensorflow.python.framework import dtypes 30 from tensorflow.python.framework import errors 31 from tensorflow.python.framework import ops 32 from tensorflow.python.ops import array_ops 33 from tensorflow.python.platform import tf_logging as logging 34 from tensorflow.python.tpu.ops import tpu_ops 35 from tensorflow.python.training import session_run_hook 36 from tensorflow.python.training import training_util 37 38 _WATCHDOG = None 39 40 41 class CoordinatorShutdownException(Exception): 42 """Raised when the coordinator needs to shutdown.""" 43 pass 44 45 46 def _clone_session(session, graph=None): 47 return session_lib.Session( 48 target=session.sess_str, 49 config=session._config, # pylint: disable=protected-access 50 graph=graph if graph else session.graph) 51 52 53 class WorkerHeartbeatManager(object): 54 """Manages the status/heartbeat monitor for a set of workers.""" 55 56 def __init__(self, session, devices, heartbeat_ops, request_placeholder): 57 """Construct a new WorkerHeartbeatManager. 58 59 (Prefer using `WorkerHeartbeatManager.from_devices` when possible.) 60 61 Args: 62 session: `tf.Session`, session to use for heartbeat operations. 63 devices: `list[string]` Set of devices to connect to. 64 heartbeat_ops: `list[tf.Operation]` Heartbeat operations. 65 request_placeholder: `tf.Placeholder[String]` Placeholder used to specify 66 the WorkerHeartbeatRequest protocol buffer. 67 """ 68 self._session = session 69 self._devices = devices 70 self._ops = heartbeat_ops 71 self._request_placeholder = request_placeholder 72 73 @staticmethod 74 def from_devices(session, devices): 75 """Construct a heartbeat manager for the given devices.""" 76 if not devices: 77 logging.error('Trying to create heartbeat manager with no devices?') 78 79 logging.info('Creating heartbeat manager for %s', devices) 80 request_placeholder = array_ops.placeholder( 81 name='worker_heartbeat_request', dtype=dtypes.string) 82 83 heartbeat_ops = [] 84 for device in devices: 85 with ops.device(device): 86 heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder)) 87 88 return WorkerHeartbeatManager(session, devices, heartbeat_ops, 89 request_placeholder) 90 91 def num_workers(self): 92 return len(self._devices) 93 94 def configure(self, message): 95 """Configure heartbeat manager for all devices. 96 97 Args: 98 message: `event_pb2.WorkerHeartbeatRequest` 99 Returns: `None` 100 """ 101 logging.info('Configuring worker heartbeat: %s', 102 text_format.MessageToString(message)) 103 self._session.run(self._ops, 104 {self._request_placeholder: message.SerializeToString()}) 105 106 def ping(self, request=None, timeout_in_ms=5000): 107 """Ping all workers, returning the parsed status results.""" 108 if request is None: 109 request = event_pb2.WorkerHeartbeatRequest() 110 111 options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms) 112 results = self._session.run( 113 self._ops, 114 feed_dict={self._request_placeholder: request.SerializeToString()}, 115 options=options) 116 parsed_results = [ 117 event_pb2.WorkerHeartbeatResponse.FromString(res_pb) 118 for res_pb in results 119 ] 120 logging.debug('Ping results: %s', parsed_results) 121 return parsed_results 122 123 def lame_workers(self): 124 """Ping all workers, returning manager containing lame workers (or None).""" 125 ping_results = self.ping() 126 lame_workers = [] 127 128 for ping_response, device, op in zip(ping_results, self._devices, 129 self._ops): 130 if ping_response.health_status != event_pb2.OK: 131 lame_workers.append((device, op)) 132 133 if not lame_workers: 134 return None 135 136 bad_devices, bad_ops = zip(*lame_workers) 137 return WorkerHeartbeatManager(self._session, bad_devices, bad_ops, 138 self._request_placeholder) 139 140 def __repr__(self): 141 return 'HeartbeatManager(%s)' % ','.join(self._devices) 142 143 def shutdown(self, timeout_ms=10000): 144 """Shutdown all workers after `shutdown_timeout_secs`.""" 145 logging.info('Shutting down %s.', self) 146 req = event_pb2.WorkerHeartbeatRequest( 147 watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms), 148 shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR) 149 self.configure(req) 150 151 # Wait for workers to shutdown. This isn't strictly required 152 # but it avoids triggering multiple checkpoints with the same lame worker. 153 logging.info('Waiting %dms for worker shutdown.', timeout_ms) 154 time.sleep(timeout_ms / 1000) 155 156 157 def all_worker_devices(session): 158 """Return a list of devices for each worker in the system.""" 159 devices = session.list_devices() 160 161 devices_that_support_heartbeats = [] 162 163 for device in devices: 164 name = device.name 165 # Pick devices that have a TPU but target the attached CPU 166 if ':TPU:0' in name and 'coordinator' not in name: 167 devices_that_support_heartbeats.append(name.replace('TPU', 'CPU')) 168 169 return devices_that_support_heartbeats 170 171 172 class WatchdogManager(threading.Thread): 173 """Configures worker watchdog timer and handles periodic pings. 174 175 Usage: 176 # Ping workers every minute, shutting down workers if they haven't received 177 # a ping after 1 hour. 178 watchdog_manager = WatchdogManager( 179 ping_interval=60, shutdown_timeout=3600 180 ) 181 182 # Use as a context manager, resetting watchdog on context exit: 183 with watchdog_manager: 184 session.run(...) 185 186 # Or setup globally; watchdog will remain active until program exit. 187 watchdog_manager.configure_and_run() 188 """ 189 190 def __init__(self, 191 session, 192 devices=None, 193 ping_interval=60, 194 shutdown_timeout=3600): 195 """Initialize a watchdog manager. 196 197 Args: 198 session: Session connected to worker devices. A cloned session and graph 199 will be created for managing worker pings. 200 devices: Set of devices to monitor. If none, all workers will be 201 monitored. 202 ping_interval: Time, in seconds, between watchdog pings. 203 shutdown_timeout: Time, in seconds, before watchdog timeout. 204 """ 205 threading.Thread.__init__(self) 206 self.ping_interval = ping_interval 207 self.shutdown_timeout = shutdown_timeout 208 self.daemon = True 209 self._config = session._config # pylint: disable=protected-access 210 self._target = session.sess_str 211 self._running = False 212 self._devices = devices 213 214 self._graph = None 215 self._session = None 216 self._worker_manager = None 217 218 def _reset_manager(self): 219 """Reset the graph, session and worker manager.""" 220 self._graph = ops.Graph() 221 self._session = session_lib.Session( 222 target=self._target, 223 graph=self._graph, 224 config=self._config, 225 ) 226 227 if self._devices is None: 228 self._devices = all_worker_devices(self._session) 229 230 with self._graph.as_default(): 231 self._worker_manager = WorkerHeartbeatManager.from_devices( 232 self._session, self._devices) 233 234 self._worker_manager.configure( 235 event_pb2.WorkerHeartbeatRequest( 236 watchdog_config=event_pb2.WatchdogConfig( 237 timeout_ms=self.shutdown_timeout * 1000,), 238 shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) 239 240 def configure_and_run(self): 241 logging.info( 242 'Enabling watchdog timer with %d second timeout ' 243 'and %d second ping interval.', self.shutdown_timeout, 244 self.ping_interval) 245 self._reset_manager() 246 self._running = True 247 self.start() 248 249 def stop(self): 250 logging.info('Stopping worker watchdog.') 251 self._worker_manager.configure( 252 event_pb2.WorkerHeartbeatRequest( 253 watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,), 254 shutdown_mode=event_pb2.NOT_CONFIGURED)) 255 self._running = False 256 self.join() 257 258 def __enter__(self): 259 self.configure_and_run() 260 261 def __exit__(self, exc_type, exc_val, exc_tb): 262 self.stop() 263 264 def run(self): 265 # Don't fetch logs or adjust timing: just ping the watchdog. 266 # 267 # If we hit an exception, reset our session as it is likely broken. 268 while self._running: 269 try: 270 self._worker_manager.ping(request=None) 271 time.sleep(self.ping_interval) 272 except errors.OpError as e: 273 # Catch any TF errors that occur so we don't stop sending heartbeats 274 logging.debug('Caught error while sending heartbeat: %s', e) 275 self._reset_manager() 276 277 278 def start_worker_watchdog(session, 279 devices=None, 280 ping_interval=60, 281 shutdown_timeout=3600): 282 """Start global worker watchdog to shutdown workers on coordinator exit.""" 283 global _WATCHDOG 284 if _WATCHDOG is None: 285 # Ensure we can send a few pings before we timeout! 286 ping_interval = min(shutdown_timeout / 10., ping_interval) 287 _WATCHDOG = WatchdogManager(session, devices, ping_interval, 288 shutdown_timeout) 289 _WATCHDOG.configure_and_run() 290 291 292 class GracefulShutdownHook(session_run_hook.SessionRunHook): 293 """Session hook that watches for shutdown events. 294 295 If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed, and a 296 SystemShutdown exception is raised to terminate the main session. If `saver` 297 is None the `SAVERS` collection will be read to find a saver. 298 299 `on_shutdown_hooks` is an optional list of functions that should be called 300 after checkpointing. The function is called with (`run_context`, 301 `all_workers`, `lame_workers`). 302 303 If `heartbeat_group` is not specified, it will default to all CPU workers 304 in the system. 305 """ 306 307 def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None): 308 self._saver = saver 309 self._checkpoint_prefix = checkpoint_prefix 310 self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else [] 311 312 # Worker heartbeats are managed independently of the main training graph. 313 self._graph = ops.Graph() 314 self._workers = None 315 self._session = None 316 self._heartbeat_supported = False 317 318 def after_create_session(self, training_session, coord): # pylint: disable=unused-argument 319 # N.B. We have to pull the global step here to avoid it being unavailable 320 # at checkpoint time; the graph has been frozen at that point. 321 if training_util.get_global_step() is None and self.saver() is not None: 322 raise ValueError( 323 'Saver defined but no global step. Run `get_or_create_global_step()`' 324 ' in your model definition to allow checkpointing.') 325 326 with self._graph.as_default(): 327 logging.info('Installing graceful shutdown hook.') 328 self._session = _clone_session(training_session, self._graph) 329 self._workers = WorkerHeartbeatManager.from_devices( 330 self._session, all_worker_devices(self._session)) 331 self._heartbeat_supported = self._workers.num_workers() > 0 332 if self._heartbeat_supported: 333 try: 334 self._workers.configure( 335 event_pb2.WorkerHeartbeatRequest( 336 shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)) 337 except errors.InvalidArgumentError: 338 logging.warn( 339 'TPU device does not support heartbeats. Failure ' 340 'handling will be disabled.') 341 self._heartbeat_supported = False 342 else: 343 logging.warn( 344 'No workers support hearbeats. Failure handling will be disabled.') 345 346 def saver(self): 347 if self._saver: 348 return self._saver 349 350 savers = ops.get_collection(ops.GraphKeys.SAVERS) 351 if not savers: 352 return None 353 354 if not isinstance(savers, list): 355 return savers 356 357 if len(savers) > 1: 358 logging.error( 359 'Multiple savers in the SAVERS collection. On-demand checkpointing ' 360 'will be disabled. Pass an explicit `saver` to the constructor to ' 361 'override this behavior.') 362 return None 363 364 return savers[0] 365 366 def after_run(self, run_context, run_values): 367 del run_values 368 369 if not self._heartbeat_supported: 370 return 371 372 lame_workers = self._workers.lame_workers() 373 if lame_workers: 374 logging.info('ShutdownHook: lame workers found: %s', lame_workers) 375 376 if self.saver(): 377 logging.info('ShutdownHook: saving checkpoint to %s', 378 self._checkpoint_prefix) 379 self.saver().save( 380 run_context.session, 381 self._checkpoint_prefix, 382 global_step=training_util.get_global_step(), 383 write_state=True, 384 ) 385 else: 386 logging.info('ShutdownHook: no Saver defined.') 387 388 for fn in self._on_shutdown_hooks: 389 fn(run_context, self._workers, lame_workers) 390 391 392 class RestartComputation(object): 393 """Restart the entire computation. 394 395 This hook shuts down all workers and returns control to the top-level by 396 throwing a CoordinatorShutdownException. 397 """ 398 399 def __init__(self, timeout_ms=10000): 400 self.timeout_ms = timeout_ms 401 402 def __call__(self, run_context, all_workers, lame_workers): 403 del run_context, lame_workers 404 all_workers.shutdown(timeout_ms=self.timeout_ms) 405 406 logging.info('Terminating coordinator.') 407 raise CoordinatorShutdownException() 408 409 410 class ShutdownLameWorkers(object): 411 """Shutdown lamed workers. 412 413 Processing will continue normally (typically by waiting for the down 414 workers to be restarted). 415 """ 416 417 def __init__(self, timeout_ms=10000): 418 self.timeout_in_ms = timeout_ms 419 420 def __call__(self, run_context, all_workers, lame_workers): 421 lame_workers.shutdown(timeout_ms=self.timeout_in_ms) 422