1 #!/usr/bin/python2 2 # 3 # Copyright (c) 2013 The Chromium OS Authors. All rights reserved. 4 # Use of this source code is governed by a BSD-style license that can be 5 # found in the LICENSE file. 6 7 """Unit testing checker.py.""" 8 9 from __future__ import print_function 10 11 import array 12 import collections 13 import cStringIO 14 import hashlib 15 import itertools 16 import os 17 import unittest 18 19 # pylint cannot find mox. 20 # pylint: disable=F0401 21 import mox 22 23 from update_payload import checker 24 from update_payload import common 25 from update_payload import test_utils 26 from update_payload import update_metadata_pb2 27 from update_payload.error import PayloadError 28 from update_payload.payload import Payload # Avoid name conflicts later. 29 30 31 def _OpTypeByName(op_name): 32 """Returns the type of an operation from itsname.""" 33 op_name_to_type = { 34 'REPLACE': common.OpType.REPLACE, 35 'REPLACE_BZ': common.OpType.REPLACE_BZ, 36 'MOVE': common.OpType.MOVE, 37 'BSDIFF': common.OpType.BSDIFF, 38 'SOURCE_COPY': common.OpType.SOURCE_COPY, 39 'SOURCE_BSDIFF': common.OpType.SOURCE_BSDIFF, 40 'ZERO': common.OpType.ZERO, 41 'DISCARD': common.OpType.DISCARD, 42 'REPLACE_XZ': common.OpType.REPLACE_XZ, 43 'PUFFDIFF': common.OpType.PUFFDIFF, 44 'BROTLI_BSDIFF': common.OpType.BROTLI_BSDIFF, 45 } 46 return op_name_to_type[op_name] 47 48 49 def _GetPayloadChecker(payload_gen_write_to_file_func, payload_gen_dargs=None, 50 checker_init_dargs=None): 51 """Returns a payload checker from a given payload generator.""" 52 if payload_gen_dargs is None: 53 payload_gen_dargs = {} 54 if checker_init_dargs is None: 55 checker_init_dargs = {} 56 57 payload_file = cStringIO.StringIO() 58 payload_gen_write_to_file_func(payload_file, **payload_gen_dargs) 59 payload_file.seek(0) 60 payload = Payload(payload_file) 61 payload.Init() 62 return checker.PayloadChecker(payload, **checker_init_dargs) 63 64 65 def _GetPayloadCheckerWithData(payload_gen): 66 """Returns a 
payload checker from a given payload generator.""" 67 payload_file = cStringIO.StringIO() 68 payload_gen.WriteToFile(payload_file) 69 payload_file.seek(0) 70 payload = Payload(payload_file) 71 payload.Init() 72 return checker.PayloadChecker(payload) 73 74 75 # This class doesn't need an __init__(). 76 # pylint: disable=W0232 77 # Unit testing is all about running protected methods. 78 # pylint: disable=W0212 79 # Don't bark about missing members of classes you cannot import. 80 # pylint: disable=E1101 81 class PayloadCheckerTest(mox.MoxTestBase): 82 """Tests the PayloadChecker class. 83 84 In addition to ordinary testFoo() methods, which are automatically invoked by 85 the unittest framework, in this class we make use of DoBarTest() calls that 86 implement parametric tests of certain features. In order to invoke each test, 87 which embodies a unique combination of parameter values, as a complete unit 88 test, we perform explicit enumeration of the parameter space and create 89 individual invocation contexts for each, which are then bound as 90 testBar__param1=val1__param2=val2(). The enumeration of parameter spaces for 91 all such tests is done in AddAllParametricTests(). 92 """ 93 94 def MockPayload(self): 95 """Create a mock payload object, complete with a mock manifest.""" 96 payload = self.mox.CreateMock(Payload) 97 payload.is_init = True 98 payload.manifest = self.mox.CreateMock( 99 update_metadata_pb2.DeltaArchiveManifest) 100 return payload 101 102 @staticmethod 103 def NewExtent(start_block, num_blocks): 104 """Returns an Extent message. 105 106 Each of the provided fields is set iff it is >= 0; otherwise, it's left at 107 its default state. 108 109 Args: 110 start_block: The starting block of the extent. 111 num_blocks: The number of blocks in the extent. 112 113 Returns: 114 An Extent message. 
115 """ 116 ex = update_metadata_pb2.Extent() 117 if start_block >= 0: 118 ex.start_block = start_block 119 if num_blocks >= 0: 120 ex.num_blocks = num_blocks 121 return ex 122 123 @staticmethod 124 def NewExtentList(*args): 125 """Returns an list of extents. 126 127 Args: 128 *args: (start_block, num_blocks) pairs defining the extents. 129 130 Returns: 131 A list of Extent objects. 132 """ 133 ex_list = [] 134 for start_block, num_blocks in args: 135 ex_list.append(PayloadCheckerTest.NewExtent(start_block, num_blocks)) 136 return ex_list 137 138 @staticmethod 139 def AddToMessage(repeated_field, field_vals): 140 for field_val in field_vals: 141 new_field = repeated_field.add() 142 new_field.CopyFrom(field_val) 143 144 def SetupAddElemTest(self, is_present, is_submsg, convert=str, 145 linebreak=False, indent=0): 146 """Setup for testing of _CheckElem() and its derivatives. 147 148 Args: 149 is_present: Whether or not the element is found in the message. 150 is_submsg: Whether the element is a sub-message itself. 151 convert: A representation conversion function. 152 linebreak: Whether or not a linebreak is to be used in the report. 153 indent: Indentation used for the report. 154 155 Returns: 156 msg: A mock message object. 157 report: A mock report object. 158 subreport: A mock sub-report object. 159 name: An element name to check. 160 val: Expected element value. 161 """ 162 name = 'foo' 163 val = 'fake submsg' if is_submsg else 'fake field' 164 subreport = 'fake subreport' 165 166 # Create a mock message. 167 msg = self.mox.CreateMock(update_metadata_pb2._message.Message) 168 msg.HasField(name).AndReturn(is_present) 169 setattr(msg, name, val) 170 171 # Create a mock report. 
172 report = self.mox.CreateMock(checker._PayloadReport) 173 if is_present: 174 if is_submsg: 175 report.AddSubReport(name).AndReturn(subreport) 176 else: 177 report.AddField(name, convert(val), linebreak=linebreak, indent=indent) 178 179 self.mox.ReplayAll() 180 return (msg, report, subreport, name, val) 181 182 def DoAddElemTest(self, is_present, is_mandatory, is_submsg, convert, 183 linebreak, indent): 184 """Parametric testing of _CheckElem(). 185 186 Args: 187 is_present: Whether or not the element is found in the message. 188 is_mandatory: Whether or not it's a mandatory element. 189 is_submsg: Whether the element is a sub-message itself. 190 convert: A representation conversion function. 191 linebreak: Whether or not a linebreak is to be used in the report. 192 indent: Indentation used for the report. 193 """ 194 msg, report, subreport, name, val = self.SetupAddElemTest( 195 is_present, is_submsg, convert, linebreak, indent) 196 197 args = (msg, name, report, is_mandatory, is_submsg) 198 kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent} 199 if is_mandatory and not is_present: 200 self.assertRaises(PayloadError, 201 checker.PayloadChecker._CheckElem, *args, **kwargs) 202 else: 203 ret_val, ret_subreport = checker.PayloadChecker._CheckElem(*args, 204 **kwargs) 205 self.assertEquals(val if is_present else None, ret_val) 206 self.assertEquals(subreport if is_present and is_submsg else None, 207 ret_subreport) 208 209 def DoAddFieldTest(self, is_mandatory, is_present, convert, linebreak, 210 indent): 211 """Parametric testing of _Check{Mandatory,Optional}Field(). 212 213 Args: 214 is_mandatory: Whether we're testing a mandatory call. 215 is_present: Whether or not the element is found in the message. 216 convert: A representation conversion function. 217 linebreak: Whether or not a linebreak is to be used in the report. 218 indent: Indentation used for the report. 
219 """ 220 msg, report, _, name, val = self.SetupAddElemTest( 221 is_present, False, convert, linebreak, indent) 222 223 # Prepare for invocation of the tested method. 224 args = [msg, name, report] 225 kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent} 226 if is_mandatory: 227 args.append('bar') 228 tested_func = checker.PayloadChecker._CheckMandatoryField 229 else: 230 tested_func = checker.PayloadChecker._CheckOptionalField 231 232 # Test the method call. 233 if is_mandatory and not is_present: 234 self.assertRaises(PayloadError, tested_func, *args, **kwargs) 235 else: 236 ret_val = tested_func(*args, **kwargs) 237 self.assertEquals(val if is_present else None, ret_val) 238 239 def DoAddSubMsgTest(self, is_mandatory, is_present): 240 """Parametrized testing of _Check{Mandatory,Optional}SubMsg(). 241 242 Args: 243 is_mandatory: Whether we're testing a mandatory call. 244 is_present: Whether or not the element is found in the message. 245 """ 246 msg, report, subreport, name, val = self.SetupAddElemTest(is_present, True) 247 248 # Prepare for invocation of the tested method. 249 args = [msg, name, report] 250 if is_mandatory: 251 args.append('bar') 252 tested_func = checker.PayloadChecker._CheckMandatorySubMsg 253 else: 254 tested_func = checker.PayloadChecker._CheckOptionalSubMsg 255 256 # Test the method call. 
257 if is_mandatory and not is_present: 258 self.assertRaises(PayloadError, tested_func, *args) 259 else: 260 ret_val, ret_subreport = tested_func(*args) 261 self.assertEquals(val if is_present else None, ret_val) 262 self.assertEquals(subreport if is_present else None, ret_subreport) 263 264 def testCheckPresentIff(self): 265 """Tests _CheckPresentIff().""" 266 self.assertIsNone(checker.PayloadChecker._CheckPresentIff( 267 None, None, 'foo', 'bar', 'baz')) 268 self.assertIsNone(checker.PayloadChecker._CheckPresentIff( 269 'a', 'b', 'foo', 'bar', 'baz')) 270 self.assertRaises(PayloadError, checker.PayloadChecker._CheckPresentIff, 271 'a', None, 'foo', 'bar', 'baz') 272 self.assertRaises(PayloadError, checker.PayloadChecker._CheckPresentIff, 273 None, 'b', 'foo', 'bar', 'baz') 274 275 def DoCheckSha256SignatureTest(self, expect_pass, expect_subprocess_call, 276 sig_data, sig_asn1_header, 277 returned_signed_hash, expected_signed_hash): 278 """Parametric testing of _CheckSha256SignatureTest(). 279 280 Args: 281 expect_pass: Whether or not it should pass. 282 expect_subprocess_call: Whether to expect the openssl call to happen. 283 sig_data: The signature raw data. 284 sig_asn1_header: The ASN1 header. 285 returned_signed_hash: The signed hash data retuned by openssl. 286 expected_signed_hash: The signed hash data to compare against. 287 """ 288 try: 289 # Stub out the subprocess invocation. 
290 self.mox.StubOutWithMock(checker.PayloadChecker, '_Run') 291 if expect_subprocess_call: 292 checker.PayloadChecker._Run( 293 mox.IsA(list), send_data=sig_data).AndReturn( 294 (sig_asn1_header + returned_signed_hash, None)) 295 296 self.mox.ReplayAll() 297 if expect_pass: 298 self.assertIsNone(checker.PayloadChecker._CheckSha256Signature( 299 sig_data, 'foo', expected_signed_hash, 'bar')) 300 else: 301 self.assertRaises(PayloadError, 302 checker.PayloadChecker._CheckSha256Signature, 303 sig_data, 'foo', expected_signed_hash, 'bar') 304 finally: 305 self.mox.UnsetStubs() 306 307 def testCheckSha256Signature_Pass(self): 308 """Tests _CheckSha256Signature(); pass case.""" 309 sig_data = 'fake-signature'.ljust(256) 310 signed_hash = hashlib.sha256('fake-data').digest() 311 self.DoCheckSha256SignatureTest(True, True, sig_data, 312 common.SIG_ASN1_HEADER, signed_hash, 313 signed_hash) 314 315 def testCheckSha256Signature_FailBadSignature(self): 316 """Tests _CheckSha256Signature(); fails due to malformed signature.""" 317 sig_data = 'fake-signature' # Malformed (not 256 bytes in length). 318 signed_hash = hashlib.sha256('fake-data').digest() 319 self.DoCheckSha256SignatureTest(False, False, sig_data, 320 common.SIG_ASN1_HEADER, signed_hash, 321 signed_hash) 322 323 def testCheckSha256Signature_FailBadOutputLength(self): 324 """Tests _CheckSha256Signature(); fails due to unexpected output length.""" 325 sig_data = 'fake-signature'.ljust(256) 326 signed_hash = 'fake-hash' # Malformed (not 32 bytes in length). 
327 self.DoCheckSha256SignatureTest(False, True, sig_data, 328 common.SIG_ASN1_HEADER, signed_hash, 329 signed_hash) 330 331 def testCheckSha256Signature_FailBadAsnHeader(self): 332 """Tests _CheckSha256Signature(); fails due to bad ASN1 header.""" 333 sig_data = 'fake-signature'.ljust(256) 334 signed_hash = hashlib.sha256('fake-data').digest() 335 bad_asn1_header = 'bad-asn-header'.ljust(len(common.SIG_ASN1_HEADER)) 336 self.DoCheckSha256SignatureTest(False, True, sig_data, bad_asn1_header, 337 signed_hash, signed_hash) 338 339 def testCheckSha256Signature_FailBadHash(self): 340 """Tests _CheckSha256Signature(); fails due to bad hash returned.""" 341 sig_data = 'fake-signature'.ljust(256) 342 expected_signed_hash = hashlib.sha256('fake-data').digest() 343 returned_signed_hash = hashlib.sha256('bad-fake-data').digest() 344 self.DoCheckSha256SignatureTest(False, True, sig_data, 345 common.SIG_ASN1_HEADER, 346 expected_signed_hash, returned_signed_hash) 347 348 def testCheckBlocksFitLength_Pass(self): 349 """Tests _CheckBlocksFitLength(); pass case.""" 350 self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength( 351 64, 4, 16, 'foo')) 352 self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength( 353 60, 4, 16, 'foo')) 354 self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength( 355 49, 4, 16, 'foo')) 356 self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength( 357 48, 3, 16, 'foo')) 358 359 def testCheckBlocksFitLength_TooManyBlocks(self): 360 """Tests _CheckBlocksFitLength(); fails due to excess blocks.""" 361 self.assertRaises(PayloadError, 362 checker.PayloadChecker._CheckBlocksFitLength, 363 64, 5, 16, 'foo') 364 self.assertRaises(PayloadError, 365 checker.PayloadChecker._CheckBlocksFitLength, 366 60, 5, 16, 'foo') 367 self.assertRaises(PayloadError, 368 checker.PayloadChecker._CheckBlocksFitLength, 369 49, 5, 16, 'foo') 370 self.assertRaises(PayloadError, 371 checker.PayloadChecker._CheckBlocksFitLength, 372 48, 4, 16, 'foo') 373 374 
def testCheckBlocksFitLength_TooFewBlocks(self): 375 """Tests _CheckBlocksFitLength(); fails due to insufficient blocks.""" 376 self.assertRaises(PayloadError, 377 checker.PayloadChecker._CheckBlocksFitLength, 378 64, 3, 16, 'foo') 379 self.assertRaises(PayloadError, 380 checker.PayloadChecker._CheckBlocksFitLength, 381 60, 3, 16, 'foo') 382 self.assertRaises(PayloadError, 383 checker.PayloadChecker._CheckBlocksFitLength, 384 49, 3, 16, 'foo') 385 self.assertRaises(PayloadError, 386 checker.PayloadChecker._CheckBlocksFitLength, 387 48, 2, 16, 'foo') 388 389 def DoCheckManifestTest(self, fail_mismatched_block_size, fail_bad_sigs, 390 fail_mismatched_oki_ori, fail_bad_oki, fail_bad_ori, 391 fail_bad_nki, fail_bad_nri, fail_old_kernel_fs_size, 392 fail_old_rootfs_fs_size, fail_new_kernel_fs_size, 393 fail_new_rootfs_fs_size): 394 """Parametric testing of _CheckManifest(). 395 396 Args: 397 fail_mismatched_block_size: Simulate a missing block_size field. 398 fail_bad_sigs: Make signatures descriptor inconsistent. 399 fail_mismatched_oki_ori: Make old rootfs/kernel info partially present. 400 fail_bad_oki: Tamper with old kernel info. 401 fail_bad_ori: Tamper with old rootfs info. 402 fail_bad_nki: Tamper with new kernel info. 403 fail_bad_nri: Tamper with new rootfs info. 404 fail_old_kernel_fs_size: Make old kernel fs size too big. 405 fail_old_rootfs_fs_size: Make old rootfs fs size too big. 406 fail_new_kernel_fs_size: Make new kernel fs size too big. 407 fail_new_rootfs_fs_size: Make new rootfs fs size too big. 408 """ 409 # Generate a test payload. For this test, we only care about the manifest 410 # and don't need any data blobs, hence we can use a plain paylaod generator 411 # (which also gives us more control on things that can be screwed up). 412 payload_gen = test_utils.PayloadGenerator() 413 414 # Tamper with block size, if required. 
415 if fail_mismatched_block_size: 416 payload_gen.SetBlockSize(test_utils.KiB(1)) 417 else: 418 payload_gen.SetBlockSize(test_utils.KiB(4)) 419 420 # Add some operations. 421 payload_gen.AddOperation(False, common.OpType.MOVE, 422 src_extents=[(0, 16), (16, 497)], 423 dst_extents=[(16, 496), (0, 16)]) 424 payload_gen.AddOperation(True, common.OpType.MOVE, 425 src_extents=[(0, 8), (8, 8)], 426 dst_extents=[(8, 8), (0, 8)]) 427 428 # Set an invalid signatures block (offset but no size), if required. 429 if fail_bad_sigs: 430 payload_gen.SetSignatures(32, None) 431 432 # Set partition / filesystem sizes. 433 rootfs_part_size = test_utils.MiB(8) 434 kernel_part_size = test_utils.KiB(512) 435 old_rootfs_fs_size = new_rootfs_fs_size = rootfs_part_size 436 old_kernel_fs_size = new_kernel_fs_size = kernel_part_size 437 if fail_old_kernel_fs_size: 438 old_kernel_fs_size += 100 439 if fail_old_rootfs_fs_size: 440 old_rootfs_fs_size += 100 441 if fail_new_kernel_fs_size: 442 new_kernel_fs_size += 100 443 if fail_new_rootfs_fs_size: 444 new_rootfs_fs_size += 100 445 446 # Add old kernel/rootfs partition info, as required. 447 if fail_mismatched_oki_ori or fail_old_kernel_fs_size or fail_bad_oki: 448 oki_hash = (None if fail_bad_oki 449 else hashlib.sha256('fake-oki-content').digest()) 450 payload_gen.SetPartInfo(True, False, old_kernel_fs_size, oki_hash) 451 if not fail_mismatched_oki_ori and (fail_old_rootfs_fs_size or 452 fail_bad_ori): 453 ori_hash = (None if fail_bad_ori 454 else hashlib.sha256('fake-ori-content').digest()) 455 payload_gen.SetPartInfo(False, False, old_rootfs_fs_size, ori_hash) 456 457 # Add new kernel/rootfs partition info. 458 payload_gen.SetPartInfo( 459 True, True, new_kernel_fs_size, 460 None if fail_bad_nki else hashlib.sha256('fake-nki-content').digest()) 461 payload_gen.SetPartInfo( 462 False, True, new_rootfs_fs_size, 463 None if fail_bad_nri else hashlib.sha256('fake-nri-content').digest()) 464 465 # Set the minor version. 
466 payload_gen.SetMinorVersion(0) 467 468 # Create the test object. 469 payload_checker = _GetPayloadChecker(payload_gen.WriteToFile) 470 report = checker._PayloadReport() 471 472 should_fail = (fail_mismatched_block_size or fail_bad_sigs or 473 fail_mismatched_oki_ori or fail_bad_oki or fail_bad_ori or 474 fail_bad_nki or fail_bad_nri or fail_old_kernel_fs_size or 475 fail_old_rootfs_fs_size or fail_new_kernel_fs_size or 476 fail_new_rootfs_fs_size) 477 if should_fail: 478 self.assertRaises(PayloadError, payload_checker._CheckManifest, report, 479 rootfs_part_size, kernel_part_size) 480 else: 481 self.assertIsNone(payload_checker._CheckManifest(report, 482 rootfs_part_size, 483 kernel_part_size)) 484 485 def testCheckLength(self): 486 """Tests _CheckLength().""" 487 payload_checker = checker.PayloadChecker(self.MockPayload()) 488 block_size = payload_checker.block_size 489 490 # Passes. 491 self.assertIsNone(payload_checker._CheckLength( 492 int(3.5 * block_size), 4, 'foo', 'bar')) 493 # Fails, too few blocks. 494 self.assertRaises(PayloadError, payload_checker._CheckLength, 495 int(3.5 * block_size), 3, 'foo', 'bar') 496 # Fails, too many blocks. 497 self.assertRaises(PayloadError, payload_checker._CheckLength, 498 int(3.5 * block_size), 5, 'foo', 'bar') 499 500 def testCheckExtents(self): 501 """Tests _CheckExtents().""" 502 payload_checker = checker.PayloadChecker(self.MockPayload()) 503 block_size = payload_checker.block_size 504 505 # Passes w/ all real extents. 506 extents = self.NewExtentList((0, 4), (8, 3), (1024, 16)) 507 self.assertEquals( 508 23, 509 payload_checker._CheckExtents(extents, (1024 + 16) * block_size, 510 collections.defaultdict(int), 'foo')) 511 512 # Passes w/ pseudo-extents (aka sparse holes). 
513 extents = self.NewExtentList((0, 4), (common.PSEUDO_EXTENT_MARKER, 5), 514 (8, 3)) 515 self.assertEquals( 516 12, 517 payload_checker._CheckExtents(extents, (1024 + 16) * block_size, 518 collections.defaultdict(int), 'foo', 519 allow_pseudo=True)) 520 521 # Passes w/ pseudo-extent due to a signature. 522 extents = self.NewExtentList((common.PSEUDO_EXTENT_MARKER, 2)) 523 self.assertEquals( 524 2, 525 payload_checker._CheckExtents(extents, (1024 + 16) * block_size, 526 collections.defaultdict(int), 'foo', 527 allow_signature=True)) 528 529 # Fails, extent missing a start block. 530 extents = self.NewExtentList((-1, 4), (8, 3), (1024, 16)) 531 self.assertRaises( 532 PayloadError, payload_checker._CheckExtents, extents, 533 (1024 + 16) * block_size, collections.defaultdict(int), 'foo') 534 535 # Fails, extent missing block count. 536 extents = self.NewExtentList((0, -1), (8, 3), (1024, 16)) 537 self.assertRaises( 538 PayloadError, payload_checker._CheckExtents, extents, 539 (1024 + 16) * block_size, collections.defaultdict(int), 'foo') 540 541 # Fails, extent has zero blocks. 542 extents = self.NewExtentList((0, 4), (8, 3), (1024, 0)) 543 self.assertRaises( 544 PayloadError, payload_checker._CheckExtents, extents, 545 (1024 + 16) * block_size, collections.defaultdict(int), 'foo') 546 547 # Fails, extent exceeds partition boundaries. 548 extents = self.NewExtentList((0, 4), (8, 3), (1024, 16)) 549 self.assertRaises( 550 PayloadError, payload_checker._CheckExtents, extents, 551 (1024 + 15) * block_size, collections.defaultdict(int), 'foo') 552 553 def testCheckReplaceOperation(self): 554 """Tests _CheckReplaceOperation() where op.type == REPLACE.""" 555 payload_checker = checker.PayloadChecker(self.MockPayload()) 556 block_size = payload_checker.block_size 557 data_length = 10000 558 559 op = self.mox.CreateMock( 560 update_metadata_pb2.InstallOperation) 561 op.type = common.OpType.REPLACE 562 563 # Pass. 
564 op.src_extents = [] 565 self.assertIsNone( 566 payload_checker._CheckReplaceOperation( 567 op, data_length, (data_length + block_size - 1) / block_size, 568 'foo')) 569 570 # Fail, src extents founds. 571 op.src_extents = ['bar'] 572 self.assertRaises( 573 PayloadError, payload_checker._CheckReplaceOperation, 574 op, data_length, (data_length + block_size - 1) / block_size, 'foo') 575 576 # Fail, missing data. 577 op.src_extents = [] 578 self.assertRaises( 579 PayloadError, payload_checker._CheckReplaceOperation, 580 op, None, (data_length + block_size - 1) / block_size, 'foo') 581 582 # Fail, length / block number mismatch. 583 op.src_extents = ['bar'] 584 self.assertRaises( 585 PayloadError, payload_checker._CheckReplaceOperation, 586 op, data_length, (data_length + block_size - 1) / block_size + 1, 'foo') 587 588 def testCheckReplaceBzOperation(self): 589 """Tests _CheckReplaceOperation() where op.type == REPLACE_BZ.""" 590 payload_checker = checker.PayloadChecker(self.MockPayload()) 591 block_size = payload_checker.block_size 592 data_length = block_size * 3 593 594 op = self.mox.CreateMock( 595 update_metadata_pb2.InstallOperation) 596 op.type = common.OpType.REPLACE_BZ 597 598 # Pass. 599 op.src_extents = [] 600 self.assertIsNone( 601 payload_checker._CheckReplaceOperation( 602 op, data_length, (data_length + block_size - 1) / block_size + 5, 603 'foo')) 604 605 # Fail, src extents founds. 606 op.src_extents = ['bar'] 607 self.assertRaises( 608 PayloadError, payload_checker._CheckReplaceOperation, 609 op, data_length, (data_length + block_size - 1) / block_size + 5, 'foo') 610 611 # Fail, missing data. 612 op.src_extents = [] 613 self.assertRaises( 614 PayloadError, payload_checker._CheckReplaceOperation, 615 op, None, (data_length + block_size - 1) / block_size, 'foo') 616 617 # Fail, too few blocks to justify BZ. 
618 op.src_extents = [] 619 self.assertRaises( 620 PayloadError, payload_checker._CheckReplaceOperation, 621 op, data_length, (data_length + block_size - 1) / block_size, 'foo') 622 623 def testCheckMoveOperation_Pass(self): 624 """Tests _CheckMoveOperation(); pass case.""" 625 payload_checker = checker.PayloadChecker(self.MockPayload()) 626 op = update_metadata_pb2.InstallOperation() 627 op.type = common.OpType.MOVE 628 629 self.AddToMessage(op.src_extents, 630 self.NewExtentList((1, 4), (12, 2), (1024, 128))) 631 self.AddToMessage(op.dst_extents, 632 self.NewExtentList((16, 128), (512, 6))) 633 self.assertIsNone( 634 payload_checker._CheckMoveOperation(op, None, 134, 134, 'foo')) 635 636 def testCheckMoveOperation_FailContainsData(self): 637 """Tests _CheckMoveOperation(); fails, message contains data.""" 638 payload_checker = checker.PayloadChecker(self.MockPayload()) 639 op = update_metadata_pb2.InstallOperation() 640 op.type = common.OpType.MOVE 641 642 self.AddToMessage(op.src_extents, 643 self.NewExtentList((1, 4), (12, 2), (1024, 128))) 644 self.AddToMessage(op.dst_extents, 645 self.NewExtentList((16, 128), (512, 6))) 646 self.assertRaises( 647 PayloadError, payload_checker._CheckMoveOperation, 648 op, 1024, 134, 134, 'foo') 649 650 def testCheckMoveOperation_FailInsufficientSrcBlocks(self): 651 """Tests _CheckMoveOperation(); fails, not enough actual src blocks.""" 652 payload_checker = checker.PayloadChecker(self.MockPayload()) 653 op = update_metadata_pb2.InstallOperation() 654 op.type = common.OpType.MOVE 655 656 self.AddToMessage(op.src_extents, 657 self.NewExtentList((1, 4), (12, 2), (1024, 127))) 658 self.AddToMessage(op.dst_extents, 659 self.NewExtentList((16, 128), (512, 6))) 660 self.assertRaises( 661 PayloadError, payload_checker._CheckMoveOperation, 662 op, None, 134, 134, 'foo') 663 664 def testCheckMoveOperation_FailInsufficientDstBlocks(self): 665 """Tests _CheckMoveOperation(); fails, not enough actual dst blocks.""" 666 payload_checker = 
checker.PayloadChecker(self.MockPayload()) 667 op = update_metadata_pb2.InstallOperation() 668 op.type = common.OpType.MOVE 669 670 self.AddToMessage(op.src_extents, 671 self.NewExtentList((1, 4), (12, 2), (1024, 128))) 672 self.AddToMessage(op.dst_extents, 673 self.NewExtentList((16, 128), (512, 5))) 674 self.assertRaises( 675 PayloadError, payload_checker._CheckMoveOperation, 676 op, None, 134, 134, 'foo') 677 678 def testCheckMoveOperation_FailExcessSrcBlocks(self): 679 """Tests _CheckMoveOperation(); fails, too many actual src blocks.""" 680 payload_checker = checker.PayloadChecker(self.MockPayload()) 681 op = update_metadata_pb2.InstallOperation() 682 op.type = common.OpType.MOVE 683 684 self.AddToMessage(op.src_extents, 685 self.NewExtentList((1, 4), (12, 2), (1024, 128))) 686 self.AddToMessage(op.dst_extents, 687 self.NewExtentList((16, 128), (512, 5))) 688 self.assertRaises( 689 PayloadError, payload_checker._CheckMoveOperation, 690 op, None, 134, 134, 'foo') 691 self.AddToMessage(op.src_extents, 692 self.NewExtentList((1, 4), (12, 2), (1024, 129))) 693 self.AddToMessage(op.dst_extents, 694 self.NewExtentList((16, 128), (512, 6))) 695 self.assertRaises( 696 PayloadError, payload_checker._CheckMoveOperation, 697 op, None, 134, 134, 'foo') 698 699 def testCheckMoveOperation_FailExcessDstBlocks(self): 700 """Tests _CheckMoveOperation(); fails, too many actual dst blocks.""" 701 payload_checker = checker.PayloadChecker(self.MockPayload()) 702 op = update_metadata_pb2.InstallOperation() 703 op.type = common.OpType.MOVE 704 705 self.AddToMessage(op.src_extents, 706 self.NewExtentList((1, 4), (12, 2), (1024, 128))) 707 self.AddToMessage(op.dst_extents, 708 self.NewExtentList((16, 128), (512, 7))) 709 self.assertRaises( 710 PayloadError, payload_checker._CheckMoveOperation, 711 op, None, 134, 134, 'foo') 712 713 def testCheckMoveOperation_FailStagnantBlocks(self): 714 """Tests _CheckMoveOperation(); fails, there are blocks that do not move.""" 715 payload_checker = 
checker.PayloadChecker(self.MockPayload()) 716 op = update_metadata_pb2.InstallOperation() 717 op.type = common.OpType.MOVE 718 719 self.AddToMessage(op.src_extents, 720 self.NewExtentList((1, 4), (12, 2), (1024, 128))) 721 self.AddToMessage(op.dst_extents, 722 self.NewExtentList((8, 128), (512, 6))) 723 self.assertRaises( 724 PayloadError, payload_checker._CheckMoveOperation, 725 op, None, 134, 134, 'foo') 726 727 def testCheckMoveOperation_FailZeroStartBlock(self): 728 """Tests _CheckMoveOperation(); fails, has extent with start block 0.""" 729 payload_checker = checker.PayloadChecker(self.MockPayload()) 730 op = update_metadata_pb2.InstallOperation() 731 op.type = common.OpType.MOVE 732 733 self.AddToMessage(op.src_extents, 734 self.NewExtentList((0, 4), (12, 2), (1024, 128))) 735 self.AddToMessage(op.dst_extents, 736 self.NewExtentList((8, 128), (512, 6))) 737 self.assertRaises( 738 PayloadError, payload_checker._CheckMoveOperation, 739 op, None, 134, 134, 'foo') 740 741 self.AddToMessage(op.src_extents, 742 self.NewExtentList((1, 4), (12, 2), (1024, 128))) 743 self.AddToMessage(op.dst_extents, 744 self.NewExtentList((0, 128), (512, 6))) 745 self.assertRaises( 746 PayloadError, payload_checker._CheckMoveOperation, 747 op, None, 134, 134, 'foo') 748 749 def testCheckAnyDiff(self): 750 """Tests _CheckAnyDiffOperation().""" 751 payload_checker = checker.PayloadChecker(self.MockPayload()) 752 op = update_metadata_pb2.InstallOperation() 753 754 # Pass. 755 self.assertIsNone( 756 payload_checker._CheckAnyDiffOperation(op, 10000, 3, 'foo')) 757 758 # Fail, missing data blob. 759 self.assertRaises( 760 PayloadError, payload_checker._CheckAnyDiffOperation, 761 op, None, 3, 'foo') 762 763 # Fail, too big of a diff blob (unjustified). 
# Final check of the enclosing test: _CheckAnyDiffOperation() must reject `op`
# for these arguments by raising PayloadError.
    self.assertRaises(
        PayloadError, payload_checker._CheckAnyDiffOperation,
        op, 10000, 2, 'foo')

  def testCheckSourceCopyOperation_Pass(self):
    """Tests _CheckSourceCopyOperation(); pass case.

    No data argument (None) and equal src/dst block totals (134 each) must be
    accepted: the check returns None without raising.
    """
    payload_checker = checker.PayloadChecker(self.MockPayload())
    self.assertIsNone(
        payload_checker._CheckSourceCopyOperation(None, 134, 134, 'foo'))

  def testCheckSourceCopyOperation_FailContainsData(self):
    """Tests _CheckSourceCopyOperation(); message contains data.

    SOURCE_COPY operations do not carry data, so passing a data argument (134)
    must raise PayloadError.
    """
    payload_checker = checker.PayloadChecker(self.MockPayload())
    self.assertRaises(PayloadError, payload_checker._CheckSourceCopyOperation,
                      134, 0, 0, 'foo')

  def testCheckSourceCopyOperation_FailBlockCountsMismatch(self):
    """Tests _CheckSourceCopyOperation(); src and dst block totals not equal.

    A src block total (0) that differs from the dst block total (1) must raise
    PayloadError.
    """
    payload_checker = checker.PayloadChecker(self.MockPayload())
    self.assertRaises(PayloadError, payload_checker._CheckSourceCopyOperation,
                      None, 0, 1, 'foo')

  def DoCheckOperationTest(self, op_type_name, is_last, allow_signature,
                           allow_unhashed, fail_src_extents, fail_dst_extents,
                           fail_mismatched_data_offset_length,
                           fail_missing_dst_extents, fail_src_length,
                           fail_dst_length, fail_data_hash,
                           fail_prev_data_offset, fail_bad_minor_version):
    """Parametric testing of _CheckOperation().

    Builds a single InstallOperation of the requested type and tampers with it
    as dictated by the fail_* flags. The test then asserts one of two outcomes:
      - If any fail_* flag is set, _CheckOperation() raises PayloadError.
      - Otherwise, it returns the operation's data_length (0 when data_length
        is unset).

    Args:
      op_type_name: 'REPLACE', 'REPLACE_BZ', 'MOVE', 'BSDIFF', 'SOURCE_COPY',
        'SOURCE_BSDIFF', 'BROTLI_BSDIFF' or 'PUFFDIFF'.
      is_last: Whether we're testing the last operation in a sequence.
      allow_signature: Whether we're testing a signature-capable operation.
      allow_unhashed: Whether we're allowing to not hash the data.
      fail_src_extents: Tamper with src extents (use a zero-length extent).
      fail_dst_extents: Tamper with dst extents (add a zero-length extent).
      fail_mismatched_data_offset_length: Make data_{offset,length}
        inconsistent (data_length is left unset while data_offset is set).
      fail_missing_dst_extents: Do not include dst extents.
      fail_src_length: Make src length inconsistent.
      fail_dst_length: Make dst length inconsistent.
      fail_data_hash: Tamper with the data blob hash.
      fail_prev_data_offset: Make data space uses non-contiguous.
      fail_bad_minor_version: Make minor version incompatible with op.
    """
    op_type = _OpTypeByName(op_type_name)

    # Create the test object.
    payload = self.MockPayload()
    payload_checker = checker.PayloadChecker(payload,
                                             allow_unhashed=allow_unhashed)
    block_size = payload_checker.block_size

    # Create auxiliary arguments.
    old_part_size = test_utils.MiB(4)
    new_part_size = test_utils.MiB(8)
    # One byte-sized counter per block, rounding partial blocks up.
    old_block_counters = array.array(
        'B', [0] * ((old_part_size + block_size - 1) / block_size))
    new_block_counters = array.array(
        'B', [0] * ((new_part_size + block_size - 1) / block_size))
    # The data_offset the checker expects this operation to start at.
    prev_data_offset = 1876
    blob_hash_counts = collections.defaultdict(int)

    # Create the operation object for the test.
    op = update_metadata_pb2.InstallOperation()
    op.type = op_type

    # Op types that take src extents get 16 blocks starting at block 1;
    # fail_src_extents substitutes a zero-length extent (and 0 src blocks).
    total_src_blocks = 0
    if op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
                   common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF,
                   common.OpType.PUFFDIFF, common.OpType.BROTLI_BSDIFF):
      if fail_src_extents:
        self.AddToMessage(op.src_extents,
                          self.NewExtentList((1, 0)))
      else:
        self.AddToMessage(op.src_extents,
                          self.NewExtentList((1, 16)))
        total_src_blocks = 16

    # Pick a minor version the op type is valid for; with
    # fail_bad_minor_version, pick a neighboring version it is not valid for.
    if op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ):
      payload_checker.minor_version = 0
    elif op_type in (common.OpType.MOVE, common.OpType.BSDIFF):
      payload_checker.minor_version = 2 if fail_bad_minor_version else 1
    elif op_type in (common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF):
      payload_checker.minor_version = 1 if fail_bad_minor_version else 2
    elif op_type in (common.OpType.ZERO, common.OpType.DISCARD,
                     common.OpType.BROTLI_BSDIFF):
      payload_checker.minor_version = 3 if fail_bad_minor_version else 4
    elif op_type == common.OpType.PUFFDIFF:
      payload_checker.minor_version = 4 if fail_bad_minor_version else 5

    # MOVE and SOURCE_COPY carry no data blob, so only other op types get
    # data_offset/data_length/data hash fields.
    if op_type not in (common.OpType.MOVE, common.OpType.SOURCE_COPY):
      if not fail_mismatched_data_offset_length:
        op.data_length = 16 * block_size - 8
      if fail_prev_data_offset:
        op.data_offset = prev_data_offset + 16
      else:
        op.data_offset = prev_data_offset

      fake_data = 'fake-data'.ljust(op.data_length)
      # A data hash is expected unless unhashed data is allowed, or this is the
      # last REPLACE operation with signatures allowed (presumably the
      # signature pseudo-operation -- confirm against checker.py).
      if not (allow_unhashed or (is_last and allow_signature and
                                 op_type == common.OpType.REPLACE)):
        if not fail_data_hash:
          # Create a valid data blob hash.
          op.data_sha256_hash = hashlib.sha256(fake_data).digest()
          # Record the mox expectation that the checker reads the blob back.
          payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
              fake_data)

      elif fail_data_hash:
        # Create an invalid data blob hash.
        op.data_sha256_hash = hashlib.sha256(
            fake_data.replace(' ', '-')).digest()
        payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
            fake_data)

    # Dst extents cover 16 blocks in two pieces; fail_dst_extents adds a
    # zero-length extent instead.
    total_dst_blocks = 0
    if not fail_missing_dst_extents:
      total_dst_blocks = 16
      if fail_dst_extents:
        self.AddToMessage(op.dst_extents,
                          self.NewExtentList((4, 16), (32, 0)))
      else:
        self.AddToMessage(op.dst_extents,
                          self.NewExtentList((4, 8), (64, 8)))

    # Explicit src/dst lengths are only set for MOVE/BSDIFF/SOURCE_BSDIFF at
    # minor version <= 3; the fail_* variants are off by 8 bytes or orphaned.
    if total_src_blocks:
      if fail_src_length:
        op.src_length = total_src_blocks * block_size + 8
      elif (op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
                        common.OpType.SOURCE_BSDIFF) and
            payload_checker.minor_version <= 3):
        op.src_length = total_src_blocks * block_size
    elif fail_src_length:
      # Add an orphaned src_length.
      op.src_length = 16

    if total_dst_blocks:
      if fail_dst_length:
        op.dst_length = total_dst_blocks * block_size + 8
      elif (op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
                        common.OpType.SOURCE_BSDIFF) and
            payload_checker.minor_version <= 3):
        op.dst_length = total_dst_blocks * block_size

    # Switch mox from recording expectations to replaying them.
    self.mox.ReplayAll()
    should_fail = (fail_src_extents or fail_dst_extents or
                   fail_mismatched_data_offset_length or
                   fail_missing_dst_extents or fail_src_length or
                   fail_dst_length or fail_data_hash or fail_prev_data_offset or
                   fail_bad_minor_version)
    args = (op, 'foo', is_last, old_block_counters, new_block_counters,
            old_part_size, new_part_size, prev_data_offset, allow_signature,
            blob_hash_counts)
    if should_fail:
      self.assertRaises(PayloadError, payload_checker._CheckOperation, *args)
    else:
      self.assertEqual(op.data_length if op.HasField('data_length') else 0,
                       payload_checker._CheckOperation(*args))

  def testAllocBlockCounters(self):
    """Tests _AllocBlockCounters().

    The result must hold one int counter per block of the partition, with a
    trailing partial block counted as a whole block.
    """
    payload_checker = checker.PayloadChecker(self.MockPayload())
    block_size = payload_checker.block_size

    # Check allocation for block-aligned partition size, ensure the counters
    # are ints.
    result = payload_checker._AllocBlockCounters(16 * block_size)
    self.assertEqual(16, len(result))
    self.assertEqual(int, type(result[0]))

    # Check allocation of unaligned partition sizes.
    result = payload_checker._AllocBlockCounters(16 * block_size - 1)
    self.assertEqual(16, len(result))
    result = payload_checker._AllocBlockCounters(16 * block_size + 1)
    self.assertEqual(17, len(result))

  def DoCheckOperationsTest(self, fail_nonexhaustive_full_update):
    """Tests _CheckOperations().

    Args:
      fail_nonexhaustive_full_update: Shrink the single REPLACE operation by
        one block so that it no longer covers the whole rootfs partition.
    """
    # Generate a test payload. For this test, we only care about one
    # (arbitrary) set of operations, so we only generate a single rootfs
    # operation and test with it.
    payload_gen = test_utils.PayloadGenerator()

    block_size = test_utils.KiB(4)
    payload_gen.SetBlockSize(block_size)

    rootfs_part_size = test_utils.MiB(8)

    # Fake rootfs operations in a full update, tampered with as required.
    rootfs_op_type = common.OpType.REPLACE
    rootfs_data_length = rootfs_part_size
    if fail_nonexhaustive_full_update:
      rootfs_data_length -= block_size

    payload_gen.AddOperation(False, rootfs_op_type,
                             dst_extents=[(0, rootfs_data_length / block_size)],
                             data_offset=0,
                             data_length=rootfs_data_length)

    # Create the test object.
    payload_checker = _GetPayloadChecker(payload_gen.WriteToFile,
                                         checker_init_dargs={
                                             'allow_unhashed': True})
    payload_checker.payload_type = checker._TYPE_FULL
    report = checker._PayloadReport()

    args = (payload_checker.payload.manifest.install_operations, report, 'foo',
            0, rootfs_part_size, rootfs_part_size, rootfs_part_size, 0, False)
    if fail_nonexhaustive_full_update:
      self.assertRaises(PayloadError, payload_checker._CheckOperations, *args)
    else:
      # On success, the total data length of the operations is returned.
      self.assertEqual(rootfs_data_length,
                       payload_checker._CheckOperations(*args))

  def DoCheckSignaturesTest(self, fail_empty_sigs_blob, fail_missing_pseudo_op,
                            fail_mismatched_pseudo_op, fail_sig_missing_fields,
                            fail_unknown_sig_version, fail_incorrect_sig):
    """Tests _CheckSignatures().

    Whenever any fail_* flag is set, a signatures blob is forged by hand and
    _CheckSignatures() must raise PayloadError. Otherwise the generator signs
    the payload itself and the check must pass.
    """
    # Generate a test payload. For this test, we only care about the signature
    # block and how it relates to the payload hash. Therefore, we're generating
    # a random (otherwise useless) payload for this purpose.
    payload_gen = test_utils.EnhancedPayloadGenerator()
    block_size = test_utils.KiB(4)
    payload_gen.SetBlockSize(block_size)
    rootfs_part_size = test_utils.MiB(2)
    kernel_part_size = test_utils.KiB(16)
    payload_gen.SetPartInfo(False, True, rootfs_part_size,
                            hashlib.sha256('fake-new-rootfs-content').digest())
    payload_gen.SetPartInfo(True, True, kernel_part_size,
                            hashlib.sha256('fake-new-kernel-content').digest())
    payload_gen.SetMinorVersion(0)
    payload_gen.AddOperationWithData(
        False, common.OpType.REPLACE,
        dst_extents=[(0, rootfs_part_size / block_size)],
        data_blob=os.urandom(rootfs_part_size))

    do_forge_pseudo_op = (fail_missing_pseudo_op or fail_mismatched_pseudo_op)
    do_forge_sigs_data = (do_forge_pseudo_op or fail_empty_sigs_blob or
                          fail_sig_missing_fields or fail_unknown_sig_version
                          or fail_incorrect_sig)

    sigs_data = None
    if do_forge_sigs_data:
      sigs_gen = test_utils.SignaturesGenerator()
      if not fail_empty_sigs_blob:
        if fail_sig_missing_fields:
          sig_data = None
        else:
          # Signs a fixed string rather than the generated (random) payload,
          # so this signature presumably never verifies against it.
          sig_data = test_utils.SignSha256('fake-payload-content',
                                           test_utils._PRIVKEY_FILE_NAME)
        sigs_gen.AddSig(5 if fail_unknown_sig_version else 1, sig_data)

      # Even with fail_empty_sigs_blob, an (empty) blob is serialized here.
      sigs_data = sigs_gen.ToBinary()
      payload_gen.SetSignatures(payload_gen.curr_offset, len(sigs_data))

    if do_forge_pseudo_op:
      assert sigs_data is not None, 'should have forged signatures blob by now'
      sigs_len = len(sigs_data)
      # A hand-made pseudo-operation whose offset/length are deliberately
      # half of what the signatures blob would require.
      payload_gen.AddOperation(
          False, common.OpType.REPLACE,
          data_offset=payload_gen.curr_offset / 2,
          data_length=sigs_len / 2,
          dst_extents=[(0, (sigs_len / 2 + block_size - 1) / block_size)])

    # Generate payload (complete w/ signature) and create the test object.
    payload_checker = _GetPayloadChecker(
        payload_gen.WriteToFileWithData,
        payload_gen_dargs={
            'sigs_data': sigs_data,
            'privkey_file_name': test_utils._PRIVKEY_FILE_NAME,
            'do_add_pseudo_operation': not do_forge_pseudo_op})
    payload_checker.payload_type = checker._TYPE_FULL
    report = checker._PayloadReport()

    # We have to check the manifest first in order to set signature attributes.
    payload_checker._CheckManifest(report, rootfs_part_size, kernel_part_size)

    should_fail = (fail_empty_sigs_blob or fail_missing_pseudo_op or
                   fail_mismatched_pseudo_op or fail_sig_missing_fields or
                   fail_unknown_sig_version or fail_incorrect_sig)
    args = (report, test_utils._PUBKEY_FILE_NAME)
    if should_fail:
      self.assertRaises(PayloadError, payload_checker._CheckSignatures, *args)
    else:
      self.assertIsNone(payload_checker._CheckSignatures(*args))

  def DoCheckManifestMinorVersionTest(self, minor_version, payload_type):
    """Parametric testing for _CheckManifestMinorVersion().

    Only minor version 0 with a full payload, or minor versions 1-5 with a
    delta payload, are expected to pass; everything else must raise
    PayloadError.

    Args:
      minor_version: The payload minor version to test with.
      payload_type: The type of the payload we're testing, delta or full.
    """
    # Create the test object.
    payload = self.MockPayload()
    payload.manifest.minor_version = minor_version
    payload_checker = checker.PayloadChecker(payload)
    payload_checker.payload_type = payload_type
    report = checker._PayloadReport()

    should_succeed = (
        (minor_version == 0 and payload_type == checker._TYPE_FULL) or
        (minor_version == 1 and payload_type == checker._TYPE_DELTA) or
        (minor_version == 2 and payload_type == checker._TYPE_DELTA) or
        (minor_version == 3 and payload_type == checker._TYPE_DELTA) or
        (minor_version == 4 and payload_type == checker._TYPE_DELTA) or
        (minor_version == 5 and payload_type == checker._TYPE_DELTA))
    args = (report,)

    if should_succeed:
      self.assertIsNone(payload_checker._CheckManifestMinorVersion(*args))
    else:
      self.assertRaises(PayloadError,
                        payload_checker._CheckManifestMinorVersion, *args)

  def DoRunTest(self, rootfs_part_size_provided, kernel_part_size_provided,
                fail_wrong_payload_type, fail_invalid_block_size,
                fail_mismatched_block_size, fail_excess_data,
                fail_rootfs_part_size_exceeded,
                fail_kernel_part_size_exceeded):
    """Tests Run().

    Args:
      rootfs_part_size_provided: Pass an explicit rootfs partition size (one
        block larger than the filesystem) to Run(); otherwise pass 0.
      kernel_part_size_provided: Same as above, for the kernel partition.
      fail_wrong_payload_type: Assert a 'delta' payload on a full payload.
      fail_invalid_block_size: Give the checker a non power-of-two block size;
        constructing the checker itself must fail.
      fail_mismatched_block_size: Give the checker twice the payload's block
        size.
      fail_excess_data: Append 1 KiB of random padding to the payload.
      fail_rootfs_part_size_exceeded: Make the rootfs operation one block
        larger than the size it is checked against.
      fail_kernel_part_size_exceeded: Same as above, for the kernel operation.
    """
    # Generate a test payload. For this test, we generate a full update that
    # has sample kernel and rootfs operations. Since most testing is done with
    # internal PayloadChecker methods that are tested elsewhere, here we only
    # tamper with what's actually being manipulated and/or tested in the Run()
    # method itself. Note that the checker doesn't verify partition hashes, so
    # they're safe to fake.
    payload_gen = test_utils.EnhancedPayloadGenerator()
    block_size = test_utils.KiB(4)
    payload_gen.SetBlockSize(block_size)
    kernel_filesystem_size = test_utils.KiB(16)
    rootfs_filesystem_size = test_utils.MiB(2)
    payload_gen.SetPartInfo(False, True, rootfs_filesystem_size,
                            hashlib.sha256('fake-new-rootfs-content').digest())
    payload_gen.SetPartInfo(True, True, kernel_filesystem_size,
                            hashlib.sha256('fake-new-kernel-content').digest())
    payload_gen.SetMinorVersion(0)

    # A partition size of 0 means "not provided"; the operation then spans
    # the filesystem size instead.
    rootfs_part_size = 0
    if rootfs_part_size_provided:
      rootfs_part_size = rootfs_filesystem_size + block_size
    rootfs_op_size = rootfs_part_size or rootfs_filesystem_size
    if fail_rootfs_part_size_exceeded:
      rootfs_op_size += block_size
    payload_gen.AddOperationWithData(
        False, common.OpType.REPLACE,
        dst_extents=[(0, rootfs_op_size / block_size)],
        data_blob=os.urandom(rootfs_op_size))

    kernel_part_size = 0
    if kernel_part_size_provided:
      kernel_part_size = kernel_filesystem_size + block_size
    kernel_op_size = kernel_part_size or kernel_filesystem_size
    if fail_kernel_part_size_exceeded:
      kernel_op_size += block_size
    payload_gen.AddOperationWithData(
        True, common.OpType.REPLACE,
        dst_extents=[(0, kernel_op_size / block_size)],
        data_blob=os.urandom(kernel_op_size))

    # Generate payload (complete w/ signature) and create the test object.
    if fail_invalid_block_size:
      use_block_size = block_size + 5  # Not a power of two.
    elif fail_mismatched_block_size:
      use_block_size = block_size * 2  # Different than the payload states.
    else:
      use_block_size = block_size

    kwargs = {
        'payload_gen_dargs': {
            'privkey_file_name': test_utils._PRIVKEY_FILE_NAME,
            'do_add_pseudo_operation': True,
            'is_pseudo_in_kernel': True,
            'padding': os.urandom(1024) if fail_excess_data else None},
        'checker_init_dargs': {
            'assert_type': 'delta' if fail_wrong_payload_type else 'full',
            'block_size': use_block_size}}
    if fail_invalid_block_size:
      # An invalid block size is rejected when the checker is constructed.
      self.assertRaises(PayloadError, _GetPayloadChecker,
                        payload_gen.WriteToFileWithData, **kwargs)
    else:
      payload_checker = _GetPayloadChecker(payload_gen.WriteToFileWithData,
                                           **kwargs)

      kwargs = {'pubkey_file_name': test_utils._PUBKEY_FILE_NAME,
                'rootfs_part_size': rootfs_part_size,
                'kernel_part_size': kernel_part_size}
      should_fail = (fail_wrong_payload_type or fail_mismatched_block_size or
                     fail_excess_data or
                     fail_rootfs_part_size_exceeded or
                     fail_kernel_part_size_exceeded)
      if should_fail:
        self.assertRaises(PayloadError, payload_checker.Run, **kwargs)
      else:
        self.assertIsNone(payload_checker.Run(**kwargs))

