#
# Copyright 2008 Google Inc. All Rights Reserved.
#
"""
This module contains the generic CLI object.

High Level Design:

The atest class contains the attributes & methods generic to all the CLI
operations.

The class inheritance is shown here using the command
'atest host create ...' as an example:

atest <-- host <-- host_create <-- site_host_create

Note: The site_<topic>.py and its classes are only needed if you need
to override the common <topic>.py methods with your site specific ones.


High Level Algorithm:

1. Topic and action detection
   atest figures out the topic and action from the first 2 arguments
   on the command line and imports the <topic> (or site_<topic>)
   module.

2. Init
   The main atest module creates a <topic>_<action> object. The
   __init__() function is used to set up the parser options, if this
   <action> has some specific options to add to its <topic>.

   If it exists, the child __init__() method must call its parent
   class __init__() before adding its own parser arguments.

3. Parsing
   If the child wants to validate the parsing (e.g. make sure that
   there are hosts in the arguments), or if it wants to check the
   options it added in its __init__(), it should implement a parse()
   method.

   The child parser must call its parent parser and get back the
   options dictionary and the rest of the command line arguments
   (leftover). Each level gets to see all the options, but the
   leftovers can be deleted as they can be consumed by only one
   object.

4. Execution
   The execute() method is specific to the child and should use
   self.execute_rpc() to send commands to the Autotest Front-End. It
   should return results.

5. Output
   The child output() method is called with the execute() results as a
   parameter. This is child-specific, but should leverage the
   atest.print_*() methods.
"""

import optparse
import os
import re
import sys
import textwrap
import traceback
import urllib2

from autotest_lib.cli import rpc
from autotest_lib.client.common_lib.test_utils import mock


# Maps the AFE keys to printable names (used as column headers and as the
# 'Name' part of 'Name=value' output).
# Note: 'invalid' is displayed as 'Valid'; KEYS_CONVERT inverts its value.
KEYS_TO_NAMES_EN = {'hostname': 'Host',
                    'platform': 'Platform',
                    'status': 'Status',
                    'locked': 'Locked',
                    'locked_by': 'Locked by',
                    'lock_time': 'Locked time',
                    'lock_reason': 'Lock Reason',
                    'labels': 'Labels',
                    'description': 'Description',
                    'hosts': 'Hosts',
                    'users': 'Users',
                    'id': 'Id',
                    'name': 'Name',
                    'invalid': 'Valid',
                    'login': 'Login',
                    'access_level': 'Access Level',
                    'job_id': 'Job Id',
                    'job_owner': 'Job Owner',
                    'job_name': 'Job Name',
                    'test_type': 'Test Type',
                    'test_class': 'Test Class',
                    'path': 'Path',
                    'owner': 'Owner',
                    'status_counts': 'Status Counts',
                    'hosts_status': 'Host Status',
                    'hosts_selected_status': 'Hosts filtered by Status',
                    'priority': 'Priority',
                    'control_type': 'Control Type',
                    'created_on': 'Created On',
                    'synch_type': 'Synch Type',
                    'control_file': 'Control File',
                    'only_if_needed': 'Use only if needed',
                    'protection': 'Protection',
                    'run_verify': 'Run verify',
                    'reboot_before': 'Pre-job reboot',
                    'reboot_after': 'Post-job reboot',
                    'experimental': 'Experimental',
                    'synch_count': 'Sync Count',
                    'max_number_of_machines': 'Max. hosts to use',
                    'parse_failed_repair': 'Include failed repair results',
                    # Nested key: looked up as item['atomic_group']['name'].
                    'atomic_group.name': 'Atomic Group Name',
                    'shard': 'Shard',
                    }

# Placeholder substituted for the failing item inside a failure message, so
# that identical messages about different items can be grouped together.
FAIL_TAG = '<XYZ>'

