Home | History | Annotate | Download | only in scripts
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.tools.dist_test.scripts.k8s_tensorflow_lib."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.python.platform import googletest
     22 from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib
     23 
     24 
     25 class K8sTensorflowTest(googletest.TestCase):
     26 
     27   def testGenerateConfig_LoadBalancer(self):
     28     # Use loadbalancer
     29     config = k8s_tensorflow_lib.GenerateConfig(
     30         num_workers=1,
     31         num_param_servers=1,
     32         port=5000,
     33         request_load_balancer=True,
     34         docker_image='test_image',
     35         name_prefix='abc',
     36         use_shared_volume=False)
     37     self.assertTrue('LoadBalancer' in config)
     38 
     39     # Don't use loadbalancer
     40     config = k8s_tensorflow_lib.GenerateConfig(
     41         num_workers=1,
     42         num_param_servers=1,
     43         port=5000,
     44         request_load_balancer=False,
     45         docker_image='test_image',
     46         name_prefix='abc',
     47         use_shared_volume=False)
     48     self.assertFalse('LoadBalancer' in config)
     49 
     50   def testGenerateConfig_SharedVolume(self):
     51     # Use shared directory
     52     config = k8s_tensorflow_lib.GenerateConfig(
     53         num_workers=1,
     54         num_param_servers=1,
     55         port=5000,
     56         request_load_balancer=False,
     57         docker_image='test_image',
     58         name_prefix='abc',
     59         use_shared_volume=True)
     60     self.assertTrue('/shared' in config)
     61 
     62     # Don't use shared directory
     63     config = k8s_tensorflow_lib.GenerateConfig(
     64         num_workers=1,
     65         num_param_servers=1,
     66         port=5000,
     67         request_load_balancer=False,
     68         docker_image='test_image',
     69         name_prefix='abc',
     70         use_shared_volume=False)
     71     self.assertFalse('/shared' in config)
     72 
     73   def testEnvVar(self):
     74     # Use loadbalancer
     75     config = k8s_tensorflow_lib.GenerateConfig(
     76         num_workers=1,
     77         num_param_servers=1,
     78         port=5000,
     79         request_load_balancer=True,
     80         docker_image='test_image',
     81         name_prefix='abc',
     82         use_shared_volume=False,
     83         env_vars={'test1': 'test1_value', 'test2': 'test2_value'})
     84     self.assertTrue('{name: "test1", value: "test1_value"}' in config)
     85     self.assertTrue('{name: "test2", value: "test2_value"}' in config)
     86 
     87   def testClusterSpec(self):
     88     # Use cluster_spec
     89     config = k8s_tensorflow_lib.GenerateConfig(
     90         num_workers=1,
     91         num_param_servers=1,
     92         port=5000,
     93         request_load_balancer=True,
     94         docker_image='test_image',
     95         name_prefix='abc',
     96         use_shared_volume=False,
     97         use_cluster_spec=True)
     98     self.assertFalse('worker_hosts' in config)
     99     self.assertFalse('ps_hosts' in config)
    100     self.assertTrue(
    101         '"--cluster_spec=worker|abc-worker0:5000,ps|abc-ps0:5000"' in config)
    102 
    103     # Don't use cluster_spec
    104     config = k8s_tensorflow_lib.GenerateConfig(
    105         num_workers=1,
    106         num_param_servers=1,
    107         port=5000,
    108         request_load_balancer=True,
    109         docker_image='test_image',
    110         name_prefix='abc',
    111         use_shared_volume=False,
    112         use_cluster_spec=False)
    113     self.assertFalse('cluster_spec' in config)
    114     self.assertTrue('"--worker_hosts=abc-worker0:5000"' in config)
    115     self.assertTrue('"--ps_hosts=abc-ps0:5000"' in config)
    116 
    117   def testWorkerHosts(self):
    118     self.assertEquals(
    119         'test_prefix-worker0:1234',
    120         k8s_tensorflow_lib.WorkerHosts(1, 1234, 'test_prefix'))
    121     self.assertEquals(
    122         'test_prefix-worker0:1234,test_prefix-worker1:1234',
    123         k8s_tensorflow_lib.WorkerHosts(2, 1234, 'test_prefix'))
    124 
    125   def testPsHosts(self):
    126     self.assertEquals(
    127         'test_prefix-ps0:1234,test_prefix-ps1:1234',
    128         k8s_tensorflow_lib.PsHosts(2, 1234, 'test_prefix'))
    129 
    130 
    131 if __name__ == '__main__':
    132   googletest.main()
    133