Home | History | Annotate | Download | only in scheduler
      1 #!/usr/bin/python
      2 #pylint: disable-msg=C0111
      3 
      4 import datetime
      5 import unittest
      6 
      7 import common
      8 from autotest_lib.frontend import setup_django_environment
      9 from autotest_lib.frontend.afe import frontend_test_utils
     10 from autotest_lib.client.common_lib import host_queue_entry_states
     11 from autotest_lib.database import database_connection
     12 from autotest_lib.frontend.afe import models, model_attributes
     13 from autotest_lib.scheduler import monitor_db
     14 from autotest_lib.scheduler import scheduler_lib
     15 from autotest_lib.scheduler import scheduler_models
     16 
# Passed through as the test database connection's `debug` attribute in
# BaseSchedulerModelsTest._set_monitor_stubs; flip to True when debugging.
_DEBUG = False
     18 
     19 
     20 class BaseSchedulerModelsTest(unittest.TestCase,
     21                               frontend_test_utils.FrontendTestMixin):
     22     _config_section = 'AUTOTEST_WEB'
     23 
     24     def _do_query(self, sql):
     25         self._database.execute(sql)
     26 
     27 
     28     def _set_monitor_stubs(self):
     29         # Clear the instance cache as this is a brand new database.
     30         scheduler_models.DBObject._clear_instance_cache()
     31 
     32         self._database = (
     33             database_connection.TranslatingDatabase.get_test_database(
     34                 translators=scheduler_lib._DB_TRANSLATORS))
     35         self._database.connect(db_type='django')
     36         self._database.debug = _DEBUG
     37 
     38         self.god.stub_with(scheduler_models, '_db', self._database)
     39 
     40 
     41     def setUp(self):
     42         self._frontend_common_setup()
     43         self._set_monitor_stubs()
     44 
     45 
     46     def tearDown(self):
     47         self._database.disconnect()
     48         self._frontend_common_teardown()
     49 
     50 
     51     def _update_hqe(self, set, where=''):
     52         query = 'UPDATE afe_host_queue_entries SET ' + set
     53         if where:
     54             query += ' WHERE ' + where
     55         self._do_query(query)
     56 
     57 
     58 class DBObjectTest(BaseSchedulerModelsTest):
     59 
     60     def test_compare_fields_in_row(self):
     61         host = scheduler_models.Host(id=1)
     62         fields = list(host._fields)
     63         row_data = [getattr(host, fieldname) for fieldname in fields]
     64         self.assertEqual({}, host._compare_fields_in_row(row_data))
     65         row_data[fields.index('hostname')] = 'spam'
     66         self.assertEqual({'hostname': ('host1', 'spam')},
     67                          host._compare_fields_in_row(row_data))
     68         row_data[fields.index('id')] = 23
     69         self.assertEqual({'hostname': ('host1', 'spam'), 'id': (1, 23)},
     70                          host._compare_fields_in_row(row_data))
     71 
     72 
     73     def test_compare_fields_in_row_datetime_ignores_microseconds(self):
     74         datetime_with_us = datetime.datetime(2009, 10, 07, 12, 34, 56, 7890)
     75         datetime_without_us = datetime.datetime(2009, 10, 07, 12, 34, 56, 0)
     76         class TestTable(scheduler_models.DBObject):
     77             _table_name = 'test_table'
     78             _fields = ('id', 'test_datetime')
     79         tt = TestTable(row=[1, datetime_without_us])
     80         self.assertEqual({}, tt._compare_fields_in_row([1, datetime_with_us]))
     81 
     82 
     83     def test_always_query(self):
     84         host_a = scheduler_models.Host(id=2)
     85         self.assertEqual(host_a.hostname, 'host2')
     86         self._do_query('UPDATE afe_hosts SET hostname="host2-updated" '
     87                        'WHERE id=2')
     88         host_b = scheduler_models.Host(id=2, always_query=True)
     89         self.assert_(host_a is host_b, 'Cached instance not returned.')
     90         self.assertEqual(host_a.hostname, 'host2-updated',
     91                          'Database was not re-queried')
     92 
     93         # If either of these are called, a query was made when it shouldn't be.
     94         host_a._compare_fields_in_row = lambda _: self.fail('eek! a query!')
     95         host_a._update_fields_from_row = host_a._compare_fields_in_row
     96         host_c = scheduler_models.Host(id=2, always_query=False)
     97         self.assert_(host_a is host_c, 'Cached instance not returned')
     98 
     99 
    100     def test_delete(self):
    101         host = scheduler_models.Host(id=3)
    102         host.delete()
    103         host = self.assertRaises(scheduler_models.DBError, scheduler_models.Host, id=3,
    104                                  always_query=False)
    105         host = self.assertRaises(scheduler_models.DBError, scheduler_models.Host, id=3,
    106                                  always_query=True)
    107 
    108     def test_save(self):
    109         # Dummy Job to avoid creating a one in the HostQueueEntry __init__.
    110         class MockJob(object):
    111             def __init__(self, id, row):
    112                 pass
    113             def tag(self):
    114                 return 'MockJob'
    115         self.god.stub_with(scheduler_models, 'Job', MockJob)
    116         hqe = scheduler_models.HostQueueEntry(
    117                 new_record=True,
    118                 row=[0, 1, 2, 'Queued', None, 0, 0, 0, '.', None, False, None,
    119                      None])
    120         hqe.save()
    121         new_id = hqe.id
    122         # Force a re-query and verify that the correct data was stored.
    123         scheduler_models.DBObject._clear_instance_cache()
    124         hqe = scheduler_models.HostQueueEntry(id=new_id)
    125         self.assertEqual(hqe.id, new_id)
    126         self.assertEqual(hqe.job_id, 1)
    127         self.assertEqual(hqe.host_id, 2)
    128         self.assertEqual(hqe.status, 'Queued')
    129         self.assertEqual(hqe.meta_host, None)
    130         self.assertEqual(hqe.active, False)
    131         self.assertEqual(hqe.complete, False)
    132         self.assertEqual(hqe.deleted, False)
    133         self.assertEqual(hqe.execution_subdir, '.')
    134         self.assertEqual(hqe.started_on, None)
    135         self.assertEqual(hqe.finished_on, None)
    136 
    137 
    138 class HostTest(BaseSchedulerModelsTest):
    139 
    140     def setUp(self):
    141         super(HostTest, self).setUp()
    142         self.old_config = scheduler_models.RESPECT_STATIC_LABELS
    143 
    144 
    145     def tearDown(self):
    146         super(HostTest, self).tearDown()
    147         scheduler_models.RESPECT_STATIC_LABELS = self.old_config
    148 
    149 
    150     def _setup_static_labels(self):
    151         label1 = models.Label.objects.create(name='non_static_label')
    152         non_static_platform = models.Label.objects.create(
    153                 name='static_platform', platform=False)
    154         models.ReplacedLabel.objects.create(label_id=non_static_platform.id)
    155 
    156         static_label1 = models.StaticLabel.objects.create(
    157                 name='no_reference_label', platform=False)
    158         static_platform = models.StaticLabel.objects.create(
    159                 name=non_static_platform.name, platform=True)
    160 
    161         host1 = models.Host.objects.create(hostname='test_host')
    162         host1.labels.add(label1)
    163         host1.labels.add(non_static_platform)
    164         host1.static_labels.add(static_label1)
    165         host1.static_labels.add(static_platform)
    166         host1.save()
    167         return host1
    168 
    169 
    170     def test_platform_and_labels_with_respect(self):
    171         scheduler_models.RESPECT_STATIC_LABELS = True
    172         test_host = self._setup_static_labels()
    173         host = scheduler_models.Host(id=test_host.id)
    174         platform, all_labels = host.platform_and_labels()
    175         self.assertEqual(platform, 'static_platform')
    176         self.assertNotIn('no_reference_label', all_labels)
    177         self.assertEqual(all_labels, ['non_static_label', 'static_platform'])
    178 
    179 
    180     def test_platform_and_labels_without_respect(self):
    181         scheduler_models.RESPECT_STATIC_LABELS = False
    182         test_host = self._setup_static_labels()
    183         host = scheduler_models.Host(id=test_host.id)
    184         platform, all_labels = host.platform_and_labels()
    185         self.assertIsNone(platform)
    186         self.assertEqual(all_labels, ['non_static_label', 'static_platform'])
    187 
    188 
    189     def test_cmp_for_sort(self):
    190         expected_order = [
    191                 'alice', 'Host1', 'host2', 'host3', 'host09', 'HOST010',
    192                 'host10', 'host11', 'yolkfolk']
    193         hostname_idx = list(scheduler_models.Host._fields).index('hostname')
    194         row = [None] * len(scheduler_models.Host._fields)
    195         hosts = []
    196         for hostname in expected_order:
    197             row[hostname_idx] = hostname
    198             hosts.append(scheduler_models.Host(row=row, new_record=True))
    199 
    200         host1 = hosts[expected_order.index('Host1')]
    201         host010 = hosts[expected_order.index('HOST010')]
    202         host10 = hosts[expected_order.index('host10')]
    203         host3 = hosts[expected_order.index('host3')]
    204         alice = hosts[expected_order.index('alice')]
    205         self.assertEqual(0, scheduler_models.Host.cmp_for_sort(host10, host10))
    206         self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host10, host010))
    207         self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host010, host10))
    208         self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host10))
    209         self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host010))
    210         self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host3, host10))
    211         self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host3, host010))
    212         self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host3, host1))
    213         self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host3))
    214         self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(alice, host3))
    215         self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host3, alice))
    216         self.assertEqual(0, scheduler_models.Host.cmp_for_sort(alice, alice))
    217 
    218         hosts.sort(cmp=scheduler_models.Host.cmp_for_sort)
    219         self.assertEqual(expected_order, [h.hostname for h in hosts])
    220 
    221         hosts.reverse()
    222         hosts.sort(cmp=scheduler_models.Host.cmp_for_sort)
    223         self.assertEqual(expected_order, [h.hostname for h in hosts])
    224 
    225 
    226 class HostQueueEntryTest(BaseSchedulerModelsTest):
    227     def _create_hqe(self, dependency_labels=(), **create_job_kwargs):
    228         job = self._create_job(**create_job_kwargs)
    229         for label in dependency_labels:
    230             job.dependency_labels.add(label)
    231         hqes = list(scheduler_models.HostQueueEntry.fetch(where='job_id=%d' % job.id))
    232         self.assertEqual(1, len(hqes))
    233         return hqes[0]
    234 
    235 
    236     def _check_hqe_labels(self, hqe, expected_labels):
    237         expected_labels = set(expected_labels)
    238         label_names = set(label.name for label in hqe.get_labels())
    239         self.assertEqual(expected_labels, label_names)
    240 
    241 
    242     def test_get_labels_empty(self):
    243         hqe = self._create_hqe(hosts=[1])
    244         labels = list(hqe.get_labels())
    245         self.assertEqual([], labels)
    246 
    247 
    248     def test_get_labels_metahost(self):
    249         hqe = self._create_hqe(metahosts=[2])
    250         self._check_hqe_labels(hqe, ['label2'])
    251 
    252 
    253     def test_get_labels_dependencies(self):
    254         hqe = self._create_hqe(dependency_labels=(self.label3,),
    255                                metahosts=[1])
    256         self._check_hqe_labels(hqe, ['label1', 'label3'])
    257 
    258 
    259     def setup_abort_test(self, agent_finished=True):
    260         """Setup the variables for testing abort method.
    261 
    262         @param agent_finished: True to mock agent is finished before aborting
    263                                the hqe.
    264         @return hqe, dispatcher: Mock object of hqe and dispatcher to be used
    265                                to test abort method.
    266         """
    267         hqe = self._create_hqe(hosts=[1])
    268         hqe.aborted = True
    269         hqe.complete = False
    270         hqe.status = models.HostQueueEntry.Status.STARTING
    271         hqe.started_on = datetime.datetime.now()
    272 
    273         dispatcher = self.god.create_mock_class(monitor_db.Dispatcher,
    274                                                 'Dispatcher')
    275         agent = self.god.create_mock_class(monitor_db.Agent, 'Agent')
    276         dispatcher.get_agents_for_entry.expect_call(hqe).and_return([agent])
    277         agent.is_done.expect_call().and_return(agent_finished)
    278         return hqe, dispatcher
    279 
    280 
    281     def test_abort_fail_with_unfinished_agent(self):
    282         """abort should fail if the hqe still has agent not finished.
    283         """
    284         hqe, dispatcher = self.setup_abort_test(agent_finished=False)
    285         self.assertIsNone(hqe.finished_on)
    286         with self.assertRaises(AssertionError):
    287             hqe.abort(dispatcher)
    288         self.god.check_playback()
    289         # abort failed, finished_on should not be set
    290         self.assertIsNone(hqe.finished_on)
    291 
    292 
    293     def test_abort_success(self):
    294         """abort should succeed if all agents for the hqe are finished.
    295         """
    296         hqe, dispatcher = self.setup_abort_test(agent_finished=True)
    297         self.assertIsNone(hqe.finished_on)
    298         hqe.abort(dispatcher)
    299         self.god.check_playback()
    300         self.assertIsNotNone(hqe.finished_on)
    301 
    302 
    303     def test_set_finished_on(self):
    304         """Test that finished_on is set when hqe completes."""
    305         for status in host_queue_entry_states.Status.values:
    306             hqe = self._create_hqe(hosts=[1])
    307             hqe.started_on = datetime.datetime.now()
    308             hqe.job.update_field('shard_id', 3)
    309             self.assertIsNone(hqe.finished_on)
    310             hqe.set_status(status)
    311             if status in host_queue_entry_states.COMPLETE_STATUSES:
    312                 self.assertIsNotNone(hqe.finished_on)
    313                 self.assertIsNone(hqe.job.shard_id)
    314             else:
    315                 self.assertIsNone(hqe.finished_on)
    316                 self.assertEquals(hqe.job.shard_id, 3)
    317 
    318 
    319 class JobTest(BaseSchedulerModelsTest):
    320     def setUp(self):
    321         super(JobTest, self).setUp()
    322 
    323         def _mock_create(**kwargs):
    324             task = models.SpecialTask(**kwargs)
    325             task.save()
    326             self._tasks.append(task)
    327         self.god.stub_with(models.SpecialTask.objects, 'create', _mock_create)
    328 
    329 
    330     def _test_pre_job_tasks_helper(self,
    331                             reboot_before=model_attributes.RebootBefore.ALWAYS):
    332         """
    333         Calls HQE._do_schedule_pre_job_tasks() and returns the created special
    334         task
    335         """
    336         self._tasks = []
    337         queue_entry = scheduler_models.HostQueueEntry.fetch('id = 1')[0]
    338         queue_entry.job.reboot_before = reboot_before
    339         queue_entry._do_schedule_pre_job_tasks()
    340         return self._tasks
    341 
    342 
    343     def test_job_request_abort(self):
    344         django_job = self._create_job(hosts=[5, 6])
    345         job = scheduler_models.Job(django_job.id)
    346         job.request_abort()
    347         django_hqes = list(models.HostQueueEntry.objects.filter(job=job.id))
    348         for hqe in django_hqes:
    349             self.assertTrue(hqe.aborted)
    350 
    351 
    352     def _check_special_tasks(self, tasks, task_types):
    353         self.assertEquals(len(tasks), len(task_types))
    354         for task, (task_type, queue_entry_id) in zip(tasks, task_types):
    355             self.assertEquals(task.task, task_type)
    356             self.assertEquals(task.host.id, 1)
    357             if queue_entry_id:
    358                 self.assertEquals(task.queue_entry.id, queue_entry_id)
    359 
    360 
    361     def test_run_asynchronous(self):
    362         self._create_job(hosts=[1, 2])
    363 
    364         tasks = self._test_pre_job_tasks_helper()
    365 
    366         self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])
    367 
    368 
    369     def test_run_asynchronous_skip_verify(self):
    370         job = self._create_job(hosts=[1, 2])
    371         job.run_verify = False
    372         job.save()
    373 
    374         tasks = self._test_pre_job_tasks_helper()
    375 
    376         self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])
    377 
    378 
    379     def test_run_synchronous_verify(self):
    380         self._create_job(hosts=[1, 2], synchronous=True)
    381 
    382         tasks = self._test_pre_job_tasks_helper()
    383 
    384         self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])
    385 
    386 
    387     def test_run_synchronous_skip_verify(self):
    388         job = self._create_job(hosts=[1, 2], synchronous=True)
    389         job.run_verify = False
    390         job.save()
    391 
    392         tasks = self._test_pre_job_tasks_helper()
    393 
    394         self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])
    395 
    396 
    397     def test_run_asynchronous_do_not_reset(self):
    398         job = self._create_job(hosts=[1, 2])
    399         job.run_reset = False
    400         job.run_verify = False
    401         job.save()
    402 
    403         tasks = self._test_pre_job_tasks_helper()
    404 
    405         self.assertEquals(tasks, [])
    406 
    407 
    408     def test_run_synchronous_do_not_reset_no_RebootBefore(self):
    409         job = self._create_job(hosts=[1, 2], synchronous=True)
    410         job.reboot_before = model_attributes.RebootBefore.NEVER
    411         job.save()
    412 
    413         tasks = self._test_pre_job_tasks_helper(
    414                             reboot_before=model_attributes.RebootBefore.NEVER)
    415 
    416         self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])
    417 
    418 
    419     def test_run_asynchronous_do_not_reset(self):
    420         job = self._create_job(hosts=[1, 2], synchronous=False)
    421         job.reboot_before = model_attributes.RebootBefore.NEVER
    422         job.save()
    423 
    424         tasks = self._test_pre_job_tasks_helper(
    425                             reboot_before=model_attributes.RebootBefore.NEVER)
    426 
    427         self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])
    428 
    429 
    430     def test_reboot_before_always(self):
    431         job = self._create_job(hosts=[1])
    432         job.reboot_before = model_attributes.RebootBefore.ALWAYS
    433         job.save()
    434 
    435         tasks = self._test_pre_job_tasks_helper()
    436 
    437         self._check_special_tasks(tasks, [
    438                 (models.SpecialTask.Task.RESET, None)
    439             ])
    440 
    441 
    442     def _test_reboot_before_if_dirty_helper(self):
    443         job = self._create_job(hosts=[1])
    444         job.reboot_before = model_attributes.RebootBefore.IF_DIRTY
    445         job.save()
    446 
    447         tasks = self._test_pre_job_tasks_helper()
    448         task_types = [(models.SpecialTask.Task.RESET, None)]
    449 
    450         self._check_special_tasks(tasks, task_types)
    451 
    452 
    453     def test_reboot_before_if_dirty(self):
    454         models.Host.smart_get(1).update_object(dirty=True)
    455         self._test_reboot_before_if_dirty_helper()
    456 
    457 
    458     def test_reboot_before_not_dirty(self):
    459         models.Host.smart_get(1).update_object(dirty=False)
    460         self._test_reboot_before_if_dirty_helper()
    461 
    462 
if __name__ == '__main__':
    # Run every TestCase defined in this module when executed directly.
    unittest.main()
    465