# Copyright 2008 Google Inc, Martin J. Bligh <mbligh@google.com>,
#                Benjamin Poirier, Ryan Stutsman
# Released under the GPL v2
"""
Miscellaneous small functions.

DO NOT import this file directly - it is mixed in by server/utils.py,
import that instead
"""

import atexit
import errno
import os
import re
import shutil
import sys
import tempfile
import textwrap
import types

from autotest_lib.client.common_lib import barrier, utils
from autotest_lib.server import subcommand


# Maps a process id to the list of temporary directories that process
# created through get_tmp_dir().
__tmp_dirs = {}


def scp_remote_escape(filename):
    """
    Escape special characters from a filename so that it can be passed
    to scp (within double quotes) as a remote file.

    Bis-quoting has to be used with scp for remote files, "bis-quoting"
    as in quoting x 2. scp does not support a newline in the filename.

    Args:
        filename: the filename string to escape.

    Returns:
        The escaped filename string. The required enclosing double
        quotes are NOT added and so should be added at some point by
        the caller.
    """
    escape_chars = r' !"$&' "'" r'()*,:;<=>?[\]^`{|}'

    # First quoting level: backslash-escape every special character for
    # the remote side.
    new_name = []
    for char in filename:
        if char in escape_chars:
            new_name.append("\\%s" % (char,))
        else:
            new_name.append(char)

    # Second quoting level, applied to the already-escaped name for the
    # local double-quoted context.
    return utils.sh_escape("".join(new_name))


def get(location, local_copy = False):
    """Get a file or directory to a local temporary directory.

    Args:
        location: the source of the material to get. This source may
                  be one of:
                  * a local file or directory
                  * a URL (any string starting with 'http' or 'ftp')
                  * a python file-like object (anything with a read()
                    method)
                  Any other string is treated as literal content and is
                  written to a new temporary file.
        local_copy: when False (the default) an existing local path is
                  returned as-is instead of being copied into the
                  temporary directory.

    Returns:
        The location of the file or directory where the requested
        content was saved. Unless an existing local path is returned
        as-is, this will be contained in a temporary directory on the
        local host. If the material to get was a directory, the
        location will contain a trailing '/'.
        None is returned if location is neither file-like nor a string.
    """
    tmpdir = get_tmp_dir()

    # location is a file-like object
    if hasattr(location, "read"):
        tmpfile = os.path.join(tmpdir, "file")
        with open(tmpfile, 'w') as tmpfileobj:
            shutil.copyfileobj(location, tmpfileobj)
        return tmpfile

    if isinstance(location, types.StringTypes):
        # location is a URL
        if location.startswith('http') or location.startswith('ftp'):
            tmpfile = os.path.join(tmpdir, os.path.basename(location))
            utils.urlretrieve(location, tmpfile)
            return tmpfile
        # location is a local path
        elif os.path.exists(os.path.abspath(location)):
            if not local_copy:
                if os.path.isdir(location):
                    return location.rstrip('/') + '/'
                else:
                    return location
            # Strip trailing slashes first: basename("some/dir/") is '',
            # which would make copytree() target the (already existing)
            # tmpdir itself and fail.
            tmpfile = os.path.join(tmpdir,
                                   os.path.basename(location.rstrip('/')))
            if os.path.isdir(location):
                tmpfile += '/'
                shutil.copytree(location, tmpfile, symlinks=True)
                return tmpfile
            shutil.copyfile(location, tmpfile)
            return tmpfile
        # location is just a string, dump it to a file
        else:
            tmpfd, tmpfile = tempfile.mkstemp(dir=tmpdir)
            tmpfileobj = os.fdopen(tmpfd, 'w')
            try:
                tmpfileobj.write(location)
            finally:
                tmpfileobj.close()
            return tmpfile


def get_tmp_dir():
    """Return the pathname of a directory on the host suitable
    for temporary file storage.

    A new directory (name prefixed with "autoserv-") is created on every
    call and recorded under the current process id. The directory and its
    content will be deleted automatically at the end of the program
    execution if they are still present.
    """
    dir_name = tempfile.mkdtemp(prefix="autoserv-")
    __tmp_dirs.setdefault(os.getpid(), []).append(dir_name)
    return dir_name


def __clean_tmp_dirs():
    """Erase temporary directories that were created by the get_tmp_dir()
    function and that are still present.

    Only directories recorded under the current process id are removed,
    so a forked child does not delete directories registered by its
    parent. Cleanup is best-effort: a directory that no longer exists is
    ignored, and any other removal error is reported on stderr without
    stopping the cleanup of the remaining directories.
    """
    pid = os.getpid()
    if pid not in __tmp_dirs:
        return
    for tmp_dir in __tmp_dirs[pid]:
        try:
            shutil.rmtree(tmp_dir)
        except OSError as e:
            if e.errno != errno.ENOENT:
                sys.stderr.write('Failed to remove temporary directory '
                                 '%s: %s\n' % (tmp_dir, e))
    __tmp_dirs[pid] = []
