Home | History | Annotate | Download | only in handshake
      1 # Copyright 2012, Google Inc.
      2 # All rights reserved.
      3 #
      4 # Redistribution and use in source and binary forms, with or without
      5 # modification, are permitted provided that the following conditions are
      6 # met:
      7 #
      8 #     * Redistributions of source code must retain the above copyright
      9 # notice, this list of conditions and the following disclaimer.
     10 #     * Redistributions in binary form must reproduce the above
     11 # copyright notice, this list of conditions and the following disclaimer
     12 # in the documentation and/or other materials provided with the
     13 # distribution.
     14 #     * Neither the name of Google Inc. nor the names of its
     15 # contributors may be used to endorse or promote products derived from
     16 # this software without specific prior written permission.
     17 #
     18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     19 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     20 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     21 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     22 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     23 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     24 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     25 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     26 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     27 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     29 
     30 
     31 """This file provides the opening handshake processor for the WebSocket
     32 protocol (RFC 6455).
     33 
     34 Specification:
     35 http://tools.ietf.org/html/rfc6455
     36 """
     37 
     38 
# Note: request.connection.write is used in this module, even though the
# mod_python documentation says that it should be used only in connection
# handlers. Unfortunately, we have no other option. For example, request.write
# is not suitable because it doesn't allow writing raw bytes directly.
     43 
     44 
     45 import base64
     46 import logging
     47 import os
     48 import re
     49 
     50 from mod_pywebsocket import common
     51 from mod_pywebsocket.extensions import get_extension_processor
     52 from mod_pywebsocket.extensions import is_compression_extension
     53 from mod_pywebsocket.handshake._base import check_request_line
     54 from mod_pywebsocket.handshake._base import format_header
     55 from mod_pywebsocket.handshake._base import get_mandatory_header
     56 from mod_pywebsocket.handshake._base import HandshakeException
     57 from mod_pywebsocket.handshake._base import parse_token_list
     58 from mod_pywebsocket.handshake._base import validate_mandatory_header
     59 from mod_pywebsocket.handshake._base import validate_subprotocol
     60 from mod_pywebsocket.handshake._base import VersionException
     61 from mod_pywebsocket.stream import Stream
     62 from mod_pywebsocket.stream import StreamOptions
     63 from mod_pywebsocket import util
     64 
     65 
     66 # Used to validate the value in the Sec-WebSocket-Key header strictly. RFC 4648
     67 # disallows non-zero padding, so the character right before == must be any of
     68 # A, Q, g and w.
     69 _SEC_WEBSOCKET_KEY_REGEX = re.compile('^[+/0-9A-Za-z]{21}[AQgw]==$')
     70 
     71 # Defining aliases for values used frequently.
     72 _VERSION_LATEST = common.VERSION_HYBI_LATEST
     73 _VERSION_LATEST_STRING = str(_VERSION_LATEST)
     74 _SUPPORTED_VERSIONS = [
     75     _VERSION_LATEST,
     76 ]
     77 
     78 
     79 def compute_accept(key):
     80     """Computes value for the Sec-WebSocket-Accept header from value of the
     81     Sec-WebSocket-Key header.
     82     """
     83 
     84     accept_binary = util.sha1_hash(
     85         key + common.WEBSOCKET_ACCEPT_UUID).digest()
     86     accept = base64.b64encode(accept_binary)
     87 
     88     return (accept, accept_binary)
     89 
     90 
     91 class Handshaker(object):
     92     """Opening handshake processor for the WebSocket protocol (RFC 6455)."""
     93 
     94     def __init__(self, request, dispatcher):
     95         """Construct an instance.
     96 
     97         Args:
     98             request: mod_python request.
     99             dispatcher: Dispatcher (dispatch.Dispatcher).
    100 
    101         Handshaker will add attributes such as ws_resource during handshake.
    102         """
    103 
    104         self._logger = util.get_class_logger(self)
    105 
    106         self._request = request
    107         self._dispatcher = dispatcher
    108 
    109     def _validate_connection_header(self):
    110         connection = get_mandatory_header(
    111             self._request, common.CONNECTION_HEADER)
    112 
    113         try:
    114             connection_tokens = parse_token_list(connection)
    115         except HandshakeException, e:
    116             raise HandshakeException(
    117                 'Failed to parse %s: %s' % (common.CONNECTION_HEADER, e))
    118 
    119         connection_is_valid = False
    120         for token in connection_tokens:
    121             if token.lower() == common.UPGRADE_CONNECTION_TYPE.lower():
    122                 connection_is_valid = True
    123                 break
    124         if not connection_is_valid:
    125             raise HandshakeException(
    126                 '%s header doesn\'t contain "%s"' %
    127                 (common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))
    128 
    129     def do_handshake(self):
    130         self._request.ws_close_code = None
    131         self._request.ws_close_reason = None
    132 
    133         # Parsing.
    134 
    135         check_request_line(self._request)
    136 
    137         validate_mandatory_header(
    138             self._request,
    139             common.UPGRADE_HEADER,
    140             common.WEBSOCKET_UPGRADE_TYPE)
    141 
    142         self._validate_connection_header()
    143 
    144         self._request.ws_resource = self._request.uri
    145 
    146         unused_host = get_mandatory_header(self._request, common.HOST_HEADER)
    147 
    148         self._request.ws_version = self._check_version()
    149 
    150         try:
    151             self._get_origin()
    152             self._set_protocol()
    153             self._parse_extensions()
    154 
    155             # Key validation, response generation.
    156 
    157             key = self._get_key()
    158             (accept, accept_binary) = compute_accept(key)
    159             self._logger.debug(
    160                 '%s: %r (%s)',
    161                 common.SEC_WEBSOCKET_ACCEPT_HEADER,
    162                 accept,
    163                 util.hexify(accept_binary))
    164 
    165             self._logger.debug('Protocol version is RFC 6455')
    166 
    167             # Setup extension processors.
    168 
    169             processors = []
    170             if self._request.ws_requested_extensions is not None:
    171                 for extension_request in self._request.ws_requested_extensions:
    172                     processor = get_extension_processor(extension_request)
    173                     # Unknown extension requests are just ignored.
    174                     if processor is not None:
    175                         processors.append(processor)
    176             self._request.ws_extension_processors = processors
    177 
    178             # List of extra headers. The extra handshake handler may add header
    179             # data as name/value pairs to this list and pywebsocket appends
    180             # them to the WebSocket handshake.
    181             self._request.extra_headers = []
    182 
    183             # Extra handshake handler may modify/remove processors.
    184             self._dispatcher.do_extra_handshake(self._request)
    185             processors = filter(lambda processor: processor is not None,
    186                                 self._request.ws_extension_processors)
    187 
    188             # Ask each processor if there are extensions on the request which
    189             # cannot co-exist. When processor decided other processors cannot
    190             # co-exist with it, the processor marks them (or itself) as
    191             # "inactive". The first extension processor has the right to
    192             # make the final call.
    193             for processor in reversed(processors):
    194                 if processor.is_active():
    195                     processor.check_consistency_with_other_processors(
    196                         processors)
    197             processors = filter(lambda processor: processor.is_active(),
    198                                 processors)
    199 
    200             accepted_extensions = []
    201 
    202             # We need to take into account of mux extension here.
    203             # If mux extension exists:
    204             # - Remove processors of extensions for logical channel,
    205             #   which are processors located before the mux processor
    206             # - Pass extension requests for logical channel to mux processor
    207             # - Attach the mux processor to the request. It will be referred
    208             #   by dispatcher to see whether the dispatcher should use mux
    209             #   handler or not.
    210             mux_index = -1
    211             for i, processor in enumerate(processors):
    212                 if processor.name() == common.MUX_EXTENSION:
    213                     mux_index = i
    214                     break
    215             if mux_index >= 0:
    216                 logical_channel_extensions = []
    217                 for processor in processors[:mux_index]:
    218                     logical_channel_extensions.append(processor.request())
    219                     processor.set_active(False)
    220                 self._request.mux_processor = processors[mux_index]
    221                 self._request.mux_processor.set_extensions(
    222                     logical_channel_extensions)
    223                 processors = filter(lambda processor: processor.is_active(),
    224                                     processors)
    225 
    226             stream_options = StreamOptions()
    227 
    228             for index, processor in enumerate(processors):
    229                 if not processor.is_active():
    230                     continue
    231 
    232                 extension_response = processor.get_extension_response()
    233                 if extension_response is None:
    234                     # Rejected.
    235                     continue
    236 
    237                 accepted_extensions.append(extension_response)
    238 
    239                 processor.setup_stream_options(stream_options)
    240 
    241                 if not is_compression_extension(processor.name()):
    242                     continue
    243 
    244                 # Inactivate all of the following compression extensions.
    245                 for j in xrange(index + 1, len(processors)):
    246                     if is_compression_extension(processors[j].name()):
    247                         processors[j].set_active(False)
    248 
    249             if len(accepted_extensions) > 0:
    250                 self._request.ws_extensions = accepted_extensions
    251                 self._logger.debug(
    252                     'Extensions accepted: %r',
    253                     map(common.ExtensionParameter.name, accepted_extensions))
    254             else:
    255                 self._request.ws_extensions = None
    256 
    257             self._request.ws_stream = self._create_stream(stream_options)
    258 
    259             if self._request.ws_requested_protocols is not None:
    260                 if self._request.ws_protocol is None:
    261                     raise HandshakeException(
    262                         'do_extra_handshake must choose one subprotocol from '
    263                         'ws_requested_protocols and set it to ws_protocol')
    264                 validate_subprotocol(self._request.ws_protocol)
    265 
    266                 self._logger.debug(
    267                     'Subprotocol accepted: %r',
    268                     self._request.ws_protocol)
    269             else:
    270                 if self._request.ws_protocol is not None:
    271                     raise HandshakeException(
    272                         'ws_protocol must be None when the client didn\'t '
    273                         'request any subprotocol')
    274 
    275             self._send_handshake(accept)
    276         except HandshakeException, e:
    277             if not e.status:
    278                 # Fallback to 400 bad request by default.
    279                 e.status = common.HTTP_STATUS_BAD_REQUEST
    280             raise e
    281 
    282     def _get_origin(self):
    283         origin_header = common.ORIGIN_HEADER
    284         origin = self._request.headers_in.get(origin_header)
    285         if origin is None:
    286             self._logger.debug('Client request does not have origin header')
    287         self._request.ws_origin = origin
    288 
    289     def _check_version(self):
    290         version = get_mandatory_header(self._request,
    291                                        common.SEC_WEBSOCKET_VERSION_HEADER)
    292         if version == _VERSION_LATEST_STRING:
    293             return _VERSION_LATEST
    294 
    295         if version.find(',') >= 0:
    296             raise HandshakeException(
    297                 'Multiple versions (%r) are not allowed for header %s' %
    298                 (version, common.SEC_WEBSOCKET_VERSION_HEADER),
    299                 status=common.HTTP_STATUS_BAD_REQUEST)
    300         raise VersionException(
    301             'Unsupported version %r for header %s' %
    302             (version, common.SEC_WEBSOCKET_VERSION_HEADER),
    303             supported_versions=', '.join(map(str, _SUPPORTED_VERSIONS)))
    304 
    305     def _set_protocol(self):
    306         self._request.ws_protocol = None
    307 
    308         protocol_header = self._request.headers_in.get(
    309             common.SEC_WEBSOCKET_PROTOCOL_HEADER)
    310 
    311         if protocol_header is None:
    312             self._request.ws_requested_protocols = None
    313             return
    314 
    315         self._request.ws_requested_protocols = parse_token_list(
    316             protocol_header)
    317         self._logger.debug('Subprotocols requested: %r',
    318                            self._request.ws_requested_protocols)
    319 
    320     def _parse_extensions(self):
    321         extensions_header = self._request.headers_in.get(
    322             common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
    323         if not extensions_header:
    324             self._request.ws_requested_extensions = None
    325             return
    326 
    327         try:
    328             self._request.ws_requested_extensions = common.parse_extensions(
    329                 extensions_header)
    330         except common.ExtensionParsingException, e:
    331             raise HandshakeException(
    332                 'Failed to parse Sec-WebSocket-Extensions header: %r' % e)
    333 
    334         self._logger.debug(
    335             'Extensions requested: %r',
    336             map(common.ExtensionParameter.name,
    337                 self._request.ws_requested_extensions))
    338 
    339     def _validate_key(self, key):
    340         if key.find(',') >= 0:
    341             raise HandshakeException('Request has multiple %s header lines or '
    342                                      'contains illegal character \',\': %r' %
    343                                      (common.SEC_WEBSOCKET_KEY_HEADER, key))
    344 
    345         # Validate
    346         key_is_valid = False
    347         try:
    348             # Validate key by quick regex match before parsing by base64
    349             # module. Because base64 module skips invalid characters, we have
    350             # to do this in advance to make this server strictly reject illegal
    351             # keys.
    352             if _SEC_WEBSOCKET_KEY_REGEX.match(key):
    353                 decoded_key = base64.b64decode(key)
    354                 if len(decoded_key) == 16:
    355                     key_is_valid = True
    356         except TypeError, e:
    357             pass
    358 
    359         if not key_is_valid:
    360             raise HandshakeException(
    361                 'Illegal value for header %s: %r' %
    362                 (common.SEC_WEBSOCKET_KEY_HEADER, key))
    363 
    364         return decoded_key
    365 
    366     def _get_key(self):
    367         key = get_mandatory_header(
    368             self._request, common.SEC_WEBSOCKET_KEY_HEADER)
    369 
    370         decoded_key = self._validate_key(key)
    371 
    372         self._logger.debug(
    373             '%s: %r (%s)',
    374             common.SEC_WEBSOCKET_KEY_HEADER,
    375             key,
    376             util.hexify(decoded_key))
    377 
    378         return key
    379 
    380     def _create_stream(self, stream_options):
    381         return Stream(self._request, stream_options)
    382 
    383     def _create_handshake_response(self, accept):
    384         response = []
    385 
    386         response.append('HTTP/1.1 101 Switching Protocols\r\n')
    387 
    388         # WebSocket headers
    389         response.append(format_header(
    390             common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE))
    391         response.append(format_header(
    392             common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))
    393         response.append(format_header(
    394             common.SEC_WEBSOCKET_ACCEPT_HEADER, accept))
    395         if self._request.ws_protocol is not None:
    396             response.append(format_header(
    397                 common.SEC_WEBSOCKET_PROTOCOL_HEADER,
    398                 self._request.ws_protocol))
    399         if (self._request.ws_extensions is not None and
    400             len(self._request.ws_extensions) != 0):
    401             response.append(format_header(
    402                 common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
    403                 common.format_extensions(self._request.ws_extensions)))
    404 
    405         # Headers not specific for WebSocket
    406         for name, value in self._request.extra_headers:
    407             response.append(format_header(name, value))
    408 
    409         response.append('\r\n')
    410 
    411         return ''.join(response)
    412 
    413     def _send_handshake(self, accept):
    414         raw_response = self._create_handshake_response(accept)
    415         self._request.connection.write(raw_response)
    416         self._logger.debug('Sent server\'s opening handshake: %r',
    417                            raw_response)
    418 
    419 
    420 # vi:sts=4 sw=4 et
    421