Home | History | Annotate | Download | only in conscrypt
      1 /*
      2  * Copyright (C) 2009 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package org.conscrypt;
     18 
     19 import java.io.ByteArrayInputStream;
     20 import java.io.ByteArrayOutputStream;
     21 import java.io.DataInputStream;
     22 import java.io.DataOutputStream;
     23 import java.io.IOException;
     24 import java.security.cert.Certificate;
     25 import java.security.cert.CertificateEncodingException;
     26 import java.security.cert.X509Certificate;
     27 import java.util.Arrays;
     28 import java.util.Enumeration;
     29 import java.util.Iterator;
     30 import java.util.LinkedHashMap;
     31 import java.util.Map;
     32 import java.util.NoSuchElementException;
     33 import javax.net.ssl.SSLSession;
     34 import javax.net.ssl.SSLSessionContext;
     35 
     36 /**
     37  * Supports SSL session caches.
     38  */
     39 abstract class AbstractSessionContext implements SSLSessionContext {
     40 
     41     /**
     42      * Maximum lifetime of a session (in seconds) after which it's considered invalid and should not
     43      * be used to for new connections.
     44      */
     45     private static final int DEFAULT_SESSION_TIMEOUT_SECONDS = 8 * 60 * 60;
     46 
     47     volatile int maximumSize;
     48     volatile int timeout = DEFAULT_SESSION_TIMEOUT_SECONDS;
     49 
     50     final long sslCtxNativePointer = NativeCrypto.SSL_CTX_new();
     51 
     52     /** Identifies OpenSSL sessions. */
     53     static final int OPEN_SSL = 1;
     54 
     55     private final Map<ByteArray, SSLSession> sessions
     56             = new LinkedHashMap<ByteArray, SSLSession>() {
     57         @Override
     58         protected boolean removeEldestEntry(
     59                 Map.Entry<ByteArray, SSLSession> eldest) {
     60             boolean remove = maximumSize > 0 && size() > maximumSize;
     61             if (remove) {
     62                 remove(eldest.getKey());
     63                 sessionRemoved(eldest.getValue());
     64             }
     65             return false;
     66         }
     67     };
     68 
     69     /**
     70      * Constructs a new session context.
     71      *
     72      * @param maximumSize of cache
     73      */
     74     AbstractSessionContext(int maximumSize) {
     75         this.maximumSize = maximumSize;
     76     }
     77 
     78     /**
     79      * Returns the collection of sessions ordered from oldest to newest
     80      */
     81     private Iterator<SSLSession> sessionIterator() {
     82         synchronized (sessions) {
     83             SSLSession[] array = sessions.values().toArray(
     84                     new SSLSession[sessions.size()]);
     85             return Arrays.asList(array).iterator();
     86         }
     87     }
     88 
     89     @Override
     90     public final Enumeration<byte[]> getIds() {
     91         final Iterator<SSLSession> i = sessionIterator();
     92         return new Enumeration<byte[]>() {
     93             private SSLSession next;
     94 
     95             @Override
     96             public boolean hasMoreElements() {
     97                 if (next != null) {
     98                     return true;
     99                 }
    100                 while (i.hasNext()) {
    101                     SSLSession session = i.next();
    102                     if (session.isValid()) {
    103                         next = session;
    104                         return true;
    105                     }
    106                 }
    107                 next = null;
    108                 return false;
    109             }
    110 
    111             @Override
    112             public byte[] nextElement() {
    113                 if (hasMoreElements()) {
    114                     byte[] id = next.getId();
    115                     next = null;
    116                     return id;
    117                 }
    118                 throw new NoSuchElementException();
    119             }
    120         };
    121     }
    122 
    123     @Override
    124     public final int getSessionCacheSize() {
    125         return maximumSize;
    126     }
    127 
    128     @Override
    129     public final int getSessionTimeout() {
    130         return timeout;
    131     }
    132 
    133     /**
    134      * Makes sure cache size is < maximumSize.
    135      */
    136     protected void trimToSize() {
    137         synchronized (sessions) {
    138             int size = sessions.size();
    139             if (size > maximumSize) {
    140                 int removals = size - maximumSize;
    141                 Iterator<SSLSession> i = sessions.values().iterator();
    142                 do {
    143                     SSLSession session = i.next();
    144                     i.remove();
    145                     sessionRemoved(session);
    146                 } while (--removals > 0);
    147             }
    148         }
    149     }
    150 
    151     @Override
    152     public void setSessionTimeout(int seconds)
    153             throws IllegalArgumentException {
    154         if (seconds < 0) {
    155             throw new IllegalArgumentException("seconds < 0");
    156         }
    157         timeout = seconds;
    158 
    159         synchronized (sessions) {
    160             Iterator<SSLSession> i = sessions.values().iterator();
    161             while (i.hasNext()) {
    162                 SSLSession session = i.next();
    163                 // SSLSession's know their context and consult the
    164                 // timeout as part of their validity condition.
    165                 if (!session.isValid()) {
    166                     i.remove();
    167                     sessionRemoved(session);
    168                 }
    169             }
    170         }
    171     }
    172 
    173     /**
    174      * Called when a session is removed. Used by ClientSessionContext
    175      * to update its host-and-port based cache.
    176      */
    177     protected abstract void sessionRemoved(SSLSession session);
    178 
    179     @Override
    180     public final void setSessionCacheSize(int size)
    181             throws IllegalArgumentException {
    182         if (size < 0) {
    183             throw new IllegalArgumentException("size < 0");
    184         }
    185 
    186         int oldMaximum = maximumSize;
    187         maximumSize = size;
    188 
    189         // Trim cache to size if necessary.
    190         if (size < oldMaximum) {
    191             trimToSize();
    192         }
    193     }
    194 
    195     /**
    196      * Converts the given session to bytes.
    197      *
    198      * @return session data as bytes or null if the session can't be converted
    199      */
    200     byte[] toBytes(SSLSession session) {
    201         // TODO: Support SSLSessionImpl, too.
    202         if (!(session instanceof OpenSSLSessionImpl)) {
    203             return null;
    204         }
    205 
    206         OpenSSLSessionImpl sslSession = (OpenSSLSessionImpl) session;
    207         try {
    208             ByteArrayOutputStream baos = new ByteArrayOutputStream();
    209             DataOutputStream daos = new DataOutputStream(baos);
    210 
    211             daos.writeInt(OPEN_SSL); // session type ID
    212 
    213             // Session data.
    214             byte[] data = sslSession.getEncoded();
    215             daos.writeInt(data.length);
    216             daos.write(data);
    217 
    218             // Certificates.
    219             Certificate[] certs = session.getPeerCertificates();
    220             daos.writeInt(certs.length);
    221 
    222             for (Certificate cert : certs) {
    223                 data = cert.getEncoded();
    224                 daos.writeInt(data.length);
    225                 daos.write(data);
    226             }
    227             // TODO: local certificates?
    228 
    229             return baos.toByteArray();
    230         } catch (IOException e) {
    231             log(e);
    232             return null;
    233         } catch (CertificateEncodingException e) {
    234             log(e);
    235             return null;
    236         }
    237     }
    238 
    239     /**
    240      * Creates a session from the given bytes.
    241      *
    242      * @return a session or null if the session can't be converted
    243      */
    244     SSLSession toSession(byte[] data, String host, int port) {
    245         ByteArrayInputStream bais = new ByteArrayInputStream(data);
    246         DataInputStream dais = new DataInputStream(bais);
    247         try {
    248             int type = dais.readInt();
    249             if (type != OPEN_SSL) {
    250                 log(new AssertionError("Unexpected type ID: " + type));
    251                 return null;
    252             }
    253 
    254             int length = dais.readInt();
    255             byte[] sessionData = new byte[length];
    256             dais.readFully(sessionData);
    257 
    258             int count = dais.readInt();
    259             X509Certificate[] certs = new X509Certificate[count];
    260             for (int i = 0; i < count; i++) {
    261                 length = dais.readInt();
    262                 byte[] certData = new byte[length];
    263                 dais.readFully(certData);
    264                 certs[i] = OpenSSLX509Certificate.fromX509Der(certData);
    265             }
    266 
    267             return new OpenSSLSessionImpl(sessionData, host, port, certs, this);
    268         } catch (IOException e) {
    269             log(e);
    270             return null;
    271         }
    272     }
    273 
    274     @Override
    275     public SSLSession getSession(byte[] sessionId) {
    276         if (sessionId == null) {
    277             throw new NullPointerException("sessionId == null");
    278         }
    279         ByteArray key = new ByteArray(sessionId);
    280         SSLSession session;
    281         synchronized (sessions) {
    282             session = sessions.get(key);
    283         }
    284         if (session != null && session.isValid()) {
    285             return session;
    286         }
    287         return null;
    288     }
    289 
    290     void putSession(SSLSession session) {
    291         byte[] id = session.getId();
    292         if (id.length == 0) {
    293             return;
    294         }
    295         ByteArray key = new ByteArray(id);
    296         synchronized (sessions) {
    297             sessions.put(key, session);
    298         }
    299     }
    300 
    301     static void log(Throwable t) {
    302         new Exception("Error converting session", t).printStackTrace();
    303     }
    304 
    305     @Override
    306     protected void finalize() throws Throwable {
    307         try {
    308             NativeCrypto.SSL_CTX_free(sslCtxNativePointer);
    309         } finally {
    310             super.finalize();
    311         }
    312     }
    313 }
    314