Home | History | Annotate | Download | only in internal
      1 /*
      2  * Copyright (C) 2015 Square, Inc.
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 package com.squareup.okhttp.internal;
     17 
     18 import com.squareup.okhttp.ConnectionSpec;
     19 import com.squareup.okhttp.TlsVersion;
     20 import java.io.IOException;
     21 import java.security.cert.CertificateException;
     22 import java.util.Arrays;
     23 import java.util.LinkedHashSet;
     24 import java.util.Set;
     25 import javax.net.ssl.SSLContext;
     26 import javax.net.ssl.SSLHandshakeException;
     27 import javax.net.ssl.SSLSocket;
     28 import org.junit.Test;
     29 
     30 import static org.junit.Assert.assertEquals;
     31 import static org.junit.Assert.assertFalse;
     32 import static org.junit.Assert.assertTrue;
     33 
     34 public class ConnectionSpecSelectorTest {
     35   static {
     36     Internal.initializeInstanceForTests();
     37   }
     38 
     39   public static final SSLHandshakeException RETRYABLE_EXCEPTION = new SSLHandshakeException(
     40       "Simulated handshake exception");
     41 
     42   private SSLContext sslContext = SslContextBuilder.localhost();
     43 
     44   @Test
     45   public void nonRetryableIOException() throws Exception {
     46     ConnectionSpecSelector connectionSpecSelector =
     47         createConnectionSpecSelector(ConnectionSpec.MODERN_TLS, ConnectionSpec.COMPATIBLE_TLS);
     48     SSLSocket socket = createSocketWithEnabledProtocols(TlsVersion.TLS_1_1, TlsVersion.TLS_1_0);
     49     connectionSpecSelector.configureSecureSocket(socket);
     50 
     51     boolean retry = connectionSpecSelector.connectionFailed(
     52         new IOException("Non-handshake exception"));
     53     assertFalse(retry);
     54     socket.close();
     55   }
     56 
     57   @Test
     58   public void nonRetryableSSLHandshakeException() throws Exception {
     59     ConnectionSpecSelector connectionSpecSelector =
     60         createConnectionSpecSelector(ConnectionSpec.MODERN_TLS, ConnectionSpec.COMPATIBLE_TLS);
     61     SSLSocket socket = createSocketWithEnabledProtocols(TlsVersion.TLS_1_1, TlsVersion.TLS_1_0);
     62     connectionSpecSelector.configureSecureSocket(socket);
     63 
     64     SSLHandshakeException trustIssueException =
     65         new SSLHandshakeException("Certificate handshake exception");
     66     trustIssueException.initCause(new CertificateException());
     67     boolean retry = connectionSpecSelector.connectionFailed(trustIssueException);
     68     assertFalse(retry);
     69     socket.close();
     70   }
     71 
     72   @Test
     73   public void retryableSSLHandshakeException() throws Exception {
     74     ConnectionSpecSelector connectionSpecSelector =
     75         createConnectionSpecSelector(ConnectionSpec.MODERN_TLS, ConnectionSpec.COMPATIBLE_TLS);
     76     SSLSocket socket = createSocketWithEnabledProtocols(TlsVersion.TLS_1_1, TlsVersion.TLS_1_0);
     77     connectionSpecSelector.configureSecureSocket(socket);
     78 
     79     boolean retry = connectionSpecSelector.connectionFailed(RETRYABLE_EXCEPTION);
     80     assertTrue(retry);
     81     socket.close();
     82   }
     83 
     84   @Test
     85   public void someFallbacksSupported() throws Exception {
     86     ConnectionSpec sslV3 =
     87         new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS)
     88             .tlsVersions(TlsVersion.SSL_3_0)
     89             .build();
     90 
     91     ConnectionSpecSelector connectionSpecSelector = createConnectionSpecSelector(
     92         ConnectionSpec.MODERN_TLS, ConnectionSpec.COMPATIBLE_TLS, sslV3);
     93 
     94     TlsVersion[] enabledSocketTlsVersions = { TlsVersion.TLS_1_1, TlsVersion.TLS_1_0 };
     95     SSLSocket socket = createSocketWithEnabledProtocols(enabledSocketTlsVersions);
     96 
     97     // MODERN_TLS is used here.
     98     connectionSpecSelector.configureSecureSocket(socket);
     99     assertEnabledProtocols(socket, TlsVersion.TLS_1_1, TlsVersion.TLS_1_0);
    100 
    101     boolean retry = connectionSpecSelector.connectionFailed(RETRYABLE_EXCEPTION);
    102     assertTrue(retry);
    103     socket.close();
    104 
    105     // COMPATIBLE_TLS is used here.
    106     socket = createSocketWithEnabledProtocols(enabledSocketTlsVersions);
    107     connectionSpecSelector.configureSecureSocket(socket);
    108     assertEnabledProtocols(socket, TlsVersion.TLS_1_0);
    109 
    110     retry = connectionSpecSelector.connectionFailed(RETRYABLE_EXCEPTION);
    111     assertFalse(retry);
    112     socket.close();
    113 
    114     // sslV3 is not used because SSLv3 is not enabled on the socket.
    115   }
    116 
    117   private static ConnectionSpecSelector createConnectionSpecSelector(
    118       ConnectionSpec... connectionSpecs) {
    119     return new ConnectionSpecSelector(Arrays.asList(connectionSpecs));
    120   }
    121 
    122   private SSLSocket createSocketWithEnabledProtocols(TlsVersion... tlsVersions) throws IOException {
    123     SSLSocket socket = (SSLSocket) sslContext.getSocketFactory().createSocket();
    124     socket.setEnabledProtocols(javaNames(tlsVersions));
    125     return socket;
    126   }
    127 
    128   private static void assertEnabledProtocols(SSLSocket socket, TlsVersion... required) {
    129     Set<String> actual = new LinkedHashSet<>(Arrays.asList(socket.getEnabledProtocols()));
    130     Set<String> expected = new LinkedHashSet<>(Arrays.asList(javaNames(required)));
    131     assertEquals(expected, actual);
    132   }
    133 
    134   private static String[] javaNames(TlsVersion... tlsVersions) {
    135     String[] protocols = new String[tlsVersions.length];
    136     for (int i = 0; i < tlsVersions.length; i++) {
    137       protocols[i] = tlsVersions[i].javaName();
    138     }
    139     return protocols;
    140   }
    141 }
    142