# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Operations for handling session logging and shutdown notifications."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import time

from google.protobuf import text_format

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util

_WATCHDOG = None


class CoordinatorShutdownException(Exception):
  """Raised when the coordinator needs to shut down."""
  pass


def _clone_session(session, graph=None):
  """Return a new session with the same target and config as `session`."""
  return session_lib.Session(
      target=session.sess_str,
      config=session._config,  # pylint: disable=protected-access
      graph=graph if graph else session.graph)


class WorkerHeartbeatManager(object):
  """Manages the status/heartbeat monitor for a set of workers."""

  def __init__(self, session, devices, heartbeat_ops, request_placeholder):
    """Construct a new WorkerHeartbeatManager.

    (Prefer using `WorkerHeartbeatManager.from_devices` when possible.)

    Args:
      session: `tf.Session`, session to use for heartbeat operations.
      devices: `list[string]` Set of devices to connect to.
      heartbeat_ops: `list[tf.Operation]` Heartbeat operations.
      request_placeholder: `tf.Placeholder[String]` Placeholder used to specify
        the WorkerHeartbeatRequest protocol buffer.
    """
    self._session = session
    self._devices = devices
    self._ops = heartbeat_ops
    self._request_placeholder = request_placeholder

  @staticmethod
  def from_devices(session, devices):
    """Construct a heartbeat manager for the given devices."""
    if not devices:
      logging.error('Trying to create heartbeat manager with no devices?')

    logging.info('Creating heartbeat manager for %s', devices)
    request_placeholder = array_ops.placeholder(
        name='worker_heartbeat_request', dtype=dtypes.string)

    heartbeat_ops = []
    for device in devices:
      with ops.device(device):
        heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))

    return WorkerHeartbeatManager(session, devices, heartbeat_ops,
                                  request_placeholder)

  def num_workers(self):
    return len(self._devices)

  def configure(self, message):
    """Configure heartbeat manager for all devices.

    Args:
      message: `event_pb2.WorkerHeartbeatRequest`

    Returns:
      `None`
    """
    logging.info('Configuring worker heartbeat: %s',
                 text_format.MessageToString(message))
    self._session.run(self._ops,
                      {self._request_placeholder: message.SerializeToString()})

  def ping(self, request=None, timeout_in_ms=5000):
    """Ping all workers, returning the parsed status results."""
    if request is None:
      request = event_pb2.WorkerHeartbeatRequest()

    options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
    results = self._session.run(
        self._ops,
        feed_dict={self._request_placeholder: request.SerializeToString()},
        options=options)
    parsed_results = [
        event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
        for res_pb in results
    ]
    logging.debug('Ping results: %s', parsed_results)
    return parsed_results

  def lame_workers(self):
    """Ping all workers, returning a manager of any lame workers (or None)."""
    ping_results = self.ping()
    lame_workers = []

    for ping_response, device, op in zip(ping_results, self._devices,
                                         self._ops):
      if ping_response.health_status != event_pb2.OK:
        lame_workers.append((device, op))

    if not lame_workers:
      return None

    bad_devices, bad_ops = zip(*lame_workers)
    return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
                                  self._request_placeholder)

  def __repr__(self):
    return 'HeartbeatManager(%s)' % ','.join(self._devices)

  def shutdown(self, timeout_ms=10000):
    """Shut down all workers after `timeout_ms` milliseconds."""
    logging.info('Shutting down %s.', self)
    req = event_pb2.WorkerHeartbeatRequest(
        watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms),
        shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)
    self.configure(req)

    # Wait for workers to shut down.  This isn't strictly required
    # but it avoids triggering multiple checkpoints with the same lame worker.
    logging.info('Waiting %dms for worker shutdown.', timeout_ms)
    time.sleep(timeout_ms / 1000)

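# A minimal usage sketch for WorkerHeartbeatManager (illustrative only, not
# executed here).  The gRPC target below is hypothetical; `sess` is assumed
# to be a live session connected to the coordinator of an existing cluster.
#
#   sess = session_lib.Session('grpc://coordinator:8470')
#   workers = WorkerHeartbeatManager.from_devices(
#       sess, all_worker_devices(sess))
#   workers.configure(
#       event_pb2.WorkerHeartbeatRequest(
#           watchdog_config=event_pb2.WatchdogConfig(timeout_ms=60 * 1000),
#           shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
#   responses = workers.ping()      # parsed WorkerHeartbeatResponse protos
#   lame = workers.lame_workers()   # None if every worker reported OK
#   if lame is not None:
#     lame.shutdown(timeout_ms=10000)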

def all_worker_devices(session):
  """Return a list of devices for each worker in the system."""
  devices = session.list_devices()

  devices_that_support_heartbeats = []

  for device in devices:
    name = device.name
    # Pick devices that have a TPU but target the attached CPU
    if ':TPU:0' in name and 'coordinator' not in name:
      devices_that_support_heartbeats.append(name.replace('TPU', 'CPU'))

  return devices_that_support_heartbeats

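# For illustration only (the job/task names below are hypothetical):
# all_worker_devices() maps each worker's TPU device to the CPU device on the
# same task, and the heartbeat ops are then placed on those CPU devices, e.g.
#   '/job:worker/replica:0/task:0/device:TPU:0'
#       -> '/job:worker/replica:0/task:0/device:CPU:0'
#   '/job:worker/replica:0/task:1/device:TPU:0'
#       -> '/job:worker/replica:0/task:1/device:CPU:0'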

class WatchdogManager(threading.Thread):
  """Configures worker watchdog timer and handles periodic pings.

  Usage:
    # Ping workers every minute, shutting down workers if they haven't received
    # a ping after 1 hour.
    watchdog_manager = WatchdogManager(
      session, ping_interval=60, shutdown_timeout=3600
    )

    # Use as a context manager, resetting the watchdog on context exit:
    with watchdog_manager:
      session.run(...)

    # Or set up globally; the watchdog will remain active until program exit.
    watchdog_manager.configure_and_run()
  """

  def __init__(self,
               session,
               devices=None,
               ping_interval=60,
               shutdown_timeout=3600):
    """Initialize a watchdog manager.

    Args:
      session: Session connected to worker devices.  A cloned session and graph
        will be created for managing worker pings.
      devices: Set of devices to monitor.  If None, all workers will be
        monitored.
      ping_interval: Time, in seconds, between watchdog pings.
      shutdown_timeout: Time, in seconds, before watchdog timeout.
    """
    threading.Thread.__init__(self)
    self.ping_interval = ping_interval
    self.shutdown_timeout = shutdown_timeout
    self.daemon = True
    self._config = session._config  # pylint: disable=protected-access
    self._target = session.sess_str
    self._running = False
    self._devices = devices

    self._graph = None
    self._session = None
    self._worker_manager = None

  def _reset_manager(self):
    """Reset the graph, session and worker manager."""
    self._graph = ops.Graph()
    self._session = session_lib.Session(
        target=self._target,
        graph=self._graph,
        config=self._config,
    )

    if self._devices is None:
      self._devices = all_worker_devices(self._session)

    with self._graph.as_default():
      self._worker_manager = WorkerHeartbeatManager.from_devices(
          self._session, self._devices)

    self._worker_manager.configure(
        event_pb2.WorkerHeartbeatRequest(
            watchdog_config=event_pb2.WatchdogConfig(
                timeout_ms=self.shutdown_timeout * 1000,),
            shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))

  def configure_and_run(self):
    logging.info(
        'Enabling watchdog timer with %d second timeout '
        'and %d second ping interval.', self.shutdown_timeout,
        self.ping_interval)
    self._reset_manager()
    self._running = True
    self.start()

  def stop(self):
    logging.info('Stopping worker watchdog.')
    self._worker_manager.configure(
        event_pb2.WorkerHeartbeatRequest(
            watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,),
            shutdown_mode=event_pb2.NOT_CONFIGURED))
    self._running = False
    self.join()

  def __enter__(self):
    self.configure_and_run()

  def __exit__(self, exc_type, exc_val, exc_tb):
    self.stop()

  def run(self):
    # Don't fetch logs or adjust timing: just ping the watchdog.
    #
    # If we hit an exception, reset our session as it is likely broken.
    while self._running:
      try:
        self._worker_manager.ping(request=None)
        time.sleep(self.ping_interval)
      except errors.OpError as e:
        # Catch any TF errors that occur so we don't stop sending heartbeats
        logging.debug('Caught error while sending heartbeat: %s', e)
        self._reset_manager()


def start_worker_watchdog(session,
                          devices=None,
                          ping_interval=60,
                          shutdown_timeout=3600):
  """Start global worker watchdog to shut down workers on coordinator exit."""
  global _WATCHDOG
  if _WATCHDOG is None:
    # Ensure we can send a few pings before we time out!
    ping_interval = min(shutdown_timeout / 10., ping_interval)
    _WATCHDOG = WatchdogManager(session, devices, ping_interval,
                                shutdown_timeout)
    _WATCHDOG.configure_and_run()

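# A minimal sketch of enabling the global watchdog from the coordinator
# (illustrative only; `sess` is assumed to be a session connected to the
# cluster).  Workers are pinged every 60 seconds, capped at one tenth of the
# shutdown timeout, and shut themselves down if no ping arrives for an hour.
#
#   start_worker_watchdog(sess, ping_interval=60, shutdown_timeout=3600)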

class GracefulShutdownHook(session_run_hook.SessionRunHook):
  """Session hook that watches for shutdown events.

  If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed and
  the registered `on_shutdown_hooks` are run (e.g. `RestartComputation` raises
  a `CoordinatorShutdownException` to terminate the main session).  If `saver`
  is None, the `SAVERS` collection will be read to find a saver.

  `on_shutdown_hooks` is an optional list of functions that should be called
  after checkpointing.  Each function is called with (`run_context`,
  `all_workers`, `lame_workers`).

  Heartbeats are monitored on the CPU device attached to every TPU worker in
  the system.
  """

  def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None):
    self._saver = saver
    self._checkpoint_prefix = checkpoint_prefix
    self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else []

    # Worker heartbeats are managed independently of the main training graph.
    self._graph = ops.Graph()
    self._workers = None
    self._session = None
    self._heartbeat_supported = False

  def after_create_session(self, training_session, coord):  # pylint: disable=unused-argument
    # N.B. We have to pull the global step here to avoid it being unavailable
    # at checkpoint time; the graph has been frozen at that point.
    if training_util.get_global_step() is None and self.saver() is not None:
      raise ValueError(
          'Saver defined but no global step.  Run `get_or_create_global_step()`'
          ' in your model definition to allow checkpointing.')

    with self._graph.as_default():
      logging.info('Installing graceful shutdown hook.')
      self._session = _clone_session(training_session, self._graph)
      self._workers = WorkerHeartbeatManager.from_devices(
          self._session, all_worker_devices(self._session))
      self._heartbeat_supported = self._workers.num_workers() > 0
      if self._heartbeat_supported:
        try:
          self._workers.configure(
              event_pb2.WorkerHeartbeatRequest(
                  shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
        except errors.InvalidArgumentError:
          logging.warn(
              'TPU device does not support heartbeats. Failure '
              'handling will be disabled.')
          self._heartbeat_supported = False
      else:
        logging.warn(
            'No workers support heartbeats. Failure handling will be disabled.')

  def saver(self):
    if self._saver:
      return self._saver

    savers = ops.get_collection(ops.GraphKeys.SAVERS)
    if not savers:
      return None

    if not isinstance(savers, list):
      return savers

    if len(savers) > 1:
      logging.error(
          'Multiple savers in the SAVERS collection.  On-demand checkpointing '
          'will be disabled. Pass an explicit `saver` to the constructor to '
          'override this behavior.')
      return None

    return savers[0]

  def after_run(self, run_context, run_values):
    del run_values

    if not self._heartbeat_supported:
      return

    lame_workers = self._workers.lame_workers()
    if lame_workers:
      logging.info('ShutdownHook: lame workers found: %s', lame_workers)

      if self.saver():
        logging.info('ShutdownHook: saving checkpoint to %s',
                     self._checkpoint_prefix)
        self.saver().save(
            run_context.session,
            self._checkpoint_prefix,
            global_step=training_util.get_global_step(),
            write_state=True,
        )
      else:
        logging.info('ShutdownHook: no Saver defined.')

      for fn in self._on_shutdown_hooks:
        fn(run_context, self._workers, lame_workers)

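# A minimal sketch of wiring the hook into a training loop (illustrative
# only; the master address, checkpoint paths, and `train_op` are
# hypothetical).  On-demand checkpointing needs a Saver, either passed
# explicitly or found in the SAVERS collection, plus a global step created
# with `tf.train.get_or_create_global_step()`; without one, lame workers are
# still detected but no checkpoint is written.
#
#   hook = GracefulShutdownHook(checkpoint_prefix='/tmp/model/ckpt')
#   with tf.train.MonitoredTrainingSession(
#       master='grpc://coordinator:8470',
#       checkpoint_dir='/tmp/model',
#       hooks=[hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)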

class RestartComputation(object):
  """Restart the entire computation.

  This hook shuts down all workers and returns control to the top-level by
  throwing a CoordinatorShutdownException.
  """

  def __init__(self, timeout_ms=10000):
    self.timeout_ms = timeout_ms

  def __call__(self, run_context, all_workers, lame_workers):
    del run_context, lame_workers
    all_workers.shutdown(timeout_ms=self.timeout_ms)

    logging.info('Terminating coordinator.')
    raise CoordinatorShutdownException()


class ShutdownLameWorkers(object):
  """Shut down lame workers.

  Processing will continue normally (typically by waiting for the down
  workers to be restarted).
  """

  def __init__(self, timeout_ms=10000):
    self.timeout_in_ms = timeout_ms

  def __call__(self, run_context, all_workers, lame_workers):
    lame_workers.shutdown(timeout_ms=self.timeout_in_ms)

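# A minimal sketch (with hypothetical names) of restarting the whole
# computation when a worker goes lame: RestartComputation shuts down every
# worker and raises CoordinatorShutdownException, which the outer loop
# catches before rebuilding the session and resuming from the last
# checkpoint.
#
#   hook = GracefulShutdownHook(
#       checkpoint_prefix='/tmp/model/ckpt',
#       on_shutdown_hooks=[RestartComputation(timeout_ms=10000)])
#   while True:
#     try:
#       run_training(hooks=[hook])   # hypothetical training-loop helper
#       break
#     except CoordinatorShutdownException:
#       logging.info('Coordinator shut down; restarting from checkpoint.')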