Home | History | Annotate | Download | only in ws
      1 /*
      2  * Copyright (C) 2014 Square, Inc.
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 package com.squareup.okhttp.ws;
     17 
     18 import com.squareup.okhttp.OkHttpClient;
     19 import com.squareup.okhttp.Request;
     20 import com.squareup.okhttp.Response;
     21 import com.squareup.okhttp.internal.SslContextBuilder;
     22 import com.squareup.okhttp.mockwebserver.MockResponse;
     23 import com.squareup.okhttp.mockwebserver.MockWebServer;
     24 import com.squareup.okhttp.testing.RecordingHostnameVerifier;
     25 import java.io.IOException;
     26 import java.net.ProtocolException;
     27 import java.util.Random;
     28 import java.util.concurrent.CountDownLatch;
     29 import java.util.concurrent.TimeUnit;
     30 import java.util.concurrent.atomic.AtomicReference;
     31 import javax.net.ssl.SSLContext;
     32 import okio.Buffer;
     33 import okio.BufferedSink;
     34 import okio.BufferedSource;
     35 import org.junit.After;
     36 import org.junit.Rule;
     37 import org.junit.Test;
     38 
     39 import static com.squareup.okhttp.ws.WebSocket.PayloadType.TEXT;
     40 
     41 public final class WebSocketCallTest {
     42   @Rule public final MockWebServer server = new MockWebServer();
     43 
     44   private final SSLContext sslContext = SslContextBuilder.localhost();
     45   private final WebSocketRecorder listener = new WebSocketRecorder();
     46   private final OkHttpClient client = new OkHttpClient();
     47   private final Random random = new Random(0);
     48 
     49   @After public void tearDown() {
     50     listener.assertExhausted();
     51   }
     52 
     53   @Test public void clientPingPong() throws IOException {
     54     WebSocketListener serverListener = new EmptyWebSocketListener();
     55     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
     56 
     57     WebSocket webSocket = awaitWebSocket();
     58     webSocket.sendPing(new Buffer().writeUtf8("Hello, WebSockets!"));
     59     listener.assertPong(new Buffer().writeUtf8("Hello, WebSockets!"));
     60   }
     61 
     62   @Test public void clientMessage() throws IOException {
     63     WebSocketRecorder serverListener = new WebSocketRecorder();
     64     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
     65 
     66     WebSocket webSocket = awaitWebSocket();
     67     webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!"));
     68     serverListener.assertTextMessage("Hello, WebSockets!");
     69   }
     70 
     71   @Test public void serverMessage() throws IOException {
     72     WebSocketListener serverListener = new EmptyWebSocketListener() {
     73       @Override public void onOpen(final WebSocket webSocket, Response response) {
     74         new Thread() {
     75           @Override public void run() {
     76             try {
     77               webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!"));
     78             } catch (IOException e) {
     79               throw new AssertionError(e);
     80             }
     81           }
     82         }.start();
     83       }
     84     };
     85     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
     86 
     87     awaitWebSocket();
     88     listener.assertTextMessage("Hello, WebSockets!");
     89   }
     90 
     91   @Test public void clientStreamingMessage() throws IOException {
     92     WebSocketRecorder serverListener = new WebSocketRecorder();
     93     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
     94 
     95     WebSocket webSocket = awaitWebSocket();
     96     BufferedSink sink = webSocket.newMessageSink(TEXT);
     97     sink.writeUtf8("Hello, ").flush();
     98     sink.writeUtf8("WebSockets!").flush();
     99     sink.close();
    100 
    101     serverListener.assertTextMessage("Hello, WebSockets!");
    102   }
    103 
    104   @Test public void serverStreamingMessage() throws IOException {
    105     WebSocketListener serverListener = new EmptyWebSocketListener() {
    106       @Override public void onOpen(final WebSocket webSocket, Response response) {
    107         new Thread() {
    108           @Override public void run() {
    109             try {
    110               BufferedSink sink = webSocket.newMessageSink(TEXT);
    111               sink.writeUtf8("Hello, ").flush();
    112               sink.writeUtf8("WebSockets!").flush();
    113               sink.close();
    114             } catch (IOException e) {
    115               throw new AssertionError(e);
    116             }
    117           }
    118         }.start();
    119       }
    120     };
    121     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
    122 
    123     awaitWebSocket();
    124     listener.assertTextMessage("Hello, WebSockets!");
    125   }
    126 
    127   @Test public void okButNotOk() {
    128     server.enqueue(new MockResponse());
    129     awaitWebSocket();
    130     listener.assertFailure(ProtocolException.class, "Expected HTTP 101 response but was '200 OK'");
    131   }
    132 
    133   @Test public void notFound() {
    134     server.enqueue(new MockResponse().setStatus("HTTP/1.1 404 Not Found"));
    135     awaitWebSocket();
    136     listener.assertFailure(ProtocolException.class,
    137         "Expected HTTP 101 response but was '404 Not Found'");
    138   }
    139 
    140   @Test public void missingConnectionHeader() {
    141     server.enqueue(new MockResponse()
    142         .setResponseCode(101)
    143         .setHeader("Upgrade", "websocket")
    144         .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
    145     awaitWebSocket();
    146     listener.assertFailure(ProtocolException.class,
    147         "Expected 'Connection' header value 'Upgrade' but was 'null'");
    148   }
    149 
    150   @Test public void wrongConnectionHeader() {
    151     server.enqueue(new MockResponse()
    152         .setResponseCode(101)
    153         .setHeader("Upgrade", "websocket")
    154         .setHeader("Connection", "Downgrade")
    155         .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
    156     awaitWebSocket();
    157     listener.assertFailure(ProtocolException.class,
    158         "Expected 'Connection' header value 'Upgrade' but was 'Downgrade'");
    159   }
    160 
    161   @Test public void missingUpgradeHeader() {
    162     server.enqueue(new MockResponse()
    163         .setResponseCode(101)
    164         .setHeader("Connection", "Upgrade")
    165         .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
    166     awaitWebSocket();
    167     listener.assertFailure(ProtocolException.class,
    168         "Expected 'Upgrade' header value 'websocket' but was 'null'");
    169   }
    170 
    171   @Test public void wrongUpgradeHeader() {
    172     server.enqueue(new MockResponse()
    173         .setResponseCode(101)
    174         .setHeader("Connection", "Upgrade")
    175         .setHeader("Upgrade", "Pepsi")
    176         .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk="));
    177     awaitWebSocket();
    178     listener.assertFailure(ProtocolException.class,
    179         "Expected 'Upgrade' header value 'websocket' but was 'Pepsi'");
    180   }
    181 
    182   @Test public void missingMagicHeader() {
    183     server.enqueue(new MockResponse()
    184         .setResponseCode(101)
    185         .setHeader("Connection", "Upgrade")
    186         .setHeader("Upgrade", "websocket"));
    187     awaitWebSocket();
    188     listener.assertFailure(ProtocolException.class,
    189         "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'");
    190   }
    191 
    192   @Test public void wrongMagicHeader() {
    193     server.enqueue(new MockResponse()
    194         .setResponseCode(101)
    195         .setHeader("Connection", "Upgrade")
    196         .setHeader("Upgrade", "websocket")
    197         .setHeader("Sec-WebSocket-Accept", "magic"));
    198     awaitWebSocket();
    199     listener.assertFailure(ProtocolException.class,
    200         "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'");
    201   }
    202 
    203   @Test public void wsScheme() throws IOException {
    204     websocketScheme("ws");
    205   }
    206 
    207   @Test public void wsUppercaseScheme() throws IOException {
    208     websocketScheme("WS");
    209   }
    210 
    211   @Test public void wssScheme() throws IOException {
    212     server.useHttps(sslContext.getSocketFactory(), false);
    213     client.setSslSocketFactory(sslContext.getSocketFactory());
    214     client.setHostnameVerifier(new RecordingHostnameVerifier());
    215 
    216     websocketScheme("wss");
    217   }
    218 
    219   @Test public void httpsScheme() throws IOException {
    220     server.useHttps(sslContext.getSocketFactory(), false);
    221     client.setSslSocketFactory(sslContext.getSocketFactory());
    222     client.setHostnameVerifier(new RecordingHostnameVerifier());
    223 
    224     websocketScheme("https");
    225   }
    226 
    227   private void websocketScheme(String scheme) throws IOException {
    228     WebSocketRecorder serverListener = new WebSocketRecorder();
    229     server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener));
    230 
    231     Request request1 = new Request.Builder()
    232         .url(scheme + "://" + server.getHostName() + ":" + server.getPort() + "/")
    233         .build();
    234 
    235     WebSocket webSocket = awaitWebSocket(request1);
    236     webSocket.sendMessage(TEXT, new Buffer().writeUtf8("abc"));
    237     serverListener.assertTextMessage("abc");
    238   }
    239 
    240   private WebSocket awaitWebSocket() {
    241     return awaitWebSocket(new Request.Builder().get().url(server.url("/")).build());
    242   }
    243 
    244   private WebSocket awaitWebSocket(Request request) {
    245     WebSocketCall call = new WebSocketCall(client, request, random);
    246 
    247     final AtomicReference<Response> responseRef = new AtomicReference<>();
    248     final AtomicReference<WebSocket> webSocketRef = new AtomicReference<>();
    249     final AtomicReference<IOException> failureRef = new AtomicReference<>();
    250     final CountDownLatch latch = new CountDownLatch(1);
    251     call.enqueue(new WebSocketListener() {
    252       @Override public void onOpen(WebSocket webSocket, Response response) {
    253         webSocketRef.set(webSocket);
    254         responseRef.set(response);
    255         latch.countDown();
    256       }
    257 
    258       @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type)
    259           throws IOException {
    260         listener.onMessage(payload, type);
    261       }
    262 
    263       @Override public void onPong(Buffer payload) {
    264         listener.onPong(payload);
    265       }
    266 
    267       @Override public void onClose(int code, String reason) {
    268         listener.onClose(code, reason);
    269       }
    270 
    271       @Override public void onFailure(IOException e, Response response) {
    272         listener.onFailure(e, null);
    273         failureRef.set(e);
    274         latch.countDown();
    275       }
    276     });
    277 
    278     try {
    279       if (!latch.await(10, TimeUnit.SECONDS)) {
    280         throw new AssertionError("Timed out.");
    281       }
    282     } catch (InterruptedException e) {
    283       throw new AssertionError(e);
    284     }
    285 
    286     return webSocketRef.get();
    287   }
    288 
    289   private static class EmptyWebSocketListener implements WebSocketListener {
    290     @Override public void onOpen(WebSocket webSocket, Response response) {
    291     }
    292 
    293     @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type)
    294         throws IOException {
    295     }
    296 
    297     @Override public void onPong(Buffer payload) {
    298     }
    299 
    300     @Override public void onClose(int code, String reason) {
    301     }
    302 
    303     @Override public void onFailure(IOException e, Response response) {
    304     }
    305   }
    306 }
    307