Home | History | Annotate | Download | only in server
      1 # Copyright 2008 Google Inc, Martin J. Bligh <mbligh (at] google.com>,
      2 #                Benjamin Poirier, Ryan Stutsman
      3 # Released under the GPL v2
      4 """
      5 Miscellaneous small functions.
      6 
      7 DO NOT import this file directly - it is mixed in by server/utils.py,
      8 import that instead
      9 """
     10 
     11 import atexit, os, re, shutil, textwrap, sys, tempfile, types
     12 
     13 from autotest_lib.client.common_lib import barrier, utils
     14 from autotest_lib.server import subcommand
     15 
     16 
     17 # A dictionary of pid and a list of tmpdirs for that pid
     18 __tmp_dirs = {}
     19 
     20 
     21 def scp_remote_escape(filename):
     22     """
     23     Escape special characters from a filename so that it can be passed
     24     to scp (within double quotes) as a remote file.
     25 
     26     Bis-quoting has to be used with scp for remote files, "bis-quoting"
     27     as in quoting x 2
     28     scp does not support a newline in the filename
     29 
     30     Args:
     31             filename: the filename string to escape.
     32 
     33     Returns:
     34             The escaped filename string. The required englobing double
     35             quotes are NOT added and so should be added at some point by
     36             the caller.
     37     """
     38     escape_chars= r' !"$&' "'" r'()*,:;<=>?[\]^`{|}'
     39 
     40     new_name= []
     41     for char in filename:
     42         if char in escape_chars:
     43             new_name.append("\\%s" % (char,))
     44         else:
     45             new_name.append(char)
     46 
     47     return utils.sh_escape("".join(new_name))
     48 
     49 
     50 def get(location, local_copy = False):
     51     """Get a file or directory to a local temporary directory.
     52 
     53     Args:
     54             location: the source of the material to get. This source may
     55                     be one of:
     56                     * a local file or directory
     57                     * a URL (http or ftp)
     58                     * a python file-like object
     59 
     60     Returns:
     61             The location of the file or directory where the requested
     62             content was saved. This will be contained in a temporary
     63             directory on the local host. If the material to get was a
     64             directory, the location will contain a trailing '/'
     65     """
     66     tmpdir = get_tmp_dir()
     67 
     68     # location is a file-like object
     69     if hasattr(location, "read"):
     70         tmpfile = os.path.join(tmpdir, "file")
     71         tmpfileobj = file(tmpfile, 'w')
     72         shutil.copyfileobj(location, tmpfileobj)
     73         tmpfileobj.close()
     74         return tmpfile
     75 
     76     if isinstance(location, types.StringTypes):
     77         # location is a URL
     78         if location.startswith('http') or location.startswith('ftp'):
     79             tmpfile = os.path.join(tmpdir, os.path.basename(location))
     80             utils.urlretrieve(location, tmpfile)
     81             return tmpfile
     82         # location is a local path
     83         elif os.path.exists(os.path.abspath(location)):
     84             if not local_copy:
     85                 if os.path.isdir(location):
     86                     return location.rstrip('/') + '/'
     87                 else:
     88                     return location
     89             tmpfile = os.path.join(tmpdir, os.path.basename(location))
     90             if os.path.isdir(location):
     91                 tmpfile += '/'
     92                 shutil.copytree(location, tmpfile, symlinks=True)
     93                 return tmpfile
     94             shutil.copyfile(location, tmpfile)
     95             return tmpfile
     96         # location is just a string, dump it to a file
     97         else:
     98             tmpfd, tmpfile = tempfile.mkstemp(dir=tmpdir)
     99             tmpfileobj = os.fdopen(tmpfd, 'w')
    100             tmpfileobj.write(location)
    101             tmpfileobj.close()
    102             return tmpfile
    103 
    104 
    105 def get_tmp_dir():
    106     """Return the pathname of a directory on the host suitable
    107     for temporary file storage.
    108 
    109     The directory and its content will be deleted automatically
    110     at the end of the program execution if they are still present.
    111     """
    112     dir_name = tempfile.mkdtemp(prefix="autoserv-")
    113     pid = os.getpid()
    114     if not pid in __tmp_dirs:
    115         __tmp_dirs[pid] = []
    116     __tmp_dirs[pid].append(dir_name)
    117     return dir_name
    118 
    119 
    120 def __clean_tmp_dirs():
    121     """Erase temporary directories that were created by the get_tmp_dir()
    122     function and that are still present.
    123     """
    124     pid = os.getpid()
    125     if pid not in __tmp_dirs:
    126         return
    127     for dir in __tmp_dirs[pid]:
    128         try:
    129             shutil.rmtree(dir)
    130         except OSError, e:
    131             if e.errno == 2:
    132                 pass
    133     __tmp_dirs[pid] = []
    134 atexit.register(__clean_tmp_dirs)
    135 subcommand.subcommand.register_join_hook(lambda _: __clean_tmp_dirs())
    136 
    137 
    138 def unarchive(host, source_material):
    139     """Uncompress and untar an archive on a host.
    140 
    141     If the "source_material" is compresses (according to the file
    142     extension) it will be uncompressed. Supported compression formats
    143     are gzip and bzip2. Afterwards, if the source_material is a tar
    144     archive, it will be untarred.
    145 
    146     Args:
    147             host: the host object on which the archive is located
    148             source_material: the path of the archive on the host
    149 
    150     Returns:
    151             The file or directory name of the unarchived source material.
    152             If the material is a tar archive, it will be extracted in the
    153             directory where it is and the path returned will be the first
    154             entry in the archive, assuming it is the topmost directory.
    155             If the material is not an archive, nothing will be done so this
    156             function is "harmless" when it is "useless".
    157     """
    158     # uncompress
    159     if (source_material.endswith(".gz") or
    160             source_material.endswith(".gzip")):
    161         host.run('gunzip "%s"' % (utils.sh_escape(source_material)))
    162         source_material= ".".join(source_material.split(".")[:-1])
    163     elif source_material.endswith("bz2"):
    164         host.run('bunzip2 "%s"' % (utils.sh_escape(source_material)))
    165         source_material= ".".join(source_material.split(".")[:-1])
    166 
    167     # untar
    168     if source_material.endswith(".tar"):
    169         retval= host.run('tar -C "%s" -xvf "%s"' % (
    170                 utils.sh_escape(os.path.dirname(source_material)),
    171                 utils.sh_escape(source_material),))
    172         source_material= os.path.join(os.path.dirname(source_material),
    173                 retval.stdout.split()[0])
    174 
    175     return source_material
    176 
    177 
    178 def get_server_dir():
    179     path = os.path.dirname(sys.modules['autotest_lib.server.utils'].__file__)
    180     return os.path.abspath(path)
    181 
    182 
    183 def find_pid(command):
    184     for line in utils.system_output('ps -eo pid,cmd').rstrip().split('\n'):
    185         (pid, cmd) = line.split(None, 1)
    186         if re.search(command, cmd):
    187             return int(pid)
    188     return None
    189 
    190 
    191 def nohup(command, stdout='/dev/null', stderr='/dev/null', background=True,
    192                                                                 env = {}):
    193     cmd = ' '.join(key+'='+val for key, val in env.iteritems())
    194     cmd += ' nohup ' + command
    195     cmd += ' > %s' % stdout
    196     if stdout == stderr:
    197         cmd += ' 2>&1'
    198     else:
    199         cmd += ' 2> %s' % stderr
    200     if background:
    201         cmd += ' &'
    202     utils.system(cmd)
    203 
    204 
    205 def default_mappings(machines):
    206     """
    207     Returns a simple mapping in which all machines are assigned to the
    208     same key.  Provides the default behavior for
    209     form_ntuples_from_machines. """
    210     mappings = {}
    211     failures = []
    212 
    213     mach = machines[0]
    214     mappings['ident'] = [mach]
    215     if len(machines) > 1:
    216         machines = machines[1:]
    217         for machine in machines:
    218             mappings['ident'].append(machine)
    219 
    220     return (mappings, failures)
    221 
    222 
    223 def form_ntuples_from_machines(machines, n=2, mapping_func=default_mappings):
    224     """Returns a set of ntuples from machines where the machines in an
    225        ntuple are in the same mapping, and a set of failures which are
    226        (machine name, reason) tuples."""
    227     ntuples = []
    228     (mappings, failures) = mapping_func(machines)
    229 
    230     # now run through the mappings and create n-tuples.
    231     # throw out the odd guys out
    232     for key in mappings:
    233         key_machines = mappings[key]
    234         total_machines = len(key_machines)
    235 
    236         # form n-tuples
    237         while len(key_machines) >= n:
    238             ntuples.append(key_machines[0:n])
    239             key_machines = key_machines[n:]
    240 
    241         for mach in key_machines:
    242             failures.append((mach, "machine can not be tupled"))
    243 
    244     return (ntuples, failures)
    245 
    246 
    247 def parse_machine(machine, user='root', password='', port=22):
    248     """
    249     Parse the machine string user:pass@host:port and return it separately,
    250     if the machine string is not complete, use the default parameters
    251     when appropriate.
    252     """
    253 
    254     if '@' in machine:
    255         user, machine = machine.split('@', 1)
    256 
    257     if ':' in user:
    258         user, password = user.split(':', 1)
    259 
    260     # Brackets are required to protect an IPv6 address whenever a
    261     # [xx::xx]:port number (or a file [xx::xx]:/path/) is appended to
    262     # it. Do not attempt to extract a (non-existent) port number from
    263     # an unprotected/bare IPv6 address "xx::xx".
    264     # In the Python >= 3.3 future, 'import ipaddress' will parse
    265     # addresses; and maybe more.
    266     bare_ipv6 = '[' != machine[0] and re.search(r':.*:', machine)
    267 
    268     # Extract trailing :port number if any.
    269     if not bare_ipv6 and re.search(r':\d*$', machine):
    270         machine, port = machine.rsplit(':', 1)
    271         port = int(port)
    272 
    273     # Strip any IPv6 brackets (ssh does not support them).
    274     # We'll add them back later for rsync, scp, etc.
    275     if machine[0] == '[' and machine[-1] == ']':
    276         machine = machine[1:-1]
    277 
    278     if not machine or not user:
    279         raise ValueError
    280 
    281     return machine, user, password, port
    282 
    283 
    284 def get_public_key():
    285     """
    286     Return a valid string ssh public key for the user executing autoserv or
    287     autotest. If there's no DSA or RSA public key, create a DSA keypair with
    288     ssh-keygen and return it.
    289     """
    290 
    291     ssh_conf_path = os.path.expanduser('~/.ssh')
    292 
    293     dsa_public_key_path = os.path.join(ssh_conf_path, 'id_dsa.pub')
    294     dsa_private_key_path = os.path.join(ssh_conf_path, 'id_dsa')
    295 
    296     rsa_public_key_path = os.path.join(ssh_conf_path, 'id_rsa.pub')
    297     rsa_private_key_path = os.path.join(ssh_conf_path, 'id_rsa')
    298 
    299     has_dsa_keypair = os.path.isfile(dsa_public_key_path) and \
    300         os.path.isfile(dsa_private_key_path)
    301     has_rsa_keypair = os.path.isfile(rsa_public_key_path) and \
    302         os.path.isfile(rsa_private_key_path)
    303 
    304     if has_dsa_keypair:
    305         print 'DSA keypair found, using it'
    306         public_key_path = dsa_public_key_path
    307 
    308     elif has_rsa_keypair:
    309         print 'RSA keypair found, using it'
    310         public_key_path = rsa_public_key_path
    311 
    312     else:
    313         print 'Neither RSA nor DSA keypair found, creating DSA ssh key pair'
    314         utils.system('ssh-keygen -t dsa -q -N "" -f %s' % dsa_private_key_path)
    315         public_key_path = dsa_public_key_path
    316 
    317     public_key = open(public_key_path, 'r')
    318     public_key_str = public_key.read()
    319     public_key.close()
    320 
    321     return public_key_str
    322 
    323 
    324 def get_sync_control_file(control, host_name, host_num,
    325                           instance, num_jobs, port_base=63100):
    326     """
    327     This function is used when there is a need to run more than one
    328     job simultaneously starting exactly at the same time. It basically returns
    329     a modified control file (containing the synchronization code prepended)
    330     whenever it is ready to run the control file. The synchronization
    331     is done using barriers to make sure that the jobs start at the same time.
    332 
    333     Here is how the synchronization is done to make sure that the tests
    334     start at exactly the same time on the client.
    335     sc_bar is a server barrier and s_bar, c_bar are the normal barriers
    336 
    337                       Job1              Job2         ......      JobN
    338      Server:   |                        sc_bar
    339      Server:   |                        s_bar        ......      s_bar
    340      Server:   |      at.run()         at.run()      ......      at.run()
    341      ----------|------------------------------------------------------
    342      Client    |      sc_bar
    343      Client    |      c_bar             c_bar        ......      c_bar
    344      Client    |    <run test>         <run test>    ......     <run test>
    345 
    346     @param control: The control file which to which the above synchronization
    347             code will be prepended.
    348     @param host_name: The host name on which the job is going to run.
    349     @param host_num: (non negative) A number to identify the machine so that
    350             we have different sets of s_bar_ports for each of the machines.
    351     @param instance: The number of the job
    352     @param num_jobs: Total number of jobs that are going to run in parallel
    353             with this job starting at the same time.
    354     @param port_base: Port number that is used to derive the actual barrier
    355             ports.
    356 
    357     @returns The modified control file.
    358     """
    359     sc_bar_port = port_base
    360     c_bar_port = port_base
    361     if host_num < 0:
    362         print "Please provide a non negative number for the host"
    363         return None
    364     s_bar_port = port_base + 1 + host_num # The set of s_bar_ports are
    365                                           # the same for a given machine
    366 
    367     sc_bar_timeout = 180
    368     s_bar_timeout = c_bar_timeout = 120
    369 
    370     # The barrier code snippet is prepended into the conrol file
    371     # dynamically before at.run() is called finally.
    372     control_new = []
    373 
    374     # jobid is the unique name used to identify the processes
    375     # trying to reach the barriers
    376     jobid = "%s#%d" % (host_name, instance)
    377 
    378     rendv = []
    379     # rendvstr is a temp holder for the rendezvous list of the processes
    380     for n in range(num_jobs):
    381         rendv.append("'%s#%d'" % (host_name, n))
    382     rendvstr = ",".join(rendv)
    383 
    384     if instance == 0:
    385         # Do the setup and wait at the server barrier
    386         # Clean up the tmp and the control dirs for the first instance
    387         control_new.append('if os.path.exists(job.tmpdir):')
    388         control_new.append("\t system('umount -f %s > /dev/null"
    389                            "2> /dev/null' % job.tmpdir,"
    390                            "ignore_status=True)")
    391         control_new.append("\t system('rm -rf ' + job.tmpdir)")
    392         control_new.append(
    393             'b0 = job.barrier("%s", "sc_bar", %d, port=%d)'
    394             % (jobid, sc_bar_timeout, sc_bar_port))
    395         control_new.append(
    396         'b0.rendezvous_servers("PARALLEL_MASTER", "%s")'
    397          % jobid)
    398 
    399     elif instance == 1:
    400         # Wait at the server barrier to wait for instance=0
    401         # process to complete setup
    402         b0 = barrier.barrier("PARALLEL_MASTER", "sc_bar", sc_bar_timeout,
    403                      port=sc_bar_port)
    404         b0.rendezvous_servers("PARALLEL_MASTER", jobid)
    405 
    406         if(num_jobs > 2):
    407             b1 = barrier.barrier(jobid, "s_bar", s_bar_timeout,
    408                          port=s_bar_port)
    409             b1.rendezvous(rendvstr)
    410 
    411     else:
    412         # For the rest of the clients
    413         b2 = barrier.barrier(jobid, "s_bar", s_bar_timeout, port=s_bar_port)
    414         b2.rendezvous(rendvstr)
    415 
    416     # Client side barrier for all the tests to start at the same time
    417     control_new.append('b1 = job.barrier("%s", "c_bar", %d, port=%d)'
    418                     % (jobid, c_bar_timeout, c_bar_port))
    419     control_new.append("b1.rendezvous(%s)" % rendvstr)
    420 
    421     # Stick in the rest of the control file
    422     control_new.append(control)
    423 
    424     return "\n".join(control_new)
    425