# This implements a generic API, hence the occasional unused args.
# pylint: disable=W0613
def ValidateCheckOperationTest(op_type_name, is_last, allow_signature,
                               allow_unhashed, fail_src_extents,
                               fail_dst_extents,
                               fail_mismatched_data_offset_length,
                               fail_missing_dst_extents, fail_src_length,
                               fail_dst_length, fail_data_hash,
                               fail_prev_data_offset, fail_bad_minor_version):
  """Returns True iff the combination of arguments represents a valid test.

  Used as the validate_func for the CheckOperation parametric tests, to prune
  combinations whose fail_* flag cannot apply to the given operation type.
  """
  op_type = _OpTypeByName(op_type_name)

  # REPLACE/REPLACE_BZ operations don't read data from src partition. They are
  # compatible with all valid minor versions, so we don't need to check that.
  if (op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ) and (
      fail_src_extents or fail_src_length or fail_bad_minor_version)):
    return False

  # MOVE and SOURCE_COPY operations don't carry data.
  if (op_type in (common.OpType.MOVE, common.OpType.SOURCE_COPY) and (
      fail_mismatched_data_offset_length or fail_data_hash or
      fail_prev_data_offset)):
    return False

  return True


def TestMethodBody(run_method_name, run_dargs):
  """Returns a function that invokes a named method with named arguments.

  Building the lambda inside this helper gives each generated test its own
  binding of run_method_name/run_dargs (no late-binding loop-variable issue).
  """
  return lambda self: getattr(self, run_method_name)(**run_dargs)


