Home | History | Annotate | Download | only in net
      1 package libcore.java.net;
      2 
      3 import junit.framework.TestCase;
      4 
      5 import java.io.IOException;
      6 import java.net.Socket;
      7 import java.net.SocketImpl;
      8 import java.net.SocketException;
      9 import java.net.SocketAddress;
     10 import java.net.ServerSocket;
     11 import java.util.BitSet;
     12 import java.util.Locale;
     13 import java.util.Set;
     14 import java.util.concurrent.CountDownLatch;
     15 import java.util.concurrent.TimeUnit;
     16 import java.util.concurrent.atomic.AtomicInteger;
     17 import java.util.concurrent.atomic.AtomicReference;
     18 
     19 /**
     20  * Tests for race conditions between {@link ServerSocket#close()} and
     21  * {@link ServerSocket#accept()}.
     22  */
     23 public class ServerSocketConcurrentCloseTest extends TestCase {
     24     private static final String TAG = ServerSocketConcurrentCloseTest.class.getSimpleName();
     25 
     26     /**
     27      * The implementation of {@link ServerSocket#accept()} checks closed state before
     28      * delegating to the {@link ServerSocket#implAccept(Socket)}, however this is not
     29      * sufficient for correctness because the socket might be closed after the check.
     30      * This checks that implAccept() itself also detects closed sockets and throws
     31      * SocketException.
     32      */
     33     public void testImplAccept_detectsClosedState() throws Exception {
     34         /** A ServerSocket that exposes implAccept() */
     35         class ExposedServerSocket extends ServerSocket {
     36             public ExposedServerSocket() throws IOException {
     37                 super(0 /* allocate port number automatically */);
     38             }
     39 
     40             public void implAcceptExposedForTest(Socket socket) throws IOException {
     41                 implAccept(socket);
     42             }
     43         }
     44         final ExposedServerSocket serverSocket = new ExposedServerSocket();
     45         serverSocket.close();
     46         // implAccept() on background thread to prevent this test hanging
     47         final AtomicReference<Exception> failure = new AtomicReference<>();
     48         final CountDownLatch threadFinishedLatch = new CountDownLatch(1);
     49         Thread thread = new Thread("implAccept() closed ServerSocket") {
     50             public void run() {
     51                 try {
     52                     // Hack: Need to subclass to access the protected constructor without reflection
     53                     Socket socket = new Socket((SocketImpl) null) { };
     54                     serverSocket.implAcceptExposedForTest(socket);
     55                 } catch (SocketException expected) {
     56                     // pass
     57                 } catch (IOException|RuntimeException e) {
     58                     failure.set(e);
     59                 } finally {
     60                     threadFinishedLatch.countDown();
     61                 }
     62             }
     63         };
     64         thread.start();
     65 
     66         boolean completed = threadFinishedLatch.await(5, TimeUnit.SECONDS);
     67         assertTrue("implAccept didn't throw or return within time limit", completed);
     68         Exception e = failure.get();
     69         if (e != null) {
     70             throw new AssertionError("Unexpected exception", e);
     71         }
     72         thread.join();
     73     }
     74 
     75     /**
     76      * Test for b/27763633.
     77      */
     78     public void testConcurrentServerSocketCloseReliablyThrows() {
     79         int numIterations = 100;
     80         int minNumIterationsWithConnections = 5;
     81         int msecPerIteration = 50;
     82         BitSet iterationsWithConnections = new BitSet(numIterations);
     83         for (int i = 0; i < numIterations; i++) {
     84             int numConnectionsMade = checkConnectIterationAndCloseSocket(
     85                     "Iteration " + (i+1) + " of " + numIterations, msecPerIteration);
     86             if (numConnectionsMade > 0) {
     87                 iterationsWithConnections.set(i);
     88             }
     89         }
     90 
     91         // Guard against the test passing as a false positive if no connections were actually
     92         // established. If the test was running for much longer then this would fail during
     93         // later iterations because TCP connections cannot be closed immediately (they stay
     94         // in TIME_WAIT state for a few minutes) and only some number (tens of thousands?)
     95         // can be open at a time. If this assertion turns out flaky in future, consider
     96         // reducing msecPerIteration or numIterations.
     97         int numIterationsWithConnections = iterationsWithConnections.cardinality();
     98         String msg = String.format(Locale.US,
     99                 "Connections only made on these %d/%d iterations of %d msec: %s",
    100                 numIterationsWithConnections, numIterations, msecPerIteration,
    101                 iterationsWithConnections);
    102         assertTrue(msg, numIterationsWithConnections >= minNumIterationsWithConnections);
    103     }
    104 
    105     /**
    106      * Checks that a concurrent {@link ServerSocket#close()} reliably causes
    107      * {@link ServerSocket#accept()} to throw {@link SocketException}.
    108      *
    109      * <p>Spawns a server and client thread that continuously connect to each
    110      * other for up to {@code maxSleepsPerIteration * sleepMsec} msec.
    111      * Then, closes the {@link ServerSocket} and verifies that the server
    112      * quickly shuts down.
    113      *
    114      * @return number of connections made between server and client threads
    115      */
    116     private int checkConnectIterationAndCloseSocket(String iterationName,
    117             int msecPerIteration) {
    118         ServerSocket serverSocket;
    119         try {
    120             serverSocket = new ServerSocket(0 /* allocate port number automatically */);
    121         } catch (IOException e) {
    122             fail("Abort: " + e);
    123             throw new AssertionError("unreachable");
    124         }
    125         ServerRunnable serverRunnable = new ServerRunnable(serverSocket);
    126         Thread serverThread = new Thread(serverRunnable, TAG + " (server)");
    127         ClientRunnable clientRunnable = new ClientRunnable(
    128                 serverSocket.getLocalSocketAddress(), serverRunnable);
    129         Thread clientThread = new Thread(clientRunnable, TAG + " (client)");
    130         serverThread.start();
    131         clientThread.start();
    132         try {
    133             assertTrue("Slow server startup", serverRunnable.awaitStart(1, TimeUnit.SECONDS));
    134             assertTrue("Slow client startup", clientRunnable.awaitStart(1, TimeUnit.SECONDS));
    135             if (serverRunnable.isShutdown()) {
    136                 fail("Server prematurely shut down");
    137             }
    138             // Let server and client keep connecting for some time, then close the socket.
    139             Thread.sleep(msecPerIteration);
    140             try {
    141                 serverSocket.close();
    142             } catch (IOException e) {
    143                 throw new AssertionError("serverSocket.close() failed: ", e);
    144             }
    145             // Check that the server shut down quickly in response to the socket closing.
    146             long hardLimitSeconds = 5;
    147             boolean serverShutdownReached = serverRunnable.awaitShutdown(hardLimitSeconds, TimeUnit.SECONDS);
    148             if (!serverShutdownReached) { // b/27763633
    149                 String serverStackTrace = stackTraceAsString(serverThread.getStackTrace());
    150                 fail("Server took > " + hardLimitSeconds + "sec to react to serverSocket.close(). "
    151                         + "Server thread's stackTrace: " + serverStackTrace);
    152             }
    153             assertTrue(serverRunnable.isShutdown());
    154             // Sanity check to ensure the threads don't live into the next iteration. This should
    155             // be quick because we only get here if shutdownLatch reached 0 within the time limit.
    156             serverThread.join();
    157             clientThread.join();
    158             return serverRunnable.numSuccessfulConnections.get();
    159         } catch (InterruptedException e) {
    160             throw new AssertionError("Unexpected interruption", e);
    161         }
    162     }
    163 
    164     /**
    165      * Repeatedly tries to connect to and disconnect from a SocketAddress until
    166      * it observes {@code shutdownLatch} reaching 0. Does not read/write any
    167      * data from/to the socket.
    168      */
    169     static class ClientRunnable implements Runnable {
    170         private final SocketAddress socketAddress;
    171 
    172         private final ServerRunnable serverRunnable;
    173         private final CountDownLatch startLatch = new CountDownLatch(1);
    174 
    175         public ClientRunnable(
    176                 SocketAddress socketAddress, ServerRunnable serverRunnable) {
    177             this.socketAddress = socketAddress;
    178             this.serverRunnable = serverRunnable;
    179         }
    180 
    181         @Override
    182         public void run() {
    183             startLatch.countDown();
    184             while (!serverRunnable.isShutdown()) {
    185                 try {
    186                     Socket socket = new Socket();
    187                     socket.connect(socketAddress, /* timeout (msec) */ 10);
    188                     socket.close();
    189                 } catch (IOException e) {
    190                     // harmless, as long as enough connections are successful
    191                 }
    192             }
    193         }
    194 
    195         public boolean awaitStart(long timeout, TimeUnit timeUnit) throws InterruptedException {
    196             return startLatch.await(timeout, timeUnit);
    197         }
    198 
    199     }
    200 
    201     /**
    202      * Repeatedly accepts connections from a ServerSocket and immediately closes them.
    203      * When it encounters a SocketException, it counts down the CountDownLatch and exits.
    204      */
    205     static class ServerRunnable implements Runnable {
    206         private final ServerSocket serverSocket;
    207         final AtomicInteger numSuccessfulConnections = new AtomicInteger();
    208         private final CountDownLatch startLatch = new CountDownLatch(1);
    209         private final CountDownLatch shutdownLatch = new CountDownLatch(1);
    210 
    211         ServerRunnable(ServerSocket serverSocket) {
    212             this.serverSocket = serverSocket;
    213         }
    214 
    215         @Override
    216         public void run() {
    217             startLatch.countDown();
    218             while (true) {
    219                 try {
    220                     Socket socket = serverSocket.accept();
    221                     numSuccessfulConnections.incrementAndGet();
    222                     socket.close();
    223                 } catch (SocketException e) {
    224                     shutdownLatch.countDown();
    225                     return;
    226                 } catch (IOException e) {
    227                     // harmless, as long as enough connections are successful
    228                 }
    229             }
    230         }
    231 
    232         public boolean awaitStart(long timeout, TimeUnit timeUnit) throws InterruptedException {
    233             return startLatch.await(timeout, timeUnit);
    234         }
    235 
    236         public boolean awaitShutdown(long timeout, TimeUnit timeUnit) throws InterruptedException {
    237             return shutdownLatch.await(timeout, timeUnit);
    238         }
    239 
    240         public boolean isShutdown() {
    241             return shutdownLatch.getCount() == 0;
    242         }
    243     }
    244 
    245     private static String stackTraceAsString(StackTraceElement[] stackTraceElements) {
    246         StringBuilder sb = new StringBuilder();
    247         for (StackTraceElement stackTraceElement : stackTraceElements) {
    248             sb.append("\n\t at ").append(stackTraceElement);
    249         }
    250         return sb.toString();
    251     }
    252 
    253 }
    254