1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome/test/chromedriver/net/websocket.h" 6 7 #include <string.h> 8 9 #include "base/base64.h" 10 #include "base/bind.h" 11 #include "base/bind_helpers.h" 12 #include "base/memory/scoped_vector.h" 13 #include "base/rand_util.h" 14 #include "base/sha1.h" 15 #include "base/strings/string_number_conversions.h" 16 #include "base/strings/stringprintf.h" 17 #include "net/base/address_list.h" 18 #include "net/base/io_buffer.h" 19 #include "net/base/ip_endpoint.h" 20 #include "net/base/net_errors.h" 21 #include "net/base/net_util.h" 22 #include "net/base/sys_addrinfo.h" 23 #include "net/http/http_response_headers.h" 24 #include "net/http/http_util.h" 25 #include "net/websockets/websocket_frame.h" 26 27 #if defined(OS_WIN) 28 #include <Winsock2.h> 29 #endif 30 31 namespace { 32 33 bool ResolveHost(const std::string& host, net::IPAddressNumber* address) { 34 struct addrinfo hints; 35 memset(&hints, 0, sizeof(hints)); 36 hints.ai_family = AF_UNSPEC; 37 hints.ai_socktype = SOCK_STREAM; 38 39 struct addrinfo* result; 40 if (getaddrinfo(host.c_str(), NULL, &hints, &result)) 41 return false; 42 43 for (struct addrinfo* addr = result; addr; addr = addr->ai_next) { 44 if (addr->ai_family == AF_INET || addr->ai_family == AF_INET6) { 45 net::IPEndPoint end_point; 46 if (!end_point.FromSockAddr(addr->ai_addr, addr->ai_addrlen)) { 47 freeaddrinfo(result); 48 return false; 49 } 50 *address = end_point.address(); 51 } 52 } 53 freeaddrinfo(result); 54 return true; 55 } 56 57 } // namespace 58 59 WebSocket::WebSocket(const GURL& url, WebSocketListener* listener) 60 : url_(url), 61 listener_(listener), 62 state_(INITIALIZED), 63 write_buffer_(new net::DrainableIOBuffer(new net::IOBuffer(0), 0)), 64 read_buffer_(new net::IOBufferWithSize(4096)) {} 65 66 WebSocket::~WebSocket() { 67 CHECK(thread_checker_.CalledOnValidThread()); 68 } 69 70 void WebSocket::Connect(const net::CompletionCallback& callback) { 71 CHECK(thread_checker_.CalledOnValidThread()); 72 CHECK_EQ(INITIALIZED, state_); 73 74 net::IPAddressNumber address; 75 if (!net::ParseIPLiteralToNumber(url_.HostNoBrackets(), &address)) { 76 if (!ResolveHost(url_.HostNoBrackets(), &address)) { 77 callback.Run(net::ERR_ADDRESS_UNREACHABLE); 78 return; 79 } 80 } 81 int port = 80; 82 base::StringToInt(url_.port(), &port); 83 net::AddressList addresses(net::IPEndPoint(address, port)); 84 net::NetLog::Source source; 85 socket_.reset(new net::TCPClientSocket(addresses, NULL, source)); 86 87 state_ = CONNECTING; 88 connect_callback_ = callback; 89 int code = socket_->Connect(base::Bind( 90 &WebSocket::OnSocketConnect, base::Unretained(this))); 91 if (code != net::ERR_IO_PENDING) 92 OnSocketConnect(code); 93 } 94 95 bool WebSocket::Send(const std::string& message) { 96 CHECK(thread_checker_.CalledOnValidThread()); 97 if (state_ != OPEN) 98 return false; 99 100 net::WebSocketFrameHeader header(net::WebSocketFrameHeader::kOpCodeText); 101 header.final = true; 102 header.masked = true; 103 header.payload_length = message.length(); 104 int header_size = net::GetWebSocketFrameHeaderSize(header); 105 net::WebSocketMaskingKey masking_key = net::GenerateWebSocketMaskingKey(); 106 std::string header_str; 107 header_str.resize(header_size); 108 CHECK_EQ(header_size, net::WriteWebSocketFrameHeader( 109 header, &masking_key, &header_str[0], header_str.length())); 110 111 std::string masked_message = message; 112 net::MaskWebSocketFramePayload( 113 masking_key, 0, &masked_message[0], masked_message.length()); 114 Write(header_str + masked_message); 115 return true; 116 } 117 118 void WebSocket::OnSocketConnect(int code) { 119 if (code != net::OK) { 120 Close(code); 121 return; 122 } 123 124 base::Base64Encode(base::RandBytesAsString(16), &sec_key_); 125 std::string handshake = base::StringPrintf( 126 "GET %s HTTP/1.1\r\n" 127 "Host: %s\r\n" 128 "Upgrade: websocket\r\n" 129 "Connection: Upgrade\r\n" 130 "Sec-WebSocket-Key: %s\r\n" 131 "Sec-WebSocket-Version: 13\r\n" 132 "Pragma: no-cache\r\n" 133 "Cache-Control: no-cache\r\n" 134 "\r\n", 135 url_.path().c_str(), 136 url_.host().c_str(), 137 sec_key_.c_str()); 138 Write(handshake); 139 Read(); 140 } 141 142 void WebSocket::Write(const std::string& data) { 143 pending_write_ += data; 144 if (!write_buffer_->BytesRemaining()) 145 ContinueWritingIfNecessary(); 146 } 147 148 void WebSocket::OnWrite(int code) { 149 if (!socket_->IsConnected()) { 150 // Supposedly if |StreamSocket| is closed, the error code may be undefined. 151 Close(net::ERR_FAILED); 152 return; 153 } 154 if (code < 0) { 155 Close(code); 156 return; 157 } 158 159 write_buffer_->DidConsume(code); 160 ContinueWritingIfNecessary(); 161 } 162 163 void WebSocket::ContinueWritingIfNecessary() { 164 if (!write_buffer_->BytesRemaining()) { 165 if (pending_write_.empty()) 166 return; 167 write_buffer_ = new net::DrainableIOBuffer( 168 new net::StringIOBuffer(pending_write_), 169 pending_write_.length()); 170 pending_write_.clear(); 171 } 172 int code = 173 socket_->Write(write_buffer_.get(), 174 write_buffer_->BytesRemaining(), 175 base::Bind(&WebSocket::OnWrite, base::Unretained(this))); 176 if (code != net::ERR_IO_PENDING) 177 OnWrite(code); 178 } 179 180 void WebSocket::Read() { 181 int code = 182 socket_->Read(read_buffer_.get(), 183 read_buffer_->size(), 184 base::Bind(&WebSocket::OnRead, base::Unretained(this))); 185 if (code != net::ERR_IO_PENDING) 186 OnRead(code); 187 } 188 189 void WebSocket::OnRead(int code) { 190 if (code <= 0) { 191 Close(code ? code : net::ERR_FAILED); 192 return; 193 } 194 195 if (state_ == CONNECTING) 196 OnReadDuringHandshake(read_buffer_->data(), code); 197 else if (state_ == OPEN) 198 OnReadDuringOpen(read_buffer_->data(), code); 199 200 if (state_ != CLOSED) 201 Read(); 202 } 203 204 void WebSocket::OnReadDuringHandshake(const char* data, int len) { 205 handshake_response_ += std::string(data, len); 206 int headers_end = net::HttpUtil::LocateEndOfHeaders( 207 handshake_response_.data(), handshake_response_.size(), 0); 208 if (headers_end == -1) 209 return; 210 211 const char kMagicKey[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; 212 std::string websocket_accept; 213 base::Base64Encode(base::SHA1HashString(sec_key_ + kMagicKey), 214 &websocket_accept); 215 scoped_refptr<net::HttpResponseHeaders> headers( 216 new net::HttpResponseHeaders( 217 net::HttpUtil::AssembleRawHeaders( 218 handshake_response_.data(), headers_end))); 219 if (headers->response_code() != 101 || 220 !headers->HasHeaderValue("Upgrade", "WebSocket") || 221 !headers->HasHeaderValue("Connection", "Upgrade") || 222 !headers->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept)) { 223 Close(net::ERR_FAILED); 224 return; 225 } 226 std::string leftover_message = handshake_response_.substr(headers_end); 227 handshake_response_.clear(); 228 sec_key_.clear(); 229 state_ = OPEN; 230 InvokeConnectCallback(net::OK); 231 if (!leftover_message.empty()) 232 OnReadDuringOpen(leftover_message.c_str(), leftover_message.length()); 233 } 234 235 void WebSocket::OnReadDuringOpen(const char* data, int len) { 236 ScopedVector<net::WebSocketFrameChunk> frame_chunks; 237 CHECK(parser_.Decode(data, len, &frame_chunks)); 238 for (size_t i = 0; i < frame_chunks.size(); ++i) { 239 scoped_refptr<net::IOBufferWithSize> buffer = frame_chunks[i]->data; 240 if (buffer.get()) 241 next_message_ += std::string(buffer->data(), buffer->size()); 242 if (frame_chunks[i]->final_chunk) { 243 listener_->OnMessageReceived(next_message_); 244 next_message_.clear(); 245 } 246 } 247 } 248 249 void WebSocket::InvokeConnectCallback(int code) { 250 net::CompletionCallback temp = connect_callback_; 251 connect_callback_.Reset(); 252 CHECK(!temp.is_null()); 253 temp.Run(code); 254 } 255 256 void WebSocket::Close(int code) { 257 socket_->Disconnect(); 258 if (!connect_callback_.is_null()) 259 InvokeConnectCallback(code); 260 if (state_ == OPEN) 261 listener_->OnClose(); 262 263 state_ = CLOSED; 264 } 265