Home | History | Annotate | Download | only in test
      1 #!/usr/bin/env python
      2 #
      3 # Copyright 2012, Google Inc.
      4 # All rights reserved.
      5 #
      6 # Redistribution and use in source and binary forms, with or without
      7 # modification, are permitted provided that the following conditions are
      8 # met:
      9 #
     10 #     * Redistributions of source code must retain the above copyright
     11 # notice, this list of conditions and the following disclaimer.
     12 #     * Redistributions in binary form must reproduce the above
     13 # copyright notice, this list of conditions and the following disclaimer
     14 # in the documentation and/or other materials provided with the
     15 # distribution.
     16 #     * Neither the name of Google Inc. nor the names of its
     17 # contributors may be used to endorse or promote products derived from
     18 # this software without specific prior written permission.
     19 #
     20 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     21 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     22 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     23 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     24 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     25 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     26 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     27 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     28 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     29 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     30 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     31 
     32 
     33 """WebSocket client utility for testing.
     34 
     35 This module contains helper methods for performing handshake, frame
     36 sending/receiving as a WebSocket client.
     37 
     38 This is code for testing mod_pywebsocket. Keep this code independent from
     39 mod_pywebsocket. Don't import e.g. Stream class for generating frame for
     40 testing. Using util.hexify, etc. that are not related to protocol processing
     41 is allowed.
     42 
     43 Note:
     44 This code is far from robust, e.g., we cut corners in handshake.
     45 """
     46 
     47 
     48 import base64
     49 import errno
     50 import logging
     51 import os
     52 import random
     53 import re
     54 import socket
     55 import struct
     56 
     57 from mod_pywebsocket import util
     58 
     59 
# Default ports for plain and TLS connections. _format_host_header leaves the
# port out of the Host header when it equals the default for the scheme.
DEFAULT_PORT = 80
DEFAULT_SECURE_PORT = 443

# Opcodes introduced in IETF HyBi 01 for the new framing format
OPCODE_CONTINUATION = 0x0
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xa
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2

# Strings used for handshake
# Each header string already ends with CRLF so it can be sent as-is.
_UPGRADE_HEADER = 'Upgrade: websocket\r\n'
# Capitalized spelling of the Upgrade token, used by the Hixie 75 and
# HyBi 00 handshake classes in this file.
_UPGRADE_HEADER_HIXIE75 = 'Upgrade: WebSocket\r\n'
_CONNECTION_HEADER = 'Connection: Upgrade\r\n'

# Appended to the Sec-WebSocket-Key value before SHA-1 hashing to compute the
# expected Sec-WebSocket-Accept value (see WebSocketHandshake.handshake).
WEBSOCKET_ACCEPT_UUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

# Status codes
# NOTE(review): presumably the close status codes defined by RFC 6455; they
# are not referenced in the part of the file reviewed here -- confirm usage.
STATUS_NORMAL_CLOSURE = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA = 1003
STATUS_NO_STATUS_RECEIVED = 1005
STATUS_ABNORMAL_CLOSURE = 1006
STATUS_INVALID_FRAME_PAYLOAD_DATA = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_MANDATORY_EXT = 1010
STATUS_INTERNAL_SERVER_ERROR = 1011
STATUS_TLS_HANDSHAKE = 1015

# Extension tokens
# Compared against the entries of the server's Sec-WebSocket-Extensions
# response header during the RFC 6455 handshake.
_DEFLATE_STREAM_EXTENSION = 'deflate-stream'
_DEFLATE_FRAME_EXTENSION = 'deflate-frame'
# TODO(bashi): Update after mux implementation finished.
_MUX_EXTENSION = 'mux_DO_NOT_USE'
     97 
     98 def _method_line(resource):
     99     return 'GET %s HTTP/1.1\r\n' % resource
    100 
    101 
    102 def _sec_origin_header(origin):
    103     return 'Sec-WebSocket-Origin: %s\r\n' % origin.lower()
    104 
    105 
    106 def _origin_header(origin):
    107     # 4.1 13. concatenation of the string "Origin:", a U+0020 SPACE character,
    108     # and the /origin/ value, converted to ASCII lowercase, to /fields/.
    109     return 'Origin: %s\r\n' % origin.lower()
    110 
    111 
    112 def _format_host_header(host, port, secure):
    113     # 4.1 9. Let /hostport/ be an empty string.
    114     # 4.1 10. Append the /host/ value, converted to ASCII lowercase, to
    115     # /hostport/
    116     hostport = host.lower()
    117     # 4.1 11. If /secure/ is false, and /port/ is not 80, or if /secure/
    118     # is true, and /port/ is not 443, then append a U+003A COLON character
    119     # (:) followed by the value of /port/, expressed as a base-ten integer,
    120     # to /hostport/
    121     if ((not secure and port != DEFAULT_PORT) or
    122         (secure and port != DEFAULT_SECURE_PORT)):
    123         hostport += ':' + str(port)
    124     # 4.1 12. concatenation of the string "Host:", a U+0020 SPACE
    125     # character, and /hostport/, to /fields/.
    126     return 'Host: %s\r\n' % hostport
    127 
    128 
    129 # TODO(tyoshino): Define a base class and move these shared methods to that.
    130 
    131 
    132 def receive_bytes(socket, length):
    133     bytes = []
    134     remaining = length
    135     while remaining > 0:
    136         received_bytes = socket.recv(remaining)
    137         if not received_bytes:
    138             raise Exception(
    139                 'Connection closed before receiving requested length '
    140                 '(requested %d bytes but received only %d bytes)' %
    141                 (length, length - remaining))
    142         bytes.append(received_bytes)
    143         remaining -= len(received_bytes)
    144     return ''.join(bytes)
    145 
    146 
