# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating TensorFlow servers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
  """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or
      `tf.train.ClusterDef` protocol buffer, or a
      `tf.train.ClusterSpec` object, describing the server to be
      defined and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server
      is a member. Defaults to the value in `server_or_cluster_def`, if
      specified.
    task_index: (Optional.) Specifies the task index of the server in its job.
      Defaults to the value in `server_or_cluster_def`, if specified. Otherwise
      defaults to 0 if the server's job has only one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc"`. Defaults to the value in
      `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
    config: (Optional.) A `tf.ConfigProto` that specifies default configuration
      options for all sessions that run on this server.

  Returns:
    A `tf.train.ServerDef`.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
    ValueError: If an argument is not specified and cannot be inferred.
  """
  server_def = tensorflow_server_pb2.ServerDef()
  if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
    server_def.MergeFrom(server_or_cluster_def)
    if job_name is not None:
      server_def.job_name = job_name
    if task_index is not None:
      server_def.task_index = task_index
    if protocol is not None:
      server_def.protocol = protocol
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  else:
    try:
      cluster_spec = ClusterSpec(server_or_cluster_def)
    except TypeError:
      raise TypeError("Could not convert `server_or_cluster_def` to a "
                      "`tf.train.ServerDef` or `tf.train.ClusterSpec`.")
    if job_name is None:
      if len(cluster_spec.jobs) == 1:
        job_name = cluster_spec.jobs[0]
      else:
        raise ValueError("Must specify an explicit `job_name`.")
    if task_index is None:
      task_indices = cluster_spec.task_indices(job_name)
      if len(task_indices) == 1:
        task_index = task_indices[0]
      else:
        raise ValueError("Must specify an explicit `task_index`.")
    if protocol is None:
      protocol = "grpc"

    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_spec.as_cluster_def(),
        job_name=job_name, task_index=task_index, protocol=protocol)
    if config is not None:
      server_def.default_session_config.MergeFrom(config)
  return server_def


@tf_export("train.Server")
class Server(object):
  """An in-process TensorFlow server, for use in distributed training.

  A `tf.train.Server` instance encapsulates a set of devices and a
  @{tf.Session} target that
  can participate in distributed training. A server belongs to a
  cluster (specified by a @{tf.train.ClusterSpec}), and
  corresponds to a particular task in a named job. The server can
  communicate with any other server in the same cluster.
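
  For example, a minimal sketch of starting one task of a two-task
  `"worker"` job (the addresses below are illustrative placeholders):

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["localhost:2222",
                                             "localhost:2223"]})
  # This process runs the task with index 0 in the "worker" job.
  server = tf.train.Server(cluster, job_name="worker", task_index=0)
  server.join()
  ```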
  """

  def __init__(self,
               server_or_cluster_def,
               job_name=None,
               task_index=None,
               protocol=None,
               config=None,
               start=True):
    """Creates a new server with the given definition.

    The `job_name`, `task_index`, and `protocol` arguments are optional, and
    override any information provided in `server_or_cluster_def`.

    Args:
      server_or_cluster_def: A `tf.train.ServerDef` or
        `tf.train.ClusterDef` protocol buffer, or a
        `tf.train.ClusterSpec` object, describing the server to be
        created and/or the cluster of which it is a member.
      job_name: (Optional.) Specifies the name of the job of which the server
        is a member. Defaults to the value in `server_or_cluster_def`, if
        specified.
      task_index: (Optional.) Specifies the task index of the server in its
        job. Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise defaults to 0 if the server's job has only one task.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc"`. Defaults to the value in
        `server_or_cluster_def`, if specified. Otherwise defaults to `"grpc"`.
      config: (Optional.) A `tf.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server
        after creating it. Defaults to `True`.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
    """
    self._server_def = _make_server_def(server_or_cluster_def,
                                        job_name, task_index, protocol, config)
    with errors.raise_exception_on_not_ok_status() as status:
      self._server = pywrap_tensorflow.PyServer_New(
          self._server_def.SerializeToString(), status)
    if start:
      self.start()

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the TensorFlow server.
    """
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.PyServer_Start(self._server, status)

  def join(self):
    """Blocks until the server has shut down.

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the TensorFlow server.
    """
    with errors.raise_exception_on_not_ok_status() as status:
      pywrap_tensorflow.PyServer_Join(self._server, status)

  @property
  def server_def(self):
    """Returns the `tf.train.ServerDef` for this server.

    Returns:
      A `tf.train.ServerDef` protocol buffer that describes the configuration
      of this server.
    """
    return self._server_def

  @property
  def target(self):
    """Returns the target for a `tf.Session` to connect to this server.

    To create a
    @{tf.Session} that
    connects to this server, use the following snippet:

    ```python
    server = tf.train.Server(...)
    with tf.Session(server.target):
      # ...
    ```

    Returns:
      A string containing a session target for this server.
    """
    return self._server.target()

  @staticmethod
  def create_local_server(config=None, start=True):
    """Creates a new single-process cluster running on the local host.

    This method is a convenience wrapper for creating a
    `tf.train.Server` with a `tf.train.ServerDef` that specifies a
    single-process cluster containing a single task in a job called
    `"local"`.

    Args:
      config: (Optional.) A `tf.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Returns:
      A local `tf.train.Server`.
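
    For example, a minimal single-process sketch:

    ```python
    server = tf.train.Server.create_local_server()
    with tf.Session(server.target) as sess:
      print(sess.run(tf.constant(37.0)))
    ```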
    """
    # Specifying port 0 means that the OS will choose a free port for the
    # server.
    return Server({"local": ["localhost:0"]}, protocol="grpc", config=config,
                  start=start)


@tf_export("train.ClusterSpec")
class ClusterSpec(object):
  """Represents a cluster as a set of "tasks", organized into "jobs".

  A `tf.train.ClusterSpec` represents the set of processes that
  participate in a distributed TensorFlow computation. Every
  @{tf.train.Server} is constructed in a particular cluster.

  To create a cluster with two jobs and five tasks, you specify the
  mapping from job names to lists of network addresses (typically
  hostname-port pairs).

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222",
                                             "worker2.example.com:2222"],
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```

  Each job may also be specified as a sparse mapping from task indices
  to network addresses. This enables a server to be configured without
  needing to know the identity of (for example) all other worker
  tasks:

  ```python
  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```
  """

  def __init__(self, cluster):
    """Creates a `ClusterSpec`.

    Args:
      cluster: A dictionary mapping one or more job names to (i) a
        list of network addresses, or (ii) a dictionary mapping integer
        task indices to network addresses; or a `tf.train.ClusterDef`
        protocol buffer.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping strings to lists
        of strings, and not a `tf.train.ClusterDef` protobuf.
    """
    if isinstance(cluster, dict):
      self._cluster_spec = {}
      for job_name, tasks in cluster.items():
        if isinstance(tasks, (list, tuple)):
          job_tasks = {i: task for i, task in enumerate(tasks)}
        elif isinstance(tasks, dict):
          job_tasks = {i: task for i, task in tasks.items()}
        else:
          raise TypeError("The tasks for job %r must be a list or a dictionary "
                          "from integers to strings." % job_name)
        self._cluster_spec[job_name] = job_tasks
      self._make_cluster_def()
    elif isinstance(cluster, cluster_pb2.ClusterDef):
      self._cluster_def = cluster
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()}
    elif isinstance(cluster, ClusterSpec):
      self._cluster_def = cluster_pb2.ClusterDef()
      self._cluster_def.MergeFrom(cluster.as_cluster_def())
      self._cluster_spec = {}
      for job_def in self._cluster_def.job:
        self._cluster_spec[job_def.name] = {
            i: t for i, t in job_def.tasks.items()}
    else:
      raise TypeError("`cluster` must be a dictionary mapping one or more "
                      "job names to lists of network addresses, or a "
                      "`ClusterDef` protocol buffer")

  def __nonzero__(self):
    return bool(self._cluster_spec)

  # Python 3.x
  __bool__ = __nonzero__

  def __eq__(self, other):
    return self._cluster_spec == other

  def __ne__(self, other):
    return self._cluster_spec != other

  def __str__(self):
    key_values = self.as_dict()
    string_items = [
        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)]
    return "ClusterSpec({" + ", ".join(string_items) + "})"

  def as_dict(self):
    """Returns a dictionary from job names to their tasks.

    For each job, if the task index space is dense, the corresponding
    value will be a list of network addresses; otherwise it will be a
    dictionary mapping (sparse) task indices to the corresponding
    addresses.

    Returns:
      A dictionary mapping job names to lists or dictionaries
      describing the tasks in those jobs.
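
    For example (the addresses below are illustrative), a job with dense task
    indices is returned as a list, while a sparsely specified job is returned
    as a dictionary:

    ```python
    cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                    "ps": ["ps0.example.com:2222"]})
    cluster.as_dict()
    # ==> {"worker": {1: "worker1.example.com:2222"},
    #      "ps": ["ps0.example.com:2222"]}
    ```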
    """
    ret = {}
    for job in self.jobs:
      task_indices = self.task_indices(job)
      if max(task_indices) + 1 == len(task_indices):
        # Return a list because the task indices are dense. This
        # matches the behavior of `as_dict()` before support for
        # sparse jobs was added.
        ret[job] = self.job_tasks(job)
      else:
        ret[job] = {i: self.task_address(job, i) for i in task_indices}
    return ret

  def as_cluster_def(self):
    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
    return self._cluster_def

  @property
  def jobs(self):
    """Returns a list of job names in this cluster.

    Returns:
      A list of strings, corresponding to the names of jobs in this cluster.
    """
    return list(self._cluster_spec.keys())

  def num_tasks(self, job_name):
    """Returns the number of tasks defined in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      The number of tasks defined in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
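
    For example (illustrative addresses), a sparsely specified job counts only
    the tasks that were actually defined:

    ```python
    cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                    "ps": ["ps0.example.com:2222",
                                           "ps1.example.com:2222"]})
    cluster.num_tasks("worker")  # ==> 1
    cluster.num_tasks("ps")      # ==> 2
    ```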
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return len(job)

  def task_indices(self, job_name):
    """Returns a list of valid task indices in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of valid task indices in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
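
    For example (illustrative addresses), a job defined with a sparse set of
    task indices yields only the indices that were defined, in sorted order:

    ```python
    cluster = tf.train.ClusterSpec({"worker": {0: "worker0.example.com:2222",
                                               2: "worker2.example.com:2222"}})
    cluster.task_indices("worker")  # ==> [0, 2]
    ```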
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    return list(sorted(job.keys()))

  def task_address(self, job_name, task_index):
    """Returns the address of the given task in the given job.

    Args:
      job_name: The string name of a job in this cluster.
      task_index: A non-negative integer.

    Returns:
      The address of the given task in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
      or no task with index `task_index` is defined in that job.
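
    For example (illustrative address):

    ```python
    cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"}})
    cluster.task_address("worker", 1)  # ==> "worker1.example.com:2222"
    ```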
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    try:
      return job[task_index]
    except KeyError:
      raise ValueError("No task with index %r in job %r"
                       % (task_index, job_name))

  def job_tasks(self, job_name):
    """Returns a mapping from task ID to address in the given job.

    NOTE: For backwards compatibility, this method returns a list. If
    the given job was defined with a sparse set of task indices, the
    length of this list may not reflect the number of tasks defined in
    this job. Use the @{tf.train.ClusterSpec.num_tasks} method
    to find the number of tasks defined in a particular job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of task addresses, where the index in the list
      corresponds to the task index of each task. The list may contain
      `None` if the job was defined with a sparse set of task indices.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
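
    For example (illustrative address), a job defined sparsely with only task
    index 1 yields a two-element list whose entry for the undefined task
    index 0 is `None`:

    ```python
    cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"}})
    cluster.job_tasks("worker")  # ==> [None, "worker1.example.com:2222"]
    ```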
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError:
      raise ValueError("No such job in cluster: %r" % job_name)
    ret = [None for _ in range(max(job.keys()) + 1)]
    for i, task in job.items():
      ret[i] = task
    return ret

  def _make_cluster_def(self):
    """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.

    Raises:
      TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
        of strings.
    """
    self._cluster_def = cluster_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._cluster_spec.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError:
        raise TypeError("Job name %r must be bytes or unicode" % job_name)

      job_def = self._cluster_def.job.add()
      job_def.name = job_name

      for i, task_address in sorted(tasks.items()):
        try:
          task_address = compat.as_bytes(task_address)
        except TypeError:
          raise TypeError(
              "Task address %r must be bytes or unicode" % task_address)
        job_def.tasks[i] = task_address