def AddParametricTests(tested_method_name, arg_space, validate_func=None):
  """Enumerates and adds specific parametric tests to PayloadCheckerTest.

  This function enumerates a space of test parameters (defined by arg_space),
  then binds a new, unique method name in PayloadCheckerTest to a test function
  that gets handed the said parameters. This is a preferable approach to doing
  the enumeration and invocation during the tests because this way each test is
  treated as a complete run by the unittest framework, and so benefits from the
  usual setUp/tearDown mechanics.

  Each generated test calls PayloadCheckerTest.Do<tested_method_name>Test.
  Its name is test<tested_method_name> followed by '__key=value' for every
  argument that is truthy or an int. Falsy non-int values are left out of
  the name: False, None and empty values.

  Args:
    tested_method_name: Name of the tested PayloadChecker method.
    arg_space: A dictionary containing variables (keys) and lists of values
      (values) associated with them.
    validate_func: A function used for validating test argument combinations;
      combinations for which it returns a falsy value are skipped.
  """
  # itervalues()/iterkeys() walk an unmodified dict in the same order, so the
  # zip below pairs each key with its own value.
  for value_tuple in itertools.product(*arg_space.itervalues()):
    run_dargs = dict(zip(arg_space.iterkeys(), value_tuple))
    if validate_func and not validate_func(**run_dargs):
      continue
    run_method_name = 'Do%sTest' % tested_method_name
    test_method_name = 'test%s' % tested_method_name
    for arg_key, arg_val in run_dargs.iteritems():
      # type() (not isinstance) excludes bools, so False is omitted but 0 kept.
      if arg_val or type(arg_val) is int:
        test_method_name += '__%s=%s' % (arg_key, arg_val)
    setattr(PayloadCheckerTest, test_method_name,
            TestMethodBody(run_method_name, run_dargs))


