Home | History | Annotate | Download | only in afe
      1 # pylint: disable=missing-docstring
      2 """
      3 Utility functions for rpc_interface.py.  We keep them in a separate file so that
      4 only RPC interface functions go into that file.
      5 """
      6 
      7 __author__ = 'showard (at] google.com (Steve Howard)'
      8 
      9 import collections
     10 import datetime
     11 from functools import wraps
     12 import inspect
     13 import os
     14 import sys
     15 import django.db.utils
     16 import django.http
     17 
     18 from autotest_lib.frontend import thread_local
     19 from autotest_lib.frontend.afe import models, model_logic
     20 from autotest_lib.client.common_lib import control_data, error
     21 from autotest_lib.client.common_lib import global_config
     22 from autotest_lib.client.common_lib import time_utils
     23 from autotest_lib.client.common_lib.cros import dev_server
     24 from autotest_lib.server import utils as server_utils
     25 from autotest_lib.server.cros import provision
     26 from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
     27 
# Sentinel timestamps treated as "no value": _prepare_data() serializes any
# date/datetime equal to these back to None for RPC callers.
NULL_DATETIME = datetime.datetime.max
NULL_DATE = datetime.date.max
# Substring of the IntegrityError message raised on a duplicate-key insert;
# used by _ensure_label_exists() to detect label-creation races.
DUPLICATE_KEY_MSG = 'Duplicate entry'
     31 
     32 def prepare_for_serialization(objects):
     33     """
     34     Prepare Python objects to be returned via RPC.
     35     @param objects: objects to be prepared.
     36     """
     37     if (isinstance(objects, list) and len(objects) and
     38         isinstance(objects[0], dict) and 'id' in objects[0]):
     39         objects = _gather_unique_dicts(objects)
     40     return _prepare_data(objects)
     41 
     42 
     43 def prepare_rows_as_nested_dicts(query, nested_dict_column_names):
     44     """
     45     Prepare a Django query to be returned via RPC as a sequence of nested
     46     dictionaries.
     47 
     48     @param query - A Django model query object with a select_related() method.
     49     @param nested_dict_column_names - A list of column/attribute names for the
     50             rows returned by query to expand into nested dictionaries using
     51             their get_object_dict() method when not None.
     52 
     53     @returns An list suitable to returned in an RPC.
     54     """
     55     all_dicts = []
     56     for row in query.select_related():
     57         row_dict = row.get_object_dict()
     58         for column in nested_dict_column_names:
     59             if row_dict[column] is not None:
     60                 row_dict[column] = getattr(row, column).get_object_dict()
     61         all_dicts.append(row_dict)
     62     return prepare_for_serialization(all_dicts)
     63 
     64 
     65 def _prepare_data(data):
     66     """
     67     Recursively process data structures, performing necessary type
     68     conversions to values in data to allow for RPC serialization:
     69     -convert datetimes to strings
     70     -convert tuples and sets to lists
     71     """
     72     if isinstance(data, dict):
     73         new_data = {}
     74         for key, value in data.iteritems():
     75             new_data[key] = _prepare_data(value)
     76         return new_data
     77     elif (isinstance(data, list) or isinstance(data, tuple) or
     78           isinstance(data, set)):
     79         return [_prepare_data(item) for item in data]
     80     elif isinstance(data, datetime.date):
     81         if data is NULL_DATETIME or data is NULL_DATE:
     82             return None
     83         return str(data)
     84     else:
     85         return data
     86 
     87 
     88 def fetchall_as_list_of_dicts(cursor):
     89     """
     90     Converts each row in the cursor to a dictionary so that values can be read
     91     by using the column name.
     92     @param cursor: The database cursor to read from.
     93     @returns: A list of each row in the cursor as a dictionary.
     94     """
     95     desc = cursor.description
     96     return [ dict(zip([col[0] for col in desc], row))
     97              for row in cursor.fetchall() ]
     98 
     99 
def raw_http_response(response_data, content_type=None):
    """
    Wrap raw data in an HttpResponse carrying an explicit Content-length.

    @param response_data: raw body for the response.
    @param content_type: MIME type for the response; None lets Django pick
            its default.
    @returns A django.http.HttpResponse wrapping |response_data|.
    """
    # NOTE(review): 'mimetype' is the pre-Django-1.7 spelling of
    # 'content_type' -- confirm against the Django version this runs on.
    response = django.http.HttpResponse(response_data, mimetype=content_type)
    response['Content-length'] = str(len(response.content))
    return response
    104 
    105 
    106 def _gather_unique_dicts(dict_iterable):
    107     """\
    108     Pick out unique objects (by ID) from an iterable of object dicts.
    109     """
    110     objects = collections.OrderedDict()
    111     for obj in dict_iterable:
    112         objects.setdefault(obj['id'], obj)
    113     return objects.values()
    114 
    115 
    116 def extra_job_status_filters(not_yet_run=False, running=False, finished=False):
    117     """\
    118     Generate a SQL WHERE clause for job status filtering, and return it in
    119     a dict of keyword args to pass to query.extra().
    120     * not_yet_run: all HQEs are Queued
    121     * finished: all HQEs are complete
    122     * running: everything else
    123     """
    124     if not (not_yet_run or running or finished):
    125         return {}
    126     not_queued = ('(SELECT job_id FROM afe_host_queue_entries '
    127                   'WHERE status != "%s")'
    128                   % models.HostQueueEntry.Status.QUEUED)
    129     not_finished = ('(SELECT job_id FROM afe_host_queue_entries '
    130                     'WHERE not complete)')
    131 
    132     where = []
    133     if not_yet_run:
    134         where.append('id NOT IN ' + not_queued)
    135     if running:
    136         where.append('(id IN %s) AND (id IN %s)' % (not_queued, not_finished))
    137     if finished:
    138         where.append('id NOT IN ' + not_finished)
    139     return {'where': [' OR '.join(['(%s)' % x for x in where])]}
    140 
    141 
    142 def extra_job_type_filters(extra_args, suite=False,
    143                            sub=False, standalone=False):
    144     """\
    145     Generate a SQL WHERE clause for job status filtering, and return it in
    146     a dict of keyword args to pass to query.extra().
    147 
    148     param extra_args: a dict of existing extra_args.
    149 
    150     No more than one of the parameters should be passed as True:
    151     * suite: job which is parent of other jobs
    152     * sub: job with a parent job
    153     * standalone: job with no child or parent jobs
    154     """
    155     assert not ((suite and sub) or
    156                 (suite and standalone) or
    157                 (sub and standalone)), ('Cannot specify more than one '
    158                                         'filter to this function')
    159 
    160     where = extra_args.get('where', [])
    161     parent_job_id = ('DISTINCT parent_job_id')
    162     child_job_id = ('id')
    163     filter_common = ('(SELECT %s FROM afe_jobs '
    164                      'WHERE parent_job_id IS NOT NULL)')
    165 
    166     if suite:
    167         where.append('id IN ' + filter_common % parent_job_id)
    168     elif sub:
    169         where.append('id IN ' + filter_common % child_job_id)
    170     elif standalone:
    171         where.append('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query '
    172                      'WHERE parent_job_id IS NOT NULL'
    173                      ' AND (sub_query.parent_job_id=afe_jobs.id'
    174                      ' OR sub_query.id=afe_jobs.id))')
    175     else:
    176         return extra_args
    177 
    178     extra_args['where'] = where
    179     return extra_args
    180 
    181 
    182 
    183 def extra_host_filters(multiple_labels=()):
    184     """\
    185     Generate SQL WHERE clauses for matching hosts in an intersection of
    186     labels.
    187     """
    188     extra_args = {}
    189     where_str = ('afe_hosts.id in (select host_id from afe_hosts_labels '
    190                  'where label_id=%s)')
    191     extra_args['where'] = [where_str] * len(multiple_labels)
    192     extra_args['params'] = [models.Label.smart_get(label).id
    193                             for label in multiple_labels]
    194     return extra_args
    195 
    196 
