Home | History | Annotate | Download | only in afe
      1 #!/usr/bin/env python2
      2 # pylint: disable=missing-docstring
      3 
      4 import datetime
      5 import mox
      6 import unittest
      7 
      8 import common
      9 from autotest_lib.client.common_lib import control_data
     10 from autotest_lib.client.common_lib import error
     11 from autotest_lib.client.common_lib import global_config
     12 from autotest_lib.client.common_lib import priorities
     13 from autotest_lib.client.common_lib.cros import dev_server
     14 from autotest_lib.client.common_lib.test_utils import mock
     15 from autotest_lib.frontend import setup_django_environment
     16 from autotest_lib.frontend.afe import frontend_test_utils
     17 from autotest_lib.frontend.afe import model_logic
     18 from autotest_lib.frontend.afe import models
     19 from autotest_lib.frontend.afe import rpc_interface
     20 from autotest_lib.frontend.afe import rpc_utils
     21 from autotest_lib.server import frontend
     22 from autotest_lib.server import utils as server_utils
     23 from autotest_lib.server.cros import provision
     24 from autotest_lib.server.cros.dynamic_suite import constants
     25 from autotest_lib.server.cros.dynamic_suite import control_file_getter
     26 from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
     27 from autotest_lib.server.cros.dynamic_suite import suite_common
     28 
     29 
# Control-type names; CLIENT is what the tests in this file pass as
# `control_type` to rpc_interface.create_job.
CLIENT = control_data.CONTROL_TYPE_NAMES.CLIENT
SERVER = control_data.CONTROL_TYPE_NAMES.SERVER

