Home | History | Annotate | Download | only in afe
      1 #!/usr/bin/env python2
      2 # pylint: disable=missing-docstring
      3 
      4 import datetime
      5 import mox
      6 import unittest
      7 
      8 import common
      9 from autotest_lib.client.common_lib import control_data
     10 from autotest_lib.client.common_lib import error
     11 from autotest_lib.client.common_lib import global_config
     12 from autotest_lib.client.common_lib import priorities
     13 from autotest_lib.client.common_lib.cros import dev_server
     14 from autotest_lib.client.common_lib.test_utils import mock
     15 from autotest_lib.frontend import setup_django_environment
     16 from autotest_lib.frontend.afe import frontend_test_utils
     17 from autotest_lib.frontend.afe import model_logic
     18 from autotest_lib.frontend.afe import models
     19 from autotest_lib.frontend.afe import rpc_interface
     20 from autotest_lib.frontend.afe import rpc_utils
     21 from autotest_lib.server import frontend
     22 from autotest_lib.server import utils as server_utils
     23 from autotest_lib.server.cros import provision
     24 from autotest_lib.server.cros.dynamic_suite import constants
     25 from autotest_lib.server.cros.dynamic_suite import control_file_getter
     26 from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
     27 
     28 CLIENT = control_data.CONTROL_TYPE_NAMES.CLIENT
     29 SERVER = control_data.CONTROL_TYPE_NAMES.SERVER
     30 
     31 _hqe_status = models.HostQueueEntry.Status
     32 
     33 
     34 class ShardHeartbeatTest(mox.MoxTestBase, unittest.TestCase):
     35 
     36     _PRIORITY = priorities.Priority.DEFAULT
     37 
     38     def _do_heartbeat_and_assert_response(self, shard_hostname='shard1',
     39                                           upload_jobs=(), upload_hqes=(),
     40                                           known_jobs=(), known_hosts=(),
     41                                           **kwargs):
     42         known_job_ids = [job.id for job in known_jobs]
     43         known_host_ids = [host.id for host in known_hosts]
     44         known_host_statuses = [host.status for host in known_hosts]
     45 
     46         retval = rpc_interface.shard_heartbeat(
     47             shard_hostname=shard_hostname,
     48             jobs=upload_jobs, hqes=upload_hqes,
     49             known_job_ids=known_job_ids, known_host_ids=known_host_ids,
     50             known_host_statuses=known_host_statuses)
     51 
     52         self._assert_shard_heartbeat_response(shard_hostname, retval,
     53                                               **kwargs)
     54 
     55         return shard_hostname
     56 
     57 
     58     def _assert_shard_heartbeat_response(self, shard_hostname, retval, jobs=[],
     59                                          hosts=[], hqes=[],
     60                                          incorrect_host_ids=[]):
     61 
     62         retval_hosts, retval_jobs = retval['hosts'], retval['jobs']
     63         retval_incorrect_hosts = retval['incorrect_host_ids']
     64 
     65         expected_jobs = [
     66             (job.id, job.name, shard_hostname) for job in jobs]
     67         returned_jobs = [(job['id'], job['name'], job['shard']['hostname'])
     68                          for job in retval_jobs]
     69         self.assertEqual(returned_jobs, expected_jobs)
     70 
     71         expected_hosts = [(host.id, host.hostname) for host in hosts]
     72         returned_hosts = [(host['id'], host['hostname'])
     73                           for host in retval_hosts]
     74         self.assertEqual(returned_hosts, expected_hosts)
     75 
     76         retval_hqes = []
     77         for job in retval_jobs:
     78             retval_hqes += job['hostqueueentry_set']
     79 
     80         expected_hqes = [(hqe.id) for hqe in hqes]
     81         returned_hqes = [(hqe['id']) for hqe in retval_hqes]
     82         self.assertEqual(returned_hqes, expected_hqes)
     83 
     84         self.assertEqual(retval_incorrect_hosts, incorrect_host_ids)
     85 
     86 
     87     def _createJobForLabel(self, label):
     88         job_id = rpc_interface.create_job(name='dummy', priority=self._PRIORITY,
     89                                           control_file='foo',
     90                                           control_type=CLIENT,
     91                                           meta_hosts=[label.name],
     92                                           dependencies=(label.name,))
     93         return models.Job.objects.get(id=job_id)
     94 
     95 
     96     def _testShardHeartbeatFetchHostlessJobHelper(self, host1):
     97         """Create a hostless job and ensure it's not assigned to a shard."""
     98         label2 = models.Label.objects.create(name='bluetooth', platform=False)
     99 
    100         job1 = self._create_job(hostless=True)
    101 
    102         # Hostless jobs should be executed by the global scheduler.
    103         self._do_heartbeat_and_assert_response(hosts=[host1])
    104 
    105 
    106     def _testShardHeartbeatIncorrectHostsHelper(self, host1):
    107         """Ensure that hosts that don't belong to shard are determined."""
    108         host2 = models.Host.objects.create(hostname='test_host2', leased=False)
    109 
    110         # host2 should not belong to shard1. Ensure that if shard1 thinks host2
    111         # is a known host, then it is returned as invalid.
    112         self._do_heartbeat_and_assert_response(known_hosts=[host1, host2],
    113                                                incorrect_host_ids=[host2.id])
    114 
    115 
    116     def _testShardHeartbeatLabelRemovalRaceHelper(self, shard1, host1, label1):
    117         """Ensure correctness if label removed during heartbeat."""
    118         host2 = models.Host.objects.create(hostname='test_host2', leased=False)
    119         host2.labels.add(label1)
    120         self.assertEqual(host2.shard, None)
    121 
    122         # In the middle of the assign_to_shard call, remove label1 from shard1.
    123         self.mox.StubOutWithMock(models.Host, '_assign_to_shard_nothing_helper')
    124         def remove_label():
    125             rpc_interface.remove_board_from_shard(shard1.hostname, label1.name)
    126 
    127         models.Host._assign_to_shard_nothing_helper().WithSideEffects(
    128             remove_label)
    129         self.mox.ReplayAll()
    130 
    131         self._do_heartbeat_and_assert_response(
    132             known_hosts=[host1], hosts=[], incorrect_host_ids=[host1.id])
    133         host2 = models.Host.smart_get(host2.id)
    134         self.assertEqual(host2.shard, None)
    135 
    136 
    137     def _testShardRetrieveJobsHelper(self, shard1, host1, label1, shard2,
    138                                      host2, label2):
    139         """Create jobs and retrieve them."""
    140         # should never be returned by heartbeat
    141         leased_host = models.Host.objects.create(hostname='leased_host',
    142                                                  leased=True)
    143 
    144         leased_host.labels.add(label1)
    145 
    146         job1 = self._createJobForLabel(label1)
    147 
    148         job2 = self._createJobForLabel(label2)
    149 
    150         job_completed = self._createJobForLabel(label1)
    151         # Job is already being run, so don't sync it
    152         job_completed.hostqueueentry_set.update(complete=True)
    153         job_completed.hostqueueentry_set.create(complete=False)
    154 
    155         job_active = self._createJobForLabel(label1)
    156         # Job is already started, so don't sync it
    157         job_active.hostqueueentry_set.update(active=True)
    158         job_active.hostqueueentry_set.create(complete=False, active=False)
    159 
    160         self._do_heartbeat_and_assert_response(
    161             jobs=[job1], hosts=[host1], hqes=job1.hostqueueentry_set.all())
    162 
    163         self._do_heartbeat_and_assert_response(
    164             shard_hostname=shard2.hostname,
    165             jobs=[job2], hosts=[host2], hqes=job2.hostqueueentry_set.all())
    166 
    167         host3 = models.Host.objects.create(hostname='test_host3', leased=False)
    168         host3.labels.add(label1)
    169 
    170         self._do_heartbeat_and_assert_response(
    171             known_jobs=[job1], known_hosts=[host1], hosts=[host3])
    172 
    173 
    174     def _testResendJobsAfterFailedHeartbeatHelper(self, shard1, host1, label1):
    175         """Create jobs, retrieve them, fail on client, fetch them again."""
    176         job1 = self._createJobForLabel(label1)
    177 
    178         self._do_heartbeat_and_assert_response(
    179             jobs=[job1],
    180             hqes=job1.hostqueueentry_set.all(), hosts=[host1])
    181 
    182         # Make sure it's resubmitted by sending last_job=None again
    183         self._do_heartbeat_and_assert_response(
    184             known_hosts=[host1],
    185             jobs=[job1], hqes=job1.hostqueueentry_set.all(), hosts=[])
    186 
    187         # Now it worked, make sure it's not sent again
    188         self._do_heartbeat_and_assert_response(
    189             known_jobs=[job1], known_hosts=[host1])
    190 
    191         job1 = models.Job.objects.get(pk=job1.id)
    192         job1.hostqueueentry_set.all().update(complete=True)
    193 
    194         # Job is completed, make sure it's not sent again
    195         self._do_heartbeat_and_assert_response(
    196             known_hosts=[host1])
    197 
    198         job2 = self._createJobForLabel(label1)
    199 
    200         # job2's creation was later, it should be returned now.
    201         self._do_heartbeat_and_assert_response(
    202             known_hosts=[host1],
    203             jobs=[job2], hqes=job2.hostqueueentry_set.all())
    204 
    205         self._do_heartbeat_and_assert_response(
    206             known_jobs=[job2], known_hosts=[host1])
    207 
    208         job2 = models.Job.objects.get(pk=job2.pk)
    209         job2.hostqueueentry_set.update(aborted=True)
    210         # Setting a job to a complete status will set the shard_id to None in
    211         # scheduler_models. We have to emulate that here, because we use Django
    212         # models in tests.
    213         job2.shard = None
    214         job2.save()
    215 
    216         self._do_heartbeat_and_assert_response(
    217             known_jobs=[job2], known_hosts=[host1],
    218             jobs=[job2],
    219             hqes=job2.hostqueueentry_set.all())
    220 
    221         models.Test.objects.create(name='platform_BootPerfServer:shard',
    222                                    test_type=1)
    223         self.mox.StubOutWithMock(server_utils, 'read_file')
    224         self.mox.ReplayAll()
    225         rpc_interface.delete_shard(hostname=shard1.hostname)
    226 
    227         self.assertRaises(
    228             models.Shard.DoesNotExist, models.Shard.objects.get, pk=shard1.id)
    229 
    230         job1 = models.Job.objects.get(pk=job1.id)
    231         label1 = models.Label.objects.get(pk=label1.id)
    232 
    233         self.assertIsNone(job1.shard)
    234         self.assertEqual(len(label1.shard_set.all()), 0)
    235 
    236 
    237     def _testResendHostsAfterFailedHeartbeatHelper(self, host1):
    238         """Check that master accepts resending updated records after failure."""
    239         # Send the host
    240         self._do_heartbeat_and_assert_response(hosts=[host1])
    241 
    242         # Send it again because previous one didn't persist correctly
    243         self._do_heartbeat_and_assert_response(hosts=[host1])
    244 
    245         # Now it worked, make sure it isn't sent again
    246         self._do_heartbeat_and_assert_response(known_hosts=[host1])
    247 
    248 
    249 class RpcInterfaceTestWithStaticAttribute(
    250         mox.MoxTestBase, unittest.TestCase,
    251         frontend_test_utils.FrontendTestMixin):
    252 
    253     def setUp(self):
    254         super(RpcInterfaceTestWithStaticAttribute, self).setUp()
    255         self._frontend_common_setup()
    256         self.god = mock.mock_god()
    257         self.old_respect_static_config = rpc_interface.RESPECT_STATIC_ATTRIBUTES
    258         rpc_interface.RESPECT_STATIC_ATTRIBUTES = True
    259         models.RESPECT_STATIC_ATTRIBUTES = True
    260 
    261 
    262     def tearDown(self):
    263         self.god.unstub_all()
    264         self._frontend_common_teardown()
    265         global_config.global_config.reset_config_values()
    266         rpc_interface.RESPECT_STATIC_ATTRIBUTES = self.old_respect_static_config
    267         models.RESPECT_STATIC_ATTRIBUTES = self.old_respect_static_config
    268 
    269 
    270     def _fake_host_with_static_attributes(self):
    271         host1 = models.Host.objects.create(hostname='test_host')
    272         host1.set_attribute('test_attribute1', 'test_value1')
    273         host1.set_attribute('test_attribute2', 'test_value2')
    274         self._set_static_attribute(host1, 'test_attribute1', 'static_value1')
    275         self._set_static_attribute(host1, 'static_attribute1', 'static_value2')
    276         host1.save()
    277         return host1
    278 
    279 
    280     def test_get_hosts(self):
    281         host1 = self._fake_host_with_static_attributes()
    282         hosts = rpc_interface.get_hosts(hostname=host1.hostname)
    283         host = hosts[0]
    284 
    285         self.assertEquals(host['hostname'], 'test_host')
    286         self.assertEquals(host['acls'], ['Everyone'])
    287         # Respect the value of static attributes.
    288         self.assertEquals(host['attributes'],
    289                           {'test_attribute1': 'static_value1',
    290                            'test_attribute2': 'test_value2',
    291                            'static_attribute1': 'static_value2'})
    292 
    293     def test_get_host_attribute_with_static(self):
    294         host1 = models.Host.objects.create(hostname='test_host1')
    295         host1.set_attribute('test_attribute1', 'test_value1')
    296         self._set_static_attribute(host1, 'test_attribute1', 'static_value1')
    297         host2 = models.Host.objects.create(hostname='test_host2')
    298         host2.set_attribute('test_attribute1', 'test_value1')
    299         host2.set_attribute('test_attribute2', 'test_value2')
    300 
    301         attributes = rpc_interface.get_host_attribute(
    302                 'test_attribute1',
    303                 hostname__in=['test_host1', 'test_host2'])
    304         hosts = [attr['host'] for attr in attributes]
    305         values = [attr['value'] for attr in attributes]
    306         self.assertEquals(set(hosts),
    307                           set(['test_host1', 'test_host2']))
    308         self.assertEquals(set(values),
    309                           set(['test_value1', 'static_value1']))
    310 
    311 
    312     def test_get_hosts_by_attribute_without_static(self):
    313         host1 = models.Host.objects.create(hostname='test_host1')
    314         host1.set_attribute('test_attribute1', 'test_value1')
    315         host2 = models.Host.objects.create(hostname='test_host2')
    316         host2.set_attribute('test_attribute1', 'test_value1')
    317 
    318         hosts = rpc_interface.get_hosts_by_attribute(
    319                 'test_attribute1', 'test_value1')
    320         self.assertEquals(set(hosts),
    321                           set(['test_host1', 'test_host2']))
    322 
    323 
    324     def test_get_hosts_by_attribute_with_static(self):
    325         host1 = models.Host.objects.create(hostname='test_host1')
    326         host1.set_attribute('test_attribute1', 'test_value1')
    327         self._set_static_attribute(host1, 'test_attribute1', 'test_value1')
    328         host2 = models.Host.objects.create(hostname='test_host2')
    329         host2.set_attribute('test_attribute1', 'test_value1')
    330         self._set_static_attribute(host2, 'test_attribute1', 'static_value1')
    331         host3 = models.Host.objects.create(hostname='test_host3')
    332         self._set_static_attribute(host3, 'test_attribute1', 'test_value1')
    333         host4 = models.Host.objects.create(hostname='test_host4')
    334         host4.set_attribute('test_attribute1', 'test_value1')
    335         host5 = models.Host.objects.create(hostname='test_host5')
    336         host5.set_attribute('test_attribute1', 'temp_value1')
    337         self._set_static_attribute(host5, 'test_attribute1', 'test_value1')
    338 
    339         hosts = rpc_interface.get_hosts_by_attribute(
    340                 'test_attribute1', 'test_value1')
    341         # host1: matched, it has the same value for test_attribute1.
    342         # host2: not matched, it has a new value in
    343         #        afe_static_host_attributes for test_attribute1.
    344         # host3: matched, it has a corresponding entry in
    345         #        afe_host_attributes for test_attribute1.
    346         # host4: matched, test_attribute1 is not replaced by static
    347         #        attribute.
    348         # host5: matched, it has an updated & matched value for
    349         #        test_attribute1 in afe_static_host_attributes.
    350         self.assertEquals(set(hosts),
    351                           set(['test_host1', 'test_host3',
    352                                'test_host4', 'test_host5']))
    353 
    354 
    355 class RpcInterfaceTestWithStaticLabel(ShardHeartbeatTest,
    356                                       frontend_test_utils.FrontendTestMixin):
    357 
    358     _STATIC_LABELS = ['board:lumpy']
    359 
    def setUp(self):
        super(RpcInterfaceTestWithStaticLabel, self).setUp()
        self._frontend_common_setup()
        self.god = mock.mock_god()
        # rpc_interface and models each hold their own copy of the flag;
        # remember both so tearDown can restore each one to its own value.
        self.old_respect_static_config = rpc_interface.RESPECT_STATIC_LABELS
        self.old_models_respect_static_config = models.RESPECT_STATIC_LABELS
        rpc_interface.RESPECT_STATIC_LABELS = True
        models.RESPECT_STATIC_LABELS = True


    def tearDown(self):
        self.god.unstub_all()
        self._frontend_common_teardown()
        global_config.global_config.reset_config_values()
        rpc_interface.RESPECT_STATIC_LABELS = self.old_respect_static_config
        # Restore models' flag from its own saved value, not rpc_interface's,
        # so differing initial values are not cross-contaminated.
        models.RESPECT_STATIC_LABELS = self.old_models_respect_static_config
    375 
    376 
    def _fake_host_with_static_labels(self):
        """Create 'test_host' with a plain label and a replaced platform.

        The host gets the non-static label 'non_static_label1', a non-static
        'static_platform' label (platform=False) that is marked as replaced,
        and a static 'static_platform' label flagged as the platform.

        @returns The created Host.
        """
        host = models.Host.objects.create(hostname='test_host')
        plain_label = models.Label.objects.create(
                name='non_static_label1', platform=False)
        replaced_platform = models.Label.objects.create(
                name='static_platform', platform=False)
        static_platform = models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=replaced_platform.id)

        host.static_labels.add(static_platform)
        for label in (replaced_platform, plain_label):
            host.labels.add(label)
        host.save()
        return host
    391 
    392 
    def test_get_hosts(self):
        """get_hosts lists the host's labels and takes its static platform."""
        host1 = self._fake_host_with_static_labels()
        hosts = rpc_interface.get_hosts(hostname=host1.hostname)
        host = hosts[0]

        self.assertEqual(host['hostname'], 'test_host')
        self.assertEqual(host['acls'], ['Everyone'])
        # Respect all labels in afe_hosts_labels.
        self.assertEqual(host['labels'],
                         ['non_static_label1', 'static_platform'])
        # Respect static labels.
        self.assertEqual(host['platform'], 'static_platform')
    405 
    406 
    def test_get_hosts_multiple_labels(self):
        """get_hosts finds a host by a mix of non-static and static labels."""
        self._fake_host_with_static_labels()
        hosts = rpc_interface.get_hosts(
                multiple_labels=['non_static_label1', 'static_platform'])
        host = hosts[0]
        self.assertEqual(host['hostname'], 'test_host')
    413 
    414 
    def test_delete_static_label(self):
        """Deleting a static label raises UnmodifiableLabelException."""
        static = models.Label.smart_get('static')

        sharded_host = models.Host.objects.all()[1]
        sharded_host.shard = models.Shard.objects.create(hostname='shard1')
        sharded_host.labels.add(static)
        sharded_host.save()

        # Swap RetryingAFE for a mock, presumably so delete_label cannot reach
        # a real shard while the host is sharded.
        afe_mock = self.god.create_mock_class_obj(
                frontend_wrappers.RetryingAFE, 'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', afe_mock)

        self.assertRaises(error.UnmodifiableLabelException,
                          rpc_interface.delete_label, static.id)

        self.god.check_playback()
    433 
    434 
    def test_modify_static_label(self):
        """Modifying a static label is refused and leaves it unchanged."""
        static = models.Label.smart_get('static')
        self.assertEqual(static.invalid, 0)

        sharded_host = models.Host.objects.all()[1]
        sharded_host.shard = models.Shard.objects.create(hostname='shard1')
        sharded_host.labels.add(static)
        sharded_host.save()

        afe_mock = self.god.create_mock_class_obj(
                frontend_wrappers.RetryingAFE, 'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', afe_mock)

        self.assertRaises(error.UnmodifiableLabelException,
                          rpc_interface.modify_label, static.id, invalid=1)

        # Re-read the label to confirm nothing was written.
        refreshed = models.Label.smart_get('static')
        self.assertEqual(refreshed.invalid, 0)
        self.god.check_playback()
    456 
    457 
    def test_multiple_platforms_add_non_static_to_static(self):
        """Test non-static platform to a host with static platform."""
        static_platform = models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        non_static_platform = models.Label.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=non_static_platform.id)
        # Second platform; only referenced by name ('platform2') below.
        models.Label.objects.create(name='platform2', platform=True)
        host1 = models.Host.objects.create(hostname='test_host')
        host1.static_labels.add(static_platform)
        host1.labels.add(non_static_platform)
        host1.save()

        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts, id='platform2',
                          hosts=['test_host'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='test_host', labels=['platform2'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['test_host'], platform=True)
        self.assertEqual(len(platforms), 1)
    481 
    482 
    def test_multiple_platforms_add_static_to_non_static(self):
        """Test static platform to a host with non-static platform."""
        platform1 = models.Label.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=platform1.id)
        # Static counterpart of platform1; only referenced by name below.
        models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        platform2 = models.Label.objects.create(
                name='platform2', platform=True)

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(platform2)
        host1.save()

        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts,
                          id='static_platform',
                          hosts=['test_host'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='test_host', labels=['static_platform'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['test_host'], platform=True)
        self.assertEqual(len(platforms), 1)
    508 
    509 
    def test_label_remove_hosts(self):
        """Removing a static label from its hosts is refused."""
        static_name = 'static'
        label = models.Label.smart_get(static_name)
        static_label = models.StaticLabel.objects.create(name=static_name)

        host = models.Host.objects.create(hostname='test_host')
        host.labels.add(label)
        host.static_labels.add(static_label)
        host.save()

        self.assertRaises(error.UnmodifiableLabelException,
                          rpc_interface.label_remove_hosts,
                          id=static_name, hosts=[host.hostname])
    523 
    524 
    def test_host_remove_labels(self):
        """Test remove labels of a given host."""
        label = models.Label.smart_get('static')
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        static_label = models.StaticLabel.objects.create(name='static')

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(label)
        host1.labels.add(label1)
        host1.labels.add(label2)
        host1.static_labels.add(static_label)
        host1.save()

        rpc_interface.host_remove_labels(
                'test_host', ['static', 'label1'])
        labels = rpc_interface.get_labels(host__hostname__in=['test_host'])
        # Only non_static label 'label1' is removed; 'static' is kept.
        self.assertEqual(len(labels), 2)
        self.assertEqual(labels[0].get('name'), 'label2')
    545 
    546 
    def test_remove_board_from_shard(self):
        """test remove a board (static label) from shard."""
        board_label = models.Label.smart_get('static')
        board_static_label = models.StaticLabel.objects.create(name='static')

        shard = models.Shard.objects.create(hostname='test_shard')
        shard.labels.add(board_label)

        host = models.Host.objects.create(
                hostname='test_host', leased=False, shard=shard)
        host.static_labels.add(board_static_label)
        host.save()

        rpc_interface.remove_board_from_shard(shard.hostname, board_label.name)

        # Reload both records so the assertions see the stored state.
        reloaded_host = models.Host.smart_get(host.id)
        reloaded_shard = models.Shard.smart_get(shard.id)
        self.assertEqual(reloaded_host.shard, None)
        self.assertItemsEqual(reloaded_shard.labels.all(), [])
    566 
    567 
    def test_check_job_dependencies_success(self):
        """A dependency on a static label is met by a host carrying it."""
        static_label = models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.static_labels.add(static_label)
        host.save()

        # Check against a freshly loaded copy of the host.
        reloaded = models.Host.smart_get(host.id)
        rpc_utils.check_job_dependencies([reloaded], ['static'])
    578 
    579 
    def test_check_job_dependencies_fail(self):
        """Test check_job_dependencies with raising ValidationError."""
        label = models.Label.smart_get('static')
        # The StaticLabel exists but is deliberately NOT attached to the host;
        # only the plain 'static' Label is, so the dependency check must fail.
        models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.labels.add(label)
        host.save()

        host1 = models.Host.smart_get(host.id)
        self.assertRaises(model_logic.ValidationError,
                          rpc_utils.check_job_dependencies,
                          [host1],
                          ['static'])
    594 
    def test_check_job_metahost_dependencies_success(self):
        """Metahost check passes when a host has every required label."""
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        static = models.Label.smart_get('static')
        static_label = models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.static_labels.add(static_label)
        for plain_label in (label1, label2):
            host.labels.add(plain_label)
        host.save()

        # The static label works both as a metahost and as a dependency.
        rpc_utils.check_job_metahost_dependencies([label1, static],
                                                  [label2.name])
        rpc_utils.check_job_metahost_dependencies(
                [label1], [label2.name, static_label.name])
    612 
    613 
    def test_check_job_metahost_dependencies_fail(self):
        """Metahost check fails when no host carries the static label."""
        label1, label2, static = [models.Label.smart_get(name)
                                  for name in ('label1', 'label2', 'static')]
        static_label = models.StaticLabel.objects.create(name='static')

        host = models.Host.objects.create(hostname='test_host')
        host.labels.add(label1)
        host.labels.add(label2)
        host.save()

        # Static label requested as a metahost, then as a dependency.
        for metahosts, dependencies in (
                ([label1, static], [label2.name]),
                ([label1], [label2.name, static_label.name])):
            self.assertRaises(error.NoEligibleHostException,
                              rpc_utils.check_job_metahost_dependencies,
                              metahosts, dependencies)
    632 
    633 
    def _createShardAndHostWithStaticLabel(self,
                                           shard_hostname='shard1',
                                           host_hostname='test_host1',
                                           label_name='board:lumpy'):
        """Create a shard owning a label and an unleased host on that shard.

        When label_name is listed in _STATIC_LABELS, the label is also marked
        as replaced and a same-named StaticLabel is attached to the host.

        @returns A (shard, host, label) tuple.
        """
        label = models.Label.objects.create(name=label_name)

        shard = models.Shard.objects.create(hostname=shard_hostname)
        shard.labels.add(label)

        host = models.Host.objects.create(
                hostname=host_hostname, leased=False, shard=shard)
        host.labels.add(label)

        if label_name not in self._STATIC_LABELS:
            return shard, host, label

        models.ReplacedLabel.objects.create(label_id=label.id)
        host.static_labels.add(
                models.StaticLabel.objects.create(name=label_name))
        return shard, host, label
    652 
    653 
    def testShardHeartbeatFetchHostlessJob(self):
        """Run the hostless-job heartbeat helper against a fresh shard host."""
        # Only the host is needed; shard and label are discarded.
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatFetchHostlessJobHelper(host1)
    658 
    659 
    def testShardHeartbeatIncorrectHosts(self):
        """Run the incorrect-hosts heartbeat helper against a fresh host."""
        # Only the host is needed; shard and label are discarded.
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatIncorrectHostsHelper(host1)
    664 
    665 
    def testShardHeartbeatLabelRemovalRace(self):
        """Run the label-removal-race helper on a fresh shard/host/label."""
        shard, host, label = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatLabelRemovalRaceHelper(shard, host, label)
    670 
    671 
    def testShardRetrieveJobs(self):
        """Run the job-retrieval helper with two shards on different boards."""
        lumpy_shard, lumpy_host, lumpy_label = (
                self._createShardAndHostWithStaticLabel())
        grumpy_shard, grumpy_host, grumpy_label = (
                self._createShardAndHostWithStaticLabel(
                        shard_hostname='shard2',
                        host_hostname='test_host2',
                        label_name='board:grumpy'))
        self._testShardRetrieveJobsHelper(
                lumpy_shard, lumpy_host, lumpy_label,
                grumpy_shard, grumpy_host, grumpy_label)
    678 
    679 
    def testResendJobsAfterFailedHeartbeat(self):
        """Run the resend-jobs helper on a default shard/host/label."""
        created_shard, created_host, created_label = (
                self._createShardAndHostWithStaticLabel())
        self._testResendJobsAfterFailedHeartbeatHelper(
                created_shard, created_host, created_label)
    683 
    684 
    def testResendHostsAfterFailedHeartbeat(self):
        """Run the resend-hosts helper against a fresh shard host."""
        # Only the host is needed; shard and label are discarded.
        _, host1, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testResendHostsAfterFailedHeartbeatHelper(host1)
    689 
    690 
    691 class RpcInterfaceTest(unittest.TestCase,
    692                        frontend_test_utils.FrontendTestMixin):
    def setUp(self):
        """Build the shared frontend fixtures and a fresh mock_god."""
        self._frontend_common_setup()
        god = mock.mock_god()
        self.god = god
    696 
    697 
    def tearDown(self):
        """Undo stubs, tear down fixtures, and drop config overrides."""
        self.god.unstub_all()
        self._frontend_common_teardown()
        # Some tests (e.g. _modify_host_helper) override global config values.
        config = global_config.global_config
        config.reset_config_values()
    702 
    703 
    def test_validation(self):
        """add_label/add_host reject invalid input with ValidationError."""
        invalid_calls = (
                # A required field (the label name) is omitted.
                (rpc_interface.add_label, {'name': None}),
                # 'host1' already exists, violating the uniqueness constraint.
                (rpc_interface.add_host, {'hostname': 'host1'}),
        )
        for rpc, kwargs in invalid_calls:
            self.assertRaises(model_logic.ValidationError, rpc, **kwargs)
    711 
    712 
    def test_multiple_platforms(self):
        """Attaching a second platform label to a host is rejected.

        Both label_add_hosts and host_add_labels must raise ValidationError,
        and host1/host2 must still report 'myplatform' as their only platform.
        """
        # The created label is only referenced by name below.
        models.Label.objects.create(name='platform2', platform=True)
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts, id='platform2',
                          hosts=['host1', 'host2'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='host1', labels=['platform2'])
        # Make sure the platform didn't get added.
        platforms = rpc_interface.get_labels(
            host__hostname__in=['host1', 'host2'], platform=True)
        self.assertEquals(len(platforms), 1)
        self.assertEquals(platforms[0]['name'], 'myplatform')
    726 
    727 
    def _check_hostnames(self, hosts, expected_hostnames):
        """Assert the host dicts name exactly expected_hostnames, any order."""
        actual_names = {host_dict['hostname'] for host_dict in hosts}
        self.assertEquals(actual_names, set(expected_hostnames))
    731 
    732 
    def test_ping_db(self):
        """ping_db returns [True] against the test database."""
        ping_result = rpc_interface.ping_db()
        self.assertEquals(ping_result, [True])
    735 
    736 
    def test_get_hosts_by_attribute(self):
        """get_hosts_by_attribute finds every host with the attribute value."""
        hostnames = ['test_host1', 'test_host2']
        for hostname in hostnames:
            new_host = models.Host.objects.create(hostname=hostname)
            new_host.set_attribute('test_attribute1', 'test_value1')

        matching = rpc_interface.get_hosts_by_attribute(
                'test_attribute1', 'test_value1')
        self.assertEquals(set(matching), set(hostnames))
    747 
    748 
    def test_get_host_attribute(self):
        """get_host_attribute reports the attribute for the selected hosts."""
        hostnames = ['test_host1', 'test_host2']
        for hostname in hostnames:
            new_host = models.Host.objects.create(hostname=hostname)
            new_host.set_attribute('test_attribute1', 'test_value1')

        attribute_rows = rpc_interface.get_host_attribute(
                'test_attribute1',
                hostname__in=['test_host1', 'test_host2'])
        reported_hosts = set(row['host'] for row in attribute_rows)
        reported_values = set(row['value'] for row in attribute_rows)
        self.assertEquals(reported_hosts, set(hostnames))
        self.assertEquals(reported_values, set(['test_value1']))
    763 
    764 
    def test_get_hosts(self):
        """get_hosts lists every fixture host and can filter by hostname."""
        every_host = rpc_interface.get_hosts()
        self._check_hostnames(every_host,
                              [fixture.hostname for fixture in self.hosts])

        only_host1 = rpc_interface.get_hosts(hostname='host1')
        self._check_hostnames(only_host1, ['host1'])
        details = only_host1[0]
        self.assertEquals(sorted(details['labels']), ['label1', 'myplatform'])
        self.assertEquals(details['platform'], 'myplatform')
        self.assertEquals(details['acls'], ['my_acl'])
        self.assertEquals(details['attributes'], {})
    776 
    777 
    def test_get_hosts_multiple_labels(self):
        """Filtering on 'myplatform' AND 'label1' yields only host1."""
        required_labels = ['myplatform', 'label1']
        found = rpc_interface.get_hosts(multiple_labels=required_labels)
        self._check_hostnames(found, ['host1'])
    782 
    783 
    def test_job_keyvals(self):
        """Keyvals passed to create_job are returned by get_jobs."""
        expected_keyvals = {'mykey': 'myvalue'}
        new_job_id = rpc_interface.create_job(
                name='test',
                priority=priorities.Priority.DEFAULT,
                control_file='foo',
                control_type=CLIENT,
                hosts=['host1'],
                keyvals=expected_keyvals)
        fetched_jobs = rpc_interface.get_jobs(id=new_job_id)
        self.assertEquals(len(fetched_jobs), 1)
        self.assertEquals(fetched_jobs[0]['keyvals'], expected_keyvals)
    795 
    796 
    def test_test_retry(self):
        """The test_retry count given to create_job is stored on the job."""
        new_job_id = rpc_interface.create_job(
                name='flake',
                priority=priorities.Priority.DEFAULT,
                control_file='foo',
                control_type=CLIENT,
                hosts=['host1'],
                test_retry=10)
        fetched_jobs = rpc_interface.get_jobs(id=new_job_id)
        self.assertEquals(len(fetched_jobs), 1)
        self.assertEquals(fetched_jobs[0]['test_retry'], 10)
    807 
    808 
    def test_get_jobs_summary(self):
        """get_jobs_summary tallies HQE statuses, counting aborted failures.

        Of three entries, one stays Queued, one is Failed, and one is Failed
        and aborted; the summary must report Queued: 1, Failed: 2.
        """
        job = self._create_job(hosts=xrange(1, 4))
        entries = list(job.hostqueueentry_set.all())

        failed_entry = entries[1]
        failed_entry.status = _hqe_status.FAILED
        failed_entry.save()

        aborted_entry = entries[2]
        aborted_entry.status = _hqe_status.FAILED
        aborted_entry.aborted = True
        aborted_entry.save()

        # Make tko_rpc_interface.get_status_counts return None.
        self.god.stub_function_to_return(rpc_interface.tko_rpc_interface,
                                         'get_status_counts', None)

        summaries = rpc_interface.get_jobs_summary(id=job.id)
        self.assertEquals(len(summaries), 1)
        self.assertEquals(summaries[0]['status_counts'],
                          {'Queued': 1, 'Failed': 2})
    828 
    829 
    def _check_job_ids(self, actual_job_dicts, expected_jobs):
        """Assert the job dicts carry exactly the ids of expected_jobs."""
        actual_ids = {job_dict['id'] for job_dict in actual_job_dicts}
        expected_ids = {job.id for job in expected_jobs}
        self.assertEquals(actual_ids, expected_ids)
    834 
    835 
    def test_get_jobs_status_filters(self):
        """get_jobs not_yet_run/running/finished filters classify by HQEs.

        Every job has two host queue entries. The expectations encode:
        both Queued -> not yet run; both Completed -> finished; any other
        mix exercised here (including Parsing) -> running.
        """
        def create_two_host_job():
            return self._create_job(hosts=[1, 2])

        def set_hqe_statuses(job, first_status, second_status):
            # Evaluate the queryset once so both updates target two distinct
            # rows from a single fetch, instead of issuing two separate
            # sliced queries whose row order is not guaranteed to agree.
            entries = list(job.hostqueueentry_set.all())
            entries[0].update_object(status=first_status)
            entries[1].update_object(status=second_status)

        queued = create_two_host_job()

        queued_and_running = create_two_host_job()
        set_hqe_statuses(queued_and_running, _hqe_status.QUEUED,
                         _hqe_status.RUNNING)

        running_and_complete = create_two_host_job()
        set_hqe_statuses(running_and_complete, _hqe_status.RUNNING,
                         _hqe_status.COMPLETED)

        complete = create_two_host_job()
        set_hqe_statuses(complete, _hqe_status.COMPLETED,
                         _hqe_status.COMPLETED)

        started_but_inactive = create_two_host_job()
        set_hqe_statuses(started_but_inactive, _hqe_status.QUEUED,
                         _hqe_status.COMPLETED)

        parsing = create_two_host_job()
        set_hqe_statuses(parsing, _hqe_status.PARSING, _hqe_status.PARSING)

        self._check_job_ids(rpc_interface.get_jobs(not_yet_run=True), [queued])
        self._check_job_ids(rpc_interface.get_jobs(running=True),
                            [queued_and_running, running_and_complete,
                             started_but_inactive, parsing])
        self._check_job_ids(rpc_interface.get_jobs(finished=True), [complete])
    870 
    871 
    def test_get_jobs_type_filters(self):
        """suite/sub/standalone filters are mutually exclusive and classify.

        Combining any two of the filters must raise AssertionError; each one
        alone must select only the parent, child, or standalone job.
        """
        exclusive_pairs = ({'suite': True, 'sub': True},
                           {'suite': True, 'standalone': True},
                           {'standalone': True, 'sub': True})
        for pair in exclusive_pairs:
            self.assertRaises(AssertionError, rpc_interface.get_jobs, **pair)

        parent_job = self._create_job(hosts=[1])
        # A single child job (spanning hosts 1 and 2) of parent_job.
        child_job = self._create_job(hosts=[1, 2],
                                     parent_job_id=parent_job.id)
        standalone_job = self._create_job(hosts=[1])

        expectations = (({'suite': True}, [parent_job]),
                        ({'sub': True}, [child_job]),
                        ({'standalone': True}, [standalone_job]))
        for filter_kwargs, expected_jobs in expectations:
            self._check_job_ids(rpc_interface.get_jobs(**filter_kwargs),
                                expected_jobs)
    889 
    890 
    def _create_job_helper(self, **kwargs):
        """Create a server job named 'test' via create_job; return its id.

        Any extra keyword arguments are passed straight to create_job.
        """
        return rpc_interface.create_job(
                name='test',
                priority=priorities.Priority.DEFAULT,
                control_file='control file',
                control_type=SERVER,
                **kwargs)
    896 
    897 
    def test_one_time_hosts(self):
        """A one-time host is created invalid, with no labels or ACL groups."""
        # The returned job id is not needed; only the host side effect is.
        self._create_job_helper(one_time_hosts=['testhost'])
        host = models.Host.objects.get(hostname='testhost')
        self.assertEquals(host.invalid, True)
        self.assertEquals(host.labels.count(), 0)
        self.assertEquals(host.aclgroup_set.count(), 0)
    904 
    905 
    def test_create_job_duplicate_hosts(self):
        """Listing the same host twice in one job raises ValidationError."""
        repeated_hosts = [1, 1]
        self.assertRaises(model_logic.ValidationError,
                          self._create_job_helper,
                          hosts=repeated_hosts)
    909 
    910 
    def test_create_unrunnable_metahost_job(self):
        """A metahost job on label 'unused' raises NoEligibleHostException."""
        unmatched_metahosts = ['unused']
        self.assertRaises(error.NoEligibleHostException,
                          self._create_job_helper,
                          meta_hosts=unmatched_metahosts)
    914 
    915 
    def test_create_hostless_job(self):
        """A hostless job gets exactly one HQE with no host and no metahost."""
        created_id = self._create_job_helper(hostless=True)
        created_job = models.Job.objects.get(pk=created_id)
        entries = created_job.hostqueueentry_set.all()
        self.assertEquals(len(entries), 1)
        lone_entry = entries[0]
        self.assertEquals(lone_entry.host, None)
        self.assertEquals(lone_entry.meta_host, None)
    923 
    924 
    def _setup_special_tasks(self):
        """Create two jobs on host1 and three VERIFY special tasks.

        Sets self.task1 (complete, started 2009-01-01, before job 1's entry),
        self.task2 (active, attached to job 2's queue entry) and self.task3
        (never started).
        """
        host = self.hosts[0]

        jobs = [self._create_job(hosts=[1]), self._create_job(hosts=[1])]

        entries = []
        for job, start_day in zip(jobs, (2, 3)):
            entry = job.hostqueueentry_set.all()[0]
            entry.update_object(started_on=datetime.datetime(2009, 1, start_day),
                                execution_subdir='host1')
            entries.append(entry)
        second_entry = entries[1]

        def create_verify(**extra):
            # Every task is a VERIFY on host1 requested by the current user.
            return models.SpecialTask.objects.create(
                    host=host, task=models.SpecialTask.Task.VERIFY,
                    requested_by=models.User.current_user(), **extra)

        self.task1 = create_verify(
                time_started=datetime.datetime(2009, 1, 1),  # ran before job 1
                is_complete=True)
        self.task2 = create_verify(queue_entry=second_entry,  # ran with job 2
                                   is_active=True)
        self.task3 = create_verify()  # not yet run
    949 
    950 
    def test_get_special_tasks(self):
        """Tasks on host1 without a queue entry are returned; first is done."""
        self._setup_special_tasks()
        unattached = rpc_interface.get_special_tasks(host__hostname='host1',
                                                     queue_entry__isnull=True)
        self.assertEquals(len(unattached), 2)
        first_task = unattached[0]
        self.assertEquals(first_task['task'], models.SpecialTask.Task.VERIFY)
        self.assertEquals(first_task['is_active'], False)
        self.assertEquals(first_task['is_complete'], True)
    959 
    960 
    def test_get_latest_special_task(self):
        """Sorting by -time_started with query_limit=1 yields the latest task."""
        # A particular usage of get_special_tasks().
        self._setup_special_tasks()
        # task1 started on 2009-01-01; give task2 a later start time.
        latest_start = datetime.datetime(2009, 1, 2)
        self.task2.time_started = latest_start
        self.task2.save()

        newest = rpc_interface.get_special_tasks(
                host__hostname='host1',
                task=models.SpecialTask.Task.VERIFY,
                time_started__isnull=False,
                sort_by=['-time_started'],
                query_limit=1)
        self.assertEquals(len(newest), 1)
        self.assertEquals(newest[0]['id'], 2)
    973 
    974 
    def _common_entry_check(self, entry_dict):
        """Assert entry_dict refers to host 'host1' and job id 2."""
        expected_fields = (('host', 'hostname', 'host1'), ('job', 'id', 2))
        for outer_key, inner_key, expected in expected_fields:
            self.assertEquals(entry_dict[outer_key][inner_key], expected)
    978 
    979 
    def test_get_host_queue_entries_and_special_tasks(self):
        """Jobs and special tasks on host1 come back interleaved as below."""
        self._setup_special_tasks()

        host_id = self.hosts[0].id
        combined = rpc_interface.get_host_queue_entries_and_special_tasks(
                host_id)

        expected_paths = ['hosts/host1/3-verify',
                          '2-autotest_system/host1',
                          'hosts/host1/2-verify',
                          '1-autotest_system/host1',
                          'hosts/host1/1-verify']
        self.assertEquals([item['execution_path'] for item in combined],
                          expected_paths)

        second_verify = combined[2]
        self._common_entry_check(second_verify)
        self.assertEquals(second_verify['type'], 'Verify')
        self.assertEquals(second_verify['status'], 'Running')
        self.assertEquals(second_verify['execution_path'],
                          'hosts/host1/2-verify')

        second_job_entry = combined[1]
        self._common_entry_check(second_job_entry)
        self.assertEquals(second_job_entry['type'], 'Job')
        self.assertEquals(second_job_entry['status'], 'Queued')
        self.assertEquals(second_job_entry['started_on'],
                          '2009-01-03 00:00:00')
   1005 
   1006 
   1007     def _create_hqes_and_start_time_index_entries(self):
   1008         shard = models.Shard.objects.create(hostname='shard')
   1009         job = self._create_job(shard=shard, control_file='foo')
   1010         HqeStatus = models.HostQueueEntry.Status
   1011 
   1012         models.HostQueueEntry(
   1013             id=1, job=job, started_on='2017-01-01',
   1014             status=HqeStatus.QUEUED).save()
   1015         models.HostQueueEntry(
   1016             id=2, job=job, started_on='2017-01-02',
   1017             status=HqeStatus.QUEUED).save()
   1018         models.HostQueueEntry(
   1019             id=3, job=job, started_on='2017-01-03',
   1020             status=HqeStatus.QUEUED).save()
   1021 
   1022         models.HostQueueEntryStartTimes(
   1023             insert_time='2017-01-03', highest_hqe_id=3).save()
   1024         models.HostQueueEntryStartTimes(
   1025             insert_time='2017-01-02', highest_hqe_id=2).save()
   1026         models.HostQueueEntryStartTimes(
   1027             insert_time='2017-01-01', highest_hqe_id=1).save()
   1028 
   1029     def test_get_host_queue_entries_by_insert_time(self):
   1030         """Check the insert_time_after and insert_time_before constraints."""
   1031         self._create_hqes_and_start_time_index_entries()
   1032         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1033             insert_time_after='2017-01-01')
   1034         self.assertEquals(len(hqes), 3)
   1035 
   1036         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1037             insert_time_after='2017-01-02')
   1038         self.assertEquals(len(hqes), 2)
   1039 
   1040         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1041             insert_time_after='2017-01-03')
   1042         self.assertEquals(len(hqes), 1)
   1043 
   1044         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1045             insert_time_before='2017-01-01')
   1046         self.assertEquals(len(hqes), 1)
   1047 
   1048         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1049             insert_time_before='2017-01-02')
   1050         self.assertEquals(len(hqes), 2)
   1051 
   1052         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1053             insert_time_before='2017-01-03')
   1054         self.assertEquals(len(hqes), 3)
   1055 
   1056 
   1057     def test_get_host_queue_entries_by_insert_time_with_missing_index_row(self):
   1058         """Shows that the constraints are approximate.
   1059 
   1060         The query may return rows which are actually outside of the bounds
   1061         given, if the index table does not have an entry for the specific time.
   1062         """
   1063         self._create_hqes_and_start_time_index_entries()
   1064         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1065             insert_time_before='2016-12-01')
   1066         self.assertEquals(len(hqes), 1)
   1067 
   1068     def test_get_hqe_by_insert_time_with_before_and_after(self):
   1069         self._create_hqes_and_start_time_index_entries()
   1070         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1071             insert_time_before='2017-01-02',
   1072             insert_time_after='2017-01-02')
   1073         self.assertEquals(len(hqes), 1)
   1074 
   1075     def test_get_hqe_by_insert_time_and_id_constraint(self):
   1076         self._create_hqes_and_start_time_index_entries()
   1077         # The time constraint is looser than the id constraint, so the time
   1078         # constraint should take precedence.
   1079         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1080             insert_time_before='2017-01-02',
   1081             id__lte=1)
   1082         self.assertEquals(len(hqes), 1)
   1083 
   1084         # Now make the time constraint tighter than the id constraint.
   1085         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1086             insert_time_before='2017-01-01',
   1087             id__lte=42)
   1088         self.assertEquals(len(hqes), 1)
   1089 
   1090     def test_view_invalid_host(self):
   1091         # RPCs used by View Host page should work for invalid hosts
   1092         self._create_job_helper(hosts=[1])
   1093         host = self.hosts[0]
   1094         host.delete()
   1095 
   1096         self.assertEquals(1, rpc_interface.get_num_hosts(hostname='host1',
   1097                                                          valid_only=False))
   1098         data = rpc_interface.get_hosts(hostname='host1', valid_only=False)
   1099         self.assertEquals(1, len(data))
   1100 
   1101         self.assertEquals(1, rpc_interface.get_num_host_queue_entries(
   1102                 host__hostname='host1'))
   1103         data = rpc_interface.get_host_queue_entries(host__hostname='host1')
   1104         self.assertEquals(1, len(data))
   1105 
   1106         count = rpc_interface.get_num_host_queue_entries_and_special_tasks(
   1107                 host=host.id)
   1108         self.assertEquals(1, count)
   1109         data = rpc_interface.get_host_queue_entries_and_special_tasks(
   1110                 host=host.id)
   1111         self.assertEquals(1, len(data))
   1112 
   1113 
   1114     def test_reverify_hosts(self):
   1115         hostname_list = rpc_interface.reverify_hosts(id__in=[1, 2])
   1116         self.assertEquals(hostname_list, ['host1', 'host2'])
   1117         tasks = rpc_interface.get_special_tasks()
   1118         self.assertEquals(len(tasks), 2)
   1119         self.assertEquals(set(task['host']['id'] for task in tasks),
   1120                           set([1, 2]))
   1121 
   1122         task = tasks[0]
   1123         self.assertEquals(task['task'], models.SpecialTask.Task.VERIFY)
   1124         self.assertEquals(task['requested_by'], 'autotest_system')
   1125 
   1126 
   1127     def test_repair_hosts(self):
   1128         hostname_list = rpc_interface.repair_hosts(id__in=[1, 2])
   1129         self.assertEquals(hostname_list, ['host1', 'host2'])
   1130         tasks = rpc_interface.get_special_tasks()
   1131         self.assertEquals(len(tasks), 2)
   1132         self.assertEquals(set(task['host']['id'] for task in tasks),
   1133                           set([1, 2]))
   1134 
   1135         task = tasks[0]
   1136         self.assertEquals(task['task'], models.SpecialTask.Task.REPAIR)
   1137         self.assertEquals(task['requested_by'], 'autotest_system')
   1138 
   1139 
   1140     def _modify_host_helper(self, on_shard=False, host_on_shard=False):
   1141         shard_hostname = 'shard1'
   1142         if on_shard:
   1143             global_config.global_config.override_config_value(
   1144                 'SHARD', 'shard_hostname', shard_hostname)
   1145 
   1146         host = models.Host.objects.all()[0]
   1147         if host_on_shard:
   1148             shard = models.Shard.objects.create(hostname=shard_hostname)
   1149             host.shard = shard
   1150             host.save()
   1151 
   1152         self.assertFalse(host.locked)
   1153 
   1154         self.god.stub_class_method(frontend.AFE, 'run')
   1155 
   1156         if host_on_shard and not on_shard:
   1157             mock_afe = self.god.create_mock_class_obj(
   1158                     frontend_wrappers.RetryingAFE, 'MockAFE')
   1159             self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1160 
   1161             mock_afe2 = frontend_wrappers.RetryingAFE.expect_new(
   1162                     server=shard_hostname, user=None)
   1163             mock_afe2.run.expect_call('modify_host_local', id=host.id,
   1164                     locked=True, lock_reason='_modify_host_helper lock',
   1165                     lock_time=datetime.datetime(2015, 12, 15))
   1166         elif on_shard:
   1167             mock_afe = self.god.create_mock_class_obj(
   1168                     frontend_wrappers.RetryingAFE, 'MockAFE')
   1169             self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1170 
   1171             mock_afe2 = frontend_wrappers.RetryingAFE.expect_new(
   1172                     server=server_utils.get_global_afe_hostname(), user=None)
   1173             mock_afe2.run.expect_call('modify_host', id=host.id,
   1174                     locked=True, lock_reason='_modify_host_helper lock',
   1175                     lock_time=datetime.datetime(2015, 12, 15))
   1176 
   1177         rpc_interface.modify_host(id=host.id, locked=True,
   1178                                   lock_reason='_modify_host_helper lock',
   1179                                   lock_time=datetime.datetime(2015, 12, 15))
   1180 
   1181         host = models.Host.objects.get(pk=host.id)
   1182         if on_shard:
   1183             # modify_host on shard does nothing but routing the RPC to master.
   1184             self.assertFalse(host.locked)
   1185         else:
   1186             self.assertTrue(host.locked)
   1187         self.god.check_playback()
   1188 
   1189 
   1190     def test_modify_host_on_master_host_on_master(self):
   1191         """Call modify_host to master for host in master."""
   1192         self._modify_host_helper()
   1193 
   1194 
   1195     def test_modify_host_on_master_host_on_shard(self):
   1196         """Call modify_host to master for host in shard."""
   1197         self._modify_host_helper(host_on_shard=True)
   1198 
   1199 
   1200     def test_modify_host_on_shard(self):
   1201         """Call modify_host to shard for host in shard."""
   1202         self._modify_host_helper(on_shard=True, host_on_shard=True)
   1203 
   1204 
   1205     def test_modify_hosts_on_master_host_on_shard(self):
   1206         """Ensure calls to modify_hosts are correctly forwarded to shards."""
   1207         host1 = models.Host.objects.all()[0]
   1208         host2 = models.Host.objects.all()[1]
   1209 
   1210         shard1 = models.Shard.objects.create(hostname='shard1')
   1211         host1.shard = shard1
   1212         host1.save()
   1213 
   1214         shard2 = models.Shard.objects.create(hostname='shard2')
   1215         host2.shard = shard2
   1216         host2.save()
   1217 
   1218         self.assertFalse(host1.locked)
   1219         self.assertFalse(host2.locked)
   1220 
   1221         mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
   1222                                                   'MockAFE')
   1223         self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1224 
   1225         # The statuses of one host might differ on master and shard.
   1226         # Filters are always applied on the master. So the host on the shard
   1227         # will be affected no matter what his status is.
   1228         filters_to_use = {'status': 'Ready'}
   1229 
   1230         mock_afe2 = frontend_wrappers.RetryingAFE.expect_new(
   1231                 server='shard2', user=None)
   1232         mock_afe2.run.expect_call(
   1233             'modify_hosts_local',
   1234             host_filter_data={'id__in': [shard1.id, shard2.id]},
   1235             update_data={'locked': True,
   1236                          'lock_reason': 'Testing forward to shard',
   1237                          'lock_time' : datetime.datetime(2015, 12, 15) })
   1238 
   1239         mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
   1240                 server='shard1', user=None)
   1241         mock_afe1.run.expect_call(
   1242             'modify_hosts_local',
   1243             host_filter_data={'id__in': [shard1.id, shard2.id]},
   1244             update_data={'locked': True,
   1245                          'lock_reason': 'Testing forward to shard',
   1246                          'lock_time' : datetime.datetime(2015, 12, 15)})
   1247 
   1248         rpc_interface.modify_hosts(
   1249                 host_filter_data={'status': 'Ready'},
   1250                 update_data={'locked': True,
   1251                              'lock_reason': 'Testing forward to shard',
   1252                              'lock_time' : datetime.datetime(2015, 12, 15) })
   1253 
   1254         host1 = models.Host.objects.get(pk=host1.id)
   1255         self.assertTrue(host1.locked)
   1256         host2 = models.Host.objects.get(pk=host2.id)
   1257         self.assertTrue(host2.locked)
   1258         self.god.check_playback()
   1259 
   1260 
   1261     def test_delete_host(self):
   1262         """Ensure an RPC is made on delete a host, if it is on a shard."""
   1263         host1 = models.Host.objects.all()[0]
   1264         shard1 = models.Shard.objects.create(hostname='shard1')
   1265         host1.shard = shard1
   1266         host1.save()
   1267         host1_id = host1.id
   1268 
   1269         mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
   1270                                                  'MockAFE')
   1271         self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1272 
   1273         mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
   1274                 server='shard1', user=None)
   1275         mock_afe1.run.expect_call('delete_host', id=host1.id)
   1276 
   1277         rpc_interface.delete_host(id=host1.id)
   1278 
   1279         self.assertRaises(models.Host.DoesNotExist,
   1280                           models.Host.smart_get, host1_id)
   1281 
   1282         self.god.check_playback()
   1283 
   1284 
   1285     def test_modify_label(self):
   1286         label1 = models.Label.objects.all()[0]
   1287         self.assertEqual(label1.invalid, 0)
   1288 
   1289         host2 = models.Host.objects.all()[1]
   1290         shard1 = models.Shard.objects.create(hostname='shard1')
   1291         host2.shard = shard1
   1292         host2.labels.add(label1)
   1293         host2.save()
   1294 
   1295         mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
   1296                                                   'MockAFE')
   1297         self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1298 
   1299         mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
   1300                 server='shard1', user=None)
   1301         mock_afe1.run.expect_call('modify_label', id=label1.id, invalid=1)
   1302 
   1303         rpc_interface.modify_label(label1.id, invalid=1)
   1304 
   1305         self.assertEqual(models.Label.objects.all()[0].invalid, 1)
   1306         self.god.check_playback()
   1307 
   1308 
   1309     def test_delete_label(self):
   1310         label1 = models.Label.objects.all()[0]
   1311 
   1312         host2 = models.Host.objects.all()[1]
   1313         shard1 = models.Shard.objects.create(hostname='shard1')
   1314         host2.shard = shard1
   1315         host2.labels.add(label1)
   1316         host2.save()
   1317 
   1318         mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
   1319                                                   'MockAFE')
   1320         self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1321 
   1322         mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
   1323                 server='shard1', user=None)
   1324         mock_afe1.run.expect_call('delete_label', id=label1.id)
   1325 
   1326         rpc_interface.delete_label(id=label1.id)
   1327 
   1328         self.assertRaises(models.Label.DoesNotExist,
   1329                           models.Label.smart_get, label1.id)
   1330         self.god.check_playback()
   1331 
   1332 
   1333     def test_get_image_for_job_with_keyval_build(self):
   1334         keyval_dict = {'build': 'cool-image'}
   1335         job_id = rpc_interface.create_job(name='test',
   1336                                           priority=priorities.Priority.DEFAULT,
   1337                                           control_file='foo',
   1338                                           control_type=CLIENT,
   1339                                           hosts=['host1'],
   1340                                           keyvals=keyval_dict)
   1341         job = models.Job.objects.get(id=job_id)
   1342         self.assertIsNotNone(job)
   1343         image = rpc_interface._get_image_for_job(job, True)
   1344         self.assertEquals('cool-image', image)
   1345 
   1346 
   1347     def test_get_image_for_job_with_keyval_builds(self):
   1348         keyval_dict = {'builds': {'cros-version': 'cool-image'}}
   1349         job_id = rpc_interface.create_job(name='test',
   1350                                           priority=priorities.Priority.DEFAULT,
   1351                                           control_file='foo',
   1352                                           control_type=CLIENT,
   1353                                           hosts=['host1'],
   1354                                           keyvals=keyval_dict)
   1355         job = models.Job.objects.get(id=job_id)
   1356         self.assertIsNotNone(job)
   1357         image = rpc_interface._get_image_for_job(job, True)
   1358         self.assertEquals('cool-image', image)
   1359 
   1360 
   1361     def test_get_image_for_job_with_control_build(self):
   1362         CONTROL_FILE = """build='cool-image'
   1363         """
   1364         job_id = rpc_interface.create_job(name='test',
   1365                                           priority=priorities.Priority.DEFAULT,
   1366                                           control_file='foo',
   1367                                           control_type=CLIENT,
   1368                                           hosts=['host1'])
   1369         job = models.Job.objects.get(id=job_id)
   1370         self.assertIsNotNone(job)
   1371         job.control_file = CONTROL_FILE
   1372         image = rpc_interface._get_image_for_job(job, True)
   1373         self.assertEquals('cool-image', image)
   1374 
   1375 
   1376     def test_get_image_for_job_with_control_builds(self):
   1377         CONTROL_FILE = """builds={'cros-version': 'cool-image'}
   1378         """
   1379         job_id = rpc_interface.create_job(name='test',
   1380                                           priority=priorities.Priority.DEFAULT,
   1381                                           control_file='foo',
   1382                                           control_type=CLIENT,
   1383                                           hosts=['host1'])
   1384         job = models.Job.objects.get(id=job_id)
   1385         self.assertIsNotNone(job)
   1386         job.control_file = CONTROL_FILE
   1387         image = rpc_interface._get_image_for_job(job, True)
   1388         self.assertEquals('cool-image', image)
   1389 
   1390 
   1391 class ExtraRpcInterfaceTest(frontend_test_utils.FrontendTestMixin,
   1392                             ShardHeartbeatTest):
   1393     """Unit tests for functions originally in site_rpc_interface.py.
   1394 
   1395     @var _NAME: fake suite name.
   1396     @var _BOARD: fake board to reimage.
   1397     @var _BUILD: fake build with which to reimage.
   1398     @var _PRIORITY: fake priority with which to reimage.
   1399     """
   1400     _NAME = 'name'
   1401     _BOARD = 'link'
   1402     _BUILD = 'link-release/R36-5812.0.0'
   1403     _BUILDS = {provision.CROS_VERSION_PREFIX: _BUILD}
   1404     _PRIORITY = priorities.Priority.DEFAULT
   1405     _TIMEOUT = 24
   1406 
   1407 
   1408     def setUp(self):
   1409         super(ExtraRpcInterfaceTest, self).setUp()
   1410         self._SUITE_NAME = rpc_interface.canonicalize_suite_name(
   1411             self._NAME)
   1412         self.dev_server = self.mox.CreateMock(dev_server.ImageServer)
   1413         self._frontend_common_setup(fill_data=False)
   1414 
   1415 
   1416     def tearDown(self):
   1417         self._frontend_common_teardown()
   1418 
   1419 
   1420     def _setupDevserver(self):
   1421         self.mox.StubOutClassWithMocks(dev_server, 'ImageServer')
   1422         dev_server.resolve(self._BUILD).AndReturn(self.dev_server)
   1423 
   1424 
   1425     def _mockDevServerGetter(self, get_control_file=True):
   1426         self._setupDevserver()
   1427         if get_control_file:
   1428           self.getter = self.mox.CreateMock(
   1429               control_file_getter.DevServerGetter)
   1430           self.mox.StubOutWithMock(control_file_getter.DevServerGetter,
   1431                                    'create')
   1432           control_file_getter.DevServerGetter.create(
   1433               mox.IgnoreArg(), mox.IgnoreArg()).AndReturn(self.getter)
   1434 
   1435 
   1436     def _mockRpcUtils(self, to_return, control_file_substring=''):
   1437         """Fake out the autotest rpc_utils module with a mockable class.
   1438 
   1439         @param to_return: the value that rpc_utils.create_job_common() should
   1440                           be mocked out to return.
   1441         @param control_file_substring: A substring that is expected to appear
   1442                                        in the control file output string that
   1443                                        is passed to create_job_common.
   1444                                        Default: ''
   1445         """
   1446         download_started_time = constants.DOWNLOAD_STARTED_TIME
   1447         payload_finished_time = constants.PAYLOAD_FINISHED_TIME
   1448         self.mox.StubOutWithMock(rpc_utils, 'create_job_common')
   1449         rpc_utils.create_job_common(mox.And(mox.StrContains(self._NAME),
   1450                                     mox.StrContains(self._BUILD)),
   1451                             priority=self._PRIORITY,
   1452                             timeout_mins=self._TIMEOUT*60,
   1453                             max_runtime_mins=self._TIMEOUT*60,
   1454                             control_type='Server',
   1455                             control_file=mox.And(mox.StrContains(self._BOARD),
   1456                                                  mox.StrContains(self._BUILD),
   1457                                                  mox.StrContains(
   1458                                                      control_file_substring)),
   1459                             hostless=True,
   1460                             keyvals=mox.And(mox.In(download_started_time),
   1461                                             mox.In(payload_finished_time))
   1462                             ).AndReturn(to_return)
   1463 
   1464 
   1465     def testStageBuildFail(self):
   1466         """Ensure that a failure to stage the desired build fails the RPC."""
   1467         self._setupDevserver()
   1468 
   1469         self.dev_server.hostname = 'mox_url'
   1470         self.dev_server.stage_artifacts(
   1471                 image=self._BUILD, artifacts=['test_suites']).AndRaise(
   1472                 dev_server.DevServerException())
   1473         self.mox.ReplayAll()
   1474         self.assertRaises(error.StageControlFileFailure,
   1475                           rpc_interface.create_suite_job,
   1476                           name=self._NAME,
   1477                           board=self._BOARD,
   1478                           builds=self._BUILDS,
   1479                           pool=None)
   1480 
   1481 
   1482     def testGetControlFileFail(self):
   1483         """Ensure that a failure to get needed control file fails the RPC."""
   1484         self._mockDevServerGetter()
   1485 
   1486         self.dev_server.hostname = 'mox_url'
   1487         self.dev_server.stage_artifacts(
   1488                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1489 
   1490         self.getter.get_control_file_contents_by_name(
   1491             self._SUITE_NAME).AndReturn(None)
   1492         self.mox.ReplayAll()
   1493         self.assertRaises(error.ControlFileEmpty,
   1494                           rpc_interface.create_suite_job,
   1495                           name=self._NAME,
   1496                           board=self._BOARD,
   1497                           builds=self._BUILDS,
   1498                           pool=None)
   1499 
   1500 
   1501     def testGetControlFileListFail(self):
   1502         """Ensure that a failure to get needed control file fails the RPC."""
   1503         self._mockDevServerGetter()
   1504 
   1505         self.dev_server.hostname = 'mox_url'
   1506         self.dev_server.stage_artifacts(
   1507                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1508 
   1509         self.getter.get_control_file_contents_by_name(
   1510             self._SUITE_NAME).AndRaise(error.NoControlFileList())
   1511         self.mox.ReplayAll()
   1512         self.assertRaises(error.NoControlFileList,
   1513                           rpc_interface.create_suite_job,
   1514                           name=self._NAME,
   1515                           board=self._BOARD,
   1516                           builds=self._BUILDS,
   1517                           pool=None)
   1518 
   1519 
   1520     def testCreateSuiteJobFail(self):
   1521         """Ensure that failure to schedule the suite job fails the RPC."""
   1522         self._mockDevServerGetter()
   1523 
   1524         self.dev_server.hostname = 'mox_url'
   1525         self.dev_server.stage_artifacts(
   1526                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1527 
   1528         self.getter.get_control_file_contents_by_name(
   1529             self._SUITE_NAME).AndReturn('f')
   1530 
   1531         self.dev_server.url().AndReturn('mox_url')
   1532         self._mockRpcUtils(-1)
   1533         self.mox.ReplayAll()
   1534         self.assertEquals(
   1535             rpc_interface.create_suite_job(name=self._NAME,
   1536                                            board=self._BOARD,
   1537                                            builds=self._BUILDS, pool=None),
   1538             -1)
   1539 
   1540 
   1541     def testCreateSuiteJobSuccess(self):
   1542         """Ensures that success results in a successful RPC."""
   1543         self._mockDevServerGetter()
   1544 
   1545         self.dev_server.hostname = 'mox_url'
   1546         self.dev_server.stage_artifacts(
   1547                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1548 
   1549         self.getter.get_control_file_contents_by_name(
   1550             self._SUITE_NAME).AndReturn('f')
   1551 
   1552         self.dev_server.url().AndReturn('mox_url')
   1553         job_id = 5
   1554         self._mockRpcUtils(job_id)
   1555         self.mox.ReplayAll()
   1556         self.assertEquals(
   1557             rpc_interface.create_suite_job(name=self._NAME,
   1558                                            board=self._BOARD,
   1559                                            builds=self._BUILDS,
   1560                                            pool=None),
   1561             job_id)
   1562 
   1563 
   1564     def testCreateSuiteJobNoHostCheckSuccess(self):
   1565         """Ensures that success results in a successful RPC."""
   1566         self._mockDevServerGetter()
   1567 
   1568         self.dev_server.hostname = 'mox_url'
   1569         self.dev_server.stage_artifacts(
   1570                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1571 
   1572         self.getter.get_control_file_contents_by_name(
   1573             self._SUITE_NAME).AndReturn('f')
   1574 
   1575         self.dev_server.url().AndReturn('mox_url')
   1576         job_id = 5
   1577         self._mockRpcUtils(job_id)
   1578         self.mox.ReplayAll()
   1579         self.assertEquals(
   1580           rpc_interface.create_suite_job(name=self._NAME,
   1581                                          board=self._BOARD,
   1582                                          builds=self._BUILDS,
   1583                                          pool=None, check_hosts=False),
   1584           job_id)
   1585 
   1586 
   1587     def testCreateSuiteJobControlFileSupplied(self):
   1588         """Ensure we can supply the control file to create_suite_job."""
   1589         self._mockDevServerGetter(get_control_file=False)
   1590 
   1591         self.dev_server.hostname = 'mox_url'
   1592         self.dev_server.stage_artifacts(
   1593                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1594         self.dev_server.url().AndReturn('mox_url')
   1595         job_id = 5
   1596         self._mockRpcUtils(job_id)
   1597         self.mox.ReplayAll()
   1598         self.assertEquals(
   1599             rpc_interface.create_suite_job(name='%s/%s' % (self._NAME,
   1600                                                            self._BUILD),
   1601                                            board=None,
   1602                                            builds=self._BUILDS,
   1603                                            pool=None,
   1604                                            control_file='CONTROL FILE'),
   1605             job_id)
   1606 
   1607 
   1608     def _get_records_for_sending_to_master(self):
   1609         return [{'control_file': 'foo',
   1610                  'control_type': 1,
   1611                  'created_on': datetime.datetime(2014, 8, 21),
   1612                  'drone_set': None,
   1613                  'email_list': '',
   1614                  'max_runtime_hrs': 72,
   1615                  'max_runtime_mins': 1440,
   1616                  'name': 'dummy',
   1617                  'owner': 'autotest_system',
   1618                  'parse_failed_repair': True,
   1619                  'priority': 40,
   1620                  'reboot_after': 0,
   1621                  'reboot_before': 1,
   1622                  'run_reset': True,
   1623                  'run_verify': False,
   1624                  'synch_count': 0,
   1625                  'test_retry': 10,
   1626                  'timeout': 24,
   1627                  'timeout_mins': 1440,
   1628                  'id': 1
   1629                  }], [{
   1630                     'aborted': False,
   1631                     'active': False,
   1632                     'complete': False,
   1633                     'deleted': False,
   1634                     'execution_subdir': '',
   1635                     'finished_on': None,
   1636                     'started_on': None,
   1637                     'status': 'Queued',
   1638                     'id': 1
   1639                 }]
   1640 
   1641 
   1642     def _send_records_to_master_helper(
   1643         self, jobs, hqes, shard_hostname='host1',
   1644         exception_to_throw=error.UnallowedRecordsSentToMaster, aborted=False):
   1645         job_id = rpc_interface.create_job(
   1646                 name='dummy',
   1647                 priority=self._PRIORITY,
   1648                 control_file='foo',
   1649                 control_type=SERVER,
   1650                 test_retry=10, hostless=True)
   1651         job = models.Job.objects.get(pk=job_id)
   1652         shard = models.Shard.objects.create(hostname='host1')
   1653         job.shard = shard
   1654         job.save()
   1655 
   1656         if aborted:
   1657             job.hostqueueentry_set.update(aborted=True)
   1658             job.shard = None
   1659             job.save()
   1660 
   1661         hqe = job.hostqueueentry_set.all()[0]
   1662         if not exception_to_throw:
   1663             self._do_heartbeat_and_assert_response(
   1664                 shard_hostname=shard_hostname,
   1665                 upload_jobs=jobs, upload_hqes=hqes)
   1666         else:
   1667             self.assertRaises(
   1668                 exception_to_throw,
   1669                 self._do_heartbeat_and_assert_response,
   1670                 shard_hostname=shard_hostname,
   1671                 upload_jobs=jobs, upload_hqes=hqes)
   1672 
   1673 
   1674     def testSendingRecordsToMaster(self):
   1675         """Send records to the master and ensure they are persisted."""
   1676         jobs, hqes = self._get_records_for_sending_to_master()
   1677         hqes[0]['status'] = 'Completed'
   1678         self._send_records_to_master_helper(
   1679             jobs=jobs, hqes=hqes, exception_to_throw=None)
   1680 
   1681         # Check the entry was actually written to db
   1682         self.assertEqual(models.HostQueueEntry.objects.all()[0].status,
   1683                          'Completed')
   1684 
   1685 
   1686     def testSendingRecordsToMasterAbortedOnMaster(self):
   1687         """Send records to the master and ensure they are persisted."""
   1688         jobs, hqes = self._get_records_for_sending_to_master()
   1689         hqes[0]['status'] = 'Completed'
   1690         self._send_records_to_master_helper(
   1691             jobs=jobs, hqes=hqes, exception_to_throw=None, aborted=True)
   1692 
   1693         # Check the entry was actually written to db
   1694         self.assertEqual(models.HostQueueEntry.objects.all()[0].status,
   1695                          'Completed')
   1696 
   1697 
   1698     def testSendingRecordsToMasterJobAssignedToDifferentShard(self):
   1699         """Ensure records belonging to different shard are silently rejected."""
   1700         shard1 = models.Shard.objects.create(hostname='shard1')
   1701         shard2 = models.Shard.objects.create(hostname='shard2')
   1702         job1 = self._create_job(shard=shard1, control_file='foo1')
   1703         job2 = self._create_job(shard=shard2, control_file='foo2')
   1704         job1_id = job1.id
   1705         job2_id = job2.id
   1706         hqe1 = models.HostQueueEntry.objects.create(job=job1)
   1707         hqe2 = models.HostQueueEntry.objects.create(job=job2)
   1708         hqe1_id = hqe1.id
   1709         hqe2_id = hqe2.id
   1710         job1_record = job1.serialize(include_dependencies=False)
   1711         job2_record = job2.serialize(include_dependencies=False)
   1712         hqe1_record = hqe1.serialize(include_dependencies=False)
   1713         hqe2_record = hqe2.serialize(include_dependencies=False)
   1714 
   1715         # Prepare a bogus job record update from the wrong shard. The update
   1716         # should not throw an exception. Non-bogus jobs in the same update
   1717         # should happily update.
   1718         job1_record.update({'control_file': 'bar1'})
   1719         job2_record.update({'control_file': 'bar2'})
   1720         hqe1_record.update({'status': 'Aborted'})
   1721         hqe2_record.update({'status': 'Aborted'})
   1722         self._do_heartbeat_and_assert_response(
   1723             shard_hostname='shard2', upload_jobs=[job1_record, job2_record],
   1724             upload_hqes=[hqe1_record, hqe2_record])
   1725 
   1726         # Job and HQE record for wrong job should not be modified, because the
   1727         # rpc came from the wrong shard. Job and HQE record for valid job are
   1728         # modified.
   1729         self.assertEqual(models.Job.objects.get(id=job1_id).control_file,
   1730                          'foo1')
   1731         self.assertEqual(models.Job.objects.get(id=job2_id).control_file,
   1732                          'bar2')
   1733         self.assertEqual(models.HostQueueEntry.objects.get(id=hqe1_id).status,
   1734                          '')
   1735         self.assertEqual(models.HostQueueEntry.objects.get(id=hqe2_id).status,
   1736                          'Aborted')
   1737 
   1738 
   1739     def testSendingRecordsToMasterNotExistingJob(self):
   1740         """Ensure update for non existing job gets rejected."""
   1741         jobs, hqes = self._get_records_for_sending_to_master()
   1742         jobs[0]['id'] = 3
   1743 
   1744         self._send_records_to_master_helper(
   1745             jobs=jobs, hqes=hqes)
   1746 
   1747 
   1748     def _createShardAndHostWithLabel(self, shard_hostname='shard1',
   1749                                      host_hostname='host1',
   1750                                      label_name='board:lumpy'):
   1751         """Create a label, host, shard, and assign host to shard."""
   1752         try:
   1753             label = models.Label.objects.create(name=label_name)
   1754         except:
   1755             label = models.Label.smart_get(label_name)
   1756 
   1757         shard = models.Shard.objects.create(hostname=shard_hostname)
   1758         shard.labels.add(label)
   1759 
   1760         host = models.Host.objects.create(hostname=host_hostname, leased=False,
   1761                                           shard=shard)
   1762         host.labels.add(label)
   1763 
   1764         return shard, host, label
   1765 
   1766 
   1767     def testShardLabelRemovalInvalid(self):
   1768         """Ensure you cannot remove the wrong label from shard."""
   1769         shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
   1770         stumpy_label = models.Label.objects.create(
   1771                 name='board:stumpy', platform=True)
   1772         with self.assertRaises(error.RPCException):
   1773             rpc_interface.remove_board_from_shard(
   1774                     shard1.hostname, stumpy_label.name)
   1775 
   1776 
   1777     def testShardHeartbeatLabelRemoval(self):
   1778         """Ensure label removal from shard works."""
   1779         shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
   1780 
   1781         self.assertEqual(host1.shard, shard1)
   1782         self.assertItemsEqual(shard1.labels.all(), [lumpy_label])
   1783         rpc_interface.remove_board_from_shard(
   1784                 shard1.hostname, lumpy_label.name)
   1785         host1 = models.Host.smart_get(host1.id)
   1786         shard1 = models.Shard.smart_get(shard1.id)
   1787         self.assertEqual(host1.shard, None)
   1788         self.assertItemsEqual(shard1.labels.all(), [])
   1789 
   1790 
   1791     def testCreateListShard(self):
   1792         """Retrieve a list of all shards."""
   1793         lumpy_label = models.Label.objects.create(name='board:lumpy',
   1794                                                   platform=True)
   1795         stumpy_label = models.Label.objects.create(name='board:stumpy',
   1796                                                   platform=True)
   1797         peppy_label = models.Label.objects.create(name='board:peppy',
   1798                                                   platform=True)
   1799 
   1800         shard_id = rpc_interface.add_shard(
   1801             hostname='host1', labels='board:lumpy,board:stumpy')
   1802         self.assertRaises(error.RPCException,
   1803                           rpc_interface.add_shard,
   1804                           hostname='host1', labels='board:lumpy,board:stumpy')
   1805         self.assertRaises(model_logic.ValidationError,
   1806                           rpc_interface.add_shard,
   1807                           hostname='host1', labels='board:peppy')
   1808         shard = models.Shard.objects.get(pk=shard_id)
   1809         self.assertEqual(shard.hostname, 'host1')
   1810         self.assertEqual(shard.labels.values_list('pk')[0], (lumpy_label.id,))
   1811         self.assertEqual(shard.labels.values_list('pk')[1], (stumpy_label.id,))
   1812 
   1813         self.assertEqual(rpc_interface.get_shards(),
   1814                          [{'labels': ['board:lumpy','board:stumpy'],
   1815                            'hostname': 'host1',
   1816                            'id': 1}])
   1817 
   1818 
   1819     def testAddBoardsToShard(self):
   1820         """Add boards to a given shard."""
   1821         shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
   1822         stumpy_label = models.Label.objects.create(name='board:stumpy',
   1823                                                    platform=True)
   1824         shard_id = rpc_interface.add_board_to_shard(
   1825             hostname='shard1', labels='board:stumpy')
   1826         # Test whether raise exception when board label does not exist.
   1827         self.assertRaises(models.Label.DoesNotExist,
   1828                           rpc_interface.add_board_to_shard,
   1829                           hostname='shard1', labels='board:test')
   1830         # Test whether raise exception when board already sharded.
   1831         self.assertRaises(error.RPCException,
   1832                           rpc_interface.add_board_to_shard,
   1833                           hostname='shard1', labels='board:lumpy')
   1834         shard = models.Shard.objects.get(pk=shard_id)
   1835         self.assertEqual(shard.hostname, 'shard1')
   1836         self.assertEqual(shard.labels.values_list('pk')[0], (lumpy_label.id,))
   1837         self.assertEqual(shard.labels.values_list('pk')[1], (stumpy_label.id,))
   1838 
   1839         self.assertEqual(rpc_interface.get_shards(),
   1840                          [{'labels': ['board:lumpy','board:stumpy'],
   1841                            'hostname': 'shard1',
   1842                            'id': 1}])
   1843 
   1844 
   1845     def testShardHeartbeatFetchHostlessJob(self):
   1846         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1847         self._testShardHeartbeatFetchHostlessJobHelper(host1)
   1848 
   1849 
   1850     def testShardHeartbeatIncorrectHosts(self):
   1851         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1852         self._testShardHeartbeatIncorrectHostsHelper(host1)
   1853 
   1854 
   1855     def testShardHeartbeatLabelRemovalRace(self):
   1856         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1857         self._testShardHeartbeatLabelRemovalRaceHelper(shard1, host1, label1)
   1858 
   1859 
   1860     def testShardRetrieveJobs(self):
   1861         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1862         shard2, host2, label2 = self._createShardAndHostWithLabel(
   1863                 'shard2', 'host2', 'board:grumpy')
   1864         self._testShardRetrieveJobsHelper(shard1, host1, label1,
   1865                                           shard2, host2, label2)
   1866 
   1867 
   1868     def testResendJobsAfterFailedHeartbeat(self):
   1869         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1870         self._testResendJobsAfterFailedHeartbeatHelper(shard1, host1, label1)
   1871 
   1872 
   1873     def testResendHostsAfterFailedHeartbeat(self):
   1874         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1875         self._testResendHostsAfterFailedHeartbeatHelper(host1)
   1876 
   1877 
   1878 if __name__ == '__main__':
   1879     unittest.main()
   1880