Home | History | Annotate | Download | only in handshake
      1 # Copyright 2012, Google Inc.
      2 # All rights reserved.
      3 #
      4 # Redistribution and use in source and binary forms, with or without
      5 # modification, are permitted provided that the following conditions are
      6 # met:
      7 #
      8 #     * Redistributions of source code must retain the above copyright
      9 # notice, this list of conditions and the following disclaimer.
     10 #     * Redistributions in binary form must reproduce the above
     11 # copyright notice, this list of conditions and the following disclaimer
     12 # in the documentation and/or other materials provided with the
     13 # distribution.
     14 #     * Neither the name of Google Inc. nor the names of its
     15 # contributors may be used to endorse or promote products derived from
     16 # this software without specific prior written permission.
     17 #
     18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     19 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     20 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     21 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     22 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     23 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     24 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     25 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     26 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     27 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     29 
     30 
     31 """This file provides the opening handshake processor for the WebSocket
     32 protocol (RFC 6455).
     33 
     34 Specification:
     35 http://tools.ietf.org/html/rfc6455
     36 """
     37 
     38 
# Note: request.connection.write is used in this module even though the
# mod_python documentation says it should be used only in connection handlers.
# Unfortunately, there is no alternative: request.write, for example, is not
# suitable because it doesn't allow writing raw bytes directly.
     43 
     44 
     45 import base64
     46 import logging
     47 import os
     48 import re
     49 
     50 from mod_pywebsocket import common
     51 from mod_pywebsocket.extensions import get_extension_processor
     52 from mod_pywebsocket.extensions import is_compression_extension
     53 from mod_pywebsocket.handshake._base import check_request_line
     54 from mod_pywebsocket.handshake._base import format_header
     55 from mod_pywebsocket.handshake._base import get_mandatory_header
     56 from mod_pywebsocket.handshake._base import HandshakeException
     57 from mod_pywebsocket.handshake._base import parse_token_list
     58 from mod_pywebsocket.handshake._base import validate_mandatory_header
     59 from mod_pywebsocket.handshake._base import validate_subprotocol
     60 from mod_pywebsocket.handshake._base import VersionException
     61 from mod_pywebsocket.stream import Stream
     62 from mod_pywebsocket.stream import StreamOptions
     63 from mod_pywebsocket import util
     64 
     65 
     66 # Used to validate the value in the Sec-WebSocket-Key header strictly. RFC 4648
     67 # disallows non-zero padding, so the character right before == must be any of
     68 # A, Q, g and w.
     69 _SEC_WEBSOCKET_KEY_REGEX = re.compile('^[+/0-9A-Za-z]{21}[AQgw]==$')
     70 
     71 # Defining aliases for values used frequently.
     72 _VERSION_HYBI08 = common.VERSION_HYBI08
     73 _VERSION_HYBI08_STRING = str(_VERSION_HYBI08)
     74 _VERSION_LATEST = common.VERSION_HYBI_LATEST
     75 _VERSION_LATEST_STRING = str(_VERSION_LATEST)
     76 _SUPPORTED_VERSIONS = [
     77     _VERSION_LATEST,
     78     _VERSION_HYBI08,
     79 ]
     80 
     81 
     82 def compute_accept(key):
     83     """Computes value for the Sec-WebSocket-Accept header from value of the
     84     Sec-WebSocket-Key header.
     85     """
     86 
     87     accept_binary = util.sha1_hash(
     88         key + common.WEBSOCKET_ACCEPT_UUID).digest()
     89     accept = base64.b64encode(accept_binary)
     90 
     91     return (accept, accept_binary)
     92 
     93 
     94 class Handshaker(object):
     95     """Opening handshake processor for the WebSocket protocol (RFC 6455)."""
     96 
     97     def __init__(self, request, dispatcher):
     98         """Construct an instance.
     99 
    100         Args:
    101             request: mod_python request.
    102             dispatcher: Dispatcher (dispatch.Dispatcher).
    103 
    104         Handshaker will add attributes such as ws_resource during handshake.
    105         """
    106 
    107         self._logger = util.get_class_logger(self)
    108 
    109         self._request = request
    110         self._dispatcher = dispatcher
    111 
    112     def _validate_connection_header(self):
    113         connection = get_mandatory_header(
    114             self._request, common.CONNECTION_HEADER)
    115 
    116         try:
    117             connection_tokens = parse_token_list(connection)
    118         except HandshakeException, e:
    119             raise HandshakeException(
    120                 'Failed to parse %s: %s' % (common.CONNECTION_HEADER, e))
    121 
    122         connection_is_valid = False
    123         for token in connection_tokens:
    124             if token.lower() == common.UPGRADE_CONNECTION_TYPE.lower():
    125                 connection_is_valid = True
    126                 break
    127         if not connection_is_valid:
    128             raise HandshakeException(
    129                 '%s header doesn\'t contain "%s"' %
    130                 (common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))
    131 
    132     def do_handshake(self):
    133         self._request.ws_close_code = None
    134         self._request.ws_close_reason = None
    135 
    136         # Parsing.
    137 
    138         check_request_line(self._request)
    139 
    140         validate_mandatory_header(
    141             self._request,
    142             common.UPGRADE_HEADER,
    143             common.WEBSOCKET_UPGRADE_TYPE)
    144 
    145         self._validate_connection_header()
    146 
    147         self._request.ws_resource = self._request.uri
    148 
    149         unused_host = get_mandatory_header(self._request, common.HOST_HEADER)
    150 
    151         self._request.ws_version = self._check_version()
    152 
    153         # This handshake must be based on latest hybi. We are responsible to
    154         # fallback to HTTP on handshake failure as latest hybi handshake
    155         # specifies.
    156         try:
    157             self._get_origin()
    158             self._set_protocol()
    159             self._parse_extensions()
    160 
    161             # Key validation, response generation.
    162 
    163             key = self._get_key()
    164             (accept, accept_binary) = compute_accept(key)
    165             self._logger.debug(
    166                 '%s: %r (%s)',
    167                 common.SEC_WEBSOCKET_ACCEPT_HEADER,
    168                 accept,
    169                 util.hexify(accept_binary))
    170 
    171             self._logger.debug('Protocol version is RFC 6455')
    172 
    173             # Setup extension processors.
    174 
    175             processors = []
    176             if self._request.ws_requested_extensions is not None:
    177                 for extension_request in self._request.ws_requested_extensions:
    178                     processor = get_extension_processor(extension_request)
    179                     # Unknown extension requests are just ignored.
    180                     if processor is not None:
    181                         processors.append(processor)
    182             self._request.ws_extension_processors = processors
    183 
    184             # List of extra headers. The extra handshake handler may add header
    185             # data as name/value pairs to this list and pywebsocket appends
    186             # them to the WebSocket handshake.
    187             self._request.extra_headers = []
    188 
    189             # Extra handshake handler may modify/remove processors.
    190             self._dispatcher.do_extra_handshake(self._request)
    191             processors = filter(lambda processor: processor is not None,
    192                                 self._request.ws_extension_processors)
    193 
    194             # Ask each processor if there are extensions on the request which
    195             # cannot co-exist. When processor decided other processors cannot
    196             # co-exist with it, the processor marks them (or itself) as
    197             # "inactive". The first extension processor has the right to
    198             # make the final call.
    199             for processor in reversed(processors):
    200                 if processor.is_active():
    201                     processor.check_consistency_with_other_processors(
    202                         processors)
    203             processors = filter(lambda processor: processor.is_active(),
    204                                 processors)
    205 
    206             accepted_extensions = []
    207 
    208             # We need to take into account of mux extension here.
    209             # If mux extension exists:
    210             # - Remove processors of extensions for logical channel,
    211             #   which are processors located before the mux processor
    212             # - Pass extension requests for logical channel to mux processor
    213             # - Attach the mux processor to the request. It will be referred
    214             #   by dispatcher to see whether the dispatcher should use mux
    215             #   handler or not.
    216             mux_index = -1
    217             for i, processor in enumerate(processors):
    218                 if processor.name() == common.MUX_EXTENSION:
    219                     mux_index = i
    220                     break
    221             if mux_index >= 0:
    222                 logical_channel_extensions = []
    223                 for processor in processors[:mux_index]:
    224                     logical_channel_extensions.append(processor.request())
    225                     processor.set_active(False)
    226                 self._request.mux_processor = processors[mux_index]
    227                 self._request.mux_processor.set_extensions(
    228                     logical_channel_extensions)
    229                 processors = filter(lambda processor: processor.is_active(),
    230                                     processors)
    231 
    232             stream_options = StreamOptions()
    233 
    234             for index, processor in enumerate(processors):
    235                 if not processor.is_active():
    236                     continue
    237 
    238                 extension_response = processor.get_extension_response()
    239                 if extension_response is None:
    240                     # Rejected.
    241                     continue
    242 
    243                 accepted_extensions.append(extension_response)
    244 
    245                 processor.setup_stream_options(stream_options)
    246 
    247                 if not is_compression_extension(processor.name()):
    248                     continue
    249 
    250                 # Inactivate all of the following compression extensions.
    251                 for j in xrange(index + 1, len(processors)):
    252                     if is_compression_extension(processors[j].name()):
    253                         processors[j].set_active(False)
    254 
    255             if len(accepted_extensions) > 0:
    256                 self._request.ws_extensions = accepted_extensions
    257                 self._logger.debug(
    258                     'Extensions accepted: %r',
    259                     map(common.ExtensionParameter.name, accepted_extensions))
    260             else:
    261                 self._request.ws_extensions = None
    262 
    263             self._request.ws_stream = self._create_stream(stream_options)
    264 
    265             if self._request.ws_requested_protocols is not None:
    266                 if self._request.ws_protocol is None:
    267                     raise HandshakeException(
    268                         'do_extra_handshake must choose one subprotocol from '
    269                         'ws_requested_protocols and set it to ws_protocol')
    270                 validate_subprotocol(self._request.ws_protocol)
    271 
    272                 self._logger.debug(
    273                     'Subprotocol accepted: %r',
    274                     self._request.ws_protocol)
    275             else:
    276                 if self._request.ws_protocol is not None:
    277                     raise HandshakeException(
    278                         'ws_protocol must be None when the client didn\'t '
    279                         'request any subprotocol')
    280 
    281             self._send_handshake(accept)
    282         except HandshakeException, e:
    283             if not e.status:
    284                 # Fallback to 400 bad request by default.
    285                 e.status = common.HTTP_STATUS_BAD_REQUEST
    286             raise e
    287 
    288     def _get_origin(self):
    289         if self._request.ws_version is _VERSION_HYBI08:
    290             origin_header = common.SEC_WEBSOCKET_ORIGIN_HEADER
    291         else:
    292             origin_header = common.ORIGIN_HEADER
    293         origin = self._request.headers_in.get(origin_header)
    294         if origin is None:
    295             self._logger.debug('Client request does not have origin header')
    296         self._request.ws_origin = origin
    297 
    298     def _check_version(self):
    299         version = get_mandatory_header(self._request,
    300                                        common.SEC_WEBSOCKET_VERSION_HEADER)
    301         if version == _VERSION_HYBI08_STRING:
    302             return _VERSION_HYBI08
    303         if version == _VERSION_LATEST_STRING:
    304             return _VERSION_LATEST
    305 
    306         if version.find(',') >= 0:
    307             raise HandshakeException(
    308                 'Multiple versions (%r) are not allowed for header %s' %
    309                 (version, common.SEC_WEBSOCKET_VERSION_HEADER),
    310                 status=common.HTTP_STATUS_BAD_REQUEST)
    311         raise VersionException(
    312             'Unsupported version %r for header %s' %
    313             (version, common.SEC_WEBSOCKET_VERSION_HEADER),
    314             supported_versions=', '.join(map(str, _SUPPORTED_VERSIONS)))
    315 
    316     def _set_protocol(self):
    317         self._request.ws_protocol = None
    318 
    319         protocol_header = self._request.headers_in.get(
    320             common.SEC_WEBSOCKET_PROTOCOL_HEADER)
    321 
    322         if protocol_header is None:
    323             self._request.ws_requested_protocols = None
    324             return
    325 
    326         self._request.ws_requested_protocols = parse_token_list(
    327             protocol_header)
    328         self._logger.debug('Subprotocols requested: %r',
    329                            self._request.ws_requested_protocols)
    330 
    331     def _parse_extensions(self):
    332         extensions_header = self._request.headers_in.get(
    333             common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
    334         if not extensions_header:
    335             self._request.ws_requested_extensions = None
    336             return
    337 
    338         if self._request.ws_version is common.VERSION_HYBI08:
    339             allow_quoted_string=False
    340         else:
    341             allow_quoted_string=True
    342         try:
    343             self._request.ws_requested_extensions = common.parse_extensions(
    344                 extensions_header, allow_quoted_string=allow_quoted_string)
    345         except common.ExtensionParsingException, e:
    346             raise HandshakeException(
    347                 'Failed to parse Sec-WebSocket-Extensions header: %r' % e)
    348 
    349         self._logger.debug(
    350             'Extensions requested: %r',
    351             map(common.ExtensionParameter.name,
    352                 self._request.ws_requested_extensions))
    353 
    354     def _validate_key(self, key):
    355         if key.find(',') >= 0:
    356             raise HandshakeException('Request has multiple %s header lines or '
    357                                      'contains illegal character \',\': %r' %
    358                                      (common.SEC_WEBSOCKET_KEY_HEADER, key))
    359 
    360         # Validate
    361         key_is_valid = False
    362         try:
    363             # Validate key by quick regex match before parsing by base64
    364             # module. Because base64 module skips invalid characters, we have
    365             # to do this in advance to make this server strictly reject illegal
    366             # keys.
    367             if _SEC_WEBSOCKET_KEY_REGEX.match(key):
    368                 decoded_key = base64.b64decode(key)
    369                 if len(decoded_key) == 16:
    370                     key_is_valid = True
    371         except TypeError, e:
    372             pass
    373 
    374         if not key_is_valid:
    375             raise HandshakeException(
    376                 'Illegal value for header %s: %r' %
    377                 (common.SEC_WEBSOCKET_KEY_HEADER, key))
    378 
    379         return decoded_key
    380 
    381     def _get_key(self):
    382         key = get_mandatory_header(
    383             self._request, common.SEC_WEBSOCKET_KEY_HEADER)
    384 
    385         decoded_key = self._validate_key(key)
    386 
    387         self._logger.debug(
    388             '%s: %r (%s)',
    389             common.SEC_WEBSOCKET_KEY_HEADER,
    390             key,
    391             util.hexify(decoded_key))
    392 
    393         return key
    394 
    395     def _create_stream(self, stream_options):
    396         return Stream(self._request, stream_options)
    397 
    398     def _create_handshake_response(self, accept):
    399         response = []
    400 
    401         response.append('HTTP/1.1 101 Switching Protocols\r\n')
    402 
    403         # WebSocket headers
    404         response.append(format_header(
    405             common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE))
    406         response.append(format_header(
    407             common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))
    408         response.append(format_header(
    409             common.SEC_WEBSOCKET_ACCEPT_HEADER, accept))
    410         if self._request.ws_protocol is not None:
    411             response.append(format_header(
    412                 common.SEC_WEBSOCKET_PROTOCOL_HEADER,
    413                 self._request.ws_protocol))
    414         if (self._request.ws_extensions is not None and
    415             len(self._request.ws_extensions) != 0):
    416             response.append(format_header(
    417                 common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
    418                 common.format_extensions(self._request.ws_extensions)))
    419 
    420         # Headers not specific for WebSocket
    421         for name, value in self._request.extra_headers:
    422             response.append(format_header(name, value))
    423 
    424         response.append('\r\n')
    425 
    426         return ''.join(response)
    427 
    428     def _send_handshake(self, accept):
    429         raw_response = self._create_handshake_response(accept)
    430         self._request.connection.write(raw_response)
    431         self._logger.debug('Sent server\'s opening handshake: %r',
    432                            raw_response)
    433 
    434 
    435 # vi:sts=4 sw=4 et
    436