Home | History | Annotate | Download | only in server
      1 # Copyright 2008 Google Inc, Martin J. Bligh <mbligh (at] google.com>,
      2 #                Benjamin Poirier, Ryan Stutsman
      3 # Released under the GPL v2
      4 """
      5 Miscellaneous small functions.
      6 
      7 DO NOT import this file directly - it is mixed in by server/utils.py,
      8 import that instead
      9 """
     10 
     11 import atexit, os, re, shutil, textwrap, sys, tempfile, types
     12 
     13 from autotest_lib.client.common_lib import barrier, utils
     14 from autotest_lib.server import subcommand
     15 
     16 
     17 # A dictionary of pid and a list of tmpdirs for that pid
     18 __tmp_dirs = {}
     19 
     20 
     21 def scp_remote_escape(filename):
     22     """
     23     Escape special characters from a filename so that it can be passed
     24     to scp (within double quotes) as a remote file.
     25 
     26     Bis-quoting has to be used with scp for remote files, "bis-quoting"
     27     as in quoting x 2
     28     scp does not support a newline in the filename
     29 
     30     Args:
     31             filename: the filename string to escape.
     32 
     33     Returns:
     34             The escaped filename string. The required englobing double
     35             quotes are NOT added and so should be added at some point by
     36             the caller.
     37     """
     38     escape_chars= r' !"$&' "'" r'()*,:;<=>?[\]^`{|}'
     39 
     40     new_name= []
     41     for char in filename:
     42         if char in escape_chars:
     43             new_name.append("\\%s" % (char,))
     44         else:
     45             new_name.append(char)
     46 
     47     return utils.sh_escape("".join(new_name))
     48 
     49 
     50 def get(location, local_copy = False):
     51     """Get a file or directory to a local temporary directory.
     52 
     53     Args:
     54             location: the source of the material to get. This source may
     55                     be one of:
     56                     * a local file or directory
     57                     * a URL (http or ftp)
     58                     * a python file-like object
     59 
     60     Returns:
     61             The location of the file or directory where the requested
     62             content was saved. This will be contained in a temporary
     63             directory on the local host. If the material to get was a
     64             directory, the location will contain a trailing '/'
     65     """
     66     tmpdir = get_tmp_dir()
     67 
     68     # location is a file-like object
     69     if hasattr(location, "read"):
     70         tmpfile = os.path.join(tmpdir, "file")
     71         tmpfileobj = file(tmpfile, 'w')
     72         shutil.copyfileobj(location, tmpfileobj)
     73         tmpfileobj.close()
     74         return tmpfile
     75 
     76     if isinstance(location, types.StringTypes):
     77         # location is a URL
     78         if location.startswith('http') or location.startswith('ftp'):
     79             tmpfile = os.path.join(tmpdir, os.path.basename(location))
     80             utils.urlretrieve(location, tmpfile)
     81             return tmpfile
     82         # location is a local path
     83         elif os.path.exists(os.path.abspath(location)):
     84             if not local_copy:
     85                 if os.path.isdir(location):
     86                     return location.rstrip('/') + '/'
     87                 else:
     88                     return location
     89             tmpfile = os.path.join(tmpdir, os.path.basename(location))
     90             if os.path.isdir(location):
     91                 tmpfile += '/'
     92                 shutil.copytree(location, tmpfile, symlinks=True)
     93                 return tmpfile
     94             shutil.copyfile(location, tmpfile)
     95             return tmpfile
     96         # location is just a string, dump it to a file
     97         else:
     98             tmpfd, tmpfile = tempfile.mkstemp(dir=tmpdir)
     99             tmpfileobj = os.fdopen(tmpfd, 'w')
    100             tmpfileobj.write(location)
    101             tmpfileobj.close()
    102             return tmpfile
    103 
    104 
    105 def get_tmp_dir():
    106     """Return the pathname of a directory on the host suitable
    107     for temporary file storage.
    108 
    109     The directory and its content will be deleted automatically
    110     at the end of the program execution if they are still present.
    111     """
    112     dir_name = tempfile.mkdtemp(prefix="autoserv-")
    113     pid = os.getpid()
    114     if not pid in __tmp_dirs:
    115         __tmp_dirs[pid] = []
    116     __tmp_dirs[pid].append(dir_name)
    117     return dir_name
    118 
    119 
    120 def __clean_tmp_dirs():
    121     """Erase temporary directories that were created by the get_tmp_dir()
    122     function and that are still present.
    123     """
    124     pid = os.getpid()
    125     if pid not in __tmp_dirs:
    126         return
    127     for dir in __tmp_dirs[pid]:
    128         try:
    129             shutil.rmtree(dir)
    130         except OSError, e:
    131             if e.errno == 2:
    132                 pass
    133     __tmp_dirs[pid] = []
    134 atexit.register(__clean_tmp_dirs)
    135 subcommand.subcommand.register_join_hook(lambda _: __clean_tmp_dirs())
    136 
    137 
    138 def unarchive(host, source_material):
    139     """Uncompress and untar an archive on a host.
    140 
    141     If the "source_material" is compresses (according to the file
    142     extension) it will be uncompressed. Supported compression formats
    143     are gzip and bzip2. Afterwards, if the source_material is a tar
    144     archive, it will be untarred.
    145 
    146     Args:
    147             host: the host object on which the archive is located
    148             source_material: the path of the archive on the host
    149 
    150     Returns:
    151             The file or directory name of the unarchived source material.
    152             If the material is a tar archive, it will be extracted in the
    153             directory where it is and the path returned will be the first
    154             entry in the archive, assuming it is the topmost directory.
    155             If the material is not an archive, nothing will be done so this
    156             function is "harmless" when it is "useless".
    157     """
    158     # uncompress
    159     if (source_material.endswith(".gz") or
    160             source_material.endswith(".gzip")):
    161         host.run('gunzip "%s"' % (utils.sh_escape(source_material)))
    162         source_material= ".".join(source_material.split(".")[:-1])
    163     elif source_material.endswith("bz2"):
    164         host.run('bunzip2 "%s"' % (utils.sh_escape(source_material)))
    165         source_material= ".".join(source_material.split(".")[:-1])
    166 
    167     # untar
    168     if source_material.endswith(".tar"):
    169         retval= host.run('tar -C "%s" -xvf "%s"' % (
    170                 utils.sh_escape(os.path.dirname(source_material)),
    171                 utils.sh_escape(source_material),))
    172         source_material= os.path.join(os.path.dirname(source_material),
    173                 retval.stdout.split()[0])
    174 
    175     return source_material
    176 
    177 
    178 def get_server_dir():
    179     path = os.path.dirname(sys.modules['autotest_lib.server.utils'].__file__)
    180     return os.path.abspath(path)
    181 
    182 
    183 def find_pid(command):
    184     for line in utils.system_output('ps -eo pid,cmd').rstrip().split('\n'):
    185         (pid, cmd) = line.split(None, 1)
    186         if re.search(command, cmd):
    187             return int(pid)
    188     return None
    189 
    190 
    191 def default_mappings(machines):
    192     """
    193     Returns a simple mapping in which all machines are assigned to the
    194     same key.  Provides the default behavior for
    195     form_ntuples_from_machines. """
    196     mappings = {}
    197     failures = []
    198 
    199     mach = machines[0]
    200     mappings['ident'] = [mach]
    201     if len(machines) > 1:
    202         machines = machines[1:]
    203         for machine in machines:
    204             mappings['ident'].append(machine)
    205 
    206     return (mappings, failures)
    207 
    208 
    209 def form_ntuples_from_machines(machines, n=2, mapping_func=default_mappings):
    210     """Returns a set of ntuples from machines where the machines in an
    211        ntuple are in the same mapping, and a set of failures which are
    212        (machine name, reason) tuples."""
    213     ntuples = []
    214     (mappings, failures) = mapping_func(machines)
    215 
    216     # now run through the mappings and create n-tuples.
    217     # throw out the odd guys out
    218     for key in mappings:
    219         key_machines = mappings[key]
    220         total_machines = len(key_machines)
    221 
    222         # form n-tuples
    223         while len(key_machines) >= n:
    224             ntuples.append(key_machines[0:n])
    225             key_machines = key_machines[n:]
    226 
    227         for mach in key_machines:
    228             failures.append((mach, "machine can not be tupled"))
    229 
    230     return (ntuples, failures)
    231 
    232 
    233 def parse_machine(machine, user='root', password='', port=22):
    234     """
    235     Parse the machine string user:pass@host:port and return it separately,
    236     if the machine string is not complete, use the default parameters
    237     when appropriate.
    238     """
    239 
    240     if '@' in machine:
    241         user, machine = machine.split('@', 1)
    242 
    243     if ':' in user:
    244         user, password = user.split(':', 1)
    245 
    246     # Brackets are required to protect an IPv6 address whenever a
    247     # [xx::xx]:port number (or a file [xx::xx]:/path/) is appended to
    248     # it. Do not attempt to extract a (non-existent) port number from
    249     # an unprotected/bare IPv6 address "xx::xx".
    250     # In the Python >= 3.3 future, 'import ipaddress' will parse
    251     # addresses; and maybe more.
    252     bare_ipv6 = '[' != machine[0] and re.search(r':.*:', machine)
    253 
    254     # Extract trailing :port number if any.
    255     if not bare_ipv6 and re.search(r':\d*$', machine):
    256         machine, port = machine.rsplit(':', 1)
    257         port = int(port)
    258 
    259     # Strip any IPv6 brackets (ssh does not support them).
    260     # We'll add them back later for rsync, scp, etc.
    261     if machine[0] == '[' and machine[-1] == ']':
    262         machine = machine[1:-1]
    263 
    264     if not machine or not user:
    265         raise ValueError
    266 
    267     return machine, user, password, port
    268 
    269 
    270 def get_public_key():
    271     """
    272     Return a valid string ssh public key for the user executing autoserv or
    273     autotest. If there's no DSA or RSA public key, create a DSA keypair with
    274     ssh-keygen and return it.
    275     """
    276 
    277     ssh_conf_path = os.path.expanduser('~/.ssh')
    278 
    279     dsa_public_key_path = os.path.join(ssh_conf_path, 'id_dsa.pub')
    280     dsa_private_key_path = os.path.join(ssh_conf_path, 'id_dsa')
    281 
    282     rsa_public_key_path = os.path.join(ssh_conf_path, 'id_rsa.pub')
    283     rsa_private_key_path = os.path.join(ssh_conf_path, 'id_rsa')
    284 
    285     has_dsa_keypair = os.path.isfile(dsa_public_key_path) and \
    286         os.path.isfile(dsa_private_key_path)
    287     has_rsa_keypair = os.path.isfile(rsa_public_key_path) and \
    288         os.path.isfile(rsa_private_key_path)
    289 
    290     if has_dsa_keypair:
    291         print 'DSA keypair found, using it'
    292         public_key_path = dsa_public_key_path
    293 
    294     elif has_rsa_keypair:
    295         print 'RSA keypair found, using it'
    296         public_key_path = rsa_public_key_path
    297 
    298     else:
    299         print 'Neither RSA nor DSA keypair found, creating DSA ssh key pair'
    300         utils.system('ssh-keygen -t dsa -q -N "" -f %s' % dsa_private_key_path)
    301         public_key_path = dsa_public_key_path
    302 
    303     public_key = open(public_key_path, 'r')
    304     public_key_str = public_key.read()
    305     public_key.close()
    306 
    307     return public_key_str
    308