1 /* 2 * Copyright (C) 2014 Square, Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 package com.squareup.okhttp.ws; 17 18 import com.squareup.okhttp.OkHttpClient; 19 import com.squareup.okhttp.Request; 20 import com.squareup.okhttp.Response; 21 import com.squareup.okhttp.internal.SslContextBuilder; 22 import com.squareup.okhttp.mockwebserver.MockResponse; 23 import com.squareup.okhttp.mockwebserver.MockWebServer; 24 import com.squareup.okhttp.testing.RecordingHostnameVerifier; 25 import java.io.IOException; 26 import java.net.ProtocolException; 27 import java.util.Random; 28 import java.util.concurrent.CountDownLatch; 29 import java.util.concurrent.TimeUnit; 30 import java.util.concurrent.atomic.AtomicReference; 31 import javax.net.ssl.SSLContext; 32 import okio.Buffer; 33 import okio.BufferedSink; 34 import okio.BufferedSource; 35 import org.junit.After; 36 import org.junit.Rule; 37 import org.junit.Test; 38 39 import static com.squareup.okhttp.ws.WebSocket.PayloadType.TEXT; 40 41 public final class WebSocketCallTest { 42 @Rule public final MockWebServer server = new MockWebServer(); 43 44 private final SSLContext sslContext = SslContextBuilder.localhost(); 45 private final WebSocketRecorder listener = new WebSocketRecorder(); 46 private final OkHttpClient client = new OkHttpClient(); 47 private final Random random = new Random(0); 48 49 @After public void tearDown() { 50 listener.assertExhausted(); 51 } 52 53 @Test public void clientPingPong() throws IOException { 54 WebSocketListener serverListener = new EmptyWebSocketListener(); 55 server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); 56 57 WebSocket webSocket = awaitWebSocket(); 58 webSocket.sendPing(new Buffer().writeUtf8("Hello, WebSockets!")); 59 listener.assertPong(new Buffer().writeUtf8("Hello, WebSockets!")); 60 } 61 62 @Test public void clientMessage() throws IOException { 63 WebSocketRecorder serverListener = new WebSocketRecorder(); 64 server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); 65 66 WebSocket webSocket = awaitWebSocket(); 67 webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!")); 68 serverListener.assertTextMessage("Hello, WebSockets!"); 69 } 70 71 @Test public void serverMessage() throws IOException { 72 WebSocketListener serverListener = new EmptyWebSocketListener() { 73 @Override public void onOpen(final WebSocket webSocket, Response response) { 74 new Thread() { 75 @Override public void run() { 76 try { 77 webSocket.sendMessage(TEXT, new Buffer().writeUtf8("Hello, WebSockets!")); 78 } catch (IOException e) { 79 throw new AssertionError(e); 80 } 81 } 82 }.start(); 83 } 84 }; 85 server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); 86 87 awaitWebSocket(); 88 listener.assertTextMessage("Hello, WebSockets!"); 89 } 90 91 @Test public void clientStreamingMessage() throws IOException { 92 WebSocketRecorder serverListener = new WebSocketRecorder(); 93 server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); 94 95 WebSocket webSocket = awaitWebSocket(); 96 BufferedSink sink = webSocket.newMessageSink(TEXT); 97 sink.writeUtf8("Hello, ").flush(); 98 sink.writeUtf8("WebSockets!").flush(); 99 sink.close(); 100 101 serverListener.assertTextMessage("Hello, WebSockets!"); 102 } 103 104 @Test public void serverStreamingMessage() throws IOException { 105 WebSocketListener serverListener = new EmptyWebSocketListener() { 106 @Override public void onOpen(final WebSocket webSocket, Response response) { 107 new Thread() { 108 @Override public void run() { 109 try { 110 BufferedSink sink = webSocket.newMessageSink(TEXT); 111 sink.writeUtf8("Hello, ").flush(); 112 sink.writeUtf8("WebSockets!").flush(); 113 sink.close(); 114 } catch (IOException e) { 115 throw new AssertionError(e); 116 } 117 } 118 }.start(); 119 } 120 }; 121 server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); 122 123 awaitWebSocket(); 124 listener.assertTextMessage("Hello, WebSockets!"); 125 } 126 127 @Test public void okButNotOk() { 128 server.enqueue(new MockResponse()); 129 awaitWebSocket(); 130 listener.assertFailure(ProtocolException.class, "Expected HTTP 101 response but was '200 OK'"); 131 } 132 133 @Test public void notFound() { 134 server.enqueue(new MockResponse().setStatus("HTTP/1.1 404 Not Found")); 135 awaitWebSocket(); 136 listener.assertFailure(ProtocolException.class, 137 "Expected HTTP 101 response but was '404 Not Found'"); 138 } 139 140 @Test public void missingConnectionHeader() { 141 server.enqueue(new MockResponse() 142 .setResponseCode(101) 143 .setHeader("Upgrade", "websocket") 144 .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); 145 awaitWebSocket(); 146 listener.assertFailure(ProtocolException.class, 147 "Expected 'Connection' header value 'Upgrade' but was 'null'"); 148 } 149 150 @Test public void wrongConnectionHeader() { 151 server.enqueue(new MockResponse() 152 .setResponseCode(101) 153 .setHeader("Upgrade", "websocket") 154 .setHeader("Connection", "Downgrade") 155 .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); 156 awaitWebSocket(); 157 listener.assertFailure(ProtocolException.class, 158 "Expected 'Connection' header value 'Upgrade' but was 'Downgrade'"); 159 } 160 161 @Test public void missingUpgradeHeader() { 162 server.enqueue(new MockResponse() 163 .setResponseCode(101) 164 .setHeader("Connection", "Upgrade") 165 .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); 166 awaitWebSocket(); 167 listener.assertFailure(ProtocolException.class, 168 "Expected 'Upgrade' header value 'websocket' but was 'null'"); 169 } 170 171 @Test public void wrongUpgradeHeader() { 172 server.enqueue(new MockResponse() 173 .setResponseCode(101) 174 .setHeader("Connection", "Upgrade") 175 .setHeader("Upgrade", "Pepsi") 176 .setHeader("Sec-WebSocket-Accept", "ujmZX4KXZqjwy6vi1aQFH5p4Ygk=")); 177 awaitWebSocket(); 178 listener.assertFailure(ProtocolException.class, 179 "Expected 'Upgrade' header value 'websocket' but was 'Pepsi'"); 180 } 181 182 @Test public void missingMagicHeader() { 183 server.enqueue(new MockResponse() 184 .setResponseCode(101) 185 .setHeader("Connection", "Upgrade") 186 .setHeader("Upgrade", "websocket")); 187 awaitWebSocket(); 188 listener.assertFailure(ProtocolException.class, 189 "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'null'"); 190 } 191 192 @Test public void wrongMagicHeader() { 193 server.enqueue(new MockResponse() 194 .setResponseCode(101) 195 .setHeader("Connection", "Upgrade") 196 .setHeader("Upgrade", "websocket") 197 .setHeader("Sec-WebSocket-Accept", "magic")); 198 awaitWebSocket(); 199 listener.assertFailure(ProtocolException.class, 200 "Expected 'Sec-WebSocket-Accept' header value 'ujmZX4KXZqjwy6vi1aQFH5p4Ygk=' but was 'magic'"); 201 } 202 203 @Test public void wsScheme() throws IOException { 204 websocketScheme("ws"); 205 } 206 207 @Test public void wsUppercaseScheme() throws IOException { 208 websocketScheme("WS"); 209 } 210 211 @Test public void wssScheme() throws IOException { 212 server.useHttps(sslContext.getSocketFactory(), false); 213 client.setSslSocketFactory(sslContext.getSocketFactory()); 214 client.setHostnameVerifier(new RecordingHostnameVerifier()); 215 216 websocketScheme("wss"); 217 } 218 219 @Test public void httpsScheme() throws IOException { 220 server.useHttps(sslContext.getSocketFactory(), false); 221 client.setSslSocketFactory(sslContext.getSocketFactory()); 222 client.setHostnameVerifier(new RecordingHostnameVerifier()); 223 224 websocketScheme("https"); 225 } 226 227 private void websocketScheme(String scheme) throws IOException { 228 WebSocketRecorder serverListener = new WebSocketRecorder(); 229 server.enqueue(new MockResponse().withWebSocketUpgrade(serverListener)); 230 231 Request request1 = new Request.Builder() 232 .url(scheme + "://" + server.getHostName() + ":" + server.getPort() + "/") 233 .build(); 234 235 WebSocket webSocket = awaitWebSocket(request1); 236 webSocket.sendMessage(TEXT, new Buffer().writeUtf8("abc")); 237 serverListener.assertTextMessage("abc"); 238 } 239 240 private WebSocket awaitWebSocket() { 241 return awaitWebSocket(new Request.Builder().get().url(server.url("/")).build()); 242 } 243 244 private WebSocket awaitWebSocket(Request request) { 245 WebSocketCall call = new WebSocketCall(client, request, random); 246 247 final AtomicReference<Response> responseRef = new AtomicReference<>(); 248 final AtomicReference<WebSocket> webSocketRef = new AtomicReference<>(); 249 final AtomicReference<IOException> failureRef = new AtomicReference<>(); 250 final CountDownLatch latch = new CountDownLatch(1); 251 call.enqueue(new WebSocketListener() { 252 @Override public void onOpen(WebSocket webSocket, Response response) { 253 webSocketRef.set(webSocket); 254 responseRef.set(response); 255 latch.countDown(); 256 } 257 258 @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type) 259 throws IOException { 260 listener.onMessage(payload, type); 261 } 262 263 @Override public void onPong(Buffer payload) { 264 listener.onPong(payload); 265 } 266 267 @Override public void onClose(int code, String reason) { 268 listener.onClose(code, reason); 269 } 270 271 @Override public void onFailure(IOException e, Response response) { 272 listener.onFailure(e, null); 273 failureRef.set(e); 274 latch.countDown(); 275 } 276 }); 277 278 try { 279 if (!latch.await(10, TimeUnit.SECONDS)) { 280 throw new AssertionError("Timed out."); 281 } 282 } catch (InterruptedException e) { 283 throw new AssertionError(e); 284 } 285 286 return webSocketRef.get(); 287 } 288 289 private static class EmptyWebSocketListener implements WebSocketListener { 290 @Override public void onOpen(WebSocket webSocket, Response response) { 291 } 292 293 @Override public void onMessage(BufferedSource payload, WebSocket.PayloadType type) 294 throws IOException { 295 } 296 297 @Override public void onPong(Buffer payload) { 298 } 299 300 @Override public void onClose(int code, String reason) { 301 } 302 303 @Override public void onFailure(IOException e, Response response) { 304 } 305 } 306 } 307