Home | History | Annotate | Download | only in conscrypt
      1 /*
      2  * Copyright (C) 2015 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package org.conscrypt;
     18 
     19 import static org.junit.Assert.assertEquals;
     20 import static org.junit.Assert.assertFalse;
     21 
     22 import java.io.FileNotFoundException;
     23 import java.io.IOException;
     24 import java.io.InputStream;
     25 import java.lang.reflect.Method;
     26 import java.net.ServerSocket;
     27 import java.nio.ByteBuffer;
     28 import java.nio.charset.Charset;
     29 import java.security.NoSuchAlgorithmException;
     30 import java.security.Provider;
     31 import java.security.Security;
     32 import javax.net.ssl.SSLContext;
     33 import javax.net.ssl.SSLEngine;
     34 import javax.net.ssl.SSLEngineResult;
     35 import javax.net.ssl.SSLException;
     36 import javax.net.ssl.SSLServerSocketFactory;
     37 import javax.net.ssl.SSLSocketFactory;
     38 import libcore.io.Streams;
     39 import libcore.java.security.TestKeyStore;
     40 
     41 /**
     42  * Utility methods to support testing.
     43  */
     44 public final class TestUtils {
     45     static final Charset UTF_8 = Charset.forName("UTF-8");
     46 
     47     private static final Provider JDK_PROVIDER = getDefaultTlsProvider();
     48     private static final byte[] CHARS =
     49             "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".getBytes(UTF_8);
     50     private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0);
     51 
     52     public static final String PROTOCOL_TLS_V1_2 = "TLSv1.2";
     53     public static final String PROVIDER_PROPERTY = "SSLContext.TLSv1.2";
     54     public static final String LOCALHOST = "localhost";
     55 
     56     static final String TEST_CIPHER = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
     57 
     58     private TestUtils() {}
     59 
     60     private static Provider getDefaultTlsProvider() {
     61         for (Provider p : Security.getProviders()) {
     62             if (p.get(PROVIDER_PROPERTY) != null) {
     63                 return p;
     64             }
     65         }
     66         throw new RuntimeException("Unable to find a default provider for " + PROVIDER_PROPERTY);
     67     }
     68 
     69     static Provider getJdkProvider() {
     70         return JDK_PROVIDER;
     71     }
     72 
     73     public static Provider getConscryptProvider() {
     74         try {
     75             return (Provider) conscryptClass("OpenSSLProvider")
     76                 .getConstructor()
     77                 .newInstance();
     78         } catch (Exception e) {
     79             throw new RuntimeException(e);
     80         }
     81     }
     82 
     83     public static void installConscryptAsDefaultProvider() {
     84         final Provider conscryptProvider = getConscryptProvider();
     85         synchronized (getConscryptProvider()) {
     86             Provider[] providers = Security.getProviders();
     87             if (providers.length == 0 || !providers[0].equals(conscryptProvider)) {
     88                 Security.insertProviderAt(conscryptProvider, 1);
     89                 return;
     90             }
     91         }
     92     }
     93 
     94     public static InputStream openTestFile(String name) throws FileNotFoundException {
     95         InputStream is = TestUtils.class.getResourceAsStream("/" + name);
     96         if (is == null) {
     97             throw new FileNotFoundException(name);
     98         }
     99         return is;
    100     }
    101 
    102     public static byte[] readTestFile(String name) throws IOException {
    103         return Streams.readFully(openTestFile(name));
    104     }
    105 
    106     /**
    107      * Looks up the conscrypt class for the given simple name (i.e. no package prefix).
    108      */
    109     public static Class<?> conscryptClass(String simpleName) throws ClassNotFoundException {
    110         ClassNotFoundException ex = null;
    111         for (String packageName : new String[]{"com.android.org.conscrypt", "org.conscrypt"}) {
    112             String name = packageName + "." + simpleName;
    113             try {
    114                 return Class.forName(name);
    115             } catch (ClassNotFoundException e) {
    116                 ex = e;
    117             }
    118         }
    119         throw ex;
    120     }
    121 
    122     /**
    123      * Returns an array containing only {@link #PROTOCOL_TLS_V1_2}.
    124      */
    125     public static String[] getProtocols() {
    126         return new String[] {PROTOCOL_TLS_V1_2};
    127     }
    128 
    129     public static SSLSocketFactory getJdkSocketFactory() {
    130         return getSocketFactory(JDK_PROVIDER);
    131     }
    132 
    133     public static SSLServerSocketFactory getJdkServerSocketFactory() {
    134         return getServerSocketFactory(JDK_PROVIDER);
    135     }
    136 
    137     static SSLSocketFactory setUseEngineSocket(SSLSocketFactory conscryptFactory, boolean useEngineSocket) {
    138         try {
    139             Class<?> clazz = conscryptClass("Conscrypt$SocketFactories");
    140             Method method = clazz.getMethod("setUseEngineSocket", SSLSocketFactory.class, boolean.class);
    141             method.invoke(null, conscryptFactory, useEngineSocket);
    142             return conscryptFactory;
    143         } catch (Exception e) {
    144             throw new RuntimeException(e);
    145         }
    146     }
    147 
    148     static SSLServerSocketFactory setUseEngineSocket(SSLServerSocketFactory conscryptFactory, boolean useEngineSocket) {
    149         try {
    150             Class<?> clazz = conscryptClass("Conscrypt$ServerSocketFactories");
    151             Method method = clazz.getMethod("setUseEngineSocket", SSLServerSocketFactory.class, boolean.class);
    152             method.invoke(null, conscryptFactory, useEngineSocket);
    153             return conscryptFactory;
    154         } catch (Exception e) {
    155             throw new RuntimeException(e);
    156         }
    157     }
    158 
    159     public static SSLSocketFactory getConscryptSocketFactory(boolean useEngineSocket) {
    160         return setUseEngineSocket(getSocketFactory(getConscryptProvider()), useEngineSocket);
    161     }
    162 
    163     public static SSLServerSocketFactory getConscryptServerSocketFactory(boolean useEngineSocket) {
    164         return setUseEngineSocket(getServerSocketFactory(getConscryptProvider()), useEngineSocket);
    165     }
    166 
    167     private static SSLSocketFactory getSocketFactory(Provider provider) {
    168         SSLContext clientContext = initClientSslContext(newContext(provider));
    169         return clientContext.getSocketFactory();
    170     }
    171 
    172     private static SSLServerSocketFactory getServerSocketFactory(Provider provider) {
    173         SSLContext serverContext = initServerSslContext(newContext(provider));
    174         return serverContext.getServerSocketFactory();
    175     }
    176 
    177     private static SSLContext newContext(Provider provider) {
    178         try {
    179             return SSLContext.getInstance("TLS", provider);
    180         } catch (NoSuchAlgorithmException e) {
    181             throw new RuntimeException(e);
    182         }
    183     }
    184 
    185     /**
    186      * Picks a port that is not used right at this moment.
    187      * Warning: Not thread safe. May see "BindException: Address already in use: bind" if using the
    188      * returned port to create a new server socket when other threads/processes are concurrently
    189      * creating new sockets without a specific port.
    190      */
    191     public static int pickUnusedPort() {
    192         try {
    193             ServerSocket serverSocket = new ServerSocket(0);
    194             int port = serverSocket.getLocalPort();
    195             serverSocket.close();
    196             return port;
    197         } catch (IOException e) {
    198             throw new RuntimeException(e);
    199         }
    200     }
    201 
    202     /**
    203      * Creates a text message of the given length.
    204      */
    205     public static byte[] newTextMessage(int length) {
    206         byte[] msg = new byte[length];
    207         for (int msgIndex = 0; msgIndex < length;) {
    208             int remaining = length - msgIndex;
    209             int numChars = Math.min(remaining, CHARS.length);
    210             System.arraycopy(CHARS, 0, msg, msgIndex, numChars);
    211             msgIndex += numChars;
    212         }
    213         return msg;
    214     }
    215 
    216     /**
    217      * Initializes the given engine with the cipher and client mode.
    218      */
    219     static SSLEngine initEngine(SSLEngine engine, String cipher, boolean client) {
    220         engine.setEnabledProtocols(getProtocols());
    221         engine.setEnabledCipherSuites(new String[] {cipher});
    222         engine.setUseClientMode(client);
    223         return engine;
    224     }
    225 
    226     static SSLContext newClientSslContext(Provider provider) {
    227         SSLContext context = newContext(provider);
    228         return initClientSslContext(context);
    229     }
    230 
    231     static SSLContext newServerSslContext(Provider provider) {
    232         SSLContext context = newContext(provider);
    233         return initServerSslContext(context);
    234     }
    235 
    236     /**
    237      * Initializes the given client-side {@code context} with a default cert.
    238      */
    239     public static SSLContext initClientSslContext(SSLContext context) {
    240         return initSslContext(context, TestKeyStore.getClient());
    241     }
    242 
    243     /**
    244      * Initializes the given server-side {@code context} with the given cert chain and private key.
    245      */
    246     public static SSLContext initServerSslContext(SSLContext context) {
    247         return initSslContext(context, TestKeyStore.getServer());
    248     }
    249 
    250     /**
    251      * Initializes the given {@code context} from the {@code keyStore}.
    252      */
    253     static SSLContext initSslContext(SSLContext context, TestKeyStore keyStore) {
    254         try {
    255             context.init(keyStore.keyManagers, keyStore.trustManagers, null);
    256             return context;
    257         } catch (Exception e) {
    258             throw new RuntimeException(e);
    259         }
    260     }
    261 
    262     /**
    263      * Performs the intial TLS handshake between the two {@link SSLEngine} instances.
    264      */
    265     public static void doEngineHandshake(SSLEngine clientEngine, SSLEngine serverEngine,
    266             ByteBuffer clientAppBuffer, ByteBuffer clientPacketBuffer, ByteBuffer serverAppBuffer,
    267             ByteBuffer serverPacketBuffer) throws SSLException {
    268         clientEngine.beginHandshake();
    269         serverEngine.beginHandshake();
    270 
    271         SSLEngineResult clientResult;
    272         SSLEngineResult serverResult;
    273 
    274         boolean clientHandshakeFinished = false;
    275         boolean serverHandshakeFinished = false;
    276 
    277         do {
    278             int cTOsPos = clientPacketBuffer.position();
    279             int sTOcPos = serverPacketBuffer.position();
    280 
    281             clientResult = clientEngine.wrap(EMPTY_BUFFER, clientPacketBuffer);
    282             runDelegatedTasks(clientResult, clientEngine);
    283             serverResult = serverEngine.wrap(EMPTY_BUFFER, serverPacketBuffer);
    284             runDelegatedTasks(serverResult, serverEngine);
    285 
    286             // Verify that the consumed and produced number match what is in the buffers now.
    287             assertEquals(0, clientResult.bytesConsumed());
    288             assertEquals(0, serverResult.bytesConsumed());
    289             assertEquals(clientPacketBuffer.position() - cTOsPos, clientResult.bytesProduced());
    290             assertEquals(serverPacketBuffer.position() - sTOcPos, serverResult.bytesProduced());
    291 
    292             clientPacketBuffer.flip();
    293             serverPacketBuffer.flip();
    294 
    295             // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED
    296             if (isHandshakeFinished(clientResult)) {
    297                 assertFalse(clientHandshakeFinished);
    298                 clientHandshakeFinished = true;
    299             }
    300             if (isHandshakeFinished(serverResult)) {
    301                 assertFalse(serverHandshakeFinished);
    302                 serverHandshakeFinished = true;
    303             }
    304 
    305             cTOsPos = clientPacketBuffer.position();
    306             sTOcPos = serverPacketBuffer.position();
    307 
    308             int clientAppReadBufferPos = clientAppBuffer.position();
    309             int serverAppReadBufferPos = serverAppBuffer.position();
    310 
    311             clientResult = clientEngine.unwrap(serverPacketBuffer, clientAppBuffer);
    312             runDelegatedTasks(clientResult, clientEngine);
    313             serverResult = serverEngine.unwrap(clientPacketBuffer, serverAppBuffer);
    314             runDelegatedTasks(serverResult, serverEngine);
    315 
    316             // Verify that the consumed and produced number match what is in the buffers now.
    317             assertEquals(serverPacketBuffer.position() - sTOcPos, clientResult.bytesConsumed());
    318             assertEquals(clientPacketBuffer.position() - cTOsPos, serverResult.bytesConsumed());
    319             assertEquals(clientAppBuffer.position() - clientAppReadBufferPos,
    320                     clientResult.bytesProduced());
    321             assertEquals(serverAppBuffer.position() - serverAppReadBufferPos,
    322                     serverResult.bytesProduced());
    323 
    324             clientPacketBuffer.compact();
    325             serverPacketBuffer.compact();
    326 
    327             // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED
    328             if (isHandshakeFinished(clientResult)) {
    329                 assertFalse(clientHandshakeFinished);
    330                 clientHandshakeFinished = true;
    331             }
    332             if (isHandshakeFinished(serverResult)) {
    333                 assertFalse(serverHandshakeFinished);
    334                 serverHandshakeFinished = true;
    335             }
    336         } while (!clientHandshakeFinished || !serverHandshakeFinished);
    337     }
    338 
    339     private static boolean isHandshakeFinished(SSLEngineResult result) {
    340         return result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED;
    341     }
    342 
    343     private static void runDelegatedTasks(SSLEngineResult result, SSLEngine engine) {
    344         if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
    345             for (;;) {
    346                 Runnable task = engine.getDelegatedTask();
    347                 if (task == null) {
    348                     break;
    349                 }
    350                 task.run();
    351             }
    352         }
    353     }
    354 }
    355