1 # Copyright 2012, Google Inc. 2 # All rights reserved. 3 # 4 # Redistribution and use in source and binary forms, with or without 5 # modification, are permitted provided that the following conditions are 6 # met: 7 # 8 # * Redistributions of source code must retain the above copyright 9 # notice, this list of conditions and the following disclaimer. 10 # * Redistributions in binary form must reproduce the above 11 # copyright notice, this list of conditions and the following disclaimer 12 # in the documentation and/or other materials provided with the 13 # distribution. 14 # * Neither the name of Google Inc. nor the names of its 15 # contributors may be used to endorse or promote products derived from 16 # this software without specific prior written permission. 17 # 18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 30 31 """This file provides the opening handshake processor for the WebSocket 32 protocol (RFC 6455). 33 34 Specification: 35 http://tools.ietf.org/html/rfc6455 36 """ 37 38 39 # Note: request.connection.write is used in this module, even though mod_python 40 # document says that it should be used only in connection handlers. 41 # Unfortunately, we have no other options. For example, request.write is not 42 # suitable because it doesn't allow direct raw bytes writing. 43 44 45 import base64 46 import logging 47 import os 48 import re 49 50 from mod_pywebsocket import common 51 from mod_pywebsocket.extensions import get_extension_processor 52 from mod_pywebsocket.extensions import is_compression_extension 53 from mod_pywebsocket.handshake._base import check_request_line 54 from mod_pywebsocket.handshake._base import format_header 55 from mod_pywebsocket.handshake._base import get_mandatory_header 56 from mod_pywebsocket.handshake._base import HandshakeException 57 from mod_pywebsocket.handshake._base import parse_token_list 58 from mod_pywebsocket.handshake._base import validate_mandatory_header 59 from mod_pywebsocket.handshake._base import validate_subprotocol 60 from mod_pywebsocket.handshake._base import VersionException 61 from mod_pywebsocket.stream import Stream 62 from mod_pywebsocket.stream import StreamOptions 63 from mod_pywebsocket import util 64 65 66 # Used to validate the value in the Sec-WebSocket-Key header strictly. RFC 4648 67 # disallows non-zero padding, so the character right before == must be any of 68 # A, Q, g and w. 69 _SEC_WEBSOCKET_KEY_REGEX = re.compile('^[+/0-9A-Za-z]{21}[AQgw]==$') 70 71 # Defining aliases for values used frequently. 72 _VERSION_LATEST = common.VERSION_HYBI_LATEST 73 _VERSION_LATEST_STRING = str(_VERSION_LATEST) 74 _SUPPORTED_VERSIONS = [ 75 _VERSION_LATEST, 76 ] 77 78 79 def compute_accept(key): 80 """Computes value for the Sec-WebSocket-Accept header from value of the 81 Sec-WebSocket-Key header. 82 """ 83 84 accept_binary = util.sha1_hash( 85 key + common.WEBSOCKET_ACCEPT_UUID).digest() 86 accept = base64.b64encode(accept_binary) 87 88 return (accept, accept_binary) 89 90 91 class Handshaker(object): 92 """Opening handshake processor for the WebSocket protocol (RFC 6455).""" 93 94 def __init__(self, request, dispatcher): 95 """Construct an instance. 96 97 Args: 98 request: mod_python request. 99 dispatcher: Dispatcher (dispatch.Dispatcher). 100 101 Handshaker will add attributes such as ws_resource during handshake. 102 """ 103 104 self._logger = util.get_class_logger(self) 105 106 self._request = request 107 self._dispatcher = dispatcher 108 109 def _validate_connection_header(self): 110 connection = get_mandatory_header( 111 self._request, common.CONNECTION_HEADER) 112 113 try: 114 connection_tokens = parse_token_list(connection) 115 except HandshakeException, e: 116 raise HandshakeException( 117 'Failed to parse %s: %s' % (common.CONNECTION_HEADER, e)) 118 119 connection_is_valid = False 120 for token in connection_tokens: 121 if token.lower() == common.UPGRADE_CONNECTION_TYPE.lower(): 122 connection_is_valid = True 123 break 124 if not connection_is_valid: 125 raise HandshakeException( 126 '%s header doesn\'t contain "%s"' % 127 (common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE)) 128 129 def do_handshake(self): 130 self._request.ws_close_code = None 131 self._request.ws_close_reason = None 132 133 # Parsing. 134 135 check_request_line(self._request) 136 137 validate_mandatory_header( 138 self._request, 139 common.UPGRADE_HEADER, 140 common.WEBSOCKET_UPGRADE_TYPE) 141 142 self._validate_connection_header() 143 144 self._request.ws_resource = self._request.uri 145 146 unused_host = get_mandatory_header(self._request, common.HOST_HEADER) 147 148 self._request.ws_version = self._check_version() 149 150 try: 151 self._get_origin() 152 self._set_protocol() 153 self._parse_extensions() 154 155 # Key validation, response generation. 156 157 key = self._get_key() 158 (accept, accept_binary) = compute_accept(key) 159 self._logger.debug( 160 '%s: %r (%s)', 161 common.SEC_WEBSOCKET_ACCEPT_HEADER, 162 accept, 163 util.hexify(accept_binary)) 164 165 self._logger.debug('Protocol version is RFC 6455') 166 167 # Setup extension processors. 168 169 processors = [] 170 if self._request.ws_requested_extensions is not None: 171 for extension_request in self._request.ws_requested_extensions: 172 processor = get_extension_processor(extension_request) 173 # Unknown extension requests are just ignored. 174 if processor is not None: 175 processors.append(processor) 176 self._request.ws_extension_processors = processors 177 178 # List of extra headers. The extra handshake handler may add header 179 # data as name/value pairs to this list and pywebsocket appends 180 # them to the WebSocket handshake. 181 self._request.extra_headers = [] 182 183 # Extra handshake handler may modify/remove processors. 184 self._dispatcher.do_extra_handshake(self._request) 185 processors = filter(lambda processor: processor is not None, 186 self._request.ws_extension_processors) 187 188 # Ask each processor if there are extensions on the request which 189 # cannot co-exist. When processor decided other processors cannot 190 # co-exist with it, the processor marks them (or itself) as 191 # "inactive". The first extension processor has the right to 192 # make the final call. 193 for processor in reversed(processors): 194 if processor.is_active(): 195 processor.check_consistency_with_other_processors( 196 processors) 197 processors = filter(lambda processor: processor.is_active(), 198 processors) 199 200 accepted_extensions = [] 201 202 # We need to take into account of mux extension here. 203 # If mux extension exists: 204 # - Remove processors of extensions for logical channel, 205 # which are processors located before the mux processor 206 # - Pass extension requests for logical channel to mux processor 207 # - Attach the mux processor to the request. It will be referred 208 # by dispatcher to see whether the dispatcher should use mux 209 # handler or not. 210 mux_index = -1 211 for i, processor in enumerate(processors): 212 if processor.name() == common.MUX_EXTENSION: 213 mux_index = i 214 break 215 if mux_index >= 0: 216 logical_channel_extensions = [] 217 for processor in processors[:mux_index]: 218 logical_channel_extensions.append(processor.request()) 219 processor.set_active(False) 220 self._request.mux_processor = processors[mux_index] 221 self._request.mux_processor.set_extensions( 222 logical_channel_extensions) 223 processors = filter(lambda processor: processor.is_active(), 224 processors) 225 226 stream_options = StreamOptions() 227 228 for index, processor in enumerate(processors): 229 if not processor.is_active(): 230 continue 231 232 extension_response = processor.get_extension_response() 233 if extension_response is None: 234 # Rejected. 235 continue 236 237 accepted_extensions.append(extension_response) 238 239 processor.setup_stream_options(stream_options) 240 241 if not is_compression_extension(processor.name()): 242 continue 243 244 # Inactivate all of the following compression extensions. 245 for j in xrange(index + 1, len(processors)): 246 if is_compression_extension(processors[j].name()): 247 processors[j].set_active(False) 248 249 if len(accepted_extensions) > 0: 250 self._request.ws_extensions = accepted_extensions 251 self._logger.debug( 252 'Extensions accepted: %r', 253 map(common.ExtensionParameter.name, accepted_extensions)) 254 else: 255 self._request.ws_extensions = None 256 257 self._request.ws_stream = self._create_stream(stream_options) 258 259 if self._request.ws_requested_protocols is not None: 260 if self._request.ws_protocol is None: 261 raise HandshakeException( 262 'do_extra_handshake must choose one subprotocol from ' 263 'ws_requested_protocols and set it to ws_protocol') 264 validate_subprotocol(self._request.ws_protocol) 265 266 self._logger.debug( 267 'Subprotocol accepted: %r', 268 self._request.ws_protocol) 269 else: 270 if self._request.ws_protocol is not None: 271 raise HandshakeException( 272 'ws_protocol must be None when the client didn\'t ' 273 'request any subprotocol') 274 275 self._send_handshake(accept) 276 except HandshakeException, e: 277 if not e.status: 278 # Fallback to 400 bad request by default. 279 e.status = common.HTTP_STATUS_BAD_REQUEST 280 raise e 281 282 def _get_origin(self): 283 origin_header = common.ORIGIN_HEADER 284 origin = self._request.headers_in.get(origin_header) 285 if origin is None: 286 self._logger.debug('Client request does not have origin header') 287 self._request.ws_origin = origin 288 289 def _check_version(self): 290 version = get_mandatory_header(self._request, 291 common.SEC_WEBSOCKET_VERSION_HEADER) 292 if version == _VERSION_LATEST_STRING: 293 return _VERSION_LATEST 294 295 if version.find(',') >= 0: 296 raise HandshakeException( 297 'Multiple versions (%r) are not allowed for header %s' % 298 (version, common.SEC_WEBSOCKET_VERSION_HEADER), 299 status=common.HTTP_STATUS_BAD_REQUEST) 300 raise VersionException( 301 'Unsupported version %r for header %s' % 302 (version, common.SEC_WEBSOCKET_VERSION_HEADER), 303 supported_versions=', '.join(map(str, _SUPPORTED_VERSIONS))) 304 305 def _set_protocol(self): 306 self._request.ws_protocol = None 307 308 protocol_header = self._request.headers_in.get( 309 common.SEC_WEBSOCKET_PROTOCOL_HEADER) 310 311 if protocol_header is None: 312 self._request.ws_requested_protocols = None 313 return 314 315 self._request.ws_requested_protocols = parse_token_list( 316 protocol_header) 317 self._logger.debug('Subprotocols requested: %r', 318 self._request.ws_requested_protocols) 319 320 def _parse_extensions(self): 321 extensions_header = self._request.headers_in.get( 322 common.SEC_WEBSOCKET_EXTENSIONS_HEADER) 323 if not extensions_header: 324 self._request.ws_requested_extensions = None 325 return 326 327 try: 328 self._request.ws_requested_extensions = common.parse_extensions( 329 extensions_header) 330 except common.ExtensionParsingException, e: 331 raise HandshakeException( 332 'Failed to parse Sec-WebSocket-Extensions header: %r' % e) 333 334 self._logger.debug( 335 'Extensions requested: %r', 336 map(common.ExtensionParameter.name, 337 self._request.ws_requested_extensions)) 338 339 def _validate_key(self, key): 340 if key.find(',') >= 0: 341 raise HandshakeException('Request has multiple %s header lines or ' 342 'contains illegal character \',\': %r' % 343 (common.SEC_WEBSOCKET_KEY_HEADER, key)) 344 345 # Validate 346 key_is_valid = False 347 try: 348 # Validate key by quick regex match before parsing by base64 349 # module. Because base64 module skips invalid characters, we have 350 # to do this in advance to make this server strictly reject illegal 351 # keys. 352 if _SEC_WEBSOCKET_KEY_REGEX.match(key): 353 decoded_key = base64.b64decode(key) 354 if len(decoded_key) == 16: 355 key_is_valid = True 356 except TypeError, e: 357 pass 358 359 if not key_is_valid: 360 raise HandshakeException( 361 'Illegal value for header %s: %r' % 362 (common.SEC_WEBSOCKET_KEY_HEADER, key)) 363 364 return decoded_key 365 366 def _get_key(self): 367 key = get_mandatory_header( 368 self._request, common.SEC_WEBSOCKET_KEY_HEADER) 369 370 decoded_key = self._validate_key(key) 371 372 self._logger.debug( 373 '%s: %r (%s)', 374 common.SEC_WEBSOCKET_KEY_HEADER, 375 key, 376 util.hexify(decoded_key)) 377 378 return key 379 380 def _create_stream(self, stream_options): 381 return Stream(self._request, stream_options) 382 383 def _create_handshake_response(self, accept): 384 response = [] 385 386 response.append('HTTP/1.1 101 Switching Protocols\r\n') 387 388 # WebSocket headers 389 response.append(format_header( 390 common.UPGRADE_HEADER, common.WEBSOCKET_UPGRADE_TYPE)) 391 response.append(format_header( 392 common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE)) 393 response.append(format_header( 394 common.SEC_WEBSOCKET_ACCEPT_HEADER, accept)) 395 if self._request.ws_protocol is not None: 396 response.append(format_header( 397 common.SEC_WEBSOCKET_PROTOCOL_HEADER, 398 self._request.ws_protocol)) 399 if (self._request.ws_extensions is not None and 400 len(self._request.ws_extensions) != 0): 401 response.append(format_header( 402 common.SEC_WEBSOCKET_EXTENSIONS_HEADER, 403 common.format_extensions(self._request.ws_extensions))) 404 405 # Headers not specific for WebSocket 406 for name, value in self._request.extra_headers: 407 response.append(format_header(name, value)) 408 409 response.append('\r\n') 410 411 return ''.join(response) 412 413 def _send_handshake(self, accept): 414 raw_response = self._create_handshake_response(accept) 415 self._request.connection.write(raw_response) 416 self._logger.debug('Sent server\'s opening handshake: %r', 417 raw_response) 418 419 420 # vi:sts=4 sw=4 et 421