Home | History | Annotate | Download | only in conscrypt
      1 /*
      2  * Copyright (C) 2016 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 /*
     18  * Copyright 2016 The Netty Project
     19  *
     20  * The Netty Project licenses this file to you under the Apache License,
     21  * version 2.0 (the "License"); you may not use this file except in compliance
     22  * with the License. You may obtain a copy of the License at:
     23  *
     24  *   http://www.apache.org/licenses/LICENSE-2.0
     25  *
     26  * Unless required by applicable law or agreed to in writing, software
     27  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
     28  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
     29  * License for the specific language governing permissions and limitations
     30  * under the License.
     31  */
     32 
     33 package org.conscrypt;
     34 
     35 import static java.lang.Math.min;
     36 import static org.conscrypt.NativeConstants.SSL3_RT_ALERT;
     37 import static org.conscrypt.NativeConstants.SSL3_RT_APPLICATION_DATA;
     38 import static org.conscrypt.NativeConstants.SSL3_RT_CHANGE_CIPHER_SPEC;
     39 import static org.conscrypt.NativeConstants.SSL3_RT_HANDSHAKE;
     40 import static org.conscrypt.NativeConstants.SSL3_RT_HEADER_LENGTH;
     41 import static org.conscrypt.NativeConstants.SSL3_RT_MAX_PACKET_SIZE;
     42 
     43 import java.io.ByteArrayInputStream;
     44 import java.nio.ByteBuffer;
     45 import java.nio.charset.Charset;
     46 import java.security.cert.CertificateEncodingException;
     47 import java.security.cert.CertificateFactory;
     48 import java.security.cert.X509Certificate;
     49 import java.util.Arrays;
     50 import java.util.HashSet;
     51 import java.util.Set;
     52 import javax.net.ssl.SSLException;
     53 import javax.net.ssl.SSLHandshakeException;
     54 import javax.net.ssl.SSLPeerUnverifiedException;
     55 import javax.net.ssl.SSLSession;
     56 import javax.security.cert.CertificateException;
     57 
     58 /**
     59  * Utility methods for SSL packet processing. Copied from the Netty project.
     60  * <p>
     61  * This is a public class to allow testing to occur on Android via CTS.
     62  */
     63 final class SSLUtils {
     64     static final boolean USE_ENGINE_SOCKET_BY_DEFAULT = Boolean.parseBoolean(
     65             System.getProperty("org.conscrypt.useEngineSocketByDefault", "false"));
     66     private static final int MAX_PROTOCOL_LENGTH = 255;
     67 
     68     private static final Charset US_ASCII = Charset.forName("US-ASCII");
     69 
     70     // TODO(nathanmittler): Should these be in NativeConstants?
     71     enum SessionType {
     72         /**
     73          * Identifies OpenSSL sessions.
     74          */
     75         OPEN_SSL(1),
     76 
     77         /**
     78          * Identifies OpenSSL sessions with OCSP stapled data.
     79          */
     80         OPEN_SSL_WITH_OCSP(2),
     81 
     82         /**
     83          * Identifies OpenSSL sessions with TLS SCT data.
     84          */
     85         OPEN_SSL_WITH_TLS_SCT(3);
     86 
     87         SessionType(int value) {
     88             this.value = value;
     89         }
     90 
     91         static boolean isSupportedType(int type) {
     92             return type == OPEN_SSL.value || type == OPEN_SSL_WITH_OCSP.value
     93                     || type == OPEN_SSL_WITH_TLS_SCT.value;
     94         }
     95 
     96         final int value;
     97     }
     98 
     99     /**
    100      * States for SSL engines.
    101      */
    102     static final class EngineStates {
    103         private EngineStates() {}
    104 
    105         /**
    106          * The engine is constructed, but the initial handshake hasn't been started
    107          */
    108         static final int STATE_NEW = 0;
    109 
    110         /**
    111          * The client/server mode of the engine has been set.
    112          */
    113         static final int STATE_MODE_SET = 1;
    114 
    115         /**
    116          * The handshake has been started
    117          */
    118         static final int STATE_HANDSHAKE_STARTED = 2;
    119 
    120         /**
    121          * Listeners of the handshake have been notified of completion but the handshake call
    122          * hasn't returned.
    123          */
    124         static final int STATE_HANDSHAKE_COMPLETED = 3;
    125 
    126         /**
    127          * The handshake call returned but the listeners have not yet been notified. This is expected
    128          * behaviour in cut-through mode, where SSL_do_handshake returns before the handshake is
    129          * complete. We can now start writing data to the socket.
    130          */
    131         static final int STATE_READY_HANDSHAKE_CUT_THROUGH = 4;
    132 
    133         /**
    134          * The handshake call has returned and the listeners have been notified. Ready to begin
    135          * writing data.
    136          */
    137         static final int STATE_READY = 5;
    138 
    139         /**
    140          * The inbound direction of the engine has been closed.
    141          */
    142         static final int STATE_CLOSED_INBOUND = 6;
    143 
    144         /**
    145          * The outbound direction of the engine has been closed.
    146          */
    147         static final int STATE_CLOSED_OUTBOUND = 7;
    148 
    149         /**
    150          * The engine has been closed.
    151          */
    152         static final int STATE_CLOSED = 8;
    153     }
    154 
    155     /**
    156      * This is the maximum overhead when encrypting plaintext as defined by
    157      * <a href="https://www.ietf.org/rfc/rfc5246.txt">rfc5264</a>,
    158      * <a href="https://www.ietf.org/rfc/rfc5289.txt">rfc5289</a> and openssl implementation itself.
    159      *
    160      * Please note that we use a padding of 16 here as openssl uses PKC#5 which uses 16 bytes
    161      * whilethe spec itself allow up to 255 bytes. 16 bytes is the max for PKC#5 (which handles it
    162      * the same way as PKC#7) as we use a block size of 16. See <a
    163      * href="https://tools.ietf.org/html/rfc5652#section-6.3">rfc5652#section-6.3</a>.
    164      *
    165      * 16 (IV) + 48 (MAC) + 1 (Padding_length field) + 15 (Padding) + 1 (ContentType) + 2
    166      * (ProtocolVersion) + 2 (Length)
    167      *
    168      * TODO: We may need to review this calculation once TLS 1.3 becomes available.
    169      */
    170     private static final int MAX_ENCRYPTION_OVERHEAD_LENGTH = 15 + 48 + 1 + 16 + 1 + 2 + 2;
    171 
    172     private static final int MAX_ENCRYPTION_OVERHEAD_DIFF =
    173             Integer.MAX_VALUE - MAX_ENCRYPTION_OVERHEAD_LENGTH;
    174 
    175     /** Key type: RSA certificate. */
    176     private static final String KEY_TYPE_RSA = "RSA";
    177 
    178     /** Key type: Elliptic Curve certificate. */
    179     private static final String KEY_TYPE_EC = "EC";
    180 
    181     /**
    182      * If the given session is a {@link SessionDecorator}, unwraps the session and returns the
    183      * underlying (non-decorated) session. Otherwise, returns the provided session.
    184      */
    185     static SSLSession unwrapSession(SSLSession session) {
    186         while (session instanceof SessionDecorator) {
    187             session = ((SessionDecorator) session).getDelegate();
    188         }
    189         return session;
    190     }
    191 
    192     static X509Certificate[] decodeX509CertificateChain(byte[][] certChain)
    193             throws java.security.cert.CertificateException {
    194         CertificateFactory certificateFactory = getCertificateFactory();
    195         int numCerts = certChain.length;
    196         X509Certificate[] decodedCerts = new X509Certificate[numCerts];
    197         for (int i = 0; i < numCerts; i++) {
    198             decodedCerts[i] = decodeX509Certificate(certificateFactory, certChain[i]);
    199         }
    200         return decodedCerts;
    201     }
    202 
    203     private static CertificateFactory getCertificateFactory() {
    204         try {
    205             return CertificateFactory.getInstance("X.509");
    206         } catch (java.security.cert.CertificateException e) {
    207             return null;
    208         }
    209     }
    210 
    211     private static X509Certificate decodeX509Certificate(CertificateFactory certificateFactory,
    212             byte[] bytes) throws java.security.cert.CertificateException {
    213         if (certificateFactory != null) {
    214             return (X509Certificate) certificateFactory.generateCertificate(
    215                     new ByteArrayInputStream(bytes));
    216         }
    217         return OpenSSLX509Certificate.fromX509Der(bytes);
    218     }
    219 
    220     /**
    221      * Returns key type constant suitable for calling X509KeyManager.chooseServerAlias or
    222      * X509ExtendedKeyManager.chooseEngineServerAlias. Returns {@code null} for key exchanges that
    223      * do not use X.509 for server authentication.
    224      */
    225     static String getServerX509KeyType(long sslCipherNative) throws SSLException {
    226         String kx_name = NativeCrypto.SSL_CIPHER_get_kx_name(sslCipherNative);
    227         if (kx_name.equals("RSA") || kx_name.equals("DHE_RSA") || kx_name.equals("ECDHE_RSA")) {
    228             return KEY_TYPE_RSA;
    229         } else if (kx_name.equals("ECDHE_ECDSA")) {
    230             return KEY_TYPE_EC;
    231         } else {
    232             return null;
    233         }
    234     }
    235 
    236     /**
    237      * Similar to getServerKeyType, but returns value given TLS
    238      * ClientCertificateType byte values from a CertificateRequest
    239      * message for use with X509KeyManager.chooseClientAlias or
    240      * X509ExtendedKeyManager.chooseEngineClientAlias.
    241      * <p>
    242      * Visible for testing.
    243      */
    244     static String getClientKeyType(byte clientCertificateType) {
    245         // See also http://www.ietf.org/assignments/tls-parameters/tls-parameters.xml
    246         switch (clientCertificateType) {
    247             case NativeConstants.TLS_CT_RSA_SIGN:
    248                 return KEY_TYPE_RSA; // RFC rsa_sign
    249             case NativeConstants.TLS_CT_ECDSA_SIGN:
    250                 return KEY_TYPE_EC; // RFC ecdsa_sign
    251             default:
    252                 return null;
    253         }
    254     }
    255 
    256     /**
    257      * Gets the supported key types for client certificates based on the
    258      * {@code ClientCertificateType} values provided by the server.
    259      *
    260      * @param clientCertificateTypes {@code ClientCertificateType} values provided by the server.
    261      *        See https://www.ietf.org/assignments/tls-parameters/tls-parameters.xml.
    262      * @return supported key types that can be used in {@code X509KeyManager.chooseClientAlias} and
    263      *         {@code X509ExtendedKeyManager.chooseEngineClientAlias}.
    264      *
    265      * Visible for testing.
    266      */
    267     static Set<String> getSupportedClientKeyTypes(byte[] clientCertificateTypes) {
    268         Set<String> result = new HashSet<String>(clientCertificateTypes.length);
    269         for (byte keyTypeCode : clientCertificateTypes) {
    270             String keyType = SSLUtils.getClientKeyType(keyTypeCode);
    271             if (keyType == null) {
    272                 // Unsupported client key type -- ignore
    273                 continue;
    274             }
    275             result.add(keyType);
    276         }
    277         return result;
    278     }
    279 
    280     static byte[][] encodeIssuerX509Principals(X509Certificate[] certificates)
    281             throws CertificateEncodingException {
    282         byte[][] principalBytes = new byte[certificates.length][];
    283         for (int i = 0; i < certificates.length; i++) {
    284             principalBytes[i] = certificates[i].getIssuerX500Principal().getEncoded();
    285         }
    286         return principalBytes;
    287     }
    288 
    289     /**
    290      * Converts the peer certificates into a cert chain.
    291      */
    292     static javax.security.cert.X509Certificate[] toCertificateChain(X509Certificate[] certificates)
    293             throws SSLPeerUnverifiedException {
    294         try {
    295             javax.security.cert.X509Certificate[] chain =
    296                     new javax.security.cert.X509Certificate[certificates.length];
    297 
    298             for (int i = 0; i < certificates.length; i++) {
    299                 byte[] encoded = certificates[i].getEncoded();
    300                 chain[i] = javax.security.cert.X509Certificate.getInstance(encoded);
    301             }
    302             return chain;
    303         } catch (CertificateEncodingException e) {
    304             SSLPeerUnverifiedException exception = new SSLPeerUnverifiedException(e.getMessage());
    305             exception.initCause(exception);
    306             throw exception;
    307         } catch (CertificateException e) {
    308             SSLPeerUnverifiedException exception = new SSLPeerUnverifiedException(e.getMessage());
    309             exception.initCause(exception);
    310             throw exception;
    311         }
    312     }
    313 
    314     /**
    315      * Calculates the minimum bytes required in the encrypted output buffer for the given number of
    316      * plaintext source bytes.
    317      */
    318     static int calculateOutNetBufSize(int pendingBytes) {
    319         return min(SSL3_RT_MAX_PACKET_SIZE,
    320                 MAX_ENCRYPTION_OVERHEAD_LENGTH + min(MAX_ENCRYPTION_OVERHEAD_DIFF, pendingBytes));
    321     }
    322 
    323     /**
    324      * Wraps the given exception if it's not already a {@link SSLHandshakeException}.
    325      */
    326     static SSLHandshakeException toSSLHandshakeException(Throwable e) {
    327         if (e instanceof SSLHandshakeException) {
    328             return (SSLHandshakeException) e;
    329         }
    330 
    331         return (SSLHandshakeException) new SSLHandshakeException(e.getMessage()).initCause(e);
    332     }
    333 
    334     /**
    335      * Wraps the given exception if it's not already a {@link SSLException}.
    336      */
    337     static SSLException toSSLException(Throwable e) {
    338         if (e instanceof SSLException) {
    339             return (SSLException) e;
    340         }
    341         return new SSLException(e);
    342     }
    343 
    344     static String toProtocolString(byte[] bytes) {
    345         if (bytes == null) {
    346             return null;
    347         }
    348         return new String(bytes, US_ASCII);
    349     }
    350 
    351     static byte[] toProtocolBytes(String protocol) {
    352         if (protocol == null) {
    353             return null;
    354         }
    355         return protocol.getBytes(US_ASCII);
    356     }
    357 
    358     /**
    359      * Decodes the given list of protocols into {@link String}s.
    360      * @param protocols the encoded protocol list
    361      * @return the decoded protocols or {@link EmptyArray#BYTE} if {@code protocols} is
    362      * empty.
    363      * @throws NullPointerException if protocols is {@code null}.
    364      */
    365     static String[] decodeProtocols(byte[] protocols) {
    366         if (protocols.length == 0) {
    367             return EmptyArray.STRING;
    368         }
    369 
    370         int numProtocols = 0;
    371         for (int i = 0; i < protocols.length;) {
    372             int protocolLength = protocols[i];
    373             if (protocolLength < 0 || protocolLength > protocols.length - i) {
    374                 throw new IllegalArgumentException(
    375                     "Protocol has invalid length (" + protocolLength + " at position " + i
    376                         + "): " + (protocols.length < 50
    377                         ? Arrays.toString(protocols) : protocols.length + " byte array"));
    378             }
    379 
    380             numProtocols++;
    381             i += 1 + protocolLength;
    382         }
    383 
    384         String[] decoded = new String[numProtocols];
    385         for (int i = 0, d = 0; i < protocols.length;) {
    386             int protocolLength = protocols[i];
    387             decoded[d++] = protocolLength > 0
    388                     ? new String(protocols, i + 1, protocolLength, US_ASCII)
    389                     : "";
    390             i += 1 + protocolLength;
    391         }
    392 
    393         return decoded;
    394     }
    395 
    396     /**
    397      * Encodes a list of protocols into the wire-format (length-prefixed 8-bit strings).
    398      * Requires that all strings be encoded with US-ASCII.
    399      *
    400      * @param protocols the list of protocols to be encoded
    401      * @return the encoded form of the protocol list.
    402      * @throws IllegalArgumentException if protocols is {@code null}, or if any element is
    403      * {@code null} or an empty string.
    404      */
    405     static byte[] encodeProtocols(String[] protocols) {
    406         if (protocols == null) {
    407             throw new IllegalArgumentException("protocols array must be non-null");
    408         }
    409 
    410         if (protocols.length == 0) {
    411             return EmptyArray.BYTE;
    412         }
    413 
    414         // Calculate the encoded length.
    415         int length = 0;
    416         for (int i = 0; i < protocols.length; ++i) {
    417             String protocol = protocols[i];
    418             if (protocol == null) {
    419                 throw new IllegalArgumentException("protocol[" + i + "] is null");
    420             }
    421             int protocolLength = protocols[i].length();
    422 
    423             // Verify that the length is valid here, so that we don't attempt to allocate an array
    424             // below if the threshold is violated.
    425             if (protocolLength == 0 || protocolLength > MAX_PROTOCOL_LENGTH) {
    426                 throw new IllegalArgumentException(
    427                     "protocol[" + i + "] has invalid length: " + protocolLength);
    428             }
    429 
    430             // Include a 1-byte prefix for each protocol.
    431             length += 1 + protocolLength;
    432         }
    433 
    434         byte[] data = new byte[length];
    435         for (int dataIndex = 0, i = 0; i < protocols.length; ++i) {
    436             String protocol = protocols[i];
    437             int protocolLength = protocol.length();
    438 
    439             // Add the length prefix.
    440             data[dataIndex++] = (byte) protocolLength;
    441             for (int ci = 0; ci < protocolLength; ++ci) {
    442                 char c = protocol.charAt(ci);
    443                 if (c > Byte.MAX_VALUE) {
    444                     // Enforce US-ASCII
    445                     throw new IllegalArgumentException("Protocol contains invalid character: "
    446                         + c + "(protocol=" + protocol + ")");
    447                 }
    448                 data[dataIndex++] = (byte) c;
    449             }
    450         }
    451         return data;
    452     }
    453 
    454     /**
    455      * Return how much bytes can be read out of the encrypted data. Be aware that this method will
    456      * not increase the readerIndex of the given {@link ByteBuffer}.
    457      *
    458      * @param buffers The {@link ByteBuffer}s to read from. Be aware that they must have at least
    459      * {@link org.conscrypt.NativeConstants#SSL3_RT_HEADER_LENGTH} bytes to read, otherwise it will
    460      * throw an {@link IllegalArgumentException}.
    461      * @return length The length of the encrypted packet that is included in the buffer. This will
    462      * return {@code -1} if the given {@link ByteBuffer} is not encrypted at all.
    463      * @throws IllegalArgumentException Is thrown if the given {@link ByteBuffer} has not at least
    464      * {@link org.conscrypt.NativeConstants#SSL3_RT_HEADER_LENGTH} bytes to read.
    465      */
    466     static int getEncryptedPacketLength(ByteBuffer[] buffers, int offset) {
    467         ByteBuffer buffer = buffers[offset];
    468 
    469         // Check if everything we need is in one ByteBuffer. If so we can make use of the fast-path.
    470         if (buffer.remaining() >= SSL3_RT_HEADER_LENGTH) {
    471             return getEncryptedPacketLength(buffer);
    472         }
    473 
    474         // We need to copy 5 bytes into a temporary buffer so we can parse out the packet length
    475         // easily.
    476         ByteBuffer tmp = ByteBuffer.allocate(SSL3_RT_HEADER_LENGTH);
    477         do {
    478             buffer = buffers[offset++];
    479             int pos = buffer.position();
    480             int limit = buffer.limit();
    481             if (buffer.remaining() > tmp.remaining()) {
    482                 buffer.limit(pos + tmp.remaining());
    483             }
    484             try {
    485                 tmp.put(buffer);
    486             } finally {
    487                 // Restore the original indices.
    488                 buffer.limit(limit);
    489                 buffer.position(pos);
    490             }
    491         } while (tmp.hasRemaining());
    492 
    493         // Done, flip the buffer so we can read from it.
    494         tmp.flip();
    495         return getEncryptedPacketLength(tmp);
    496     }
    497 
    498     private static int getEncryptedPacketLength(ByteBuffer buffer) {
    499         int pos = buffer.position();
    500         // SSLv3 or TLS - Check ContentType
    501         switch (unsignedByte(buffer.get(pos))) {
    502             case SSL3_RT_CHANGE_CIPHER_SPEC:
    503             case SSL3_RT_ALERT:
    504             case SSL3_RT_HANDSHAKE:
    505             case SSL3_RT_APPLICATION_DATA:
    506                 break;
    507             default:
    508                 // SSLv2 or bad data
    509                 return -1;
    510         }
    511 
    512         // SSLv3 or TLS - Check ProtocolVersion
    513         int majorVersion = unsignedByte(buffer.get(pos + 1));
    514         if (majorVersion != 3) {
    515             // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
    516             return -1;
    517         }
    518 
    519         // SSLv3 or TLS
    520         int packetLength = unsignedShort(buffer.getShort(pos + 3)) + SSL3_RT_HEADER_LENGTH;
    521         if (packetLength <= SSL3_RT_HEADER_LENGTH) {
    522             // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
    523             return -1;
    524         }
    525         return packetLength;
    526     }
    527 
    528     private static short unsignedByte(byte b) {
    529         return (short) (b & 0xFF);
    530     }
    531 
    532     private static int unsignedShort(short s) {
    533         return s & 0xFFFF;
    534     }
    535 
    536     private SSLUtils() {}
    537 }
    538