def AddAllParametricTests():
  """Enumerates and adds all parametric tests to PayloadCheckerTest.

  Must run before unittest.main() so the generated test methods exist when
  the test loader collects PayloadCheckerTest.
  """
  # Add all _CheckElem() test cases.
  AddParametricTests('AddElem',
                     {'linebreak': (True, False),
                      'indent': (0, 1, 2),
                      'convert': (str, lambda s: s[::-1]),
                      'is_present': (True, False),
                      'is_mandatory': (True, False),
                      'is_submsg': (True, False)})

  # Add all _Add{Mandatory,Optional}Field tests.
  AddParametricTests('AddField',
                     {'is_mandatory': (True, False),
                      'linebreak': (True, False),
                      'indent': (0, 1, 2),
                      'convert': (str, lambda s: s[::-1]),
                      'is_present': (True, False)})

  # Add all _Add{Mandatory,Optional}SubMsg tests.
  AddParametricTests('AddSubMsg',
                     {'is_mandatory': (True, False),
                      'is_present': (True, False)})

  # Add all _CheckManifest() test cases.
  AddParametricTests('CheckManifest',
                     {'fail_mismatched_block_size': (True, False),
                      'fail_bad_sigs': (True, False),
                      'fail_mismatched_oki_ori': (True, False),
                      'fail_bad_oki': (True, False),
                      'fail_bad_ori': (True, False),
                      'fail_bad_nki': (True, False),
                      'fail_bad_nri': (True, False),
                      'fail_old_kernel_fs_size': (True, False),
                      'fail_old_rootfs_fs_size': (True, False),
                      'fail_new_kernel_fs_size': (True, False),
                      'fail_new_rootfs_fs_size': (True, False)})

  # Add all _CheckOperation() test cases.
  AddParametricTests('CheckOperation',
                     {'op_type_name': ('REPLACE', 'REPLACE_BZ', 'MOVE',
                                       'BSDIFF', 'SOURCE_COPY',
                                       'SOURCE_BSDIFF', 'PUFFDIFF',
                                       'BROTLI_BSDIFF'),
                      'is_last': (True, False),
                      'allow_signature': (True, False),
                      'allow_unhashed': (True, False),
                      'fail_src_extents': (True, False),
                      'fail_dst_extents': (True, False),
                      'fail_mismatched_data_offset_length': (True, False),
                      'fail_missing_dst_extents': (True, False),
                      'fail_src_length': (True, False),
                      'fail_dst_length': (True, False),
                      'fail_data_hash': (True, False),
                      'fail_prev_data_offset': (True, False),
                      'fail_bad_minor_version': (True, False)},
                     validate_func=ValidateCheckOperationTest)

  # Add all _CheckOperations() test cases.
  AddParametricTests('CheckOperations',
                     {'fail_nonexhaustive_full_update': (True, False)})

  # Add all _CheckSignatures() test cases.
  AddParametricTests('CheckSignatures',
                     {'fail_empty_sigs_blob': (True, False),
                      'fail_missing_pseudo_op': (True, False),
                      'fail_mismatched_pseudo_op': (True, False),
                      'fail_sig_missing_fields': (True, False),
                      'fail_unknown_sig_version': (True, False),
                      'fail_incorrect_sig': (True, False)})

  # Add all _CheckManifestMinorVersion() test cases.
  AddParametricTests('CheckManifestMinorVersion',
                     {'minor_version': (None, 0, 1, 2, 3, 4, 5, 555),
                      'payload_type': (checker._TYPE_FULL,
                                       checker._TYPE_DELTA)})

  # Add all Run() test cases.
  AddParametricTests('Run',
                     {'rootfs_part_size_provided': (True, False),
                      'kernel_part_size_provided': (True, False),
                      'fail_wrong_payload_type': (True, False),
                      'fail_invalid_block_size': (True, False),
                      'fail_mismatched_block_size': (True, False),
                      'fail_excess_data': (True, False),
                      'fail_rootfs_part_size_exceeded': (True, False),
                      'fail_kernel_part_size_exceeded': (True, False)})


if __name__ == '__main__':
  AddAllParametricTests()
  unittest.main()