# Global socket timeout: uploading kernels can take much,
# much longer than the default.
# NOTE(review): 60*30 is presumably seconds (30 minutes); this constant is not
# referenced elsewhere in this file, so confirm its unit where it is used.
UPLOAD_SOCKET_TIMEOUT = 60*30


# Conversion functions to be called for printing,
# e.g. to print True/False for booleans.
125 def __convert_platform(field): 126 if field is None: 127 return "" 128 elif isinstance(field, int): 129 # Can be 0/1 for False/True 130 return str(bool(field)) 131 else: 132 # Can be a platform name 133 return field 134 135 136 def _int_2_bool_string(value): 137 return str(bool(value)) 138 139 KEYS_CONVERT = {'locked': _int_2_bool_string, 140 'invalid': lambda flag: str(bool(not flag)), 141 'only_if_needed': _int_2_bool_string, 142 'platform': __convert_platform, 143 'labels': lambda labels: ', '.join(labels), 144 'shards': lambda shard: shard.hostname if shard else ''} 145 146 147 def _get_item_key(item, key): 148 """Allow for lookups in nested dictionaries using '.'s within a key.""" 149 if key in item: 150 return item[key] 151 nested_item = item 152 for subkey in key.split('.'): 153 if not subkey: 154 raise ValueError('empty subkey in %r' % key) 155 try: 156 nested_item = nested_item[subkey] 157 except KeyError, e: 158 raise KeyError('%r - looking up key %r in %r' % 159 (e, key, nested_item)) 160 else: 161 return nested_item 162 163 164 class CliError(Exception): 165 """Error raised by cli calls. 166 """ 167 pass 168 169 170 class item_parse_info(object): 171 """Object keeping track of the parsing options. 
172 """ 173 174 def __init__(self, attribute_name, inline_option='', 175 filename_option='', use_leftover=False): 176 """Object keeping track of the parsing options that will 177 make up the content of the atest attribute: 178 attribute_name: the atest attribute name to populate (label) 179 inline_option: the option containing the items (--label) 180 filename_option: the option containing the filename (--blist) 181 use_leftover: whether to add the leftover arguments or not.""" 182 self.attribute_name = attribute_name 183 self.filename_option = filename_option 184 self.inline_option = inline_option 185 self.use_leftover = use_leftover 186 187 188 def get_values(self, options, leftover=[]): 189 """Returns the value for that attribute by accumualting all 190 the values found through the inline option, the parsing of the 191 file and the leftover""" 192 193 def __get_items(input, split_spaces=True): 194 """Splits a string of comma separated items. Escaped commas will not 195 be split. I.e. Splitting 'a, b\,c, d' will yield ['a', 'b,c', 'd']. 196 If split_spaces is set to False spaces will not be split. I.e. 197 Splitting 'a b, c\,d, e' will yield ['a b', 'c,d', 'e']""" 198 199 # Replace escaped slashes with null characters so we don't misparse 200 # proceeding commas. 201 input = input.replace(r'\\', '\0') 202 203 # Split on commas which are not preceded by a slash. 204 if not split_spaces: 205 split = re.split(r'(?<!\\),', input) 206 else: 207 split = re.split(r'(?<!\\),|\s', input) 208 209 # Convert null characters to single slashes and escaped commas to 210 # just plain commas. 
211 return (item.strip().replace('\0', '\\').replace(r'\,', ',') for 212 item in split if item.strip()) 213 214 if self.use_leftover: 215 add_on = leftover 216 leftover = [] 217 else: 218 add_on = [] 219 220 # Start with the add_on 221 result = set() 222 for items in add_on: 223 # Don't split on space here because the add-on 224 # may have some spaces (like the job name) 225 result.update(__get_items(items, split_spaces=False)) 226 227 # Process the inline_option, if any 228 try: 229 items = getattr(options, self.inline_option) 230 result.update(__get_items(items)) 231 except (AttributeError, TypeError): 232 pass 233 234 # Process the file list, if any and not empty 235 # The file can contain space and/or comma separated items 236 try: 237 flist = getattr(options, self.filename_option) 238 file_content = [] 239 for line in open(flist).readlines(): 240 file_content += __get_items(line) 241 if len(file_content) == 0: 242 raise CliError("Empty file %s" % flist) 243 result.update(file_content) 244 except (AttributeError, TypeError): 245 pass 246 except IOError: 247 raise CliError("Could not open file %s" % flist) 248 249 return list(result), leftover 250 251 252 class atest(object): 253 """Common class for generic processing 254 Should only be instantiated by itself for usage 255 references, otherwise, the <topic> objects should 256 be used.""" 257 msg_topic = ('[acl|host|job|label|shard|atomicgroup|test|user|server|' 258 'stable_version]') 259 usage_action = '[action]' 260 msg_items = '' 261 262 def invalid_arg(self, header, follow_up=''): 263 """Fail the command with error that command line has invalid argument. 264 265 @param header: Header of the error message. 266 @param follow_up: Extra error message, default to empty string. 
267 """ 268 twrap = textwrap.TextWrapper(initial_indent=' ', 269 subsequent_indent=' ') 270 rest = twrap.fill(follow_up) 271 272 if self.kill_on_failure: 273 self.invalid_syntax(header + rest) 274 else: 275 print >> sys.stderr, header + rest 276 277 278 def invalid_syntax(self, msg): 279 """Fail the command with error that the command line syntax is wrong. 280 281 @param msg: Error message. 282 """ 283 print 284 print >> sys.stderr, msg 285 print 286 print "usage:", 287 print self._get_usage() 288 print 289 sys.exit(1) 290 291 292 def generic_error(self, msg): 293 """Fail the command with a generic error. 294 295 @param msg: Error message. 296 """ 297 if self.debug: 298 traceback.print_exc() 299 print >> sys.stderr, msg 300 sys.exit(1) 301 302 303 def parse_json_exception(self, full_error): 304 """Parses the JSON exception to extract the bad 305 items and returns them 306 This is very kludgy for the moment, but we would need 307 to refactor the exceptions sent from the front end 308 to make this better. 309 310 @param full_error: The complete error message. 311 """ 312 errmsg = str(full_error).split('Traceback')[0].rstrip('\n') 313 parts = errmsg.split(':') 314 # Kludge: If there are 2 colons the last parts contains 315 # the items that failed. 316 if len(parts) != 3: 317 return [] 318 return [item.strip() for item in parts[2].split(',') if item.strip()] 319 320 321 def failure(self, full_error, item=None, what_failed='', fatal=False): 322 """If kill_on_failure, print this error and die, 323 otherwise, queue the error and accumulate all the items 324 that triggered the same error. 325 326 @param full_error: The complete error message. 327 @param item: Name of the actionable item, e.g., hostname. 328 @param what_failed: Name of the failed item. 329 @param fatal: True to exit the program with failure. 
330 """ 331 332 if self.debug: 333 errmsg = str(full_error) 334 else: 335 errmsg = str(full_error).split('Traceback')[0].rstrip('\n') 336 337 if self.kill_on_failure or fatal: 338 print >> sys.stderr, "%s\n %s" % (what_failed, errmsg) 339 sys.exit(1) 340 341 # Build a dictionary with the 'what_failed' as keys. The 342 # values are dictionaries with the errmsg as keys and a set 343 # of items as values. 344 # self.failed = 345 # {'Operation delete_host_failed': {'AclAccessViolation: 346 # set('host0', 'host1')}} 347 # Try to gather all the same error messages together, 348 # even if they contain the 'item' 349 if item and item in errmsg: 350 errmsg = errmsg.replace(item, FAIL_TAG) 351 if self.failed.has_key(what_failed): 352 self.failed[what_failed].setdefault(errmsg, set()).add(item) 353 else: 354 self.failed[what_failed] = {errmsg: set([item])} 355 356 357 def show_all_failures(self): 358 """Print all failure information. 359 """ 360 if not self.failed: 361 return 0 362 for what_failed in self.failed.keys(): 363 print >> sys.stderr, what_failed + ':' 364 for (errmsg, items) in self.failed[what_failed].iteritems(): 365 if len(items) == 0: 366 print >> sys.stderr, errmsg 367 elif items == set(['']): 368 print >> sys.stderr, ' ' + errmsg 369 elif len(items) == 1: 370 # Restore the only item 371 if FAIL_TAG in errmsg: 372 errmsg = errmsg.replace(FAIL_TAG, items.pop()) 373 else: 374 errmsg = '%s (%s)' % (errmsg, items.pop()) 375 print >> sys.stderr, ' ' + errmsg 376 else: 377 print >> sys.stderr, ' ' + errmsg + ' with <XYZ> in:' 378 twrap = textwrap.TextWrapper(initial_indent=' ', 379 subsequent_indent=' ') 380 items = list(items) 381 items.sort() 382 print >> sys.stderr, twrap.fill(', '.join(items)) 383 return 1 384 385 386 def __init__(self): 387 """Setup the parser common options""" 388 # Initialized for unit tests. 
389 self.afe = None 390 self.failed = {} 391 self.data = {} 392 self.debug = False 393 self.parse_delim = '|' 394 self.kill_on_failure = False 395 self.web_server = '' 396 self.verbose = False 397 self.no_confirmation = False 398 self.topic_parse_info = item_parse_info(attribute_name='not_used') 399 400 self.parser = optparse.OptionParser(self._get_usage()) 401 self.parser.add_option('-g', '--debug', 402 help='Print debugging information', 403 action='store_true', default=False) 404 self.parser.add_option('--kill-on-failure', 405 help='Stop at the first failure', 406 action='store_true', default=False) 407 self.parser.add_option('--parse', 408 help='Print the output using | ' 409 'separated key=value fields', 410 action='store_true', default=False) 411 self.parser.add_option('--parse-delim', 412 help='Delimiter to use to separate the ' 413 'key=value fields', default='|') 414 self.parser.add_option('--no-confirmation', 415 help=('Skip all confirmation in when function ' 416 'require_confirmation is called.'), 417 action='store_true', default=False) 418 self.parser.add_option('-v', '--verbose', 419 action='store_true', default=False) 420 self.parser.add_option('-w', '--web', 421 help='Specify the autotest server ' 422 'to talk to', 423 action='store', type='string', 424 dest='web_server', default=None) 425 426 427 def _get_usage(self): 428 return "atest %s %s [options] %s" % (self.msg_topic.lower(), 429 self.usage_action, 430 self.msg_items) 431 432 433 def backward_compatibility(self, action, argv): 434 """To be overidden by subclass if their syntax changed. 435 436 @param action: Name of the action. 437 @param argv: A list of arguments. 438 """ 439 return action 440 441 442 def parse(self, parse_info=[], req_items=None): 443 """Parse command arguments. 444 445 parse_info is a list of item_parse_info objects. 446 There should only be one use_leftover set to True in the list. 447 448 Also check that the req_items is not empty after parsing. 
449 450 @param parse_info: A list of item_parse_info objects. 451 @param req_items: A list of required items. 452 """ 453 (options, leftover) = self.parse_global() 454 455 all_parse_info = parse_info[:] 456 all_parse_info.append(self.topic_parse_info) 457 458 try: 459 for item_parse_info in all_parse_info: 460 values, leftover = item_parse_info.get_values(options, 461 leftover) 462 setattr(self, item_parse_info.attribute_name, values) 463 except CliError, s: 464 self.invalid_syntax(s) 465 466 if (req_items and not getattr(self, req_items, None)): 467 self.invalid_syntax('%s %s requires at least one %s' % 468 (self.msg_topic, 469 self.usage_action, 470 self.msg_topic)) 471 472 return (options, leftover) 473 474 475 def parse_global(self): 476 """Parse the global arguments. 477 478 It consumes what the common object needs to know, and 479 let the children look at all the options. We could 480 remove the options that we have used, but there is no 481 harm in leaving them, and the children may need them 482 in the future. 483 484 Must be called from its children parse()""" 485 (options, leftover) = self.parser.parse_args() 486 # Handle our own options setup in __init__() 487 self.debug = options.debug 488 self.kill_on_failure = options.kill_on_failure 489 490 if options.parse: 491 suffix = '_parse' 492 else: 493 suffix = '_std' 494 for func in ['print_fields', 'print_table', 495 'print_by_ids', 'print_list']: 496 setattr(self, func, getattr(self, func + suffix)) 497 498 self.parse_delim = options.parse_delim 499 500 self.verbose = options.verbose 501 self.no_confirmation = options.no_confirmation 502 self.web_server = options.web_server 503 try: 504 self.afe = rpc.afe_comm(self.web_server) 505 except rpc.AuthError, s: 506 self.failure(str(s), fatal=True) 507 508 return (options, leftover) 509 510 511 def check_and_create_items(self, op_get, op_create, 512 items, **data_create): 513 """Create the items if they don't exist already. 
514 515 @param op_get: Name of `get` RPC. 516 @param op_create: Name of `create` RPC. 517 @param items: Actionable items specified in CLI command, e.g., hostname, 518 to be passed to each RPC. 519 @param data_create: Data to be passed to `create` RPC. 520 """ 521 for item in items: 522 ret = self.execute_rpc(op_get, name=item) 523 524 if len(ret) == 0: 525 try: 526 data_create['name'] = item 527 self.execute_rpc(op_create, **data_create) 528 except CliError: 529 continue 530 531 532 def execute_rpc(self, op, item='', **data): 533 """Execute RPC. 534 535 @param op: Name of the RPC. 536 @param item: Actionable item specified in CLI command. 537 @param data: Data to be passed to RPC. 538 """ 539 retry = 2 540 while retry: 541 try: 542 return self.afe.run(op, **data) 543 except urllib2.URLError, err: 544 if hasattr(err, 'reason'): 545 if 'timed out' not in err.reason: 546 self.invalid_syntax('Invalid server name %s: %s' % 547 (self.afe.web_server, err)) 548 if hasattr(err, 'code'): 549 error_parts = [str(err)] 550 if self.debug: 551 error_parts.append(err.read()) # read the response body 552 self.failure('\n\n'.join(error_parts), item=item, 553 what_failed=("Error received from web server")) 554 raise CliError("Error from web server") 555 if self.debug: 556 print 'retrying: %r %d' % (data, retry) 557 retry -= 1 558 if retry == 0: 559 if item: 560 myerr = '%s timed out for %s' % (op, item) 561 else: 562 myerr = '%s timed out' % op 563 self.failure(myerr, item=item, 564 what_failed=("Timed-out contacting " 565 "the Autotest server")) 566 raise CliError("Timed-out contacting the Autotest server") 567 except mock.CheckPlaybackError: 568 raise 569 except Exception, full_error: 570 # There are various exceptions throwns by JSON, 571 # urllib & httplib, so catch them all. 572 self.failure(full_error, item=item, 573 what_failed='Operation %s failed' % op) 574 raise CliError(str(full_error)) 575 576 577 # There is no output() method in the atest object (yet?) 
578 # but here are some helper functions to be used by its 579 # children 580 def print_wrapped(self, msg, values): 581 """Print given message and values in wrapped lines unless 582 AUTOTEST_CLI_NO_WRAP is specified in environment variables. 583 584 @param msg: Message to print. 585 @param values: A list of values to print. 586 """ 587 if len(values) == 0: 588 return 589 elif len(values) == 1: 590 print msg + ': ' 591 elif len(values) > 1: 592 if msg.endswith('s'): 593 print msg + ': ' 594 else: 595 print msg + 's: ' 596 597 values.sort() 598 599 if 'AUTOTEST_CLI_NO_WRAP' in os.environ: 600 print '\n'.join(values) 601 return 602 603 twrap = textwrap.TextWrapper(initial_indent='\t', 604 subsequent_indent='\t') 605 print twrap.fill(', '.join(values)) 606 607 608 def __conv_value(self, type, value): 609 return KEYS_CONVERT.get(type, str)(value) 610 611 612 def print_fields_std(self, items, keys, title=None): 613 """Print the keys in each item, one on each line. 614 615 @param items: Items to print. 616 @param keys: Name of the keys to look up each item in items. 617 @param title: Title of the output, default to None. 618 """ 619 if not items: 620 return 621 if title: 622 print title 623 for item in items: 624 for key in keys: 625 print '%s: %s' % (KEYS_TO_NAMES_EN[key], 626 self.__conv_value(key, 627 _get_item_key(item, key))) 628 629 630 def print_fields_parse(self, items, keys, title=None): 631 """Print the keys in each item as comma separated name=value 632 633 @param items: Items to print. 634 @param keys: Name of the keys to look up each item in items. 635 @param title: Title of the output, default to None. 
636 """ 637 for item in items: 638 values = ['%s=%s' % (KEYS_TO_NAMES_EN[key], 639 self.__conv_value(key, 640 _get_item_key(item, key))) 641 for key in keys 642 if self.__conv_value(key, 643 _get_item_key(item, key)) != ''] 644 print self.parse_delim.join(values) 645 646 647 def __find_justified_fmt(self, items, keys): 648 """Find the max length for each field. 649 650 @param items: Items to lookup for. 651 @param keys: Name of the keys to look up each item in items. 652 """ 653 lens = {} 654 # Don't justify the last field, otherwise we have blank 655 # lines when the max is overlaps but the current values 656 # are smaller 657 if not items: 658 print "No results" 659 return 660 for key in keys[:-1]: 661 lens[key] = max(len(self.__conv_value(key, 662 _get_item_key(item, key))) 663 for item in items) 664 lens[key] = max(lens[key], len(KEYS_TO_NAMES_EN[key])) 665 lens[keys[-1]] = 0 666 667 return ' '.join(["%%-%ds" % lens[key] for key in keys]) 668 669 670 def print_dict(self, items, title=None, line_before=False): 671 """Print a dictionary. 672 673 @param items: Dictionary to print. 674 @param title: Title of the output, default to None. 675 @param line_before: True to print an empty line before the output, 676 default to False. 677 """ 678 if not items: 679 return 680 if line_before: 681 print 682 print title 683 for key, value in items.items(): 684 print '%s : %s' % (key, value) 685 686 687 def print_table_std(self, items, keys_header, sublist_keys=()): 688 """Print a mix of header and lists in a user readable format. 689 690 The headers are justified, the sublist_keys are wrapped. 691 692 @param items: Items to print. 693 @param keys_header: Header of the keys, use to look up in items. 694 @param sublist_keys: Keys for sublist in each item. 
695 """ 696 if not items: 697 return 698 fmt = self.__find_justified_fmt(items, keys_header) 699 header = tuple(KEYS_TO_NAMES_EN[key] for key in keys_header) 700 print fmt % header 701 for item in items: 702 values = tuple(self.__conv_value(key, 703 _get_item_key(item, key)) 704 for key in keys_header) 705 print fmt % values 706 if sublist_keys: 707 for key in sublist_keys: 708 self.print_wrapped(KEYS_TO_NAMES_EN[key], 709 _get_item_key(item, key)) 710 print '\n' 711 712 713 def print_table_parse(self, items, keys_header, sublist_keys=()): 714 """Print a mix of header and lists in a user readable format. 715 716 @param items: Items to print. 717 @param keys_header: Header of the keys, use to look up in items. 718 @param sublist_keys: Keys for sublist in each item. 719 """ 720 for item in items: 721 values = ['%s=%s' % (KEYS_TO_NAMES_EN[key], 722 self.__conv_value(key, _get_item_key(item, key))) 723 for key in keys_header 724 if self.__conv_value(key, 725 _get_item_key(item, key)) != ''] 726 727 if sublist_keys: 728 [values.append('%s=%s'% (KEYS_TO_NAMES_EN[key], 729 ','.join(_get_item_key(item, key)))) 730 for key in sublist_keys 731 if len(_get_item_key(item, key))] 732 733 print self.parse_delim.join(values) 734 735 736 def print_by_ids_std(self, items, title=None, line_before=False): 737 """Prints ID & names of items in a user readable form. 738 739 @param items: Items to print. 740 @param title: Title of the output, default to None. 741 @param line_before: True to print an empty line before the output, 742 default to False. 743 """ 744 if not items: 745 return 746 if line_before: 747 print 748 if title: 749 print title + ':' 750 self.print_table_std(items, keys_header=['id', 'name']) 751 752 753 def print_by_ids_parse(self, items, title=None, line_before=False): 754 """Prints ID & names of items in a parseable format. 755 756 @param items: Items to print. 757 @param title: Title of the output, default to None. 
758 @param line_before: True to print an empty line before the output, 759 default to False. 760 """ 761 if not items: 762 return 763 if line_before: 764 print 765 if title: 766 print title + '=', 767 values = [] 768 for item in items: 769 values += ['%s=%s' % (KEYS_TO_NAMES_EN[key], 770 self.__conv_value(key, 771 _get_item_key(item, key))) 772 for key in ['id', 'name'] 773 if self.__conv_value(key, 774 _get_item_key(item, key)) != ''] 775 print self.parse_delim.join(values) 776 777 778 def print_list_std(self, items, key): 779 """Print a wrapped list of results 780 781 @param items: Items to to lookup for given key, could be a nested 782 dictionary. 783 @param key: Name of the key to look up for value. 784 """ 785 if not items: 786 return 787 print ' '.join(_get_item_key(item, key) for item in items) 788 789 790 def print_list_parse(self, items, key): 791 """Print a wrapped list of results. 792 793 @param items: Items to to lookup for given key, could be a nested 794 dictionary. 795 @param key: Name of the key to look up for value. 796 """ 797 if not items: 798 return 799 print '%s=%s' % (KEYS_TO_NAMES_EN[key], 800 ','.join(_get_item_key(item, key) for item in items)) 801 802 803 @staticmethod 804 def prompt_confirmation(message=None): 805 """Prompt a question for user to confirm the action before proceeding. 806 807 @param message: A detailed message to explain possible impact of the 808 action. 809 810 @return: True to proceed or False to abort. 811 """ 812 if message: 813 print message 814 sys.stdout.write('Continue? [y/N] ') 815 read = raw_input().lower() 816 if read == 'y': 817 return True 818 else: 819 print 'User did not confirm. Aborting...' 820 return False 821 822 823 @staticmethod 824 def require_confirmation(message=None): 825 """Decorator to prompt a question for user to confirm action before 826 proceeding. 827 828 If user chooses not to proceed, do not call the function. 
829 830 @param message: A detailed message to explain possible impact of the 831 action. 832 833 @return: A decorator wrapper for calling the actual function. 834 """ 835 def deco_require_confirmation(func): 836 """Wrapper for the decorator. 837 838 @param func: Function to be called. 839 840 @return: the actual decorator to call the function. 841 """ 842 def func_require_confirmation(*args, **kwargs): 843 """Decorator to prompt a question for user to confirm. 844 845 @param message: A detailed message to explain possible impact of 846 the action. 847 """ 848 if (args[0].no_confirmation or 849 atest.prompt_confirmation(message)): 850 func(*args, **kwargs) 851 852 return func_require_confirmation 853 return deco_require_confirmation 854