Home | History | Annotate | Download | only in rh
      1 # -*- coding:utf-8 -*-
      2 # Copyright 2016 The Android Open Source Project
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #      http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 
     16 """Git helper functions."""
     17 
     18 from __future__ import print_function
     19 
     20 import os
     21 import re
     22 import sys
     23 
# Put the directory two levels above this file (presumably the one that
# contains the `rh` package -- confirm against the checkout layout) at the
# front of sys.path so that the `import rh.utils` below can resolve.  This
# must execute before that import; the temporary name is deleted afterwards
# so it does not leak into the module namespace.
_path = os.path.realpath(__file__ + '/../..')
if sys.path[0] != _path:
    # Only insert when it is not already first, to avoid duplicate entries
    # on repeated imports.
    sys.path.insert(0, _path)
del _path
     28 
     29 # pylint: disable=wrong-import-position
     30 import rh.utils
     31 
     32 
     33 def get_upstream_remote():
     34     """Returns the current upstream remote name."""
     35     # First get the current branch name.
     36     cmd = ['git', 'rev-parse', '--abbrev-ref', 'HEAD']
     37     result = rh.utils.run_command(cmd, capture_output=True)
     38     branch = result.output.strip()
     39 
     40     # Then get the remote associated with this branch.
     41     cmd = ['git', 'config', 'branch.%s.remote' % branch]
     42     result = rh.utils.run_command(cmd, capture_output=True)
     43     return result.output.strip()
     44 
     45 
     46 def get_upstream_branch():
     47     """Returns the upstream tracking branch of the current branch.
     48 
     49     Raises:
     50       Error if there is no tracking branch
     51     """
     52     cmd = ['git', 'symbolic-ref', 'HEAD']
     53     result = rh.utils.run_command(cmd, capture_output=True)
     54     current_branch = result.output.strip().replace('refs/heads/', '')
     55     if not current_branch:
     56         raise ValueError('Need to be on a tracking branch')
     57 
     58     cfg_option = 'branch.' + current_branch + '.%s'
     59     cmd = ['git', 'config', cfg_option % 'merge']
     60     result = rh.utils.run_command(cmd, capture_output=True)
     61     full_upstream = result.output.strip()
     62     # If remote is not fully qualified, add an implicit namespace.
     63     if '/' not in full_upstream:
     64         full_upstream = 'refs/heads/%s' % full_upstream
     65     cmd = ['git', 'config', cfg_option % 'remote']
     66     result = rh.utils.run_command(cmd, capture_output=True)
     67     remote = result.output.strip()
     68     if not remote or not full_upstream:
     69         raise ValueError('Need to be on a tracking branch')
     70 
     71     return full_upstream.replace('heads', 'remotes/' + remote)
     72 
     73 
     74 def get_commit_for_ref(ref):
     75     """Returns the latest commit for this ref."""
     76     cmd = ['git', 'rev-parse', ref]
     77     result = rh.utils.run_command(cmd, capture_output=True)
     78     return result.output.strip()
     79 
     80 
     81 def get_remote_revision(ref, remote):
     82     """Returns the remote revision for this ref."""
     83     prefix = 'refs/remotes/%s/' % remote
     84     if ref.startswith(prefix):
     85         return ref[len(prefix):]
     86     return ref
     87 
     88 
     89 def get_patch(commit):
     90     """Returns the patch for this commit."""
     91     cmd = ['git', 'format-patch', '--stdout', '-1', commit]
     92     return rh.utils.run_command(cmd, capture_output=True).output
     93 
     94 
     95 def _try_utf8_decode(data):
     96     """Attempts to decode a string as UTF-8.
     97 
     98     Returns:
     99       The decoded Unicode object, or the original string if parsing fails.
    100     """
    101     try:
    102         return unicode(data, 'utf-8', 'strict')
    103     except UnicodeDecodeError:
    104         return data
    105 
    106 
    107 def get_file_content(commit, path):
    108     """Returns the content of a file at a specific commit.
    109 
    110     We can't rely on the file as it exists in the filesystem as people might be
    111     uploading a series of changes which modifies the file multiple times.
    112 
    113     Note: The "content" of a symlink is just the target.  So if you're expecting
    114     a full file, you should check that first.  One way to detect is that the
    115     content will not have any newlines.
    116     """
    117     cmd = ['git', 'show', '%s:%s' % (commit, path)]
    118     return rh.utils.run_command(cmd, capture_output=True).output
    119 
    120 
