Home | History | Annotate | Download | only in ssh
      1 # Copyright 2016 - The Android Open Source Project
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 
     15 import collections
     16 import logging
     17 import os
     18 import re
     19 import shutil
     20 import tempfile
     21 import threading
     22 import time
     23 import uuid
     24 
     25 from acts.controllers.utils_lib import host_utils
     26 from acts.controllers.utils_lib.ssh import formatter
     27 from acts.libs.proc import job
     28 
     29 
     30 class Error(Exception):
     31     """An error occured during an ssh operation."""
     32 
     33 
     34 class CommandError(Exception):
     35     """An error occured with the command.
     36 
     37     Attributes:
     38         result: The results of the ssh command that had the error.
     39     """
     40 
     41     def __init__(self, result):
     42         """
     43         Args:
     44             result: The result of the ssh command that created the problem.
     45         """
     46         self.result = result
     47 
     48     def __str__(self):
     49         return 'cmd: %s\nstdout: %s\nstderr: %s' % (
     50             self.result.command, self.result.stdout, self.result.stderr)
     51 
     52 
     53 _Tunnel = collections.namedtuple('_Tunnel',
     54                                  ['local_port', 'remote_port', 'proc'])
     55 
     56 
     57 class SshConnection(object):
     58     """Provides a connection to a remote machine through ssh.
     59 
     60     Provides the ability to connect to a remote machine and execute a command
    on it. The connection will try to establish a persistent connection when
     62     a command is run. If the persistent connection fails it will attempt
     63     to connect normally.
     64     """
     65 
     66     @property
     67     def socket_path(self):
     68         """Returns: The os path to the master socket file."""
     69         return os.path.join(self._master_ssh_tempdir, 'socket')
     70 
     71     def __init__(self, settings):
     72         """
     73         Args:
     74             settings: The ssh settings to use for this conneciton.
     75             formatter: The object that will handle formatting ssh command
     76                        for use with the background job.
     77         """
     78         self._settings = settings
     79         self._formatter = formatter.SshFormatter()
     80         self._lock = threading.Lock()
     81         self._master_ssh_proc = None
     82         self._master_ssh_tempdir = None
     83         self._tunnels = list()
     84 
     85     def __del__(self):
     86         self.close()
     87 
     88     def setup_master_ssh(self, timeout_seconds=5):
     89         """Sets up the master ssh connection.
     90 
     91         Sets up the inital master ssh connection if it has not already been
     92         started.
     93 
     94         Args:
     95             timeout_seconds: The time to wait for the master ssh connection to be made.
     96 
     97         Raises:
     98             Error: When setting up the master ssh connection fails.
     99         """
    100         with self._lock:
    101             if self._master_ssh_proc is not None:
    102                 socket_path = self.socket_path
    103                 if (not os.path.exists(socket_path) or
    104                         self._master_ssh_proc.poll() is not None):
    105                     logging.debug('Master ssh connection to %s is down.',
    106                                   self._settings.hostname)
    107                     self._cleanup_master_ssh()
    108 
    109             if self._master_ssh_proc is None:
    110                 # Create a shared socket in a temp location.
    111                 self._master_ssh_tempdir = tempfile.mkdtemp(
    112                     prefix='ssh-master')
    113 
    114                 # Setup flags and options for running the master ssh
    115                 # -N: Do not execute a remote command.
    116                 # ControlMaster: Spawn a master connection.
    117                 # ControlPath: The master connection socket path.
    118                 extra_flags = {'-N': None}
    119                 extra_options = {
    120                     'ControlMaster': True,
    121                     'ControlPath': self.socket_path,
    122                     'BatchMode': True
    123                 }
    124 
    125                 # Construct the command and start it.
    126                 master_cmd = self._formatter.format_ssh_local_command(
    127                     self._settings,
    128                     extra_flags=extra_flags,
    129                     extra_options=extra_options)
    130                 logging.info('Starting master ssh connection to %s',
    131                              self._settings.hostname)
    132                 self._master_ssh_proc = job.run_async(master_cmd)
    133 
    134                 end_time = time.time() + timeout_seconds
    135 
    136                 while time.time() < end_time:
    137                     if os.path.exists(self.socket_path):
    138                         break
    139                     time.sleep(.2)
    140                 else:
    141                     self._cleanup_master_ssh()
    142                     raise Error('Master ssh connection timed out.')
    143 
    def run(self,
            command,
            timeout=3600,
            ignore_status=False,
            env=None,
            io_encoding='utf-8'):
        """Runs a remote command over ssh.

        Will ssh to a remote host and run a command. This method will
        block until the remote command is finished.

        The remote command is prefixed with an echo of a unique marker so
        that a failure of ssh itself can be told apart from a failure of the
        command: if the marker is present in stdout the connection worked.

        Args:
            command: The command to execute over ssh. It is interpolated
                     into the remote command line with '%s'.
            timeout: number seconds to wait for command to finish.
            ignore_status: bool True to ignore the exit code of the remote
                           subprocess.  Note that if you do ignore status codes,
                           you should handle non-zero exit codes explicitly.
            env: dict environment variables to setup on the remote host.
            io_encoding: str unicode encoding of command output.
                         NOTE(review): not forwarded anywhere in this method;
                         the encoding reported by the job result is used.

        Returns:
            A job.Result containing the results of the ssh command, with the
            marker line removed from stdout.

        Raises:
            job.TimeoutError: When the remote command took too long to
                              execute (presumably raised by job.run; verify).
            Error: When the ssh connection failed to be created.
            job.Error: Ssh worked, but the command exited with a non-zero
                       status and ignore_status is False.
        """
        if env is None:
            env = {}

        try:
            self.setup_master_ssh(self._settings.connect_timeout)
        except Error:
            logging.warning('Failed to create master ssh connection, using '
                            'normal ssh connection.')

        extra_options = {'BatchMode': True}
        if self._master_ssh_proc:
            # Reuse the master connection when one is available.
            extra_options['ControlPath'] = self.socket_path

        identifier = str(uuid.uuid4())
        full_command = 'echo "CONNECTED: %s"; %s' % (identifier, command)

        terminal_command = self._formatter.format_command(
            full_command, env, self._settings, extra_options=extra_options)

        connected_pattern = re.compile(
            '^CONNECTED: %s' % re.escape(identifier), flags=re.MULTILINE)

        # The command is attempted at most this many times when ssh reports
        # a name resolution failure.
        dns_retry_count = 2
        while True:
            # The exit status is always ignored here and evaluated below,
            # once it is known whether ssh itself worked.
            result = job.run(terminal_command,
                             ignore_status=True,
                             timeout=timeout)
            output = result.stdout

            # Check for a connected message to prevent false negatives.
            valid_connection = connected_pattern.search(output)
            if valid_connection:
                # Drop everything up to and including the marker line. The
                # cut is made relative to the marker itself so that output
                # printed before it cannot cause the wrong line to be
                # removed.
                marker_line_end = output.find('\n', valid_connection.end())
                if marker_line_end == -1:
                    remaining = ''
                else:
                    remaining = output[marker_line_end + 1:]
                real_output = remaining.encode(result._encoding)

                result = job.Result(
                    command=result.command,
                    stdout=real_output,
                    stderr=result._raw_stderr,
                    exit_status=result.exit_status,
                    duration=result.duration,
                    did_timeout=result.did_timeout,
                    encoding=result._encoding)
                if result.exit_status and not ignore_status:
                    raise job.Error(result)
                return result

            error_string = result.stderr

            had_dns_failure = (result.exit_status == 255 and re.search(
                r'^ssh: .*: Name or service not known',
                error_string,
                flags=re.MULTILINE))
            if had_dns_failure:
                dns_retry_count -= 1
                if not dns_retry_count:
                    raise Error('DNS failed to find host.', result)
                logging.debug('Failed to connecto to host, retrying...')
            else:
                break

        # ssh never reached the remote command; classify the failure from
        # what ssh wrote to stderr.
        had_timeout = re.search(
            r'^ssh: connect to host .* port .*: '
            r'Connection timed out\r?$',
            error_string,
            flags=re.MULTILINE)
        if had_timeout:
            raise Error('Ssh timed out.', result)

        permission_denied = 'Permission denied' in error_string
        if permission_denied:
            raise Error('Permission denied.', result)

        unknown_host = re.search(
            r'ssh: Could not resolve hostname .*: '
            r'Name or service not known',
            error_string,
            flags=re.MULTILINE)
        if unknown_host:
            raise Error('Unknown host.', result)

        raise Error('The job failed for unkown reasons.', result)
    254 
    255     def run_async(self, command, env=None):
    256         """Starts up a background command over ssh.
    257 
    258         Will ssh to a remote host and startup a command. This method will
    259         block until there is confirmation that the remote command has started.
    260 
    261         Args:
    262             command: The command to execute over ssh. Can be either a string
    263                      or a list.
    264             env: A dictonary of enviroment variables to setup on the remote
    265                  host.
    266 
    267         Returns:
    268             The result of the command to launch the background job.
    269 
    270         Raises:
    271             CmdTimeoutError: When the remote command took to long to execute.
    272             SshTimeoutError: When the connection took to long to established.
    273             SshPermissionDeniedError: When permission is not allowed on the
    274                                       remote host.
    275         """
    276         command = '(%s) < /dev/null > /dev/null 2>&1 & echo -n $!' % command
    277         result = self.run(command, env=env)
    278         return result
    279 
    280     def close(self):
    281         """Clean up open connections to remote host."""
    282         self._cleanup_master_ssh()
    283         while self._tunnels:
    284             self.close_ssh_tunnel(self._tunnels[0].local_port)
    285 
    286     def _cleanup_master_ssh(self):
    287         """
    288         Release all resources (process, temporary directory) used by an active
    289         master SSH connection.
    290         """
    291         # If a master SSH connection is running, kill it.
    292         if self._master_ssh_proc is not None:
    293             logging.debug('Nuking master_ssh_job.')
    294             self._master_ssh_proc.kill()
    295             self._master_ssh_proc.wait()
    296             self._master_ssh_proc = None
    297 
    298         # Remove the temporary directory for the master SSH socket.
    299         if self._master_ssh_tempdir is not None:
    300             logging.debug('Cleaning master_ssh_tempdir.')
    301             shutil.rmtree(self._master_ssh_tempdir)
    302             self._master_ssh_tempdir = None
    303 
    304     def create_ssh_tunnel(self, port, local_port=None):
    305         """Create an ssh tunnel from local_port to port.
    306 
    307         This securely forwards traffic from local_port on this machine to the
    308         remote SSH host at port.
    309 
    310         Args:
    311             port: remote port on the host.
    312             local_port: local forwarding port, or None to pick an available
    313                         port.
    314 
    315         Returns:
    316             the created tunnel process.
    317         """
    318         if local_port is None:
    319             local_port = host_utils.get_available_host_port()
    320         else:
    321             for tunnel in self._tunnels:
    322                 if tunnel.remote_port == port:
    323                     return tunnel.local_port
    324 
    325         extra_flags = {
    326             '-n': None,  # Read from /dev/null for stdin
    327             '-N': None,  # Do not execute a remote command
    328             '-q': None,  # Suppress warnings and diagnostic commands
    329             '-L': '%d:localhost:%d' % (local_port, port),
    330         }
    331         extra_options = dict()
    332         if self._master_ssh_proc:
    333             extra_options['ControlPath'] = self.socket_path
    334         tunnel_cmd = self._formatter.format_ssh_local_command(
    335             self._settings,
    336             extra_flags=extra_flags,
    337             extra_options=extra_options)
    338         logging.debug('Full tunnel command: %s', tunnel_cmd)
    339         # Exec the ssh process directly so that when we deliver signals, we
    340         # deliver them straight to the child process.
    341         tunnel_proc = job.run_async(tunnel_cmd)
    342         logging.debug('Started ssh tunnel, local = %d'
    343                       ' remote = %d, pid = %d', local_port, port,
    344                       tunnel_proc.pid)
    345         self._tunnels.append(_Tunnel(local_port, port, tunnel_proc))
    346         return local_port
    347 
    348     def close_ssh_tunnel(self, local_port):
    349         """Close a previously created ssh tunnel of a TCP port.
    350 
    351         Args:
    352             local_port: int port on localhost previously forwarded to the remote
    353                         host.
    354 
    355         Returns:
    356             integer port number this port was forwarded to on the remote host or
    357             None if no tunnel was found.
    358         """
    359         idx = None
    360         for i, tunnel in enumerate(self._tunnels):
    361             if tunnel.local_port == local_port:
    362                 idx = i
    363                 break
    364         if idx is not None:
    365             tunnel = self._tunnels.pop(idx)
    366             tunnel.proc.kill()
    367             tunnel.proc.wait()
    368             return tunnel.remote_port
    369         return None
    370 
    def send_file(self, local_path, remote_path):
        """Send a file from the local host to the remote host.

        Runs a plain `scp <local_path> <host>:<remote_path>` as a local job,
        where <host> is the host name formatted from the ssh settings. No
        other options are added to the scp command line here.

        Args:
            local_path: string path of file to send on local host.
            remote_path: string path to copy file to on remote host.
        """
        # TODO: This may belong somewhere else: b/32572515
        # NOTE(review): both paths are interpolated unquoted; paths with
        # spaces or shell metacharacters are presumably unsafe. Verify
        # against how job.run executes string commands.
        destination = '%s:%s' % (
            self._formatter.format_host_name(self._settings), remote_path)
        job.run('scp %s %s' % (local_path, destination))
    381 
    382     def find_free_port(self, interface_name='localhost'):
    383         """Find a unused port on the remote host.
    384 
    385         Note that this method is inherently racy, since it is impossible
    386         to promise that the remote port will remain free.
    387 
    388         Args:
    389             interface_name: string name of interface to check whether a
    390                             port is used against.
    391 
    392         Returns:
    393             integer port number on remote interface that was free.
    394         """
    395         # TODO: This may belong somewhere else: b/3257251
    396         free_port_cmd = (
    397             'python -c "import socket; s=socket.socket(); '
    398             's.bind((\'%s\', 0)); print(s.getsockname()[1]); s.close()"'
    399         ) % interface_name
    400         port = int(self.run(free_port_cmd).stdout)
    401         # Yield to the os to ensure the port gets cleaned up.
    402         time.sleep(0.001)
    403         return port
    404