#!/usr/bin/python2
#
# Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Unit testing checker.py."""

from __future__ import print_function

import array
import collections
import cStringIO
import hashlib
import itertools
import os
import unittest

# pylint cannot find mox.
# pylint: disable=F0401
import mox

import checker
import common
import payload as update_payload  # Avoid name conflicts later.
import test_utils
import update_metadata_pb2


def _OpTypeByName(op_name):
  """Returns the common.OpType value for an operation name.

  Args:
    op_name: One of the operation names listed in the mapping below,
      e.g. 'REPLACE' or 'SOURCE_COPY'.

  Returns:
    The matching common.OpType constant.

  Raises:
    KeyError: If op_name is not a known operation name.
  """
  op_name_to_type = {
      'REPLACE': common.OpType.REPLACE,
      'REPLACE_BZ': common.OpType.REPLACE_BZ,
      'MOVE': common.OpType.MOVE,
      'BSDIFF': common.OpType.BSDIFF,
      'SOURCE_COPY': common.OpType.SOURCE_COPY,
      'SOURCE_BSDIFF': common.OpType.SOURCE_BSDIFF,
      'ZERO': common.OpType.ZERO,
      'DISCARD': common.OpType.DISCARD,
      'REPLACE_XZ': common.OpType.REPLACE_XZ,
      'IMGDIFF': common.OpType.IMGDIFF,
  }
  return op_name_to_type[op_name]


def _GetPayloadChecker(payload_gen_write_to_file_func, payload_gen_dargs=None,
                       checker_init_dargs=None):
  """Returns a payload checker from a given payload generator.

  Args:
    payload_gen_write_to_file_func: Callable that writes a payload into the
      file object passed as its first argument.
    payload_gen_dargs: Optional keyword arguments forwarded to
      payload_gen_write_to_file_func.
    checker_init_dargs: Optional keyword arguments forwarded to the
      checker.PayloadChecker constructor.

  Returns:
    A checker.PayloadChecker wrapping the initialized in-memory payload.
  """
  if payload_gen_dargs is None:
    payload_gen_dargs = {}
  if checker_init_dargs is None:
    checker_init_dargs = {}

  payload_file = cStringIO.StringIO()
  payload_gen_write_to_file_func(payload_file, **payload_gen_dargs)
  payload_file.seek(0)  # Rewind so Payload reads from the start.
  payload = update_payload.Payload(payload_file)
  payload.Init()
  return checker.PayloadChecker(payload, **checker_init_dargs)


def _GetPayloadCheckerWithData(payload_gen):
  """Returns a payload checker for the payload written by payload_gen.

  Args:
    payload_gen: An object whose WriteToFile(file) writes a payload.

  Returns:
    A checker.PayloadChecker wrapping the initialized in-memory payload.
  """
  payload_file = cStringIO.StringIO()
  payload_gen.WriteToFile(payload_file)
  payload_file.seek(0)  # Rewind so Payload reads from the start.
  payload = update_payload.Payload(payload_file)
  payload.Init()
  return checker.PayloadChecker(payload)


