Home | History | Annotate | Download | only in mod_pywebsocket
      1 # Copyright 2011, Google Inc.
      2 # All rights reserved.
      3 #
      4 # Redistribution and use in source and binary forms, with or without
      5 # modification, are permitted provided that the following conditions are
      6 # met:
      7 #
      8 #     * Redistributions of source code must retain the above copyright
      9 # notice, this list of conditions and the following disclaimer.
     10 #     * Redistributions in binary form must reproduce the above
     11 # copyright notice, this list of conditions and the following disclaimer
     12 # in the documentation and/or other materials provided with the
     13 # distribution.
     14 #     * Neither the name of Google Inc. nor the names of its
     15 # contributors may be used to endorse or promote products derived from
     16 # this software without specific prior written permission.
     17 #
     18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     19 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     20 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     21 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     22 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     23 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     24 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     25 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     26 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     27 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     29 
     30 
     31 """WebSocket utilities.
     32 """
     33 
     34 
     35 import array
     36 import errno
     37 
# Import hash classes from a module available and recommended for each Python
# version and re-export those symbols. Use the sha and md5 modules in Python
# 2.4, and the hashlib module in Python 2.5 and later.
     41 try:
     42     import hashlib
     43     md5_hash = hashlib.md5
     44     sha1_hash = hashlib.sha1
     45 except ImportError:
     46     import md5
     47     import sha
     48     md5_hash = md5.md5
     49     sha1_hash = sha.sha
     50 
     51 import StringIO
     52 import logging
     53 import os
     54 import re
     55 import socket
     56 import traceback
     57 import zlib
     58 
     59 
     60 def get_stack_trace():
     61     """Get the current stack trace as string.
     62 
     63     This is needed to support Python 2.3.
     64     TODO: Remove this when we only support Python 2.4 and above.
     65           Use traceback.format_exc instead.
     66     """
     67 
     68     out = StringIO.StringIO()
     69     traceback.print_exc(file=out)
     70     return out.getvalue()
     71 
     72 
     73 def prepend_message_to_exception(message, exc):
     74     """Prepend message to the exception."""
     75 
     76     exc.args = (message + str(exc),)
     77     return
     78 
     79 
     80 def __translate_interp(interp, cygwin_path):
     81     """Translate interp program path for Win32 python to run cygwin program
     82     (e.g. perl).  Note that it doesn't support path that contains space,
     83     which is typically true for Unix, where #!-script is written.
     84     For Win32 python, cygwin_path is a directory of cygwin binaries.
     85 
     86     Args:
     87       interp: interp command line
     88       cygwin_path: directory name of cygwin binary, or None
     89     Returns:
     90       translated interp command line.
     91     """
     92     if not cygwin_path:
     93         return interp
     94     m = re.match('^[^ ]*/([^ ]+)( .*)?', interp)
     95     if m:
     96         cmd = os.path.join(cygwin_path, m.group(1))
     97         return cmd + m.group(2)
     98     return interp
     99 
    100 
    101 def get_script_interp(script_path, cygwin_path=None):
    102     """Gets #!-interpreter command line from the script.
    103 
    104     It also fixes command path.  When Cygwin Python is used, e.g. in WebKit,
    105     it could run "/usr/bin/perl -wT hello.pl".
    106     When Win32 Python is used, e.g. in Chromium, it couldn't.  So, fix
    107     "/usr/bin/perl" to "<cygwin_path>\perl.exe".
    108 
    109     Args:
    110       script_path: pathname of the script
    111       cygwin_path: directory name of cygwin binary, or None
    112     Returns:
    113       #!-interpreter command line, or None if it is not #!-script.
    114     """
    115     fp = open(script_path)
    116     line = fp.readline()
    117     fp.close()
    118     m = re.match('^#!(.*)', line)
    119     if m:
    120         return __translate_interp(m.group(1), cygwin_path)
    121     return None
    122 
    123 
    124 def wrap_popen3_for_win(cygwin_path):
    125     """Wrap popen3 to support #!-script on Windows.
    126 
    127     Args:
    128       cygwin_path:  path for cygwin binary if command path is needed to be
    129                     translated.  None if no translation required.
    130     """
    131 
    132     __orig_popen3 = os.popen3
    133 
    134     def __wrap_popen3(cmd, mode='t', bufsize=-1):
    135         cmdline = cmd.split(' ')
    136         interp = get_script_interp(cmdline[0], cygwin_path)
    137         if interp:
    138             cmd = interp + ' ' + cmd
    139         return __orig_popen3(cmd, mode, bufsize)
    140 
    141     os.popen3 = __wrap_popen3
    142 
    143 
    144 def hexify(s):
    145     return ' '.join(map(lambda x: '%02x' % ord(x), s))
    146 
    147 
    148 def get_class_logger(o):
    149     return logging.getLogger(
    150         '%s.%s' % (o.__class__.__module__, o.__class__.__name__))
    151 
    152 
    153 class NoopMasker(object):
    154     """A masking object that has the same interface as RepeatedXorMasker but
    155     just returns the string passed in without making any change.
    156     """
    157 
    158     def __init__(self):
    159         pass
    160 
    161     def mask(self, s):
    162         return s
    163 
    164 
    165 class RepeatedXorMasker(object):
    166     """A masking object that applies XOR on the string given to mask method
    167     with the masking bytes given to the constructor repeatedly. This object
    168     remembers the position in the masking bytes the last mask method call
    169     ended and resumes from that point on the next mask method call.
    170     """
    171 
    172     def __init__(self, mask):
    173         self._mask = map(ord, mask)
    174         self._mask_size = len(self._mask)
    175         self._count = 0
    176 
    177     def mask(self, s):
    178         result = array.array('B')
    179         result.fromstring(s)
    180         # Use temporary local variables to eliminate the cost to access
    181         # attributes
    182         count = self._count
    183         mask = self._mask
    184         mask_size = self._mask_size
    185         for i in xrange(len(result)):
    186             result[i] ^= mask[count]
    187             count = (count + 1) % mask_size
    188         self._count = count
    189 
    190         return result.tostring()
    191 
    192 
    193 class DeflateRequest(object):
    194     """A wrapper class for request object to intercept send and recv to perform
    195     deflate compression and decompression transparently.
    196     """
    197 
    198     def __init__(self, request):
    199         self._request = request
    200         self.connection = DeflateConnection(request.connection)
    201 
    202     def __getattribute__(self, name):
    203         if name in ('_request', 'connection'):
    204             return object.__getattribute__(self, name)
    205         return self._request.__getattribute__(name)
    206 
    207     def __setattr__(self, name, value):
    208         if name in ('_request', 'connection'):
    209             return object.__setattr__(self, name, value)
    210         return self._request.__setattr__(name, value)
    211 
    212 
