Home | History | Annotate | Download | only in hosts
      1 # pylint: disable-msg=C0111
      2 import os, time, signal, socket, re, fnmatch, logging, threading
      3 import paramiko
      4 
      5 from autotest_lib.client.common_lib import utils, error, global_config
      6 from autotest_lib.server import subcommand
      7 from autotest_lib.server.hosts import abstract_ssh
      8 
      9 
     10 class ParamikoHost(abstract_ssh.AbstractSSHHost):
     11     KEEPALIVE_TIMEOUT_SECONDS = 30
     12     CONNECT_TIMEOUT_SECONDS = 30
     13     CONNECT_TIMEOUT_RETRIES = 3
     14     BUFFSIZE = 2**16
     15 
     16     def _initialize(self, hostname, *args, **dargs):
     17         super(ParamikoHost, self)._initialize(hostname=hostname, *args, **dargs)
     18 
     19         # paramiko is very noisy, tone down the logging
     20         paramiko.util.log_to_file("/dev/null", paramiko.util.ERROR)
     21 
     22         self.keys = self.get_user_keys(hostname)
     23         self.pid = None
     24 
     25 
     26     @staticmethod
     27     def _load_key(path):
     28         """Given a path to a private key file, load the appropriate keyfile.
     29 
     30         Tries to load the file as both an RSAKey and a DSAKey. If the file
     31         cannot be loaded as either type, returns None."""
     32         try:
     33             return paramiko.DSSKey.from_private_key_file(path)
     34         except paramiko.SSHException:
     35             try:
     36                 return paramiko.RSAKey.from_private_key_file(path)
     37             except paramiko.SSHException:
     38                 return None
     39 
     40 
     41     @staticmethod
     42     def _parse_config_line(line):
     43         """Given an ssh config line, return a (key, value) tuple for the
     44         config value listed in the line, or (None, None)"""
     45         match = re.match(r"\s*(\w+)\s*=?(.*)\n", line)
     46         if match:
     47             return match.groups()
     48         else:
     49             return None, None
     50 
     51 
     52     @staticmethod
     53     def get_user_keys(hostname):
     54         """Returns a mapping of path -> paramiko.PKey entries available for
     55         this user. Keys are found in the default locations (~/.ssh/id_[d|r]sa)
     56         as well as any IdentityFile entries in the standard ssh config files.
     57         """
     58         raw_identity_files = ["~/.ssh/id_dsa", "~/.ssh/id_rsa"]
     59         for config_path in ("/etc/ssh/ssh_config", "~/.ssh/config"):
     60             config_path = os.path.expanduser(config_path)
     61             if not os.path.exists(config_path):
     62                 continue
     63             host_pattern = "*"
     64             config_lines = open(config_path).readlines()
     65             for line in config_lines:
     66                 key, value = ParamikoHost._parse_config_line(line)
     67                 if key == "Host":
     68                     host_pattern = value
     69                 elif (key == "IdentityFile"
     70                       and fnmatch.fnmatch(hostname, host_pattern)):
     71                     raw_identity_files.append(value)
     72 
     73         # drop any files that use percent-escapes; we don't support them
     74         identity_files = []
     75         UNSUPPORTED_ESCAPES = ["%d", "%u", "%l", "%h", "%r"]
     76         for path in raw_identity_files:
     77             # skip this path if it uses % escapes
     78             if sum((escape in path) for escape in UNSUPPORTED_ESCAPES):
     79                 continue
     80             path = os.path.expanduser(path)
     81             if os.path.exists(path):
     82                 identity_files.append(path)
     83 
     84         # load up all the keys that we can and return them
     85         user_keys = {}
     86         for path in identity_files:
     87             key = ParamikoHost._load_key(path)
     88             if key:
     89                 user_keys[path] = key
     90 
     91         # load up all the ssh agent keys
     92         use_sshagent = global_config.global_config.get_config_value(
     93             'AUTOSERV', 'use_sshagent_with_paramiko', type=bool)
     94         if use_sshagent:
     95             ssh_agent = paramiko.Agent()
     96             for i, key in enumerate(ssh_agent.get_keys()):
     97                 user_keys['agent-key-%d' % i] = key
     98 
     99         return user_keys
    100 
    101 
    102     def _check_transport_error(self, transport):
    103         error = transport.get_exception()
    104         if error:
    105             transport.close()
    106             raise error
    107 
    108 
    109     def _connect_socket(self):
    110         """Return a socket for use in instantiating a paramiko transport. Does
    111         not have to be a literal socket, it can be anything that the
    112         paramiko.Transport constructor accepts."""
    113         return self.hostname, self.port
    114 
    115 
    116     def _connect_transport(self, pkey):
    117         for _ in xrange(self.CONNECT_TIMEOUT_RETRIES):
    118             transport = paramiko.Transport(self._connect_socket())
    119             completed = threading.Event()
    120             transport.start_client(completed)
    121             completed.wait(self.CONNECT_TIMEOUT_SECONDS)
    122             if completed.isSet():
    123                 self._check_transport_error(transport)
    124                 completed.clear()
    125                 transport.auth_publickey(self.user, pkey, completed)
    126                 completed.wait(self.CONNECT_TIMEOUT_SECONDS)
    127                 if completed.isSet():
    128                     self._check_transport_error(transport)
    129                     if not transport.is_authenticated():
    130                         transport.close()
    131                         raise paramiko.AuthenticationException()
    132                     return transport
    133             logging.warning("SSH negotiation (%s:%d) timed out, retrying",
    134                          self.hostname, self.port)
    135             # HACK: we can't count on transport.join not hanging now, either
    136             transport.join = lambda: None
    137             transport.close()
    138         logging.error("SSH negotation (%s:%d) has timed out %s times, "
    139                       "giving up", self.hostname, self.port,
    140                       self.CONNECT_TIMEOUT_RETRIES)
    141         raise error.AutoservSSHTimeout("SSH negotiation timed out")
    142 
    143 
    144     def _init_transport(self):
    145         for path, key in self.keys.iteritems():
    146             try:
    147                 logging.debug("Connecting with %s", path)
    148                 transport = self._connect_transport(key)
    149                 transport.set_keepalive(self.KEEPALIVE_TIMEOUT_SECONDS)
    150                 self.transport = transport
    151                 self.pid = os.getpid()
    152                 return
    153             except paramiko.AuthenticationException:
    154                 logging.debug("Authentication failure")
    155         else:
    156             raise error.AutoservSshPermissionDeniedError(
    157                 "Permission denied using all keys available to ParamikoHost",
    158                 utils.CmdResult())
    159 
    160 
    161     def _open_channel(self, timeout):
    162         start_time = time.time()
    163         if os.getpid() != self.pid:
    164             if self.pid is not None:
    165                 # HACK: paramiko tries to join() on its worker thread
    166                 # and this just hangs on linux after a fork()
    167                 self.transport.join = lambda: None
    168                 self.transport.atfork()
    169                 join_hook = lambda cmd: self._close_transport()
    170                 subcommand.subcommand.register_join_hook(join_hook)
    171                 logging.debug("Reopening SSH connection after a process fork")
    172             self._init_transport()
    173 
    174         channel = None
    175         try:
    176             channel = self.transport.open_session()
    177         except (socket.error, paramiko.SSHException, EOFError), e:
    178             logging.warning("Exception occured while opening session: %s", e)
    179             if time.time() - start_time >= timeout:
    180                 raise error.AutoservSSHTimeout("ssh failed: %s" % e)
    181 
    182         if not channel:
    183             # we couldn't get a channel; re-initing transport should fix that
    184             try:
    185                 self.transport.close()
    186             except Exception, e:
    187                 logging.debug("paramiko.Transport.close failed with %s", e)
    188             self._init_transport()
    189             return self.transport.open_session()
    190         else:
    191             return channel
    192 
    193 
    194     def _close_transport(self):
    195         if os.getpid() == self.pid:
    196             self.transport.close()
    197 
    198 
    199     def close(self):
    200         super(ParamikoHost, self).close()
    201         self._close_transport()
    202 
    203 
    204     @classmethod
    205     def _exhaust_stream(cls, tee, output_list, recvfunc):
    206         while True:
    207             try:
    208                 output_list.append(recvfunc(cls.BUFFSIZE))
    209             except socket.timeout:
    210                 return
    211             tee.write(output_list[-1])
    212             if not output_list[-1]:
    213                 return
    214 
    215 
    216     @classmethod
    217     def __send_stdin(cls, channel, stdin):
    218         if not stdin or not channel.send_ready():
    219             # nothing more to send or just no space to send now
    220             return
    221 
    222         sent = channel.send(stdin[:cls.BUFFSIZE])
    223         if not sent:
    224             logging.warning('Could not send a single stdin byte.')
    225         else:
    226             stdin = stdin[sent:]
    227             if not stdin:
    228                 # no more stdin input, close output direction
    229                 channel.shutdown_write()
    230         return stdin
    231 
    232 
    233     def run(self, command, timeout=3600, ignore_status=False,
    234             stdout_tee=utils.TEE_TO_LOGS, stderr_tee=utils.TEE_TO_LOGS,
    235             connect_timeout=30, stdin=None, verbose=True, args=(),
    236             ignore_timeout=False):
    237         """
    238         Run a command on the remote host.
    239         @see common_lib.hosts.host.run()
    240 
    241         @param connect_timeout: connection timeout (in seconds)
    242         @param options: string with additional ssh command options
    243         @param verbose: log the commands
    244         @param ignore_timeout: bool True command timeouts should be
    245                                ignored.  Will return None on command timeout.
    246 
    247         @raises AutoservRunError: if the command failed
    248         @raises AutoservSSHTimeout: ssh connection has timed out
    249         """
    250 
    251         stdout = utils.get_stream_tee_file(
    252                 stdout_tee, utils.DEFAULT_STDOUT_LEVEL,
    253                 prefix=utils.STDOUT_PREFIX)
    254         stderr = utils.get_stream_tee_file(
    255                 stderr_tee, utils.get_stderr_level(ignore_status),
    256                 prefix=utils.STDERR_PREFIX)
    257 
    258         for arg in args:
    259             command += ' "%s"' % utils.sh_escape(arg)
    260 
    261         if verbose:
    262             logging.debug("Running (ssh-paramiko) '%s'", command)
    263 
    264         # start up the command
    265         start_time = time.time()
    266         try:
    267             channel = self._open_channel(timeout)
    268             channel.exec_command(command)
    269         except (socket.error, paramiko.SSHException, EOFError), e:
    270             # This has to match the string from paramiko *exactly*.
    271             if str(e) != 'Channel closed.':
    272                 raise error.AutoservSSHTimeout("ssh failed: %s" % e)
    273 
    274         # pull in all the stdout, stderr until the command terminates
    275         raw_stdout, raw_stderr = [], []
    276         timed_out = False
    277         while not channel.exit_status_ready():
    278             if channel.recv_ready():
    279                 raw_stdout.append(channel.recv(self.BUFFSIZE))
    280                 stdout.write(raw_stdout[-1])
    281             if channel.recv_stderr_ready():
    282                 raw_stderr.append(channel.recv_stderr(self.BUFFSIZE))
    283                 stderr.write(raw_stderr[-1])
    284             if timeout and time.time() - start_time > timeout:
    285                 timed_out = True
    286                 break
    287             stdin = self.__send_stdin(channel, stdin)
    288             time.sleep(1)
    289 
    290         if timed_out:
    291             exit_status = -signal.SIGTERM
    292         else:
    293             exit_status = channel.recv_exit_status()
    294         channel.settimeout(10)
    295         self._exhaust_stream(stdout, raw_stdout, channel.recv)
    296         self._exhaust_stream(stderr, raw_stderr, channel.recv_stderr)
    297         channel.close()
    298         duration = time.time() - start_time
    299 
    300         # create the appropriate results
    301         stdout = "".join(raw_stdout)
    302         stderr = "".join(raw_stderr)
    303         result = utils.CmdResult(command, stdout, stderr, exit_status,
    304                                  duration)
    305         if exit_status == -signal.SIGHUP:
    306             msg = "ssh connection unexpectedly terminated"
    307             raise error.AutoservRunError(msg, result)
    308         if timed_out:
    309             logging.warning('Paramiko command timed out after %s sec: %s', timeout,
    310                          command)
    311             if not ignore_timeout:
    312                 raise error.AutoservRunError("command timed out", result)
    313         if not ignore_status and exit_status:
    314             raise error.AutoservRunError(command, result)
    315         return result
    316