1 package libcore.java.net; 2 3 import junit.framework.TestCase; 4 5 import java.io.IOException; 6 import java.net.Socket; 7 import java.net.SocketImpl; 8 import java.net.SocketException; 9 import java.net.SocketAddress; 10 import java.net.ServerSocket; 11 import java.util.BitSet; 12 import java.util.Locale; 13 import java.util.Set; 14 import java.util.concurrent.CountDownLatch; 15 import java.util.concurrent.TimeUnit; 16 import java.util.concurrent.atomic.AtomicInteger; 17 import java.util.concurrent.atomic.AtomicReference; 18 19 /** 20 * Tests for race conditions between {@link ServerSocket#close()} and 21 * {@link ServerSocket#accept()}. 22 */ 23 public class ServerSocketConcurrentCloseTest extends TestCase { 24 private static final String TAG = ServerSocketConcurrentCloseTest.class.getSimpleName(); 25 26 /** 27 * The implementation of {@link ServerSocket#accept()} checks closed state before 28 * delegating to the {@link ServerSocket#implAccept(Socket)}, however this is not 29 * sufficient for correctness because the socket might be closed after the check. 30 * This checks that implAccept() itself also detects closed sockets and throws 31 * SocketException. 32 */ 33 public void testImplAccept_detectsClosedState() throws Exception { 34 /** A ServerSocket that exposes implAccept() */ 35 class ExposedServerSocket extends ServerSocket { 36 public ExposedServerSocket() throws IOException { 37 super(0 /* allocate port number automatically */); 38 } 39 40 public void implAcceptExposedForTest(Socket socket) throws IOException { 41 implAccept(socket); 42 } 43 } 44 final ExposedServerSocket serverSocket = new ExposedServerSocket(); 45 serverSocket.close(); 46 // implAccept() on background thread to prevent this test hanging 47 final AtomicReference<Exception> failure = new AtomicReference<>(); 48 final CountDownLatch threadFinishedLatch = new CountDownLatch(1); 49 Thread thread = new Thread("implAccept() closed ServerSocket") { 50 public void run() { 51 try { 52 // Hack: Need to subclass to access the protected constructor without reflection 53 Socket socket = new Socket((SocketImpl) null) { }; 54 serverSocket.implAcceptExposedForTest(socket); 55 } catch (SocketException expected) { 56 // pass 57 } catch (IOException|RuntimeException e) { 58 failure.set(e); 59 } finally { 60 threadFinishedLatch.countDown(); 61 } 62 } 63 }; 64 thread.start(); 65 66 boolean completed = threadFinishedLatch.await(5, TimeUnit.SECONDS); 67 assertTrue("implAccept didn't throw or return within time limit", completed); 68 Exception e = failure.get(); 69 if (e != null) { 70 throw new AssertionError("Unexpected exception", e); 71 } 72 thread.join(); 73 } 74 75 /** 76 * Test for b/27763633. 77 */ 78 public void testConcurrentServerSocketCloseReliablyThrows() { 79 int numIterations = 100; 80 int minNumIterationsWithConnections = 5; 81 int msecPerIteration = 50; 82 BitSet iterationsWithConnections = new BitSet(numIterations); 83 for (int i = 0; i < numIterations; i++) { 84 int numConnectionsMade = checkConnectIterationAndCloseSocket( 85 "Iteration " + (i+1) + " of " + numIterations, msecPerIteration); 86 if (numConnectionsMade > 0) { 87 iterationsWithConnections.set(i); 88 } 89 } 90 91 // Guard against the test passing as a false positive if no connections were actually 92 // established. If the test was running for much longer then this would fail during 93 // later iterations because TCP connections cannot be closed immediately (they stay 94 // in TIME_WAIT state for a few minutes) and only some number (tens of thousands?) 95 // can be open at a time. If this assertion turns out flaky in future, consider 96 // reducing msecPerIteration or numIterations. 97 int numIterationsWithConnections = iterationsWithConnections.cardinality(); 98 String msg = String.format(Locale.US, 99 "Connections only made on these %d/%d iterations of %d msec: %s", 100 numIterationsWithConnections, numIterations, msecPerIteration, 101 iterationsWithConnections); 102 assertTrue(msg, numIterationsWithConnections >= minNumIterationsWithConnections); 103 } 104 105 /** 106 * Checks that a concurrent {@link ServerSocket#close()} reliably causes 107 * {@link ServerSocket#accept()} to throw {@link SocketException}. 108 * 109 * <p>Spawns a server and client thread that continuously connect to each 110 * other for up to {@code maxSleepsPerIteration * sleepMsec} msec. 111 * Then, closes the {@link ServerSocket} and verifies that the server 112 * quickly shuts down. 113 * 114 * @return number of connections made between server and client threads 115 */ 116 private int checkConnectIterationAndCloseSocket(String iterationName, 117 int msecPerIteration) { 118 ServerSocket serverSocket; 119 try { 120 serverSocket = new ServerSocket(0 /* allocate port number automatically */); 121 } catch (IOException e) { 122 fail("Abort: " + e); 123 throw new AssertionError("unreachable"); 124 } 125 ServerRunnable serverRunnable = new ServerRunnable(serverSocket); 126 Thread serverThread = new Thread(serverRunnable, TAG + " (server)"); 127 ClientRunnable clientRunnable = new ClientRunnable( 128 serverSocket.getLocalSocketAddress(), serverRunnable); 129 Thread clientThread = new Thread(clientRunnable, TAG + " (client)"); 130 serverThread.start(); 131 clientThread.start(); 132 try { 133 assertTrue("Slow server startup", serverRunnable.awaitStart(1, TimeUnit.SECONDS)); 134 assertTrue("Slow client startup", clientRunnable.awaitStart(1, TimeUnit.SECONDS)); 135 if (serverRunnable.isShutdown()) { 136 fail("Server prematurely shut down"); 137 } 138 // Let server and client keep connecting for some time, then close the socket. 139 Thread.sleep(msecPerIteration); 140 try { 141 serverSocket.close(); 142 } catch (IOException e) { 143 throw new AssertionError("serverSocket.close() failed: ", e); 144 } 145 // Check that the server shut down quickly in response to the socket closing. 146 long hardLimitSeconds = 5; 147 boolean serverShutdownReached = serverRunnable.awaitShutdown(hardLimitSeconds, TimeUnit.SECONDS); 148 if (!serverShutdownReached) { // b/27763633 149 String serverStackTrace = stackTraceAsString(serverThread.getStackTrace()); 150 fail("Server took > " + hardLimitSeconds + "sec to react to serverSocket.close(). " 151 + "Server thread's stackTrace: " + serverStackTrace); 152 } 153 assertTrue(serverRunnable.isShutdown()); 154 // Sanity check to ensure the threads don't live into the next iteration. This should 155 // be quick because we only get here if shutdownLatch reached 0 within the time limit. 156 serverThread.join(); 157 clientThread.join(); 158 return serverRunnable.numSuccessfulConnections.get(); 159 } catch (InterruptedException e) { 160 throw new AssertionError("Unexpected interruption", e); 161 } 162 } 163 164 /** 165 * Repeatedly tries to connect to and disconnect from a SocketAddress until 166 * it observes {@code shutdownLatch} reaching 0. Does not read/write any 167 * data from/to the socket. 168 */ 169 static class ClientRunnable implements Runnable { 170 private final SocketAddress socketAddress; 171 172 private final ServerRunnable serverRunnable; 173 private final CountDownLatch startLatch = new CountDownLatch(1); 174 175 public ClientRunnable( 176 SocketAddress socketAddress, ServerRunnable serverRunnable) { 177 this.socketAddress = socketAddress; 178 this.serverRunnable = serverRunnable; 179 } 180 181 @Override 182 public void run() { 183 startLatch.countDown(); 184 while (!serverRunnable.isShutdown()) { 185 try { 186 Socket socket = new Socket(); 187 socket.connect(socketAddress, /* timeout (msec) */ 10); 188 socket.close(); 189 } catch (IOException e) { 190 // harmless, as long as enough connections are successful 191 } 192 } 193 } 194 195 public boolean awaitStart(long timeout, TimeUnit timeUnit) throws InterruptedException { 196 return startLatch.await(timeout, timeUnit); 197 } 198 199 } 200 201 /** 202 * Repeatedly accepts connections from a ServerSocket and immediately closes them. 203 * When it encounters a SocketException, it counts down the CountDownLatch and exits. 204 */ 205 static class ServerRunnable implements Runnable { 206 private final ServerSocket serverSocket; 207 final AtomicInteger numSuccessfulConnections = new AtomicInteger(); 208 private final CountDownLatch startLatch = new CountDownLatch(1); 209 private final CountDownLatch shutdownLatch = new CountDownLatch(1); 210 211 ServerRunnable(ServerSocket serverSocket) { 212 this.serverSocket = serverSocket; 213 } 214 215 @Override 216 public void run() { 217 startLatch.countDown(); 218 while (true) { 219 try { 220 Socket socket = serverSocket.accept(); 221 numSuccessfulConnections.incrementAndGet(); 222 socket.close(); 223 } catch (SocketException e) { 224 shutdownLatch.countDown(); 225 return; 226 } catch (IOException e) { 227 // harmless, as long as enough connections are successful 228 } 229 } 230 } 231 232 public boolean awaitStart(long timeout, TimeUnit timeUnit) throws InterruptedException { 233 return startLatch.await(timeout, timeUnit); 234 } 235 236 public boolean awaitShutdown(long timeout, TimeUnit timeUnit) throws InterruptedException { 237 return shutdownLatch.await(timeout, timeUnit); 238 } 239 240 public boolean isShutdown() { 241 return shutdownLatch.getCount() == 0; 242 } 243 } 244 245 private static String stackTraceAsString(StackTraceElement[] stackTraceElements) { 246 StringBuilder sb = new StringBuilder(); 247 for (StackTraceElement stackTraceElement : stackTraceElements) { 248 sb.append("\n\t at ").append(stackTraceElement); 249 } 250 return sb.toString(); 251 } 252 253 } 254