Home | History | Annotate | Download | only in update_payload
      1 #!/usr/bin/python2
      2 #
      3 # Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
      4 # Use of this source code is governed by a BSD-style license that can be
      5 # found in the LICENSE file.
      6 
      7 """Unit testing checker.py."""
      8 
      9 from __future__ import print_function
     10 
     11 import array
     12 import collections
     13 import cStringIO
     14 import hashlib
     15 import itertools
     16 import os
     17 import unittest
     18 
     19 # pylint cannot find mox.
     20 # pylint: disable=F0401
     21 import mox
     22 
     23 import checker
     24 import common
     25 import payload as update_payload  # Avoid name conflicts later.
     26 import test_utils
     27 import update_metadata_pb2
     28 
     29 
     30 def _OpTypeByName(op_name):
     31   op_name_to_type = {
     32       'REPLACE': common.OpType.REPLACE,
     33       'REPLACE_BZ': common.OpType.REPLACE_BZ,
     34       'MOVE': common.OpType.MOVE,
     35       'BSDIFF': common.OpType.BSDIFF,
     36       'SOURCE_COPY': common.OpType.SOURCE_COPY,
     37       'SOURCE_BSDIFF': common.OpType.SOURCE_BSDIFF,
     38       'ZERO': common.OpType.ZERO,
     39       'DISCARD': common.OpType.DISCARD,
     40       'REPLACE_XZ': common.OpType.REPLACE_XZ,
     41       'IMGDIFF': common.OpType.IMGDIFF,
     42   }
     43   return op_name_to_type[op_name]
     44 
     45 
     46 def _GetPayloadChecker(payload_gen_write_to_file_func, payload_gen_dargs=None,
     47                        checker_init_dargs=None):
     48   """Returns a payload checker from a given payload generator."""
     49   if payload_gen_dargs is None:
     50     payload_gen_dargs = {}
     51   if checker_init_dargs is None:
     52     checker_init_dargs = {}
     53 
     54   payload_file = cStringIO.StringIO()
     55   payload_gen_write_to_file_func(payload_file, **payload_gen_dargs)
     56   payload_file.seek(0)
     57   payload = update_payload.Payload(payload_file)
     58   payload.Init()
     59   return checker.PayloadChecker(payload, **checker_init_dargs)
     60 
     61 
     62 def _GetPayloadCheckerWithData(payload_gen):
     63   """Returns a payload checker from a given payload generator."""
     64   payload_file = cStringIO.StringIO()
     65   payload_gen.WriteToFile(payload_file)
     66   payload_file.seek(0)
     67   payload = update_payload.Payload(payload_file)
     68   payload.Init()
     69   return checker.PayloadChecker(payload)
     70 
     71 
     72 # This class doesn't need an __init__().
     73 # pylint: disable=W0232
     74 # Unit testing is all about running protected methods.
     75 # pylint: disable=W0212
     76 # Don't bark about missing members of classes you cannot import.
     77 # pylint: disable=E1101
     78 class PayloadCheckerTest(mox.MoxTestBase):
     79   """Tests the PayloadChecker class.
     80 
     81   In addition to ordinary testFoo() methods, which are automatically invoked by
     82   the unittest framework, in this class we make use of DoBarTest() calls that
     83   implement parametric tests of certain features. In order to invoke each test,
     84   which embodies a unique combination of parameter values, as a complete unit
     85   test, we perform explicit enumeration of the parameter space and create
     86   individual invocation contexts for each, which are then bound as
     87   testBar__param1=val1__param2=val2(). The enumeration of parameter spaces for
     88   all such tests is done in AddAllParametricTests().
     89   """
     90 
     91   def MockPayload(self):
     92     """Create a mock payload object, complete with a mock manifest."""
     93     payload = self.mox.CreateMock(update_payload.Payload)
     94     payload.is_init = True
     95     payload.manifest = self.mox.CreateMock(
     96         update_metadata_pb2.DeltaArchiveManifest)
     97     return payload
     98 
     99   @staticmethod
    100   def NewExtent(start_block, num_blocks):
    101     """Returns an Extent message.
    102 
    103     Each of the provided fields is set iff it is >= 0; otherwise, it's left at
    104     its default state.
    105 
    106     Args:
    107       start_block: The starting block of the extent.
    108       num_blocks: The number of blocks in the extent.
    109 
    110     Returns:
    111       An Extent message.
    112     """
    113     ex = update_metadata_pb2.Extent()
    114     if start_block >= 0:
    115       ex.start_block = start_block
    116     if num_blocks >= 0:
    117       ex.num_blocks = num_blocks
    118     return ex
    119 
    120   @staticmethod
    121   def NewExtentList(*args):
    122     """Returns an list of extents.
    123 
    124     Args:
    125       *args: (start_block, num_blocks) pairs defining the extents.
    126 
    127     Returns:
    128       A list of Extent objects.
    129     """
    130     ex_list = []
    131     for start_block, num_blocks in args:
    132       ex_list.append(PayloadCheckerTest.NewExtent(start_block, num_blocks))
    133     return ex_list
    134 
    135   @staticmethod
    136   def AddToMessage(repeated_field, field_vals):
    137     for field_val in field_vals:
    138       new_field = repeated_field.add()
    139       new_field.CopyFrom(field_val)
    140 
    141   def SetupAddElemTest(self, is_present, is_submsg, convert=str,
    142                        linebreak=False, indent=0):
    143     """Setup for testing of _CheckElem() and its derivatives.
    144 
    145     Args:
    146       is_present: Whether or not the element is found in the message.
    147       is_submsg: Whether the element is a sub-message itself.
    148       convert: A representation conversion function.
    149       linebreak: Whether or not a linebreak is to be used in the report.
    150       indent: Indentation used for the report.
    151 
    152     Returns:
    153       msg: A mock message object.
    154       report: A mock report object.
    155       subreport: A mock sub-report object.
    156       name: An element name to check.
    157       val: Expected element value.
    158     """
    159     name = 'foo'
    160     val = 'fake submsg' if is_submsg else 'fake field'
    161     subreport = 'fake subreport'
    162 
    163     # Create a mock message.
    164     msg = self.mox.CreateMock(update_metadata_pb2._message.Message)
    165     msg.HasField(name).AndReturn(is_present)
    166     setattr(msg, name, val)
    167 
    168     # Create a mock report.
    169     report = self.mox.CreateMock(checker._PayloadReport)
    170     if is_present:
    171       if is_submsg:
    172         report.AddSubReport(name).AndReturn(subreport)
    173       else:
    174         report.AddField(name, convert(val), linebreak=linebreak, indent=indent)
    175 
    176     self.mox.ReplayAll()
    177     return (msg, report, subreport, name, val)
    178 
    179   def DoAddElemTest(self, is_present, is_mandatory, is_submsg, convert,
    180                     linebreak, indent):
    181     """Parametric testing of _CheckElem().
    182 
    183     Args:
    184       is_present: Whether or not the element is found in the message.
    185       is_mandatory: Whether or not it's a mandatory element.
    186       is_submsg: Whether the element is a sub-message itself.
    187       convert: A representation conversion function.
    188       linebreak: Whether or not a linebreak is to be used in the report.
    189       indent: Indentation used for the report.
    190     """
    191     msg, report, subreport, name, val = self.SetupAddElemTest(
    192         is_present, is_submsg, convert, linebreak, indent)
    193 
    194     args = (msg, name, report, is_mandatory, is_submsg)
    195     kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
    196     if is_mandatory and not is_present:
    197       self.assertRaises(update_payload.PayloadError,
    198                         checker.PayloadChecker._CheckElem, *args, **kwargs)
    199     else:
    200       ret_val, ret_subreport = checker.PayloadChecker._CheckElem(*args,
    201                                                                  **kwargs)
    202       self.assertEquals(val if is_present else None, ret_val)
    203       self.assertEquals(subreport if is_present and is_submsg else None,
    204                         ret_subreport)
    205 
    206   def DoAddFieldTest(self, is_mandatory, is_present, convert, linebreak,
    207                      indent):
    208     """Parametric testing of _Check{Mandatory,Optional}Field().
    209 
    210     Args:
    211       is_mandatory: Whether we're testing a mandatory call.
    212       is_present: Whether or not the element is found in the message.
    213       convert: A representation conversion function.
    214       linebreak: Whether or not a linebreak is to be used in the report.
    215       indent: Indentation used for the report.
    216     """
    217     msg, report, _, name, val = self.SetupAddElemTest(
    218         is_present, False, convert, linebreak, indent)
    219 
    220     # Prepare for invocation of the tested method.
    221     args = [msg, name, report]
    222     kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
    223     if is_mandatory:
    224       args.append('bar')
    225       tested_func = checker.PayloadChecker._CheckMandatoryField
    226     else:
    227       tested_func = checker.PayloadChecker._CheckOptionalField
    228 
    229     # Test the method call.
    230     if is_mandatory and not is_present:
    231       self.assertRaises(update_payload.PayloadError, tested_func, *args,
    232                         **kwargs)
    233     else:
    234       ret_val = tested_func(*args, **kwargs)
    235       self.assertEquals(val if is_present else None, ret_val)
    236 
    237   def DoAddSubMsgTest(self, is_mandatory, is_present):
    238     """Parametrized testing of _Check{Mandatory,Optional}SubMsg().
    239 
    240     Args:
    241       is_mandatory: Whether we're testing a mandatory call.
    242       is_present: Whether or not the element is found in the message.
    243     """
    244     msg, report, subreport, name, val = self.SetupAddElemTest(is_present, True)
    245 
    246     # Prepare for invocation of the tested method.
    247     args = [msg, name, report]
    248     if is_mandatory:
    249       args.append('bar')
    250       tested_func = checker.PayloadChecker._CheckMandatorySubMsg
    251     else:
    252       tested_func = checker.PayloadChecker._CheckOptionalSubMsg
    253 
    254     # Test the method call.
    255     if is_mandatory and not is_present:
    256       self.assertRaises(update_payload.PayloadError, tested_func, *args)
    257     else:
    258       ret_val, ret_subreport = tested_func(*args)
    259       self.assertEquals(val if is_present else None, ret_val)
    260       self.assertEquals(subreport if is_present else None, ret_subreport)
    261 
    262   def testCheckPresentIff(self):
    263     """Tests _CheckPresentIff()."""
    264     self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
    265         None, None, 'foo', 'bar', 'baz'))
    266     self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
    267         'a', 'b', 'foo', 'bar', 'baz'))
    268     self.assertRaises(update_payload.PayloadError,
    269                       checker.PayloadChecker._CheckPresentIff,
    270                       'a', None, 'foo', 'bar', 'baz')
    271     self.assertRaises(update_payload.PayloadError,
    272                       checker.PayloadChecker._CheckPresentIff,
    273                       None, 'b', 'foo', 'bar', 'baz')
    274 
    275   def DoCheckSha256SignatureTest(self, expect_pass, expect_subprocess_call,
    276                                  sig_data, sig_asn1_header,
    277                                  returned_signed_hash, expected_signed_hash):
    278     """Parametric testing of _CheckSha256SignatureTest().
    279 
    280     Args:
    281       expect_pass: Whether or not it should pass.
    282       expect_subprocess_call: Whether to expect the openssl call to happen.
    283       sig_data: The signature raw data.
    284       sig_asn1_header: The ASN1 header.
    285       returned_signed_hash: The signed hash data retuned by openssl.
    286       expected_signed_hash: The signed hash data to compare against.
    287     """
    288     try:
    289       # Stub out the subprocess invocation.
    290       self.mox.StubOutWithMock(checker.PayloadChecker, '_Run')
    291       if expect_subprocess_call:
    292         checker.PayloadChecker._Run(
    293             mox.IsA(list), send_data=sig_data).AndReturn(
    294                 (sig_asn1_header + returned_signed_hash, None))
    295 
    296       self.mox.ReplayAll()
    297       if expect_pass:
    298         self.assertIsNone(checker.PayloadChecker._CheckSha256Signature(
    299             sig_data, 'foo', expected_signed_hash, 'bar'))
    300       else:
    301         self.assertRaises(update_payload.PayloadError,
    302                           checker.PayloadChecker._CheckSha256Signature,
    303                           sig_data, 'foo', expected_signed_hash, 'bar')
    304     finally:
    305       self.mox.UnsetStubs()
    306 
    307   def testCheckSha256Signature_Pass(self):
    308     """Tests _CheckSha256Signature(); pass case."""
    309     sig_data = 'fake-signature'.ljust(256)
    310     signed_hash = hashlib.sha256('fake-data').digest()
    311     self.DoCheckSha256SignatureTest(True, True, sig_data,
    312                                     common.SIG_ASN1_HEADER, signed_hash,
    313                                     signed_hash)
    314 
    315   def testCheckSha256Signature_FailBadSignature(self):
    316     """Tests _CheckSha256Signature(); fails due to malformed signature."""
    317     sig_data = 'fake-signature'  # Malformed (not 256 bytes in length).
    318     signed_hash = hashlib.sha256('fake-data').digest()
    319     self.DoCheckSha256SignatureTest(False, False, sig_data,
    320                                     common.SIG_ASN1_HEADER, signed_hash,
    321                                     signed_hash)
    322 
    323   def testCheckSha256Signature_FailBadOutputLength(self):
    324     """Tests _CheckSha256Signature(); fails due to unexpected output length."""
    325     sig_data = 'fake-signature'.ljust(256)
    326     signed_hash = 'fake-hash'  # Malformed (not 32 bytes in length).
    327     self.DoCheckSha256SignatureTest(False, True, sig_data,
    328                                     common.SIG_ASN1_HEADER, signed_hash,
    329                                     signed_hash)
    330 
    331   def testCheckSha256Signature_FailBadAsnHeader(self):
    332     """Tests _CheckSha256Signature(); fails due to bad ASN1 header."""
    333     sig_data = 'fake-signature'.ljust(256)
    334     signed_hash = hashlib.sha256('fake-data').digest()
    335     bad_asn1_header = 'bad-asn-header'.ljust(len(common.SIG_ASN1_HEADER))
    336     self.DoCheckSha256SignatureTest(False, True, sig_data, bad_asn1_header,
    337                                     signed_hash, signed_hash)
    338 
    339   def testCheckSha256Signature_FailBadHash(self):
    340     """Tests _CheckSha256Signature(); fails due to bad hash returned."""
    341     sig_data = 'fake-signature'.ljust(256)
    342     expected_signed_hash = hashlib.sha256('fake-data').digest()
    343     returned_signed_hash = hashlib.sha256('bad-fake-data').digest()
    344     self.DoCheckSha256SignatureTest(False, True, sig_data,
    345                                     common.SIG_ASN1_HEADER,
    346                                     expected_signed_hash, returned_signed_hash)
    347 
    348   def testCheckBlocksFitLength_Pass(self):
    349     """Tests _CheckBlocksFitLength(); pass case."""
    350     self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
    351         64, 4, 16, 'foo'))
    352     self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
    353         60, 4, 16, 'foo'))
    354     self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
    355         49, 4, 16, 'foo'))
    356     self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
    357         48, 3, 16, 'foo'))
    358 
    359   def testCheckBlocksFitLength_TooManyBlocks(self):
    360     """Tests _CheckBlocksFitLength(); fails due to excess blocks."""
    361     self.assertRaises(update_payload.PayloadError,
    362                       checker.PayloadChecker._CheckBlocksFitLength,
    363                       64, 5, 16, 'foo')
    364     self.assertRaises(update_payload.PayloadError,
    365                       checker.PayloadChecker._CheckBlocksFitLength,
    366                       60, 5, 16, 'foo')
    367     self.assertRaises(update_payload.PayloadError,
    368                       checker.PayloadChecker._CheckBlocksFitLength,
    369                       49, 5, 16, 'foo')
    370     self.assertRaises(update_payload.PayloadError,
    371                       checker.PayloadChecker._CheckBlocksFitLength,
    372                       48, 4, 16, 'foo')
    373 
    374   def testCheckBlocksFitLength_TooFewBlocks(self):
    375     """Tests _CheckBlocksFitLength(); fails due to insufficient blocks."""
    376     self.assertRaises(update_payload.PayloadError,
    377                       checker.PayloadChecker._CheckBlocksFitLength,
    378                       64, 3, 16, 'foo')
    379     self.assertRaises(update_payload.PayloadError,
    380                       checker.PayloadChecker._CheckBlocksFitLength,
    381                       60, 3, 16, 'foo')
    382     self.assertRaises(update_payload.PayloadError,
    383                       checker.PayloadChecker._CheckBlocksFitLength,
    384                       49, 3, 16, 'foo')
    385     self.assertRaises(update_payload.PayloadError,
    386                       checker.PayloadChecker._CheckBlocksFitLength,
    387                       48, 2, 16, 'foo')
    388 
  def DoCheckManifestTest(self, fail_mismatched_block_size, fail_bad_sigs,
                          fail_mismatched_oki_ori, fail_bad_oki, fail_bad_ori,
                          fail_bad_nki, fail_bad_nri, fail_old_kernel_fs_size,
                          fail_old_rootfs_fs_size, fail_new_kernel_fs_size,
                          fail_new_rootfs_fs_size):
    """Parametric testing of _CheckManifest().

    Builds a payload whose manifest is tampered with according to the flags,
    then expects _CheckManifest() to raise PayloadError if any flag is set and
    to return None otherwise.

    Args:
      fail_mismatched_block_size: Use a 1 KiB block size instead of 4 KiB.
      fail_bad_sigs: Make signatures descriptor inconsistent.
      fail_mismatched_oki_ori: Make old rootfs/kernel info partially present.
      fail_bad_oki: Tamper with old kernel info.
      fail_bad_ori: Tamper with old rootfs info.
      fail_bad_nki: Tamper with new kernel info.
      fail_bad_nri: Tamper with new rootfs info.
      fail_old_kernel_fs_size: Make old kernel fs size too big.
      fail_old_rootfs_fs_size: Make old rootfs fs size too big.
      fail_new_kernel_fs_size: Make new kernel fs size too big.
      fail_new_rootfs_fs_size: Make new rootfs fs size too big.
    """
    # Generate a test payload. For this test, we only care about the manifest
    # and don't need any data blobs, hence we can use a plain payload generator
    # (which also gives us more control on things that can be screwed up).
    payload_gen = test_utils.PayloadGenerator()

    # Tamper with block size, if required.
    if fail_mismatched_block_size:
      payload_gen.SetBlockSize(test_utils.KiB(1))
    else:
      payload_gen.SetBlockSize(test_utils.KiB(4))

    # Add some operations.
    payload_gen.AddOperation(False, common.OpType.MOVE,
                             src_extents=[(0, 16), (16, 497)],
                             dst_extents=[(16, 496), (0, 16)])
    payload_gen.AddOperation(True, common.OpType.MOVE,
                             src_extents=[(0, 8), (8, 8)],
                             dst_extents=[(8, 8), (0, 8)])

    # Set an invalid signatures block (offset but no size), if required.
    if fail_bad_sigs:
      payload_gen.SetSignatures(32, None)

    # Set partition / filesystem sizes. Each fs size starts out equal to its
    # partition size; a failing case bumps it 100 bytes past the partition.
    rootfs_part_size = test_utils.MiB(8)
    kernel_part_size = test_utils.KiB(512)
    old_rootfs_fs_size = new_rootfs_fs_size = rootfs_part_size
    old_kernel_fs_size = new_kernel_fs_size = kernel_part_size
    if fail_old_kernel_fs_size:
      old_kernel_fs_size += 100
    if fail_old_rootfs_fs_size:
      old_rootfs_fs_size += 100
    if fail_new_kernel_fs_size:
      new_kernel_fs_size += 100
    if fail_new_rootfs_fs_size:
      new_rootfs_fs_size += 100

    # Add old kernel/rootfs partition info, as required. A hash of None is
    # how the "bad info" cases tamper with the entry.
    if fail_mismatched_oki_ori or fail_old_kernel_fs_size or fail_bad_oki:
      oki_hash = (None if fail_bad_oki
                  else hashlib.sha256('fake-oki-content').digest())
      payload_gen.SetPartInfo(True, False, old_kernel_fs_size, oki_hash)
    # In the mismatched case the old rootfs info is never added, so only the
    # old kernel info (added above) ends up in the manifest.
    if not fail_mismatched_oki_ori and (fail_old_rootfs_fs_size or
                                        fail_bad_ori):
      ori_hash = (None if fail_bad_ori
                  else hashlib.sha256('fake-ori-content').digest())
      payload_gen.SetPartInfo(False, False, old_rootfs_fs_size, ori_hash)

    # Add new kernel/rootfs partition info.
    payload_gen.SetPartInfo(
        True, True, new_kernel_fs_size,
        None if fail_bad_nki else hashlib.sha256('fake-nki-content').digest())
    payload_gen.SetPartInfo(
        False, True, new_rootfs_fs_size,
        None if fail_bad_nri else hashlib.sha256('fake-nri-content').digest())

    # Set the minor version.
    payload_gen.SetMinorVersion(0)

    # Create the test object.
    payload_checker = _GetPayloadChecker(payload_gen.WriteToFile)
    report = checker._PayloadReport()

    # Any single tampering flag is expected to make the manifest check fail.
    should_fail = (fail_mismatched_block_size or fail_bad_sigs or
                   fail_mismatched_oki_ori or fail_bad_oki or fail_bad_ori or
                   fail_bad_nki or fail_bad_nri or fail_old_kernel_fs_size or
                   fail_old_rootfs_fs_size or fail_new_kernel_fs_size or
                   fail_new_rootfs_fs_size)
    if should_fail:
      self.assertRaises(update_payload.PayloadError,
                        payload_checker._CheckManifest, report,
                        rootfs_part_size, kernel_part_size)
    else:
      self.assertIsNone(payload_checker._CheckManifest(report,
                                                       rootfs_part_size,
                                                       kernel_part_size))
    485 
    486   def testCheckLength(self):
    487     """Tests _CheckLength()."""
    488     payload_checker = checker.PayloadChecker(self.MockPayload())
    489     block_size = payload_checker.block_size
    490 
    491     # Passes.
    492     self.assertIsNone(payload_checker._CheckLength(
    493         int(3.5 * block_size), 4, 'foo', 'bar'))
    494     # Fails, too few blocks.
    495     self.assertRaises(update_payload.PayloadError,
    496                       payload_checker._CheckLength,
    497                       int(3.5 * block_size), 3, 'foo', 'bar')
    498     # Fails, too many blocks.
    499     self.assertRaises(update_payload.PayloadError,
    500                       payload_checker._CheckLength,
    501                       int(3.5 * block_size), 5, 'foo', 'bar')
    502 
    503   def testCheckExtents(self):
    504     """Tests _CheckExtents()."""
    505     payload_checker = checker.PayloadChecker(self.MockPayload())
    506     block_size = payload_checker.block_size
    507 
    508     # Passes w/ all real extents.
    509     extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
    510     self.assertEquals(
    511         23,
    512         payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
    513                                       collections.defaultdict(int), 'foo'))
    514 
    515     # Passes w/ pseudo-extents (aka sparse holes).
    516     extents = self.NewExtentList((0, 4), (common.PSEUDO_EXTENT_MARKER, 5),
    517                                  (8, 3))
    518     self.assertEquals(
    519         12,
    520         payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
    521                                       collections.defaultdict(int), 'foo',
    522                                       allow_pseudo=True))
    523 
    524     # Passes w/ pseudo-extent due to a signature.
    525     extents = self.NewExtentList((common.PSEUDO_EXTENT_MARKER, 2))
    526     self.assertEquals(
    527         2,
    528         payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
    529                                       collections.defaultdict(int), 'foo',
    530                                       allow_signature=True))
    531 
    532     # Fails, extent missing a start block.
    533     extents = self.NewExtentList((-1, 4), (8, 3), (1024, 16))
    534     self.assertRaises(
    535         update_payload.PayloadError, payload_checker._CheckExtents,
    536         extents, (1024 + 16) * block_size, collections.defaultdict(int),
    537         'foo')
    538 
    539     # Fails, extent missing block count.
    540     extents = self.NewExtentList((0, -1), (8, 3), (1024, 16))
    541     self.assertRaises(
    542         update_payload.PayloadError, payload_checker._CheckExtents,
    543         extents, (1024 + 16) * block_size, collections.defaultdict(int),
    544         'foo')
    545 
    546     # Fails, extent has zero blocks.
    547     extents = self.NewExtentList((0, 4), (8, 3), (1024, 0))
    548     self.assertRaises(
    549         update_payload.PayloadError, payload_checker._CheckExtents,
    550         extents, (1024 + 16) * block_size, collections.defaultdict(int),
    551         'foo')
    552 
    553     # Fails, extent exceeds partition boundaries.
    554     extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
    555     self.assertRaises(
    556         update_payload.PayloadError, payload_checker._CheckExtents,
    557         extents, (1024 + 15) * block_size, collections.defaultdict(int),
    558         'foo')
    559 
    560   def testCheckReplaceOperation(self):
    561     """Tests _CheckReplaceOperation() where op.type == REPLACE."""
    562     payload_checker = checker.PayloadChecker(self.MockPayload())
    563     block_size = payload_checker.block_size
    564     data_length = 10000
    565 
    566     op = self.mox.CreateMock(
    567         update_metadata_pb2.InstallOperation)
    568     op.type = common.OpType.REPLACE
    569 
    570     # Pass.
    571     op.src_extents = []
    572     self.assertIsNone(
    573         payload_checker._CheckReplaceOperation(
    574             op, data_length, (data_length + block_size - 1) / block_size,
    575             'foo'))
    576 
    577     # Fail, src extents founds.
    578     op.src_extents = ['bar']
    579     self.assertRaises(
    580         update_payload.PayloadError,
    581         payload_checker._CheckReplaceOperation,
    582         op, data_length, (data_length + block_size - 1) / block_size, 'foo')
    583 
    584     # Fail, missing data.
    585     op.src_extents = []
    586     self.assertRaises(
    587         update_payload.PayloadError,
    588         payload_checker._CheckReplaceOperation,
    589         op, None, (data_length + block_size - 1) / block_size, 'foo')
    590 
    591     # Fail, length / block number mismatch.
    592     op.src_extents = ['bar']
    593     self.assertRaises(
    594         update_payload.PayloadError,
    595         payload_checker._CheckReplaceOperation,
    596         op, data_length, (data_length + block_size - 1) / block_size + 1, 'foo')
    597 
    598   def testCheckReplaceBzOperation(self):
    599     """Tests _CheckReplaceOperation() where op.type == REPLACE_BZ."""
    600     payload_checker = checker.PayloadChecker(self.MockPayload())
    601     block_size = payload_checker.block_size
    602     data_length = block_size * 3
    603 
    604     op = self.mox.CreateMock(
    605         update_metadata_pb2.InstallOperation)
    606     op.type = common.OpType.REPLACE_BZ
    607 
    608     # Pass.
    609     op.src_extents = []
    610     self.assertIsNone(
    611         payload_checker._CheckReplaceOperation(
    612             op, data_length, (data_length + block_size - 1) / block_size + 5,
    613             'foo'))
    614 
    615     # Fail, src extents founds.
    616     op.src_extents = ['bar']
    617     self.assertRaises(
    618         update_payload.PayloadError,
    619         payload_checker._CheckReplaceOperation,
    620         op, data_length, (data_length + block_size - 1) / block_size + 5, 'foo')
    621 
    622     # Fail, missing data.
    623     op.src_extents = []
    624     self.assertRaises(
    625         update_payload.PayloadError,
    626         payload_checker._CheckReplaceOperation,
    627         op, None, (data_length + block_size - 1) / block_size, 'foo')
    628 
    629     # Fail, too few blocks to justify BZ.
    630     op.src_extents = []
    631     self.assertRaises(
    632         update_payload.PayloadError,
    633         payload_checker._CheckReplaceOperation,
    634         op, data_length, (data_length + block_size - 1) / block_size, 'foo')
    635 
    636   def testCheckMoveOperation_Pass(self):
    637     """Tests _CheckMoveOperation(); pass case."""
    638     payload_checker = checker.PayloadChecker(self.MockPayload())
    639     op = update_metadata_pb2.InstallOperation()
    640     op.type = common.OpType.MOVE
    641 
    642     self.AddToMessage(op.src_extents,
    643                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    644     self.AddToMessage(op.dst_extents,
    645                       self.NewExtentList((16, 128), (512, 6)))
    646     self.assertIsNone(
    647         payload_checker._CheckMoveOperation(op, None, 134, 134, 'foo'))
    648 
    649   def testCheckMoveOperation_FailContainsData(self):
    650     """Tests _CheckMoveOperation(); fails, message contains data."""
    651     payload_checker = checker.PayloadChecker(self.MockPayload())
    652     op = update_metadata_pb2.InstallOperation()
    653     op.type = common.OpType.MOVE
    654 
    655     self.AddToMessage(op.src_extents,
    656                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    657     self.AddToMessage(op.dst_extents,
    658                       self.NewExtentList((16, 128), (512, 6)))
    659     self.assertRaises(
    660         update_payload.PayloadError,
    661         payload_checker._CheckMoveOperation,
    662         op, 1024, 134, 134, 'foo')
    663 
    664   def testCheckMoveOperation_FailInsufficientSrcBlocks(self):
    665     """Tests _CheckMoveOperation(); fails, not enough actual src blocks."""
    666     payload_checker = checker.PayloadChecker(self.MockPayload())
    667     op = update_metadata_pb2.InstallOperation()
    668     op.type = common.OpType.MOVE
    669 
    670     self.AddToMessage(op.src_extents,
    671                       self.NewExtentList((1, 4), (12, 2), (1024, 127)))
    672     self.AddToMessage(op.dst_extents,
    673                       self.NewExtentList((16, 128), (512, 6)))
    674     self.assertRaises(
    675         update_payload.PayloadError,
    676         payload_checker._CheckMoveOperation,
    677         op, None, 134, 134, 'foo')
    678 
    679   def testCheckMoveOperation_FailInsufficientDstBlocks(self):
    680     """Tests _CheckMoveOperation(); fails, not enough actual dst blocks."""
    681     payload_checker = checker.PayloadChecker(self.MockPayload())
    682     op = update_metadata_pb2.InstallOperation()
    683     op.type = common.OpType.MOVE
    684 
    685     self.AddToMessage(op.src_extents,
    686                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    687     self.AddToMessage(op.dst_extents,
    688                       self.NewExtentList((16, 128), (512, 5)))
    689     self.assertRaises(
    690         update_payload.PayloadError,
    691         payload_checker._CheckMoveOperation,
    692         op, None, 134, 134, 'foo')
    693 
    694   def testCheckMoveOperation_FailExcessSrcBlocks(self):
    695     """Tests _CheckMoveOperation(); fails, too many actual src blocks."""
    696     payload_checker = checker.PayloadChecker(self.MockPayload())
    697     op = update_metadata_pb2.InstallOperation()
    698     op.type = common.OpType.MOVE
    699 
    700     self.AddToMessage(op.src_extents,
    701                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    702     self.AddToMessage(op.dst_extents,
    703                       self.NewExtentList((16, 128), (512, 5)))
    704     self.assertRaises(
    705         update_payload.PayloadError,
    706         payload_checker._CheckMoveOperation,
    707         op, None, 134, 134, 'foo')
    708     self.AddToMessage(op.src_extents,
    709                       self.NewExtentList((1, 4), (12, 2), (1024, 129)))
    710     self.AddToMessage(op.dst_extents,
    711                       self.NewExtentList((16, 128), (512, 6)))
    712     self.assertRaises(
    713         update_payload.PayloadError,
    714         payload_checker._CheckMoveOperation,
    715         op, None, 134, 134, 'foo')
    716 
    717   def testCheckMoveOperation_FailExcessDstBlocks(self):
    718     """Tests _CheckMoveOperation(); fails, too many actual dst blocks."""
    719     payload_checker = checker.PayloadChecker(self.MockPayload())
    720     op = update_metadata_pb2.InstallOperation()
    721     op.type = common.OpType.MOVE
    722 
    723     self.AddToMessage(op.src_extents,
    724                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    725     self.AddToMessage(op.dst_extents,
    726                       self.NewExtentList((16, 128), (512, 7)))
    727     self.assertRaises(
    728         update_payload.PayloadError,
    729         payload_checker._CheckMoveOperation,
    730         op, None, 134, 134, 'foo')
    731 
    732   def testCheckMoveOperation_FailStagnantBlocks(self):
    733     """Tests _CheckMoveOperation(); fails, there are blocks that do not move."""
    734     payload_checker = checker.PayloadChecker(self.MockPayload())
    735     op = update_metadata_pb2.InstallOperation()
    736     op.type = common.OpType.MOVE
    737 
    738     self.AddToMessage(op.src_extents,
    739                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    740     self.AddToMessage(op.dst_extents,
    741                       self.NewExtentList((8, 128), (512, 6)))
    742     self.assertRaises(
    743         update_payload.PayloadError,
    744         payload_checker._CheckMoveOperation,
    745         op, None, 134, 134, 'foo')
    746 
    747   def testCheckMoveOperation_FailZeroStartBlock(self):
    748     """Tests _CheckMoveOperation(); fails, has extent with start block 0."""
    749     payload_checker = checker.PayloadChecker(self.MockPayload())
    750     op = update_metadata_pb2.InstallOperation()
    751     op.type = common.OpType.MOVE
    752 
    753     self.AddToMessage(op.src_extents,
    754                       self.NewExtentList((0, 4), (12, 2), (1024, 128)))
    755     self.AddToMessage(op.dst_extents,
    756                       self.NewExtentList((8, 128), (512, 6)))
    757     self.assertRaises(
    758         update_payload.PayloadError,
    759         payload_checker._CheckMoveOperation,
    760         op, None, 134, 134, 'foo')
    761 
    762     self.AddToMessage(op.src_extents,
    763                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    764     self.AddToMessage(op.dst_extents,
    765                       self.NewExtentList((0, 128), (512, 6)))
    766     self.assertRaises(
    767         update_payload.PayloadError,
    768         payload_checker._CheckMoveOperation,
    769         op, None, 134, 134, 'foo')
    770 
    771   def testCheckAnyDiff(self):
    772     """Tests _CheckAnyDiffOperation()."""
    773     payload_checker = checker.PayloadChecker(self.MockPayload())
    774 
    775     # Pass.
    776     self.assertIsNone(
    777         payload_checker._CheckAnyDiffOperation(10000, 3, 'foo'))
    778 
    779     # Fail, missing data blob.
    780     self.assertRaises(
    781         update_payload.PayloadError,
    782         payload_checker._CheckAnyDiffOperation,
    783         None, 3, 'foo')
    784 
    785     # Fail, too big of a diff blob (unjustified).
    786     self.assertRaises(
    787         update_payload.PayloadError,
    788         payload_checker._CheckAnyDiffOperation,
    789         10000, 2, 'foo')
    790 
    791   def testCheckSourceCopyOperation_Pass(self):
    792     """Tests _CheckSourceCopyOperation(); pass case."""
    793     payload_checker = checker.PayloadChecker(self.MockPayload())
    794     self.assertIsNone(
    795         payload_checker._CheckSourceCopyOperation(None, 134, 134, 'foo'))
    796 
    797   def testCheckSourceCopyOperation_FailContainsData(self):
    798     """Tests _CheckSourceCopyOperation(); message contains data."""
    799     payload_checker = checker.PayloadChecker(self.MockPayload())
    800     self.assertRaises(update_payload.PayloadError,
    801                       payload_checker._CheckSourceCopyOperation,
    802                       134, 0, 0, 'foo')
    803 
    804   def testCheckSourceCopyOperation_FailBlockCountsMismatch(self):
    805     """Tests _CheckSourceCopyOperation(); src and dst block totals not equal."""
    806     payload_checker = checker.PayloadChecker(self.MockPayload())
    807     self.assertRaises(update_payload.PayloadError,
    808                       payload_checker._CheckSourceCopyOperation,
    809                       None, 0, 1, 'foo')
    810 
    811   def DoCheckOperationTest(self, op_type_name, is_last, allow_signature,
    812                            allow_unhashed, fail_src_extents, fail_dst_extents,
    813                            fail_mismatched_data_offset_length,
    814                            fail_missing_dst_extents, fail_src_length,
    815                            fail_dst_length, fail_data_hash,
    816                            fail_prev_data_offset, fail_bad_minor_version):
    817     """Parametric testing of _CheckOperation().
    818 
    819     Args:
    820       op_type_name: 'REPLACE', 'REPLACE_BZ', 'MOVE', 'BSDIFF', 'SOURCE_COPY',
    821         or 'SOURCE_BSDIFF'.
    822       is_last: Whether we're testing the last operation in a sequence.
    823       allow_signature: Whether we're testing a signature-capable operation.
    824       allow_unhashed: Whether we're allowing to not hash the data.
    825       fail_src_extents: Tamper with src extents.
    826       fail_dst_extents: Tamper with dst extents.
    827       fail_mismatched_data_offset_length: Make data_{offset,length}
    828         inconsistent.
    829       fail_missing_dst_extents: Do not include dst extents.
    830       fail_src_length: Make src length inconsistent.
    831       fail_dst_length: Make dst length inconsistent.
    832       fail_data_hash: Tamper with the data blob hash.
    833       fail_prev_data_offset: Make data space uses incontiguous.
    834       fail_bad_minor_version: Make minor version incompatible with op.
    835     """
    836     op_type = _OpTypeByName(op_type_name)
    837 
    838     # Create the test object.
    839     payload = self.MockPayload()
    840     payload_checker = checker.PayloadChecker(payload,
    841                                              allow_unhashed=allow_unhashed)
    842     block_size = payload_checker.block_size
    843 
    844     # Create auxiliary arguments.
    845     old_part_size = test_utils.MiB(4)
    846     new_part_size = test_utils.MiB(8)
    847     old_block_counters = array.array(
    848         'B', [0] * ((old_part_size + block_size - 1) / block_size))
    849     new_block_counters = array.array(
    850         'B', [0] * ((new_part_size + block_size - 1) / block_size))
    851     prev_data_offset = 1876
    852     blob_hash_counts = collections.defaultdict(int)
    853 
    854     # Create the operation object for the test.
    855     op = update_metadata_pb2.InstallOperation()
    856     op.type = op_type
    857 
    858     total_src_blocks = 0
    859     if op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
    860                    common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF):
    861       if fail_src_extents:
    862         self.AddToMessage(op.src_extents,
    863                           self.NewExtentList((1, 0)))
    864       else:
    865         self.AddToMessage(op.src_extents,
    866                           self.NewExtentList((1, 16)))
    867         total_src_blocks = 16
    868 
    869     if op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ):
    870       payload_checker.minor_version = 0
    871     elif op_type in (common.OpType.MOVE, common.OpType.BSDIFF):
    872       payload_checker.minor_version = 2 if fail_bad_minor_version else 1
    873     elif op_type in (common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF):
    874       payload_checker.minor_version = 1 if fail_bad_minor_version else 2
    875 
    876     if op_type not in (common.OpType.MOVE, common.OpType.SOURCE_COPY):
    877       if not fail_mismatched_data_offset_length:
    878         op.data_length = 16 * block_size - 8
    879       if fail_prev_data_offset:
    880         op.data_offset = prev_data_offset + 16
    881       else:
    882         op.data_offset = prev_data_offset
    883 
    884       fake_data = 'fake-data'.ljust(op.data_length)
    885       if not (allow_unhashed or (is_last and allow_signature and
    886                                  op_type == common.OpType.REPLACE)):
    887         if not fail_data_hash:
    888           # Create a valid data blob hash.
    889           op.data_sha256_hash = hashlib.sha256(fake_data).digest()
    890           payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
    891               fake_data)
    892       elif fail_data_hash:
    893         # Create an invalid data blob hash.
    894         op.data_sha256_hash = hashlib.sha256(
    895             fake_data.replace(' ', '-')).digest()
    896         payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
    897             fake_data)
    898 
    899     total_dst_blocks = 0
    900     if not fail_missing_dst_extents:
    901       total_dst_blocks = 16
    902       if fail_dst_extents:
    903         self.AddToMessage(op.dst_extents,
    904                           self.NewExtentList((4, 16), (32, 0)))
    905       else:
    906         self.AddToMessage(op.dst_extents,
    907                           self.NewExtentList((4, 8), (64, 8)))
    908 
    909     if total_src_blocks:
    910       if fail_src_length:
    911         op.src_length = total_src_blocks * block_size + 8
    912       else:
    913         op.src_length = total_src_blocks * block_size
    914     elif fail_src_length:
    915       # Add an orphaned src_length.
    916       op.src_length = 16
    917 
    918     if total_dst_blocks:
    919       if fail_dst_length:
    920         op.dst_length = total_dst_blocks * block_size + 8
    921       else:
    922         op.dst_length = total_dst_blocks * block_size
    923 
    924     self.mox.ReplayAll()
    925     should_fail = (fail_src_extents or fail_dst_extents or
    926                    fail_mismatched_data_offset_length or
    927                    fail_missing_dst_extents or fail_src_length or
    928                    fail_dst_length or fail_data_hash or fail_prev_data_offset or
    929                    fail_bad_minor_version)
    930     args = (op, 'foo', is_last, old_block_counters, new_block_counters,
    931             old_part_size, new_part_size, prev_data_offset, allow_signature,
    932             blob_hash_counts)
    933     if should_fail:
    934       self.assertRaises(update_payload.PayloadError,
    935                         payload_checker._CheckOperation, *args)
    936     else:
    937       self.assertEqual(op.data_length if op.HasField('data_length') else 0,
    938                        payload_checker._CheckOperation(*args))
    939 
    940   def testAllocBlockCounters(self):
    941     """Tests _CheckMoveOperation()."""
    942     payload_checker = checker.PayloadChecker(self.MockPayload())
    943     block_size = payload_checker.block_size
    944 
    945     # Check allocation for block-aligned partition size, ensure it's integers.
    946     result = payload_checker._AllocBlockCounters(16 * block_size)
    947     self.assertEqual(16, len(result))
    948     self.assertEqual(int, type(result[0]))
    949 
    950     # Check allocation of unaligned partition sizes.
    951     result = payload_checker._AllocBlockCounters(16 * block_size - 1)
    952     self.assertEqual(16, len(result))
    953     result = payload_checker._AllocBlockCounters(16 * block_size + 1)
    954     self.assertEqual(17, len(result))
    955 
    956   def DoCheckOperationsTest(self, fail_nonexhaustive_full_update):
    957     # Generate a test payload. For this test, we only care about one
    958     # (arbitrary) set of operations, so we'll only be generating kernel and
    959     # test with them.
    960     payload_gen = test_utils.PayloadGenerator()
    961 
    962     block_size = test_utils.KiB(4)
    963     payload_gen.SetBlockSize(block_size)
    964 
    965     rootfs_part_size = test_utils.MiB(8)
    966 
    967     # Fake rootfs operations in a full update, tampered with as required.
    968     rootfs_op_type = common.OpType.REPLACE
    969     rootfs_data_length = rootfs_part_size
    970     if fail_nonexhaustive_full_update:
    971       rootfs_data_length -= block_size
    972 
    973     payload_gen.AddOperation(False, rootfs_op_type,
    974                              dst_extents=[(0, rootfs_data_length / block_size)],
    975                              data_offset=0,
    976                              data_length=rootfs_data_length)
    977 
    978     # Create the test object.
    979     payload_checker = _GetPayloadChecker(payload_gen.WriteToFile,
    980                                          checker_init_dargs={
    981                                              'allow_unhashed': True})
    982     payload_checker.payload_type = checker._TYPE_FULL
    983     report = checker._PayloadReport()
    984 
    985     args = (payload_checker.payload.manifest.install_operations, report,
    986             'foo', 0, rootfs_part_size, rootfs_part_size, 0, False)
    987     if fail_nonexhaustive_full_update:
    988       self.assertRaises(update_payload.PayloadError,
    989                         payload_checker._CheckOperations, *args)
    990     else:
    991       self.assertEqual(rootfs_data_length,
    992                        payload_checker._CheckOperations(*args))
    993 
    994   def DoCheckSignaturesTest(self, fail_empty_sigs_blob, fail_missing_pseudo_op,
    995                             fail_mismatched_pseudo_op, fail_sig_missing_fields,
    996                             fail_unknown_sig_version, fail_incorrect_sig):
    997     # Generate a test payload. For this test, we only care about the signature
    998     # block and how it relates to the payload hash. Therefore, we're generating
    999     # a random (otherwise useless) payload for this purpose.
   1000     payload_gen = test_utils.EnhancedPayloadGenerator()
   1001     block_size = test_utils.KiB(4)
   1002     payload_gen.SetBlockSize(block_size)
   1003     rootfs_part_size = test_utils.MiB(2)
   1004     kernel_part_size = test_utils.KiB(16)
   1005     payload_gen.SetPartInfo(False, True, rootfs_part_size,
   1006                             hashlib.sha256('fake-new-rootfs-content').digest())
   1007     payload_gen.SetPartInfo(True, True, kernel_part_size,
   1008                             hashlib.sha256('fake-new-kernel-content').digest())
   1009     payload_gen.SetMinorVersion(0)
   1010     payload_gen.AddOperationWithData(
   1011         False, common.OpType.REPLACE,
   1012         dst_extents=[(0, rootfs_part_size / block_size)],
   1013         data_blob=os.urandom(rootfs_part_size))
   1014 
   1015     do_forge_pseudo_op = (fail_missing_pseudo_op or fail_mismatched_pseudo_op)
   1016     do_forge_sigs_data = (do_forge_pseudo_op or fail_empty_sigs_blob or
   1017                           fail_sig_missing_fields or fail_unknown_sig_version
   1018                           or fail_incorrect_sig)
   1019 
   1020     sigs_data = None
   1021     if do_forge_sigs_data:
   1022       sigs_gen = test_utils.SignaturesGenerator()
   1023       if not fail_empty_sigs_blob:
   1024         if fail_sig_missing_fields:
   1025           sig_data = None
   1026         else:
   1027           sig_data = test_utils.SignSha256('fake-payload-content',
   1028                                            test_utils._PRIVKEY_FILE_NAME)
   1029         sigs_gen.AddSig(5 if fail_unknown_sig_version else 1, sig_data)
   1030 
   1031       sigs_data = sigs_gen.ToBinary()
   1032       payload_gen.SetSignatures(payload_gen.curr_offset, len(sigs_data))
   1033 
   1034     if do_forge_pseudo_op:
   1035       assert sigs_data is not None, 'should have forged signatures blob by now'
   1036       sigs_len = len(sigs_data)
   1037       payload_gen.AddOperation(
   1038           False, common.OpType.REPLACE,
   1039           data_offset=payload_gen.curr_offset / 2,
   1040           data_length=sigs_len / 2,
   1041           dst_extents=[(0, (sigs_len / 2 + block_size - 1) / block_size)])
   1042 
   1043     # Generate payload (complete w/ signature) and create the test object.
   1044     payload_checker = _GetPayloadChecker(
   1045         payload_gen.WriteToFileWithData,
   1046         payload_gen_dargs={
   1047             'sigs_data': sigs_data,
   1048             'privkey_file_name': test_utils._PRIVKEY_FILE_NAME,
   1049             'do_add_pseudo_operation': not do_forge_pseudo_op})
   1050     payload_checker.payload_type = checker._TYPE_FULL
   1051     report = checker._PayloadReport()
   1052 
   1053     # We have to check the manifest first in order to set signature attributes.
   1054     payload_checker._CheckManifest(report, rootfs_part_size, kernel_part_size)
   1055 
   1056     should_fail = (fail_empty_sigs_blob or fail_missing_pseudo_op or
   1057                    fail_mismatched_pseudo_op or fail_sig_missing_fields or
   1058                    fail_unknown_sig_version or fail_incorrect_sig)
   1059     args = (report, test_utils._PUBKEY_FILE_NAME)
   1060     if should_fail:
   1061       self.assertRaises(update_payload.PayloadError,
   1062                         payload_checker._CheckSignatures, *args)
   1063     else:
   1064       self.assertIsNone(payload_checker._CheckSignatures(*args))
   1065 
   1066   def DoCheckManifestMinorVersionTest(self, minor_version, payload_type):
   1067     """Parametric testing for CheckManifestMinorVersion().
   1068 
   1069     Args:
   1070       minor_version: The payload minor version to test with.
   1071       payload_type: The type of the payload we're testing, delta or full.
   1072     """
   1073     # Create the test object.
   1074     payload = self.MockPayload()
   1075     payload.manifest.minor_version = minor_version
   1076     payload_checker = checker.PayloadChecker(payload)
   1077     payload_checker.payload_type = payload_type
   1078     report = checker._PayloadReport()
   1079 
   1080     should_succeed = (
   1081         (minor_version == 0 and payload_type == checker._TYPE_FULL) or
   1082         (minor_version == 1 and payload_type == checker._TYPE_DELTA) or
   1083         (minor_version == 2 and payload_type == checker._TYPE_DELTA) or
   1084         (minor_version == 3 and payload_type == checker._TYPE_DELTA) or
   1085         (minor_version == 4 and payload_type == checker._TYPE_DELTA))
   1086     args = (report,)
   1087 
   1088     if should_succeed:
   1089       self.assertIsNone(payload_checker._CheckManifestMinorVersion(*args))
   1090     else:
   1091       self.assertRaises(update_payload.PayloadError,
   1092                         payload_checker._CheckManifestMinorVersion, *args)
   1093 
   1094   def DoRunTest(self, rootfs_part_size_provided, kernel_part_size_provided,
   1095                 fail_wrong_payload_type, fail_invalid_block_size,
   1096                 fail_mismatched_block_size, fail_excess_data,
   1097                 fail_rootfs_part_size_exceeded,
   1098                 fail_kernel_part_size_exceeded):
   1099     # Generate a test payload. For this test, we generate a full update that
   1100     # has sample kernel and rootfs operations. Since most testing is done with
   1101     # internal PayloadChecker methods that are tested elsewhere, here we only
   1102     # tamper with what's actually being manipulated and/or tested in the Run()
   1103     # method itself. Note that the checker doesn't verify partition hashes, so
   1104     # they're safe to fake.
   1105     payload_gen = test_utils.EnhancedPayloadGenerator()
   1106     block_size = test_utils.KiB(4)
   1107     payload_gen.SetBlockSize(block_size)
   1108     kernel_filesystem_size = test_utils.KiB(16)
   1109     rootfs_filesystem_size = test_utils.MiB(2)
   1110     payload_gen.SetPartInfo(False, True, rootfs_filesystem_size,
   1111                             hashlib.sha256('fake-new-rootfs-content').digest())
   1112     payload_gen.SetPartInfo(True, True, kernel_filesystem_size,
   1113                             hashlib.sha256('fake-new-kernel-content').digest())
   1114     payload_gen.SetMinorVersion(0)
   1115 
   1116     rootfs_part_size = 0
   1117     if rootfs_part_size_provided:
   1118       rootfs_part_size = rootfs_filesystem_size + block_size
   1119     rootfs_op_size = rootfs_part_size or rootfs_filesystem_size
   1120     if fail_rootfs_part_size_exceeded:
   1121       rootfs_op_size += block_size
   1122     payload_gen.AddOperationWithData(
   1123         False, common.OpType.REPLACE,
   1124         dst_extents=[(0, rootfs_op_size / block_size)],
   1125         data_blob=os.urandom(rootfs_op_size))
   1126 
   1127     kernel_part_size = 0
   1128     if kernel_part_size_provided:
   1129       kernel_part_size = kernel_filesystem_size + block_size
   1130     kernel_op_size = kernel_part_size or kernel_filesystem_size
   1131     if fail_kernel_part_size_exceeded:
   1132       kernel_op_size += block_size
   1133     payload_gen.AddOperationWithData(
   1134         True, common.OpType.REPLACE,
   1135         dst_extents=[(0, kernel_op_size / block_size)],
   1136         data_blob=os.urandom(kernel_op_size))
   1137 
   1138     # Generate payload (complete w/ signature) and create the test object.
   1139     if fail_invalid_block_size:
   1140       use_block_size = block_size + 5  # Not a power of two.
   1141     elif fail_mismatched_block_size:
   1142       use_block_size = block_size * 2  # Different that payload stated.
   1143     else:
   1144       use_block_size = block_size
   1145 
   1146     kwargs = {
   1147         'payload_gen_dargs': {
   1148             'privkey_file_name': test_utils._PRIVKEY_FILE_NAME,
   1149             'do_add_pseudo_operation': True,
   1150             'is_pseudo_in_kernel': True,
   1151             'padding': os.urandom(1024) if fail_excess_data else None},
   1152         'checker_init_dargs': {
   1153             'assert_type': 'delta' if fail_wrong_payload_type else 'full',
   1154             'block_size': use_block_size}}
   1155     if fail_invalid_block_size:
   1156       self.assertRaises(update_payload.PayloadError, _GetPayloadChecker,
   1157                         payload_gen.WriteToFileWithData, **kwargs)
   1158     else:
   1159       payload_checker = _GetPayloadChecker(payload_gen.WriteToFileWithData,
   1160                                            **kwargs)
   1161 
   1162       kwargs = {'pubkey_file_name': test_utils._PUBKEY_FILE_NAME,
   1163                 'rootfs_part_size': rootfs_part_size,
   1164                 'kernel_part_size': kernel_part_size}
   1165       should_fail = (fail_wrong_payload_type or fail_mismatched_block_size or
   1166                      fail_excess_data or
   1167                      fail_rootfs_part_size_exceeded or
   1168                      fail_kernel_part_size_exceeded)
   1169       if should_fail:
   1170         self.assertRaises(update_payload.PayloadError, payload_checker.Run,
   1171                           **kwargs)
   1172       else:
   1173         self.assertIsNone(payload_checker.Run(**kwargs))
   1174 
   1175 # This implements a generic API, hence the occasional unused args.
   1176 # pylint: disable=W0613
   1177 def ValidateCheckOperationTest(op_type_name, is_last, allow_signature,
   1178                                allow_unhashed, fail_src_extents,
   1179                                fail_dst_extents,
   1180                                fail_mismatched_data_offset_length,
   1181                                fail_missing_dst_extents, fail_src_length,
   1182                                fail_dst_length, fail_data_hash,
   1183                                fail_prev_data_offset, fail_bad_minor_version):
   1184   """Returns True iff the combination of arguments represents a valid test."""
   1185   op_type = _OpTypeByName(op_type_name)
   1186 
   1187   # REPLACE/REPLACE_BZ operations don't read data from src partition. They are
   1188   # compatible with all valid minor versions, so we don't need to check that.
   1189   if (op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ) and (
   1190       fail_src_extents or fail_src_length or fail_bad_minor_version)):
   1191     return False
   1192 
   1193   # MOVE and SOURCE_COPY operations don't carry data.
   1194   if (op_type in (common.OpType.MOVE, common.OpType.SOURCE_COPY) and (
   1195       fail_mismatched_data_offset_length or fail_data_hash or
   1196       fail_prev_data_offset)):
   1197     return False
   1198 
   1199   return True
   1200 
   1201 
   1202 def TestMethodBody(run_method_name, run_dargs):
   1203   """Returns a function that invokes a named method with named arguments."""
   1204   return lambda self: getattr(self, run_method_name)(**run_dargs)
   1205 
   1206 
   1207 def AddParametricTests(tested_method_name, arg_space, validate_func=None):
   1208   """Enumerates and adds specific parametric tests to PayloadCheckerTest.
   1209 
   1210   This function enumerates a space of test parameters (defined by arg_space),
   1211   then binds a new, unique method name in PayloadCheckerTest to a test function
   1212   that gets handed the said parameters. This is a preferable approach to doing
   1213   the enumeration and invocation during the tests because this way each test is
   1214   treated as a complete run by the unittest framework, and so benefits from the
   1215   usual setUp/tearDown mechanics.
   1216 
   1217   Args:
   1218     tested_method_name: Name of the tested PayloadChecker method.
   1219     arg_space: A dictionary containing variables (keys) and lists of values
   1220                (values) associated with them.
   1221     validate_func: A function used for validating test argument combinations.
   1222   """
   1223   for value_tuple in itertools.product(*arg_space.itervalues()):
   1224     run_dargs = dict(zip(arg_space.iterkeys(), value_tuple))
   1225     if validate_func and not validate_func(**run_dargs):
   1226       continue
   1227     run_method_name = 'Do%sTest' % tested_method_name
   1228     test_method_name = 'test%s' % tested_method_name
   1229     for arg_key, arg_val in run_dargs.iteritems():
   1230       if arg_val or type(arg_val) is int:
   1231         test_method_name += '__%s=%s' % (arg_key, arg_val)
   1232     setattr(PayloadCheckerTest, test_method_name,
   1233             TestMethodBody(run_method_name, run_dargs))
   1234 
   1235 
   1236 def AddAllParametricTests():
   1237   """Enumerates and adds all parametric tests to PayloadCheckerTest."""
   1238   # Add all _CheckElem() test cases.
   1239   AddParametricTests('AddElem',
   1240                      {'linebreak': (True, False),
   1241                       'indent': (0, 1, 2),
   1242                       'convert': (str, lambda s: s[::-1]),
   1243                       'is_present': (True, False),
   1244                       'is_mandatory': (True, False),
   1245                       'is_submsg': (True, False)})
   1246 
   1247   # Add all _Add{Mandatory,Optional}Field tests.
   1248   AddParametricTests('AddField',
   1249                      {'is_mandatory': (True, False),
   1250                       'linebreak': (True, False),
   1251                       'indent': (0, 1, 2),
   1252                       'convert': (str, lambda s: s[::-1]),
   1253                       'is_present': (True, False)})
   1254 
   1255   # Add all _Add{Mandatory,Optional}SubMsg tests.
   1256   AddParametricTests('AddSubMsg',
   1257                      {'is_mandatory': (True, False),
   1258                       'is_present': (True, False)})
   1259 
   1260   # Add all _CheckManifest() test cases.
   1261   AddParametricTests('CheckManifest',
   1262                      {'fail_mismatched_block_size': (True, False),
   1263                       'fail_bad_sigs': (True, False),
   1264                       'fail_mismatched_oki_ori': (True, False),
   1265                       'fail_bad_oki': (True, False),
   1266                       'fail_bad_ori': (True, False),
   1267                       'fail_bad_nki': (True, False),
   1268                       'fail_bad_nri': (True, False),
   1269                       'fail_old_kernel_fs_size': (True, False),
   1270                       'fail_old_rootfs_fs_size': (True, False),
   1271                       'fail_new_kernel_fs_size': (True, False),
   1272                       'fail_new_rootfs_fs_size': (True, False)})
   1273 
   1274   # Add all _CheckOperation() test cases.
   1275   AddParametricTests('CheckOperation',
   1276                      {'op_type_name': ('REPLACE', 'REPLACE_BZ', 'MOVE',
   1277                                        'BSDIFF', 'SOURCE_COPY',
   1278                                        'SOURCE_BSDIFF'),
   1279                       'is_last': (True, False),
   1280                       'allow_signature': (True, False),
   1281                       'allow_unhashed': (True, False),
   1282                       'fail_src_extents': (True, False),
   1283                       'fail_dst_extents': (True, False),
   1284                       'fail_mismatched_data_offset_length': (True, False),
   1285                       'fail_missing_dst_extents': (True, False),
   1286                       'fail_src_length': (True, False),
   1287                       'fail_dst_length': (True, False),
   1288                       'fail_data_hash': (True, False),
   1289                       'fail_prev_data_offset': (True, False),
   1290                       'fail_bad_minor_version': (True, False)},
   1291                      validate_func=ValidateCheckOperationTest)
   1292 
   1293   # Add all _CheckOperations() test cases.
   1294   AddParametricTests('CheckOperations',
   1295                      {'fail_nonexhaustive_full_update': (True, False)})
   1296 
   1297   # Add all _CheckSignatures() test cases.
   1298   AddParametricTests('CheckSignatures',
   1299                      {'fail_empty_sigs_blob': (True, False),
   1300                       'fail_missing_pseudo_op': (True, False),
   1301                       'fail_mismatched_pseudo_op': (True, False),
   1302                       'fail_sig_missing_fields': (True, False),
   1303                       'fail_unknown_sig_version': (True, False),
   1304                       'fail_incorrect_sig': (True, False)})
   1305 
   1306   # Add all _CheckManifestMinorVersion() test cases.
   1307   AddParametricTests('CheckManifestMinorVersion',
   1308                      {'minor_version': (None, 0, 1, 2, 3, 4, 555),
   1309                       'payload_type': (checker._TYPE_FULL,
   1310                                        checker._TYPE_DELTA)})
   1311 
   1312   # Add all Run() test cases.
   1313   AddParametricTests('Run',
   1314                      {'rootfs_part_size_provided': (True, False),
   1315                       'kernel_part_size_provided': (True, False),
   1316                       'fail_wrong_payload_type': (True, False),
   1317                       'fail_invalid_block_size': (True, False),
   1318                       'fail_mismatched_block_size': (True, False),
   1319                       'fail_excess_data': (True, False),
   1320                       'fail_rootfs_part_size_exceeded': (True, False),
   1321                       'fail_kernel_part_size_exceeded': (True, False)})
   1322 
   1323 
   1324 if __name__ == '__main__':
   1325   AddAllParametricTests()  # Add the parametric test cases before running.
   1326   unittest.main()  # Run the tests defined in this module.
   1327