# By making the wbits option negative, we suppress the CMF/FLG (2 octets)
# and ADLER32 (4 octets) fields of the zlib format, so that the zlib module
# can be used as a raw deflate library. DICTID won't be added as long as we
# don't set a dictionary.
#
# For decompression, an LZ77 window of 32K (zlib.MAX_WBITS) is used so that
# data compressed with any window size can be decoded. For compression, the
# window size is chosen by the caller (32K by default), and receivers must
# support that window size.
#
# Compression level is Z_DEFAULT_COMPRESSION. The decoder doesn't need to
# know the level.
#
# See zconf.h, deflate.c, inflate.c of the zlib library, and zlibmodule.c of
# Python. See also RFC1950 (ZLIB 3.3).
    225 
    226 
    227 class _Deflater(object):
    228 
    229     def __init__(self, window_bits):
    230         self._logger = get_class_logger(self)
    231 
    232         self._compress = zlib.compressobj(
    233             zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -window_bits)
    234 
    235     def compress_and_flush(self, bytes):
    236         compressed_bytes = self._compress.compress(bytes)
    237         compressed_bytes += self._compress.flush(zlib.Z_SYNC_FLUSH)
    238         self._logger.debug('Compress input %r', bytes)
    239         self._logger.debug('Compress result %r', compressed_bytes)
    240         return compressed_bytes
    241 
    242 
