1 /* 2 * Copyright (C) 2010 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package tests.http; 18 19 import java.io.BufferedInputStream; 20 import java.io.BufferedOutputStream; 21 import java.io.ByteArrayOutputStream; 22 import java.io.IOException; 23 import java.io.InputStream; 24 import java.io.OutputStream; 25 import java.net.InetAddress; 26 import java.net.InetSocketAddress; 27 import java.net.MalformedURLException; 28 import java.net.Proxy; 29 import java.net.ServerSocket; 30 import java.net.Socket; 31 import java.net.URL; 32 import java.net.UnknownHostException; 33 import java.util.ArrayList; 34 import java.util.Collections; 35 import java.util.HashSet; 36 import java.util.Iterator; 37 import java.util.List; 38 import java.util.Set; 39 import java.util.concurrent.BlockingQueue; 40 import java.util.concurrent.Callable; 41 import java.util.concurrent.ExecutorService; 42 import java.util.concurrent.Executors; 43 import java.util.concurrent.LinkedBlockingDeque; 44 import java.util.concurrent.LinkedBlockingQueue; 45 import java.util.concurrent.atomic.AtomicInteger; 46 import javax.net.ssl.SSLSocket; 47 import javax.net.ssl.SSLSocketFactory; 48 49 /** 50 * A scriptable web server. Callers supply canned responses and the server 51 * replays them upon request in sequence. 52 */ 53 public final class MockWebServer { 54 55 static final String ASCII = "US-ASCII"; 56 57 private final BlockingQueue<RecordedRequest> requestQueue 58 = new LinkedBlockingQueue<RecordedRequest>(); 59 private final BlockingQueue<MockResponse> responseQueue 60 = new LinkedBlockingDeque<MockResponse>(); 61 private final Set<Socket> openClientSockets 62 = Collections.synchronizedSet(new HashSet<Socket>()); 63 private boolean singleResponse; 64 private final AtomicInteger requestCount = new AtomicInteger(); 65 private int bodyLimit = Integer.MAX_VALUE; 66 private ServerSocket serverSocket; 67 private SSLSocketFactory sslSocketFactory; 68 private ExecutorService executor; 69 private boolean tunnelProxy; 70 71 private int port = -1; 72 73 public int getPort() { 74 if (port == -1) { 75 throw new IllegalStateException("Cannot retrieve port before calling play()"); 76 } 77 return port; 78 } 79 80 public Proxy toProxyAddress() { 81 return new Proxy(Proxy.Type.HTTP, new InetSocketAddress("localhost", getPort())); 82 } 83 84 /** 85 * Returns a URL for connecting to this server. 86 * 87 * @param path the request path, such as "/". 88 */ 89 public URL getUrl(String path) throws MalformedURLException, UnknownHostException { 90 String host = InetAddress.getLocalHost().getHostName(); 91 return sslSocketFactory != null 92 ? new URL("https://" + host + ":" + getPort() + path) 93 : new URL("http://" + host + ":" + getPort() + path); 94 } 95 96 /** 97 * Sets the number of bytes of the POST body to keep in memory to the given 98 * limit. 99 */ 100 public void setBodyLimit(int maxBodyLength) { 101 this.bodyLimit = maxBodyLength; 102 } 103 104 /** 105 * Serve requests with HTTPS rather than otherwise. 106 * 107 * @param tunnelProxy whether to expect the HTTP CONNECT method before 108 * negotiating TLS. 109 */ 110 public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { 111 this.sslSocketFactory = sslSocketFactory; 112 this.tunnelProxy = tunnelProxy; 113 } 114 115 /** 116 * Awaits the next HTTP request, removes it, and returns it. Callers should 117 * use this to verify the request sent was as intended. 118 */ 119 public RecordedRequest takeRequest() throws InterruptedException { 120 return requestQueue.take(); 121 } 122 123 /** 124 * Returns the number of HTTP requests received thus far by this server. 125 * This may exceed the number of HTTP connections when connection reuse is 126 * in practice. 127 */ 128 public int getRequestCount() { 129 return requestCount.get(); 130 } 131 132 public void enqueue(MockResponse response) { 133 responseQueue.add(response); 134 } 135 136 /** 137 * By default, this class processes requests coming in by adding them to a 138 * queue and serves responses by removing them from another queue. This mode 139 * is appropriate for correctness testing. 140 * 141 * <p>Serving a single response causes the server to be stateless: requests 142 * are not enqueued, and responses are not dequeued. This mode is appropriate 143 * for benchmarking. 144 */ 145 public void setSingleResponse(boolean singleResponse) { 146 this.singleResponse = singleResponse; 147 } 148 149 /** 150 * Starts the server, serves all enqueued requests, and shuts the server 151 * down. 152 */ 153 public void play() throws IOException { 154 executor = Executors.newCachedThreadPool(); 155 serverSocket = new ServerSocket(0); 156 serverSocket.setReuseAddress(true); 157 158 port = serverSocket.getLocalPort(); 159 executor.submit(namedCallable("MockWebServer-accept-" + port, new Callable<Void>() { 160 public Void call() throws Exception { 161 List<Throwable> failures = new ArrayList<Throwable>(); 162 try { 163 acceptConnections(); 164 } catch (Throwable e) { 165 failures.add(e); 166 } 167 168 /* 169 * This gnarly block of code will release all sockets and 170 * all thread, even if any close fails. 171 */ 172 try { 173 serverSocket.close(); 174 } catch (Throwable e) { 175 failures.add(e); 176 } 177 for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext(); ) { 178 try { 179 s.next().close(); 180 s.remove(); 181 } catch (Throwable e) { 182 failures.add(e); 183 } 184 } 185 try { 186 executor.shutdown(); 187 } catch (Throwable e) { 188 failures.add(e); 189 } 190 if (!failures.isEmpty()) { 191 Throwable thrown = failures.get(0); 192 if (thrown instanceof Exception) { 193 throw (Exception) thrown; 194 } else { 195 throw (Error) thrown; 196 } 197 } else { 198 return null; 199 } 200 } 201 202 public void acceptConnections() throws Exception { 203 int count = 0; 204 while (true) { 205 if (count > 0 && responseQueue.isEmpty()) { 206 return; 207 } 208 209 Socket socket = serverSocket.accept(); 210 if (responseQueue.peek().getDisconnectAtStart()) { 211 responseQueue.take(); 212 socket.close(); 213 continue; 214 } 215 openClientSockets.add(socket); 216 serveConnection(socket); 217 count++; 218 } 219 } 220 })); 221 } 222 223 public void shutdown() throws IOException { 224 if (serverSocket != null) { 225 serverSocket.close(); // should cause acceptConnections() to break out 226 } 227 } 228 229 private void serveConnection(final Socket raw) { 230 String name = "MockWebServer-" + raw.getRemoteSocketAddress(); 231 executor.submit(namedCallable(name, new Callable<Void>() { 232 int sequenceNumber = 0; 233 234 public Void call() throws Exception { 235 Socket socket; 236 if (sslSocketFactory != null) { 237 if (tunnelProxy) { 238 if (!processOneRequest(raw.getInputStream(), raw.getOutputStream())) { 239 throw new IllegalStateException("Tunnel without any CONNECT!"); 240 } 241 } 242 socket = sslSocketFactory.createSocket( 243 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 244 ((SSLSocket) socket).setUseClientMode(false); 245 } else { 246 socket = raw; 247 } 248 249 InputStream in = new BufferedInputStream(socket.getInputStream()); 250 OutputStream out = new BufferedOutputStream(socket.getOutputStream()); 251 252 if (!processOneRequest(in, out)) { 253 throw new IllegalStateException("Connection without any request!"); 254 } 255 while (processOneRequest(in, out)) {} 256 257 in.close(); 258 out.close(); 259 raw.close(); 260 openClientSockets.remove(raw); 261 return null; 262 } 263 264 /** 265 * Reads a request and writes its response. Returns true if a request 266 * was processed. 267 */ 268 private boolean processOneRequest(InputStream in, OutputStream out) 269 throws IOException, InterruptedException { 270 RecordedRequest request = readRequest(in, sequenceNumber); 271 if (request == null) { 272 return false; 273 } 274 MockResponse response = dispatch(request); 275 writeResponse(out, response); 276 if (response.getDisconnectAtEnd()) { 277 in.close(); 278 out.close(); 279 } 280 sequenceNumber++; 281 return true; 282 } 283 })); 284 } 285 286 /** 287 * @param sequenceNumber the index of this request on this connection. 288 */ 289 private RecordedRequest readRequest(InputStream in, int sequenceNumber) throws IOException { 290 String request = readAsciiUntilCrlf(in); 291 if (request.isEmpty()) { 292 return null; // end of data; no more requests 293 } 294 295 List<String> headers = new ArrayList<String>(); 296 int contentLength = -1; 297 boolean chunked = false; 298 String header; 299 while (!(header = readAsciiUntilCrlf(in)).isEmpty()) { 300 headers.add(header); 301 String lowercaseHeader = header.toLowerCase(); 302 if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { 303 contentLength = Integer.parseInt(header.substring(15).trim()); 304 } 305 if (lowercaseHeader.startsWith("transfer-encoding:") && 306 lowercaseHeader.substring(18).trim().equals("chunked")) { 307 chunked = true; 308 } 309 } 310 311 boolean hasBody = false; 312 TruncatingOutputStream requestBody = new TruncatingOutputStream(); 313 List<Integer> chunkSizes = new ArrayList<Integer>(); 314 if (contentLength != -1) { 315 hasBody = true; 316 transfer(contentLength, in, requestBody); 317 } else if (chunked) { 318 hasBody = true; 319 while (true) { 320 int chunkSize = Integer.parseInt(readAsciiUntilCrlf(in).trim(), 16); 321 if (chunkSize == 0) { 322 readEmptyLine(in); 323 break; 324 } 325 chunkSizes.add(chunkSize); 326 transfer(chunkSize, in, requestBody); 327 readEmptyLine(in); 328 } 329 } 330 331 if (request.startsWith("GET ") || request.startsWith("CONNECT ")) { 332 if (hasBody) { 333 throw new IllegalArgumentException("GET requests should not have a body!"); 334 } 335 } else if (request.startsWith("POST ")) { 336 if (!hasBody) { 337 throw new IllegalArgumentException("POST requests must have a body!"); 338 } 339 } else { 340 throw new UnsupportedOperationException("Unexpected method: " + request); 341 } 342 343 return new RecordedRequest(request, headers, chunkSizes, 344 requestBody.numBytesReceived, requestBody.toByteArray(), sequenceNumber); 345 } 346 347 /** 348 * Returns a response to satisfy {@code request}. 349 */ 350 private MockResponse dispatch(RecordedRequest request) throws InterruptedException { 351 if (responseQueue.isEmpty()) { 352 throw new IllegalStateException("Unexpected request: " + request); 353 } 354 355 if (singleResponse) { 356 return responseQueue.peek(); 357 } else { 358 requestCount.incrementAndGet(); 359 requestQueue.add(request); 360 return responseQueue.take(); 361 } 362 } 363 364 private void writeResponse(OutputStream out, MockResponse response) throws IOException { 365 out.write((response.getStatus() + "\r\n").getBytes(ASCII)); 366 for (String header : response.getHeaders()) { 367 out.write((header + "\r\n").getBytes(ASCII)); 368 } 369 out.write(("\r\n").getBytes(ASCII)); 370 out.write(response.getBody()); 371 out.flush(); 372 } 373 374 /** 375 * Transfer bytes from {@code in} to {@code out} until either {@code length} 376 * bytes have been transferred or {@code in} is exhausted. 377 */ 378 private void transfer(int length, InputStream in, OutputStream out) throws IOException { 379 byte[] buffer = new byte[1024]; 380 while (length > 0) { 381 int count = in.read(buffer, 0, Math.min(buffer.length, length)); 382 if (count == -1) { 383 return; 384 } 385 out.write(buffer, 0, count); 386 length -= count; 387 } 388 } 389 390 /** 391 * Returns the text from {@code in} until the next "\r\n", or null if 392 * {@code in} is exhausted. 393 */ 394 private String readAsciiUntilCrlf(InputStream in) throws IOException { 395 StringBuilder builder = new StringBuilder(); 396 while (true) { 397 int c = in.read(); 398 if (c == '\n' && builder.length() > 0 && builder.charAt(builder.length() - 1) == '\r') { 399 builder.deleteCharAt(builder.length() - 1); 400 return builder.toString(); 401 } else if (c == -1) { 402 return builder.toString(); 403 } else { 404 builder.append((char) c); 405 } 406 } 407 } 408 409 private void readEmptyLine(InputStream in) throws IOException { 410 String line = readAsciiUntilCrlf(in); 411 if (!line.isEmpty()) { 412 throw new IllegalStateException("Expected empty but was: " + line); 413 } 414 } 415 416 /** 417 * An output stream that drops data after bodyLimit bytes. 418 */ 419 private class TruncatingOutputStream extends ByteArrayOutputStream { 420 private int numBytesReceived = 0; 421 @Override public void write(byte[] buffer, int offset, int len) { 422 numBytesReceived += len; 423 super.write(buffer, offset, Math.min(len, bodyLimit - count)); 424 } 425 @Override public void write(int oneByte) { 426 numBytesReceived++; 427 if (count < bodyLimit) { 428 super.write(oneByte); 429 } 430 } 431 } 432 433 private static <T> Callable<T> namedCallable(final String name, final Callable<T> callable) { 434 return new Callable<T>() { 435 public T call() throws Exception { 436 String originalName = Thread.currentThread().getName(); 437 Thread.currentThread().setName(name); 438 try { 439 return callable.call(); 440 } finally { 441 Thread.currentThread().setName(originalName); 442 } 443 } 444 }; 445 } 446 } 447