Home | History | Annotate | Download | only in scripts
      1 #!/usr/bin/python
      2 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 # ==============================================================================
     16 
     17 """Generates YAML configuration files for distributed TensorFlow workers.
     18 
     19 The workers will be run in a Kubernetes (k8s) container cluster.
     20 """
     21 from __future__ import absolute_import
     22 from __future__ import division
     23 from __future__ import print_function
     24 
     25 import argparse
     26 import sys
     27 
     28 import k8s_tensorflow_lib
     29 
     30 # Note: It is intentional that we do not import tensorflow in this script. The
     31 # machine that launches a TensorFlow k8s cluster does not have to have the
     32 # Python package of TensorFlow installed on it.
     33 
     34 
     35 DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server'
     36 DEFAULT_PORT = 2222
     37 
     38 
     39 def main():
     40   """Do arg parsing."""
     41   parser = argparse.ArgumentParser()
     42   parser.register(
     43       'type', 'bool', lambda v: v.lower() in ('true', 't', 'y', 'yes'))
     44   parser.add_argument('--num_workers',
     45                       type=int,
     46                       default=2,
     47                       help='How many worker pods to run')
     48   parser.add_argument('--num_parameter_servers',
     49                       type=int,
     50                       default=1,
     51                       help='How many paramater server pods to run')
     52   parser.add_argument('--grpc_port',
     53                       type=int,
     54                       default=DEFAULT_PORT,
     55                       help='GRPC server port (Default: %d)' % DEFAULT_PORT)
     56   parser.add_argument('--request_load_balancer',
     57                       type='bool',
     58                       default=False,
     59                       help='To request worker0 to be exposed on a public IP '
     60                       'address via an external load balancer, enabling you to '
     61                       'run client processes from outside the cluster')
     62   parser.add_argument('--docker_image',
     63                       type=str,
     64                       default=DEFAULT_DOCKER_IMAGE,
     65                       help='Override default docker image for the TensorFlow '
     66                       'GRPC server')
     67   parser.add_argument('--name_prefix',
     68                       type=str,
     69                       default='tf',
     70                       help='Prefix for job names. Jobs will be named as '
     71                       '<name_prefix>_worker|ps<task_id>')
     72   parser.add_argument('--use_shared_volume',
     73                       type='bool',
     74                       default=True,
     75                       help='Whether to mount /shared directory from host to '
     76                       'the pod')
     77   args = parser.parse_args()
     78 
     79   if args.num_workers <= 0:
     80     sys.stderr.write('--num_workers must be greater than 0; received %d\n'
     81                      % args.num_workers)
     82     sys.exit(1)
     83   if args.num_parameter_servers <= 0:
     84     sys.stderr.write(
     85         '--num_parameter_servers must be greater than 0; received %d\n'
     86         % args.num_parameter_servers)
     87     sys.exit(1)
     88 
     89   # Generate contents of yaml config
     90   yaml_config = k8s_tensorflow_lib.GenerateConfig(
     91       args.num_workers,
     92       args.num_parameter_servers,
     93       args.grpc_port,
     94       args.request_load_balancer,
     95       args.docker_image,
     96       args.name_prefix,
     97       env_vars=None,
     98       use_shared_volume=args.use_shared_volume)
     99   print(yaml_config)  # pylint: disable=superfluous-parens
    100 
    101 
    102 if __name__ == '__main__':
    103   main()
    104