1 #!/usr/bin/python 2 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 # 4 # Licensed under the Apache License, Version 2.0 (the "License"); 5 # you may not use this file except in compliance with the License. 6 # You may obtain a copy of the License at 7 # 8 # http://www.apache.org/licenses/LICENSE-2.0 9 # 10 # Unless required by applicable law or agreed to in writing, software 11 # distributed under the License is distributed on an "AS IS" BASIS, 12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 # See the License for the specific language governing permissions and 14 # limitations under the License. 15 # ============================================================================== 16 17 """Generates YAML configuration files for distributed TensorFlow workers. 18 19 The workers will be run in a Kubernetes (k8s) container cluster. 20 """ 21 from __future__ import absolute_import 22 from __future__ import division 23 from __future__ import print_function 24 25 import argparse 26 import sys 27 28 import k8s_tensorflow_lib 29 30 # Note: It is intentional that we do not import tensorflow in this script. The 31 # machine that launches a TensorFlow k8s cluster does not have to have the 32 # Python package of TensorFlow installed on it. 
DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server'
DEFAULT_PORT = 2222

# Highest valid TCP port number; used to validate --grpc_port.
_MAX_PORT = 65535

# Spellings accepted by the 'bool' argument type registered in main().
_TRUE_STRINGS = ('true', 't', 'y', 'yes')
_FALSE_STRINGS = ('false', 'f', 'n', 'no')


def parse_bool_flag(value):
  """Parses the string value of a boolean command-line flag.

  Matching is case-insensitive.

  Args:
    value: The raw string given on the command line.

  Returns:
    True if `value` is one of 'true', 't', 'y', 'yes'; False if it is one of
    'false', 'f', 'n', 'no'.

  Raises:
    argparse.ArgumentTypeError: If `value` is any other string. Reporting the
      error (rather than silently treating it as False) keeps a typo such as
      'ture' from quietly flipping a flag.
  """
  lowered = value.lower()
  if lowered in _TRUE_STRINGS:
    return True
  if lowered in _FALSE_STRINGS:
    return False
  raise argparse.ArgumentTypeError(
      'expected a boolean (one of %s); received %r'
      % (', '.join(_TRUE_STRINGS + _FALSE_STRINGS), value))


def _fail(message):
  """Writes `message` plus a newline to stderr and exits with status 1."""
  sys.stderr.write(message + '\n')
  sys.exit(1)


def main():
  """Parses command-line flags and prints the generated k8s YAML config.

  Writes a message to stderr and exits with status 1 when --num_workers or
  --num_parameter_servers is not positive, or when --grpc_port is outside
  1..65535. Malformed flag values (including unrecognised booleans) are
  rejected by argparse itself.
  """
  parser = argparse.ArgumentParser()
  # Custom 'bool' type so flags can be written as e.g.
  # --request_load_balancer=true.
  parser.register('type', 'bool', parse_bool_flag)
  parser.add_argument('--num_workers',
                      type=int,
                      default=2,
                      help='How many worker pods to run')
  parser.add_argument('--num_parameter_servers',
                      type=int,
                      default=1,
                      help='How many parameter server pods to run')
  parser.add_argument('--grpc_port',
                      type=int,
                      default=DEFAULT_PORT,
                      help='GRPC server port (Default: %d)' % DEFAULT_PORT)
  parser.add_argument('--request_load_balancer',
                      type='bool',
                      default=False,
                      help='To request worker0 to be exposed on a public IP '
                      'address via an external load balancer, enabling you to '
                      'run client processes from outside the cluster')
  parser.add_argument('--docker_image',
                      type=str,
                      default=DEFAULT_DOCKER_IMAGE,
                      help='Override default docker image for the TensorFlow '
                      'GRPC server')
  parser.add_argument('--name_prefix',
                      type=str,
                      default='tf',
                      help='Prefix for job names. Jobs will be named as '
                      '<name_prefix>_worker|ps<task_id>')
  parser.add_argument('--use_shared_volume',
                      type='bool',
                      default=True,
                      help='Whether to mount /shared directory from host to '
                      'the pod')
  args = parser.parse_args()

  if args.num_workers <= 0:
    _fail('--num_workers must be greater than 0; received %d'
          % args.num_workers)
  if args.num_parameter_servers <= 0:
    _fail('--num_parameter_servers must be greater than 0; received %d'
          % args.num_parameter_servers)
  # The port is written into the generated config; reject values that can
  # never be a valid TCP port instead of emitting a broken config.
  if not 0 < args.grpc_port <= _MAX_PORT:
    _fail('--grpc_port must be between 1 and %d; received %d'
          % (_MAX_PORT, args.grpc_port))

  # Generate contents of yaml config
  yaml_config = k8s_tensorflow_lib.GenerateConfig(
      args.num_workers,
      args.num_parameter_servers,
      args.grpc_port,
      args.request_load_balancer,
      args.docker_image,
      args.name_prefix,
      env_vars=None,
      use_shared_volume=args.use_shared_volume)
  print(yaml_config)  # pylint: disable=superfluous-parens


if __name__ == '__main__':
  main()