# Short alias for models.HostQueueEntry.Status.
_hqe_status = models.HostQueueEntry.Status
     34 
     35 
     36 class ShardHeartbeatTest(mox.MoxTestBase, unittest.TestCase):
     37 
     38     _PRIORITY = priorities.Priority.DEFAULT
     39 
     40 
     41     def _do_heartbeat_and_assert_response(self, shard_hostname='shard1',
     42                                           upload_jobs=(), upload_hqes=(),
     43                                           known_jobs=(), known_hosts=(),
     44                                           **kwargs):
     45         known_job_ids = [job.id for job in known_jobs]
     46         known_host_ids = [host.id for host in known_hosts]
     47         known_host_statuses = [host.status for host in known_hosts]
     48 
     49         retval = rpc_interface.shard_heartbeat(
     50             shard_hostname=shard_hostname,
     51             jobs=upload_jobs, hqes=upload_hqes,
     52             known_job_ids=known_job_ids, known_host_ids=known_host_ids,
     53             known_host_statuses=known_host_statuses)
     54 
     55         self._assert_shard_heartbeat_response(shard_hostname, retval,
     56                                               **kwargs)
     57 
     58         return shard_hostname
     59 
     60 
     61     def _assert_shard_heartbeat_response(self, shard_hostname, retval, jobs=[],
     62                                          hosts=[], hqes=[],
     63                                          incorrect_host_ids=[]):
     64 
     65         retval_hosts, retval_jobs = retval['hosts'], retval['jobs']
     66         retval_incorrect_hosts = retval['incorrect_host_ids']
     67 
     68         expected_jobs = [
     69             (job.id, job.name, shard_hostname) for job in jobs]
     70         returned_jobs = [(job['id'], job['name'], job['shard']['hostname'])
     71                          for job in retval_jobs]
     72         self.assertEqual(returned_jobs, expected_jobs)
     73 
     74         expected_hosts = [(host.id, host.hostname) for host in hosts]
     75         returned_hosts = [(host['id'], host['hostname'])
     76                           for host in retval_hosts]
     77         self.assertEqual(returned_hosts, expected_hosts)
     78 
     79         retval_hqes = []
     80         for job in retval_jobs:
     81             retval_hqes += job['hostqueueentry_set']
     82 
     83         expected_hqes = [(hqe.id) for hqe in hqes]
     84         returned_hqes = [(hqe['id']) for hqe in retval_hqes]
     85         self.assertEqual(returned_hqes, expected_hqes)
     86 
     87         self.assertEqual(retval_incorrect_hosts, incorrect_host_ids)
     88 
     89 
     90     def _createJobForLabel(self, label):
     91         job_id = rpc_interface.create_job(name='dummy', priority=self._PRIORITY,
     92                                           control_file='foo',
     93                                           control_type=CLIENT,
     94                                           meta_hosts=[label.name],
     95                                           dependencies=(label.name,))
     96         return models.Job.objects.get(id=job_id)
     97 
     98 
     99     def _testShardHeartbeatFetchHostlessJobHelper(self, host1):
    100         """Create a hostless job and ensure it's not assigned to a shard."""
    101         label2 = models.Label.objects.create(name='bluetooth', platform=False)
    102 
    103         job1 = self._create_job(hostless=True)
    104 
    105         # Hostless jobs should be executed by the global scheduler.
    106         self._do_heartbeat_and_assert_response(hosts=[host1])
    107 
    108 
    109     def _testShardHeartbeatIncorrectHostsHelper(self, host1):
    110         """Ensure that hosts that don't belong to shard are determined."""
    111         host2 = models.Host.objects.create(hostname='test_host2', leased=False)
    112 
    113         # host2 should not belong to shard1. Ensure that if shard1 thinks host2
    114         # is a known host, then it is returned as invalid.
    115         self._do_heartbeat_and_assert_response(known_hosts=[host1, host2],
    116                                                incorrect_host_ids=[host2.id])
    117 
    118 
    119     def _testShardHeartbeatLabelRemovalRaceHelper(self, shard1, host1, label1):
    120         """Ensure correctness if label removed during heartbeat."""
    121         host2 = models.Host.objects.create(hostname='test_host2', leased=False)
    122         host2.labels.add(label1)
    123         self.assertEqual(host2.shard, None)
    124 
    125         # In the middle of the assign_to_shard call, remove label1 from shard1.
    126         self.mox.StubOutWithMock(models.Host, '_assign_to_shard_nothing_helper')
    127         def remove_label():
    128             rpc_interface.remove_board_from_shard(shard1.hostname, label1.name)
    129 
    130         models.Host._assign_to_shard_nothing_helper().WithSideEffects(
    131             remove_label)
    132         self.mox.ReplayAll()
    133 
    134         self._do_heartbeat_and_assert_response(
    135             known_hosts=[host1], hosts=[], incorrect_host_ids=[host1.id])
    136         host2 = models.Host.smart_get(host2.id)
    137         self.assertEqual(host2.shard, None)
    138 
    139 
    140     def _testShardRetrieveJobsHelper(self, shard1, host1, label1, shard2,
    141                                      host2, label2):
    142         """Create jobs and retrieve them."""
    143         # should never be returned by heartbeat
    144         leased_host = models.Host.objects.create(hostname='leased_host',
    145                                                  leased=True)
    146 
    147         leased_host.labels.add(label1)
    148 
    149         job1 = self._createJobForLabel(label1)
    150 
    151         job2 = self._createJobForLabel(label2)
    152 
    153         job_completed = self._createJobForLabel(label1)
    154         # Job is already being run, so don't sync it
    155         job_completed.hostqueueentry_set.update(complete=True)
    156         job_completed.hostqueueentry_set.create(complete=False)
    157 
    158         job_active = self._createJobForLabel(label1)
    159         # Job is already started, so don't sync it
    160         job_active.hostqueueentry_set.update(active=True)
    161         job_active.hostqueueentry_set.create(complete=False, active=False)
    162 
    163         self._do_heartbeat_and_assert_response(
    164             jobs=[job1], hosts=[host1], hqes=job1.hostqueueentry_set.all())
    165 
    166         self._do_heartbeat_and_assert_response(
    167             shard_hostname=shard2.hostname,
    168             jobs=[job2], hosts=[host2], hqes=job2.hostqueueentry_set.all())
    169 
    170         host3 = models.Host.objects.create(hostname='test_host3', leased=False)
    171         host3.labels.add(label1)
    172 
    173         self._do_heartbeat_and_assert_response(
    174             known_jobs=[job1], known_hosts=[host1], hosts=[host3])
    175 
    176 
    177     def _testResendJobsAfterFailedHeartbeatHelper(self, shard1, host1, label1):
    178         """Create jobs, retrieve them, fail on client, fetch them again."""
    179         job1 = self._createJobForLabel(label1)
    180 
    181         self._do_heartbeat_and_assert_response(
    182             jobs=[job1],
    183             hqes=job1.hostqueueentry_set.all(), hosts=[host1])
    184 
    185         # Make sure it's resubmitted by sending last_job=None again
    186         self._do_heartbeat_and_assert_response(
    187             known_hosts=[host1],
    188             jobs=[job1], hqes=job1.hostqueueentry_set.all(), hosts=[])
    189 
    190         # Now it worked, make sure it's not sent again
    191         self._do_heartbeat_and_assert_response(
    192             known_jobs=[job1], known_hosts=[host1])
    193 
    194         job1 = models.Job.objects.get(pk=job1.id)
    195         job1.hostqueueentry_set.all().update(complete=True)
    196 
    197         # Job is completed, make sure it's not sent again
    198         self._do_heartbeat_and_assert_response(
    199             known_hosts=[host1])
    200 
    201         job2 = self._createJobForLabel(label1)
    202 
    203         # job2's creation was later, it should be returned now.
    204         self._do_heartbeat_and_assert_response(
    205             known_hosts=[host1],
    206             jobs=[job2], hqes=job2.hostqueueentry_set.all())
    207 
    208         self._do_heartbeat_and_assert_response(
    209             known_jobs=[job2], known_hosts=[host1])
    210 
    211         job2 = models.Job.objects.get(pk=job2.pk)
    212         job2.hostqueueentry_set.update(aborted=True)
    213         # Setting a job to a complete status will set the shard_id to None in
    214         # scheduler_models. We have to emulate that here, because we use Django
    215         # models in tests.
    216         job2.shard = None
    217         job2.save()
    218 
    219         self._do_heartbeat_and_assert_response(
    220             known_jobs=[job2], known_hosts=[host1],
    221             jobs=[job2],
    222             hqes=job2.hostqueueentry_set.all())
    223 
    224         models.Test.objects.create(name='platform_BootPerfServer:shard',
    225                                    test_type=1)
    226         self.mox.StubOutWithMock(server_utils, 'read_file')
    227         self.mox.ReplayAll()
    228         rpc_interface.delete_shard(hostname=shard1.hostname)
    229 
    230         self.assertRaises(
    231             models.Shard.DoesNotExist, models.Shard.objects.get, pk=shard1.id)
    232 
    233         job1 = models.Job.objects.get(pk=job1.id)
    234         label1 = models.Label.objects.get(pk=label1.id)
    235 
    236         self.assertIsNone(job1.shard)
    237         self.assertEqual(len(label1.shard_set.all()), 0)
    238 
    239 
    240     def _testResendHostsAfterFailedHeartbeatHelper(self, host1):
    241         """Check that master accepts resending updated records after failure."""
    242         # Send the host
    243         self._do_heartbeat_and_assert_response(hosts=[host1])
    244 
    245         # Send it again because previous one didn't persist correctly
    246         self._do_heartbeat_and_assert_response(hosts=[host1])
    247 
    248         # Now it worked, make sure it isn't sent again
    249         self._do_heartbeat_and_assert_response(known_hosts=[host1])
    250 
    251 
    252 class RpcInterfaceTestWithStaticAttribute(
    253         mox.MoxTestBase, unittest.TestCase,
    254         frontend_test_utils.FrontendTestMixin):
    255 
    256     def setUp(self):
    257         super(RpcInterfaceTestWithStaticAttribute, self).setUp()
    258         self._frontend_common_setup()
    259         self.god = mock.mock_god()
    260         self.old_respect_static_config = rpc_interface.RESPECT_STATIC_ATTRIBUTES
    261         rpc_interface.RESPECT_STATIC_ATTRIBUTES = True
    262         models.RESPECT_STATIC_ATTRIBUTES = True
    263 
    264 
    265     def tearDown(self):
    266         self.god.unstub_all()
    267         self._frontend_common_teardown()
    268         global_config.global_config.reset_config_values()
    269         rpc_interface.RESPECT_STATIC_ATTRIBUTES = self.old_respect_static_config
    270         models.RESPECT_STATIC_ATTRIBUTES = self.old_respect_static_config
    271 
    272 
    273     def _fake_host_with_static_attributes(self):
    274         host1 = models.Host.objects.create(hostname='test_host')
    275         host1.set_attribute('test_attribute1', 'test_value1')
    276         host1.set_attribute('test_attribute2', 'test_value2')
    277         self._set_static_attribute(host1, 'test_attribute1', 'static_value1')
    278         self._set_static_attribute(host1, 'static_attribute1', 'static_value2')
    279         host1.save()
    280         return host1
    281 
    282 
    283     def test_get_hosts(self):
    284         host1 = self._fake_host_with_static_attributes()
    285         hosts = rpc_interface.get_hosts(hostname=host1.hostname)
    286         host = hosts[0]
    287 
    288         self.assertEquals(host['hostname'], 'test_host')
    289         self.assertEquals(host['acls'], ['Everyone'])
    290         # Respect the value of static attributes.
    291         self.assertEquals(host['attributes'],
    292                           {'test_attribute1': 'static_value1',
    293                            'test_attribute2': 'test_value2',
    294                            'static_attribute1': 'static_value2'})
    295 
    296     def test_get_host_attribute_with_static(self):
    297         host1 = models.Host.objects.create(hostname='test_host1')
    298         host1.set_attribute('test_attribute1', 'test_value1')
    299         self._set_static_attribute(host1, 'test_attribute1', 'static_value1')
    300         host2 = models.Host.objects.create(hostname='test_host2')
    301         host2.set_attribute('test_attribute1', 'test_value1')
    302         host2.set_attribute('test_attribute2', 'test_value2')
    303 
    304         attributes = rpc_interface.get_host_attribute(
    305                 'test_attribute1',
    306                 hostname__in=['test_host1', 'test_host2'])
    307         hosts = [attr['host'] for attr in attributes]
    308         values = [attr['value'] for attr in attributes]
    309         self.assertEquals(set(hosts),
    310                           set(['test_host1', 'test_host2']))
    311         self.assertEquals(set(values),
    312                           set(['test_value1', 'static_value1']))
    313 
    314 
    315     def test_get_hosts_by_attribute_without_static(self):
    316         host1 = models.Host.objects.create(hostname='test_host1')
    317         host1.set_attribute('test_attribute1', 'test_value1')
    318         host2 = models.Host.objects.create(hostname='test_host2')
    319         host2.set_attribute('test_attribute1', 'test_value1')
    320 
    321         hosts = rpc_interface.get_hosts_by_attribute(
    322                 'test_attribute1', 'test_value1')
    323         self.assertEquals(set(hosts),
    324                           set(['test_host1', 'test_host2']))
    325 
    326 
    327     def test_get_hosts_by_attribute_with_static(self):
    328         host1 = models.Host.objects.create(hostname='test_host1')
    329         host1.set_attribute('test_attribute1', 'test_value1')
    330         self._set_static_attribute(host1, 'test_attribute1', 'test_value1')
    331         host2 = models.Host.objects.create(hostname='test_host2')
    332         host2.set_attribute('test_attribute1', 'test_value1')
    333         self._set_static_attribute(host2, 'test_attribute1', 'static_value1')
    334         host3 = models.Host.objects.create(hostname='test_host3')
    335         self._set_static_attribute(host3, 'test_attribute1', 'test_value1')
    336         host4 = models.Host.objects.create(hostname='test_host4')
    337         host4.set_attribute('test_attribute1', 'test_value1')
    338         host5 = models.Host.objects.create(hostname='test_host5')
    339         host5.set_attribute('test_attribute1', 'temp_value1')
    340         self._set_static_attribute(host5, 'test_attribute1', 'test_value1')
    341 
    342         hosts = rpc_interface.get_hosts_by_attribute(
    343                 'test_attribute1', 'test_value1')
    344         # host1: matched, it has the same value for test_attribute1.
    345         # host2: not matched, it has a new value in
    346         #        afe_static_host_attributes for test_attribute1.
    347         # host3: matched, it has a corresponding entry in
    348         #        afe_host_attributes for test_attribute1.
    349         # host4: matched, test_attribute1 is not replaced by static
    350         #        attribute.
    351         # host5: matched, it has an updated & matched value for
    352         #        test_attribute1 in afe_static_host_attributes.
    353         self.assertEquals(set(hosts),
    354                           set(['test_host1', 'test_host3',
    355                                'test_host4', 'test_host5']))
    356 
    357 
    358 class RpcInterfaceTestWithStaticLabel(ShardHeartbeatTest,
    359                                       frontend_test_utils.FrontendTestMixin):
    360 
    361     _STATIC_LABELS = ['board:lumpy']
    362 
    def setUp(self):
        """Set up the frontend fixtures and force RESPECT_STATIC_LABELS on."""
        super(RpcInterfaceTestWithStaticLabel, self).setUp()
        self._frontend_common_setup()
        self.god = mock.mock_god()
        # Remember the flag as seen on rpc_interface so tearDown can put it
        # back, then enable it on both modules that read it.
        self.old_respect_static_config = rpc_interface.RESPECT_STATIC_LABELS
        for flagged_module in (rpc_interface, models):
            flagged_module.RESPECT_STATIC_LABELS = True
    370 
    371 
    def tearDown(self):
        """Undo stubs and fixtures and restore RESPECT_STATIC_LABELS."""
        self.god.unstub_all()
        self._frontend_common_teardown()
        global_config.global_config.reset_config_values()
        # NOTE(review): both modules are restored from the value saved off
        # rpc_interface in setUp; this assumes models had the same value
        # before the test -- confirm.
        saved_flag = self.old_respect_static_config
        rpc_interface.RESPECT_STATIC_LABELS = saved_flag
        models.RESPECT_STATIC_LABELS = saved_flag
    378 
    379 
    def _fake_host_with_static_labels(self):
        """Create 'test_host' carrying one plain and one replaced label.

        The regular label 'static_platform' (platform=False) is registered as
        replaced and paired with a static label of the same name that has
        platform=True.

        @returns The saved models.Host.
        """
        host = models.Host.objects.create(hostname='test_host')
        plain_label = models.Label.objects.create(
                name='non_static_label1', platform=False)
        replaced_label = models.Label.objects.create(
                name='static_platform', platform=False)
        static_label = models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        # Mark the regular label as superseded by its static counterpart.
        models.ReplacedLabel.objects.create(label_id=replaced_label.id)

        host.static_labels.add(static_label)
        for regular_label in (replaced_label, plain_label):
            host.labels.add(regular_label)
        host.save()
        return host
    394 
    395 
    def test_get_hosts(self):
        """get_hosts lists all labels and takes platform from static labels."""
        host1 = self._fake_host_with_static_labels()
        hosts = rpc_interface.get_hosts(hostname=host1.hostname)
        host = hosts[0]

        # assertEqual rather than the deprecated assertEquals alias.
        self.assertEqual(host['hostname'], 'test_host')
        self.assertEqual(host['acls'], ['Everyone'])
        # Respect all labels in afe_hosts_labels.
        self.assertEqual(host['labels'],
                         ['non_static_label1', 'static_platform'])
        # Respect static labels.
        self.assertEqual(host['platform'], 'static_platform')
    408 
    409 
    def test_get_hosts_multiple_labels(self):
        """get_hosts finds a host by a mix of regular and replaced labels."""
        self._fake_host_with_static_labels()
        hosts = rpc_interface.get_hosts(
                multiple_labels=['non_static_label1', 'static_platform'])
        host = hosts[0]
        # assertEqual rather than the deprecated assertEquals alias.
        self.assertEqual(host['hostname'], 'test_host')
    416 
    417 
    def test_delete_static_label(self):
        """delete_label on the 'static' label must be refused."""
        static = models.Label.smart_get('static')

        # Attach the label to a host that lives on a shard.
        sharded_host = models.Host.objects.all()[1]
        sharded_host.shard = models.Shard.objects.create(hostname='shard1')
        sharded_host.labels.add(static)
        sharded_host.save()

        # Replace RetryingAFE with a mock that has no recorded expectations.
        fake_afe_cls = self.god.create_mock_class_obj(
                frontend_wrappers.RetryingAFE, 'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', fake_afe_cls)

        with self.assertRaises(error.UnmodifiableLabelException):
            rpc_interface.delete_label(static.id)

        self.god.check_playback()
    436 
    437 
    def test_modify_static_label(self):
        """modify_label on the 'static' label must be refused and a no-op."""
        static = models.Label.smart_get('static')
        self.assertEqual(static.invalid, 0)

        # Attach the label to a host that lives on a shard.
        sharded_host = models.Host.objects.all()[1]
        sharded_host.shard = models.Shard.objects.create(hostname='shard1')
        sharded_host.labels.add(static)
        sharded_host.save()

        # Replace RetryingAFE with a mock that has no recorded expectations.
        fake_afe_cls = self.god.create_mock_class_obj(
                frontend_wrappers.RetryingAFE, 'MockAFE')
        self.god.stub_with(frontend_wrappers, 'RetryingAFE', fake_afe_cls)

        with self.assertRaises(error.UnmodifiableLabelException):
            rpc_interface.modify_label(static.id, invalid=1)

        # The label must be left untouched by the rejected call.
        reloaded = models.Label.smart_get('static')
        self.assertEqual(reloaded.invalid, 0)
        self.god.check_playback()
    459 
    460 
    def test_multiple_platforms_add_non_static_to_static(self):
        """Test non-static platform to a host with static platform."""
        static_platform = models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        non_static_platform = models.Label.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=non_static_platform.id)
        # Second platform label; referenced only by name below.
        models.Label.objects.create(name='platform2', platform=True)
        host1 = models.Host.objects.create(hostname='test_host')
        host1.static_labels.add(static_platform)
        host1.labels.add(non_static_platform)
        host1.save()

        # Both RPC entry points must reject a second platform.
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts, id='platform2',
                          hosts=['test_host'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='test_host', labels=['platform2'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['test_host'], platform=True)
        self.assertEqual(len(platforms), 1)
    484 
    485 
    def test_multiple_platforms_add_static_to_non_static(self):
        """Test static platform to a host with non-static platform."""
        platform1 = models.Label.objects.create(
                name='static_platform', platform=True)
        models.ReplacedLabel.objects.create(label_id=platform1.id)
        # Static counterpart of platform1; referenced only by name below.
        models.StaticLabel.objects.create(
                name='static_platform', platform=True)
        platform2 = models.Label.objects.create(
                name='platform2', platform=True)

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(platform2)
        host1.save()

        # Both RPC entry points must reject a second platform.
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.label_add_hosts,
                          id='static_platform',
                          hosts=['test_host'])
        self.assertRaises(model_logic.ValidationError,
                          rpc_interface.host_add_labels,
                          id='test_host', labels=['static_platform'])
        # make sure the platform didn't get added
        platforms = rpc_interface.get_labels(
            host__hostname__in=['test_host'], platform=True)
        self.assertEqual(len(platforms), 1)
    511 
    512 
    def test_label_remove_hosts(self):
        """Removing hosts from the 'static' label must be refused."""
        regular = models.Label.smart_get('static')
        static = models.StaticLabel.objects.create(name='static')

        target = models.Host.objects.create(hostname='test_host')
        target.labels.add(regular)
        target.static_labels.add(static)
        target.save()

        with self.assertRaises(error.UnmodifiableLabelException):
            rpc_interface.label_remove_hosts(id='static', hosts=['test_host'])
    526 
    527 
    def test_host_remove_labels(self):
        """Test remove labels of a given host."""
        label = models.Label.smart_get('static')
        label1 = models.Label.smart_get('label1')
        label2 = models.Label.smart_get('label2')
        static_label = models.StaticLabel.objects.create(name='static')

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(label)
        host1.labels.add(label1)
        host1.labels.add(label2)
        host1.static_labels.add(static_label)
        host1.save()

        rpc_interface.host_remove_labels(
                'test_host', ['static', 'label1'])
        labels = rpc_interface.get_labels(host__hostname__in=['test_host'])
        # Only non_static label 'label1' is removed.
        # assertEqual rather than the deprecated assertEquals alias.
        self.assertEqual(len(labels), 2)
        self.assertEqual(labels[0].get('name'), 'label2')
    548 
    549 
    def test_remove_board_from_shard(self):
        """Removing a static board label detaches it and its host from a shard."""
        board = models.Label.smart_get('static')
        static_board = models.StaticLabel.objects.create(name='static')

        test_shard = models.Shard.objects.create(hostname='test_shard')
        test_shard.labels.add(board)

        # The host carries the board only as a static label.
        sharded_host = models.Host.objects.create(hostname='test_host',
                                                  leased=False,
                                                  shard=test_shard)
        sharded_host.static_labels.add(static_board)
        sharded_host.save()

        rpc_interface.remove_board_from_shard(test_shard.hostname, board.name)

        # Re-read both records to observe what the RPC changed.
        reloaded_host = models.Host.smart_get(sharded_host.id)
        reloaded_shard = models.Shard.smart_get(test_shard.id)
        self.assertEqual(reloaded_host.shard, None)
        self.assertItemsEqual(reloaded_shard.labels.all(), [])
    569 
    570 
    def test_check_job_dependencies_success(self):
        """check_job_dependencies accepts a host holding the static label."""
        static = models.StaticLabel.objects.create(name='static')

        created = models.Host.objects.create(hostname='test_host')
        created.static_labels.add(static)
        created.save()

        # Passes as long as no exception is raised.
        reloaded = models.Host.smart_get(created.id)
        rpc_utils.check_job_dependencies([reloaded], ['static'])
    581 
    582 
    def test_check_job_dependencies_fail(self):
        """check_job_dependencies rejects a host holding only the regular label."""
        regular = models.Label.smart_get('static')
        # A static label of the same name exists but is not given to the host.
        models.StaticLabel.objects.create(name='static')

        created = models.Host.objects.create(hostname='test_host')
        created.labels.add(regular)
        created.save()

        reloaded = models.Host.smart_get(created.id)
        with self.assertRaises(model_logic.ValidationError):
            rpc_utils.check_job_dependencies([reloaded], ['static'])
    597 
    def test_check_job_metahost_dependencies_success(self):
        """Metahost checks pass when the host has the static label."""
        first = models.Label.smart_get('label1')
        second = models.Label.smart_get('label2')
        regular_static = models.Label.smart_get('static')
        static = models.StaticLabel.objects.create(name='static')

        target = models.Host.objects.create(hostname='test_host')
        target.static_labels.add(static)
        target.labels.add(first)
        target.labels.add(second)
        target.save()

        # 'static' given once as a metahost and once as a dependency; both
        # calls pass as long as nothing is raised.
        metahost_cases = (
                ([first, regular_static], [second.name]),
                ([first], [second.name, static.name]),
        )
        for metahosts, dependencies in metahost_cases:
            rpc_utils.check_job_metahost_dependencies(metahosts, dependencies)
    615 
    616 
    def test_check_job_metahost_dependencies_fail(self):
        """Metahost checks fail when no host has the static label."""
        first = models.Label.smart_get('label1')
        second = models.Label.smart_get('label2')
        regular_static = models.Label.smart_get('static')
        static = models.StaticLabel.objects.create(name='static')

        # The host gets neither the regular nor the static 'static' label.
        target = models.Host.objects.create(hostname='test_host')
        target.labels.add(first)
        target.labels.add(second)
        target.save()

        with self.assertRaises(error.NoEligibleHostException):
            rpc_utils.check_job_metahost_dependencies(
                    [first, regular_static], [second.name])
        with self.assertRaises(error.NoEligibleHostException):
            rpc_utils.check_job_metahost_dependencies(
                    [first], [second.name, static.name])
    635 
    636 
    def _createShardAndHostWithStaticLabel(self,
                                           shard_hostname='shard1',
                                           host_hostname='test_host1',
                                           label_name='board:lumpy'):
        """Create a shard, a host on it, and the label they share.

        When `label_name` is listed in `_STATIC_LABELS`, the label is also
        registered as replaced and a static label of the same name is added
        to the host.

        @param shard_hostname: Hostname for the new shard.
        @param host_hostname: Hostname for the new host.
        @param label_name: Name of the label attached to shard and host.

        @returns Tuple (shard, host, label) of the created model objects.
        """
        label = models.Label.objects.create(name=label_name)

        shard = models.Shard.objects.create(hostname=shard_hostname)
        shard.labels.add(label)

        host = models.Host.objects.create(hostname=host_hostname, leased=False,
                                          shard=shard)
        host.labels.add(label)

        created = (shard, host, label)
        if label_name not in self._STATIC_LABELS:
            return created

        # Static case: mark the regular label as replaced and attach its
        # static counterpart to the host.
        models.ReplacedLabel.objects.create(label_id=label.id)
        host.static_labels.add(
                models.StaticLabel.objects.create(name=label_name))
        return created
    655 
    656 
    def testShardHeartbeatFetchHostlessJob(self):
        """Run the hostless-job heartbeat helper against a fresh shard host."""
        # Only the host is needed; the shard and label are discarded.
        _, shard_host, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatFetchHostlessJobHelper(shard_host)
    661 
    662 
    def testShardHeartbeatIncorrectHosts(self):
        """Run the incorrect-hosts heartbeat helper against a fresh shard host."""
        # Only the host is needed; the shard and label are discarded.
        _, shard_host, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testShardHeartbeatIncorrectHostsHelper(shard_host)
    667 
    668 
    def testShardHeartbeatLabelRemovalRace(self):
        """Run the label-removal-race helper on a fresh shard, host and label."""
        fixture = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        shard, shard_host, shared_label = fixture
        self._testShardHeartbeatLabelRemovalRaceHelper(
                shard, shard_host, shared_label)
    673 
    674 
    def testShardRetrieveJobs(self):
        """Run the retrieve-jobs helper with two independent shards."""
        # First fixture uses the helper's defaults (shard1 / board:lumpy).
        lumpy_fixture = self._createShardAndHostWithStaticLabel()
        grumpy_fixture = self._createShardAndHostWithStaticLabel(
                'shard2', 'test_host2', 'board:grumpy')

        # The helper takes shard, host, label for each of the two shards.
        helper_args = tuple(lumpy_fixture) + tuple(grumpy_fixture)
        self._testShardRetrieveJobsHelper(*helper_args)
    681 
    682 
    def testResendJobsAfterFailedHeartbeat(self):
        """Run the resend-jobs helper on a default shard, host and label."""
        shard, shard_host, shared_label = (
                self._createShardAndHostWithStaticLabel())
        self._testResendJobsAfterFailedHeartbeatHelper(
                shard, shard_host, shared_label)
    686 
    687 
    def testResendHostsAfterFailedHeartbeat(self):
        """Run the resend-hosts helper against a fresh shard host."""
        # Only the host is needed; the shard and label are discarded.
        _, shard_host, _ = self._createShardAndHostWithStaticLabel(
                host_hostname='test_host1')
        self._testResendHostsAfterFailedHeartbeatHelper(shard_host)
    692 
    693 
    694 class RpcInterfaceTest(unittest.TestCase,
    695                        frontend_test_utils.FrontendTestMixin):
    def setUp(self):
        """Run the shared frontend setup and create the per-test mock god."""
        self._frontend_common_setup()
        # self.god is used by the tests for stubbing and playback checks.
        self.god = mock.mock_god()
    699 
    700 
    def tearDown(self):
        """Remove stubs, run the shared frontend teardown, reset the config."""
        self.god.unstub_all()
        self._frontend_common_teardown()
        # NOTE(review): presumably clears overrides made through
        # override_config_value (see _modify_host_helper) -- confirm.
        global_config.global_config.reset_config_values()
    705 
    706 
    def test_validation(self):
        """Invalid add_label / add_host requests raise ValidationError."""
        # Omit a required field.
        with self.assertRaises(model_logic.ValidationError):
            rpc_interface.add_label(name=None)

        # Violate the uniqueness constraint.
        with self.assertRaises(model_logic.ValidationError):
            rpc_interface.add_host(hostname='host1')
    714 
    715 
    def test_multiple_platforms(self):
        """Adding a second platform label to a host is rejected both ways."""
        models.Label.objects.create(name='platform2', platform=True)

        # Attach via the label-side RPC.
        with self.assertRaises(model_logic.ValidationError):
            rpc_interface.label_add_hosts(id='platform2',
                                          hosts=['host1', 'host2'])
        # Attach via the host-side RPC.
        with self.assertRaises(model_logic.ValidationError):
            rpc_interface.host_add_labels(id='host1', labels=['platform2'])

        # Make sure the platform didn't get added.
        platforms = rpc_interface.get_labels(
                host__hostname__in=['host1', 'host2'], platform=True)
        platform_names = [platform['name'] for platform in platforms]
        self.assertEqual(platform_names, ['myplatform'])
    729 
    730 
    def _check_hostnames(self, hosts, expected_hostnames):
        """Assert the host dicts cover exactly the expected hostnames.

        Both sides are compared as sets, so order and duplicates are ignored.

        @param hosts: Iterable of dicts, each with a 'hostname' key.
        @param expected_hostnames: Iterable of expected hostname strings.
        """
        actual_hostnames = {host_dict['hostname'] for host_dict in hosts}
        self.assertEqual(actual_hostnames, set(expected_hostnames))
    734 
    735 
    def test_ping_db(self):
        """ping_db answers with a single-element list containing True."""
        ping_result = rpc_interface.ping_db()
        self.assertEqual(ping_result, [True])
    738 
    739 
    def test_get_hosts_by_attribute(self):
        """get_hosts_by_attribute returns every host with the attribute value."""
        hostnames = ['test_host1', 'test_host2']
        for hostname in hostnames:
            new_host = models.Host.objects.create(hostname=hostname)
            new_host.set_attribute('test_attribute1', 'test_value1')

        matching = rpc_interface.get_hosts_by_attribute(
                'test_attribute1', 'test_value1')
        self.assertEqual(set(matching), set(hostnames))
    750 
    751 
    def test_get_host_attribute(self):
        """get_host_attribute reports the attribute for each requested host."""
        hostnames = ['test_host1', 'test_host2']
        for hostname in hostnames:
            new_host = models.Host.objects.create(hostname=hostname)
            new_host.set_attribute('test_attribute1', 'test_value1')

        attributes = rpc_interface.get_host_attribute(
                'test_attribute1', hostname__in=hostnames)

        reported_hosts = {attribute['host'] for attribute in attributes}
        reported_values = {attribute['value'] for attribute in attributes}
        self.assertEqual(reported_hosts, set(hostnames))
        self.assertEqual(reported_values, {'test_value1'})
    766 
    767 
    def test_get_hosts(self):
        """get_hosts lists all fixture hosts and full details for host1."""
        fixture_hostnames = [fixture.hostname for fixture in self.hosts]
        self._check_hostnames(rpc_interface.get_hosts(), fixture_hostnames)

        matches = rpc_interface.get_hosts(hostname='host1')
        self._check_hostnames(matches, ['host1'])

        host_dict = matches[0]
        self.assertEqual(sorted(host_dict['labels']),
                         ['label1', 'myplatform'])
        self.assertEqual(host_dict['platform'], 'myplatform')
        self.assertEqual(host_dict['acls'], ['my_acl'])
        self.assertEqual(host_dict['attributes'], {})
    779 
    780 
    def test_get_hosts_multiple_labels(self):
        """Filtering on 'myplatform' and 'label1' together yields only host1."""
        required_labels = ['myplatform', 'label1']
        matches = rpc_interface.get_hosts(multiple_labels=required_labels)
        self._check_hostnames(matches, ['host1'])
    785 
    786 
    def test_job_keyvals(self):
        """Keyvals given to create_job are returned unchanged by get_jobs."""
        expected_keyvals = {'mykey': 'myvalue'}
        new_job_id = rpc_interface.create_job(
                name='test',
                priority=priorities.Priority.DEFAULT,
                control_file='foo',
                control_type=CLIENT,
                hosts=['host1'],
                keyvals=expected_keyvals)

        fetched = rpc_interface.get_jobs(id=new_job_id)
        self.assertEqual(len(fetched), 1)
        self.assertEqual(fetched[0]['keyvals'], expected_keyvals)
    798 
    799 
    def test_get_jobs_summary(self):
        """status_counts tallies HQE statuses per job.

        One entry is left untouched, one is failed, and one is failed and
        aborted; the summary is expected to report one Queued and two Failed.
        """
        job = self._create_job(hosts=xrange(1, 4))
        hqes = list(job.hostqueueentry_set.all())

        failed_hqe = hqes[1]
        failed_hqe.status = _hqe_status.FAILED
        failed_hqe.save()

        aborted_hqe = hqes[2]
        aborted_hqe.status = _hqe_status.FAILED
        aborted_hqe.aborted = True
        aborted_hqe.save()

        # Mock up tko_rpc_interface.get_status_counts.
        self.god.stub_function_to_return(
                rpc_interface.tko_rpc_interface, 'get_status_counts', None)

        summaries = rpc_interface.get_jobs_summary(id=job.id)
        self.assertEqual(len(summaries), 1)
        expected_counts = {'Queued': 1, 'Failed': 2}
        self.assertEqual(summaries[0]['status_counts'], expected_counts)
    819 
    820 
    def _check_job_ids(self, actual_job_dicts, expected_jobs):
        """Assert the job dicts refer to exactly the expected job objects.

        Ids are compared as sets, so order and duplicates are ignored.

        @param actual_job_dicts: Iterable of dicts, each with an 'id' key.
        @param expected_jobs: Iterable of objects exposing an `id` attribute.
        """
        actual_ids = {job_dict['id'] for job_dict in actual_job_dicts}
        expected_ids = {job.id for job in expected_jobs}
        self.assertEqual(actual_ids, expected_ids)
    825 
    826 
    def test_get_jobs_status_filters(self):
        """not_yet_run / running / finished filters partition jobs by HQE state.

        Each job has two host queue entries; the combination of their
        statuses decides which filter the job is expected to match.
        """
        Status = models.HostQueueEntry.Status

        def make_job(*statuses):
            """Create a job on hosts 1 and 2, optionally setting both HQEs."""
            job = self._create_job(hosts=[1, 2])
            if statuses:
                first_status, second_status = statuses
                hqes = job.hostqueueentry_set.all()
                hqes[0].update_object(status=first_status)
                hqes[1].update_object(status=second_status)
            return job

        # Entries keep whatever status they were created with.
        queued = make_job()
        queued_and_running = make_job(Status.QUEUED, Status.RUNNING)
        running_and_complete = make_job(Status.RUNNING, Status.COMPLETED)
        complete = make_job(Status.COMPLETED, Status.COMPLETED)
        started_but_inactive = make_job(Status.QUEUED, Status.COMPLETED)
        parsing = make_job(Status.PARSING, Status.PARSING)

        expectations = (
            ({'not_yet_run': True}, [queued]),
            ({'running': True}, [queued_and_running, running_and_complete,
                                 started_but_inactive, parsing]),
            ({'finished': True}, [complete]),
        )
        for job_filter, expected_jobs in expectations:
            self._check_job_ids(rpc_interface.get_jobs(**job_filter),
                                expected_jobs)
    861 
    862 
    def test_get_jobs_type_filters(self):
        """suite / sub / standalone filters are exclusive and select by parent.

        Passing any two of the filters together trips an assertion in
        get_jobs; individually they select the parent job, the job that has
        a parent, and the job with neither role.
        """
        conflicting_filters = (
            {'suite': True, 'sub': True},
            {'suite': True, 'standalone': True},
            {'standalone': True, 'sub': True},
        )
        for job_filter in conflicting_filters:
            with self.assertRaises(AssertionError):
                rpc_interface.get_jobs(**job_filter)

        suite_job = self._create_job(hosts=[1])
        sub_job = self._create_job(hosts=[1, 2], parent_job_id=suite_job.id)
        lone_job = self._create_job(hosts=[1])

        self._check_job_ids(rpc_interface.get_jobs(suite=True), [suite_job])
        self._check_job_ids(rpc_interface.get_jobs(sub=True), [sub_job])
        self._check_job_ids(rpc_interface.get_jobs(standalone=True),
                            [lone_job])
    880 
    881 
    def _create_job_helper(self, **kwargs):
        """Create a default-priority server job named 'test' via the RPC.

        @param kwargs: Extra keyword arguments forwarded to
                rpc_interface.create_job (e.g. hosts, meta_hosts, hostless).

        @returns: Whatever rpc_interface.create_job returns.
        """
        return rpc_interface.create_job(
                name='test',
                priority=priorities.Priority.DEFAULT,
                control_file='control file',
                control_type=SERVER,
                **kwargs)
    887 
    888 
    def test_one_time_hosts(self):
        """A one-time host is created invalid, without labels or ACL groups."""
        self._create_job_helper(one_time_hosts=['testhost'])

        one_time_host = models.Host.objects.get(hostname='testhost')
        self.assertEqual(one_time_host.invalid, True)
        self.assertEqual(one_time_host.labels.count(), 0)
        self.assertEqual(one_time_host.aclgroup_set.count(), 0)
    895 
    896 
    def test_create_job_duplicate_hosts(self):
        """Listing the same host twice for a job raises ValidationError."""
        with self.assertRaises(model_logic.ValidationError):
            self._create_job_helper(hosts=[1, 1])
    900 
    901 
    def test_create_unrunnable_metahost_job(self):
        """A metahost job on the 'unused' label raises NoEligibleHostException."""
        with self.assertRaises(error.NoEligibleHostException):
            self._create_job_helper(meta_hosts=['unused'])
    905 
    906 
    def test_create_hostless_job(self):
        """A hostless job gets one queue entry with no host and no metahost."""
        hostless_job_id = self._create_job_helper(hostless=True)
        hostless_job = models.Job.objects.get(pk=hostless_job_id)

        hqes = hostless_job.hostqueueentry_set.all()
        self.assertEqual(len(hqes), 1)
        only_hqe = hqes[0]
        self.assertEqual(only_hqe.host, None)
        self.assertEqual(only_hqe.meta_host, None)
    914 
    915 
    def _setup_special_tasks(self):
        """Create two jobs on host 1 and three VERIFY tasks for that host.

        The tasks are stored on self as task1 (complete, started before
        job 1), task2 (active, tied to job 2's queue entry) and task3 (not
        yet run).
        """
        target_host = self.hosts[0]

        first_job = self._create_job(hosts=[1])
        second_job = self._create_job(hosts=[1])

        first_hqe = first_job.hostqueueentry_set.all()[0]
        first_hqe.update_object(started_on=datetime.datetime(2009, 1, 2),
                                execution_subdir='host1')
        second_hqe = second_job.hostqueueentry_set.all()[0]
        second_hqe.update_object(started_on=datetime.datetime(2009, 1, 3),
                                 execution_subdir='host1')

        def create_verify_task(**extra_fields):
            """Create a VERIFY task on target_host for the current user."""
            return models.SpecialTask.objects.create(
                    host=target_host,
                    task=models.SpecialTask.Task.VERIFY,
                    requested_by=models.User.current_user(),
                    **extra_fields)

        # Ran before job 1.
        self.task1 = create_verify_task(
                time_started=datetime.datetime(2009, 1, 1), is_complete=True)
        # Ran with job 2.
        self.task2 = create_verify_task(queue_entry=second_hqe,
                                        is_active=True)
        # Not yet run.
        self.task3 = create_verify_task()
    940 
    941 
    def test_get_special_tasks(self):
        """Tasks without a queue entry are returned; the first is complete."""
        self._setup_special_tasks()

        standalone_tasks = rpc_interface.get_special_tasks(
                host__hostname='host1', queue_entry__isnull=True)
        self.assertEqual(len(standalone_tasks), 2)

        first_task = standalone_tasks[0]
        self.assertEqual(first_task['task'], models.SpecialTask.Task.VERIFY)
        self.assertEqual(first_task['is_active'], False)
        self.assertEqual(first_task['is_complete'], True)
    950 
    951 
    def test_get_latest_special_task(self):
        """Fetch the most recently started VERIFY task for host1.

        This is a particular usage of get_special_tasks(): filter on started
        tasks, sort by descending start time and keep only one result.
        """
        self._setup_special_tasks()
        # Give task2 a start time later than task1's.
        self.task2.time_started = datetime.datetime(2009, 1, 2)
        self.task2.save()

        latest_query = {
            'host__hostname': 'host1',
            'task': models.SpecialTask.Task.VERIFY,
            'time_started__isnull': False,
            'sort_by': ['-time_started'],
            'query_limit': 1,
        }
        latest_tasks = rpc_interface.get_special_tasks(**latest_query)
        self.assertEqual(len(latest_tasks), 1)
        self.assertEqual(latest_tasks[0]['id'], 2)
    964 
    965 
    def _common_entry_check(self, entry_dict):
        """Assert that the entry dict belongs to host1 and to job id 2."""
        hostname = entry_dict['host']['hostname']
        self.assertEqual(hostname, 'host1')
        job_id = entry_dict['job']['id']
        self.assertEqual(job_id, 2)
    969 
    970 
    def test_get_host_queue_entries_and_special_tasks(self):
        """Queue entries and special tasks of a host come back interleaved."""
        self._setup_special_tasks()

        host_id = self.hosts[0].id
        combined = rpc_interface.get_host_queue_entries_and_special_tasks(
                host_id)

        expected_paths = [
            'hosts/host1/3-verify',
            '2-autotest_system/host1',
            'hosts/host1/2-verify',
            '1-autotest_system/host1',
            'hosts/host1/1-verify',
        ]
        self.assertEqual([item['execution_path'] for item in combined],
                         expected_paths)

        # Third item: the verify task that ran with job 2.
        verify_item = combined[2]
        self._common_entry_check(verify_item)
        self.assertEqual(verify_item['type'], 'Verify')
        self.assertEqual(verify_item['status'], 'Running')
        self.assertEqual(verify_item['execution_path'],
                         'hosts/host1/2-verify')

        # Second item: the queue entry of job 2 itself.
        job_item = combined[1]
        self._common_entry_check(job_item)
        self.assertEqual(job_item['type'], 'Job')
        self.assertEqual(job_item['status'], 'Queued')
        self.assertEqual(job_item['started_on'], '2009-01-03 00:00:00')
    996 
    997 
    def _create_hqes_and_start_time_index_entries(self):
        """Create three queued HQEs plus matching start-time index rows.

        HQE ids 1..3 are started on 2017-01-01..03 respectively, and one
        HostQueueEntryStartTimes row is stored per day whose highest_hqe_id
        is the HQE started on that day.
        """
        shard = models.Shard.objects.create(hostname='shard')
        job = self._create_job(shard=shard, control_file='foo')
        queued = models.HostQueueEntry.Status.QUEUED

        id_and_day = (
            (1, '2017-01-01'),
            (2, '2017-01-02'),
            (3, '2017-01-03'),
        )
        for hqe_id, day in id_and_day:
            models.HostQueueEntry(id=hqe_id, job=job, started_on=day,
                                  status=queued).save()

        # Index rows are saved newest first.
        for hqe_id, day in reversed(id_and_day):
            models.HostQueueEntryStartTimes(insert_time=day,
                                            highest_hqe_id=hqe_id).save()
   1019 
   1020     def test_get_host_queue_entries_by_insert_time(self):
   1021         """Check the insert_time_after and insert_time_before constraints."""
   1022         self._create_hqes_and_start_time_index_entries()
   1023         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1024             insert_time_after='2017-01-01')
   1025         self.assertEquals(len(hqes), 3)
   1026 
   1027         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1028             insert_time_after='2017-01-02')
   1029         self.assertEquals(len(hqes), 2)
   1030 
   1031         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1032             insert_time_after='2017-01-03')
   1033         self.assertEquals(len(hqes), 1)
   1034 
   1035         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1036             insert_time_before='2017-01-01')
   1037         self.assertEquals(len(hqes), 1)
   1038 
   1039         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1040             insert_time_before='2017-01-02')
   1041         self.assertEquals(len(hqes), 2)
   1042 
   1043         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1044             insert_time_before='2017-01-03')
   1045         self.assertEquals(len(hqes), 3)
   1046 
   1047 
   1048     def test_get_host_queue_entries_by_insert_time_with_missing_index_row(self):
   1049         """Shows that the constraints are approximate.
   1050 
   1051         The query may return rows which are actually outside of the bounds
   1052         given, if the index table does not have an entry for the specific time.
   1053         """
   1054         self._create_hqes_and_start_time_index_entries()
   1055         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1056             insert_time_before='2016-12-01')
   1057         self.assertEquals(len(hqes), 1)
   1058 
   1059     def test_get_hqe_by_insert_time_with_before_and_after(self):
   1060         self._create_hqes_and_start_time_index_entries()
   1061         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1062             insert_time_before='2017-01-02',
   1063             insert_time_after='2017-01-02')
   1064         self.assertEquals(len(hqes), 1)
   1065 
   1066     def test_get_hqe_by_insert_time_and_id_constraint(self):
   1067         self._create_hqes_and_start_time_index_entries()
   1068         # The time constraint is looser than the id constraint, so the time
   1069         # constraint should take precedence.
   1070         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1071             insert_time_before='2017-01-02',
   1072             id__lte=1)
   1073         self.assertEquals(len(hqes), 1)
   1074 
   1075         # Now make the time constraint tighter than the id constraint.
   1076         hqes = rpc_interface.get_host_queue_entries_by_insert_time(
   1077             insert_time_before='2017-01-01',
   1078             id__lte=42)
   1079         self.assertEquals(len(hqes), 1)
   1080 
   1081     def test_view_invalid_host(self):
   1082         # RPCs used by View Host page should work for invalid hosts
   1083         self._create_job_helper(hosts=[1])
   1084         host = self.hosts[0]
   1085         host.delete()
   1086 
   1087         self.assertEquals(1, rpc_interface.get_num_hosts(hostname='host1',
   1088                                                          valid_only=False))
   1089         data = rpc_interface.get_hosts(hostname='host1', valid_only=False)
   1090         self.assertEquals(1, len(data))
   1091 
   1092         self.assertEquals(1, rpc_interface.get_num_host_queue_entries(
   1093                 host__hostname='host1'))
   1094         data = rpc_interface.get_host_queue_entries(host__hostname='host1')
   1095         self.assertEquals(1, len(data))
   1096 
   1097         count = rpc_interface.get_num_host_queue_entries_and_special_tasks(
   1098                 host=host.id)
   1099         self.assertEquals(1, count)
   1100         data = rpc_interface.get_host_queue_entries_and_special_tasks(
   1101                 host=host.id)
   1102         self.assertEquals(1, len(data))
   1103 
   1104 
   1105     def test_reverify_hosts(self):
   1106         hostname_list = rpc_interface.reverify_hosts(id__in=[1, 2])
   1107         self.assertEquals(hostname_list, ['host1', 'host2'])
   1108         tasks = rpc_interface.get_special_tasks()
   1109         self.assertEquals(len(tasks), 2)
   1110         self.assertEquals(set(task['host']['id'] for task in tasks),
   1111                           set([1, 2]))
   1112 
   1113         task = tasks[0]
   1114         self.assertEquals(task['task'], models.SpecialTask.Task.VERIFY)
   1115         self.assertEquals(task['requested_by'], 'autotest_system')
   1116 
   1117 
   1118     def test_repair_hosts(self):
   1119         hostname_list = rpc_interface.repair_hosts(id__in=[1, 2])
   1120         self.assertEquals(hostname_list, ['host1', 'host2'])
   1121         tasks = rpc_interface.get_special_tasks()
   1122         self.assertEquals(len(tasks), 2)
   1123         self.assertEquals(set(task['host']['id'] for task in tasks),
   1124                           set([1, 2]))
   1125 
   1126         task = tasks[0]
   1127         self.assertEquals(task['task'], models.SpecialTask.Task.REPAIR)
   1128         self.assertEquals(task['requested_by'], 'autotest_system')
   1129 
   1130 
   1131     def _modify_host_helper(self, on_shard=False, host_on_shard=False):
   1132         shard_hostname = 'shard1'
   1133         if on_shard:
   1134             global_config.global_config.override_config_value(
   1135                 'SHARD', 'shard_hostname', shard_hostname)
   1136 
   1137         host = models.Host.objects.all()[0]
   1138         if host_on_shard:
   1139             shard = models.Shard.objects.create(hostname=shard_hostname)
   1140             host.shard = shard
   1141             host.save()
   1142 
   1143         self.assertFalse(host.locked)
   1144 
   1145         self.god.stub_class_method(frontend.AFE, 'run')
   1146 
   1147         if host_on_shard and not on_shard:
   1148             mock_afe = self.god.create_mock_class_obj(
   1149                     frontend_wrappers.RetryingAFE, 'MockAFE')
   1150             self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1151 
   1152             mock_afe2 = frontend_wrappers.RetryingAFE.expect_new(
   1153                     server=shard_hostname, user=None)
   1154             mock_afe2.run.expect_call('modify_host_local', id=host.id,
   1155                     locked=True, lock_reason='_modify_host_helper lock',
   1156                     lock_time=datetime.datetime(2015, 12, 15))
   1157         elif on_shard:
   1158             mock_afe = self.god.create_mock_class_obj(
   1159                     frontend_wrappers.RetryingAFE, 'MockAFE')
   1160             self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1161 
   1162             mock_afe2 = frontend_wrappers.RetryingAFE.expect_new(
   1163                     server=server_utils.get_global_afe_hostname(), user=None)
   1164             mock_afe2.run.expect_call('modify_host', id=host.id,
   1165                     locked=True, lock_reason='_modify_host_helper lock',
   1166                     lock_time=datetime.datetime(2015, 12, 15))
   1167 
   1168         rpc_interface.modify_host(id=host.id, locked=True,
   1169                                   lock_reason='_modify_host_helper lock',
   1170                                   lock_time=datetime.datetime(2015, 12, 15))
   1171 
   1172         host = models.Host.objects.get(pk=host.id)
   1173         if on_shard:
   1174             # modify_host on shard does nothing but routing the RPC to master.
   1175             self.assertFalse(host.locked)
   1176         else:
   1177             self.assertTrue(host.locked)
   1178         self.god.check_playback()
   1179 
   1180 
   1181     def test_modify_host_on_master_host_on_master(self):
   1182         """Call modify_host to master for host in master."""
   1183         self._modify_host_helper()
   1184 
   1185 
   1186     def test_modify_host_on_master_host_on_shard(self):
   1187         """Call modify_host to master for host in shard."""
   1188         self._modify_host_helper(host_on_shard=True)
   1189 
   1190 
   1191     def test_modify_host_on_shard(self):
   1192         """Call modify_host to shard for host in shard."""
   1193         self._modify_host_helper(on_shard=True, host_on_shard=True)
   1194 
   1195 
   1196     def test_modify_hosts_on_master_host_on_shard(self):
   1197         """Ensure calls to modify_hosts are correctly forwarded to shards."""
   1198         host1 = models.Host.objects.all()[0]
   1199         host2 = models.Host.objects.all()[1]
   1200 
   1201         shard1 = models.Shard.objects.create(hostname='shard1')
   1202         host1.shard = shard1
   1203         host1.save()
   1204 
   1205         shard2 = models.Shard.objects.create(hostname='shard2')
   1206         host2.shard = shard2
   1207         host2.save()
   1208 
   1209         self.assertFalse(host1.locked)
   1210         self.assertFalse(host2.locked)
   1211 
   1212         mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
   1213                                                   'MockAFE')
   1214         self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1215 
   1216         # The statuses of one host might differ on master and shard.
   1217         # Filters are always applied on the master. So the host on the shard
   1218         # will be affected no matter what his status is.
   1219         filters_to_use = {'status': 'Ready'}
   1220 
   1221         mock_afe2 = frontend_wrappers.RetryingAFE.expect_new(
   1222                 server='shard2', user=None)
   1223         mock_afe2.run.expect_call(
   1224             'modify_hosts_local',
   1225             host_filter_data={'id__in': [shard1.id, shard2.id]},
   1226             update_data={'locked': True,
   1227                          'lock_reason': 'Testing forward to shard',
   1228                          'lock_time' : datetime.datetime(2015, 12, 15) })
   1229 
   1230         mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
   1231                 server='shard1', user=None)
   1232         mock_afe1.run.expect_call(
   1233             'modify_hosts_local',
   1234             host_filter_data={'id__in': [shard1.id, shard2.id]},
   1235             update_data={'locked': True,
   1236                          'lock_reason': 'Testing forward to shard',
   1237                          'lock_time' : datetime.datetime(2015, 12, 15)})
   1238 
   1239         rpc_interface.modify_hosts(
   1240                 host_filter_data={'status': 'Ready'},
   1241                 update_data={'locked': True,
   1242                              'lock_reason': 'Testing forward to shard',
   1243                              'lock_time' : datetime.datetime(2015, 12, 15) })
   1244 
   1245         host1 = models.Host.objects.get(pk=host1.id)
   1246         self.assertTrue(host1.locked)
   1247         host2 = models.Host.objects.get(pk=host2.id)
   1248         self.assertTrue(host2.locked)
   1249         self.god.check_playback()
   1250 
   1251 
   1252     def test_delete_host(self):
   1253         """Ensure an RPC is made on delete a host, if it is on a shard."""
   1254         host1 = models.Host.objects.all()[0]
   1255         shard1 = models.Shard.objects.create(hostname='shard1')
   1256         host1.shard = shard1
   1257         host1.save()
   1258         host1_id = host1.id
   1259 
   1260         mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
   1261                                                  'MockAFE')
   1262         self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1263 
   1264         mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
   1265                 server='shard1', user=None)
   1266         mock_afe1.run.expect_call('delete_host', id=host1.id)
   1267 
   1268         rpc_interface.delete_host(id=host1.id)
   1269 
   1270         self.assertRaises(models.Host.DoesNotExist,
   1271                           models.Host.smart_get, host1_id)
   1272 
   1273         self.god.check_playback()
   1274 
   1275 
   1276     def test_delete_shard(self):
   1277         """Ensure the RPC can delete a shard."""
   1278         host1 = models.Host.objects.all()[0]
   1279         shard1 = models.Shard.objects.create(hostname='shard1')
   1280         host1.shard = shard1
   1281         host1.save()
   1282 
   1283         rpc_interface.delete_shard(hostname=shard1.hostname)
   1284 
   1285         host1 = models.Host.smart_get(host1.id)
   1286         self.assertIsNone(host1.shard)
   1287         self.assertRaises(models.Shard.DoesNotExist,
   1288                           models.Shard.smart_get, shard1.hostname)
   1289 
   1290 
   1291     def test_modify_label(self):
   1292         label1 = models.Label.objects.all()[0]
   1293         self.assertEqual(label1.invalid, 0)
   1294 
   1295         host2 = models.Host.objects.all()[1]
   1296         shard1 = models.Shard.objects.create(hostname='shard1')
   1297         host2.shard = shard1
   1298         host2.labels.add(label1)
   1299         host2.save()
   1300 
   1301         mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
   1302                                                   'MockAFE')
   1303         self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1304 
   1305         mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
   1306                 server='shard1', user=None)
   1307         mock_afe1.run.expect_call('modify_label', id=label1.id, invalid=1)
   1308 
   1309         rpc_interface.modify_label(label1.id, invalid=1)
   1310 
   1311         self.assertEqual(models.Label.objects.all()[0].invalid, 1)
   1312         self.god.check_playback()
   1313 
   1314 
   1315     def test_delete_label(self):
   1316         label1 = models.Label.objects.all()[0]
   1317 
   1318         host2 = models.Host.objects.all()[1]
   1319         shard1 = models.Shard.objects.create(hostname='shard1')
   1320         host2.shard = shard1
   1321         host2.labels.add(label1)
   1322         host2.save()
   1323 
   1324         mock_afe = self.god.create_mock_class_obj(frontend_wrappers.RetryingAFE,
   1325                                                   'MockAFE')
   1326         self.god.stub_with(frontend_wrappers, 'RetryingAFE', mock_afe)
   1327 
   1328         mock_afe1 = frontend_wrappers.RetryingAFE.expect_new(
   1329                 server='shard1', user=None)
   1330         mock_afe1.run.expect_call('delete_label', id=label1.id)
   1331 
   1332         rpc_interface.delete_label(id=label1.id)
   1333 
   1334         self.assertRaises(models.Label.DoesNotExist,
   1335                           models.Label.smart_get, label1.id)
   1336         self.god.check_playback()
   1337 
   1338 
   1339     def test_get_image_for_job_with_keyval_build(self):
   1340         keyval_dict = {'build': 'cool-image'}
   1341         job_id = rpc_interface.create_job(name='test',
   1342                                           priority=priorities.Priority.DEFAULT,
   1343                                           control_file='foo',
   1344                                           control_type=CLIENT,
   1345                                           hosts=['host1'],
   1346                                           keyvals=keyval_dict)
   1347         job = models.Job.objects.get(id=job_id)
   1348         self.assertIsNotNone(job)
   1349         image = rpc_interface._get_image_for_job(job, True)
   1350         self.assertEquals('cool-image', image)
   1351 
   1352 
   1353     def test_get_image_for_job_with_keyval_builds(self):
   1354         keyval_dict = {'builds': {'cros-version': 'cool-image'}}
   1355         job_id = rpc_interface.create_job(name='test',
   1356                                           priority=priorities.Priority.DEFAULT,
   1357                                           control_file='foo',
   1358                                           control_type=CLIENT,
   1359                                           hosts=['host1'],
   1360                                           keyvals=keyval_dict)
   1361         job = models.Job.objects.get(id=job_id)
   1362         self.assertIsNotNone(job)
   1363         image = rpc_interface._get_image_for_job(job, True)
   1364         self.assertEquals('cool-image', image)
   1365 
   1366 
   1367     def test_get_image_for_job_with_control_build(self):
   1368         CONTROL_FILE = """build='cool-image'
   1369         """
   1370         job_id = rpc_interface.create_job(name='test',
   1371                                           priority=priorities.Priority.DEFAULT,
   1372                                           control_file='foo',
   1373                                           control_type=CLIENT,
   1374                                           hosts=['host1'])
   1375         job = models.Job.objects.get(id=job_id)
   1376         self.assertIsNotNone(job)
   1377         job.control_file = CONTROL_FILE
   1378         image = rpc_interface._get_image_for_job(job, True)
   1379         self.assertEquals('cool-image', image)
   1380 
   1381 
   1382     def test_get_image_for_job_with_control_builds(self):
   1383         CONTROL_FILE = """builds={'cros-version': 'cool-image'}
   1384         """
   1385         job_id = rpc_interface.create_job(name='test',
   1386                                           priority=priorities.Priority.DEFAULT,
   1387                                           control_file='foo',
   1388                                           control_type=CLIENT,
   1389                                           hosts=['host1'])
   1390         job = models.Job.objects.get(id=job_id)
   1391         self.assertIsNotNone(job)
   1392         job.control_file = CONTROL_FILE
   1393         image = rpc_interface._get_image_for_job(job, True)
   1394         self.assertEquals('cool-image', image)
   1395 
   1396 
