Home | History | Annotate | Download | only in mbim_compliance
      1 # Copyright (c) 2015 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 import array
      6 import logging
      7 import mox
      8 import multiprocessing
      9 import struct
     10 import unittest
     11 
     12 import common
     13 from autotest_lib.client.cros.cellular.mbim_compliance import mbim_channel
     14 from autotest_lib.client.cros.cellular.mbim_compliance import mbim_errors
     15 
     16 
     17 class MBIMChannelTestCase(unittest.TestCase):
     18     """ Test cases for the MBIMChannel class. """
     19 
     20     def setUp(self):
     21         # Arguments passed to MBIMChannel. Irrelevant for these tests, mostly.
     22         self._device = None
     23         self._interface_number = 0
     24         self._interrupt_endpoint_address = 0x01
     25         self._in_buffer_size = 100
     26 
     27         self._setup_mock_subprocess()
     28         self._mox = mox.Mox()
     29 
     30         # Reach into |MBIMChannel| and mock out the request queue, so we can set
     31         # expectations on it.
     32         # |multiprocessing.Queue| is actually a function that returns some
     33         # hidden |multiprocessing.queues.Queue| class. We'll grab the class from
     34         # a temporary object so we can mock it.
     35         some_queue = multiprocessing.Queue()
     36         queue_class = some_queue.__class__
     37         self._mock_request_queue = self._mox.CreateMock(queue_class)
     38         self._channel._request_queue = self._mock_request_queue
     39 
     40         # On the other hand, just grab the real response queue.
     41         self._response_queue = self._channel._response_queue
     42 
     43         # Decrease timeouts to small values to speed up tests.
     44         self._channel.FRAGMENT_TIMEOUT_S = 0.2
     45         self._channel.TRANSACTION_TIMEOUT_S = 0.5
     46 
     47 
     48     def tearDown(self):
     49         self._channel.close()
     50         self._subprocess_mox.VerifyAll()
     51 
     52 
     53     def _setup_mock_subprocess(self):
     54         """
     55         Setup long-term expectations on the mocked out subprocess.
     56 
     57         These expectations are only met when |self._channel.close| is called in
     58         |tearDown|.
     59 
     60         """
     61         self._subprocess_mox = mox.Mox()
     62         mock_process = self._subprocess_mox.CreateMock(multiprocessing.Process)
     63         mock_process(target=mox.IgnoreArg(),
     64                      args=mox.IgnoreArg()).AndReturn(mock_process)
     65         mock_process.start()
     66 
     67         # Each API call into MBIMChannel results in an aliveness ping to the
     68         # subprocess.
     69         # Finally, when |self._channel| is destructed, it will attempt to
     70         # terminate the |mock_process|, with increasingly drastic actions.
     71         mock_process.is_alive().MultipleTimes().AndReturn(True)
     72         mock_process.join(mox.IgnoreArg())
     73         mock_process.is_alive().AndReturn(True)
     74         mock_process.terminate()
     75 
     76         self._subprocess_mox.ReplayAll()
     77         self._channel = mbim_channel.MBIMChannel(
     78                 self._device,
     79                 self._interface_number,
     80                 self._interrupt_endpoint_address,
     81                 self._in_buffer_size,
     82                 mock_process)
     83 
     84 
     85     def test_creation(self):
     86         """ A trivial test that we mocked out the |Process| class correctly. """
     87         pass
     88 
     89 
     90     def test_unfragmented_packet_successful(self):
     91         """ Test that we can synchronously send an unfragmented packet. """
     92         packet = self._get_unfragmented_packet(1)
     93         response_packet = self._get_unfragmented_packet(1)
     94         self._expect_transaction([packet], [response_packet])
     95         self._verify_transaction_successful([packet], [response_packet])
     96 
     97 
     98     def test_unfragmented_packet_timeout(self):
     99         """ Test the case when an unfragmented packet receives no response. """
    100         packet = self._get_unfragmented_packet(1)
    101         self._expect_transaction([packet])
    102         self._verify_transaction_failed([packet])
    103 
    104 
    105     def test_single_fragment_successful(self):
    106         """ Test that we can synchronously send a fragmented packet. """
    107         fragment = self._get_fragment(1, 1, 0)
    108         response_fragment = self._get_fragment(1, 1, 0)
    109         self._expect_transaction([fragment], [response_fragment])
    110         self._verify_transaction_successful([fragment], [response_fragment])
    111 
    112 
    113     def test_single_fragment_timeout(self):
    114         """ Test the case when a fragmented packet receives no response. """
    115         fragment = self._get_fragment(1, 1, 0)
    116         self._expect_transaction([fragment])
    117         self._verify_transaction_failed([fragment])
    118 
    119 
    120     def test_single_fragment_corrupted_reply(self):
    121         """ Test the case when the response has a corrupted fragment header. """
    122         fragment = self._get_fragment(1, 1, 0)
    123         response_fragment = self._get_fragment(1, 1, 0)
    124         response_fragment = response_fragment[:len(response_fragment)-1]
    125         self._expect_transaction([fragment], [response_fragment])
    126         self._verify_transaction_failed([fragment])
    127 
    128 
    129     def test_multiple_fragments_successful(self):
    130         """ Test that we can send/recieve multi-fragment packets. """
    131         fragment_0 = self._get_fragment(1, 2, 0)
    132         fragment_1 = self._get_fragment(1, 2, 1)
    133         response_fragment_0 = self._get_fragment(1, 2, 0)
    134         response_fragment_1 = self._get_fragment(1, 2, 1)
    135         self._expect_transaction([fragment_0, fragment_1],
    136                                  [response_fragment_0, response_fragment_1])
    137         self._verify_transaction_successful(
    138                 [fragment_0, fragment_1],
    139                 [response_fragment_0, response_fragment_1])
    140 
    141 
    142     def test_multiple_fragments_incorrect_total_fragments(self):
    143         """ Test the case when one of the fragment reports incorrect total. """
    144         fragment = self._get_fragment(1, 1, 0)
    145         response_fragment_0 = self._get_fragment(1, 2, 0)
    146         # total_fragment should have been 2, but is 99.
    147         response_fragment_1 = self._get_fragment(1, 99, 1)
    148         self._expect_transaction([fragment],
    149                                  [response_fragment_0, response_fragment_1])
    150         self._verify_transaction_failed([fragment])
    151 
    152 
    153     def test_multiple_fragments_reordered_reply_1(self):
    154         """ Test the case when the first fragemnt reports incorrect index. """
    155         fragment = self._get_fragment(1, 1, 0)
    156         # Incorrect first fragment number.
    157         response_fragment = self._get_fragment(1, 2, 1)
    158         self._expect_transaction([fragment], [response_fragment])
    159         self._verify_transaction_failed([fragment])
    160 
    161 
    162     def test_multiple_fragments_reordered_reply_2(self):
    163         """ Test the case when a follow up fragment reports incorrect index. """
    164         fragment = self._get_fragment(1, 1, 0)
    165         response_fragment_0 = self._get_fragment(1, 2, 0)
    166         # Incorrect second fragment number.
    167         response_fragment_1 = self._get_fragment(1, 2, 99)
    168         self._expect_transaction([fragment],
    169                                  [response_fragment_0, response_fragment_1])
    170         self._verify_transaction_failed([fragment])
    171 
    172 
    173     def test_multiple_fragments_insufficient_reply_timeout(self):
    174         """ Test the case when we recieve only part of the response. """
    175         fragment = self._get_fragment(1, 1, 0)
    176         # The second fragment will never arrive.
    177         response_fragment_0 = self._get_fragment(1, 2, 0)
    178         self._expect_transaction([fragment], [response_fragment_0])
    179         self._verify_transaction_successful([fragment], [response_fragment_0])
    180 
    181 
    182     def test_unfragmented_packet_notification(self):
    183         """ Test the case when a notification comes before the response. """
    184         packet = self._get_unfragmented_packet(1)
    185         response = self._get_unfragmented_packet(1)
    186         notification = self._get_unfragmented_packet(0)
    187         self._expect_transaction([packet], [notification, response])
    188         self._verify_transaction_successful([packet], [response])
    189         self.assertEqual([[notification]],
    190                          self._channel.get_outstanding_packets())
    191 
    192 
    193     def test_fragmented_notification(self):
    194         """ Test the case when a fragmented notification preceeds response. """
    195         packet_fragment_0 = self._get_fragment(1, 2, 0)
    196         packet_fragment_1 = self._get_fragment(1, 2, 1)
    197         response_fragment_0 = self._get_fragment(1, 2, 0)
    198         response_fragment_1 = self._get_fragment(1, 2, 1)
    199         notification_0_fragment_0 = self._get_fragment(0, 2, 0)
    200         notification_0_fragment_1 = self._get_fragment(0, 2, 1)
    201         notification_1_fragment_0 = self._get_fragment(99, 2, 0)
    202         notification_1_fragment_1 = self._get_fragment(99, 2, 1)
    203 
    204         self._expect_transaction(
    205                 [packet_fragment_0, packet_fragment_1],
    206                 [notification_0_fragment_0, notification_0_fragment_1,
    207                  notification_1_fragment_0, notification_1_fragment_1,
    208                  response_fragment_0, response_fragment_1])
    209         self._verify_transaction_successful(
    210                 [packet_fragment_0, packet_fragment_1],
    211                 [response_fragment_0, response_fragment_1])
    212         self.assertEqual(
    213                 [[notification_0_fragment_0, notification_0_fragment_1],
    214                  [notification_1_fragment_0, notification_1_fragment_1]],
    215                 self._channel.get_outstanding_packets())
    216 
    217 
    218     def test_multiple_packets_rollover_notification(self):
    219         """
    220         Test the case when we receive incomplete response, followed by
    221         fragmented notifications.
    222 
    223         We have to be smart enough to realize that the incorrect fragment
    224         recieved at the end of the response belongs to the next notification
    225         instead.
    226 
    227         """
    228         packet = self._get_fragment(1, 1, 0)
    229         # The second fragment never comes, instead we get a notification
    230         # fragment.
    231         response_fragment_0 = self._get_fragment(1, 2, 0)
    232         notification_0_fragment_0 = self._get_fragment(0, 2, 0)
    233         notification_0_fragment_1 = self._get_fragment(0, 2, 1)
    234         notification_1_fragment_0 = self._get_fragment(99, 2, 0)
    235         notification_1_fragment_1 = self._get_fragment(99, 2, 1)
    236 
    237         self._expect_transaction(
    238                 [packet],
    239                 [response_fragment_0,
    240                  notification_0_fragment_0, notification_0_fragment_1,
    241                  notification_1_fragment_0, notification_1_fragment_1])
    242         self._verify_transaction_successful(
    243                 [packet],
    244                 [response_fragment_0])
    245         self.assertEqual(
    246                 [[notification_0_fragment_0, notification_0_fragment_1],
    247                  [notification_1_fragment_0, notification_1_fragment_1]],
    248                 self._channel.get_outstanding_packets())
    249 
    250 
    251     def test_data(self):
    252         """ Test that data is transferred transaperntly. """
    253         packet = self._get_unfragmented_packet(1)
    254         packet.fromlist([0xFF, 0xFF, 0xFF, 0xFF, 0xDD, 0xDD, 0xDD, 0xDD])
    255         response_packet = self._get_unfragmented_packet(1)
    256         response_packet.fromlist([0xAA, 0xAA, 0xBB, 0xBB])
    257         self._expect_transaction([packet], [response_packet])
    258         self._verify_transaction_successful([packet], [response_packet])
    259 
    260 
    261     def test_flush_successful(self):
    262         """ Test that flush clears all queues. """
    263         packet = self._get_unfragmented_packet(1)
    264         response = self._get_unfragmented_packet(1)
    265         notification_1 = self._get_fragment(0, 1, 0)
    266         self._response_queue.put_nowait(notification_1)
    267         self._mock_request_queue.qsize().AndReturn(1)
    268         self._mock_request_queue.empty().AndReturn(False)
    269         self._mock_request_queue.empty().WithSideEffects(
    270                 self._response_queue.put_nowait(response)).AndReturn(True)
    271         self._mox.ReplayAll()
    272         self._channel.flush()
    273         self._mox.VerifyAll()
    274         self.assertEqual(0, self._response_queue.qsize())
    275 
    276 
    277     def test_flush_failed(self):
    278         """ Test the case when the request queue fails to empty out. """
    279         packet = self._get_unfragmented_packet(1)
    280         self._mock_request_queue.qsize().AndReturn(1)
    281         self._mock_request_queue.empty().MultipleTimes().AndReturn(False)
    282         self._mox.ReplayAll()
    283         self.assertRaises(
    284                 mbim_errors.MBIMComplianceChannelError,
    285                 self._channel.flush)
    286         self._mox.VerifyAll()
    287 
    288 
    289     def _queue_responses(self, responses):
    290         """ Helper method for |_expect_transaction|. Do not use directly. """
    291         for response in responses:
    292             self._response_queue.put_nowait(response)
    293 
    294 
    295     def _expect_transaction(self, requests, responses=None):
    296         """
    297         Helper method to setup expectations on the queues.
    298 
    299         @param requests: A list of packets to expect on the |_request_queue|.
    300         @param respones: An optional list of packets to respond with after the
    301                 last request.
    302 
    303         """
    304 
    305         last_request = requests[len(requests) - 1]
    306         earlier_requests = requests[:len(requests) - 1]
    307         for request in earlier_requests:
    308             self._mock_request_queue.put_nowait(request)
    309         if responses:
    310             self._mock_request_queue.put_nowait(last_request).WithSideEffects(
    311                     lambda _: self._queue_responses(responses))
    312         else:
    313             self._mock_request_queue.put_nowait(last_request)
    314 
    315 
    316     def _verify_transaction_successful(self, requests, responses):
    317         """
    318         Helper method to assert that the transaction was successful.
    319 
    320         @param requests: List of packets sent.
    321         @param responses: List of packets expected back.
    322         """
    323         self._mox.ReplayAll()
    324         self.assertEqual(responses,
    325                          self._channel.bidirectional_transaction(*requests))
    326         self._mox.VerifyAll()
    327 
    328 
    329     def _verify_transaction_failed(self, requests):
    330         """
    331         Helper method to assert that the transaction failed.
    332 
    333         @param requests: List of packets sent.
    334 
    335         """
    336         self._mox.ReplayAll()
    337         self.assertRaises(mbim_errors.MBIMComplianceChannelError,
    338                           self._channel.bidirectional_transaction,
    339                           *requests)
    340         self._mox.VerifyAll()
    341 
    342 
    343     def _get_unfragmented_packet(self, transaction_id):
    344         """ Creates a packet that has no fragment header. """
    345         packet_format = '<LLL' # This does not contain a fragment header.
    346         packet = self._create_buffer(struct.calcsize(packet_format))
    347         struct.pack_into(packet_format,
    348                          packet,
    349                          0,
    350                          0x00000000,  # 0x0 does not need fragments.
    351                          struct.calcsize(packet_format),
    352                          transaction_id)
    353         return packet
    354 
    355 
    356     def _get_fragment(self, transaction_id, total_fragments, current_fragment):
    357         """ Creates a fragment with the given fields. """
    358         fragment_header_format = '<LLLLL'
    359         message_type = 0x00000003  # MBIM_COMMAND_MSG has fragments.
    360         fragment = self._create_buffer(struct.calcsize(fragment_header_format))
    361         struct.pack_into(fragment_header_format,
    362                          fragment,
    363                          0,
    364                          message_type,
    365                          struct.calcsize(fragment_header_format),
    366                          transaction_id,
    367                          total_fragments,
    368                          current_fragment)
    369         return fragment
    370 
    371 
    372     def _create_buffer(self, size):
    373         """ Create an array of the give size initialized to 0x00. """
    374         return array.array('B', '\x00' * size)
    375 
    376 
    377 if __name__ == '__main__':
    378     logging.basicConfig(level=logging.DEBUG)
    379     unittest.main()
    380