# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


"""This file provides the opening handshake processor for the WebSocket
protocol (RFC 6455).

Specification:
http://tools.ietf.org/html/rfc6455
"""


# Note: request.connection.write is used in this module, even though the
# mod_python documentation says that it should be used only in connection
# handlers. Unfortunately, we have no other options.
# For example, request.write is not suitable because it doesn't allow direct
# raw bytes writing.


import base64
import logging
import os
import re

from mod_pywebsocket import common
from mod_pywebsocket.extensions import get_extension_processor
from mod_pywebsocket.handshake._base import check_request_line
from mod_pywebsocket.handshake._base import format_header
from mod_pywebsocket.handshake._base import get_mandatory_header
from mod_pywebsocket.handshake._base import HandshakeException
from mod_pywebsocket.handshake._base import parse_token_list
from mod_pywebsocket.handshake._base import validate_mandatory_header
from mod_pywebsocket.handshake._base import validate_subprotocol
from mod_pywebsocket.handshake._base import VersionException
from mod_pywebsocket.stream import Stream
from mod_pywebsocket.stream import StreamOptions
from mod_pywebsocket import util


# Used to validate the value in the Sec-WebSocket-Key header strictly. RFC 4648
# disallows non-zero padding, so the character right before == must be any of
# A, Q, g and w. 21 free characters + 1 restricted character + "==" is exactly
# the base64 encoding of 16 bytes. \Z is used instead of $ because $ would
# also accept a trailing newline.
_SEC_WEBSOCKET_KEY_REGEX = re.compile(r'^[+/0-9A-Za-z]{21}[AQgw]==\Z')

# Defining aliases for values used frequently.
_VERSION_HYBI08 = common.VERSION_HYBI08
_VERSION_HYBI08_STRING = str(_VERSION_HYBI08)
_VERSION_LATEST = common.VERSION_HYBI_LATEST
_VERSION_LATEST_STRING = str(_VERSION_LATEST)
_SUPPORTED_VERSIONS = [
    _VERSION_LATEST,
    _VERSION_HYBI08,
]


def compute_accept(key):
    """Computes value for the Sec-WebSocket-Accept header from value of the
    Sec-WebSocket-Key header.

    The key is concatenated with common.WEBSOCKET_ACCEPT_UUID, hashed with
    util.sha1_hash, and the digest is base64-encoded.

    Args:
        key: value of the Sec-WebSocket-Key header as sent by the client.

    Returns:
        A tuple (accept, accept_binary): the base64-encoded digest to send in
        Sec-WebSocket-Accept, and the raw digest (used for debug logging).
    """

    accept_binary = util.sha1_hash(
        key + common.WEBSOCKET_ACCEPT_UUID).digest()
    accept = base64.b64encode(accept_binary)

    return (accept, accept_binary)