def get_host_query(multiple_labels, exclude_only_if_needed_labels,
                   valid_only, filter_data):
    """
    Build a Host queryset constrained by labels and generic filter data.

    @param multiple_labels: label names each returned host must carry.
    @param exclude_only_if_needed_labels: when true, drop hosts carrying any
            label marked only_if_needed.
    @param valid_only: when true, start from valid (non-deleted) hosts.
    @param filter_data: extra filters for Host.query_objects(); must not
            already contain 'extra_args' (it is filled in here).

    @returns A queryset of matching hosts, or an empty queryset when a
            requested label does not exist.
    """
    if valid_only:
        query = models.Host.valid_objects.all()
    else:
        query = models.Host.objects.all()

    if exclude_only_if_needed_labels:
        only_if_needed_labels = models.Label.valid_objects.filter(
            only_if_needed=True)
        if only_if_needed_labels.count() > 0:
            only_if_needed_ids = ','.join(
                    str(label['id'])
                    for label in only_if_needed_labels.values('id'))
            # Exclude via an extra join on afe_hosts_labels; the _exclude_OIN
            # suffix keeps it distinct from any other label join on the query.
            query = models.Host.objects.add_join(
                query, 'afe_hosts_labels', join_key='host_id',
                join_condition=('afe_hosts_labels_exclude_OIN.label_id IN (%s)'
                                % only_if_needed_ids),
                suffix='_exclude_OIN', exclude=True)
    try:
        assert 'extra_args' not in filter_data
        filter_data['extra_args'] = extra_host_filters(multiple_labels)
        return models.Host.query_objects(filter_data, initial_query=query)
    except models.Label.DoesNotExist:
        # smart_get() raised for an unknown label name: nothing can match.
        return models.Host.objects.none()
    222 
    223 
