Home | History | Annotate | Download | only in net
      1 package java.net;
      2 
      3 import junit.framework.TestCase;
      4 
      5 import java.io.IOException;
      6 import java.util.Locale;
      7 import java.util.concurrent.CountDownLatch;
      8 import java.util.concurrent.TimeUnit;
      9 
     10 /**
     11  * Tests for race conditions between {@link ServerSocket#close()} and
     12  * {@link ServerSocket#accept()}.
     13  */
     14 public class ServerSocketConcurrentCloseTest extends TestCase {
     15     private static final String TAG = ServerSocketConcurrentCloseTest.class.getSimpleName();
     16 
     17     /**
     18      * The implementation of {@link ServerSocket#accept()} checks closed state before
     19      * delegating to the {@link ServerSocket#implAccept(Socket)}, however this is not
     20      * sufficient for correctness because the socket might be closed after the check.
     21      * This checks that implAccept() itself also detects closed sockets and throws
     22      * SocketException.
     23      */
     24     public void testImplAccept_detectsClosedState() throws Exception {
     25         /** A ServerSocket that exposes implAccept() */
     26         class ExposedServerSocket extends ServerSocket {
     27             public ExposedServerSocket() throws IOException {
     28                 super(0 /* allocate port number automatically */);
     29             }
     30 
     31             public void implAcceptExposedForTest(Socket socket) throws IOException {
     32                 implAccept(socket);
     33             }
     34         }
     35         final ExposedServerSocket serverSocket = new ExposedServerSocket();
     36         serverSocket.close();
     37         try {
     38             // Hack: Need to subclass to access the protected constructor without reflection
     39             Socket socket = new Socket((SocketImpl) null) { };
     40             serverSocket.implAcceptExposedForTest(socket);
     41             fail("accepting on a closed socket should throw");
     42         } catch (SocketException expected) {
     43             // expected
     44         } catch (IOException e) {
     45             throw new AssertionError(e);
     46         }
     47     }
     48 
     49     /**
     50      * Test for b/27763633.
     51      */
     52     public void testConcurrentServerSocketCloseReliablyThrows() {
     53         int numIterations = 200;
     54         for (int i = 0; i < numIterations; i++) {
     55             checkConnectIterationAndCloseSocket("Iteration " + (i+1) + " of " + numIterations,
     56                     /* msecPerIteration */ 50);
     57         }
     58     }
     59 
     60     /**
     61      * Checks that a concurrent {@link ServerSocket#close()} reliably causes
     62      * {@link ServerSocket#accept()} to throw {@link SocketException}.
     63      *
     64      * <p>Spawns a server and client thread that continuously connect to each other
     65      * for {@code msecPerIteration} msec. Then, closes the {@link ServerSocket} and
     66      * verifies that the server quickly shuts down.
     67      */
     68     private void checkConnectIterationAndCloseSocket(String iterationName, int msecPerIteration) {
     69         ServerSocket serverSocket;
     70         try {
     71             serverSocket = new ServerSocket(0 /* allocate port number automatically */);
     72         } catch (IOException e) {
     73             fail("Abort: " + e);
     74             throw new AssertionError("unreachable");
     75         }
     76         final CountDownLatch shutdownLatch = new CountDownLatch(1);
     77         ServerRunnable serverRunnable = new ServerRunnable(serverSocket, shutdownLatch);
     78         Thread serverThread = new Thread(serverRunnable, TAG + " (server)");
     79         ClientRunnable clientRunnable = new ClientRunnable(
     80                 serverSocket.getLocalSocketAddress(), shutdownLatch);
     81         Thread clientThread = new Thread(clientRunnable, TAG + " (client)");
     82         serverThread.start();
     83         clientThread.start();
     84         try {
     85             if (shutdownLatch.getCount() == 0) {
     86                 fail("Server prematurely shut down");
     87             }
     88             // Let server and client keep connecting for some time, then close the socket.
     89             Thread.sleep(msecPerIteration);
     90             try {
     91                 serverSocket.close();
     92             } catch (IOException e) {
     93                 throw new AssertionError("serverSocket.close() failed: ", e);
     94             }
     95             // Check that the server shut down quickly in response to the socket closing.
     96             long hardLimitSeconds = 5;
     97             boolean serverShutdownReached = shutdownLatch.await(hardLimitSeconds, TimeUnit.SECONDS);
     98             if (!serverShutdownReached) { // b/27763633
     99                 shutdownLatch.countDown();
    100                 String serverStackTrace = stackTraceAsString(serverThread.getStackTrace());
    101                 fail("Server took > " + hardLimitSeconds + "sec to react to serverSocket.close(). "
    102                         + "Server thread's stackTrace: " + serverStackTrace);
    103             }
    104             // Guard against the test passing as a false positive if no connections were actually
    105             // established. If the test was running for much longer then this would fail during
    106             // later iterations because TCP connections cannot be closed immediately (they stay
    107             // in TIME_WAIT state for a few minutes) and only some number (tens of thousands?)
    108             // can be open at a time. If this assertion turns out flaky in future, consider
    109             // reducing msecPerIteration or number of iterations.
    110             assertTrue(String.format(Locale.US, "%s: No connections in %d msec.",
    111                     iterationName, msecPerIteration),
    112                     serverRunnable.numSuccessfulConnections > 0);
    113 
    114             assertEquals(0, shutdownLatch.getCount());
    115             // Sanity check to ensure the threads don't live into the next iteration. This should
    116             // be quick because we only get here if shutdownLatch reached 0 within the time limit.
    117             serverThread.join();
    118             clientThread.join();
    119         } catch (InterruptedException e) {
    120             throw new AssertionError("Unexpected interruption", e);
    121         }
    122     }
    123 
    124     /**
    125      * Repeatedly tries to connect to and disconnect from a SocketAddress until
    126      * it observes {@code shutdownLatch} reaching 0. Does not read/write any
    127      * data from/to the socket.
    128      */
    129     static class ClientRunnable implements Runnable {
    130         private final SocketAddress socketAddress;
    131         private final CountDownLatch shutdownLatch;
    132 
    133         public ClientRunnable(
    134                 SocketAddress socketAddress, CountDownLatch shutdownLatch) {
    135             this.socketAddress = socketAddress;
    136             this.shutdownLatch = shutdownLatch;
    137         }
    138 
    139         @Override
    140         public void run() {
    141             while (shutdownLatch.getCount() != 0) { // check if server is shutting down
    142                 try {
    143                     Socket socket = new Socket();
    144                     socket.connect(socketAddress, /* timeout (msec) */ 10);
    145                     socket.close();
    146                 } catch (IOException e) {
    147                     // harmless, as long as enough connections are successful
    148                 }
    149             }
    150         }
    151     }
    152 
    153     /**
    154      * Repeatedly accepts connections from a ServerSocket and immediately closes them.
    155      * When it encounters a SocketException, it counts down the CountDownLatch and exits.
    156      */
    157     static class ServerRunnable implements Runnable {
    158         private final ServerSocket serverSocket;
    159         volatile int numSuccessfulConnections;
    160         private final CountDownLatch shutdownLatch;
    161 
    162         ServerRunnable(ServerSocket serverSocket, CountDownLatch shutdownLatch) {
    163             this.serverSocket = serverSocket;
    164             this.shutdownLatch = shutdownLatch;
    165         }
    166 
    167         @Override
    168         public void run() {
    169             int numSuccessfulConnections = 0;
    170             while (true) {
    171                 try {
    172                     Socket socket = serverSocket.accept();
    173                     numSuccessfulConnections++;
    174                     socket.close();
    175                 } catch (SocketException e) {
    176                     this.numSuccessfulConnections = numSuccessfulConnections;
    177                     shutdownLatch.countDown();
    178                     return;
    179                 } catch (IOException e) {
    180                     // harmless, as long as enough connections are successful
    181                 }
    182             }
    183         }
    184     }
    185 
    186     private static String stackTraceAsString(StackTraceElement[] stackTraceElements) {
    187         StringBuilder sb = new StringBuilder();
    188         for (StackTraceElement stackTraceElement : stackTraceElements) {
    189             sb.append("\n\t at ").append(stackTraceElement);
    190         }
    191         return sb.toString();
    192     }
    193 
    194 }
    195