# This class doesn't need an __init__().
# pylint: disable=W0232
# Unit testing is all about running protected methods.
# pylint: disable=W0212
# Don't bark about missing members of classes you cannot import.
# pylint: disable=E1101
class PayloadCheckerTest(mox.MoxTestBase):
  """Tests the PayloadChecker class.

  In addition to ordinary testFoo() methods, which are automatically invoked by
  the unittest framework, in this class we make use of DoBarTest() calls that
  implement parametric tests of certain features. In order to invoke each test,
  which embodies a unique combination of parameter values, as a complete unit
  test, we perform explicit enumeration of the parameter space and create
  individual invocation contexts for each, which are then bound as
  testBar__param1=val1__param2=val2(). The enumeration of parameter spaces for
  all such tests is done in AddAllParametricTests().
  """

  def MockPayload(self):
    """Create a mock payload object, complete with a mock manifest."""
    payload = self.mox.CreateMock(update_payload.Payload)
    payload.is_init = True
    payload.manifest = self.mox.CreateMock(
        update_metadata_pb2.DeltaArchiveManifest)
    return payload

  @staticmethod
  def NewExtent(start_block, num_blocks):
    """Returns an Extent message.

    Each of the provided fields is set iff it is >= 0; otherwise, it's left at
    its default state.

    Args:
      start_block: The starting block of the extent.
      num_blocks: The number of blocks in the extent.

    Returns:
      An Extent message.
    """
    ex = update_metadata_pb2.Extent()
    if start_block >= 0:
      ex.start_block = start_block
    if num_blocks >= 0:
      ex.num_blocks = num_blocks
    return ex

  @staticmethod
  def NewExtentList(*args):
    """Returns a list of extents.

    Args:
      *args: (start_block, num_blocks) pairs defining the extents.

    Returns:
      A list of Extent objects, in the order given.
    """
    return [PayloadCheckerTest.NewExtent(start_block, num_blocks)
            for start_block, num_blocks in args]

  @staticmethod
  def AddToMessage(repeated_field, field_vals):
    """Appends a copy of each message in field_vals to repeated_field.

    Entries already present in repeated_field are kept; the new copies are
    added after them. Tests that need a clean operation must therefore build a
    new message rather than calling this again on a populated one.

    Args:
      repeated_field: A repeated message field (e.g. op.src_extents).
      field_vals: Messages to copy into the field.
    """
    for field_val in field_vals:
      new_field = repeated_field.add()
      new_field.CopyFrom(field_val)

  def SetupAddElemTest(self, is_present, is_submsg, convert=str,
                       linebreak=False, indent=0):
    """Setup for testing of _CheckElem() and its derivatives.

    Args:
      is_present: Whether or not the element is found in the message.
      is_submsg: Whether the element is a sub-message itself.
      convert: A representation conversion function.
      linebreak: Whether or not a linebreak is to be used in the report.
      indent: Indentation used for the report.

    Returns:
      msg: A mock message object.
      report: A mock report object.
      subreport: A mock sub-report object.
      name: An element name to check.
      val: Expected element value.
    """
    name = 'foo'
    val = 'fake submsg' if is_submsg else 'fake field'
    subreport = 'fake subreport'

    # Create a mock message.
    msg = self.mox.CreateMock(update_metadata_pb2._message.Message)
    msg.HasField(name).AndReturn(is_present)
    setattr(msg, name, val)

    # Create a mock report.
    report = self.mox.CreateMock(checker._PayloadReport)
    if is_present:
      if is_submsg:
        report.AddSubReport(name).AndReturn(subreport)
      else:
        report.AddField(name, convert(val), linebreak=linebreak, indent=indent)

    self.mox.ReplayAll()
    return (msg, report, subreport, name, val)

  def DoAddElemTest(self, is_present, is_mandatory, is_submsg, convert,
                    linebreak, indent):
    """Parametric testing of _CheckElem().

    Args:
      is_present: Whether or not the element is found in the message.
      is_mandatory: Whether or not it's a mandatory element.
      is_submsg: Whether the element is a sub-message itself.
      convert: A representation conversion function.
      linebreak: Whether or not a linebreak is to be used in the report.
      indent: Indentation used for the report.
    """
    msg, report, subreport, name, val = self.SetupAddElemTest(
        is_present, is_submsg, convert, linebreak, indent)

    args = (msg, name, report, is_mandatory, is_submsg)
    kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
    if is_mandatory and not is_present:
      self.assertRaises(update_payload.PayloadError,
                        checker.PayloadChecker._CheckElem, *args, **kwargs)
    else:
      ret_val, ret_subreport = checker.PayloadChecker._CheckElem(*args,
                                                                 **kwargs)
      self.assertEquals(val if is_present else None, ret_val)
      self.assertEquals(subreport if is_present and is_submsg else None,
                        ret_subreport)

  def DoAddFieldTest(self, is_mandatory, is_present, convert, linebreak,
                     indent):
    """Parametric testing of _Check{Mandatory,Optional}Field().

    Args:
      is_mandatory: Whether we're testing a mandatory call.
      is_present: Whether or not the element is found in the message.
      convert: A representation conversion function.
      linebreak: Whether or not a linebreak is to be used in the report.
      indent: Indentation used for the report.
    """
    msg, report, _, name, val = self.SetupAddElemTest(
        is_present, False, convert, linebreak, indent)

    # Prepare for invocation of the tested method.
    args = [msg, name, report]
    kwargs = {'convert': convert, 'linebreak': linebreak, 'indent': indent}
    if is_mandatory:
      args.append('bar')
      tested_func = checker.PayloadChecker._CheckMandatoryField
    else:
      tested_func = checker.PayloadChecker._CheckOptionalField

    # Test the method call.
    if is_mandatory and not is_present:
      self.assertRaises(update_payload.PayloadError, tested_func, *args,
                        **kwargs)
    else:
      ret_val = tested_func(*args, **kwargs)
      self.assertEquals(val if is_present else None, ret_val)

  def DoAddSubMsgTest(self, is_mandatory, is_present):
    """Parametrized testing of _Check{Mandatory,Optional}SubMsg().

    Args:
      is_mandatory: Whether we're testing a mandatory call.
      is_present: Whether or not the element is found in the message.
    """
    msg, report, subreport, name, val = self.SetupAddElemTest(is_present, True)

    # Prepare for invocation of the tested method.
    args = [msg, name, report]
    if is_mandatory:
      args.append('bar')
      tested_func = checker.PayloadChecker._CheckMandatorySubMsg
    else:
      tested_func = checker.PayloadChecker._CheckOptionalSubMsg

    # Test the method call.
    if is_mandatory and not is_present:
      self.assertRaises(update_payload.PayloadError, tested_func, *args)
    else:
      ret_val, ret_subreport = tested_func(*args)
      self.assertEquals(val if is_present else None, ret_val)
      self.assertEquals(subreport if is_present else None, ret_subreport)

  def testCheckPresentIff(self):
    """Tests _CheckPresentIff()."""
    self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
        None, None, 'foo', 'bar', 'baz'))
    self.assertIsNone(checker.PayloadChecker._CheckPresentIff(
        'a', 'b', 'foo', 'bar', 'baz'))
    self.assertRaises(update_payload.PayloadError,
                      checker.PayloadChecker._CheckPresentIff,
                      'a', None, 'foo', 'bar', 'baz')
    self.assertRaises(update_payload.PayloadError,
                      checker.PayloadChecker._CheckPresentIff,
                      None, 'b', 'foo', 'bar', 'baz')

  def DoCheckSha256SignatureTest(self, expect_pass, expect_subprocess_call,
                                 sig_data, sig_asn1_header,
                                 returned_signed_hash, expected_signed_hash):
    """Parametric testing of _CheckSha256Signature().

    Args:
      expect_pass: Whether or not it should pass.
      expect_subprocess_call: Whether to expect the openssl call to happen.
      sig_data: The signature raw data.
      sig_asn1_header: The ASN1 header.
      returned_signed_hash: The signed hash data returned by openssl.
      expected_signed_hash: The signed hash data to compare against.
    """
    try:
      # Stub out the subprocess invocation.
      self.mox.StubOutWithMock(checker.PayloadChecker, '_Run')
      if expect_subprocess_call:
        checker.PayloadChecker._Run(
            mox.IsA(list), send_data=sig_data).AndReturn(
                (sig_asn1_header + returned_signed_hash, None))

      self.mox.ReplayAll()
      if expect_pass:
        self.assertIsNone(checker.PayloadChecker._CheckSha256Signature(
            sig_data, 'foo', expected_signed_hash, 'bar'))
      else:
        self.assertRaises(update_payload.PayloadError,
                          checker.PayloadChecker._CheckSha256Signature,
                          sig_data, 'foo', expected_signed_hash, 'bar')
    finally:
      self.mox.UnsetStubs()

  def testCheckSha256Signature_Pass(self):
    """Tests _CheckSha256Signature(); pass case."""
    sig_data = 'fake-signature'.ljust(256)
    signed_hash = hashlib.sha256('fake-data').digest()
    self.DoCheckSha256SignatureTest(True, True, sig_data,
                                    common.SIG_ASN1_HEADER, signed_hash,
                                    signed_hash)

  def testCheckSha256Signature_FailBadSignature(self):
    """Tests _CheckSha256Signature(); fails due to malformed signature."""
    sig_data = 'fake-signature'  # Malformed (not 256 bytes in length).
    signed_hash = hashlib.sha256('fake-data').digest()
    self.DoCheckSha256SignatureTest(False, False, sig_data,
                                    common.SIG_ASN1_HEADER, signed_hash,
                                    signed_hash)

  def testCheckSha256Signature_FailBadOutputLength(self):
    """Tests _CheckSha256Signature(); fails due to unexpected output length."""
    sig_data = 'fake-signature'.ljust(256)
    signed_hash = 'fake-hash'  # Malformed (not 32 bytes in length).
    self.DoCheckSha256SignatureTest(False, True, sig_data,
                                    common.SIG_ASN1_HEADER, signed_hash,
                                    signed_hash)

  def testCheckSha256Signature_FailBadAsnHeader(self):
    """Tests _CheckSha256Signature(); fails due to bad ASN1 header."""
    sig_data = 'fake-signature'.ljust(256)
    signed_hash = hashlib.sha256('fake-data').digest()
    bad_asn1_header = 'bad-asn-header'.ljust(len(common.SIG_ASN1_HEADER))
    self.DoCheckSha256SignatureTest(False, True, sig_data, bad_asn1_header,
                                    signed_hash, signed_hash)

  def testCheckSha256Signature_FailBadHash(self):
    """Tests _CheckSha256Signature(); fails due to bad hash returned."""
    sig_data = 'fake-signature'.ljust(256)
    expected_signed_hash = hashlib.sha256('fake-data').digest()
    returned_signed_hash = hashlib.sha256('bad-fake-data').digest()
    # Argument order matches DoCheckSha256SignatureTest's signature:
    # (..., returned_signed_hash, expected_signed_hash).
    self.DoCheckSha256SignatureTest(False, True, sig_data,
                                    common.SIG_ASN1_HEADER,
                                    returned_signed_hash, expected_signed_hash)

  def testCheckBlocksFitLength_Pass(self):
    """Tests _CheckBlocksFitLength(); pass case."""
    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
        64, 4, 16, 'foo'))
    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
        60, 4, 16, 'foo'))
    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
        49, 4, 16, 'foo'))
    self.assertIsNone(checker.PayloadChecker._CheckBlocksFitLength(
        48, 3, 16, 'foo'))

  def testCheckBlocksFitLength_TooManyBlocks(self):
    """Tests _CheckBlocksFitLength(); fails due to excess blocks."""
    self.assertRaises(update_payload.PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      64, 5, 16, 'foo')
    self.assertRaises(update_payload.PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      60, 5, 16, 'foo')
    self.assertRaises(update_payload.PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      49, 5, 16, 'foo')
    self.assertRaises(update_payload.PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      48, 4, 16, 'foo')

  def testCheckBlocksFitLength_TooFewBlocks(self):
    """Tests _CheckBlocksFitLength(); fails due to insufficient blocks."""
    self.assertRaises(update_payload.PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      64, 3, 16, 'foo')
    self.assertRaises(update_payload.PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      60, 3, 16, 'foo')
    self.assertRaises(update_payload.PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      49, 3, 16, 'foo')
    self.assertRaises(update_payload.PayloadError,
                      checker.PayloadChecker._CheckBlocksFitLength,
                      48, 2, 16, 'foo')

  def DoCheckManifestTest(self, fail_mismatched_block_size, fail_bad_sigs,
                          fail_mismatched_oki_ori, fail_bad_oki, fail_bad_ori,
                          fail_bad_nki, fail_bad_nri, fail_old_kernel_fs_size,
                          fail_old_rootfs_fs_size, fail_new_kernel_fs_size,
                          fail_new_rootfs_fs_size):
    """Parametric testing of _CheckManifest().

    Args:
      fail_mismatched_block_size: Simulate a missing block_size field.
      fail_bad_sigs: Make signatures descriptor inconsistent.
      fail_mismatched_oki_ori: Make old rootfs/kernel info partially present.
      fail_bad_oki: Tamper with old kernel info.
      fail_bad_ori: Tamper with old rootfs info.
      fail_bad_nki: Tamper with new kernel info.
      fail_bad_nri: Tamper with new rootfs info.
      fail_old_kernel_fs_size: Make old kernel fs size too big.
      fail_old_rootfs_fs_size: Make old rootfs fs size too big.
      fail_new_kernel_fs_size: Make new kernel fs size too big.
      fail_new_rootfs_fs_size: Make new rootfs fs size too big.
    """
    # Generate a test payload. For this test, we only care about the manifest
    # and don't need any data blobs, hence we can use a plain payload generator
    # (which also gives us more control on things that can be screwed up).
    payload_gen = test_utils.PayloadGenerator()

    # Tamper with block size, if required.
    if fail_mismatched_block_size:
      payload_gen.SetBlockSize(test_utils.KiB(1))
    else:
      payload_gen.SetBlockSize(test_utils.KiB(4))

    # Add some operations.
    payload_gen.AddOperation(False, common.OpType.MOVE,
                             src_extents=[(0, 16), (16, 497)],
                             dst_extents=[(16, 496), (0, 16)])
    payload_gen.AddOperation(True, common.OpType.MOVE,
                             src_extents=[(0, 8), (8, 8)],
                             dst_extents=[(8, 8), (0, 8)])

    # Set an invalid signatures block (offset but no size), if required.
    if fail_bad_sigs:
      payload_gen.SetSignatures(32, None)

    # Set partition / filesystem sizes.
    rootfs_part_size = test_utils.MiB(8)
    kernel_part_size = test_utils.KiB(512)
    old_rootfs_fs_size = new_rootfs_fs_size = rootfs_part_size
    old_kernel_fs_size = new_kernel_fs_size = kernel_part_size
    if fail_old_kernel_fs_size:
      old_kernel_fs_size += 100
    if fail_old_rootfs_fs_size:
      old_rootfs_fs_size += 100
    if fail_new_kernel_fs_size:
      new_kernel_fs_size += 100
    if fail_new_rootfs_fs_size:
      new_rootfs_fs_size += 100

    # Add old kernel/rootfs partition info, as required.
    if fail_mismatched_oki_ori or fail_old_kernel_fs_size or fail_bad_oki:
      oki_hash = (None if fail_bad_oki
                  else hashlib.sha256('fake-oki-content').digest())
      payload_gen.SetPartInfo(True, False, old_kernel_fs_size, oki_hash)
    if not fail_mismatched_oki_ori and (fail_old_rootfs_fs_size or
                                        fail_bad_ori):
      ori_hash = (None if fail_bad_ori
                  else hashlib.sha256('fake-ori-content').digest())
      payload_gen.SetPartInfo(False, False, old_rootfs_fs_size, ori_hash)

    # Add new kernel/rootfs partition info.
    payload_gen.SetPartInfo(
        True, True, new_kernel_fs_size,
        None if fail_bad_nki else hashlib.sha256('fake-nki-content').digest())
    payload_gen.SetPartInfo(
        False, True, new_rootfs_fs_size,
        None if fail_bad_nri else hashlib.sha256('fake-nri-content').digest())

    # Set the minor version.
    payload_gen.SetMinorVersion(0)

    # Create the test object.
    payload_checker = _GetPayloadChecker(payload_gen.WriteToFile)
    report = checker._PayloadReport()

    should_fail = (fail_mismatched_block_size or fail_bad_sigs or
                   fail_mismatched_oki_ori or fail_bad_oki or fail_bad_ori or
                   fail_bad_nki or fail_bad_nri or fail_old_kernel_fs_size or
                   fail_old_rootfs_fs_size or fail_new_kernel_fs_size or
                   fail_new_rootfs_fs_size)
    if should_fail:
      self.assertRaises(update_payload.PayloadError,
                        payload_checker._CheckManifest, report,
                        rootfs_part_size, kernel_part_size)
    else:
      self.assertIsNone(payload_checker._CheckManifest(report,
                                                       rootfs_part_size,
                                                       kernel_part_size))

  def testCheckLength(self):
    """Tests _CheckLength()."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    block_size = payload_checker.block_size

    # Passes.
    self.assertIsNone(payload_checker._CheckLength(
        int(3.5 * block_size), 4, 'foo', 'bar'))
    # Fails, too few blocks.
    self.assertRaises(update_payload.PayloadError,
                      payload_checker._CheckLength,
                      int(3.5 * block_size), 3, 'foo', 'bar')
    # Fails, too many blocks.
    self.assertRaises(update_payload.PayloadError,
                      payload_checker._CheckLength,
                      int(3.5 * block_size), 5, 'foo', 'bar')

  def testCheckExtents(self):
    """Tests _CheckExtents()."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    block_size = payload_checker.block_size

    # Passes w/ all real extents.
    extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
    self.assertEquals(
        23,
        payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
                                      collections.defaultdict(int), 'foo'))

    # Passes w/ pseudo-extents (aka sparse holes).
    extents = self.NewExtentList((0, 4), (common.PSEUDO_EXTENT_MARKER, 5),
                                 (8, 3))
    self.assertEquals(
        12,
        payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
                                      collections.defaultdict(int), 'foo',
                                      allow_pseudo=True))

    # Passes w/ pseudo-extent due to a signature.
    extents = self.NewExtentList((common.PSEUDO_EXTENT_MARKER, 2))
    self.assertEquals(
        2,
        payload_checker._CheckExtents(extents, (1024 + 16) * block_size,
                                      collections.defaultdict(int), 'foo',
                                      allow_signature=True))

    # Fails, extent missing a start block.
    extents = self.NewExtentList((-1, 4), (8, 3), (1024, 16))
    self.assertRaises(
        update_payload.PayloadError, payload_checker._CheckExtents,
        extents, (1024 + 16) * block_size, collections.defaultdict(int),
        'foo')

    # Fails, extent missing block count.
    extents = self.NewExtentList((0, -1), (8, 3), (1024, 16))
    self.assertRaises(
        update_payload.PayloadError, payload_checker._CheckExtents,
        extents, (1024 + 16) * block_size, collections.defaultdict(int),
        'foo')

    # Fails, extent has zero blocks.
    extents = self.NewExtentList((0, 4), (8, 3), (1024, 0))
    self.assertRaises(
        update_payload.PayloadError, payload_checker._CheckExtents,
        extents, (1024 + 16) * block_size, collections.defaultdict(int),
        'foo')

    # Fails, extent exceeds partition boundaries.
    extents = self.NewExtentList((0, 4), (8, 3), (1024, 16))
    self.assertRaises(
        update_payload.PayloadError, payload_checker._CheckExtents,
        extents, (1024 + 15) * block_size, collections.defaultdict(int),
        'foo')

  def testCheckReplaceOperation(self):
    """Tests _CheckReplaceOperation() where op.type == REPLACE."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    block_size = payload_checker.block_size
    data_length = 10000

    op = self.mox.CreateMock(
        update_metadata_pb2.InstallOperation)
    op.type = common.OpType.REPLACE

    # Pass.
    op.src_extents = []
    self.assertIsNone(
        payload_checker._CheckReplaceOperation(
            op, data_length, (data_length + block_size - 1) / block_size,
            'foo'))

    # Fail, src extents found.
    op.src_extents = ['bar']
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckReplaceOperation,
        op, data_length, (data_length + block_size - 1) / block_size, 'foo')

    # Fail, missing data.
    op.src_extents = []
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckReplaceOperation,
        op, None, (data_length + block_size - 1) / block_size, 'foo')

    # Fail, length / block number mismatch. src_extents is left empty so the
    # failure comes from the extra dst block, not from the src-extents check.
    op.src_extents = []
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckReplaceOperation,
        op, data_length, (data_length + block_size - 1) / block_size + 1, 'foo')

  def testCheckReplaceBzOperation(self):
    """Tests _CheckReplaceOperation() where op.type == REPLACE_BZ."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    block_size = payload_checker.block_size
    data_length = block_size * 3

    op = self.mox.CreateMock(
        update_metadata_pb2.InstallOperation)
    op.type = common.OpType.REPLACE_BZ

    # Pass.
    op.src_extents = []
    self.assertIsNone(
        payload_checker._CheckReplaceOperation(
            op, data_length, (data_length + block_size - 1) / block_size + 5,
            'foo'))

    # Fail, src extents found.
    op.src_extents = ['bar']
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckReplaceOperation,
        op, data_length, (data_length + block_size - 1) / block_size + 5, 'foo')

    # Fail, missing data.
    op.src_extents = []
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckReplaceOperation,
        op, None, (data_length + block_size - 1) / block_size, 'foo')

    # Fail, too few blocks to justify BZ.
    op.src_extents = []
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckReplaceOperation,
        op, data_length, (data_length + block_size - 1) / block_size, 'foo')

  def testCheckMoveOperation_Pass(self):
    """Tests _CheckMoveOperation(); pass case."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    op = update_metadata_pb2.InstallOperation()
    op.type = common.OpType.MOVE

    self.AddToMessage(op.src_extents,
                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    self.AddToMessage(op.dst_extents,
                      self.NewExtentList((16, 128), (512, 6)))
    self.assertIsNone(
        payload_checker._CheckMoveOperation(op, None, 134, 134, 'foo'))

  def testCheckMoveOperation_FailContainsData(self):
    """Tests _CheckMoveOperation(); fails, message contains data."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    op = update_metadata_pb2.InstallOperation()
    op.type = common.OpType.MOVE

    self.AddToMessage(op.src_extents,
                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    self.AddToMessage(op.dst_extents,
                      self.NewExtentList((16, 128), (512, 6)))
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckMoveOperation,
        op, 1024, 134, 134, 'foo')

  def testCheckMoveOperation_FailInsufficientSrcBlocks(self):
    """Tests _CheckMoveOperation(); fails, not enough actual src blocks."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    op = update_metadata_pb2.InstallOperation()
    op.type = common.OpType.MOVE

    self.AddToMessage(op.src_extents,
                      self.NewExtentList((1, 4), (12, 2), (1024, 127)))
    self.AddToMessage(op.dst_extents,
                      self.NewExtentList((16, 128), (512, 6)))
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckMoveOperation,
        op, None, 134, 134, 'foo')

  def testCheckMoveOperation_FailInsufficientDstBlocks(self):
    """Tests _CheckMoveOperation(); fails, not enough actual dst blocks."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    op = update_metadata_pb2.InstallOperation()
    op.type = common.OpType.MOVE

    self.AddToMessage(op.src_extents,
                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    self.AddToMessage(op.dst_extents,
                      self.NewExtentList((16, 128), (512, 5)))
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckMoveOperation,
        op, None, 134, 134, 'foo')

  def testCheckMoveOperation_FailExcessSrcBlocks(self):
    """Tests _CheckMoveOperation(); fails, too many actual src blocks."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    op = update_metadata_pb2.InstallOperation()
    op.type = common.OpType.MOVE

    # 135 actual src blocks against 134 dst blocks, with both totals declared
    # as 134. A single fresh operation is used because AddToMessage() appends
    # to whatever extents are already present.
    self.AddToMessage(op.src_extents,
                      self.NewExtentList((1, 4), (12, 2), (1024, 129)))
    self.AddToMessage(op.dst_extents,
                      self.NewExtentList((16, 128), (512, 6)))
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckMoveOperation,
        op, None, 134, 134, 'foo')

  def testCheckMoveOperation_FailExcessDstBlocks(self):
    """Tests _CheckMoveOperation(); fails, too many actual dst blocks."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    op = update_metadata_pb2.InstallOperation()
    op.type = common.OpType.MOVE

    self.AddToMessage(op.src_extents,
                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    self.AddToMessage(op.dst_extents,
                      self.NewExtentList((16, 128), (512, 7)))
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckMoveOperation,
        op, None, 134, 134, 'foo')

  def testCheckMoveOperation_FailStagnantBlocks(self):
    """Tests _CheckMoveOperation(); fails, there are blocks that do not move."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    op = update_metadata_pb2.InstallOperation()
    op.type = common.OpType.MOVE

    self.AddToMessage(op.src_extents,
                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    self.AddToMessage(op.dst_extents,
                      self.NewExtentList((8, 128), (512, 6)))
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckMoveOperation,
        op, None, 134, 134, 'foo')

  def testCheckMoveOperation_FailZeroStartBlock(self):
    """Tests _CheckMoveOperation(); fails, has extent with start block 0."""
    payload_checker = checker.PayloadChecker(self.MockPayload())

    # Source extent starting at block 0.
    op = update_metadata_pb2.InstallOperation()
    op.type = common.OpType.MOVE
    self.AddToMessage(op.src_extents,
                      self.NewExtentList((0, 4), (12, 2), (1024, 128)))
    self.AddToMessage(op.dst_extents,
                      self.NewExtentList((8, 128), (512, 6)))
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckMoveOperation,
        op, None, 134, 134, 'foo')

    # Destination extent starting at block 0. A fresh operation is required:
    # AddToMessage() appends, so reusing the previous op would leave the (0, 4)
    # src extent first and this dst case would never be reached.
    op = update_metadata_pb2.InstallOperation()
    op.type = common.OpType.MOVE
    self.AddToMessage(op.src_extents,
                      self.NewExtentList((1, 4), (12, 2), (1024, 128)))
    self.AddToMessage(op.dst_extents,
                      self.NewExtentList((0, 128), (512, 6)))
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckMoveOperation,
        op, None, 134, 134, 'foo')

  def testCheckAnyDiff(self):
    """Tests _CheckAnyDiffOperation()."""
    payload_checker = checker.PayloadChecker(self.MockPayload())

    # Pass.
    self.assertIsNone(
        payload_checker._CheckAnyDiffOperation(10000, 3, 'foo'))

    # Fail, missing data blob.
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckAnyDiffOperation,
        None, 3, 'foo')

    # Fail, too big of a diff blob (unjustified).
    self.assertRaises(
        update_payload.PayloadError,
        payload_checker._CheckAnyDiffOperation,
        10000, 2, 'foo')

  def testCheckSourceCopyOperation_Pass(self):
    """Tests _CheckSourceCopyOperation(); pass case."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    self.assertIsNone(
        payload_checker._CheckSourceCopyOperation(None, 134, 134, 'foo'))

  def testCheckSourceCopyOperation_FailContainsData(self):
    """Tests _CheckSourceCopyOperation(); message contains data."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    self.assertRaises(update_payload.PayloadError,
                      payload_checker._CheckSourceCopyOperation,
                      134, 0, 0, 'foo')

  def testCheckSourceCopyOperation_FailBlockCountsMismatch(self):
    """Tests _CheckSourceCopyOperation(); src and dst block totals not equal."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    self.assertRaises(update_payload.PayloadError,
                      payload_checker._CheckSourceCopyOperation,
                      None, 0, 1, 'foo')

  def DoCheckOperationTest(self, op_type_name, is_last, allow_signature,
                           allow_unhashed, fail_src_extents, fail_dst_extents,
                           fail_mismatched_data_offset_length,
                           fail_missing_dst_extents, fail_src_length,
                           fail_dst_length, fail_data_hash,
                           fail_prev_data_offset, fail_bad_minor_version):
    """Parametric testing of _CheckOperation().

    Args:
      op_type_name: 'REPLACE', 'REPLACE_BZ', 'MOVE', 'BSDIFF', 'SOURCE_COPY',
        or 'SOURCE_BSDIFF'.
      is_last: Whether we're testing the last operation in a sequence.
      allow_signature: Whether we're testing a signature-capable operation.
      allow_unhashed: Whether we're allowing to not hash the data.
      fail_src_extents: Tamper with src extents.
      fail_dst_extents: Tamper with dst extents.
      fail_mismatched_data_offset_length: Make data_{offset,length}
        inconsistent.
      fail_missing_dst_extents: Do not include dst extents.
      fail_src_length: Make src length inconsistent.
      fail_dst_length: Make dst length inconsistent.
      fail_data_hash: Tamper with the data blob hash.
      fail_prev_data_offset: Make data space uses non-contiguous.
      fail_bad_minor_version: Make minor version incompatible with op.
    """
    op_type = _OpTypeByName(op_type_name)

    # Create the test object.
    payload = self.MockPayload()
    payload_checker = checker.PayloadChecker(payload,
                                             allow_unhashed=allow_unhashed)
    block_size = payload_checker.block_size

    # Create auxiliary arguments.
    old_part_size = test_utils.MiB(4)
    new_part_size = test_utils.MiB(8)
    old_block_counters = array.array(
        'B', [0] * ((old_part_size + block_size - 1) / block_size))
    new_block_counters = array.array(
        'B', [0] * ((new_part_size + block_size - 1) / block_size))
    prev_data_offset = 1876
    blob_hash_counts = collections.defaultdict(int)

    # Create the operation object for the test.
    op = update_metadata_pb2.InstallOperation()
    op.type = op_type

    total_src_blocks = 0
    if op_type in (common.OpType.MOVE, common.OpType.BSDIFF,
                   common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF):
      if fail_src_extents:
        self.AddToMessage(op.src_extents,
                          self.NewExtentList((1, 0)))
      else:
        self.AddToMessage(op.src_extents,
                          self.NewExtentList((1, 16)))
        total_src_blocks = 16

    if op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ):
      payload_checker.minor_version = 0
    elif op_type in (common.OpType.MOVE, common.OpType.BSDIFF):
      payload_checker.minor_version = 2 if fail_bad_minor_version else 1
    elif op_type in (common.OpType.SOURCE_COPY, common.OpType.SOURCE_BSDIFF):
      payload_checker.minor_version = 1 if fail_bad_minor_version else 2

    if op_type not in (common.OpType.MOVE, common.OpType.SOURCE_COPY):
      if not fail_mismatched_data_offset_length:
        op.data_length = 16 * block_size - 8
      if fail_prev_data_offset:
        op.data_offset = prev_data_offset + 16
      else:
        op.data_offset = prev_data_offset

      fake_data = 'fake-data'.ljust(op.data_length)
      if not (allow_unhashed or (is_last and allow_signature and
                                 op_type == common.OpType.REPLACE)):
        if not fail_data_hash:
          # Create a valid data blob hash.
          op.data_sha256_hash = hashlib.sha256(fake_data).digest()
          payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
              fake_data)
      elif fail_data_hash:
        # Create an invalid data blob hash.
        op.data_sha256_hash = hashlib.sha256(
            fake_data.replace(' ', '-')).digest()
        payload.ReadDataBlob(op.data_offset, op.data_length).AndReturn(
            fake_data)

    total_dst_blocks = 0
    if not fail_missing_dst_extents:
      total_dst_blocks = 16
      if fail_dst_extents:
        self.AddToMessage(op.dst_extents,
                          self.NewExtentList((4, 16), (32, 0)))
      else:
        self.AddToMessage(op.dst_extents,
                          self.NewExtentList((4, 8), (64, 8)))

    if total_src_blocks:
      if fail_src_length:
        op.src_length = total_src_blocks * block_size + 8
      else:
        op.src_length = total_src_blocks * block_size
    elif fail_src_length:
      # Add an orphaned src_length.
      op.src_length = 16

    if total_dst_blocks:
      if fail_dst_length:
        op.dst_length = total_dst_blocks * block_size + 8
      else:
        op.dst_length = total_dst_blocks * block_size

    self.mox.ReplayAll()
    should_fail = (fail_src_extents or fail_dst_extents or
                   fail_mismatched_data_offset_length or
                   fail_missing_dst_extents or fail_src_length or
                   fail_dst_length or fail_data_hash or fail_prev_data_offset or
                   fail_bad_minor_version)
    args = (op, 'foo', is_last, old_block_counters, new_block_counters,
            old_part_size, new_part_size, prev_data_offset, allow_signature,
            blob_hash_counts)
    if should_fail:
      self.assertRaises(update_payload.PayloadError,
                        payload_checker._CheckOperation, *args)
    else:
      self.assertEqual(op.data_length if op.HasField('data_length') else 0,
                       payload_checker._CheckOperation(*args))

  def testAllocBlockCounters(self):
    """Tests _CheckMoveOperation()."""
    payload_checker = checker.PayloadChecker(self.MockPayload())
    block_size = payload_checker.block_size

    # Check allocation for block-aligned partition size, ensure it's integers.