class InconsistencyException(Exception):
    """Raised when a list of objects does not have a consistent value.

    args carries the two objects whose field values disagree (see
    get_consistent_value()).
    """
    226 
    227 
    228 def get_consistent_value(objects, field):
    229     if not objects:
    230         # well a list of nothing is consistent
    231         return None
    232 
    233     value = getattr(objects[0], field)
    234     for obj in objects:
    235         this_value = getattr(obj, field)
    236         if this_value != value:
    237             raise InconsistencyException(objects[0], obj)
    238     return value
    239 
    240 
    241 def afe_test_dict_to_test_object(test_dict):
    242     if not isinstance(test_dict, dict):
    243         return test_dict
    244 
    245     numerized_dict = {}
    246     for key, value in test_dict.iteritems():
    247         try:
    248             numerized_dict[key] = int(value)
    249         except (ValueError, TypeError):
    250             numerized_dict[key] = value
    251 
    252     return type('TestObject', (object,), numerized_dict)
    253 
    254 
    255 def _check_is_server_test(test_type):
    256     """Checks if the test type is a server test.
    257 
    258     @param test_type The test type in enum integer or string.
    259 
    260     @returns A boolean to identify if the test type is server test.
    261     """
    262     if test_type is not None:
    263         if isinstance(test_type, basestring):
    264             try:
    265                 test_type = control_data.CONTROL_TYPE.get_value(test_type)
    266             except AttributeError:
    267                 return False
    268         return (test_type == control_data.CONTROL_TYPE.SERVER)
    269     return False
    270 
    271 
    272 def prepare_generate_control_file(tests, profilers, db_tests=True):
    273     if db_tests:
    274         test_objects = [models.Test.smart_get(test) for test in tests]
    275     else:
    276         test_objects = [afe_test_dict_to_test_object(test) for test in tests]
    277 
    278     profiler_objects = [models.Profiler.smart_get(profiler)
    279                         for profiler in profilers]
    280     # ensure tests are all the same type
    281     try:
    282         test_type = get_consistent_value(test_objects, 'test_type')
    283     except InconsistencyException, exc:
    284         test1, test2 = exc.args
    285         raise model_logic.ValidationError(
    286             {'tests' : 'You cannot run both test_suites and server-side '
    287              'tests together (tests %s and %s differ' % (
    288             test1.name, test2.name)})
    289 
    290     is_server = _check_is_server_test(test_type)
    291     if test_objects:
    292         synch_count = max(test.sync_count for test in test_objects)
    293     else:
    294         synch_count = 1
    295 
    296     if db_tests:
    297         dependencies = set(label.name for label
    298                            in models.Label.objects.filter(test__in=test_objects))
    299     else:
    300         dependencies = reduce(
    301                 set.union, [set(test.dependencies) for test in test_objects])
    302 
    303     cf_info = dict(is_server=is_server, synch_count=synch_count,
    304                    dependencies=list(dependencies))
    305     return cf_info, test_objects, profiler_objects
    306 
    307 
    308 def check_job_dependencies(host_objects, job_dependencies):
    309     """
    310     Check that a set of machines satisfies a job's dependencies.
    311     host_objects: list of models.Host objects
    312     job_dependencies: list of names of labels
    313     """
    314     # check that hosts satisfy dependencies
    315     host_ids = [host.id for host in host_objects]
    316     hosts_in_job = models.Host.objects.filter(id__in=host_ids)
    317     ok_hosts = hosts_in_job
    318     for index, dependency in enumerate(job_dependencies):
    319         if not provision.is_for_special_action(dependency):
    320             ok_hosts = ok_hosts.filter(labels__name=dependency)
    321     failing_hosts = (set(host.hostname for host in host_objects) -
    322                      set(host.hostname for host in ok_hosts))
    323     if failing_hosts:
    324         raise model_logic.ValidationError(
    325             {'hosts' : 'Host(s) failed to meet job dependencies (' +
    326                        (', '.join(job_dependencies)) + '): ' +
    327                        (', '.join(failing_hosts))})
    328 
    329 
    330 def check_job_metahost_dependencies(metahost_objects, job_dependencies):
    331     """
    332     Check that at least one machine within the metahost spec satisfies the job's
    333     dependencies.
    334 
    335     @param metahost_objects A list of label objects representing the metahosts.
    336     @param job_dependencies A list of strings of the required label names.
    337     @raises NoEligibleHostException If a metahost cannot run the job.
    338     """
    339     for metahost in metahost_objects:
    340         hosts = models.Host.objects.filter(labels=metahost)
    341         for label_name in job_dependencies:
    342             if not provision.is_for_special_action(label_name):
    343                 hosts = hosts.filter(labels__name=label_name)
    344         if not any(hosts):
    345             raise error.NoEligibleHostException("No hosts within %s satisfy %s."
    346                     % (metahost.name, ', '.join(job_dependencies)))
    347 
    348 
    349 def _execution_key_for(host_queue_entry):
    350     return (host_queue_entry.job.id, host_queue_entry.execution_subdir)
    351 
    352 
    353 def check_abort_synchronous_jobs(host_queue_entries):
    354     # ensure user isn't aborting part of a synchronous autoserv execution
    355     count_per_execution = {}
    356     for queue_entry in host_queue_entries:
    357         key = _execution_key_for(queue_entry)
    358         count_per_execution.setdefault(key, 0)
    359         count_per_execution[key] += 1
    360 
    361     for queue_entry in host_queue_entries:
    362         if not queue_entry.execution_subdir:
    363             continue
    364         execution_count = count_per_execution[_execution_key_for(queue_entry)]
    365         if execution_count < queue_entry.job.synch_count:
    366             raise model_logic.ValidationError(
    367                 {'' : 'You cannot abort part of a synchronous job execution '
    368                       '(%d/%s), %d included, %d expected'
    369                       % (queue_entry.job.id, queue_entry.execution_subdir,
    370                          execution_count, queue_entry.job.synch_count)})
    371 
    372 
    373 def check_modify_host(update_data):
    374     """
    375     Sanity check modify_host* requests.
    376 
    377     @param update_data: A dictionary with the changes to make to a host
    378             or hosts.
    379     """
    380     # Only the scheduler (monitor_db) is allowed to modify Host status.
    381     # Otherwise race conditions happen as a hosts state is changed out from
    382     # beneath tasks being run on a host.
    383     if 'status' in update_data:
    384         raise model_logic.ValidationError({
    385                 'status': 'Host status can not be modified by the frontend.'})
    386 
    387 
    388 def check_modify_host_locking(host, update_data):
    389     """
    390     Checks when locking/unlocking has been requested if the host is already
    391     locked/unlocked.
    392 
    393     @param host: models.Host object to be modified
    394     @param update_data: A dictionary with the changes to make to the host.
    395     """
    396     locked = update_data.get('locked', None)
    397     lock_reason = update_data.get('lock_reason', None)
    398     if locked is not None:
    399         if locked and host.locked:
    400             raise model_logic.ValidationError({
    401                     'locked': 'Host %s already locked by %s on %s.' %
    402                     (host.hostname, host.locked_by, host.lock_time)})
    403         if not locked and not host.locked:
    404             raise model_logic.ValidationError({
    405                     'locked': 'Host %s already unlocked.' % host.hostname})
    406         if locked and not lock_reason and not host.locked:
    407             raise model_logic.ValidationError({
    408                     'locked': 'Please provide a reason for locking Host %s' %
    409                     host.hostname})
    410 
    411 
    412 def get_motd():
    413     dirname = os.path.dirname(__file__)
    414     filename = os.path.join(dirname, "..", "..", "motd.txt")
    415     text = ''
    416     try:
    417         fp = open(filename, "r")
    418         try:
    419             text = fp.read()
    420         finally:
    421             fp.close()
    422     except:
    423         pass
    424 
    425     return text
    426 
    427 
    428 def _get_metahost_counts(metahost_objects):
    429     metahost_counts = {}
    430     for metahost in metahost_objects:
    431         metahost_counts.setdefault(metahost, 0)
    432         metahost_counts[metahost] += 1
    433     return metahost_counts
    434 
    435 
    436 def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None):
    437     hosts = []
    438     one_time_hosts = []
    439     meta_hosts = []
    440     hostless = False
    441 
    442     queue_entries = job.hostqueueentry_set.all()
    443     if queue_entry_filter_data:
    444         queue_entries = models.HostQueueEntry.query_objects(
    445             queue_entry_filter_data, initial_query=queue_entries)
    446 
    447     for queue_entry in queue_entries:
    448         if (queue_entry.host and (preserve_metahosts or
    449                                   not queue_entry.meta_host)):
    450             if queue_entry.deleted:
    451                 continue
    452             if queue_entry.host.invalid:
    453                 one_time_hosts.append(queue_entry.host)
    454             else:
    455                 hosts.append(queue_entry.host)
    456         elif queue_entry.meta_host:
    457             meta_hosts.append(queue_entry.meta_host)
    458         else:
    459             hostless = True
    460 
    461     meta_host_counts = _get_metahost_counts(meta_hosts)
    462 
    463     info = dict(dependencies=[label.name for label
    464                               in job.dependency_labels.all()],
    465                 hosts=hosts,
    466                 meta_hosts=meta_hosts,
    467                 meta_host_counts=meta_host_counts,
    468                 one_time_hosts=one_time_hosts,
    469                 hostless=hostless)
    470     return info
    471 
    472 
    473 def check_for_duplicate_hosts(host_objects):
    474     host_counts = collections.Counter(host_objects)
    475     duplicate_hostnames = {host.hostname
    476                            for host, count in host_counts.iteritems()
    477                            if count > 1}
    478     if duplicate_hostnames:
    479         raise model_logic.ValidationError(
    480                 {'hosts' : 'Duplicate hosts: %s'
    481                  % ', '.join(duplicate_hostnames)})
    482 
    483 
    484 def create_new_job(owner, options, host_objects, metahost_objects):
    485     all_host_objects = host_objects + metahost_objects
    486     dependencies = options.get('dependencies', [])
    487     synch_count = options.get('synch_count')
    488 
    489     if synch_count is not None and synch_count > len(all_host_objects):
    490         raise model_logic.ValidationError(
    491                 {'hosts':
    492                  'only %d hosts provided for job with synch_count = %d' %
    493                  (len(all_host_objects), synch_count)})
    494 
    495     check_for_duplicate_hosts(host_objects)
    496 
    497     for label_name in dependencies:
    498         if provision.is_for_special_action(label_name):
    499             # TODO: We could save a few queries
    500             # if we had a bulk ensure-label-exists function, which used
    501             # a bulk .get() call. The win is probably very small.
    502             _ensure_label_exists(label_name)
    503 
    504     # This only checks targeted hosts, not hosts eligible due to the metahost
    505     check_job_dependencies(host_objects, dependencies)
    506     check_job_metahost_dependencies(metahost_objects, dependencies)
    507 
    508     options['dependencies'] = list(
    509             models.Label.objects.filter(name__in=dependencies))
    510 
    511     job = models.Job.create(owner=owner, options=options,
    512                             hosts=all_host_objects)
    513     job.queue(all_host_objects,
    514               is_template=options.get('is_template', False))
    515     return job.id
    516 
    517 
    518 def _ensure_label_exists(name):
    519     """
    520     Ensure that a label called |name| exists in the Django models.
    521 
    522     This function is to be called from within afe rpcs only, as an
    523     alternative to server.cros.provision.ensure_label_exists(...). It works
    524     by Django model manipulation, rather than by making another create_label
    525     rpc call.
    526 
    527     @param name: the label to check for/create.
    528     @raises ValidationError: There was an error in the response that was
    529                              not because the label already existed.
    530     @returns True is a label was created, False otherwise.
    531     """
    532     # Make sure this function is not called on shards but only on master.
    533     assert not server_utils.is_shard()
    534     try:
    535         models.Label.objects.get(name=name)
    536     except models.Label.DoesNotExist:
    537         try:
    538             new_label = models.Label.objects.create(name=name)
    539             new_label.save()
    540             return True
    541         except django.db.utils.IntegrityError as e:
    542             # It is possible that another suite/test already
    543             # created the label between the check and save.
    544             if DUPLICATE_KEY_MSG in str(e):
    545                 return False
    546             else:
    547                 raise
    548     return False
    549 
    550 
    551 def find_platform(host):
    552     """
    553     Figure out the platform name for the given host
    554     object.  If none, the return value for either will be None.
    555 
    556     @returns platform name for the given host.
    557     """
    558     platforms = [label.name for label in host.label_list if label.platform]
    559     if not platforms:
    560         platform = None
    561     else:
    562         platform = platforms[0]
    563     if len(platforms) > 1:
    564         raise ValueError('Host %s has more than one platform: %s' %
    565                          (host.hostname, ', '.join(platforms)))
    566     return platform
    567 
    568 
    569 # support for get_host_queue_entries_and_special_tasks()
    570 
    571 def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on):
    572     return dict(type=type,
    573                 host=entry['host'],
    574                 job=job_dict,
    575                 execution_path=exec_path,
    576                 status=status,
    577                 started_on=started_on,
    578                 id=str(entry['id']) + type,
    579                 oid=entry['id'])
    580 
    581 
    582 def _special_task_to_dict(task, queue_entries):
    583     """Transforms a special task dictionary to another form of dictionary.
    584 
    585     @param task           Special task as a dictionary type
    586     @param queue_entries  Host queue entries as a list of dictionaries.
    587 
    588     @return Transformed dictionary for a special task.
    589     """
    590     job_dict = None
    591     if task['queue_entry']:
    592         # Scan queue_entries to get the job detail info.
    593         for qentry in queue_entries:
    594             if task['queue_entry']['id'] == qentry['id']:
    595                 job_dict = qentry['job']
    596                 break
    597         # If not found, get it from DB.
    598         if job_dict is None:
    599             job = models.Job.objects.get(id=task['queue_entry']['job'])
    600             job_dict = job.get_object_dict()
    601 
    602     exec_path = server_utils.get_special_task_exec_path(
    603             task['host']['hostname'], task['id'], task['task'],
    604             time_utils.time_string_to_datetime(task['time_requested']))
    605     status = server_utils.get_special_task_status(
    606             task['is_complete'], task['success'], task['is_active'])
    607     return _common_entry_to_dict(task, task['task'], job_dict,
    608             exec_path, status, task['time_started'])
    609 
    610 
    611 def _queue_entry_to_dict(queue_entry):
    612     job_dict = queue_entry['job']
    613     tag = server_utils.get_job_tag(job_dict['id'], job_dict['owner'])
    614     exec_path = server_utils.get_hqe_exec_path(tag,
    615                                                queue_entry['execution_subdir'])
    616     return _common_entry_to_dict(queue_entry, 'Job', job_dict, exec_path,
    617             queue_entry['status'], queue_entry['started_on'])
    618 
    619 
    620 def prepare_host_queue_entries_and_special_tasks(interleaved_entries,
    621                                                  queue_entries):
    622     """
    623     Prepare for serialization the interleaved entries of host queue entries
    624     and special tasks.
    625     Each element in the entries is a dictionary type.
    626     The special task dictionary has only a job id for a job and lacks
    627     the detail of the job while the host queue entry dictionary has.
    628     queue_entries is used to look up the job detail info.
    629 
    630     @param interleaved_entries  Host queue entries and special tasks as a list
    631                                 of dictionaries.
    632     @param queue_entries        Host queue entries as a list of dictionaries.
    633 
    634     @return A post-processed list of dictionaries that is to be serialized.
    635     """
    636     dict_list = []
    637     for e in interleaved_entries:
    638         # Distinguish the two mixed entries based on the existence of
    639         # the key "task". If an entry has the key, the entry is for
    640         # special task. Otherwise, host queue entry.
    641         if 'task' in e:
    642             dict_list.append(_special_task_to_dict(e, queue_entries))
    643         else:
    644             dict_list.append(_queue_entry_to_dict(e))
    645     return prepare_for_serialization(dict_list)
    646 
    647 