atexit.register(__clean_tmp_dirs)
# NOTE(review): presumably runs when a subcommand is joined; the hook's
# argument is ignored.
subcommand.subcommand.register_join_hook(lambda _: __clean_tmp_dirs())


def unarchive(host, source_material):
    """Uncompress and untar an archive on a host.

    If the "source_material" is compressed (according to the file
    extension) it will be uncompressed: names ending in .gz, .gzip or
    .tgz are passed to gunzip, names ending in .bz2, .tbz2 or .tbz to
    bunzip2. Afterwards, if the source_material is a tar archive, it
    will be untarred.

    Args:
        host: the host object on which the archive is located
        source_material: the path of the archive on the host

    Returns:
        The file or directory name of the unarchived source material.
        If the material is a tar archive, it will be extracted in the
        directory where it is and the path returned will be the first
        entry in the archive, assuming it is the topmost directory (if
        tar reports no entries, the tar path itself is returned).
        If the material is not an archive, nothing will be done so this
        function is "harmless" when it is "useless".
    """
    # (suffix, command, suffix of the decompressed file). gunzip/bunzip2
    # write the output next to the input, named after it minus the
    # compression suffix, except that the .tgz/.tbz2/.tbz shorthands
    # decompress to a .tar file. No suffix here is a suffix of another,
    # so the order of the entries does not matter.
    decompressors = (
        ('.gz', 'gunzip', ''),
        ('.gzip', 'gunzip', ''),
        ('.tgz', 'gunzip', '.tar'),
        ('.bz2', 'bunzip2', ''),
        ('.tbz2', 'bunzip2', '.tar'),
        ('.tbz', 'bunzip2', '.tar'),
    )
    for suffix, command, new_suffix in decompressors:
        if source_material.endswith(suffix):
            host.run('%s "%s"' % (command, utils.sh_escape(source_material)))
            source_material = source_material[:-len(suffix)] + new_suffix
            break

    # untar
    if source_material.endswith(".tar"):
        archive_dir = os.path.dirname(source_material)
        # tar -C "" fails, so a bare relative name extracts into '.'.
        retval = host.run('tar -C "%s" -xvf "%s"' % (
                utils.sh_escape(archive_dir or '.'),
                utils.sh_escape(source_material)))
        # With -v, (GNU) tar lists one extracted member per line on stdout.
        # Split on lines rather than whitespace so that member names
        # containing spaces survive intact.
        members = retval.stdout.splitlines()
        if members:
            source_material = os.path.join(archive_dir, members[0])

    return source_material


def get_server_dir():
    """Return the absolute path of the directory that contains the
    already-imported autotest_lib.server.utils module."""
    path = os.path.dirname(sys.modules['autotest_lib.server.utils'].__file__)
    return os.path.abspath(path)


def find_pid(command):
    """Return the pid of the first local process (in `ps` output order)
    whose command line matches the regular expression `command`, or None
    if no process matches."""
    lines = utils.system_output('ps -eo pid,cmd').rstrip().split('\n')
    # The first line is the "PID CMD" header, not a process.
    for line in lines[1:]:
        fields = line.split(None, 1)
        if len(fields) != 2:
            # No command column on this line; nothing to match against.
            continue
        pid, cmd = fields
        if re.search(command, cmd):
            return int(pid)
    return None


def nohup(command, stdout='/dev/null', stderr='/dev/null', background=True,
          env=None):
    """Run `command` locally under nohup.

    Args:
        command: the shell command line to run.
        stdout: file that receives the command's standard output.
        stderr: file that receives the command's standard error; when it
            equals stdout, stderr is merged into stdout with 2>&1.
        background: if True, append '&' so the command runs in the
            background.
        env: optional dict of environment variables, prepended to the
            command line as unquoted KEY=VALUE assignments.
    """
    if env is None:
        env = {}
    cmd = ' '.join(key + '=' + val for key, val in env.items())
    cmd += ' nohup ' + command
    cmd += ' > %s' % stdout
    if stdout == stderr:
        cmd += ' 2>&1'
    else:
        cmd += ' 2> %s' % stderr
    if background:
        cmd += ' &'
    utils.system(cmd)


