# Wrapper module for _ssl, providing some additional facilities
# implemented in Python.  Written by Bill Janssen.

"""\
This module provides some more Pythonic support for SSL.

Object types:

  SSLSocket -- subtype of socket.socket which does SSL over the socket

Exceptions:

  SSLError -- exception raised for I/O errors

Functions:

  wrap_socket -- wrap an existing socket object in an SSLSocket

  cert_time_to_seconds -- convert the time string used for the certificate
                          notBefore and notAfter values to seconds past
                          the Epoch (the same kind of value returned by
                          time.time())

  get_server_certificate (ADDR) -- fetch the certificate provided by the
                          server listening at ADDR, a (HOST, PORT) pair,
                          and return it PEM-encoded.  The certificate is
                          validated only when CA_CERTS is also given.

  DER_cert_to_PEM_cert, PEM_cert_to_DER_cert -- convert a certificate
                          between the binary DER and ASCII PEM encodings

Integer constants:

SSL_ERROR_ZERO_RETURN
SSL_ERROR_WANT_READ
SSL_ERROR_WANT_WRITE
SSL_ERROR_WANT_X509_LOOKUP
SSL_ERROR_SYSCALL
SSL_ERROR_SSL
SSL_ERROR_WANT_CONNECT

SSL_ERROR_EOF
SSL_ERROR_INVALID_ERROR_CODE

The following group defines the certificate requirements that one side
is allowing/requiring from the other side:

CERT_NONE - no certificates from the other side are required (or will
            be looked at if provided)
CERT_OPTIONAL - certificates are not required, but if provided will be
                validated, and if validation fails, the connection will
                also fail
CERT_REQUIRED - certificates are required, and will be validated, and
                if validation fails, the connection will also fail

The following constants identify various SSL protocol variants:

PROTOCOL_SSLv2 (only if the underlying _ssl module provides it)
PROTOCOL_SSLv3
PROTOCOL_SSLv23
PROTOCOL_TLSv1
"""

import textwrap

import _ssl             # if we can't import it, let the error propagate

from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
from _ssl import SSLError
from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from _ssl import RAND_status, RAND_egd, RAND_add
from _ssl import (
    SSL_ERROR_ZERO_RETURN,
    SSL_ERROR_WANT_READ,
    SSL_ERROR_WANT_WRITE,
    SSL_ERROR_WANT_X509_LOOKUP,
    SSL_ERROR_SYSCALL,
    SSL_ERROR_SSL,
    SSL_ERROR_WANT_CONNECT,
    SSL_ERROR_EOF,
    SSL_ERROR_INVALID_ERROR_CODE,
)
from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1

# Human-readable names for the protocol constants, used by get_protocol_name().
_PROTOCOL_NAMES = {
    PROTOCOL_TLSv1: "TLSv1",
    PROTOCOL_SSLv23: "SSLv23",
    PROTOCOL_SSLv3: "SSLv3",
}
# SSLv2 may be absent from the _ssl build; register it only when present.
try:
    from _ssl import PROTOCOL_SSLv2
except ImportError:
    pass
else:
    _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2"

from socket import socket, _fileobject, _delegate_methods, error as socket_error
from socket import getnameinfo as _getnameinfo
import base64        # for DER-to-PEM translation
import errno


