Home | History | Annotate | Download | only in framework
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 """Class to represent a device."""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import copy
     22 from tensorflow.python.util.tf_export import tf_export
     23 
     24 
     25 @tf_export("DeviceSpec")
     26 class DeviceSpec(object):
     27   """Represents a (possibly partial) specification for a TensorFlow device.
     28 
     29   `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
     30   and computations occur. Using `DeviceSpec` allows you to parse device spec
     31   strings to verify their validity, merge them or compose them programmatically.
     32 
     33   Example:
     34 
     35   ```python
     36   # Place the operations on device "GPU:0" in the "ps" job.
     37   device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
     38   with tf.device(device_spec):
     39     # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
     40     my_var = tf.Variable(..., name="my_variable")
     41     squared_var = tf.square(my_var)
     42   ```
     43 
     44   If a `DeviceSpec` is partially specified, it will be merged with other
     45   `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
     46   components defined in inner scopes take precedence over those defined in
     47   outer scopes.
     48 
     49   ```python
     50   with tf.device(DeviceSpec(job="train", )):
     51     with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0):
     52       # Nodes created here will be assigned to /job:ps/device:GPU:0.
     53     with tf.device(DeviceSpec(device_type="GPU", device_index=1):
     54       # Nodes created here will be assigned to /job:train/device:GPU:1.
     55   ```
     56 
     57   A `DeviceSpec` consists of 5 components -- each of
     58   which is optionally specified:
     59 
     60   * Job: The job name.
     61   * Replica: The replica index.
     62   * Task: The task index.
     63   * Device type: The device type string (e.g. "CPU" or "GPU").
     64   * Device index: The device index.
     65   """
     66 
     67   def __init__(self, job=None, replica=None, task=None, device_type=None,
     68                device_index=None):
     69     """Create a new `DeviceSpec` object.
     70 
     71     Args:
     72       job: string.  Optional job name.
     73       replica: int.  Optional replica index.
     74       task: int.  Optional task index.
     75       device_type: Optional device type string (e.g. "CPU" or "GPU")
     76       device_index: int.  Optional device index.  If left
     77         unspecified, device represents 'any' device_index.
     78     """
     79     self.job = job
     80     self.replica = replica
     81     self.task = task
     82     if device_type == "cpu" or device_type == "gpu":
     83       # For backwards compatibility only, we support lowercase variants of
     84       # cpu and gpu but turn them into uppercase here.
     85       self.device_type = device_type.upper()
     86     else:
     87       self.device_type = device_type
     88     self.device_index = device_index
     89 
     90   def _clear(self):
     91     self._job = None
     92     self._replica = None
     93     self._task = None
     94     self.device_type = None
     95     self.device_index = None
     96 
     97   @property
     98   def job(self):
     99     return self._job
    100 
    101   @job.setter
    102   def job(self, job):
    103     if job is not None:
    104       self._job = str(job)
    105     else:
    106       self._job = None
    107 
    108   @property
    109   def replica(self):
    110     return self._replica
    111 
    112   @replica.setter
    113   def replica(self, replica):
    114     if replica is not None:
    115       self._replica = int(replica)
    116     else:
    117       self._replica = None
    118 
    119   @property
    120   def task(self):
    121     return self._task
    122 
    123   @task.setter
    124   def task(self, task):
    125     if task is not None:
    126       self._task = int(task)
    127     else:
    128       self._task = None
    129 
    130   def parse_from_string(self, spec):
    131     """Parse a `DeviceSpec` name into its components.
    132 
    133     Args:
    134       spec: a string of the form
    135        /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
    136       or
    137        /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
    138       as cpu and gpu are mutually exclusive.
    139       All entries are optional.
    140 
    141     Returns:
    142       The `DeviceSpec`.
    143 
    144     Raises:
    145       ValueError: if the spec was not valid.
    146     """
    147     self._clear()
    148     splits = [x.split(":") for x in spec.split("/")]
    149     for y in splits:
    150       ly = len(y)
    151       if y:
    152         # NOTE(touts): we use the property getters here.
    153         if ly == 2 and y[0] == "job":
    154           self.job = y[1]
    155         elif ly == 2 and y[0] == "replica":
    156           self.replica = y[1]
    157         elif ly == 2 and y[0] == "task":
    158           self.task = y[1]
    159         elif ((ly == 1 or ly == 2) and
    160               ((y[0].upper() == "GPU") or (y[0].upper() == "CPU"))):
    161           if self.device_type is not None:
    162             raise ValueError("Cannot specify multiple device types: %s" % spec)
    163           self.device_type = y[0].upper()
    164           if ly == 2 and y[1] != "*":
    165             self.device_index = int(y[1])
    166         elif ly == 3 and y[0] == "device":
    167           if self.device_type is not None:
    168             raise ValueError("Cannot specify multiple device types: %s" % spec)
    169           self.device_type = y[1]
    170           if y[2] != "*":
    171             self.device_index = int(y[2])
    172         elif ly and y[0] != "":  # pylint: disable=g-explicit-bool-comparison
    173           raise ValueError("Unknown attribute: '%s' in '%s'" % (y[0], spec))
    174 
    175     return self
    176 
    177   def merge_from(self, dev):
    178     """Merge the properties of "dev" into this `DeviceSpec`.
    179 
    180     Args:
    181       dev: a `DeviceSpec`.
    182     """
    183     if dev.job is not None:
    184       self.job = dev.job
    185     if dev.replica is not None:
    186       self.replica = dev.replica
    187     if dev.task is not None:
    188       self.task = dev.task
    189     if dev.device_type is not None:
    190       self.device_type = dev.device_type
    191     if dev.device_index is not None:
    192       self.device_index = dev.device_index
    193 
    194   def to_string(self):
    195     """Return a string representation of this `DeviceSpec`.
    196 
    197     Returns:
    198       a string of the form
    199       /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
    200     """
    201     dev = ""
    202     if self.job is not None:
    203       dev += "/job:" + self.job
    204     if self.replica is not None:
    205       dev += "/replica:" + str(self.replica)
    206     if self.task is not None:
    207       dev += "/task:" + str(self.task)
    208     if self.device_type is not None:
    209       device_index_string = "*"
    210       if self.device_index is not None:
    211         device_index_string = str(self.device_index)
    212       dev += "/device:%s:%s" % (self.device_type, device_index_string)
    213     return dev
    214 
    215   @staticmethod
    216   def from_string(spec):
    217     """Construct a `DeviceSpec` from a string.
    218 
    219     Args:
    220       spec: a string of the form
    221        /job:<name>/replica:<id>/task:<id>/device:CPU:<id>
    222       or
    223        /job:<name>/replica:<id>/task:<id>/device:GPU:<id>
    224       as cpu and gpu are mutually exclusive.
    225       All entries are optional.
    226 
    227     Returns:
    228       A DeviceSpec.
    229     """
    230     return DeviceSpec().parse_from_string(spec)
    231 
    232 
    233 def check_valid(spec):
    234   """Check that a device spec is valid.
    235 
    236   Args:
    237     spec: a string.
    238 
    239   Raises:
    240     An exception if the spec is invalid.
    241   """
    242   # Construct a DeviceSpec.  It will assert a failure if spec is invalid.
    243   DeviceSpec.from_string(spec)
    244 
    245 
    246 def canonical_name(device):
    247   """Returns a canonical name for the given `DeviceSpec` or device name."""
    248   if device is None:
    249     return ""
    250   if isinstance(device, DeviceSpec):
    251     return device.to_string()
    252   else:
    253     device = DeviceSpec.from_string(device)
    254     return device.to_string()
    255 
    256 
    257 def merge_device(spec):
    258   """Returns a device function that merges devices specifications.
    259 
    260   This can be used to merge partial specifications of devices. The
    261   innermost setting for a device field takes precedence. For example:
    262 
    263     with tf.device(merge_device("/device:GPU:0"))
    264       # Nodes created here have device "/device:GPU:0"
    265       with tf.device(merge_device("/job:worker")):
    266         # Nodes created here have device "/job:worker/device:GPU:0"
    267         with tf.device(merge_device("/device:CPU:0")):
    268           # Nodes created here have device "/job:worker/device:CPU:0"
    269           with tf.device(merge_device("/job:ps")):
    270             # Nodes created here have device "/job:ps/device:CPU:0"
    271 
    272   Args:
    273     spec: A `DeviceSpec` or a device spec string (partially) describing the
    274       device that should be used for all nodes created in the scope of
    275       the returned device function's with block.
    276 
    277   Returns:
    278     A device function with the above-described behavior.
    279 
    280   Raises:
    281     ValueError: if the spec was not valid.
    282   """
    283   if not isinstance(spec, DeviceSpec):
    284     spec = DeviceSpec.from_string(spec or "")
    285   def _device_function(node_def):
    286     current_device = DeviceSpec.from_string(node_def.device or "")
    287     copy_spec = copy.copy(spec)
    288     copy_spec.merge_from(current_device)  # current_device takes precedence.
    289     return copy_spec
    290   return _device_function
    291