Home | History | Annotate | Download | only in mockwebserver
      1 /*
      2  * Copyright (C) 2011 Google Inc.
      3  * Copyright (C) 2013 Square, Inc.
      4  *
      5  * Licensed under the Apache License, Version 2.0 (the "License");
      6  * you may not use this file except in compliance with the License.
      7  * You may obtain a copy of the License at
      8  *
      9  *      http://www.apache.org/licenses/LICENSE-2.0
     10  *
     11  * Unless required by applicable law or agreed to in writing, software
     12  * distributed under the License is distributed on an "AS IS" BASIS,
     13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14  * See the License for the specific language governing permissions and
     15  * limitations under the License.
     16  */
     17 
     18 package com.squareup.okhttp.mockwebserver;
     19 
     20 import com.squareup.okhttp.Headers;
     21 import com.squareup.okhttp.HttpUrl;
     22 import com.squareup.okhttp.Protocol;
     23 import com.squareup.okhttp.Request;
     24 import com.squareup.okhttp.Response;
     25 import com.squareup.okhttp.internal.NamedRunnable;
     26 import com.squareup.okhttp.internal.Platform;
     27 import com.squareup.okhttp.internal.Util;
     28 import com.squareup.okhttp.internal.framed.ErrorCode;
     29 import com.squareup.okhttp.internal.framed.FramedConnection;
     30 import com.squareup.okhttp.internal.framed.FramedStream;
     31 import com.squareup.okhttp.internal.framed.Header;
     32 import com.squareup.okhttp.internal.framed.Settings;
     33 import com.squareup.okhttp.internal.http.HttpMethod;
     34 import com.squareup.okhttp.internal.ws.RealWebSocket;
     35 import com.squareup.okhttp.internal.ws.WebSocketProtocol;
     36 import com.squareup.okhttp.ws.WebSocketListener;
     37 import java.io.IOException;
     38 import java.net.InetAddress;
     39 import java.net.InetSocketAddress;
     40 import java.net.ProtocolException;
     41 import java.net.Proxy;
     42 import java.net.ServerSocket;
     43 import java.net.Socket;
     44 import java.net.SocketException;
     45 import java.net.URL;
     46 import java.security.SecureRandom;
     47 import java.security.cert.CertificateException;
     48 import java.security.cert.X509Certificate;
     49 import java.util.ArrayList;
     50 import java.util.Collections;
     51 import java.util.Iterator;
     52 import java.util.List;
     53 import java.util.Locale;
     54 import java.util.Set;
     55 import java.util.concurrent.BlockingQueue;
     56 import java.util.concurrent.ConcurrentHashMap;
     57 import java.util.concurrent.CountDownLatch;
     58 import java.util.concurrent.ExecutorService;
     59 import java.util.concurrent.Executors;
     60 import java.util.concurrent.LinkedBlockingDeque;
     61 import java.util.concurrent.LinkedBlockingQueue;
     62 import java.util.concurrent.ThreadPoolExecutor;
     63 import java.util.concurrent.TimeUnit;
     64 import java.util.concurrent.atomic.AtomicInteger;
     65 import java.util.logging.Level;
     66 import java.util.logging.Logger;
     67 import javax.net.ServerSocketFactory;
     68 import javax.net.ssl.SSLContext;
     69 import javax.net.ssl.SSLSocket;
     70 import javax.net.ssl.SSLSocketFactory;
     71 import javax.net.ssl.TrustManager;
     72 import javax.net.ssl.X509TrustManager;
     73 import okio.Buffer;
     74 import okio.BufferedSink;
     75 import okio.BufferedSource;
     76 import okio.ByteString;
     77 import okio.Okio;
     78 import okio.Sink;
     79 import okio.Timeout;
     80 import org.junit.rules.TestRule;
     81 import org.junit.runner.Description;
     82 import org.junit.runners.model.Statement;
     83 
     84 import static com.squareup.okhttp.mockwebserver.SocketPolicy.DISCONNECT_AT_START;
     85 import static com.squareup.okhttp.mockwebserver.SocketPolicy.DISCONNECT_DURING_REQUEST_BODY;
     86 import static com.squareup.okhttp.mockwebserver.SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY;
     87 import static com.squareup.okhttp.mockwebserver.SocketPolicy.FAIL_HANDSHAKE;
     88 import static java.util.concurrent.TimeUnit.SECONDS;
     89 
     90 /**
     91  * A scriptable web server. Callers supply canned responses and the server
     92  * replays them upon request in sequence.
     93  */
     94 public final class MockWebServer implements TestRule {
     95   private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() {
     96     @Override public void checkClientTrusted(X509Certificate[] chain, String authType)
     97         throws CertificateException {
     98       throw new CertificateException();
     99     }
    100 
    101     @Override public void checkServerTrusted(X509Certificate[] chain, String authType) {
    102       throw new AssertionError();
    103     }
    104 
    105     @Override public X509Certificate[] getAcceptedIssuers() {
    106       throw new AssertionError();
    107     }
    108   };
    109 
    110   private static final Logger logger = Logger.getLogger(MockWebServer.class.getName());
    111 
    112   private final BlockingQueue<RecordedRequest> requestQueue = new LinkedBlockingQueue<>();
    113 
    114   private final Set<Socket> openClientSockets =
    115       Collections.newSetFromMap(new ConcurrentHashMap<Socket, Boolean>());
    116   private final Set<FramedConnection> openFramedConnections =
    117       Collections.newSetFromMap(new ConcurrentHashMap<FramedConnection, Boolean>());
    118   private final AtomicInteger requestCount = new AtomicInteger();
    119   private long bodyLimit = Long.MAX_VALUE;
    120   private ServerSocketFactory serverSocketFactory = ServerSocketFactory.getDefault();
    121   private ServerSocket serverSocket;
    122   private SSLSocketFactory sslSocketFactory;
    123   private ExecutorService executor;
    124   private boolean tunnelProxy;
    125   private Dispatcher dispatcher = new QueueDispatcher();
    126 
    127   private int port = -1;
    128   private InetSocketAddress inetSocketAddress;
    129   private boolean protocolNegotiationEnabled = true;
    130   private List<Protocol> protocols
    131       = Util.immutableList(Protocol.HTTP_2, Protocol.SPDY_3, Protocol.HTTP_1_1);
    132 
    133   private boolean started;
    134 
    135   private synchronized void maybeStart() {
    136     if (started) return;
    137     try {
    138       start();
    139     } catch (IOException e) {
    140       throw new RuntimeException(e);
    141     }
    142   }
    143 
    144   @Override public Statement apply(final Statement base, Description description) {
    145     return new Statement() {
    146       @Override public void evaluate() throws Throwable {
    147         maybeStart();
    148         try {
    149           base.evaluate();
    150         } finally {
    151           try {
    152             shutdown();
    153           } catch (IOException e) {
    154             logger.log(Level.WARNING, "MockWebServer shutdown failed", e);
    155           }
    156         }
    157       }
    158     };
    159   }
    160 
    161   public int getPort() {
    162     maybeStart();
    163     return port;
    164   }
    165 
    166   public String getHostName() {
    167     maybeStart();
    168     return inetSocketAddress.getHostName();
    169   }
    170 
    171   public Proxy toProxyAddress() {
    172     maybeStart();
    173     InetSocketAddress address = new InetSocketAddress(inetSocketAddress.getAddress(), getPort());
    174     return new Proxy(Proxy.Type.HTTP, address);
    175   }
    176 
    177   public void setServerSocketFactory(ServerSocketFactory serverSocketFactory) {
    178     if (executor != null) {
    179       throw new IllegalStateException(
    180           "setServerSocketFactory() must be called before start()");
    181     }
    182     this.serverSocketFactory = serverSocketFactory;
    183   }
    184 
    185   /**
    186    * Returns a URL for connecting to this server.
    187    * @param path the request path, such as "/".
    188    */
    189   @Deprecated
    190   public URL getUrl(String path) {
    191     return url(path).url();
    192   }
    193 
    194   /**
    195    * Returns a URL for connecting to this server.
    196    *
    197    * @param path the request path, such as "/".
    198    */
    199   public HttpUrl url(String path) {
    200     return new HttpUrl.Builder()
    201         .scheme(sslSocketFactory != null ? "https" : "http")
    202         .host(getHostName())
    203         .port(getPort())
    204         .build()
    205         .resolve(path);
    206   }
    207 
    208   /**
    209    * Returns a cookie domain for this server. This returns the server's
    210    * non-loopback host name if it is known. Otherwise this returns ".local" for
    211    * this server's loopback name.
    212    */
    213   public String getCookieDomain() {
    214     String hostName = getHostName();
    215     return hostName.contains(".") ? hostName : ".local";
    216   }
    217 
    218   /**
    219    * Sets the number of bytes of the POST body to keep in memory to the given
    220    * limit.
    221    */
    222   public void setBodyLimit(long maxBodyLength) {
    223     this.bodyLimit = maxBodyLength;
    224   }
    225 
    226   /**
    227    * Sets whether ALPN is used on incoming HTTPS connections to
    228    * negotiate a protocol like HTTP/1.1 or HTTP/2. Call this method to disable
    229    * negotiation and restrict connections to HTTP/1.1.
    230    */
    231   public void setProtocolNegotiationEnabled(boolean protocolNegotiationEnabled) {
    232     this.protocolNegotiationEnabled = protocolNegotiationEnabled;
    233   }
    234 
    235   /**
    236    * Indicates the protocols supported by ALPN on incoming HTTPS
    237    * connections. This list is ignored when
    238    * {@link #setProtocolNegotiationEnabled negotiation is disabled}.
    239    *
    240    * @param protocols the protocols to use, in order of preference. The list
    241    *     must contain {@linkplain Protocol#HTTP_1_1}. It must not contain null.
    242    */
    243   public void setProtocols(List<Protocol> protocols) {
    244     protocols = Util.immutableList(protocols);
    245     if (!protocols.contains(Protocol.HTTP_1_1)) {
    246       throw new IllegalArgumentException("protocols doesn't contain http/1.1: " + protocols);
    247     }
    248     if (protocols.contains(null)) {
    249       throw new IllegalArgumentException("protocols must not contain null");
    250     }
    251     this.protocols = protocols;
    252   }
    253 
    254   /**
    255    * Serve requests with HTTPS rather than otherwise.
    256    * @param tunnelProxy true to expect the HTTP CONNECT method before
    257    *     negotiating TLS.
    258    */
    259   public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
    260     this.sslSocketFactory = sslSocketFactory;
    261     this.tunnelProxy = tunnelProxy;
    262   }
    263 
    264   /**
    265    * Awaits the next HTTP request, removes it, and returns it. Callers should
    266    * use this to verify the request was sent as intended. This method will block until the
    267    * request is available, possibly forever.
    268    *
    269    * @return the head of the request queue
    270    */
    271   public RecordedRequest takeRequest() throws InterruptedException {
    272     return requestQueue.take();
    273   }
    274 
    275   /**
    276    * Awaits the next HTTP request (waiting up to the
    277    * specified wait time if necessary), removes it, and returns it. Callers should
    278    * use this to verify the request was sent as intended within the given time.
    279    *
    280    * @param timeout how long to wait before giving up, in units of
    281   *        {@code unit}
    282    * @param unit a {@code TimeUnit} determining how to interpret the
    283    *        {@code timeout} parameter
    284    * @return the head of the request queue
    285    */
    286   public RecordedRequest takeRequest(long timeout, TimeUnit unit) throws InterruptedException {
    287     return requestQueue.poll(timeout, unit);
    288   }
    289 
    290   /**
    291    * Returns the number of HTTP requests received thus far by this server. This
    292    * may exceed the number of HTTP connections when connection reuse is in
    293    * practice.
    294    */
    295   public int getRequestCount() {
    296     return requestCount.get();
    297   }
    298 
    299   /**
    300    * Scripts {@code response} to be returned to a request made in sequence. The
    301    * first request is served by the first enqueued response; the second request
    302    * by the second enqueued response; and so on.
    303    *
    304    * @throws ClassCastException if the default dispatcher has been replaced
    305    *     with {@link #setDispatcher(Dispatcher)}.
    306    */
    307   public void enqueue(MockResponse response) {
    308     ((QueueDispatcher) dispatcher).enqueueResponse(response.clone());
    309   }
    310 
    311   /** Equivalent to {@code start(0)}. */
    312   public void start() throws IOException {
    313     start(0);
    314   }
    315 
    316   /**
    317    * Starts the server on the loopback interface for the given port.
    318    *
    319    * @param port the port to listen to, or 0 for any available port. Automated
    320    *     tests should always use port 0 to avoid flakiness when a specific port
    321    *     is unavailable.
    322    */
    323   public void start(int port) throws IOException {
    324     start(InetAddress.getByName("localhost"), port);
    325   }
    326 
    327   /**
    328    * Starts the server on the given address and port.
    329    *
    330    * @param inetAddress the address to create the server socket on
    331    *
    332    * @param port the port to listen to, or 0 for any available port. Automated
    333    *     tests should always use port 0 to avoid flakiness when a specific port
    334    *     is unavailable.
    335    */
    336   public void start(InetAddress inetAddress, int port) throws IOException {
    337     start(new InetSocketAddress(inetAddress, port));
    338   }
    339 
    340   /**
    341    * Starts the server and binds to the given socket address.
    342    *
    343    * @param inetSocketAddress the socket address to bind the server on
    344    */
    345   private synchronized void start(InetSocketAddress inetSocketAddress) throws IOException {
    346     if (started) throw new IllegalStateException("start() already called");
    347     started = true;
    348 
    349     executor = Executors.newCachedThreadPool(Util.threadFactory("MockWebServer", false));
    350     this.inetSocketAddress = inetSocketAddress;
    351     serverSocket = serverSocketFactory.createServerSocket();
    352     // Reuse if the user specified a port
    353     serverSocket.setReuseAddress(inetSocketAddress.getPort() != 0);
    354     serverSocket.bind(inetSocketAddress, 50);
    355 
    356     port = serverSocket.getLocalPort();
    357     executor.execute(new NamedRunnable("MockWebServer %s", port) {
    358       @Override protected void execute() {
    359         try {
    360           logger.info(MockWebServer.this + " starting to accept connections");
    361           acceptConnections();
    362         } catch (Throwable e) {
    363           logger.log(Level.WARNING, MockWebServer.this + " failed unexpectedly", e);
    364         }
    365 
    366         // Release all sockets and all threads, even if any close fails.
    367         Util.closeQuietly(serverSocket);
    368         for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext(); ) {
    369           Util.closeQuietly(s.next());
    370           s.remove();
    371         }
    372         for (Iterator<FramedConnection> s = openFramedConnections.iterator(); s.hasNext(); ) {
    373           Util.closeQuietly(s.next());
    374           s.remove();
    375         }
    376         executor.shutdown();
    377       }
    378 
    379       private void acceptConnections() throws Exception {
    380         while (true) {
    381           Socket socket;
    382           try {
    383             socket = serverSocket.accept();
    384           } catch (SocketException e) {
    385             logger.info(MockWebServer.this + " done accepting connections: " + e.getMessage());
    386             return;
    387           }
    388           SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
    389           if (socketPolicy == DISCONNECT_AT_START) {
    390             dispatchBookkeepingRequest(0, socket);
    391             socket.close();
    392           } else {
    393             openClientSockets.add(socket);
    394             serveConnection(socket);
    395           }
    396         }
    397       }
    398     });
    399   }
    400 
    401   public synchronized void shutdown() throws IOException {
    402     if (!started) return;
    403     if (serverSocket == null) throw new IllegalStateException("shutdown() before start()");
    404 
    405     // Cause acceptConnections() to break out.
    406     serverSocket.close();
    407 
    408     // Await shutdown.
    409     try {
    410       if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
    411         throw new IOException("Gave up waiting for executor to shut down");
    412       }
    413     } catch (InterruptedException e) {
    414       throw new AssertionError();
    415     }
    416   }
    417 
    418   private void serveConnection(final Socket raw) {
    419     executor.execute(new NamedRunnable("MockWebServer %s", raw.getRemoteSocketAddress()) {
    420       int sequenceNumber = 0;
    421 
    422       @Override protected void execute() {
    423         try {
    424           processConnection();
    425         } catch (IOException e) {
    426           logger.info(
    427               MockWebServer.this + " connection from " + raw.getInetAddress() + " failed: " + e);
    428         } catch (Exception e) {
    429           logger.log(Level.SEVERE,
    430               MockWebServer.this + " connection from " + raw.getInetAddress() + " crashed", e);
    431         }
    432       }
    433 
    434       public void processConnection() throws Exception {
    435         Protocol protocol = Protocol.HTTP_1_1;
    436         Socket socket;
    437         if (sslSocketFactory != null) {
    438           if (tunnelProxy) {
    439             createTunnel();
    440           }
    441           SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
    442           if (socketPolicy == FAIL_HANDSHAKE) {
    443             dispatchBookkeepingRequest(sequenceNumber, raw);
    444             processHandshakeFailure(raw);
    445             return;
    446           }
    447           socket = sslSocketFactory.createSocket(raw, raw.getInetAddress().getHostAddress(),
    448               raw.getPort(), true);
    449           SSLSocket sslSocket = (SSLSocket) socket;
    450           sslSocket.setUseClientMode(false);
    451           openClientSockets.add(socket);
    452 
    453           if (protocolNegotiationEnabled) {
    454             Platform.get().configureTlsExtensions(sslSocket, null, protocols);
    455           }
    456 
    457           sslSocket.startHandshake();
    458 
    459           if (protocolNegotiationEnabled) {
    460             String protocolString = Platform.get().getSelectedProtocol(sslSocket);
    461             protocol = protocolString != null ? Protocol.get(protocolString) : Protocol.HTTP_1_1;
    462           }
    463           openClientSockets.remove(raw);
    464         } else {
    465           socket = raw;
    466         }
    467 
    468         if (protocol != Protocol.HTTP_1_1) {
    469           FramedSocketHandler framedSocketListener = new FramedSocketHandler(socket, protocol);
    470           FramedConnection framedConnection = new FramedConnection.Builder(false)
    471               .socket(socket)
    472               .protocol(protocol)
    473               .listener(framedSocketListener)
    474               .build();
    475           openFramedConnections.add(framedConnection);
    476           openClientSockets.remove(socket);
    477           return;
    478         }
    479 
    480         BufferedSource source = Okio.buffer(Okio.source(socket));
    481         BufferedSink sink = Okio.buffer(Okio.sink(socket));
    482 
    483         while (processOneRequest(socket, source, sink)) {
    484         }
    485 
    486         if (sequenceNumber == 0) {
    487           logger.warning(MockWebServer.this
    488               + " connection from "
    489               + raw.getInetAddress()
    490               + " didn't make a request");
    491         }
    492 
    493         source.close();
    494         sink.close();
    495         socket.close();
    496         openClientSockets.remove(socket);
    497       }
    498 
    499       /**
    500        * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response is
    501        * dispatched.
    502        */
    503       private void createTunnel() throws IOException, InterruptedException {
    504         BufferedSource source = Okio.buffer(Okio.source(raw));
    505         BufferedSink sink = Okio.buffer(Okio.sink(raw));
    506         while (true) {
    507           SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy();
    508           if (!processOneRequest(raw, source, sink)) {
    509             throw new IllegalStateException("Tunnel without any CONNECT!");
    510           }
    511           if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return;
    512         }
    513       }
    514 
    515       /**
    516        * Reads a request and writes its response. Returns true if further calls should be attempted
    517        * on the socket.
    518        */
    519       private boolean processOneRequest(Socket socket, BufferedSource source, BufferedSink sink)
    520           throws IOException, InterruptedException {
    521         RecordedRequest request = readRequest(socket, source, sink, sequenceNumber);
    522         if (request == null) return false;
    523 
    524         requestCount.incrementAndGet();
    525         requestQueue.add(request);
    526 
    527         MockResponse response = dispatcher.dispatch(request);
    528         if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AFTER_REQUEST) {
    529           socket.close();
    530           return false;
    531         }
    532         if (response.getSocketPolicy() == SocketPolicy.NO_RESPONSE) {
    533           // This read should block until the socket is closed. (Because nobody is writing.)
    534           if (source.exhausted()) return false;
    535           throw new ProtocolException("unexpected data");
    536         }
    537 
    538         boolean reuseSocket = true;
    539         boolean requestWantsWebSockets = "Upgrade".equalsIgnoreCase(request.getHeader("Connection"))
    540             && "websocket".equalsIgnoreCase(request.getHeader("Upgrade"));
    541         boolean responseWantsWebSockets = response.getWebSocketListener() != null;
    542         if (requestWantsWebSockets && responseWantsWebSockets) {
    543           handleWebSocketUpgrade(socket, source, sink, request, response);
    544           reuseSocket = false;
    545         } else {
    546           writeHttpResponse(socket, sink, response);
    547         }
    548 
    549         if (logger.isLoggable(Level.INFO)) {
    550           logger.info(MockWebServer.this + " received request: " + request
    551               + " and responded: " + response);
    552         }
    553 
    554         if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) {
    555           socket.close();
    556           return false;
    557         } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) {
    558           socket.shutdownInput();
    559         } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) {
    560           socket.shutdownOutput();
    561         }
    562 
    563         sequenceNumber++;
    564         return reuseSocket;
    565       }
    566     });
    567   }
    568 
    569   private void processHandshakeFailure(Socket raw) throws Exception {
    570     SSLContext context = SSLContext.getInstance("TLS");
    571     context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom());
    572     SSLSocketFactory sslSocketFactory = context.getSocketFactory();
    573     SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket(
    574         raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
    575     try {
    576       socket.startHandshake(); // we're testing a handshake failure
    577       throw new AssertionError();
    578     } catch (IOException expected) {
    579     }
    580     socket.close();
    581   }
    582 
    583   private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket)
    584       throws InterruptedException {
    585     requestCount.incrementAndGet();
    586     dispatcher.dispatch(new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket));
    587   }
    588 
    589   /** @param sequenceNumber the index of this request on this connection. */
    590   private RecordedRequest readRequest(Socket socket, BufferedSource source, BufferedSink sink,
    591       int sequenceNumber) throws IOException {
    592     String request;
    593     try {
    594       request = source.readUtf8LineStrict();
    595     } catch (IOException streamIsClosed) {
    596       return null; // no request because we closed the stream
    597     }
    598     if (request.length() == 0) {
    599       return null; // no request because the stream is exhausted
    600     }
    601 
    602     Headers.Builder headers = new Headers.Builder();
    603     long contentLength = -1;
    604     boolean chunked = false;
    605     boolean expectContinue = false;
    606     String header;
    607     while ((header = source.readUtf8LineStrict()).length() != 0) {
    608       headers.add(header);
    609       String lowercaseHeader = header.toLowerCase(Locale.US);
    610       if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
    611         contentLength = Long.parseLong(header.substring(15).trim());
    612       }
    613       if (lowercaseHeader.startsWith("transfer-encoding:")
    614           && lowercaseHeader.substring(18).trim().equals("chunked")) {
    615         chunked = true;
    616       }
    617       if (lowercaseHeader.startsWith("expect:")
    618           && lowercaseHeader.substring(7).trim().equals("100-continue")) {
    619         expectContinue = true;
    620       }
    621     }
    622 
    623     if (expectContinue) {
    624       sink.writeUtf8("HTTP/1.1 100 Continue\r\n");
    625       sink.writeUtf8("Content-Length: 0\r\n");
    626       sink.writeUtf8("\r\n");
    627       sink.flush();
    628     }
    629 
    630     boolean hasBody = false;
    631     TruncatingBuffer requestBody = new TruncatingBuffer(bodyLimit);
    632     List<Integer> chunkSizes = new ArrayList<>();
    633     MockResponse policy = dispatcher.peek();
    634     if (contentLength != -1) {
    635       hasBody = contentLength > 0;
    636       throttledTransfer(policy, socket, source, Okio.buffer(requestBody), contentLength, true);
    637     } else if (chunked) {
    638       hasBody = true;
    639       while (true) {
    640         int chunkSize = Integer.parseInt(source.readUtf8LineStrict().trim(), 16);
    641         if (chunkSize == 0) {
    642           readEmptyLine(source);
    643           break;
    644         }
    645         chunkSizes.add(chunkSize);
    646         throttledTransfer(policy, socket, source, Okio.buffer(requestBody), chunkSize, true);
    647         readEmptyLine(source);
    648       }
    649     }
    650 
    651     String method = request.substring(0, request.indexOf(' '));
    652     if (hasBody && !HttpMethod.permitsRequestBody(method)) {
    653       throw new IllegalArgumentException("Request must not have a body: " + request);
    654     }
    655 
    656     return new RecordedRequest(request, headers.build(), chunkSizes, requestBody.receivedByteCount,
    657         requestBody.buffer, sequenceNumber, socket);
    658   }
    659 
    660   private void handleWebSocketUpgrade(Socket socket, BufferedSource source, BufferedSink sink,
    661       RecordedRequest request, MockResponse response) throws IOException {
    662     String key = request.getHeader("Sec-WebSocket-Key");
    663     String acceptKey = Util.shaBase64(key + WebSocketProtocol.ACCEPT_MAGIC);
    664     response.setHeader("Sec-WebSocket-Accept", acceptKey);
    665 
    666     writeHttpResponse(socket, sink, response);
    667 
    668     final WebSocketListener listener = response.getWebSocketListener();
    669     final CountDownLatch connectionClose = new CountDownLatch(1);
    670 
    671     ThreadPoolExecutor replyExecutor =
    672         new ThreadPoolExecutor(1, 1, 1, SECONDS, new LinkedBlockingDeque<Runnable>(),
    673             Util.threadFactory(String.format("MockWebServer %s WebSocket", request.getPath()),
    674                 true));
    675     replyExecutor.allowCoreThreadTimeOut(true);
    676     final RealWebSocket webSocket =
    677         new RealWebSocket(false /* is server */, source, sink, new SecureRandom(), replyExecutor,
    678             listener, request.getPath()) {
    679           @Override protected void close() throws IOException {
    680             connectionClose.countDown();
    681           }
    682         };
    683 
    684     // Adapt the request and response into our Request and Response domain model.
    685     String scheme = request.getTlsVersion() != null ? "https" : "http";
    686     String authority = request.getHeader("Host"); // Has host and port.
    687     final Request fancyRequest = new Request.Builder()
    688         .url(scheme + "://" + authority + "/")
    689         .headers(request.getHeaders())
    690         .build();
    691     final Response fancyResponse = new Response.Builder()
    692         .code(Integer.parseInt(response.getStatus().split(" ")[1]))
    693         .message(response.getStatus().split(" ", 3)[2])
    694         .headers(response.getHeaders())
    695         .request(fancyRequest)
    696         .protocol(Protocol.HTTP_1_1)
    697         .build();
    698 
    699     listener.onOpen(webSocket, fancyResponse);
    700 
    701     while (webSocket.readMessage()) {
    702     }
    703 
    704     // Even if messages are no longer being read we need to wait for the connection close signal.
    705     try {
    706       connectionClose.await();
    707     } catch (InterruptedException e) {
    708       throw new RuntimeException(e);
    709     }
    710 
    711     replyExecutor.shutdown();
    712     Util.closeQuietly(sink);
    713     Util.closeQuietly(source);
    714   }
    715 
    716   private void writeHttpResponse(Socket socket, BufferedSink sink, MockResponse response)
    717       throws IOException {
    718     sink.writeUtf8(response.getStatus());
    719     sink.writeUtf8("\r\n");
    720 
    721     Headers headers = response.getHeaders();
    722     for (int i = 0, size = headers.size(); i < size; i++) {
    723       sink.writeUtf8(headers.name(i));
    724       sink.writeUtf8(": ");
    725       sink.writeUtf8(headers.value(i));
    726       sink.writeUtf8("\r\n");
    727     }
    728     sink.writeUtf8("\r\n");
    729     sink.flush();
    730 
    731     Buffer body = response.getBody();
    732     if (body == null) return;
    733     sleepIfDelayed(response);
    734     throttledTransfer(response, socket, body, sink, body.size(), false);
    735   }
    736 
    737   private void sleepIfDelayed(MockResponse response) {
    738     long delayMs = response.getBodyDelay(TimeUnit.MILLISECONDS);
    739     if (delayMs != 0) {
    740       try {
    741         Thread.sleep(delayMs);
    742       } catch (InterruptedException e) {
    743         throw new AssertionError(e);
    744       }
    745     }
    746   }
    747 
    748   /**
    749    * Transfer bytes from {@code source} to {@code sink} until either {@code byteCount}
    750    * bytes have been transferred or {@code source} is exhausted. The transfer is
    751    * throttled according to {@code policy}.
    752    */
    753   private void throttledTransfer(MockResponse policy, Socket socket, BufferedSource source,
    754       BufferedSink sink, long byteCount, boolean isRequest) throws IOException {
    755     if (byteCount == 0) return;
    756 
    757     Buffer buffer = new Buffer();
    758     long bytesPerPeriod = policy.getThrottleBytesPerPeriod();
    759     long periodDelayMs = policy.getThrottlePeriod(TimeUnit.MILLISECONDS);
    760 
    761     long halfByteCount = byteCount / 2;
    762     boolean disconnectHalfway = isRequest
    763         ? policy.getSocketPolicy() == DISCONNECT_DURING_REQUEST_BODY
    764         : policy.getSocketPolicy() == DISCONNECT_DURING_RESPONSE_BODY;
    765 
    766     while (!socket.isClosed()) {
    767       for (int b = 0; b < bytesPerPeriod; ) {
    768         // Ensure we do not read past the allotted bytes in this period.
    769         long toRead = Math.min(byteCount, bytesPerPeriod - b);
    770         // Ensure we do not read past halfway if the policy will kill the connection.
    771         if (disconnectHalfway) {
    772           toRead = Math.min(toRead, byteCount - halfByteCount);
    773         }
    774 
    775         long read = source.read(buffer, toRead);
    776         if (read == -1) return;
    777 
    778         sink.write(buffer, read);
    779         sink.flush();
    780         b += read;
    781         byteCount -= read;
    782 
    783         if (disconnectHalfway && byteCount == halfByteCount) {
    784           socket.close();
    785           return;
    786         }
    787 
    788         if (byteCount == 0) return;
    789       }
    790 
    791       if (periodDelayMs != 0) {
    792         try {
    793           Thread.sleep(periodDelayMs);
    794         } catch (InterruptedException e) {
    795           throw new AssertionError();
    796         }
    797       }
    798     }
    799   }
    800 
    801   private void readEmptyLine(BufferedSource source) throws IOException {
    802     String line = source.readUtf8LineStrict();
    803     if (line.length() != 0) throw new IllegalStateException("Expected empty but was: " + line);
    804   }
    805 
    806   /**
    807    * Sets the dispatcher used to match incoming requests to mock responses.
    808    * The default dispatcher simply serves a fixed sequence of responses from
    809    * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the
    810    * response based on timing or the content of the request.
    811    */
    812   public void setDispatcher(Dispatcher dispatcher) {
    813     if (dispatcher == null) throw new NullPointerException();
    814     this.dispatcher = dispatcher;
    815   }
    816 
    817   @Override public String toString() {
    818     return "MockWebServer[" + port + "]";
    819   }
    820 
    821   /** A buffer wrapper that drops data after {@code bodyLimit} bytes. */
    822   private static class TruncatingBuffer implements Sink {
    823     private final Buffer buffer = new Buffer();
    824     private long remainingByteCount;
    825     private long receivedByteCount;
    826 
    827     TruncatingBuffer(long bodyLimit) {
    828       remainingByteCount = bodyLimit;
    829     }
    830 
    831     @Override public void write(Buffer source, long byteCount) throws IOException {
    832       long toRead = Math.min(remainingByteCount, byteCount);
    833       if (toRead > 0) {
    834         source.read(buffer, toRead);
    835       }
    836       long toSkip = byteCount - toRead;
    837       if (toSkip > 0) {
    838         source.skip(toSkip);
    839       }
    840       remainingByteCount -= toRead;
    841       receivedByteCount += byteCount;
    842     }
    843 
    844     @Override public void flush() throws IOException {
    845     }
    846 
    847     @Override public Timeout timeout() {
    848       return Timeout.NONE;
    849     }
    850 
    851     @Override public void close() throws IOException {
    852     }
    853   }
    854 
    855   /** Processes HTTP requests layered over framed protocols. */
    856   private class FramedSocketHandler extends FramedConnection.Listener {
    857     private final Socket socket;
    858     private final Protocol protocol;
    859     private final AtomicInteger sequenceNumber = new AtomicInteger();
    860 
    861     private FramedSocketHandler(Socket socket, Protocol protocol) {
    862       this.socket = socket;
    863       this.protocol = protocol;
    864     }
    865 
    866     @Override public void onStream(FramedStream stream) throws IOException {
    867       RecordedRequest request = readRequest(stream);
    868       requestQueue.add(request);
    869       MockResponse response;
    870       try {
    871         response = dispatcher.dispatch(request);
    872       } catch (InterruptedException e) {
    873         throw new AssertionError(e);
    874       }
    875       writeResponse(stream, response);
    876       if (logger.isLoggable(Level.INFO)) {
    877         logger.info(MockWebServer.this + " received request: " + request
    878             + " and responded: " + response + " protocol is " + protocol.toString());
    879       }
    880     }
    881 
    882     private RecordedRequest readRequest(FramedStream stream) throws IOException {
    883       List<Header> streamHeaders = stream.getRequestHeaders();
    884       Headers.Builder httpHeaders = new Headers.Builder();
    885       String method = "<:method omitted>";
    886       String path = "<:path omitted>";
    887       String version = protocol == Protocol.SPDY_3 ? "<:version omitted>" : "HTTP/1.1";
    888       for (int i = 0, size = streamHeaders.size(); i < size; i++) {
    889         ByteString name = streamHeaders.get(i).name;
    890         String value = streamHeaders.get(i).value.utf8();
    891         if (name.equals(Header.TARGET_METHOD)) {
    892           method = value;
    893         } else if (name.equals(Header.TARGET_PATH)) {
    894           path = value;
    895         } else if (name.equals(Header.VERSION)) {
    896           version = value;
    897         } else if (protocol == Protocol.SPDY_3) {
    898           for (String s : value.split("\u0000", -1)) {
    899             httpHeaders.add(name.utf8(), s);
    900           }
    901         } else if (protocol == Protocol.HTTP_2) {
    902           httpHeaders.add(name.utf8(), value);
    903         } else {
    904           throw new IllegalStateException();
    905         }
    906       }
    907 
    908       Buffer body = new Buffer();
    909       body.writeAll(stream.getSource());
    910       body.close();
    911 
    912       String requestLine = method + ' ' + path + ' ' + version;
    913       List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY.
    914       return new RecordedRequest(requestLine, httpHeaders.build(), chunkSizes, body.size(), body,
    915           sequenceNumber.getAndIncrement(), socket);
    916     }
    917 
    918     private void writeResponse(FramedStream stream, MockResponse response) throws IOException {
    919       Settings settings = response.getSettings();
    920       if (settings != null) {
    921         stream.getConnection().setSettings(settings);
    922       }
    923 
    924       if (response.getSocketPolicy() == SocketPolicy.NO_RESPONSE) {
    925         return;
    926       }
    927       List<Header> spdyHeaders = new ArrayList<>();
    928       String[] statusParts = response.getStatus().split(" ", 2);
    929       if (statusParts.length != 2) {
    930         throw new AssertionError("Unexpected status: " + response.getStatus());
    931       }
    932       // TODO: constants for well-known header names.
    933       spdyHeaders.add(new Header(Header.RESPONSE_STATUS, statusParts[1]));
    934       if (protocol == Protocol.SPDY_3) {
    935         spdyHeaders.add(new Header(Header.VERSION, statusParts[0]));
    936       }
    937       Headers headers = response.getHeaders();
    938       for (int i = 0, size = headers.size(); i < size; i++) {
    939         spdyHeaders.add(new Header(headers.name(i), headers.value(i)));
    940       }
    941 
    942       Buffer body = response.getBody();
    943       boolean closeStreamAfterHeaders = body != null || !response.getPushPromises().isEmpty();
    944       stream.reply(spdyHeaders, closeStreamAfterHeaders);
    945       pushPromises(stream, response.getPushPromises());
    946       if (body != null) {
    947         BufferedSink sink = Okio.buffer(stream.getSink());
    948         sleepIfDelayed(response);
    949         throttledTransfer(response, socket, body, sink, bodyLimit, false);
    950         sink.close();
    951       } else if (closeStreamAfterHeaders) {
    952         stream.close(ErrorCode.NO_ERROR);
    953       }
    954     }
    955 
    956     private void pushPromises(FramedStream stream, List<PushPromise> promises) throws IOException {
    957       for (PushPromise pushPromise : promises) {
    958         List<Header> pushedHeaders = new ArrayList<>();
    959         pushedHeaders.add(new Header(stream.getConnection().getProtocol() == Protocol.SPDY_3
    960             ? Header.TARGET_HOST
    961             : Header.TARGET_AUTHORITY, url(pushPromise.getPath()).host()));
    962         pushedHeaders.add(new Header(Header.TARGET_METHOD, pushPromise.getMethod()));
    963         pushedHeaders.add(new Header(Header.TARGET_PATH, pushPromise.getPath()));
    964         Headers pushPromiseHeaders = pushPromise.getHeaders();
    965         for (int i = 0, size = pushPromiseHeaders.size(); i < size; i++) {
    966           pushedHeaders.add(new Header(pushPromiseHeaders.name(i), pushPromiseHeaders.value(i)));
    967         }
    968         String requestLine = pushPromise.getMethod() + ' ' + pushPromise.getPath() + " HTTP/1.1";
    969         List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY.
    970         requestQueue.add(new RecordedRequest(requestLine, pushPromise.getHeaders(), chunkSizes, 0,
    971             new Buffer(), sequenceNumber.getAndIncrement(), socket));
    972         boolean hasBody = pushPromise.getResponse().getBody() != null;
    973         FramedStream pushedStream =
    974             stream.getConnection().pushStream(stream.getId(), pushedHeaders, hasBody);
    975         writeResponse(pushedStream, pushPromise.getResponse());
    976       }
    977     }
    978   }
    979 }
    980