Home | History | Annotate | Download | only in conscrypt
      1 /*
      2  * Copyright (C) 2015 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package org.conscrypt;
     18 
     19 import static org.junit.Assert.assertEquals;
     20 import static org.junit.Assert.assertFalse;
     21 
     22 import java.io.FileNotFoundException;
     23 import java.io.IOException;
     24 import java.io.InputStream;
     25 import java.lang.reflect.Method;
     26 import java.net.InetAddress;
     27 import java.net.ServerSocket;
     28 import java.net.UnknownHostException;
     29 import java.nio.ByteBuffer;
     30 import java.nio.charset.Charset;
     31 import java.security.NoSuchAlgorithmException;
     32 import java.security.Provider;
     33 import java.security.Security;
     34 import java.util.ArrayList;
     35 import java.util.Arrays;
     36 import java.util.Iterator;
     37 import java.util.LinkedHashSet;
     38 import java.util.List;
     39 import java.util.Set;
     40 import javax.net.ssl.SSLContext;
     41 import javax.net.ssl.SSLEngine;
     42 import javax.net.ssl.SSLEngineResult;
     43 import javax.net.ssl.SSLException;
     44 import javax.net.ssl.SSLParameters;
     45 import javax.net.ssl.SSLServerSocketFactory;
     46 import javax.net.ssl.SSLSocketFactory;
     47 import libcore.io.Streams;
     48 import org.bouncycastle.jce.provider.BouncyCastleProvider;
     49 import org.conscrypt.java.security.TestKeyStore;
     50 import org.junit.Assume;
     51 
     52 /**
     53  * Utility methods to support testing.
     54  */
     55 public final class TestUtils {
     56     public static final Charset UTF_8 = Charset.forName("UTF-8");
     57     private static final String PROTOCOL_TLS_V1_2 = "TLSv1.2";
     58     private static final String PROTOCOL_TLS_V1_1 = "TLSv1.1";
     59     private static final String PROTOCOL_TLS_V1 = "TLSv1";
     60     private static final String[] DESIRED_PROTOCOLS =
     61         new String[] {PROTOCOL_TLS_V1_2, PROTOCOL_TLS_V1_1, /* For Java 6 */ PROTOCOL_TLS_V1};
     62     private static final Provider JDK_PROVIDER = getDefaultTlsProvider();
     63     private static final byte[] CHARS =
     64             "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789".getBytes(UTF_8);
     65     private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0);
     66     private static final String[] PROTOCOLS = getProtocolsInternal();
     67 
     68     static final String TEST_CIPHER = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
     69 
     70     private TestUtils() {}
     71 
     72     private static Provider getDefaultTlsProvider() {
     73         for (String protocol : DESIRED_PROTOCOLS) {
     74             for (Provider p : Security.getProviders()) {
     75                 if (hasProtocol(p, protocol)) {
     76                     return p;
     77                 }
     78             }
     79         }
     80         // For Java 1.6 testing
     81         return new BouncyCastleProvider();
     82     }
     83 
     84     private static boolean hasProtocol(Provider p, String protocol) {
     85         return p.get("SSLContext." + protocol) != null;
     86     }
     87 
     88     static Provider getJdkProvider() {
     89         return JDK_PROVIDER;
     90     }
     91 
     92     private static void assumeClassAvailable(String classname) {
     93         boolean available = false;
     94         try {
     95             Class.forName(classname);
     96             available = true;
     97         } catch (ClassNotFoundException ignore) {
     98             // Ignored
     99         }
    100         Assume.assumeTrue("Skipping test: " + classname + " unavailable", available);
    101     }
    102 
    103     public static void assumeSNIHostnameAvailable() {
    104         assumeClassAvailable("javax.net.ssl.SNIHostName");
    105     }
    106 
    107     public static void assumeSetEndpointIdentificationAlgorithmAvailable() {
    108         boolean supported = false;
    109         try {
    110             SSLParameters.class.getMethod("setEndpointIdentificationAlgorithm", String.class);
    111             supported = true;
    112         } catch (NoSuchMethodException ignore) {
    113             // Ignored
    114         }
    115         Assume.assumeTrue("Skipping test: "
    116                 + "SSLParameters.setEndpointIdentificationAlgorithm unavailable", supported);
    117     }
    118 
    119     public static void assumeAEADAvailable() {
    120         assumeClassAvailable("javax.crypto.AEADBadTagException");
    121     }
    122 
    123     private static boolean isAndroid() {
    124         try {
    125             Class.forName("android.app.Application", false, ClassLoader.getSystemClassLoader());
    126             return true;
    127         } catch (Throwable ignored) {
    128             // Failed to load the class uniquely available in Android.
    129             return false;
    130         }
    131     }
    132 
    133     public static void assumeAndroid() {
    134         Assume.assumeTrue(isAndroid());
    135     }
    136 
    137     public static void assumeAllowsUnsignedCrypto() {
    138         // The Oracle JRE disallows loading crypto providers from unsigned jars
    139         Assume.assumeTrue(isAndroid()
    140                 || !System.getProperty("java.vm.name").contains("HotSpot"));
    141     }
    142 
    143     public static InetAddress getLoopbackAddress() {
    144         try {
    145             Method method = InetAddress.class.getMethod("getLoopbackAddress");
    146             return (InetAddress) method.invoke(null);
    147         } catch (Exception ignore) {
    148             // Ignored.
    149         }
    150         try {
    151             return InetAddress.getLocalHost();
    152         } catch (UnknownHostException e) {
    153             throw new RuntimeException(e);
    154         }
    155     }
    156 
    157     public static Provider getConscryptProvider() {
    158         try {
    159             return (Provider) conscryptClass("OpenSSLProvider").getConstructor().newInstance();
    160         } catch (Exception e) {
    161             throw new RuntimeException(e);
    162         }
    163     }
    164 
    165     public static synchronized void installConscryptAsDefaultProvider() {
    166         final Provider conscryptProvider = getConscryptProvider();
    167         Provider[] providers = Security.getProviders();
    168         if (providers.length == 0 || !providers[0].equals(conscryptProvider)) {
    169             Security.insertProviderAt(conscryptProvider, 1);
    170         }
    171     }
    172 
    173     public static InputStream openTestFile(String name) throws FileNotFoundException {
    174         InputStream is = TestUtils.class.getResourceAsStream("/" + name);
    175         if (is == null) {
    176             throw new FileNotFoundException(name);
    177         }
    178         return is;
    179     }
    180 
    181     public static byte[] readTestFile(String name) throws IOException {
    182         return Streams.readFully(openTestFile(name));
    183     }
    184 
    185     /**
    186      * Looks up the conscrypt class for the given simple name (i.e. no package prefix).
    187      */
    188     public static Class<?> conscryptClass(String simpleName) throws ClassNotFoundException {
    189         ClassNotFoundException ex = null;
    190         for (String packageName : new String[] {"org.conscrypt", "com.android.org.conscrypt"}) {
    191             String name = packageName + "." + simpleName;
    192             try {
    193                 return Class.forName(name);
    194             } catch (ClassNotFoundException e) {
    195                 ex = e;
    196             }
    197         }
    198         throw ex;
    199     }
    200 
    201     /**
    202      * Returns an array containing only {@link #PROTOCOL_TLS_V1_2}.
    203      */
    204     public static String[] getProtocols() {
    205         return PROTOCOLS;
    206     }
    207 
    208     private static String[] getProtocolsInternal() {
    209         List<String> protocols = new ArrayList<String>();
    210         for (String protocol : DESIRED_PROTOCOLS) {
    211             if (hasProtocol(getJdkProvider(), protocol)) {
    212                 protocols.add(protocol);
    213             }
    214         }
    215         return protocols.toArray(new String[protocols.size()]);
    216     }
    217 
    218     public static SSLSocketFactory getJdkSocketFactory() {
    219         return getSocketFactory(JDK_PROVIDER);
    220     }
    221 
    222     public static SSLServerSocketFactory getJdkServerSocketFactory() {
    223         return getServerSocketFactory(JDK_PROVIDER);
    224     }
    225 
    226     static SSLSocketFactory setUseEngineSocket(
    227             SSLSocketFactory conscryptFactory, boolean useEngineSocket) {
    228         try {
    229             Class<?> clazz = conscryptClass("Conscrypt");
    230             Method method =
    231                     clazz.getMethod("setUseEngineSocket", SSLSocketFactory.class, boolean.class);
    232             method.invoke(null, conscryptFactory, useEngineSocket);
    233             return conscryptFactory;
    234         } catch (Exception e) {
    235             throw new RuntimeException(e);
    236         }
    237     }
    238 
    239     static SSLServerSocketFactory setUseEngineSocket(
    240             SSLServerSocketFactory conscryptFactory, boolean useEngineSocket) {
    241         try {
    242             Class<?> clazz = conscryptClass("Conscrypt");
    243             Method method = clazz.getMethod(
    244                     "setUseEngineSocket", SSLServerSocketFactory.class, boolean.class);
    245             method.invoke(null, conscryptFactory, useEngineSocket);
    246             return conscryptFactory;
    247         } catch (Exception e) {
    248             throw new RuntimeException(e);
    249         }
    250     }
    251 
    252     public static SSLSocketFactory getConscryptSocketFactory(boolean useEngineSocket) {
    253         return setUseEngineSocket(getSocketFactory(getConscryptProvider()), useEngineSocket);
    254     }
    255 
    256     public static SSLServerSocketFactory getConscryptServerSocketFactory(boolean useEngineSocket) {
    257         return setUseEngineSocket(getServerSocketFactory(getConscryptProvider()), useEngineSocket);
    258     }
    259 
    260     private static SSLSocketFactory getSocketFactory(Provider provider) {
    261         SSLContext clientContext = initClientSslContext(newContext(provider));
    262         return clientContext.getSocketFactory();
    263     }
    264 
    265     private static SSLServerSocketFactory getServerSocketFactory(Provider provider) {
    266         SSLContext serverContext = initServerSslContext(newContext(provider));
    267         return serverContext.getServerSocketFactory();
    268     }
    269 
    270     static SSLContext newContext(Provider provider) {
    271         try {
    272             return SSLContext.getInstance("TLS", provider);
    273         } catch (NoSuchAlgorithmException e) {
    274             throw new RuntimeException(e);
    275         }
    276     }
    277 
    278     static String[] getCommonCipherSuites() {
    279         SSLContext jdkContext =
    280                 TestUtils.initSslContext(newContext(getJdkProvider()), TestKeyStore.getClient());
    281         SSLContext conscryptContext = TestUtils.initSslContext(
    282                 newContext(getConscryptProvider()), TestKeyStore.getClient());
    283         Set<String> supported = new LinkedHashSet<String>();
    284         supported.addAll(supportedCiphers(jdkContext));
    285         supported.retainAll(supportedCiphers(conscryptContext));
    286         filterCiphers(supported);
    287 
    288         return supported.toArray(new String[supported.size()]);
    289     }
    290 
    291     private static List<String> supportedCiphers(SSLContext ctx) {
    292         return Arrays.asList(ctx.getDefaultSSLParameters().getCipherSuites());
    293     }
    294 
    295     private static void filterCiphers(Iterable<String> ciphers) {
    296         // Filter all non-TLS ciphers.
    297         Iterator<String> iter = ciphers.iterator();
    298         while (iter.hasNext()) {
    299             String cipher = iter.next();
    300             if (cipher.startsWith("SSL_") || cipher.startsWith("TLS_EMPTY")
    301                     || cipher.contains("_RC4_")) {
    302                 iter.remove();
    303             }
    304         }
    305     }
    306 
    307     /**
    308      * Picks a port that is not used right at this moment.
    309      * Warning: Not thread safe. May see "BindException: Address already in use: bind" if using the
    310      * returned port to create a new server socket when other threads/processes are concurrently
    311      * creating new sockets without a specific port.
    312      */
    313     public static int pickUnusedPort() {
    314         try {
    315             ServerSocket serverSocket = new ServerSocket(0);
    316             int port = serverSocket.getLocalPort();
    317             serverSocket.close();
    318             return port;
    319         } catch (IOException e) {
    320             throw new RuntimeException(e);
    321         }
    322     }
    323 
    324     /**
    325      * Creates a text message of the given length.
    326      */
    327     public static byte[] newTextMessage(int length) {
    328         byte[] msg = new byte[length];
    329         for (int msgIndex = 0; msgIndex < length;) {
    330             int remaining = length - msgIndex;
    331             int numChars = Math.min(remaining, CHARS.length);
    332             System.arraycopy(CHARS, 0, msg, msgIndex, numChars);
    333             msgIndex += numChars;
    334         }
    335         return msg;
    336     }
    337 
    338     static SSLContext newClientSslContext(Provider provider) {
    339         SSLContext context = newContext(provider);
    340         return initClientSslContext(context);
    341     }
    342 
    343     static SSLContext newServerSslContext(Provider provider) {
    344         SSLContext context = newContext(provider);
    345         return initServerSslContext(context);
    346     }
    347 
    348     /**
    349      * Initializes the given client-side {@code context} with a default cert.
    350      */
    351     public static SSLContext initClientSslContext(SSLContext context) {
    352         return initSslContext(context, TestKeyStore.getClient());
    353     }
    354 
    355     /**
    356      * Initializes the given server-side {@code context} with the given cert chain and private key.
    357      */
    358     public static SSLContext initServerSslContext(SSLContext context) {
    359         return initSslContext(context, TestKeyStore.getServer());
    360     }
    361 
    362     /**
    363      * Initializes the given {@code context} from the {@code keyStore}.
    364      */
    365     static SSLContext initSslContext(SSLContext context, TestKeyStore keyStore) {
    366         try {
    367             context.init(keyStore.keyManagers, keyStore.trustManagers, null);
    368             return context;
    369         } catch (Exception e) {
    370             throw new RuntimeException(e);
    371         }
    372     }
    373 
    374     /**
    375      * Performs the intial TLS handshake between the two {@link SSLEngine} instances.
    376      */
    377     public static void doEngineHandshake(SSLEngine clientEngine, SSLEngine serverEngine,
    378         ByteBuffer clientAppBuffer, ByteBuffer clientPacketBuffer, ByteBuffer serverAppBuffer,
    379         ByteBuffer serverPacketBuffer, boolean beginHandshake) throws SSLException {
    380         if (beginHandshake) {
    381             clientEngine.beginHandshake();
    382             serverEngine.beginHandshake();
    383         }
    384 
    385         SSLEngineResult clientResult;
    386         SSLEngineResult serverResult;
    387 
    388         boolean clientHandshakeFinished = false;
    389         boolean serverHandshakeFinished = false;
    390 
    391         do {
    392             int cTOsPos = clientPacketBuffer.position();
    393             int sTOcPos = serverPacketBuffer.position();
    394 
    395             clientResult = clientEngine.wrap(EMPTY_BUFFER, clientPacketBuffer);
    396             runDelegatedTasks(clientResult, clientEngine);
    397             serverResult = serverEngine.wrap(EMPTY_BUFFER, serverPacketBuffer);
    398             runDelegatedTasks(serverResult, serverEngine);
    399 
    400             // Verify that the consumed and produced number match what is in the buffers now.
    401             assertEquals(0, clientResult.bytesConsumed());
    402             assertEquals(0, serverResult.bytesConsumed());
    403             assertEquals(clientPacketBuffer.position() - cTOsPos, clientResult.bytesProduced());
    404             assertEquals(serverPacketBuffer.position() - sTOcPos, serverResult.bytesProduced());
    405 
    406             clientPacketBuffer.flip();
    407             serverPacketBuffer.flip();
    408 
    409             // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED
    410             if (isHandshakeFinished(clientResult)) {
    411                 assertFalse(clientHandshakeFinished);
    412                 clientHandshakeFinished = true;
    413             }
    414             if (isHandshakeFinished(serverResult)) {
    415                 assertFalse(serverHandshakeFinished);
    416                 serverHandshakeFinished = true;
    417             }
    418 
    419             cTOsPos = clientPacketBuffer.position();
    420             sTOcPos = serverPacketBuffer.position();
    421 
    422             int clientAppReadBufferPos = clientAppBuffer.position();
    423             int serverAppReadBufferPos = serverAppBuffer.position();
    424 
    425             clientResult = clientEngine.unwrap(serverPacketBuffer, clientAppBuffer);
    426             runDelegatedTasks(clientResult, clientEngine);
    427             serverResult = serverEngine.unwrap(clientPacketBuffer, serverAppBuffer);
    428             runDelegatedTasks(serverResult, serverEngine);
    429 
    430             // Verify that the consumed and produced number match what is in the buffers now.
    431             assertEquals(serverPacketBuffer.position() - sTOcPos, clientResult.bytesConsumed());
    432             assertEquals(clientPacketBuffer.position() - cTOsPos, serverResult.bytesConsumed());
    433             assertEquals(clientAppBuffer.position() - clientAppReadBufferPos,
    434                 clientResult.bytesProduced());
    435             assertEquals(serverAppBuffer.position() - serverAppReadBufferPos,
    436                 serverResult.bytesProduced());
    437 
    438             clientPacketBuffer.compact();
    439             serverPacketBuffer.compact();
    440 
    441             // Verify that we only had one SSLEngineResult.HandshakeStatus.FINISHED
    442             if (isHandshakeFinished(clientResult)) {
    443                 assertFalse(clientHandshakeFinished);
    444                 clientHandshakeFinished = true;
    445             }
    446             if (isHandshakeFinished(serverResult)) {
    447                 assertFalse(serverHandshakeFinished);
    448                 serverHandshakeFinished = true;
    449             }
    450         } while (!clientHandshakeFinished || !serverHandshakeFinished);
    451     }
    452 
    453     private static boolean isHandshakeFinished(SSLEngineResult result) {
    454         return result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED;
    455     }
    456 
    457     private static void runDelegatedTasks(SSLEngineResult result, SSLEngine engine) {
    458         if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) {
    459             for (;;) {
    460                 Runnable task = engine.getDelegatedTask();
    461                 if (task == null) {
    462                     break;
    463                 }
    464                 task.run();
    465             }
    466         }
    467     }
    468 
    469     /**
    470      * Decodes the provided hexadecimal string into a byte array.  Odd-length inputs
    471      * are not allowed.
    472      *
    473      * Throws an {@code IllegalArgumentException} if the input is malformed.
    474      */
    475     public static byte[] decodeHex(String encoded) throws IllegalArgumentException {
    476         return decodeHex(encoded.toCharArray());
    477     }
    478 
    479     /**
    480      * Decodes the provided hexadecimal string into a byte array. If {@code allowSingleChar}
    481      * is {@code true} odd-length inputs are allowed and the first character is interpreted
    482      * as the lower bits of the first result byte.
    483      *
    484      * Throws an {@code IllegalArgumentException} if the input is malformed.
    485      */
    486     public static byte[] decodeHex(String encoded, boolean allowSingleChar) throws IllegalArgumentException {
    487         return decodeHex(encoded.toCharArray(), allowSingleChar);
    488     }
    489 
    490     /**
    491      * Decodes the provided hexadecimal string into a byte array.  Odd-length inputs
    492      * are not allowed.
    493      *
    494      * Throws an {@code IllegalArgumentException} if the input is malformed.
    495      */
    496     public static byte[] decodeHex(char[] encoded) throws IllegalArgumentException {
    497         return decodeHex(encoded, false);
    498     }
    499 
    500     /**
    501      * Decodes the provided hexadecimal string into a byte array. If {@code allowSingleChar}
    502      * is {@code true} odd-length inputs are allowed and the first character is interpreted
    503      * as the lower bits of the first result byte.
    504      *
    505      * Throws an {@code IllegalArgumentException} if the input is malformed.
    506      */
    507     public static byte[] decodeHex(char[] encoded, boolean allowSingleChar) throws IllegalArgumentException {
    508         int resultLengthBytes = (encoded.length + 1) / 2;
    509         byte[] result = new byte[resultLengthBytes];
    510 
    511         int resultOffset = 0;
    512         int i = 0;
    513         if (allowSingleChar) {
    514             if ((encoded.length % 2) != 0) {
    515                 // Odd number of digits -- the first digit is the lower 4 bits of the first result byte.
    516                 result[resultOffset++] = (byte) toDigit(encoded, i);
    517                 i++;
    518             }
    519         } else {
    520             if ((encoded.length % 2) != 0) {
    521                 throw new IllegalArgumentException("Invalid input length: " + encoded.length);
    522             }
    523         }
    524 
    525         for (int len = encoded.length; i < len; i += 2) {
    526             result[resultOffset++] = (byte) ((toDigit(encoded, i) << 4) | toDigit(encoded, i + 1));
    527         }
    528 
    529         return result;
    530     }
    531 
    532 
    533     private static int toDigit(char[] str, int offset) throws IllegalArgumentException {
    534         // NOTE: that this isn't really a code point in the traditional sense, since we're
    535         // just rejecting surrogate pairs outright.
    536         int pseudoCodePoint = str[offset];
    537 
    538         if ('0' <= pseudoCodePoint && pseudoCodePoint <= '9') {
    539             return pseudoCodePoint - '0';
    540         } else if ('a' <= pseudoCodePoint && pseudoCodePoint <= 'f') {
    541             return 10 + (pseudoCodePoint - 'a');
    542         } else if ('A' <= pseudoCodePoint && pseudoCodePoint <= 'F') {
    543             return 10 + (pseudoCodePoint - 'A');
    544         }
    545 
    546         throw new IllegalArgumentException("Illegal char: " + str[offset] +
    547                 " at offset " + offset);
    548     }
    549 }
    550