# RawDiffEntry represents one parsed line of `git diff --raw` output; raw_diff
# builds instances from the named groups of DIFF_RE and then fills in `file`.
# Fields (the keyword values are presumably the per-field defaults -- the
# semantics of rh.utils.collection are not visible here, confirm there):
#   src_mode/dst_mode: file modes.  raw_diff converts the octal digit strings
#     with int(), i.e. as *decimal* numbers ('100644' -> 100644).
#   src_sha/dst_sha: blob ids as printed by git.
#   status/score: status letter and optional numeric score (still strings).
#   src_file/dst_file: paths; dst_file is only matched when git prints a
#     second path (renames/copies).
#   file: dst_file when set, otherwise src_file (computed by raw_diff).
# NOTE(review): raw_diff assigns to these attributes, so the generated type
# is presumably mutable.
RawDiffEntry = rh.utils.collection(
    'RawDiffEntry',
    src_mode=0, dst_mode=0, src_sha=None, dst_sha=None,
    status=None, score=None, src_file=None, dst_file=None, file=None)
    126 
    127 
# This regular expression pulls apart a line of raw formatted git diff output
# (`git diff --raw`), for example:
#   :100644 100644 bcd1234 0123456 M<TAB>file0
#   :100644 100644 abcd123 1234567 R086<TAB>old<TAB>new
# Its named groups are passed straight to RawDiffEntry by raw_diff.
# NOTE(review): paths that git quotes (core.quotePath) are not unquoted here.
DIFF_RE = re.compile(
    # Source and destination file modes (octal digits).
    r':(?P<src_mode>[0-7]*) (?P<dst_mode>[0-7]*) '
    # Source and destination blob ids.  The unnamed (\.)* groups swallow any
    # trailing dots that git may append to abbreviated ids.
    r'(?P<src_sha>[0-9a-f]*)(\.)* (?P<dst_sha>[0-9a-f]*)(\.)* '
    # A single status letter, optionally followed by a numeric score (as in
    # R086 for a rename), then a tab.
    r'(?P<status>[ACDMRTUX])(?P<score>[0-9]+)?\t'
    # Source path, plus an optional tab-separated destination path.
    r'(?P<src_file>[^\t]+)\t?(?P<dst_file>[^\t]+)?')
    134 
    135 
    136 def raw_diff(path, target):
    137     """Return the parsed raw format diff of target
    138 
    139     Args:
    140       path: Path to the git repository to diff in.
    141       target: The target to diff.
    142 
    143     Returns:
    144       A list of RawDiffEntry's.
    145     """
    146     entries = []
    147 
    148     cmd = ['git', 'diff', '--no-ext-diff', '-M', '--raw', target]
    149     diff = rh.utils.run_command(cmd, cwd=path, capture_output=True).output
    150     diff_lines = diff.strip().splitlines()
    151     for line in diff_lines:
    152         match = DIFF_RE.match(line)
    153         if not match:
    154             raise ValueError('Failed to parse diff output: %s' % line)
    155         diff = RawDiffEntry(**match.groupdict())
    156         diff.src_mode = int(diff.src_mode)
    157         diff.dst_mode = int(diff.dst_mode)
    158         diff.file = diff.dst_file if diff.dst_file else diff.src_file
    159         entries.append(diff)
    160 
    161     return entries
    162 
    163 
    164 def get_affected_files(commit):
    165     """Returns list of file paths that were modified/added.
    166 
    167     Returns:
    168       A list of modified/added (and perhaps deleted) files
    169     """
    170     return raw_diff(os.getcwd(), '%s^!' % commit)
    171 
    172 
    173 def get_commits(ignore_merged_commits=False):
    174     """Returns a list of commits for this review."""
    175     cmd = ['git', 'log', '%s..' % get_upstream_branch(), '--format=%H']
    176     if ignore_merged_commits:
    177         cmd.append('--first-parent')
    178     return rh.utils.run_command(cmd, capture_output=True).output.split()
    179 
    180 
    181 def get_commit_desc(commit):
    182     """Returns the full commit message of a commit."""
    183     cmd = ['git', 'log', '--format=%B', commit + '^!']
    184     return rh.utils.run_command(cmd, capture_output=True).output
    185 
    186 
    187 def find_repo_root(path=None):
    188     """Locate the top level of this repo checkout starting at |path|."""
    189     if path is None:
    190         path = os.getcwd()
    191     orig_path = path
    192 
    193     path = os.path.abspath(path)
    194     while not os.path.exists(os.path.join(path, '.repo')):
    195         path = os.path.dirname(path)
    196         if path == '/':
    197             raise ValueError('Could not locate .repo in %s' % orig_path)
    198 
    199     return path
    200