Home | History | Annotate | Download | only in server
      1 #!/usr/bin/python
      2 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 # ==============================================================================
     16 """Python-based TensorFlow GRPC server.
     17 
     18 Takes input arguments cluster_spec, job_name and task_id, and starts a blocking
     19 TensorFlow GRPC server.
     20 
     21 Usage:
     22     grpc_tensorflow_server.py --cluster_spec=SPEC --job_name=NAME --task_id=ID
     23 
     24 Where:
     25     SPEC is <JOB>(,<JOB>)*
     26     JOB  is <NAME>|<HOST:PORT>(;<HOST:PORT>)*
     27     NAME is a valid job name ([a-z][0-9a-z]*)
     28     HOST is a hostname or IP address
     29     PORT is a port number
     30 """
     31 
     32 from __future__ import absolute_import
     33 from __future__ import division
     34 from __future__ import print_function
     35 
     36 import argparse
     37 import sys
     38 
     39 from tensorflow.core.protobuf import config_pb2
     40 from tensorflow.core.protobuf import tensorflow_server_pb2
     41 from tensorflow.python.platform import app
     42 from tensorflow.python.training import server_lib
     43 
     44 
     45 def parse_cluster_spec(cluster_spec, cluster, verbose=False):
     46   """Parse content of cluster_spec string and inject info into cluster protobuf.
     47 
     48   Args:
     49     cluster_spec: cluster specification string, e.g.,
     50           "local|localhost:2222;localhost:2223"
     51     cluster: cluster protobuf.
     52     verbose: If verbose logging is requested.
     53 
     54   Raises:
     55     ValueError: if the cluster_spec string is invalid.
     56   """
     57 
     58   job_strings = cluster_spec.split(",")
     59 
     60   if not cluster_spec:
     61     raise ValueError("Empty cluster_spec string")
     62 
     63   for job_string in job_strings:
     64     job_def = cluster.job.add()
     65 
     66     if job_string.count("|") != 1:
     67       raise ValueError("Not exactly one instance of '|' in cluster_spec")
     68 
     69     job_name = job_string.split("|")[0]
     70 
     71     if not job_name:
     72       raise ValueError("Empty job_name in cluster_spec")
     73 
     74     job_def.name = job_name
     75 
     76     if verbose:
     77       print("Added job named \"%s\"" % job_name)
     78 
     79     job_tasks = job_string.split("|")[1].split(";")
     80     for i in range(len(job_tasks)):
     81       if not job_tasks[i]:
     82         raise ValueError("Empty task string at position %d" % i)
     83 
     84       job_def.tasks[i] = job_tasks[i]
     85 
     86       if verbose:
     87         print("  Added task \"%s\" to job \"%s\"" % (job_tasks[i], job_name))
     88 
     89 
     90 def main(unused_args):
     91   # Create Protobuf ServerDef
     92   server_def = tensorflow_server_pb2.ServerDef(protocol="grpc")
     93 
     94   # Cluster info
     95   parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster, FLAGS.verbose)
     96 
     97   # Job name
     98   if not FLAGS.job_name:
     99     raise ValueError("Empty job_name")
    100   server_def.job_name = FLAGS.job_name
    101 
    102   # Task index
    103   if FLAGS.task_id < 0:
    104     raise ValueError("Invalid task_id: %d" % FLAGS.task_id)
    105   server_def.task_index = FLAGS.task_id
    106 
    107   config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
    108       per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction))
    109 
    110   # Create GRPC Server instance
    111   server = server_lib.Server(server_def, config=config)
    112 
    113   # join() is blocking, unlike start()
    114   server.join()
    115 
    116 
    117 if __name__ == "__main__":
    118   parser = argparse.ArgumentParser()
    119   parser.register("type", "bool", lambda v: v.lower() == "true")
    120   parser.add_argument(
    121       "--cluster_spec",
    122       type=str,
    123       default="",
    124       help="""\
    125       Cluster spec: SPEC.     SPEC is <JOB>(,<JOB>)*,"     JOB  is
    126       <NAME>|<HOST:PORT>(;<HOST:PORT>)*,"     NAME is a valid job name
    127       ([a-z][0-9a-z]*),"     HOST is a hostname or IP address,"     PORT is a
    128       port number." E.g., local|localhost:2222;localhost:2223,
    129       ps|ps0:2222;ps1:2222\
    130       """
    131   )
    132   parser.add_argument(
    133       "--job_name",
    134       type=str,
    135       default="",
    136       help="Job name: e.g., local"
    137   )
    138   parser.add_argument(
    139       "--task_id",
    140       type=int,
    141       default=0,
    142       help="Task index, e.g., 0"
    143   )
    144   parser.add_argument(
    145       "--gpu_memory_fraction",
    146       type=float,
    147       default=1.0,
    148       help="Fraction of GPU memory allocated",)
    149   parser.add_argument(
    150       "--verbose",
    151       type="bool",
    152       nargs="?",
    153       const=True,
    154       default=False,
    155       help="Verbose mode"
    156   )
    157 
    158   FLAGS, unparsed = parser.parse_known_args()
    159   app.run(main=main, argv=[sys.argv[0]] + unparsed)
    160