Home | History | Annotate | Download | only in integration
      1 """TLS Lite + Twisted."""
      2 
      3 from twisted.protocols.policies import ProtocolWrapper, WrappingFactory
      4 from twisted.python.failure import Failure
      5 
      6 from AsyncStateMachine import AsyncStateMachine
      7 from tlslite.TLSConnection import TLSConnection
      8 from tlslite.errors import *
      9 
     10 import socket
     11 import errno
     12 
     13 
     14 #The TLSConnection is created around a "fake socket" that
     15 #plugs it into the underlying Twisted transport
     16 class _FakeSocket:
     17     def __init__(self, wrapper):
     18         self.wrapper = wrapper
     19         self.data = ""
     20 
     21     def send(self, data):
     22         ProtocolWrapper.write(self.wrapper, data)
     23         return len(data)
     24 
     25     def recv(self, numBytes):
     26         if self.data == "":
     27             raise socket.error, (errno.EWOULDBLOCK, "")
     28         returnData = self.data[:numBytes]
     29         self.data = self.data[numBytes:]
     30         return returnData
     31 
     32 class TLSTwistedProtocolWrapper(ProtocolWrapper, AsyncStateMachine):
     33     """This class can wrap Twisted protocols to add TLS support.
     34 
     35     Below is a complete example of using TLS Lite with a Twisted echo
     36     server.
     37 
     38     There are two server implementations below.  Echo is the original
     39     protocol, which is oblivious to TLS.  Echo1 subclasses Echo and
     40     negotiates TLS when the client connects.  Echo2 subclasses Echo and
     41     negotiates TLS when the client sends "STARTTLS"::
     42 
     43         from twisted.internet.protocol import Protocol, Factory
     44         from twisted.internet import reactor
     45         from twisted.protocols.policies import WrappingFactory
     46         from twisted.protocols.basic import LineReceiver
     47         from twisted.python import log
     48         from twisted.python.failure import Failure
     49         import sys
     50         from tlslite.api import *
     51 
     52         s = open("./serverX509Cert.pem").read()
     53         x509 = X509()
     54         x509.parse(s)
     55         certChain = X509CertChain([x509])
     56 
     57         s = open("./serverX509Key.pem").read()
     58         privateKey = parsePEMKey(s, private=True)
     59 
     60         verifierDB = VerifierDB("verifierDB")
     61         verifierDB.open()
     62 
     63         class Echo(LineReceiver):
     64             def connectionMade(self):
     65                 self.transport.write("Welcome to the echo server!\\r\\n")
     66 
     67             def lineReceived(self, line):
     68                 self.transport.write(line + "\\r\\n")
     69 
     70         class Echo1(Echo):
     71             def connectionMade(self):
     72                 if not self.transport.tlsStarted:
     73                     self.transport.setServerHandshakeOp(certChain=certChain,
     74                                                         privateKey=privateKey,
     75                                                         verifierDB=verifierDB)
     76                 else:
     77                     Echo.connectionMade(self)
     78 
     79             def connectionLost(self, reason):
     80                 pass #Handle any TLS exceptions here
     81 
     82         class Echo2(Echo):
     83             def lineReceived(self, data):
     84                 if data == "STARTTLS":
     85                     self.transport.setServerHandshakeOp(certChain=certChain,
     86                                                         privateKey=privateKey,
     87                                                         verifierDB=verifierDB)
     88                 else:
     89                     Echo.lineReceived(self, data)
     90 
     91             def connectionLost(self, reason):
     92                 pass #Handle any TLS exceptions here
     93 
     94         factory = Factory()
     95         factory.protocol = Echo1
     96         #factory.protocol = Echo2
     97 
     98         wrappingFactory = WrappingFactory(factory)
     99         wrappingFactory.protocol = TLSTwistedProtocolWrapper
    100 
    101         log.startLogging(sys.stdout)
    102         reactor.listenTCP(1079, wrappingFactory)
    103         reactor.run()
    104 
    105     This class works as follows:
    106 
    107     Data comes in and is given to the AsyncStateMachine for handling.
    108     AsyncStateMachine will forward events to this class, and we'll
    109     pass them on to the ProtocolHandler, which will proxy them to the
    110     wrapped protocol.  The wrapped protocol may then call back into
    111     this class, and these calls will be proxied into the
    112     AsyncStateMachine.
    113 
    114     The call graph looks like this:
    115      - self.dataReceived
    116        - AsyncStateMachine.inReadEvent
    117          - self.out(Connect|Close|Read)Event
    118            - ProtocolWrapper.(connectionMade|loseConnection|dataReceived)
    119              - self.(loseConnection|write|writeSequence)
    120                - AsyncStateMachine.(setCloseOp|setWriteOp)
    121     """
    122 
    123     #WARNING: IF YOU COPY-AND-PASTE THE ABOVE CODE, BE SURE TO REMOVE
    124     #THE EXTRA ESCAPING AROUND "\\r\\n"
    125 
    126     def __init__(self, factory, wrappedProtocol):
    127         ProtocolWrapper.__init__(self, factory, wrappedProtocol)
    128         AsyncStateMachine.__init__(self)
    129         self.fakeSocket = _FakeSocket(self)
    130         self.tlsConnection = TLSConnection(self.fakeSocket)
    131         self.tlsStarted = False
    132         self.connectionLostCalled = False
    133 
    134     def connectionMade(self):
    135         try:
    136             ProtocolWrapper.connectionMade(self)
    137         except TLSError, e:
    138             self.connectionLost(Failure(e))
    139             ProtocolWrapper.loseConnection(self)
    140 
    141     def dataReceived(self, data):
    142         try:
    143             if not self.tlsStarted:
    144                 ProtocolWrapper.dataReceived(self, data)
    145             else:
    146                 self.fakeSocket.data += data
    147                 while self.fakeSocket.data:
    148                     AsyncStateMachine.inReadEvent(self)
    149         except TLSError, e:
    150             self.connectionLost(Failure(e))
    151             ProtocolWrapper.loseConnection(self)
    152 
    153     def connectionLost(self, reason):
    154         if not self.connectionLostCalled:
    155             ProtocolWrapper.connectionLost(self, reason)
    156             self.connectionLostCalled = True
    157 
    158 
    159     def outConnectEvent(self):
    160         ProtocolWrapper.connectionMade(self)
    161 
    162     def outCloseEvent(self):
    163         ProtocolWrapper.loseConnection(self)
    164 
    165     def outReadEvent(self, data):
    166         if data == "":
    167             ProtocolWrapper.loseConnection(self)
    168         else:
    169             ProtocolWrapper.dataReceived(self, data)
    170 
    171 
    172     def setServerHandshakeOp(self, **args):
    173         self.tlsStarted = True
    174         AsyncStateMachine.setServerHandshakeOp(self, **args)
    175 
    176     def loseConnection(self):
    177         if not self.tlsStarted:
    178             ProtocolWrapper.loseConnection(self)
    179         else:
    180             AsyncStateMachine.setCloseOp(self)
    181 
    182     def write(self, data):
    183         if not self.tlsStarted:
    184             ProtocolWrapper.write(self, data)
    185         else:
    186             #Because of the FakeSocket, write operations are guaranteed to
    187             #terminate immediately.
    188             AsyncStateMachine.setWriteOp(self, data)
    189 
    190     def writeSequence(self, seq):
    191         if not self.tlsStarted:
    192             ProtocolWrapper.writeSequence(self, seq)
    193         else:
    194             #Because of the FakeSocket, write operations are guaranteed to
    195             #terminate immediately.
    196             AsyncStateMachine.setWriteOp(self, "".join(seq))