1 /* 2 * Copyright (C) 2011 Google Inc. 3 * Copyright (C) 2013 Square, Inc. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package com.squareup.okhttp.mockwebserver; 19 20 import com.squareup.okhttp.Headers; 21 import com.squareup.okhttp.Protocol; 22 import com.squareup.okhttp.Request; 23 import com.squareup.okhttp.Response; 24 import com.squareup.okhttp.internal.NamedRunnable; 25 import com.squareup.okhttp.internal.Platform; 26 import com.squareup.okhttp.internal.Util; 27 import com.squareup.okhttp.internal.spdy.ErrorCode; 28 import com.squareup.okhttp.internal.spdy.Header; 29 import com.squareup.okhttp.internal.spdy.IncomingStreamHandler; 30 import com.squareup.okhttp.internal.spdy.SpdyConnection; 31 import com.squareup.okhttp.internal.spdy.SpdyStream; 32 import com.squareup.okhttp.internal.ws.RealWebSocket; 33 import com.squareup.okhttp.internal.ws.WebSocketProtocol; 34 import com.squareup.okhttp.ws.WebSocketListener; 35 import java.io.IOException; 36 import java.net.InetAddress; 37 import java.net.InetSocketAddress; 38 import java.net.MalformedURLException; 39 import java.net.ProtocolException; 40 import java.net.Proxy; 41 import java.net.ServerSocket; 42 import java.net.Socket; 43 import java.net.SocketException; 44 import java.net.URL; 45 import java.security.SecureRandom; 46 import java.security.cert.CertificateException; 47 import java.security.cert.X509Certificate; 48 import java.util.ArrayList; 49 import java.util.Collections; 50 import java.util.Iterator; 51 import java.util.List; 52 import java.util.Locale; 53 import java.util.Set; 54 import java.util.concurrent.BlockingQueue; 55 import java.util.concurrent.ConcurrentHashMap; 56 import java.util.concurrent.CountDownLatch; 57 import java.util.concurrent.ExecutorService; 58 import java.util.concurrent.Executors; 59 import java.util.concurrent.LinkedBlockingDeque; 60 import java.util.concurrent.LinkedBlockingQueue; 61 import java.util.concurrent.ThreadPoolExecutor; 62 import java.util.concurrent.TimeUnit; 63 import java.util.concurrent.atomic.AtomicInteger; 64 import java.util.logging.Level; 65 import java.util.logging.Logger; 66 import javax.net.ServerSocketFactory; 67 import javax.net.ssl.SSLContext; 68 import javax.net.ssl.SSLSocket; 69 import javax.net.ssl.SSLSocketFactory; 70 import javax.net.ssl.TrustManager; 71 import javax.net.ssl.X509TrustManager; 72 import okio.Buffer; 73 import okio.BufferedSink; 74 import okio.BufferedSource; 75 import okio.ByteString; 76 import okio.Okio; 77 import okio.Sink; 78 import okio.Timeout; 79 80 import static com.squareup.okhttp.mockwebserver.SocketPolicy.DISCONNECT_AT_START; 81 import static com.squareup.okhttp.mockwebserver.SocketPolicy.FAIL_HANDSHAKE; 82 import static java.util.concurrent.TimeUnit.SECONDS; 83 84 /** 85 * A scriptable web server. Callers supply canned responses and the server 86 * replays them upon request in sequence. 87 */ 88 public final class MockWebServer { 89 private static final X509TrustManager UNTRUSTED_TRUST_MANAGER = new X509TrustManager() { 90 @Override public void checkClientTrusted(X509Certificate[] chain, String authType) 91 throws CertificateException { 92 throw new CertificateException(); 93 } 94 95 @Override public void checkServerTrusted(X509Certificate[] chain, String authType) { 96 throw new AssertionError(); 97 } 98 99 @Override public X509Certificate[] getAcceptedIssuers() { 100 throw new AssertionError(); 101 } 102 }; 103 104 private static final Logger logger = Logger.getLogger(MockWebServer.class.getName()); 105 106 private final BlockingQueue<RecordedRequest> requestQueue = new LinkedBlockingQueue<>(); 107 108 private final Set<Socket> openClientSockets = 109 Collections.newSetFromMap(new ConcurrentHashMap<Socket, Boolean>()); 110 private final Set<SpdyConnection> openSpdyConnections = 111 Collections.newSetFromMap(new ConcurrentHashMap<SpdyConnection, Boolean>()); 112 private final AtomicInteger requestCount = new AtomicInteger(); 113 private long bodyLimit = Long.MAX_VALUE; 114 private ServerSocketFactory serverSocketFactory = ServerSocketFactory.getDefault(); 115 private ServerSocket serverSocket; 116 private SSLSocketFactory sslSocketFactory; 117 private ExecutorService executor; 118 private boolean tunnelProxy; 119 private Dispatcher dispatcher = new QueueDispatcher(); 120 121 private int port = -1; 122 private InetSocketAddress inetSocketAddress; 123 private boolean protocolNegotiationEnabled = true; 124 private List<Protocol> protocols 125 = Util.immutableList(Protocol.HTTP_2, Protocol.SPDY_3, Protocol.HTTP_1_1); 126 127 public void setServerSocketFactory(ServerSocketFactory serverSocketFactory) { 128 if (serverSocketFactory == null) throw new IllegalArgumentException("null serverSocketFactory"); 129 this.serverSocketFactory = serverSocketFactory; 130 } 131 132 public int getPort() { 133 if (port == -1) throw new IllegalStateException("Call start() before getPort()"); 134 return port; 135 } 136 137 public String getHostName() { 138 if (inetSocketAddress == null) { 139 throw new IllegalStateException("Call start() before getHostName()"); 140 } 141 return inetSocketAddress.getHostName(); 142 } 143 144 public Proxy toProxyAddress() { 145 if (inetSocketAddress == null) { 146 throw new IllegalStateException("Call start() before toProxyAddress()"); 147 } 148 InetSocketAddress address = new InetSocketAddress(inetSocketAddress.getAddress(), getPort()); 149 return new Proxy(Proxy.Type.HTTP, address); 150 } 151 152 /** 153 * Returns a URL for connecting to this server. 154 * @param path the request path, such as "/". 155 */ 156 public URL getUrl(String path) { 157 try { 158 return sslSocketFactory != null 159 ? new URL("https://" + getHostName() + ":" + getPort() + path) 160 : new URL("http://" + getHostName() + ":" + getPort() + path); 161 } catch (MalformedURLException e) { 162 throw new AssertionError(e); 163 } 164 } 165 166 /** 167 * Returns a cookie domain for this server. This returns the server's 168 * non-loopback host name if it is known. Otherwise this returns ".local" for 169 * this server's loopback name. 170 */ 171 public String getCookieDomain() { 172 String hostName = getHostName(); 173 return hostName.contains(".") ? hostName : ".local"; 174 } 175 176 /** 177 * Sets the number of bytes of the POST body to keep in memory to the given 178 * limit. 179 */ 180 public void setBodyLimit(long maxBodyLength) { 181 this.bodyLimit = maxBodyLength; 182 } 183 184 /** 185 * Sets whether ALPN is used on incoming HTTPS connections to 186 * negotiate a protocol like HTTP/1.1 or HTTP/2. Call this method to disable 187 * negotiation and restrict connections to HTTP/1.1. 188 */ 189 public void setProtocolNegotiationEnabled(boolean protocolNegotiationEnabled) { 190 this.protocolNegotiationEnabled = protocolNegotiationEnabled; 191 } 192 193 /** 194 * Indicates the protocols supported by ALPN on incoming HTTPS 195 * connections. This list is ignored when 196 * {@link #setProtocolNegotiationEnabled negotiation is disabled}. 197 * 198 * @param protocols the protocols to use, in order of preference. The list 199 * must contain {@linkplain Protocol#HTTP_1_1}. It must not contain null. 200 */ 201 public void setProtocols(List<Protocol> protocols) { 202 protocols = Util.immutableList(protocols); 203 if (!protocols.contains(Protocol.HTTP_1_1)) { 204 throw new IllegalArgumentException("protocols doesn't contain http/1.1: " + protocols); 205 } 206 if (protocols.contains(null)) { 207 throw new IllegalArgumentException("protocols must not contain null"); 208 } 209 this.protocols = protocols; 210 } 211 212 /** 213 * Serve requests with HTTPS rather than otherwise. 214 * @param tunnelProxy true to expect the HTTP CONNECT method before 215 * negotiating TLS. 216 */ 217 public void useHttps(SSLSocketFactory sslSocketFactory, boolean tunnelProxy) { 218 this.sslSocketFactory = sslSocketFactory; 219 this.tunnelProxy = tunnelProxy; 220 } 221 222 /** 223 * Awaits the next HTTP request, removes it, and returns it. Callers should 224 * use this to verify the request was sent as intended. This method will block until the 225 * request is available, possibly forever. 226 * 227 * @return the head of the request queue 228 */ 229 public RecordedRequest takeRequest() throws InterruptedException { 230 return requestQueue.take(); 231 } 232 233 /** 234 * Awaits the next HTTP request (waiting up to the 235 * specified wait time if necessary), removes it, and returns it. Callers should 236 * use this to verify the request was sent as intended within the given time. 237 * 238 * @param timeout how long to wait before giving up, in units of 239 * {@code unit} 240 * @param unit a {@code TimeUnit} determining how to interpret the 241 * {@code timeout} parameter 242 * @return the head of the request queue 243 */ 244 public RecordedRequest takeRequest(long timeout, TimeUnit unit) throws InterruptedException { 245 return requestQueue.poll(timeout, unit); 246 } 247 248 /** 249 * Returns the number of HTTP requests received thus far by this server. This 250 * may exceed the number of HTTP connections when connection reuse is in 251 * practice. 252 */ 253 public int getRequestCount() { 254 return requestCount.get(); 255 } 256 257 /** 258 * Scripts {@code response} to be returned to a request made in sequence. The 259 * first request is served by the first enqueued response; the second request 260 * by the second enqueued response; and so on. 261 * 262 * @throws ClassCastException if the default dispatcher has been replaced 263 * with {@link #setDispatcher(Dispatcher)}. 264 */ 265 public void enqueue(MockResponse response) { 266 ((QueueDispatcher) dispatcher).enqueueResponse(response.clone()); 267 } 268 269 /** @deprecated Use {@link #start()}. */ 270 public void play() throws IOException { 271 start(); 272 } 273 274 /** @deprecated Use {@link #start(int)}. */ 275 public void play(int port) throws IOException { 276 start(port); 277 } 278 279 /** Equivalent to {@code start(0)}. */ 280 public void start() throws IOException { 281 start(0); 282 } 283 284 /** 285 * Starts the server on the loopback interface for the given port. 286 * 287 * @param port the port to listen to, or 0 for any available port. Automated 288 * tests should always use port 0 to avoid flakiness when a specific port 289 * is unavailable. 290 */ 291 public void start(int port) throws IOException { 292 start(InetAddress.getByName("localhost"), port); 293 } 294 295 /** 296 * Starts the server on the given address and port. 297 * 298 * @param inetAddress the address to create the server socket on 299 * 300 * @param port the port to listen to, or 0 for any available port. Automated 301 * tests should always use port 0 to avoid flakiness when a specific port 302 * is unavailable. 303 */ 304 public void start(InetAddress inetAddress, int port) throws IOException { 305 start(new InetSocketAddress(inetAddress, port)); 306 } 307 308 /** 309 * Starts the server and binds to the given socket address. 310 * 311 * @param inetSocketAddress the socket address to bind the server on 312 */ 313 private void start(InetSocketAddress inetSocketAddress) throws IOException { 314 if (executor != null) throw new IllegalStateException("start() already called"); 315 executor = Executors.newCachedThreadPool(Util.threadFactory("MockWebServer", false)); 316 this.inetSocketAddress = inetSocketAddress; 317 serverSocket = serverSocketFactory.createServerSocket(); 318 // Reuse if the user specified a port 319 serverSocket.setReuseAddress(inetSocketAddress.getPort() != 0); 320 serverSocket.bind(inetSocketAddress, 50); 321 322 port = serverSocket.getLocalPort(); 323 executor.execute(new NamedRunnable("MockWebServer %s", port) { 324 @Override protected void execute() { 325 try { 326 logger.info(MockWebServer.this + " starting to accept connections"); 327 acceptConnections(); 328 } catch (Throwable e) { 329 logger.log(Level.WARNING, MockWebServer.this + " failed unexpectedly", e); 330 } 331 332 // Release all sockets and all threads, even if any close fails. 333 Util.closeQuietly(serverSocket); 334 for (Iterator<Socket> s = openClientSockets.iterator(); s.hasNext(); ) { 335 Util.closeQuietly(s.next()); 336 s.remove(); 337 } 338 for (Iterator<SpdyConnection> s = openSpdyConnections.iterator(); s.hasNext(); ) { 339 Util.closeQuietly(s.next()); 340 s.remove(); 341 } 342 executor.shutdown(); 343 } 344 345 private void acceptConnections() throws Exception { 346 while (true) { 347 Socket socket; 348 try { 349 socket = serverSocket.accept(); 350 } catch (SocketException e) { 351 logger.info(MockWebServer.this + " done accepting connections: " + e.getMessage()); 352 return; 353 } 354 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 355 if (socketPolicy == DISCONNECT_AT_START) { 356 dispatchBookkeepingRequest(0, socket); 357 socket.close(); 358 } else { 359 openClientSockets.add(socket); 360 serveConnection(socket); 361 } 362 } 363 } 364 }); 365 } 366 367 public void shutdown() throws IOException { 368 if (serverSocket == null) throw new IllegalStateException("shutdown() before start()"); 369 370 // Cause acceptConnections() to break out. 371 serverSocket.close(); 372 373 // Await shutdown. 374 try { 375 if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { 376 throw new IOException("Gave up waiting for executor to shut down"); 377 } 378 } catch (InterruptedException e) { 379 throw new AssertionError(); 380 } 381 } 382 383 private void serveConnection(final Socket raw) { 384 executor.execute(new NamedRunnable("MockWebServer %s", raw.getRemoteSocketAddress()) { 385 int sequenceNumber = 0; 386 387 @Override protected void execute() { 388 try { 389 processConnection(); 390 } catch (IOException e) { 391 logger.info( 392 MockWebServer.this + " connection from " + raw.getInetAddress() + " failed: " + e); 393 } catch (Exception e) { 394 logger.log(Level.SEVERE, 395 MockWebServer.this + " connection from " + raw.getInetAddress() + " crashed", e); 396 } 397 } 398 399 public void processConnection() throws Exception { 400 Protocol protocol = Protocol.HTTP_1_1; 401 Socket socket; 402 if (sslSocketFactory != null) { 403 if (tunnelProxy) { 404 createTunnel(); 405 } 406 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 407 if (socketPolicy == FAIL_HANDSHAKE) { 408 dispatchBookkeepingRequest(sequenceNumber, raw); 409 processHandshakeFailure(raw); 410 return; 411 } 412 socket = sslSocketFactory.createSocket(raw, raw.getInetAddress().getHostAddress(), 413 raw.getPort(), true); 414 SSLSocket sslSocket = (SSLSocket) socket; 415 sslSocket.setUseClientMode(false); 416 openClientSockets.add(socket); 417 418 if (protocolNegotiationEnabled) { 419 Platform.get().configureTlsExtensions(sslSocket, null, protocols); 420 } 421 422 sslSocket.startHandshake(); 423 424 if (protocolNegotiationEnabled) { 425 String protocolString = Platform.get().getSelectedProtocol(sslSocket); 426 protocol = protocolString != null ? Protocol.get(protocolString) : Protocol.HTTP_1_1; 427 } 428 openClientSockets.remove(raw); 429 } else { 430 socket = raw; 431 } 432 433 if (protocol != Protocol.HTTP_1_1) { 434 SpdySocketHandler spdySocketHandler = new SpdySocketHandler(socket, protocol); 435 SpdyConnection spdyConnection = 436 new SpdyConnection.Builder(false, socket).protocol(protocol) 437 .handler(spdySocketHandler) 438 .build(); 439 openSpdyConnections.add(spdyConnection); 440 openClientSockets.remove(socket); 441 return; 442 } 443 444 BufferedSource source = Okio.buffer(Okio.source(socket)); 445 BufferedSink sink = Okio.buffer(Okio.sink(socket)); 446 447 while (processOneRequest(socket, source, sink)) { 448 } 449 450 if (sequenceNumber == 0) { 451 logger.warning(MockWebServer.this 452 + " connection from " 453 + raw.getInetAddress() 454 + " didn't make a request"); 455 } 456 457 source.close(); 458 sink.close(); 459 socket.close(); 460 openClientSockets.remove(socket); 461 } 462 463 /** 464 * Respond to CONNECT requests until a SWITCH_TO_SSL_AT_END response is 465 * dispatched. 466 */ 467 private void createTunnel() throws IOException, InterruptedException { 468 BufferedSource source = Okio.buffer(Okio.source(raw)); 469 BufferedSink sink = Okio.buffer(Okio.sink(raw)); 470 while (true) { 471 SocketPolicy socketPolicy = dispatcher.peek().getSocketPolicy(); 472 if (!processOneRequest(raw, source, sink)) { 473 throw new IllegalStateException("Tunnel without any CONNECT!"); 474 } 475 if (socketPolicy == SocketPolicy.UPGRADE_TO_SSL_AT_END) return; 476 } 477 } 478 479 /** 480 * Reads a request and writes its response. Returns true if further calls should be attempted 481 * on the socket. 482 */ 483 private boolean processOneRequest(Socket socket, BufferedSource source, BufferedSink sink) 484 throws IOException, InterruptedException { 485 RecordedRequest request = readRequest(socket, source, sink, sequenceNumber); 486 if (request == null) return false; 487 488 requestCount.incrementAndGet(); 489 requestQueue.add(request); 490 491 MockResponse response = dispatcher.dispatch(request); 492 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AFTER_REQUEST) { 493 socket.close(); 494 return false; 495 } 496 if (response.getSocketPolicy() == SocketPolicy.NO_RESPONSE) { 497 // This read should block until the socket is closed. (Because nobody is writing.) 498 if (source.exhausted()) return false; 499 throw new ProtocolException("unexpected data"); 500 } 501 502 boolean requestWantsWebSockets = "Upgrade".equalsIgnoreCase(request.getHeader("Connection")) 503 && "websocket".equalsIgnoreCase(request.getHeader("Upgrade")); 504 boolean responseWantsWebSockets = response.getWebSocketListener() != null; 505 if (requestWantsWebSockets && responseWantsWebSockets) { 506 handleWebSocketUpgrade(socket, source, sink, request, response); 507 } else { 508 writeHttpResponse(socket, sink, response); 509 } 510 511 if (logger.isLoggable(Level.INFO)) { 512 logger.info(MockWebServer.this + " received request: " + request 513 + " and responded: " + response); 514 } 515 516 if (response.getSocketPolicy() == SocketPolicy.DISCONNECT_AT_END) { 517 socket.close(); 518 return false; 519 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_INPUT_AT_END) { 520 socket.shutdownInput(); 521 } else if (response.getSocketPolicy() == SocketPolicy.SHUTDOWN_OUTPUT_AT_END) { 522 socket.shutdownOutput(); 523 } 524 525 sequenceNumber++; 526 return true; 527 } 528 }); 529 } 530 531 private void processHandshakeFailure(Socket raw) throws Exception { 532 SSLContext context = SSLContext.getInstance("TLS"); 533 context.init(null, new TrustManager[] { UNTRUSTED_TRUST_MANAGER }, new SecureRandom()); 534 SSLSocketFactory sslSocketFactory = context.getSocketFactory(); 535 SSLSocket socket = (SSLSocket) sslSocketFactory.createSocket( 536 raw, raw.getInetAddress().getHostAddress(), raw.getPort(), true); 537 try { 538 socket.startHandshake(); // we're testing a handshake failure 539 throw new AssertionError(); 540 } catch (IOException expected) { 541 } 542 socket.close(); 543 } 544 545 private void dispatchBookkeepingRequest(int sequenceNumber, Socket socket) 546 throws InterruptedException { 547 requestCount.incrementAndGet(); 548 dispatcher.dispatch(new RecordedRequest(null, null, null, -1, null, sequenceNumber, socket)); 549 } 550 551 /** @param sequenceNumber the index of this request on this connection. */ 552 private RecordedRequest readRequest(Socket socket, BufferedSource source, BufferedSink sink, 553 int sequenceNumber) throws IOException { 554 String request; 555 try { 556 request = source.readUtf8LineStrict(); 557 } catch (IOException streamIsClosed) { 558 return null; // no request because we closed the stream 559 } 560 if (request.length() == 0) { 561 return null; // no request because the stream is exhausted 562 } 563 564 Headers.Builder headers = new Headers.Builder(); 565 long contentLength = -1; 566 boolean chunked = false; 567 boolean expectContinue = false; 568 String header; 569 while ((header = source.readUtf8LineStrict()).length() != 0) { 570 headers.add(header); 571 String lowercaseHeader = header.toLowerCase(Locale.US); 572 if (contentLength == -1 && lowercaseHeader.startsWith("content-length:")) { 573 contentLength = Long.parseLong(header.substring(15).trim()); 574 } 575 if (lowercaseHeader.startsWith("transfer-encoding:") 576 && lowercaseHeader.substring(18).trim().equals("chunked")) { 577 chunked = true; 578 } 579 if (lowercaseHeader.startsWith("expect:") 580 && lowercaseHeader.substring(7).trim().equals("100-continue")) { 581 expectContinue = true; 582 } 583 } 584 585 if (expectContinue) { 586 sink.writeUtf8("HTTP/1.1 100 Continue\r\n"); 587 sink.writeUtf8("Content-Length: 0\r\n"); 588 sink.writeUtf8("\r\n"); 589 sink.flush(); 590 } 591 592 boolean hasBody = false; 593 TruncatingBuffer requestBody = new TruncatingBuffer(bodyLimit); 594 List<Integer> chunkSizes = new ArrayList<>(); 595 MockResponse throttlePolicy = dispatcher.peek(); 596 if (contentLength != -1) { 597 hasBody = contentLength > 0; 598 throttledTransfer(throttlePolicy, socket, source, Okio.buffer(requestBody), contentLength); 599 } else if (chunked) { 600 hasBody = true; 601 while (true) { 602 int chunkSize = Integer.parseInt(source.readUtf8LineStrict().trim(), 16); 603 if (chunkSize == 0) { 604 readEmptyLine(source); 605 break; 606 } 607 chunkSizes.add(chunkSize); 608 throttledTransfer(throttlePolicy, socket, source, Okio.buffer(requestBody), chunkSize); 609 readEmptyLine(source); 610 } 611 } 612 613 if (request.startsWith("OPTIONS ") 614 || request.startsWith("GET ") 615 || request.startsWith("HEAD ") 616 || request.startsWith("TRACE ") 617 || request.startsWith("CONNECT ")) { 618 if (hasBody) { 619 throw new IllegalArgumentException("Request must not have a body: " + request); 620 } 621 } else if (!request.startsWith("POST ") 622 && !request.startsWith("PUT ") 623 && !request.startsWith("PATCH ") 624 && !request.startsWith("DELETE ")) { // Permitted as spec is ambiguous. 625 throw new UnsupportedOperationException("Unexpected method: " + request); 626 } 627 628 return new RecordedRequest(request, headers.build(), chunkSizes, requestBody.receivedByteCount, 629 requestBody.buffer, sequenceNumber, socket); 630 } 631 632 private void handleWebSocketUpgrade(Socket socket, BufferedSource source, BufferedSink sink, 633 RecordedRequest request, MockResponse response) throws IOException { 634 String key = request.getHeader("Sec-WebSocket-Key"); 635 String acceptKey = Util.shaBase64(key + WebSocketProtocol.ACCEPT_MAGIC); 636 response.setHeader("Sec-WebSocket-Accept", acceptKey); 637 638 writeHttpResponse(socket, sink, response); 639 640 final WebSocketListener listener = response.getWebSocketListener(); 641 final CountDownLatch connectionClose = new CountDownLatch(1); 642 643 ThreadPoolExecutor replyExecutor = 644 new ThreadPoolExecutor(1, 1, 1, SECONDS, new LinkedBlockingDeque<Runnable>(), 645 Util.threadFactory(String.format("MockWebServer %s WebSocket", request.getPath()), 646 true)); 647 replyExecutor.allowCoreThreadTimeOut(true); 648 final RealWebSocket webSocket = 649 new RealWebSocket(false /* is server */, source, sink, new SecureRandom(), replyExecutor, 650 listener, request.getPath()) { 651 @Override protected void closeConnection() throws IOException { 652 connectionClose.countDown(); 653 } 654 }; 655 656 // Adapt the request and response into our Request and Response domain model. 657 final Request fancyRequest = new Request.Builder() 658 .get().url(request.getPath()) 659 .headers(request.getHeaders()) 660 .build(); 661 final Response fancyResponse = new Response.Builder() 662 .code(Integer.parseInt(response.getStatus().split(" ")[1])) 663 .message(response.getStatus().split(" ", 3)[2]) 664 .headers(response.getHeaders()) 665 .request(fancyRequest) 666 .protocol(Protocol.HTTP_1_1) 667 .build(); 668 669 // The callback might act synchronously. Give it its own thread. 670 new Thread(new Runnable() { 671 @Override public void run() { 672 try { 673 listener.onOpen(webSocket, fancyRequest, fancyResponse); 674 } catch (IOException e) { 675 // TODO try to write close frame? 676 connectionClose.countDown(); 677 } 678 } 679 }, "MockWebServer WebSocket Writer " + request.getPath()).start(); 680 681 // Use this thread to continuously read messages. 682 while (webSocket.readMessage()) { 683 } 684 685 // Even if messages are no longer being read we need to wait for the connection close signal. 686 try { 687 connectionClose.await(); 688 } catch (InterruptedException e) { 689 throw new RuntimeException(e); 690 } 691 692 Util.closeQuietly(sink); 693 Util.closeQuietly(source); 694 } 695 696 private void writeHttpResponse(Socket socket, BufferedSink sink, MockResponse response) 697 throws IOException { 698 sink.writeUtf8(response.getStatus()); 699 sink.writeUtf8("\r\n"); 700 701 Headers headers = response.getHeaders(); 702 for (int i = 0, size = headers.size(); i < size; i++) { 703 sink.writeUtf8(headers.name(i)); 704 sink.writeUtf8(": "); 705 sink.writeUtf8(headers.value(i)); 706 sink.writeUtf8("\r\n"); 707 } 708 sink.writeUtf8("\r\n"); 709 sink.flush(); 710 711 Buffer body = response.getBody(); 712 if (body == null) return; 713 sleepIfDelayed(response); 714 throttledTransfer(response, socket, body, sink, Long.MAX_VALUE); 715 } 716 717 private void sleepIfDelayed(MockResponse response) { 718 long delayMs = response.getBodyDelay(TimeUnit.MILLISECONDS); 719 if (delayMs != 0) { 720 try { 721 Thread.sleep(delayMs); 722 } catch (InterruptedException e) { 723 throw new AssertionError(e); 724 } 725 } 726 } 727 728 /** 729 * Transfer bytes from {@code source} to {@code sink} until either {@code byteCount} 730 * bytes have been transferred or {@code source} is exhausted. The transfer is 731 * throttled according to {@code throttlePolicy}. 732 */ 733 private void throttledTransfer(MockResponse throttlePolicy, Socket socket, BufferedSource source, 734 BufferedSink sink, long byteCount) throws IOException { 735 if (byteCount == 0) return; 736 737 Buffer buffer = new Buffer(); 738 long bytesPerPeriod = throttlePolicy.getThrottleBytesPerPeriod(); 739 long periodDelayMs = throttlePolicy.getThrottlePeriod(TimeUnit.MILLISECONDS); 740 741 while (!socket.isClosed()) { 742 for (int b = 0; b < bytesPerPeriod; ) { 743 long toRead = Math.min(Math.min(2048, byteCount), bytesPerPeriod - b); 744 long read = source.read(buffer, toRead); 745 if (read == -1) return; 746 747 sink.write(buffer, read); 748 sink.flush(); 749 b += read; 750 byteCount -= read; 751 752 if (byteCount == 0) return; 753 } 754 755 if (periodDelayMs != 0) { 756 try { 757 Thread.sleep(periodDelayMs); 758 } catch (InterruptedException e) { 759 throw new AssertionError(); 760 } 761 } 762 } 763 } 764 765 private void readEmptyLine(BufferedSource source) throws IOException { 766 String line = source.readUtf8LineStrict(); 767 if (line.length() != 0) throw new IllegalStateException("Expected empty but was: " + line); 768 } 769 770 /** 771 * Sets the dispatcher used to match incoming requests to mock responses. 772 * The default dispatcher simply serves a fixed sequence of responses from 773 * a {@link #enqueue(MockResponse) queue}; custom dispatchers can vary the 774 * response based on timing or the content of the request. 775 */ 776 public void setDispatcher(Dispatcher dispatcher) { 777 if (dispatcher == null) throw new NullPointerException(); 778 this.dispatcher = dispatcher; 779 } 780 781 @Override public String toString() { 782 return "MockWebServer[" + port + "]"; 783 } 784 785 /** A buffer wrapper that drops data after {@code bodyLimit} bytes. */ 786 private static class TruncatingBuffer implements Sink { 787 private final Buffer buffer = new Buffer(); 788 private long remainingByteCount; 789 private long receivedByteCount; 790 791 TruncatingBuffer(long bodyLimit) { 792 remainingByteCount = bodyLimit; 793 } 794 795 @Override public void write(Buffer source, long byteCount) throws IOException { 796 long toRead = Math.min(remainingByteCount, byteCount); 797 if (toRead > 0) { 798 source.read(buffer, toRead); 799 } 800 long toSkip = byteCount - toRead; 801 if (toSkip > 0) { 802 source.skip(toSkip); 803 } 804 remainingByteCount -= toRead; 805 receivedByteCount += byteCount; 806 } 807 808 @Override public void flush() throws IOException { 809 } 810 811 @Override public Timeout timeout() { 812 return Timeout.NONE; 813 } 814 815 @Override public void close() throws IOException { 816 } 817 } 818 819 /** Processes HTTP requests layered over SPDY/3. */ 820 private class SpdySocketHandler implements IncomingStreamHandler { 821 private final Socket socket; 822 private final Protocol protocol; 823 private final AtomicInteger sequenceNumber = new AtomicInteger(); 824 825 private SpdySocketHandler(Socket socket, Protocol protocol) { 826 this.socket = socket; 827 this.protocol = protocol; 828 } 829 830 @Override public void receive(SpdyStream stream) throws IOException { 831 RecordedRequest request = readRequest(stream); 832 requestQueue.add(request); 833 MockResponse response; 834 try { 835 response = dispatcher.dispatch(request); 836 } catch (InterruptedException e) { 837 throw new AssertionError(e); 838 } 839 writeResponse(stream, response); 840 if (logger.isLoggable(Level.INFO)) { 841 logger.info(MockWebServer.this + " received request: " + request 842 + " and responded: " + response + " protocol is " + protocol.toString()); 843 } 844 } 845 846 private RecordedRequest readRequest(SpdyStream stream) throws IOException { 847 List<Header> spdyHeaders = stream.getRequestHeaders(); 848 Headers.Builder httpHeaders = new Headers.Builder(); 849 String method = "<:method omitted>"; 850 String path = "<:path omitted>"; 851 String version = protocol == Protocol.SPDY_3 ? "<:version omitted>" : "HTTP/1.1"; 852 for (int i = 0, size = spdyHeaders.size(); i < size; i++) { 853 ByteString name = spdyHeaders.get(i).name; 854 String value = spdyHeaders.get(i).value.utf8(); 855 if (name.equals(Header.TARGET_METHOD)) { 856 method = value; 857 } else if (name.equals(Header.TARGET_PATH)) { 858 path = value; 859 } else if (name.equals(Header.VERSION)) { 860 version = value; 861 } else { 862 httpHeaders.add(name.utf8(), value); 863 } 864 } 865 866 Buffer body = new Buffer(); 867 body.writeAll(stream.getSource()); 868 body.close(); 869 870 String requestLine = method + ' ' + path + ' ' + version; 871 List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY. 872 return new RecordedRequest(requestLine, httpHeaders.build(), chunkSizes, body.size(), body, 873 sequenceNumber.getAndIncrement(), socket); 874 } 875 876 private void writeResponse(SpdyStream stream, MockResponse response) throws IOException { 877 if (response.getSocketPolicy() == SocketPolicy.NO_RESPONSE) { 878 return; 879 } 880 List<Header> spdyHeaders = new ArrayList<>(); 881 String[] statusParts = response.getStatus().split(" ", 2); 882 if (statusParts.length != 2) { 883 throw new AssertionError("Unexpected status: " + response.getStatus()); 884 } 885 // TODO: constants for well-known header names. 886 spdyHeaders.add(new Header(Header.RESPONSE_STATUS, statusParts[1])); 887 if (protocol == Protocol.SPDY_3) { 888 spdyHeaders.add(new Header(Header.VERSION, statusParts[0])); 889 } 890 Headers headers = response.getHeaders(); 891 for (int i = 0, size = headers.size(); i < size; i++) { 892 spdyHeaders.add(new Header(headers.name(i), headers.value(i))); 893 } 894 895 Buffer body = response.getBody(); 896 boolean closeStreamAfterHeaders = body != null || !response.getPushPromises().isEmpty(); 897 stream.reply(spdyHeaders, closeStreamAfterHeaders); 898 pushPromises(stream, response.getPushPromises()); 899 if (body != null) { 900 BufferedSink sink = Okio.buffer(stream.getSink()); 901 sleepIfDelayed(response); 902 throttledTransfer(response, socket, body, sink, bodyLimit); 903 sink.close(); 904 } else if (closeStreamAfterHeaders) { 905 stream.close(ErrorCode.NO_ERROR); 906 } 907 } 908 909 private void pushPromises(SpdyStream stream, List<PushPromise> promises) throws IOException { 910 for (PushPromise pushPromise : promises) { 911 List<Header> pushedHeaders = new ArrayList<>(); 912 pushedHeaders.add(new Header(stream.getConnection().getProtocol() == Protocol.SPDY_3 913 ? Header.TARGET_HOST 914 : Header.TARGET_AUTHORITY, getUrl(pushPromise.getPath()).getHost())); 915 pushedHeaders.add(new Header(Header.TARGET_METHOD, pushPromise.getMethod())); 916 pushedHeaders.add(new Header(Header.TARGET_PATH, pushPromise.getPath())); 917 Headers pushPromiseHeaders = pushPromise.getHeaders(); 918 for (int i = 0, size = pushPromiseHeaders.size(); i < size; i++) { 919 pushedHeaders.add(new Header(pushPromiseHeaders.name(i), pushPromiseHeaders.value(i))); 920 } 921 String requestLine = pushPromise.getMethod() + ' ' + pushPromise.getPath() + " HTTP/1.1"; 922 List<Integer> chunkSizes = Collections.emptyList(); // No chunked encoding for SPDY. 923 requestQueue.add(new RecordedRequest(requestLine, pushPromise.getHeaders(), chunkSizes, 0, 924 new Buffer(), sequenceNumber.getAndIncrement(), socket)); 925 boolean hasBody = pushPromise.getResponse().getBody() != null; 926 SpdyStream pushedStream = 927 stream.getConnection().pushStream(stream.getId(), pushedHeaders, hasBody); 928 writeResponse(pushedStream, pushPromise.getResponse()); 929 } 930 } 931 } 932 } 933