Home | History | Annotate | Download | only in shard
      1 # Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 import datetime
      6 import mox
      7 import time
      8 import unittest
      9 
     10 import common
     11 
     12 from autotest_lib.frontend import setup_django_environment
     13 from autotest_lib.frontend.afe import frontend_test_utils
     14 from autotest_lib.frontend.afe import models
     15 from autotest_lib.client.common_lib import error
     16 from autotest_lib.client.common_lib import global_config
     17 from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
     18 from autotest_lib.scheduler.shard import shard_client
     19 
     20 
     21 class ShardClientTest(mox.MoxTestBase,
     22                       frontend_test_utils.FrontendTestMixin):
     23     """Unit tests for functions in shard_client.py"""
     24 
     25 
     26     GLOBAL_AFE_HOSTNAME = 'foo_autotest'
     27 
     28 
     29     def setUp(self):
     30         super(ShardClientTest, self).setUp()
     31 
     32         global_config.global_config.override_config_value(
     33                 'SHARD', 'global_afe_hostname', self.GLOBAL_AFE_HOSTNAME)
     34 
     35         self._frontend_common_setup(fill_data=False)
     36 
     37 
     38     def tearDown(self):
     39         self.mox.UnsetStubs()
     40 
     41 
     42     def setup_mocks(self):
     43         self.mox.StubOutClassWithMocks(frontend_wrappers, 'RetryingAFE')
     44         self.afe = frontend_wrappers.RetryingAFE(server=mox.IgnoreArg(),
     45                                                  delay_sec=5,
     46                                                  timeout_min=5)
     47 
     48 
     49     def setup_global_config(self):
     50         global_config.global_config.override_config_value(
     51                 'SHARD', 'is_slave_shard', 'True')
     52         global_config.global_config.override_config_value(
     53                 'SHARD', 'shard_hostname', 'host1')
     54 
     55 
     56     def expect_heartbeat(self, shard_hostname='host1',
     57                          known_job_ids=[], known_host_ids=[],
     58                          known_host_statuses=[], hqes=[], jobs=[],
     59                          side_effect=None, return_hosts=[], return_jobs=[],
     60                          return_suite_keyvals=[], return_incorrect_hosts=[]):
     61         call = self.afe.run(
     62             'shard_heartbeat', shard_hostname=shard_hostname,
     63             hqes=hqes, jobs=jobs,
     64             known_job_ids=known_job_ids, known_host_ids=known_host_ids,
     65             known_host_statuses=known_host_statuses,
     66             )
     67 
     68         if side_effect:
     69             call = call.WithSideEffects(side_effect)
     70 
     71         call.AndReturn({
     72                 'hosts': return_hosts,
     73                 'jobs': return_jobs,
     74                 'suite_keyvals': return_suite_keyvals,
     75                 'incorrect_host_ids': return_incorrect_hosts,
     76             })
     77 
     78 
     79     def tearDown(self):
     80         self._frontend_common_teardown()
     81 
     82         # Without this global_config will keep state over test cases
     83         global_config.global_config.reset_config_values()
     84 
     85 
     86     def _get_sample_serialized_host(self):
     87         return {'aclgroup_set': [],
     88                 'dirty': True,
     89                 'hostattribute_set': [],
     90                 'hostname': u'host1',
     91                 u'id': 2,
     92                 'invalid': False,
     93                 'labels': [],
     94                 'leased': True,
     95                 'lock_time': None,
     96                 'locked': False,
     97                 'protection': 0,
     98                 'shard': None,
     99                 'status': u'Ready'}
    100 
    101 
    102     def _get_sample_serialized_job(self):
    103         return {'control_file': u'foo',
    104                 'control_type': 2,
    105                 'created_on': datetime.datetime(2014, 9, 23, 15, 56, 10, 0),
    106                 'dependency_labels': [{u'id': 1,
    107                                        'invalid': False,
    108                                        'kernel_config': u'',
    109                                        'name': u'board:lumpy',
    110                                        'only_if_needed': False,
    111                                        'platform': False}],
    112                 'email_list': u'',
    113                 'hostqueueentry_set': [{'aborted': False,
    114                                         'active': False,
    115                                         'complete': False,
    116                                         'deleted': False,
    117                                         'execution_subdir': u'',
    118                                         'finished_on': None,
    119                                         u'id': 1,
    120                                         'meta_host': {u'id': 1,
    121                                                       'invalid': False,
    122                                                       'kernel_config': u'',
    123                                                       'name': u'board:lumpy',
    124                                                       'only_if_needed': False,
    125                                                       'platform': False},
    126                                         'started_on': None,
    127                                         'status': u'Queued'}],
    128                 u'id': 1,
    129                 'jobkeyval_set': [],
    130                 'max_runtime_hrs': 72,
    131                 'max_runtime_mins': 1440,
    132                 'name': u'dummy',
    133                 'owner': u'autotest_system',
    134                 'parse_failed_repair': True,
    135                 'priority': 40,
    136                 'parent_job_id': 0,
    137                 'reboot_after': 0,
    138                 'reboot_before': 1,
    139                 'run_reset': True,
    140                 'run_verify': False,
    141                 'shard': {'hostname': u'shard1', u'id': 1},
    142                 'synch_count': 0,
    143                 'test_retry': 0,
    144                 'timeout': 24,
    145                 'timeout_mins': 1440}
    146 
    147 
    148     def _get_sample_serialized_suite_keyvals(self):
    149         return {'id': 1,
    150                 'job_id': 0,
    151                 'key': 'test_key',
    152                 'value': 'test_value'}
    153 
    154 
    155     def testHeartbeat(self):
    156         """Trigger heartbeat, verify RPCs and persisting of the responses."""
    157         self.setup_mocks()
    158 
    159         global_config.global_config.override_config_value(
    160                 'SHARD', 'shard_hostname', 'host1')
    161 
    162         self.expect_heartbeat(
    163                 return_hosts=[self._get_sample_serialized_host()],
    164                 return_jobs=[self._get_sample_serialized_job()],
    165                 return_suite_keyvals=[
    166                         self._get_sample_serialized_suite_keyvals()])
    167 
    168         modified_sample_host = self._get_sample_serialized_host()
    169         modified_sample_host['hostname'] = 'host2'
    170 
    171         self.expect_heartbeat(
    172                 return_hosts=[modified_sample_host],
    173                 known_host_ids=[modified_sample_host['id']],
    174                 known_host_statuses=[modified_sample_host['status']],
    175                 known_job_ids=[1])
    176 
    177 
    178         def verify_upload_jobs_and_hqes(name, shard_hostname, jobs, hqes,
    179                                         known_host_ids, known_host_statuses,
    180                                         known_job_ids):
    181             self.assertEqual(len(jobs), 1)
    182             self.assertEqual(len(hqes), 1)
    183             job, hqe = jobs[0], hqes[0]
    184             self.assertEqual(hqe['status'], 'Completed')
    185 
    186 
    187         self.expect_heartbeat(
    188                 jobs=mox.IgnoreArg(), hqes=mox.IgnoreArg(),
    189                 known_host_ids=[modified_sample_host['id']],
    190                 known_host_statuses=[modified_sample_host['status']],
    191                 known_job_ids=[], side_effect=verify_upload_jobs_and_hqes)
    192 
    193         self.mox.ReplayAll()
    194         sut = shard_client.get_shard_client()
    195 
    196         sut.do_heartbeat()
    197 
    198         # Check if dummy object was saved to DB
    199         host = models.Host.objects.get(id=2)
    200         self.assertEqual(host.hostname, 'host1')
    201 
    202         # Check if suite keyval  was saved to DB
    203         suite_keyval = models.JobKeyval.objects.filter(job_id=0)[0]
    204         self.assertEqual(suite_keyval.key, 'test_key')
    205 
    206         sut.do_heartbeat()
    207 
    208         # Ensure it wasn't overwritten
    209         host = models.Host.objects.get(id=2)
    210         self.assertEqual(host.hostname, 'host1')
    211 
    212         job = models.Job.objects.all()[0]
    213         job.shard = None
    214         job.save()
    215         hqe = job.hostqueueentry_set.all()[0]
    216         hqe.status = 'Completed'
    217         hqe.save()
    218 
    219         sut.do_heartbeat()
    220 
    221 
    222         self.mox.VerifyAll()
    223 
    224 
    225     def testRemoveInvalidHosts(self):
    226         self.setup_mocks()
    227         self.setup_global_config()
    228 
    229         host_serialized = self._get_sample_serialized_host()
    230         host_id = host_serialized[u'id']
    231 
    232         # 1st heartbeat: return a host.
    233         # 2nd heartbeat: "delete" that host. Also send a spurious extra ID
    234         # that isn't present to ensure shard client doesn't crash. (Note: delete
    235         # operation doesn't actually delete db entry. Djanjo model ;logic
    236         # instead simply marks it as invalid.
    237         # 3rd heartbeat: host is no longer present in shard's request.
    238 
    239         self.expect_heartbeat(return_hosts=[host_serialized])
    240         self.expect_heartbeat(known_host_ids=[host_id],
    241                               known_host_statuses=[u'Ready'],
    242                               return_incorrect_hosts=[host_id, 42])
    243         self.expect_heartbeat()
    244 
    245         self.mox.ReplayAll()
    246         sut = shard_client.get_shard_client()
    247 
    248         sut.do_heartbeat()
    249         host = models.Host.smart_get(host_id)
    250         self.assertFalse(host.invalid)
    251 
    252         # Host should no longer "exist" after the invalidation.
    253         # Why don't we simply count the number of hosts in db? Because the host
    254         # actually remains int he db, but simply has it's invalid bit set to
    255         # True.
    256         sut.do_heartbeat()
    257         with self.assertRaises(models.Host.DoesNotExist):
    258             host = models.Host.smart_get(host_id)
    259 
    260 
    261         # Subsequent heartbeat no longer passes the host id as a known host.
    262         sut.do_heartbeat()
    263 
    264 
    265     def testFailAndRedownloadJobs(self):
    266         self.setup_mocks()
    267         self.setup_global_config()
    268 
    269         job1_serialized = self._get_sample_serialized_job()
    270         job2_serialized = self._get_sample_serialized_job()
    271         job2_serialized['id'] = 2
    272         job2_serialized['hostqueueentry_set'][0]['id'] = 2
    273 
    274         self.expect_heartbeat(return_jobs=[job1_serialized])
    275         self.expect_heartbeat(return_jobs=[job1_serialized, job2_serialized])
    276         self.expect_heartbeat(known_job_ids=[job1_serialized['id'],
    277                                              job2_serialized['id']])
    278         self.expect_heartbeat(known_job_ids=[job2_serialized['id']])
    279 
    280         self.mox.ReplayAll()
    281         sut = shard_client.get_shard_client()
    282 
    283         original_process_heartbeat_response = sut.process_heartbeat_response
    284         def failing_process_heartbeat_response(*args, **kwargs):
    285             raise RuntimeError
    286 
    287         sut.process_heartbeat_response = failing_process_heartbeat_response
    288         self.assertRaises(RuntimeError, sut.do_heartbeat)
    289 
    290         sut.process_heartbeat_response = original_process_heartbeat_response
    291         sut.do_heartbeat()
    292         sut.do_heartbeat()
    293 
    294         job2 = models.Job.objects.get(pk=job1_serialized['id'])
    295         job2.hostqueueentry_set.all().update(complete=True)
    296 
    297         sut.do_heartbeat()
    298 
    299         self.mox.VerifyAll()
    300 
    301 
    302     def testFailAndRedownloadHosts(self):
    303         self.setup_mocks()
    304         self.setup_global_config()
    305 
    306         host1_serialized = self._get_sample_serialized_host()
    307         host2_serialized = self._get_sample_serialized_host()
    308         host2_serialized['id'] = 3
    309         host2_serialized['hostname'] = 'host2'
    310 
    311         self.expect_heartbeat(return_hosts=[host1_serialized])
    312         self.expect_heartbeat(return_hosts=[host1_serialized, host2_serialized])
    313         self.expect_heartbeat(known_host_ids=[host1_serialized['id'],
    314                                               host2_serialized['id']],
    315                               known_host_statuses=[host1_serialized['status'],
    316                                                    host2_serialized['status']])
    317 
    318         self.mox.ReplayAll()
    319         sut = shard_client.get_shard_client()
    320 
    321         original_process_heartbeat_response = sut.process_heartbeat_response
    322         def failing_process_heartbeat_response(*args, **kwargs):
    323             raise RuntimeError
    324 
    325         sut.process_heartbeat_response = failing_process_heartbeat_response
    326         self.assertRaises(RuntimeError, sut.do_heartbeat)
    327 
    328         self.assertEqual(models.Host.objects.count(), 0)
    329 
    330         sut.process_heartbeat_response = original_process_heartbeat_response
    331         sut.do_heartbeat()
    332         sut.do_heartbeat()
    333 
    334         self.mox.VerifyAll()
    335 
    336 
    337     def testHeartbeatNoShardMode(self):
    338         """Ensure an exception is thrown when run on a non-shard machine."""
    339         self.mox.ReplayAll()
    340 
    341         self.assertRaises(error.HeartbeatOnlyAllowedInShardModeException,
    342                           shard_client.get_shard_client)
    343 
    344         self.mox.VerifyAll()
    345 
    346 
    347     def testLoop(self):
    348         """Test looping over heartbeats and aborting that loop works."""
    349         self.setup_mocks()
    350         self.setup_global_config()
    351 
    352         global_config.global_config.override_config_value(
    353                 'SHARD', 'heartbeat_pause_sec', '0.01')
    354 
    355         self.expect_heartbeat()
    356 
    357         sut = None
    358 
    359         def shutdown_sut(*args, **kwargs):
    360             sut.shutdown()
    361 
    362         self.expect_heartbeat(side_effect=shutdown_sut)
    363 
    364         self.mox.ReplayAll()
    365         sut = shard_client.get_shard_client()
    366         sut.loop(None)
    367 
    368         self.mox.VerifyAll()
    369 
    370 
    371     def testLoopWithDeadline(self):
    372         """Test looping over heartbeats with a timeout."""
    373         self.setup_mocks()
    374         self.setup_global_config()
    375         self.mox.StubOutWithMock(time, 'time')
    376 
    377         global_config.global_config.override_config_value(
    378                 'SHARD', 'heartbeat_pause_sec', '0.01')
    379         time.time().AndReturn(1516894000)
    380         time.time().AndReturn(1516894000)
    381         self.expect_heartbeat()
    382         # Set expectation that heartbeat took 1 minute.
    383         time.time().MultipleTimes().AndReturn(1516894000 + 60)
    384 
    385         self.mox.ReplayAll()
    386         sut = shard_client.get_shard_client()
    387         # 36 seconds
    388         sut.loop(lifetime_hours=0.01)
    389         self.mox.VerifyAll()
    390 
    391 
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
    394