Home | History | Annotate | Download | only in cli
      1 # Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 import common
      6 import inspect, new, socket, sys
      7 
      8 from autotest_lib.client.bin import utils
      9 from autotest_lib.cli import host, rpc
     10 from autotest_lib.server import hosts
     11 from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
     12 from autotest_lib.client.common_lib import error, host_protections
     13 
     14 
     15 # In order for hosts to work correctly, some of its variables must be setup.
     16 hosts.factory.ssh_user = 'root'
     17 hosts.factory.ssh_pass = ''
     18 hosts.factory.ssh_port = 22
     19 hosts.factory.ssh_verbosity_flag = ''
     20 hosts.factory.ssh_options = ''
     21 
     22 
     23 # pylint: disable=missing-docstring
     24 class site_host(host.host):
     25     pass
     26 
     27 
     28 class site_host_create(site_host, host.host_create):
     29     """
     30     site_host_create subclasses host_create in host.py.
     31     """
     32 
     33     @classmethod
     34     def construct_without_parse(
     35             cls, web_server, hosts, platform=None,
     36             locked=False, lock_reason='', labels=[], acls=[],
     37             protection=host_protections.Protection.NO_PROTECTION):
     38         """Construct an site_host_create object and fill in data from args.
     39 
     40         Do not need to call parse after the construction.
     41 
     42         Return an object of site_host_create ready to execute.
     43 
     44         @param web_server: A string specifies the autotest webserver url.
     45             It is needed to setup comm to make rpc.
     46         @param hosts: A list of hostnames as strings.
     47         @param platform: A string or None.
     48         @param locked: A boolean.
     49         @param lock_reason: A string.
     50         @param labels: A list of labels as strings.
     51         @param acls: A list of acls as strings.
     52         @param protection: An enum defined in host_protections.
     53         """
     54         obj = cls()
     55         obj.web_server = web_server
     56         try:
     57             # Setup stuff needed for afe comm.
     58             obj.afe = rpc.afe_comm(web_server)
     59         except rpc.AuthError, s:
     60             obj.failure(str(s), fatal=True)
     61         obj.hosts = hosts
     62         obj.platform = platform
     63         obj.locked = locked
     64         if locked and lock_reason.strip():
     65             obj.data['lock_reason'] = lock_reason.strip()
     66         obj.labels = labels
     67         obj.acls = acls
     68         if protection:
     69             obj.data['protection'] = protection
     70         # TODO(kevcheng): Update the admin page to take in serials?
     71         obj.serials = None
     72         return obj
     73 
     74 
     75     def _execute_add_one_host(self, host):
     76         # Always add the hosts as locked to avoid the host
     77         # being picked up by the scheduler before it's ACL'ed.
     78         self.data['locked'] = True
     79         if not self.locked:
     80             self.data['lock_reason'] = 'Forced lock on device creation'
     81         self.execute_rpc('add_host', hostname=host,
     82                          status="Ready", **self.data)
     83         # If there are labels avaliable for host, use them.
     84         host_info = self.host_info_map[host]
     85         labels = set(self.labels)
     86         if host_info.labels:
     87             labels.update(host_info.labels)
     88         # Now add the platform label.
     89         # If a platform was not provided and we were able to retrieve it
     90         # from the host, use the retrieved platform.
     91         platform = self.platform if self.platform else host_info.platform
     92         if platform:
     93             labels.add(platform)
     94 
     95         if len(labels):
     96             self.execute_rpc('host_add_labels', id=host, labels=list(labels))
     97 
     98         if self.serials:
     99             afe = frontend_wrappers.RetryingAFE(timeout_min=5, delay_sec=10)
    100             afe.set_host_attribute('serials', ','.join(self.serials),
    101                                    hostname=host)
    102 
    103 
    104     def execute(self):
    105         # Check to see if the platform or any other labels can be grabbed from
    106         # the hosts.
    107         self.host_info_map = {}
    108         for host in self.hosts:
    109             try:
    110                 if utils.ping(host, tries=1, deadline=1) == 0:
    111                     if self.serials and len(self.serials) > 1:
    112                         host_dut = hosts.create_testbed(
    113                                 host, adb_serials=self.serials)
    114                     else:
    115                         adb_serial = None
    116                         if self.serials:
    117                             adb_serial = self.serials[0]
    118                         host_dut = hosts.create_host(host,
    119                                                      adb_serial=adb_serial)
    120                     host_info = host_information(host,
    121                                                  host_dut.get_platform(),
    122                                                  host_dut.get_labels())
    123                 else:
    124                     # Can't ping the host, use default information.
    125                     host_info = host_information(host, None, [])
    126             except (socket.gaierror, error.AutoservRunError,
    127                     error.AutoservSSHTimeout):
    128                 # We may be adding a host that does not exist yet or we can't
    129                 # reach due to hostname/address issues or if the host is down.
    130                 host_info = host_information(host, None, [])
    131             self.host_info_map[host] = host_info
    132         # We need to check if these labels & ACLs exist,
    133         # and create them if not.
    134         if self.platform:
    135             self.check_and_create_items('get_labels', 'add_label',
    136                                         [self.platform],
    137                                         platform=True)
    138         else:
    139             # No platform was provided so check and create the platform label
    140             # for each host.
    141             platforms = []
    142             for host_info in self.host_info_map.values():
    143                 if host_info.platform and host_info.platform not in platforms:
    144                     platforms.append(host_info.platform)
    145             if platforms:
    146                 self.check_and_create_items('get_labels', 'add_label',
    147                                             platforms,
    148                                             platform=True)
    149         labels_to_check_and_create = self.labels[:]
    150         for host_info in self.host_info_map.values():
    151             labels_to_check_and_create = (host_info.labels +
    152                                           labels_to_check_and_create)
    153         if labels_to_check_and_create:
    154             self.check_and_create_items('get_labels', 'add_label',
    155                                         labels_to_check_and_create,
    156                                         platform=False)
    157 
    158         if self.acls:
    159             self.check_and_create_items('get_acl_groups',
    160                                         'add_acl_group',
    161                                         self.acls)
    162 
    163         return self._execute_add_hosts()
    164 
    165 
    166 class host_information(object):
    167     """Store host information so we don't have to keep looking it up."""
    168 
    169 
    170     def __init__(self, hostname, platform, labels):
    171         self.hostname = hostname
    172         self.platform = platform
    173         self.labels = labels
    174 
    175 
    176 # Any classes we don't override in host should be copied automatically
    177 for cls in [getattr(host, n) for n in dir(host) if not n.startswith("_")]:
    178     if not inspect.isclass(cls):
    179         continue
    180     cls_name = cls.__name__
    181     site_cls_name = 'site_' + cls_name
    182     if hasattr(sys.modules[__name__], site_cls_name):
    183         continue
    184     bases = (site_host, cls)
    185     members = {'__doc__': cls.__doc__}
    186     site_cls = new.classobj(site_cls_name, bases, members)
    187     setattr(sys.modules[__name__], site_cls_name, site_cls)
    188