Home | History | Annotate | Download | only in core
      1 # Copyright 2013 The Chromium Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 """A wrapper around ssh for common operations on a CrOS-based device"""
      5 import logging
      6 import os
      7 import re
      8 import shutil
      9 import stat
     10 import subprocess
     11 import tempfile
     12 
     13 # Some developers' workflow includes running the Chrome process from
     14 # /usr/local/... instead of the default location. We have to check for both
     15 # paths in order to support this workflow.
     16 _CHROME_PROCESS_REGEX = [re.compile(r'^/opt/google/chrome/chrome '),
     17                          re.compile(r'^/usr/local/?.*/chrome/chrome ')]
     18 
     19 
     20 def RunCmd(args, cwd=None, quiet=False):
     21   """Opens a subprocess to execute a program and returns its return value.
     22 
     23   Args:
     24     args: A string or a sequence of program arguments. The program to execute is
     25       the string or the first item in the args sequence.
     26     cwd: If not None, the subprocess's current directory will be changed to
     27       |cwd| before it's executed.
     28 
     29   Returns:
     30     Return code from the command execution.
     31   """
     32   if not quiet:
     33     logging.debug(' '.join(args) + ' ' + (cwd or ''))
     34   with open(os.devnull, 'w') as devnull:
     35     p = subprocess.Popen(args=args,
     36                          cwd=cwd,
     37                          stdout=devnull,
     38                          stderr=devnull,
     39                          stdin=devnull,
     40                          shell=False)
     41     return p.wait()
     42 
     43 
     44 def GetAllCmdOutput(args, cwd=None, quiet=False):
     45   """Open a subprocess to execute a program and returns its output.
     46 
     47   Args:
     48     args: A string or a sequence of program arguments. The program to execute is
     49       the string or the first item in the args sequence.
     50     cwd: If not None, the subprocess's current directory will be changed to
     51       |cwd| before it's executed.
     52 
     53   Returns:
     54     Captures and returns the command's stdout.
     55     Prints the command's stderr to logger (which defaults to stdout).
     56   """
     57   if not quiet:
     58     logging.debug(' '.join(args) + ' ' + (cwd or ''))
     59   with open(os.devnull, 'w') as devnull:
     60     p = subprocess.Popen(args=args,
     61                          cwd=cwd,
     62                          stdout=subprocess.PIPE,
     63                          stderr=subprocess.PIPE,
     64                          stdin=devnull)
     65     stdout, stderr = p.communicate()
     66     if not quiet:
     67       logging.debug(' > stdout=[%s], stderr=[%s]', stdout, stderr)
     68     return stdout, stderr
     69 
     70 
     71 def HasSSH():
     72   try:
     73     RunCmd(['ssh'], quiet=True)
     74     RunCmd(['scp'], quiet=True)
     75     logging.debug("HasSSH()->True")
     76     return True
     77   except OSError:
     78     logging.debug("HasSSH()->False")
     79     return False
     80 
     81 
     82 class LoginException(Exception):
     83   pass
     84 
     85 
     86 class KeylessLoginRequiredException(LoginException):
     87   pass
     88 
     89 
     90 class DNSFailureException(LoginException):
     91   pass
     92 
     93 
     94 class CrOSInterface(object):
     95 
     96   def __init__(self, hostname=None, ssh_port=None, ssh_identity=None):
     97     self._hostname = hostname
     98     self._ssh_port = ssh_port
     99 
    100     # List of ports generated from GetRemotePort() that may not be in use yet.
    101     self._reserved_ports = []
    102 
    103     if self.local:
    104       return
    105 
    106     self._ssh_identity = None
    107     self._ssh_args = ['-o ConnectTimeout=5', '-o StrictHostKeyChecking=no',
    108                       '-o KbdInteractiveAuthentication=no',
    109                       '-o PreferredAuthentications=publickey',
    110                       '-o UserKnownHostsFile=/dev/null', '-o ControlMaster=no']
    111 
    112     if ssh_identity:
    113       self._ssh_identity = os.path.abspath(os.path.expanduser(ssh_identity))
    114       os.chmod(self._ssh_identity, stat.S_IREAD)
    115 
    116     # Establish master SSH connection using ControlPersist.
    117     # Since only one test will be run on a remote host at a time,
    118     # the control socket filename can be telemetry@hostname.
    119     self._ssh_control_file = '/tmp/' + 'telemetry' + '@' + hostname
    120     with open(os.devnull, 'w') as devnull:
    121       subprocess.call(
    122           self.FormSSHCommandLine(['-M', '-o ControlPersist=yes']),
    123           stdin=devnull,
    124           stdout=devnull,
    125           stderr=devnull)
    126 
    127   def __enter__(self):
    128     return self
    129 
    130   def __exit__(self, *args):
    131     self.CloseConnection()
    132 
    133   @property
    134   def local(self):
    135     return not self._hostname
    136 
    137   @property
    138   def hostname(self):
    139     return self._hostname
    140 
    141   @property
    142   def ssh_port(self):
    143     return self._ssh_port
    144 
    145   def FormSSHCommandLine(self, args, extra_ssh_args=None):
    146     """Constructs a subprocess-suitable command line for `ssh'.
    147     """
    148     if self.local:
    149       # We run the command through the shell locally for consistency with
    150       # how commands are run through SSH (crbug.com/239161). This work
    151       # around will be unnecessary once we implement a persistent SSH
    152       # connection to run remote commands (crbug.com/239607).
    153       return ['sh', '-c', " ".join(args)]
    154 
    155     full_args = ['ssh', '-o ForwardX11=no', '-o ForwardX11Trusted=no', '-n',
    156                  '-S', self._ssh_control_file] + self._ssh_args
    157     if self._ssh_identity is not None:
    158       full_args.extend(['-i', self._ssh_identity])
    159     if extra_ssh_args:
    160       full_args.extend(extra_ssh_args)
    161     full_args.append('root@%s' % self._hostname)
    162     full_args.append('-p%d' % self._ssh_port)
    163     full_args.extend(args)
    164     return full_args
    165 
    166   def _FormSCPCommandLine(self, src, dst, extra_scp_args=None):
    167     """Constructs a subprocess-suitable command line for `scp'.
    168 
    169     Note: this function is not designed to work with IPv6 addresses, which need
    170     to have their addresses enclosed in brackets and a '-6' flag supplied
    171     in order to be properly parsed by `scp'.
    172     """
    173     assert not self.local, "Cannot use SCP on local target."
    174 
    175     args = ['scp', '-P', str(self._ssh_port)] + self._ssh_args
    176     if self._ssh_identity:
    177       args.extend(['-i', self._ssh_identity])
    178     if extra_scp_args:
    179       args.extend(extra_scp_args)
    180     args += [src, dst]
    181     return args
    182 
    183   def _FormSCPToRemote(self,
    184                        source,
    185                        remote_dest,
    186                        extra_scp_args=None,
    187                        user='root'):
    188     return self._FormSCPCommandLine(source,
    189                                     '%s@%s:%s' % (user, self._hostname,
    190                                                   remote_dest),
    191                                     extra_scp_args=extra_scp_args)
    192 
    193   def _FormSCPFromRemote(self,
    194                          remote_source,
    195                          dest,
    196                          extra_scp_args=None,
    197                          user='root'):
    198     return self._FormSCPCommandLine('%s@%s:%s' % (user, self._hostname,
    199                                                   remote_source),
    200                                     dest,
    201                                     extra_scp_args=extra_scp_args)
    202 
    203   def _RemoveSSHWarnings(self, toClean):
    204     """Removes specific ssh warning lines from a string.
    205 
    206     Args:
    207       toClean: A string that may be containing multiple lines.
    208 
    209     Returns:
    210       A copy of toClean with all the Warning lines removed.
    211     """
    212     # Remove the Warning about connecting to a new host for the first time.
    213     return re.sub(
    214         r'Warning: Permanently added [^\n]* to the list of known hosts.\s\n',
    215         '', toClean)
    216 
    217   def RunCmdOnDevice(self, args, cwd=None, quiet=False):
    218     stdout, stderr = GetAllCmdOutput(
    219         self.FormSSHCommandLine(args),
    220         cwd,
    221         quiet=quiet)
    222     # The initial login will add the host to the hosts file but will also print
    223     # a warning to stderr that we need to remove.
    224     stderr = self._RemoveSSHWarnings(stderr)
    225     return stdout, stderr
    226 
    227   def TryLogin(self):
    228     logging.debug('TryLogin()')
    229     assert not self.local
    230     stdout, stderr = self.RunCmdOnDevice(['echo', '$USER'], quiet=True)
    231     if stderr != '':
    232       if 'Host key verification failed' in stderr:
    233         raise LoginException(('%s host key verification failed. ' +
    234                               'SSH to it manually to fix connectivity.') %
    235                              self._hostname)
    236       if 'Operation timed out' in stderr:
    237         raise LoginException('Timed out while logging into %s' % self._hostname)
    238       if 'UNPROTECTED PRIVATE KEY FILE!' in stderr:
    239         raise LoginException('Permissions for %s are too open. To fix this,\n'
    240                              'chmod 600 %s' % (self._ssh_identity,
    241                                                self._ssh_identity))
    242       if 'Permission denied (publickey,keyboard-interactive)' in stderr:
    243         raise KeylessLoginRequiredException('Need to set up ssh auth for %s' %
    244                                             self._hostname)
    245       if 'Could not resolve hostname' in stderr:
    246         raise DNSFailureException('Unable to resolve the hostname for: %s' %
    247                                   self._hostname)
    248       raise LoginException('While logging into %s, got %s' % (self._hostname,
    249                                                               stderr))
    250     if stdout != 'root\n':
    251       raise LoginException('Logged into %s, expected $USER=root, but got %s.' %
    252                            (self._hostname, stdout))
    253 
    254   def FileExistsOnDevice(self, file_name):
    255     if self.local:
    256       return os.path.exists(file_name)
    257 
    258     stdout, stderr = self.RunCmdOnDevice(
    259         [
    260             'if', 'test', '-e', file_name, ';', 'then', 'echo', '1', ';', 'fi'
    261         ],
    262         quiet=True)
    263     if stderr != '':
    264       if "Connection timed out" in stderr:
    265         raise OSError('Machine wasn\'t responding to ssh: %s' % stderr)
    266       raise OSError('Unexpected error: %s' % stderr)
    267     exists = stdout == '1\n'
    268     logging.debug("FileExistsOnDevice(<text>, %s)->%s" % (file_name, exists))
    269     return exists
    270 
    271   def PushFile(self, filename, remote_filename):
    272     if self.local:
    273       args = ['cp', '-r', filename, remote_filename]
    274       stdout, stderr = GetAllCmdOutput(args, quiet=True)
    275       if stderr != '':
    276         raise OSError('No such file or directory %s' % stderr)
    277       return
    278 
    279     args = self._FormSCPToRemote(
    280         os.path.abspath(filename),
    281         remote_filename,
    282         extra_scp_args=['-r'])
    283 
    284     stdout, stderr = GetAllCmdOutput(args, quiet=True)
    285     stderr = self._RemoveSSHWarnings(stderr)
    286     if stderr != '':
    287       raise OSError('No such file or directory %s' % stderr)
    288 
    289   def PushContents(self, text, remote_filename):
    290     logging.debug("PushContents(<text>, %s)" % remote_filename)
    291     with tempfile.NamedTemporaryFile() as f:
    292       f.write(text)
    293       f.flush()
    294       self.PushFile(f.name, remote_filename)
    295 
    296   def GetFile(self, filename, destfile=None):
    297     """Copies a local file |filename| to |destfile| on the device.
    298 
    299     Args:
    300       filename: The name of the local source file.
    301       destfile: The name of the file to copy to, and if it is not specified
    302         then it is the basename of the source file.
    303 
    304     """
    305     logging.debug("GetFile(%s, %s)" % (filename, destfile))
    306     if self.local:
    307       if destfile is not None and destfile != filename:
    308         shutil.copyfile(filename, destfile)
    309       return
    310 
    311     if destfile is None:
    312       destfile = os.path.basename(filename)
    313     args = self._FormSCPFromRemote(filename, os.path.abspath(destfile))
    314 
    315     stdout, stderr = GetAllCmdOutput(args, quiet=True)
    316     stderr = self._RemoveSSHWarnings(stderr)
    317     if stderr != '':
    318       raise OSError('No such file or directory %s' % stderr)
    319 
    320   def GetFileContents(self, filename):
    321     """Get the contents of a file on the device.
    322 
    323     Args:
    324       filename: The name of the file on the device.
    325 
    326     Returns:
    327       A string containing the contents of the file.
    328     """
    329     # TODO: handle the self.local case
    330     assert not self.local
    331     t = tempfile.NamedTemporaryFile()
    332     self.GetFile(filename, t.name)
    333     with open(t.name, 'r') as f2:
    334       res = f2.read()
    335       logging.debug("GetFileContents(%s)->%s" % (filename, res))
    336       f2.close()
    337       return res
    338 
    339   def ListProcesses(self):
    340     """Returns (pid, cmd, ppid, state) of all processes on the device."""
    341     stdout, stderr = self.RunCmdOnDevice(
    342         [
    343             '/bin/ps', '--no-headers', '-A', '-o', 'pid,ppid,args:4096,state'
    344         ],
    345         quiet=True)
    346     assert stderr == '', stderr
    347     procs = []
    348     for l in stdout.split('\n'):
    349       if l == '':
    350         continue
    351       m = re.match(r'^\s*(\d+)\s+(\d+)\s+(.+)\s+(.+)', l, re.DOTALL)
    352       assert m
    353       procs.append((int(m.group(1)), m.group(3).rstrip(), int(m.group(2)),
    354                     m.group(4)))
    355     logging.debug("ListProcesses(<predicate>)->[%i processes]" % len(procs))
    356     return procs
    357 
    358   def _GetSessionManagerPid(self, procs):
    359     """Returns the pid of the session_manager process, given the list of
    360     processes."""
    361     for pid, process, _, _ in procs:
    362       argv = process.split()
    363       if argv and os.path.basename(argv[0]) == 'session_manager':
    364         return pid
    365     return None
    366 
    367   def GetChromeProcess(self):
    368     """Locates the the main chrome browser process.
    369 
    370     Chrome on cros is usually in /opt/google/chrome, but could be in
    371     /usr/local/ for developer workflows - debug chrome is too large to fit on
    372     rootfs.
    373 
    374     Chrome spawns multiple processes for renderers. pids wrap around after they
    375     are exhausted so looking for the smallest pid is not always correct. We
    376     locate the session_manager's pid, and look for the chrome process that's an
    377     immediate child. This is the main browser process.
    378     """
    379     procs = self.ListProcesses()
    380     session_manager_pid = self._GetSessionManagerPid(procs)
    381     if not session_manager_pid:
    382       return None
    383 
    384     # Find the chrome process that is the child of the session_manager.
    385     for pid, process, ppid, _ in procs:
    386       if ppid != session_manager_pid:
    387         continue
    388       for regex in _CHROME_PROCESS_REGEX:
    389         path_match = re.match(regex, process)
    390         if path_match is not None:
    391           return {'pid': pid, 'path': path_match.group(), 'args': process}
    392     return None
    393 
    394   def GetChromePid(self):
    395     """Returns pid of main chrome browser process."""
    396     result = self.GetChromeProcess()
    397     if result and 'pid' in result:
    398       return result['pid']
    399     return None
    400 
    401   def RmRF(self, filename):
    402     logging.debug("rm -rf %s" % filename)
    403     self.RunCmdOnDevice(['rm', '-rf', filename], quiet=True)
    404 
    405   def Chown(self, filename):
    406     self.RunCmdOnDevice(['chown', '-R', 'chronos:chronos', filename])
    407 
    408   def KillAllMatching(self, predicate):
    409     kills = ['kill', '-KILL']
    410     for pid, cmd, _, _ in self.ListProcesses():
    411       if predicate(cmd):
    412         logging.info('Killing %s, pid %d' % cmd, pid)
    413         kills.append(pid)
    414     logging.debug("KillAllMatching(<predicate>)->%i" % (len(kills) - 2))
    415     if len(kills) > 2:
    416       self.RunCmdOnDevice(kills, quiet=True)
    417     return len(kills) - 2
    418 
    419   def IsServiceRunning(self, service_name):
    420     stdout, stderr = self.RunCmdOnDevice(['status', service_name], quiet=True)
    421     assert stderr == '', stderr
    422     running = 'running, process' in stdout
    423     logging.debug("IsServiceRunning(%s)->%s" % (service_name, running))
    424     return running
    425 
    426   def GetRemotePort(self):
    427     netstat = self.RunCmdOnDevice(['netstat', '-ant'])
    428     netstat = netstat[0].split('\n')
    429     ports_in_use = []
    430 
    431     for line in netstat[2:]:
    432       if not line:
    433         continue
    434       address_in_use = line.split()[3]
    435       port_in_use = address_in_use.split(':')[-1]
    436       ports_in_use.append(int(port_in_use))
    437 
    438     ports_in_use.extend(self._reserved_ports)
    439 
    440     new_port = sorted(ports_in_use)[-1] + 1
    441     self._reserved_ports.append(new_port)
    442 
    443     return new_port
    444 
    445   def IsHTTPServerRunningOnPort(self, port):
    446     wget_output = self.RunCmdOnDevice(['wget', 'localhost:%i' % (port), '-T1',
    447                                        '-t1'])
    448 
    449     if 'Connection refused' in wget_output[1]:
    450       return False
    451 
    452     return True
    453 
    454   def FilesystemMountedAt(self, path):
    455     """Returns the filesystem mounted at |path|"""
    456     df_out, _ = self.RunCmdOnDevice(['/bin/df', path])
    457     df_ary = df_out.split('\n')
    458     # 3 lines for title, mount info, and empty line.
    459     if len(df_ary) == 3:
    460       line_ary = df_ary[1].split()
    461       if line_ary:
    462         return line_ary[0]
    463     return None
    464 
    465   def CryptohomePath(self, user):
    466     """Returns the cryptohome mount point for |user|."""
    467     stdout, stderr = self.RunCmdOnDevice(['cryptohome-path', 'user', "'%s'" %
    468                                           user])
    469     if stderr != '':
    470       raise OSError('cryptohome-path failed: %s' % stderr)
    471     return stdout.rstrip()
    472 
    473   def IsCryptohomeMounted(self, username, is_guest):
    474     """Returns True iff |user|'s cryptohome is mounted."""
    475     profile_path = self.CryptohomePath(username)
    476     mount = self.FilesystemMountedAt(profile_path)
    477     mount_prefix = 'guestfs' if is_guest else '/home/.shadow/'
    478     return mount and mount.startswith(mount_prefix)
    479 
    480   def TakeScreenShot(self, screenshot_prefix):
    481     """Takes a screenshot, useful for debugging failures."""
    482     # TODO(achuith): Find a better location for screenshots. Cros autotests
    483     # upload everything in /var/log so use /var/log/screenshots for now.
    484     SCREENSHOT_DIR = '/var/log/screenshots/'
    485     SCREENSHOT_EXT = '.png'
    486 
    487     self.RunCmdOnDevice(['mkdir', '-p', SCREENSHOT_DIR])
    488     # Large number of screenshots can increase hardware lab bandwidth
    489     # dramatically, so keep this number low. crbug.com/524814.
    490     for i in xrange(2):
    491       screenshot_file = ('%s%s-%d%s' %
    492                          (SCREENSHOT_DIR, screenshot_prefix, i, SCREENSHOT_EXT))
    493       if not self.FileExistsOnDevice(screenshot_file):
    494         self.RunCmdOnDevice([
    495             '/usr/local/autotest/bin/screenshot.py', screenshot_file
    496         ])
    497         return
    498     logging.warning('screenshot directory full.')
    499 
    500   def RestartUI(self, clear_enterprise_policy):
    501     logging.info('(Re)starting the ui (logs the user out)')
    502     if clear_enterprise_policy:
    503       self.RunCmdOnDevice(['stop', 'ui'])
    504       self.RmRF('/var/lib/whitelist/*')
    505       self.RmRF(r'/home/chronos/Local\ State')
    506 
    507     if self.IsServiceRunning('ui'):
    508       self.RunCmdOnDevice(['restart', 'ui'])
    509     else:
    510       self.RunCmdOnDevice(['start', 'ui'])
    511 
    512   def CloseConnection(self):
    513     if not self.local:
    514       with open(os.devnull, 'w') as devnull:
    515         subprocess.call(
    516             self.FormSSHCommandLine(['-O', 'exit', self._hostname]),
    517             stdout=devnull,
    518             stderr=devnull)
    519