Home | History | Annotate | Download | only in mod_pywebsocket
      1 # Copyright 2011, Google Inc.
      2 # All rights reserved.
      3 #
      4 # Redistribution and use in source and binary forms, with or without
      5 # modification, are permitted provided that the following conditions are
      6 # met:
      7 #
      8 #     * Redistributions of source code must retain the above copyright
      9 # notice, this list of conditions and the following disclaimer.
     10 #     * Redistributions in binary form must reproduce the above
     11 # copyright notice, this list of conditions and the following disclaimer
     12 # in the documentation and/or other materials provided with the
     13 # distribution.
     14 #     * Neither the name of Google Inc. nor the names of its
     15 # contributors may be used to endorse or promote products derived from
     16 # this software without specific prior written permission.
     17 #
     18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     19 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     20 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     21 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     22 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     23 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     24 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     25 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     26 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     27 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     28 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     29 
     30 
     31 """WebSocket utilities.
     32 """
     33 
     34 
     35 import array
     36 import errno
     37 
# Import hash classes from a module available and recommended for each Python
# version and re-export those symbols. Use the sha and md5 modules in Python
# 2.4, and the hashlib module in Python 2.6.
     41 try:
     42     import hashlib
     43     md5_hash = hashlib.md5
     44     sha1_hash = hashlib.sha1
     45 except ImportError:
     46     import md5
     47     import sha
     48     md5_hash = md5.md5
     49     sha1_hash = sha.sha
     50 
     51 import StringIO
     52 import logging
     53 import os
     54 import re
     55 import socket
     56 import traceback
     57 import zlib
     58 
     59 try:
     60     from mod_pywebsocket import fast_masking
     61 except ImportError:
     62     pass
     63 
     64 
     65 def get_stack_trace():
     66     """Get the current stack trace as string.
     67 
     68     This is needed to support Python 2.3.
     69     TODO: Remove this when we only support Python 2.4 and above.
     70           Use traceback.format_exc instead.
     71     """
     72 
     73     out = StringIO.StringIO()
     74     traceback.print_exc(file=out)
     75     return out.getvalue()
     76 
     77 
     78 def prepend_message_to_exception(message, exc):
     79     """Prepend message to the exception."""
     80 
     81     exc.args = (message + str(exc),)
     82     return
     83 
     84 
     85 def __translate_interp(interp, cygwin_path):
     86     """Translate interp program path for Win32 python to run cygwin program
     87     (e.g. perl).  Note that it doesn't support path that contains space,
     88     which is typically true for Unix, where #!-script is written.
     89     For Win32 python, cygwin_path is a directory of cygwin binaries.
     90 
     91     Args:
     92       interp: interp command line
     93       cygwin_path: directory name of cygwin binary, or None
     94     Returns:
     95       translated interp command line.
     96     """
     97     if not cygwin_path:
     98         return interp
     99     m = re.match('^[^ ]*/([^ ]+)( .*)?', interp)
    100     if m:
    101         cmd = os.path.join(cygwin_path, m.group(1))
    102         return cmd + m.group(2)
    103     return interp
    104 
    105 
    106 def get_script_interp(script_path, cygwin_path=None):
    107     """Gets #!-interpreter command line from the script.
    108 
    109     It also fixes command path.  When Cygwin Python is used, e.g. in WebKit,
    110     it could run "/usr/bin/perl -wT hello.pl".
    111     When Win32 Python is used, e.g. in Chromium, it couldn't.  So, fix
    112     "/usr/bin/perl" to "<cygwin_path>\perl.exe".
    113 
    114     Args:
    115       script_path: pathname of the script
    116       cygwin_path: directory name of cygwin binary, or None
    117     Returns:
    118       #!-interpreter command line, or None if it is not #!-script.
    119     """
    120     fp = open(script_path)
    121     line = fp.readline()
    122     fp.close()
    123     m = re.match('^#!(.*)', line)
    124     if m:
    125         return __translate_interp(m.group(1), cygwin_path)
    126     return None
    127 
    128 
    129 def wrap_popen3_for_win(cygwin_path):
    130     """Wrap popen3 to support #!-script on Windows.
    131 
    132     Args:
    133       cygwin_path:  path for cygwin binary if command path is needed to be
    134                     translated.  None if no translation required.
    135     """
    136 
    137     __orig_popen3 = os.popen3
    138 
    139     def __wrap_popen3(cmd, mode='t', bufsize=-1):
    140         cmdline = cmd.split(' ')
    141         interp = get_script_interp(cmdline[0], cygwin_path)
    142         if interp:
    143             cmd = interp + ' ' + cmd
    144         return __orig_popen3(cmd, mode, bufsize)
    145 
    146     os.popen3 = __wrap_popen3
    147 
    148 
    149 def hexify(s):
    150     return ' '.join(map(lambda x: '%02x' % ord(x), s))
    151 
    152 
    153 def get_class_logger(o):
    154     return logging.getLogger(
    155         '%s.%s' % (o.__class__.__module__, o.__class__.__name__))
    156 
    157 
    158 class NoopMasker(object):
    159     """A masking object that has the same interface as RepeatedXorMasker but
    160     just returns the string passed in without making any change.
    161     """
    162 
    163     def __init__(self):
    164         pass
    165 
    166     def mask(self, s):
    167         return s
    168 
    169 
    170 class RepeatedXorMasker(object):
    171     """A masking object that applies XOR on the string given to mask method
    172     with the masking bytes given to the constructor repeatedly. This object
    173     remembers the position in the masking bytes the last mask method call
    174     ended and resumes from that point on the next mask method call.
    175     """
    176 
    177     def __init__(self, masking_key):
    178         self._masking_key = masking_key
    179         self._masking_key_index = 0
    180 
    181     def _mask_using_swig(self, s):
    182         masked_data = fast_masking.mask(
    183                 s, self._masking_key, self._masking_key_index)
    184         self._masking_key_index = (
    185                 (self._masking_key_index + len(s)) % len(self._masking_key))
    186         return masked_data
    187 
    188     def _mask_using_array(self, s):
    189         result = array.array('B')
    190         result.fromstring(s)
    191 
    192         # Use temporary local variables to eliminate the cost to access
    193         # attributes
    194         masking_key = map(ord, self._masking_key)
    195         masking_key_size = len(masking_key)
    196         masking_key_index = self._masking_key_index
    197 
    198         for i in xrange(len(result)):
    199             result[i] ^= masking_key[masking_key_index]
    200             masking_key_index = (masking_key_index + 1) % masking_key_size
    201 
    202         self._masking_key_index = masking_key_index
    203 
    204         return result.tostring()
    205 
    206     if 'fast_masking' in globals():
    207         mask = _mask_using_swig
    208     else:
    209         mask = _mask_using_array
    210 
    211 
