# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Cloud TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import six
from six.moves.urllib.request import Request
from six.moves.urllib.request import urlopen

from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec

_GOOGLE_API_CLIENT_INSTALLED = True
try:
  from googleapiclient import discovery  # pylint: disable=g-import-not-at-top
  from oauth2client.client import GoogleCredentials  # pylint: disable=g-import-not-at-top
except ImportError:
  _GOOGLE_API_CLIENT_INSTALLED = False


class TPUClusterResolver(ClusterResolver):
  """Cluster Resolver for Google Cloud TPUs.

  This is an implementation of cluster resolvers for the Google Cloud TPU
  service. As Cloud TPUs are in alpha, you will need to specify an API
  definition file for this to consume, in addition to a list of Cloud TPUs in
  your Google Cloud Platform project.
  """

  def _requestComputeMetadata(self, path):
    """Fetches a single value from the GCE metadata server.

    Args:
      path: Metadata path relative to `computeMetadata/v1/`. Leading slashes
        are stripped so that both 'project/project-id' and
        '/project/project-id' produce the same, well-formed URL.

    Returns:
      The response body as a native `str`.
    """
    # Strip leading slashes to avoid building '.../v1//<path>'.
    req = Request('http://metadata/computeMetadata/v1/%s' % path.lstrip('/'),
                  headers={'Metadata-Flavor': 'Google'})
    resp = urlopen(req)
    try:
      body = resp.read()
    finally:
      # Python 2 responses are not context managers, so close explicitly.
      resp.close()

    # On Python 3 the body is `bytes`; callers format and split it as text.
    if not isinstance(body, str):
      body = body.decode('utf-8')
    return body

  def __init__(self,
               tpu_names,
               zone=None,
               project=None,
               job_name='tpu_worker',
               credentials='default',
               service=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU APIs
    for the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu_names: A list of names of the target Cloud TPUs. A single name passed
        as a string is treated as a one-element list.
      zone: Zone where the TPUs are located. If omitted or empty, we will assume
        that the zone of the TPU is the same as the zone of the GCE VM, which we
        will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.
      job_name: Name of the TensorFlow job the TPUs belong to.
      credentials: GCE Credentials. If left as 'default', the application
        default credentials from the oauth2client are used (when the Google API
        client libraries are installed).
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.

    Raises:
      ImportError: If the googleapiclient is not installed and no custom
        service object was supplied.
    """

    # A bare string would otherwise be iterated character by character in
    # cluster_spec(), issuing one bogus API request per character.
    if isinstance(tpu_names, six.string_types):
      tpu_names = [tpu_names]

    if not project:
      project = self._requestComputeMetadata('/project/project-id')

    if not zone:
      # The metadata server returns 'projects/<number>/zones/<zone>'; only the
      # last path component is the zone name.
      zone_path = self._requestComputeMetadata('/instance/zone')
      zone = zone_path.split('/')[-1]

    self._project = project
    self._zone = zone
    self._tpu_names = tpu_names
    self._job_name = job_name
    self._credentials = credentials

    if credentials == 'default':
      if _GOOGLE_API_CLIENT_INSTALLED:
        self._credentials = GoogleCredentials.get_application_default()

    if service is None:
      if not _GOOGLE_API_CLIENT_INSTALLED:
        raise ImportError('googleapiclient must be installed before using the '
                          'TPU cluster resolver')

      self._service = discovery.build(
          'tpu', 'v1alpha1',
          credentials=self._credentials)
    else:
      self._service = service

  def get_master(self):
    """Get the ClusterSpec grpc master path.

    This returns the grpc path (grpc://1.2.3.4:8470) of first instance in the
    ClusterSpec returned by the cluster_spec function. This is suitable for use
    for the `master` argument in tf.Session() when you are using one TPU.

    Returns:
      string, the grpc path of the first instance in the ClusterSpec.

    Raises:
      ValueError: If none of the TPUs specified exists.
    """
    job_tasks = self.cluster_spec().job_tasks(self._job_name)
    if not job_tasks:
      raise ValueError('No TPUs with the specified names exist.')

    return 'grpc://' + job_tasks[0]

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called. Only TPUs whose reported health is 'HEALTHY' are included.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs.
    """
    worker_list = []

    for tpu_name in self._tpu_names:
      full_name = 'projects/%s/locations/%s/nodes/%s' % (
          self._project, self._zone, tpu_name)
      request = self._service.projects().locations().nodes().get(name=full_name)
      response = request.execute()

      # Nodes with a missing or non-healthy status are skipped.
      if response.get('health') == 'HEALTHY':
        instance_url = '%s:%s' % (response['ipAddress'], response['port'])
        worker_list.append(instance_url)

    return ClusterSpec({self._job_name: worker_list})