def default_mappings(machines):
    """
    Returns a simple mapping in which all machines are assigned to the
    same key. Provides the default behavior for
    form_ntuples_from_machines.

    Returns:
        A (mappings, failures) tuple where mappings is
        {'ident': [every machine, in order]} and failures is an empty list.
    """
    mappings = {'ident': list(machines)}
    failures = []
    return (mappings, failures)


def form_ntuples_from_machines(machines, n=2, mapping_func=default_mappings):
    """Returns a list of ntuples from machines where the machines in an
    ntuple are in the same mapping, and a list of failures which are
    (machine name, reason) tuples.

    Args:
        machines: the machines to group.
        n: the number of machines per tuple; must be at least 1.
        mapping_func: callable taking `machines` and returning a
            (mappings, failures) tuple, where mappings maps a key to the
            list of machines sharing that key.

    Returns:
        A (ntuples, failures) tuple. ntuples holds consecutive n-sized
        slices of each mapping's machine list; machines left over at the
        end of a mapping are added to failures with the reason
        "machine can not be tupled".

    Raises:
        ValueError: if n < 1 (the grouping could never make progress).
    """
    if n < 1:
        raise ValueError('n must be at least 1, got %r' % (n,))

    ntuples = []
    (mappings, failures) = mapping_func(machines)

    # Run through the mappings and create n-tuples; the machines that do
    # not fill a complete tuple are thrown out as failures.
    for key in mappings:
        key_machines = mappings[key]
        usable = len(key_machines) - len(key_machines) % n
        for start in range(0, usable, n):
            ntuples.append(key_machines[start:start + n])
        for mach in key_machines[usable:]:
            failures.append((mach, "machine can not be tupled"))

    return (ntuples, failures)


def parse_machine(machine, user='root', password='', port=22):
    """
    Parse the machine string user:pass@host:port and return it separately,
    if the machine string is not complete, use the default parameters
    when appropriate.

    Returns:
        A (machine, user, password, port) tuple. A port parsed from the
        string is returned as an int; IPv6 brackets are stripped from the
        host.

    Raises:
        ValueError: if the host or user part is empty, or a trailing ':'
            is not followed by a port number.
    """
    machine_spec = machine

    if '@' in machine:
        user, machine = machine.split('@', 1)

    if ':' in user:
        user, password = user.split(':', 1)

    # Brackets are required to protect an IPv6 address whenever a
    # [xx::xx]:port number (or a file [xx::xx]:/path/) is appended to
    # it. Do not attempt to extract a (non-existent) port number from
    # an unprotected/bare IPv6 address "xx::xx".
    # In the Python >= 3.3 future, 'import ipaddress' will parse
    # addresses; and maybe more.
    # startswith()/endswith() (rather than indexing) keep an empty host
    # from raising IndexError before the ValueError check below.
    bare_ipv6 = not machine.startswith('[') and re.search(r':.*:', machine)

    # Extract trailing :port number if any.
    if not bare_ipv6 and re.search(r':\d*$', machine):
        machine, port = machine.rsplit(':', 1)
        port = int(port)

    # Strip any IPv6 brackets (ssh does not support them).
    # We'll add them back later for rsync, scp, etc.
    if machine.startswith('[') and machine.endswith(']'):
        machine = machine[1:-1]

    if not machine or not user:
        raise ValueError('Invalid machine specification: %r' % (machine_spec,))

    return machine, user, password, port


def get_public_key():
    """
    Return a valid string ssh public key for the user executing autoserv or
    autotest. If there's no DSA or RSA public key, create a DSA keypair with
    ssh-keygen and return it.

    A DSA keypair is preferred over an RSA one when both exist.
    NOTE(review): recent OpenSSH releases disable DSA keys by default;
    consider generating RSA (or newer) keys instead.
    """

    ssh_conf_path = os.path.expanduser('~/.ssh')

    dsa_public_key_path = os.path.join(ssh_conf_path, 'id_dsa.pub')
    dsa_private_key_path = os.path.join(ssh_conf_path, 'id_dsa')

    rsa_public_key_path = os.path.join(ssh_conf_path, 'id_rsa.pub')
    rsa_private_key_path = os.path.join(ssh_conf_path, 'id_rsa')

    has_dsa_keypair = (os.path.isfile(dsa_public_key_path) and
                       os.path.isfile(dsa_private_key_path))
    has_rsa_keypair = (os.path.isfile(rsa_public_key_path) and
                       os.path.isfile(rsa_private_key_path))

    if has_dsa_keypair:
        print('DSA keypair found, using it')
        public_key_path = dsa_public_key_path

    elif has_rsa_keypair:
        print('RSA keypair found, using it')
        public_key_path = rsa_public_key_path

    else:
        print('Neither RSA nor DSA keypair found, creating DSA ssh key pair')
        # Quote the path: the home directory may contain shell-special
        # characters.
        utils.system('ssh-keygen -t dsa -q -N "" -f "%s"'
                     % utils.sh_escape(dsa_private_key_path))
        public_key_path = dsa_public_key_path

    with open(public_key_path, 'r') as public_key:
        return public_key.read()


