#!/usr/bin/python
# Copyright 2016 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Unit tests for the create_host and create_testbed factory functions."""

import mock
import unittest

import common
from autotest_lib.client.common_lib import error
from autotest_lib.server.hosts import base_label_unittest, factory
from autotest_lib.server.hosts import host_info


class MockHost(object):
    """Mock host object with no side effects.

    Records the keyword arguments it was constructed with (plus the
    hostname) in self._init_args so tests can inspect what the factory
    passed through.
    """
    def __init__(self, hostname, **args):
        self._init_args = args
        self._init_args['hostname'] = hostname


    def job_start(self):
        """Only method called by factory."""
        pass


class MockConnectivity(object):
    """Mock connectivity object with no side effects."""
    def __init__(self, hostname, **args):
        pass


    def close(self):
        """Only method called by factory."""
        pass


def _gen_mock_host(name, check_host=False):
    """Create an identifiable mock host class.

    @param name: value stored in the class attribute _host_cls_name, used by
                 tests to identify which host class was selected.
    @param check_host: value returned by the generated class's check_host
                       static method.

    @return: a new subclass of MockHost.
    """
    return type('mock_host_%s' % name, (MockHost,), {
        '_host_cls_name': name,
        'check_host': staticmethod(lambda host, timeout=None: check_host)
    })


def _gen_mock_conn(name):
    """Create an identifiable mock connectivity class.

    @param name: value stored in the class attribute _conn_cls_name, used by
                 tests to identify which connectivity class was selected.

    @return: a new subclass of MockConnectivity.
    """
    return type('mock_conn_%s' % name, (MockConnectivity,),
                {'_conn_cls_name': name})


def _gen_machine_dict(hostname='localhost', labels=None, attributes=None):
    """Generate a machine dictionary with the specified parameters.

    @param hostname: hostname of machine
    @param labels: list of host labels; defaults to a fresh empty list
    @param attributes: dict of host attributes; defaults to a fresh empty
                       dict

    @return: machine dict with mocked AFE Host object and an in-memory host
             info store.
    """
    # Use None sentinels rather than mutable defaults so that each call gets
    # its own containers and nothing can leak from one test case to another.
    if labels is None:
        labels = []
    if attributes is None:
        attributes = {}
    afe_host = base_label_unittest.MockAFEHost(labels, attributes)
    store = host_info.InMemoryHostInfoStore()
    store.commit(host_info.HostInfo(labels, attributes))
    return {'hostname': hostname,
            'afe_host': afe_host,
            'host_info_store': store}