# TODO(tyoshino): The WebSocketHandshake class currently reuses these
# functions. We should move to an HTTP parser as specified in RFC 6455. For
# HyBi 00 and Hixie 75, pack these functions into some parser class.
    150 
    151 
    152 def _read_fields(socket):
    153     # 4.1 32. let /fields/ be a list of name-value pairs, initially empty.
    154     fields = {}
    155     while True:
    156         # 4.1 33. let /name/ and /value/ be empty byte arrays
    157         name = ''
    158         value = ''
    159         # 4.1 34. read /name/
    160         name = _read_name(socket)
    161         if name is None:
    162             break
    163         # 4.1 35. read spaces
    164         # TODO(tyoshino): Skip only one space as described in the spec.
    165         ch = _skip_spaces(socket)
    166         # 4.1 36. read /value/
    167         value = _read_value(socket, ch)
    168         # 4.1 37. read a byte from the server
    169         ch = receive_bytes(socket, 1)
    170         if ch != '\n':  # 0x0A
    171             raise Exception(
    172                 'Expected LF but found %r while reading value %r for header '
    173                 '%r' % (ch, name, value))
    174         # 4.1 38. append an entry to the /fields/ list that has the name
    175         # given by the string obtained by interpreting the /name/ byte
    176         # array as a UTF-8 stream and the value given by the string
    177         # obtained by interpreting the /value/ byte array as a UTF-8 byte
    178         # stream.
    179         fields.setdefault(name, []).append(value)
    180         # 4.1 39. return to the "Field" step above
    181     return fields
    182 
    183 
    184 def _read_name(socket):
    185     # 4.1 33. let /name/ be empty byte arrays
    186     name = ''
    187     while True:
    188         # 4.1 34. read a byte from the server
    189         ch = receive_bytes(socket, 1)
    190         if ch == '\r':  # 0x0D
    191             return None
    192         elif ch == '\n':  # 0x0A
    193             raise Exception(
    194                 'Unexpected LF when reading header name %r' % name)
    195         elif ch == ':':  # 0x3A
    196             return name
    197         elif ch >= 'A' and ch <= 'Z':  # range 0x31 to 0x5A
    198             ch = chr(ord(ch) + 0x20)
    199             name += ch
    200         else:
    201             name += ch
    202 
    203 
    204 def _skip_spaces(socket):
    205     # 4.1 35. read a byte from the server
    206     while True:
    207         ch = receive_bytes(socket, 1)
    208         if ch == ' ':  # 0x20
    209             continue
    210         return ch
    211 
    212 
    213 def _read_value(socket, ch):
    214     # 4.1 33. let /value/ be empty byte arrays
    215     value = ''
    216     # 4.1 36. read a byte from server.
    217     while True:
    218         if ch == '\r':  # 0x0D
    219             return value
    220         elif ch == '\n':  # 0x0A
    221             raise Exception(
    222                 'Unexpected LF when reading header value %r' % value)
    223         else:
    224             value += ch
    225         ch = receive_bytes(socket, 1)
    226 
    227 
    228 def read_frame_header(socket):
    229     received = receive_bytes(socket, 2)
    230 
    231     first_byte = ord(received[0])
    232     fin = (first_byte >> 7) & 1
    233     rsv1 = (first_byte >> 6) & 1
    234     rsv2 = (first_byte >> 5) & 1
    235     rsv3 = (first_byte >> 4) & 1
    236     opcode = first_byte & 0xf
    237 
    238     second_byte = ord(received[1])
    239     mask = (second_byte >> 7) & 1
    240     payload_length = second_byte & 0x7f
    241 
    242     if mask != 0:
    243         raise Exception(
    244             'Mask bit must be 0 for frames coming from server')
    245 
    246     if payload_length == 127:
    247         extended_payload_length = receive_bytes(socket, 8)
    248         payload_length = struct.unpack(
    249             '!Q', extended_payload_length)[0]
    250         if payload_length > 0x7FFFFFFFFFFFFFFF:
    251             raise Exception('Extended payload length >= 2^63')
    252     elif payload_length == 126:
    253         extended_payload_length = receive_bytes(socket, 2)
    254         payload_length = struct.unpack(
    255             '!H', extended_payload_length)[0]
    256 
    257     return fin, rsv1, rsv2, rsv3, opcode, payload_length
    258 
    259 
    260 class _TLSSocket(object):
    261     """Wrapper for a TLS connection."""
    262 
    263     def __init__(self, raw_socket):
    264         self._ssl = socket.ssl(raw_socket)
    265 
    266     def send(self, bytes):
    267         return self._ssl.write(bytes)
    268 
    269     def recv(self, size=-1):
    270         return self._ssl.read(size)
    271 
    272     def close(self):
    273         # Nothing to do.
    274         pass
    275 
    276 
    277 class HttpStatusException(Exception):
    278     """This exception will be raised when unexpected http status code was
    279     received as a result of handshake.
    280     """
    281 
    282     def __init__(self, name, status):
    283         super(HttpStatusException, self).__init__(name)
    284         self.status = status
    285 
    286 
    287 class WebSocketHandshake(object):
    288     """Opening handshake processor for the WebSocket protocol (RFC 6455)."""
    289 
    290     def __init__(self, options):
    291         self._logger = util.get_class_logger(self)
    292 
    293         self._options = options
    294 
    295     def handshake(self, socket):
    296         """Handshake WebSocket.
    297 
    298         Raises:
    299             Exception: handshake failed.
    300         """
    301 
    302         self._socket = socket
    303 
    304         request_line = _method_line(self._options.resource)
    305         self._logger.debug('Opening handshake Request-Line: %r', request_line)
    306         self._socket.sendall(request_line)
    307 
    308         fields = []
    309         fields.append(_UPGRADE_HEADER)
    310         fields.append(_CONNECTION_HEADER)
    311 
    312         fields.append(_format_host_header(
    313             self._options.server_host,
    314             self._options.server_port,
    315             self._options.use_tls))
    316 
    317         if self._options.version is 8:
    318             fields.append(_sec_origin_header(self._options.origin))
    319         else:
    320             fields.append(_origin_header(self._options.origin))
    321 
    322         original_key = os.urandom(16)
    323         key = base64.b64encode(original_key)
    324         self._logger.debug(
    325             'Sec-WebSocket-Key: %s (%s)', key, util.hexify(original_key))
    326         fields.append('Sec-WebSocket-Key: %s\r\n' % key)
    327 
    328         fields.append('Sec-WebSocket-Version: %d\r\n' % self._options.version)
    329 
    330         # Setting up extensions.
    331         if len(self._options.extensions) > 0:
    332             fields.append('Sec-WebSocket-Extensions: %s\r\n' %
    333                           ', '.join(self._options.extensions))
    334 
    335         self._logger.debug('Opening handshake request headers: %r', fields)
    336 
    337         for field in fields:
    338             self._socket.sendall(field)
    339         self._socket.sendall('\r\n')
    340 
    341         self._logger.info('Sent opening handshake request')
    342 
    343         field = ''
    344         while True:
    345             ch = receive_bytes(self._socket, 1)
    346             field += ch
    347             if ch == '\n':
    348                 break
    349 
    350         self._logger.debug('Opening handshake Response-Line: %r', field)
    351 
    352         if len(field) < 7 or not field.endswith('\r\n'):
    353             raise Exception('Wrong status line: %r' % field)
    354         m = re.match('[^ ]* ([^ ]*) .*', field)
    355         if m is None:
    356             raise Exception(
    357                 'No HTTP status code found in status line: %r' % field)
    358         code = m.group(1)
    359         if not re.match('[0-9][0-9][0-9]', code):
    360             raise Exception(
    361                 'HTTP status code %r is not three digit in status line: %r' %
    362                 (code, field))
    363         if code != '101':
    364             raise HttpStatusException(
    365                 'Expected HTTP status code 101 but found %r in status line: '
    366                 '%r' % (code, field), int(code))
    367         fields = _read_fields(self._socket)
    368         ch = receive_bytes(self._socket, 1)
    369         if ch != '\n':  # 0x0A
    370             raise Exception('Expected LF but found: %r' % ch)
    371 
    372         self._logger.debug('Opening handshake response headers: %r', fields)
    373 
    374         # Check /fields/
    375         if len(fields['upgrade']) != 1:
    376             raise Exception(
    377                 'Multiple Upgrade headers found: %s' % fields['upgrade'])
    378         if len(fields['connection']) != 1:
    379             raise Exception(
    380                 'Multiple Connection headers found: %s' % fields['connection'])
    381         if fields['upgrade'][0] != 'websocket':
    382             raise Exception(
    383                 'Unexpected Upgrade header value: %s' % fields['upgrade'][0])
    384         if fields['connection'][0].lower() != 'upgrade':
    385             raise Exception(
    386                 'Unexpected Connection header value: %s' %
    387                 fields['connection'][0])
    388 
    389         if len(fields['sec-websocket-accept']) != 1:
    390             raise Exception(
    391                 'Multiple Sec-WebSocket-Accept headers found: %s' %
    392                 fields['sec-websocket-accept'])
    393 
    394         accept = fields['sec-websocket-accept'][0]
    395 
    396         # Validate
    397         try:
    398             decoded_accept = base64.b64decode(accept)
    399         except TypeError, e:
    400             raise HandshakeException(
    401                 'Illegal value for header Sec-WebSocket-Accept: ' + accept)
    402 
    403         if len(decoded_accept) != 20:
    404             raise HandshakeException(
    405                 'Decoded value of Sec-WebSocket-Accept is not 20-byte long')
    406 
    407         self._logger.debug('Actual Sec-WebSocket-Accept: %r (%s)',
    408                            accept, util.hexify(decoded_accept))
    409 
    410         original_expected_accept = util.sha1_hash(
    411             key + WEBSOCKET_ACCEPT_UUID).digest()
    412         expected_accept = base64.b64encode(original_expected_accept)
    413 
    414         self._logger.debug('Expected Sec-WebSocket-Accept: %r (%s)',
    415                            expected_accept,
    416                            util.hexify(original_expected_accept))
    417 
    418         if accept != expected_accept:
    419             raise Exception(
    420                 'Invalid Sec-WebSocket-Accept header: %r (expected) != %r '
    421                 '(actual)' % (accept, expected_accept))
    422 
    423         server_extensions_header = fields.get('sec-websocket-extensions')
    424         if (server_extensions_header is None or
    425             len(server_extensions_header) != 1):
    426             accepted_extensions = []
    427         else:
    428             accepted_extensions = server_extensions_header[0].split(',')
    429             # TODO(tyoshino): Follow the ABNF in the spec.
    430             accepted_extensions = [s.strip() for s in accepted_extensions]
    431 
    432         # Scan accepted extension list to check if there is any unrecognized
    433         # extensions or extensions we didn't request in it. Then, for
    434         # extensions we request, parse them and store parameters. They will be
    435         # used later by each extension.
    436         deflate_stream_accepted = False
    437         deflate_frame_accepted = False
    438         mux_accepted = False
    439         for extension in accepted_extensions:
    440             if extension == '':
    441                 continue
    442             if extension == _DEFLATE_STREAM_EXTENSION:
    443                 if self._options.use_deflate_stream:
    444                     deflate_stream_accepted = True
    445                     continue
    446             if extension == _DEFLATE_FRAME_EXTENSION:
    447                 if self._options.use_deflate_frame:
    448                     deflate_frame_accepted = True
    449                     continue
    450             if extension == _MUX_EXTENSION:
    451                 if self._options.use_mux:
    452                     mux_accepted = True
    453                     continue
    454 
    455             raise Exception(
    456                 'Received unrecognized extension: %s' % extension)
    457 
    458         # Let all extensions check the response for extension request.
    459 
    460         if self._options.use_deflate_stream and not deflate_stream_accepted:
    461             raise Exception('%s extension not accepted' %
    462                             _DEFLATE_STREAM_EXTENSION)
    463 
    464         if (self._options.use_deflate_frame and
    465             not deflate_frame_accepted):
    466             raise Exception('%s extension not accepted' %
    467                             _DEFLATE_FRAME_EXTENSION)
    468 
    469         if self._options.use_mux and not mux_accepted:
    470             raise Exception('%s extension not accepted' % _MUX_EXTENSION)
    471 
    472 
    473 class WebSocketHybi00Handshake(object):
    474     """Opening handshake processor for the WebSocket protocol version HyBi 00.
    475     """
    476 
    477     def __init__(self, options, draft_field):
    478         self._logger = util.get_class_logger(self)
    479 
    480         self._options = options
    481         self._draft_field = draft_field
    482 
    483     def handshake(self, socket):
    484         """Handshake WebSocket.
    485 
    486         Raises:
    487             Exception: handshake failed.
    488         """
    489 
    490         self._socket = socket
    491 
    492         # 4.1 5. send request line.
    493         request_line = _method_line(self._options.resource)
    494         self._logger.debug('Opening handshake Request-Line: %r', request_line)
    495         self._socket.sendall(request_line)
    496         # 4.1 6. Let /fields/ be an empty list of strings.
    497         fields = []
    498         # 4.1 7. Add the string "Upgrade: WebSocket" to /fields/.
    499         fields.append(_UPGRADE_HEADER_HIXIE75)
    500         # 4.1 8. Add the string "Connection: Upgrade" to /fields/.
    501         fields.append(_CONNECTION_HEADER)
    502         # 4.1 9-12. Add Host: field to /fields/.
    503         fields.append(_format_host_header(
    504             self._options.server_host,
    505             self._options.server_port,
    506             self._options.use_tls))
    507         # 4.1 13. Add Origin: field to /fields/.
    508         fields.append(_origin_header(self._options.origin))
    509         # TODO: 4.1 14 Add Sec-WebSocket-Protocol: field to /fields/.
    510         # TODO: 4.1 15 Add cookie headers to /fields/.
    511 
    512         # 4.1 16-23. Add Sec-WebSocket-Key<n> to /fields/.
    513         self._number1, key1 = self._generate_sec_websocket_key()
    514         self._logger.debug('Number1: %d', self._number1)
    515         fields.append('Sec-WebSocket-Key1: %s\r\n' % key1)
    516         self._number2, key2 = self._generate_sec_websocket_key()
    517         self._logger.debug('Number2: %d', self._number1)
    518         fields.append('Sec-WebSocket-Key2: %s\r\n' % key2)
    519 
    520         fields.append('Sec-WebSocket-Draft: %s\r\n' % self._draft_field)
    521 
    522         # 4.1 24. For each string in /fields/, in a random order: send the
    523         # string, encoded as UTF-8, followed by a UTF-8 encoded U+000D CARRIAGE
    524         # RETURN U+000A LINE FEED character pair (CRLF).
    525         random.shuffle(fields)
    526 
    527         self._logger.debug('Opening handshake request headers: %r', fields)
    528         for field in fields:
    529             self._socket.sendall(field)
    530 
    531         # 4.1 25. send a UTF-8-encoded U+000D CARRIAGE RETURN U+000A LINE FEED
    532         # character pair (CRLF).
    533         self._socket.sendall('\r\n')
    534         # 4.1 26. let /key3/ be a string consisting of eight random bytes (or
    535         # equivalently, a random 64 bit integer encoded in a big-endian order).
    536         self._key3 = self._generate_key3()
    537         # 4.1 27. send /key3/ to the server.
    538         self._socket.sendall(self._key3)
    539         self._logger.debug(
    540             'Key3: %r (%s)', self._key3, util.hexify(self._key3))
    541 
    542         self._logger.info('Sent opening handshake request')
    543 
    544         # 4.1 28. Read bytes from the server until either the connection
    545         # closes, or a 0x0A byte is read. let /field/ be these bytes, including
    546         # the 0x0A bytes.
    547         field = ''
    548         while True:
    549             ch = receive_bytes(self._socket, 1)
    550             field += ch
    551             if ch == '\n':
    552                 break
    553 
    554         self._logger.debug('Opening handshake Response-Line: %r', field)
    555 
    556         # if /field/ is not at least seven bytes long, or if the last
    557         # two bytes aren't 0x0D and 0x0A respectively, or if it does not
    558         # contain at least two 0x20 bytes, then fail the WebSocket connection
    559         # and abort these steps.
    560         if len(field) < 7 or not field.endswith('\r\n'):
    561             raise Exception('Wrong status line: %r' % field)
    562         m = re.match('[^ ]* ([^ ]*) .*', field)
    563         if m is None:
    564             raise Exception('No code found in status line: %r' % field)
    565         # 4.1 29. let /code/ be the substring of /field/ that starts from the
    566         # byte after the first 0x20 byte, and ends with the byte before the
    567         # second 0x20 byte.
    568         code = m.group(1)
    569         # 4.1 30. if /code/ is not three bytes long, or if any of the bytes in
    570         # /code/ are not in the range 0x30 to 0x90, then fail the WebSocket
    571         # connection and abort these steps.
    572         if not re.match('[0-9][0-9][0-9]', code):
    573             raise Exception(
    574                 'HTTP status code %r is not three digit in status line: %r' %
    575                 (code, field))
    576         # 4.1 31. if /code/, interpreted as UTF-8, is "101", then move to the
    577         # next step.
    578         if code != '101':
    579             raise HttpStatusException(
    580                 'Expected HTTP status code 101 but found %r in status line: '
    581                 '%r' % (code, field), int(code))
    582         # 4.1 32-39. read fields into /fields/
    583         fields = _read_fields(self._socket)
    584 
    585         self._logger.debug('Opening handshake response headers: %r', fields)
    586 
    587         # 4.1 40. _Fields processing_
    588         # read a byte from server
    589         ch = receive_bytes(self._socket, 1)
    590         if ch != '\n':  # 0x0A
    591             raise Exception('Expected LF but found %r' % ch)
    592         # 4.1 41. check /fields/
    593         if len(fields['upgrade']) != 1:
    594             raise Exception(
    595                 'Multiple Upgrade headers found: %s' % fields['upgrade'])
    596         if len(fields['connection']) != 1:
    597             raise Exception(
    598                 'Multiple Connection headers found: %s' % fields['connection'])
    599         if len(fields['sec-websocket-origin']) != 1:
    600             raise Exception(
    601                 'Multiple Sec-WebSocket-Origin headers found: %s' %
    602                 fields['sec-sebsocket-origin'])
    603         if len(fields['sec-websocket-location']) != 1:
    604             raise Exception(
    605                 'Multiple Sec-WebSocket-Location headers found: %s' %
    606                 fields['sec-sebsocket-location'])
    607         # TODO(ukai): protocol
    608         # if the entry's name is "upgrade"
    609         #  if the value is not exactly equal to the string "WebSocket",
    610         #  then fail the WebSocket connection and abort these steps.
    611         if fields['upgrade'][0] != 'WebSocket':
    612             raise Exception(
    613                 'Unexpected Upgrade header value: %s' % fields['upgrade'][0])
    614         # if the entry's name is "connection"
    615         #  if the value, converted to ASCII lowercase, is not exactly equal
    616         #  to the string "upgrade", then fail the WebSocket connection and
    617         #  abort these steps.
    618         if fields['connection'][0].lower() != 'upgrade':
    619             raise Exception(
    620                 'Unexpected Connection header value: %s' %
    621                 fields['connection'][0])
    622         # TODO(ukai): check origin, location, cookie, ..
    623 
    624         # 4.1 42. let /challenge/ be the concatenation of /number_1/,
    625         # expressed as a big endian 32 bit integer, /number_2/, expressed
    626         # as big endian 32 bit integer, and the eight bytes of /key_3/ in the
    627         # order they were sent on the wire.
    628         challenge = struct.pack('!I', self._number1)
    629         challenge += struct.pack('!I', self._number2)
    630         challenge += self._key3
    631 
    632         self._logger.debug(
    633             'Challenge: %r (%s)', challenge, util.hexify(challenge))
    634 
    635         # 4.1 43. let /expected/ be the MD5 fingerprint of /challenge/ as a
    636         # big-endian 128 bit string.
    637         expected = util.md5_hash(challenge).digest()
    638         self._logger.debug(
    639             'Expected challenge response: %r (%s)',
    640             expected, util.hexify(expected))
    641 
    642         # 4.1 44. read sixteen bytes from the server.
    643         # let /reply/ be those bytes.
    644         reply = receive_bytes(self._socket, 16)
    645         self._logger.debug(
    646             'Actual challenge response: %r (%s)', reply, util.hexify(reply))
    647 
    648         # 4.1 45. if /reply/ does not exactly equal /expected/, then fail
    649         # the WebSocket connection and abort these steps.
    650         if expected != reply:
    651             raise Exception(
    652                 'Bad challenge response: %r (expected) != %r (actual)' %
    653                 (expected, reply))
    654         # 4.1 46. The *WebSocket connection is established*.
    655 
    656     def _generate_sec_websocket_key(self):
    657         # 4.1 16. let /spaces_n/ be a random integer from 1 to 12 inclusive.
    658         spaces = random.randint(1, 12)
    659         # 4.1 17. let /max_n/ be the largest integer not greater than
    660         #  4,294,967,295 divided by /spaces_n/.
    661         maxnum = 4294967295 / spaces
    662         # 4.1 18. let /number_n/ be a random integer from 0 to /max_n/
    663         # inclusive.
    664         number = random.randint(0, maxnum)
    665         # 4.1 19. let /product_n/ be the result of multiplying /number_n/ and
    666         # /spaces_n/ together.
    667         product = number * spaces
    668         # 4.1 20. let /key_n/ be a string consisting of /product_n/, expressed
    669         # in base ten using the numerals in the range U+0030 DIGIT ZERO (0) to
    670         # U+0039 DIGIT NINE (9).
    671         key = str(product)
    672         # 4.1 21. insert between one and twelve random characters from the
    673         # range U+0021 to U+002F and U+003A to U+007E into /key_n/ at random
    674         # positions.
    675         available_chars = range(0x21, 0x2f + 1) + range(0x3a, 0x7e + 1)
    676         n = random.randint(1, 12)
    677         for _ in xrange(n):
    678             ch = random.choice(available_chars)
    679             pos = random.randint(0, len(key))
    680             key = key[0:pos] + chr(ch) + key[pos:]
    681         # 4.1 22. insert /spaces_n/ U+0020 SPACE characters into /key_n/ at
    682         # random positions other than start or end of the string.
    683         for _ in xrange(spaces):
    684             pos = random.randint(1, len(key) - 1)
    685             key = key[0:pos] + ' ' + key[pos:]
    686         return number, key
    687 
    688     def _generate_key3(self):
    689         # 4.1 26. let /key3/ be a string consisting of eight random bytes (or
    690         # equivalently, a random 64 bit integer encoded in a big-endian order).
    691         return ''.join([chr(random.randint(0, 255)) for _ in xrange(8)])
    692 
    693 
    694 class WebSocketHixie75Handshake(object):
    695     """WebSocket handshake processor for IETF Hixie 75."""
    696 
    697     _EXPECTED_RESPONSE = (
    698         'HTTP/1.1 101 Web Socket Protocol Handshake\r\n' +
    699         _UPGRADE_HEADER_HIXIE75 +
    700         _CONNECTION_HEADER)
    701 
    702     def __init__(self, options):
    703         self._logger = util.get_class_logger(self)
    704 
    705         self._options = options
    706 
    707     def _skip_headers(self):
    708         terminator = '\r\n\r\n'
    709         pos = 0
    710         while pos < len(terminator):
    711             received = receive_bytes(self._socket, 1)
    712             if received == terminator[pos]:
    713                 pos += 1
    714             elif received == terminator[0]:
    715                 pos = 1
    716             else:
    717                 pos = 0
    718 
    719     def handshake(self, socket):
    720         self._socket = socket
    721 
    722         request_line = _method_line(self._options.resource)
    723         self._logger.debug('Opening handshake Request-Line: %r', request_line)
    724         self._socket.sendall(request_line)
    725 
    726         headers = _UPGRADE_HEADER_HIXIE75 + _CONNECTION_HEADER
    727         headers += _format_host_header(
    728             self._options.server_host,
    729             self._options.server_port,
    730             self._options.use_tls)
    731         headers += _origin_header(self._options.origin)
    732         self._logger.debug('Opening handshake request headers: %r', headers)
    733         self._socket.sendall(headers)
    734 
    735         self._socket.sendall('\r\n')
    736 
    737         self._logger.info('Sent opening handshake request')
    738 
    739         for expected_char in WebSocketHixie75Handshake._EXPECTED_RESPONSE:
    740             received = receive_bytes(self._socket, 1)
    741             if expected_char != received:
    742                 raise Exception('Handshake failure')
    743         # We cut corners and skip other headers.
    744         self._skip_headers()
    745 
    746 
    747 class WebSocketStream(object):
    748     """Frame processor for the WebSocket protocol (RFC 6455)."""
    749 
    750     def __init__(self, socket, handshake):
    751         self._handshake = handshake
    752         if self._handshake._options.use_deflate_stream:
    753             self._socket = util.DeflateSocket(socket)
    754         else:
    755             self._socket = socket
    756 
    757         # Filters applied to application data part of data frames.
    758         self._outgoing_frame_filter = None
    759         self._incoming_frame_filter = None
    760 
    761         if self._handshake._options.use_deflate_frame:
    762             self._outgoing_frame_filter = (
    763                 util._RFC1979Deflater(None, False))
    764             self._incoming_frame_filter = util._RFC1979Inflater()
    765 
    766         self._fragmented = False
    767 
    768     def _mask_hybi(self, s):
    769         # TODO(tyoshino): os.urandom does open/read/close for every call. If
    770         # performance matters, change this to some library call that generates
    771         # cryptographically secure pseudo random number sequence.
    772         masking_nonce = os.urandom(4)
    773         result = [masking_nonce]
    774         count = 0
    775         for c in s:
    776             result.append(chr(ord(c) ^ ord(masking_nonce[count])))
    777             count = (count + 1) % len(masking_nonce)
    778         return ''.join(result)
    779 
    780     def send_frame_of_arbitrary_bytes(self, header, body):
    781         self._socket.sendall(header + self._mask_hybi(body))
    782 
    783     def send_data(self, payload, frame_type, end=True, mask=True):
    784         if self._outgoing_frame_filter is not None:
    785             payload = self._outgoing_frame_filter.filter(payload)
    786 
    787         if self._fragmented:
    788             opcode = OPCODE_CONTINUATION
    789         else:
    790             opcode = frame_type
    791 
    792         if end:
    793             self._fragmented = False
    794             fin = 1
    795         else:
    796             self._fragmented = True
    797             fin = 0
    798 
    799         rsv1 = 0
    800         if self._handshake._options.use_deflate_frame:
    801             rsv1 = 1
    802 
    803         if mask:
    804             mask_bit = 1 << 7
    805         else:
    806             mask_bit = 0
    807 
    808         header = chr(fin << 7 | rsv1 << 6 | opcode)
    809         payload_length = len(payload)
    810         if payload_length <= 125:
    811             header += chr(mask_bit | payload_length)
    812         elif payload_length < 1 << 16:
    813             header += chr(mask_bit | 126) + struct.pack('!H', payload_length)
    814         elif payload_length < 1 << 63:
    815             header += chr(mask_bit | 127) + struct.pack('!Q', payload_length)
    816         else:
    817             raise Exception('Too long payload (%d byte)' % payload_length)
    818         if mask:
    819             payload = self._mask_hybi(payload)
    820         self._socket.sendall(header + payload)
    821 
    822     def send_binary(self, payload, end=True, mask=True):
    823         self.send_data(payload, OPCODE_BINARY, end, mask)
    824 
    825     def send_text(self, payload, end=True, mask=True):
    826         self.send_data(payload.encode('utf-8'), OPCODE_TEXT, end, mask)
    827 
    828     def _assert_receive_data(self, payload, opcode, fin, rsv1, rsv2, rsv3):
    829         (actual_fin, actual_rsv1, actual_rsv2, actual_rsv3, actual_opcode,
    830          payload_length) = read_frame_header(self._socket)
    831 
    832         if actual_opcode != opcode:
    833             raise Exception(
    834                 'Unexpected opcode: %d (expected) vs %d (actual)' %
    835                 (opcode, actual_opcode))
    836 
    837         if actual_fin != fin:
    838             raise Exception(
    839                 'Unexpected fin: %d (expected) vs %d (actual)' %
    840                 (fin, actual_fin))
    841 
    842         if rsv1 is None:
    843             rsv1 = 0
    844             if self._handshake._options.use_deflate_frame:
    845                 rsv1 = 1
    846 
    847         if rsv2 is None:
    848             rsv2 = 0
    849 
    850         if rsv3 is None:
    851             rsv3 = 0
    852 
    853         if actual_rsv1 != rsv1:
    854             raise Exception(
    855                 'Unexpected rsv1: %r (expected) vs %r (actual)' %
    856                 (rsv1, actual_rsv1))
    857 
    858         if actual_rsv2 != rsv2:
    859             raise Exception(
    860                 'Unexpected rsv2: %r (expected) vs %r (actual)' %
    861                 (rsv2, actual_rsv2))
    862 
    863         if actual_rsv3 != rsv3:
    864             raise Exception(
    865                 'Unexpected rsv3: %r (expected) vs %r (actual)' %
    866                 (rsv3, actual_rsv3))
    867 
    868         received = receive_bytes(self._socket, payload_length)
    869 
    870         if self._incoming_frame_filter is not None:
    871             received = self._incoming_frame_filter.filter(received)
    872 
    873         if len(received) != len(payload):
    874             raise Exception(
    875                 'Unexpected payload length: %d (expected) vs %d (actual)' %
    876                 (len(payload), len(received)))
    877 
    878         if payload != received:
    879             raise Exception(
    880                 'Unexpected payload: %r (expected) vs %r (actual)' %
    881                 (payload, received))
    882 
    883     def assert_receive_binary(self, payload, opcode=OPCODE_BINARY, fin=1,
    884                               rsv1=None, rsv2=None, rsv3=None):
    885         self._assert_receive_data(payload, opcode, fin, rsv1, rsv2, rsv3)
    886 
    887     def assert_receive_text(self, payload, opcode=OPCODE_TEXT, fin=1,
    888                             rsv1=None, rsv2=None, rsv3=None):
    889         self._assert_receive_data(payload.encode('utf-8'), opcode, fin, rsv1,
    890                                   rsv2, rsv3)
    891 
    892     def _build_close_frame(self, code, reason, mask):
    893         frame = chr(1 << 7 | OPCODE_CLOSE)
    894 
    895         if code is not None:
    896             body = struct.pack('!H', code) + reason.encode('utf-8')
    897         else:
    898             body = ''
    899         if mask:
    900             frame += chr(1 << 7 | len(body)) + self._mask_hybi(body)
    901         else:
    902             frame += chr(len(body)) + body
    903         return frame
    904 
    905     def send_close(self, code, reason):
    906         self._socket.sendall(
    907             self._build_close_frame(code, reason, True))
    908 
    909     def assert_receive_close(self, code, reason):
    910         expected_frame = self._build_close_frame(code, reason, False)
    911         actual_frame = receive_bytes(self._socket, len(expected_frame))
    912         if actual_frame != expected_frame:
    913             raise Exception(
    914                 'Unexpected close frame: %r (expected) vs %r (actual)' %
    915                 (expected_frame, actual_frame))
    916 
    917 
    918 class WebSocketStreamHixie75(object):
    919     """Frame processor for the WebSocket protocol version Hixie 75 and HyBi 00.
    920     """
    921 
    922     _CLOSE_FRAME = '\xff\x00'
    923 
    924     def __init__(self, socket, unused_handshake):
    925         self._socket = socket
    926 
    927     def send_frame_of_arbitrary_bytes(self, header, body):
    928         self._socket.sendall(header + body)
    929 
    930     def send_data(self, payload, unused_frame_typem, unused_end, unused_mask):
    931         frame = ''.join(['\x00', payload, '\xff'])
    932         self._socket.sendall(frame)
    933 
    934     def send_binary(self, unused_payload, unused_end, unused_mask):
    935         pass
    936 
    937     def send_text(self, payload, unused_end, unused_mask):
    938         encoded_payload = payload.encode('utf-8')
    939         frame = ''.join(['\x00', encoded_payload, '\xff'])
    940         self._socket.sendall(frame)
    941 
    942     def assert_receive_binary(self, payload, opcode=OPCODE_BINARY, fin=1,
    943                               rsv1=0, rsv2=0, rsv3=0):
    944         raise Exception('Binary frame is not supported in hixie75')
    945 
    946     def assert_receive_text(self, payload):
    947         received = receive_bytes(self._socket, 1)
    948 
    949         if received != '\x00':
    950             raise Exception(
    951                 'Unexpected frame type: %d (expected) vs %d (actual)' %
    952                 (0, ord(received)))
    953 
    954         received = receive_bytes(self._socket, len(payload) + 1)
    955         if received[-1] != '\xff':
    956             raise Exception(
    957                 'Termination expected: 0xff (expected) vs %r (actual)' %
    958                 received)
    959 
    960         if received[0:-1] != payload:
    961             raise Exception(
    962                 'Unexpected payload: %r (expected) vs %r (actual)' %
    963                 (payload, received[0:-1]))
    964 
    965     def send_close(self, code, reason):
    966         self._socket.sendall(self._CLOSE_FRAME)
    967 
    968     def assert_receive_close(self, unused_code, unused_reason):
    969         closing = receive_bytes(self._socket, len(self._CLOSE_FRAME))
    970         if closing != self._CLOSE_FRAME:
    971             raise Exception('Didn\'t receive closing handshake')
    972 
    973 
    974 class ClientOptions(object):
    975     """Holds option values to configure the Client object."""
    976 
    977     def __init__(self):
    978         self.version = 13
    979         self.server_host = ''
    980         self.origin = ''
    981         self.resource = ''
    982         self.server_port = -1
    983         self.socket_timeout = 1000
    984         self.use_tls = False
    985         self.extensions = []
    986         # Enable deflate-stream.
    987         self.use_deflate_stream = False
    988         # Enable deflate-application-data.
    989         self.use_deflate_frame = False
    990         # Enable mux
    991         self.use_mux = False
    992 
    993     def enable_deflate_stream(self):
    994         self.use_deflate_stream = True
    995         self.extensions.append(_DEFLATE_STREAM_EXTENSION)
    996 
    997     def enable_deflate_frame(self):
    998         self.use_deflate_frame = True
    999         self.extensions.append(_DEFLATE_FRAME_EXTENSION)
   1000 
   1001     def enable_mux(self):
   1002         self.use_mux = True
   1003         self.extensions.append(_MUX_EXTENSION)
   1004 
   1005 
   1006 class Client(object):
   1007     """WebSocket client."""
   1008 
   1009     def __init__(self, options, handshake, stream_class):
   1010         self._logger = util.get_class_logger(self)
   1011 
   1012         self._options = options
   1013         self._socket = None
   1014 
   1015         self._handshake = handshake
   1016         self._stream_class = stream_class
   1017 
   1018     def connect(self):
   1019         self._socket = socket.socket()
   1020         self._socket.settimeout(self._options.socket_timeout)
   1021 
   1022         self._socket.connect((self._options.server_host,
   1023                               self._options.server_port))
   1024         if self._options.use_tls:
   1025             self._socket = _TLSSocket(self._socket)
   1026 
   1027         self._handshake.handshake(self._socket)
   1028 
   1029         self._stream = self._stream_class(self._socket, self._handshake)
   1030 
   1031         self._logger.info('Connection established')
   1032 
   1033     def send_frame_of_arbitrary_bytes(self, header, body):
   1034         self._stream.send_frame_of_arbitrary_bytes(header, body)
   1035 
   1036     def send_message(self, message, end=True, binary=False, raw=False,
   1037                      mask=True):
   1038         if binary:
   1039             self._stream.send_binary(message, end, mask)
   1040         elif raw:
   1041             self._stream.send_data(message, OPCODE_TEXT, end, mask)
   1042         else:
   1043             self._stream.send_text(message, end, mask)
   1044 
   1045     def assert_receive(self, payload, binary=False):
   1046         if binary:
   1047             self._stream.assert_receive_binary(payload)
   1048         else:
   1049             self._stream.assert_receive_text(payload)
   1050 
   1051     def send_close(self, code=STATUS_NORMAL_CLOSURE, reason=''):
   1052         self._stream.send_close(code, reason)
   1053 
   1054     def assert_receive_close(self, code=STATUS_NORMAL_CLOSURE, reason=''):
   1055         self._stream.assert_receive_close(code, reason)
   1056 
   1057     def close_socket(self):
   1058         self._socket.close()
   1059 
   1060     def assert_connection_closed(self):
   1061         try:
   1062             read_data = receive_bytes(self._socket, 1)
   1063         except Exception, e:
   1064             if str(e).find(
   1065                 'Connection closed before receiving requested length ') == 0:
   1066                 return
   1067             try:
   1068                 error_number, message = e
   1069                 for error_name in ['ECONNRESET', 'WSAECONNRESET']:
   1070                     if (error_name in dir(errno) and
   1071                         error_number == getattr(errno, error_name)):
   1072                         return
   1073             except:
   1074                 raise e
   1075             raise e
   1076 
   1077         raise Exception('Connection is not closed (Read: %r)' % read_data)
   1078 
   1079 
   1080 def create_client(options):
   1081     return Client(
   1082         options, WebSocketHandshake(options), WebSocketStream)
   1083 
   1084 
   1085 def create_client_hybi00(options):
   1086     return Client(
   1087         options,
   1088         WebSocketHybi00Handshake(options, '0'),
   1089         WebSocketStreamHixie75)
   1090 
   1091 
   1092 def create_client_hixie75(options):
   1093     return Client(
   1094         options, WebSocketHixie75Handshake(options), WebSocketStreamHixie75)
   1095 
   1096 
   1097 # vi:sts=4 sw=4 et
   1098