946 result = payload_checker._AllocBlockCounters(16 * block_size) 947 self.assertEqual(16, len(result)) 948 self.assertEqual(int, type(result[0])) 949 950 # Check allocation of unaligned partition sizes. 951 result = payload_checker._AllocBlockCounters(16 * block_size - 1) 952 self.assertEqual(16, len(result)) 953 result = payload_checker._AllocBlockCounters(16 * block_size + 1) 954 self.assertEqual(17, len(result)) 955 956 def DoCheckOperationsTest(self, fail_nonexhaustive_full_update): 957 # Generate a test payload. For this test, we only care about one 958 # (arbitrary) set of operations, so we'll only be generating kernel and 959 # test with them. 960 payload_gen = test_utils.PayloadGenerator() 961 962 block_size = test_utils.KiB(4) 963 payload_gen.SetBlockSize(block_size) 964 965 rootfs_part_size = test_utils.MiB(8) 966 967 # Fake rootfs operations in a full update, tampered with as required. 968 rootfs_op_type = common.OpType.REPLACE 969 rootfs_data_length = rootfs_part_size 970 if fail_nonexhaustive_full_update: 971 rootfs_data_length -= block_size 972 973 payload_gen.AddOperation(False, rootfs_op_type, 974 dst_extents=[(0, rootfs_data_length / block_size)], 975 data_offset=0, 976 data_length=rootfs_data_length) 977 978 # Create the test object. 
979 payload_checker = _GetPayloadChecker(payload_gen.WriteToFile, 980 checker_init_dargs={ 981 'allow_unhashed': True}) 982 payload_checker.payload_type = checker._TYPE_FULL 983 report = checker._PayloadReport() 984 985 args = (payload_checker.payload.manifest.install_operations, report, 986 'foo', 0, rootfs_part_size, rootfs_part_size, 0, False) 987 if fail_nonexhaustive_full_update: 988 self.assertRaises(update_payload.PayloadError, 989 payload_checker._CheckOperations, *args) 990 else: 991 self.assertEqual(rootfs_data_length, 992 payload_checker._CheckOperations(*args)) 993 994 def DoCheckSignaturesTest(self, fail_empty_sigs_blob, fail_missing_pseudo_op, 995 fail_mismatched_pseudo_op, fail_sig_missing_fields, 996 fail_unknown_sig_version, fail_incorrect_sig): 997 # Generate a test payload. For this test, we only care about the signature 998 # block and how it relates to the payload hash. Therefore, we're generating 999 # a random (otherwise useless) payload for this purpose. 1000 payload_gen = test_utils.EnhancedPayloadGenerator() 1001 block_size = test_utils.KiB(4) 1002 payload_gen.SetBlockSize(block_size) 1003 rootfs_part_size = test_utils.MiB(2) 1004 kernel_part_size = test_utils.KiB(16) 1005 payload_gen.SetPartInfo(False, True, rootfs_part_size, 1006 hashlib.sha256('fake-new-rootfs-content').digest()) 1007 payload_gen.SetPartInfo(True, True, kernel_part_size, 1008 hashlib.sha256('fake-new-kernel-content').digest()) 1009 payload_gen.SetMinorVersion(0) 1010 payload_gen.AddOperationWithData( 1011 False, common.OpType.REPLACE, 1012 dst_extents=[(0, rootfs_part_size / block_size)], 1013 data_blob=os.urandom(rootfs_part_size)) 1014 1015 do_forge_pseudo_op = (fail_missing_pseudo_op or fail_mismatched_pseudo_op) 1016 do_forge_sigs_data = (do_forge_pseudo_op or fail_empty_sigs_blob or 1017 fail_sig_missing_fields or fail_unknown_sig_version 1018 or fail_incorrect_sig) 1019 1020 sigs_data = None 1021 if do_forge_sigs_data: 1022 sigs_gen = 
test_utils.SignaturesGenerator() 1023 if not fail_empty_sigs_blob: 1024 if fail_sig_missing_fields: 1025 sig_data = None 1026 else: 1027 sig_data = test_utils.SignSha256('fake-payload-content', 1028 test_utils._PRIVKEY_FILE_NAME) 1029 sigs_gen.AddSig(5 if fail_unknown_sig_version else 1, sig_data) 1030 1031 sigs_data = sigs_gen.ToBinary() 1032 payload_gen.SetSignatures(payload_gen.curr_offset, len(sigs_data)) 1033 1034 if do_forge_pseudo_op: 1035 assert sigs_data is not None, 'should have forged signatures blob by now' 1036 sigs_len = len(sigs_data) 1037 payload_gen.AddOperation( 1038 False, common.OpType.REPLACE, 1039 data_offset=payload_gen.curr_offset / 2, 1040 data_length=sigs_len / 2, 1041 dst_extents=[(0, (sigs_len / 2 + block_size - 1) / block_size)]) 1042 1043 # Generate payload (complete w/ signature) and create the test object. 1044 payload_checker = _GetPayloadChecker( 1045 payload_gen.WriteToFileWithData, 1046 payload_gen_dargs={ 1047 'sigs_data': sigs_data, 1048 'privkey_file_name': test_utils._PRIVKEY_FILE_NAME, 1049 'do_add_pseudo_operation': not do_forge_pseudo_op}) 1050 payload_checker.payload_type = checker._TYPE_FULL 1051 report = checker._PayloadReport() 1052 1053 # We have to check the manifest first in order to set signature attributes. 1054 payload_checker._CheckManifest(report, rootfs_part_size, kernel_part_size) 1055 1056 should_fail = (fail_empty_sigs_blob or fail_missing_pseudo_op or 1057 fail_mismatched_pseudo_op or fail_sig_missing_fields or 1058 fail_unknown_sig_version or fail_incorrect_sig) 1059 args = (report, test_utils._PUBKEY_FILE_NAME) 1060 if should_fail: 1061 self.assertRaises(update_payload.PayloadError, 1062 payload_checker._CheckSignatures, *args) 1063 else: 1064 self.assertIsNone(payload_checker._CheckSignatures(*args)) 1065 1066 def DoCheckManifestMinorVersionTest(self, minor_version, payload_type): 1067 """Parametric testing for CheckManifestMinorVersion(). 
1068 1069 Args: 1070 minor_version: The payload minor version to test with. 1071 payload_type: The type of the payload we're testing, delta or full. 1072 """ 1073 # Create the test object. 1074 payload = self.MockPayload() 1075 payload.manifest.minor_version = minor_version 1076 payload_checker = checker.PayloadChecker(payload) 1077 payload_checker.payload_type = payload_type 1078 report = checker._PayloadReport() 1079 1080 should_succeed = ( 1081 (minor_version == 0 and payload_type == checker._TYPE_FULL) or 1082 (minor_version == 1 and payload_type == checker._TYPE_DELTA) or 1083 (minor_version == 2 and payload_type == checker._TYPE_DELTA) or 1084 (minor_version == 3 and payload_type == checker._TYPE_DELTA) or 1085 (minor_version == 4 and payload_type == checker._TYPE_DELTA)) 1086 args = (report,) 1087 1088 if should_succeed: 1089 self.assertIsNone(payload_checker._CheckManifestMinorVersion(*args)) 1090 else: 1091 self.assertRaises(update_payload.PayloadError, 1092 payload_checker._CheckManifestMinorVersion, *args) 1093 1094 def DoRunTest(self, rootfs_part_size_provided, kernel_part_size_provided, 1095 fail_wrong_payload_type, fail_invalid_block_size, 1096 fail_mismatched_block_size, fail_excess_data, 1097 fail_rootfs_part_size_exceeded, 1098 fail_kernel_part_size_exceeded): 1099 # Generate a test payload. For this test, we generate a full update that 1100 # has sample kernel and rootfs operations. Since most testing is done with 1101 # internal PayloadChecker methods that are tested elsewhere, here we only 1102 # tamper with what's actually being manipulated and/or tested in the Run() 1103 # method itself. Note that the checker doesn't verify partition hashes, so 1104 # they're safe to fake. 
1105 payload_gen = test_utils.EnhancedPayloadGenerator() 1106 block_size = test_utils.KiB(4) 1107 payload_gen.SetBlockSize(block_size) 1108 kernel_filesystem_size = test_utils.KiB(16) 1109 rootfs_filesystem_size = test_utils.MiB(2) 1110 payload_gen.SetPartInfo(False, True, rootfs_filesystem_size, 1111 hashlib.sha256('fake-new-rootfs-content').digest()) 1112 payload_gen.SetPartInfo(True, True, kernel_filesystem_size, 1113 hashlib.sha256('fake-new-kernel-content').digest()) 1114 payload_gen.SetMinorVersion(0) 1115 1116 rootfs_part_size = 0 1117 if rootfs_part_size_provided: 1118 rootfs_part_size = rootfs_filesystem_size + block_size 1119 rootfs_op_size = rootfs_part_size or rootfs_filesystem_size 1120 if fail_rootfs_part_size_exceeded: 1121 rootfs_op_size += block_size 1122 payload_gen.AddOperationWithData( 1123 False, common.OpType.REPLACE, 1124 dst_extents=[(0, rootfs_op_size / block_size)], 1125 data_blob=os.urandom(rootfs_op_size)) 1126 1127 kernel_part_size = 0 1128 if kernel_part_size_provided: 1129 kernel_part_size = kernel_filesystem_size + block_size 1130 kernel_op_size = kernel_part_size or kernel_filesystem_size 1131 if fail_kernel_part_size_exceeded: 1132 kernel_op_size += block_size 1133 payload_gen.AddOperationWithData( 1134 True, common.OpType.REPLACE, 1135 dst_extents=[(0, kernel_op_size / block_size)], 1136 data_blob=os.urandom(kernel_op_size)) 1137 1138 # Generate payload (complete w/ signature) and create the test object. 1139 if fail_invalid_block_size: 1140 use_block_size = block_size + 5 # Not a power of two. 1141 elif fail_mismatched_block_size: 1142 use_block_size = block_size * 2 # Different that payload stated. 
1143 else: 1144 use_block_size = block_size 1145 1146 kwargs = { 1147 'payload_gen_dargs': { 1148 'privkey_file_name': test_utils._PRIVKEY_FILE_NAME, 1149 'do_add_pseudo_operation': True, 1150 'is_pseudo_in_kernel': True, 1151 'padding': os.urandom(1024) if fail_excess_data else None}, 1152 'checker_init_dargs': { 1153 'assert_type': 'delta' if fail_wrong_payload_type else 'full', 1154 'block_size': use_block_size}} 1155 if fail_invalid_block_size: 1156 self.assertRaises(update_payload.PayloadError, _GetPayloadChecker, 1157 payload_gen.WriteToFileWithData, **kwargs) 1158 else: 1159 payload_checker = _GetPayloadChecker(payload_gen.WriteToFileWithData, 1160 **kwargs) 1161 1162 kwargs = {'pubkey_file_name': test_utils._PUBKEY_FILE_NAME, 1163 'rootfs_part_size': rootfs_part_size, 1164 'kernel_part_size': kernel_part_size} 1165 should_fail = (fail_wrong_payload_type or fail_mismatched_block_size or 1166 fail_excess_data or 1167 fail_rootfs_part_size_exceeded or 1168 fail_kernel_part_size_exceeded) 1169 if should_fail: 1170 self.assertRaises(update_payload.PayloadError, payload_checker.Run, 1171 **kwargs) 1172 else: 1173 self.assertIsNone(payload_checker.Run(**kwargs)) 1174 1175 # This implements a generic API, hence the occasional unused args. 1176 # pylint: disable=W0613 1177 def ValidateCheckOperationTest(op_type_name, is_last, allow_signature, 1178 allow_unhashed, fail_src_extents, 1179 fail_dst_extents, 1180 fail_mismatched_data_offset_length, 1181 fail_missing_dst_extents, fail_src_length, 1182 fail_dst_length, fail_data_hash, 1183 fail_prev_data_offset, fail_bad_minor_version): 1184 """Returns True iff the combination of arguments represents a valid test.""" 1185 op_type = _OpTypeByName(op_type_name) 1186 1187 # REPLACE/REPLACE_BZ operations don't read data from src partition. They are 1188 # compatible with all valid minor versions, so we don't need to check that. 
1189 if (op_type in (common.OpType.REPLACE, common.OpType.REPLACE_BZ) and ( 1190 fail_src_extents or fail_src_length or fail_bad_minor_version)): 1191 return False 1192 1193 # MOVE and SOURCE_COPY operations don't carry data. 1194 if (op_type in (common.OpType.MOVE, common.OpType.SOURCE_COPY) and ( 1195 fail_mismatched_data_offset_length or fail_data_hash or 1196 fail_prev_data_offset)): 1197 return False 1198 1199 return True 1200 1201 1202 def TestMethodBody(run_method_name, run_dargs): 1203 """Returns a function that invokes a named method with named arguments.""" 1204 return lambda self: getattr(self, run_method_name)(**run_dargs) 1205 1206 1207 def AddParametricTests(tested_method_name, arg_space, validate_func=None): 1208 """Enumerates and adds specific parametric tests to PayloadCheckerTest. 1209 1210 This function enumerates a space of test parameters (defined by arg_space), 1211 then binds a new, unique method name in PayloadCheckerTest to a test function 1212 that gets handed the said parameters. This is a preferable approach to doing 1213 the enumeration and invocation during the tests because this way each test is 1214 treated as a complete run by the unittest framework, and so benefits from the 1215 usual setUp/tearDown mechanics. 1216 1217 Args: 1218 tested_method_name: Name of the tested PayloadChecker method. 1219 arg_space: A dictionary containing variables (keys) and lists of values 1220 (values) associated with them. 1221 validate_func: A function used for validating test argument combinations. 
1222 """ 1223 for value_tuple in itertools.product(*arg_space.itervalues()): 1224 run_dargs = dict(zip(arg_space.iterkeys(), value_tuple)) 1225 if validate_func and not validate_func(**run_dargs): 1226 continue 1227 run_method_name = 'Do%sTest' % tested_method_name 1228 test_method_name = 'test%s' % tested_method_name 1229 for arg_key, arg_val in run_dargs.iteritems(): 1230 if arg_val or type(arg_val) is int: 1231 test_method_name += '__%s=%s' % (arg_key, arg_val) 1232 setattr(PayloadCheckerTest, test_method_name, 1233 TestMethodBody(run_method_name, run_dargs)) 1234 1235 1236 def AddAllParametricTests(): 1237 """Enumerates and adds all parametric tests to PayloadCheckerTest.""" 1238 # Add all _CheckElem() test cases. 1239 AddParametricTests('AddElem', 1240 {'linebreak': (True, False), 1241 'indent': (0, 1, 2), 1242 'convert': (str, lambda s: s[::-1]), 1243 'is_present': (True, False), 1244 'is_mandatory': (True, False), 1245 'is_submsg': (True, False)}) 1246 1247 # Add all _Add{Mandatory,Optional}Field tests. 1248 AddParametricTests('AddField', 1249 {'is_mandatory': (True, False), 1250 'linebreak': (True, False), 1251 'indent': (0, 1, 2), 1252 'convert': (str, lambda s: s[::-1]), 1253 'is_present': (True, False)}) 1254 1255 # Add all _Add{Mandatory,Optional}SubMsg tests. 1256 AddParametricTests('AddSubMsg', 1257 {'is_mandatory': (True, False), 1258 'is_present': (True, False)}) 1259 1260 # Add all _CheckManifest() test cases. 
1261 AddParametricTests('CheckManifest', 1262 {'fail_mismatched_block_size': (True, False), 1263 'fail_bad_sigs': (True, False), 1264 'fail_mismatched_oki_ori': (True, False), 1265 'fail_bad_oki': (True, False), 1266 'fail_bad_ori': (True, False), 1267 'fail_bad_nki': (True, False), 1268 'fail_bad_nri': (True, False), 1269 'fail_old_kernel_fs_size': (True, False), 1270 'fail_old_rootfs_fs_size': (True, False), 1271 'fail_new_kernel_fs_size': (True, False), 1272 'fail_new_rootfs_fs_size': (True, False)}) 1273 1274 # Add all _CheckOperation() test cases. 1275 AddParametricTests('CheckOperation', 1276 {'op_type_name': ('REPLACE', 'REPLACE_BZ', 'MOVE', 1277 'BSDIFF', 'SOURCE_COPY', 1278 'SOURCE_BSDIFF'), 1279 'is_last': (True, False), 1280 'allow_signature': (True, False), 1281 'allow_unhashed': (True, False), 1282 'fail_src_extents': (True, False), 1283 'fail_dst_extents': (True, False), 1284 'fail_mismatched_data_offset_length': (True, False), 1285 'fail_missing_dst_extents': (True, False), 1286 'fail_src_length': (True, False), 1287 'fail_dst_length': (True, False), 1288 'fail_data_hash': (True, False), 1289 'fail_prev_data_offset': (True, False), 1290 'fail_bad_minor_version': (True, False)}, 1291 validate_func=ValidateCheckOperationTest) 1292 1293 # Add all _CheckOperations() test cases. 1294 AddParametricTests('CheckOperations', 1295 {'fail_nonexhaustive_full_update': (True, False)}) 1296 1297 # Add all _CheckOperations() test cases. 1298 AddParametricTests('CheckSignatures', 1299 {'fail_empty_sigs_blob': (True, False), 1300 'fail_missing_pseudo_op': (True, False), 1301 'fail_mismatched_pseudo_op': (True, False), 1302 'fail_sig_missing_fields': (True, False), 1303 'fail_unknown_sig_version': (True, False), 1304 'fail_incorrect_sig': (True, False)}) 1305 1306 # Add all _CheckManifestMinorVersion() test cases. 
1307 AddParametricTests('CheckManifestMinorVersion', 1308 {'minor_version': (None, 0, 1, 2, 3, 4, 555), 1309 'payload_type': (checker._TYPE_FULL, 1310 checker._TYPE_DELTA)}) 1311 1312 # Add all Run() test cases. 1313 AddParametricTests('Run', 1314 {'rootfs_part_size_provided': (True, False), 1315 'kernel_part_size_provided': (True, False), 1316 'fail_wrong_payload_type': (True, False), 1317 'fail_invalid_block_size': (True, False), 1318 'fail_mismatched_block_size': (True, False), 1319 'fail_excess_data': (True, False), 1320 'fail_rootfs_part_size_exceeded': (True, False), 1321 'fail_kernel_part_size_exceeded': (True, False)}) 1322 1323 1324 if __name__ == '__main__': 1325 AddAllParametricTests() 1326 unittest.main() 1327