class CreateHostUnittests(unittest.TestCase):
    """Tests for create_host function."""

    def setUp(self):
        """Prevent use of real Host and connectivity objects due to potential
        side effects.
        """
        # Remember the originals so tearDown can restore them.
        self._orig_ssh_engine = factory.SSH_ENGINE
        self._orig_types = factory.host_types
        self._orig_dict = factory.OS_HOST_DICT
        self._orig_cros_host = factory.cros_host.CrosHost
        self._orig_local_host = factory.local_host.LocalHost
        self._orig_ssh_host = factory.ssh_host.SSHHost

        # Start each test with empty detection tables; individual tests add
        # the mock host classes they need.
        self.host_types = factory.host_types = []
        self.os_host_dict = factory.OS_HOST_DICT = {}
        factory.cros_host.CrosHost = _gen_mock_host('cros_host')
        factory.local_host.LocalHost = _gen_mock_conn('local')
        factory.ssh_host.SSHHost = _gen_mock_conn('ssh')


    def tearDown(self):
        """Clean up mocks."""
        factory.SSH_ENGINE = self._orig_ssh_engine
        factory.host_types = self._orig_types
        factory.OS_HOST_DICT = self._orig_dict
        factory.cros_host.CrosHost = self._orig_cros_host
        factory.local_host.LocalHost = self._orig_local_host
        factory.ssh_host.SSHHost = self._orig_ssh_host


    def test_use_specified(self):
        """Confirm that the specified host and connectivity classes are used."""
        machine = _gen_machine_dict()
        host_obj = factory.create_host(
                machine,
                _gen_mock_host('specified'),
                _gen_mock_conn('specified')
        )
        self.assertEqual(host_obj._host_cls_name, 'specified')
        self.assertEqual(host_obj._conn_cls_name, 'specified')


    def test_detect_host_by_os_label(self):
        """Confirm that the host object is selected by the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'])
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'foo')


    def test_detect_host_by_os_type_attribute(self):
        """Confirm that the host object is selected by the os_type attribute
        and that the os_type attribute is preferred over the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'],
                                    attributes={'os_type': 'bar'})
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        self.os_host_dict['bar'] = _gen_mock_host('bar')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'bar')


    def test_detect_host_by_check_host(self):
        """Confirm check_host logic chooses a host object when label/attribute
        detection fails.
        """
        machine = _gen_machine_dict()
        # Only the second class reports a positive check_host result.
        self.host_types.append(_gen_mock_host('first', check_host=False))
        self.host_types.append(_gen_mock_host('second', check_host=True))
        self.host_types.append(_gen_mock_host('third', check_host=False))
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'second')


    def test_detect_host_fallback_to_cros_host(self):
        """Confirm fallback to CrosHost when all other detection fails.
        """
        machine = _gen_machine_dict()
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'cros_host')


    def test_choose_connectivity_local(self):
        """Confirm local connectivity class used when hostname is localhost.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'local')


    def test_choose_connectivity_ssh(self):
        """Confirm ssh connectivity class used when configured and hostname
        is not localhost.
        """
        factory.SSH_ENGINE = 'raw_ssh'
        machine = _gen_machine_dict(hostname='somehost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'ssh')


    def test_choose_connectivity_unsupported(self):
        """Confirm exception when configured for unsupported ssh engine.
        """
        factory.SSH_ENGINE = 'unsupported'
        machine = _gen_machine_dict(hostname='somehost')
        with self.assertRaises(error.AutoservError):
            factory.create_host(machine)


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the host object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine, foo='bar')
        self.assertEqual(host_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', host_obj._init_args)
        self.assertIn('host_info_store', host_obj._init_args)
        self.assertEqual(host_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        factory.ssh_user = 'foo'
        factory.ssh_pass = 'bar'
        factory.ssh_port = 1
        factory.ssh_verbosity_flag = 'baz'
        factory.ssh_options = 'zip'
        machine = _gen_machine_dict()
        try:
            host_obj = factory.create_host(machine)
            self.assertEqual(host_obj._init_args['user'], 'foo')
            self.assertEqual(host_obj._init_args['password'], 'bar')
            self.assertEqual(host_obj._init_args['port'], 1)
            self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'baz')
            self.assertEqual(host_obj._init_args['ssh_options'], 'zip')
        finally:
            # Remove the injected module globals so they do not affect other
            # tests, even if an assertion above failed.
            del factory.ssh_user
            del factory.ssh_pass
            del factory.ssh_port
            del factory.ssh_verbosity_flag
            del factory.ssh_options


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'somebody')
        self.assertEqual(host_obj._init_args['port'], 100)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(host_obj._init_args['ssh_options'], 'options')


class CreateTestbedUnittests(unittest.TestCase):
    """Tests for create_testbed function."""

    def setUp(self):
        """Mock out TestBed class to eliminate side effects.
        """
        self._orig_testbed = factory.testbed.TestBed
        factory.testbed.TestBed = _gen_mock_host('testbed')


    def tearDown(self):
        """Clean up mock.
        """
        factory.testbed.TestBed = self._orig_testbed


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the testbed object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        testbed_obj = factory.create_testbed(machine, foo='bar')
        self.assertEqual(testbed_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', testbed_obj._init_args)
        self.assertIn('host_info_store', testbed_obj._init_args)
        self.assertEqual(testbed_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        factory.ssh_user = 'foo'
        factory.ssh_pass = 'bar'
        factory.ssh_port = 1
        factory.ssh_verbosity_flag = 'baz'
        factory.ssh_options = 'zip'
        machine = _gen_machine_dict()
        try:
            testbed_obj = factory.create_testbed(machine)
            self.assertEqual(testbed_obj._init_args['user'], 'foo')
            self.assertEqual(testbed_obj._init_args['password'], 'bar')
            self.assertEqual(testbed_obj._init_args['port'], 1)
            self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'],
                             'baz')
            self.assertEqual(testbed_obj._init_args['ssh_options'], 'zip')
        finally:
            # Remove the injected module globals so they do not affect other
            # tests, even if an assertion above failed.
            del factory.ssh_user
            del factory.ssh_pass
            del factory.ssh_port
            del factory.ssh_verbosity_flag
            del factory.ssh_options


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        testbed_obj = factory.create_testbed(machine)
        self.assertEqual(testbed_obj._init_args['user'], 'somebody')
        self.assertEqual(testbed_obj._init_args['port'], 100)
        self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(testbed_obj._init_args['ssh_options'], 'options')


if __name__ == '__main__':
    unittest.main()