Home | History | Annotate | Download | only in hosts
      1 #!/usr/bin/python
      2 # Copyright 2016 The Chromium OS Authors. All rights reserved.
      3 # Use of this source code is governed by a BSD-style license that can be
      4 # found in the LICENSE file.
      5 
      6 import mock
      7 import unittest
      8 
      9 import common
     10 from autotest_lib.client.common_lib import error
     11 from autotest_lib.server.hosts import base_label_unittest, factory
     12 from autotest_lib.server.hosts import host_info
     13 
     14 
     15 class MockHost(object):
     16     """Mock host object with no side effects."""
     17     def __init__(self, hostname, **args):
     18         self._init_args = args
     19         self._init_args['hostname'] = hostname
     20 
     21 
     22     def job_start(self):
     23         """Only method called by factory."""
     24         pass
     25 
     26 
     27 class MockConnectivity(object):
     28     """Mock connectivity object with no side effects."""
     29     def __init__(self, hostname, **args):
     30         pass
     31 
     32 
     33     def close(self):
     34         """Only method called by factory."""
     35         pass
     36 
     37 
     38 def _gen_mock_host(name, check_host=False):
     39     """Create an identifiable mock host closs.
     40     """
     41     return type('mock_host_%s' % name, (MockHost,), {
     42         '_host_cls_name': name,
     43         'check_host': staticmethod(lambda host, timeout=None: check_host)
     44     })
     45 
     46 
     47 def _gen_mock_conn(name):
     48     """Create an identifiable mock connectivity class.
     49     """
     50     return type('mock_conn_%s' % name, (MockConnectivity,),
     51                 {'_conn_cls_name': name})
     52 
     53 
     54 def _gen_machine_dict(hostname='localhost', labels=[], attributes={}):
     55     """Generate a machine dictionary with the specified parameters.
     56 
     57     @param hostname: hostname of machine
     58     @param labels: list of host labels
     59     @param attributes: dict of host attributes
     60 
     61     @return: machine dict with mocked AFE Host object and fake AfeStore.
     62     """
     63     afe_host = base_label_unittest.MockAFEHost(labels, attributes)
     64     store = host_info.InMemoryHostInfoStore()
     65     store.commit(host_info.HostInfo(labels, attributes))
     66     return {'hostname': hostname,
     67             'afe_host': afe_host,
     68             'host_info_store': store}
     69 
     70 
     71 class CreateHostUnittests(unittest.TestCase):
     72     """Tests for create_host function."""
     73 
     74     def setUp(self):
     75         """Prevent use of real Host and connectivity objects due to potential
     76         side effects.
     77         """
     78         self._orig_ssh_engine = factory.SSH_ENGINE
     79         self._orig_types = factory.host_types
     80         self._orig_dict = factory.OS_HOST_DICT
     81         self._orig_cros_host = factory.cros_host.CrosHost
     82         self._orig_local_host = factory.local_host.LocalHost
     83         self._orig_ssh_host = factory.ssh_host.SSHHost
     84 
     85         self.host_types = factory.host_types = []
     86         self.os_host_dict = factory.OS_HOST_DICT = {}
     87         factory.cros_host.CrosHost = _gen_mock_host('cros_host')
     88         factory.local_host.LocalHost = _gen_mock_conn('local')
     89         factory.ssh_host.SSHHost = _gen_mock_conn('ssh')
     90 
     91 
     92     def tearDown(self):
     93         """Clean up mocks."""
     94         factory.SSH_ENGINE = self._orig_ssh_engine
     95         factory.host_types = self._orig_types
     96         factory.OS_HOST_DICT = self._orig_dict
     97         factory.cros_host.CrosHost = self._orig_cros_host
     98         factory.local_host.LocalHost = self._orig_local_host
     99         factory.ssh_host.SSHHost = self._orig_ssh_host
    100 
    101 
    102     def test_use_specified(self):
    103         """Confirm that the specified host and connectivity classes are used."""
    104         machine = _gen_machine_dict()
    105         host_obj = factory.create_host(
    106                 machine,
    107                 _gen_mock_host('specified'),
    108                 _gen_mock_conn('specified')
    109         )
    110         self.assertEqual(host_obj._host_cls_name, 'specified')
    111         self.assertEqual(host_obj._conn_cls_name, 'specified')
    112 
    113 
    114     def test_detect_host_by_os_label(self):
    115         """Confirm that the host object is selected by the os label.
    116         """
    117         machine = _gen_machine_dict(labels=['os:foo'])
    118         self.os_host_dict['foo'] = _gen_mock_host('foo')
    119         host_obj = factory.create_host(machine)
    120         self.assertEqual(host_obj._host_cls_name, 'foo')
    121 
    122 
    123     def test_detect_host_by_os_type_attribute(self):
    124         """Confirm that the host object is selected by the os_type attribute
    125         and that the os_type attribute is preferred over the os label.
    126         """
    127         machine = _gen_machine_dict(labels=['os:foo'],
    128                                          attributes={'os_type': 'bar'})
    129         self.os_host_dict['foo'] = _gen_mock_host('foo')
    130         self.os_host_dict['bar'] = _gen_mock_host('bar')
    131         host_obj = factory.create_host(machine)
    132         self.assertEqual(host_obj._host_cls_name, 'bar')
    133 
    134 
    135     def test_detect_host_by_check_host(self):
    136         """Confirm check_host logic chooses a host object when label/attribute
    137         detection fails.
    138         """
    139         machine = _gen_machine_dict()
    140         self.host_types.append(_gen_mock_host('first', check_host=False))
    141         self.host_types.append(_gen_mock_host('second', check_host=True))
    142         self.host_types.append(_gen_mock_host('third', check_host=False))
    143         host_obj = factory.create_host(machine)
    144         self.assertEqual(host_obj._host_cls_name, 'second')
    145 
    146 
    147     def test_detect_host_fallback_to_cros_host(self):
    148         """Confirm fallback to CrosHost when all other detection fails.
    149         """
    150         machine = _gen_machine_dict()
    151         host_obj = factory.create_host(machine)
    152         self.assertEqual(host_obj._host_cls_name, 'cros_host')
    153 
    154 
    155     def test_choose_connectivity_local(self):
    156         """Confirm local connectivity class used when hostname is localhost.
    157         """
    158         machine = _gen_machine_dict(hostname='localhost')
    159         host_obj = factory.create_host(machine)
    160         self.assertEqual(host_obj._conn_cls_name, 'local')
    161 
    162 
    163     def test_choose_connectivity_ssh(self):
    164         """Confirm ssh connectivity class used when configured and hostname
    165         is not localhost.
    166         """
    167         factory.SSH_ENGINE = 'raw_ssh'
    168         machine = _gen_machine_dict(hostname='somehost')
    169         host_obj = factory.create_host(machine)
    170         self.assertEqual(host_obj._conn_cls_name, 'ssh')
    171 
    172 
    173     def test_choose_connectivity_unsupported(self):
    174         """Confirm exception when configured for unsupported ssh engine.
    175         """
    176         factory.SSH_ENGINE = 'unsupported'
    177         machine = _gen_machine_dict(hostname='somehost')
    178         with self.assertRaises(error.AutoservError):
    179             factory.create_host(machine)
    180 
    181 
    182     def test_argument_passthrough(self):
    183         """Confirm that detected and specified arguments are passed through to
    184         the host object.
    185         """
    186         machine = _gen_machine_dict(hostname='localhost')
    187         host_obj = factory.create_host(machine, foo='bar')
    188         self.assertEqual(host_obj._init_args['hostname'], 'localhost')
    189         self.assertTrue('afe_host' in host_obj._init_args)
    190         self.assertTrue('host_info_store' in host_obj._init_args)
    191         self.assertEqual(host_obj._init_args['foo'], 'bar')
    192 
    193 
    194     def test_global_ssh_params(self):
    195         """Confirm passing of ssh parameters set as globals.
    196         """
    197         factory.ssh_user = 'foo'
    198         factory.ssh_pass = 'bar'
    199         factory.ssh_port = 1
    200         factory.ssh_verbosity_flag = 'baz'
    201         factory.ssh_options = 'zip'
    202         machine = _gen_machine_dict()
    203         try:
    204             host_obj = factory.create_host(machine)
    205             self.assertEqual(host_obj._init_args['user'], 'foo')
    206             self.assertEqual(host_obj._init_args['password'], 'bar')
    207             self.assertEqual(host_obj._init_args['port'], 1)
    208             self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'baz')
    209             self.assertEqual(host_obj._init_args['ssh_options'], 'zip')
    210         finally:
    211             del factory.ssh_user
    212             del factory.ssh_pass
    213             del factory.ssh_port
    214             del factory.ssh_verbosity_flag
    215             del factory.ssh_options
    216 
    217 
    218     def test_host_attribute_ssh_params(self):
    219         """Confirm passing of ssh parameters from host attributes.
    220         """
    221         machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
    222                                                 'ssh_port': 100,
    223                                                 'ssh_verbosity_flag': 'verb',
    224                                                 'ssh_options': 'options'})
    225         host_obj = factory.create_host(machine)
    226         self.assertEqual(host_obj._init_args['user'], 'somebody')
    227         self.assertEqual(host_obj._init_args['port'], 100)
    228         self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'verb')
    229         self.assertEqual(host_obj._init_args['ssh_options'], 'options')
    230 
    231 
    232 class CreateTestbedUnittests(unittest.TestCase):
    233     """Tests for create_testbed function."""
    234 
    235     def setUp(self):
    236         """Mock out TestBed class to eliminate side effects.
    237         """
    238         self._orig_testbed = factory.testbed.TestBed
    239         factory.testbed.TestBed = _gen_mock_host('testbed')
    240 
    241 
    242     def tearDown(self):
    243         """Clean up mock.
    244         """
    245         factory.testbed.TestBed = self._orig_testbed
    246 
    247 
    248     def test_argument_passthrough(self):
    249         """Confirm that detected and specified arguments are passed through to
    250         the testbed object.
    251         """
    252         machine = _gen_machine_dict(hostname='localhost')
    253         testbed_obj = factory.create_testbed(machine, foo='bar')
    254         self.assertEqual(testbed_obj._init_args['hostname'], 'localhost')
    255         self.assertTrue('afe_host' in testbed_obj._init_args)
    256         self.assertTrue('host_info_store' in testbed_obj._init_args)
    257         self.assertEqual(testbed_obj._init_args['foo'], 'bar')
    258 
    259 
    260     def test_global_ssh_params(self):
    261         """Confirm passing of ssh parameters set as globals.
    262         """
    263         factory.ssh_user = 'foo'
    264         factory.ssh_pass = 'bar'
    265         factory.ssh_port = 1
    266         factory.ssh_verbosity_flag = 'baz'
    267         factory.ssh_options = 'zip'
    268         machine = _gen_machine_dict()
    269         try:
    270             testbed_obj = factory.create_testbed(machine)
    271             self.assertEqual(testbed_obj._init_args['user'], 'foo')
    272             self.assertEqual(testbed_obj._init_args['password'], 'bar')
    273             self.assertEqual(testbed_obj._init_args['port'], 1)
    274             self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'],
    275                              'baz')
    276             self.assertEqual(testbed_obj._init_args['ssh_options'], 'zip')
    277         finally:
    278             del factory.ssh_user
    279             del factory.ssh_pass
    280             del factory.ssh_port
    281             del factory.ssh_verbosity_flag
    282             del factory.ssh_options
    283 
    284 
    285     def test_host_attribute_ssh_params(self):
    286         """Confirm passing of ssh parameters from host attributes.
    287         """
    288         machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
    289                                                 'ssh_port': 100,
    290                                                 'ssh_verbosity_flag': 'verb',
    291                                                 'ssh_options': 'options'})
    292         testbed_obj = factory.create_testbed(machine)
    293         self.assertEqual(testbed_obj._init_args['user'], 'somebody')
    294         self.assertEqual(testbed_obj._init_args['port'], 100)
    295         self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'], 'verb')
    296         self.assertEqual(testbed_obj._init_args['ssh_options'], 'options')
    297 
    298 
    299 if __name__ == '__main__':
    300     unittest.main()
    301 
    302