1 /* 2 * Copyright (C) 2011 Google Inc. 3 * Copyright (C) 2013 Square, Inc. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package com.squareup.okhttp.mockwebserver; 19 20 import com.squareup.okhttp.Headers; 21 import com.squareup.okhttp.HttpUrl; 22 import com.squareup.okhttp.Protocol; 23 import com.squareup.okhttp.Request; 24 import com.squareup.okhttp.Response; 25 import com.squareup.okhttp.internal.NamedRunnable; 26 import com.squareup.okhttp.internal.Platform; 27 import com.squareup.okhttp.internal.Util; 28 import com.squareup.okhttp.internal.framed.ErrorCode; 29 import com.squareup.okhttp.internal.framed.FramedConnection; 30 import com.squareup.okhttp.internal.framed.FramedStream; 31 import com.squareup.okhttp.internal.framed.Header; 32 import com.squareup.okhttp.internal.framed.Settings; 33 import com.squareup.okhttp.internal.http.HttpMethod; 34 import com.squareup.okhttp.internal.ws.RealWebSocket; 35 import com.squareup.okhttp.internal.ws.WebSocketProtocol; 36 import com.squareup.okhttp.ws.WebSocketListener; 37 import java.io.IOException; 38 import java.net.InetAddress; 39 import java.net.InetSocketAddress; 40 import java.net.ProtocolException; 41 import java.net.Proxy; 42 import java.net.ServerSocket; 43 import java.net.Socket; 44 import java.net.SocketException; 45 import java.net.URL; 46 import java.security.SecureRandom; 47 import java.security.cert.CertificateException; 48 import java.security.cert.X509Certificate; 49 import java.util.ArrayList; 50 import java.util.Collections; 51 import java.util.Iterator; 52 import java.util.List; 53 import java.util.Locale; 54 import java.util.Set; 55 import java.util.concurrent.BlockingQueue; 56 import java.util.concurrent.ConcurrentHashMap; 57 import java.util.concurrent.CountDownLatch; 58 import java.util.concurrent.ExecutorService; 59 import java.util.concurrent.Executors; 60 import java.util.concurrent.LinkedBlockingDeque; 61 import java.util.concurrent.LinkedBlockingQueue; 62 import java.util.concurrent.ThreadPoolExecutor; 63 import java.util.concurrent.TimeUnit; 64 import java.util.concurrent.atomic.AtomicInteger; 65 import java.util.logging.Level; 66 import java.util.logging.Logger; 67 import javax.net.ServerSocketFactory; 68 import javax.net.ssl.SSLContext; 69 import javax.net.ssl.SSLSocket; 70 import javax.net.ssl.SSLSocketFactory; 71 import javax.net.ssl.TrustManager; 72 import javax.net.ssl.X509TrustManager; 73 import okio.Buffer; 74 import okio.BufferedSink; 75 import okio.BufferedSource; 76 import okio.ByteString; 77 import okio.Okio; 78 import okio.Sink; 79 import okio.Timeout; 80 import org.junit.rules.TestRule; 81 import org.junit.runner.Description; 82 import org.junit.runners.model.Statement; 83 84 import static com.squareup.okhttp.mockwebserver.SocketPolicy.DISCONNECT_AT_START; 85 import static com.squareup.okhttp.mockwebserver.SocketPolicy.DISCONNECT_DURING_REQUEST_BODY; 86 import static com.squareup.okhttp.mockwebserver.SocketPolicy.DISCONNECT_DURING_RESPONSE_BODY; 87 import static com.squareup.okhttp.mockwebserver.SocketPolicy.FAIL_HANDSHAKE; 88 import static java.util.concurrent.TimeUnit.SECONDS; 89 90 /** 91 * A scriptable web server. Callers supply canned responses and the server 92 * replays them upon request in sequence. 93 */ 94 public final class MockWebServer implements TestRule { 95 private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() { 96 @Override public void checkClientTrusted(X509Certificate[] chain, String authType) 97 throws CertificateException { 98 throw new CertificateException(); 99 } 100 101 @Override public void checkServerTrusted(X509Certificate[] chain, String authType) { 102 throw new AssertionError(); 103 } 104 105 @Override public X509Certificate[] getAcceptedIssuers() { 106 throw new AssertionError(); 107 } 108 }; 109 110 private static final Logger logger = Logger.getLogger(MockWebServer.class.getName()); 111 112 private final BlockingQueue<RecordedRequest> requestQueue = new LinkedBlockingQueue<>(); 113 114 private final Set<Socket> openClientSockets = 115 Collections.newSetFromMap(new ConcurrentHashMap<Socket, Boolean>()); 116 private final Set<FramedConnection> openFramedConnections = 117 Collections.newSetFromMap(new ConcurrentHashMap<FramedConnection, Boolean>()); 118 private final AtomicInteger requestCount = new AtomicInteger(); 119 private long bodyLimit = Long.MAX_VALUE; 120 private ServerSocketFactory serverSocketFactory = ServerSocketFactory.getDefault(); 121 private ServerSocket serverSocket; 122 private SSLSocketFactory sslSocketFactory; 123 private ExecutorService executor; 124 private boolean tunnelProxy; 125 private Dispatcher dispatcher = new QueueDispatcher(); 126 127 private int port = -1; 128 private InetSocketAddress inetSocketAddress; 129 private boolean protocolNegotiationEnabled = true; 130 private List<Protocol> protocols 131 = Util.immutableList(Protocol.HTTP_2, Protocol.SPDY_3, Protocol.HTTP_1_1); 132 133 private boolean started; 134 135 private synchronized void maybeStart() { 136 if (started) return; 137 try { 138 start(); 139 } catch (IOException e) { 140 throw new RuntimeException(e); 141 } 142 } 143 144 @Override public Statement apply(final Statement base, Description description) { 145 return new Statement() { 146 @Override public void evaluate() throws Throwable { 147 maybeStart(); 148 try { 149 base.evaluate(); 150 } finally { 151 try { 152 shutdown(); 153 } catch (IOException e) { 154 logger.log(Level.WARNING, "MockWebServer shutdown failed", e); 155 } 156 } 157 } 158 }; 159 } 160 161 public int getPort() { 162 maybeStart(); 163 return port; 164 } 165 166 public String getHostName() { 167 maybeStart(); 168 return inetSocketAddress.getHostName(); 169 } 170 171 public Proxy toProxyAddress() { 172 maybeStart(); 173 InetSocketAddress address = new InetSocketAddress(inetSocketAddress.getAddress(), getPort()); 174 return new Proxy(Proxy.Type.HTTP, address); 175 } 176 177 public void setServerSocketFactory(ServerSocketFactory serverSocketFactory) { 178 if (executor != null) { 179 throw new IllegalStateException( 180 "setServerSocketFactory() must be called before start()"); 181 } 182 this.serverSocketFactory = serverSocketFactory; 183 } 184 185 /** 186 * Returns a URL for connecting to this server. 187 * @param path the request path, such as "/". 188 */ 189 @Deprecated 190 public URL getUrl(String path) { 191 return url(path).url(); 192 } 193 194 /** 195 * Returns a URL for connecting to this server. 196 * 197 * @param path the request path, such as "/". 198 */ 199 public HttpUrl url(String path) { 200 return new HttpUrl.Builder() 201 .scheme(sslSocketFactory != null ? "https" : "http") 202 .host(getHostName()) 203 .port(getPort()) 204 .build() 205 .resolve(path); 206 } 207 208 /** 209 * Returns a cookie domain for this server. This returns the server's 210 * non-loopback host name if it is known. Otherwise this returns ".local" for 211 * this server's loopback name. 212 */ 213 public String getCookieDomain() { 214 String hostName = getHostName(); 215 return hostName.contains(".") ? hostName : ".local"; 216 } 217 218 /** 219 * Sets the number of bytes of the POST body to keep in memory to the given 220 * limit. 221 */ 222 public void setBodyLimit(long maxBodyLength) { 223 this.bodyLimit = maxBodyLength; 224 } 225 226 /** 227 * Sets whether ALPN is used on incoming HTTPS connections to 228 * negotiate a protocol like HTTP/1.1 or HTTP/2. Call this method to disable 229 * negotiation and restrict connections to HTTP/1.1. 230 */ 231 public void setProtocolNegotiationEnabled(boolean protocolNegotiationEnabled) { 232 this.protocolNegotiationEnabled = protocolNegotiationEnabled; 233 } 234 235 /** 236 * Indicates the protocols supported by ALPN on incoming HTTPS 237 * connections. This list is ignored when 238 * {@link #setProtocolNegotiationEnabled negotiation is disabled}. 239 * 240 * @param protocols the protocols to use, in order of preference. The list 241 * must contain {@linkplain Protocol#HTTP_1_1}. It must not contain null. 242 */ 243 public void setProtocols(List<Protocol> protocols) { 244 protocols = Util.immutableList(protocols); 245 if (!protocols.contains(Protocol.HTTP_1_1)) { 246 throw new IllegalArgumentException("protocols doesn't contain http/1.1: " + protocols); 247 } 248 if (protocols.contains(null)) { 249 throw new IllegalArgumentException("protocols must not contain null"); 250 } 251 this.protocols = protocols; 252 } 253 254 /** 255 * Serve requests with HTTPS rather than otherwise. 256 * @param tunnelProxy true to expect the HTTP CONNECT method before 257 * negotiating TLS. 258 */ 259 public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { 260 this.sslSocketFactory = sslSocketFactory; 261 this.tunnelProxy = tunnelProxy; 262 } 263 264 /** 265 * Awaits the next HTTP request, removes it, and returns it. Callers should 266 * use this to verify the request was sent as intended. This method will block until the 267 * request is available, possibly forever. 268 * 269 * @return the head of the request queue 270 */ 271 public RecordedRequest takeRequest() throws InterruptedException { 272 return requestQueue.take(); 273 } 274 275 /** 276 * Awaits the next HTTP request (waiting up to the 277 * specified wait time if necessary), removes it, and returns it. Callers should 278 * use this to verify the request was sent as intended within the given time. 279 * 280 * @param timeout how long to wait before giving up, in units of 281 * {@code unit} 282 * @param unit a {@code TimeUnit} determining how to interpret the 283 * {@code timeout} parameter 284 * @return the head of the request queue 285 */ 286 public RecordedRequest takeRequest(long timeout, TimeUnit unit) throws InterruptedException { 287 return requestQueue.poll(timeout, unit); 288 } 289 290 /** 291 * Returns the number of HTTP requests received thus far by this server. This 292 * may exceed the number of HTTP connections when connection reuse is in 293 * practice. 294 */ 295 public int getRequestCount() { 296 return requestCount.get(); 297 } 298 299 /** 300 * Scripts {@code response} to be returned to a request made in sequence. The 301 * first request is served by the first enqueued response; the second request 302 * by the second enqueued response; and so on. 303 * 304 * @throws ClassCastException if the default dispatcher has been replaced 305 * with {@link #setDispatcher(Dispatcher)}. 306 */ 307 public void enqueue(MockResponse response) { 308 ((QueueDispatcher) dispatcher).enqueueResponse(response.clone()); 309 } 310 311 /** Equivalent to {@code start(0)}. */ 312 public void start() throws IOException { 313 start(0); 314 } 315 316 /** 317 * Starts the server on the loopback interface for the given port. 318 * 319 * @param port the port to listen to, or 0 for any available port. Automated 320 * tests should always use port 0 to avoid flakiness when a specific port 321 * is unavailable. 322 */ 323 public void start(int port) throws IOException { 324 start(InetAddress.getByName("localhost"), port); 325 } 326 327 /** 328 * Starts the server on the given address and port. 329 * 330 * @param inetAddress the address to create the server socket on 331 * 332 * @param port the port to listen to, or 0 for any available port. Automated 333 * tests should always use port 0 to avoid flakiness when a specific port 334 * is unavailable. 335 */ 336 public void start(InetAddress inetAddress, int port) throws IOException { 337 start(new InetSocketAddress(inetAddress, port)); 338 } 339 340 /** 341 * Starts the server and binds to the given socket address. 342 * 343 * @param inetSocketAddress the socket address to bind the server on 344 */ 345 private synchronized void start(InetSocketAddress inetSocketAddress) throws IOException { 346 if (started) throw new IllegalStateException("start() already called"); 347 started = true; 348 349 executor = Executors.newCachedThreadPool(Util.threadFactory("MockWebServer", false)); 350 this.inetSocketAddress = inetSocketAddress; 351 serverSocket = serverSocketFactory.createServerSocket(); 352 // Reuse if the user specified a port 353 serverSocket.setReuseAddress(inetSocketAddress.getPort() != 0); 354 serverSocket.bind(inetSocketAddress, 50); 355 356 port = serverSocket.getLocalPort(); 357 executor.execute(new NamedRunnable("MockWebServer %s", port) { 358 @Override protected void execute() { 359 try { 360 logger.info(MockWebServer.this + " starting to accept connections"); 361 acceptConnections(); 362 } catch (Throwable e) { 363 logger.log(Level.WARNING, MockWebServer.this + " failed unexpectedly", e); 364 } 365 366 // Release all sockets and all threads, even if any close fails. 367 Util.closeQuietly(serverSocket); 368 for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext(); ) { 369 Util.closeQuietly(s.next()); 370 s.remove(); 371 } 372 for (Iterator<FramedConnection> s = openFramedConnections.iterator(); s.hasNext(); ) { 373 Util.closeQuietly(s.next()); 374 s.remove(); 375 } 376 executor.shutdown(); 377 } 378 379 private void acceptConnections() throws Exception { 380 while (true) { 381 Socket socket; 382 try { 383 socket = serverSocket.accept(); 384 } catch (SocketException e) { 385 logger.info(MockWebServer.this + " done accepting connections: " + e.getMessage()); 386 return; 387 } 388 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 389 if (socketPolicy == DISCONNECT_AT_START) { 390 dispatchBookkeepingRequest(0, socket); 391 socket.close(); 392 } else { 393 openClientSockets.add(socket); 394 serveConnection(socket); 395 } 396 } 397 } 398 }); 399 } 400 401 public synchronized void shutdown() throws IOException { 402 if (!started) return; 403 if (serverSocket == null) throw new IllegalStateException("shutdown() before start()"); 404 405 // Cause acceptConnections() to break out. 406 serverSocket.close(); 407 408 // Await shutdown. 409 try { 410 if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { 411 throw new IOException("Gave up waiting for executor to shut down"); 412 } 413 } catch (InterruptedException e) { 414 throw new AssertionError(); 415 } 416 } 417 418 private void serveConnection(final Socket raw) { 419 executor.execute(new NamedRunnable("MockWebServer %s", raw.getRemoteSocketAddress()) { 420 int sequenceNumber = 0; 421 422 @Override protected void execute() { 423 try { 424 processConnection(); 425 } catch (IOException e) { 426 logger.info( 427 MockWebServer.this + " connection from " + raw.getInetAddress() + " failed: " + e); 428 } catch (Exception e) { 429 logger.log(Level.SEVERE, 430 MockWebServer.this + " connection from " + raw.getInetAddress() + " crashed", e); 431 } 432 } 433 434 public void processConnection() throws Exception { 435 Protocol protocol = Protocol.HTTP_1_1; 436 Socket socket; 437 if (sslSocketFactory != null) { 438 if (tunnelProxy) { 439 createTunnel(); 440 } 441 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 442 if (socketPolicy == FAIL_HANDSHAKE) { 443 dispatchBookkeepingRequest(sequenceNumber, raw); 444 processHandshakeFailure(raw); 445 return; 446 } 447 socket = sslSocketFactory.createSocket(raw, raw.getInetAddress().getHostAddress(), 448 raw.getPort(), true); 449 SSLSocket sslSocket = (SSLSocket) socket; 450 sslSocket.setUseClientMode(false); 451 openClientSockets.add(socket); 452 453 if (protocolNegotiationEnabled) { 454 Platform.get().configureTlsExtensions(sslSocket, null, protocols); 455 } 456 457 sslSocket.startHandshake(); 458 459 if (protocolNegotiationEnabled) { 460 String protocolString = Platform.get().getSelectedProtocol(sslSocket); 461 protocol = protocolString != null ? Protocol.get(protocolString) : Protocol.HTTP_1_1; 462 } 463 openClientSockets.remove(raw); 464 } else { 465 socket = raw; 466 } 467 468 if (protocol != Protocol.HTTP_1_1) { 469 FramedSocketHandler framedSocketListener = new FramedSocketHandler(socket, protocol); 470 FramedConnection framedConnection = new FramedConnection.Builder(false) 471 .socket(socket) 472 .protocol(protocol) 473 .listener(framedSocketListener) 474 .build(); 475 openFramedConnections.add(framedConnection); 476 openClientSockets.remove(socket); 477 return; 478 } 479 480 BufferedSource source = Okio.buffer(Okio.source(socket)); 481 BufferedSink sink = Okio.buffer(Okio.sink(socket)); 482 483 while (processOneRequest(socket, source, sink)) { 484 } 485 486 if (sequenceNumber == 0) { 487 logger.warning(MockWebServer.this 488 + " connection from " 489 + raw.getInetAddress() 490 + " didn't make a request"); 491 } 492 493 source.close(); 494 sink.close(); 495 socket.close(); 496 openClientSockets.remove(socket); 497 } 498 499 /** 500 * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response is 501 * dispatched. 502 */ 503 private void createTunnel() throws IOException, InterruptedException { 504 BufferedSource source = Okio.buffer(Okio.source(raw)); 505 BufferedSink sink = Okio.buffer(Okio.sink(raw)); 506 while (true) { 507 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 508 if (!processOneRequest(raw, source, sink)) { 509 throw new IllegalStateException("Tunnel without any CONNECT!"); 510 } 511 if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return; 512 } 513 } 514 515 /** 516 * Reads a request and writes its response. Returns true if further calls should be attempted 517 * on the socket. 518 */ 519 private boolean processOneRequest(Socket socket, BufferedSource source, BufferedSink sink) 520 throws IOException, InterruptedException { 521 RecordedRequest request = readRequest(socket, source, sink, sequenceNumber); 522 if (request == null) return false; 523 524 requestCount.incrementAndGet(); 525 requestQueue.add(request); 526 527 MockResponse response = dispatcher.dispatch(request); 528 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AFTER_REQUEST) { 529 socket.close(); 530 return false; 531 } 532 if (response.getSocketPolicy() == SocketPolicy.NO_RESPONSE) { 533 // This read should block until the socket is closed. (Because nobody is writing.) 534 if (source.exhausted()) return false; 535 throw new ProtocolException("unexpected data"); 536 } 537 538 boolean reuseSocket = true; 539 boolean requestWantsWebSockets = "Upgrade".equalsIgnoreCase(request.getHeader("Connection")) 540 && "websocket".equalsIgnoreCase(request.getHeader("Upgrade")); 541 boolean responseWantsWebSockets = response.getWebSocketListener() != null; 542 if (requestWantsWebSockets && responseWantsWebSockets) { 543 handleWebSocketUpgrade(socket, source, sink, request, response); 544 reuseSocket = false; 545 } else { 546 writeHttpResponse(socket, sink, response); 547 } 548 549 if (logger.isLoggable(Level.INFO)) { 550 logger.info(MockWebServer.this + " received request: " + request 551 + " and responded: " + response); 552 } 553 554 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) { 555 socket.close(); 556 return false; 557 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) { 558 socket.shutdownInput(); 559 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) { 560 socket.shutdownOutput(); 561 } 562 563 sequenceNumber++; 564 return reuseSocket; 565 } 566 }); 567 } 568 569 private void processHandshakeFailure(Socket raw) throws Exception { 570 SSLContext context = SSLContext.getInstance("TLS"); 571 context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom()); 572 SSLSocketFactory sslSocketFactory = context.getSocketFactory(); 573 SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket( 574 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 575 try { 576 socket.startHandshake(); // we're testing a handshake failure 577 throw new AssertionError(); 578 } catch (IOException expected) { 579 } 580 socket.close(); 581 } 582 583 private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket) 584 throws InterruptedException { 585 requestCount.incrementAndGet(); 586 dispatcher.dispatch(new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket)); 587 } 588 589 /** @param sequenceNumber the index of this request on this connection. */ 590 private RecordedRequest readRequest(Socket socket, BufferedSource source, BufferedSink sink, 591 int sequenceNumber) throws IOException { 592 String request; 593 try { 594 request = source.readUtf8LineStrict(); 595 } catch (IOException streamIsClosed) { 596 return null; // no request because we closed the stream 597 } 598 if (request.length() == 0) { 599 return null; // no request because the stream is exhausted 600 } 601 602 Headers.Builder headers = new Headers.Builder(); 603 long contentLength = -1; 604 boolean chunked = false; 605 boolean expectContinue = false; 606 String header; 607 while ((header = source.readUtf8LineStrict()).length() != 0) { 608 headers.add(header); 609 String lowercaseHeader = header.toLowerCase(Locale.US); 610 if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { 611 contentLength = Long.parseLong(header.substring(15).trim()); 612 } 613 if (lowercaseHeader.startsWith("transfer-encoding:") 614 && lowercaseHeader.substring(18).trim().equals("chunked")) { 615 chunked = true; 616 } 617 if (lowercaseHeader.startsWith("expect:") 618 && lowercaseHeader.substring(7).trim().equals("100-continue")) { 619 expectContinue = true; 620 } 621 } 622 623 if (expectContinue) { 624 sink.writeUtf8("HTTP/1.1 100 Continue\r\n"); 625 sink.writeUtf8("Content-Length: 0\r\n"); 626 sink.writeUtf8("\r\n"); 627 sink.flush(); 628 } 629 630 boolean hasBody = false; 631 TruncatingBuffer requestBody = new TruncatingBuffer(bodyLimit); 632 List<Integer> chunkSizes = new ArrayList<>(); 633 MockResponse policy = dispatcher.peek(); 634 if (contentLength != -1) { 635 hasBody = contentLength > 0; 636 throttledTransfer(policy, socket, source, Okio.buffer(requestBody), contentLength, true); 637 } else if (chunked) { 638 hasBody = true; 639 while (true) { 640 int chunkSize = Integer.parseInt(source.readUtf8LineStrict().trim(), 16); 641 if (chunkSize == 0) { 642 readEmptyLine(source); 643 break; 644 } 645 chunkSizes.add(chunkSize); 646 throttledTransfer(policy, socket, source, Okio.buffer(requestBody), chunkSize, true); 647 readEmptyLine(source); 648 } 649 } 650 651 String method = request.substring(0, request.indexOf(' ')); 652 if (hasBody && !HttpMethod.permitsRequestBody(method)) { 653 throw new IllegalArgumentException("Request must not have a body: " + request); 654 } 655 656 return new RecordedRequest(request, headers.build(), chunkSizes, requestBody.receivedByteCount, 657 requestBody.buffer, sequenceNumber, socket); 658 } 659 660 private void handleWebSocketUpgrade(Socket socket, BufferedSource source, BufferedSink sink, 661 RecordedRequest request, MockResponse response) throws IOException { 662 String key = request.getHeader("Sec-WebSocket-Key"); 663 String acceptKey = Util.shaBase64(key + WebSocketProtocol.ACCEPT_MAGIC); 664 response.setHeader("Sec-WebSocket-Accept", acceptKey); 665 666 writeHttpResponse(socket, sink, response); 667 668 final WebSocketListener listener = response.getWebSocketListener(); 669 final CountDownLatch connectionClose = new CountDownLatch(1); 670 671 ThreadPoolExecutor replyExecutor = 672 new ThreadPoolExecutor(1, 1, 1, SECONDS, new LinkedBlockingDeque<Runnable>(), 673 Util.threadFactory(String.format("MockWebServer %s WebSocket", request.getPath()), 674 true)); 675 replyExecutor.allowCoreThreadTimeOut(true); 676 final RealWebSocket webSocket = 677 new RealWebSocket(false /* is server */, source, sink, new SecureRandom(), replyExecutor, 678 listener, request.getPath()) { 679 @Override protected void close() throws IOException { 680 connectionClose.countDown(); 681 } 682 }; 683 684 // Adapt the request and response into our Request and Response domain model. 685 String scheme = request.getTlsVersion() != null ? "https" : "http"; 686 String authority = request.getHeader("Host"); // Has host and port. 687 final Request fancyRequest = new Request.Builder() 688 .url(scheme + "://" + authority + "/") 689 .headers(request.getHeaders()) 690 .build(); 691 final Response fancyResponse = new Response.Builder() 692 .code(Integer.parseInt(response.getStatus().split(" ")[1])) 693 .message(response.getStatus().split(" ", 3)[2]) 694 .headers(response.getHeaders()) 695 .request(fancyRequest) 696 .protocol(Protocol.HTTP_1_1) 697 .build(); 698 699 listener.onOpen(webSocket, fancyResponse); 700 701 while (webSocket.readMessage()) { 702 } 703 704 // Even if messages are no longer being read we need to wait for the connection close signal. 705 try { 706 connectionClose.await(); 707 } catch (InterruptedException e) { 708 throw new RuntimeException(e); 709 } 710 711 replyExecutor.shutdown(); 712 Util.closeQuietly(sink); 713 Util.closeQuietly(source); 714 } 715 716 private void writeHttpResponse(Socket socket, BufferedSink sink, MockResponse response) 717 throws IOException { 718 sink.writeUtf8(response.getStatus()); 719 sink.writeUtf8("\r\n"); 720 721 Headers headers = response.getHeaders(); 722 for (int i = 0, size = headers.size(); i < size; i++) { 723 sink.writeUtf8(headers.name(i)); 724 sink.writeUtf8(": "); 725 sink.writeUtf8(headers.value(i)); 726 sink.writeUtf8("\r\n"); 727 } 728 sink.writeUtf8("\r\n"); 729 sink.flush(); 730 731 Buffer body = response.getBody(); 732 if (body == null) return; 733 sleepIfDelayed(response); 734 throttledTransfer(response, socket, body, sink, body.size(), false); 735 } 736 737 private void sleepIfDelayed(MockResponse response) { 738 long delayMs = response.getBodyDelay(TimeUnit.MILLISECONDS); 739 if (delayMs != 0) { 740 try { 741 Thread.sleep(delayMs); 742 } catch (InterruptedException e) { 743 throw new AssertionError(e); 744 } 745 } 746 } 747 748 /** 749 * Transfer bytes from {@code source} to {@code sink} until either {@code byteCount} 750 * bytes have been transferred or {@code source} is exhausted. The transfer is 751 * throttled according to {@code policy}. 752 */ 753 private void throttledTransfer(MockResponse policy, Socket socket, BufferedSource source, 754 BufferedSink sink, long byteCount, boolean isRequest) throws IOException { 755 if (byteCount == 0) return; 756 757 Buffer buffer = new Buffer(); 758 long bytesPerPeriod = policy.getThrottleBytesPerPeriod(); 759 long periodDelayMs = policy.getThrottlePeriod(TimeUnit.MILLISECONDS); 760 761 long halfByteCount = byteCount / 2; 762 boolean disconnectHalfway = isRequest 763 ? policy.getSocketPolicy() == DISCONNECT_DURING_REQUEST_BODY 764 : policy.getSocketPolicy() == DISCONNECT_DURING_RESPONSE_BODY; 765 766 while (!socket.isClosed()) { 767 for (int b = 0; b < bytesPerPeriod; ) { 768 // Ensure we do not read past the allotted bytes in this period. 769 long toRead = Math.min(byteCount, bytesPerPeriod - b); 770 // Ensure we do not read past halfway if the policy will kill the connection. 771 if (disconnectHalfway) { 772 toRead = Math.min(toRead, byteCount - halfByteCount); 773 } 774 775 long read = source.read(buffer, toRead); 776 if (read == -1) return; 777 778 sink.write(buffer, read); 779 sink.flush(); 780 b += read; 781 byteCount -= read; 782 783 if (disconnectHalfway && byteCount == halfByteCount) { 784 socket.close(); 785 return; 786 } 787 788 if (byteCount == 0) return; 789 } 790 791 if (periodDelayMs != 0) { 792 try { 793 Thread.sleep(periodDelayMs); 794 } catch (InterruptedException e) { 795 throw new AssertionError(); 796 } 797 } 798 } 799 } 800 801 private void readEmptyLine(BufferedSource source) throws IOException { 802 String line = source.readUtf8LineStrict(); 803 if (line.length() != 0) throw new IllegalStateException("Expected empty but was: " + line); 804 } 805 806 /** 807 * Sets the dispatcher used to match incoming requests to mock responses. 808 * The default dispatcher simply serves a fixed sequence of responses from 809 * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the 810 * response based on timing or the content of the request. 811 */ 812 public void setDispatcher(Dispatcher dispatcher) { 813 if (dispatcher == null) throw new NullPointerException(); 814 this.dispatcher = dispatcher; 815 } 816 817 @Override public String toString() { 818 return "MockWebServer[" + port + "]"; 819 } 820 821 /** A buffer wrapper that drops data after {@code bodyLimit} bytes. */ 822 private static class TruncatingBuffer implements Sink { 823 private final Buffer buffer = new Buffer(); 824 private long remainingByteCount; 825 private long receivedByteCount; 826 827 TruncatingBuffer(long bodyLimit) { 828 remainingByteCount = bodyLimit; 829 } 830 831 @Override public void write(Buffer source, long byteCount) throws IOException { 832 long toRead = Math.min(remainingByteCount, byteCount); 833 if (toRead > 0) { 834 source.read(buffer, toRead); 835 } 836 long toSkip = byteCount - toRead; 837 if (toSkip > 0) { 838 source.skip(toSkip); 839 } 840 remainingByteCount -= toRead; 841 receivedByteCount += byteCount; 842 } 843 844 @Override public void flush() throws IOException { 845 } 846 847 @Override public Timeout timeout() { 848 return Timeout.NONE; 849 } 850 851 @Override public void close() throws IOException { 852 } 853 } 854 855 /** Processes HTTP requests layered over framed protocols. */ 856 private class FramedSocketHandler extends FramedConnection.Listener { 857 private final Socket socket; 858 private final Protocol protocol; 859 private final AtomicInteger sequenceNumber = new AtomicInteger(); 860 861 private FramedSocketHandler(Socket socket, Protocol protocol) { 862 this.socket = socket; 863 this.protocol = protocol; 864 } 865 866 @Override public void onStream(FramedStream stream) throws IOException { 867 RecordedRequest request = readRequest(stream); 868 requestQueue.add(request); 869 MockResponse response; 870 try { 871 response = dispatcher.dispatch(request); 872 } catch (InterruptedException e) { 873 throw new AssertionError(e); 874 } 875 writeResponse(stream, response); 876 if (logger.isLoggable(Level.INFO)) { 877 logger.info(MockWebServer.this + " received request: " + request 878 + " and responded: " + response + " protocol is " + protocol.toString()); 879 } 880 } 881 882 private RecordedRequest readRequest(FramedStream stream) throws IOException { 883 List<Header> streamHeaders = stream.getRequestHeaders(); 884 Headers.Builder httpHeaders = new Headers.Builder(); 885 String method = "<:method omitted>"; 886 String path = "<:path omitted>"; 887 String version = protocol == Protocol.SPDY_3 ? "<:version omitted>" : "HTTP/1.1"; 888 for (int i = 0, size = streamHeaders.size(); i < size; i++) { 889 ByteString name = streamHeaders.get(i).name; 890 String value = streamHeaders.get(i).value.utf8(); 891 if (name.equals(Header.TARGET_METHOD)) { 892 method = value; 893 } else if (name.equals(Header.TARGET_PATH)) { 894 path = value; 895 } else if (name.equals(Header.VERSION)) { 896 version = value; 897 } else if (protocol == Protocol.SPDY_3) { 898 for (String s : value.split("\u0000", -1)) { 899 httpHeaders.add(name.utf8(), s); 900 } 901 } else if (protocol == Protocol.HTTP_2) { 902 httpHeaders.add(name.utf8(), value); 903 } else { 904 throw new IllegalStateException(); 905 } 906 } 907 908 Buffer body = new Buffer(); 909 body.writeAll(stream.getSource()); 910 body.close(); 911 912 String requestLine = method + ' ' + path + ' ' + version; 913 List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY. 914 return new RecordedRequest(requestLine, httpHeaders.build(), chunkSizes, body.size(), body, 915 sequenceNumber.getAndIncrement(), socket); 916 } 917 918 private void writeResponse(FramedStream stream, MockResponse response) throws IOException { 919 Settings settings = response.getSettings(); 920 if (settings != null) { 921 stream.getConnection().setSettings(settings); 922 } 923 924 if (response.getSocketPolicy() == SocketPolicy.NO_RESPONSE) { 925 return; 926 } 927 List<Header> spdyHeaders = new ArrayList<>(); 928 String[] statusParts = response.getStatus().split(" ", 2); 929 if (statusParts.length != 2) { 930 throw new AssertionError("Unexpected status: " + response.getStatus()); 931 } 932 // TODO: constants for well-known header names. 933 spdyHeaders.add(new Header(Header.RESPONSE_STATUS, statusParts[1])); 934 if (protocol == Protocol.SPDY_3) { 935 spdyHeaders.add(new Header(Header.VERSION, statusParts[0])); 936 } 937 Headers headers = response.getHeaders(); 938 for (int i = 0, size = headers.size(); i < size; i++) { 939 spdyHeaders.add(new Header(headers.name(i), headers.value(i))); 940 } 941 942 Buffer body = response.getBody(); 943 boolean closeStreamAfterHeaders = body != null || !response.getPushPromises().isEmpty(); 944 stream.reply(spdyHeaders, closeStreamAfterHeaders); 945 pushPromises(stream, response.getPushPromises()); 946 if (body != null) { 947 BufferedSink sink = Okio.buffer(stream.getSink()); 948 sleepIfDelayed(response); 949 throttledTransfer(response, socket, body, sink, bodyLimit, false); 950 sink.close(); 951 } else if (closeStreamAfterHeaders) { 952 stream.close(ErrorCode.NO_ERROR); 953 } 954 } 955 956 private void pushPromises(FramedStream stream, List<PushPromise> promises) throws IOException { 957 for (PushPromise pushPromise : promises) { 958 List<Header> pushedHeaders = new ArrayList<>(); 959 pushedHeaders.add(new Header(stream.getConnection().getProtocol() == Protocol.SPDY_3 960 ? Header.TARGET_HOST 961 : Header.TARGET_AUTHORITY, url(pushPromise.getPath()).host())); 962 pushedHeaders.add(new Header(Header.TARGET_METHOD, pushPromise.getMethod())); 963 pushedHeaders.add(new Header(Header.TARGET_PATH, pushPromise.getPath())); 964 Headers pushPromiseHeaders = pushPromise.getHeaders(); 965 for (int i = 0, size = pushPromiseHeaders.size(); i < size; i++) { 966 pushedHeaders.add(new Header(pushPromiseHeaders.name(i), pushPromiseHeaders.value(i))); 967 } 968 String requestLine = pushPromise.getMethod() + ' ' + pushPromise.getPath() + " HTTP/1.1"; 969 List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY. 970 requestQueue.add(new RecordedRequest(requestLine, pushPromise.getHeaders(), chunkSizes, 0, 971 new Buffer(), sequenceNumber.getAndIncrement(), socket)); 972 boolean hasBody = pushPromise.getResponse().getBody() != null; 973 FramedStream pushedStream = 974 stream.getConnection().pushStream(stream.getId(), pushedHeaders, hasBody); 975 writeResponse(pushedStream, pushPromise.getResponse()); 976 } 977 } 978 } 979 } 980