Home | History | Annotate | Download | only in test
      1 #!/usr/bin/env python
      2 #
      3 # Copyright 2011, Google Inc.
      4 # All rights reserved.
      5 #
      6 # Redistribution and use in source and binary forms, with or without
      7 # modification, are permitted provided that the following conditions are
      8 # met:
      9 #
     10 #     * Redistributions of source code must retain the above copyright
     11 # notice, this list of conditions and the following disclaimer.
     12 #     * Redistributions in binary form must reproduce the above
     13 # copyright notice, this list of conditions and the following disclaimer
     14 # in the documentation and/or other materials provided with the
     15 # distribution.
     16 #     * Neither the name of Google Inc. nor the names of its
     17 # contributors may be used to endorse or promote products derived from
     18 # this software without specific prior written permission.
     19 #
     20 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     21 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     22 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     23 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     24 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     25 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     26 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     27 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     28 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     29 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     30 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     31 
     32 
     33 """Tests for handshake module."""
     34 
     35 
     36 import unittest
     37 
     38 import set_sys_path  # Update sys.path to locate mod_pywebsocket module.
     39 from mod_pywebsocket import common
     40 from mod_pywebsocket.handshake._base import AbortedByUserException
     41 from mod_pywebsocket.handshake._base import HandshakeException
     42 from mod_pywebsocket.handshake._base import VersionException
     43 from mod_pywebsocket.handshake.hybi import Handshaker
     44 
     45 import mock
     46 
     47 
     48 class RequestDefinition(object):
     49     """A class for holding data for constructing opening handshake strings for
     50     testing the opening handshake processor.
     51     """
     52 
     53     def __init__(self, method, uri, headers):
     54         self.method = method
     55         self.uri = uri
     56         self.headers = headers
     57 
     58 
     59 def _create_good_request_def():
     60     return RequestDefinition(
     61         'GET', '/demo',
     62         {'Host': 'server.example.com',
     63          'Upgrade': 'websocket',
     64          'Connection': 'Upgrade',
     65          'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
     66          'Sec-WebSocket-Version': '13',
     67          'Origin': 'http://example.com'})
     68 
     69 
     70 def _create_request(request_def):
     71     conn = mock.MockConn('')
     72     return mock.MockRequest(
     73         method=request_def.method,
     74         uri=request_def.uri,
     75         headers_in=request_def.headers,
     76         connection=conn)
     77 
     78 
     79 def _create_handshaker(request):
     80     handshaker = Handshaker(request, mock.MockDispatcher())
     81     return handshaker
     82 
     83 
     84 class SubprotocolChoosingDispatcher(object):
     85     """A dispatcher for testing. This dispatcher sets the i-th subprotocol
     86     of requested ones to ws_protocol where i is given on construction as index
     87     argument. If index is negative, default_value will be set to ws_protocol.
     88     """
     89 
     90     def __init__(self, index, default_value=None):
     91         self.index = index
     92         self.default_value = default_value
     93 
     94     def do_extra_handshake(self, conn_context):
     95         if self.index >= 0:
     96             conn_context.ws_protocol = conn_context.ws_requested_protocols[
     97                 self.index]
     98         else:
     99             conn_context.ws_protocol = self.default_value
    100 
    101     def transfer_data(self, conn_context):
    102         pass
    103 
    104 
    105 class HandshakeAbortedException(Exception):
    106     pass
    107 
    108 
    109 class AbortingDispatcher(object):
    110     """A dispatcher for testing. This dispatcher raises an exception in
    111     do_extra_handshake to reject the request.
    112     """
    113 
    114     def do_extra_handshake(self, conn_context):
    115         raise HandshakeAbortedException('An exception to reject the request')
    116 
    117     def transfer_data(self, conn_context):
    118         pass
    119 
    120 
    121 class AbortedByUserDispatcher(object):
    122     """A dispatcher for testing. This dispatcher raises an
    123     AbortedByUserException in do_extra_handshake to reject the request.
    124     """
    125 
    126     def do_extra_handshake(self, conn_context):
    127         raise AbortedByUserException('An AbortedByUserException to reject the '
    128                                      'request')
    129 
    130     def transfer_data(self, conn_context):
    131         pass
    132 
    133 
    134 _EXPECTED_RESPONSE = (
    135     'HTTP/1.1 101 Switching Protocols\r\n'
    136     'Upgrade: websocket\r\n'
    137     'Connection: Upgrade\r\n'
    138     'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n')
    139 
    140 
    141 class HandshakerTest(unittest.TestCase):
    142     """A unittest for draft-ietf-hybi-thewebsocketprotocol-06 and later
    143     handshake processor.
    144     """
    145 
    146     def test_do_handshake(self):
    147         request = _create_request(_create_good_request_def())
    148         dispatcher = mock.MockDispatcher()
    149         handshaker = Handshaker(request, dispatcher)
    150         handshaker.do_handshake()
    151 
    152         self.assertTrue(dispatcher.do_extra_handshake_called)
    153 
    154         self.assertEqual(
    155             _EXPECTED_RESPONSE, request.connection.written_data())
    156         self.assertEqual('/demo', request.ws_resource)
    157         self.assertEqual('http://example.com', request.ws_origin)
    158         self.assertEqual(None, request.ws_protocol)
    159         self.assertEqual(None, request.ws_extensions)
    160         self.assertEqual(common.VERSION_HYBI_LATEST, request.ws_version)
    161 
    162     def test_do_handshake_with_capitalized_value(self):
    163         request_def = _create_good_request_def()
    164         request_def.headers['upgrade'] = 'WEBSOCKET'
    165 
    166         request = _create_request(request_def)
    167         handshaker = _create_handshaker(request)
    168         handshaker.do_handshake()
    169         self.assertEqual(
    170             _EXPECTED_RESPONSE, request.connection.written_data())
    171 
    172         request_def = _create_good_request_def()
    173         request_def.headers['Connection'] = 'UPGRADE'
    174 
    175         request = _create_request(request_def)
    176         handshaker = _create_handshaker(request)
    177         handshaker.do_handshake()
    178         self.assertEqual(
    179             _EXPECTED_RESPONSE, request.connection.written_data())
    180 
    181     def test_do_handshake_with_multiple_connection_values(self):
    182         request_def = _create_good_request_def()
    183         request_def.headers['Connection'] = 'Upgrade, keep-alive, , '
    184 
    185         request = _create_request(request_def)
    186         handshaker = _create_handshaker(request)
    187         handshaker.do_handshake()
    188         self.assertEqual(
    189             _EXPECTED_RESPONSE, request.connection.written_data())
    190 
    191     def test_aborting_handshake(self):
    192         handshaker = Handshaker(
    193             _create_request(_create_good_request_def()),
    194             AbortingDispatcher())
    195         # do_extra_handshake raises an exception. Check that it's not caught by
    196         # do_handshake.
    197         self.assertRaises(HandshakeAbortedException, handshaker.do_handshake)
    198 
    199     def test_do_handshake_with_protocol(self):
    200         request_def = _create_good_request_def()
    201         request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'
    202 
    203         request = _create_request(request_def)
    204         handshaker = Handshaker(request, SubprotocolChoosingDispatcher(0))
    205         handshaker.do_handshake()
    206 
    207         EXPECTED_RESPONSE = (
    208             'HTTP/1.1 101 Switching Protocols\r\n'
    209             'Upgrade: websocket\r\n'
    210             'Connection: Upgrade\r\n'
    211             'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
    212             'Sec-WebSocket-Protocol: chat\r\n\r\n')
    213 
    214         self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
    215         self.assertEqual('chat', request.ws_protocol)
    216 
    217     def test_do_handshake_protocol_not_in_request_but_in_response(self):
    218         request_def = _create_good_request_def()
    219         request = _create_request(request_def)
    220         handshaker = Handshaker(
    221             request, SubprotocolChoosingDispatcher(-1, 'foobar'))
    222         # No request has been made but ws_protocol is set. HandshakeException
    223         # must be raised.
    224         self.assertRaises(HandshakeException, handshaker.do_handshake)
    225 
    226     def test_do_handshake_with_protocol_no_protocol_selection(self):
    227         request_def = _create_good_request_def()
    228         request_def.headers['Sec-WebSocket-Protocol'] = 'chat, superchat'
    229 
    230         request = _create_request(request_def)
    231         handshaker = _create_handshaker(request)
    232         # ws_protocol is not set. HandshakeException must be raised.
    233         self.assertRaises(HandshakeException, handshaker.do_handshake)
    234 
    235     def test_do_handshake_with_extensions(self):
    236         request_def = _create_good_request_def()
    237         request_def.headers['Sec-WebSocket-Extensions'] = (
    238             'deflate-stream, unknown')
    239 
    240         EXPECTED_RESPONSE = (
    241             'HTTP/1.1 101 Switching Protocols\r\n'
    242             'Upgrade: websocket\r\n'
    243             'Connection: Upgrade\r\n'
    244             'Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n'
    245             'Sec-WebSocket-Extensions: deflate-stream\r\n\r\n')
    246 
    247         request = _create_request(request_def)
    248         handshaker = _create_handshaker(request)
    249         handshaker.do_handshake()
    250         self.assertEqual(EXPECTED_RESPONSE, request.connection.written_data())
    251         self.assertEqual(1, len(request.ws_extensions))
    252         extension = request.ws_extensions[0]
    253         self.assertEqual('deflate-stream', extension.name())
    254         self.assertEqual(0, len(extension.get_parameter_names()))
    255 
    256     def test_do_handshake_with_quoted_extensions(self):
    257         request_def = _create_good_request_def()
    258         request_def.headers['Sec-WebSocket-Extensions'] = (
    259             'deflate-stream, , '
    260             'unknown; e   =    "mc^2"; ma="\r\n      \\\rf  "; pv=nrt')
    261 
    262         request = _create_request(request_def)
    263         handshaker = _create_handshaker(request)
    264         handshaker.do_handshake()
    265         self.assertEqual(2, len(request.ws_requested_extensions))
    266         first_extension = request.ws_requested_extensions[0]
    267         self.assertEqual('deflate-stream', first_extension.name())
    268         self.assertEqual(0, len(first_extension.get_parameter_names()))
    269         second_extension = request.ws_requested_extensions[1]
    270         self.assertEqual('unknown', second_extension.name())
    271         self.assertEqual(
    272             ['e', 'ma', 'pv'], second_extension.get_parameter_names())
    273         self.assertEqual('mc^2', second_extension.get_parameter_value('e'))
    274         self.assertEqual(' \rf ', second_extension.get_parameter_value('ma'))
    275         self.assertEqual('nrt', second_extension.get_parameter_value('pv'))
    276 
    277     def test_do_handshake_with_optional_headers(self):
    278         request_def = _create_good_request_def()
    279         request_def.headers['EmptyValue'] = ''
    280         request_def.headers['AKey'] = 'AValue'
    281 
    282         request = _create_request(request_def)
    283         handshaker = _create_handshaker(request)
    284         handshaker.do_handshake()
    285         self.assertEqual(
    286             'AValue', request.headers_in['AKey'])
    287         self.assertEqual(
    288             '', request.headers_in['EmptyValue'])
    289 
    290     def test_abort_extra_handshake(self):
    291         handshaker = Handshaker(
    292             _create_request(_create_good_request_def()),
    293             AbortedByUserDispatcher())
    294         # do_extra_handshake raises an AbortedByUserException. Check that it's
    295         # not caught by do_handshake.
    296         self.assertRaises(AbortedByUserException, handshaker.do_handshake)
    297 
    298     def test_do_handshake_with_mux_and_deflateframe(self):
    299         request_def = _create_good_request_def()
    300         request_def.headers['Sec-WebSocket-Extensions'] = ('%s, %s' % (
    301                 common.MUX_EXTENSION,
    302                 common.DEFLATE_FRAME_EXTENSION))
    303         request = _create_request(request_def)
    304         handshaker = _create_handshaker(request)
    305         handshaker.do_handshake()
    306         self.assertEqual(2, len(request.ws_extensions))
    307         self.assertEqual(common.MUX_EXTENSION,
    308                          request.ws_extensions[0].name())
    309         self.assertEqual(common.DEFLATE_FRAME_EXTENSION,
    310                          request.ws_extensions[1].name())
    311         self.assertTrue(request.mux)
    312         self.assertEqual(0, len(request.mux_extensions))
    313 
    314     def test_do_handshake_with_deflateframe_and_mux(self):
    315         request_def = _create_good_request_def()
    316         request_def.headers['Sec-WebSocket-Extensions'] = ('%s, %s' % (
    317                 common.DEFLATE_FRAME_EXTENSION,
    318                 common.MUX_EXTENSION))
    319         request = _create_request(request_def)
    320         handshaker = _create_handshaker(request)
    321         handshaker.do_handshake()
    322         # mux should be rejected.
    323         self.assertEqual(1, len(request.ws_extensions))
    324         first_extension = request.ws_extensions[0]
    325         self.assertEqual(common.DEFLATE_FRAME_EXTENSION,
    326                          first_extension.name())
    327 
    328     def test_bad_requests(self):
    329         bad_cases = [
    330             ('HTTP request',
    331              RequestDefinition(
    332                  'GET', '/demo',
    333                  {'Host': 'www.google.com',
    334                   'User-Agent':
    335                       'Mozilla/5.0 (Macintosh; U; Intel Mac OS X 10.5;'
    336                       ' en-US; rv:1.9.1.3) Gecko/20090824 Firefox/3.5.3'
    337                       ' GTB6 GTBA',
    338                   'Accept':
    339                       'text/html,application/xhtml+xml,application/xml;q=0.9,'
    340                       '*/*;q=0.8',
    341                   'Accept-Language': 'en-us,en;q=0.5',
    342                   'Accept-Encoding': 'gzip,deflate',
    343                   'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
    344                   'Keep-Alive': '300',
    345                   'Connection': 'keep-alive'}), None, True)]
    346 
    347         request_def = _create_good_request_def()
    348         request_def.method = 'POST'
    349         bad_cases.append(('Wrong method', request_def, None, True))
    350 
    351         request_def = _create_good_request_def()
    352         del request_def.headers['Host']
    353         bad_cases.append(('Missing Host', request_def, None, True))
    354 
    355         request_def = _create_good_request_def()
    356         del request_def.headers['Upgrade']
    357         bad_cases.append(('Missing Upgrade', request_def, None, True))
    358 
    359         request_def = _create_good_request_def()
    360         request_def.headers['Upgrade'] = 'nonwebsocket'
    361         bad_cases.append(('Wrong Upgrade', request_def, None, True))
    362 
    363         request_def = _create_good_request_def()
    364         del request_def.headers['Connection']
    365         bad_cases.append(('Missing Connection', request_def, None, True))
    366 
    367         request_def = _create_good_request_def()
    368         request_def.headers['Connection'] = 'Downgrade'
    369         bad_cases.append(('Wrong Connection', request_def, None, True))
    370 
    371         request_def = _create_good_request_def()
    372         del request_def.headers['Sec-WebSocket-Key']
    373         bad_cases.append(('Missing Sec-WebSocket-Key', request_def, 400, True))
    374 
    375         request_def = _create_good_request_def()
    376         request_def.headers['Sec-WebSocket-Key'] = (
    377             'dGhlIHNhbXBsZSBub25jZQ==garbage')
    378         bad_cases.append(('Wrong Sec-WebSocket-Key (with garbage on the tail)',
    379                           request_def, 400, True))
    380 
    381         request_def = _create_good_request_def()
    382         request_def.headers['Sec-WebSocket-Key'] = 'YQ=='  # BASE64 of 'a'
    383         bad_cases.append(
    384             ('Wrong Sec-WebSocket-Key (decoded value is not 16 octets long)',
    385              request_def, 400, True))
    386 
    387         request_def = _create_good_request_def()
    388         # The last character right before == must be any of A, Q, w and g.
    389         request_def.headers['Sec-WebSocket-Key'] = (
    390             'AQIDBAUGBwgJCgsMDQ4PEC==')
    391         bad_cases.append(
    392             ('Wrong Sec-WebSocket-Key (padding bits are not zero)',
    393              request_def, 400, True))
    394 
    395         request_def = _create_good_request_def()
    396         request_def.headers['Sec-WebSocket-Key'] = (
    397             'dGhlIHNhbXBsZSBub25jZQ==,dGhlIHNhbXBsZSBub25jZQ==')
    398         bad_cases.append(
    399             ('Wrong Sec-WebSocket-Key (multiple values)',
    400              request_def, 400, True))
    401 
    402         request_def = _create_good_request_def()
    403         del request_def.headers['Sec-WebSocket-Version']
    404         bad_cases.append(('Missing Sec-WebSocket-Version', request_def, None,
    405                           True))
    406 
    407         request_def = _create_good_request_def()
    408         request_def.headers['Sec-WebSocket-Version'] = '3'
    409         bad_cases.append(('Wrong Sec-WebSocket-Version', request_def, None,
    410                           False))
    411 
    412         request_def = _create_good_request_def()
    413         request_def.headers['Sec-WebSocket-Version'] = '13, 13'
    414         bad_cases.append(('Wrong Sec-WebSocket-Version (multiple values)',
    415                           request_def, 400, True))
    416 
    417         for (case_name, request_def, expected_status,
    418              expect_handshake_exception) in bad_cases:
    419             request = _create_request(request_def)
    420             handshaker = Handshaker(request, mock.MockDispatcher())
    421             try:
    422                 handshaker.do_handshake()
    423                 self.fail('No exception thrown for \'%s\' case' % case_name)
    424             except HandshakeException, e:
    425                 self.assertTrue(expect_handshake_exception)
    426                 self.assertEqual(expected_status, e.status)
    427             except VersionException, e:
    428                 self.assertFalse(expect_handshake_exception)
    429 
    430 
    431 if __name__ == '__main__':
    432     unittest.main()
    433 
    434 
    435 # vi:sts=4 sw=4 et
    436