Home | History | Annotate | Download | only in shard
      1 # Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 import datetime
      6 import mox
      7 import time
      8 import unittest
      9 
     10 import common
     11 
     12 from autotest_lib.frontend import setup_django_environment
     13 from autotest_lib.frontend.afe import frontend_test_utils
     14 from autotest_lib.frontend.afe import models
     15 from autotest_lib.frontend.afe import model_logic
     16 from autotest_lib.client.common_lib import error
     17 from autotest_lib.client.common_lib import global_config
     18 from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
     19 from autotest_lib.scheduler.shard import shard_client
     20 from django.core.exceptions import MultipleObjectsReturned
     21 
     22 
     23 class ShardClientTest(mox.MoxTestBase,
     24                       frontend_test_utils.FrontendTestMixin):
     25     """Unit tests for functions in shard_client.py"""
     26 
     27 
     28     GLOBAL_AFE_HOSTNAME = 'foo_autotest'
     29 
     30 
     31     def setUp(self):
     32         super(ShardClientTest, self).setUp()
     33 
     34         global_config.global_config.override_config_value(
     35                 'SHARD', 'global_afe_hostname', self.GLOBAL_AFE_HOSTNAME)
     36 
     37         self._frontend_common_setup(fill_data=False)
     38 
     39 
     40     def tearDown(self):
     41         self.mox.UnsetStubs()
     42 
     43 
     44     def setup_mocks(self):
     45         self.mox.StubOutClassWithMocks(frontend_wrappers, 'RetryingAFE')
     46         self.afe = frontend_wrappers.RetryingAFE(server=mox.IgnoreArg(),
     47                                                  delay_sec=mox.IgnoreArg(),
     48                                                  timeout_min=mox.IgnoreArg())
     49 
     50     def setup_global_config(self):
     51         global_config.global_config.override_config_value(
     52                 'SHARD', 'is_slave_shard', 'True')
     53         global_config.global_config.override_config_value(
     54                 'SHARD', 'shard_hostname', 'host1')
     55 
     56 
     57     def expect_heartbeat(self, shard_hostname='host1',
     58                          known_job_ids=[], known_host_ids=[],
     59                          known_host_statuses=[], hqes=[], jobs=[],
     60                          side_effect=None, return_hosts=[], return_jobs=[],
     61                          return_suite_keyvals=[], return_incorrect_hosts=[]):
     62         call = self.afe.run(
     63             'shard_heartbeat', shard_hostname=shard_hostname,
     64             hqes=hqes, jobs=jobs,
     65             known_job_ids=known_job_ids, known_host_ids=known_host_ids,
     66             known_host_statuses=known_host_statuses,
     67             )
     68 
     69         if side_effect:
     70             call = call.WithSideEffects(side_effect)
     71 
     72         call.AndReturn({
     73                 'hosts': return_hosts,
     74                 'jobs': return_jobs,
     75                 'suite_keyvals': return_suite_keyvals,
     76                 'incorrect_host_ids': return_incorrect_hosts,
     77             })
     78 
     79 
     80     def tearDown(self):
     81         self._frontend_common_teardown()
     82 
     83         # Without this global_config will keep state over test cases
     84         global_config.global_config.reset_config_values()
     85 
     86 
     87     def _get_sample_serialized_host(self):
     88         return {'aclgroup_set': [],
     89                 'dirty': True,
     90                 'hostattribute_set': [],
     91                 'hostname': u'host1',
     92                 u'id': 2,
     93                 'invalid': False,
     94                 'labels': [],
     95                 'leased': True,
     96                 'lock_time': None,
     97                 'locked': False,
     98                 'protection': 0,
     99                 'shard': None,
    100                 'status': u'Ready'}
    101 
    102 
    103     def _get_sample_serialized_job(self):
    104         return {'control_file': u'foo',
    105                 'control_type': 2,
    106                 'created_on': datetime.datetime(2014, 9, 23, 15, 56, 10, 0),
    107                 'dependency_labels': [{u'id': 1,
    108                                        'invalid': False,
    109                                        'kernel_config': u'',
    110                                        'name': u'board:lumpy',
    111                                        'only_if_needed': False,
    112                                        'platform': False}],
    113                 'email_list': u'',
    114                 'hostqueueentry_set': [{'aborted': False,
    115                                         'active': False,
    116                                         'complete': False,
    117                                         'deleted': False,
    118                                         'execution_subdir': u'',
    119                                         'finished_on': None,
    120                                         u'id': 1,
    121                                         'meta_host': {u'id': 1,
    122                                                       'invalid': False,
    123                                                       'kernel_config': u'',
    124                                                       'name': u'board:lumpy',
    125                                                       'only_if_needed': False,
    126                                                       'platform': False},
    127                                         'started_on': None,
    128                                         'status': u'Queued'}],
    129                 u'id': 1,
    130                 'jobkeyval_set': [],
    131                 'max_runtime_hrs': 72,
    132                 'max_runtime_mins': 1440,
    133                 'name': u'dummy',
    134                 'owner': u'autotest_system',
    135                 'parse_failed_repair': True,
    136                 'priority': 40,
    137                 'parent_job_id': 0,
    138                 'reboot_after': 0,
    139                 'reboot_before': 1,
    140                 'run_reset': True,
    141                 'run_verify': False,
    142                 'shard': {'hostname': u'shard1', u'id': 1},
    143                 'synch_count': 0,
    144                 'test_retry': 0,
    145                 'timeout': 24,
    146                 'timeout_mins': 1440}
    147 
    148 
    149     def _get_sample_serialized_suite_keyvals(self):
    150         return {'id': 1,
    151                 'job_id': 0,
    152                 'key': 'test_key',
    153                 'value': 'test_value'}
    154 
    155 
    156     def testHeartbeat(self):
    157         """Trigger heartbeat, verify RPCs and persisting of the responses."""
    158         self.setup_mocks()
    159 
    160         global_config.global_config.override_config_value(
    161                 'SHARD', 'shard_hostname', 'host1')
    162 
    163         self.expect_heartbeat(
    164                 return_hosts=[self._get_sample_serialized_host()],
    165                 return_jobs=[self._get_sample_serialized_job()],
    166                 return_suite_keyvals=[
    167                         self._get_sample_serialized_suite_keyvals()])
    168 
    169         modified_sample_host = self._get_sample_serialized_host()
    170         modified_sample_host['hostname'] = 'host2'
    171 
    172         self.expect_heartbeat(
    173                 return_hosts=[modified_sample_host],
    174                 known_host_ids=[modified_sample_host['id']],
    175                 known_host_statuses=[modified_sample_host['status']],
    176                 known_job_ids=[1])
    177 
    178 
    179         def verify_upload_jobs_and_hqes(name, shard_hostname, jobs, hqes,
    180                                         known_host_ids, known_host_statuses,
    181                                         known_job_ids):
    182             self.assertEqual(len(jobs), 1)
    183             self.assertEqual(len(hqes), 1)
    184             job, hqe = jobs[0], hqes[0]
    185             self.assertEqual(hqe['status'], 'Completed')
    186 
    187 
    188         self.expect_heartbeat(
    189                 jobs=mox.IgnoreArg(), hqes=mox.IgnoreArg(),
    190                 known_host_ids=[modified_sample_host['id']],
    191                 known_host_statuses=[modified_sample_host['status']],
    192                 known_job_ids=[], side_effect=verify_upload_jobs_and_hqes)
    193 
    194         self.mox.ReplayAll()
    195         sut = shard_client.get_shard_client()
    196 
    197         sut.do_heartbeat()
    198 
    199         # Check if dummy object was saved to DB
    200         host = models.Host.objects.get(id=2)
    201         self.assertEqual(host.hostname, 'host1')
    202 
    203         # Check if suite keyval  was saved to DB
    204         suite_keyval = models.JobKeyval.objects.filter(job_id=0)[0]
    205         self.assertEqual(suite_keyval.key, 'test_key')
    206 
    207         sut.do_heartbeat()
    208 
    209         # Ensure it wasn't overwritten
    210         host = models.Host.objects.get(id=2)
    211         self.assertEqual(host.hostname, 'host1')
    212 
    213         job = models.Job.objects.all()[0]
    214         job.shard = None
    215         job.save()
    216         hqe = job.hostqueueentry_set.all()[0]
    217         hqe.status = 'Completed'
    218         hqe.save()
    219 
    220         sut.do_heartbeat()
    221 
    222 
    223         self.mox.VerifyAll()
    224 
    225 
    226     def testRemoveInvalidHosts(self):
    227         self.setup_mocks()
    228         self.setup_global_config()
    229 
    230         host_serialized = self._get_sample_serialized_host()
    231         host_id = host_serialized[u'id']
    232 
    233         # 1st heartbeat: return a host.
    234         # 2nd heartbeat: "delete" that host. Also send a spurious extra ID
    235         # that isn't present to ensure shard client doesn't crash. (Note: delete
    236         # operation doesn't actually delete db entry. Djanjo model ;logic
    237         # instead simply marks it as invalid.
    238         # 3rd heartbeat: host is no longer present in shard's request.
    239 
    240         self.expect_heartbeat(return_hosts=[host_serialized])
    241         self.expect_heartbeat(known_host_ids=[host_id],
    242                               known_host_statuses=[u'Ready'],
    243                               return_incorrect_hosts=[host_id, 42])
    244         self.expect_heartbeat()
    245 
    246         self.mox.ReplayAll()
    247         sut = shard_client.get_shard_client()
    248 
    249         sut.do_heartbeat()
    250         host = models.Host.smart_get(host_id)
    251         self.assertFalse(host.invalid)
    252 
    253         # Host should no longer "exist" after the invalidation.
    254         # Why don't we simply count the number of hosts in db? Because the host
    255         # actually remains int he db, but simply has it's invalid bit set to
    256         # True.
    257         sut.do_heartbeat()
    258         with self.assertRaises(models.Host.DoesNotExist):
    259             host = models.Host.smart_get(host_id)
    260 
    261 
    262         # Subsequent heartbeat no longer passes the host id as a known host.
    263         sut.do_heartbeat()
    264 
    265 
    266     def testFailAndRedownloadJobs(self):
    267         self.setup_mocks()
    268         self.setup_global_config()
    269 
    270         job1_serialized = self._get_sample_serialized_job()
    271         job2_serialized = self._get_sample_serialized_job()
    272         job2_serialized['id'] = 2
    273         job2_serialized['hostqueueentry_set'][0]['id'] = 2
    274 
    275         self.expect_heartbeat(return_jobs=[job1_serialized])
    276         self.expect_heartbeat(return_jobs=[job1_serialized, job2_serialized])
    277         self.expect_heartbeat(known_job_ids=[job1_serialized['id'],
    278                                              job2_serialized['id']])
    279         self.expect_heartbeat(known_job_ids=[job2_serialized['id']])
    280 
    281         self.mox.ReplayAll()
    282         sut = shard_client.get_shard_client()
    283 
    284         original_process_heartbeat_response = sut.process_heartbeat_response
    285         def failing_process_heartbeat_response(*args, **kwargs):
    286             raise RuntimeError
    287 
    288         sut.process_heartbeat_response = failing_process_heartbeat_response
    289         self.assertRaises(RuntimeError, sut.do_heartbeat)
    290 
    291         sut.process_heartbeat_response = original_process_heartbeat_response
    292         sut.do_heartbeat()
    293         sut.do_heartbeat()
    294 
    295         job2 = models.Job.objects.get(pk=job1_serialized['id'])
    296         job2.hostqueueentry_set.all().update(complete=True)
    297 
    298         sut.do_heartbeat()
    299 
    300         self.mox.VerifyAll()
    301 
    302 
    303     def testFailAndRedownloadHosts(self):
    304         self.setup_mocks()
    305         self.setup_global_config()
    306 
    307         host1_serialized = self._get_sample_serialized_host()
    308         host2_serialized = self._get_sample_serialized_host()
    309         host2_serialized['id'] = 3
    310         host2_serialized['hostname'] = 'host2'
    311 
    312         self.expect_heartbeat(return_hosts=[host1_serialized])
    313         self.expect_heartbeat(return_hosts=[host1_serialized, host2_serialized])
    314         self.expect_heartbeat(known_host_ids=[host1_serialized['id'],
    315                                               host2_serialized['id']],
    316                               known_host_statuses=[host1_serialized['status'],
    317                                                    host2_serialized['status']])
    318 
    319         self.mox.ReplayAll()
    320         sut = shard_client.get_shard_client()
    321 
    322         original_process_heartbeat_response = sut.process_heartbeat_response
    323         def failing_process_heartbeat_response(*args, **kwargs):
    324             raise RuntimeError
    325 
    326         sut.process_heartbeat_response = failing_process_heartbeat_response
    327         self.assertRaises(RuntimeError, sut.do_heartbeat)
    328 
    329         self.assertEqual(models.Host.objects.count(), 0)
    330 
    331         sut.process_heartbeat_response = original_process_heartbeat_response
    332         sut.do_heartbeat()
    333         sut.do_heartbeat()
    334 
    335         self.mox.VerifyAll()
    336 
    337 
    338     def testHeartbeatNoShardMode(self):
    339         """Ensure an exception is thrown when run on a non-shard machine."""
    340         self.mox.ReplayAll()
    341 
    342         self.assertRaises(error.HeartbeatOnlyAllowedInShardModeException,
    343                           shard_client.get_shard_client)
    344 
    345         self.mox.VerifyAll()
    346 
    347 
    348     def testLoop(self):
    349         """Test looping over heartbeats and aborting that loop works."""
    350         self.setup_mocks()
    351         self.setup_global_config()
    352 
    353         global_config.global_config.override_config_value(
    354                 'SHARD', 'heartbeat_pause_sec', '0.01')
    355 
    356         self.expect_heartbeat()
    357 
    358         sut = None
    359 
    360         def shutdown_sut(*args, **kwargs):
    361             sut.shutdown()
    362 
    363         self.expect_heartbeat(side_effect=shutdown_sut)
    364 
    365         self.mox.ReplayAll()
    366         sut = shard_client.get_shard_client()
    367         sut.loop(None)
    368 
    369         self.mox.VerifyAll()
    370 
    371 
    372     def testLoopWithDeadline(self):
    373         """Test looping over heartbeats with a timeout."""
    374         self.setup_mocks()
    375         self.setup_global_config()
    376         self.mox.StubOutWithMock(time, 'time')
    377 
    378         global_config.global_config.override_config_value(
    379                 'SHARD', 'heartbeat_pause_sec', '0.01')
    380         time.time().AndReturn(1516894000)
    381         time.time().AndReturn(1516894000)
    382         self.expect_heartbeat()
    383         # Set expectation that heartbeat took 1 minute.
    384         time.time().MultipleTimes().AndReturn(1516894000 + 60)
    385 
    386         self.mox.ReplayAll()
    387         sut = shard_client.get_shard_client()
    388         # 36 seconds
    389         sut.loop(lifetime_hours=0.01)
    390         self.mox.VerifyAll()
    391 
    392     def test_remove_incorrect_hosts(self):
    393         """Test _remove_incorrect_hosts with MultipleObjectsReturned."""
    394         self.setup_mocks()
    395         self.setup_global_config()
    396         self.mox.StubOutWithMock(model_logic.ModelWithInvalidQuerySet, 'delete')
    397         call = models.Host.objects.filter(id__in=[1]).delete()
    398         call.AndRaise(MultipleObjectsReturned('e'))
    399 
    400         self.mox.ReplayAll()
    401         sut = shard_client.get_shard_client()
    402         sut._remove_incorrect_hosts(incorrect_host_ids=[1])
    403 
    404         self.mox.VerifyAll()
    405 
    406 
# Run the test cases in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
    409