class _Inflater(object):
    """Incremental raw-DEFLATE decompressor.

    Compressed bytes are buffered with append() and decompressed on demand
    with decompress(). If the buffered input holds the end of one deflate
    stream followed by the start of another, a fresh zlib Decompress object
    is created (see reset()) and decompression continues with the rest.
    """

    def __init__(self):
        self._logger = get_class_logger(self)

        # Compressed bytes appended but not yet consumed by
        # self._decompress.
        self._unconsumed = ''

        self.reset()

    def decompress(self, size):
        """Decompress buffered input.

        Args:
          size: -1 to decompress all buffered input, or a positive maximum
                number of decompressed bytes to return.
        Returns:
          The decompressed bytes. May be empty if the buffered input does
          not yet yield any output.
        Raises:
          Exception: if size is neither -1 nor positive.
        """
        if not (size == -1 or size > 0):
            raise Exception('size must be -1 or positive')

        data = ''

        while True:
            if size == -1:
                data += self._decompress.decompress(self._unconsumed)
                # See Python bug http://bugs.python.org/issue12050 to
                # understand why the same code cannot be used for updating
                # self._unconsumed for here and else block.
                self._unconsumed = ''
            else:
                # Cap the output at the number of bytes still wanted; input
                # left over because of the cap is kept in self._unconsumed
                # (via unconsumed_tail) for a later call.
                data += self._decompress.decompress(
                    self._unconsumed, size - len(data))
                self._unconsumed = self._decompress.unconsumed_tail
            if self._decompress.unused_data:
                # Encountered a last block (i.e. a block with BFINAL = 1) and
                # found a new stream (unused_data). We cannot use the same
                # zlib.Decompress object for the new stream. Create a new
                # Decompress object to decompress the new one.
                #
                # It's fine to ignore unconsumed_tail if unused_data is not
                # empty.
                self._unconsumed = self._decompress.unused_data
                self.reset()
                if size >= 0 and len(data) == size:
                    # data is filled. Don't call decompress again.
                    break
                else:
                    # Re-invoke Decompress.decompress to try to decompress all
                    # available bytes before invoking read which blocks until
                    # any new byte is available.
                    continue
            else:
                # Here, since unused_data is empty, even if unconsumed_tail is
                # not empty, bytes of requested length are already in data. We
                # don't have to "continue" here.
                # NOTE(review): if unconsumed_tail is also empty, all buffered
                # input has been fed to zlib and data may be shorter than
                # size; the caller must append() more input and call again.
                break

        if data:
            self._logger.debug('Decompressed %r', data)
        return data

    def append(self, data):
        """Buffer more compressed bytes for later decompress() calls."""
        self._logger.debug('Appended %r', data)
        self._unconsumed += data

    def reset(self):
        """Start a new Decompress object (used for each new stream)."""
        self._logger.debug('Reset')
        # Negative wbits: raw deflate (no zlib header/ADLER32) with the
        # maximum window, so input compressed with any window size decodes.
        self._decompress = zlib.decompressobj(-zlib.MAX_WBITS)
    304 
    305 
    306 # Compresses/decompresses given octets using the method introduced in RFC1979.
    307 
    308 
    309 class _RFC1979Deflater(object):
    310     """A compressor class that applies DEFLATE to given byte sequence and
    311     flushes using the algorithm described in the RFC1979 section 2.1.
    312     """
    313 
    314     def __init__(self, window_bits, no_context_takeover):
    315         self._deflater = None
    316         if window_bits is None:
    317             window_bits = zlib.MAX_WBITS
    318         self._window_bits = window_bits
    319         self._no_context_takeover = no_context_takeover
    320 
    321     def filter(self, bytes):
    322         if self._deflater is None or self._no_context_takeover:
    323             self._deflater = _Deflater(self._window_bits)
    324 
    325         # Strip last 4 octets which is LEN and NLEN field of a non-compressed
    326         # block added for Z_SYNC_FLUSH.
    327         return self._deflater.compress_and_flush(bytes)[:-4]
    328 
    329 
    330 class _RFC1979Inflater(object):
    331     """A decompressor class for byte sequence compressed and flushed following
    332     the algorithm described in the RFC1979 section 2.1.
    333     """
    334 
    335     def __init__(self):
    336         self._inflater = _Inflater()
    337 
    338     def filter(self, bytes):
    339         # Restore stripped LEN and NLEN field of a non-compressed block added
    340         # for Z_SYNC_FLUSH.
    341         self._inflater.append(bytes + '\x00\x00\xff\xff')
    342         return self._inflater.decompress(-1)
    343 
    344 
    345 class DeflateSocket(object):
    346     """A wrapper class for socket object to intercept send and recv to perform
    347     deflate compression and decompression transparently.
    348     """
    349 
    350     # Size of the buffer passed to recv to receive compressed data.
    351     _RECV_SIZE = 4096
    352 
    353     def __init__(self, socket):
    354         self._socket = socket
    355 
    356         self._logger = get_class_logger(self)
    357 
    358         self._deflater = _Deflater(zlib.MAX_WBITS)
    359         self._inflater = _Inflater()
    360 
    361     def recv(self, size):
    362         """Receives data from the socket specified on the construction up
    363         to the specified size. Once any data is available, returns it even
    364         if it's smaller than the specified size.
    365         """
    366 
    367         # TODO(tyoshino): Allow call with size=0. It should block until any
    368         # decompressed data is available.
    369         if size <= 0:
    370             raise Exception('Non-positive size passed')
    371         while True:
    372             data = self._inflater.decompress(size)
    373             if len(data) != 0:
    374                 return data
    375 
    376             read_data = self._socket.recv(DeflateSocket._RECV_SIZE)
    377             if not read_data:
    378                 return ''
    379             self._inflater.append(read_data)
    380 
    381     def sendall(self, bytes):
    382         self.send(bytes)
    383 
    384     def send(self, bytes):
    385         self._socket.sendall(self._deflater.compress_and_flush(bytes))
    386         return len(bytes)
    387 
    388 
    389 class DeflateConnection(object):
    390     """A wrapper class for request object to intercept write and read to
    391     perform deflate compression and decompression transparently.
    392     """
    393 
    394     def __init__(self, connection):
    395         self._connection = connection
    396 
    397         self._logger = get_class_logger(self)
    398 
    399         self._deflater = _Deflater(zlib.MAX_WBITS)
    400         self._inflater = _Inflater()
    401 
    402     def get_remote_addr(self):
    403         return self._connection.remote_addr
    404     remote_addr = property(get_remote_addr)
    405 
    406     def put_bytes(self, bytes):
    407         self.write(bytes)
    408 
    409     def read(self, size=-1):
    410         """Reads at most size bytes. Blocks until there's at least one byte
    411         available.
    412         """
    413 
    414         # TODO(tyoshino): Allow call with size=0.
    415         if not (size == -1 or size > 0):
    416             raise Exception('size must be -1 or positive')
    417 
    418         data = ''
    419         while True:
    420             if size == -1:
    421                 data += self._inflater.decompress(-1)
    422             else:
    423                 data += self._inflater.decompress(size - len(data))
    424 
    425             if size >= 0 and len(data) != 0:
    426                 break
    427 
    428             # TODO(tyoshino): Make this read efficient by some workaround.
    429             #
    430             # In 3.0.3 and prior of mod_python, read blocks until length bytes
    431             # was read. We don't know the exact size to read while using
    432             # deflate, so read byte-by-byte.
    433             #
    434             # _StandaloneRequest.read that ultimately performs
    435             # socket._fileobject.read also blocks until length bytes was read
    436             read_data = self._connection.read(1)
    437             if not read_data:
    438                 break
    439             self._inflater.append(read_data)
    440         return data
    441 
    442     def write(self, bytes):
    443         self._connection.write(self._deflater.compress_and_flush(bytes))
    444 
    445 
    446 def _is_ewouldblock_errno(error_number):
    447     """Returns True iff error_number indicates that receive operation would
    448     block. To make this portable, we check availability of errno and then
    449     compare them.
    450     """
    451 
    452     for error_name in ['WSAEWOULDBLOCK', 'EWOULDBLOCK', 'EAGAIN']:
    453         if (error_name in dir(errno) and
    454             error_number == getattr(errno, error_name)):
    455             return True
    456     return False
    457 
    458 
    459 def drain_received_data(raw_socket):
    460     # Set the socket non-blocking.
    461     original_timeout = raw_socket.gettimeout()
    462     raw_socket.settimeout(0.0)
    463 
    464     drained_data = []
    465 
    466     # Drain until the socket is closed or no data is immediately
    467     # available for read.
    468     while True:
    469         try:
    470             data = raw_socket.recv(1)
    471             if not data:
    472                 break
    473             drained_data.append(data)
    474         except socket.error, e:
    475             # e can be either a pair (errno, string) or just a string (or
    476             # something else) telling what went wrong. We suppress only
    477             # the errors that indicates that the socket blocks. Those
    478             # exceptions can be parsed as a pair (errno, string).
    479             try:
    480                 error_number, message = e
    481             except:
    482                 # Failed to parse socket.error.
    483                 raise e
    484 
    485             if _is_ewouldblock_errno(error_number):
    486                 break
    487             else:
    488                 raise e
    489 
    490     # Rollback timeout value.
    491     raw_socket.settimeout(original_timeout)
    492 
    493     return ''.join(drained_data)
    494 
    495 
    496 # vi:sts=4 sw=4 et
    497