Home | History | Annotate | Download | only in test
      1 #!/usr/bin/env python
      2 #
      3 # Copyright 2012, Google Inc.
      4 # All rights reserved.
      5 #
      6 # Redistribution and use in source and binary forms, with or without
      7 # modification, are permitted provided that the following conditions are
      8 # met:
      9 #
     10 #     * Redistributions of source code must retain the above copyright
     11 # notice, this list of conditions and the following disclaimer.
     12 #     * Redistributions in binary form must reproduce the above
     13 # copyright notice, this list of conditions and the following disclaimer
     14 # in the documentation and/or other materials provided with the
     15 # distribution.
     16 #     * Neither the name of Google Inc. nor the names of its
     17 # contributors may be used to endorse or promote products derived from
     18 # this software without specific prior written permission.
     19 #
     20 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     21 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     22 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     23 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     24 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     25 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     26 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     27 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     28 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     29 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     30 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     31 
     32 
     33 """Tests for msgutil module."""
     34 
     35 
     36 import array
     37 import Queue
     38 import struct
     39 import unittest
     40 import zlib
     41 
     42 import set_sys_path  # Update sys.path to locate mod_pywebsocket module.
     43 
     44 from mod_pywebsocket import common
     45 from mod_pywebsocket.extensions import DeflateFrameExtensionProcessor
     46 from mod_pywebsocket.extensions import PerFrameCompressionExtensionProcessor
     47 from mod_pywebsocket.extensions import PerMessageCompressionExtensionProcessor
     48 from mod_pywebsocket import msgutil
     49 from mod_pywebsocket.stream import InvalidUTF8Exception
     50 from mod_pywebsocket.stream import Stream
     51 from mod_pywebsocket.stream import StreamHixie75
     52 from mod_pywebsocket.stream import StreamOptions
     53 from mod_pywebsocket import util
     54 from test import mock
     55 
     56 
