Home | History | Annotate | Download | only in tlslite
      1 """Helper class for TLSConnection."""
      2 from __future__ import generators
      3 
      4 from utils.compat import *
      5 from utils.cryptomath import *
      6 from utils.cipherfactory import createAES, createRC4, createTripleDES
      7 from utils.codec import *
      8 from errors import *
      9 from messages import *
     10 from mathtls import *
     11 from constants import *
     12 from utils.cryptomath import getRandomBytes
     13 from utils import hmac
     14 from FileObject import FileObject
     15 
     16 # The sha module is deprecated in Python 2.6 
     17 try:
     18     import sha
     19 except ImportError:
     20     from hashlib import sha1 as sha
     21 
     22 # The md5 module is deprecated in Python 2.6
     23 try:
     24     import md5
     25 except ImportError:
     26     from hashlib import md5
     27 
     28 import socket
     29 import errno
     30 import traceback
     31 
     32 class _ConnectionState:
     33     def __init__(self):
     34         self.macContext = None
     35         self.encContext = None
     36         self.seqnum = 0
     37 
     38     def getSeqNumStr(self):
     39         w = Writer(8)
     40         w.add(self.seqnum, 8)
     41         seqnumStr = bytesToString(w.bytes)
     42         self.seqnum += 1
     43         return seqnumStr
     44 
     45 
     46 class TLSRecordLayer:
     47     """
     48     This class handles data transmission for a TLS connection.
     49 
     50     Its only subclass is L{tlslite.TLSConnection.TLSConnection}.  We've
     51     separated the code in this class from TLSConnection to make things
     52     more readable.
     53 
     54 
     55     @type sock: socket.socket
     56     @ivar sock: The underlying socket object.
     57 
     58     @type session: L{tlslite.Session.Session}
     59     @ivar session: The session corresponding to this connection.
     60 
     61     Due to TLS session resumption, multiple connections can correspond
     62     to the same underlying session.
     63 
     64     @type version: tuple
     65     @ivar version: The TLS version being used for this connection.
     66 
     67     (3,0) means SSL 3.0, and (3,1) means TLS 1.0.
     68 
     69     @type closed: bool
     70     @ivar closed: If this connection is closed.
     71 
     72     @type resumed: bool
     73     @ivar resumed: If this connection is based on a resumed session.
     74 
     75     @type allegedSharedKeyUsername: str or None
     76     @ivar allegedSharedKeyUsername:  This is set to the shared-key
     77     username asserted by the client, whether the handshake succeeded or
     78     not.  If the handshake fails, this can be inspected to
     79     determine if a guessing attack is in progress against a particular
     80     user account.
     81 
     82     @type allegedSrpUsername: str or None
     83     @ivar allegedSrpUsername:  This is set to the SRP username
     84     asserted by the client, whether the handshake succeeded or not.
     85     If the handshake fails, this can be inspected to determine
     86     if a guessing attack is in progress against a particular user
     87     account.
     88 
     89     @type closeSocket: bool
     90     @ivar closeSocket: If the socket should be closed when the
     91     connection is closed (writable).
     92 
     93     If you set this to True, TLS Lite will assume the responsibility of
     94     closing the socket when the TLS Connection is shutdown (either
     95     through an error or through the user calling close()).  The default
     96     is False.
     97 
     98     @type ignoreAbruptClose: bool
     99     @ivar ignoreAbruptClose: If an abrupt close of the socket should
    100     raise an error (writable).
    101 
    102     If you set this to True, TLS Lite will not raise a
    103     L{tlslite.errors.TLSAbruptCloseError} exception if the underlying
    104     socket is unexpectedly closed.  Such an unexpected closure could be
    105     caused by an attacker.  However, it also occurs with some incorrect
    106     TLS implementations.
    107 
    108     You should set this to True only if you're not worried about an
    109     attacker truncating the connection, and only if necessary to avoid
    110     spurious errors.  The default is False.
    111 
    112     @sort: __init__, read, readAsync, write, writeAsync, close, closeAsync,
    113     getCipherImplementation, getCipherName
    114     """
    115 
    116     def __init__(self, sock):
    117         self.sock = sock
    118 
    119         #My session object (Session instance; read-only)
    120         self.session = None
    121 
    122         #Am I a client or server?
    123         self._client = None
    124 
    125         #Buffers for processing messages
    126         self._handshakeBuffer = []
    127         self._readBuffer = ""
    128 
    129         #Handshake digests
    130         self._handshake_md5 = md5.md5()
    131         self._handshake_sha = sha.sha()
    132 
    133         #TLS Protocol Version
    134         self.version = (0,0) #read-only
    135         self._versionCheck = False #Once we choose a version, this is True
    136 
    137         #Current and Pending connection states
    138         self._writeState = _ConnectionState()
    139         self._readState = _ConnectionState()
    140         self._pendingWriteState = _ConnectionState()
    141         self._pendingReadState = _ConnectionState()
    142 
    143         #Is the connection open?
    144         self.closed = True #read-only
    145         self._refCount = 0 #Used to trigger closure
    146 
    147         #Is this a resumed (or shared-key) session?
    148         self.resumed = False #read-only
    149 
    150         #What username did the client claim in his handshake?
    151         self.allegedSharedKeyUsername = None
    152         self.allegedSrpUsername = None
    153 
    154         #On a call to close(), do we close the socket? (writeable)
    155         self.closeSocket = False
    156 
    157         #If the socket is abruptly closed, do we ignore it
    158         #and pretend the connection was shut down properly? (writeable)
    159         self.ignoreAbruptClose = False
    160 
    161         #Fault we will induce, for testing purposes
    162         self.fault = None
    163 
    164     #*********************************************************
    165     # Public Functions START
    166     #*********************************************************
    167 
    168     def read(self, max=None, min=1):
    169         """Read some data from the TLS connection.
    170 
    171         This function will block until at least 'min' bytes are
    172         available (or the connection is closed).
    173 
    174         If an exception is raised, the connection will have been
    175         automatically closed.
    176 
    177         @type max: int
    178         @param max: The maximum number of bytes to return.
    179 
    180         @type min: int
    181         @param min: The minimum number of bytes to return
    182 
    183         @rtype: str
    184         @return: A string of no more than 'max' bytes, and no fewer
    185         than 'min' (unless the connection has been closed, in which
    186         case fewer than 'min' bytes may be returned).
    187 
    188         @raise socket.error: If a socket error occurs.
    189         @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
    190         without a preceding alert.
    191         @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
    192         """
    193         for result in self.readAsync(max, min):
    194             pass
    195         return result
    196 
    197     def readAsync(self, max=None, min=1):
    198         """Start a read operation on the TLS connection.
    199 
    200         This function returns a generator which behaves similarly to
    201         read().  Successive invocations of the generator will return 0
    202         if it is waiting to read from the socket, 1 if it is waiting
    203         to write to the socket, or a string if the read operation has
    204         completed.
    205 
    206         @rtype: iterable
    207         @return: A generator; see above for details.
    208         """
    209         try:
    210             while len(self._readBuffer)<min and not self.closed:
    211                 try:
    212                     for result in self._getMsg(ContentType.application_data):
    213                         if result in (0,1):
    214                             yield result
    215                     applicationData = result
    216                     self._readBuffer += bytesToString(applicationData.write())
    217                 except TLSRemoteAlert, alert:
    218                     if alert.description != AlertDescription.close_notify:
    219                         raise
    220                 except TLSAbruptCloseError:
    221                     if not self.ignoreAbruptClose:
    222                         raise
    223                     else:
    224                         self._shutdown(True)
    225 
    226             if max == None:
    227                 max = len(self._readBuffer)
    228 
    229             returnStr = self._readBuffer[:max]
    230             self._readBuffer = self._readBuffer[max:]
    231             yield returnStr
    232         except:
    233             self._shutdown(False)
    234             raise
    235 
    236     def write(self, s):
    237         """Write some data to the TLS connection.
    238 
    239         This function will block until all the data has been sent.
    240 
    241         If an exception is raised, the connection will have been
    242         automatically closed.
    243 
    244         @type s: str
    245         @param s: The data to transmit to the other party.
    246 
    247         @raise socket.error: If a socket error occurs.
    248         """
    249         for result in self.writeAsync(s):
    250             pass
    251 
    252     def writeAsync(self, s):
    253         """Start a write operation on the TLS connection.
    254 
    255         This function returns a generator which behaves similarly to
    256         write().  Successive invocations of the generator will return
    257         1 if it is waiting to write to the socket, or will raise
    258         StopIteration if the write operation has completed.
    259 
    260         @rtype: iterable
    261         @return: A generator; see above for details.
    262         """
    263         try:
    264             if self.closed:
    265                 raise ValueError()
    266 
    267             index = 0
    268             blockSize = 16384
    269             skipEmptyFrag = False
    270             while 1:
    271                 startIndex = index * blockSize
    272                 endIndex = startIndex + blockSize
    273                 if startIndex >= len(s):
    274                     break
    275                 if endIndex > len(s):
    276                     endIndex = len(s)
    277                 block = stringToBytes(s[startIndex : endIndex])
    278                 applicationData = ApplicationData().create(block)
    279                 for result in self._sendMsg(applicationData, skipEmptyFrag):
    280                     yield result
    281                 skipEmptyFrag = True #only send an empy fragment on 1st message
    282                 index += 1
    283         except:
    284             self._shutdown(False)
    285             raise
    286 
    287     def close(self):
    288         """Close the TLS connection.
    289 
    290         This function will block until it has exchanged close_notify
    291         alerts with the other party.  After doing so, it will shut down the
    292         TLS connection.  Further attempts to read through this connection
    293         will return "".  Further attempts to write through this connection
    294         will raise ValueError.
    295 
    296         If makefile() has been called on this connection, the connection
    297         will be not be closed until the connection object and all file
    298         objects have been closed.
    299 
    300         Even if an exception is raised, the connection will have been
    301         closed.
    302 
    303         @raise socket.error: If a socket error occurs.
    304         @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
    305         without a preceding alert.
    306         @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
    307         """
    308         if not self.closed:
    309             for result in self._decrefAsync():
    310                 pass
    311 
    312     def closeAsync(self):
    313         """Start a close operation on the TLS connection.
    314 
    315         This function returns a generator which behaves similarly to
    316         close().  Successive invocations of the generator will return 0
    317         if it is waiting to read from the socket, 1 if it is waiting
    318         to write to the socket, or will raise StopIteration if the
    319         close operation has completed.
    320 
    321         @rtype: iterable
    322         @return: A generator; see above for details.
    323         """
    324         if not self.closed:
    325             for result in self._decrefAsync():
    326                 yield result
    327 
    328     def _decrefAsync(self):
    329         self._refCount -= 1
    330         if self._refCount == 0 and not self.closed:
    331             try:
    332                 for result in self._sendMsg(Alert().create(\
    333                         AlertDescription.close_notify, AlertLevel.warning)):
    334                     yield result
    335                 alert = None
    336                 # Forcing a shutdown as WinHTTP does not seem to be
    337                 # responsive to the close notify.
    338                 prevCloseSocket = self.closeSocket
    339                 self.closeSocket = True
    340                 self._shutdown(True)
    341                 self.closeSocket = prevCloseSocket
    342                 while not alert:
    343                     for result in self._getMsg((ContentType.alert, \
    344                                               ContentType.application_data)):
    345                         if result in (0,1):
    346                             yield result
    347                     if result.contentType == ContentType.alert:
    348                         alert = result
    349                 if alert.description == AlertDescription.close_notify:
    350                     self._shutdown(True)
    351                 else:
    352                     raise TLSRemoteAlert(alert)
    353             except (socket.error, TLSAbruptCloseError):
    354                 #If the other side closes the socket, that's okay
    355                 self._shutdown(True)
    356             except:
    357                 self._shutdown(False)
    358                 raise
    359 
    360     def getCipherName(self):
    361         """Get the name of the cipher used with this connection.
    362 
    363         @rtype: str
    364         @return: The name of the cipher used with this connection.
    365         Either 'aes128', 'aes256', 'rc4', or '3des'.
    366         """
    367         if not self._writeState.encContext:
    368             return None
    369         return self._writeState.encContext.name
    370 
    371     def getCipherImplementation(self):
    372         """Get the name of the cipher implementation used with
    373         this connection.
    374 
    375         @rtype: str
    376         @return: The name of the cipher implementation used with
    377         this connection.  Either 'python', 'cryptlib', 'openssl',
    378         or 'pycrypto'.
    379         """
    380         if not self._writeState.encContext:
    381             return None
    382         return self._writeState.encContext.implementation
    383 
    384 
    385 
    386     #Emulate a socket, somewhat -
    387     def send(self, s):
    388         """Send data to the TLS connection (socket emulation).
    389 
    390         @raise socket.error: If a socket error occurs.
    391         """
    392         self.write(s)
    393         return len(s)
    394 
    395     def sendall(self, s):
    396         """Send data to the TLS connection (socket emulation).
    397 
    398         @raise socket.error: If a socket error occurs.
    399         """
    400         self.write(s)
    401 
    402     def recv(self, bufsize):
    403         """Get some data from the TLS connection (socket emulation).
    404 
    405         @raise socket.error: If a socket error occurs.
    406         @raise tlslite.errors.TLSAbruptCloseError: If the socket is closed
    407         without a preceding alert.
    408         @raise tlslite.errors.TLSAlert: If a TLS alert is signalled.
    409         """
    410         return self.read(bufsize)
    411 
    412     def makefile(self, mode='r', bufsize=-1):
    413         """Create a file object for the TLS connection (socket emulation).
    414 
    415         @rtype: L{tlslite.FileObject.FileObject}
    416         """
    417         self._refCount += 1
    418         return FileObject(self, mode, bufsize)
    419 
    420     def getsockname(self):
    421         """Return the socket's own address (socket emulation)."""
    422         return self.sock.getsockname()
    423 
    424     def getpeername(self):
    425         """Return the remote address to which the socket is connected
    426         (socket emulation)."""
    427         return self.sock.getpeername()
    428 
    429     def settimeout(self, value):
    430         """Set a timeout on blocking socket operations (socket emulation)."""
    431         return self.sock.settimeout(value)
    432 
    433     def gettimeout(self):
    434         """Return the timeout associated with socket operations (socket
    435         emulation)."""
    436         return self.sock.gettimeout()
    437 
    438     def setsockopt(self, level, optname, value):
    439         """Set the value of the given socket option (socket emulation)."""
    440         return self.sock.setsockopt(level, optname, value)
    441 
    442 
    443      #*********************************************************
    444      # Public Functions END
    445      #*********************************************************
    446 
    447     def _shutdown(self, resumable):
    448         self._writeState = _ConnectionState()
    449         self._readState = _ConnectionState()
    450         #Don't do this: self._readBuffer = ""
    451         self.version = (0,0)
    452         self._versionCheck = False
    453         self.closed = True
    454         if self.closeSocket:
    455             self.sock.close()
    456 
    457         #Even if resumable is False, we'll never toggle this on
    458         if not resumable and self.session:
    459             self.session.resumable = False
    460 
    461 
    462     def _sendError(self, alertDescription, errorStr=None):
    463         alert = Alert().create(alertDescription, AlertLevel.fatal)
    464         for result in self._sendMsg(alert):
    465             yield result
    466         self._shutdown(False)
    467         raise TLSLocalAlert(alert, errorStr)
    468 
    469     def _sendMsgs(self, msgs):
    470         skipEmptyFrag = False
    471         for msg in msgs:
    472             for result in self._sendMsg(msg, skipEmptyFrag):
    473                 yield result
    474             skipEmptyFrag = True
    475 
    476     def _sendMsg(self, msg, skipEmptyFrag=False):
    477         bytes = msg.write()
    478         contentType = msg.contentType
    479 
    480         #Whenever we're connected and asked to send a message,
    481         #we first send an empty Application Data message.  This prevents
    482         #an attacker from launching a chosen-plaintext attack based on
    483         #knowing the next IV.
    484         if not self.closed and not skipEmptyFrag and self.version == (3,1):
    485             if self._writeState.encContext:
    486                 if self._writeState.encContext.isBlockCipher:
    487                     for result in self._sendMsg(ApplicationData(),
    488                                                skipEmptyFrag=True):
    489                         yield result
    490 
    491         #Update handshake hashes
    492         if contentType == ContentType.handshake:
    493             bytesStr = bytesToString(bytes)
    494             self._handshake_md5.update(bytesStr)
    495             self._handshake_sha.update(bytesStr)
    496 
    497         #Calculate MAC
    498         if self._writeState.macContext:
    499             seqnumStr = self._writeState.getSeqNumStr()
    500             bytesStr = bytesToString(bytes)
    501             mac = self._writeState.macContext.copy()
    502             mac.update(seqnumStr)
    503             mac.update(chr(contentType))
    504             if self.version == (3,0):
    505                 mac.update( chr( int(len(bytes)/256) ) )
    506                 mac.update( chr( int(len(bytes)%256) ) )
    507             elif self.version in ((3,1), (3,2)):
    508                 mac.update(chr(self.version[0]))
    509                 mac.update(chr(self.version[1]))
    510                 mac.update( chr( int(len(bytes)/256) ) )
    511                 mac.update( chr( int(len(bytes)%256) ) )
    512             else:
    513                 raise AssertionError()
    514             mac.update(bytesStr)
    515             macString = mac.digest()
    516             macBytes = stringToBytes(macString)
    517             if self.fault == Fault.badMAC:
    518                 macBytes[0] = (macBytes[0]+1) % 256
    519 
    520         #Encrypt for Block or Stream Cipher
    521         if self._writeState.encContext:
    522             #Add padding and encrypt (for Block Cipher):
    523             if self._writeState.encContext.isBlockCipher:
    524 
    525                 #Add TLS 1.1 fixed block
    526                 if self.version == (3,2):
    527                     bytes = self.fixedIVBlock + bytes
    528 
    529                 #Add padding: bytes = bytes + (macBytes + paddingBytes)
    530                 currentLength = len(bytes) + len(macBytes) + 1
    531                 blockLength = self._writeState.encContext.block_size
    532                 paddingLength = blockLength-(currentLength % blockLength)
    533 
    534                 paddingBytes = createByteArraySequence([paddingLength] * \
    535                                                        (paddingLength+1))
    536                 if self.fault == Fault.badPadding:
    537                     paddingBytes[0] = (paddingBytes[0]+1) % 256
    538                 endBytes = concatArrays(macBytes, paddingBytes)
    539                 bytes = concatArrays(bytes, endBytes)
    540                 #Encrypt
    541                 plaintext = stringToBytes(bytes)
    542                 ciphertext = self._writeState.encContext.encrypt(plaintext)
    543                 bytes = stringToBytes(ciphertext)
    544 
    545             #Encrypt (for Stream Cipher)
    546             else:
    547                 bytes = concatArrays(bytes, macBytes)
    548                 plaintext = bytesToString(bytes)
    549                 ciphertext = self._writeState.encContext.encrypt(plaintext)
    550                 bytes = stringToBytes(ciphertext)
    551 
    552         #Add record header and send
    553         r = RecordHeader3().create(self.version, contentType, len(bytes))
    554         s = bytesToString(concatArrays(r.write(), bytes))
    555         while 1:
    556             try:
    557                 bytesSent = self.sock.send(s) #Might raise socket.error
    558             except socket.error, why:
    559                 if why[0] == errno.EWOULDBLOCK:
    560                     yield 1
    561                     continue
    562                 else:
    563                     raise
    564             if bytesSent == len(s):
    565                 return
    566             s = s[bytesSent:]
    567             yield 1
    568 
    569 
    570     def _getMsg(self, expectedType, secondaryType=None, constructorType=None):
    571         try:
    572             if not isinstance(expectedType, tuple):
    573                 expectedType = (expectedType,)
    574 
    575             #Spin in a loop, until we've got a non-empty record of a type we
    576             #expect.  The loop will be repeated if:
    577             #  - we receive a renegotiation attempt; we send no_renegotiation,
    578             #    then try again
    579             #  - we receive an empty application-data fragment; we try again
    580             while 1:
    581                 for result in self._getNextRecord():
    582                     if result in (0,1):
    583                         yield result
    584                 recordHeader, p = result
    585 
    586                 #If this is an empty application-data fragment, try again
    587                 if recordHeader.type == ContentType.application_data:
    588                     if p.index == len(p.bytes):
    589                         continue
    590 
    591                 #If we received an unexpected record type...
    592                 if recordHeader.type not in expectedType:
    593 
    594                     #If we received an alert...
    595                     if recordHeader.type == ContentType.alert:
    596                         alert = Alert().parse(p)
    597 
    598                         #We either received a fatal error, a warning, or a
    599                         #close_notify.  In any case, we're going to close the
    600                         #connection.  In the latter two cases we respond with
    601                         #a close_notify, but ignore any socket errors, since
    602                         #the other side might have already closed the socket.
    603                         if alert.level == AlertLevel.warning or \
    604                            alert.description == AlertDescription.close_notify:
    605 
    606                             #If the sendMsg() call fails because the socket has
    607                             #already been closed, we will be forgiving and not
    608                             #report the error nor invalidate the "resumability"
    609                             #of the session.
    610                             try:
    611                                 alertMsg = Alert()
    612                                 alertMsg.create(AlertDescription.close_notify,
    613                                                 AlertLevel.warning)
    614                                 for result in self._sendMsg(alertMsg):
    615                                     yield result
    616                             except socket.error:
    617                                 pass
    618 
    619                             if alert.description == \
    620                                    AlertDescription.close_notify:
    621                                 self._shutdown(True)
    622                             elif alert.level == AlertLevel.warning:
    623                                 self._shutdown(False)
    624 
    625                         else: #Fatal alert:
    626                             self._shutdown(False)
    627 
    628                         #Raise the alert as an exception
    629                         raise TLSRemoteAlert(alert)
    630 
    631                     #If we received a renegotiation attempt...
    632                     if recordHeader.type == ContentType.handshake:
    633                         subType = p.get(1)
    634                         reneg = False
    635                         if self._client:
    636                             if subType == HandshakeType.hello_request:
    637                                 reneg = True
    638                         else:
    639                             if subType == HandshakeType.client_hello:
    640                                 reneg = True
    641                         #Send no_renegotiation, then try again
    642                         if reneg:
    643                             alertMsg = Alert()
    644                             alertMsg.create(AlertDescription.no_renegotiation,
    645                                             AlertLevel.warning)
    646                             for result in self._sendMsg(alertMsg):
    647                                 yield result
    648                             continue
    649 
    650                     #Otherwise: this is an unexpected record, but neither an
    651                     #alert nor renegotiation
    652                     for result in self._sendError(\
    653                             AlertDescription.unexpected_message,
    654                             "received type=%d" % recordHeader.type):
    655                         yield result
    656 
    657                 break
    658 
    659             #Parse based on content_type
    660             if recordHeader.type == ContentType.change_cipher_spec:
    661                 yield ChangeCipherSpec().parse(p)
    662             elif recordHeader.type == ContentType.alert:
    663                 yield Alert().parse(p)
    664             elif recordHeader.type == ContentType.application_data:
    665                 yield ApplicationData().parse(p)
    666             elif recordHeader.type == ContentType.handshake:
    667                 #Convert secondaryType to tuple, if it isn't already
    668                 if not isinstance(secondaryType, tuple):
    669                     secondaryType = (secondaryType,)
    670 
    671                 #If it's a handshake message, check handshake header
    672                 if recordHeader.ssl2:
    673                     subType = p.get(1)
    674                     if subType != HandshakeType.client_hello:
    675                         for result in self._sendError(\
    676                                 AlertDescription.unexpected_message,
    677                                 "Can only handle SSLv2 ClientHello messages"):
    678                             yield result
    679                     if HandshakeType.client_hello not in secondaryType:
    680                         for result in self._sendError(\
    681                                 AlertDescription.unexpected_message):
    682                             yield result
    683                     subType = HandshakeType.client_hello
    684                 else:
    685                     subType = p.get(1)
    686                     if subType not in secondaryType:
    687                         for result in self._sendError(\
    688                                 AlertDescription.unexpected_message,
    689                                 "Expecting %s, got %s" % (str(secondaryType), subType)):
    690                             yield result
    691 
    692                 #Update handshake hashes
    693                 sToHash = bytesToString(p.bytes)
    694                 self._handshake_md5.update(sToHash)
    695                 self._handshake_sha.update(sToHash)
    696 
    697                 #Parse based on handshake type
    698                 if subType == HandshakeType.client_hello:
    699                     yield ClientHello(recordHeader.ssl2).parse(p)
    700                 elif subType == HandshakeType.server_hello:
    701                     yield ServerHello().parse(p)
    702                 elif subType == HandshakeType.certificate:
    703                     yield Certificate(constructorType).parse(p)
    704                 elif subType == HandshakeType.certificate_request:
    705                     yield CertificateRequest().parse(p)
    706                 elif subType == HandshakeType.certificate_verify:
    707                     yield CertificateVerify().parse(p)
    708                 elif subType == HandshakeType.server_key_exchange:
    709                     yield ServerKeyExchange(constructorType).parse(p)
    710                 elif subType == HandshakeType.server_hello_done:
    711                     yield ServerHelloDone().parse(p)
    712                 elif subType == HandshakeType.client_key_exchange:
    713                     yield ClientKeyExchange(constructorType, \
    714                                             self.version).parse(p)
    715                 elif subType == HandshakeType.finished:
    716                     yield Finished(self.version).parse(p)
    717                 elif subType == HandshakeType.encrypted_extensions:
    718                     yield EncryptedExtensions().parse(p)
    719                 else:
    720                     raise AssertionError()
    721 
    722         #If an exception was raised by a Parser or Message instance:
    723         except SyntaxError, e:
    724             for result in self._sendError(AlertDescription.decode_error,
    725                                          formatExceptionTrace(e)):
    726                 yield result
    727 
    728 
    729     #Returns next record or next handshake message
    730     def _getNextRecord(self):
    731 
    732         #If there's a handshake message waiting, return it
    733         if self._handshakeBuffer:
    734             recordHeader, bytes = self._handshakeBuffer[0]
    735             self._handshakeBuffer = self._handshakeBuffer[1:]
    736             yield (recordHeader, Parser(bytes))
    737             return
    738 
    739         #Otherwise...
    740         #Read the next record header
    741         bytes = createByteArraySequence([])
    742         recordHeaderLength = 1
    743         ssl2 = False
    744         while 1:
    745             try:
    746                 s = self.sock.recv(recordHeaderLength-len(bytes))
    747             except socket.error, why:
    748                 if why[0] == errno.EWOULDBLOCK:
    749                     yield 0
    750                     continue
    751                 else:
    752                     raise
    753 
    754             #If the connection was abruptly closed, raise an error
    755             if len(s)==0:
    756                 raise TLSAbruptCloseError()
    757 
    758             bytes += stringToBytes(s)
    759             if len(bytes)==1:
    760                 if bytes[0] in ContentType.all:
    761                     ssl2 = False
    762                     recordHeaderLength = 5
    763                 elif bytes[0] == 128:
    764                     ssl2 = True
    765                     recordHeaderLength = 2
    766                 else:
    767                     raise SyntaxError()
    768             if len(bytes) == recordHeaderLength:
    769                 break
    770 
    771         #Parse the record header
    772         if ssl2:
    773             r = RecordHeader2().parse(Parser(bytes))
    774         else:
    775             r = RecordHeader3().parse(Parser(bytes))
    776 
    777         #Check the record header fields
    778         if r.length > 18432:
    779             for result in self._sendError(AlertDescription.record_overflow):
    780                 yield result
    781 
    782         #Read the record contents
    783         bytes = createByteArraySequence([])
    784         while 1:
    785             try:
    786                 s = self.sock.recv(r.length - len(bytes))
    787             except socket.error, why:
    788                 if why[0] == errno.EWOULDBLOCK:
    789                     yield 0
    790                     continue
    791                 else:
    792                     raise
    793 
    794             #If the connection is closed, raise a socket error
    795             if len(s)==0:
    796                     raise TLSAbruptCloseError()
    797 
    798             bytes += stringToBytes(s)
    799             if len(bytes) == r.length:
    800                 break
    801 
    802         #Check the record header fields (2)
    803         #We do this after reading the contents from the socket, so that
    804         #if there's an error, we at least don't leave extra bytes in the
    805         #socket..
    806         #
    807         # THIS CHECK HAS NO SECURITY RELEVANCE (?), BUT COULD HURT INTEROP.
    808         # SO WE LEAVE IT OUT FOR NOW.
    809         #
    810         #if self._versionCheck and r.version != self.version:
    811         #    for result in self._sendError(AlertDescription.protocol_version,
    812         #            "Version in header field: %s, should be %s" % (str(r.version),
    813         #                                                       str(self.version))):
    814         #        yield result
    815 
    816         #Decrypt the record
    817         for result in self._decryptRecord(r.type, bytes):
    818             if result in (0,1):
    819                 yield result
    820             else:
    821                 break
    822         bytes = result
    823         p = Parser(bytes)
    824 
    825         #If it doesn't contain handshake messages, we can just return it
    826         if r.type != ContentType.handshake:
    827             yield (r, p)
    828         #If it's an SSLv2 ClientHello, we can return it as well
    829         elif r.ssl2:
    830             yield (r, p)
    831         else:
    832             #Otherwise, we loop through and add the handshake messages to the
    833             #handshake buffer
    834             while 1:
    835                 if p.index == len(bytes): #If we're at the end
    836                     if not self._handshakeBuffer:
    837                         for result in self._sendError(\
    838                                 AlertDescription.decode_error, \
    839                                 "Received empty handshake record"):
    840                             yield result
    841                     break
    842                 #There needs to be at least 4 bytes to get a header
    843                 if p.index+4 > len(bytes):
    844                     for result in self._sendError(\
    845                             AlertDescription.decode_error,
    846                             "A record has a partial handshake message (1)"):
    847                         yield result
    848                 p.get(1) # skip handshake type
    849                 msgLength = p.get(3)
    850                 if p.index+msgLength > len(bytes):
    851                     for result in self._sendError(\
    852                             AlertDescription.decode_error,
    853                             "A record has a partial handshake message (2)"):
    854                         yield result
    855 
    856                 handshakePair = (r, bytes[p.index-4 : p.index+msgLength])
    857                 self._handshakeBuffer.append(handshakePair)
    858                 p.index += msgLength
    859 
    860             #We've moved at least one handshake message into the
    861             #handshakeBuffer, return the first one
    862             recordHeader, bytes = self._handshakeBuffer[0]
    863             self._handshakeBuffer = self._handshakeBuffer[1:]
    864             yield (recordHeader, Parser(bytes))
    865 
    866 
    def _decryptRecord(self, recordType, bytes):
        """Generator: decrypt and authenticate one record body.

        recordType -- the ContentType from the record header; it is fed into
        the MAC.
        bytes -- the record body as read from the socket.

        With no active read cipher the body is yielded unchanged.  Otherwise
        it is decrypted, the block-cipher padding and the MAC are checked and
        stripped, and an alert is sent through _sendError if a check fails.
        Status values produced by _sendError are passed through; the last
        value yielded is the plaintext byte array.
        """
        if self._readState.encContext:

            #Decrypt if it's a block cipher
            if self._readState.encContext.isBlockCipher:
                blockLength = self._readState.encContext.block_size
                if len(bytes) % blockLength != 0:
                    for result in self._sendError(\
                            AlertDescription.decryption_failed,
                            "Encrypted data not a multiple of blocksize"):
                        yield result
                ciphertext = bytesToString(bytes)
                plaintext = self._readState.encContext.decrypt(ciphertext)
                if self.version == (3,2): #For TLS 1.1, remove explicit IV
                    plaintext = plaintext[self._readState.encContext.block_size : ]
                bytes = stringToBytes(plaintext)

                #Check padding.  The plaintext can be empty, e.g. a TLS 1.1
                #record made of nothing but the explicit IV.  It then has no
                #padding-length byte, so treat it as bad padding rather than
                #letting bytes[-1] raise IndexError.  The MAC check below
                #also fails, and bad_record_mac is sent as for any other
                #malformed record.
                paddingGood = True
                if len(bytes) == 0:
                    paddingGood = False
                    totalPaddingLength = 0
                else:
                    paddingLength = bytes[-1]
                    if (paddingLength+1) > len(bytes):
                        paddingGood=False
                        totalPaddingLength = 0
                    else:
                        if self.version == (3,0):
                            totalPaddingLength = paddingLength+1
                        elif self.version in ((3,1), (3,2)):
                            #TLS requires every padding byte to equal the
                            #padding length
                            totalPaddingLength = paddingLength+1
                            paddingBytes = bytes[-totalPaddingLength:-1]
                            for byte in paddingBytes:
                                if byte != paddingLength:
                                    paddingGood = False
                                    totalPaddingLength = 0
                        else:
                            raise AssertionError()

            #Decrypt if it's a stream cipher
            else:
                paddingGood = True
                ciphertext = bytesToString(bytes)
                plaintext = self._readState.encContext.decrypt(ciphertext)
                bytes = stringToBytes(plaintext)
                totalPaddingLength = 0

            #Check MAC
            #NOTE(review): neither the padding check nor the MAC comparison
            #below is constant-time.
            macGood = True
            macLength = self._readState.macContext.digest_size
            endLength = macLength + totalPaddingLength
            if endLength > len(bytes):
                macGood = False
            else:
                #Read MAC
                startIndex = len(bytes) - endLength
                endIndex = startIndex + macLength
                checkBytes = bytes[startIndex : endIndex]

                #Calculate MAC over seqnum || type || [version ||] length ||
                #content
                seqnumStr = self._readState.getSeqNumStr()
                bytes = bytes[:-endLength]
                bytesStr = bytesToString(bytes)
                mac = self._readState.macContext.copy()
                mac.update(seqnumStr)
                mac.update(chr(recordType))
                if self.version == (3,0):
                    mac.update( chr( int(len(bytes)/256) ) )
                    mac.update( chr( int(len(bytes)%256) ) )
                elif self.version in ((3,1), (3,2)):
                    mac.update(chr(self.version[0]))
                    mac.update(chr(self.version[1]))
                    mac.update( chr( int(len(bytes)/256) ) )
                    mac.update( chr( int(len(bytes)%256) ) )
                else:
                    raise AssertionError()
                mac.update(bytesStr)
                macString = mac.digest()
                macBytes = stringToBytes(macString)

                #Compare MACs
                if macBytes != checkBytes:
                    macGood = False

            if not (paddingGood and macGood):
                for result in self._sendError(AlertDescription.bad_record_mac,
                                          "MAC failure (or padding failure)"):
                    yield result

        yield bytes
    954 
    955     def _handshakeStart(self, client):
    956         self._client = client
    957         self._handshake_md5 = md5.md5()
    958         self._handshake_sha = sha.sha()
    959         self._handshakeBuffer = []
    960         self.allegedSharedKeyUsername = None
    961         self.allegedSrpUsername = None
    962         self._refCount = 1
    963 
    964     def _handshakeDone(self, resumed):
    965         self.resumed = resumed
    966         self.closed = False
    967 
    def _calcPendingStates(self, clientRandom, serverRandom, implementations):
        """Derive the pending read and write connection states.

        The session's master secret is expanded with the server and client
        randoms into a key block: PRF_SSL for SSL 3.0, or PRF with the
        "key expansion" label for TLS 1.0/1.1.  The block is sliced into MAC
        keys, cipher keys and IVs (client before server each time), and the
        resulting states become self._pendingWriteState and
        self._pendingReadState according to self._client.  For TLS 1.1 with
        an IV-using suite, a random self.fixedIVBlock is also chosen.

        implementations -- passed through to the cipher factory.
        Raises AssertionError for an unknown cipher suite or protocol version.
        """
        suite = self.session.cipherSuite

        #(suite list, MAC length, key length, IV length, cipher factory);
        #the first row whose list contains the negotiated suite is used.
        suiteTable = (
            (CipherSuite.aes128Suites,    20, 16, 16, createAES),
            (CipherSuite.aes256Suites,    20, 32, 16, createAES),
            (CipherSuite.rc4Suites,       20, 16,  0, createRC4),
            (CipherSuite.tripleDESSuites, 20, 24,  8, createTripleDES),
        )
        for (suites, macLength, keyLength, ivLength,
             createCipherFunc) in suiteTable:
            if suite in suites:
                break
        else:
            raise AssertionError()

        #Two of each: one for the client, one for the server
        outputLength = 2 * (macLength + keyLength + ivLength)

        #Pick the MAC construction and expand the master secret
        if self.version == (3,0):
            createMACFunc = MAC_SSL
            keyBlock = PRF_SSL(self.session.masterSecret,
                               concatArrays(serverRandom, clientRandom),
                               outputLength)
        elif self.version in ((3,1), (3,2)):
            createMACFunc = hmac.HMAC
            keyBlock = PRF(self.session.masterSecret,
                           "key expansion",
                           concatArrays(serverRandom, clientRandom),
                           outputLength)
        else:
            raise AssertionError()

        #Slice up the keying material in order
        parser = Parser(keyBlock)
        def take(n):
            return bytesToString(parser.getFixBytes(n))
        clientMacKey, serverMacKey = take(macLength), take(macLength)
        clientKey, serverKey = take(keyLength), take(keyLength)
        clientIV, serverIV = take(ivLength), take(ivLength)

        clientState = _ConnectionState()
        serverState = _ConnectionState()
        clientState.macContext = createMACFunc(clientMacKey, digestmod=sha)
        serverState.macContext = createMACFunc(serverMacKey, digestmod=sha)
        clientState.encContext = createCipherFunc(clientKey, clientIV,
                                                  implementations)
        serverState.encContext = createCipherFunc(serverKey, serverIV,
                                                  implementations)

        #This end writes with its own keys and reads with the peer's
        if self._client:
            self._pendingWriteState, self._pendingReadState = \
                clientState, serverState
        else:
            self._pendingWriteState, self._pendingReadState = \
                serverState, clientState

        if self.version == (3,2) and ivLength:
            #Choose fixedIVBlock for TLS 1.1 (this is encrypted with the CBC
            #residue to create the IV for each sent block)
            self.fixedIVBlock = getRandomBytes(ivLength)
   1045 
   1046     def _changeWriteState(self):
   1047         self._writeState = self._pendingWriteState
   1048         self._pendingWriteState = _ConnectionState()
   1049 
   1050     def _changeReadState(self):
   1051         self._readState = self._pendingReadState
   1052         self._pendingReadState = _ConnectionState()
   1053 
   1054     def _sendFinished(self):
   1055         #Send ChangeCipherSpec
   1056         for result in self._sendMsg(ChangeCipherSpec()):
   1057             yield result
   1058 
   1059         #Switch to pending write state
   1060         self._changeWriteState()
   1061 
   1062         #Calculate verification data
   1063         verifyData = self._calcFinished(True)
   1064         if self.fault == Fault.badFinished:
   1065             verifyData[0] = (verifyData[0]+1)%256
   1066 
   1067         #Send Finished message under new state
   1068         finished = Finished(self.version).create(verifyData)
   1069         for result in self._sendMsg(finished):
   1070             yield result
   1071 
   1072     def _getChangeCipherSpec(self):
   1073         #Get and check ChangeCipherSpec
   1074         for result in self._getMsg(ContentType.change_cipher_spec):
   1075             if result in (0,1):
   1076                 yield result
   1077         changeCipherSpec = result
   1078 
   1079         if changeCipherSpec.type != 1:
   1080             for result in self._sendError(AlertDescription.illegal_parameter,
   1081                                          "ChangeCipherSpec type incorrect"):
   1082                 yield result
   1083 
   1084         #Switch to pending read state
   1085         self._changeReadState()
   1086 
   1087     def _getEncryptedExtensions(self):
   1088         for result in self._getMsg(ContentType.handshake,
   1089                                    HandshakeType.encrypted_extensions):
   1090             if result in (0,1):
   1091                 yield result
   1092         encrypted_extensions = result
   1093         self.channel_id = encrypted_extensions.channel_id_key
   1094 
   1095     def _getFinished(self):
   1096         #Calculate verification data
   1097         verifyData = self._calcFinished(False)
   1098 
   1099         #Get and check Finished message under new state
   1100         for result in self._getMsg(ContentType.handshake,
   1101                                   HandshakeType.finished):
   1102             if result in (0,1):
   1103                 yield result
   1104         finished = result
   1105         if finished.verify_data != verifyData:
   1106             for result in self._sendError(AlertDescription.decrypt_error,
   1107                                          "Finished message is incorrect"):
   1108                 yield result
   1109 
   1110     def _calcFinished(self, send=True):
   1111         if self.version == (3,0):
   1112             if (self._client and send) or (not self._client and not send):
   1113                 senderStr = "\x43\x4C\x4E\x54"
   1114             else:
   1115                 senderStr = "\x53\x52\x56\x52"
   1116 
   1117             verifyData = self._calcSSLHandshakeHash(self.session.masterSecret,
   1118                                                    senderStr)
   1119             return verifyData
   1120 
   1121         elif self.version in ((3,1), (3,2)):
   1122             if (self._client and send) or (not self._client and not send):
   1123                 label = "client finished"
   1124             else:
   1125                 label = "server finished"
   1126 
   1127             handshakeHashes = stringToBytes(self._handshake_md5.digest() + \
   1128                                             self._handshake_sha.digest())
   1129             verifyData = PRF(self.session.masterSecret, label, handshakeHashes,
   1130                              12)
   1131             return verifyData
   1132         else:
   1133             raise AssertionError()
   1134 
   1135     #Used for Finished messages and CertificateVerify messages in SSL v3
   1136     def _calcSSLHandshakeHash(self, masterSecret, label):
   1137         masterSecretStr = bytesToString(masterSecret)
   1138 
   1139         imac_md5 = self._handshake_md5.copy()
   1140         imac_sha = self._handshake_sha.copy()
   1141 
   1142         imac_md5.update(label + masterSecretStr + '\x36'*48)
   1143         imac_sha.update(label + masterSecretStr + '\x36'*40)
   1144 
   1145         md5Str = md5.md5(masterSecretStr + ('\x5c'*48) + \
   1146                          imac_md5.digest()).digest()
   1147         shaStr = sha.sha(masterSecretStr + ('\x5c'*40) + \
   1148                          imac_sha.digest()).digest()
   1149 
   1150         return stringToBytes(md5Str + shaStr)
   1151