1 #!/usr/bin/python 2 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 # 4 # Licensed under the Apache License, Version 2.0 (the "License"); 5 # you may not use this file except in compliance with the License. 6 # You may obtain a copy of the License at 7 # 8 # http://www.apache.org/licenses/LICENSE-2.0 9 # 10 # Unless required by applicable law or agreed to in writing, software 11 # distributed under the License is distributed on an "AS IS" BASIS, 12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 # See the License for the specific language governing permissions and 14 # limitations under the License. 15 # ============================================================================== 16 """Python-based TensorFlow GRPC server. 17 18 Takes input arguments cluster_spec, job_name and task_id, and start a blocking 19 TensorFlow GRPC server. 20 21 Usage: 22 grpc_tensorflow_server.py --cluster_spec=SPEC --job_name=NAME --task_id=ID 23 24 Where: 25 SPEC is <JOB>(,<JOB>)* 26 JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)* 27 NAME is a valid job name ([a-z][0-9a-z]*) 28 HOST is a hostname or IP address 29 PORT is a port number 30 """ 31 32 from __future__ import absolute_import 33 from __future__ import division 34 from __future__ import print_function 35 36 import argparse 37 import sys 38 39 from tensorflow.core.protobuf import config_pb2 40 from tensorflow.core.protobuf import tensorflow_server_pb2 41 from tensorflow.python.platform import app 42 from tensorflow.python.training import server_lib 43 44 45 def parse_cluster_spec(cluster_spec, cluster, verbose=False): 46 """Parse content of cluster_spec string and inject info into cluster protobuf. 47 48 Args: 49 cluster_spec: cluster specification string, e.g., 50 "local|localhost:2222;localhost:2223" 51 cluster: cluster protobuf. 52 verbose: If verbose logging is requested. 53 54 Raises: 55 ValueError: if the cluster_spec string is invalid. 56 """ 57 58 job_strings = cluster_spec.split(",") 59 60 if not cluster_spec: 61 raise ValueError("Empty cluster_spec string") 62 63 for job_string in job_strings: 64 job_def = cluster.job.add() 65 66 if job_string.count("|") != 1: 67 raise ValueError("Not exactly one instance of '|' in cluster_spec") 68 69 job_name = job_string.split("|")[0] 70 71 if not job_name: 72 raise ValueError("Empty job_name in cluster_spec") 73 74 job_def.name = job_name 75 76 if verbose: 77 print("Added job named \"%s\"" % job_name) 78 79 job_tasks = job_string.split("|")[1].split(";") 80 for i in range(len(job_tasks)): 81 if not job_tasks[i]: 82 raise ValueError("Empty task string at position %d" % i) 83 84 job_def.tasks[i] = job_tasks[i] 85 86 if verbose: 87 print(" Added task \"%s\" to job \"%s\"" % (job_tasks[i], job_name)) 88 89 90 def main(unused_args): 91 # Create Protobuf ServerDef 92 server_def = tensorflow_server_pb2.ServerDef(protocol="grpc") 93 94 # Cluster info 95 parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster, FLAGS.verbose) 96 97 # Job name 98 if not FLAGS.job_name: 99 raise ValueError("Empty job_name") 100 server_def.job_name = FLAGS.job_name 101 102 # Task index 103 if FLAGS.task_id < 0: 104 raise ValueError("Invalid task_id: %d" % FLAGS.task_id) 105 server_def.task_index = FLAGS.task_id 106 107 config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions( 108 per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)) 109 110 # Create GRPC Server instance 111 server = server_lib.Server(server_def, config=config) 112 113 # join() is blocking, unlike start() 114 server.join() 115 116 117 if __name__ == "__main__": 118 parser = argparse.ArgumentParser() 119 parser.register("type", "bool", lambda v: v.lower() == "true") 120 parser.add_argument( 121 "--cluster_spec", 122 type=str, 123 default="", 124 help="""\ 125 Cluster spec: SPEC. SPEC is <JOB>(,<JOB>)*," JOB is 126 <NAME>|<HOST:PORT>(;<HOST:PORT>)*," NAME is a valid job name 127 ([a-z][0-9a-z]*)," HOST is a hostname or IP address," PORT is a 128 port number." E.g., local|localhost:2222;localhost:2223, 129 ps|ps0:2222;ps1:2222\ 130 """ 131 ) 132 parser.add_argument( 133 "--job_name", 134 type=str, 135 default="", 136 help="Job name: e.g., local" 137 ) 138 parser.add_argument( 139 "--task_id", 140 type=int, 141 default=0, 142 help="Task index, e.g., 0" 143 ) 144 parser.add_argument( 145 "--gpu_memory_fraction", 146 type=float, 147 default=1.0, 148 help="Fraction of GPU memory allocated",) 149 parser.add_argument( 150 "--verbose", 151 type="bool", 152 nargs="?", 153 const=True, 154 default=False, 155 help="Verbose mode" 156 ) 157 158 FLAGS, unparsed = parser.parse_known_args() 159 app.run(main=main, argv=[sys.argv[0]] + unparsed) 160