# pylint: disable-msg=C0111
import os, time, signal, socket, re, fnmatch, logging, threading
import paramiko

from autotest_lib.client.common_lib import utils, error, global_config
from autotest_lib.server import subcommand
from autotest_lib.server.hosts import abstract_ssh


class ParamikoHost(abstract_ssh.AbstractSSHHost):
    """An SSH host whose connections are made with the paramiko library.

    Commands run over a paramiko transport authenticated with the user's
    private keys (and, when enabled in the global config, ssh-agent keys).
    The transport is re-established when the object is used from a process
    other than the one that opened it (e.g. after a fork).
    """

    KEEPALIVE_TIMEOUT_SECONDS = 30
    CONNECT_TIMEOUT_SECONDS = 30
    CONNECT_TIMEOUT_RETRIES = 3
    BUFFSIZE = 2**16

    def _initialize(self, hostname, *args, **dargs):
        super(ParamikoHost, self)._initialize(hostname=hostname, *args, **dargs)

        # paramiko is very noisy, tone down the logging
        paramiko.util.log_to_file("/dev/null", paramiko.util.ERROR)

        # path -> paramiko.PKey; each key is tried in turn by _init_transport
        self.keys = self.get_user_keys(hostname)
        # pid of the process that opened self.transport; None until connected
        self.pid = None


    @staticmethod
    def _load_key(path):
        """Given a path to a private key file, load the appropriate keyfile.

        Tries to load the file as a DSSKey first, then as an RSAKey.

        @param path: path to a private key file.

        @return a paramiko.PKey, or None if the file cannot be loaded as
                either type.
        """
        for key_class in (paramiko.DSSKey, paramiko.RSAKey):
            try:
                return key_class.from_private_key_file(path)
            except paramiko.SSHException:
                pass
        return None


    @staticmethod
    def _parse_config_line(line):
        """Given an ssh config line, return a (key, value) tuple for the
        config value listed in the line, or (None, None).

        Both "Keyword value" and "Keyword=value" forms are accepted.
        Whitespace around the value (including any trailing newline) is
        stripped, so "IdentityFile = ~/key" yields "~/key" and a final line
        without a newline is still parsed. Blank and comment lines yield
        (None, None).
        """
        match = re.match(r"\s*(\w+)\s*=?\s*(.*?)\s*$", line)
        if match:
            return match.groups()
        return None, None


    @staticmethod
    def _host_matches(hostname, host_patterns):
        """Return True if hostname matches an ssh_config Host pattern list.

        host_patterns is the whitespace-separated value of a Host line.
        A pattern prefixed with "!" is a negation: if it matches, the whole
        entry does not apply, regardless of the other patterns.
        """
        matched = False
        for pattern in host_patterns.split():
            if pattern.startswith("!"):
                if fnmatch.fnmatch(hostname, pattern[1:]):
                    return False
            elif fnmatch.fnmatch(hostname, pattern):
                matched = True
        return matched


    @staticmethod
    def get_user_keys(hostname):
        """Returns a mapping of path -> paramiko.PKey entries available for
        this user. Keys are found in the default locations (~/.ssh/id_[d|r]sa)
        as well as any IdentityFile entries in the standard ssh config files
        whose Host section matches hostname. When the AUTOSERV
        use_sshagent_with_paramiko config value is true, ssh-agent keys are
        added under the names 'agent-key-<n>'.
        """
        raw_identity_files = ["~/.ssh/id_dsa", "~/.ssh/id_rsa"]
        for config_path in ("/etc/ssh/ssh_config", "~/.ssh/config"):
            config_path = os.path.expanduser(config_path)
            if not os.path.exists(config_path):
                continue
            # directives before the first Host line apply to every host
            host_patterns = "*"
            with open(config_path) as config_file:
                for line in config_file:
                    key, value = ParamikoHost._parse_config_line(line)
                    if key is None:
                        continue
                    # ssh_config keywords are case-insensitive
                    key = key.lower()
                    if key == "host":
                        host_patterns = value
                    elif (key == "identityfile" and
                          ParamikoHost._host_matches(hostname, host_patterns)):
                        raw_identity_files.append(value)

        # drop any files that use percent-escapes; we don't support them
        unsupported_escapes = ("%d", "%u", "%l", "%h", "%r")
        identity_files = []
        for path in raw_identity_files:
            if any(escape in path for escape in unsupported_escapes):
                continue
            path = os.path.expanduser(path)
            if os.path.exists(path):
                identity_files.append(path)

        # load up all the keys that we can and return them
        user_keys = {}
        for path in identity_files:
            key = ParamikoHost._load_key(path)
            if key:
                user_keys[path] = key

        # load up all the ssh agent keys
        use_sshagent = global_config.global_config.get_config_value(
            'AUTOSERV', 'use_sshagent_with_paramiko', type=bool)
        if use_sshagent:
            ssh_agent = paramiko.Agent()
            for i, key in enumerate(ssh_agent.get_keys()):
                user_keys['agent-key-%d' % i] = key

        return user_keys


    def _check_transport_error(self, transport):
        """Close transport and raise the exception it recorded, if any."""
        # named exc, not error, so the module-level `error` isn't shadowed
        exc = transport.get_exception()
        if exc:
            transport.close()
            raise exc


    def _connect_socket(self):
        """Return a socket for use in instantiating a paramiko transport. Does
        not have to be a literal socket, it can be anything that the
        paramiko.Transport constructor accepts."""
        return self.hostname, self.port


    def _connect_transport(self, pkey):
        """Open a transport to the host and authenticate with pkey.

        Negotiation and authentication each wait up to
        CONNECT_TIMEOUT_SECONDS. If either step times out, the attempt is
        abandoned and retried, up to CONNECT_TIMEOUT_RETRIES attempts.

        @param pkey: the paramiko.PKey to authenticate with.

        @return an authenticated paramiko.Transport.

        @raises paramiko.AuthenticationException: if the key was rejected.
        @raises AutoservSSHTimeout: if every attempt timed out.
        Any other exception recorded by the transport is re-raised as is.
        """
        for _ in range(self.CONNECT_TIMEOUT_RETRIES):
            transport = paramiko.Transport(self._connect_socket())
            completed = threading.Event()
            transport.start_client(completed)
            completed.wait(self.CONNECT_TIMEOUT_SECONDS)
            if completed.isSet():
                self._check_transport_error(transport)
                completed.clear()
                transport.auth_publickey(self.user, pkey, completed)
                completed.wait(self.CONNECT_TIMEOUT_SECONDS)
                if completed.isSet():
                    self._check_transport_error(transport)
                    if not transport.is_authenticated():
                        transport.close()
                        raise paramiko.AuthenticationException()
                    return transport
            logging.warning("SSH negotiation (%s:%d) timed out, retrying",
                            self.hostname, self.port)
            # HACK: we can't count on transport.join not hanging now, either
            transport.join = lambda: None
            transport.close()
        logging.error("SSH negotation (%s:%d) has timed out %s times, "
                      "giving up", self.hostname, self.port,
                      self.CONNECT_TIMEOUT_RETRIES)
        raise error.AutoservSSHTimeout("SSH negotiation timed out")


    def _init_transport(self):
        """Connect using the first key in self.keys that authenticates.

        On success, stores the transport in self.transport and records the
        current pid as its owner in self.pid.

        @raises AutoservSshPermissionDeniedError: if no key authenticates.
        """
        for path, key in self.keys.items():
            try:
                logging.debug("Connecting with %s", path)
                transport = self._connect_transport(key)
            except paramiko.AuthenticationException:
                logging.debug("Authentication failure")
                continue
            transport.set_keepalive(self.KEEPALIVE_TIMEOUT_SECONDS)
            self.transport = transport
            self.pid = os.getpid()
            return
        raise error.AutoservSshPermissionDeniedError(
            "Permission denied using all keys available to ParamikoHost",
            utils.CmdResult())


    def _open_channel(self, timeout):
        """Open a new session channel on the SSH transport.

        The transport is (re)initialized first when it is not owned by the
        current process (never connected, or inherited across a fork). If
        opening a session fails, the transport is rebuilt once and the
        session is opened on the new transport.

        @param timeout: seconds; if opening the session fails after at
                least this long since the call began, raise instead of
                reconnecting.

        @return a paramiko channel.

        @raises AutoservSSHTimeout: if opening the session failed after
                timeout seconds had already elapsed.
        """
        start_time = time.time()
        if os.getpid() != self.pid:
            if self.pid is not None:
                # HACK: paramiko tries to join() on its worker thread
                # and this just hangs on linux after a fork()
                self.transport.join = lambda: None
                self.transport.atfork()
                join_hook = lambda cmd: self._close_transport()
                subcommand.subcommand.register_join_hook(join_hook)
                logging.debug("Reopening SSH connection after a process fork")
            self._init_transport()

        channel = None
        try:
            channel = self.transport.open_session()
        except (socket.error, paramiko.SSHException, EOFError) as e:
            logging.warning("Exception occured while opening session: %s", e)
            if time.time() - start_time >= timeout:
                raise error.AutoservSSHTimeout("ssh failed: %s" % e)

        if channel:
            return channel

        # we couldn't get a channel; re-initing transport should fix that
        try:
            self.transport.close()
        except Exception as e:
            logging.debug("paramiko.Transport.close failed with %s", e)
        self._init_transport()
        return self.transport.open_session()


    def _close_transport(self):
        """Close the transport, but only from the process that opened it."""
        if os.getpid() == self.pid:
            self.transport.close()


    def close(self):
        super(ParamikoHost, self).close()
        self._close_transport()


    @classmethod
    def _exhaust_stream(cls, tee, output_list, recvfunc):
        """Drain a channel stream into output_list, echoing each chunk to tee.

        Reads BUFFSIZE-sized chunks with recvfunc until it returns an empty
        chunk (end of stream) or raises socket.timeout.
        """
        while True:
            try:
                output_list.append(recvfunc(cls.BUFFSIZE))
            except socket.timeout:
                return
            tee.write(output_list[-1])
            if not output_list[-1]:
                return


    @classmethod
    def __send_stdin(cls, channel, stdin):
        """Send up to BUFFSIZE bytes of pending stdin over the channel.

        @param channel: the channel running the command.
        @param stdin: data not yet sent (may be None or empty).

        @return the data still left to send. When the last byte has been
                sent, the write side of the channel is shut down.
        """
        if not stdin or not channel.send_ready():
            # nothing more to send, or no space to send now; keep whatever
            # is pending so a later call can send it
            return stdin

        sent = channel.send(stdin[:cls.BUFFSIZE])
        if not sent:
            logging.warning('Could not send a single stdin byte.')
        else:
            stdin = stdin[sent:]
            if not stdin:
                # no more stdin input, close output direction
                channel.shutdown_write()
        return stdin


    def run(self, command, timeout=3600, ignore_status=False,
            stdout_tee=utils.TEE_TO_LOGS, stderr_tee=utils.TEE_TO_LOGS,
            connect_timeout=30, stdin=None, verbose=True, args=(),
            ignore_timeout=False):
        """
        Run a command on the remote host.
        @see common_lib.hosts.host.run()

        @param command: the command line to run remotely.
        @param timeout: seconds, counted from before the channel is opened,
                after which the command is abandoned; a false value disables
                the limit while waiting for the command to finish.
        @param ignore_status: do not raise when the command exits non-zero.
        @param stdout_tee: where to tee remote stdout as it arrives.
        @param stderr_tee: where to tee remote stderr as it arrives.
        @param connect_timeout: connection timeout (in seconds); not used by
                this implementation.
        @param stdin: data fed to the command's stdin, or None.
        @param verbose: log the commands
        @param args: extra arguments, each shell-escaped, double-quoted and
                appended to command.
        @param ignore_timeout: bool True command timeouts should be
                ignored. On timeout the CmdResult is returned, with exit
                status -SIGTERM and whatever output had been collected.

        @return a utils.CmdResult for the command.

        @raises AutoservRunError: if the command failed
        @raises AutoservSSHTimeout: ssh connection has timed out
        """

        stdout = utils.get_stream_tee_file(
            stdout_tee, utils.DEFAULT_STDOUT_LEVEL,
            prefix=utils.STDOUT_PREFIX)
        stderr = utils.get_stream_tee_file(
            stderr_tee, utils.get_stderr_level(ignore_status),
            prefix=utils.STDERR_PREFIX)

        for arg in args:
            command += ' "%s"' % utils.sh_escape(arg)

        if verbose:
            logging.debug("Running (ssh-paramiko) '%s'", command)

        # start up the command
        start_time = time.time()
        channel = None
        try:
            channel = self._open_channel(timeout)
            channel.exec_command(command)
        except (socket.error, paramiko.SSHException, EOFError) as e:
            # A 'Channel closed.' error once the channel exists is tolerated;
            # the exit status is collected from the channel below. Any other
            # error, or one raised before a channel exists, is fatal.
            # This has to match the string from paramiko *exactly*.
            if channel is None or str(e) != 'Channel closed.':
                raise error.AutoservSSHTimeout("ssh failed: %s" % e)

        raw_stdout, raw_stderr = [], []
        timed_out = False
        try:
            # pull in all the stdout, stderr until the command terminates
            while not channel.exit_status_ready():
                made_progress = False
                if channel.recv_ready():
                    raw_stdout.append(channel.recv(self.BUFFSIZE))
                    stdout.write(raw_stdout[-1])
                    made_progress = True
                if channel.recv_stderr_ready():
                    raw_stderr.append(channel.recv_stderr(self.BUFFSIZE))
                    stderr.write(raw_stderr[-1])
                    made_progress = True
                if timeout and time.time() - start_time > timeout:
                    timed_out = True
                    break
                pending = len(stdin) if stdin else 0
                stdin = self.__send_stdin(channel, stdin)
                if (len(stdin) if stdin else 0) < pending:
                    made_progress = True
                # Back off only when idle; sleeping after every chunk would
                # limit each stream to one BUFFSIZE read per second.
                if not made_progress:
                    time.sleep(1)

            if timed_out:
                exit_status = -signal.SIGTERM
            else:
                exit_status = channel.recv_exit_status()
            channel.settimeout(10)
            self._exhaust_stream(stdout, raw_stdout, channel.recv)
            self._exhaust_stream(stderr, raw_stderr, channel.recv_stderr)
        finally:
            # release the channel even if reading its output failed
            channel.close()
        duration = time.time() - start_time

        # create the appropriate results
        result = utils.CmdResult(command, "".join(raw_stdout),
                                 "".join(raw_stderr), exit_status, duration)
        # -SIGHUP == -1, which paramiko's recv_exit_status() returns when
        # the server never sent an exit status
        if exit_status == -signal.SIGHUP:
            msg = "ssh connection unexpectedly terminated"
            raise error.AutoservRunError(msg, result)
        if timed_out:
            logging.warning('Paramiko command timed out after %s sec: %s',
                            timeout, command)
            if not ignore_timeout:
                raise error.AutoservRunError("command timed out", result)
            # the timeout is being ignored: don't fail on the synthetic
            # -SIGTERM status either
            return result
        if not ignore_status and exit_status:
            raise error.AutoservRunError(command, result)
        return result