Home | History | Annotate | Download | only in Lib
      1 # Wrapper module for _ssl, providing some additional facilities

      2 # implemented in Python.  Written by Bill Janssen.

      3 
      4 """\
      5 This module provides some more Pythonic support for SSL.
      6 
      7 Object types:
      8 
      9   SSLSocket -- subtype of socket.socket which does SSL over the socket
     10 
     11 Exceptions:
     12 
     13   SSLError -- exception raised for I/O errors
     14 
     15 Functions:
     16 
     17   cert_time_to_seconds -- convert a time string taken from a certificate's
     18                           notBefore or notAfter field to a number of
     19                           seconds past the Epoch (the same scale as the
     20                           values returned by time.time())
     21 
     22   get_server_certificate (ADDR) -- fetch the PEM-encoded certificate
     23                           provided by the server at ADDR, a (HOST, PORT)
     24                           pair.  It is validated only if ca_certs is given.
     25 
     26 Integer constants:
     27 
     28 SSL_ERROR_ZERO_RETURN
     29 SSL_ERROR_WANT_READ
     30 SSL_ERROR_WANT_WRITE
     31 SSL_ERROR_WANT_X509_LOOKUP
     32 SSL_ERROR_SYSCALL
     33 SSL_ERROR_SSL
     34 SSL_ERROR_WANT_CONNECT
     35 
     36 SSL_ERROR_EOF
     37 SSL_ERROR_INVALID_ERROR_CODE
     38 
     39 The following constants define which certificates one side allows or
     40 requires from the other side:
     41 
     42 CERT_NONE - no certificates from the other side are required (or will
     43             be looked at if provided)
     44 CERT_OPTIONAL - certificates are not required, but if provided will be
     45                 validated, and if validation fails, the connection will
     46                 also fail
     47 CERT_REQUIRED - certificates are required, and will be validated, and
     48                 if validation fails, the connection will also fail
     49 
     50 The following constants identify various SSL protocol variants:
     51 
     52 PROTOCOL_SSLv2 (only defined when the _ssl module provides it)
     53 PROTOCOL_SSLv3
     54 PROTOCOL_SSLv23
     55 PROTOCOL_TLSv1
     56 """
     57 
     58 import textwrap
     59 
     60 import _ssl             # if we can't import it, let the error propagate

     61 
     62 from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
     63 from _ssl import SSLError
     64 from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
     65 from _ssl import RAND_status, RAND_egd, RAND_add
     66 from _ssl import \
     67      SSL_ERROR_ZERO_RETURN, \
     68      SSL_ERROR_WANT_READ, \
     69      SSL_ERROR_WANT_WRITE, \
     70      SSL_ERROR_WANT_X509_LOOKUP, \
     71      SSL_ERROR_SYSCALL, \
     72      SSL_ERROR_SSL, \
     73      SSL_ERROR_WANT_CONNECT, \
     74      SSL_ERROR_EOF, \
     75      SSL_ERROR_INVALID_ERROR_CODE
     76 from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1
     77 _PROTOCOL_NAMES = {
     78     PROTOCOL_TLSv1: "TLSv1",
     79     PROTOCOL_SSLv23: "SSLv23",
     80     PROTOCOL_SSLv3: "SSLv3",
     81 }
     82 try:
     83     from _ssl import PROTOCOL_SSLv2
     84 except ImportError:
     85     pass
     86 else:
     87     _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2"
     88 
     89 from socket import socket, _fileobject, _delegate_methods, error as socket_error
     90 from socket import getnameinfo as _getnameinfo
     91 import base64        # for DER-to-PEM translation

     92 import errno
     93 
     94 class SSLSocket(socket):
     95 
     96     """This class implements a subtype of socket.socket that wraps
     97     the underlying OS socket in an SSL context when necessary, and
     98     provides read and write methods over that channel."""
     99 
    100     def __init__(self, sock, keyfile=None, certfile=None,
    101                  server_side=False, cert_reqs=CERT_NONE,
    102                  ssl_version=PROTOCOL_SSLv23, ca_certs=None,
    103                  do_handshake_on_connect=True,
    104                  suppress_ragged_eofs=True, ciphers=None):
    105         socket.__init__(self, _sock=sock._sock)
    106         # The initializer for socket overrides the methods send(), recv(), etc.

    107         # in the instancce, which we don't need -- but we want to provide the

    108         # methods defined in SSLSocket.

    109         for attr in _delegate_methods:
    110             try:
    111                 delattr(self, attr)
    112             except AttributeError:
    113                 pass
    114 
    115         if certfile and not keyfile:
    116             keyfile = certfile
    117         # see if it's connected

    118         try:
    119             socket.getpeername(self)
    120         except socket_error, e:
    121             if e.errno != errno.ENOTCONN:
    122                 raise
    123             # no, no connection yet

    124             self._connected = False
    125             self._sslobj = None
    126         else:
    127             # yes, create the SSL object

    128             self._connected = True
    129             self._sslobj = _ssl.sslwrap(self._sock, server_side,
    130                                         keyfile, certfile,
    131                                         cert_reqs, ssl_version, ca_certs,
    132                                         ciphers)
    133             if do_handshake_on_connect:
    134                 self.do_handshake()
    135         self.keyfile = keyfile
    136         self.certfile = certfile
    137         self.cert_reqs = cert_reqs
    138         self.ssl_version = ssl_version
    139         self.ca_certs = ca_certs
    140         self.ciphers = ciphers
    141         self.do_handshake_on_connect = do_handshake_on_connect
    142         self.suppress_ragged_eofs = suppress_ragged_eofs
    143         self._makefile_refs = 0
    144 
    145     def read(self, len=1024):
    146 
    147         """Read up to LEN bytes and return them.
    148         Return zero-length string on EOF."""
    149 
    150         try:
    151             return self._sslobj.read(len)
    152         except SSLError, x:
    153             if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
    154                 return ''
    155             else:
    156                 raise
    157 
    158     def write(self, data):
    159 
    160         """Write DATA to the underlying SSL channel.  Returns
    161         number of bytes of DATA actually transmitted."""
    162 
    163         return self._sslobj.write(data)
    164 
    165     def getpeercert(self, binary_form=False):
    166 
    167         """Returns a formatted version of the data in the
    168         certificate provided by the other end of the SSL channel.
    169         Return None if no certificate was provided, {} if a
    170         certificate was provided, but not validated."""
    171 
    172         return self._sslobj.peer_certificate(binary_form)
    173 
    174     def cipher(self):
    175 
    176         if not self._sslobj:
    177             return None
    178         else:
    179             return self._sslobj.cipher()
    180 
    181     def send(self, data, flags=0):
    182         if self._sslobj:
    183             if flags != 0:
    184                 raise ValueError(
    185                     "non-zero flags not allowed in calls to send() on %s" %
    186                     self.__class__)
    187             while True:
    188                 try:
    189                     v = self._sslobj.write(data)
    190                 except SSLError, x:
    191                     if x.args[0] == SSL_ERROR_WANT_READ:
    192                         return 0
    193                     elif x.args[0] == SSL_ERROR_WANT_WRITE:
    194                         return 0
    195                     else:
    196                         raise
    197                 else:
    198                     return v
    199         else:
    200             return self._sock.send(data, flags)
    201 
    202     def sendto(self, data, flags_or_addr, addr=None):
    203         if self._sslobj:
    204             raise ValueError("sendto not allowed on instances of %s" %
    205                              self.__class__)
    206         elif addr is None:
    207             return self._sock.sendto(data, flags_or_addr)
    208         else:
    209             return self._sock.sendto(data, flags_or_addr, addr)
    210 
    211     def sendall(self, data, flags=0):
    212         if self._sslobj:
    213             if flags != 0:
    214                 raise ValueError(
    215                     "non-zero flags not allowed in calls to sendall() on %s" %
    216                     self.__class__)
    217             amount = len(data)
    218             count = 0
    219             while (count < amount):
    220                 v = self.send(data[count:])
    221                 count += v
    222             return amount
    223         else:
    224             return socket.sendall(self, data, flags)
    225 
    226     def recv(self, buflen=1024, flags=0):
    227         if self._sslobj:
    228             if flags != 0:
    229                 raise ValueError(
    230                     "non-zero flags not allowed in calls to recv() on %s" %
    231                     self.__class__)
    232             return self.read(buflen)
    233         else:
    234             return self._sock.recv(buflen, flags)
    235 
    236     def recv_into(self, buffer, nbytes=None, flags=0):
    237         if buffer and (nbytes is None):
    238             nbytes = len(buffer)
    239         elif nbytes is None:
    240             nbytes = 1024
    241         if self._sslobj:
    242             if flags != 0:
    243                 raise ValueError(
    244                   "non-zero flags not allowed in calls to recv_into() on %s" %
    245                   self.__class__)
    246             tmp_buffer = self.read(nbytes)
    247             v = len(tmp_buffer)
    248             buffer[:v] = tmp_buffer
    249             return v
    250         else:
    251             return self._sock.recv_into(buffer, nbytes, flags)
    252 
    253     def recvfrom(self, buflen=1024, flags=0):
    254         if self._sslobj:
    255             raise ValueError("recvfrom not allowed on instances of %s" %
    256                              self.__class__)
    257         else:
    258             return self._sock.recvfrom(buflen, flags)
    259 
    260     def recvfrom_into(self, buffer, nbytes=None, flags=0):
    261         if self._sslobj:
    262             raise ValueError("recvfrom_into not allowed on instances of %s" %
    263                              self.__class__)
    264         else:
    265             return self._sock.recvfrom_into(buffer, nbytes, flags)
    266 
    267     def pending(self):
    268         if self._sslobj:
    269             return self._sslobj.pending()
    270         else:
    271             return 0
    272 
    273     def unwrap(self):
    274         if self._sslobj:
    275             s = self._sslobj.shutdown()
    276             self._sslobj = None
    277             return s
    278         else:
    279             raise ValueError("No SSL wrapper around " + str(self))
    280 
    281     def shutdown(self, how):
    282         self._sslobj = None
    283         socket.shutdown(self, how)
    284 
    285     def close(self):
    286         if self._makefile_refs < 1:
    287             self._sslobj = None
    288             socket.close(self)
    289         else:
    290             self._makefile_refs -= 1
    291 
    292     def do_handshake(self):
    293 
    294         """Perform a TLS/SSL handshake."""
    295 
    296         self._sslobj.do_handshake()
    297 
    298     def _real_connect(self, addr, return_errno):
    299         # Here we assume that the socket is client-side, and not

    300         # connected at the time of the call.  We connect it, then wrap it.

    301         if self._connected:
    302             raise ValueError("attempt to connect already-connected SSLSocket!")
    303         self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
    304                                     self.cert_reqs, self.ssl_version,
    305                                     self.ca_certs, self.ciphers)
    306         try:
    307             socket.connect(self, addr)
    308             if self.do_handshake_on_connect:
    309                 self.do_handshake()
    310         except socket_error as e:
    311             if return_errno:
    312                 return e.errno
    313             else:
    314                 self._sslobj = None
    315                 raise e
    316         self._connected = True
    317         return 0
    318 
    319     def connect(self, addr):
    320         """Connects to remote ADDR, and then wraps the connection in
    321         an SSL channel."""
    322         self._real_connect(addr, False)
    323 
    324     def connect_ex(self, addr):
    325         """Connects to remote ADDR, and then wraps the connection in
    326         an SSL channel."""
    327         return self._real_connect(addr, True)
    328 
    329     def accept(self):
    330 
    331         """Accepts a new connection from a remote client, and returns
    332         a tuple containing that new connection wrapped with a server-side
    333         SSL channel, and the address of the remote client."""
    334 
    335         newsock, addr = socket.accept(self)
    336         return (SSLSocket(newsock,
    337                           keyfile=self.keyfile,
    338                           certfile=self.certfile,
    339                           server_side=True,
    340                           cert_reqs=self.cert_reqs,
    341                           ssl_version=self.ssl_version,
    342                           ca_certs=self.ca_certs,
    343                           ciphers=self.ciphers,
    344                           do_handshake_on_connect=self.do_handshake_on_connect,
    345                           suppress_ragged_eofs=self.suppress_ragged_eofs),
    346                 addr)
    347 
    348     def makefile(self, mode='r', bufsize=-1):
    349 
    350         """Make and return a file-like object that
    351         works with the SSL connection.  Just use the code
    352         from the socket module."""
    353 
    354         self._makefile_refs += 1
    355         # close=True so as to decrement the reference count when done with

    356         # the file-like object.

    357         return _fileobject(self, mode, bufsize, close=True)
    358 
    359 
    360 
    361 def wrap_socket(sock, keyfile=None, certfile=None,
    362                 server_side=False, cert_reqs=CERT_NONE,
    363                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
    364                 do_handshake_on_connect=True,
    365                 suppress_ragged_eofs=True, ciphers=None):
    366 
    367     return SSLSocket(sock, keyfile=keyfile, certfile=certfile,
    368                      server_side=server_side, cert_reqs=cert_reqs,
    369                      ssl_version=ssl_version, ca_certs=ca_certs,
    370                      do_handshake_on_connect=do_handshake_on_connect,
    371                      suppress_ragged_eofs=suppress_ragged_eofs,
    372                      ciphers=ciphers)
    373 
    374 
    375 # some utility functions

    376 
    377 def cert_time_to_seconds(cert_time):
    378 
    379     """Takes a date-time string in standard ASN1_print form
    380     ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
    381     a Python time value in seconds past the epoch."""
    382 
    383     import time
    384     return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT"))
    385 
    386 PEM_HEADER = "-----BEGIN CERTIFICATE-----"
    387 PEM_FOOTER = "-----END CERTIFICATE-----"
    388 
    389 def DER_cert_to_PEM_cert(der_cert_bytes):
    390 
    391     """Takes a certificate in binary DER format and returns the
    392     PEM version of it as a string."""
    393 
    394     if hasattr(base64, 'standard_b64encode'):
    395         # preferred because older API gets line-length wrong

    396         f = base64.standard_b64encode(der_cert_bytes)
    397         return (PEM_HEADER + '\n' +
    398                 textwrap.fill(f, 64) + '\n' +
    399                 PEM_FOOTER + '\n')
    400     else:
    401         return (PEM_HEADER + '\n' +
    402                 base64.encodestring(der_cert_bytes) +
    403                 PEM_FOOTER + '\n')
    404 
    405 def PEM_cert_to_DER_cert(pem_cert_string):
    406 
    407     """Takes a certificate in ASCII PEM format and returns the
    408     DER-encoded version of it as a byte sequence"""
    409 
    410     if not pem_cert_string.startswith(PEM_HEADER):
    411         raise ValueError("Invalid PEM encoding; must start with %s"
    412                          % PEM_HEADER)
    413     if not pem_cert_string.strip().endswith(PEM_FOOTER):
    414         raise ValueError("Invalid PEM encoding; must end with %s"
    415                          % PEM_FOOTER)
    416     d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
    417     return base64.decodestring(d)
    418 
    419 def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):
    420 
    421     """Retrieve the certificate from the server at the specified address,
    422     and return it as a PEM-encoded string.
    423     If 'ca_certs' is specified, validate the server cert against it.
    424     If 'ssl_version' is specified, use it in the connection attempt."""
    425 
    426     host, port = addr
    427     if (ca_certs is not None):
    428         cert_reqs = CERT_REQUIRED
    429     else:
    430         cert_reqs = CERT_NONE
    431     s = wrap_socket(socket(), ssl_version=ssl_version,
    432                     cert_reqs=cert_reqs, ca_certs=ca_certs)
    433     s.connect(addr)
    434     dercert = s.getpeercert(True)
    435     s.close()
    436     return DER_cert_to_PEM_cert(dercert)
    437 
    438 def get_protocol_name(protocol_code):
    439     return _PROTOCOL_NAMES.get(protocol_code, '<unknown>')
    440 
    441 
    442 # a replacement for the old socket.ssl function

    443 
    444 def sslwrap_simple(sock, keyfile=None, certfile=None):
    445 
    446     """A replacement for the old socket.ssl function.  Designed
    447     for compability with Python 2.5 and earlier.  Will disappear in
    448     Python 3.0."""
    449 
    450     if hasattr(sock, "_sock"):
    451         sock = sock._sock
    452 
    453     ssl_sock = _ssl.sslwrap(sock, 0, keyfile, certfile, CERT_NONE,
    454                             PROTOCOL_SSLv23, None)
    455     try:
    456         sock.getpeername()
    457     except socket_error:
    458         # no, no connection yet

    459         pass
    460     else:
    461         # yes, do the handshake

    462         ssl_sock.do_handshake()
    463 
    464     return ssl_sock
    465