# By making the wbits option negative, we can suppress the CMF/FLG (2 octets)
# and ADLER32 (4 octets) fields of zlib so that we can use the zlib module
# just as a raw deflate library. DICTID won't be added as long as we don't
# set a dictionary. With window bits of zlib.MAX_WBITS, an LZ77 window of 32K
# is used for both compression and decompression. For decompression, a 32K
# window covers any window size. For compression, using 32K means receivers
# must also use 32K.
#
# Compression level is Z_DEFAULT_COMPRESSION. The decoder doesn't have to
# match the level.
#
# See zconf.h, deflate.c, inflate.c of the zlib library, and zlibmodule.c of
# Python. See also RFC1950 (ZLIB 3.3).
    224 
    225 
    226 class _Deflater(object):
    227 
    228     def __init__(self, window_bits):
    229         self._logger = get_class_logger(self)
    230 
    231         self._compress = zlib.compressobj(
    232             zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -window_bits)
    233 
    234     def compress(self, bytes):
    235         compressed_bytes = self._compress.compress(bytes)
    236         self._logger.debug('Compress input %r', bytes)
    237         self._logger.debug('Compress result %r', compressed_bytes)
    238         return compressed_bytes
    239 
    240     def compress_and_flush(self, bytes):
    241         compressed_bytes = self._compress.compress(bytes)
    242         compressed_bytes += self._compress.flush(zlib.Z_SYNC_FLUSH)
    243         self._logger.debug('Compress input %r', bytes)
    244         self._logger.debug('Compress result %r', compressed_bytes)
    245         return compressed_bytes
    246 
    247     def compress_and_finish(self, bytes):
    248         compressed_bytes = self._compress.compress(bytes)
    249         compressed_bytes += self._compress.flush(zlib.Z_FINISH)
    250         self._logger.debug('Compress input %r', bytes)
    251         self._logger.debug('Compress result %r', compressed_bytes)
    252         return compressed_bytes
    253 
    254 