def _compute_next_job_for_tasks(queue_entries, special_tasks):
    """
    For each task, try to figure out the next job that ran after that task.
    This is done using two pieces of information:
    * if the task has a queue entry, we can use that entry's job ID.
    * if the task has a time_started, we can try to compare that against the
      started_on field of queue_entries. this isn't guaranteed to work perfectly
      since queue_entries may also have null started_on values.
    * if the task has neither, or if use of time_started fails, just use the
      last computed job ID.

    Both lists are iterated in the order given (descending ID, per
    interleave_entries); each task dict gets a 'next_job_id' key set in
    place.

    @param queue_entries    Host queue entries as a list of dictionaries.
    @param special_tasks    Special tasks as a list of dictionaries.
    """
    next_job_id = None # most recently computed next job
    hqe_index = 0 # index for scanning by started_on times
    for task in special_tasks:
        if task['queue_entry']:
            next_job_id = task['queue_entry']['job']
        elif task['time_started'] is not None:
            # Scan HQEs that started at or after this task; the last such
            # entry before an older one is this task's next job.
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['started_on'] is None:
                    continue
                t1 = time_utils.time_string_to_datetime(
                        queue_entry['started_on'])
                t2 = time_utils.time_string_to_datetime(task['time_started'])
                if t1 < t2:
                    break
                next_job_id = queue_entry['job']['id']

        task['next_job_id'] = next_job_id

        # advance hqe_index to just after next_job_id
        if next_job_id is not None:
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['job']['id'] < next_job_id:
                    break
                hqe_index += 1
    686 
    687 