class SSLSocket(socket):

    """This class implements a subtype of socket.socket that wraps
    the underlying OS socket in an SSL context when necessary, and
    provides read and write methods over that channel."""

    def __init__(self, sock, keyfile=None, certfile=None,
                 server_side=False, cert_reqs=CERT_NONE,
                 ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                 do_handshake_on_connect=True,
                 suppress_ragged_eofs=True, ciphers=None):
        """Wrap the OS-level socket of SOCK.

        If SOCK is already connected, the SSL object is created at once
        (and the handshake run when DO_HANDSHAKE_ON_CONNECT is true);
        otherwise creation is deferred until connect()/connect_ex().
        """
        socket.__init__(self, _sock=sock._sock)
        # The initializer for socket overrides the methods send(), recv(), etc.
        # in the instance, which we don't need -- but we want to provide the
        # methods defined in SSLSocket.
        for attr in _delegate_methods:
            try:
                delattr(self, attr)
            except AttributeError:
                pass

        # A combined PEM file may carry both the key and the certificate.
        if certfile and not keyfile:
            keyfile = certfile
        # see if it's connected
        try:
            socket.getpeername(self)
        except socket_error as e:
            if e.errno != errno.ENOTCONN:
                raise
            # no, no connection yet
            self._connected = False
            self._sslobj = None
        else:
            # yes, create the SSL object
            self._connected = True
            self._sslobj = _ssl.sslwrap(self._sock, server_side,
                                        keyfile, certfile,
                                        cert_reqs, ssl_version, ca_certs,
                                        ciphers)
            if do_handshake_on_connect:
                self.do_handshake()
        self.keyfile = keyfile
        self.certfile = certfile
        self.cert_reqs = cert_reqs
        self.ssl_version = ssl_version
        self.ca_certs = ca_certs
        self.ciphers = ciphers
        self.do_handshake_on_connect = do_handshake_on_connect
        self.suppress_ragged_eofs = suppress_ragged_eofs
        # Number of live makefile() objects; close() is deferred until
        # they have all been closed.
        self._makefile_refs = 0

    def read(self, len=1024):

        """Read up to LEN bytes and return them.
        Return zero-length string on EOF."""

        try:
            return self._sslobj.read(len)
        except SSLError as x:
            # An unclean shutdown by the peer is treated as a plain EOF
            # when suppress_ragged_eofs is set.
            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                return ''
            else:
                raise

    def write(self, data):

        """Write DATA to the underlying SSL channel.  Returns
        number of bytes of DATA actually transmitted."""

        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):

        """Returns a formatted version of the data in the
        certificate provided by the other end of the SSL channel.
        Return None if no certificate was provided, {} if a
        certificate was provided, but not validated."""

        return self._sslobj.peer_certificate(binary_form)

    def cipher(self):

        """Return the cipher reported by the SSL object, or None when
        the socket is not (yet) wrapped."""

        if not self._sslobj:
            return None
        else:
            return self._sslobj.cipher()

    def send(self, data, flags=0):

        """Send DATA and return the number of bytes written.

        Over SSL, FLAGS must be 0, and 0 is returned when the SSL layer
        reports SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.  Without an
        SSL object the call is passed straight to the plain socket."""

        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to send() on %s" %
                    self.__class__)
            try:
                return self._sslobj.write(data)
            except SSLError as x:
                # The SSL layer needs more I/O before it can make progress
                # (typically on a non-blocking socket): nothing was sent.
                if x.args[0] in (SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE):
                    return 0
                raise
        else:
            return self._sock.send(data, flags)

    def sendto(self, data, flags_or_addr, addr=None):

        """Forward to the plain socket's sendto(); not allowed over SSL."""

        if self._sslobj:
            raise ValueError("sendto not allowed on instances of %s" %
                             self.__class__)
        elif addr is None:
            return self._sock.sendto(data, flags_or_addr)
        else:
            return self._sock.sendto(data, flags_or_addr, addr)

    def sendall(self, data, flags=0):

        """Send all of DATA.  Over SSL, FLAGS must be 0 and the length of
        DATA is returned; otherwise socket.sendall() is used."""

        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to sendall() on %s" %
                    self.__class__)
            amount = len(data)
            count = 0
            # send() returns 0 when the SSL layer wants more I/O, so on a
            # non-blocking socket this loop keeps retrying until all of
            # the data has been written.
            while (count < amount):
                v = self.send(data[count:])
                count += v
            return amount
        else:
            return socket.sendall(self, data, flags)

    def recv(self, buflen=1024, flags=0):

        """Receive up to BUFLEN bytes.  Over SSL, FLAGS must be 0."""

        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to recv() on %s" %
                    self.__class__)
            return self.read(buflen)
        else:
            return self._sock.recv(buflen, flags)

    def recv_into(self, buffer, nbytes=None, flags=0):

        """Receive up to NBYTES bytes into BUFFER and return the count.

        NBYTES defaults to len(BUFFER), or 1024 when BUFFER is empty.
        Over SSL, FLAGS must be 0."""

        # NOTE(review): with an empty bytearray BUFFER and no NBYTES, the
        # SSL path reads up to 1024 bytes and the slice assignment below
        # grows BUFFER -- confirm whether callers rely on that.
        if buffer and (nbytes is None):
            nbytes = len(buffer)
        elif nbytes is None:
            nbytes = 1024
        if self._sslobj:
            if flags != 0:
                raise ValueError(
                    "non-zero flags not allowed in calls to recv_into() on %s" %
                    self.__class__)
            tmp_buffer = self.read(nbytes)
            v = len(tmp_buffer)
            buffer[:v] = tmp_buffer
            return v
        else:
            return self._sock.recv_into(buffer, nbytes, flags)

    def recvfrom(self, buflen=1024, flags=0):

        """Forward to the plain socket's recvfrom(); not allowed over SSL."""

        if self._sslobj:
            raise ValueError("recvfrom not allowed on instances of %s" %
                             self.__class__)
        else:
            return self._sock.recvfrom(buflen, flags)

    def recvfrom_into(self, buffer, nbytes=None, flags=0):

        """Forward to the plain socket's recvfrom_into(); not allowed over
        SSL.  When NBYTES is None the whole of BUFFER may be filled."""

        if self._sslobj:
            raise ValueError("recvfrom_into not allowed on instances of %s" %
                             self.__class__)
        else:
            # The underlying call takes an integer byte count (0 meaning
            # "the size of the buffer"); passing None would be rejected.
            if nbytes is None:
                nbytes = 0
            return self._sock.recvfrom_into(buffer, nbytes, flags)

    def pending(self):

        """Return the number of decrypted bytes already buffered by the
        SSL object, or 0 when the socket is not wrapped."""

        if self._sslobj:
            return self._sslobj.pending()
        else:
            return 0

    def unwrap(self):

        """Shut down the SSL layer, drop the SSL object and return
        whatever its shutdown() returns.

        Raises ValueError when there is no SSL wrapper."""

        if self._sslobj:
            s = self._sslobj.shutdown()
            self._sslobj = None
            return s
        else:
            raise ValueError("No SSL wrapper around " + str(self))

    def shutdown(self, how):

        """Drop the SSL object, then shut down the plain socket."""

        self._sslobj = None
        socket.shutdown(self, how)

    def close(self):

        """Close the socket, unless makefile() objects still refer to it;
        in that case just drop one of those references."""

        if self._makefile_refs < 1:
            self._sslobj = None
            socket.close(self)
        else:
            self._makefile_refs -= 1

    def do_handshake(self):

        """Perform a TLS/SSL handshake."""

        self._sslobj.do_handshake()

    def _real_connect(self, addr, return_errno):
        """Connect to ADDR and wrap the connection in an SSL channel.

        With RETURN_ERRNO true (connect_ex), failures are reported as an
        errno value and 0 means success; otherwise they are raised."""
        # Here we assume that the socket is client-side, and not
        # connected at the time of the call.  We connect it, then wrap it.
        if self._connected:
            raise ValueError("attempt to connect already-connected SSLSocket!")
        self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile,
                                    self.cert_reqs, self.ssl_version,
                                    self.ca_certs, self.ciphers)
        try:
            if return_errno:
                # socket.connect_ex() reports failures (timeouts included)
                # as an errno value; catching socket.connect()'s exception
                # instead would yield None for a timeout.
                rc = socket.connect_ex(self, addr)
                if rc:
                    return rc
            else:
                socket.connect(self, addr)
            if self.do_handshake_on_connect:
                self.do_handshake()
        except socket_error as e:
            if return_errno:
                return e.errno
            self._sslobj = None
            # A bare raise keeps the original traceback.
            raise
        self._connected = True
        return 0

    def connect(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel."""
        self._real_connect(addr, False)

    def connect_ex(self, addr):
        """Connects to remote ADDR, and then wraps the connection in
        an SSL channel.  Returns 0 on success or an errno value on
        failure instead of raising."""
        return self._real_connect(addr, True)

    def accept(self):

        """Accepts a new connection from a remote client, and returns
        a tuple containing that new connection wrapped with a server-side
        SSL channel, and the address of the remote client."""

        newsock, addr = socket.accept(self)
        return (SSLSocket(newsock,
                          keyfile=self.keyfile,
                          certfile=self.certfile,
                          server_side=True,
                          cert_reqs=self.cert_reqs,
                          ssl_version=self.ssl_version,
                          ca_certs=self.ca_certs,
                          ciphers=self.ciphers,
                          do_handshake_on_connect=self.do_handshake_on_connect,
                          suppress_ragged_eofs=self.suppress_ragged_eofs),
                addr)

    def makefile(self, mode='r', bufsize=-1):

        """Make and return a file-like object that
        works with the SSL connection.  Just use the code
        from the socket module."""

        self._makefile_refs += 1
        # close=True so as to decrement the reference count when done with
        # the file-like object.
        return _fileobject(self, mode, bufsize, close=True)



def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_SSLv23, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True, ciphers=None):

    """Return an SSLSocket wrapping SOCK; all other arguments are passed
    through to the SSLSocket constructor unchanged."""

    return SSLSocket(sock, keyfile=keyfile, certfile=certfile,
                     server_side=server_side, cert_reqs=cert_reqs,
                     ssl_version=ssl_version, ca_certs=ca_certs,
                     do_handshake_on_connect=do_handshake_on_connect,
                     suppress_ragged_eofs=suppress_ragged_eofs,
                     ciphers=ciphers)


# some utility functions

def cert_time_to_seconds(cert_time):

    """Takes a date-time string in standard ASN1_print form
    ("MON DAY 24HOUR:MINUTE:SEC YEAR GMT") and return
    a Python time value in seconds past the epoch.

    The string is interpreted as UTC, as its GMT suffix states."""

    import calendar
    import time
    # The parsed fields are UTC, so convert them with calendar.timegm();
    # time.mktime() would treat them as local time and be off by the local
    # UTC offset.  float() keeps the float result of the old mktime()
    # version.  Note that strptime's %b uses the locale's month names.
    return float(calendar.timegm(
        time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT")))

PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"

def DER_cert_to_PEM_cert(der_cert_bytes):

    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string."""

    if hasattr(base64, 'standard_b64encode'):
        # preferred because older API gets line-length wrong
        f = base64.standard_b64encode(der_cert_bytes)
        return (PEM_HEADER + '\n' +
                textwrap.fill(f, 64) + '\n' +
                PEM_FOOTER + '\n')
    else:
        return (PEM_HEADER + '\n' +
                base64.encodestring(der_cert_bytes) +
                PEM_FOOTER + '\n')

def PEM_cert_to_DER_cert(pem_cert_string):

    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    Surrounding whitespace is ignored.  Raises ValueError when the
    PEM header or footer is missing."""

    # Strip once and validate/slice the same string, so leading whitespace
    # is tolerated just like trailing whitespace.
    stripped = pem_cert_string.strip()
    if not stripped.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not stripped.endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = stripped[len(PEM_HEADER):-len(PEM_FOOTER)]
    return base64.decodestring(d)

def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None):

    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt."""

    if ca_certs is not None:
        cert_reqs = CERT_REQUIRED
    else:
        cert_reqs = CERT_NONE
    s = wrap_socket(socket(), ssl_version=ssl_version,
                    cert_reqs=cert_reqs, ca_certs=ca_certs)
    try:
        s.connect(addr)
        dercert = s.getpeercert(True)
    finally:
        # Release the socket even when connecting or the handshake fails.
        s.close()
    return DER_cert_to_PEM_cert(dercert)

def get_protocol_name(protocol_code):
    """Return the name of a PROTOCOL_* constant, or '<unknown>'."""
    return _PROTOCOL_NAMES.get(protocol_code, '<unknown>')


# a replacement for the old socket.ssl function

def sslwrap_simple(sock, keyfile=None, certfile=None):

    """A replacement for the old socket.ssl function.  Designed
    for compatibility with Python 2.5 and earlier.  Will disappear in
    Python 3.0."""

    if hasattr(sock, "_sock"):
        sock = sock._sock

    ssl_sock = _ssl.sslwrap(sock, 0, keyfile, certfile, CERT_NONE,
                            PROTOCOL_SSLv23, None)
    try:
        sock.getpeername()
    except socket_error:
        # no, no connection yet
        pass
    else:
        # yes, do the handshake
        ssl_sock.do_handshake()

    return ssl_sock