1 package java.net; 2 3 import junit.framework.TestCase; 4 5 import java.io.IOException; 6 import java.util.Locale; 7 import java.util.concurrent.CountDownLatch; 8 import java.util.concurrent.TimeUnit; 9 10 /** 11 * Tests for race conditions between {@link ServerSocket#close()} and 12 * {@link ServerSocket#accept()}. 13 */ 14 public class ServerSocketConcurrentCloseTest extends TestCase { 15 private static final String TAG = ServerSocketConcurrentCloseTest.class.getSimpleName(); 16 17 /** 18 * The implementation of {@link ServerSocket#accept()} checks closed state before 19 * delegating to the {@link ServerSocket#implAccept(Socket)}, however this is not 20 * sufficient for correctness because the socket might be closed after the check. 21 * This checks that implAccept() itself also detects closed sockets and throws 22 * SocketException. 23 */ 24 public void testImplAccept_detectsClosedState() throws Exception { 25 /** A ServerSocket that exposes implAccept() */ 26 class ExposedServerSocket extends ServerSocket { 27 public ExposedServerSocket() throws IOException { 28 super(0 /* allocate port number automatically */); 29 } 30 31 public void implAcceptExposedForTest(Socket socket) throws IOException { 32 implAccept(socket); 33 } 34 } 35 final ExposedServerSocket serverSocket = new ExposedServerSocket(); 36 serverSocket.close(); 37 try { 38 // Hack: Need to subclass to access the protected constructor without reflection 39 Socket socket = new Socket((SocketImpl) null) { }; 40 serverSocket.implAcceptExposedForTest(socket); 41 fail("accepting on a closed socket should throw"); 42 } catch (SocketException expected) { 43 // expected 44 } catch (IOException e) { 45 throw new AssertionError(e); 46 } 47 } 48 49 /** 50 * Test for b/27763633. 51 */ 52 public void testConcurrentServerSocketCloseReliablyThrows() { 53 int numIterations = 200; 54 for (int i = 0; i < numIterations; i++) { 55 checkConnectIterationAndCloseSocket("Iteration " + (i+1) + " of " + numIterations, 56 /* msecPerIteration */ 50); 57 } 58 } 59 60 /** 61 * Checks that a concurrent {@link ServerSocket#close()} reliably causes 62 * {@link ServerSocket#accept()} to throw {@link SocketException}. 63 * 64 * <p>Spawns a server and client thread that continuously connect to each other 65 * for {@code msecPerIteration} msec. Then, closes the {@link ServerSocket} and 66 * verifies that the server quickly shuts down. 67 */ 68 private void checkConnectIterationAndCloseSocket(String iterationName, int msecPerIteration) { 69 ServerSocket serverSocket; 70 try { 71 serverSocket = new ServerSocket(0 /* allocate port number automatically */); 72 } catch (IOException e) { 73 fail("Abort: " + e); 74 throw new AssertionError("unreachable"); 75 } 76 final CountDownLatch shutdownLatch = new CountDownLatch(1); 77 ServerRunnable serverRunnable = new ServerRunnable(serverSocket, shutdownLatch); 78 Thread serverThread = new Thread(serverRunnable, TAG + " (server)"); 79 ClientRunnable clientRunnable = new ClientRunnable( 80 serverSocket.getLocalSocketAddress(), shutdownLatch); 81 Thread clientThread = new Thread(clientRunnable, TAG + " (client)"); 82 serverThread.start(); 83 clientThread.start(); 84 try { 85 if (shutdownLatch.getCount() == 0) { 86 fail("Server prematurely shut down"); 87 } 88 // Let server and client keep connecting for some time, then close the socket. 89 Thread.sleep(msecPerIteration); 90 try { 91 serverSocket.close(); 92 } catch (IOException e) { 93 throw new AssertionError("serverSocket.close() failed: ", e); 94 } 95 // Check that the server shut down quickly in response to the socket closing. 96 long hardLimitSeconds = 5; 97 boolean serverShutdownReached = shutdownLatch.await(hardLimitSeconds, TimeUnit.SECONDS); 98 if (!serverShutdownReached) { // b/27763633 99 shutdownLatch.countDown(); 100 String serverStackTrace = stackTraceAsString(serverThread.getStackTrace()); 101 fail("Server took > " + hardLimitSeconds + "sec to react to serverSocket.close(). " 102 + "Server thread's stackTrace: " + serverStackTrace); 103 } 104 // Guard against the test passing as a false positive if no connections were actually 105 // established. If the test was running for much longer then this would fail during 106 // later iterations because TCP connections cannot be closed immediately (they stay 107 // in TIME_WAIT state for a few minutes) and only some number (tens of thousands?) 108 // can be open at a time. If this assertion turns out flaky in future, consider 109 // reducing msecPerIteration or number of iterations. 110 assertTrue(String.format(Locale.US, "%s: No connections in %d msec.", 111 iterationName, msecPerIteration), 112 serverRunnable.numSuccessfulConnections > 0); 113 114 assertEquals(0, shutdownLatch.getCount()); 115 // Sanity check to ensure the threads don't live into the next iteration. This should 116 // be quick because we only get here if shutdownLatch reached 0 within the time limit. 117 serverThread.join(); 118 clientThread.join(); 119 } catch (InterruptedException e) { 120 throw new AssertionError("Unexpected interruption", e); 121 } 122 } 123 124 /** 125 * Repeatedly tries to connect to and disconnect from a SocketAddress until 126 * it observes {@code shutdownLatch} reaching 0. Does not read/write any 127 * data from/to the socket. 128 */ 129 static class ClientRunnable implements Runnable { 130 private final SocketAddress socketAddress; 131 private final CountDownLatch shutdownLatch; 132 133 public ClientRunnable( 134 SocketAddress socketAddress, CountDownLatch shutdownLatch) { 135 this.socketAddress = socketAddress; 136 this.shutdownLatch = shutdownLatch; 137 } 138 139 @Override 140 public void run() { 141 while (shutdownLatch.getCount() != 0) { // check if server is shutting down 142 try { 143 Socket socket = new Socket(); 144 socket.connect(socketAddress, /* timeout (msec) */ 10); 145 socket.close(); 146 } catch (IOException e) { 147 // harmless, as long as enough connections are successful 148 } 149 } 150 } 151 } 152 153 /** 154 * Repeatedly accepts connections from a ServerSocket and immediately closes them. 155 * When it encounters a SocketException, it counts down the CountDownLatch and exits. 156 */ 157 static class ServerRunnable implements Runnable { 158 private final ServerSocket serverSocket; 159 volatile int numSuccessfulConnections; 160 private final CountDownLatch shutdownLatch; 161 162 ServerRunnable(ServerSocket serverSocket, CountDownLatch shutdownLatch) { 163 this.serverSocket = serverSocket; 164 this.shutdownLatch = shutdownLatch; 165 } 166 167 @Override 168 public void run() { 169 int numSuccessfulConnections = 0; 170 while (true) { 171 try { 172 Socket socket = serverSocket.accept(); 173 numSuccessfulConnections++; 174 socket.close(); 175 } catch (SocketException e) { 176 this.numSuccessfulConnections = numSuccessfulConnections; 177 shutdownLatch.countDown(); 178 return; 179 } catch (IOException e) { 180 // harmless, as long as enough connections are successful 181 } 182 } 183 } 184 } 185 186 private static String stackTraceAsString(StackTraceElement[] stackTraceElements) { 187 StringBuilder sb = new StringBuilder(); 188 for (StackTraceElement stackTraceElement : stackTraceElements) { 189 sb.append("\n\t at ").append(stackTraceElement); 190 } 191 return sb.toString(); 192 } 193 194 } 195