Home | History | Annotate | Download | only in afe
      1 # pylint: disable=missing-docstring
      2 """
      3 Utility functions for rpc_interface.py.  We keep them in a separate file so that
      4 only RPC interface functions go into that file.
      5 """
      6 
      7 __author__ = 'showard (at] google.com (Steve Howard)'
      8 
      9 import collections
     10 import datetime
     11 from functools import wraps
     12 import inspect
     13 import os
     14 import sys
     15 import django.db.utils
     16 import django.http
     17 
     18 from autotest_lib.frontend import thread_local
     19 from autotest_lib.frontend.afe import models, model_logic
     20 from autotest_lib.client.common_lib import control_data, error
     21 from autotest_lib.client.common_lib import global_config
     22 from autotest_lib.client.common_lib import time_utils
     23 from autotest_lib.client.common_lib.cros import dev_server
     24 # TODO(akeshet): Replace with monarch once we know how to instrument rpc server
     25 # with ts_mon.
     26 from autotest_lib.client.common_lib.cros.graphite import autotest_stats
     27 from autotest_lib.server import utils as server_utils
     28 from autotest_lib.server.cros import provision
     29 from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
     30 
     31 NULL_DATETIME = datetime.datetime.max
     32 NULL_DATE = datetime.date.max
     33 DUPLICATE_KEY_MSG = 'Duplicate entry'
     34 
     35 def prepare_for_serialization(objects):
     36     """
     37     Prepare Python objects to be returned via RPC.
     38     @param objects: objects to be prepared.
     39     """
     40     if (isinstance(objects, list) and len(objects) and
     41         isinstance(objects[0], dict) and 'id' in objects[0]):
     42         objects = _gather_unique_dicts(objects)
     43     return _prepare_data(objects)
     44 
     45 
     46 def prepare_rows_as_nested_dicts(query, nested_dict_column_names):
     47     """
     48     Prepare a Django query to be returned via RPC as a sequence of nested
     49     dictionaries.
     50 
     51     @param query - A Django model query object with a select_related() method.
     52     @param nested_dict_column_names - A list of column/attribute names for the
     53             rows returned by query to expand into nested dictionaries using
     54             their get_object_dict() method when not None.
     55 
     56     @returns An list suitable to returned in an RPC.
     57     """
     58     all_dicts = []
     59     for row in query.select_related():
     60         row_dict = row.get_object_dict()
     61         for column in nested_dict_column_names:
     62             if row_dict[column] is not None:
     63                 row_dict[column] = getattr(row, column).get_object_dict()
     64         all_dicts.append(row_dict)
     65     return prepare_for_serialization(all_dicts)
     66 
     67 
     68 def _prepare_data(data):
     69     """
     70     Recursively process data structures, performing necessary type
     71     conversions to values in data to allow for RPC serialization:
     72     -convert datetimes to strings
     73     -convert tuples and sets to lists
     74     """
     75     if isinstance(data, dict):
     76         new_data = {}
     77         for key, value in data.iteritems():
     78             new_data[key] = _prepare_data(value)
     79         return new_data
     80     elif (isinstance(data, list) or isinstance(data, tuple) or
     81           isinstance(data, set)):
     82         return [_prepare_data(item) for item in data]
     83     elif isinstance(data, datetime.date):
     84         if data is NULL_DATETIME or data is NULL_DATE:
     85             return None
     86         return str(data)
     87     else:
     88         return data
     89 
     90 
     91 def fetchall_as_list_of_dicts(cursor):
     92     """
     93     Converts each row in the cursor to a dictionary so that values can be read
     94     by using the column name.
     95     @param cursor: The database cursor to read from.
     96     @returns: A list of each row in the cursor as a dictionary.
     97     """
     98     desc = cursor.description
     99     return [ dict(zip([col[0] for col in desc], row))
    100              for row in cursor.fetchall() ]
    101 
    102 
    103 def raw_http_response(response_data, content_type=None):
    104     response = django.http.HttpResponse(response_data, mimetype=content_type)
    105     response['Content-length'] = str(len(response.content))
    106     return response
    107 
    108 
    109 def _gather_unique_dicts(dict_iterable):
    110     """\
    111     Pick out unique objects (by ID) from an iterable of object dicts.
    112     """
    113     objects = collections.OrderedDict()
    114     for obj in dict_iterable:
    115         objects.setdefault(obj['id'], obj)
    116     return objects.values()
    117 
    118 
    119 def extra_job_status_filters(not_yet_run=False, running=False, finished=False):
    120     """\
    121     Generate a SQL WHERE clause for job status filtering, and return it in
    122     a dict of keyword args to pass to query.extra().
    123     * not_yet_run: all HQEs are Queued
    124     * finished: all HQEs are complete
    125     * running: everything else
    126     """
    127     if not (not_yet_run or running or finished):
    128         return {}
    129     not_queued = ('(SELECT job_id FROM afe_host_queue_entries '
    130                   'WHERE status != "%s")'
    131                   % models.HostQueueEntry.Status.QUEUED)
    132     not_finished = ('(SELECT job_id FROM afe_host_queue_entries '
    133                     'WHERE not complete)')
    134 
    135     where = []
    136     if not_yet_run:
    137         where.append('id NOT IN ' + not_queued)
    138     if running:
    139         where.append('(id IN %s) AND (id IN %s)' % (not_queued, not_finished))
    140     if finished:
    141         where.append('id NOT IN ' + not_finished)
    142     return {'where': [' OR '.join(['(%s)' % x for x in where])]}
    143 
    144 
    145 def extra_job_type_filters(extra_args, suite=False,
    146                            sub=False, standalone=False):
    147     """\
    148     Generate a SQL WHERE clause for job status filtering, and return it in
    149     a dict of keyword args to pass to query.extra().
    150 
    151     param extra_args: a dict of existing extra_args.
    152 
    153     No more than one of the parameters should be passed as True:
    154     * suite: job which is parent of other jobs
    155     * sub: job with a parent job
    156     * standalone: job with no child or parent jobs
    157     """
    158     assert not ((suite and sub) or
    159                 (suite and standalone) or
    160                 (sub and standalone)), ('Cannot specify more than one '
    161                                         'filter to this function')
    162 
    163     where = extra_args.get('where', [])
    164     parent_job_id = ('DISTINCT parent_job_id')
    165     child_job_id = ('id')
    166     filter_common = ('(SELECT %s FROM afe_jobs '
    167                      'WHERE parent_job_id IS NOT NULL)')
    168 
    169     if suite:
    170         where.append('id IN ' + filter_common % parent_job_id)
    171     elif sub:
    172         where.append('id IN ' + filter_common % child_job_id)
    173     elif standalone:
    174         where.append('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query '
    175                      'WHERE parent_job_id IS NOT NULL'
    176                      ' AND (sub_query.parent_job_id=afe_jobs.id'
    177                      ' OR sub_query.id=afe_jobs.id))')
    178     else:
    179         return extra_args
    180 
    181     extra_args['where'] = where
    182     return extra_args
    183 
    184 
    185 
    186 def extra_host_filters(multiple_labels=()):
    187     """\
    188     Generate SQL WHERE clauses for matching hosts in an intersection of
    189     labels.
    190     """
    191     extra_args = {}
    192     where_str = ('afe_hosts.id in (select host_id from afe_hosts_labels '
    193                  'where label_id=%s)')
    194     extra_args['where'] = [where_str] * len(multiple_labels)
    195     extra_args['params'] = [models.Label.smart_get(label).id
    196                             for label in multiple_labels]
    197     return extra_args
    198 
    199 
    200 def get_host_query(multiple_labels, exclude_only_if_needed_labels,
    201                    valid_only, filter_data):
    202     if valid_only:
    203         query = models.Host.valid_objects.all()
    204     else:
    205         query = models.Host.objects.all()
    206 
    207     if exclude_only_if_needed_labels:
    208         only_if_needed_labels = models.Label.valid_objects.filter(
    209             only_if_needed=True)
    210         if only_if_needed_labels.count() > 0:
    211             only_if_needed_ids = ','.join(
    212                     str(label['id'])
    213                     for label in only_if_needed_labels.values('id'))
    214             query = models.Host.objects.add_join(
    215                 query, 'afe_hosts_labels', join_key='host_id',
    216                 join_condition=('afe_hosts_labels_exclude_OIN.label_id IN (%s)'
    217                                 % only_if_needed_ids),
    218                 suffix='_exclude_OIN', exclude=True)
    219     try:
    220         assert 'extra_args' not in filter_data
    221         filter_data['extra_args'] = extra_host_filters(multiple_labels)
    222         return models.Host.query_objects(filter_data, initial_query=query)
    223     except models.Label.DoesNotExist:
    224         return models.Host.objects.none()
    225 
    226 
    227 class InconsistencyException(Exception):
    228     'Raised when a list of objects does not have a consistent value'
    229 
    230 
    231 def get_consistent_value(objects, field):
    232     if not objects:
    233         # well a list of nothing is consistent
    234         return None
    235 
    236     value = getattr(objects[0], field)
    237     for obj in objects:
    238         this_value = getattr(obj, field)
    239         if this_value != value:
    240             raise InconsistencyException(objects[0], obj)
    241     return value
    242 
    243 
    244 def afe_test_dict_to_test_object(test_dict):
    245     if not isinstance(test_dict, dict):
    246         return test_dict
    247 
    248     numerized_dict = {}
    249     for key, value in test_dict.iteritems():
    250         try:
    251             numerized_dict[key] = int(value)
    252         except (ValueError, TypeError):
    253             numerized_dict[key] = value
    254 
    255     return type('TestObject', (object,), numerized_dict)
    256 
    257 
    258 def _check_is_server_test(test_type):
    259     """Checks if the test type is a server test.
    260 
    261     @param test_type The test type in enum integer or string.
    262 
    263     @returns A boolean to identify if the test type is server test.
    264     """
    265     if test_type is not None:
    266         if isinstance(test_type, basestring):
    267             try:
    268                 test_type = control_data.CONTROL_TYPE.get_value(test_type)
    269             except AttributeError:
    270                 return False
    271         return (test_type == control_data.CONTROL_TYPE.SERVER)
    272     return False
    273 
    274 
    275 def prepare_generate_control_file(tests, profilers, db_tests=True):
    276     if db_tests:
    277         test_objects = [models.Test.smart_get(test) for test in tests]
    278     else:
    279         test_objects = [afe_test_dict_to_test_object(test) for test in tests]
    280 
    281     profiler_objects = [models.Profiler.smart_get(profiler)
    282                         for profiler in profilers]
    283     # ensure tests are all the same type
    284     try:
    285         test_type = get_consistent_value(test_objects, 'test_type')
    286     except InconsistencyException, exc:
    287         test1, test2 = exc.args
    288         raise model_logic.ValidationError(
    289             {'tests' : 'You cannot run both test_suites and server-side '
    290              'tests together (tests %s and %s differ' % (
    291             test1.name, test2.name)})
    292 
    293     is_server = _check_is_server_test(test_type)
    294     if test_objects:
    295         synch_count = max(test.sync_count for test in test_objects)
    296     else:
    297         synch_count = 1
    298 
    299     if db_tests:
    300         dependencies = set(label.name for label
    301                            in models.Label.objects.filter(test__in=test_objects))
    302     else:
    303         dependencies = reduce(
    304                 set.union, [set(test.dependencies) for test in test_objects])
    305 
    306     cf_info = dict(is_server=is_server, synch_count=synch_count,
    307                    dependencies=list(dependencies))
    308     return cf_info, test_objects, profiler_objects
    309 
    310 
    311 def check_job_dependencies(host_objects, job_dependencies):
    312     """
    313     Check that a set of machines satisfies a job's dependencies.
    314     host_objects: list of models.Host objects
    315     job_dependencies: list of names of labels
    316     """
    317     # check that hosts satisfy dependencies
    318     host_ids = [host.id for host in host_objects]
    319     hosts_in_job = models.Host.objects.filter(id__in=host_ids)
    320     ok_hosts = hosts_in_job
    321     for index, dependency in enumerate(job_dependencies):
    322         if not provision.is_for_special_action(dependency):
    323             ok_hosts = ok_hosts.filter(labels__name=dependency)
    324     failing_hosts = (set(host.hostname for host in host_objects) -
    325                      set(host.hostname for host in ok_hosts))
    326     if failing_hosts:
    327         raise model_logic.ValidationError(
    328             {'hosts' : 'Host(s) failed to meet job dependencies (' +
    329                        (', '.join(job_dependencies)) + '): ' +
    330                        (', '.join(failing_hosts))})
    331 
    332 
    333 def check_job_metahost_dependencies(metahost_objects, job_dependencies):
    334     """
    335     Check that at least one machine within the metahost spec satisfies the job's
    336     dependencies.
    337 
    338     @param metahost_objects A list of label objects representing the metahosts.
    339     @param job_dependencies A list of strings of the required label names.
    340     @raises NoEligibleHostException If a metahost cannot run the job.
    341     """
    342     for metahost in metahost_objects:
    343         hosts = models.Host.objects.filter(labels=metahost)
    344         for label_name in job_dependencies:
    345             if not provision.is_for_special_action(label_name):
    346                 hosts = hosts.filter(labels__name=label_name)
    347         if not any(hosts):
    348             raise error.NoEligibleHostException("No hosts within %s satisfy %s."
    349                     % (metahost.name, ', '.join(job_dependencies)))
    350 
    351 
    352 def _execution_key_for(host_queue_entry):
    353     return (host_queue_entry.job.id, host_queue_entry.execution_subdir)
    354 
    355 
    356 def check_abort_synchronous_jobs(host_queue_entries):
    357     # ensure user isn't aborting part of a synchronous autoserv execution
    358     count_per_execution = {}
    359     for queue_entry in host_queue_entries:
    360         key = _execution_key_for(queue_entry)
    361         count_per_execution.setdefault(key, 0)
    362         count_per_execution[key] += 1
    363 
    364     for queue_entry in host_queue_entries:
    365         if not queue_entry.execution_subdir:
    366             continue
    367         execution_count = count_per_execution[_execution_key_for(queue_entry)]
    368         if execution_count < queue_entry.job.synch_count:
    369             raise model_logic.ValidationError(
    370                 {'' : 'You cannot abort part of a synchronous job execution '
    371                       '(%d/%s), %d included, %d expected'
    372                       % (queue_entry.job.id, queue_entry.execution_subdir,
    373                          execution_count, queue_entry.job.synch_count)})
    374 
    375 
    376 def check_modify_host(update_data):
    377     """
    378     Sanity check modify_host* requests.
    379 
    380     @param update_data: A dictionary with the changes to make to a host
    381             or hosts.
    382     """
    383     # Only the scheduler (monitor_db) is allowed to modify Host status.
    384     # Otherwise race conditions happen as a hosts state is changed out from
    385     # beneath tasks being run on a host.
    386     if 'status' in update_data:
    387         raise model_logic.ValidationError({
    388                 'status': 'Host status can not be modified by the frontend.'})
    389 
    390 
    391 def check_modify_host_locking(host, update_data):
    392     """
    393     Checks when locking/unlocking has been requested if the host is already
    394     locked/unlocked.
    395 
    396     @param host: models.Host object to be modified
    397     @param update_data: A dictionary with the changes to make to the host.
    398     """
    399     locked = update_data.get('locked', None)
    400     lock_reason = update_data.get('lock_reason', None)
    401     if locked is not None:
    402         if locked and host.locked:
    403             raise model_logic.ValidationError({
    404                     'locked': 'Host %s already locked by %s on %s.' %
    405                     (host.hostname, host.locked_by, host.lock_time)})
    406         if not locked and not host.locked:
    407             raise model_logic.ValidationError({
    408                     'locked': 'Host %s already unlocked.' % host.hostname})
    409         if locked and not lock_reason and not host.locked:
    410             raise model_logic.ValidationError({
    411                     'locked': 'Please provide a reason for locking Host %s' %
    412                     host.hostname})
    413 
    414 
    415 def get_motd():
    416     dirname = os.path.dirname(__file__)
    417     filename = os.path.join(dirname, "..", "..", "motd.txt")
    418     text = ''
    419     try:
    420         fp = open(filename, "r")
    421         try:
    422             text = fp.read()
    423         finally:
    424             fp.close()
    425     except:
    426         pass
    427 
    428     return text
    429 
    430 
    431 def _get_metahost_counts(metahost_objects):
    432     metahost_counts = {}
    433     for metahost in metahost_objects:
    434         metahost_counts.setdefault(metahost, 0)
    435         metahost_counts[metahost] += 1
    436     return metahost_counts
    437 
    438 
    439 def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None):
    440     hosts = []
    441     one_time_hosts = []
    442     meta_hosts = []
    443     hostless = False
    444 
    445     queue_entries = job.hostqueueentry_set.all()
    446     if queue_entry_filter_data:
    447         queue_entries = models.HostQueueEntry.query_objects(
    448             queue_entry_filter_data, initial_query=queue_entries)
    449 
    450     for queue_entry in queue_entries:
    451         if (queue_entry.host and (preserve_metahosts or
    452                                   not queue_entry.meta_host)):
    453             if queue_entry.deleted:
    454                 continue
    455             if queue_entry.host.invalid:
    456                 one_time_hosts.append(queue_entry.host)
    457             else:
    458                 hosts.append(queue_entry.host)
    459         elif queue_entry.meta_host:
    460             meta_hosts.append(queue_entry.meta_host)
    461         else:
    462             hostless = True
    463 
    464     meta_host_counts = _get_metahost_counts(meta_hosts)
    465 
    466     info = dict(dependencies=[label.name for label
    467                               in job.dependency_labels.all()],
    468                 hosts=hosts,
    469                 meta_hosts=meta_hosts,
    470                 meta_host_counts=meta_host_counts,
    471                 one_time_hosts=one_time_hosts,
    472                 hostless=hostless)
    473     return info
    474 
    475 
    476 def check_for_duplicate_hosts(host_objects):
    477     host_counts = collections.Counter(host_objects)
    478     duplicate_hostnames = {host.hostname
    479                            for host, count in host_counts.iteritems()
    480                            if count > 1}
    481     if duplicate_hostnames:
    482         raise model_logic.ValidationError(
    483                 {'hosts' : 'Duplicate hosts: %s'
    484                  % ', '.join(duplicate_hostnames)})
    485 
    486 
    487 def create_new_job(owner, options, host_objects, metahost_objects):
    488     all_host_objects = host_objects + metahost_objects
    489     dependencies = options.get('dependencies', [])
    490     synch_count = options.get('synch_count')
    491 
    492     if synch_count is not None and synch_count > len(all_host_objects):
    493         raise model_logic.ValidationError(
    494                 {'hosts':
    495                  'only %d hosts provided for job with synch_count = %d' %
    496                  (len(all_host_objects), synch_count)})
    497 
    498     check_for_duplicate_hosts(host_objects)
    499 
    500     for label_name in dependencies:
    501         if provision.is_for_special_action(label_name):
    502             # TODO: We could save a few queries
    503             # if we had a bulk ensure-label-exists function, which used
    504             # a bulk .get() call. The win is probably very small.
    505             _ensure_label_exists(label_name)
    506 
    507     # This only checks targeted hosts, not hosts eligible due to the metahost
    508     check_job_dependencies(host_objects, dependencies)
    509     check_job_metahost_dependencies(metahost_objects, dependencies)
    510 
    511     options['dependencies'] = list(
    512             models.Label.objects.filter(name__in=dependencies))
    513 
    514     job = models.Job.create(owner=owner, options=options,
    515                             hosts=all_host_objects)
    516     job.queue(all_host_objects,
    517               is_template=options.get('is_template', False))
    518     return job.id
    519 
    520 
    521 def _ensure_label_exists(name):
    522     """
    523     Ensure that a label called |name| exists in the Django models.
    524 
    525     This function is to be called from within afe rpcs only, as an
    526     alternative to server.cros.provision.ensure_label_exists(...). It works
    527     by Django model manipulation, rather than by making another create_label
    528     rpc call.
    529 
    530     @param name: the label to check for/create.
    531     @raises ValidationError: There was an error in the response that was
    532                              not because the label already existed.
    533     @returns True is a label was created, False otherwise.
    534     """
    535     # Make sure this function is not called on shards but only on master.
    536     assert not server_utils.is_shard()
    537     try:
    538         models.Label.objects.get(name=name)
    539     except models.Label.DoesNotExist:
    540         try:
    541             new_label = models.Label.objects.create(name=name)
    542             new_label.save()
    543             return True
    544         except django.db.utils.IntegrityError as e:
    545             # It is possible that another suite/test already
    546             # created the label between the check and save.
    547             if DUPLICATE_KEY_MSG in str(e):
    548                 return False
    549             else:
    550                 raise
    551     return False
    552 
    553 
    554 def find_platform(host):
    555     """
    556     Figure out the platform name for the given host
    557     object.  If none, the return value for either will be None.
    558 
    559     @returns platform name for the given host.
    560     """
    561     platforms = [label.name for label in host.label_list if label.platform]
    562     if not platforms:
    563         platform = None
    564     else:
    565         platform = platforms[0]
    566     if len(platforms) > 1:
    567         raise ValueError('Host %s has more than one platform: %s' %
    568                          (host.hostname, ', '.join(platforms)))
    569     return platform
    570 
    571 
    572 # support for get_host_queue_entries_and_special_tasks()
    573 
    574 def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on):
    575     return dict(type=type,
    576                 host=entry['host'],
    577                 job=job_dict,
    578                 execution_path=exec_path,
    579                 status=status,
    580                 started_on=started_on,
    581                 id=str(entry['id']) + type,
    582                 oid=entry['id'])
    583 
    584 
    585 def _special_task_to_dict(task, queue_entries):
    586     """Transforms a special task dictionary to another form of dictionary.
    587 
    588     @param task           Special task as a dictionary type
    589     @param queue_entries  Host queue entries as a list of dictionaries.
    590 
    591     @return Transformed dictionary for a special task.
    592     """
    593     job_dict = None
    594     if task['queue_entry']:
    595         # Scan queue_entries to get the job detail info.
    596         for qentry in queue_entries:
    597             if task['queue_entry']['id'] == qentry['id']:
    598                 job_dict = qentry['job']
    599                 break
    600         # If not found, get it from DB.
    601         if job_dict is None:
    602             job = models.Job.objects.get(id=task['queue_entry']['job'])
    603             job_dict = job.get_object_dict()
    604 
    605     exec_path = server_utils.get_special_task_exec_path(
    606             task['host']['hostname'], task['id'], task['task'],
    607             time_utils.time_string_to_datetime(task['time_requested']))
    608     status = server_utils.get_special_task_status(
    609             task['is_complete'], task['success'], task['is_active'])
    610     return _common_entry_to_dict(task, task['task'], job_dict,
    611             exec_path, status, task['time_started'])
    612 
    613 
    614 def _queue_entry_to_dict(queue_entry):
    615     job_dict = queue_entry['job']
    616     tag = server_utils.get_job_tag(job_dict['id'], job_dict['owner'])
    617     exec_path = server_utils.get_hqe_exec_path(tag,
    618                                                queue_entry['execution_subdir'])
    619     return _common_entry_to_dict(queue_entry, 'Job', job_dict, exec_path,
    620             queue_entry['status'], queue_entry['started_on'])
    621 
    622 
    623 def prepare_host_queue_entries_and_special_tasks(interleaved_entries,
    624                                                  queue_entries):
    625     """
    626     Prepare for serialization the interleaved entries of host queue entries
    627     and special tasks.
    628     Each element in the entries is a dictionary type.
    629     The special task dictionary has only a job id for a job and lacks
    630     the detail of the job while the host queue entry dictionary has.
    631     queue_entries is used to look up the job detail info.
    632 
    633     @param interleaved_entries  Host queue entries and special tasks as a list
    634                                 of dictionaries.
    635     @param queue_entries        Host queue entries as a list of dictionaries.
    636 
    637     @return A post-processed list of dictionaries that is to be serialized.
    638     """
    639     dict_list = []
    640     for e in interleaved_entries:
    641         # Distinguish the two mixed entries based on the existence of
    642         # the key "task". If an entry has the key, the entry is for
    643         # special task. Otherwise, host queue entry.
    644         if 'task' in e:
    645             dict_list.append(_special_task_to_dict(e, queue_entries))
    646         else:
    647             dict_list.append(_queue_entry_to_dict(e))
    648     return prepare_for_serialization(dict_list)
    649 
    650 
    651 def _compute_next_job_for_tasks(queue_entries, special_tasks):
    652     """
    653     For each task, try to figure out the next job that ran after that task.
    654     This is done using two pieces of information:
    655     * if the task has a queue entry, we can use that entry's job ID.
    656     * if the task has a time_started, we can try to compare that against the
    657       started_on field of queue_entries. this isn't guaranteed to work perfectly
    658       since queue_entries may also have null started_on values.
    659     * if the task has neither, or if use of time_started fails, just use the
    660       last computed job ID.
    661 
    662     @param queue_entries    Host queue entries as a list of dictionaries.
    663     @param special_tasks    Special tasks as a list of dictionaries.
    664     """
    665     next_job_id = None # most recently computed next job
    666     hqe_index = 0 # index for scanning by started_on times
    667     for task in special_tasks:
    668         if task['queue_entry']:
    669             next_job_id = task['queue_entry']['job']
    670         elif task['time_started'] is not None:
    671             for queue_entry in queue_entries[hqe_index:]:
    672                 if queue_entry['started_on'] is None:
    673                     continue
    674                 t1 = time_utils.time_string_to_datetime(
    675                         queue_entry['started_on'])
    676                 t2 = time_utils.time_string_to_datetime(task['time_started'])
    677                 if t1 < t2:
    678                     break
    679                 next_job_id = queue_entry['job']['id']
    680 
    681         task['next_job_id'] = next_job_id
    682 
    683         # advance hqe_index to just after next_job_id
    684         if next_job_id is not None:
    685             for queue_entry in queue_entries[hqe_index:]:
    686                 if queue_entry['job']['id'] < next_job_id:
    687                     break
    688                 hqe_index += 1
    689 
    690 
    691 def interleave_entries(queue_entries, special_tasks):
    692     """
    693     Both lists should be ordered by descending ID.
    694     """
    695     _compute_next_job_for_tasks(queue_entries, special_tasks)
    696 
    697     # start with all special tasks that've run since the last job
    698     interleaved_entries = []
    699     for task in special_tasks:
    700         if task['next_job_id'] is not None:
    701             break
    702         interleaved_entries.append(task)
    703 
    704     # now interleave queue entries with the remaining special tasks
    705     special_task_index = len(interleaved_entries)
    706     for queue_entry in queue_entries:
    707         interleaved_entries.append(queue_entry)
    708         # add all tasks that ran between this job and the previous one
    709         for task in special_tasks[special_task_index:]:
    710             if task['next_job_id'] < queue_entry['job']['id']:
    711                 break
    712             interleaved_entries.append(task)
    713             special_task_index += 1
    714 
    715     return interleaved_entries
    716 
    717 
    718 def bucket_hosts_by_shard(host_objs, rpc_hostnames=False):
    719     """Figure out which hosts are on which shards.
    720 
    721     @param host_objs: A list of host objects.
    722     @param rpc_hostnames: If True, the rpc_hostnames of a shard are returned
    723         instead of the 'real' shard hostnames. This only matters for testing
    724         environments.
    725 
    726     @return: A map of shard hostname: list of hosts on the shard.
    727     """
    728     shard_host_map = collections.defaultdict(list)
    729     for host in host_objs:
    730         if host.shard:
    731             shard_name = (host.shard.rpc_hostname() if rpc_hostnames
    732                           else host.shard.hostname)
    733             shard_host_map[shard_name].append(host.hostname)
    734     return shard_host_map
    735 
    736 
    737 def create_job_common(
    738         name,
    739         priority,
    740         control_type,
    741         control_file=None,
    742         hosts=(),
    743         meta_hosts=(),
    744         one_time_hosts=(),
    745         synch_count=None,
    746         is_template=False,
    747         timeout=None,
    748         timeout_mins=None,
    749         max_runtime_mins=None,
    750         run_verify=True,
    751         email_list='',
    752         dependencies=(),
    753         reboot_before=None,
    754         reboot_after=None,
    755         parse_failed_repair=None,
    756         hostless=False,
    757         keyvals=None,
    758         drone_set=None,
    759         parent_job_id=None,
    760         test_retry=0,
    761         run_reset=True,
    762         require_ssp=None):
    763     #pylint: disable-msg=C0111
    764     """
    765     Common code between creating "standard" jobs and creating parameterized jobs
    766     """
    767     # input validation
    768     host_args_passed = any((hosts, meta_hosts, one_time_hosts))
    769     if hostless:
    770         if host_args_passed:
    771             raise model_logic.ValidationError({
    772                     'hostless': 'Hostless jobs cannot include any hosts!'})
    773         if control_type != control_data.CONTROL_TYPE_NAMES.SERVER:
    774             raise model_logic.ValidationError({
    775                     'control_type': 'Hostless jobs cannot use client-side '
    776                                     'control files'})
    777     elif not host_args_passed:
    778         raise model_logic.ValidationError({
    779             'arguments' : "For host jobs, you must pass at least one of"
    780                           " 'hosts', 'meta_hosts', 'one_time_hosts'."
    781             })
    782     label_objects = list(models.Label.objects.filter(name__in=meta_hosts))
    783 
    784     # convert hostnames & meta hosts to host/label objects
    785     host_objects = models.Host.smart_get_bulk(hosts)
    786     _validate_host_job_sharding(host_objects)
    787     for host in one_time_hosts:
    788         this_host = models.Host.create_one_time_host(host)
    789         host_objects.append(this_host)
    790 
    791     metahost_objects = []
    792     meta_host_labels_by_name = {label.name: label for label in label_objects}
    793     for label_name in meta_hosts:
    794         if label_name in meta_host_labels_by_name:
    795             metahost_objects.append(meta_host_labels_by_name[label_name])
    796         else:
    797             raise model_logic.ValidationError(
    798                 {'meta_hosts' : 'Label "%s" not found' % label_name})
    799 
    800     options = dict(name=name,
    801                    priority=priority,
    802                    control_file=control_file,
    803                    control_type=control_type,
    804                    is_template=is_template,
    805                    timeout=timeout,
    806                    timeout_mins=timeout_mins,
    807                    max_runtime_mins=max_runtime_mins,
    808                    synch_count=synch_count,
    809                    run_verify=run_verify,
    810                    email_list=email_list,
    811                    dependencies=dependencies,
    812                    reboot_before=reboot_before,
    813                    reboot_after=reboot_after,
    814                    parse_failed_repair=parse_failed_repair,
    815                    keyvals=keyvals,
    816                    drone_set=drone_set,
    817                    parent_job_id=parent_job_id,
    818                    test_retry=test_retry,
    819                    run_reset=run_reset,
    820                    require_ssp=require_ssp)
    821 
    822     return create_new_job(owner=models.User.current_user().login,
    823                           options=options,
    824                           host_objects=host_objects,
    825                           metahost_objects=metahost_objects)
    826 
    827 
    828 def _validate_host_job_sharding(host_objects):
    829     """Check that the hosts obey job sharding rules."""
    830     if not (server_utils.is_shard()
    831             or _allowed_hosts_for_master_job(host_objects)):
    832         shard_host_map = bucket_hosts_by_shard(host_objects)
    833         raise ValueError(
    834                 'The following hosts are on shard(s), please create '
    835                 'seperate jobs for hosts on each shard: %s ' %
    836                 shard_host_map)
    837 
    838 
    839 def _allowed_hosts_for_master_job(host_objects):
    840     """Check that the hosts are allowed for a job on master."""
    841     # We disallow the following jobs on master:
    842     #   num_shards > 1: this is a job spanning across multiple shards.
    843     #   num_shards == 1 but number of hosts on shard is less
    844     #   than total number of hosts: this is a job that spans across
    845     #   one shard and the master.
    846     shard_host_map = bucket_hosts_by_shard(host_objects)
    847     num_shards = len(shard_host_map)
    848     if num_shards > 1:
    849         return False
    850     if num_shards == 1:
    851         hosts_on_shard = shard_host_map.values()[0]
    852         assert len(hosts_on_shard) <= len(host_objects)
    853         return len(hosts_on_shard) == len(host_objects)
    854     else:
    855         return True
    856 
    857 
    858 def encode_ascii(control_file):
    859     """Force a control file to only contain ascii characters.
    860 
    861     @param control_file: Control file to encode.
    862 
    863     @returns the control file in an ascii encoding.
    864 
    865     @raises error.ControlFileMalformed: if encoding fails.
    866     """
    867     try:
    868         return control_file.encode('ascii')
    869     except UnicodeDecodeError as e:
    870         raise error.ControlFileMalformed(str(e))
    871 
    872 
    873 def get_wmatrix_url():
    874     """Get wmatrix url from config file.
    875 
    876     @returns the wmatrix url or an empty string.
    877     """
    878     return global_config.global_config.get_config_value('AUTOTEST_WEB',
    879                                                         'wmatrix_url',
    880                                                         default='')
    881 
    882 
    883 def inject_times_to_filter(start_time_key=None, end_time_key=None,
    884                          start_time_value=None, end_time_value=None,
    885                          **filter_data):
    886     """Inject the key value pairs of start and end time if provided.
    887 
    888     @param start_time_key: A string represents the filter key of start_time.
    889     @param end_time_key: A string represents the filter key of end_time.
    890     @param start_time_value: Start_time value.
    891     @param end_time_value: End_time value.
    892 
    893     @returns the injected filter_data.
    894     """
    895     if start_time_value:
    896         filter_data[start_time_key] = start_time_value
    897     if end_time_value:
    898         filter_data[end_time_key] = end_time_value
    899     return filter_data
    900 
    901 
    902 def inject_times_to_hqe_special_tasks_filters(filter_data_common,
    903                                               start_time, end_time):
    904     """Inject start and end time to hqe and special tasks filters.
    905 
    906     @param filter_data_common: Common filter for hqe and special tasks.
    907     @param start_time_key: A string represents the filter key of start_time.
    908     @param end_time_key: A string represents the filter key of end_time.
    909 
    910     @returns a pair of hqe and special tasks filters.
    911     """
    912     filter_data_special_tasks = filter_data_common.copy()
    913     return (inject_times_to_filter('started_on__gte', 'started_on__lte',
    914                                    start_time, end_time, **filter_data_common),
    915            inject_times_to_filter('time_started__gte', 'time_started__lte',
    916                                   start_time, end_time,
    917                                   **filter_data_special_tasks))
    918 
    919 
    920 def retrieve_shard(shard_hostname):
    921     """
    922     Retrieves the shard with the given hostname from the database.
    923 
    924     @param shard_hostname: Hostname of the shard to retrieve
    925 
    926     @raises models.Shard.DoesNotExist, if no shard with this hostname was found.
    927 
    928     @returns: Shard object
    929     """
    930     timer = autotest_stats.Timer('shard_heartbeat.retrieve_shard')
    931     with timer:
    932         return models.Shard.smart_get(shard_hostname)
    933 
    934 
    935 def find_records_for_shard(shard, known_job_ids, known_host_ids):
    936     """Find records that should be sent to a shard.
    937 
    938     @param shard: Shard to find records for.
    939     @param known_job_ids: List of ids of jobs the shard already has.
    940     @param known_host_ids: List of ids of hosts the shard already has.
    941 
    942     @returns: Tuple of three lists for hosts, jobs, and suite job keyvals:
    943               (hosts, jobs, suite_job_keyvals).
    944     """
    945     timer = autotest_stats.Timer('shard_heartbeat')
    946     with timer.get_client('find_hosts'):
    947         hosts = models.Host.assign_to_shard(shard, known_host_ids)
    948     with timer.get_client('find_jobs'):
    949         jobs = models.Job.assign_to_shard(shard, known_job_ids)
    950     with timer.get_client('find_suite_job_keyvals'):
    951         parent_job_ids = [job.parent_job_id for job in jobs]
    952         suite_job_keyvals = models.JobKeyval.objects.filter(
    953                 job_id__in=parent_job_ids)
    954     return hosts, jobs, suite_job_keyvals
    955 
    956 
    957 def _persist_records_with_type_sent_from_shard(
    958     shard, records, record_type, *args, **kwargs):
    959     """
    960     Handle records of a specified type that were sent to the shard master.
    961 
    962     @param shard: The shard the records were sent from.
    963     @param records: The records sent in their serialized format.
    964     @param record_type: Type of the objects represented by records.
    965     @param args: Additional arguments that will be passed on to the sanity
    966                  checks.
    967     @param kwargs: Additional arguments that will be passed on to the sanity
    968                   checks.
    969 
    970     @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail.
    971 
    972     @returns: List of primary keys of the processed records.
    973     """
    974     pks = []
    975     for serialized_record in records:
    976         pk = serialized_record['id']
    977         try:
    978             current_record = record_type.objects.get(pk=pk)
    979         except record_type.DoesNotExist:
    980             raise error.UnallowedRecordsSentToMaster(
    981                 'Object with pk %s of type %s does not exist on master.' % (
    982                     pk, record_type))
    983 
    984         current_record.sanity_check_update_from_shard(
    985             shard, serialized_record, *args, **kwargs)
    986 
    987         current_record.update_from_serialized(serialized_record)
    988         pks.append(pk)
    989     return pks
    990 
    991 
    992 def persist_records_sent_from_shard(shard, jobs, hqes):
    993     """
    994     Sanity checking then saving serialized records sent to master from shard.
    995 
    996     During heartbeats shards upload jobs and hostqueuentries. This performs
    997     some sanity checks on these and then updates the existing records for those
    998     entries with the updated ones from the heartbeat.
    999 
   1000     The sanity checks include:
   1001     - Checking if the objects sent already exist on the master.
   1002     - Checking if the objects sent were assigned to this shard.
   1003     - hostqueueentries must be sent together with their jobs.
   1004 
   1005     @param shard: The shard the records were sent from.
   1006     @param jobs: The jobs the shard sent.
   1007     @param hqes: The hostqueuentries the shart sent.
   1008 
   1009     @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail.
   1010     """
   1011     timer = autotest_stats.Timer('shard_heartbeat')
   1012     with timer.get_client('persist_jobs'):
   1013         job_ids_sent = _persist_records_with_type_sent_from_shard(
   1014                 shard, jobs, models.Job)
   1015 
   1016     with timer.get_client('persist_hqes'):
   1017         _persist_records_with_type_sent_from_shard(
   1018                 shard, hqes, models.HostQueueEntry, job_ids_sent=job_ids_sent)
   1019 
   1020 
   1021 def forward_single_host_rpc_to_shard(func):
   1022     """This decorator forwards rpc calls that modify a host to a shard.
   1023 
   1024     If a host is assigned to a shard, rpcs that change his attributes should be
   1025     forwarded to the shard.
   1026 
   1027     This assumes the first argument of the function represents a host id.
   1028 
   1029     @param func: The function to decorate
   1030 
   1031     @returns: The function to replace func with.
   1032     """
   1033     def replacement(**kwargs):
   1034         # Only keyword arguments can be accepted here, as we need the argument
   1035         # names to send the rpc. serviceHandler always provides arguments with
   1036         # their keywords, so this is not a problem.
   1037 
   1038         # A host record (identified by kwargs['id']) can be deleted in
   1039         # func(). Therefore, we should save the data that can be needed later
   1040         # before func() is called.
   1041         shard_hostname = None
   1042         host = models.Host.smart_get(kwargs['id'])
   1043         if host and host.shard:
   1044             shard_hostname = host.shard.rpc_hostname()
   1045         ret = func(**kwargs)
   1046         if shard_hostname and not server_utils.is_shard():
   1047             run_rpc_on_multiple_hostnames(func.func_name,
   1048                                           [shard_hostname],
   1049                                           **kwargs)
   1050         return ret
   1051 
   1052     return replacement
   1053 
   1054 
   1055 def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs):
   1056     """Fanout the given rpc to shards of given hosts.
   1057 
   1058     @param host_objs: Host objects for the rpc.
   1059     @param rpc_name: The name of the rpc.
   1060     @param include_hostnames: If True, include the hostnames in the kwargs.
   1061         Hostnames are not always necessary, this functions is designed to
   1062         send rpcs to the shard a host is on, the rpcs themselves could be
   1063         related to labels, acls etc.
   1064     @param kwargs: The kwargs for the rpc.
   1065     """
   1066     # Figure out which hosts are on which shards.
   1067     shard_host_map = bucket_hosts_by_shard(
   1068             host_objs, rpc_hostnames=True)
   1069 
   1070     # Execute the rpc against the appropriate shards.
   1071     for shard, hostnames in shard_host_map.iteritems():
   1072         if include_hostnames:
   1073             kwargs['hosts'] = hostnames
   1074         try:
   1075             run_rpc_on_multiple_hostnames(rpc_name, [shard], **kwargs)
   1076         except:
   1077             ei = sys.exc_info()
   1078             new_exc = error.RPCException('RPC %s failed on shard %s due to '
   1079                     '%s: %s' % (rpc_name, shard, ei[0].__name__, ei[1]))
   1080             raise new_exc.__class__, new_exc, ei[2]
   1081 
   1082 
   1083 def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs):
   1084     """Runs an rpc to multiple AFEs
   1085 
   1086     This is i.e. used to propagate changes made to hosts after they are assigned
   1087     to a shard.
   1088 
   1089     @param rpc_call: Name of the rpc endpoint to call.
   1090     @param shard_hostnames: List of hostnames to run the rpcs on.
   1091     @param **kwargs: Keyword arguments to pass in the rpcs.
   1092     """
   1093     # Make sure this function is not called on shards but only on master.
   1094     assert not server_utils.is_shard()
   1095     for shard_hostname in shard_hostnames:
   1096         afe = frontend_wrappers.RetryingAFE(server=shard_hostname,
   1097                                             user=thread_local.get_user())
   1098         afe.run(rpc_call, **kwargs)
   1099 
   1100 
   1101 def get_label(name):
   1102     """Gets a label object using a given name.
   1103 
   1104     @param name: Label name.
   1105     @raises model.Label.DoesNotExist: when there is no label matching
   1106                                       the given name.
   1107     @return: a label object matching the given name.
   1108     """
   1109     try:
   1110         label = models.Label.smart_get(name)
   1111     except models.Label.DoesNotExist:
   1112         return None
   1113     return label
   1114 
   1115 
   1116 # TODO: hide the following rpcs under is_moblab
   1117 def moblab_only(func):
   1118     """Ensure moblab specific functions only run on Moblab devices."""
   1119     def verify(*args, **kwargs):
   1120         if not server_utils.is_moblab():
   1121             raise error.RPCException('RPC: %s can only run on Moblab Systems!',
   1122                                      func.__name__)
   1123         return func(*args, **kwargs)
   1124     return verify
   1125 
   1126 
   1127 def route_rpc_to_master(func):
   1128     """Route RPC to master AFE.
   1129 
   1130     When a shard receives an RPC decorated by this, the RPC is just
   1131     forwarded to the master.
   1132     When the master gets the RPC, the RPC function is executed.
   1133 
   1134     @param func: An RPC function to decorate
   1135 
   1136     @returns: A function replacing the RPC func.
   1137     """
   1138     argspec = inspect.getargspec(func)
   1139     if argspec.varargs is not None:
   1140         raise Exception('RPC function must not have *args.')
   1141 
   1142     @wraps(func)
   1143     def replacement(*args, **kwargs):
   1144         """We need special handling when decorating an RPC that can be called
   1145         directly using positional arguments.
   1146 
   1147         One example is rpc_interface.create_job().
   1148         rpc_interface.create_job_page_handler() calls the function using both
   1149         positional and keyword arguments.  Since frontend.RpcClient.run()
   1150         takes only keyword arguments for an RPC, positional arguments of the
   1151         RPC function need to be transformed into keyword arguments.
   1152         """
   1153         kwargs = _convert_to_kwargs_only(func, args, kwargs)
   1154         if server_utils.is_shard():
   1155             afe = frontend_wrappers.RetryingAFE(
   1156                     server=server_utils.get_global_afe_hostname(),
   1157                     user=thread_local.get_user())
   1158             return afe.run(func.func_name, **kwargs)
   1159         return func(**kwargs)
   1160 
   1161     return replacement
   1162 
   1163 
   1164 def _convert_to_kwargs_only(func, args, kwargs):
   1165     """Convert a function call's arguments to a kwargs dict.
   1166 
   1167     This is best illustrated with an example.  Given:
   1168 
   1169     def foo(a, b, **kwargs):
   1170         pass
   1171     _to_kwargs(foo, (1, 2), {'c': 3})  # corresponding to foo(1, 2, c=3)
   1172 
   1173         foo(**kwargs)
   1174 
   1175     @param func: function whose signature to use
   1176     @param args: positional arguments of call
   1177     @param kwargs: keyword arguments of call
   1178 
   1179     @returns: kwargs dict
   1180     """
   1181     argspec = inspect.getargspec(func)
   1182     # callargs looks like {'a': 1, 'b': 2, 'kwargs': {'c': 3}}
   1183     callargs = inspect.getcallargs(func, *args, **kwargs)
   1184     if argspec.keywords is None:
   1185         kwargs = {}
   1186     else:
   1187         kwargs = callargs.pop(argspec.keywords)
   1188     kwargs.update(callargs)
   1189     return kwargs
   1190 
   1191 
   1192 def get_sample_dut(board, pool):
   1193     """Get a dut with the given board and pool.
   1194 
   1195     This method is used to help to locate a dut with the given board and pool.
   1196     The dut then can be used to identify a devserver in the same subnet.
   1197 
   1198     @param board: Name of the board.
   1199     @param pool: Name of the pool.
   1200 
   1201     @return: Name of a dut with the given board and pool.
   1202     """
   1203     if not (dev_server.PREFER_LOCAL_DEVSERVER and pool and board):
   1204         return None
   1205     hosts = list(get_host_query(
   1206         multiple_labels=('pool:%s' % pool, 'board:%s' % board),
   1207         exclude_only_if_needed_labels=False,
   1208         valid_only=True,
   1209         filter_data={},
   1210     ))
   1211     if not hosts:
   1212         return None
   1213     else:
   1214         return hosts[0].hostname
   1215