Home | History | Annotate | Download | only in update_payload
      1 #!/usr/bin/python2
      2 #
      3 # Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
      4 # Use of this source code is governed by a BSD-style license that can be
      5 # found in the LICENSE file.
      6 
      7 """Unit testing checker.py."""
      8 
      9 from __future__ import print_function
     10 
     11 import array
     12 import collections
     13 import cStringIO
     14 import hashlib
     15 import itertools
     16 import os
     17 import unittest
     18 
     19 # pylint cannot find mox.
     20 # pylint: disable=F0401
     21 import mox
     22 
     23 from update_payload import checker
     24 from update_payload import common
     25 from update_payload import test_utils
     26 from update_payload import update_metadata_pb2
     27 from update_payload.error import PayloadError
     28 from update_payload.payload import Payload # Avoid name conflicts later.
     29 
     30 
     31 def _OpTypeByName(op_name):
     32   """Returns the type of an operation from itsname."""
     33   op_name_to_type = {
     34       'REPLACE': common.OpType.REPLACE,
     35       'REPLACE_BZ': common.OpType.REPLACE_BZ,
     36       'MOVE': common.OpType.MOVE,
     37       'BSDIFF': common.OpType.BSDIFF,
     38       'SOURCE_COPY': common.OpType.SOURCE_COPY,
     39       'SOURCE_BSDIFF': common.OpType.SOURCE_BSDIFF,
     40       'ZERO': common.OpType.ZERO,
     41       'DISCARD': common.OpType.DISCARD,
     42       'REPLACE_XZ': common.OpType.REPLACE_XZ,
     43       'PUFFDIFF': common.OpType.PUFFDIFF,
     44       'BROTLI_BSDIFF': common.OpType.BROTLI_BSDIFF,
     45   }
     46   return op_name_to_type[op_name]
     47 
     48 
     49 def _GetPayloadChecker(payload_gen_write_to_file_func, payload_gen_dargs=None,
     50                        checker_init_dargs=None):
     51   """Returns a payload checker from a given payload generator."""
     52   if payload_gen_dargs is None:
     53     payload_gen_dargs = {}
     54   if checker_init_dargs is None:
     55     checker_init_dargs = {}
     56 
     57   payload_file = cStringIO.StringIO()
     58   payload_gen_write_to_file_func(payload_file, **payload_gen_dargs)
     59   payload_file.seek(0)
     60   payload = Payload(payload_file)
     61   payload.Init()
     62   return checker.PayloadChecker(payload, **checker_init_dargs)
     63 
     64 
     65 def _GetPayloadCheckerWithData(payload_gen):
     66   """Returns a payload checker from a given payload generator."""
     67   payload_file = cStringIO.StringIO()
     68   payload_gen.WriteToFile(payload_file)
     69   payload_file.seek(0)
     70   payload = Payload(payload_file)
     71   payload.Init()
     72   return checker.PayloadChecker(payload)
     73 
     74 
     75 # This class doesn't need an __init__().
     76 # pylint: disable=W0232
     77 # Unit testing is all about running protected methods.
     78 # pylint: disable=W0212
     79 # Don't bark about missing members of classes you cannot import.
     80 # pylint: disable=E1101
     81 class PayloadCheckerTest(mox.MoxTestBase):
     82   """Tests the PayloadChecker class.
     83 
     84   In addition to ordinary testFoo() methods, which are automatically invoked by
     85   the unittest framework, in this class we make use of DoBarTest() calls that
     86   implement parametric tests of certain features. In order to invoke each test,
     87   which embodies a unique combination of parameter values, as a complete unit
     88   test, we perform explicit enumeration of the parameter space and create
     89   individual invocation contexts for each, which are then bound as
     90   testBar__param1=val1__param2=val2(). The enumeration of parameter spaces for
     91   all such tests is done in AddAllParametricTests().
     92   """
     93 
     94   def MockPayload(self):
     95     """Create a mock payload object, complete with a mock manifest."""
     96     payload = self.mox.CreateMock(Payload)
     97     payload.is_init = True
     98     payload.manifest = self.mox.CreateMock(
     99         update_metadata_pb2.DeltaArchiveManifest)
    100     return payload
    101 
    102   @staticmethod
    103   def NewExtent(start_block, num_blocks):
    104     """Returns an Extent message.
    105 
    106     Each of the provided fields is set iff it is >= 0; otherwise, it's left at
    107     its default state.
    108 
    109     Args:
    110       start_block: The starting block of the extent.
    111       num_blocks: The number of blocks in the extent.
    112 
    113     Returns:
    114       An Extent message.
    115     """
    116     ex = update_metadata_pb2.Extent()
    117     if start_block >= 0:
    118       ex.start_block = start_block
    119     if num_blocks >= 0:
    120       ex.num_blocks = num_blocks
    121     return ex
    122 
    123   @staticmethod
    124   def NewExtentList(*args):
    125     """Returns an list of extents.
    126 
    127     Args:
    128       *args: (start_block, num_blocks) pairs defining the extents.
    129 
    130     Returns:
    131       A list of Extent objects.
    132     """
    133     ex_list = []
    134     for start_block, num_blocks in args:
    135       ex_list.append(PayloadCheckerTest.NewExtent(start_block, num_blocks))
    136     return ex_list
    137 
    138   @staticmethod
    139   def AddToMessage(repeated_field, field_vals):
    140     for field_val in field_vals:
    141       new_field = repeated_field.add()
    142       new_field.CopyFrom(field_val)
    143 
    144   def SetupAddElemTest(self, is_present, is_submsg, convert=str,
    145                        linebreak=False, indent=0):
    146     """Setup for testing of _CheckElem() and its derivatives.
    147 
    148     Args:
    149       is_present: Whether or not the element is found in the message.
    150       is_submsg: Whether the element is a sub-message itself.
    151       convert: A representation conversion function.
    152       linebreak: Whether or not a linebreak is to be used in the report.
    153       indent: Indentation used for the report.
    154 
    155     Returns:
    156       msg: A mock message object.
    157       report: A mock report object.
    158       subreport: A mock sub-report object.
    159       name: An element name to check.
    160       val: Expected element value.
    161     """
    162     name = 'foo'
    163     val = 'fake submsg' if is_submsg else 'fake field'
    164     subreport = 'fake subreport'
    165 
    166     # Create a mock message.
    167     msg = self.mox.CreateMock(update_metadata_pb2._message.Message)
    168     msg.HasField(name).AndReturn(is_present)
    169     setattr(msg, name, val)
    170 
    171     # Create a mock report.
    172     report = self.mox.CreateMock(checker._PayloadReport)
    173     if is_present:
    174       if is_submsg:
    175         report.AddSubReport(name).AndReturn(subreport)
    176       else:
    177         report.AddField(name, convert(val), linebreak=linebreak, indent=indent)
    178 
    179     self.mox.ReplayAll()
    180     return (msg, report, subreport, name, val)
    181 
    182   def DoAddElemTest(self, is_present, is_mandatory, is_submsg, convert,
    183                     linebreak, indent):
    184     """Parametric testing of _CheckElem().
    185 
    186     Args:
    187       is_present: Whether or not the element is found in the message.
    188       is_mandatory: Whether or not it's a mandatory element.
    189       is_submsg: Whether the element is a sub-message itself.
    190       convert: A representation conversion function.
    191       linebreak: Whether or not a linebreak is to be used in the report.
    192       indent: Indentation used for the report.
    193     """
    194     msg, report, subreport, name, val = self.SetupAddElemTest(
    195         is_present, is_submsg, convert, linebreak, indent)
    196 
    197     args = (msg, name, report, is_mandatory, is_submsg)
    198     kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
    199     if is_mandatory and not is_present:
    200       self.assertRaises(PayloadError,
    201                         checker.PayloadChecker._CheckElem, *args, **kwargs)
    202     else:
    203       ret_val, ret_subreport = checker.PayloadChecker._CheckElem(*args,
    204                                                                  **kwargs)
    205       self.assertEquals(val if is_present else None, ret_val)
    206       self.assertEquals(subreport if is_present and is_submsg else None,
    207                         ret_subreport)
    208 
    209   def DoAddFieldTest(self, is_mandatory, is_present, convert, linebreak,
    210                      indent):
    211     """Parametric testing of _Check{Mandatory,Optional}Field().
    212 
    213     Args:
    214       is_mandatory: Whether we're testing a mandatory call.
    215       is_present: Whether or not the element is found in the message.
    216       convert: A representation conversion function.
    217       linebreak: Whether or not a linebreak is to be used in the report.
    218       indent: Indentation used for the report.
    219     """
    220     msg, report, _, name, val = self.SetupAddElemTest(
    221         is_present, False, convert, linebreak, indent)
    222 
    223     # Prepare for invocation of the tested method.
    224     args = [msg, name, report]
    225     kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
    226     if is_mandatory:
    227       args.append('bar')
    228       tested_func = checker.PayloadChecker._CheckMandatoryField
    229     else:
    230       tested_func = checker.PayloadChecker._CheckOptionalField
    231 
    232     # Test the method call.
    233     if is_mandatory and not is_present:
    234       self.assertRaises(PayloadError, tested_func, *args, **kwargs)
    235     else:
    236       ret_val = tested_func(*args, **kwargs)
    237       self.assertEquals(val if is_present else None, ret_val)
    238 
    239   def DoAddSubMsgTest(self, is_mandatory, is_present):
    240     """Parametrized testing of _Check{Mandatory,Optional}SubMsg().
    241 
    242     Args:
    243       is_mandatory: Whether we're testing a mandatory call.
    244       is_present: Whether or not the element is found in the message.
    245     """
    246     msg, report, subreport, name, val = self.SetupAddElemTest(is_present, True)
    247 
    248     # Prepare for invocation of the tested method.
    249     args = [msg, name, report]
    250     if is_mandatory:
    251       args.append('bar')
    252       tested_func = checker.PayloadChecker._CheckMandatorySubMsg
    253     else:
    254       tested_func = checker.PayloadChecker._CheckOptionalSubMsg
    255 
    256     # Test the method call.
    257     if is_mandatory and not is_present:
    258       self.assertRaises(PayloadError, tested_func, *args)
    259     else:
    260       ret_val, ret_subreport = tested_func(*args)
    261       self.assertEquals(val if is_present else None, ret_val)
    262       self.assertEquals(subreport if is_present else None, ret_subreport)
    263 
    264   def testCheckPresentIff(self):
    265     """Tests _CheckPresentIff()."""
    266     self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
    267         None, None, 'foo', 'bar', 'baz'))
    268     self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
    269         'a', 'b', 'foo', 'bar', 'baz'))
    270     self.assertRaises(PayloadError, checker.PayloadChecker._CheckPresentIff,
    271                       'a', None, 'foo', 'bar', 'baz')
    272     self.assertRaises(PayloadError, checker.PayloadChecker._CheckPresentIff,
    273                       None, 'b', 'foo', 'bar', 'baz')
    274 
    275   def DoCheckSha256SignatureTest(self, expect_pass, expect_subprocess_call,
    276                                  sig_data, sig_asn1_header,
    277                                  returned_signed_hash, expected_signed_hash):
    278     """Parametric testing of _CheckSha256SignatureTest().
    279 
    280     Args:
    281       expect_pass: Whether or not it should pass.
    282       expect_subprocess_call: Whether to expect the openssl call to happen.
    283       sig_data: The signature raw data.
    284       sig_asn1_header: The ASN1 header.
    285       returned_signed_hash: The signed hash data retuned by openssl.
    286       expected_signed_hash: The signed hash data to compare against.
    287     """
    288     try:
    289       # Stub out the subprocess invocation.
    290       self.mox.StubOutWithMock(checker.PayloadChecker, '_Run')
    291       if expect_subprocess_call:
    292         checker.PayloadChecker._Run(
    293             mox.IsA(list), send_data=sig_data).AndReturn(
    294                 (sig_asn1_header + returned_signed_hash, None))
    295 
    296       self.mox.ReplayAll()
    297       if expect_pass:
    298         self.assertIsNone(checker.PayloadChecker._CheckSha256Signature(
    299             sig_data, 'foo', expected_signed_hash, 'bar'))
    300       else:
    301         self.assertRaises(PayloadError,
    302                           checker.PayloadChecker._CheckSha256Signature,
    303                           sig_data, 'foo', expected_signed_hash, 'bar')
    304     finally:
    305       self.mox.UnsetStubs()
    306 
    307   def testCheckSha256Signature_Pass(self):
    308     """Tests _CheckSha256Signature(); pass case."""
    309     sig_data = 'fake-signature'.ljust(256)
    310     signed_hash = hashlib.sha256('fake-data').digest()
    311     self.DoCheckSha256SignatureTest(True, True, sig_data,
    312                                     common.SIG_ASN1_HEADER, signed_hash,
    313                                     signed_hash)
    314 
    315   def testCheckSha256Signature_FailBadSignature(self):
    316     """Tests _CheckSha256Signature(); fails due to malformed signature."""
    317     sig_data = 'fake-signature'  # Malformed (not 256 bytes in length).
    318     signed_hash = hashlib.sha256('fake-data').digest()
    319     self.DoCheckSha256SignatureTest(False, False, sig_data,
    320                                     common.SIG_ASN1_HEADER, signed_hash,
    321                                     signed_hash)
    322 
    323   def testCheckSha256Signature_FailBadOutputLength(self):
    324     """Tests _CheckSha256Signature(); fails due to unexpected output length."""
    325     sig_data = 'fake-signature'.ljust(256)
    326     signed_hash = 'fake-hash'  # Malformed (not 32 bytes in length).
    327     self.DoCheckSha256SignatureTest(False, True, sig_data,
    328                                     common.SIG_ASN1_HEADER, signed_hash,
    329                                     signed_hash)
    330 
    331   def testCheckSha256Signature_FailBadAsnHeader(self):
    332     """Tests _CheckSha256Signature(); fails due to bad ASN1 header."""
    333     sig_data = 'fake-signature'.ljust(256)
    334     signed_hash = hashlib.sha256('fake-data').digest()
    335     bad_asn1_header = 'bad-asn-header'.ljust(len(common.SIG_ASN1_HEADER))
    336     self.DoCheckSha256SignatureTest(False, True, sig_data, bad_asn1_header,
    337                                     signed_hash, signed_hash)
    338 
    339   def testCheckSha256Signature_FailBadHash(self):
    340     """Tests _CheckSha256Signature(); fails due to bad hash returned."""
    341     sig_data = 'fake-signature'.ljust(256)
    342     expected_signed_hash = hashlib.sha256('fake-data').digest()
    343     returned_signed_hash = hashlib.sha256('bad-fake-data').digest()
    344     self.DoCheckSha256SignatureTest(False, True, sig_data,
    345                                     common.SIG_ASN1_HEADER,
    346                                     expected_signed_hash, returned_signed_hash)
    347 
    348   def testCheckBlocksFitLength_Pass(self):
    349     """Tests _CheckBlocksFitLength(); pass case."""
    350     self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
    351         64, 4, 16, 'foo'))
    352     self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
    353         60, 4, 16, 'foo'))
    354     self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
    355         49, 4, 16, 'foo'))
    356     self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
    357         48, 3, 16, 'foo'))
    358 
    359   def testCheckBlocksFitLength_TooManyBlocks(self):
    360     """Tests _CheckBlocksFitLength(); fails due to excess blocks."""
    361     self.assertRaises(PayloadError,
    362                       checker.PayloadChecker._CheckBlocksFitLength,
    363                       64, 5, 16, 'foo')
    364     self.assertRaises(PayloadError,
    365                       checker.PayloadChecker._CheckBlocksFitLength,
    366                       60, 5, 16, 'foo')
    367     self.assertRaises(PayloadError,
    368                       checker.PayloadChecker._CheckBlocksFitLength,
    369                       49, 5, 16, 'foo')
    370     self.assertRaises(PayloadError,
    371                       checker.PayloadChecker._CheckBlocksFitLength,
    372                       48, 4, 16, 'foo')
    373 
    374   def testCheckBlocksFitLength_TooFewBlocks(self):
    375     """Tests _CheckBlocksFitLength(); fails due to insufficient blocks."""
    376     self.assertRaises(PayloadError,
    377                       checker.PayloadChecker._CheckBlocksFitLength,
    378                       64, 3, 16, 'foo')
    379     self.assertRaises(PayloadError,
    380                       checker.PayloadChecker._CheckBlocksFitLength,
    381                       60, 3, 16, 'foo')
    382     self.assertRaises(PayloadError,
    383                       checker.PayloadChecker._CheckBlocksFitLength,
    384                       49, 3, 16, 'foo')
    385     self.assertRaises(PayloadError,
    386                       checker.PayloadChecker._CheckBlocksFitLength,
    387                       48, 2, 16, 'foo')
    388 
    389   def DoCheckManifestTest(self, fail_mismatched_block_size, fail_bad_sigs,
    390                           fail_mismatched_oki_ori, fail_bad_oki, fail_bad_ori,
    391                           fail_bad_nki, fail_bad_nri, fail_old_kernel_fs_size,
    392                           fail_old_rootfs_fs_size, fail_new_kernel_fs_size,
    393                           fail_new_rootfs_fs_size):
    394     """Parametric testing of _CheckManifest().
    395 
    396     Args:
    397       fail_mismatched_block_size: Simulate a missing block_size field.
    398       fail_bad_sigs: Make signatures descriptor inconsistent.
    399       fail_mismatched_oki_ori: Make old rootfs/kernel info partially present.
    400       fail_bad_oki: Tamper with old kernel info.
    401       fail_bad_ori: Tamper with old rootfs info.
    402       fail_bad_nki: Tamper with new kernel info.
    403       fail_bad_nri: Tamper with new rootfs info.
    404       fail_old_kernel_fs_size: Make old kernel fs size too big.
    405       fail_old_rootfs_fs_size: Make old rootfs fs size too big.
    406       fail_new_kernel_fs_size: Make new kernel fs size too big.
    407       fail_new_rootfs_fs_size: Make new rootfs fs size too big.
    408     """
    409     # Generate a test payload. For this test, we only care about the manifest
    410     # and don't need any data blobs, hence we can use a plain paylaod generator
    411     # (which also gives us more control on things that can be screwed up).
    412     payload_gen = test_utils.PayloadGenerator()
    413 
    414     # Tamper with block size, if required.
    415     if fail_mismatched_block_size:
    416       payload_gen.SetBlockSize(test_utils.KiB(1))
    417     else:
    418       payload_gen.SetBlockSize(test_utils.KiB(4))
    419 
    420     # Add some operations.
    421     payload_gen.AddOperation(False, common.OpType.MOVE,
    422                              src_extents=[(0, 16), (16, 497)],
    423                              dst_extents=[(16, 496), (0, 16)])
    424     payload_gen.AddOperation(True, common.OpType.MOVE,
    425                              src_extents=[(0, 8), (8, 8)],
    426                              dst_extents=[(8, 8), (0, 8)])
    427 
    428     # Set an invalid signatures block (offset but no size), if required.
    429     if fail_bad_sigs:
    430       payload_gen.SetSignatures(32, None)
    431 
    432     # Set partition / filesystem sizes.
    433     rootfs_part_size = test_utils.MiB(8)
    434     kernel_part_size = test_utils.KiB(512)
    435     old_rootfs_fs_size = new_rootfs_fs_size = rootfs_part_size
    436     old_kernel_fs_size = new_kernel_fs_size = kernel_part_size
    437     if fail_old_kernel_fs_size:
    438       old_kernel_fs_size += 100
    439     if fail_old_rootfs_fs_size:
    440       old_rootfs_fs_size += 100
    441     if fail_new_kernel_fs_size:
    442       new_kernel_fs_size += 100
    443     if fail_new_rootfs_fs_size:
    444       new_rootfs_fs_size += 100
    445 
    446     # Add old kernel/rootfs partition info, as required.
    447     if fail_mismatched_oki_ori or fail_old_kernel_fs_size or fail_bad_oki:
    448       oki_hash = (None if fail_bad_oki
    449                   else hashlib.sha256('fake-oki-content').digest())
    450       payload_gen.SetPartInfo(True, False, old_kernel_fs_size, oki_hash)
    451     if not fail_mismatched_oki_ori and (fail_old_rootfs_fs_size or
    452                                         fail_bad_ori):
    453       ori_hash = (None if fail_bad_ori
    454                   else hashlib.sha256('fake-ori-content').digest())
    455       payload_gen.SetPartInfo(False, False, old_rootfs_fs_size, ori_hash)
    456 
    457     # Add new kernel/rootfs partition info.
    458     payload_gen.SetPartInfo(
    459         True, True, new_kernel_fs_size,
    460         None if fail_bad_nki else hashlib.sha256('fake-nki-content').digest())
    461     payload_gen.SetPartInfo(
    462         False, True, new_rootfs_fs_size,
    463         None if fail_bad_nri else hashlib.sha256('fake-nri-content').digest())
    464 
    465     # Set the minor version.
    466     payload_gen.SetMinorVersion(0)
    467 
    468     # Create the test object.
    469     payload_checker = _GetPayloadChecker(payload_gen.WriteToFile)
    470     report = checker._PayloadReport()
    471 
    472     should_fail = (fail_mismatched_block_size or fail_bad_sigs or
    473                    fail_mismatched_oki_ori or fail_bad_oki or fail_bad_ori or
    474                    fail_bad_nki or fail_bad_nri or fail_old_kernel_fs_size or
    475                    fail_old_rootfs_fs_size or fail_new_kernel_fs_size or
    476                    fail_new_rootfs_fs_size)
    477     if should_fail:
    478       self.assertRaises(PayloadError, payload_checker._CheckManifest, report,
    479                         rootfs_part_size, kernel_part_size)
    480     else:
    481       self.assertIsNone(payload_checker._CheckManifest(report,
    482                                                        rootfs_part_size,
    483                                                        kernel_part_size))
    484 
    485   def testCheckLength(self):
    486     """Tests _CheckLength()."""
    487     payload_checker = checker.PayloadChecker(self.MockPayload())
    488     block_size = payload_checker.block_size
    489 
    490     # Passes.
    491     self.assertIsNone(payload_checker._CheckLength(
    492         int(3.5 * block_size), 4, 'foo', 'bar'))
    493     # Fails, too few blocks.
    494     self.assertRaises(PayloadError, payload_checker._CheckLength,
    495                       int(3.5 * block_size), 3, 'foo', 'bar')
    496     # Fails, too many blocks.
    497     self.assertRaises(PayloadError, payload_checker._CheckLength,
    498                       int(3.5 * block_size), 5, 'foo', 'bar')
    499 
    500   def testCheckExtents(self):
    501     """Tests _CheckExtents()."""
    502     payload_checker = checker.PayloadChecker(self.MockPayload())
    503     block_size = payload_checker.block_size
    504 
    505     # Passes w/ all real extents.
    506     extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
    507     self.assertEquals(
    508         23,
    509         payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
    510                                       collections.defaultdict(int), 'foo'))
    511 
    512     # Passes w/ pseudo-extents (aka sparse holes).
    513     extents = self.NewExtentList((0, 4), (common.PSEUDO_EXTENT_MARKER, 5),
    514                                  (8, 3))
    515     self.assertEquals(
    516         12,
    517         payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
    518                                       collections.defaultdict(int), 'foo',
    519                                       allow_pseudo=True))
    520 
    521     # Passes w/ pseudo-extent due to a signature.
    522     extents = self.NewExtentList((common.PSEUDO_EXTENT_MARKER, 2))
    523     self.assertEquals(
    524         2,
    525         payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
    526                                       collections.defaultdict(int), 'foo',
    527                                       allow_signature=True))
    528 
    529     # Fails, extent missing a start block.
    530     extents = self.NewExtentList((-1, 4), (8, 3), (1024, 16))
    531     self.assertRaises(
    532         PayloadError, payload_checker._CheckExtents, extents,
    533         (1024 + 16) * block_size, collections.defaultdict(int), 'foo')
    534 
    535     # Fails, extent missing block count.
    536     extents = self.NewExtentList((0, -1), (8, 3), (1024, 16))
    537     self.assertRaises(
    538         PayloadError, payload_checker._CheckExtents, extents,
    539         (1024 + 16) * block_size, collections.defaultdict(int), 'foo')
    540 
    541     # Fails, extent has zero blocks.
    542     extents = self.NewExtentList((0, 4), (8, 3), (1024, 0))
    543     self.assertRaises(
    544         PayloadError, payload_checker._CheckExtents, extents,
    545         (1024 + 16) * block_size, collections.defaultdict(int), 'foo')
    546 
    547     # Fails, extent exceeds partition boundaries.
    548     extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
    549     self.assertRaises(
    550         PayloadError, payload_checker._CheckExtents, extents,
    551         (1024 + 15) * block_size, collections.defaultdict(int), 'foo')
    552 
    553   def testCheckReplaceOperation(self):
    554     """Tests _CheckReplaceOperation() where op.type == REPLACE."""
    555     payload_checker = checker.PayloadChecker(self.MockPayload())
    556     block_size = payload_checker.block_size
    557     data_length = 10000
    558 
    559     op = self.mox.CreateMock(
    560         update_metadata_pb2.InstallOperation)
    561     op.type = common.OpType.REPLACE
    562 
    563     # Pass.
    564     op.src_extents = []
    565     self.assertIsNone(
    566         payload_checker._CheckReplaceOperation(
    567             op, data_length, (data_length + block_size - 1) / block_size,
    568             'foo'))
    569 
    570     # Fail, src extents founds.
    571     op.src_extents = ['bar']
    572     self.assertRaises(
    573         PayloadError, payload_checker._CheckReplaceOperation,
    574         op, data_length, (data_length + block_size - 1) / block_size, 'foo')
    575 
    576     # Fail, missing data.
    577     op.src_extents = []
    578     self.assertRaises(
    579         PayloadError, payload_checker._CheckReplaceOperation,
    580         op, None, (data_length + block_size - 1) / block_size, 'foo')
    581 
    582     # Fail, length / block number mismatch.
    583     op.src_extents = ['bar']
    584     self.assertRaises(
    585         PayloadError, payload_checker._CheckReplaceOperation,
    586         op, data_length, (data_length + block_size - 1) / block_size + 1, 'foo')
    587 
    588   def testCheckReplaceBzOperation(self):
    589     """Tests _CheckReplaceOperation() where op.type == REPLACE_BZ."""
    590     payload_checker = checker.PayloadChecker(self.MockPayload())
    591     block_size = payload_checker.block_size
    592     data_length = block_size * 3
    593 
    594     op = self.mox.CreateMock(
    595         update_metadata_pb2.InstallOperation)
    596     op.type = common.OpType.REPLACE_BZ
    597 
    598     # Pass.
    599     op.src_extents = []
    600     self.assertIsNone(
    601         payload_checker._CheckReplaceOperation(
    602             op, data_length, (data_length + block_size - 1) / block_size + 5,
    603             'foo'))
    604 
    605     # Fail, src extents founds.
    606     op.src_extents = ['bar']
    607     self.assertRaises(
    608         PayloadError, payload_checker._CheckReplaceOperation,
    609         op, data_length, (data_length + block_size - 1) / block_size + 5, 'foo')
    610 
    611     # Fail, missing data.
    612     op.src_extents = []
    613     self.assertRaises(
    614         PayloadError, payload_checker._CheckReplaceOperation,
    615         op, None, (data_length + block_size - 1) / block_size, 'foo')
    616 
    617     # Fail, too few blocks to justify BZ.
    618     op.src_extents = []
    619     self.assertRaises(
    620         PayloadError, payload_checker._CheckReplaceOperation,
    621         op, data_length, (data_length + block_size - 1) / block_size, 'foo')
    622 
    623   def testCheckMoveOperation_Pass(self):
    624     """Tests _CheckMoveOperation(); pass case."""
    625     payload_checker = checker.PayloadChecker(self.MockPayload())
    626     op = update_metadata_pb2.InstallOperation()
    627     op.type = common.OpType.MOVE
    628 
    629     self.AddToMessage(op.src_extents,
    630                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    631     self.AddToMessage(op.dst_extents,
    632                       self.NewExtentList((16, 128), (512, 6)))
    633     self.assertIsNone(
    634         payload_checker._CheckMoveOperation(op, None, 134, 134, 'foo'))
    635 
    636   def testCheckMoveOperation_FailContainsData(self):
    637     """Tests _CheckMoveOperation(); fails, message contains data."""
    638     payload_checker = checker.PayloadChecker(self.MockPayload())
    639     op = update_metadata_pb2.InstallOperation()
    640     op.type = common.OpType.MOVE
    641 
    642     self.AddToMessage(op.src_extents,
    643                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    644     self.AddToMessage(op.dst_extents,
    645                       self.NewExtentList((16, 128), (512, 6)))
    646     self.assertRaises(
    647         PayloadError, payload_checker._CheckMoveOperation,
    648         op, 1024, 134, 134, 'foo')
    649 
    650   def testCheckMoveOperation_FailInsufficientSrcBlocks(self):
    651     """Tests _CheckMoveOperation(); fails, not enough actual src blocks."""
    652     payload_checker = checker.PayloadChecker(self.MockPayload())
    653     op = update_metadata_pb2.InstallOperation()
    654     op.type = common.OpType.MOVE
    655 
    656     self.AddToMessage(op.src_extents,
    657                       self.NewExtentList((1, 4), (12, 2), (1024, 127)))
    658     self.AddToMessage(op.dst_extents,
    659                       self.NewExtentList((16, 128), (512, 6)))
    660     self.assertRaises(
    661         PayloadError, payload_checker._CheckMoveOperation,
    662         op, None, 134, 134, 'foo')
    663 
    664   def testCheckMoveOperation_FailInsufficientDstBlocks(self):
    665     """Tests _CheckMoveOperation(); fails, not enough actual dst blocks."""
    666     payload_checker = checker.PayloadChecker(self.MockPayload())
    667     op = update_metadata_pb2.InstallOperation()
    668     op.type = common.OpType.MOVE
    669 
    670     self.AddToMessage(op.src_extents,
    671                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    672     self.AddToMessage(op.dst_extents,
    673                       self.NewExtentList((16, 128), (512, 5)))
    674     self.assertRaises(
    675         PayloadError, payload_checker._CheckMoveOperation,
    676         op, None, 134, 134, 'foo')
    677 
    678   def testCheckMoveOperation_FailExcessSrcBlocks(self):
    679     """Tests _CheckMoveOperation(); fails, too many actual src blocks."""
    680     payload_checker = checker.PayloadChecker(self.MockPayload())
    681     op = update_metadata_pb2.InstallOperation()
    682     op.type = common.OpType.MOVE
    683 
    684     self.AddToMessage(op.src_extents,
    685                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    686     self.AddToMessage(op.dst_extents,
    687                       self.NewExtentList((16, 128), (512, 5)))
    688     self.assertRaises(
    689         PayloadError, payload_checker._CheckMoveOperation,
    690         op, None, 134, 134, 'foo')
    691     self.AddToMessage(op.src_extents,
    692                       self.NewExtentList((1, 4), (12, 2), (1024, 129)))
    693     self.AddToMessage(op.dst_extents,
    694                       self.NewExtentList((16, 128), (512, 6)))
    695     self.assertRaises(
    696         PayloadError, payload_checker._CheckMoveOperation,
    697         op, None, 134, 134, 'foo')
    698 
    699   def testCheckMoveOperation_FailExcessDstBlocks(self):
    700     """Tests _CheckMoveOperation(); fails, too many actual dst blocks."""
    701     payload_checker = checker.PayloadChecker(self.MockPayload())
    702     op = update_metadata_pb2.InstallOperation()
    703     op.type = common.OpType.MOVE
    704 
    705     self.AddToMessage(op.src_extents,
    706                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    707     self.AddToMessage(op.dst_extents,
    708                       self.NewExtentList((16, 128), (512, 7)))
    709     self.assertRaises(
    710         PayloadError, payload_checker._CheckMoveOperation,
    711         op, None, 134, 134, 'foo')
    712 
    713   def testCheckMoveOperation_FailStagnantBlocks(self):
    714     """Tests _CheckMoveOperation(); fails, there are blocks that do not move."""
    715     payload_checker = checker.PayloadChecker(self.MockPayload())
    716     op = update_metadata_pb2.InstallOperation()
    717     op.type = common.OpType.MOVE
    718 
    719     self.AddToMessage(op.src_extents,
    720                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    721     self.AddToMessage(op.dst_extents,
    722                       self.NewExtentList((8, 128), (512, 6)))
    723     self.assertRaises(
    724         PayloadError, payload_checker._CheckMoveOperation,
    725         op, None, 134, 134, 'foo')
    726 
    727   def testCheckMoveOperation_FailZeroStartBlock(self):
    728     """Tests _CheckMoveOperation(); fails, has extent with start block 0."""
    729     payload_checker = checker.PayloadChecker(self.MockPayload())
    730     op = update_metadata_pb2.InstallOperation()
    731     op.type = common.OpType.MOVE
    732 
    733     self.AddToMessage(op.src_extents,
    734                       self.NewExtentList((0, 4), (12, 2), (1024, 128)))
    735     self.AddToMessage(op.dst_extents,
    736                       self.NewExtentList((8, 128), (512, 6)))
    737     self.assertRaises(
    738         PayloadError, payload_checker._CheckMoveOperation,
    739         op, None, 134, 134, 'foo')
    740 
    741     self.AddToMessage(op.src_extents,
    742                       self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    743     self.AddToMessage(op.dst_extents,
    744                       self.NewExtentList((0, 128), (512, 6)))
    745     self.assertRaises(
    746         PayloadError, payload_checker._CheckMoveOperation,
    747         op, None, 134, 134, 'foo')
    748 
    749   def testCheckAnyDiff(self):
    750     """Tests _CheckAnyDiffOperation()."""
    751     payload_checker = checker.PayloadChecker(self.MockPayload())
    752     op = update_metadata_pb2.InstallOperation()
    753 
    754     # Pass.
    755     self.assertIsNone(
    756         payload_checker._CheckAnyDiffOperation(op, 10000, 3, 'foo'))
    757 
    758     # Fail, missing data blob.
    759     self.assertRaises(
    760         PayloadError, payload_checker._CheckAnyDiffOperation,
    761         op, None, 3, 'foo')
    762 
    763     # Fail, too big of a diff blob (unjustified).
    764     self.assertRaises(
    765         PayloadError, payload_checker._CheckAnyDiffOperation,
    766         op, 10000, 2, 'foo')
    767 
    768   def testCheckSourceCopyOperation_Pass(self):
    769     """Tests _CheckSourceCopyOperation(); pass case."""
    770     payload_checker = checker.PayloadChecker(self.MockPayload())
    771     self.assertIsNone(
    772         payload_checker._CheckSourceCopyOperation(None, 134, 134, 'foo'))
    773 
    774   def testCheckSourceCopyOperation_FailContainsData(self):
    775     """Tests _CheckSourceCopyOperation(); message contains data."""
    776     payload_checker = checker.PayloadChecker(self.MockPayload())
    777     self.assertRaises(PayloadError, payload_checker._CheckSourceCopyOperation,
    778                       134, 0, 0, 'foo')
    779 
    780   def testCheckSourceCopyOperation_FailBlockCountsMismatch(self):
    781     """Tests _CheckSourceCopyOperation(); src and dst block totals not equal."""
    782     payload_checker = checker.PayloadChecker(self.MockPayload())
    783     self.assertRaises(PayloadError, payload_checker._CheckSourceCopyOperation,
    784                       None, 0, 1, 'foo')
    785 
  def DoCheckOperationTest(self, op_type_name, is_last, allow_signature,
                           allow_unhashed, fail_src_extents, fail_dst_extents,
                           fail_mismatched_data_offset_length,
                           fail_missing_dst_extents, fail_src_length,
                           fail_dst_length, fail_data_hash,
                           fail_prev_data_offset, fail_bad_minor_version):
    """Parametric testing of _CheckOperation().

    Builds one InstallOperation of the requested type and tampers with it
    according to the fail_* flags. If any fail_* flag is set, asserts that
    _CheckOperation() raises PayloadError. Otherwise asserts that it returns
    the operation's data_length, or 0 when data_length is unset.

    Args:
      op_type_name: 'REPLACE', 'REPLACE_BZ', 'MOVE', 'BSDIFF', 'SOURCE_COPY',
        'SOURCE_BSDIFF', 'BROTLI_BSDIFF' or 'PUFFDIFF'.
      is_last: Whether we're testing the last operation in a sequence.
      allow_signature: Whether we're testing a signature-capable operation.
      allow_unhashed: Whether we're allowing to not hash the data.
      fail_src_extents: Tamper with src extents.
      fail_dst_extents: Tamper with dst extents.
      fail_mismatched_data_offset_length: Make data_{offset,length}
        inconsistent.
      fail_missing_dst_extents: Do not include dst extents.
      fail_src_length: Make src length inconsistent.
      fail_dst_length: Make dst length inconsistent.
      fail_data_hash: Tamper with the data blob hash.
      fail_prev_data_offset: Make data space uses non-contiguous.
      fail_bad_minor_version: Make minor version incompatible with op.
    """
    op_type = _OpTypeByName(op_type_name)

    # Create the test object.
    payload = self.MockPayload()
    payload_checker = checker.PayloadChecker(payload,
                                             allow_unhashed=allow_unhashed)
    block_size = payload_checker.block_size

    # Create auxiliary arguments.
    # Partition sizes plus one unsigned-char ('B') counter per block, rounding
    # a partial trailing block up. This relies on Python 2 integer `/`.
    old_part_size = test_utils.MiB(4)
    new_part_size = test_utils.MiB(8)
    old_block_counters = array.array(
        'B', [0] * ((old_part_size + block_size - 1) / block_size))
    new_block_counters = array.array(
        'B', [0] * ((new_part_size + block_size - 1) / block_size))
    # Arbitrary offset at which the previous op's data ended. The op under
    # test starts its data exactly here unless fail_prev_data_offset is set.
    prev_data_offset = 1876
    blob_hash_counts = collections.defaultdict(int)

    # Create the operation object for the test.
    op = update_metadata_pb2.InstallOperation()
    op.type = op_type

    # Source-reading op types get a single 16-block src extent. When
    # fail_src_extents is set they get a zero-length extent instead.
    total_src_blocks = 0
    if op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
                   common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF,
                   common.OpType.PUFFDIFF, common.OpType.BROTLI_BSDIFF):
      if fail_src_extents:
        self.AddToMessage(op.src_extents,
                          self.NewExtentList((1, 0)))
      else:
        self.AddToMessage(op.src_extents,
                          self.NewExtentList((1, 16)))
        total_src_blocks = 16

    # Choose a minor version that the test treats as valid for this op type.
    # When fail_bad_minor_version is set, choose one it treats as invalid.
    if op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ):
      payload_checker.minor_version = 0
    elif op_type in (common.OpType.MOVE, common.OpType.BSDIFF):
      payload_checker.minor_version = 2 if fail_bad_minor_version else 1
    elif op_type in (common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF):
      payload_checker.minor_version = 1 if fail_bad_minor_version else 2
    elif op_type in (common.OpType.ZERO, common.OpType.DISCARD,
                     common.OpType.BROTLI_BSDIFF):
      payload_checker.minor_version = 3 if fail_bad_minor_version else 4
    elif op_type == common.OpType.PUFFDIFF:
      payload_checker.minor_version = 4 if fail_bad_minor_version else 5

    # Every op type except MOVE and SOURCE_COPY is given a data blob.
    if op_type not in (common.OpType.MOVE, common.OpType.SOURCE_COPY):
      # Leaving data_length unset while data_offset is set is what makes the
      # offset/length pair inconsistent.
      if not fail_mismatched_data_offset_length:
        op.data_length = 16 * block_size - 8
      if fail_prev_data_offset:
        op.data_offset = prev_data_offset + 16
      else:
        op.data_offset = prev_data_offset

      fake_data = 'fake-data'.ljust(op.data_length)
      # A hash is mandatory unless unhashed data is allowed, or this is the
      # last REPLACE of a signature-capable sequence. In the mandatory case,
      # fail_data_hash omits the hash entirely. Otherwise it supplies a wrong
      # one. Each ReadDataBlob(...).AndReturn(...) records a mox expectation
      # that the checker reads the blob.
      if not (allow_unhashed or (is_last and allow_signature and
                                 op_type == common.OpType.REPLACE)):
        if not fail_data_hash:
          # Create a valid data blob hash.
          op.data_sha256_hash = hashlib.sha256(fake_data).digest()
          payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
              fake_data)

      elif fail_data_hash:
        # Create an invalid data blob hash.
        op.data_sha256_hash = hashlib.sha256(
            fake_data.replace(' ', '-')).digest()
        payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
            fake_data)

    # Two dst extents totalling 16 blocks. With fail_dst_extents the second
    # extent is empty (zero blocks).
    total_dst_blocks = 0
    if not fail_missing_dst_extents:
      total_dst_blocks = 16
      if fail_dst_extents:
        self.AddToMessage(op.dst_extents,
                          self.NewExtentList((4, 16), (32, 0)))
      else:
        self.AddToMessage(op.dst_extents,
                          self.NewExtentList((4, 8), (64, 8)))

    # src_length/dst_length are only set for MOVE, BSDIFF and SOURCE_BSDIFF at
    # minor version <= 3. The fail_*_length flags make them 8 bytes off, or,
    # for src, add an orphaned src_length when there are no src blocks.
    if total_src_blocks:
      if fail_src_length:
        op.src_length = total_src_blocks * block_size + 8
      elif (op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
                        common.OpType.SOURCE_BSDIFF) and
            payload_checker.minor_version <= 3):
        op.src_length = total_src_blocks * block_size
    elif fail_src_length:
      # Add an orphaned src_length.
      op.src_length = 16

    if total_dst_blocks:
      if fail_dst_length:
        op.dst_length = total_dst_blocks * block_size + 8
      elif (op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
                        common.OpType.SOURCE_BSDIFF) and
            payload_checker.minor_version <= 3):
        op.dst_length = total_dst_blocks * block_size

    # Stop recording mox expectations before exercising the checker.
    self.mox.ReplayAll()
    should_fail = (fail_src_extents or fail_dst_extents or
                   fail_mismatched_data_offset_length or
                   fail_missing_dst_extents or fail_src_length or
                   fail_dst_length or fail_data_hash or fail_prev_data_offset or
                   fail_bad_minor_version)
    args = (op, 'foo', is_last, old_block_counters, new_block_counters,
            old_part_size, new_part_size, prev_data_offset, allow_signature,
            blob_hash_counts)
    if should_fail:
      self.assertRaises(PayloadError, payload_checker._CheckOperation, *args)
    else:
      self.assertEqual(op.data_length if op.HasField('data_length') else 0,
                       payload_checker._CheckOperation(*args))
    924 
    925   def testAllocBlockCounters(self):
    926     """Tests _CheckMoveOperation()."""
    927     payload_checker = checker.PayloadChecker(self.MockPayload())
    928     block_size = payload_checker.block_size
    929 
    930     # Check allocation for block-aligned partition size, ensure it's integers.
    931     result = payload_checker._AllocBlockCounters(16 * block_size)
    932     self.assertEqual(16, len(result))
    933     self.assertEqual(int, type(result[0]))
    934 
    935     # Check allocation of unaligned partition sizes.
    936     result = payload_checker._AllocBlockCounters(16 * block_size - 1)
    937     self.assertEqual(16, len(result))
    938     result = payload_checker._AllocBlockCounters(16 * block_size + 1)
    939     self.assertEqual(17, len(result))
    940 
    941   def DoCheckOperationsTest(self, fail_nonexhaustive_full_update):
    942     """Tests _CheckOperations()."""
    943     # Generate a test payload. For this test, we only care about one
    944     # (arbitrary) set of operations, so we'll only be generating kernel and
    945     # test with them.
    946     payload_gen = test_utils.PayloadGenerator()
    947 
    948     block_size = test_utils.KiB(4)
    949     payload_gen.SetBlockSize(block_size)
    950 
    951     rootfs_part_size = test_utils.MiB(8)
    952 
    953     # Fake rootfs operations in a full update, tampered with as required.
    954     rootfs_op_type = common.OpType.REPLACE
    955     rootfs_data_length = rootfs_part_size
    956     if fail_nonexhaustive_full_update:
    957       rootfs_data_length -= block_size
    958 
    959     payload_gen.AddOperation(False, rootfs_op_type,
    960                              dst_extents=[(0, rootfs_data_length / block_size)],
    961                              data_offset=0,
    962                              data_length=rootfs_data_length)
    963 
    964     # Create the test object.
    965     payload_checker = _GetPayloadChecker(payload_gen.WriteToFile,
    966                                          checker_init_dargs={
    967                                              'allow_unhashed': True})
    968     payload_checker.payload_type = checker._TYPE_FULL
    969     report = checker._PayloadReport()
    970 
    971     args = (payload_checker.payload.manifest.install_operations, report, 'foo',
    972             0, rootfs_part_size, rootfs_part_size, rootfs_part_size, 0, False)
    973     if fail_nonexhaustive_full_update:
    974       self.assertRaises(PayloadError, payload_checker._CheckOperations, *args)
    975     else:
    976       self.assertEqual(rootfs_data_length,
    977                        payload_checker._CheckOperations(*args))
    978 
    979   def DoCheckSignaturesTest(self, fail_empty_sigs_blob, fail_missing_pseudo_op,
    980                             fail_mismatched_pseudo_op, fail_sig_missing_fields,
    981                             fail_unknown_sig_version, fail_incorrect_sig):
    982     """Tests _CheckSignatures()."""
    983     # Generate a test payload. For this test, we only care about the signature
    984     # block and how it relates to the payload hash. Therefore, we're generating
    985     # a random (otherwise useless) payload for this purpose.
    986     payload_gen = test_utils.EnhancedPayloadGenerator()
    987     block_size = test_utils.KiB(4)
    988     payload_gen.SetBlockSize(block_size)
    989     rootfs_part_size = test_utils.MiB(2)
    990     kernel_part_size = test_utils.KiB(16)
    991     payload_gen.SetPartInfo(False, True, rootfs_part_size,
    992                             hashlib.sha256('fake-new-rootfs-content').digest())
    993     payload_gen.SetPartInfo(True, True, kernel_part_size,
    994                             hashlib.sha256('fake-new-kernel-content').digest())
    995     payload_gen.SetMinorVersion(0)
    996     payload_gen.AddOperationWithData(
    997         False, common.OpType.REPLACE,
    998         dst_extents=[(0, rootfs_part_size / block_size)],
    999         data_blob=os.urandom(rootfs_part_size))
   1000 
   1001     do_forge_pseudo_op = (fail_missing_pseudo_op or fail_mismatched_pseudo_op)
   1002     do_forge_sigs_data = (do_forge_pseudo_op or fail_empty_sigs_blob or
   1003                           fail_sig_missing_fields or fail_unknown_sig_version
   1004                           or fail_incorrect_sig)
   1005 
   1006     sigs_data = None
   1007     if do_forge_sigs_data:
   1008       sigs_gen = test_utils.SignaturesGenerator()
   1009       if not fail_empty_sigs_blob:
   1010         if fail_sig_missing_fields:
   1011           sig_data = None
   1012         else:
   1013           sig_data = test_utils.SignSha256('fake-payload-content',
   1014                                            test_utils._PRIVKEY_FILE_NAME)
   1015         sigs_gen.AddSig(5 if fail_unknown_sig_version else 1, sig_data)
   1016 
   1017       sigs_data = sigs_gen.ToBinary()
   1018       payload_gen.SetSignatures(payload_gen.curr_offset, len(sigs_data))
   1019 
   1020     if do_forge_pseudo_op:
   1021       assert sigs_data is not None, 'should have forged signatures blob by now'
   1022       sigs_len = len(sigs_data)
   1023       payload_gen.AddOperation(
   1024           False, common.OpType.REPLACE,
   1025           data_offset=payload_gen.curr_offset / 2,
   1026           data_length=sigs_len / 2,
   1027           dst_extents=[(0, (sigs_len / 2 + block_size - 1) / block_size)])
   1028 
   1029     # Generate payload (complete w/ signature) and create the test object.
   1030     payload_checker = _GetPayloadChecker(
   1031         payload_gen.WriteToFileWithData,
   1032         payload_gen_dargs={
   1033             'sigs_data': sigs_data,
   1034             'privkey_file_name': test_utils._PRIVKEY_FILE_NAME,
   1035             'do_add_pseudo_operation': not do_forge_pseudo_op})
   1036     payload_checker.payload_type = checker._TYPE_FULL
   1037     report = checker._PayloadReport()
   1038 
   1039     # We have to check the manifest first in order to set signature attributes.
   1040     payload_checker._CheckManifest(report, rootfs_part_size, kernel_part_size)
   1041 
   1042     should_fail = (fail_empty_sigs_blob or fail_missing_pseudo_op or
   1043                    fail_mismatched_pseudo_op or fail_sig_missing_fields or
   1044                    fail_unknown_sig_version or fail_incorrect_sig)
   1045     args = (report, test_utils._PUBKEY_FILE_NAME)
   1046     if should_fail:
   1047       self.assertRaises(PayloadError, payload_checker._CheckSignatures, *args)
   1048     else:
   1049       self.assertIsNone(payload_checker._CheckSignatures(*args))
   1050 
   1051   def DoCheckManifestMinorVersionTest(self, minor_version, payload_type):
   1052     """Parametric testing for CheckManifestMinorVersion().
   1053 
   1054     Args:
   1055       minor_version: The payload minor version to test with.
   1056       payload_type: The type of the payload we're testing, delta or full.
   1057     """
   1058     # Create the test object.
   1059     payload = self.MockPayload()
   1060     payload.manifest.minor_version = minor_version
   1061     payload_checker = checker.PayloadChecker(payload)
   1062     payload_checker.payload_type = payload_type
   1063     report = checker._PayloadReport()
   1064 
   1065     should_succeed = (
   1066         (minor_version == 0 and payload_type == checker._TYPE_FULL) or
   1067         (minor_version == 1 and payload_type == checker._TYPE_DELTA) or
   1068         (minor_version == 2 and payload_type == checker._TYPE_DELTA) or
   1069         (minor_version == 3 and payload_type == checker._TYPE_DELTA) or
   1070         (minor_version == 4 and payload_type == checker._TYPE_DELTA) or
   1071         (minor_version == 5 and payload_type == checker._TYPE_DELTA))
   1072     args = (report,)
   1073 
   1074     if should_succeed:
   1075       self.assertIsNone(payload_checker._CheckManifestMinorVersion(*args))
   1076     else:
   1077       self.assertRaises(PayloadError,
   1078                         payload_checker._CheckManifestMinorVersion, *args)
   1079 
   1080   def DoRunTest(self, rootfs_part_size_provided, kernel_part_size_provided,
   1081                 fail_wrong_payload_type, fail_invalid_block_size,
   1082                 fail_mismatched_block_size, fail_excess_data,
   1083                 fail_rootfs_part_size_exceeded,
   1084                 fail_kernel_part_size_exceeded):
   1085     """Tests Run()."""
   1086     # Generate a test payload. For this test, we generate a full update that
   1087     # has sample kernel and rootfs operations. Since most testing is done with
   1088     # internal PayloadChecker methods that are tested elsewhere, here we only
   1089     # tamper with what's actually being manipulated and/or tested in the Run()
   1090     # method itself. Note that the checker doesn't verify partition hashes, so
   1091     # they're safe to fake.
   1092     payload_gen = test_utils.EnhancedPayloadGenerator()
   1093     block_size = test_utils.KiB(4)
   1094     payload_gen.SetBlockSize(block_size)
   1095     kernel_filesystem_size = test_utils.KiB(16)
   1096     rootfs_filesystem_size = test_utils.MiB(2)
   1097     payload_gen.SetPartInfo(False, True, rootfs_filesystem_size,
   1098                             hashlib.sha256('fake-new-rootfs-content').digest())
   1099     payload_gen.SetPartInfo(True, True, kernel_filesystem_size,
   1100                             hashlib.sha256('fake-new-kernel-content').digest())
   1101     payload_gen.SetMinorVersion(0)
   1102 
   1103     rootfs_part_size = 0
   1104     if rootfs_part_size_provided:
   1105       rootfs_part_size = rootfs_filesystem_size + block_size
   1106     rootfs_op_size = rootfs_part_size or rootfs_filesystem_size
   1107     if fail_rootfs_part_size_exceeded:
   1108       rootfs_op_size += block_size
   1109     payload_gen.AddOperationWithData(
   1110         False, common.OpType.REPLACE,
   1111         dst_extents=[(0, rootfs_op_size / block_size)],
   1112         data_blob=os.urandom(rootfs_op_size))
   1113 
   1114     kernel_part_size = 0
   1115     if kernel_part_size_provided:
   1116       kernel_part_size = kernel_filesystem_size + block_size
   1117     kernel_op_size = kernel_part_size or kernel_filesystem_size
   1118     if fail_kernel_part_size_exceeded:
   1119       kernel_op_size += block_size
   1120     payload_gen.AddOperationWithData(
   1121         True, common.OpType.REPLACE,
   1122         dst_extents=[(0, kernel_op_size / block_size)],
   1123         data_blob=os.urandom(kernel_op_size))
   1124 
   1125     # Generate payload (complete w/ signature) and create the test object.
   1126     if fail_invalid_block_size:
   1127       use_block_size = block_size + 5  # Not a power of two.
   1128     elif fail_mismatched_block_size:
   1129       use_block_size = block_size * 2  # Different that payload stated.
   1130     else:
   1131       use_block_size = block_size
   1132 
   1133     kwargs = {
   1134         'payload_gen_dargs': {
   1135             'privkey_file_name': test_utils._PRIVKEY_FILE_NAME,
   1136             'do_add_pseudo_operation': True,
   1137             'is_pseudo_in_kernel': True,
   1138             'padding': os.urandom(1024) if fail_excess_data else None},
   1139         'checker_init_dargs': {
   1140             'assert_type': 'delta' if fail_wrong_payload_type else 'full',
   1141             'block_size': use_block_size}}
   1142     if fail_invalid_block_size:
   1143       self.assertRaises(PayloadError, _GetPayloadChecker,
   1144                         payload_gen.WriteToFileWithData, **kwargs)
   1145     else:
   1146       payload_checker = _GetPayloadChecker(payload_gen.WriteToFileWithData,
   1147                                            **kwargs)
   1148 
   1149       kwargs = {'pubkey_file_name': test_utils._PUBKEY_FILE_NAME,
   1150                 'rootfs_part_size': rootfs_part_size,
   1151                 'kernel_part_size': kernel_part_size}
   1152       should_fail = (fail_wrong_payload_type or fail_mismatched_block_size or
   1153                      fail_excess_data or
   1154                      fail_rootfs_part_size_exceeded or
   1155                      fail_kernel_part_size_exceeded)
   1156       if should_fail:
   1157         self.assertRaises(PayloadError, payload_checker.Run, **kwargs)
   1158       else:
   1159         self.assertIsNone(payload_checker.Run(**kwargs))
   1160 
   1161 # This implements a generic API, hence the occasional unused args.
   1162 # pylint: disable=W0613
   1163 def ValidateCheckOperationTest(op_type_name, is_last, allow_signature,
   1164                                allow_unhashed, fail_src_extents,
   1165                                fail_dst_extents,
   1166                                fail_mismatched_data_offset_length,
   1167                                fail_missing_dst_extents, fail_src_length,
   1168                                fail_dst_length, fail_data_hash,
   1169                                fail_prev_data_offset, fail_bad_minor_version):
   1170   """Returns True iff the combination of arguments represents a valid test."""
   1171   op_type = _OpTypeByName(op_type_name)
   1172 
   1173   # REPLACE/REPLACE_BZ operations don't read data from src partition. They are
   1174   # compatible with all valid minor versions, so we don't need to check that.
   1175   if (op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ) and (
   1176       fail_src_extents or fail_src_length or fail_bad_minor_version)):
   1177     return False
   1178 
   1179   # MOVE and SOURCE_COPY operations don't carry data.
   1180   if (op_type in (common.OpType.MOVE, common.OpType.SOURCE_COPY) and (
   1181       fail_mismatched_data_offset_length or fail_data_hash or
   1182       fail_prev_data_offset)):
   1183     return False
   1184 
   1185   return True
   1186 
   1187 
   1188 def TestMethodBody(run_method_name, run_dargs):
   1189   """Returns a function that invokes a named method with named arguments."""
   1190   return lambda self: getattr(self, run_method_name)(**run_dargs)
   1191 
   1192 
   1193 def AddParametricTests(tested_method_name, arg_space, validate_func=None):
   1194   """Enumerates and adds specific parametric tests to PayloadCheckerTest.
   1195 
   1196   This function enumerates a space of test parameters (defined by arg_space),
   1197   then binds a new, unique method name in PayloadCheckerTest to a test function
   1198   that gets handed the said parameters. This is a preferable approach to doing
   1199   the enumeration and invocation during the tests because this way each test is
   1200   treated as a complete run by the unittest framework, and so benefits from the
   1201   usual setUp/tearDown mechanics.
   1202 
   1203   Args:
   1204     tested_method_name: Name of the tested PayloadChecker method.
   1205     arg_space: A dictionary containing variables (keys) and lists of values
   1206                (values) associated with them.
   1207     validate_func: A function used for validating test argument combinations.
   1208   """
   1209   for value_tuple in itertools.product(*arg_space.itervalues()):
   1210     run_dargs = dict(zip(arg_space.iterkeys(), value_tuple))
   1211     if validate_func and not validate_func(**run_dargs):
   1212       continue
   1213     run_method_name = 'Do%sTest' % tested_method_name
   1214     test_method_name = 'test%s' % tested_method_name
   1215     for arg_key, arg_val in run_dargs.iteritems():
   1216       if arg_val or type(arg_val) is int:
   1217         test_method_name += '__%s=%s' % (arg_key, arg_val)
   1218     setattr(PayloadCheckerTest, test_method_name,
   1219             TestMethodBody(run_method_name, run_dargs))
   1220 
   1221 
   1222 def AddAllParametricTests():
   1223   """Enumerates and adds all parametric tests to PayloadCheckerTest."""
   1224   # Add all _CheckElem() test cases.
   1225   AddParametricTests('AddElem',
   1226                      {'linebreak': (True, False),
   1227                       'indent': (0, 1, 2),
   1228                       'convert': (str, lambda s: s[::-1]),
   1229                       'is_present': (True, False),
   1230                       'is_mandatory': (True, False),
   1231                       'is_submsg': (True, False)})
   1232 
   1233   # Add all _Add{Mandatory,Optional}Field tests.
   1234   AddParametricTests('AddField',
   1235                      {'is_mandatory': (True, False),
   1236                       'linebreak': (True, False),
   1237                       'indent': (0, 1, 2),
   1238                       'convert': (str, lambda s: s[::-1]),
   1239                       'is_present': (True, False)})
   1240 
   1241   # Add all _Add{Mandatory,Optional}SubMsg tests.
   1242   AddParametricTests('AddSubMsg',
   1243                      {'is_mandatory': (True, False),
   1244                       'is_present': (True, False)})
   1245 
   1246   # Add all _CheckManifest() test cases.
   1247   AddParametricTests('CheckManifest',
   1248                      {'fail_mismatched_block_size': (True, False),
   1249                       'fail_bad_sigs': (True, False),
   1250                       'fail_mismatched_oki_ori': (True, False),
   1251                       'fail_bad_oki': (True, False),
   1252                       'fail_bad_ori': (True, False),
   1253                       'fail_bad_nki': (True, False),
   1254                       'fail_bad_nri': (True, False),
   1255                       'fail_old_kernel_fs_size': (True, False),
   1256                       'fail_old_rootfs_fs_size': (True, False),
   1257                       'fail_new_kernel_fs_size': (True, False),
   1258                       'fail_new_rootfs_fs_size': (True, False)})
   1259 
   1260   # Add all _CheckOperation() test cases.
   1261   AddParametricTests('CheckOperation',
   1262                      {'op_type_name': ('REPLACE', 'REPLACE_BZ', 'MOVE',
   1263                                        'BSDIFF', 'SOURCE_COPY',
   1264                                        'SOURCE_BSDIFF', 'PUFFDIFF',
   1265                                        'BROTLI_BSDIFF'),
   1266                       'is_last': (True, False),
   1267                       'allow_signature': (True, False),
   1268                       'allow_unhashed': (True, False),
   1269                       'fail_src_extents': (True, False),
   1270                       'fail_dst_extents': (True, False),
   1271                       'fail_mismatched_data_offset_length': (True, False),
   1272                       'fail_missing_dst_extents': (True, False),
   1273                       'fail_src_length': (True, False),
   1274                       'fail_dst_length': (True, False),
   1275                       'fail_data_hash': (True, False),
   1276                       'fail_prev_data_offset': (True, False),
   1277                       'fail_bad_minor_version': (True, False)},
   1278                      validate_func=ValidateCheckOperationTest)
   1279 
   1280   # Add all _CheckOperations() test cases.
   1281   AddParametricTests('CheckOperations',
   1282                      {'fail_nonexhaustive_full_update': (True, False)})
   1283 
   1284   # Add all _CheckSignatures() test cases.
   1285   AddParametricTests('CheckSignatures',
   1286                      {'fail_empty_sigs_blob': (True, False),
   1287                       'fail_missing_pseudo_op': (True, False),
   1288                       'fail_mismatched_pseudo_op': (True, False),
   1289                       'fail_sig_missing_fields': (True, False),
   1290                       'fail_unknown_sig_version': (True, False),
   1291                       'fail_incorrect_sig': (True, False)})
   1292 
   1293   # Add all _CheckManifestMinorVersion() test cases.
   1294   AddParametricTests('CheckManifestMinorVersion',
   1295                      {'minor_version': (None, 0, 1, 2, 3, 4, 5, 555),
   1296                       'payload_type': (checker._TYPE_FULL,
   1297                                        checker._TYPE_DELTA)})
   1298 
   1299   # Add all Run() test cases.
   1300   AddParametricTests('Run',
   1301                      {'rootfs_part_size_provided': (True, False),
   1302                       'kernel_part_size_provided': (True, False),
   1303                       'fail_wrong_payload_type': (True, False),
   1304                       'fail_invalid_block_size': (True, False),
   1305                       'fail_mismatched_block_size': (True, False),
   1306                       'fail_excess_data': (True, False),
   1307                       'fail_rootfs_part_size_exceeded': (True, False),
   1308                       'fail_kernel_part_size_exceeded': (True, False)})
   1309 
   1310 
   1311 if __name__ == '__main__':
   1312   AddAllParametricTests()
   1313   unittest.main()
   1314