Home | History | Annotate | Download | only in update_payload
      1 # Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 """Applying a Chrome OS update payload.
      6 
      7 This module is used internally by the main Payload class for applying an update
      8 payload. The interface for invoking the applier is as follows:
      9 
     10   applier = PayloadApplier(payload)
     11   applier.Run(...)
     12 
     13 """
     14 
     15 from __future__ import print_function
     16 
     17 import array
     18 import bz2
     19 import hashlib
     20 import itertools
     21 import os
     22 import shutil
     23 import subprocess
     24 import sys
     25 import tempfile
     26 
     27 import common
     28 from error import PayloadError
     29 
     30 
     31 #
     32 # Helper functions.
     33 #
     34 def _VerifySha256(file_obj, expected_hash, name, length=-1):
     35   """Verifies the SHA256 hash of a file.
     36 
     37   Args:
     38     file_obj: file object to read
     39     expected_hash: the hash digest we expect to be getting
     40     name: name string of this hash, for error reporting
     41     length: precise length of data to verify (optional)
     42 
     43   Raises:
     44     PayloadError if computed hash doesn't match expected one, or if fails to
     45     read the specified length of data.
     46   """
     47   # pylint: disable=E1101
     48   hasher = hashlib.sha256()
     49   block_length = 1024 * 1024
     50   max_length = length if length >= 0 else sys.maxint
     51 
     52   while max_length > 0:
     53     read_length = min(max_length, block_length)
     54     data = file_obj.read(read_length)
     55     if not data:
     56       break
     57     max_length -= len(data)
     58     hasher.update(data)
     59 
     60   if length >= 0 and max_length > 0:
     61     raise PayloadError(
     62         'insufficient data (%d instead of %d) when verifying %s' %
     63         (length - max_length, length, name))
     64 
     65   actual_hash = hasher.digest()
     66   if actual_hash != expected_hash:
     67     raise PayloadError('%s hash (%s) not as expected (%s)' %
     68                        (name, common.FormatSha256(actual_hash),
     69                         common.FormatSha256(expected_hash)))
     70 
     71 
     72 def _ReadExtents(file_obj, extents, block_size, max_length=-1):
     73   """Reads data from file as defined by extent sequence.
     74 
     75   This tries to be efficient by not copying data as it is read in chunks.
     76 
     77   Args:
     78     file_obj: file object
     79     extents: sequence of block extents (offset and length)
     80     block_size: size of each block
     81     max_length: maximum length to read (optional)
     82 
     83   Returns:
     84     A character array containing the concatenated read data.
     85   """
     86   data = array.array('c')
     87   if max_length < 0:
     88     max_length = sys.maxint
     89   for ex in extents:
     90     if max_length == 0:
     91       break
     92     read_length = min(max_length, ex.num_blocks * block_size)
     93 
     94     # Fill with zeros or read from file, depending on the type of extent.
     95     if ex.start_block == common.PSEUDO_EXTENT_MARKER:
     96       data.extend(itertools.repeat('\0', read_length))
     97     else:
     98       file_obj.seek(ex.start_block * block_size)
     99       data.fromfile(file_obj, read_length)
    100 
    101     max_length -= read_length
    102 
    103   return data
    104 
    105 
    106 def _WriteExtents(file_obj, data, extents, block_size, base_name):
    107   """Writes data to file as defined by extent sequence.
    108 
    109   This tries to be efficient by not copy data as it is written in chunks.
    110 
    111   Args:
    112     file_obj: file object
    113     data: data to write
    114     extents: sequence of block extents (offset and length)
    115     block_size: size of each block
    116     base_name: name string of extent sequence for error reporting
    117 
    118   Raises:
    119     PayloadError when things don't add up.
    120   """
    121   data_offset = 0
    122   data_length = len(data)
    123   for ex, ex_name in common.ExtentIter(extents, base_name):
    124     if not data_length:
    125       raise PayloadError('%s: more write extents than data' % ex_name)
    126     write_length = min(data_length, ex.num_blocks * block_size)
    127 
    128     # Only do actual writing if this is not a pseudo-extent.
    129     if ex.start_block != common.PSEUDO_EXTENT_MARKER:
    130       file_obj.seek(ex.start_block * block_size)
    131       data_view = buffer(data, data_offset, write_length)
    132       file_obj.write(data_view)
    133 
    134     data_offset += write_length
    135     data_length -= write_length
    136 
    137   if data_length:
    138     raise PayloadError('%s: more data than write extents' % base_name)
    139 
    140 
    141 def _ExtentsToBspatchArg(extents, block_size, base_name, data_length=-1):
    142   """Translates an extent sequence into a bspatch-compatible string argument.
    143 
    144   Args:
    145     extents: sequence of block extents (offset and length)
    146     block_size: size of each block
    147     base_name: name string of extent sequence for error reporting
    148     data_length: the actual total length of the data in bytes (optional)
    149 
    150   Returns:
    151     A tuple consisting of (i) a string of the form
    152     "off_1:len_1,...,off_n:len_n", (ii) an offset where zero padding is needed
    153     for filling the last extent, (iii) the length of the padding (zero means no
    154     padding is needed and the extents cover the full length of data).
    155 
    156   Raises:
    157     PayloadError if data_length is too short or too long.
    158   """
    159   arg = ''
    160   pad_off = pad_len = 0
    161   if data_length < 0:
    162     data_length = sys.maxint
    163   for ex, ex_name in common.ExtentIter(extents, base_name):
    164     if not data_length:
    165       raise PayloadError('%s: more extents than total data length' % ex_name)
    166 
    167     is_pseudo = ex.start_block == common.PSEUDO_EXTENT_MARKER
    168     start_byte = -1 if is_pseudo else ex.start_block * block_size
    169     num_bytes = ex.num_blocks * block_size
    170     if data_length < num_bytes:
    171       # We're only padding a real extent.
    172       if not is_pseudo:
    173         pad_off = start_byte + data_length
    174         pad_len = num_bytes - data_length
    175 
    176       num_bytes = data_length
    177 
    178     arg += '%s%d:%d' % (arg and ',', start_byte, num_bytes)
    179     data_length -= num_bytes
    180 
    181   if data_length:
    182     raise PayloadError('%s: extents not covering full data length' % base_name)
    183 
    184   return arg, pad_off, pad_len
    185 
    186 
    187 #
    188 # Payload application.
    189 #
    190 class PayloadApplier(object):
    191   """Applying an update payload.
    192 
    193   This is a short-lived object whose purpose is to isolate the logic used for
    194   applying an update payload.
    195   """
    196 
    197   def __init__(self, payload, bsdiff_in_place=True, bspatch_path=None,
    198                imgpatch_path=None, truncate_to_expected_size=True):
    199     """Initialize the applier.
    200 
    201     Args:
    202       payload: the payload object to check
    203       bsdiff_in_place: whether to perform BSDIFF operation in-place (optional)
    204       bspatch_path: path to the bspatch binary (optional)
    205       imgpatch_path: path to the imgpatch binary (optional)
    206       truncate_to_expected_size: whether to truncate the resulting partitions
    207                                  to their expected sizes, as specified in the
    208                                  payload (optional)
    209     """
    210     assert payload.is_init, 'uninitialized update payload'
    211     self.payload = payload
    212     self.block_size = payload.manifest.block_size
    213     self.minor_version = payload.manifest.minor_version
    214     self.bsdiff_in_place = bsdiff_in_place
    215     self.bspatch_path = bspatch_path or 'bspatch'
    216     self.imgpatch_path = imgpatch_path or 'imgpatch'
    217     self.truncate_to_expected_size = truncate_to_expected_size
    218 
    219   def _ApplyReplaceOperation(self, op, op_name, out_data, part_file, part_size):
    220     """Applies a REPLACE{,_BZ} operation.
    221 
    222     Args:
    223       op: the operation object
    224       op_name: name string for error reporting
    225       out_data: the data to be written
    226       part_file: the partition file object
    227       part_size: the size of the partition
    228 
    229     Raises:
    230       PayloadError if something goes wrong.
    231     """
    232     block_size = self.block_size
    233     data_length = len(out_data)
    234 
    235     # Decompress data if needed.
    236     if op.type == common.OpType.REPLACE_BZ:
    237       out_data = bz2.decompress(out_data)
    238       data_length = len(out_data)
    239 
    240     # Write data to blocks specified in dst extents.
    241     data_start = 0
    242     for ex, ex_name in common.ExtentIter(op.dst_extents,
    243                                          '%s.dst_extents' % op_name):
    244       start_block = ex.start_block
    245       num_blocks = ex.num_blocks
    246       count = num_blocks * block_size
    247 
    248       # Make sure it's not a fake (signature) operation.
    249       if start_block != common.PSEUDO_EXTENT_MARKER:
    250         data_end = data_start + count
    251 
    252         # Make sure we're not running past partition boundary.
    253         if (start_block + num_blocks) * block_size > part_size:
    254           raise PayloadError(
    255               '%s: extent (%s) exceeds partition size (%d)' %
    256               (ex_name, common.FormatExtent(ex, block_size),
    257                part_size))
    258 
    259         # Make sure that we have enough data to write.
    260         if data_end >= data_length + block_size:
    261           raise PayloadError(
    262               '%s: more dst blocks than data (even with padding)')
    263 
    264         # Pad with zeros if necessary.
    265         if data_end > data_length:
    266           padding = data_end - data_length
    267           out_data += '\0' * padding
    268 
    269         self.payload.payload_file.seek(start_block * block_size)
    270         part_file.seek(start_block * block_size)
    271         part_file.write(out_data[data_start:data_end])
    272 
    273       data_start += count
    274 
    275     # Make sure we wrote all data.
    276     if data_start < data_length:
    277       raise PayloadError('%s: wrote fewer bytes (%d) than expected (%d)' %
    278                          (op_name, data_start, data_length))
    279 
    280   def _ApplyMoveOperation(self, op, op_name, part_file):
    281     """Applies a MOVE operation.
    282 
    283     Note that this operation must read the whole block data from the input and
    284     only then dump it, due to our in-place update semantics; otherwise, it
    285     might clobber data midway through.
    286 
    287     Args:
    288       op: the operation object
    289       op_name: name string for error reporting
    290       part_file: the partition file object
    291 
    292     Raises:
    293       PayloadError if something goes wrong.
    294     """
    295     block_size = self.block_size
    296 
    297     # Gather input raw data from src extents.
    298     in_data = _ReadExtents(part_file, op.src_extents, block_size)
    299 
    300     # Dump extracted data to dst extents.
    301     _WriteExtents(part_file, in_data, op.dst_extents, block_size,
    302                   '%s.dst_extents' % op_name)
    303 
    304   def _ApplyBsdiffOperation(self, op, op_name, patch_data, new_part_file):
    305     """Applies a BSDIFF operation.
    306 
    307     Args:
    308       op: the operation object
    309       op_name: name string for error reporting
    310       patch_data: the binary patch content
    311       new_part_file: the target partition file object
    312 
    313     Raises:
    314       PayloadError if something goes wrong.
    315     """
    316     # Implemented using a SOURCE_BSDIFF operation with the source and target
    317     # partition set to the new partition.
    318     self._ApplyDiffOperation(op, op_name, patch_data, new_part_file,
    319                              new_part_file)
    320 
    321   def _ApplySourceCopyOperation(self, op, op_name, old_part_file,
    322                                 new_part_file):
    323     """Applies a SOURCE_COPY operation.
    324 
    325     Args:
    326       op: the operation object
    327       op_name: name string for error reporting
    328       old_part_file: the old partition file object
    329       new_part_file: the new partition file object
    330 
    331     Raises:
    332       PayloadError if something goes wrong.
    333     """
    334     if not old_part_file:
    335       raise PayloadError(
    336           '%s: no source partition file provided for operation type (%d)' %
    337           (op_name, op.type))
    338 
    339     block_size = self.block_size
    340 
    341     # Gather input raw data from src extents.
    342     in_data = _ReadExtents(old_part_file, op.src_extents, block_size)
    343 
    344     # Dump extracted data to dst extents.
    345     _WriteExtents(new_part_file, in_data, op.dst_extents, block_size,
    346                   '%s.dst_extents' % op_name)
    347 
    348   def _ApplyDiffOperation(self, op, op_name, patch_data, old_part_file,
    349                           new_part_file):
    350     """Applies a SOURCE_BSDIFF or IMGDIFF operation.
    351 
    352     Args:
    353       op: the operation object
    354       op_name: name string for error reporting
    355       patch_data: the binary patch content
    356       old_part_file: the source partition file object
    357       new_part_file: the target partition file object
    358 
    359     Raises:
    360       PayloadError if something goes wrong.
    361     """
    362     if not old_part_file:
    363       raise PayloadError(
    364           '%s: no source partition file provided for operation type (%d)' %
    365           (op_name, op.type))
    366 
    367     block_size = self.block_size
    368 
    369     # Dump patch data to file.
    370     with tempfile.NamedTemporaryFile(delete=False) as patch_file:
    371       patch_file_name = patch_file.name
    372       patch_file.write(patch_data)
    373 
    374     if (hasattr(new_part_file, 'fileno') and
    375         ((not old_part_file) or hasattr(old_part_file, 'fileno')) and
    376         op.type != common.OpType.IMGDIFF):
    377       # Construct input and output extents argument for bspatch.
    378       in_extents_arg, _, _ = _ExtentsToBspatchArg(
    379           op.src_extents, block_size, '%s.src_extents' % op_name,
    380           data_length=op.src_length)
    381       out_extents_arg, pad_off, pad_len = _ExtentsToBspatchArg(
    382           op.dst_extents, block_size, '%s.dst_extents' % op_name,
    383           data_length=op.dst_length)
    384 
    385       new_file_name = '/dev/fd/%d' % new_part_file.fileno()
    386       # Diff from source partition.
    387       old_file_name = '/dev/fd/%d' % old_part_file.fileno()
    388 
    389       # Invoke bspatch on partition file with extents args.
    390       bspatch_cmd = [self.bspatch_path, old_file_name, new_file_name,
    391                      patch_file_name, in_extents_arg, out_extents_arg]
    392       subprocess.check_call(bspatch_cmd)
    393 
    394       # Pad with zeros past the total output length.
    395       if pad_len:
    396         new_part_file.seek(pad_off)
    397         new_part_file.write('\0' * pad_len)
    398     else:
    399       # Gather input raw data and write to a temp file.
    400       input_part_file = old_part_file if old_part_file else new_part_file
    401       in_data = _ReadExtents(input_part_file, op.src_extents, block_size,
    402                              max_length=op.src_length)
    403       with tempfile.NamedTemporaryFile(delete=False) as in_file:
    404         in_file_name = in_file.name
    405         in_file.write(in_data)
    406 
    407       # Allocate temporary output file.
    408       with tempfile.NamedTemporaryFile(delete=False) as out_file:
    409         out_file_name = out_file.name
    410 
    411       # Invoke bspatch.
    412       patch_cmd = [self.bspatch_path, in_file_name, out_file_name,
    413                    patch_file_name]
    414       if op.type == common.OpType.IMGDIFF:
    415         patch_cmd[0] = self.imgpatch_path
    416       subprocess.check_call(patch_cmd)
    417 
    418       # Read output.
    419       with open(out_file_name, 'rb') as out_file:
    420         out_data = out_file.read()
    421         if len(out_data) != op.dst_length:
    422           raise PayloadError(
    423               '%s: actual patched data length (%d) not as expected (%d)' %
    424               (op_name, len(out_data), op.dst_length))
    425 
    426       # Write output back to partition, with padding.
    427       unaligned_out_len = len(out_data) % block_size
    428       if unaligned_out_len:
    429         out_data += '\0' * (block_size - unaligned_out_len)
    430       _WriteExtents(new_part_file, out_data, op.dst_extents, block_size,
    431                     '%s.dst_extents' % op_name)
    432 
    433       # Delete input/output files.
    434       os.remove(in_file_name)
    435       os.remove(out_file_name)
    436 
    437     # Delete patch file.
    438     os.remove(patch_file_name)
    439 
    440   def _ApplyOperations(self, operations, base_name, old_part_file,
    441                        new_part_file, part_size):
    442     """Applies a sequence of update operations to a partition.
    443 
    444     This assumes an in-place update semantics for MOVE and BSDIFF, namely all
    445     reads are performed first, then the data is processed and written back to
    446     the same file.
    447 
    448     Args:
    449       operations: the sequence of operations
    450       base_name: the name of the operation sequence
    451       old_part_file: the old partition file object, open for reading/writing
    452       new_part_file: the new partition file object, open for reading/writing
    453       part_size: the partition size
    454 
    455     Raises:
    456       PayloadError if anything goes wrong while processing the payload.
    457     """
    458     for op, op_name in common.OperationIter(operations, base_name):
    459       # Read data blob.
    460       data = self.payload.ReadDataBlob(op.data_offset, op.data_length)
    461 
    462       if op.type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ):
    463         self._ApplyReplaceOperation(op, op_name, data, new_part_file, part_size)
    464       elif op.type == common.OpType.MOVE:
    465         self._ApplyMoveOperation(op, op_name, new_part_file)
    466       elif op.type == common.OpType.BSDIFF:
    467         self._ApplyBsdiffOperation(op, op_name, data, new_part_file)
    468       elif op.type == common.OpType.SOURCE_COPY:
    469         self._ApplySourceCopyOperation(op, op_name, old_part_file,
    470                                        new_part_file)
    471       elif op.type in (common.OpType.SOURCE_BSDIFF, common.OpType.IMGDIFF):
    472         self._ApplyDiffOperation(op, op_name, data, old_part_file,
    473                                  new_part_file)
    474       else:
    475         raise PayloadError('%s: unknown operation type (%d)' %
    476                            (op_name, op.type))
    477 
    478   def _ApplyToPartition(self, operations, part_name, base_name,
    479                         new_part_file_name, new_part_info,
    480                         old_part_file_name=None, old_part_info=None):
    481     """Applies an update to a partition.
    482 
    483     Args:
    484       operations: the sequence of update operations to apply
    485       part_name: the name of the partition, for error reporting
    486       base_name: the name of the operation sequence
    487       new_part_file_name: file name to write partition data to
    488       new_part_info: size and expected hash of dest partition
    489       old_part_file_name: file name of source partition (optional)
    490       old_part_info: size and expected hash of source partition (optional)
    491 
    492     Raises:
    493       PayloadError if anything goes wrong with the update.
    494     """
    495     # Do we have a source partition?
    496     if old_part_file_name:
    497       # Verify the source partition.
    498       with open(old_part_file_name, 'rb') as old_part_file:
    499         _VerifySha256(old_part_file, old_part_info.hash,
    500                       'old ' + part_name, length=old_part_info.size)
    501       new_part_file_mode = 'r+b'
    502       if self.minor_version == common.INPLACE_MINOR_PAYLOAD_VERSION:
    503         # Copy the src partition to the dst one; make sure we don't truncate it.
    504         shutil.copyfile(old_part_file_name, new_part_file_name)
    505       elif (self.minor_version == common.SOURCE_MINOR_PAYLOAD_VERSION or
    506             self.minor_version == common.OPSRCHASH_MINOR_PAYLOAD_VERSION or
    507             self.minor_version == common.IMGDIFF_MINOR_PAYLOAD_VERSION):
    508         # In minor version >= 2, we don't want to copy the partitions, so
    509         # instead just make the new partition file.
    510         open(new_part_file_name, 'w').close()
    511       else:
    512         raise PayloadError("Unknown minor version: %d" % self.minor_version)
    513     else:
    514       # We need to create/truncate the dst partition file.
    515       new_part_file_mode = 'w+b'
    516 
    517     # Apply operations.
    518     with open(new_part_file_name, new_part_file_mode) as new_part_file:
    519       old_part_file = (open(old_part_file_name, 'r+b')
    520                        if old_part_file_name else None)
    521       try:
    522         self._ApplyOperations(operations, base_name, old_part_file,
    523                               new_part_file, new_part_info.size)
    524       finally:
    525         if old_part_file:
    526           old_part_file.close()
    527 
    528       # Truncate the result, if so instructed.
    529       if self.truncate_to_expected_size:
    530         new_part_file.seek(0, 2)
    531         if new_part_file.tell() > new_part_info.size:
    532           new_part_file.seek(new_part_info.size)
    533           new_part_file.truncate()
    534 
    535     # Verify the resulting partition.
    536     with open(new_part_file_name, 'rb') as new_part_file:
    537       _VerifySha256(new_part_file, new_part_info.hash,
    538                     'new ' + part_name, length=new_part_info.size)
    539 
    540   def Run(self, new_kernel_part, new_rootfs_part, old_kernel_part=None,
    541           old_rootfs_part=None):
    542     """Applier entry point, invoking all update operations.
    543 
    544     Args:
    545       new_kernel_part: name of dest kernel partition file
    546       new_rootfs_part: name of dest rootfs partition file
    547       old_kernel_part: name of source kernel partition file (optional)
    548       old_rootfs_part: name of source rootfs partition file (optional)
    549 
    550     Raises:
    551       PayloadError if payload application failed.
    552     """
    553     self.payload.ResetFile()
    554 
    555     # Make sure the arguments are sane and match the payload.
    556     if not (new_kernel_part and new_rootfs_part):
    557       raise PayloadError('missing dst {kernel,rootfs} partitions')
    558 
    559     if not (old_kernel_part or old_rootfs_part):
    560       if not self.payload.IsFull():
    561         raise PayloadError('trying to apply a non-full update without src '
    562                            '{kernel,rootfs} partitions')
    563     elif old_kernel_part and old_rootfs_part:
    564       if not self.payload.IsDelta():
    565         raise PayloadError('trying to apply a non-delta update onto src '
    566                            '{kernel,rootfs} partitions')
    567     else:
    568       raise PayloadError('not all src partitions provided')
    569 
    570     # Apply update to rootfs.
    571     self._ApplyToPartition(
    572         self.payload.manifest.install_operations, 'rootfs',
    573         'install_operations', new_rootfs_part,
    574         self.payload.manifest.new_rootfs_info, old_rootfs_part,
    575         self.payload.manifest.old_rootfs_info)
    576 
    577     # Apply update to kernel update.
    578     self._ApplyToPartition(
    579         self.payload.manifest.kernel_install_operations, 'kernel',
    580         'kernel_install_operations', new_kernel_part,
    581         self.payload.manifest.new_kernel_info, old_kernel_part,
    582         self.payload.manifest.old_kernel_info)
    583