class ExtraRpcInterfaceTest(frontend_test_utils.FrontendTestMixin,
                            ShardHeartbeatTest):
    """Unit tests for functions originally in site_rpc_interface.py.

    @var _NAME: fake suite name.
    @var _BOARD: fake board to reimage.
    @var _BUILD: fake build with which to reimage.
    @var _BUILDS: fake builds dict mapping provision.CROS_VERSION_PREFIX to
                  _BUILD; passed as the builds argument of create_suite_job.
    @var _PRIORITY: fake priority with which to reimage.
    @var _TIMEOUT: fake timeout; _mockRpcUtils expects _TIMEOUT * 60 as both
                   timeout_mins and max_runtime_mins.
    """
    _NAME = 'name'
    _BOARD = 'link'
    _BUILD = 'link-release/R36-5812.0.0'
    _BUILDS = {provision.CROS_VERSION_PREFIX: _BUILD}
    _PRIORITY = priorities.Priority.DEFAULT
    _TIMEOUT = 24
   1412 
   1413 
   1414     def setUp(self):
   1415         super(ExtraRpcInterfaceTest, self).setUp()
   1416         self._SUITE_NAME = suite_common.canonicalize_suite_name(
   1417             self._NAME)
   1418         self.dev_server = self.mox.CreateMock(dev_server.ImageServer)
   1419         self._frontend_common_setup(fill_data=False)
   1420 
   1421 
    def tearDown(self):
        """Undo the common frontend setup performed in setUp.

        Only _frontend_common_teardown() is called here (presumably provided
        by FrontendTestMixin); the parent tearDown is not invoked.
        """
        self._frontend_common_teardown()
   1424 
   1425 
    def _setupDevserver(self):
        """Record that resolving _BUILD yields the mocked dev server.

        dev_server.ImageServer is stubbed out with mox class mocks, and a
        dev_server.resolve(self._BUILD) call returning self.dev_server is
        recorded.
        """
        self.mox.StubOutClassWithMocks(dev_server, 'ImageServer')
        # NOTE(review): resolve() itself is not stubbed in this method;
        # presumably the recording works through the stubbed ImageServer
        # class -- confirm against dev_server.resolve.
        dev_server.resolve(self._BUILD).AndReturn(self.dev_server)
   1429 
   1430 
   1431     def _mockDevServerGetter(self, get_control_file=True):
   1432         self._setupDevserver()
   1433         if get_control_file:
   1434           self.getter = self.mox.CreateMock(
   1435               control_file_getter.DevServerGetter)
   1436           self.mox.StubOutWithMock(control_file_getter.DevServerGetter,
   1437                                    'create')
   1438           control_file_getter.DevServerGetter.create(
   1439               mox.IgnoreArg(), mox.IgnoreArg()).AndReturn(self.getter)
   1440 
   1441 
   1442     def _mockRpcUtils(self, to_return, control_file_substring=''):
   1443         """Fake out the autotest rpc_utils module with a mockable class.
   1444 
   1445         @param to_return: the value that rpc_utils.create_job_common() should
   1446                           be mocked out to return.
   1447         @param control_file_substring: A substring that is expected to appear
   1448                                        in the control file output string that
   1449                                        is passed to create_job_common.
   1450                                        Default: ''
   1451         """
   1452         download_started_time = constants.DOWNLOAD_STARTED_TIME
   1453         payload_finished_time = constants.PAYLOAD_FINISHED_TIME
   1454         self.mox.StubOutWithMock(rpc_utils, 'create_job_common')
   1455         rpc_utils.create_job_common(mox.And(mox.StrContains(self._NAME),
   1456                                     mox.StrContains(self._BUILD)),
   1457                             priority=self._PRIORITY,
   1458                             timeout_mins=self._TIMEOUT*60,
   1459                             max_runtime_mins=self._TIMEOUT*60,
   1460                             control_type='Server',
   1461                             control_file=mox.And(mox.StrContains(self._BOARD),
   1462                                                  mox.StrContains(self._BUILD),
   1463                                                  mox.StrContains(
   1464                                                      control_file_substring)),
   1465                             hostless=True,
   1466                             keyvals=mox.And(mox.In(download_started_time),
   1467                                             mox.In(payload_finished_time))
   1468                             ).AndReturn(to_return)
   1469 
   1470 
   1471     def testStageBuildFail(self):
   1472         """Ensure that a failure to stage the desired build fails the RPC."""
   1473         self._setupDevserver()
   1474 
   1475         self.dev_server.hostname = 'mox_url'
   1476         self.dev_server.stage_artifacts(
   1477                 image=self._BUILD, artifacts=['test_suites']).AndRaise(
   1478                 dev_server.DevServerException())
   1479         self.mox.ReplayAll()
   1480         self.assertRaises(error.StageControlFileFailure,
   1481                           rpc_interface.create_suite_job,
   1482                           name=self._NAME,
   1483                           board=self._BOARD,
   1484                           builds=self._BUILDS,
   1485                           pool=None)
   1486 
   1487 
   1488     def testGetControlFileFail(self):
   1489         """Ensure that a failure to get needed control file fails the RPC."""
   1490         self._mockDevServerGetter()
   1491 
   1492         self.dev_server.hostname = 'mox_url'
   1493         self.dev_server.stage_artifacts(
   1494                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1495 
   1496         self.getter.get_control_file_contents_by_name(
   1497             self._SUITE_NAME).AndReturn(None)
   1498         self.mox.ReplayAll()
   1499         self.assertRaises(error.ControlFileEmpty,
   1500                           rpc_interface.create_suite_job,
   1501                           name=self._NAME,
   1502                           board=self._BOARD,
   1503                           builds=self._BUILDS,
   1504                           pool=None)
   1505 
   1506 
   1507     def testGetControlFileListFail(self):
   1508         """Ensure that a failure to get needed control file fails the RPC."""
   1509         self._mockDevServerGetter()
   1510 
   1511         self.dev_server.hostname = 'mox_url'
   1512         self.dev_server.stage_artifacts(
   1513                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1514 
   1515         self.getter.get_control_file_contents_by_name(
   1516             self._SUITE_NAME).AndRaise(error.NoControlFileList())
   1517         self.mox.ReplayAll()
   1518         self.assertRaises(error.NoControlFileList,
   1519                           rpc_interface.create_suite_job,
   1520                           name=self._NAME,
   1521                           board=self._BOARD,
   1522                           builds=self._BUILDS,
   1523                           pool=None)
   1524 
   1525 
   1526     def testCreateSuiteJobFail(self):
   1527         """Ensure that failure to schedule the suite job fails the RPC."""
   1528         self._mockDevServerGetter()
   1529 
   1530         self.dev_server.hostname = 'mox_url'
   1531         self.dev_server.stage_artifacts(
   1532                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1533 
   1534         self.getter.get_control_file_contents_by_name(
   1535             self._SUITE_NAME).AndReturn('f')
   1536 
   1537         self.dev_server.url().AndReturn('mox_url')
   1538         self._mockRpcUtils(-1)
   1539         self.mox.ReplayAll()
   1540         self.assertEquals(
   1541             rpc_interface.create_suite_job(name=self._NAME,
   1542                                            board=self._BOARD,
   1543                                            builds=self._BUILDS, pool=None),
   1544             -1)
   1545 
   1546 
   1547     def testCreateSuiteJobSuccess(self):
   1548         """Ensures that success results in a successful RPC."""
   1549         self._mockDevServerGetter()
   1550 
   1551         self.dev_server.hostname = 'mox_url'
   1552         self.dev_server.stage_artifacts(
   1553                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1554 
   1555         self.getter.get_control_file_contents_by_name(
   1556             self._SUITE_NAME).AndReturn('f')
   1557 
   1558         self.dev_server.url().AndReturn('mox_url')
   1559         job_id = 5
   1560         self._mockRpcUtils(job_id)
   1561         self.mox.ReplayAll()
   1562         self.assertEquals(
   1563             rpc_interface.create_suite_job(name=self._NAME,
   1564                                            board=self._BOARD,
   1565                                            builds=self._BUILDS,
   1566                                            pool=None),
   1567             job_id)
   1568 
   1569 
   1570     def testCreateSuiteJobNoHostCheckSuccess(self):
   1571         """Ensures that success results in a successful RPC."""
   1572         self._mockDevServerGetter()
   1573 
   1574         self.dev_server.hostname = 'mox_url'
   1575         self.dev_server.stage_artifacts(
   1576                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1577 
   1578         self.getter.get_control_file_contents_by_name(
   1579             self._SUITE_NAME).AndReturn('f')
   1580 
   1581         self.dev_server.url().AndReturn('mox_url')
   1582         job_id = 5
   1583         self._mockRpcUtils(job_id)
   1584         self.mox.ReplayAll()
   1585         self.assertEquals(
   1586           rpc_interface.create_suite_job(name=self._NAME,
   1587                                          board=self._BOARD,
   1588                                          builds=self._BUILDS,
   1589                                          pool=None, check_hosts=False),
   1590           job_id)
   1591 
   1592 
   1593     def testCreateSuiteJobControlFileSupplied(self):
   1594         """Ensure we can supply the control file to create_suite_job."""
   1595         self._mockDevServerGetter(get_control_file=False)
   1596 
   1597         self.dev_server.hostname = 'mox_url'
   1598         self.dev_server.stage_artifacts(
   1599                 image=self._BUILD, artifacts=['test_suites']).AndReturn(True)
   1600         self.dev_server.url().AndReturn('mox_url')
   1601         job_id = 5
   1602         self._mockRpcUtils(job_id)
   1603         self.mox.ReplayAll()
   1604         self.assertEquals(
   1605             rpc_interface.create_suite_job(name='%s/%s' % (self._NAME,
   1606                                                            self._BUILD),
   1607                                            board=None,
   1608                                            builds=self._BUILDS,
   1609                                            pool=None,
   1610                                            control_file='CONTROL FILE'),
   1611             job_id)
   1612 
   1613 
   1614     def _get_records_for_sending_to_master(self):
   1615         return [{'control_file': 'foo',
   1616                  'control_type': 1,
   1617                  'created_on': datetime.datetime(2014, 8, 21),
   1618                  'drone_set': None,
   1619                  'email_list': '',
   1620                  'max_runtime_hrs': 72,
   1621                  'max_runtime_mins': 1440,
   1622                  'name': 'dummy',
   1623                  'owner': 'autotest_system',
   1624                  'parse_failed_repair': True,
   1625                  'priority': 40,
   1626                  'reboot_after': 0,
   1627                  'reboot_before': 1,
   1628                  'run_reset': True,
   1629                  'run_verify': False,
   1630                  'synch_count': 0,
   1631                  'test_retry': 0,
   1632                  'timeout': 24,
   1633                  'timeout_mins': 1440,
   1634                  'id': 1
   1635                  }], [{
   1636                     'aborted': False,
   1637                     'active': False,
   1638                     'complete': False,
   1639                     'deleted': False,
   1640                     'execution_subdir': '',
   1641                     'finished_on': None,
   1642                     'started_on': None,
   1643                     'status': 'Queued',
   1644                     'id': 1
   1645                 }]
   1646 
   1647 
   1648     def _send_records_to_master_helper(
   1649         self, jobs, hqes, shard_hostname='host1',
   1650         exception_to_throw=error.UnallowedRecordsSentToMaster, aborted=False):
   1651         job_id = rpc_interface.create_job(
   1652                 name='dummy',
   1653                 priority=self._PRIORITY,
   1654                 control_file='foo',
   1655                 control_type=SERVER,
   1656                 hostless=True)
   1657         job = models.Job.objects.get(pk=job_id)
   1658         shard = models.Shard.objects.create(hostname='host1')
   1659         job.shard = shard
   1660         job.save()
   1661 
   1662         if aborted:
   1663             job.hostqueueentry_set.update(aborted=True)
   1664             job.shard = None
   1665             job.save()
   1666 
   1667         hqe = job.hostqueueentry_set.all()[0]
   1668         if not exception_to_throw:
   1669             self._do_heartbeat_and_assert_response(
   1670                 shard_hostname=shard_hostname,
   1671                 upload_jobs=jobs, upload_hqes=hqes)
   1672         else:
   1673             self.assertRaises(
   1674                 exception_to_throw,
   1675                 self._do_heartbeat_and_assert_response,
   1676                 shard_hostname=shard_hostname,
   1677                 upload_jobs=jobs, upload_hqes=hqes)
   1678 
   1679 
   1680     def testSendingRecordsToMaster(self):
   1681         """Send records to the master and ensure they are persisted."""
   1682         jobs, hqes = self._get_records_for_sending_to_master()
   1683         hqes[0]['status'] = 'Completed'
   1684         self._send_records_to_master_helper(
   1685             jobs=jobs, hqes=hqes, exception_to_throw=None)
   1686 
   1687         # Check the entry was actually written to db
   1688         self.assertEqual(models.HostQueueEntry.objects.all()[0].status,
   1689                          'Completed')
   1690 
   1691 
   1692     def testSendingRecordsToMasterAbortedOnMaster(self):
   1693         """Send records to the master and ensure they are persisted."""
   1694         jobs, hqes = self._get_records_for_sending_to_master()
   1695         hqes[0]['status'] = 'Completed'
   1696         self._send_records_to_master_helper(
   1697             jobs=jobs, hqes=hqes, exception_to_throw=None, aborted=True)
   1698 
   1699         # Check the entry was actually written to db
   1700         self.assertEqual(models.HostQueueEntry.objects.all()[0].status,
   1701                          'Completed')
   1702 
   1703 
   1704     def testSendingRecordsToMasterJobAssignedToDifferentShard(self):
   1705         """Ensure records belonging to different shard are silently rejected."""
   1706         shard1 = models.Shard.objects.create(hostname='shard1')
   1707         shard2 = models.Shard.objects.create(hostname='shard2')
   1708         job1 = self._create_job(shard=shard1, control_file='foo1')
   1709         job2 = self._create_job(shard=shard2, control_file='foo2')
   1710         job1_id = job1.id
   1711         job2_id = job2.id
   1712         hqe1 = models.HostQueueEntry.objects.create(job=job1)
   1713         hqe2 = models.HostQueueEntry.objects.create(job=job2)
   1714         hqe1_id = hqe1.id
   1715         hqe2_id = hqe2.id
   1716         job1_record = job1.serialize(include_dependencies=False)
   1717         job2_record = job2.serialize(include_dependencies=False)
   1718         hqe1_record = hqe1.serialize(include_dependencies=False)
   1719         hqe2_record = hqe2.serialize(include_dependencies=False)
   1720 
   1721         # Prepare a bogus job record update from the wrong shard. The update
   1722         # should not throw an exception. Non-bogus jobs in the same update
   1723         # should happily update.
   1724         job1_record.update({'control_file': 'bar1'})
   1725         job2_record.update({'control_file': 'bar2'})
   1726         hqe1_record.update({'status': 'Aborted'})
   1727         hqe2_record.update({'status': 'Aborted'})
   1728         self._do_heartbeat_and_assert_response(
   1729             shard_hostname='shard2', upload_jobs=[job1_record, job2_record],
   1730             upload_hqes=[hqe1_record, hqe2_record])
   1731 
   1732         # Job and HQE record for wrong job should not be modified, because the
   1733         # rpc came from the wrong shard. Job and HQE record for valid job are
   1734         # modified.
   1735         self.assertEqual(models.Job.objects.get(id=job1_id).control_file,
   1736                          'foo1')
   1737         self.assertEqual(models.Job.objects.get(id=job2_id).control_file,
   1738                          'bar2')
   1739         self.assertEqual(models.HostQueueEntry.objects.get(id=hqe1_id).status,
   1740                          '')
   1741         self.assertEqual(models.HostQueueEntry.objects.get(id=hqe2_id).status,
   1742                          'Aborted')
   1743 
   1744 
   1745     def testSendingRecordsToMasterNotExistingJob(self):
   1746         """Ensure update for non existing job gets rejected."""
   1747         jobs, hqes = self._get_records_for_sending_to_master()
   1748         jobs[0]['id'] = 3
   1749 
   1750         self._send_records_to_master_helper(
   1751             jobs=jobs, hqes=hqes)
   1752 
   1753 
   1754     def _createShardAndHostWithLabel(self, shard_hostname='shard1',
   1755                                      host_hostname='host1',
   1756                                      label_name='board:lumpy'):
   1757         """Create a label, host, shard, and assign host to shard."""
   1758         try:
   1759             label = models.Label.objects.create(name=label_name)
   1760         except:
   1761             label = models.Label.smart_get(label_name)
   1762 
   1763         shard = models.Shard.objects.create(hostname=shard_hostname)
   1764         shard.labels.add(label)
   1765 
   1766         host = models.Host.objects.create(hostname=host_hostname, leased=False,
   1767                                           shard=shard)
   1768         host.labels.add(label)
   1769 
   1770         return shard, host, label
   1771 
   1772 
   1773     def testShardLabelRemovalInvalid(self):
   1774         """Ensure you cannot remove the wrong label from shard."""
   1775         shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
   1776         stumpy_label = models.Label.objects.create(
   1777                 name='board:stumpy', platform=True)
   1778         with self.assertRaises(error.RPCException):
   1779             rpc_interface.remove_board_from_shard(
   1780                     shard1.hostname, stumpy_label.name)
   1781 
   1782 
   1783     def testShardHeartbeatLabelRemoval(self):
   1784         """Ensure label removal from shard works."""
   1785         shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
   1786 
   1787         self.assertEqual(host1.shard, shard1)
   1788         self.assertItemsEqual(shard1.labels.all(), [lumpy_label])
   1789         rpc_interface.remove_board_from_shard(
   1790                 shard1.hostname, lumpy_label.name)
   1791         host1 = models.Host.smart_get(host1.id)
   1792         shard1 = models.Shard.smart_get(shard1.id)
   1793         self.assertEqual(host1.shard, None)
   1794         self.assertItemsEqual(shard1.labels.all(), [])
   1795 
   1796 
   1797     def testCreateListShard(self):
   1798         """Retrieve a list of all shards."""
   1799         lumpy_label = models.Label.objects.create(name='board:lumpy',
   1800                                                   platform=True)
   1801         stumpy_label = models.Label.objects.create(name='board:stumpy',
   1802                                                   platform=True)
   1803         peppy_label = models.Label.objects.create(name='board:peppy',
   1804                                                   platform=True)
   1805 
   1806         shard_id = rpc_interface.add_shard(
   1807             hostname='host1', labels='board:lumpy,board:stumpy')
   1808         self.assertRaises(error.RPCException,
   1809                           rpc_interface.add_shard,
   1810                           hostname='host1', labels='board:lumpy,board:stumpy')
   1811         self.assertRaises(model_logic.ValidationError,
   1812                           rpc_interface.add_shard,
   1813                           hostname='host1', labels='board:peppy')
   1814         shard = models.Shard.objects.get(pk=shard_id)
   1815         self.assertEqual(shard.hostname, 'host1')
   1816         self.assertEqual(shard.labels.values_list('pk')[0], (lumpy_label.id,))
   1817         self.assertEqual(shard.labels.values_list('pk')[1], (stumpy_label.id,))
   1818 
   1819         self.assertEqual(rpc_interface.get_shards(),
   1820                          [{'labels': ['board:lumpy','board:stumpy'],
   1821                            'hostname': 'host1',
   1822                            'id': 1}])
   1823 
   1824 
   1825     def testAddBoardsToShard(self):
   1826         """Add boards to a given shard."""
   1827         shard1, host1, lumpy_label = self._createShardAndHostWithLabel()
   1828         stumpy_label = models.Label.objects.create(name='board:stumpy',
   1829                                                    platform=True)
   1830         shard_id = rpc_interface.add_board_to_shard(
   1831             hostname='shard1', labels='board:stumpy')
   1832         # Test whether raise exception when board label does not exist.
   1833         self.assertRaises(models.Label.DoesNotExist,
   1834                           rpc_interface.add_board_to_shard,
   1835                           hostname='shard1', labels='board:test')
   1836         # Test whether raise exception when board already sharded.
   1837         self.assertRaises(error.RPCException,
   1838                           rpc_interface.add_board_to_shard,
   1839                           hostname='shard1', labels='board:lumpy')
   1840         shard = models.Shard.objects.get(pk=shard_id)
   1841         self.assertEqual(shard.hostname, 'shard1')
   1842         self.assertEqual(shard.labels.values_list('pk')[0], (lumpy_label.id,))
   1843         self.assertEqual(shard.labels.values_list('pk')[1], (stumpy_label.id,))
   1844 
   1845         self.assertEqual(rpc_interface.get_shards(),
   1846                          [{'labels': ['board:lumpy','board:stumpy'],
   1847                            'hostname': 'shard1',
   1848                            'id': 1}])
   1849 
   1850 
   1851     def testShardHeartbeatFetchHostlessJob(self):
   1852         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1853         self._testShardHeartbeatFetchHostlessJobHelper(host1)
   1854 
   1855 
   1856     def testShardHeartbeatIncorrectHosts(self):
   1857         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1858         self._testShardHeartbeatIncorrectHostsHelper(host1)
   1859 
   1860 
   1861     def testShardHeartbeatLabelRemovalRace(self):
   1862         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1863         self._testShardHeartbeatLabelRemovalRaceHelper(shard1, host1, label1)
   1864 
   1865 
   1866     def testShardRetrieveJobs(self):
   1867         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1868         shard2, host2, label2 = self._createShardAndHostWithLabel(
   1869                 'shard2', 'host2', 'board:grumpy')
   1870         self._testShardRetrieveJobsHelper(shard1, host1, label1,
   1871                                           shard2, host2, label2)
   1872 
   1873 
   1874     def testResendJobsAfterFailedHeartbeat(self):
   1875         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1876         self._testResendJobsAfterFailedHeartbeatHelper(shard1, host1, label1)
   1877 
   1878 
   1879     def testResendHostsAfterFailedHeartbeat(self):
   1880         shard1, host1, label1 = self._createShardAndHostWithLabel()
   1881         self._testResendHostsAfterFailedHeartbeatHelper(host1)
   1882 
   1883 
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
   1886