Home | History | Annotate | Download | only in conscrypt
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License
     15  */
     16 
     17 package org.conscrypt;
     18 
     19 import static org.junit.Assert.assertArrayEquals;
     20 import static org.junit.Assert.assertEquals;
     21 import static org.junit.Assert.assertNotEquals;
     22 
     23 import java.io.EOFException;
     24 import java.io.IOException;
     25 import java.net.InetSocketAddress;
     26 import java.nio.ByteBuffer;
     27 import java.nio.channels.ServerSocketChannel;
     28 import java.nio.channels.SocketChannel;
     29 import java.util.Arrays;
     30 import java.util.LinkedHashSet;
     31 import java.util.Set;
     32 import java.util.concurrent.ExecutionException;
     33 import java.util.concurrent.ExecutorService;
     34 import java.util.concurrent.Executors;
     35 import java.util.concurrent.Future;
     36 import java.util.concurrent.TimeUnit;
     37 import java.util.concurrent.TimeoutException;
     38 import javax.net.ssl.SSLContext;
     39 import javax.net.ssl.SSLEngine;
     40 import javax.net.ssl.SSLEngineResult;
     41 import javax.net.ssl.SSLEngineResult.Status;
     42 import javax.net.ssl.SSLSocket;
     43 import javax.net.ssl.SSLSocketFactory;
     44 import org.conscrypt.java.security.TestKeyStore;
     45 import org.junit.After;
     46 import org.junit.Before;
     47 import org.junit.Test;
     48 import org.junit.runner.RunWith;
     49 import org.junit.runners.Parameterized;
     50 import org.junit.runners.Parameterized.Parameter;
     51 import org.junit.runners.Parameterized.Parameters;
     52 
     53 /**
     54  * This tests that server-initiated cipher renegotiation works properly with a Conscrypt client.
     55  * BoringSSL does not support user-initiated renegotiation, so we use the JDK implementation for
     56  * the server.
     57  */
     58 @RunWith(Parameterized.class)
     59 public class RenegotiationTest {
     60     private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0);
     61     private static final String[] CIPHERS = TestUtils.getCommonCipherSuites();
     62     private static final byte[] MESSAGE_BYTES = "Hello".getBytes(TestUtils.UTF_8);
     63     private static final ByteBuffer MESSAGE_BUFFER =
     64             ByteBuffer.wrap(MESSAGE_BYTES).asReadOnlyBuffer();
     65     private static final int MESSAGE_LENGTH = MESSAGE_BYTES.length;
     66 
     67     public enum SocketType {
     68         FILE_DESCRIPTOR {
     69             @Override
     70             Client newClient(int port) {
     71                 return new Client(false, port);
     72             }
     73         },
     74         ENGINE {
     75             @Override
     76             Client newClient(int port) {
     77                 return new Client(true, port);
     78             }
     79         };
     80 
     81         abstract Client newClient(int port);
     82     }
     83 
     84     @Parameters(name = "{0}")
     85     public static Object[] data() {
     86         return new Object[] {SocketType.FILE_DESCRIPTOR, SocketType.ENGINE};
     87     }
     88 
     89     @Parameter
     90     public SocketType socketType;
     91 
     92     private Client client;
     93     private Server server;
     94 
     95     @Before
     96     public void setup() throws Exception {
     97         server = new Server();
     98         Future<?> connectedFuture = server.start();
     99 
    100         client = socketType.newClient(server.port());
    101         client.start();
    102 
    103         // Wait for the initial connection to complete.
    104         connectedFuture.get(5, TimeUnit.SECONDS);
    105     }
    106 
    107     @After
    108     public void teardown() {
    109         client.stop();
    110         server.stop();
    111     }
    112 
    113     @Test
    114     public void test() throws Exception {
    115         client.socket.startHandshake();
    116         String initialCipher = client.socket.getSession().getCipherSuite();
    117 
    118         client.sendMessage();
    119 
    120         Future<?> repliesFuture = client.readReplies();
    121         server.await(5, TimeUnit.SECONDS);
    122         repliesFuture.get(5, TimeUnit.SECONDS);
    123 
    124         // Verify that the cipher has changed.
    125         assertNotEquals(initialCipher, client.socket.getSession().getCipherSuite());
    126     }
    127 
    128     private static SSLContext newConscryptClientContext() {
    129         SSLContext context = TestUtils.newContext(TestUtils.getConscryptProvider());
    130         return TestUtils.initSslContext(context, TestKeyStore.getClient());
    131     }
    132 
    133     private static SSLContext newJdkServerContext() {
    134         SSLContext context = TestUtils.newContext(TestUtils.getJdkProvider());
    135         return TestUtils.initSslContext(context, TestKeyStore.getServer());
    136     }
    137 
    138     private static final class Client {
    139         private final SSLSocket socket;
    140         private ExecutorService executor;
    141         private volatile boolean stopping;
    142 
    143         Client(boolean useEngineSocket, int port) {
    144             try {
    145                 SSLSocketFactory socketFactory = newConscryptClientContext().getSocketFactory();
    146                 Conscrypt.setUseEngineSocket(socketFactory, useEngineSocket);
    147                 socket = (SSLSocket) socketFactory.createSocket(
    148                         TestUtils.getLoopbackAddress(), port);
    149                 socket.setEnabledCipherSuites(CIPHERS);
    150             } catch (IOException e) {
    151                 throw new RuntimeException(e);
    152             }
    153         }
    154 
    155         void start() {
    156             try {
    157                 executor = Executors.newSingleThreadExecutor();
    158                 socket.startHandshake();
    159             } catch (IOException e) {
    160                 e.printStackTrace();
    161                 throw new RuntimeException(e);
    162             }
    163         }
    164 
    165         void stop() {
    166             try {
    167                 stopping = true;
    168                 socket.close();
    169 
    170                 if (executor != null) {
    171                     executor.shutdown();
    172                     executor.awaitTermination(5, TimeUnit.SECONDS);
    173                     executor = null;
    174                 }
    175             } catch (RuntimeException e) {
    176                 throw e;
    177             } catch (Exception e) {
    178                 throw new RuntimeException(e);
    179             }
    180         }
    181 
    182         Future<?> readReplies() {
    183             return executor.submit(new Runnable() {
    184                 @Override
    185                 public void run() {
    186                     readReply();
    187                 }
    188             });
    189         }
    190 
    191         private void readReply() {
    192             try {
    193                 byte[] buffer = new byte[MESSAGE_LENGTH];
    194                 int totalBytesRead = 0;
    195                 while (totalBytesRead < MESSAGE_LENGTH) {
    196                     int remaining = MESSAGE_LENGTH - totalBytesRead;
    197                     int bytesRead = socket.getInputStream().read(buffer, totalBytesRead, remaining);
    198                     if (bytesRead == -1) {
    199                         throw new EOFException();
    200                     }
    201                     totalBytesRead += bytesRead;
    202                 }
    203 
    204                 // Verify the reply is correct.
    205                 assertEquals(MESSAGE_LENGTH, totalBytesRead);
    206                 assertArrayEquals(MESSAGE_BYTES, buffer);
    207             } catch (IOException e) {
    208                 throw new RuntimeException(e);
    209             }
    210         }
    211 
    212         void sendMessage() throws IOException {
    213             try {
    214                 socket.getOutputStream().write(MESSAGE_BYTES);
    215                 socket.getOutputStream().flush();
    216             } catch (IOException e) {
    217                 throw new RuntimeException(e);
    218             }
    219         }
    220     }
    221 
    222     private static final class Server {
    223         private final ServerSocketChannel serverChannel;
    224         private final SSLEngine engine;
    225         private final ByteBuffer inboundPacketBuffer;
    226         private final ByteBuffer inboundAppBuffer;
    227         private final ByteBuffer outboundPacketBuffer;
    228         private final Set<String> ciphers = new LinkedHashSet<String>(Arrays.asList(CIPHERS));
    229         private SocketChannel channel;
    230         private ExecutorService executor;
    231         private volatile boolean stopping;
    232         private volatile Future<?> echoFuture;
    233 
    234         Server() throws IOException {
    235             serverChannel = ServerSocketChannel.open();
    236             serverChannel.socket().bind(new InetSocketAddress(TestUtils.getLoopbackAddress(), 0));
    237             engine = newJdkServerContext().createSSLEngine();
    238             engine.setEnabledCipherSuites(CIPHERS);
    239             engine.setUseClientMode(false);
    240 
    241             inboundPacketBuffer =
    242                     ByteBuffer.allocateDirect(engine.getSession().getPacketBufferSize());
    243             inboundAppBuffer =
    244                     ByteBuffer.allocateDirect(engine.getSession().getApplicationBufferSize());
    245             outboundPacketBuffer =
    246                     ByteBuffer.allocateDirect(engine.getSession().getPacketBufferSize());
    247         }
    248 
    249         Future<?> start() throws IOException {
    250             executor = Executors.newSingleThreadExecutor();
    251             return executor.submit(new AcceptTask());
    252         }
    253 
    254         void await(long timeout, TimeUnit unit)
    255                 throws InterruptedException, ExecutionException, TimeoutException {
    256             echoFuture.get(timeout, unit);
    257         }
    258 
    259         void stop() {
    260             try {
    261                 stopping = true;
    262 
    263                 if (channel != null) {
    264                     channel.close();
    265                     channel = null;
    266                 }
    267 
    268                 serverChannel.close();
    269 
    270                 if (executor != null) {
    271                     executor.shutdown();
    272                     executor.awaitTermination(5, TimeUnit.SECONDS);
    273                     executor = null;
    274                 }
    275             } catch (IOException e) {
    276                 throw new RuntimeException(e);
    277             } catch (InterruptedException e) {
    278                 throw new RuntimeException(e);
    279             }
    280         }
    281 
    282         int port() {
    283             return serverChannel.socket().getLocalPort();
    284         }
    285 
    286         private final class AcceptTask implements Runnable {
    287             @Override
    288             public void run() {
    289                 try {
    290                     if (stopping) {
    291                         return;
    292                     }
    293                     channel = serverChannel.accept();
    294                     channel.configureBlocking(false);
    295 
    296                     doHandshake();
    297 
    298                     if (stopping) {
    299                         return;
    300                     }
    301                     echoFuture = executor.submit(new EchoTask());
    302                 } catch (Throwable e) {
    303                     e.printStackTrace();
    304                     throw new RuntimeException(e);
    305                 }
    306             }
    307         }
    308 
    309         private final class EchoTask implements Runnable {
    310             @Override
    311             public void run() {
    312                 try {
    313                     readMessage();
    314                     renegotiate();
    315                     reply();
    316                 } catch (Throwable e) {
    317                     e.printStackTrace();
    318                     throw new RuntimeException(e);
    319                 }
    320             }
    321 
    322             private void renegotiate() throws Exception {
    323                 // Remove the current cipher from the set and renegotiate to force a new
    324                 // cipher to be selected.
    325                 String currentCipher = engine.getSession().getCipherSuite();
    326                 ciphers.remove(currentCipher);
    327                 engine.setEnabledCipherSuites(ciphers.toArray(new String[ciphers.size()]));
    328                 doHandshake();
    329             }
    330 
    331             private void reply() throws IOException {
    332                 SSLEngineResult result = wrap(newMessage());
    333                 if (result.getStatus() != Status.OK) {
    334                     throw new RuntimeException("Wrap failed. Status: " + result.getStatus());
    335                 }
    336             }
    337 
    338             private ByteBuffer newMessage() {
    339                 return MESSAGE_BUFFER.duplicate();
    340             }
    341 
    342             private void readMessage() throws IOException {
    343                 int totalProduced = 0;
    344                 while (!stopping) {
    345                     SSLEngineResult result = unwrap();
    346                     if (result.getStatus() != Status.OK) {
    347                         throw new RuntimeException("Failed reading message: " + result);
    348                     }
    349                     totalProduced += result.bytesProduced();
    350                     if (totalProduced == MESSAGE_LENGTH) {
    351                         return;
    352                     }
    353                 }
    354             }
    355         }
    356 
    357         private SSLEngineResult wrap(ByteBuffer src) throws IOException {
    358             outboundPacketBuffer.clear();
    359 
    360             // Check if the engine has bytes to wrap.
    361             SSLEngineResult result = engine.wrap(src, outboundPacketBuffer);
    362 
    363             // Write any wrapped bytes to the socket.
    364             outboundPacketBuffer.flip();
    365 
    366             do {
    367                 channel.write(outboundPacketBuffer);
    368             } while (outboundPacketBuffer.hasRemaining());
    369 
    370             return result;
    371         }
    372 
    373         private SSLEngineResult unwrap() throws IOException {
    374             // Unwrap any available bytes from the socket.
    375             SSLEngineResult result = null;
    376             boolean done = false;
    377             while (!done) {
    378                 if (channel.read(inboundPacketBuffer) == -1) {
    379                     throw new EOFException();
    380                 }
    381                 // Just clear the app buffer - we don't really use it.
    382                 inboundAppBuffer.clear();
    383                 inboundPacketBuffer.flip();
    384                 result = engine.unwrap(inboundPacketBuffer, inboundAppBuffer);
    385                 switch (result.getStatus()) {
    386                     case BUFFER_UNDERFLOW:
    387                         // Continue reading from the socket in a moment.
    388                         try {
    389                             Thread.sleep(10);
    390                         } catch (InterruptedException e) {
    391                             throw new RuntimeException(e);
    392                         }
    393                         break;
    394                     case OK:
    395                         done = true;
    396                         break;
    397                     default: { throw new RuntimeException("Unexpected unwrap result: " + result); }
    398                 }
    399 
    400                 // Compact for the next socket read.
    401                 inboundPacketBuffer.compact();
    402             }
    403             return result;
    404         }
    405 
    406         private void doHandshake() throws IOException {
    407             engine.beginHandshake();
    408 
    409             boolean done = false;
    410             while (!done) {
    411                 switch (engine.getHandshakeStatus()) {
    412                     case NEED_WRAP: {
    413                         wrap(EMPTY_BUFFER);
    414                         break;
    415                     }
    416                     case NEED_UNWRAP: {
    417                         unwrap();
    418                         break;
    419                     }
    420                     case NEED_TASK: {
    421                         runDelegatedTasks();
    422                         break;
    423                     }
    424                     default: {
    425                         done = true;
    426                         break;
    427                     }
    428                 }
    429             }
    430         }
    431 
    432         private void runDelegatedTasks() {
    433             for (;;) {
    434                 Runnable task = engine.getDelegatedTask();
    435                 if (task == null) {
    436                     break;
    437                 }
    438                 task.run();
    439             }
    440         }
    441     }
    442 }
    443