Home | History | Annotate | Download | only in training
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 # http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Implementation of Cluster Resolvers for Cloud TPUs."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 
     22 from six.moves.urllib.request import Request
     23 from six.moves.urllib.request import urlopen
     24 
     25 from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
     26 from tensorflow.python.training.server_lib import ClusterSpec
     27 
# The Google API client libraries are optional dependencies. Record whether
# they could be imported so that TPUClusterResolver can raise a descriptive
# ImportError at construction time instead of failing when this module loads.
_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False
     34 
     35 
     36 class TPUClusterResolver(ClusterResolver):
     37   """Cluster Resolver for Google Cloud TPUs.
     38 
     39   This is an implementation of cluster resolvers for the Google Cloud TPU
     40   service. As Cloud TPUs are in alpha, you will need to specify a API definition
     41   file for this to consume, in addition to a list of Cloud TPUs in your Google
     42   Cloud Platform project.
     43   """
     44 
     45   def _requestComputeMetadata(self, path):
     46     req = Request('http://metadata/computeMetadata/v1/%s' % path,
     47                   headers={'Metadata-Flavor': 'Google'})
     48     resp = urlopen(req)
     49     return resp.read()
     50 
     51   def __init__(self,
     52                tpu_names,
     53                zone=None,
     54                project=None,
     55                job_name='tpu_worker',
     56                credentials='default',
     57                service=None):
     58     """Creates a new TPUClusterResolver object.
     59 
     60     The ClusterResolver will then use the parameters to query the Cloud TPU APIs
     61     for the IP addresses and ports of each Cloud TPU listed.
     62 
     63     Args:
     64       tpu_names: A list of names of the target Cloud TPUs.
     65       zone: Zone where the TPUs are located. If omitted or empty, we will assume
     66         that the zone of the TPU is the same as the zone of the GCE VM, which we
     67         will try to discover from the GCE metadata service.
     68       project: Name of the GCP project containing Cloud TPUs. If omitted or
     69         empty, we will try to discover the project name of the GCE VM from the
     70         GCE metadata service.
     71       job_name: Name of the TensorFlow job the TPUs belong to.
     72       credentials: GCE Credentials. If None, then we use default credentials
     73         from the oauth2client
     74       service: The GCE API object returned by the googleapiclient.discovery
     75         function. If you specify a custom service object, then the credentials
     76         parameter will be ignored.
     77 
     78     Raises:
     79       ImportError: If the googleapiclient is not installed.
     80     """
     81 
     82     if not project:
     83       project = self._requestComputeMetadata('/project/project-id')
     84 
     85     if not zone:
     86       zone_path = self._requestComputeMetadata('/instance/zone')
     87       zone = zone_path.split('/')[-1]
     88 
     89     self._project = project
     90     self._zone = zone
     91     self._tpu_names = tpu_names
     92     self._job_name = job_name
     93     self._credentials = credentials
     94 
     95     if credentials == 'default':
     96       if _GOOGLE_API_CLIENT_INSTALLED:
     97         self._credentials = GoogleCredentials.get_application_default()
     98 
     99     if service is None:
    100       if not _GOOGLE_API_CLIENT_INSTALLED:
    101         raise ImportError('googleapiclient must be installed before using the '
    102                           'TPU cluster resolver')
    103 
    104       self._service = discovery.build(
    105           'tpu', 'v1alpha1',
    106           credentials=self._credentials)
    107     else:
    108       self._service = service
    109 
    110   def get_master(self):
    111     """Get the ClusterSpec grpc master path.
    112 
    113     This returns the grpc path (grpc://1.2.3.4:8470) of first instance in the
    114     ClusterSpec returned by the cluster_spec function. This is suitable for use
    115     for the `master` argument in tf.Session() when you are using one TPU.
    116 
    117     Returns:
    118       string, the grpc path of the first instance in the ClusterSpec.
    119 
    120     Raises:
    121       ValueError: If none of the TPUs specified exists.
    122     """
    123     job_tasks = self.cluster_spec().job_tasks(self._job_name)
    124     if not job_tasks:
    125       raise ValueError('No TPUs exists with the specified names exist.')
    126 
    127     return 'grpc://' + job_tasks[0]
    128 
    129   def cluster_spec(self):
    130     """Returns a ClusterSpec object based on the latest TPU information.
    131 
    132     We retrieve the information from the GCE APIs every time this method is
    133     called.
    134 
    135     Returns:
    136       A ClusterSpec containing host information returned from Cloud TPUs.
    137     """
    138     worker_list = []
    139 
    140     for tpu_name in self._tpu_names:
    141       full_name = 'projects/%s/locations/%s/nodes/%s' % (
    142           self._project, self._zone, tpu_name)
    143       request = self._service.projects().locations().nodes().get(name=full_name)
    144       response = request.execute()
    145 
    146       if 'health' in response and response['health'] == 'HEALTHY':
    147         instance_url = '%s:%s' % (response['ipAddress'], response['port'])
    148         worker_list.append(instance_url)
    149 
    150     return ClusterSpec({self._job_name: worker_list})
    151