def interleave_entries(queue_entries, special_tasks):
    """
    Merge host queue entries and special tasks into one display list, with
    each job followed by the special tasks that ran after it.

    Both lists should be ordered by descending ID.

    @param queue_entries   Host queue entries as a list of dictionaries.
    @param special_tasks   Special tasks as a list of dictionaries.
    @returns The combined list, newest entries first.
    """
    _compute_next_job_for_tasks(queue_entries, special_tasks)

    # start with all special tasks that've run since the last job
    interleaved_entries = []
    for task in special_tasks:
        if task['next_job_id'] is not None:
            break
        interleaved_entries.append(task)

    # now interleave queue entries with the remaining special tasks
    special_task_index = len(interleaved_entries)
    for queue_entry in queue_entries:
        interleaved_entries.append(queue_entry)
        # add all tasks that ran between this job and the previous one
        for task in special_tasks[special_task_index:]:
            if task['next_job_id'] < queue_entry['job']['id']:
                break
            interleaved_entries.append(task)
            special_task_index += 1

    return interleaved_entries
    713 
    714 
    715 def bucket_hosts_by_shard(host_objs, rpc_hostnames=False):
    716     """Figure out which hosts are on which shards.
    717 
    718     @param host_objs: A list of host objects.
    719     @param rpc_hostnames: If True, the rpc_hostnames of a shard are returned
    720         instead of the 'real' shard hostnames. This only matters for testing
    721         environments.
    722 
    723     @return: A map of shard hostname: list of hosts on the shard.
    724     """
    725     shard_host_map = collections.defaultdict(list)
    726     for host in host_objs:
    727         if host.shard:
    728             shard_name = (host.shard.rpc_hostname() if rpc_hostnames
    729                           else host.shard.hostname)
    730             shard_host_map[shard_name].append(host.hostname)
    731     return shard_host_map
    732 
    733 
    734 def create_job_common(
    735         name,
    736         priority,
    737         control_type,
    738         control_file=None,
    739         hosts=(),
    740         meta_hosts=(),
    741         one_time_hosts=(),
    742         synch_count=None,
    743         is_template=False,
    744         timeout=None,
    745         timeout_mins=None,
    746         max_runtime_mins=None,
    747         run_verify=True,
    748         email_list='',
    749         dependencies=(),
    750         reboot_before=None,
    751         reboot_after=None,
    752         parse_failed_repair=None,
    753         hostless=False,
    754         keyvals=None,
    755         drone_set=None,
    756         parent_job_id=None,
    757         test_retry=0,
    758         run_reset=True,
    759         require_ssp=None):
    760     #pylint: disable-msg=C0111
    761     """
    762     Common code between creating "standard" jobs and creating parameterized jobs
    763     """
    764     # input validation
    765     host_args_passed = any((hosts, meta_hosts, one_time_hosts))
    766     if hostless:
    767         if host_args_passed:
    768             raise model_logic.ValidationError({
    769                     'hostless': 'Hostless jobs cannot include any hosts!'})
    770         if control_type != control_data.CONTROL_TYPE_NAMES.SERVER:
    771             raise model_logic.ValidationError({
    772                     'control_type': 'Hostless jobs cannot use client-side '
    773                                     'control files'})
    774     elif not host_args_passed:
    775         raise model_logic.ValidationError({
    776             'arguments' : "For host jobs, you must pass at least one of"
    777                           " 'hosts', 'meta_hosts', 'one_time_hosts'."
    778             })
    779     label_objects = list(models.Label.objects.filter(name__in=meta_hosts))
    780 
    781     # convert hostnames & meta hosts to host/label objects
    782     host_objects = models.Host.smart_get_bulk(hosts)
    783     _validate_host_job_sharding(host_objects)
    784     for host in one_time_hosts:
    785         this_host = models.Host.create_one_time_host(host)
    786         host_objects.append(this_host)
    787 
    788     metahost_objects = []
    789     meta_host_labels_by_name = {label.name: label for label in label_objects}
    790     for label_name in meta_hosts:
    791         if label_name in meta_host_labels_by_name:
    792             metahost_objects.append(meta_host_labels_by_name[label_name])
    793         else:
    794             raise model_logic.ValidationError(
    795                 {'meta_hosts' : 'Label "%s" not found' % label_name})
    796 
    797     options = dict(name=name,
    798                    priority=priority,
    799                    control_file=control_file,
    800                    control_type=control_type,
    801                    is_template=is_template,
    802                    timeout=timeout,
    803                    timeout_mins=timeout_mins,
    804                    max_runtime_mins=max_runtime_mins,
    805                    synch_count=synch_count,
    806                    run_verify=run_verify,
    807                    email_list=email_list,
    808                    dependencies=dependencies,
    809                    reboot_before=reboot_before,
    810                    reboot_after=reboot_after,
    811                    parse_failed_repair=parse_failed_repair,
    812                    keyvals=keyvals,
    813                    drone_set=drone_set,
    814                    parent_job_id=parent_job_id,
    815                    test_retry=test_retry,
    816                    run_reset=run_reset,
    817                    require_ssp=require_ssp)
    818 
    819     return create_new_job(owner=models.User.current_user().login,
    820                           options=options,
    821                           host_objects=host_objects,
    822                           metahost_objects=metahost_objects)
    823 
    824 
    825 def _validate_host_job_sharding(host_objects):
    826     """Check that the hosts obey job sharding rules."""
    827     if not (server_utils.is_shard()
    828             or _allowed_hosts_for_master_job(host_objects)):
    829         shard_host_map = bucket_hosts_by_shard(host_objects)
    830         raise ValueError(
    831                 'The following hosts are on shard(s), please create '
    832                 'seperate jobs for hosts on each shard: %s ' %
    833                 shard_host_map)
    834 
    835 
    836 def _allowed_hosts_for_master_job(host_objects):
    837     """Check that the hosts are allowed for a job on master."""
    838     # We disallow the following jobs on master:
    839     #   num_shards > 1: this is a job spanning across multiple shards.
    840     #   num_shards == 1 but number of hosts on shard is less
    841     #   than total number of hosts: this is a job that spans across
    842     #   one shard and the master.
    843     shard_host_map = bucket_hosts_by_shard(host_objects)
    844     num_shards = len(shard_host_map)
    845     if num_shards > 1:
    846         return False
    847     if num_shards == 1:
    848         hosts_on_shard = shard_host_map.values()[0]
    849         assert len(hosts_on_shard) <= len(host_objects)
    850         return len(hosts_on_shard) == len(host_objects)
    851     else:
    852         return True
    853 
    854 
    855 def encode_ascii(control_file):
    856     """Force a control file to only contain ascii characters.
    857 
    858     @param control_file: Control file to encode.
    859 
    860     @returns the control file in an ascii encoding.
    861 
    862     @raises error.ControlFileMalformed: if encoding fails.
    863     """
    864     try:
    865         return control_file.encode('ascii')
    866     except UnicodeDecodeError as e:
    867         raise error.ControlFileMalformed(str(e))
    868 
    869 
    870 def get_wmatrix_url():
    871     """Get wmatrix url from config file.
    872 
    873     @returns the wmatrix url or an empty string.
    874     """
    875     return global_config.global_config.get_config_value('AUTOTEST_WEB',
    876                                                         'wmatrix_url',
    877                                                         default='')
    878 
    879 
    880 def inject_times_to_filter(start_time_key=None, end_time_key=None,
    881                          start_time_value=None, end_time_value=None,
    882                          **filter_data):
    883     """Inject the key value pairs of start and end time if provided.
    884 
    885     @param start_time_key: A string represents the filter key of start_time.
    886     @param end_time_key: A string represents the filter key of end_time.
    887     @param start_time_value: Start_time value.
    888     @param end_time_value: End_time value.
    889 
    890     @returns the injected filter_data.
    891     """
    892     if start_time_value:
    893         filter_data[start_time_key] = start_time_value
    894     if end_time_value:
    895         filter_data[end_time_key] = end_time_value
    896     return filter_data
    897 
    898 
    899 def inject_times_to_hqe_special_tasks_filters(filter_data_common,
    900                                               start_time, end_time):
    901     """Inject start and end time to hqe and special tasks filters.
    902 
    903     @param filter_data_common: Common filter for hqe and special tasks.
    904     @param start_time_key: A string represents the filter key of start_time.
    905     @param end_time_key: A string represents the filter key of end_time.
    906 
    907     @returns a pair of hqe and special tasks filters.
    908     """
    909     filter_data_special_tasks = filter_data_common.copy()
    910     return (inject_times_to_filter('started_on__gte', 'started_on__lte',
    911                                    start_time, end_time, **filter_data_common),
    912            inject_times_to_filter('time_started__gte', 'time_started__lte',
    913                                   start_time, end_time,
    914                                   **filter_data_special_tasks))
    915 
    916 
    917 def retrieve_shard(shard_hostname):
    918     """
    919     Retrieves the shard with the given hostname from the database.
    920 
    921     @param shard_hostname: Hostname of the shard to retrieve
    922 
    923     @raises models.Shard.DoesNotExist, if no shard with this hostname was found.
    924 
    925     @returns: Shard object
    926     """
    927     return models.Shard.smart_get(shard_hostname)
    928 
    929 
    930 def find_records_for_shard(shard, known_job_ids, known_host_ids):
    931     """Find records that should be sent to a shard.
    932 
    933     @param shard: Shard to find records for.
    934     @param known_job_ids: List of ids of jobs the shard already has.
    935     @param known_host_ids: List of ids of hosts the shard already has.
    936 
    937     @returns: Tuple of lists:
    938               (hosts, jobs, suite_job_keyvals, invalid_host_ids)
    939     """
    940     hosts, invalid_host_ids = models.Host.assign_to_shard(
    941             shard, known_host_ids)
    942     jobs = models.Job.assign_to_shard(shard, known_job_ids)
    943     parent_job_ids = [job.parent_job_id for job in jobs]
    944     suite_job_keyvals = models.JobKeyval.objects.filter(
    945             job_id__in=parent_job_ids)
    946     return hosts, jobs, suite_job_keyvals, invalid_host_ids
    947 
    948 
    949 def _persist_records_with_type_sent_from_shard(
    950     shard, records, record_type, *args, **kwargs):
    951     """
    952     Handle records of a specified type that were sent to the shard master.
    953 
    954     @param shard: The shard the records were sent from.
    955     @param records: The records sent in their serialized format.
    956     @param record_type: Type of the objects represented by records.
    957     @param args: Additional arguments that will be passed on to the sanity
    958                  checks.
    959     @param kwargs: Additional arguments that will be passed on to the sanity
    960                   checks.
    961 
    962     @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail.
    963 
    964     @returns: List of primary keys of the processed records.
    965     """
    966     pks = []
    967     for serialized_record in records:
    968         pk = serialized_record['id']
    969         try:
    970             current_record = record_type.objects.get(pk=pk)
    971         except record_type.DoesNotExist:
    972             raise error.UnallowedRecordsSentToMaster(
    973                 'Object with pk %s of type %s does not exist on master.' % (
    974                     pk, record_type))
    975 
    976         try:
    977             current_record.sanity_check_update_from_shard(
    978                 shard, serialized_record, *args, **kwargs)
    979         except error.IgnorableUnallowedRecordsSentToMaster:
    980             # An illegal record change was attempted, but it was of a non-fatal
    981             # variety. Silently skip this record.
    982             pass
    983         else:
    984             current_record.update_from_serialized(serialized_record)
    985             pks.append(pk)
    986 
    987     return pks
    988 
    989 
    990 def persist_records_sent_from_shard(shard, jobs, hqes):
    991     """
    992     Sanity checking then saving serialized records sent to master from shard.
    993 
    994     During heartbeats shards upload jobs and hostqueuentries. This performs
    995     some sanity checks on these and then updates the existing records for those
    996     entries with the updated ones from the heartbeat.
    997 
    998     The sanity checks include:
    999     - Checking if the objects sent already exist on the master.
   1000     - Checking if the objects sent were assigned to this shard.
   1001     - hostqueueentries must be sent together with their jobs.
   1002 
   1003     @param shard: The shard the records were sent from.
   1004     @param jobs: The jobs the shard sent.
   1005     @param hqes: The hostqueuentries the shart sent.
   1006 
   1007     @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail.
   1008     """
   1009     job_ids_persisted = _persist_records_with_type_sent_from_shard(
   1010             shard, jobs, models.Job)
   1011     _persist_records_with_type_sent_from_shard(
   1012             shard, hqes, models.HostQueueEntry,
   1013             job_ids_sent=job_ids_persisted)
   1014 
   1015 
   1016 def forward_single_host_rpc_to_shard(func):
   1017     """This decorator forwards rpc calls that modify a host to a shard.
   1018 
   1019     If a host is assigned to a shard, rpcs that change his attributes should be
   1020     forwarded to the shard.
   1021 
   1022     This assumes the first argument of the function represents a host id.
   1023 
   1024     @param func: The function to decorate
   1025 
   1026     @returns: The function to replace func with.
   1027     """
   1028     def replacement(**kwargs):
   1029         # Only keyword arguments can be accepted here, as we need the argument
   1030         # names to send the rpc. serviceHandler always provides arguments with
   1031         # their keywords, so this is not a problem.
   1032 
   1033         # A host record (identified by kwargs['id']) can be deleted in
   1034         # func(). Therefore, we should save the data that can be needed later
   1035         # before func() is called.
   1036         shard_hostname = None
   1037         host = models.Host.smart_get(kwargs['id'])
   1038         if host and host.shard:
   1039             shard_hostname = host.shard.rpc_hostname()
   1040         ret = func(**kwargs)
   1041         if shard_hostname and not server_utils.is_shard():
   1042             run_rpc_on_multiple_hostnames(func.func_name,
   1043                                           [shard_hostname],
   1044                                           **kwargs)
   1045         return ret
   1046 
   1047     return replacement
   1048 
   1049 
