1 /* 2 * Copyright 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package org.conscrypt; 18 19 import java.io.EOFException; 20 import java.io.IOException; 21 import java.io.InputStream; 22 import java.io.OutputStream; 23 import java.net.ServerSocket; 24 import java.net.SocketException; 25 import java.nio.channels.ClosedChannelException; 26 import java.util.concurrent.ExecutionException; 27 import java.util.concurrent.ExecutorService; 28 import java.util.concurrent.Executors; 29 import java.util.concurrent.Future; 30 import java.util.concurrent.TimeUnit; 31 import java.util.concurrent.TimeoutException; 32 import javax.net.ssl.SSLException; 33 import javax.net.ssl.SSLServerSocketFactory; 34 import javax.net.ssl.SSLSocket; 35 import javax.net.ssl.SSLSocketFactory; 36 37 /** 38 * A simple socket-based test server. 39 */ 40 final class ServerEndpoint { 41 /** 42 * A processor for receipt of a single message. 43 */ 44 public interface MessageProcessor { 45 void processMessage(byte[] message, int numBytes, OutputStream os); 46 } 47 48 /** 49 * A {@link MessageProcessor} that simply echos back the received message to the client. 50 */ 51 public static final class EchoProcessor implements MessageProcessor { 52 @Override 53 public void processMessage(byte[] message, int numBytes, OutputStream os) { 54 try { 55 os.write(message, 0, numBytes); 56 os.flush(); 57 } catch (IOException e) { 58 throw new RuntimeException(e); 59 } 60 } 61 } 62 63 private final ServerSocket serverSocket; 64 private final ChannelType channelType; 65 private final SSLSocketFactory socketFactory; 66 private final int messageSize; 67 private final String[] protocols; 68 private final String[] cipherSuites; 69 private final byte[] buffer; 70 private SSLSocket socket; 71 private ExecutorService executor; 72 private InputStream inputStream; 73 private OutputStream outputStream; 74 private volatile boolean stopping; 75 private volatile MessageProcessor messageProcessor = new EchoProcessor(); 76 private volatile Future<?> processFuture; 77 78 ServerEndpoint(SSLSocketFactory socketFactory, SSLServerSocketFactory serverSocketFactory, 79 ChannelType channelType, int messageSize, String[] protocols, 80 String[] cipherSuites) throws IOException { 81 this.serverSocket = channelType.newServerSocket(serverSocketFactory); 82 this.socketFactory = socketFactory; 83 this.channelType = channelType; 84 this.messageSize = messageSize; 85 this.protocols = protocols; 86 this.cipherSuites = cipherSuites; 87 buffer = new byte[messageSize]; 88 } 89 90 void setMessageProcessor(MessageProcessor messageProcessor) { 91 this.messageProcessor = messageProcessor; 92 } 93 94 Future<?> start() throws IOException { 95 executor = Executors.newSingleThreadExecutor(); 96 return executor.submit(new AcceptTask()); 97 } 98 99 void stop() { 100 try { 101 stopping = true; 102 103 if (socket != null) { 104 socket.close(); 105 socket = null; 106 } 107 108 if (processFuture != null) { 109 processFuture.get(5, TimeUnit.SECONDS); 110 } 111 112 serverSocket.close(); 113 114 if (executor != null) { 115 executor.shutdown(); 116 executor.awaitTermination(5, TimeUnit.SECONDS); 117 executor = null; 118 } 119 } catch (IOException | InterruptedException | ExecutionException | TimeoutException e) { 120 throw new RuntimeException(e); 121 } 122 } 123 124 public int port() { 125 return serverSocket.getLocalPort(); 126 } 127 128 private final class AcceptTask implements Runnable { 129 @Override 130 public void run() { 131 try { 132 if (stopping) { 133 return; 134 } 135 socket = channelType.accept(serverSocket, socketFactory); 136 socket.setEnabledProtocols(protocols); 137 socket.setEnabledCipherSuites(cipherSuites); 138 139 socket.startHandshake(); 140 141 inputStream = socket.getInputStream(); 142 outputStream = socket.getOutputStream(); 143 144 if (stopping) { 145 return; 146 } 147 processFuture = executor.submit(new ProcessTask()); 148 } catch (IOException e) { 149 e.printStackTrace(); 150 throw new RuntimeException(e); 151 } 152 } 153 } 154 155 private final class ProcessTask implements Runnable { 156 @Override 157 public void run() { 158 try { 159 Thread thread = Thread.currentThread(); 160 while (!stopping && !thread.isInterrupted()) { 161 int bytesRead = readMessage(); 162 if (!stopping && !thread.isInterrupted()) { 163 messageProcessor.processMessage(buffer, bytesRead, outputStream); 164 } 165 } 166 } catch (Throwable e) { 167 throw new RuntimeException(e); 168 } 169 } 170 171 private int readMessage() throws IOException { 172 int totalBytesRead = 0; 173 while (!stopping && totalBytesRead < messageSize) { 174 try { 175 int remaining = messageSize - totalBytesRead; 176 int bytesRead = inputStream.read(buffer, totalBytesRead, remaining); 177 if (bytesRead == -1) { 178 break; 179 } 180 totalBytesRead += bytesRead; 181 } catch (SSLException e) { 182 if (e.getCause() instanceof EOFException) { 183 break; 184 } 185 throw e; 186 } catch (ClosedChannelException e) { 187 // Thrown for channel-based sockets. Just treat like EOF. 188 break; 189 } catch (SocketException e) { 190 // The socket was broken. Just treat like EOF. 191 break; 192 } 193 } 194 return totalBytesRead; 195 } 196 } 197 } 198