class _Inflater(object):
    """Raw DEFLATE decompressor built on zlib.decompressobj.

    Compressed input is buffered by append() and decoded on demand by
    decompress(). If a DEFLATE stream ends (a block with BFINAL = 1) and
    more input follows it, a fresh zlib decompressor is created via reset()
    so that the following stream can be decoded as well.
    """

    def __init__(self, window_bits):
        """Creates an inflater.

        Args:
          window_bits: passed negated to zlib.decompressobj as wbits, which
              makes zlib expect raw deflate data without the zlib header
              and ADLER32 trailer.
        """
        self._logger = get_class_logger(self)
        self._window_bits = window_bits

        # Compressed input that has been appended but not yet consumed by
        # the current zlib decompressor.
        self._unconsumed = ''

        self.reset()

    def decompress(self, size):
        """Decompresses buffered input.

        Args:
          size: -1 to decompress as much of the buffered input as possible,
              or a positive maximum number of decompressed bytes to return.
        Returns:
          The decompressed data. It is empty if the buffered input has not
          yielded any output yet.
        Raises:
          Exception: if size is neither -1 nor positive. zlib.error raised
              by the underlying decompressor is not caught here.
        """
        if not (size == -1 or size > 0):
            raise Exception('size must be -1 or positive')

        data = ''

        while True:
            if size == -1:
                data += self._decompress.decompress(self._unconsumed)
                # See Python bug http://bugs.python.org/issue12050 to
                # understand why the same code cannot be used for updating
                # self._unconsumed for here and else block.
                self._unconsumed = ''
            else:
                # Ask only for the bytes still missing; input that zlib could
                # not process within that limit is left in unconsumed_tail.
                data += self._decompress.decompress(
                    self._unconsumed, size - len(data))
                self._unconsumed = self._decompress.unconsumed_tail
            if self._decompress.unused_data:
                # Encountered a last block (i.e. a block with BFINAL = 1) and
                # found a new stream (unused_data). We cannot use the same
                # zlib.Decompress object for the new stream. Create a new
                # Decompress object to decompress the new one.
                #
                # It's fine to ignore unconsumed_tail if unused_data is not
                # empty.
                self._unconsumed = self._decompress.unused_data
                self.reset()
                if size >= 0 and len(data) == size:
                    # data is filled. Don't call decompress again.
                    break
                else:
                    # Re-invoke Decompress.decompress to try to decompress all
                    # available bytes before invoking read which blocks until
                    # any new byte is available.
                    continue
            else:
                # Here, since unused_data is empty, even if unconsumed_tail is
                # not empty, bytes of requested length are already in data. We
                # don't have to "continue" here.
                break

        if data:
            self._logger.debug('Decompressed %r', data)
        return data

    def append(self, data):
        """Buffers compressed input for later decompress() calls."""
        self._logger.debug('Appended %r', data)
        self._unconsumed += data

    def reset(self):
        """Replaces the zlib decompressor with a new one.

        Input already buffered in self._unconsumed is kept.
        """
        self._logger.debug('Reset')
        self._decompress = zlib.decompressobj(-self._window_bits)
    316         self._decompress = zlib.decompressobj(-self._window_bits)
    317 
    318 
    319 # Compresses/decompresses given octets using the method introduced in RFC1979.
    320 
    321 
    322 class _RFC1979Deflater(object):
    323     """A compressor class that applies DEFLATE to given byte sequence and
    324     flushes using the algorithm described in the RFC1979 section 2.1.
    325     """
    326 
    327     def __init__(self, window_bits, no_context_takeover):
    328         self._deflater = None
    329         if window_bits is None:
    330             window_bits = zlib.MAX_WBITS
    331         self._window_bits = window_bits
    332         self._no_context_takeover = no_context_takeover
    333 
    334     def filter(self, bytes, end=True, bfinal=False):
    335         if self._deflater is None:
    336             self._deflater = _Deflater(self._window_bits)
    337 
    338         if bfinal:
    339             result = self._deflater.compress_and_finish(bytes)
    340             # Add a padding block with BFINAL = 0 and BTYPE = 0.
    341             result = result + chr(0)
    342             self._deflater = None
    343             return result
    344 
    345         result = self._deflater.compress_and_flush(bytes)
    346         if end:
    347             # Strip last 4 octets which is LEN and NLEN field of a
    348             # non-compressed block added for Z_SYNC_FLUSH.
    349             result = result[:-4]
    350 
    351         if self._no_context_takeover and end:
    352             self._deflater = None
    353 
    354         return result
    355 
    356 
    357 class _RFC1979Inflater(object):
    358     """A decompressor class for byte sequence compressed and flushed following
    359     the algorithm described in the RFC1979 section 2.1.
    360     """
    361 
    362     def __init__(self, window_bits=zlib.MAX_WBITS):
    363         self._inflater = _Inflater(window_bits)
    364 
    365     def filter(self, bytes):
    366         # Restore stripped LEN and NLEN field of a non-compressed block added
    367         # for Z_SYNC_FLUSH.
    368         self._inflater.append(bytes + '\x00\x00\xff\xff')
    369         return self._inflater.decompress(-1)
    370 
    371 
    372 class DeflateSocket(object):
    373     """A wrapper class for socket object to intercept send and recv to perform
    374     deflate compression and decompression transparently.
    375     """
    376 
    377     # Size of the buffer passed to recv to receive compressed data.
    378     _RECV_SIZE = 4096
    379 
    380     def __init__(self, socket):
    381         self._socket = socket
    382 
    383         self._logger = get_class_logger(self)
    384 
    385         self._deflater = _Deflater(zlib.MAX_WBITS)
    386         self._inflater = _Inflater(zlib.MAX_WBITS)
    387 
    388     def recv(self, size):
    389         """Receives data from the socket specified on the construction up
    390         to the specified size. Once any data is available, returns it even
    391         if it's smaller than the specified size.
    392         """
    393 
    394         # TODO(tyoshino): Allow call with size=0. It should block until any
    395         # decompressed data is available.
    396         if size <= 0:
    397             raise Exception('Non-positive size passed')
    398         while True:
    399             data = self._inflater.decompress(size)
    400             if len(data) != 0:
    401                 return data
    402 
    403             read_data = self._socket.recv(DeflateSocket._RECV_SIZE)
    404             if not read_data:
    405                 return ''
    406             self._inflater.append(read_data)
    407 
    408     def sendall(self, bytes):
    409         self.send(bytes)
    410 
    411     def send(self, bytes):
    412         self._socket.sendall(self._deflater.compress_and_flush(bytes))
    413         return len(bytes)
    414 
    415 
    416 # vi:sts=4 sw=4 et
    417