def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs):
    """Fanout the given rpc to shards of given hosts.

    @param host_objs: Host objects for the rpc.
    @param rpc_name: The name of the rpc.
    @param include_hostnames: If True, include the hostnames in the kwargs.
        Hostnames are not always necessary, this functions is designed to
        send rpcs to the shard a host is on, the rpcs themselves could be
        related to labels, acls etc.
    @param kwargs: The kwargs for the rpc.
    """
    # Figure out which hosts are on which shards.
    shard_host_map = bucket_hosts_by_shard(
            host_objs, rpc_hostnames=True)

    # Execute the rpc against the appropriate shards.
    for shard, hostnames in shard_host_map.iteritems():
        if include_hostnames:
            # NOTE(review): 'hosts' is overwritten for every shard, so a
            # caller-supplied kwargs['hosts'] does not survive this loop.
            kwargs['hosts'] = hostnames
        try:
            run_rpc_on_multiple_hostnames(rpc_name, [shard], **kwargs)
        except:
            # Wrap any failure in an RPCException that names the shard.
            ei = sys.exc_info()
            new_exc = error.RPCException('RPC %s failed on shard %s due to '
                    '%s: %s' % (rpc_name, shard, ei[0].__name__, ei[1]))
            # Python 2 three-expression raise: re-raise the wrapped exception
            # while preserving the original traceback.
            raise new_exc.__class__, new_exc, ei[2]
   1076 
   1077 
   1078 def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs):
   1079     """Runs an rpc to multiple AFEs
   1080 
   1081     This is i.e. used to propagate changes made to hosts after they are assigned
   1082     to a shard.
   1083 
   1084     @param rpc_call: Name of the rpc endpoint to call.
   1085     @param shard_hostnames: List of hostnames to run the rpcs on.
   1086     @param **kwargs: Keyword arguments to pass in the rpcs.
   1087     """
   1088     # Make sure this function is not called on shards but only on master.
   1089     assert not server_utils.is_shard()
   1090     for shard_hostname in shard_hostnames:
   1091         afe = frontend_wrappers.RetryingAFE(server=shard_hostname,
   1092                                             user=thread_local.get_user())
   1093         afe.run(rpc_call, **kwargs)
   1094 
   1095 
   1096 def get_label(name):
   1097     """Gets a label object using a given name.
   1098 
   1099     @param name: Label name.
   1100     @raises model.Label.DoesNotExist: when there is no label matching
   1101                                       the given name.
   1102     @return: a label object matching the given name.
   1103     """
   1104     try:
   1105         label = models.Label.smart_get(name)
   1106     except models.Label.DoesNotExist:
   1107         return None
   1108     return label
   1109 
   1110 
   1111 # TODO: hide the following rpcs under is_moblab
   1112 def moblab_only(func):
   1113     """Ensure moblab specific functions only run on Moblab devices."""
   1114     def verify(*args, **kwargs):
   1115         if not server_utils.is_moblab():
   1116             raise error.RPCException('RPC: %s can only run on Moblab Systems!',
   1117                                      func.__name__)
   1118         return func(*args, **kwargs)
   1119     return verify
   1120 
   1121 
   1122 def route_rpc_to_master(func):
   1123     """Route RPC to master AFE.
   1124 
   1125     When a shard receives an RPC decorated by this, the RPC is just
   1126     forwarded to the master.
   1127     When the master gets the RPC, the RPC function is executed.
   1128 
   1129     @param func: An RPC function to decorate
   1130 
   1131     @returns: A function replacing the RPC func.
   1132     """
   1133     argspec = inspect.getargspec(func)
   1134     if argspec.varargs is not None:
   1135         raise Exception('RPC function must not have *args.')
   1136 
   1137     @wraps(func)
   1138     def replacement(*args, **kwargs):
   1139         """We need special handling when decorating an RPC that can be called
   1140         directly using positional arguments.
   1141 
   1142         One example is rpc_interface.create_job().
   1143         rpc_interface.create_job_page_handler() calls the function using both
   1144         positional and keyword arguments.  Since frontend.RpcClient.run()
   1145         takes only keyword arguments for an RPC, positional arguments of the
   1146         RPC function need to be transformed into keyword arguments.
   1147         """
   1148         kwargs = _convert_to_kwargs_only(func, args, kwargs)
   1149         if server_utils.is_shard():
   1150             afe = frontend_wrappers.RetryingAFE(
   1151                     server=server_utils.get_global_afe_hostname(),
   1152                     user=thread_local.get_user())
   1153             return afe.run(func.func_name, **kwargs)
   1154         return func(**kwargs)
   1155 
   1156     return replacement
   1157 
   1158 
   1159 def _convert_to_kwargs_only(func, args, kwargs):
   1160     """Convert a function call's arguments to a kwargs dict.
   1161 
   1162     This is best illustrated with an example.  Given:
   1163 
   1164     def foo(a, b, **kwargs):
   1165         pass
   1166     _to_kwargs(foo, (1, 2), {'c': 3})  # corresponding to foo(1, 2, c=3)
   1167 
   1168         foo(**kwargs)
   1169 
   1170     @param func: function whose signature to use
   1171     @param args: positional arguments of call
   1172     @param kwargs: keyword arguments of call
   1173 
   1174     @returns: kwargs dict
   1175     """
   1176     argspec = inspect.getargspec(func)
   1177     # callargs looks like {'a': 1, 'b': 2, 'kwargs': {'c': 3}}
   1178     callargs = inspect.getcallargs(func, *args, **kwargs)
   1179     if argspec.keywords is None:
   1180         kwargs = {}
   1181     else:
   1182         kwargs = callargs.pop(argspec.keywords)
   1183     kwargs.update(callargs)
   1184     return kwargs
   1185 
   1186 
   1187 def get_sample_dut(board, pool):
   1188     """Get a dut with the given board and pool.
   1189 
   1190     This method is used to help to locate a dut with the given board and pool.
   1191     The dut then can be used to identify a devserver in the same subnet.
   1192 
   1193     @param board: Name of the board.
   1194     @param pool: Name of the pool.
   1195 
   1196     @return: Name of a dut with the given board and pool.
   1197     """
   1198     if not (dev_server.PREFER_LOCAL_DEVSERVER and pool and board):
   1199         return None
   1200     hosts = list(get_host_query(
   1201         multiple_labels=('pool:%s' % pool, 'board:%s' % board),
   1202         exclude_only_if_needed_labels=False,
   1203         valid_only=True,
   1204         filter_data={},
   1205     ))
   1206     if not hosts:
   1207         return None
   1208     else:
   1209         return hosts[0].hostname
   1210