def get_sync_control_file(control, host_name, host_num,
                          instance, num_jobs, port_base=63100):
    """
    This function is used when there is a need to run more than one
    job simultaneously starting exactly at the same time. It basically returns
    a modified control file (containing the synchronization code prepended)
    whenever it is ready to run the control file. The synchronization
    is done using barriers to make sure that the jobs start at the same time.

    Here is how the synchronization is done to make sure that the tests
    start at exactly the same time on the client.
    sc_bar is a server barrier and s_bar, c_bar are the normal barriers

                      Job1              Job2         ......      JobN
     Server:   |                        sc_bar
     Server:   |                        s_bar        ......      s_bar
     Server:   |      at.run()         at.run()      ......      at.run()
     ----------|------------------------------------------------------
     Client    |      sc_bar
     Client    |      c_bar             c_bar        ......      c_bar
     Client    |    <run test>         <run test>    ......     <run test>

    @param control: The control file to which the above synchronization
            code will be prepended.
    @param host_name: The host name on which the job is going to run.
    @param host_num: (non negative) A number to identify the machine so that
            we have different sets of s_bar_ports for each of the machines.
    @param instance: The number of the job
    @param num_jobs: Total number of jobs that are going to run in parallel
            with this job starting at the same time.
    @param port_base: Port number that is used to derive the actual barrier
            ports.

    @returns The modified control file, or None if host_num is negative.
    """
    if host_num < 0:
        print("Please provide a non negative number for the host")
        return None

    sc_bar_port = port_base
    c_bar_port = port_base
    # The set of s_bar_ports are the same for a given machine.
    s_bar_port = port_base + 1 + host_num

    sc_bar_timeout = 180
    s_bar_timeout = c_bar_timeout = 120

    # The barrier code snippet is prepended into the control file
    # dynamically before at.run() is called finally.
    control_new = []

    # jobid is the unique name used to identify the processes
    # trying to reach the barriers
    jobid = "%s#%d" % (host_name, instance)

    # rendvstr is the comma-separated list of quoted job ids of every job
    # on this host, e.g. "'host#0','host#1'".
    rendvstr = ",".join("'%s#%d'" % (host_name, n) for n in range(num_jobs))

    if instance == 0:
        # Do the setup and wait at the server barrier
        # Clean up the tmp and the control dirs for the first instance
        control_new.append('if os.path.exists(job.tmpdir):')
        # The space before "2>" matters: without it the shell parses
        # "> /dev/null2 > /dev/null", creating a stray /dev/null2 file and
        # leaving stderr unredirected.
        control_new.append("\t system('umount -f %s > /dev/null"
                           " 2> /dev/null' % job.tmpdir,"
                           "ignore_status=True)")
        control_new.append("\t system('rm -rf ' + job.tmpdir)")
        control_new.append(
            'b0 = job.barrier("%s", "sc_bar", %d, port=%d)'
            % (jobid, sc_bar_timeout, sc_bar_port))
        control_new.append(
            'b0.rendezvous_servers("PARALLEL_MASTER", "%s")'
            % jobid)

    elif instance == 1:
        # Wait at the server barrier to wait for instance=0
        # process to complete setup
        b0 = barrier.barrier("PARALLEL_MASTER", "sc_bar", sc_bar_timeout,
                             port=sc_bar_port)
        b0.rendezvous_servers("PARALLEL_MASTER", jobid)

        if num_jobs > 2:
            b1 = barrier.barrier(jobid, "s_bar", s_bar_timeout,
                                 port=s_bar_port)
            # NOTE(review): rendvstr is passed here as ONE string argument,
            # while the generated control code below splices it into the
            # call as separate string literals; confirm that
            # barrier.rendezvous() handles this form.
            b1.rendezvous(rendvstr)

    else:
        # For the rest of the clients
        b2 = barrier.barrier(jobid, "s_bar", s_bar_timeout, port=s_bar_port)
        b2.rendezvous(rendvstr)

    # Client side barrier for all the tests to start at the same time
    control_new.append('b1 = job.barrier("%s", "c_bar", %d, port=%d)'
                       % (jobid, c_bar_timeout, c_bar_port))
    control_new.append("b1.rendezvous(%s)" % rendvstr)

    # Stick in the rest of the control file
    control_new.append(control)

    return "\n".join(control_new)