Home | History | Annotate | Download | only in http
      1 /*
      2  * Copyright (C) 2010 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package tests.http;
     18 
     19 import java.io.BufferedInputStream;
     20 import java.io.BufferedOutputStream;
     21 import java.io.ByteArrayOutputStream;
     22 import java.io.IOException;
     23 import java.io.InputStream;
     24 import java.io.OutputStream;
     25 import java.net.InetAddress;
     26 import java.net.InetSocketAddress;
     27 import java.net.MalformedURLException;
     28 import java.net.Proxy;
     29 import java.net.ServerSocket;
     30 import java.net.Socket;
     31 import java.net.URL;
     32 import java.net.UnknownHostException;
     33 import java.util.ArrayList;
     34 import java.util.Collections;
     35 import java.util.HashSet;
     36 import java.util.Iterator;
     37 import java.util.List;
     38 import java.util.Set;
     39 import java.util.concurrent.BlockingQueue;
     40 import java.util.concurrent.Callable;
     41 import java.util.concurrent.ExecutorService;
     42 import java.util.concurrent.Executors;
     43 import java.util.concurrent.LinkedBlockingDeque;
     44 import java.util.concurrent.LinkedBlockingQueue;
     45 import java.util.concurrent.atomic.AtomicInteger;
     46 import javax.net.ssl.SSLSocket;
     47 import javax.net.ssl.SSLSocketFactory;
     48 
     49 /**
     50  * A scriptable web server. Callers supply canned responses and the server
     51  * replays them upon request in sequence.
     52  */
     53 public final class MockWebServer {
     54 
     55     static final String ASCII = "US-ASCII";
     56 
     57     private final BlockingQueue<RecordedRequest> requestQueue
     58             = new LinkedBlockingQueue<RecordedRequest>();
     59     private final BlockingQueue<MockResponse> responseQueue
     60             = new LinkedBlockingDeque<MockResponse>();
     61     private final Set<Socket> openClientSockets
     62             = Collections.synchronizedSet(new HashSet<Socket>());
     63     private boolean singleResponse;
     64     private final AtomicInteger requestCount = new AtomicInteger();
     65     private int bodyLimit = Integer.MAX_VALUE;
     66     private ServerSocket serverSocket;
     67     private SSLSocketFactory sslSocketFactory;
     68     private ExecutorService executor;
     69     private boolean tunnelProxy;
     70 
     71     private int port = -1;
     72 
     73     public int getPort() {
     74         if (port == -1) {
     75             throw new IllegalStateException("Cannot retrieve port before calling play()");
     76         }
     77         return port;
     78     }
     79 
     80     public Proxy toProxyAddress() {
     81         return new Proxy(Proxy.Type.HTTP, new InetSocketAddress("localhost", getPort()));
     82     }
     83 
     84     /**
     85      * Returns a URL for connecting to this server.
     86      *
     87      * @param path the request path, such as "/".
     88      */
     89     public URL getUrl(String path) throws MalformedURLException, UnknownHostException {
     90         String host = InetAddress.getLocalHost().getHostName();
     91         return sslSocketFactory != null
     92                 ? new URL("https://" + host + ":" + getPort() + path)
     93                 : new URL("http://" + host + ":" + getPort() + path);
     94     }
     95 
     96     /**
     97      * Sets the number of bytes of the POST body to keep in memory to the given
     98      * limit.
     99      */
    100     public void setBodyLimit(int maxBodyLength) {
    101         this.bodyLimit = maxBodyLength;
    102     }
    103 
    104     /**
    105      * Serve requests with HTTPS rather than otherwise.
    106      *
    107      * @param tunnelProxy whether to expect the HTTP CONNECT method before
    108      *     negotiating TLS.
    109      */
    110     public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) {
    111         this.sslSocketFactory = sslSocketFactory;
    112         this.tunnelProxy = tunnelProxy;
    113     }
    114 
    115     /**
    116      * Awaits the next HTTP request, removes it, and returns it. Callers should
    117      * use this to verify the request sent was as intended.
    118      */
    119     public RecordedRequest takeRequest() throws InterruptedException {
    120         return requestQueue.take();
    121     }
    122 
    123     /**
    124      * Returns the number of HTTP requests received thus far by this server.
    125      * This may exceed the number of HTTP connections when connection reuse is
    126      * in practice.
    127      */
    128     public int getRequestCount() {
    129         return requestCount.get();
    130     }
    131 
    132     public void enqueue(MockResponse response) {
    133         responseQueue.add(response);
    134     }
    135 
    136     /**
    137      * By default, this class processes requests coming in by adding them to a
    138      * queue and serves responses by removing them from another queue. This mode
    139      * is appropriate for correctness testing.
    140      *
    141      * <p>Serving a single response causes the server to be stateless: requests
    142      * are not enqueued, and responses are not dequeued. This mode is appropriate
    143      * for benchmarking.
    144      */
    145     public void setSingleResponse(boolean singleResponse) {
    146         this.singleResponse = singleResponse;
    147     }
    148 
    149     /**
    150      * Starts the server, serves all enqueued requests, and shuts the server
    151      * down.
    152      */
    153     public void play() throws IOException {
    154         executor = Executors.newCachedThreadPool();
    155         serverSocket = new ServerSocket(0);
    156         serverSocket.setReuseAddress(true);
    157 
    158         port = serverSocket.getLocalPort();
    159         executor.submit(namedCallable("MockWebServer-accept-" + port, new Callable<Void>() {
    160             public Void call() throws Exception {
    161                 List<Throwable> failures = new ArrayList<Throwable>();
    162                 try {
    163                     acceptConnections();
    164                 } catch (Throwable e) {
    165                     failures.add(e);
    166                 }
    167 
    168                 /*
    169                  * This gnarly block of code will release all sockets and
    170                  * all thread, even if any close fails.
    171                  */
    172                 try {
    173                     serverSocket.close();
    174                 } catch (Throwable e) {
    175                     failures.add(e);
    176                 }
    177                 for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext(); ) {
    178                     try {
    179                         s.next().close();
    180                         s.remove();
    181                     } catch (Throwable e) {
    182                         failures.add(e);
    183                     }
    184                 }
    185                 try {
    186                     executor.shutdown();
    187                 } catch (Throwable e) {
    188                     failures.add(e);
    189                 }
    190                 if (!failures.isEmpty()) {
    191                     Throwable thrown = failures.get(0);
    192                     if (thrown instanceof Exception) {
    193                         throw (Exception) thrown;
    194                     } else {
    195                         throw (Error) thrown;
    196                     }
    197                 } else {
    198                     return null;
    199                 }
    200             }
    201 
    202             public void acceptConnections() throws Exception {
    203                 int count = 0;
    204                 while (true) {
    205                     if (count > 0 && responseQueue.isEmpty()) {
    206                         return;
    207                     }
    208 
    209                     Socket socket = serverSocket.accept();
    210                     if (responseQueue.peek().getDisconnectAtStart()) {
    211                         responseQueue.take();
    212                         socket.close();
    213                         continue;
    214                     }
    215                     openClientSockets.add(socket);
    216                     serveConnection(socket);
    217                     count++;
    218                 }
    219             }
    220         }));
    221     }
    222 
    223     public void shutdown() throws IOException {
    224         if (serverSocket != null) {
    225             serverSocket.close(); // should cause acceptConnections() to break out
    226         }
    227     }
    228 
    229     private void serveConnection(final Socket raw) {
    230         String name = "MockWebServer-" + raw.getRemoteSocketAddress();
    231         executor.submit(namedCallable(name, new Callable<Void>() {
    232             int sequenceNumber = 0;
    233 
    234             public Void call() throws Exception {
    235                 Socket socket;
    236                 if (sslSocketFactory != null) {
    237                     if (tunnelProxy) {
    238                         if (!processOneRequest(raw.getInputStream(), raw.getOutputStream())) {
    239                             throw new IllegalStateException("Tunnel without any CONNECT!");
    240                         }
    241                     }
    242                     socket = sslSocketFactory.createSocket(
    243                             raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true);
    244                     ((SSLSocket) socket).setUseClientMode(false);
    245                 } else {
    246                     socket = raw;
    247                 }
    248 
    249                 InputStream in = new BufferedInputStream(socket.getInputStream());
    250                 OutputStream out = new BufferedOutputStream(socket.getOutputStream());
    251 
    252                 if (!processOneRequest(in, out)) {
    253                     throw new IllegalStateException("Connection without any request!");
    254                 }
    255                 while (processOneRequest(in, out)) {}
    256 
    257                 in.close();
    258                 out.close();
    259                 raw.close();
    260                 openClientSockets.remove(raw);
    261                 return null;
    262             }
    263 
    264             /**
    265              * Reads a request and writes its response. Returns true if a request
    266              * was processed.
    267              */
    268             private boolean processOneRequest(InputStream in, OutputStream out)
    269                     throws IOException, InterruptedException {
    270                 RecordedRequest request = readRequest(in, sequenceNumber);
    271                 if (request == null) {
    272                     return false;
    273                 }
    274                 MockResponse response = dispatch(request);
    275                 writeResponse(out, response);
    276                 if (response.getDisconnectAtEnd()) {
    277                     in.close();
    278                     out.close();
    279                 }
    280                 sequenceNumber++;
    281                 return true;
    282             }
    283         }));
    284     }
    285 
    286     /**
    287      * @param sequenceNumber the index of this request on this connection.
    288      */
    289     private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException {
    290         String request = readAsciiUntilCrlf(in);
    291         if (request.isEmpty()) {
    292             return null; // end of data; no more requests
    293         }
    294 
    295         List<String> headers = new ArrayList<String>();
    296         int contentLength = -1;
    297         boolean chunked = false;
    298         String header;
    299         while (!(header = readAsciiUntilCrlf(in)).isEmpty()) {
    300             headers.add(header);
    301             String lowercaseHeader = header.toLowerCase();
    302             if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) {
    303                 contentLength = Integer.parseInt(header.substring(15).trim());
    304             }
    305             if (lowercaseHeader.startsWith("transfer-encoding:") &&
    306                     lowercaseHeader.substring(18).trim().equals("chunked")) {
    307                 chunked = true;
    308             }
    309         }
    310 
    311         boolean hasBody = false;
    312         TruncatingOutputStream requestBody = new TruncatingOutputStream();
    313         List<Integer> chunkSizes = new ArrayList<Integer>();
    314         if (contentLength != -1) {
    315             hasBody = true;
    316             transfer(contentLength, in, requestBody);
    317         } else if (chunked) {
    318             hasBody = true;
    319             while (true) {
    320                 int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16);
    321                 if (chunkSize == 0) {
    322                     readEmptyLine(in);
    323                     break;
    324                 }
    325                 chunkSizes.add(chunkSize);
    326                 transfer(chunkSize, in, requestBody);
    327                 readEmptyLine(in);
    328             }
    329         }
    330 
    331         if (request.startsWith("GET ") || request.startsWith("CONNECT ")) {
    332             if (hasBody) {
    333                 throw new IllegalArgumentException("GET requests should not have a body!");
    334             }
    335         } else if (request.startsWith("POST ")) {
    336             if (!hasBody) {
    337                 throw new IllegalArgumentException("POST requests must have a body!");
    338             }
    339         } else {
    340             throw new UnsupportedOperationException("Unexpected method: " + request);
    341         }
    342 
    343         return new RecordedRequest(request, headers, chunkSizes,
    344                 requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber);
    345     }
    346 
    347     /**
    348      * Returns a response to satisfy {@code request}.
    349      */
    350     private MockResponse dispatch(RecordedRequest request) throws InterruptedException {
    351         if (responseQueue.isEmpty()) {
    352             throw new IllegalStateException("Unexpected request: " + request);
    353         }
    354 
    355         if (singleResponse) {
    356             return responseQueue.peek();
    357         } else {
    358             requestCount.incrementAndGet();
    359             requestQueue.add(request);
    360             return responseQueue.take();
    361         }
    362     }
    363 
    364     private void writeResponse(OutputStream out, MockResponse response) throws IOException {
    365         out.write((response.getStatus() + "\r\n").getBytes(ASCII));
    366         for (String header : response.getHeaders()) {
    367             out.write((header + "\r\n").getBytes(ASCII));
    368         }
    369         out.write(("\r\n").getBytes(ASCII));
    370         out.write(response.getBody());
    371         out.flush();
    372     }
    373 
    374     /**
    375      * Transfer bytes from {@code in} to {@code out} until either {@code length}
    376      * bytes have been transferred or {@code in} is exhausted.
    377      */
    378     private void transfer(int length, InputStream in, OutputStream out) throws IOException {
    379         byte[] buffer = new byte[1024];
    380         while (length > 0) {
    381             int count = in.read(buffer, 0, Math.min(buffer.length, length));
    382             if (count == -1) {
    383                 return;
    384             }
    385             out.write(buffer, 0, count);
    386             length -= count;
    387         }
    388     }
    389 
    390     /**
    391      * Returns the text from {@code in} until the next "\r\n", or null if
    392      * {@code in} is exhausted.
    393      */
    394     private String readAsciiUntilCrlf(InputStream in) throws IOException {
    395         StringBuilder builder = new StringBuilder();
    396         while (true) {
    397             int c = in.read();
    398             if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') {
    399                 builder.deleteCharAt(builder.length() - 1);
    400                 return builder.toString();
    401             } else if (c == -1) {
    402                 return builder.toString();
    403             } else {
    404                 builder.append((char) c);
    405             }
    406         }
    407     }
    408 
    409     private void readEmptyLine(InputStream in) throws IOException {
    410         String line = readAsciiUntilCrlf(in);
    411         if (!line.isEmpty()) {
    412             throw new IllegalStateException("Expected empty but was: " + line);
    413         }
    414     }
    415 
    416     /**
    417      * An output stream that drops data after bodyLimit bytes.
    418      */
    419     private class TruncatingOutputStream extends ByteArrayOutputStream {
    420         private int numBytesReceived = 0;
    421         @Override public void write(byte[] buffer, int offset, int len) {
    422             numBytesReceived += len;
    423             super.write(buffer, offset, Math.min(len, bodyLimit - count));
    424         }
    425         @Override public void write(int oneByte) {
    426             numBytesReceived++;
    427             if (count < bodyLimit) {
    428                 super.write(oneByte);
    429             }
    430         }
    431     }
    432 
    433     private static <T> Callable<T> namedCallable(final String name, final Callable<T> callable) {
    434         return new Callable<T>() {
    435             public T call() throws Exception {
    436                 String originalName = Thread.currentThread().getName();
    437                 Thread.currentThread().setName(name);
    438                 try {
    439                     return callable.call();
    440                 } finally {
    441                     Thread.currentThread().setName(originalName);
    442                 }
    443             }
    444         };
    445     }
    446 }
    447