Home | History | Annotate | Download | only in mockwebserver
      1 /*
      2  * Copyright (C) 2011 Google Inc.
      3  * Copyright (C) 2013 Square, Inc.
      4  *
      5  * Licensed under the Apache License, Version 2.0 (the "License");
      6  * you may not use this file except in compliance with the License.
      7  * You may obtain a copy of the License at
      8  *
      9  *      http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  */
     17 
     18 package com.squareup.okhttp.mockwebserver;
     19 
     20 import com.squareup.okhttp.Headers;
     21 import com.squareup.okhttp.Protocol;
     22 import com.squareup.okhttp.Request;
     23 import com.squareup.okhttp.Response;
     24 import com.squareup.okhttp.internal.NamedRunnable;
     25 import com.squareup.okhttp.internal.Platform;
     26 import com.squareup.okhttp.internal.Util;
     27 import com.squareup.okhttp.internal.spdy.ErrorCode;
     28 import com.squareup.okhttp.internal.spdy.Header;
     29 import com.squareup.okhttp.internal.spdy.IncomingStreamHandler;
     30 import com.squareup.okhttp.internal.spdy.SpdyConnection;
     31 import com.squareup.okhttp.internal.spdy.SpdyStream;
     32 import com.squareup.okhttp.internal.ws.RealWebSocket;
     33 import com.squareup.okhttp.internal.ws.WebSocketProtocol;
     34 import com.squareup.okhttp.ws.WebSocketListener;
     35 import java.io.IOException;
     36 import java.net.InetAddress;
     37 import java.net.InetSocketAddress;
     38 import java.net.MalformedURLException;
     39 import java.net.ProtocolException;
     40 import java.net.Proxy;
     41 import java.net.ServerSocket;
     42 import java.net.Socket;
     43 import java.net.SocketException;
     44 import java.net.URL;
     45 import java.security.SecureRandom;
     46 import java.security.cert.CertificateException;
     47 import java.security.cert.X509Certificate;
     48 import java.util.ArrayList;
     49 import java.util.Collections;
     50 import java.util.Iterator;
     51 import java.util.List;
     52 import java.util.Locale;
     53 import java.util.Set;
     54 import java.util.concurrent.BlockingQueue;
     55 import java.util.concurrent.ConcurrentHashMap;
     56 import java.util.concurrent.CountDownLatch;
     57 import java.util.concurrent.ExecutorService;
     58 import java.util.concurrent.Executors;
     59 import java.util.concurrent.LinkedBlockingDeque;
     60 import java.util.concurrent.LinkedBlockingQueue;
     61 import java.util.concurrent.ThreadPoolExecutor;
     62 import java.util.concurrent.TimeUnit;
     63 import java.util.concurrent.atomic.AtomicInteger;
     64 import java.util.logging.Level;
     65 import java.util.logging.Logger;
     66 import javax.net.ServerSocketFactory;
     67 import javax.net.ssl.SSLContext;
     68 import javax.net.ssl.SSLSocket;
     69 import javax.net.ssl.SSLSocketFactory;
     70 import javax.net.ssl.TrustManager;
     71 import javax.net.ssl.X509TrustManager;
     72 import okio.Buffer;
     73 import okio.BufferedSink;
     74 import okio.BufferedSource;
     75 import okio.ByteString;
     76 import okio.Okio;
     77 import okio.Sink;
     78 import okio.Timeout;
     79 
     80 import static com.squareup.okhttp.mockwebserver.SocketPolicy.DISCONNECT_AT_START;
     81 import static com.squareup.okhttp.mockwebserver.SocketPolicy.FAIL_HANDSHAKE;
     82 import static java.util.concurrent.TimeUnit.SECONDS;
     83 
     84 /**
     85  * A scriptable web server. Callers supply canned responses and the server
     86  * replays them upon request in sequence.
     87  */
     88 public final class MockWebServer {
     89   private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() {
     90     @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
     91         throws CertificateException {
     92       throw new CertificateException();
     93     }
     94 
     95     @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
     96       throw new AssertionError();
     97     }
     98 
     99     @Override public X509Certificate[] getAcceptedIssuers() {
    100       throw new AssertionError();
    101     }
    102   };
    103 
    104   private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
    105 
    106   private final BlockingQueue<RecordedRequest> requestQueue = new LinkedBlockingQueue<>();
    107 
    108   private final Set<Socket> openClientSockets =
    109       Collections.newSetFromMap(new ConcurrentHashMap<Socket, Boolean>());
    110   private final Set<SpdyConnection> openSpdyConnections =
    111       Collections.newSetFromMap(new ConcurrentHashMap<SpdyConnection, Boolean>());
    112   private final AtomicInteger requestCount = new AtomicInteger();
    113   private long bodyLimit = Long.MAX_VALUE;
    114   private ServerSocketFactory serverSocketFactory = ServerSocketFactory.getDefault();
    115   private ServerSocket serverSocket;
    116   private SSLSocketFactory sslSocketFactory;
    117   private ExecutorService executor;
    118   private boolean tunnelProxy;
    119   private Dispatcher dispatcher = new QueueDispatcher();
    120 
    121   private int port = -1;
    122   private InetSocketAddress inetSocketAddress;
    123   private boolean protocolNegotiationEnabled = true;
    124   private List<Protocol> protocols
    125       = Util.immutableList(Protocol.HTTP_2, Protocol.SPDY_3, Protocol.HTTP_1_1);
    126 
    127   public void setServerSocketFactory(ServerSocketFactory serverSocketFactory) {
    128     if (serverSocketFactory == null) throw new IllegalArgumentException("null serverSocketFactory");
    129     this.serverSocketFactory = serverSocketFactory;
    130   }
    131 
    132   public int getPort() {
    133     if (port == -1) throw new IllegalStateException("Call start() before getPort()");
    134     return port;
    135   }
    136 
    137   public String getHostName() {
    138     if (inetSocketAddress == null) {
    139       throw new IllegalStateException("Call start() before getHostName()");
    140     }
    141     return inetSocketAddress.getHostName();
    142   }
    143 
    144   public Proxy toProxyAddress() {
    145     if (inetSocketAddress == null) {
    146       throw new IllegalStateException("Call start() before toProxyAddress()");
    147     }
    148     InetSocketAddress address = new InetSocketAddress(inetSocketAddress.getAddress(), getPort());
    149     return new Proxy(Proxy.Type.HTTP, address);
    150   }
    151 
    152   /**
    153    * Returns a URL for connecting to this server.
    154    * @param path the request path, such as "/".
    155    */
    156   public URL getUrl(String path) {
    157     try {
    158       return sslSocketFactory != null
    159           ? new URL("https://" + getHostName() + ":" + getPort() + path)
    160           : new URL("http://" + getHostName() + ":" + getPort() + path);
    161     } catch (MalformedURLException e) {
    162       throw new AssertionError(e);
    163     }
    164   }
    165 
    166   /**
    167    * Returns a cookie domain for this server. This returns the server's
    168    * non-loopback host name if it is known. Otherwise this returns ".local" for
    169    * this server's loopback name.
    170    */
    171   public String getCookieDomain() {
    172     String hostName = getHostName();
    173     return hostName.contains(".") ? hostName : ".local";
    174   }
    175 
    176   /**
    177    * Sets the number of bytes of the POST body to keep in memory to the given
    178    * limit.
    179    */
    180   public void setBodyLimit(long maxBodyLength) {
    181     this.bodyLimit = maxBodyLength;
    182   }
    183 
    184   /**
    185    * Sets whether ALPN is used on incoming HTTPS connections to
    186    * negotiate a protocol like HTTP/1.1 or HTTP/2. Call this method to disable
    187    * negotiation and restrict connections to HTTP/1.1.
    188    */
    189   public void setProtocolNegotiationEnabled(boolean protocolNegotiationEnabled) {
    190     this.protocolNegotiationEnabled = protocolNegotiationEnabled;
    191   }
    192 
    193   /**
    194    * Indicates the protocols supported by ALPN on incoming HTTPS
    195    * connections. This list is ignored when
    196    * {@link #setProtocolNegotiationEnabled negotiation is disabled}.
    197    *
    198    * @param protocols the protocols to use, in order of preference. The list
    199    *     must contain {@linkplain Protocol#HTTP_1_1}. It must not contain null.
    200    */
    201   public void setProtocols(List<Protocol> protocols) {
    202     protocols = Util.immutableList(protocols);
    203     if (!protocols.contains(Protocol.HTTP_1_1)) {
    204       throw new IllegalArgumentException("protocols doesn't contain http/1.1: " + protocols);
    205     }
    206     if (protocols.contains(null)) {
    207       throw new IllegalArgumentException("protocols must not contain null");
    208     }
    209     this.protocols = protocols;
    210   }
    211 
    212   /**
    213    * Serve requests with HTTPS rather than otherwise.
    214    * @param tunnelProxy true to expect the HTTP CONNECT method before
    215    *     negotiating TLS.
    216    */
    217   public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
    218     this.sslSocketFactory = sslSocketFactory;
    219     this.tunnelProxy = tunnelProxy;
    220   }
    221 
    222   /**
    223    * Awaits the next HTTP request, removes it, and returns it. Callers should
    224    * use this to verify the request was sent as intended. This method will block until the
    225    * request is available, possibly forever.
    226    *
    227    * @return the head of the request queue
    228    */
    229   public RecordedRequest takeRequest() throws InterruptedException {
    230     return requestQueue.take();
    231   }
    232 
    233   /**
    234    * Awaits the next HTTP request (waiting up to the
    235    * specified wait time if necessary), removes it, and returns it. Callers should
    236    * use this to verify the request was sent as intended within the given time.
    237    *
    238    * @param timeout how long to wait before giving up, in units of
    239   *        {@code unit}
    240    * @param unit a {@code TimeUnit} determining how to interpret the
    241    *        {@code timeout} parameter
    242    * @return the head of the request queue
    243    */
    244   public RecordedRequest takeRequest(long timeout, TimeUnit unit) throws InterruptedException {
    245     return requestQueue.poll(timeout, unit);
    246   }
    247 
    248   /**
    249    * Returns the number of HTTP requests received thus far by this server. This
    250    * may exceed the number of HTTP connections when connection reuse is in
    251    * practice.
    252    */
    253   public int getRequestCount() {
    254     return requestCount.get();
    255   }
    256 
    257   /**
    258    * Scripts {@code response} to be returned to a request made in sequence. The
    259    * first request is served by the first enqueued response; the second request
    260    * by the second enqueued response; and so on.
    261    *
    262    * @throws ClassCastException if the default dispatcher has been replaced
    263    *     with {@link #setDispatcher(Dispatcher)}.
    264    */
    265   public void enqueue(MockResponse response) {
    266     ((QueueDispatcher) dispatcher).enqueueResponse(response.clone());
    267   }
    268 
    269   /** @deprecated Use {@link #start()}. */
    270   public void play() throws IOException {
    271     start();
    272   }
    273 
    274   /** @deprecated Use {@link #start(int)}. */
    275   public void play(int port) throws IOException {
    276     start(port);
    277   }
    278 
    279   /** Equivalent to {@code start(0)}. */
    280   public void start() throws IOException {
    281     start(0);
    282   }
    283 
    284   /**
    285    * Starts the server on the loopback interface for the given port.
    286    *
    287    * @param port the port to listen to, or 0 for any available port. Automated
    288    *     tests should always use port 0 to avoid flakiness when a specific port
    289    *     is unavailable.
    290    */
    291   public void start(int port) throws IOException {
    292     start(InetAddress.getByName("localhost"), port);
    293   }
    294 
    295   /**
    296    * Starts the server on the given address and port.
    297    *
    298    * @param inetAddress the address to create the server socket on
    299    *
    300    * @param port the port to listen to, or 0 for any available port. Automated
    301    *     tests should always use port 0 to avoid flakiness when a specific port
    302    *     is unavailable.
    303    */
    304   public void start(InetAddress inetAddress, int port) throws IOException {
    305     start(new InetSocketAddress(inetAddress, port));
    306   }
    307 
    308   /**
    309    * Starts the server and binds to the given socket address.
    310    *
    311    * @param inetSocketAddress the socket address to bind the server on
    312    */
    313   private void start(InetSocketAddress inetSocketAddress) throws IOException {
    314     if (executor != null) throw new IllegalStateException("start() already called");
    315     executor = Executors.newCachedThreadPool(Util.threadFactory("MockWebServer", false));
    316     this.inetSocketAddress = inetSocketAddress;
    317     serverSocket = serverSocketFactory.createServerSocket();
    318     // Reuse if the user specified a port
    319     serverSocket.setReuseAddress(inetSocketAddress.getPort() != 0);
    320     serverSocket.bind(inetSocketAddress, 50);
    321 
    322     port = serverSocket.getLocalPort();
    323     executor.execute(new NamedRunnable("MockWebServer %s", port) {
    324       @Override protected void execute() {
    325         try {
    326           logger.info(MockWebServer.this + " starting to accept connections");
    327           acceptConnections();
    328         } catch (Throwable e) {
    329           logger.log(Level.WARNING, MockWebServer.this + " failed unexpectedly", e);
    330         }
    331 
    332         // Release all sockets and all threads, even if any close fails.
    333         Util.closeQuietly(serverSocket);
    334         for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext(); ) {
    335           Util.closeQuietly(s.next());
    336           s.remove();
    337         }
    338         for (Iterator<SpdyConnection> s = openSpdyConnections.iterator(); s.hasNext(); ) {
    339           Util.closeQuietly(s.next());
    340           s.remove();
    341         }
    342         executor.shutdown();
    343       }
    344 
    345       private void acceptConnections() throws Exception {
    346         while (true) {
    347           Socket socket;
    348           try {
    349             socket = serverSocket.accept();
    350           } catch (SocketException e) {
    351             logger.info(MockWebServer.this + " done accepting connections: " + e.getMessage());
    352             return;
    353           }
    354           SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
    355           if (socketPolicy == DISCONNECT_AT_START) {
    356             dispatchBookkeepingRequest(0, socket);
    357             socket.close();
    358           } else {
    359             openClientSockets.add(socket);
    360             serveConnection(socket);
    361           }
    362         }
    363       }
    364     });
    365   }
    366 
    367   public void shutdown() throws IOException {
    368     if (serverSocket == null) throw new IllegalStateException("shutdown() before start()");
    369 
    370     // Cause acceptConnections() to break out.
    371     serverSocket.close();
    372 
    373     // Await shutdown.
    374     try {
    375       if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
    376         throw new IOException("Gave up waiting for executor to shut down");
    377       }
    378     } catch (InterruptedException e) {
    379       throw new AssertionError();
    380     }
    381   }
    382 
  /**
   * Submits a task to the executor that serves every request on the accepted connection
   * {@code raw}.
   *
   * <p>When HTTPS is configured, the task first handles any CONNECT tunnel and then either
   * fails the TLS handshake on purpose (FAIL_HANDSHAKE) or wraps {@code raw} in an
   * {@link SSLSocket} and performs the handshake. If ALPN selected a protocol other than
   * HTTP/1.1, the socket is handed to a {@code SpdyConnection}. Otherwise requests are read
   * and answered one at a time until {@code processOneRequest} returns false.
   */
  private void serveConnection(final Socket raw) {
    executor.execute(new NamedRunnable("MockWebServer %s", raw.getRemoteSocketAddress()) {
      // Index of the next request on this connection; incremented after each served request.
      int sequenceNumber = 0;

      @Override protected void execute() {
        try {
          processConnection();
        } catch (IOException e) {
          logger.info(
              MockWebServer.this + " connection from " + raw.getInetAddress() + " failed: " + e);
        } catch (Exception e) {
          logger.log(Level.SEVERE,
              MockWebServer.this + " connection from " + raw.getInetAddress() + " crashed", e);
        }
      }

      /** Sets up TLS/ALPN if configured, then serves the connection with the chosen protocol. */
      public void processConnection() throws Exception {
        Protocol protocol = Protocol.HTTP_1_1;
        Socket socket;
        if (sslSocketFactory != null) {
          if (tunnelProxy) {
            createTunnel();
          }
          SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
          if (socketPolicy == FAIL_HANDSHAKE) {
            dispatchBookkeepingRequest(sequenceNumber, raw);
            processHandshakeFailure(raw);
            return;
          }
          // autoClose=true: closing the SSLSocket also closes raw.
          socket = sslSocketFactory.createSocket(raw, raw.getInetAddress().getHostAddress(),
              raw.getPort(), true);
          SSLSocket sslSocket = (SSLSocket) socket;
          sslSocket.setUseClientMode(false);
          openClientSockets.add(socket);

          if (protocolNegotiationEnabled) {
            Platform.get().configureTlsExtensions(sslSocket, null, protocols);
          }

          sslSocket.startHandshake();

          if (protocolNegotiationEnabled) {
            String protocolString = Platform.get().getSelectedProtocol(sslSocket);
            protocol = protocolString != null ? Protocol.get(protocolString) : Protocol.HTTP_1_1;
          }
          // From here on the SSLSocket wrapper is tracked instead of the raw socket.
          openClientSockets.remove(raw);
        } else {
          socket = raw;
        }

        if (protocol != Protocol.HTTP_1_1) {
          // The SpdyConnection now owns the socket; it is tracked in openSpdyConnections so
          // the accept task closes it on shutdown.
          SpdySocketHandler spdySocketHandler = new SpdySocketHandler(socket, protocol);
          SpdyConnection spdyConnection =
              new SpdyConnection.Builder(false, socket).protocol(protocol)
                  .handler(spdySocketHandler)
                  .build();
          openSpdyConnections.add(spdyConnection);
          openClientSockets.remove(socket);
          return;
        }

        BufferedSource source = Okio.buffer(Okio.source(socket));
        BufferedSink sink = Okio.buffer(Okio.sink(socket));

        // Serve HTTP/1.1 requests until processOneRequest signals the connection is done.
        while (processOneRequest(socket, source, sink)) {
        }

        if (sequenceNumber == 0) {
          logger.warning(MockWebServer.this
              + " connection from "
              + raw.getInetAddress()
              + " didn't make a request");
        }

        source.close();
        sink.close();
        socket.close();
        openClientSockets.remove(socket);
      }

      /**
       * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response is
       * dispatched.
       *
       * NOTE(review): the code below actually checks for
       * SocketPolicy.UPGRADE_TO_SSL_AT_END; the policy name above looks stale.
       *
       * @throws IllegalStateException if the client closes the connection before
       *     such a response is served.
       */
      private void createTunnel() throws IOException, InterruptedException {
        BufferedSource source = Okio.buffer(Okio.source(raw));
        BufferedSink sink = Okio.buffer(Okio.sink(raw));
        while (true) {
          // Peek before serving: the policy belongs to the response about to be dispatched.
          SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
          if (!processOneRequest(raw, source, sink)) {
            throw new IllegalStateException("Tunnel without any CONNECT!");
          }
          if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return;
        }
      }

      /**
       * Reads a request and writes its response. Returns true if further calls should be attempted
       * on the socket.
       *
       * <p>Every request read is counted and queued for takeRequest() before it is dispatched.
       * The dispatched response's socket policy then controls what happens:
       * DISCONNECT_AFTER_REQUEST closes the socket without a response; NO_RESPONSE writes
       * nothing and waits for the client to close; DISCONNECT_AT_END and the
       * SHUTDOWN_*_AT_END policies act after the response is written.
       */
      private boolean processOneRequest(Socket socket, BufferedSource source, BufferedSink sink)
          throws IOException, InterruptedException {
        RecordedRequest request = readRequest(socket, source, sink, sequenceNumber);
        if (request == null) return false;

        requestCount.incrementAndGet();
        requestQueue.add(request);

        MockResponse response = dispatcher.dispatch(request);
        if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AFTER_REQUEST) {
          socket.close();
          return false;
        }
        if (response.getSocketPolicy() == SocketPolicy.NO_RESPONSE) {
          // This read should block until the socket is closed. (Because nobody is writing.)
          if (source.exhausted()) return false;
          throw new ProtocolException("unexpected data");
        }

        // Upgrade to a WebSocket only when the client asks for it AND the response has a listener.
        boolean requestWantsWebSockets = "Upgrade".equalsIgnoreCase(request.getHeader("Connection"))
            && "websocket".equalsIgnoreCase(request.getHeader("Upgrade"));
        boolean responseWantsWebSockets = response.getWebSocketListener() != null;
        if (requestWantsWebSockets && responseWantsWebSockets) {
          handleWebSocketUpgrade(socket, source, sink, request, response);
        } else {
          writeHttpResponse(socket, sink, response);
        }

        if (logger.isLoggable(Level.INFO)) {
          logger.info(MockWebServer.this + " received request: " + request
              + " and responded: " + response);
        }

        if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
          socket.close();
          return false;
        } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) {
          socket.shutdownInput();
        } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) {
          socket.shutdownOutput();
        }

        sequenceNumber++;
        return true;
      }
    });
  }
    530 
    531   private void processHandshakeFailure(Socket raw) throws Exception {
    532     SSLContext context = SSLContext.getInstance("TLS");
    533     context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom());
    534     SSLSocketFactory sslSocketFactory = context.getSocketFactory();
    535     SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
    536         raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
    537     try {
    538       socket.startHandshake(); // we're testing a handshake failure
    539       throw new AssertionError();
    540     } catch (IOException expected) {
    541     }
    542     socket.close();
    543   }
    544 
    545   private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket)
    546       throws InterruptedException {
    547     requestCount.incrementAndGet();
    548     dispatcher.dispatch(new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket));
    549   }
    550 
    551   /** @param sequenceNumber the index of this request on this connection. */
    552   private RecordedRequest readRequest(Socket socket, BufferedSource source, BufferedSink sink,
    553       int sequenceNumber) throws IOException {
    554     String request;
    555     try {
    556       request = source.readUtf8LineStrict();
    557     } catch (IOException streamIsClosed) {
    558       return null; // no request because we closed the stream
    559     }
    560     if (request.length() == 0) {
    561       return null; // no request because the stream is exhausted
    562     }
    563 
    564     Headers.Builder headers = new Headers.Builder();
    565     long contentLength = -1;
    566     boolean chunked = false;
    567     boolean expectContinue = false;
    568     String header;
    569     while ((header = source.readUtf8LineStrict()).length() != 0) {
    570       headers.add(header);
    571       String lowercaseHeader = header.toLowerCase(Locale.US);
    572       if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
    573         contentLength = Long.parseLong(header.substring(15).trim());
    574       }
    575       if (lowercaseHeader.startsWith("transfer-encoding:")
    576           && lowercaseHeader.substring(18).trim().equals("chunked")) {
    577         chunked = true;
    578       }
    579       if (lowercaseHeader.startsWith("expect:")
    580           && lowercaseHeader.substring(7).trim().equals("100-continue")) {
    581         expectContinue = true;
    582       }
    583     }
    584 
    585     if (expectContinue) {
    586       sink.writeUtf8("HTTP/1.1 100 Continue\r\n");
    587       sink.writeUtf8("Content-Length: 0\r\n");
    588       sink.writeUtf8("\r\n");
    589       sink.flush();
    590     }
    591 
    592     boolean hasBody = false;
    593     TruncatingBuffer requestBody = new TruncatingBuffer(bodyLimit);
    594     List<Integer> chunkSizes = new ArrayList<>();
    595     MockResponse throttlePolicy = dispatcher.peek();
    596     if (contentLength != -1) {
    597       hasBody = contentLength > 0;
    598       throttledTransfer(throttlePolicy, socket, source, Okio.buffer(requestBody), contentLength);
    599     } else if (chunked) {
    600       hasBody = true;
    601       while (true) {
    602         int chunkSize = Integer.parseInt(source.readUtf8LineStrict().trim(), 16);
    603         if (chunkSize == 0) {
    604           readEmptyLine(source);
    605           break;
    606         }
    607         chunkSizes.add(chunkSize);
    608         throttledTransfer(throttlePolicy, socket, source, Okio.buffer(requestBody), chunkSize);
    609         readEmptyLine(source);
    610       }
    611     }
    612 
    613     if (request.startsWith("OPTIONS ")
    614         || request.startsWith("GET ")
    615         || request.startsWith("HEAD ")
    616         || request.startsWith("TRACE ")
    617         || request.startsWith("CONNECT ")) {
    618       if (hasBody) {
    619         throw new IllegalArgumentException("Request must not have a body: " + request);
    620       }
    621     } else if (!request.startsWith("POST ")
    622         && !request.startsWith("PUT ")
    623         && !request.startsWith("PATCH ")
    624         && !request.startsWith("DELETE ")) { // Permitted as spec is ambiguous.
    625       throw new UnsupportedOperationException("Unexpected method: " + request);
    626     }
    627 
    628     return new RecordedRequest(request, headers.build(), chunkSizes, requestBody.receivedByteCount,
    629         requestBody.buffer, sequenceNumber, socket);
    630   }
    631 
    632   private void handleWebSocketUpgrade(Socket socket, BufferedSource source, BufferedSink sink,
    633       RecordedRequest request, MockResponse response) throws IOException {
    634     String key = request.getHeader("Sec-WebSocket-Key");
    635     String acceptKey = Util.shaBase64(key + WebSocketProtocol.ACCEPT_MAGIC);
    636     response.setHeader("Sec-WebSocket-Accept", acceptKey);
    637 
    638     writeHttpResponse(socket, sink, response);
    639 
    640     final WebSocketListener listener = response.getWebSocketListener();
    641     final CountDownLatch connectionClose = new CountDownLatch(1);
    642 
    643     ThreadPoolExecutor replyExecutor =
    644         new ThreadPoolExecutor(1, 1, 1, SECONDS, new LinkedBlockingDeque<Runnable>(),
    645             Util.threadFactory(String.format("MockWebServer %s WebSocket", request.getPath()),
    646                 true));
    647     replyExecutor.allowCoreThreadTimeOut(true);
    648     final RealWebSocket webSocket =
    649         new RealWebSocket(false /* is server */, source, sink, new SecureRandom(), replyExecutor,
    650             listener, request.getPath()) {
    651           @Override protected void closeConnection() throws IOException {
    652             connectionClose.countDown();
    653           }
    654         };
    655 
    656     // Adapt the request and response into our Request and Response domain model.
    657     final Request fancyRequest = new Request.Builder()
    658         .get().url(request.getPath())
    659         .headers(request.getHeaders())
    660         .build();
    661     final Response fancyResponse = new Response.Builder()
    662         .code(Integer.parseInt(response.getStatus().split(" ")[1]))
    663         .message(response.getStatus().split(" ", 3)[2])
    664         .headers(response.getHeaders())
    665         .request(fancyRequest)
    666         .protocol(Protocol.HTTP_1_1)
    667         .build();
    668 
    669     // The callback might act synchronously. Give it its own thread.
    670     new Thread(new Runnable() {
    671       @Override public void run() {
    672         try {
    673           listener.onOpen(webSocket, fancyRequest, fancyResponse);
    674         } catch (IOException e) {
    675           // TODO try to write close frame?
    676           connectionClose.countDown();
    677         }
    678       }
    679     }, "MockWebServer WebSocket Writer " + request.getPath()).start();
    680 
    681     // Use this thread to continuously read messages.
    682     while (webSocket.readMessage()) {
    683     }
    684 
    685     // Even if messages are no longer being read we need to wait for the connection close signal.
    686     try {
    687       connectionClose.await();
    688     } catch (InterruptedException e) {
    689       throw new RuntimeException(e);
    690     }
    691 
    692     Util.closeQuietly(sink);
    693     Util.closeQuietly(source);
    694   }
    695 
    696   private void writeHttpResponse(Socket socket, BufferedSink sink, MockResponse response)
    697       throws IOException {
    698     sink.writeUtf8(response.getStatus());
    699     sink.writeUtf8("\r\n");
    700 
    701     Headers headers = response.getHeaders();
    702     for (int i = 0, size = headers.size(); i < size; i++) {
    703       sink.writeUtf8(headers.name(i));
    704       sink.writeUtf8(": ");
    705       sink.writeUtf8(headers.value(i));
    706       sink.writeUtf8("\r\n");
    707     }
    708     sink.writeUtf8("\r\n");
    709     sink.flush();
    710 
    711     Buffer body = response.getBody();
    712     if (body == null) return;
    713     sleepIfDelayed(response);
    714     throttledTransfer(response, socket, body, sink, Long.MAX_VALUE);
    715   }
    716 
    717   private void sleepIfDelayed(MockResponse response) {
    718     long delayMs = response.getBodyDelay(TimeUnit.MILLISECONDS);
    719     if (delayMs != 0) {
    720       try {
    721         Thread.sleep(delayMs);
    722       } catch (InterruptedException e) {
    723         throw new AssertionError(e);
    724       }
    725     }
    726   }
    727 
    728   /**
    729    * Transfer bytes from {@code source} to {@code sink} until either {@code byteCount}
    730    * bytes have been transferred or {@code source} is exhausted. The transfer is
    731    * throttled according to {@code throttlePolicy}.
    732    */
    733   private void throttledTransfer(MockResponse throttlePolicy, Socket socket, BufferedSource source,
    734       BufferedSink sink, long byteCount) throws IOException {
    735     if (byteCount == 0) return;
    736 
    737     Buffer buffer = new Buffer();
    738     long bytesPerPeriod = throttlePolicy.getThrottleBytesPerPeriod();
    739     long periodDelayMs = throttlePolicy.getThrottlePeriod(TimeUnit.MILLISECONDS);
    740 
    741     while (!socket.isClosed()) {
    742       for (int b = 0; b < bytesPerPeriod; ) {
    743         long toRead = Math.min(Math.min(2048, byteCount), bytesPerPeriod - b);
    744         long read = source.read(buffer, toRead);
    745         if (read == -1) return;
    746 
    747         sink.write(buffer, read);
    748         sink.flush();
    749         b += read;
    750         byteCount -= read;
    751 
    752         if (byteCount == 0) return;
    753       }
    754 
    755       if (periodDelayMs != 0) {
    756         try {
    757           Thread.sleep(periodDelayMs);
    758         } catch (InterruptedException e) {
    759           throw new AssertionError();
    760         }
    761       }
    762     }
    763   }
    764 
    765   private void readEmptyLine(BufferedSource source) throws IOException {
    766     String line = source.readUtf8LineStrict();
    767     if (line.length() != 0) throw new IllegalStateException("Expected empty but was: " + line);
    768   }
    769 
    770   /**
    771    * Sets the dispatcher used to match incoming requests to mock responses.
    772    * The default dispatcher simply serves a fixed sequence of responses from
    773    * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the
    774    * response based on timing or the content of the request.
    775    */
    776   public void setDispatcher(Dispatcher dispatcher) {
    777     if (dispatcher == null) throw new NullPointerException();
    778     this.dispatcher = dispatcher;
    779   }
    780 
    781   @Override public String toString() {
    782     return "MockWebServer[" + port + "]";
    783   }
    784 
    785   /** A buffer wrapper that drops data after {@code bodyLimit} bytes. */
    786   private static class TruncatingBuffer implements Sink {
    787     private final Buffer buffer = new Buffer();
    788     private long remainingByteCount;
    789     private long receivedByteCount;
    790 
    791     TruncatingBuffer(long bodyLimit) {
    792       remainingByteCount = bodyLimit;
    793     }
    794 
    795     @Override public void write(Buffer source, long byteCount) throws IOException {
    796       long toRead = Math.min(remainingByteCount, byteCount);
    797       if (toRead > 0) {
    798         source.read(buffer, toRead);
    799       }
    800       long toSkip = byteCount - toRead;
    801       if (toSkip > 0) {
    802         source.skip(toSkip);
    803       }
    804       remainingByteCount -= toRead;
    805       receivedByteCount += byteCount;
    806     }
    807 
    808     @Override public void flush() throws IOException {
    809     }
    810 
    811     @Override public Timeout timeout() {
    812       return Timeout.NONE;
    813     }
    814 
    815     @Override public void close() throws IOException {
    816     }
    817   }
    818 
    819   /** Processes HTTP requests layered over SPDY/3. */
    820   private class SpdySocketHandler implements IncomingStreamHandler {
    821     private final Socket socket;
    822     private final Protocol protocol;
    823     private final AtomicInteger sequenceNumber = new AtomicInteger();
    824 
    825     private SpdySocketHandler(Socket socket, Protocol protocol) {
    826       this.socket = socket;
    827       this.protocol = protocol;
    828     }
    829 
    830     @Override public void receive(SpdyStream stream) throws IOException {
    831       RecordedRequest request = readRequest(stream);
    832       requestQueue.add(request);
    833       MockResponse response;
    834       try {
    835         response = dispatcher.dispatch(request);
    836       } catch (InterruptedException e) {
    837         throw new AssertionError(e);
    838       }
    839       writeResponse(stream, response);
    840       if (logger.isLoggable(Level.INFO)) {
    841         logger.info(MockWebServer.this + " received request: " + request
    842             + " and responded: " + response + " protocol is " + protocol.toString());
    843       }
    844     }
    845 
    846     private RecordedRequest readRequest(SpdyStream stream) throws IOException {
    847       List<Header> spdyHeaders = stream.getRequestHeaders();
    848       Headers.Builder httpHeaders = new Headers.Builder();
    849       String method = "<:method omitted>";
    850       String path = "<:path omitted>";
    851       String version = protocol == Protocol.SPDY_3 ? "<:version omitted>" : "HTTP/1.1";
    852       for (int i = 0, size = spdyHeaders.size(); i < size; i++) {
    853         ByteString name = spdyHeaders.get(i).name;
    854         String value = spdyHeaders.get(i).value.utf8();
    855         if (name.equals(Header.TARGET_METHOD)) {
    856           method = value;
    857         } else if (name.equals(Header.TARGET_PATH)) {
    858           path = value;
    859         } else if (name.equals(Header.VERSION)) {
    860           version = value;
    861         } else {
    862           httpHeaders.add(name.utf8(), value);
    863         }
    864       }
    865 
    866       Buffer body = new Buffer();
    867       body.writeAll(stream.getSource());
    868       body.close();
    869 
    870       String requestLine = method + ' ' + path + ' ' + version;
    871       List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY.
    872       return new RecordedRequest(requestLine, httpHeaders.build(), chunkSizes, body.size(), body,
    873           sequenceNumber.getAndIncrement(), socket);
    874     }
    875 
    876     private void writeResponse(SpdyStream stream, MockResponse response) throws IOException {
    877       if (response.getSocketPolicy() == SocketPolicy.NO_RESPONSE) {
    878         return;
    879       }
    880       List<Header> spdyHeaders = new ArrayList<>();
    881       String[] statusParts = response.getStatus().split(" ", 2);
    882       if (statusParts.length != 2) {
    883         throw new AssertionError("Unexpected status: " + response.getStatus());
    884       }
    885       // TODO: constants for well-known header names.
    886       spdyHeaders.add(new Header(Header.RESPONSE_STATUS, statusParts[1]));
    887       if (protocol == Protocol.SPDY_3) {
    888         spdyHeaders.add(new Header(Header.VERSION, statusParts[0]));
    889       }
    890       Headers headers = response.getHeaders();
    891       for (int i = 0, size = headers.size(); i < size; i++) {
    892         spdyHeaders.add(new Header(headers.name(i), headers.value(i)));
    893       }
    894 
    895       Buffer body = response.getBody();
    896       boolean closeStreamAfterHeaders = body != null || !response.getPushPromises().isEmpty();
    897       stream.reply(spdyHeaders, closeStreamAfterHeaders);
    898       pushPromises(stream, response.getPushPromises());
    899       if (body != null) {
    900         BufferedSink sink = Okio.buffer(stream.getSink());
    901         sleepIfDelayed(response);
    902         throttledTransfer(response, socket, body, sink, bodyLimit);
    903         sink.close();
    904       } else if (closeStreamAfterHeaders) {
    905         stream.close(ErrorCode.NO_ERROR);
    906       }
    907     }
    908 
    909     private void pushPromises(SpdyStream stream, List<PushPromise> promises) throws IOException {
    910       for (PushPromise pushPromise : promises) {
    911         List<Header> pushedHeaders = new ArrayList<>();
    912         pushedHeaders.add(new Header(stream.getConnection().getProtocol() == Protocol.SPDY_3
    913             ? Header.TARGET_HOST
    914             : Header.TARGET_AUTHORITY, getUrl(pushPromise.getPath()).getHost()));
    915         pushedHeaders.add(new Header(Header.TARGET_METHOD, pushPromise.getMethod()));
    916         pushedHeaders.add(new Header(Header.TARGET_PATH, pushPromise.getPath()));
    917         Headers pushPromiseHeaders = pushPromise.getHeaders();
    918         for (int i = 0, size = pushPromiseHeaders.size(); i < size; i++) {
    919           pushedHeaders.add(new Header(pushPromiseHeaders.name(i), pushPromiseHeaders.value(i)));
    920         }
    921         String requestLine = pushPromise.getMethod() + ' ' + pushPromise.getPath() + " HTTP/1.1";
    922         List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY.
    923         requestQueue.add(new RecordedRequest(requestLine, pushPromise.getHeaders(), chunkSizes, 0,
    924             new Buffer(), sequenceNumber.getAndIncrement(), socket));
    925         boolean hasBody = pushPromise.getResponse().getBody() != null;
    926         SpdyStream pushedStream =
    927             stream.getConnection().pushStream(stream.getId(), pushedHeaders, hasBody);
    928         writeResponse(pushedStream, pushPromise.getResponse());
    929       }
    930     }
    931   }
    932 }
    933