Home | History | Annotate | Download | only in releasetools
      1 # Copyright (C) 2014 The Android Open Source Project
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #      http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 
     15 from __future__ import print_function
     16 
     17 from collections import deque, OrderedDict
     18 from hashlib import sha1
     19 import itertools
     20 import multiprocessing
     21 import os
     22 import pprint
     23 import re
     24 import subprocess
     25 import sys
     26 import threading
     27 import tempfile
     28 
     29 from rangelib import *
     30 
     31 __all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]
     32 
     33 def compute_patch(src, tgt, imgdiff=False):
     34   srcfd, srcfile = tempfile.mkstemp(prefix="src-")
     35   tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
     36   patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
     37   os.close(patchfd)
     38 
     39   try:
     40     with os.fdopen(srcfd, "wb") as f_src:
     41       for p in src:
     42         f_src.write(p)
     43 
     44     with os.fdopen(tgtfd, "wb") as f_tgt:
     45       for p in tgt:
     46         f_tgt.write(p)
     47     try:
     48       os.unlink(patchfile)
     49     except OSError:
     50       pass
     51     if imgdiff:
     52       p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
     53                           stdout=open("/dev/null", "a"),
     54                           stderr=subprocess.STDOUT)
     55     else:
     56       p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])
     57 
     58     if p:
     59       raise ValueError("diff failed: " + str(p))
     60 
     61     with open(patchfile, "rb") as f:
     62       return f.read()
     63   finally:
     64     try:
     65       os.unlink(srcfile)
     66       os.unlink(tgtfile)
     67       os.unlink(patchfile)
     68     except OSError:
     69       pass
     70 
     71 class EmptyImage(object):
     72   """A zero-length image."""
     73   blocksize = 4096
     74   care_map = RangeSet()
     75   total_blocks = 0
     76   file_map = {}
     77   def ReadRangeSet(self, ranges):
     78     return ()
     79   def TotalSha1(self):
     80     return sha1().hexdigest()
     81 
     82 
     83 class DataImage(object):
     84   """An image wrapped around a single string of data."""
     85 
     86   def __init__(self, data, trim=False, pad=False):
     87     self.data = data
     88     self.blocksize = 4096
     89 
     90     assert not (trim and pad)
     91 
     92     partial = len(self.data) % self.blocksize
     93     if partial > 0:
     94       if trim:
     95         self.data = self.data[:-partial]
     96       elif pad:
     97         self.data += '\0' * (self.blocksize - partial)
     98       else:
     99         raise ValueError(("data for DataImage must be multiple of %d bytes "
    100                           "unless trim or pad is specified") %
    101                          (self.blocksize,))
    102 
    103     assert len(self.data) % self.blocksize == 0
    104 
    105     self.total_blocks = len(self.data) / self.blocksize
    106     self.care_map = RangeSet(data=(0, self.total_blocks))
    107 
    108     zero_blocks = []
    109     nonzero_blocks = []
    110     reference = '\0' * self.blocksize
    111 
    112     for i in range(self.total_blocks):
    113       d = self.data[i*self.blocksize : (i+1)*self.blocksize]
    114       if d == reference:
    115         zero_blocks.append(i)
    116         zero_blocks.append(i+1)
    117       else:
    118         nonzero_blocks.append(i)
    119         nonzero_blocks.append(i+1)
    120 
    121     self.file_map = {"__ZERO": RangeSet(zero_blocks),
    122                      "__NONZERO": RangeSet(nonzero_blocks)}
    123 
    124   def ReadRangeSet(self, ranges):
    125     return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]
    126 
    127   def TotalSha1(self):
    128     if not hasattr(self, "sha1"):
    129       self.sha1 = sha1(self.data).hexdigest()
    130     return self.sha1
    131 
    132 
    133 class Transfer(object):
    134   def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
    135     self.tgt_name = tgt_name
    136     self.src_name = src_name
    137     self.tgt_ranges = tgt_ranges
    138     self.src_ranges = src_ranges
    139     self.style = style
    140     self.intact = (getattr(tgt_ranges, "monotonic", False) and
    141                    getattr(src_ranges, "monotonic", False))
    142     self.goes_before = {}
    143     self.goes_after = {}
    144 
    145     self.id = len(by_id)
    146     by_id.append(self)
    147 
    148   def __str__(self):
    149     return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
    150             " to " + str(self.tgt_ranges) + ">")
    151 
    152 
    153 # BlockImageDiff works on two image objects.  An image object is
    154 # anything that provides the following attributes:
    155 #
    156 #    blocksize: the size in bytes of a block, currently must be 4096.
    157 #
    158 #    total_blocks: the total size of the partition/image, in blocks.
    159 #
    160 #    care_map: a RangeSet containing which blocks (in the range [0,
    161 #      total_blocks) we actually care about; i.e. which blocks contain
    162 #      data.
    163 #
    164 #    file_map: a dict that partitions the blocks contained in care_map
    165 #      into smaller domains that are useful for doing diffs on.
    166 #      (Typically a domain is a file, and the key in file_map is the
    167 #      pathname.)
    168 #
    169 #    ReadRangeSet(): a function that takes a RangeSet and returns the
    170 #      data contained in the image blocks of that RangeSet.  The data
    171 #      is returned as a list or tuple of strings; concatenating the
    172 #      elements together should produce the requested data.
    173 #      Implementations are free to break up the data into list/tuple
    174 #      elements in any way that is convenient.
    175 #
    176 #    TotalSha1(): a function that returns (as a hex string) the SHA-1
    177 #      hash of all the data in the image (ie, all the blocks in the
    178 #      care_map)
    179 #
    180 # When creating a BlockImageDiff, the src image may be None, in which
    181 # case the list of transfers produced will never read from the
    182 # original image.
    183 
    184 class BlockImageDiff(object):
    185   def __init__(self, tgt, src=None, threads=None):
    186     if threads is None:
    187       threads = multiprocessing.cpu_count() // 2
    188       if threads == 0: threads = 1
    189     self.threads = threads
    190 
    191     self.tgt = tgt
    192     if src is None:
    193       src = EmptyImage()
    194     self.src = src
    195 
    196     # The updater code that installs the patch always uses 4k blocks.
    197     assert tgt.blocksize == 4096
    198     assert src.blocksize == 4096
    199 
    200     # The range sets in each filemap should comprise a partition of
    201     # the care map.
    202     self.AssertPartition(src.care_map, src.file_map.values())
    203     self.AssertPartition(tgt.care_map, tgt.file_map.values())
    204 
    205   def Compute(self, prefix):
    206     # When looking for a source file to use as the diff input for a
    207     # target file, we try:
    208     #   1) an exact path match if available, otherwise
    209     #   2) a exact basename match if available, otherwise
    210     #   3) a basename match after all runs of digits are replaced by
    211     #      "#" if available, otherwise
    212     #   4) we have no source for this target.
    213     self.AbbreviateSourceNames()
    214     self.FindTransfers()
    215 
    216     # Find the ordering dependencies among transfers (this is O(n^2)
    217     # in the number of transfers).
    218     self.GenerateDigraph()
    219     # Find a sequence of transfers that satisfies as many ordering
    220     # dependencies as possible (heuristically).
    221     self.FindVertexSequence()
    222     # Fix up the ordering dependencies that the sequence didn't
    223     # satisfy.
    224     self.RemoveBackwardEdges()
    225     # Double-check our work.
    226     self.AssertSequenceGood()
    227 
    228     self.ComputePatches(prefix)
    229     self.WriteTransfers(prefix)
    230 
    231   def WriteTransfers(self, prefix):
    232     out = []
    233 
    234     out.append("1\n")   # format version number
    235     total = 0
    236     performs_read = False
    237 
    238     for xf in self.transfers:
    239 
    240       # zero [rangeset]
    241       # new [rangeset]
    242       # bsdiff patchstart patchlen [src rangeset] [tgt rangeset]
    243       # imgdiff patchstart patchlen [src rangeset] [tgt rangeset]
    244       # move [src rangeset] [tgt rangeset]
    245       # erase [rangeset]
    246 
    247       tgt_size = xf.tgt_ranges.size()
    248 
    249       if xf.style == "new":
    250         assert xf.tgt_ranges
    251         out.append("%s %s\n" % (xf.style, xf.tgt_ranges.to_string_raw()))
    252         total += tgt_size
    253       elif xf.style == "move":
    254         performs_read = True
    255         assert xf.tgt_ranges
    256         assert xf.src_ranges.size() == tgt_size
    257         if xf.src_ranges != xf.tgt_ranges:
    258           out.append("%s %s %s\n" % (
    259               xf.style,
    260               xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
    261           total += tgt_size
    262       elif xf.style in ("bsdiff", "imgdiff"):
    263         performs_read = True
    264         assert xf.tgt_ranges
    265         assert xf.src_ranges
    266         out.append("%s %d %d %s %s\n" % (
    267             xf.style, xf.patch_start, xf.patch_len,
    268             xf.src_ranges.to_string_raw(), xf.tgt_ranges.to_string_raw()))
    269         total += tgt_size
    270       elif xf.style == "zero":
    271         assert xf.tgt_ranges
    272         to_zero = xf.tgt_ranges.subtract(xf.src_ranges)
    273         if to_zero:
    274           out.append("%s %s\n" % (xf.style, to_zero.to_string_raw()))
    275           total += to_zero.size()
    276       else:
    277         raise ValueError, "unknown transfer style '%s'\n" % (xf.style,)
    278 
    279     out.insert(1, str(total) + "\n")
    280 
    281     all_tgt = RangeSet(data=(0, self.tgt.total_blocks))
    282     if performs_read:
    283       # if some of the original data is used, then at the end we'll
    284       # erase all the blocks on the partition that don't contain data
    285       # in the new image.
    286       new_dontcare = all_tgt.subtract(self.tgt.care_map)
    287       if new_dontcare:
    288         out.append("erase %s\n" % (new_dontcare.to_string_raw(),))
    289     else:
    290       # if nothing is read (ie, this is a full OTA), then we can start
    291       # by erasing the entire partition.
    292       out.insert(2, "erase %s\n" % (all_tgt.to_string_raw(),))
    293 
    294     with open(prefix + ".transfer.list", "wb") as f:
    295       for i in out:
    296         f.write(i)
    297 
    298   def ComputePatches(self, prefix):
    299     print("Reticulating splines...")
    300     diff_q = []
    301     patch_num = 0
    302     with open(prefix + ".new.dat", "wb") as new_f:
    303       for xf in self.transfers:
    304         if xf.style == "zero":
    305           pass
    306         elif xf.style == "new":
    307           for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
    308             new_f.write(piece)
    309         elif xf.style == "diff":
    310           src = self.src.ReadRangeSet(xf.src_ranges)
    311           tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)
    312 
    313           # We can't compare src and tgt directly because they may have
    314           # the same content but be broken up into blocks differently, eg:
    315           #
    316           #    ["he", "llo"]  vs  ["h", "ello"]
    317           #
    318           # We want those to compare equal, ideally without having to
    319           # actually concatenate the strings (these may be tens of
    320           # megabytes).
    321 
    322           src_sha1 = sha1()
    323           for p in src:
    324             src_sha1.update(p)
    325           tgt_sha1 = sha1()
    326           tgt_size = 0
    327           for p in tgt:
    328             tgt_sha1.update(p)
    329             tgt_size += len(p)
    330 
    331           if src_sha1.digest() == tgt_sha1.digest():
    332             # These are identical; we don't need to generate a patch,
    333             # just issue copy commands on the device.
    334             xf.style = "move"
    335           else:
    336             # For files in zip format (eg, APKs, JARs, etc.) we would
    337             # like to use imgdiff -z if possible (because it usually
    338             # produces significantly smaller patches than bsdiff).
    339             # This is permissible if:
    340             #
    341             #  - the source and target files are monotonic (ie, the
    342             #    data is stored with blocks in increasing order), and
    343             #  - we haven't removed any blocks from the source set.
    344             #
    345             # If these conditions are satisfied then appending all the
    346             # blocks in the set together in order will produce a valid
    347             # zip file (plus possibly extra zeros in the last block),
    348             # which is what imgdiff needs to operate.  (imgdiff is
    349             # fine with extra zeros at the end of the file.)
    350             imgdiff = (xf.intact and
    351                        xf.tgt_name.split(".")[-1].lower()
    352                        in ("apk", "jar", "zip"))
    353             xf.style = "imgdiff" if imgdiff else "bsdiff"
    354             diff_q.append((tgt_size, src, tgt, xf, patch_num))
    355             patch_num += 1
    356 
    357         else:
    358           assert False, "unknown style " + xf.style
    359 
    360     if diff_q:
    361       if self.threads > 1:
    362         print("Computing patches (using %d threads)..." % (self.threads,))
    363       else:
    364         print("Computing patches...")
    365       diff_q.sort()
    366 
    367       patches = [None] * patch_num
    368 
    369       lock = threading.Lock()
    370       def diff_worker():
    371         while True:
    372           with lock:
    373             if not diff_q: return
    374             tgt_size, src, tgt, xf, patchnum = diff_q.pop()
    375           patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
    376           size = len(patch)
    377           with lock:
    378             patches[patchnum] = (patch, xf)
    379             print("%10d %10d (%6.2f%%) %7s %s" % (
    380                 size, tgt_size, size * 100.0 / tgt_size, xf.style,
    381                 xf.tgt_name if xf.tgt_name == xf.src_name else (
    382                     xf.tgt_name + " (from " + xf.src_name + ")")))
    383 
    384       threads = [threading.Thread(target=diff_worker)
    385                  for i in range(self.threads)]
    386       for th in threads:
    387         th.start()
    388       while threads:
    389         threads.pop().join()
    390     else:
    391       patches = []
    392 
    393     p = 0
    394     with open(prefix + ".patch.dat", "wb") as patch_f:
    395       for patch, xf in patches:
    396         xf.patch_start = p
    397         xf.patch_len = len(patch)
    398         patch_f.write(patch)
    399         p += len(patch)
    400 
    401   def AssertSequenceGood(self):
    402     # Simulate the sequences of transfers we will output, and check that:
    403     # - we never read a block after writing it, and
    404     # - we write every block we care about exactly once.
    405 
    406     # Start with no blocks having been touched yet.
    407     touched = RangeSet()
    408 
    409     # Imagine processing the transfers in order.
    410     for xf in self.transfers:
    411       # Check that the input blocks for this transfer haven't yet been touched.
    412       assert not touched.overlaps(xf.src_ranges)
    413       # Check that the output blocks for this transfer haven't yet been touched.
    414       assert not touched.overlaps(xf.tgt_ranges)
    415       # Touch all the blocks written by this transfer.
    416       touched = touched.union(xf.tgt_ranges)
    417 
    418     # Check that we've written every target block.
    419     assert touched == self.tgt.care_map
    420 
    421   def RemoveBackwardEdges(self):
    422     print("Removing backward edges...")
    423     in_order = 0
    424     out_of_order = 0
    425     lost_source = 0
    426 
    427     for xf in self.transfers:
    428       io = 0
    429       ooo = 0
    430       lost = 0
    431       size = xf.src_ranges.size()
    432       for u in xf.goes_before:
    433         # xf should go before u
    434         if xf.order < u.order:
    435           # it does, hurray!
    436           io += 1
    437         else:
    438           # it doesn't, boo.  trim the blocks that u writes from xf's
    439           # source, so that xf can go after u.
    440           ooo += 1
    441           assert xf.src_ranges.overlaps(u.tgt_ranges)
    442           xf.src_ranges = xf.src_ranges.subtract(u.tgt_ranges)
    443           xf.intact = False
    444 
    445       if xf.style == "diff" and not xf.src_ranges:
    446         # nothing left to diff from; treat as new data
    447         xf.style = "new"
    448 
    449       lost = size - xf.src_ranges.size()
    450       lost_source += lost
    451       in_order += io
    452       out_of_order += ooo
    453 
    454     print(("  %d/%d dependencies (%.2f%%) were violated; "
    455            "%d source blocks removed.") %
    456           (out_of_order, in_order + out_of_order,
    457            (out_of_order * 100.0 / (in_order + out_of_order))
    458            if (in_order + out_of_order) else 0.0,
    459            lost_source))
    460 
    461   def FindVertexSequence(self):
    462     print("Finding vertex sequence...")
    463 
    464     # This is based on "A Fast & Effective Heuristic for the Feedback
    465     # Arc Set Problem" by P. Eades, X. Lin, and W.F. Smyth.  Think of
    466     # it as starting with the digraph G and moving all the vertices to
    467     # be on a horizontal line in some order, trying to minimize the
    468     # number of edges that end up pointing to the left.  Left-pointing
    469     # edges will get removed to turn the digraph into a DAG.  In this
    470     # case each edge has a weight which is the number of source blocks
    471     # we'll lose if that edge is removed; we try to minimize the total
    472     # weight rather than just the number of edges.
    473 
    474     # Make a copy of the edge set; this copy will get destroyed by the
    475     # algorithm.
    476     for xf in self.transfers:
    477       xf.incoming = xf.goes_after.copy()
    478       xf.outgoing = xf.goes_before.copy()
    479 
    480     # We use an OrderedDict instead of just a set so that the output
    481     # is repeatable; otherwise it would depend on the hash values of
    482     # the transfer objects.
    483     G = OrderedDict()
    484     for xf in self.transfers:
    485       G[xf] = None
    486     s1 = deque()  # the left side of the sequence, built from left to right
    487     s2 = deque()  # the right side of the sequence, built from right to left
    488 
    489     while G:
    490 
    491       # Put all sinks at the end of the sequence.
    492       while True:
    493         sinks = [u for u in G if not u.outgoing]
    494         if not sinks: break
    495         for u in sinks:
    496           s2.appendleft(u)
    497           del G[u]
    498           for iu in u.incoming:
    499             del iu.outgoing[u]
    500 
    501       # Put all the sources at the beginning of the sequence.
    502       while True:
    503         sources = [u for u in G if not u.incoming]
    504         if not sources: break
    505         for u in sources:
    506           s1.append(u)
    507           del G[u]
    508           for iu in u.outgoing:
    509             del iu.incoming[u]
    510 
    511       if not G: break
    512 
    513       # Find the "best" vertex to put next.  "Best" is the one that
    514       # maximizes the net difference in source blocks saved we get by
    515       # pretending it's a source rather than a sink.
    516 
    517       max_d = None
    518       best_u = None
    519       for u in G:
    520         d = sum(u.outgoing.values()) - sum(u.incoming.values())
    521         if best_u is None or d > max_d:
    522           max_d = d
    523           best_u = u
    524 
    525       u = best_u
    526       s1.append(u)
    527       del G[u]
    528       for iu in u.outgoing:
    529         del iu.incoming[u]
    530       for iu in u.incoming:
    531         del iu.outgoing[u]
    532 
    533     # Now record the sequence in the 'order' field of each transfer,
    534     # and by rearranging self.transfers to be in the chosen sequence.
    535 
    536     new_transfers = []
    537     for x in itertools.chain(s1, s2):
    538       x.order = len(new_transfers)
    539       new_transfers.append(x)
    540       del x.incoming
    541       del x.outgoing
    542 
    543     self.transfers = new_transfers
    544 
    545   def GenerateDigraph(self):
    546     print("Generating digraph...")
    547     for a in self.transfers:
    548       for b in self.transfers:
    549         if a is b: continue
    550 
    551         # If the blocks written by A are read by B, then B needs to go before A.
    552         i = a.tgt_ranges.intersect(b.src_ranges)
    553         if i:
    554           if b.src_name == "__ZERO":
    555             # the cost of removing source blocks for the __ZERO domain
    556             # is (nearly) zero.
    557             size = 0
    558           else:
    559             size = i.size()
    560           b.goes_before[a] = size
    561           a.goes_after[b] = size
    562 
    563   def FindTransfers(self):
    564     self.transfers = []
    565     empty = RangeSet()
    566     for tgt_fn, tgt_ranges in self.tgt.file_map.items():
    567       if tgt_fn == "__ZERO":
    568         # the special "__ZERO" domain is all the blocks not contained
    569         # in any file and that are filled with zeros.  We have a
    570         # special transfer style for zero blocks.
    571         src_ranges = self.src.file_map.get("__ZERO", empty)
    572         Transfer(tgt_fn, "__ZERO", tgt_ranges, src_ranges,
    573                  "zero", self.transfers)
    574         continue
    575 
    576       elif tgt_fn in self.src.file_map:
    577         # Look for an exact pathname match in the source.
    578         Transfer(tgt_fn, tgt_fn, tgt_ranges, self.src.file_map[tgt_fn],
    579                  "diff", self.transfers)
    580         continue
    581 
    582       b = os.path.basename(tgt_fn)
    583       if b in self.src_basenames:
    584         # Look for an exact basename match in the source.
    585         src_fn = self.src_basenames[b]
    586         Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
    587                  "diff", self.transfers)
    588         continue
    589 
    590       b = re.sub("[0-9]+", "#", b)
    591       if b in self.src_numpatterns:
    592         # Look for a 'number pattern' match (a basename match after
    593         # all runs of digits are replaced by "#").  (This is useful
    594         # for .so files that contain version numbers in the filename
    595         # that get bumped.)
    596         src_fn = self.src_numpatterns[b]
    597         Transfer(tgt_fn, src_fn, tgt_ranges, self.src.file_map[src_fn],
    598                  "diff", self.transfers)
    599         continue
    600 
    601       Transfer(tgt_fn, None, tgt_ranges, empty, "new", self.transfers)
    602 
    603   def AbbreviateSourceNames(self):
    604     self.src_basenames = {}
    605     self.src_numpatterns = {}
    606 
    607     for k in self.src.file_map.keys():
    608       b = os.path.basename(k)
    609       self.src_basenames[b] = k
    610       b = re.sub("[0-9]+", "#", b)
    611       self.src_numpatterns[b] = k
    612 
    613   @staticmethod
    614   def AssertPartition(total, seq):
    615     """Assert that all the RangeSets in 'seq' form a partition of the
    616     'total' RangeSet (ie, they are nonintersecting and their union
    617     equals 'total')."""
    618     so_far = RangeSet()
    619     for i in seq:
    620       assert not so_far.overlaps(i)
    621       so_far = so_far.union(i)
    622     assert so_far == total
    623