Home | History | Annotate | Download | only in ssl
      1 /*
      2  * Copyright (C) 2010 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 package org.conscrypt.javax.net.ssl;
     17 
     18 import static org.junit.Assert.assertTrue;
     19 
     20 import java.io.ByteArrayInputStream;
     21 import java.io.ByteArrayOutputStream;
     22 import java.io.IOException;
     23 import java.io.ObjectInputStream;
     24 import java.io.ObjectOutput;
     25 import java.io.ObjectOutputStream;
     26 import java.io.OutputStream;
     27 import java.net.InetAddress;
     28 import java.net.InetSocketAddress;
     29 import java.net.Socket;
     30 import java.security.KeyStore;
     31 import java.security.Principal;
     32 import java.security.SecureRandom;
     33 import java.security.cert.Certificate;
     34 import java.security.cert.CertificateException;
     35 import java.security.cert.X509Certificate;
     36 import java.util.Collections;
     37 import javax.net.ssl.KeyManager;
     38 import javax.net.ssl.SSLContext;
     39 import javax.net.ssl.SSLServerSocket;
     40 import javax.net.ssl.SSLSocket;
     41 import javax.net.ssl.SSLSocketFactory;
     42 import javax.net.ssl.TrustManager;
     43 import javax.net.ssl.X509TrustManager;
     44 import org.conscrypt.TestUtils;
     45 import org.conscrypt.java.security.TestKeyStore;
     46 
     47 /**
     48  * TestSSLContext is a convenience class for other tests that
     49  * want a canned SSLContext and related state for testing so they
     50  * don't have to duplicate the logic.
     51  */
     52 public final class TestSSLContext {
     53     /**
     54      * The Android SSLSocket and SSLServerSocket implementations are
     55      * based on a version of OpenSSL which includes support for RFC
     56      * 4507 session tickets. When using session tickets, the server
     57      * does not need to keep a cache mapping session IDs to SSL
     58      * sessions for reuse. Instead, the client presents the server
     59      * with a session ticket it received from the server earlier,
     60      * which is an SSL session encrypted by the server's secret
     61      * key. Since in this case the server does not need to keep a
     62      * cache, some tests may find different results depending on
     63      * whether or not the session tickets are in use. These tests can
     64      * use this function to determine if loopback SSL connections are
     65      * expected to use session tickets and conditionalize their
     66      * results appropriately.
     67      */
     68     public static boolean sslServerSocketSupportsSessionTickets() {
     69         // Disabled session tickets for better compatability b/2682876
     70         // return !IS_RI;
     71         return false;
     72     }
     73     public final KeyStore clientKeyStore;
     74     public final char[] clientStorePassword;
     75     public final KeyStore serverKeyStore;
     76     public final char[] serverStorePassword;
     77     public final KeyManager[] clientKeyManagers;
     78     public final KeyManager[] serverKeyManagers;
     79     public final X509TrustManager clientTrustManager;
     80     public final X509TrustManager serverTrustManager;
     81     public final SSLContext clientContext;
     82     public final SSLContext serverContext;
     83     public final SSLServerSocket serverSocket;
     84     public final InetAddress host;
     85     public final int port;
     86     /**
     87      * Used for replacing the hostname in an InetSocketAddress object during
     88      * serialization.
     89      */
     90     private static class HostnameRewritingObjectOutputStream extends ObjectOutputStream {
     91         private final String hostname;
     92         public HostnameRewritingObjectOutputStream(OutputStream out, String hostname)
     93                 throws IOException {
     94             super(out);
     95             this.hostname = hostname;
     96         }
     97         @Override
     98         public PutField putFields() throws IOException {
     99             return new PutFieldProxy(super.putFields(), hostname);
    100         }
    101         private static class PutFieldProxy extends ObjectOutputStream.PutField {
    102             private final PutField delegate;
    103             private final String hostname;
    104             public PutFieldProxy(ObjectOutputStream.PutField delegate, String hostname) {
    105                 this.delegate = delegate;
    106                 this.hostname = hostname;
    107             }
    108             @Override
    109             public void put(String name, boolean val) {
    110                 delegate.put(name, val);
    111             }
    112             @Override
    113             public void put(String name, byte val) {
    114                 delegate.put(name, val);
    115             }
    116             @Override
    117             public void put(String name, char val) {
    118                 delegate.put(name, val);
    119             }
    120             @Override
    121             public void put(String name, short val) {
    122                 delegate.put(name, val);
    123             }
    124             @Override
    125             public void put(String name, int val) {
    126                 delegate.put(name, val);
    127             }
    128             @Override
    129             public void put(String name, long val) {
    130                 delegate.put(name, val);
    131             }
    132             @Override
    133             public void put(String name, float val) {
    134                 delegate.put(name, val);
    135             }
    136             @Override
    137             public void put(String name, double val) {
    138                 delegate.put(name, val);
    139             }
    140             @Override
    141             public void put(String name, Object val) {
    142                 if ("hostname".equals(name)) {
    143                     delegate.put(name, hostname);
    144                 } else {
    145                     delegate.put(name, val);
    146                 }
    147             }
    148             @SuppressWarnings("deprecation")
    149             @Override
    150             public void write(ObjectOutput out) throws IOException {
    151                 delegate.write(out);
    152             }
    153         }
    154     }
    155     /**
    156      * Creates an InetSocketAddress where the hostname points to an arbitrary
    157      * hostname, but the address points to the loopback address. Useful for
    158      * testing SNI where both "localhost" and IP addresses are not allowed.
    159      */
    160     public InetSocketAddress getLoopbackAsHostname(String hostname, int port)
    161             throws IOException, ClassNotFoundException {
    162         InetSocketAddress addr = new InetSocketAddress(TestUtils.getLoopbackAddress(), port);
    163         ByteArrayOutputStream baos = new ByteArrayOutputStream();
    164         HostnameRewritingObjectOutputStream oos =
    165                 new HostnameRewritingObjectOutputStream(baos, hostname);
    166         oos.writeObject(addr);
    167         oos.close();
    168         ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()));
    169         return (InetSocketAddress) ois.readObject();
    170     }
    171     private TestSSLContext(KeyStore clientKeyStore, char[] clientStorePassword,
    172             KeyStore serverKeyStore, char[] serverStorePassword, KeyManager[] clientKeyManagers,
    173             KeyManager[] serverKeyManagers, X509TrustManager clientTrustManager,
    174             X509TrustManager serverTrustManager, SSLContext clientContext,
    175             SSLContext serverContext, SSLServerSocket serverSocket, InetAddress host, int port) {
    176         this.clientKeyStore = clientKeyStore;
    177         this.clientStorePassword = clientStorePassword;
    178         this.serverKeyStore = serverKeyStore;
    179         this.serverStorePassword = serverStorePassword;
    180         this.clientKeyManagers = clientKeyManagers;
    181         this.serverKeyManagers = serverKeyManagers;
    182         this.clientTrustManager = clientTrustManager;
    183         this.serverTrustManager = serverTrustManager;
    184         this.clientContext = clientContext;
    185         this.serverContext = serverContext;
    186         this.serverSocket = serverSocket;
    187         this.host = host;
    188         this.port = port;
    189     }
    190     public void close() {
    191         try {
    192             serverSocket.close();
    193         } catch (Exception e) {
    194             throw new RuntimeException(e);
    195         }
    196     }
    197 
    198     public static Builder newBuilder() {
    199         return new Builder();
    200     }
    201 
    202     public static final class Builder {
    203         private TestKeyStore client;
    204         private char[] clientStorePassword;
    205         private TestKeyStore server;
    206         private char[] serverStorePassword;
    207         private KeyManager[] additionalClientKeyManagers;
    208         private KeyManager[] additionalServerKeyManagers;
    209         private TrustManager clientTrustManager;
    210         private TrustManager serverTrustManager;
    211         private SSLContext clientContext;
    212         private SSLContext serverContext;
    213         private int serverReceiveBufferSize;
    214         private boolean useDefaults = true;
    215 
    216         public Builder useDefaults(boolean useDefaults) {
    217             this.useDefaults = useDefaults;
    218             return this;
    219         }
    220 
    221         public Builder client(TestKeyStore client) {
    222             this.client = client;
    223             return this;
    224         }
    225 
    226         public Builder clientStorePassword(char[] clientStorePassword) {
    227             this.clientStorePassword = clientStorePassword;
    228             return this;
    229         }
    230 
    231         public Builder server(TestKeyStore server) {
    232             this.server = server;
    233             return this;
    234         }
    235 
    236         public Builder serverStorePassword(char[] serverStorePassword) {
    237             this.serverStorePassword = serverStorePassword;
    238             return this;
    239         }
    240 
    241         public Builder additionalClientKeyManagers(KeyManager[] additionalClientKeyManagers) {
    242             this.additionalClientKeyManagers = additionalClientKeyManagers;
    243             return this;
    244         }
    245 
    246         public Builder additionalServerKeyManagers(KeyManager[] additionalServerKeyManagers) {
    247             this.additionalServerKeyManagers = additionalServerKeyManagers;
    248             return this;
    249         }
    250 
    251         public Builder clientTrustManager(TrustManager clientTrustManager) {
    252             this.clientTrustManager = clientTrustManager;
    253             return this;
    254         }
    255 
    256         public Builder serverTrustManager(TrustManager serverTrustManager) {
    257             this.serverTrustManager = serverTrustManager;
    258             return this;
    259         }
    260 
    261         public Builder clientContext(SSLContext clientContext) {
    262             this.clientContext = clientContext;
    263             return this;
    264         }
    265 
    266         public Builder serverContext(SSLContext serverContext) {
    267             this.serverContext = serverContext;
    268             return this;
    269         }
    270 
    271         public Builder serverReceiveBufferSize(int serverReceiveBufferSize) {
    272             this.serverReceiveBufferSize = serverReceiveBufferSize;
    273             return this;
    274         }
    275 
    276         TestSSLContext build() {
    277             // Get the current values for all the things.
    278             TestKeyStore client = this.client;
    279             TestKeyStore server = this.server;
    280             char[] clientStorePassword = this.clientStorePassword;
    281             char[] serverStorePassword = this.serverStorePassword;
    282             KeyManager[] clientKeyManagers = client != null ? client.keyManagers : null;
    283             KeyManager[] serverKeyManagers = server != null ? server.keyManagers : null;
    284             TrustManager clientTrustManager = this.clientTrustManager;
    285             TrustManager serverTrustManager = this.serverTrustManager;
    286             SSLContext clientContext = this.clientContext;
    287             SSLContext serverContext = this.serverContext;
    288 
    289             // Apply default values if configured to do so.
    290             if (useDefaults) {
    291                 client = client != null ? client : TestKeyStore.getClient();
    292                 server = server != null ? server : TestKeyStore.getServer();
    293                 clientStorePassword =
    294                         clientStorePassword != null ? clientStorePassword : client.storePassword;
    295                 serverStorePassword =
    296                         serverStorePassword != null ? serverStorePassword : server.storePassword;
    297                 clientKeyManagers =
    298                         clientKeyManagers != null ? clientKeyManagers : client.keyManagers;
    299                 serverKeyManagers =
    300                         serverKeyManagers != null ? serverKeyManagers : server.keyManagers;
    301                 clientKeyManagers = concat(clientKeyManagers, additionalClientKeyManagers);
    302                 serverKeyManagers = concat(serverKeyManagers, additionalServerKeyManagers);
    303                 clientTrustManager =
    304                         clientTrustManager != null ? clientTrustManager : client.trustManagers[0];
    305                 serverTrustManager =
    306                         serverTrustManager != null ? serverTrustManager : server.trustManagers[0];
    307 
    308                 String protocol = "TLSv1.2";
    309                 clientContext = clientContext != null
    310                         ? clientContext
    311                         : createSSLContext(protocol, clientKeyManagers,
    312                                   new TrustManager[] {clientTrustManager});
    313                 serverContext = serverContext != null
    314                         ? serverContext
    315                         : createSSLContext(protocol, serverKeyManagers,
    316                                   new TrustManager[] {serverTrustManager});
    317             }
    318 
    319             // Create the context.
    320             try {
    321                 SSLServerSocket serverSocket =
    322                         (SSLServerSocket) serverContext.getServerSocketFactory()
    323                                 .createServerSocket();
    324                 if (serverReceiveBufferSize > 0) {
    325                     // The TCP spec says that this should occur before listen.
    326                     serverSocket.setReceiveBufferSize(serverReceiveBufferSize);
    327                 }
    328                 InetAddress host = TestUtils.getLoopbackAddress();
    329                 serverSocket.bind(new InetSocketAddress(host, 0));
    330                 int port = serverSocket.getLocalPort();
    331                 return new TestSSLContext(client != null ? client.keyStore : null,
    332                         clientStorePassword, server != null ? server.keyStore : null,
    333                         serverStorePassword, clientKeyManagers, serverKeyManagers,
    334                         (X509TrustManager) clientTrustManager,
    335                         (X509TrustManager) serverTrustManager, clientContext, serverContext,
    336                         serverSocket, host, port);
    337             } catch (RuntimeException e) {
    338                 throw e;
    339             } catch (Exception e) {
    340                 throw new RuntimeException(e);
    341             }
    342         }
    343     }
    344 
    345     /**
    346      * Usual TestSSLContext creation method, creates underlying
    347      * SSLContext with certificate and key as well as SSLServerSocket
    348      * listening provided host and port.
    349      */
    350     public static TestSSLContext create() {
    351         return new Builder().build();
    352     }
    353 
    354     /**
    355      * TestSSLContext creation method that allows separate creation of server key store
    356      */
    357     public static TestSSLContext create(TestKeyStore client, TestKeyStore server) {
    358         return new Builder().client(client).server(server).build();
    359     }
    360     /**
    361      * Create a SSLContext with a KeyManager using the private key and
    362      * certificate chain from the given KeyStore and a TrustManager
    363      * using the certificates authorities from the same KeyStore.
    364      */
    365     public static SSLContext createSSLContext(final String protocol, final KeyManager[] keyManagers,
    366             final TrustManager[] trustManagers) {
    367         try {
    368             SSLContext context = SSLContext.getInstance(protocol);
    369             context.init(keyManagers, trustManagers, new SecureRandom());
    370             return context;
    371         } catch (Exception e) {
    372             throw new RuntimeException(e);
    373         }
    374     }
    375     public static void assertCertificateInKeyStore(Principal principal, KeyStore keyStore)
    376             throws Exception {
    377         String subjectName = principal.getName();
    378         boolean found = false;
    379         for (String alias : Collections.list(keyStore.aliases())) {
    380             if (!keyStore.isCertificateEntry(alias)) {
    381                 continue;
    382             }
    383             X509Certificate keyStoreCertificate = (X509Certificate) keyStore.getCertificate(alias);
    384             if (subjectName.equals(keyStoreCertificate.getSubjectDN().getName())) {
    385                 found = true;
    386                 break;
    387             }
    388         }
    389         assertTrue(found);
    390     }
    391     public static void assertCertificateInKeyStore(Certificate certificate, KeyStore keyStore)
    392             throws Exception {
    393         boolean found = false;
    394         for (String alias : Collections.list(keyStore.aliases())) {
    395             if (!keyStore.isCertificateEntry(alias)) {
    396                 continue;
    397             }
    398             Certificate keyStoreCertificate = keyStore.getCertificate(alias);
    399             if (certificate.equals(keyStoreCertificate)) {
    400                 found = true;
    401                 break;
    402             }
    403         }
    404         assertTrue(found);
    405     }
    406     public static void assertServerCertificateChain(
    407             X509TrustManager trustManager, Certificate[] serverChain) throws CertificateException {
    408         X509Certificate[] chain = (X509Certificate[]) serverChain;
    409         trustManager.checkServerTrusted(chain, chain[0].getPublicKey().getAlgorithm());
    410     }
    411     public static void assertClientCertificateChain(
    412             X509TrustManager trustManager, Certificate[] clientChain) throws CertificateException {
    413         X509Certificate[] chain = (X509Certificate[]) clientChain;
    414         trustManager.checkClientTrusted(chain, chain[0].getPublicKey().getAlgorithm());
    415     }
    416     /**
    417      * Returns an SSLSocketFactory that calls setWantClientAuth and
    418      * setNeedClientAuth as specified on all returned sockets.
    419      */
    420     public static SSLSocketFactory clientAuth(
    421             final SSLSocketFactory sf, final boolean want, final boolean need) {
    422         return new SSLSocketFactory() {
    423             private SSLSocket set(Socket socket) {
    424                 SSLSocket s = (SSLSocket) socket;
    425                 s.setWantClientAuth(want);
    426                 s.setNeedClientAuth(need);
    427                 return s;
    428             }
    429             @Override
    430             public Socket createSocket(String host, int port) throws IOException {
    431                 return set(sf.createSocket(host, port));
    432             }
    433             @Override
    434             public Socket createSocket(String host, int port, InetAddress localHost, int localPort)
    435                     throws IOException {
    436                 return set(sf.createSocket(host, port, localHost, localPort));
    437             }
    438             @Override
    439             public Socket createSocket(InetAddress host, int port) throws IOException {
    440                 return set(sf.createSocket(host, port));
    441             }
    442             @Override
    443             public Socket createSocket(InetAddress address, int port, InetAddress localAddress,
    444                     int localPort) throws IOException {
    445                 return set(sf.createSocket(address, port));
    446             }
    447             @Override
    448             public String[] getDefaultCipherSuites() {
    449                 return sf.getDefaultCipherSuites();
    450             }
    451             @Override
    452             public String[] getSupportedCipherSuites() {
    453                 return sf.getSupportedCipherSuites();
    454             }
    455             @Override
    456             public Socket createSocket(Socket s, String host, int port, boolean autoClose)
    457                     throws IOException {
    458                 return set(sf.createSocket(s, host, port, autoClose));
    459             }
    460         };
    461     }
    462     private static KeyManager[] concat(KeyManager[] a, KeyManager[] b) {
    463         if ((a == null) || (a.length == 0)) {
    464             return b;
    465         }
    466         if ((b == null) || (b.length == 0)) {
    467             return a;
    468         }
    469         KeyManager[] result = new KeyManager[a.length + b.length];
    470         System.arraycopy(a, 0, result, 0, a.length);
    471         System.arraycopy(b, 0, result, a.length, b.length);
    472         return result;
    473     }
    474 }
    475