# Copyright (C) 2014 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from collections import deque, OrderedDict
from hashlib import sha1
import heapq
import itertools
import multiprocessing
import os
import pprint
import re
import subprocess
import sys
import tempfile
import threading

from rangelib import RangeSet

__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]

def compute_patch(src, tgt, imgdiff=False):
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)

    # Remove the empty placeholder made by mkstemp; the diff tool will
    # create the patch file itself.
    try:
      os.unlink(patchfile)
    except OSError:
      pass

    if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
                          stderr=subprocess.STDOUT)
    else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])

    if p:
      raise ValueError("diff failed: " + str(p))

    with open(patchfile, "rb") as f:
      return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass

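# Example usage (an illustrative sketch): src and tgt are iterables of
# strings whose concatenation forms the source and target data, and the
# bsdiff (or imgdiff) host tool must be available on PATH.
#
#   patch_data = compute_patch(["old contents"], ["new contents"])
#   zip_patch = compute_patch(src_pieces, tgt_pieces, imgdiff=True)
#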
class EmptyImage(object):
  """A zero-length image."""
  blocksize = 4096
  care_map = RangeSet()
  total_blocks = 0
  file_map = {}
  def ReadRangeSet(self, ranges):
    return ()
  def TotalSha1(self):
    return sha1().hexdigest()


class DataImage(object):
  """An image wrapped around a single string of data."""

  def __init__(self, data, trim=False, pad=False):
    self.data = data
    self.blocksize = 4096

    assert not (trim and pad)

    partial = len(self.data) % self.blocksize
    if partial > 0:
      if trim:
        self.data = self.data[:-partial]
      elif pad:
        self.data += '\0' * (self.blocksize - partial)
      else:
        raise ValueError(("data for DataImage must be a multiple of %d bytes "
                          "unless trim or pad is specified") %
                         (self.blocksize,))

    assert len(self.data) % self.blocksize == 0

    self.total_blocks = len(self.data) // self.blocksize
    self.care_map = RangeSet(data=(0, self.total_blocks))

    zero_blocks = []
    nonzero_blocks = []
    reference = '\0' * self.blocksize

    # Classify each block as zero or nonzero, building up lists of
    # (start, end) pairs in the flat form that RangeSet expects.
    for i in range(self.total_blocks):
      d = self.data[i*self.blocksize : (i+1)*self.blocksize]
      if d == reference:
        zero_blocks.append(i)
        zero_blocks.append(i+1)
      else:
        nonzero_blocks.append(i)
        nonzero_blocks.append(i+1)

    self.file_map = {"__ZERO": RangeSet(data=zero_blocks),
                     "__NONZERO": RangeSet(data=nonzero_blocks)}

  def ReadRangeSet(self, ranges):
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]

  def TotalSha1(self):
    if not hasattr(self, "sha1"):
      self.sha1 = sha1(self.data).hexdigest()
    return self.sha1

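# For example (illustrative): wrapping 8192 bytes yields a two-block
# image whose file_map classifies each block by content.
#
#   img = DataImage("A" * 4096 + "\0" * 4096)
#   img.total_blocks               # 2
#   img.file_map["__ZERO"]         # RangeSet holding block 1
#   img.ReadRangeSet(img.care_map) # the full 8192 bytes, in pieces
#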

class Transfer(object):
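  """A single unit of work: writing a set of target blocks, optionally
  reading a set of source blocks, using one of the transfer styles
  ("new", "zero", "move", or "diff", which ComputePatches later
  specializes to "bsdiff"/"imgdiff")."""
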
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))
    self.goes_before = {}
    self.goes_after = {}

    self.stash_before = []
    self.use_stash = []

    self.id = len(by_id)
    by_id.append(self)

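  # The net change in stashed blocks if this transfer runs: blocks it
  # asks to have stashed before it executes, minus stashed blocks it
  # consumes (and thereby frees).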
  def NetStashChange(self):
    return (sum(sr.size() for (_, sr) in self.stash_before) -
            sum(sr.size() for (_, sr) in self.use_stash))

  def __str__(self):
    return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
            " to " + str(self.tgt_ranges) + ">")


# BlockImageDiff works on two image objects.  An image object is
# anything that provides the following attributes and methods:
#
#    blocksize: the size in bytes of a block, currently must be 4096.
#
#    total_blocks: the total size of the partition/image, in blocks.
#
#    care_map: a RangeSet containing which blocks (in the range [0,
#      total_blocks)) we actually care about; i.e. which blocks contain
#      data.
#
#    file_map: a dict that partitions the blocks contained in care_map
#      into smaller domains that are useful for doing diffs on.
#      (Typically a domain is a file, and the key in file_map is the
#      pathname.)
#
#    ReadRangeSet(): a function that takes a RangeSet and returns the
#      data contained in the image blocks of that RangeSet.  The data
#      is returned as a list or tuple of strings; concatenating the
#      elements together should produce the requested data.
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map).
#
# When creating a BlockImageDiff, the src image may be None, in which
# case the list of transfers produced will never read from the
# original image.

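# A minimal conforming image might look like this (an illustrative
# sketch only; EmptyImage and DataImage above are the real in-tree
# implementations):
#
#   class ConstantImage(object):
#     """total_blocks blocks, every byte equal to fill."""
#     blocksize = 4096
#     def __init__(self, total_blocks, fill="\xff"):
#       self.total_blocks = total_blocks
#       self.care_map = RangeSet(data=(0, total_blocks))
#       self.file_map = {"__NONZERO": self.care_map}
#       self._block = fill * self.blocksize
#     def ReadRangeSet(self, ranges):
#       return [self._block * (e - s) for (s, e) in ranges]
#     def TotalSha1(self):
#       return sha1(self._block * self.total_blocks).hexdigest()
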
class BlockImageDiff(object):
  def __init__(self, tgt, src=None, threads=None, version=2):
    if threads is None:
      threads = multiprocessing.cpu_count() // 2
      if threads == 0: threads = 1
    self.threads = threads
    self.version = version

    assert version in (1, 2)

    self.tgt = tgt
    if src is None:
      src = EmptyImage()
    self.src = src

    # The updater code that installs the patch always uses 4k blocks.
    assert tgt.blocksize == 4096
    assert src.blocksize == 4096

    # The range sets in each filemap should comprise a partition of
    # the care map.
    self.AssertPartition(src.care_map, src.file_map.values())
    self.AssertPartition(tgt.care_map, tgt.file_map.values())

  def Compute(self, prefix):
    # When looking for a source file to use as the diff input for a
    # target file, we try:
    #   1) an exact path match if available, otherwise
    #   2) an exact basename match if available, otherwise
    #   3) a basename match after all runs of digits are replaced by
    #      "#" if available, otherwise
    #   4) we have no source for this target.
    self.AbbreviateSourceNames()
    self.FindTransfers()

    # Find the ordering dependencies among transfers (this is O(n^2)
    # in the number of transfers).
    self.GenerateDigraph()
    # Find a sequence of transfers that satisfies as many ordering
    # dependencies as possible (heuristically).
    self.FindVertexSequence()
    # Fix up the ordering dependencies that the sequence didn't
    # satisfy.
    if self.version == 1:
      self.RemoveBackwardEdges()
    else:
      self.ReverseBackwardEdges()
      self.ImproveVertexSequence()

    # Double-check our work.
    self.AssertSequenceGood()

    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def WriteTransfers(self, prefix):
    out = []

    total = 0
    performs_read = False

    stashes = {}
    stashed_blocks = 0
    max_stashed_blocks = 0

    free_stash_ids = []
    next_stash_id = 0

    for xf in self.transfers:
      if self.version < 2:
        assert not xf.stash_before
        assert not xf.use_stash

      for s, sr in xf.stash_before:
        assert s not in stashes
        if free_stash_ids:
          sid = heapq.heappop(free_stash_ids)
        else:
          sid = next_stash_id
          next_stash_id += 1
        stashes[s] = sid
        stashed_blocks += sr.size()
        out.append("stash %d %s\n" % (sid, sr.to_string_raw()))

      if stashed_blocks > max_stashed_blocks:
        max_stashed_blocks = stashed_blocks

      if self.version == 1:
        src_string = xf.src_ranges.to_string_raw()
      elif self.version == 2:

        # In version 2, the source for each command is described by one
        # of these forms:
        #
        #   <# blocks> <src ranges>
        #     OR
        #   <# blocks> <src ranges> <src locs> <stash refs...>
        #     OR
        #   <# blocks> - <stash refs...>
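        #
        # Illustrative examples (a RangeSet prints raw as
        # "n,start_1,end_1,...", where n counts the endpoints):
        #
        #   "10 2,20,30"                 all ten blocks read directly
        #                                from source blocks [20,30)
        #   "10 2,20,25 2,0,5 0:2,5,10"  five blocks from the source,
        #                                five from stash slot 0
        #   "10 - 0:2,0,10"              everything comes from stash 0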

        size = xf.src_ranges.size()
        src_string = [str(size)]

        unstashed_src_ranges = xf.src_ranges
        mapped_stashes = []
        for s, sr in xf.use_stash:
          sid = stashes.pop(s)
          stashed_blocks -= sr.size()
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          src_string.append("%d:%s" % (sid, sr.to_string_raw()))
          heapq.heappush(free_stash_ids, sid)

        if unstashed_src_ranges:
          src_string.insert(1, unstashed_src_ranges.to_string_raw())
          if xf.use_stash:
            mapped_unstashed = xf.src_ranges.map_within(unstashed_src_ranges)
            src_string.insert(2, mapped_unstashed.to_string_raw())
            mapped_stashes.append(mapped_unstashed)
            self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)
        else:
          src_string.insert(1, "-")
          self.AssertPartition(RangeSet(data=(0, size)), mapped_stashes)

        src_string = " ".join(src_string)

      # both versions:
      #   zero <rangeset>
      #   new <rangeset>
      #   erase <rangeset>
      #
      # version 1:
      #   bsdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   imgdiff patchstart patchlen <src rangeset> <tgt rangeset>
      #   move <src rangeset> <tgt rangeset>
      #
      # version 2:
      #   bsdiff patchstart patchlen <tgt rangeset> <src_string>
      #   imgdiff patchstart patchlen <tgt rangeset> <src_string>
      #   move <tgt rangeset> <src_string>
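      #
      # For example, a version-2 transfer list might contain lines
      # like these (illustrative):
      #
      #   move 2,545,546 1 2,544,545
      #   bsdiff 0 3116 2,0,50 50 2,10,60
      #   zero 2,100,108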

      tgt_size = xf.tgt_ranges.size()

      if xf.style == "new":
        assert xf.tgt_ranges
        out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
        total += tgt_size
      elif xf.style == "move":
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges.size() == tgt_size
        if xf.src_ranges != xf.tgt_ranges:
          if self.version == 1:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
          elif self.version == 2:
            out.append("%s %s %s\n" % (
                xf.style,
                xf.tgt_ranges.to_string_raw(), src_string))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
        performs_read = True
        assert xf.tgt_ranges
        assert xf.src_ranges
        if self.version == 1:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
        elif self.version == 2:
          out.append("%s %d %d %s %s\n" % (
              xf.style, xf.patch_start, xf.patch_len,
              xf.tgt_ranges.to_string_raw(), src_string))
        total += tgt_size
      elif xf.style == "zero":
        assert xf.tgt_ranges
        to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
        if to_zero:
          out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
          total += to_zero.size()
      else:
        raise ValueError("unknown transfer style '%s'" % (xf.style,))

      # sanity check: abort if we're going to need more than 512 MB of
      # stash space
      assert max_stashed_blocks * self.tgt.blocksize < (512 << 20)

    all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    if performs_read:
      # if some of the original data is used, then at the end we'll
      # erase all the blocks on the partition that don't contain data
      # in the new image.
      new_dontcare = all_tgt.subtract(self.tgt.care_map)
      if new_dontcare:
        out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    else:
      # if nothing is read (ie, this is a full OTA), then we can start
      # by erasing the entire partition.
      out.insert(0, "erase %s\n" % (all_tgt.to_string_raw(),))

    out.insert(0, "%d\n" % (self.version,))   # format version number
    out.insert(1, str(total) + "\n")
    if self.version >= 2:
      # version 2 only: after the total block count, we give the number
      # of stash slots needed, and the maximum size needed (in blocks)
      out.insert(2, str(next_stash_id) + "\n")
      out.insert(3, str(max_stashed_blocks) + "\n")
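      # The resulting version-2 header thus looks like, e.g.
      # (illustrative):
      #
      #   2        <- format version
      #   245134   <- total blocks written
      #   2        <- stash slots needed
      #   40       <- maximum blocks stashed at any one time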

    with open(prefix + ".transfer.list", "wb") as f:
      for i in out:
        f.write(i)

    if self.version >= 2:
      print("max stashed blocks: %d  (%d bytes)\n" % (
          max_stashed_blocks, max_stashed_blocks * self.tgt.blocksize))

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
        if xf.style == "zero":
          pass
        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
          #    ["he", "llo"]  vs  ["h", "ello"]
          #
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
          else:
            # For files in zip format (eg, APKs, JARs, etc.) we would
            # like to use imgdiff -z if possible (because it usually
            # produces significantly smaller patches than bsdiff).
            # This is permissible if:
            #
            #  - the source and target files are monotonic (ie, the
            #    data is stored with blocks in increasing order), and
            #  - we haven't removed any blocks from the source set.
            #
            # If these conditions are satisfied then appending all the
            # blocks in the set together in order will produce a valid
            # zip file (plus possibly extra zeros in the last block),
            # which is what imgdiff needs to operate.  (imgdiff is
            # fine with extra zeros at the end of the file.)
            imgdiff = (xf.intact and
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num

      # Dispatch the diff work to a pool of worker threads; each worker
      # repeatedly pops the largest remaining target off diff_q.
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q: return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")")))

      threads = [threading.Thread(target=diff_worker)
                 for i in range(self.threads)]
      for th in threads:
        th.start()
      while threads:
        threads.pop().join()
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)

  def AssertSequenceGood(self):
    # Simulate the sequence of transfers we will output, and check that:
    # - we never read a block after writing it, and
    # - we write every block we care about exactly once.

    # Start with no blocks having been touched yet.
    touched = RangeSet()

    # Imagine processing the transfers in order.
    for xf in self.transfers:
      # Check that the input blocks for this transfer haven't yet been touched.

      x = xf.src_ranges
      if self.version >= 2:
        for _, sr in xf.use_stash:
          x = x.subtract(sr)

      assert not touched.overlaps(x)
      # Check that the output blocks for this transfer haven't yet been touched.
      assert not touched.overlaps(xf.tgt_ranges)
      # Touch all the blocks written by this transfer.
      touched = touched.union(xf.tgt_ranges)

    # Check that we've written every target block.
    assert touched == self.tgt.care_map

  def ImproveVertexSequence(self):
    print("Improving vertex order...")

    # At this point our digraph is acyclic; we reversed any edges that
    # were backwards in the heuristically-generated sequence.  The
    # previously-generated order is still acceptable, but we hope to
    # find a better order that needs less memory for stashed data.
    # Now we do a topological sort to generate a new vertex order,
    # using a greedy algorithm to choose which vertex goes next
    # whenever we have a choice.

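    # (This is Kahn's topological-sort algorithm, with the ready set
    # kept in a min-heap keyed by NetStashChange, so at each step we
    # emit the transfer that leaves the least data stashed.)
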
    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    L = []   # the new vertex order

    # S is the set of sources in the remaining graph; we always choose
    # the one that leaves the least amount of stashed data after it's
    # executed.
    S = [(u.NetStashChange(), u.order, u) for u in self.transfers
         if not u.incoming]
    heapq.heapify(S)

    while S:
      _, _, xf = heapq.heappop(S)
      L.append(xf)
      for u in xf.outgoing:
        del u.incoming[xf]
        if not u.incoming:
          heapq.heappush(S, (u.NetStashChange(), u.order, u))

    # if this fails then our graph had a cycle.
    assert len(L) == len(self.transfers)

    self.transfers = L
    for i, xf in enumerate(L):
      xf.order = i

  def RemoveBackwardEdges(self):
    print("Removing backward edges...")
    in_order = 0
    out_of_order = 0
    lost_source = 0

    for xf in self.transfers:
      size = xf.src_ranges.size()
      for u in xf.goes_before:
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  trim the blocks that u writes from xf's
          # source, so that xf can go after u.
          out_of_order += 1
          assert xf.src_ranges.overlaps(u.tgt_ranges)
          xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
          xf.intact = False

      if xf.style == "diff" and not xf.src_ranges:
        # nothing left to diff from; treat as new data
        xf.style = "new"

      lost = size - xf.src_ranges.size()
      lost_source += lost

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks removed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           lost_source))

  def ReverseBackwardEdges(self):
    print("Reversing backward edges...")
    in_order = 0
    out_of_order = 0
    stashes = 0
    stash_size = 0

    for xf in self.transfers:
      for u in xf.goes_before.copy():
        # xf should go before u
        if xf.order < u.order:
          # it does, hurray!
          in_order += 1
        else:
          # it doesn't, boo.  modify u to stash the blocks that it
          # writes that xf wants to read, and then require u to go
          # before xf.
          out_of_order += 1

          overlap = xf.src_ranges.intersect(u.tgt_ranges)
          assert overlap

          u.stash_before.append((stashes, overlap))
          xf.use_stash.append((stashes, overlap))
          stashes += 1
          stash_size += overlap.size()

          # reverse the edge direction; now xf must go after u
          del xf.goes_before[u]
          del u.goes_after[xf]
          xf.goes_after[u] = None    # value doesn't matter
          u.goes_before[xf] = None

    print(("  %d/%d dependencies (%.2f%%) were violated; "
           "%d source blocks stashed.") %
          (out_of_order, in_order + out_of_order,
           (out_of_order * 100.0 / (in_order + out_of_order))
           if (in_order + out_of_order) else 0.0,
           stash_size))

  def FindVertexSequence(self):
    print("Finding vertex sequence...")

    # This is based on "A Fast & Effective Heuristic for the Feedback
    # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    # it as starting with the digraph G and moving all the vertices to
    # be on a horizontal line in some order, trying to minimize the
    # number of edges that end up pointing to the left.  Left-pointing
    # edges will get removed to turn the digraph into a DAG.  In this
    # case each edge has a weight which is the number of source blocks
    # we'll lose if that edge is removed; we try to minimize the total
    # weight rather than just the number of edges.

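    # For example (illustrative): if transfers A, B, C have edges
    # A->B (weight 5), B->C (weight 5), and C->A (weight 1), the graph
    # is a cycle and some edge must point left in any linear order.
    # The heuristic aims to pick the order A, B, C, so that only the
    # cheapest edge (C->A, 1 block) points backward.
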
    # Make a copy of the edge set; this copy will get destroyed by the
    # algorithm.
    for xf in self.transfers:
      xf.incoming = xf.goes_after.copy()
      xf.outgoing = xf.goes_before.copy()

    # We use an OrderedDict instead of just a set so that the output
    # is repeatable; otherwise it would depend on the hash values of
    # the transfer objects.
    G = OrderedDict()
    for xf in self.transfers:
      G[xf] = None
    s1 = deque()  # the left side of the sequence, built from left to right
    s2 = deque()  # the right side of the sequence, built from right to left

    while G:

      # Put all sinks at the end of the sequence.
      while True:
        sinks = [u for u in G if not u.outgoing]
        if not sinks: break
        for u in sinks:
          s2.appendleft(u)
          del G[u]
          for iu in u.incoming:
            del iu.outgoing[u]

      # Put all the sources at the beginning of the sequence.
      while True:
        sources = [u for u in G if not u.incoming]
        if not sources: break
        for u in sources:
          s1.append(u)
          del G[u]
          for iu in u.outgoing:
            del iu.incoming[u]

      if not G: break

      # Find the "best" vertex to put next.  "Best" is the one that
      # maximizes the net number of source blocks saved by pretending
      # it's a source rather than a sink.

      max_d = None
      best_u = None
      for u in G:
        d = sum(u.outgoing.values()) - sum(u.incoming.values())
        if best_u is None or d > max_d:
          max_d = d
          best_u = u

      u = best_u
      s1.append(u)
      del G[u]
      for iu in u.outgoing:
        del iu.incoming[u]
      for iu in u.incoming:
        del iu.outgoing[u]

    # Now record the sequence in the 'order' field of each transfer,
    # and rearrange self.transfers to be in the chosen sequence.

    new_transfers = []
    for x in itertools.chain(s1, s2):
      x.order = len(new_transfers)
      new_transfers.append(x)
      del x.incoming
      del x.outgoing

    self.transfers = new_transfers

  def GenerateDigraph(self):
    print("Generating digraph...")
    for a in self.transfers:
      for b in self.transfers:
        if a is b: continue

        # If the blocks written by A are read by B, then B needs to go before A.
        i = a.tgt_ranges.intersect(b.src_ranges)
        if i:
          if b.src_name == "__ZERO":
            # the cost of removing source blocks for the __ZERO domain
            # is (nearly) zero.
            size = 0
          else:
            size = i.size()
          b.goes_before[a] = size
          a.goes_after[b] = size

  def FindTransfers(self):
    self.transfers = []
    empty = RangeSet()
    for tgt_fn, tgt_ranges in self.tgt.file_map.items():
      if tgt_fn == "__ZERO":
        # the special "__ZERO" domain is all the blocks not contained
        # in any file and that are filled with zeros.  We have a
        # special transfer style for zero blocks.
        src_ranges = self.src.file_map.get("__ZERO", empty)
        Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
                 "zero", self.transfers)
        continue

      elif tgt_fn in self.src.file_map:
        # Look for an exact pathname match in the source.
        Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
                 "diff", self.transfers)
        continue

      b = os.path.basename(tgt_fn)
      if b in self.src_basenames:
        # Look for an exact basename match in the source.
        src_fn = self.src_basenames[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      b = re.sub("[0-9]+", "#", b)
      if b in self.src_numpatterns:
        # Look for a 'number pattern' match (a basename match after
        # all runs of digits are replaced by "#").  (This is useful
        # for .so files that contain version numbers in the filename
        # that get bumped.)
        src_fn = self.src_numpatterns[b]
        Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
                 "diff", self.transfers)
        continue

      Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)

  def AbbreviateSourceNames(self):
    self.src_basenames = {}
    self.src_numpatterns = {}

    for k in self.src.file_map.keys():
      b = os.path.basename(k)
      self.src_basenames[b] = k
      b = re.sub("[0-9]+", "#", b)
      self.src_numpatterns[b] = k

  @staticmethod
  def AssertPartition(total, seq):
    """Assert that all the RangeSets in 'seq' form a partition of the
    'total' RangeSet (ie, they are nonintersecting and their union
    equals 'total')."""
    so_far = RangeSet()
    for i in seq:
      assert not so_far.overlaps(i)
      so_far = so_far.union(i)
    assert so_far == total
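
if __name__ == "__main__":
  # A minimal smoke test (an illustrative sketch, not part of the
  # original module; it assumes rangelib is importable).  Build a
  # full-OTA-style transfer list for a tiny two-block target with no
  # source image: only the "new" and "zero" styles are exercised, so
  # the bsdiff/imgdiff host tools are not required.
  demo_dir = tempfile.mkdtemp(prefix="blockimgdiff-demo-")
  demo_prefix = os.path.join(demo_dir, "demo")

  # One nonzero block followed by one zero block (blocksize is 4096).
  demo_tgt = DataImage("A" * 4096 + "\0" * 4096)
  BlockImageDiff(demo_tgt, src=None, version=2).Compute(demo_prefix)

  # Compute() wrote demo.transfer.list, demo.new.dat and demo.patch.dat
  # under demo_dir; show the generated transfer list.
  with open(demo_prefix + ".transfer.list") as f:
    print(f.read())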