Home | History | Annotate | Download | only in rh
      1 # -*- coding:utf-8 -*-
      2 # Copyright 2016 The Android Open Source Project
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #      http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 
     16 """Git helper functions."""
     17 
     18 from __future__ import print_function
     19 
     20 import os
     21 import re
     22 import sys
     23 
     24 _path = os.path.realpath(__file__ + '/../..')
     25 if sys.path[0] != _path:
     26     sys.path.insert(0, _path)
     27 del _path
     28 
     29 import rh.utils
     30 
     31 
     32 def get_upstream_remote():
     33     """Returns the current upstream remote name."""
     34     # First get the current branch name.
     35     cmd = ['git', 'rev-parse', '--abbrev-ref', 'HEAD']
     36     result = rh.utils.run_command(cmd, capture_output=True)
     37     branch = result.output.strip()
     38 
     39     # Then get the remote associated with this branch.
     40     cmd = ['git', 'config', 'branch.%s.remote' % branch]
     41     result = rh.utils.run_command(cmd, capture_output=True)
     42     return result.output.strip()
     43 
     44 
     45 def get_upstream_branch():
     46     """Returns the upstream tracking branch of the current branch.
     47 
     48     Raises:
     49       Error if there is no tracking branch
     50     """
     51     cmd = ['git', 'symbolic-ref', 'HEAD']
     52     result = rh.utils.run_command(cmd, capture_output=True)
     53     current_branch = result.output.strip().replace('refs/heads/', '')
     54     if not current_branch:
     55         raise ValueError('Need to be on a tracking branch')
     56 
     57     cfg_option = 'branch.' + current_branch + '.%s'
     58     cmd = ['git', 'config', cfg_option % 'merge']
     59     result = rh.utils.run_command(cmd, capture_output=True)
     60     full_upstream = result.output.strip()
     61     # If remote is not fully qualified, add an implicit namespace.
     62     if '/' not in full_upstream:
     63         full_upstream = 'refs/heads/%s' % full_upstream
     64     cmd = ['git', 'config', cfg_option % 'remote']
     65     result = rh.utils.run_command(cmd, capture_output=True)
     66     remote = result.output.strip()
     67     if not remote or not full_upstream:
     68         raise ValueError('Need to be on a tracking branch')
     69 
     70     return full_upstream.replace('heads', 'remotes/' + remote)
     71 
     72 
     73 def get_commit_for_ref(ref):
     74     """Returns the latest commit for this ref."""
     75     cmd = ['git', 'rev-parse', ref]
     76     result = rh.utils.run_command(cmd, capture_output=True)
     77     return result.output.strip()
     78 
     79 
     80 def get_remote_revision(ref, remote):
     81     """Returns the remote revision for this ref."""
     82     prefix = 'refs/remotes/%s/' % remote
     83     if ref.startswith(prefix):
     84         return ref[len(prefix):]
     85     return ref
     86 
     87 
     88 def get_patch(commit):
     89     """Returns the patch for this commit."""
     90     cmd = ['git', 'format-patch', '--stdout', '-1', commit]
     91     return rh.utils.run_command(cmd, capture_output=True).output
     92 
     93 
     94 def _try_utf8_decode(data):
     95     """Attempts to decode a string as UTF-8.
     96 
     97     Returns:
     98       The decoded Unicode object, or the original string if parsing fails.
     99     """
    100     try:
    101         return unicode(data, 'utf-8', 'strict')
    102     except UnicodeDecodeError:
    103         return data
    104 
    105 
    106 def get_file_content(commit, path):
    107     """Returns the content of a file at a specific commit.
    108 
    109     We can't rely on the file as it exists in the filesystem as people might be
    110     uploading a series of changes which modifies the file multiple times.
    111 
    112     Note: The "content" of a symlink is just the target.  So if you're expecting
    113     a full file, you should check that first.  One way to detect is that the
    114     content will not have any newlines.
    115     """
    116     cmd = ['git', 'show', '%s:%s' % (commit, path)]
    117     return rh.utils.run_command(cmd, capture_output=True).output
    118 
    119 
    120 # RawDiffEntry represents a line of raw formatted git diff output.
    121 RawDiffEntry = rh.utils.collection(
    122     'RawDiffEntry',
    123     src_mode=0, dst_mode=0, src_sha=None, dst_sha=None,
    124     status=None, score=None, src_file=None, dst_file=None, file=None)
    125 
    126 
    127 # This regular expression pulls apart a line of raw formatted git diff output.
    128 DIFF_RE = re.compile(
    129     r':(?P<src_mode>[0-7]*) (?P<dst_mode>[0-7]*) '
    130     r'(?P<src_sha>[0-9a-f]*)(\.)* (?P<dst_sha>[0-9a-f]*)(\.)* '
    131     r'(?P<status>[ACDMRTUX])(?P<score>[0-9]+)?\t'
    132     r'(?P<src_file>[^\t]+)\t?(?P<dst_file>[^\t]+)?')
    133 
    134 
    135 def raw_diff(path, target):
    136     """Return the parsed raw format diff of target
    137 
    138     Args:
    139       path: Path to the git repository to diff in.
    140       target: The target to diff.
    141 
    142     Returns:
    143       A list of RawDiffEntry's.
    144     """
    145     entries = []
    146 
    147     cmd = ['git', 'diff', '-M', '--raw', target]
    148     diff = rh.utils.run_command(cmd, cwd=path, capture_output=True).output
    149     diff_lines = diff.strip().splitlines()
    150     for line in diff_lines:
    151         match = DIFF_RE.match(line)
    152         if not match:
    153             raise ValueError('Failed to parse diff output: %s' % line)
    154         diff = RawDiffEntry(**match.groupdict())
    155         diff.src_mode = int(diff.src_mode)
    156         diff.dst_mode = int(diff.dst_mode)
    157         diff.file = diff.dst_file if diff.dst_file else diff.src_file
    158         entries.append(diff)
    159 
    160     return entries
    161 
    162 
    163 def get_affected_files(commit):
    164     """Returns list of file paths that were modified/added.
    165 
    166     Returns:
    167       A list of modified/added (and perhaps deleted) files
    168     """
    169     return raw_diff(os.getcwd(), '%s^!' % commit)
    170 
    171 
    172 def get_commits(ignore_merged_commits=False):
    173     """Returns a list of commits for this review."""
    174     cmd = ['git', 'log', '%s..' % get_upstream_branch(), '--format=%H']
    175     if ignore_merged_commits:
    176         cmd.append('--first-parent')
    177     return rh.utils.run_command(cmd, capture_output=True).output.split()
    178 
    179 
    180 def get_commit_desc(commit):
    181     """Returns the full commit message of a commit."""
    182     cmd = ['git', 'log', '--format=%B', commit + '^!']
    183     return rh.utils.run_command(cmd, capture_output=True).output
    184 
    185 
    186 def find_repo_root(path=None):
    187     """Locate the top level of this repo checkout starting at |path|."""
    188     if path is None:
    189         path = os.getcwd()
    190     orig_path = path
    191 
    192     path = os.path.abspath(path)
    193     while not os.path.exists(os.path.join(path, '.repo')):
    194         path = os.path.dirname(path)
    195         if path == '/':
    196             raise ValueError('Could not locate .repo in %s' % orig_path)
    197 
    198     return path
    199