Home | History | Annotate | Download | only in conscrypt
      1 /*
      2  * Copyright 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package org.conscrypt;
     18 
     19 import java.io.EOFException;
     20 import java.io.IOException;
     21 import java.io.InputStream;
     22 import java.io.OutputStream;
     23 import java.net.ServerSocket;
     24 import java.net.SocketException;
     25 import java.nio.channels.ClosedChannelException;
     26 import java.util.concurrent.ExecutionException;
     27 import java.util.concurrent.ExecutorService;
     28 import java.util.concurrent.Executors;
     29 import java.util.concurrent.Future;
     30 import java.util.concurrent.TimeUnit;
     31 import java.util.concurrent.TimeoutException;
     32 import javax.net.ssl.SSLException;
     33 import javax.net.ssl.SSLServerSocketFactory;
     34 import javax.net.ssl.SSLSocket;
     35 import javax.net.ssl.SSLSocketFactory;
     36 
     37 /**
     38  * A simple socket-based test server.
     39  */
     40 final class ServerEndpoint {
     41     /**
     42      * A processor for receipt of a single message.
     43      */
     44     public interface MessageProcessor {
     45         void processMessage(byte[] message, int numBytes, OutputStream os);
     46     }
     47 
     48     /**
     49      * A {@link MessageProcessor} that simply echos back the received message to the client.
     50      */
     51     public static final class EchoProcessor implements MessageProcessor {
     52         @Override
     53         public void processMessage(byte[] message, int numBytes, OutputStream os) {
     54             try {
     55                 os.write(message, 0, numBytes);
     56                 os.flush();
     57             } catch (IOException e) {
     58                 throw new RuntimeException(e);
     59             }
     60         }
     61     }
     62 
     63     private final ServerSocket serverSocket;
     64     private final ChannelType channelType;
     65     private final SSLSocketFactory socketFactory;
     66     private final int messageSize;
     67     private final String[] protocols;
     68     private final String[] cipherSuites;
     69     private final byte[] buffer;
     70     private SSLSocket socket;
     71     private ExecutorService executor;
     72     private InputStream inputStream;
     73     private OutputStream outputStream;
     74     private volatile boolean stopping;
     75     private volatile MessageProcessor messageProcessor = new EchoProcessor();
     76     private volatile Future<?> processFuture;
     77 
     78     ServerEndpoint(SSLSocketFactory socketFactory, SSLServerSocketFactory serverSocketFactory,
     79             ChannelType channelType, int messageSize, String[] protocols,
     80             String[] cipherSuites) throws IOException {
     81         this.serverSocket = channelType.newServerSocket(serverSocketFactory);
     82         this.socketFactory = socketFactory;
     83         this.channelType = channelType;
     84         this.messageSize = messageSize;
     85         this.protocols = protocols;
     86         this.cipherSuites = cipherSuites;
     87         buffer = new byte[messageSize];
     88     }
     89 
     90     void setMessageProcessor(MessageProcessor messageProcessor) {
     91         this.messageProcessor = messageProcessor;
     92     }
     93 
     94     Future<?> start() throws IOException {
     95         executor = Executors.newSingleThreadExecutor();
     96         return executor.submit(new AcceptTask());
     97     }
     98 
     99     void stop() {
    100         try {
    101             stopping = true;
    102 
    103             if (socket != null) {
    104                 socket.close();
    105                 socket = null;
    106             }
    107 
    108             if (processFuture != null) {
    109                 processFuture.get(5, TimeUnit.SECONDS);
    110             }
    111 
    112             serverSocket.close();
    113 
    114             if (executor != null) {
    115                 executor.shutdown();
    116                 executor.awaitTermination(5, TimeUnit.SECONDS);
    117                 executor = null;
    118             }
    119         } catch (IOException | InterruptedException | ExecutionException | TimeoutException e) {
    120             throw new RuntimeException(e);
    121         }
    122     }
    123 
    124     public int port() {
    125         return serverSocket.getLocalPort();
    126     }
    127 
    128     private final class AcceptTask implements Runnable {
    129         @Override
    130         public void run() {
    131             try {
    132                 if (stopping) {
    133                     return;
    134                 }
    135                 socket = channelType.accept(serverSocket, socketFactory);
    136                 socket.setEnabledProtocols(protocols);
    137                 socket.setEnabledCipherSuites(cipherSuites);
    138 
    139                 socket.startHandshake();
    140 
    141                 inputStream = socket.getInputStream();
    142                 outputStream = socket.getOutputStream();
    143 
    144                 if (stopping) {
    145                     return;
    146                 }
    147                 processFuture = executor.submit(new ProcessTask());
    148             } catch (IOException e) {
    149                 e.printStackTrace();
    150                 throw new RuntimeException(e);
    151             }
    152         }
    153     }
    154 
    155     private final class ProcessTask implements Runnable {
    156         @Override
    157         public void run() {
    158             try {
    159                 Thread thread = Thread.currentThread();
    160                 while (!stopping && !thread.isInterrupted()) {
    161                     int bytesRead = readMessage();
    162                     if (!stopping && !thread.isInterrupted()) {
    163                         messageProcessor.processMessage(buffer, bytesRead, outputStream);
    164                     }
    165                 }
    166             } catch (Throwable e) {
    167                 throw new RuntimeException(e);
    168             }
    169         }
    170 
    171         private int readMessage() throws IOException {
    172             int totalBytesRead = 0;
    173             while (!stopping && totalBytesRead < messageSize) {
    174                 try {
    175                     int remaining = messageSize - totalBytesRead;
    176                     int bytesRead = inputStream.read(buffer, totalBytesRead, remaining);
    177                     if (bytesRead == -1) {
    178                         break;
    179                     }
    180                     totalBytesRead += bytesRead;
    181                 } catch (SSLException e) {
    182                     if (e.getCause() instanceof EOFException) {
    183                         break;
    184                     }
    185                     throw e;
    186                 } catch (ClosedChannelException e) {
    187                     // Thrown for channel-based sockets. Just treat like EOF.
    188                     break;
    189                 } catch (SocketException e) {
    190                     // The socket was broken. Just treat like EOF.
    191                     break;
    192                 }
    193             }
    194             return totalBytesRead;
    195         }
    196     }
    197 }
    198