Home | History | Annotate | Download | only in handshake
      1 # Copyright 2012, Google Inc.
      2 # All rights reserved.
      3 #
      4 # Redistribution and use in source and binary forms, with or without
      5 # modification, are permitted provided that the following conditions are
      6 # met:
      7 #
      8 #     * Redistributions of source code must retain the above copyright
      9 # notice, this list of conditions and the following disclaimer.
     10 #     * Redistributions in binary form must reproduce the above
     11 # copyright notice, this list of conditions and the following disclaimer
     12 # in the documentation and/or other materials provided with the
     13 # distribution.
     14 #     * Neither the name of Google Inc. nor the names of its
     15 # contributors may be used to endorse or promote products derived from
     16 # this software without specific prior written permission.
     17 #
     18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     19 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     20 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     21 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     22 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     23 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     24 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     25 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     26 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     27 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     29 
     30 
     31 """This file provides the opening handshake processor for the WebSocket
     32 protocol (RFC 6455).
     33 
     34 Specification:
     35 http://tools.ietf.org/html/rfc6455
     36 """
     37 
     38 
# Note: request.connection.write is used in this module, even though the
# mod_python documentation says that it should be used only in connection
# handlers. Unfortunately, we have no other option. For example, request.write
# is not suitable because it doesn't allow writing raw bytes directly.
     43 
     44 
     45 import base64
     46 import logging
     47 import os
     48 import re
     49 
     50 from mod_pywebsocket import common
     51 from mod_pywebsocket.extensions import get_extension_processor
     52 from mod_pywebsocket.handshake._base import check_request_line
     53 from mod_pywebsocket.handshake._base import format_header
     54 from mod_pywebsocket.handshake._base import get_mandatory_header
     55 from mod_pywebsocket.handshake._base import HandshakeException
     56 from mod_pywebsocket.handshake._base import parse_token_list
     57 from mod_pywebsocket.handshake._base import validate_mandatory_header
     58 from mod_pywebsocket.handshake._base import validate_subprotocol
     59 from mod_pywebsocket.handshake._base import VersionException
     60 from mod_pywebsocket.stream import Stream
     61 from mod_pywebsocket.stream import StreamOptions
     62 from mod_pywebsocket import util
     63 
     64 
     65 # Used to validate the value in the Sec-WebSocket-Key header strictly. RFC 4648
     66 # disallows non-zero padding, so the character right before == must be any of
     67 # A, Q, g and w.
     68 _SEC_WEBSOCKET_KEY_REGEX = re.compile('^[+/0-9A-Za-z]{21}[AQgw]==$')
     69 
     70 # Defining aliases for values used frequently.
     71 _VERSION_HYBI08 = common.VERSION_HYBI08
     72 _VERSION_HYBI08_STRING = str(_VERSION_HYBI08)
     73 _VERSION_LATEST = common.VERSION_HYBI_LATEST
     74 _VERSION_LATEST_STRING = str(_VERSION_LATEST)
     75 _SUPPORTED_VERSIONS = [
     76     _VERSION_LATEST,
     77     _VERSION_HYBI08,
     78 ]
     79 
     80 
     81 def compute_accept(key):
     82     """Computes value for the Sec-WebSocket-Accept header from value of the
     83     Sec-WebSocket-Key header.
     84     """
     85 
     86     accept_binary = util.sha1_hash(
     87         key + common.WEBSOCKET_ACCEPT_UUID).digest()
     88     accept = base64.b64encode(accept_binary)
     89 
     90     return (accept, accept_binary)
     91 
     92 
     93 class Handshaker(object):
     94     """Opening handshake processor for the WebSocket protocol (RFC 6455)."""
     95 
     96     def __init__(self, request, dispatcher):
     97         """Construct an instance.
     98 
     99         Args:
    100             request: mod_python request.
    101             dispatcher: Dispatcher (dispatch.Dispatcher).
    102 
    103         Handshaker will add attributes such as ws_resource during handshake.
    104         """
    105 
    106         self._logger = util.get_class_logger(self)
    107 
    108         self._request = request
    109         self._dispatcher = dispatcher
    110 
    111     def _validate_connection_header(self):
    112         connection = get_mandatory_header(
    113             self._request, common.CONNECTION_HEADER)
    114 
    115         try:
    116             connection_tokens = parse_token_list(connection)
    117         except HandshakeException, e:
    118             raise HandshakeException(
    119                 'Failed to parse %s: %s' % (common.CONNECTION_HEADER, e))
    120 
    121         connection_is_valid = False
    122         for token in connection_tokens:
    123             if token.lower() == common.UPGRADE_CONNECTION_TYPE.lower():
    124                 connection_is_valid = True
    125                 break
    126         if not connection_is_valid:
    127             raise HandshakeException(
    128                 '%s header doesn\'t contain "%s"' %
    129                 (common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))
    130 
    131     def do_handshake(self):
    132         self._request.ws_close_code = None
    133         self._request.ws_close_reason = None
    134 
    135         # Parsing.
    136 
    137         check_request_line(self._request)
    138 
    139         validate_mandatory_header(
    140             self._request,
    141             common.UPGRADE_HEADER,
    142             common.WEBSOCKET_UPGRADE_TYPE)
    143 
    144         self._validate_connection_header()
    145 
    146         self._request.ws_resource = self._request.uri
    147 
    148         unused_host = get_mandatory_header(self._request, common.HOST_HEADER)
    149 
    150         self._request.ws_version = self._check_version()
    151 
    152         # This handshake must be based on latest hybi. We are responsible to
    153         # fallback to HTTP on handshake failure as latest hybi handshake
    154         # specifies.
    155         try:
    156             self._get_origin()
    157             self._set_protocol()
    158             self._parse_extensions()
    159 
    160             # Key validation, response generation.
    161 
    162             key = self._get_key()
    163             (accept, accept_binary) = compute_accept(key)
    164             self._logger.debug(
    165                 '%s: %r (%s)',
    166                 common.SEC_WEBSOCKET_ACCEPT_HEADER,
    167                 accept,
    168                 util.hexify(accept_binary))
    169 
    170             self._logger.debug('Protocol version is RFC 6455')
    171 
    172             # Setup extension processors.
    173 
    174             processors = []
    175             if self._request.ws_requested_extensions is not None:
    176                 for extension_request in self._request.ws_requested_extensions:
    177                     processor = get_extension_processor(extension_request)
    178                     # Unknown extension requests are just ignored.
    179                     if processor is not None:
    180                         processors.append(processor)
    181             self._request.ws_extension_processors = processors
    182 
    183             # Extra handshake handler may modify/remove processors.
    184             self._dispatcher.do_extra_handshake(self._request)
    185             processors = filter(lambda processor: processor is not None,
    186                                 self._request.ws_extension_processors)
    187 
    188             accepted_extensions = []
    189 
    190             # We need to take care of mux extension here. Extensions that
    191             # are placed before mux should be applied to logical channels.
    192             mux_index = -1
    193             for i, processor in enumerate(processors):
    194                 if processor.name() == common.MUX_EXTENSION:
    195                     mux_index = i
    196                     break
    197             if mux_index >= 0:
    198                 mux_processor = processors[mux_index]
    199                 logical_channel_processors = processors[:mux_index]
    200                 processors = processors[mux_index+1:]
    201 
    202                 for processor in logical_channel_processors:
    203                     extension_response = processor.get_extension_response()
    204                     if extension_response is None:
    205                         # Rejected.
    206                         continue
    207                     accepted_extensions.append(extension_response)
    208                 # Pass a shallow copy of accepted_extensions as extensions for
    209                 # logical channels.
    210                 mux_response = mux_processor.get_extension_response(
    211                     self._request, accepted_extensions[:])
    212                 if mux_response is not None:
    213                     accepted_extensions.append(mux_response)
    214 
    215             stream_options = StreamOptions()
    216 
    217             # When there is mux extension, here, |processors| contain only
    218             # prosessors for extensions placed after mux.
    219             for processor in processors:
    220 
    221                 extension_response = processor.get_extension_response()
    222                 if extension_response is None:
    223                     # Rejected.
    224                     continue
    225 
    226                 accepted_extensions.append(extension_response)
    227 
    228                 processor.setup_stream_options(stream_options)
    229 
    230             if len(accepted_extensions) > 0:
    231                 self._request.ws_extensions = accepted_extensions
    232                 self._logger.debug(
    233                     'Extensions accepted: %r',
    234                     map(common.ExtensionParameter.name, accepted_extensions))
    235             else:
    236                 self._request.ws_extensions = None
    237 
    238             self._request.ws_stream = self._create_stream(stream_options)
    239 
    240             if self._request.ws_requested_protocols is not None:
    241                 if self._request.ws_protocol is None:
    242                     raise HandshakeException(
    243                         'do_extra_handshake must choose one subprotocol from '
    244                         'ws_requested_protocols and set it to ws_protocol')
    245                 validate_subprotocol(self._request.ws_protocol, hixie=False)
    246 
    247                 self._logger.debug(
    248                     'Subprotocol accepted: %r',
    249                     self._request.ws_protocol)
    250             else:
    251                 if self._request.ws_protocol is not None:
    252                     raise HandshakeException(
    253                         'ws_protocol must be None when the client didn\'t '
    254                         'request any subprotocol')
    255 
    256             self._send_handshake(accept)
    257         except HandshakeException, e:
    258             if not e.status:
    259                 # Fallback to 400 bad request by default.
    260                 e.status = common.HTTP_STATUS_BAD_REQUEST
    261             raise e
    262 
    def _get_origin(self):
        """Stores the client's origin in the request's ws_origin attribute.

        For HyBi 08 the origin is read from the header named by
        common.SEC_WEBSOCKET_ORIGIN_HEADER; for any other version it is read
        from common.ORIGIN_HEADER. A missing header is not an error:
        ws_origin is set to None and a debug message is logged.
        """

        # Compare by value: ws_version is an integer, and identity of equal
        # integers is an implementation detail.
        if self._request.ws_version == _VERSION_HYBI08:
            origin_header = common.SEC_WEBSOCKET_ORIGIN_HEADER
        else:
            origin_header = common.ORIGIN_HEADER
        origin = self._request.headers_in.get(origin_header)
        if origin is None:
            self._logger.debug('Client request does not have origin header')
        self._request.ws_origin = origin
    272 
    def _check_version(self):
        """Maps the Sec-WebSocket-Version header to a supported version.

        Returns:
            _VERSION_HYBI08 or _VERSION_LATEST, whichever one's string form
            equals the header value.

        Raises:
            HandshakeException: with status common.HTTP_STATUS_BAD_REQUEST
                when the value contains a comma.
            VersionException: for any other value; it carries the
                comma-separated list of supported versions.
        """

        header_name = common.SEC_WEBSOCKET_VERSION_HEADER
        version = get_mandatory_header(self._request, header_name)

        known_versions = (
            (_VERSION_HYBI08_STRING, _VERSION_HYBI08),
            (_VERSION_LATEST_STRING, _VERSION_LATEST),
        )
        for version_string, version_number in known_versions:
            if version == version_string:
                return version_number

        if ',' in version:
            raise HandshakeException(
                'Multiple versions (%r) are not allowed for header %s' %
                (version, header_name),
                status=common.HTTP_STATUS_BAD_REQUEST)

        supported = ', '.join(map(str, _SUPPORTED_VERSIONS))
        raise VersionException(
            'Unsupported version %r for header %s' % (version, header_name),
            supported_versions=supported)
    290 
    def _set_protocol(self):
        """Records the subprotocols requested by the client.

        ws_protocol is reset to None. ws_requested_protocols becomes the
        token list parsed from the Sec-WebSocket-Protocol header, or None
        when the header is absent or empty.
        """

        request = self._request
        request.ws_protocol = None

        header_value = request.headers_in.get(
            common.SEC_WEBSOCKET_PROTOCOL_HEADER)

        if header_value:
            request.ws_requested_protocols = parse_token_list(header_value)
            self._logger.debug('Subprotocols requested: %r',
                               request.ws_requested_protocols)
        else:
            request.ws_requested_protocols = None
    305 
    306     def _parse_extensions(self):
    307         extensions_header = self._request.headers_in.get(
    308             common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
    309         if not extensions_header:
    310             self._request.ws_requested_extensions = None
    311             return
    312 
    313         if self._request.ws_version is common.VERSION_HYBI08:
    314             allow_quoted_string=False
    315         else:
    316             allow_quoted_string=True
    317         try:
    318             self._request.ws_requested_extensions = common.parse_extensions(
    319                 extensions_header, allow_quoted_string=allow_quoted_string)
    320         except common.ExtensionParsingException, e:
    321             raise HandshakeException(
    322                 'Failed to parse Sec-WebSocket-Extensions header: %r' % e)
    323 
    324         self._logger.debug(
    325             'Extensions requested: %r',
    326             map(common.ExtensionParameter.name,
    327                 self._request.ws_requested_extensions))
    328 
    def _validate_key(self, key):
        """Validates a Sec-WebSocket-Key value and returns it decoded.

        Args:
            key: value of the Sec-WebSocket-Key header.

        Returns:
            The 16 bytes obtained by base64-decoding key.

        Raises:
            HandshakeException: when key contains a comma, does not match
                _SEC_WEBSOCKET_KEY_REGEX, cannot be decoded, or does not
                decode to exactly 16 bytes.
        """

        if key.find(',') >= 0:
            raise HandshakeException('Request has multiple %s header lines or '
                                     'contains illegal character \',\': %r' %
                                     (common.SEC_WEBSOCKET_KEY_HEADER, key))

        # Validate key by quick regex match before parsing by base64
        # module. Because base64 module skips invalid characters, we have
        # to do this in advance to make this server strictly reject illegal
        # keys.
        decoded_key = None
        if _SEC_WEBSOCKET_KEY_REGEX.match(key):
            try:
                decoded_key = base64.b64decode(key)
            except (TypeError, ValueError):
                # Decoding failures surface as TypeError on Python 2 and as
                # a ValueError subclass on newer versions; either way the
                # key is reported as illegal below.
                decoded_key = None

        if decoded_key is None or len(decoded_key) != 16:
            raise HandshakeException(
                'Illegal value for header %s: %r' %
                (common.SEC_WEBSOCKET_KEY_HEADER, key))

        return decoded_key
    355 
    def _get_key(self):
        """Fetches and validates the Sec-WebSocket-Key header.

        Returns:
            The header value as received (not the decoded bytes).

        Raises:
            HandshakeException: propagated from _validate_key when the value
                is not acceptable.
        """

        header_name = common.SEC_WEBSOCKET_KEY_HEADER
        key = get_mandatory_header(self._request, header_name)

        # The decoded form is needed only for the debug log below.
        key_bytes = self._validate_key(key)
        self._logger.debug(
            '%s: %r (%s)', header_name, key, util.hexify(key_bytes))

        return key
    369 
    def _create_stream(self, stream_options):
        """Returns a new Stream built from the request and stream_options."""

        request = self._request
        return Stream(request, stream_options)
    372 
    def _create_handshake_response(self, accept):
        """Builds the text of the 101 Switching Protocols response.

        Args:
            accept: value to send in the Sec-WebSocket-Accept header.

        Returns:
            The status line, the Upgrade, Connection and
            Sec-WebSocket-Accept headers, the Sec-WebSocket-Protocol header
            when ws_protocol is not None, the Sec-WebSocket-Extensions
            header when ws_extensions is a non-empty collection, and the
            terminating blank line, joined into one string.
        """

        request = self._request

        parts = [
            'HTTP/1.1 101 Switching Protocols\r\n',
            format_header(
                common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE),
            format_header(
                common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE),
            format_header(
                common.SEC_WEBSOCKET_ACCEPT_HEADER, accept),
        ]

        protocol = request.ws_protocol
        if protocol is not None:
            parts.append(format_header(
                common.SEC_WEBSOCKET_PROTOCOL_HEADER, protocol))

        extensions = request.ws_extensions
        if extensions is not None and len(extensions) != 0:
            parts.append(format_header(
                common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
                common.format_extensions(extensions)))

        # An empty line terminates the header section.
        parts.append('\r\n')

        return ''.join(parts)
    396 
    def _send_handshake(self, accept):
        """Writes the opening handshake response to the connection.

        Args:
            accept: value to send in the Sec-WebSocket-Accept header.
        """

        handshake = self._create_handshake_response(accept)
        connection = self._request.connection
        connection.write(handshake)
        self._logger.debug('Sent server\'s opening handshake: %r',
                           handshake)
    402 
    403 
    404 # vi:sts=4 sw=4 et
    405