class Handshaker(object):
    """Opening handshake processor for the WebSocket protocol (RFC 6455)."""

    def __init__(self, request, dispatcher):
        """Construct an instance.

        Args:
            request: mod_python request.
            dispatcher: Dispatcher (dispatch.Dispatcher).

        Handshaker will add attributes such as ws_resource during handshake.
        """

        self._logger = util.get_class_logger(self)

        self._request = request
        self._dispatcher = dispatcher

    def _validate_connection_header(self):
        """Checks that the Connection header contains the Upgrade token.

        The comparison against common.UPGRADE_CONNECTION_TYPE is
        case-insensitive.

        Raises:
            HandshakeException: if the header value cannot be parsed as a
                token list or does not contain the expected token. (Absence
                of the header is reported by get_mandatory_header.)
        """

        connection = get_mandatory_header(
            self._request, common.CONNECTION_HEADER)

        try:
            connection_tokens = parse_token_list(connection)
        except HandshakeException as e:
            raise HandshakeException(
                'Failed to parse %s: %s' % (common.CONNECTION_HEADER, e))

        expected_token = common.UPGRADE_CONNECTION_TYPE.lower()
        if not any(token.lower() == expected_token
                   for token in connection_tokens):
            raise HandshakeException(
                '%s header doesn\'t contain "%s"' %
                (common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))

    def do_handshake(self):
        """Validates the client's opening handshake and sends the response.

        Along the way the request object receives the attributes
        ws_close_code, ws_close_reason, ws_resource, ws_version, ws_origin,
        ws_requested_protocols, ws_protocol, ws_requested_extensions,
        ws_extension_processors, ws_extensions and ws_stream. The
        dispatcher's do_extra_handshake is called before extensions and the
        subprotocol are finalized, so it may modify them.

        Raises:
            HandshakeException: on any validation failure. Exceptions raised
                after the version check get status 400 (Bad Request) when
                they carry no status of their own.
            VersionException: raised by _check_version when the client asks
                for an unsupported protocol version.
        """

        self._request.ws_close_code = None
        self._request.ws_close_reason = None

        # Parsing.

        check_request_line(self._request)

        validate_mandatory_header(
            self._request,
            common.UPGRADE_HEADER,
            common.WEBSOCKET_UPGRADE_TYPE)

        self._validate_connection_header()

        self._request.ws_resource = self._request.uri

        # Only the presence of Host is required; its value is not used.
        unused_host = get_mandatory_header(self._request, common.HOST_HEADER)

        self._request.ws_version = self._check_version()

        # This handshake must be based on latest hybi. We are responsible to
        # fallback to HTTP on handshake failure as latest hybi handshake
        # specifies.
        try:
            self._get_origin()
            self._set_protocol()
            self._parse_extensions()

            # Key validation, response generation.

            key = self._get_key()
            (accept, accept_binary) = compute_accept(key)
            self._logger.debug(
                '%s: %r (%s)',
                common.SEC_WEBSOCKET_ACCEPT_HEADER,
                accept,
                util.hexify(accept_binary))

            self._logger.debug('Protocol version is RFC 6455')

            # Setup extension processors. Unknown extension requests (for
            # which no processor exists) are just ignored.
            processors = []
            if self._request.ws_requested_extensions is not None:
                processors = [
                    get_extension_processor(extension_request)
                    for extension_request
                    in self._request.ws_requested_extensions]
                processors = [processor for processor in processors
                              if processor is not None]
            self._request.ws_extension_processors = processors

            # Extra handshake handler may modify/remove processors.
            self._dispatcher.do_extra_handshake(self._request)

            stream_options = self._negotiate_extensions()

            self._request.ws_stream = self._create_stream(stream_options)

            if self._request.ws_requested_protocols is not None:
                if self._request.ws_protocol is None:
                    raise HandshakeException(
                        'do_extra_handshake must choose one subprotocol from '
                        'ws_requested_protocols and set it to ws_protocol')
                validate_subprotocol(self._request.ws_protocol, hixie=False)

                self._logger.debug(
                    'Subprotocol accepted: %r',
                    self._request.ws_protocol)
            else:
                if self._request.ws_protocol is not None:
                    raise HandshakeException(
                        'ws_protocol must be None when the client didn\'t '
                        'request any subprotocol')

            self._send_handshake(accept)
        except HandshakeException as e:
            if not e.status:
                # Fallback to 400 bad request by default.
                e.status = common.HTTP_STATUS_BAD_REQUEST
            # A bare raise keeps the original traceback; "raise e" would
            # reset it under Python 2.
            raise

    def _negotiate_extensions(self):
        """Decides which extensions to accept, after do_extra_handshake ran.

        Reads self._request.ws_extension_processors (entries set to None by
        the extra handshake handler are skipped), asks each processor for its
        response, and stores the accepted responses in
        self._request.ws_extensions (None when nothing was accepted).

        Returns:
            A StreamOptions configured by the accepted processors that apply
            to the physical stream (i.e. those placed after mux, if present).
        """

        # A real list is required: it is indexed and sliced below, which a
        # lazy filter() result (Python 3) would not support.
        processors = [processor for processor
                      in self._request.ws_extension_processors
                      if processor is not None]

        accepted_extensions = []

        # We need to take care of mux extension here. Extensions that
        # are placed before mux should be applied to logical channels.
        mux_index = -1
        for i, processor in enumerate(processors):
            if processor.name() == common.MUX_EXTENSION:
                mux_index = i
                break
        if mux_index >= 0:
            mux_processor = processors[mux_index]
            logical_channel_processors = processors[:mux_index]
            processors = processors[mux_index + 1:]

            for processor in logical_channel_processors:
                extension_response = processor.get_extension_response()
                if extension_response is None:
                    # Rejected.
                    continue
                accepted_extensions.append(extension_response)
            # Pass a shallow copy of accepted_extensions as extensions for
            # logical channels.
            mux_response = mux_processor.get_extension_response(
                self._request, accepted_extensions[:])
            if mux_response is not None:
                accepted_extensions.append(mux_response)

        stream_options = StreamOptions()

        # When there is mux extension, here, |processors| contain only
        # processors for extensions placed after mux.
        for processor in processors:
            extension_response = processor.get_extension_response()
            if extension_response is None:
                # Rejected.
                continue

            accepted_extensions.append(extension_response)

            processor.setup_stream_options(stream_options)

        if accepted_extensions:
            self._request.ws_extensions = accepted_extensions
            self._logger.debug(
                'Extensions accepted: %r',
                [common.ExtensionParameter.name(extension)
                 for extension in accepted_extensions])
        else:
            self._request.ws_extensions = None

        return stream_options

    def _get_origin(self):
        """Stores the client's origin in self._request.ws_origin.

        The origin is read from Sec-WebSocket-Origin for HyBi 08 requests and
        from Origin otherwise. It is set to None (and a debug message logged)
        when the header is absent.
        """

        if self._request.ws_version == _VERSION_HYBI08:
            origin_header = common.SEC_WEBSOCKET_ORIGIN_HEADER
        else:
            origin_header = common.ORIGIN_HEADER
        origin = self._request.headers_in.get(origin_header)
        if origin is None:
            self._logger.debug('Client request does not have origin header')
        self._request.ws_origin = origin

    def _check_version(self):
        """Returns the protocol version requested by the client.

        Raises:
            HandshakeException: (status 400) if the header lists several
                comma-separated versions.
            VersionException: if the single requested version is not
                supported; it carries the supported versions as a string.
        """

        version = get_mandatory_header(self._request,
                                       common.SEC_WEBSOCKET_VERSION_HEADER)
        if version == _VERSION_HYBI08_STRING:
            return _VERSION_HYBI08
        if version == _VERSION_LATEST_STRING:
            return _VERSION_LATEST

        if version.find(',') >= 0:
            raise HandshakeException(
                'Multiple versions (%r) are not allowed for header %s' %
                (version, common.SEC_WEBSOCKET_VERSION_HEADER),
                status=common.HTTP_STATUS_BAD_REQUEST)
        raise VersionException(
            'Unsupported version %r for header %s' %
            (version, common.SEC_WEBSOCKET_VERSION_HEADER),
            supported_versions=', '.join(map(str, _SUPPORTED_VERSIONS)))

    def _set_protocol(self):
        """Records the subprotocols requested by the client.

        Resets self._request.ws_protocol to None (do_extra_handshake is
        expected to choose one) and sets ws_requested_protocols to the parsed
        token list, or None when the header is absent or empty.
        """

        self._request.ws_protocol = None

        protocol_header = self._request.headers_in.get(
            common.SEC_WEBSOCKET_PROTOCOL_HEADER)

        if not protocol_header:
            self._request.ws_requested_protocols = None
            return

        self._request.ws_requested_protocols = parse_token_list(
            protocol_header)
        self._logger.debug('Subprotocols requested: %r',
                           self._request.ws_requested_protocols)

    def _parse_extensions(self):
        """Parses Sec-WebSocket-Extensions into ws_requested_extensions.

        Sets self._request.ws_requested_extensions to None when the header
        is absent or empty. Quoted-string parameter values are only allowed
        for versions other than HyBi 08.

        Raises:
            HandshakeException: if the header cannot be parsed.
        """

        extensions_header = self._request.headers_in.get(
            common.SEC_WEBSOCKET_EXTENSIONS_HEADER)
        if not extensions_header:
            self._request.ws_requested_extensions = None
            return

        allow_quoted_string = self._request.ws_version != _VERSION_HYBI08
        try:
            self._request.ws_requested_extensions = common.parse_extensions(
                extensions_header, allow_quoted_string=allow_quoted_string)
        except common.ExtensionParsingException as e:
            raise HandshakeException(
                'Failed to parse Sec-WebSocket-Extensions header: %r' % e)

        self._logger.debug(
            'Extensions requested: %r',
            [common.ExtensionParameter.name(extension)
             for extension in self._request.ws_requested_extensions])

    def _validate_key(self, key):
        """Strictly validates a Sec-WebSocket-Key value.

        Args:
            key: the raw header value.

        Returns:
            The decoded key (16 bytes).

        Raises:
            HandshakeException: if the value contains ',' (multiple header
                lines were merged) or is not the canonical base64 encoding of
                exactly 16 bytes.
        """

        if key.find(',') >= 0:
            raise HandshakeException('Request has multiple %s header lines or '
                                     'contains illegal character \',\': %r' %
                                     (common.SEC_WEBSOCKET_KEY_HEADER, key))

        decoded_key = None
        try:
            # Validate key by quick regex match before parsing by base64
            # module. Because base64 module skips invalid characters, we have
            # to do this in advance to make this server strictly reject illegal
            # keys.
            if _SEC_WEBSOCKET_KEY_REGEX.match(key):
                decoded_key = base64.b64decode(key)
        except (TypeError, ValueError):
            # Python 2 reports decoding errors as TypeError; Python 3 raises
            # binascii.Error, a subclass of ValueError.
            decoded_key = None

        if decoded_key is None or len(decoded_key) != 16:
            raise HandshakeException(
                'Illegal value for header %s: %r' %
                (common.SEC_WEBSOCKET_KEY_HEADER, key))

        return decoded_key

    def _get_key(self):
        """Returns the client's Sec-WebSocket-Key value after validating it.

        The returned value is the raw (still base64-encoded) header value,
        which is what compute_accept expects.
        """

        key = get_mandatory_header(
            self._request, common.SEC_WEBSOCKET_KEY_HEADER)

        decoded_key = self._validate_key(key)

        self._logger.debug(
            '%s: %r (%s)',
            common.SEC_WEBSOCKET_KEY_HEADER,
            key,
            util.hexify(decoded_key))

        return key

    def _create_stream(self, stream_options):
        """Creates the Stream used for this connection."""

        return Stream(self._request, stream_options)

    def _create_handshake_response(self, accept):
        """Builds the raw "101 Switching Protocols" response text.

        Sec-WebSocket-Protocol and Sec-WebSocket-Extensions are included
        only when a subprotocol was chosen / extensions were accepted.
        """

        response = []

        response.append('HTTP/1.1 101 Switching Protocols\r\n')

        response.append(format_header(
            common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE))
        response.append(format_header(
            common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE))
        response.append(format_header(
            common.SEC_WEBSOCKET_ACCEPT_HEADER, accept))
        if self._request.ws_protocol is not None:
            response.append(format_header(
                common.SEC_WEBSOCKET_PROTOCOL_HEADER,
                self._request.ws_protocol))
        if self._request.ws_extensions:
            response.append(format_header(
                common.SEC_WEBSOCKET_EXTENSIONS_HEADER,
                common.format_extensions(self._request.ws_extensions)))
        # Blank line terminating the header section.
        response.append('\r\n')

        return ''.join(response)

    def _send_handshake(self, accept):
        """Writes the opening handshake response directly to the connection.

        request.connection.write is used because it allows writing raw bytes.
        """

        raw_response = self._create_handshake_response(accept)
        self._request.connection.write(raw_response)
        self._logger.debug('Sent server\'s opening handshake: %r',
                           raw_response)


# vi:sts=4 sw=4 et