# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.tools.dist_test.scripts.k8s_tensorflow_lib."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import googletest
from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib


def _GenerateConfig(**overrides):
  """Calls k8s_tensorflow_lib.GenerateConfig with a shared baseline.

  Every test in this file uses one worker, one parameter server, port 5000,
  docker image 'test_image' and name prefix 'abc'; only the flags under test
  vary. Keyword arguments in `overrides` replace or extend the baseline, so
  optional arguments (e.g. `env_vars`, `use_cluster_spec`) are forwarded only
  when a test passes them explicitly.

  Args:
    **overrides: keyword arguments forwarded to GenerateConfig, taking
      precedence over the baseline values below.

  Returns:
    Whatever GenerateConfig returns; the tests below treat it as a string and
    search it with `in`.
  """
  kwargs = dict(
      num_workers=1,
      num_param_servers=1,
      port=5000,
      request_load_balancer=False,
      docker_image='test_image',
      name_prefix='abc',
      use_shared_volume=False)
  kwargs.update(overrides)
  return k8s_tensorflow_lib.GenerateConfig(**kwargs)


class K8sTensorflowTest(googletest.TestCase):

  def testGenerateConfig_LoadBalancer(self):
    # Use loadbalancer
    config = _GenerateConfig(request_load_balancer=True)
    self.assertIn('LoadBalancer', config)

    # Don't use loadbalancer
    config = _GenerateConfig(request_load_balancer=False)
    self.assertNotIn('LoadBalancer', config)

  def testGenerateConfig_SharedVolume(self):
    # Use shared directory
    config = _GenerateConfig(use_shared_volume=True)
    self.assertIn('/shared', config)

    # Don't use shared directory
    config = _GenerateConfig(use_shared_volume=False)
    self.assertNotIn('/shared', config)

  def testEnvVar(self):
    # Pass extra environment variables (with the load balancer enabled).
    config = _GenerateConfig(
        request_load_balancer=True,
        env_vars={'test1': 'test1_value', 'test2': 'test2_value'})
    self.assertIn('{name: "test1", value: "test1_value"}', config)
    self.assertIn('{name: "test2", value: "test2_value"}', config)

  def testClusterSpec(self):
    # Use cluster_spec: the separate host-list flags must not appear.
    config = _GenerateConfig(request_load_balancer=True, use_cluster_spec=True)
    self.assertNotIn('worker_hosts', config)
    self.assertNotIn('ps_hosts', config)
    self.assertIn(
        '"--cluster_spec=worker|abc-worker0:5000,ps|abc-ps0:5000"', config)

    # Don't use cluster_spec: expect separate worker/ps host flags instead.
    config = _GenerateConfig(request_load_balancer=True, use_cluster_spec=False)
    self.assertNotIn('cluster_spec', config)
    self.assertIn('"--worker_hosts=abc-worker0:5000"', config)
    self.assertIn('"--ps_hosts=abc-ps0:5000"', config)

  def testWorkerHosts(self):
    # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
    self.assertEqual(
        'test_prefix-worker0:1234',
        k8s_tensorflow_lib.WorkerHosts(1, 1234, 'test_prefix'))
    self.assertEqual(
        'test_prefix-worker0:1234,test_prefix-worker1:1234',
        k8s_tensorflow_lib.WorkerHosts(2, 1234, 'test_prefix'))

  def testPsHosts(self):
    self.assertEqual(
        'test_prefix-ps0:1234,test_prefix-ps1:1234',
        k8s_tensorflow_lib.PsHosts(2, 1234, 'test_prefix'))


if __name__ == '__main__':
  googletest.main()