# We use one fixed four-character nonce for testing instead of drawing one
# from a cryptographically secure PRNG, so that the masked frames built by
# _mask_hybi() are deterministic.
_MASKING_NONCE = 'ABCD'
     59 
     60 
     61 def _mask_hybi(frame):
     62     frame_key = map(ord, _MASKING_NONCE)
     63     frame_key_len = len(frame_key)
     64     result = array.array('B')
     65     result.fromstring(frame)
     66     count = 0
     67     for i in xrange(len(result)):
     68         result[i] ^= frame_key[count]
     69         count = (count + 1) % frame_key_len
     70     return _MASKING_NONCE + result.tostring()
     71 
     72 
     73 def _install_extension_processor(processor, request, stream_options):
     74     response = processor.get_extension_response()
     75     if response is not None:
     76         processor.setup_stream_options(stream_options)
     77         request.ws_extension_processors.append(processor)
     78 
     79 
     80 def _create_request_from_rawdata(
     81     read_data, deflate_stream=False, deflate_frame_request=None,
     82     perframe_compression_request=None, permessage_compression_request=None):
     83     req = mock.MockRequest(connection=mock.MockConn(''.join(read_data)))
     84     req.ws_version = common.VERSION_HYBI_LATEST
     85     stream_options = StreamOptions()
     86     stream_options.deflate_stream = deflate_stream
     87     req.ws_extension_processors = []
     88     if deflate_frame_request is not None:
     89         processor = DeflateFrameExtensionProcessor(deflate_frame_request)
     90         _install_extension_processor(processor, req, stream_options)
     91     elif perframe_compression_request is not None:
     92         processor = PerFrameCompressionExtensionProcessor(
     93             perframe_compression_request)
     94         _install_extension_processor(processor, req, stream_options)
     95     elif permessage_compression_request is not None:
     96         processor = PerMessageCompressionExtensionProcessor(
     97             permessage_compression_request)
     98         _install_extension_processor(processor, req, stream_options)
     99 
    100     req.ws_stream = Stream(req, stream_options)
    101     return req
    102 
    103 
    104 def _create_request(*frames):
    105     """Creates MockRequest using data given as frames.
    106 
    107     frames will be returned on calling request.connection.read() where request
    108     is MockRequest returned by this function.
    109     """
    110 
    111     read_data = []
    112     for (header, body) in frames:
    113         read_data.append(header + _mask_hybi(body))
    114 
    115     return _create_request_from_rawdata(read_data)
    116 
    117 
    118 def _create_blocking_request():
    119     """Creates MockRequest.
    120 
    121     Data written to a MockRequest can be read out by calling
    122     request.connection.written_data().
    123     """
    124 
    125     req = mock.MockRequest(connection=mock.MockBlockingConn())
    126     req.ws_version = common.VERSION_HYBI_LATEST
    127     stream_options = StreamOptions()
    128     req.ws_stream = Stream(req, stream_options)
    129     return req
    130 
    131 
    132 def _create_request_hixie75(read_data=''):
    133     req = mock.MockRequest(connection=mock.MockConn(read_data))
    134     req.ws_stream = StreamHixie75(req)
    135     return req
    136 
    137 
    138 def _create_blocking_request_hixie75():
    139     req = mock.MockRequest(connection=mock.MockBlockingConn())
    140     req.ws_stream = StreamHixie75(req)
    141     return req
    142 
    143 
    144 class MessageTest(unittest.TestCase):
    145     # Tests for Stream
    146 
    def test_send_message(self):
        """Checks frames whose length is written in the second header byte.

        Covers a 5-byte payload and a 125-byte payload.
        """

        cases = (
            ('Hello', '\x81\x05'),
            ('a' * 125, '\x81\x7d'),
        )
        for message, header in cases:
            request = _create_request()
            msgutil.send_message(request, message)
            self.assertEqual(header + message,
                             request.connection.written_data())
    157 
    def test_send_medium_message(self):
        """Checks frames using the 0x7e marker and a 2-byte length field.

        Covers the payload sizes 126 and 65535.
        """

        cases = (
            (126, '\x81\x7e\x00\x7e'),
            ((1 << 16) - 1, '\x81\x7e\xff\xff'),
        )
        for size, header in cases:
            message = 'a' * size
            request = _create_request()
            msgutil.send_message(request, message)
            self.assertEqual(header + message,
                             request.connection.written_data())
    170 
    def test_send_large_message(self):
        """Checks a 65536-byte frame: 0x7f marker and an 8-byte length."""

        message = 'a' * (1 << 16)
        header = '\x81\x7f\x00\x00\x00\x00\x00\x01\x00\x00'
        request = _create_request()
        msgutil.send_message(request, message)
        self.assertEqual(header + message, request.connection.written_data())
    177 
    def test_send_message_unicode(self):
        """Checks that a unicode message is written as UTF-8 bytes."""

        request = _create_request()
        msgutil.send_message(request, u'\u65e5')
        # U+65e5 is encoded as e6,97,a5 in UTF-8, hence a 3-byte payload.
        expected = '\x81\x03' + '\xe6\x97\xa5'
        self.assertEqual(expected, request.connection.written_data())
    184 
    def test_send_message_fragments(self):
        """Checks the headers written for a message sent in four fragments.

        The third argument of send_message() is False for every fragment
        except the last one.
        """

        fragments = (
            ('Hello', False),
            (' ', False),
            ('World', False),
            ('!', True),
        )
        request = _create_request()
        for text, is_last in fragments:
            msgutil.send_message(request, text, is_last)

        expected = '\x01\x05Hello' '\x00\x01 ' '\x00\x05World' '\x80\x01!'
        self.assertEqual(expected, request.connection.written_data())
    193 
    def test_send_fragments_immediate_zero_termination(self):
        """Checks a fragmented message terminated by an empty last frame."""

        request = _create_request()
        for text, is_last in (('Hello World!', False), ('', True)):
            msgutil.send_message(request, text, is_last)

        expected = '\x01\x0cHello World!' + '\x80\x00'
        self.assertEqual(expected, request.connection.written_data())
    200 
    def test_send_message_deflate_stream(self):
        """Checks that deflate_stream compresses the whole written frame.

        The expected bytes are the uncompressed frame passed through a raw
        (negative window bits) zlib compressor and sync-flushed.
        """

        request = _create_request_from_rawdata('', deflate_stream=True)
        msgutil.send_message(request, 'Hello')

        deflater = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
        expected = (deflater.compress('\x81\x05Hello') +
                    deflater.flush(zlib.Z_SYNC_FLUSH))
        self.assertEqual(expected, request.connection.written_data())
    210 
    def test_send_message_deflate_frame(self):
        """Checks frames sent with the deflate-frame extension installed.

        Both expected payloads are produced by one shared zlib compressor,
        and each expected frame starts with the byte 0xc1 followed by a
        one-byte payload length.
        """

        deflater = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)

        def expected_frame(message):
            body = deflater.compress(message)
            body += deflater.flush(zlib.Z_SYNC_FLUSH)
            # The last four bytes emitted by the sync flush are not part of
            # the expected payload.
            body = body[:-4]
            return ('\xc1%c' % len(body)) + body

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        request = _create_request_from_rawdata(
            '', deflate_frame_request=extension)
        for message in ('Hello', 'World'):
            msgutil.send_message(request, message)

        expected = expected_frame('Hello') + expected_frame('World')
        self.assertEqual(expected, request.connection.written_data())
    236 
    def test_send_message_deflate_frame_comp_bit(self):
        """Checks toggling outgoing compression on the deflate-frame processor.

        Three 'Hello' messages are sent: with compression enabled, after
        disable_outgoing_compression(), and after
        enable_outgoing_compression(). The first and third are expected as
        compressed frames starting with 0xc1; the second is expected as the
        plain frame '\\x81\\x05Hello'. The two compressed payloads come from
        one shared zlib compressor.
        """

        deflater = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)

        def expected_compressed_frame(message):
            body = deflater.compress(message)
            body += deflater.flush(zlib.Z_SYNC_FLUSH)
            # The last four bytes emitted by the sync flush are not part of
            # the expected payload.
            body = body[:-4]
            return ('\xc1%c' % len(body)) + body

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        request = _create_request_from_rawdata(
            '', deflate_frame_request=extension)
        # assertEqual rather than the deprecated assertEquals alias, matching
        # every other assertion in this file.
        self.assertEqual(1, len(request.ws_extension_processors))
        deflate_frame_processor = request.ws_extension_processors[0]

        msgutil.send_message(request, 'Hello')
        deflate_frame_processor.disable_outgoing_compression()
        msgutil.send_message(request, 'Hello')
        deflate_frame_processor.enable_outgoing_compression()
        msgutil.send_message(request, 'Hello')

        expected = expected_compressed_frame('Hello')
        expected += '\x81\x05Hello'
        expected += expected_compressed_frame('Hello')

        self.assertEqual(expected, request.connection.written_data())
    269 
    def test_send_message_deflate_frame_no_context_takeover_parameter(self):
        """Checks sending when the request carries no_context_takeover.

        The same message sent three times is expected to produce three
        byte-identical frames, each equal to what a fresh compressor
        produces for that message.
        """

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        extension.add_parameter('no_context_takeover', None)
        request = _create_request_from_rawdata(
            '', deflate_frame_request=extension)
        for _ in xrange(3):
            msgutil.send_message(request, 'Hello')

        deflater = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
        body = deflater.compress('Hello')
        body += deflater.flush(zlib.Z_SYNC_FLUSH)
        # Drop the last four bytes emitted by the sync flush.
        body = body[:-4]
        single_frame = ('\xc1%c' % len(body)) + body

        self.assertEqual(
            single_frame * 3, request.connection.written_data())
    289 
    def test_deflate_frame_bad_request_parameters(self):
        """Tests that a malformed deflate-frame extension request is
        rejected, i.e. get_extension_response() returns None.
        """

        bad_parameters = (
            # max_window_bits less than 8 is illegal.
            ('max_window_bits', '7'),
            # max_window_bits greater than 15 is illegal.
            ('max_window_bits', '16'),
            # Non integer max_window_bits is illegal.
            ('max_window_bits', 'foobar'),
            # no_context_takeover must not have any value.
            ('no_context_takeover', 'foobar'),
        )
        for name, value in bad_parameters:
            extension = common.ExtensionParameter(
                common.DEFLATE_FRAME_EXTENSION)
            extension.add_parameter(name, value)
            processor = DeflateFrameExtensionProcessor(extension)
            self.assertEqual(None, processor.get_extension_response())
    318 
    def test_deflate_frame_response_parameters(self):
        """Checks that response settings show up in the extension response.

        set_response_window_bits(8) must yield max_window_bits='8', and
        set_response_no_context_takeover(True) must yield a
        no_context_takeover parameter whose value is None.
        """

        # max_window_bits
        processor = DeflateFrameExtensionProcessor(
            common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION))
        processor.set_response_window_bits(8)
        response = processor.get_extension_response()
        self.assertTrue(response.has_parameter('max_window_bits'))
        window_bits = response.get_parameter_value('max_window_bits')
        self.assertEqual('8', window_bits)

        # no_context_takeover
        processor = DeflateFrameExtensionProcessor(
            common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION))
        processor.set_response_no_context_takeover(True)
        response = processor.get_extension_response()
        self.assertTrue(response.has_parameter('no_context_takeover'))
        takeover_value = response.get_parameter_value('no_context_takeover')
        self.assertTrue(takeover_value is None)
    334 
    def test_send_message_perframe_compress_deflate(self):
        """Checks frames sent with per-frame compression, method deflate.

        Both expected payloads are produced by one shared zlib compressor,
        and each expected frame starts with the byte 0xc1 followed by a
        one-byte payload length.
        """

        deflater = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)

        def expected_frame(message):
            body = deflater.compress(message)
            body += deflater.flush(zlib.Z_SYNC_FLUSH)
            # The last four bytes emitted by the sync flush are not part of
            # the expected payload.
            body = body[:-4]
            return ('\xc1%c' % len(body)) + body

        extension = common.ExtensionParameter(
            common.PERFRAME_COMPRESSION_EXTENSION)
        extension.add_parameter('method', 'deflate')
        request = _create_request_from_rawdata(
            '', perframe_compression_request=extension)
        for message in ('Hello', 'World'):
            msgutil.send_message(request, message)

        expected = expected_frame('Hello') + expected_frame('World')
        self.assertEqual(expected, request.connection.written_data())
    361 
    def test_send_message_permessage_compress_deflate(self):
        """Checks frames sent with per-message compression, method deflate.

        Both expected payloads are produced by one shared zlib compressor,
        and each expected frame starts with the byte 0xc1 followed by a
        one-byte payload length.
        """

        deflater = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)

        def expected_frame(message):
            body = deflater.compress(message)
            body += deflater.flush(zlib.Z_SYNC_FLUSH)
            # The last four bytes emitted by the sync flush are not part of
            # the expected payload.
            body = body[:-4]
            return ('\xc1%c' % len(body)) + body

        extension = common.ExtensionParameter(
            common.PERMESSAGE_COMPRESSION_EXTENSION)
        extension.add_parameter('method', 'deflate')
        request = _create_request_from_rawdata(
            '', permessage_compression_request=extension)
        for message in ('Hello', 'World'):
            msgutil.send_message(request, message)

        expected = expected_frame('Hello') + expected_frame('World')
        self.assertEqual(expected, request.connection.written_data())
    388 
    def test_receive_message(self):
        """Checks receiving masked frames with a one-byte length field.

        Covers two consecutive short messages on one request and a single
        125-byte message.
        """

        frames = (('\x81\x85', 'Hello'), ('\x81\x86', 'World!'))
        request = _create_request(*frames)
        for _, body in frames:
            self.assertEqual(body, msgutil.receive_message(request))

        long_body = 'a' * 125
        request = _create_request(('\x81\xfd', long_body))
        self.assertEqual(long_body, msgutil.receive_message(request))
    398 
    def test_receive_medium_message(self):
        """Checks receiving frames that carry a 2-byte length field.

        Covers the payload sizes 126 and 65535.
        """

        cases = (
            (126, '\x81\xfe\x00\x7e'),
            ((1 << 16) - 1, '\x81\xfe\xff\xff'),
        )
        for size, header in cases:
            body = 'a' * size
            request = _create_request((header, body))
            self.assertEqual(body, msgutil.receive_message(request))
    407 
    def test_receive_large_message(self):
        """Checks receiving a 65536-byte frame with an 8-byte length."""

        body = 'a' * (1 << 16)
        header = '\x81\xff\x00\x00\x00\x00\x00\x01\x00\x00'
        request = _create_request((header, body))
        self.assertEqual(body, msgutil.receive_message(request))
    413 
    def test_receive_length_not_encoded_using_minimal_number_of_bytes(self):
        """Checks that an over-long length encoding is still accepted.

        On receiving a payload length field that doesn't use the minimal
        number of bytes, a warning is logged but processing continues.
        """

        body = 'a'
        # A length of 1 fits in the basic length field, but here it is
        # spelled with the 8-byte extended payload length field.
        header = '\x81\xff' + '\x00\x00\x00\x00\x00\x00\x00\x01'
        request = _create_request((header, body))
        self.assertEqual(body, msgutil.receive_message(request))
    423 
    def test_receive_message_unicode(self):
        """Checks that a UTF-8 payload is returned as a unicode string."""

        # U+672c is encoded as e6,9c,ac in UTF-8
        utf8_body = '\xe6\x9c\xac'
        request = _create_request(('\x81\x83', utf8_body))
        self.assertEqual(u'\u672c', msgutil.receive_message(request))
    428 
    def test_receive_message_erroneous_unicode(self):
        """Checks that an invalid UTF-8 payload raises InvalidUTF8Exception.
        """

        # \x80 and \x81 are invalid as UTF-8.
        invalid_body = '\x80\x81'
        request = _create_request(('\x81\x82', invalid_body))
        self.assertRaises(
            InvalidUTF8Exception, msgutil.receive_message, request)
    436 
    def test_receive_fragments(self):
        """Checks that four fragments are reassembled into one message."""

        fragments = (
            ('\x01\x85', 'Hello'),
            ('\x00\x81', ' '),
            ('\x00\x85', 'World'),
            ('\x80\x81', '!'),
        )
        request = _create_request(*fragments)
        self.assertEqual('Hello World!', msgutil.receive_message(request))
    444 
    def test_receive_fragments_unicode(self):
        """Checks reassembly when fragments split UTF-8 sequences."""

        # UTF-8 encodes U+6f22 into e6bca2 and U+5b57 into e5ad97. The six
        # bytes are cut into three 2-byte fragments, so neither character
        # lies within a single fragment.
        fragments = (
            ('\x01\x82', '\xe6\xbc'),
            ('\x00\x82', '\xa2\xe5'),
            ('\x80\x82', '\xad\x97'),
        )
        request = _create_request(*fragments)
        self.assertEqual(u'\u6f22\u5b57', msgutil.receive_message(request))
    452 
    def test_receive_fragments_immediate_zero_termination(self):
        """Checks a fragmented message ended by an empty last fragment."""

        first_fragment = ('\x01\x8c', 'Hello World!')
        empty_last_fragment = ('\x80\x80', '')
        request = _create_request(first_fragment, empty_last_fragment)
        self.assertEqual('Hello World!', msgutil.receive_message(request))
    457 
    def test_receive_fragments_duplicate_start(self):
        """Checks that two consecutive frames with header byte 0x01 raise
        InvalidFrameException.
        """

        first_start = ('\x01\x85', 'Hello')
        second_start = ('\x01\x85', 'World')
        request = _create_request(first_start, second_start)
        self.assertRaises(
            msgutil.InvalidFrameException, msgutil.receive_message, request)
    464 
    def test_receive_fragments_intermediate_but_not_started(self):
        """Checks that a lone frame with header byte 0x00 raises
        InvalidFrameException.
        """

        orphan_fragment = ('\x00\x85', 'Hello')
        request = _create_request(orphan_fragment)
        self.assertRaises(
            msgutil.InvalidFrameException, msgutil.receive_message, request)
    470 
    def test_receive_fragments_end_but_not_started(self):
        """Checks that a lone frame with header byte 0x80 raises
        InvalidFrameException.
        """

        orphan_fragment = ('\x80\x85', 'Hello')
        request = _create_request(orphan_fragment)
        self.assertRaises(
            msgutil.InvalidFrameException, msgutil.receive_message, request)
    476 
    def test_receive_message_discard(self):
        """Checks that unsupported frames do not disturb later messages.

        Frames with header byte 0x8f alternate with text frames; each 0x8f
        frame must raise UnsupportedFrameException and the text frame after
        it must still be received.
        """

        request = _create_request(
            ('\x8f\x86', 'IGNORE'), ('\x81\x85', 'Hello'),
            ('\x8f\x89', 'DISREGARD'), ('\x81\x86', 'World!'))
        for text in ('Hello', 'World!'):
            self.assertRaises(msgutil.UnsupportedFrameException,
                              msgutil.receive_message, request)
            self.assertEqual(text, msgutil.receive_message(request))
    487 
    def test_receive_close(self):
        """Checks handling of a frame with header byte 0x88.

        receive_message() must return None and the request must expose the
        code and reason taken from the frame body.
        """

        # The body is the code as a big-endian unsigned short followed by
        # the reason text.
        close_body = struct.pack('!H', 1000) + 'Good bye'
        request = _create_request(('\x88\x8a', close_body))
        self.assertEqual(None, msgutil.receive_message(request))
        self.assertEqual(1000, request.ws_close_code)
        self.assertEqual('Good bye', request.ws_close_reason)
    494 
    def test_receive_message_deflate_stream(self):
        """Checks receiving frames from a deflate_stream connection.

        The input is the output of two consecutive raw zlib compressors:
        the first is ended with Z_FINISH after two frames, the second
        carries one more text frame and a frame with header byte 0x88.
        drain_received_data_called must stay false until that last frame
        has been received.
        """

        def new_deflater():
            return zlib.compressobj(
                zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)

        first = new_deflater()
        chunks = [
            first.compress('\x81\x85' + _mask_hybi('Hello')),
            first.flush(zlib.Z_SYNC_FLUSH),
            first.compress('\x81\x89' + _mask_hybi('WebSocket')),
            first.flush(zlib.Z_FINISH),
        ]

        second = new_deflater()
        close_body = struct.pack('!H', 1000) + 'Good bye'
        chunks += [
            second.compress('\x81\x85' + _mask_hybi('World')),
            second.flush(zlib.Z_SYNC_FLUSH),
            # Close frame
            second.compress('\x88\x8a' + _mask_hybi(close_body)),
            second.flush(zlib.Z_SYNC_FLUSH),
        ]
        data = ''.join(chunks)

        request = _create_request_from_rawdata(data, deflate_stream=True)
        for text in ('Hello', 'WebSocket', 'World'):
            self.assertEqual(text, msgutil.receive_message(request))

        self.assertFalse(request.drain_received_data_called)

        self.assertEqual(None, msgutil.receive_message(request))

        self.assertTrue(request.drain_received_data_called)
    524 
    def test_receive_message_deflate_frame(self):
        """Checks receiving compressed frames with deflate-frame installed.

        'Hello' and 'WebSocket' are compressed by one compressor (the
        second with Z_FINISH plus one appended zero byte); 'World' comes
        from a fresh compressor. A frame with header byte 0x88 follows, for
        which receive_message() must return None.
        """

        def new_deflater():
            return zlib.compressobj(
                zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)

        def compressed_frame(body):
            # 0xc1 header byte, then the length with the 0x80 bit set, then
            # the masked body.
            return ('\xc1%c' % (len(body) | 0x80)) + _mask_hybi(body)

        first = new_deflater()

        hello = first.compress('Hello')
        hello += first.flush(zlib.Z_SYNC_FLUSH)
        # The last four bytes emitted by the sync flush are dropped.
        frames = [compressed_frame(hello[:-4])]

        websocket = first.compress('WebSocket')
        websocket += first.flush(zlib.Z_FINISH)
        websocket += '\x00'
        frames.append(compressed_frame(websocket))

        second = new_deflater()

        world = second.compress('World')
        world += second.flush(zlib.Z_SYNC_FLUSH)
        frames.append(compressed_frame(world[:-4]))

        # Close frame
        close_body = struct.pack('!H', 1000) + 'Good bye'
        frames.append('\x88\x8a' + _mask_hybi(close_body))

        data = ''.join(frames)

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        request = _create_request_from_rawdata(
            data, deflate_frame_request=extension)
        for text in ('Hello', 'WebSocket', 'World'):
            self.assertEqual(text, msgutil.receive_message(request))

        self.assertEqual(None, msgutil.receive_message(request))
    563 
    def test_receive_message_deflate_frame_client_using_smaller_window(self):
        """Test that frames coming from a client which is using a smaller
        window size than the server are correctly received.
        """

        # Use a frame whose content is bigger than the client's DEFLATE
        # window size before compression. The content mainly consists of 'a'
        # but repetition of 'b' is put at the head and tail so that if the
        # window size is big, the head is back-referenced but if small, not.
        payload = ''.join(['b' * 64, 'a' * 1024, 'b' * 64])

        # Using the smallest window bits of 8 for generating input frames.
        small_window_deflater = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -8)
        body = small_window_deflater.compress(payload)
        body += small_window_deflater.flush(zlib.Z_SYNC_FLUSH)
        # The last four bytes emitted by the sync flush are dropped.
        body = body[:-4]

        frames = [
            ('\xc1%c' % (len(body) | 0x80)) + _mask_hybi(body),
            # Close frame
            '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye'),
        ]
        data = ''.join(frames)

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        request = _create_request_from_rawdata(
            data, deflate_frame_request=extension)
        self.assertEqual(payload, msgutil.receive_message(request))

        self.assertEqual(None, msgutil.receive_message(request))
    595 
    def test_receive_message_deflate_frame_comp_bit(self):
        """Checks mixing compressed and plain frames with deflate-frame.

        The input is a compressed 'Hello' frame (header byte 0xc1), a plain
        'Hello' frame (header byte 0x81) and another compressed 'Hello'
        frame built with a fresh compressor. All three must be received as
        'Hello'.
        """

        def compressed_hello_frame():
            # Each call uses its own compressor.
            deflater = zlib.compressobj(
                zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
            body = deflater.compress('Hello')
            body += deflater.flush(zlib.Z_SYNC_FLUSH)
            # The last four bytes emitted by the sync flush are dropped.
            body = body[:-4]
            return ('\xc1%c' % (len(body) | 0x80)) + _mask_hybi(body)

        frames = [
            compressed_hello_frame(),
            '\x81\x85' + _mask_hybi('Hello'),
            compressed_hello_frame(),
        ]
        data = ''.join(frames)

        extension = common.ExtensionParameter(common.DEFLATE_FRAME_EXTENSION)
        request = _create_request_from_rawdata(
            data, deflate_frame_request=extension)
        for _ in xrange(3):
            self.assertEqual('Hello', msgutil.receive_message(request))
    624 
    def test_receive_message_perframe_compression_frame(self):
        """Checks receiving compressed frames with per-frame compression.

        'Hello' and 'WebSocket' are compressed by one compressor (the
        second with Z_FINISH plus one appended zero byte); 'World' comes
        from a fresh compressor. A frame with header byte 0x88 follows, for
        which receive_message() must return None.
        """

        def new_deflater():
            return zlib.compressobj(
                zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)

        def compressed_frame(body):
            # 0xc1 header byte, then the length with the 0x80 bit set, then
            # the masked body.
            return ('\xc1%c' % (len(body) | 0x80)) + _mask_hybi(body)

        first = new_deflater()

        hello = first.compress('Hello')
        hello += first.flush(zlib.Z_SYNC_FLUSH)
        # The last four bytes emitted by the sync flush are dropped.
        frames = [compressed_frame(hello[:-4])]

        websocket = first.compress('WebSocket')
        websocket += first.flush(zlib.Z_FINISH)
        websocket += '\x00'
        frames.append(compressed_frame(websocket))

        second = new_deflater()

        world = second.compress('World')
        world += second.flush(zlib.Z_SYNC_FLUSH)
        frames.append(compressed_frame(world[:-4]))

        # Close frame
        close_body = struct.pack('!H', 1000) + 'Good bye'
        frames.append('\x88\x8a' + _mask_hybi(close_body))

        data = ''.join(frames)

        extension = common.ExtensionParameter(
            common.PERFRAME_COMPRESSION_EXTENSION)
        extension.add_parameter('method', 'deflate')
        request = _create_request_from_rawdata(
            data, perframe_compression_request=extension)
        for text in ('Hello', 'WebSocket', 'World'):
            self.assertEqual(text, msgutil.receive_message(request))

        self.assertEqual(None, msgutil.receive_message(request))
    665 
    def test_receive_message_permessage_deflate_compression(self):
        """Tests receiving messages compressed per message with deflate.

        The first message, 'HelloWebSocket', is compressed as a whole and
        then split across two frames. The second message, 'World', is
        compressed with a fresh compressor and sent in one frame. A close
        frame ends the input, for which receive_message is expected to
        return None.
        """
        # A negative window size selects a raw deflate stream without the
        # zlib header and checksum.
        compress = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)

        data = ''

        compressed_hello = compress.compress('HelloWebSocket')
        compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
        # Drop the trailing four bytes of the sync-flushed output.
        compressed_hello = compressed_hello[:-4]
        # Floor division keeps the split point an integer regardless of
        # the division mode in effect; a float would break both the slice
        # and the '%c' length byte below.
        split_position = len(compressed_hello) // 2
        # First fragment: 0x41 is RSV1 plus the text opcode with FIN clear.
        # 0x80 added to the length is the mask bit.
        data += '\x41%c' % (split_position | 0x80)
        data += _mask_hybi(compressed_hello[:split_position])

        # Final fragment: 0x80 is FIN with the continuation opcode.
        data += '\x80%c' % ((len(compressed_hello) - split_position) | 0x80)
        data += _mask_hybi(compressed_hello[split_position:])

        compress = zlib.compressobj(
            zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)

        compressed_world = compress.compress('World')
        compressed_world += compress.flush(zlib.Z_SYNC_FLUSH)
        compressed_world = compressed_world[:-4]
        # Unfragmented message: 0xc1 is FIN, RSV1 and the text opcode.
        data += '\xc1%c' % (len(compressed_world) | 0x80)
        data += _mask_hybi(compressed_world)

        # Close frame
        data += '\x88\x8a' + _mask_hybi(struct.pack('!H', 1000) + 'Good bye')

        extension = common.ExtensionParameter(
            common.PERMESSAGE_COMPRESSION_EXTENSION)
        extension.add_parameter('method', 'deflate')
        request = _create_request_from_rawdata(
            data, permessage_compression_request=extension)
        self.assertEqual('HelloWebSocket', msgutil.receive_message(request))
        self.assertEqual('World', msgutil.receive_message(request))

        self.assertEqual(None, msgutil.receive_message(request))
    703 
    def test_send_longest_close(self):
        """Tests closing with the longest reason that is accepted.

        A 123-byte reason plus the 2-byte status code makes a 125-byte
        body; 0xfd in the queued close frame header is 125 with the top
        bit set.
        """
        longest_reason = 'a' * 123
        close_body = (
            struct.pack('!H', common.STATUS_NORMAL_CLOSURE) + longest_reason)
        request = _create_request(('\x88\xfd', close_body))

        request.ws_stream.close_connection(
            common.STATUS_NORMAL_CLOSURE, longest_reason)

        # The code and reason from the queued close frame are recorded on
        # the request.
        self.assertEqual(request.ws_close_code, common.STATUS_NORMAL_CLOSURE)
        self.assertEqual(request.ws_close_reason, longest_reason)
    713 
    def test_send_close_too_long(self):
        """Tests that a 124-byte close reason is rejected.

        This is one byte more than the reason used in
        test_send_longest_close.
        """
        request = _create_request()
        too_long_reason = 'a' * 124
        self.assertRaises(
            msgutil.BadOperationException,
            Stream.close_connection,
            request.ws_stream,
            common.STATUS_NORMAL_CLOSURE,
            too_long_reason)
    721 
    def test_send_close_inconsistent_code_and_reason(self):
        """Tests that giving a reason without a status code is rejected."""
        request = _create_request()
        no_code = None
        # reason parameter must not be specified when code is None.
        self.assertRaises(
            msgutil.BadOperationException,
            Stream.close_connection,
            request.ws_stream,
            no_code,
            'a')
    730 
    def test_send_ping(self):
        """Tests the bytes written when a ping is sent."""
        request = _create_request()
        msgutil.send_ping(request, 'Hello World!')

        # 0x89 is FIN plus the ping opcode; 0x0c is the 12-byte body
        # length with no mask bit.
        expected_frame = '\x89\x0cHello World!'
        self.assertEqual(expected_frame, request.connection.written_data())
    736 
    def test_send_longest_ping(self):
        """Tests sending a ping with a 125-byte body.

        test_send_ping_too_long shows that one more byte is rejected.
        """
        request = _create_request()
        body = 'a' * 125
        msgutil.send_ping(request, body)

        # 0x7d is 125, the body length, with no mask bit.
        expected_frame = '\x89\x7d' + 'a' * 125
        self.assertEqual(expected_frame, request.connection.written_data())
    742 
    def test_send_ping_too_long(self):
        """Tests that a ping with a 126-byte body is rejected."""
        request = _create_request()
        too_long_body = 'a' * 126
        self.assertRaises(
            msgutil.BadOperationException,
            msgutil.send_ping,
            request,
            too_long_body)
    749 
    def test_receive_ping(self):
        """Tests receiving a ping control frame."""

        def mark_called(request, message):
            # Records on the request that the handler has run.
            request.called = True

        # A ping frame followed by a text frame. 0x85 is a length of 5
        # with the mask bit set.
        frames = (('\x89\x85', 'Hello'), ('\x81\x85', 'World'))

        # Stream automatically respond to ping with pong without any action
        # by application layer.
        request = _create_request(*frames)
        self.assertEqual('World', msgutil.receive_message(request))
        # 0x8a is FIN plus the pong opcode; the body echoes the ping.
        self.assertEqual('\x8a\x05Hello',
                         request.connection.written_data())

        # With on_ping_handler set on the request, the handler is expected
        # to be called while the ping is processed.
        request = _create_request(*frames)
        request.on_ping_handler = mark_called
        self.assertEqual('World', msgutil.receive_message(request))
        self.assertTrue(request.called)
    769 
    def test_receive_longest_ping(self):
        """Tests receiving a ping with a 125-byte body.

        The ping is followed by a text frame; the pong written in reply is
        expected to carry the same 125 bytes.
        """
        ping_body = 'a' * 125
        # 0xfd is 125 with the mask bit set.
        request = _create_request(
            ('\x89\xfd', ping_body), ('\x81\x85', 'World'))

        self.assertEqual('World', msgutil.receive_message(request))

        expected_pong = '\x8a\x7d' + 'a' * 125
        self.assertEqual(expected_pong, request.connection.written_data())
    776 
    def test_receive_ping_too_long(self):
        """Tests that a received ping with a 126-byte body is invalid."""
        # 0xfe is length code 126 with the mask bit set; the next two bytes
        # give the actual length, 0x007e = 126.
        oversized_ping = ('\x89\xfe\x00\x7e', 'a' * 126)
        request = _create_request(oversized_ping)
        self.assertRaises(
            msgutil.InvalidFrameException,
            msgutil.receive_message,
            request)
    782 
    def test_receive_pong(self):
        """Tests receiving a pong control frame."""

        def mark_called(request, message):
            # Records on the request that the handler has run.
            request.called = True

        # A pong frame followed by a text frame.
        request = _create_request(
            ('\x8a\x85', 'Hello'), ('\x81\x85', 'World'))
        request.on_pong_handler = mark_called

        # 0x89 is FIN plus the ping opcode, 0x05 the unmasked length.
        sent_ping = '\x89\x05Hello'
        msgutil.send_ping(request, 'Hello')
        self.assertEqual(sent_ping, request.connection.written_data())

        # Valid pong is received, but receive_message won't return for it.
        self.assertEqual('World', msgutil.receive_message(request))
        # Check that nothing was written after receive_message call.
        self.assertEqual(sent_ping, request.connection.written_data())

        self.assertTrue(request.called)
    802 
    def test_receive_unsolicited_pong(self):
        """Tests that pongs not answering a sent ping are tolerated.

        Nothing is asserted about the results; the test passes as long as
        receive_message does not raise.
        """
        # A pong frame followed by a text frame.
        frames = (('\x8a\x85', 'Hello'), ('\x81\x85', 'World'))

        # Unsolicited pong is allowed from HyBi 07.
        request = _create_request(*frames)
        msgutil.receive_message(request)

        # Body mismatch: the ping sent carries 'Jumbo' while the pong
        # carries 'Hello'.
        request = _create_request(*frames)
        msgutil.send_ping(request, 'Jumbo')
        msgutil.receive_message(request)
    814 
    def test_ping_cannot_be_fragmented(self):
        """Tests that a ping frame without the FIN bit is invalid."""
        # 0x09 is the ping opcode with the FIN bit clear.
        fragmented_ping = ('\x09\x85', 'Hello')
        request = _create_request(fragmented_ping)
        self.assertRaises(
            msgutil.InvalidFrameException,
            msgutil.receive_message,
            request)
    820 
    def test_ping_with_too_long_payload(self):
        """Tests that a received ping with a 256-byte body is invalid."""
        # 0xfe is length code 126 with the mask bit set; the next two bytes
        # give the actual length, 0x0100 = 256.
        oversized_ping = ('\x89\xfe\x01\x00', 'a' * 256)
        request = _create_request(oversized_ping)
        self.assertRaises(
            msgutil.InvalidFrameException,
            msgutil.receive_message,
            request)
    826 
    827 
    class MessageTestHixie75(unittest.TestCase):
        """Tests for messaging on requests made by _create_request_hixie75.

        Text frames in these tests are a 0x00 byte, the UTF-8 payload and
        a terminating 0xff byte.
        """

        # NOTE(review): the previous docstring referred to
        # draft-hixie-thewebsocketprotocol-76 although the class and its
        # helper are named hixie75 -- confirm which draft is meant.

        def test_send_message(self):
            """Sends an ASCII message and checks the framed bytes."""
            request = _create_request_hixie75()
            msgutil.send_message(request, 'Hello')
            expected_frame = '\x00Hello\xff'
            self.assertEqual(
                expected_frame, request.connection.written_data())

        def test_send_message_unicode(self):
            """Sends a unicode message and checks that it goes out as UTF-8."""
            request = _create_request_hixie75()
            msgutil.send_message(request, u'\u65e5')
            # U+65e5 is encoded as e6,97,a5 in UTF-8
            expected_frame = '\x00\xe6\x97\xa5\xff'
            self.assertEqual(
                expected_frame, request.connection.written_data())

        def test_receive_message(self):
            """Receives two consecutive text frames in order."""
            raw = '\x00Hello\xff\x00World!\xff'
            request = _create_request_hixie75(raw)
            for expected in ('Hello', 'World!'):
                self.assertEqual(expected, msgutil.receive_message(request))

        def test_receive_message_unicode(self):
            """Receives a UTF-8 payload and expects it decoded to unicode."""
            # U+672c is encoded as e6,9c,ac in UTF-8
            raw = '\x00\xe6\x9c\xac\xff'
            request = _create_request_hixie75(raw)
            self.assertEqual(u'\u672c', msgutil.receive_message(request))

        def test_receive_message_erroneous_unicode(self):
            """Receives bytes that are not valid UTF-8."""
            # \x80 and \x81 are invalid as UTF-8.
            raw = '\x00\x80\x81\xff'
            request = _create_request_hixie75(raw)
            # Invalid characters should be replaced with
            # U+fffd REPLACEMENT CHARACTER
            self.assertEqual(
                u'\ufffd\ufffd', msgutil.receive_message(request))

        def test_receive_message_discard(self):
            """Receives text frames interleaved with frames to be skipped.

            Only 'Hello' and 'World!' are expected from receive_message;
            the 0x80 and 0x01 frames in between are never returned.
            """
            raw = ''.join([
                '\x80\x06IGNORE',
                '\x00Hello\xff',
                '\x01DISREGARD\xff',
                '\x00World!\xff',
            ])
            request = _create_request_hixie75(raw)
            for expected in ('Hello', 'World!'):
                self.assertEqual(expected, msgutil.receive_message(request))
    865 
    866 
    class MessageReceiverTest(unittest.TestCase):
        """Tests MessageReceiver on top of the Stream class."""

        def test_queue(self):
            """Polls and then blocks for a message via the receiver."""
            request = _create_blocking_request()
            receiver = msgutil.MessageReceiver(request)

            # No bytes have been supplied yet, so polling yields nothing.
            self.assertEqual(None, receiver.receive_nowait())

            # A text frame: 0x86 is a length of 6 with the mask bit set.
            frame = '\x81\x86' + _mask_hybi('Hello!')
            request.connection.put_bytes(frame)
            self.assertEqual('Hello!', receiver.receive())

        def test_onmessage(self):
            """Checks that a message is passed to the onmessage callback."""
            delivered = Queue.Queue()

            def enqueue(message):
                delivered.put(message)

            request = _create_blocking_request()
            receiver = msgutil.MessageReceiver(request, enqueue)

            frame = '\x81\x86' + _mask_hybi('Hello!')
            request.connection.put_bytes(frame)
            # get() blocks until the callback has stored the message.
            self.assertEqual('Hello!', delivered.get())
    890 
    891 
    class MessageReceiverHixie75Test(unittest.TestCase):
        """Tests MessageReceiver on top of the StreamHixie75 class."""

        def test_queue(self):
            """Polls and then blocks for a message via the receiver."""
            request = _create_blocking_request_hixie75()
            receiver = msgutil.MessageReceiver(request)

            # No bytes have been supplied yet, so polling yields nothing.
            self.assertEqual(None, receiver.receive_nowait())

            frame = '\x00Hello!\xff'
            request.connection.put_bytes(frame)
            self.assertEqual('Hello!', receiver.receive())

        def test_onmessage(self):
            """Checks that a message is passed to the onmessage callback."""
            delivered = Queue.Queue()

            def enqueue(message):
                delivered.put(message)

            request = _create_blocking_request_hixie75()
            receiver = msgutil.MessageReceiver(request, enqueue)

            frame = '\x00Hello!\xff'
            request.connection.put_bytes(frame)
            # get() blocks until the callback has stored the message.
            self.assertEqual('Hello!', delivered.get())
    915 
    916 
    class MessageSenderTest(unittest.TestCase):
        """Tests MessageSender on top of the Stream class."""

        def test_send(self):
            """Sends one message and checks the bytes written."""
            request = _create_blocking_request()
            sender = msgutil.MessageSender(request)

            sender.send('World')
            # 0x81 is FIN plus the text opcode, 0x05 the unmasked length.
            expected_frame = '\x81\x05World'
            self.assertEqual(
                expected_frame, request.connection.written_data())

        def test_send_nowait(self):
            """Queues two messages and checks that they go out in order."""
            # Use a queue to check the bytes written by MessageSender.
            # request.connection.written_data() cannot be used here because
            # MessageSender runs in a separate thread.
            written = Queue.Queue()

            # The parameter name is kept as in the original stand-in.
            def write(bytes):
                written.put(bytes)

            request = _create_blocking_request()
            request.connection.write = write

            sender = msgutil.MessageSender(request)

            for message in ('Hello', 'World'):
                sender.send_nowait(message)
            self.assertEqual('\x81\x05Hello', written.get())
            self.assertEqual('\x81\x05World', written.get())
    945 
    946 
    class MessageSenderHixie75Test(unittest.TestCase):
        """Tests MessageSender on top of the StreamHixie75 class."""

        def test_send(self):
            """Sends one message and checks the bytes written."""
            request = _create_blocking_request_hixie75()
            sender = msgutil.MessageSender(request)

            sender.send('World')
            expected_frame = '\x00World\xff'
            self.assertEqual(
                expected_frame, request.connection.written_data())

        def test_send_nowait(self):
            """Queues two messages and checks that they go out in order."""
            # Use a queue to check the bytes written by MessageSender.
            # request.connection.written_data() cannot be used here because
            # MessageSender runs in a separate thread.
            written = Queue.Queue()

            # The parameter name is kept as in the original stand-in.
            def write(bytes):
                written.put(bytes)

            request = _create_blocking_request_hixie75()
            request.connection.write = write

            sender = msgutil.MessageSender(request)

            for message in ('Hello', 'World'):
                sender.send_nowait(message)
            self.assertEqual('\x00Hello\xff', written.get())
            self.assertEqual('\x00World\xff', written.get())
    975 
    976 
    # Run the test cases defined in this module when it is executed as a
    # script rather than imported.
    if __name__ == '__main__':
        unittest.main()
    979 
    980 
    981 # vi:sts=4 sw=4 et
    982