Home | History | Annotate | Download | only in repo_pull
      1 #!/usr/bin/env python3
      2 
      3 #
      4 # Copyright (C) 2018 The Android Open Source Project
      5 #
      6 # Licensed under the Apache License, Version 2.0 (the "License");
      7 # you may not use this file except in compliance with the License.
      8 # You may obtain a copy of the License at
      9 #
     10 #      http://www.apache.org/licenses/LICENSE-2.0
     11 #
     12 # Unless required by applicable law or agreed to in writing, software
     13 # distributed under the License is distributed on an "AS IS" BASIS,
     14 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     15 # See the License for the specific language governing permissions and
     16 # limitations under the License.
     17 #
     18 
     19 """A command line utility to pull multiple change lists from Gerrit."""
     20 
     21 from __future__ import print_function
     22 
     23 import argparse
     24 import collections
     25 import itertools
     26 import json
     27 import multiprocessing
     28 import os
     29 import re
     30 import sys
     31 import xml.dom.minidom
     32 
     33 from gerrit import create_url_opener_from_args, query_change_lists
     34 
     35 try:
     36     # pylint: disable=redefined-builtin
     37     from __builtin__ import raw_input as input  # PY2
     38 except ImportError:
     39     pass
     40 
     41 try:
     42     from shlex import quote as _sh_quote  # PY3.3
     43 except ImportError:
     44     # Shell language simple string pattern.  If a string matches this pattern,
     45     # it doesn't have to be quoted.
     46     _SHELL_SIMPLE_PATTERN = re.compile('^[a-zA-Z90-9_./-]+$')
     47 
     48     def _sh_quote(txt):
     49         """Quote a string if it contains special characters."""
     50         return txt if _SHELL_SIMPLE_PATTERN.match(txt) else json.dumps(txt)
     51 
     52 try:
     53     from subprocess import PIPE, run  # PY3.5
     54 except ImportError:
     55     from subprocess import CalledProcessError, PIPE, Popen
     56 
     57     class CompletedProcess(object):
     58         """Process execution result returned by subprocess.run()."""
     59         # pylint: disable=too-few-public-methods
     60 
     61         def __init__(self, args, returncode, stdout, stderr):
     62             self.args = args
     63             self.returncode = returncode
     64             self.stdout = stdout
     65             self.stderr = stderr
     66 
     67     def run(*args, **kwargs):
     68         """Run a command with subprocess.Popen() and redirect input/output."""
     69 
     70         check = kwargs.pop('check', False)
     71 
     72         try:
     73             stdin = kwargs.pop('input')
     74             assert 'stdin' not in kwargs
     75             kwargs['stdin'] = PIPE
     76         except KeyError:
     77             stdin = None
     78 
     79         proc = Popen(*args, **kwargs)
     80         try:
     81             stdout, stderr = proc.communicate(stdin)
     82         except:
     83             proc.kill()
     84             proc.wait()
     85             raise
     86         returncode = proc.wait()
     87 
     88         if check and returncode:
     89             raise CalledProcessError(returncode, args, stdout)
     90         return CompletedProcess(args, returncode, stdout, stderr)
     91 
     92 
     93 if bytes is str:
     94     def write_bytes(data, file):  # PY2
     95         """Write bytes to a file."""
     96         # pylint: disable=redefined-builtin
     97         file.write(data)
     98 else:
     99     def write_bytes(data, file):  # PY3
    100         """Write bytes to a file."""
    101         # pylint: disable=redefined-builtin
    102         file.buffer.write(data)
    103 
    104 
    105 def _confirm(question, default, file=sys.stderr):
    106     """Prompt a yes/no question and convert the answer to a boolean value."""
    107     # pylint: disable=redefined-builtin
    108     answers = {'': default, 'y': True, 'yes': True, 'n': False, 'no': False}
    109     suffix = '[Y/n] ' if default else ' [y/N] '
    110     while True:
    111         file.write(question + suffix)
    112         file.flush()
    113         ans = answers.get(input().lower())
    114         if ans is not None:
    115             return ans
    116 
    117 
    118 class ChangeList(object):
    119     """A ChangeList to be checked out."""
    120     # pylint: disable=too-few-public-methods,too-many-instance-attributes
    121 
    122     def __init__(self, project, fetch, commit_sha1, commit, change_list):
    123         """Initialize a ChangeList instance."""
    124         # pylint: disable=too-many-arguments
    125 
    126         self.project = project
    127         self.number = change_list['_number']
    128 
    129         self.fetch = fetch
    130 
    131         fetch_git = None
    132         for protocol in ('http', 'sso', 'rpc'):
    133             fetch_git = fetch.get(protocol)
    134             if fetch_git:
    135                 break
    136 
    137         if not fetch_git:
    138             raise ValueError(
    139                 'unknown fetch protocols: ' + str(list(fetch.keys())))
    140 
    141         self.fetch_url = fetch_git['url']
    142         self.fetch_ref = fetch_git['ref']
    143 
    144         self.commit_sha1 = commit_sha1
    145         self.commit = commit
    146         self.parents = commit['parents']
    147 
    148         self.change_list = change_list
    149 
    150 
    151     def is_merge(self):
    152         """Check whether this change list a merge commit."""
    153         return len(self.parents) > 1
    154 
    155 
    156 def find_manifest_xml(dir_path):
    157     """Find the path to manifest.xml for this Android source tree."""
    158     dir_path_prev = None
    159     while dir_path != dir_path_prev:
    160         path = os.path.join(dir_path, '.repo', 'manifest.xml')
    161         if os.path.exists(path):
    162             return path
    163         dir_path_prev = dir_path
    164         dir_path = os.path.dirname(dir_path)
    165     raise ValueError('.repo dir not found')
    166 
    167 
    168 def build_project_name_dir_dict(manifest_path):
    169     """Build the mapping from Gerrit project name to source tree project
    170     directory path."""
    171     project_dirs = {}
    172     parsed_xml = xml.dom.minidom.parse(manifest_path)
    173     projects = parsed_xml.getElementsByTagName('project')
    174     for project in projects:
    175         name = project.getAttribute('name')
    176         path = project.getAttribute('path')
    177         if path:
    178             project_dirs[name] = path
    179         else:
    180             project_dirs[name] = name
    181     return project_dirs
    182 
    183 
    184 def group_and_sort_change_lists(change_lists):
    185     """Build a dict that maps projects to a list of topologically sorted change
    186     lists."""
    187 
    188     # Build a dict that map projects to dicts that map commits to changes.
    189     projects = collections.defaultdict(dict)
    190     for change_list in change_lists:
    191         commit_sha1 = None
    192         for commit_sha1, value in change_list['revisions'].items():
    193             fetch = value['fetch']
    194             commit = value['commit']
    195 
    196         if not commit_sha1:
    197             raise ValueError('bad revision')
    198 
    199         project = change_list['project']
    200 
    201         project_changes = projects[project]
    202         if commit_sha1 in project_changes:
    203             raise KeyError('repeated commit sha1 "{}" in project "{}"'.format(
    204                 commit_sha1, project))
    205 
    206         project_changes[commit_sha1] = ChangeList(
    207             project, fetch, commit_sha1, commit, change_list)
    208 
    209     # Sort all change lists in a project in post ordering.
    210     def _sort_project_change_lists(changes):
    211         visited_changes = set()
    212         sorted_changes = []
    213 
    214         def _post_order_traverse(change):
    215             visited_changes.add(change)
    216             for parent in change.parents:
    217                 parent_change = changes.get(parent['commit'])
    218                 if parent_change and parent_change not in visited_changes:
    219                     _post_order_traverse(parent_change)
    220             sorted_changes.append(change)
    221 
    222         for change in sorted(changes.values(), key=lambda x: x.number):
    223             if change not in visited_changes:
    224                 _post_order_traverse(change)
    225 
    226         return sorted_changes
    227 
    228     # Sort changes in each projects
    229     sorted_changes = []
    230     for project in sorted(projects.keys()):
    231         sorted_changes.append(_sort_project_change_lists(projects[project]))
    232 
    233     return sorted_changes
    234 
    235 
    236 def _main_json(args):
    237     """Print the change lists in JSON format."""
    238     change_lists = _get_change_lists_from_args(args)
    239     json.dump(change_lists, sys.stdout, indent=4, separators=(', ', ': '))
    240     print()  # Print the end-of-line
    241 
    242 
    243 # Git commands for merge commits
    244 _MERGE_COMMANDS = {
    245     'merge': ['git', 'merge', '--no-edit'],
    246     'merge-ff-only': ['git', 'merge', '--no-edit', '--ff-only'],
    247     'merge-no-ff': ['git', 'merge', '--no-edit', '--no-ff'],
    248     'reset': ['git', 'reset', '--hard'],
    249     'checkout': ['git', 'checkout'],
    250 }
    251 
    252 
    253 # Git commands for non-merge commits
    254 _PICK_COMMANDS = {
    255     'pick': ['git', 'cherry-pick', '--allow-empty'],
    256     'merge': ['git', 'merge', '--no-edit'],
    257     'merge-ff-only': ['git', 'merge', '--no-edit', '--ff-only'],
    258     'merge-no-ff': ['git', 'merge', '--no-edit', '--no-ff'],
    259     'reset': ['git', 'reset', '--hard'],
    260     'checkout': ['git', 'checkout'],
    261 }
    262 
    263 
    264 def build_pull_commands(change, branch_name, merge_opt, pick_opt):
    265     """Build command lines for each change.  The command lines will be passed
    266     to subprocess.run()."""
    267 
    268     cmds = []
    269     if branch_name is not None:
    270         cmds.append(['repo', 'start', branch_name])
    271     cmds.append(['git', 'fetch', change.fetch_url, change.fetch_ref])
    272     if change.is_merge():
    273         cmds.append(_MERGE_COMMANDS[merge_opt] + ['FETCH_HEAD'])
    274     else:
    275         cmds.append(_PICK_COMMANDS[pick_opt] + ['FETCH_HEAD'])
    276     return cmds
    277 
    278 
    279 def _sh_quote_command(cmd):
    280     """Convert a command (an argument to subprocess.run()) to a shell command
    281     string."""
    282     return ' '.join(_sh_quote(x) for x in cmd)
    283 
    284 
    285 def _sh_quote_commands(cmds):
    286     """Convert multiple commands (arguments to subprocess.run()) to shell
    287     command strings."""
    288     return ' && '.join(_sh_quote_command(cmd) for cmd in cmds)
    289 
    290 
    291 def _main_bash(args):
    292     """Print the bash command to pull the change lists."""
    293 
    294     branch_name = _get_local_branch_name_from_args(args)
    295 
    296     manifest_path = _get_manifest_xml_from_args(args)
    297     project_dirs = build_project_name_dir_dict(manifest_path)
    298 
    299     change_lists = _get_change_lists_from_args(args)
    300     change_list_groups = group_and_sort_change_lists(change_lists)
    301 
    302     for changes in change_list_groups:
    303         for change in changes:
    304             project_dir = project_dirs.get(change.project, change.project)
    305             cmds = []
    306             cmds.append(['pushd', project_dir])
    307             cmds.extend(build_pull_commands(
    308                 change, branch_name, args.merge, args.pick))
    309             cmds.append(['popd'])
    310             print(_sh_quote_commands(cmds))
    311 
    312 
    313 def _do_pull_change_lists_for_project(task):
    314     """Pick a list of changes (usually under a project directory)."""
    315     changes, task_opts = task
    316 
    317     branch_name = task_opts['branch_name']
    318     merge_opt = task_opts['merge_opt']
    319     pick_opt = task_opts['pick_opt']
    320     project_dirs = task_opts['project_dirs']
    321 
    322     for i, change in enumerate(changes):
    323         try:
    324             cwd = project_dirs[change.project]
    325         except KeyError:
    326             err_msg = 'error: project "{}" cannot be found in manifest.xml\n'
    327             err_msg = err_msg.format(change.project).encode('utf-8')
    328             return (change, changes[i + 1:], [], err_msg)
    329 
    330         print(change.commit_sha1[0:10], i + 1, cwd)
    331         cmds = build_pull_commands(change, branch_name, merge_opt, pick_opt)
    332         for cmd in cmds:
    333             proc = run(cmd, cwd=cwd, stderr=PIPE)
    334             if proc.returncode != 0:
    335                 return (change, changes[i + 1:], cmd, proc.stderr)
    336     return None
    337 
    338 
    339 def _print_pull_failures(failures, file=sys.stderr):
    340     """Print pull failures and tracebacks."""
    341     # pylint: disable=redefined-builtin
    342 
    343     separator = '=' * 78
    344     separator_sub = '-' * 78
    345 
    346     print(separator, file=file)
    347     for failed_change, skipped_changes, cmd, errors in failures:
    348         print('PROJECT:', failed_change.project, file=file)
    349         print('FAILED COMMIT:', failed_change.commit_sha1, file=file)
    350         for change in skipped_changes:
    351             print('PENDING COMMIT:', change.commit_sha1, file=file)
    352         print(separator_sub, file=sys.stderr)
    353         print('FAILED COMMAND:', _sh_quote_command(cmd), file=file)
    354         write_bytes(errors, file=sys.stderr)
    355         print(separator, file=sys.stderr)
    356 
    357 
    358 def _main_pull(args):
    359     """Pull the change lists."""
    360 
    361     branch_name = _get_local_branch_name_from_args(args)
    362 
    363     manifest_path = _get_manifest_xml_from_args(args)
    364     project_dirs = build_project_name_dir_dict(manifest_path)
    365 
    366     # Collect change lists
    367     change_lists = _get_change_lists_from_args(args)
    368     change_list_groups = group_and_sort_change_lists(change_lists)
    369 
    370     # Build the options list for tasks
    371     task_opts = {
    372         'branch_name': branch_name,
    373         'merge_opt': args.merge,
    374         'pick_opt': args.pick,
    375         'project_dirs': project_dirs,
    376     }
    377 
    378     # Run the commands to pull the change lists
    379     if args.parallel <= 1:
    380         results = [_do_pull_change_lists_for_project((changes, task_opts))
    381                    for changes in change_list_groups]
    382     else:
    383         pool = multiprocessing.Pool(processes=args.parallel)
    384         results = pool.map(_do_pull_change_lists_for_project,
    385                            zip(change_list_groups, itertools.repeat(task_opts)))
    386 
    387     # Print failures and tracebacks
    388     failures = [result for result in results if result]
    389     if failures:
    390         _print_pull_failures(failures)
    391         sys.exit(1)
    392 
    393 
    394 def _parse_args():
    395     """Parse command line options."""
    396     parser = argparse.ArgumentParser()
    397 
    398     parser.add_argument('command', choices=['pull', 'bash', 'json'],
    399                         help='Commands')
    400 
    401     parser.add_argument('query', help='Change list query string')
    402     parser.add_argument('-g', '--gerrit', required=True,
    403                         help='Gerrit review URL')
    404 
    405     parser.add_argument('--gitcookies',
    406                         default=os.path.expanduser('~/.gitcookies'),
    407                         help='Gerrit cookie file')
    408     parser.add_argument('--manifest', help='Manifest')
    409     parser.add_argument('--limits', default=1000,
    410                         help='Max number of change lists')
    411 
    412     parser.add_argument('-m', '--merge',
    413                         choices=sorted(_MERGE_COMMANDS.keys()),
    414                         default='merge-ff-only',
    415                         help='Method to pull merge commits')
    416 
    417     parser.add_argument('-p', '--pick',
    418                         choices=sorted(_PICK_COMMANDS.keys()),
    419                         default='pick',
    420                         help='Method to pull merge commits')
    421 
    422     parser.add_argument('-b', '--branch',
    423                         help='Local branch name for `repo start`')
    424 
    425     parser.add_argument('-j', '--parallel', default=1, type=int,
    426                         help='Number of parallel running commands')
    427 
    428     return parser.parse_args()
    429 
    430 
    431 def _get_manifest_xml_from_args(args):
    432     """Get the path to manifest.xml from args."""
    433     manifest_path = args.manifest
    434     if not args.manifest:
    435         manifest_path = find_manifest_xml(os.getcwd())
    436     return manifest_path
    437 
    438 
    439 def _get_change_lists_from_args(args):
    440     """Query the change lists by args."""
    441     url_opener = create_url_opener_from_args(args)
    442     return query_change_lists(url_opener, args.gerrit, args.query, args.limits)
    443 
    444 
    445 def _get_local_branch_name_from_args(args):
    446     """Get the local branch name from args."""
    447     if not args.branch and not _confirm(
    448             'Do you want to continue without local branch name?', False):
    449         print('error: `-b` or `--branch` must be specified', file=sys.stderr)
    450         sys.exit(1)
    451     return args.branch
    452 
    453 
    454 def main():
    455     """Main function"""
    456     args = _parse_args()
    457     if args.command == 'json':
    458         _main_json(args)
    459     elif args.command == 'bash':
    460         _main_bash(args)
    461     elif args.command == 'pull':
    462         _main_pull(args)
    463     else:
    464         raise KeyError('unknown command')
    465 
    466 if __name__ == '__main__':
    467     main()
    468