1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "chrome/test/chromedriver/net/websocket.h" 6 7 #include "base/base64.h" 8 #include "base/bind.h" 9 #include "base/bind_helpers.h" 10 #include "base/memory/scoped_vector.h" 11 #include "base/rand_util.h" 12 #include "base/sha1.h" 13 #include "base/strings/string_number_conversions.h" 14 #include "base/strings/stringprintf.h" 15 #include "net/base/address_list.h" 16 #include "net/base/io_buffer.h" 17 #include "net/base/ip_endpoint.h" 18 #include "net/base/net_errors.h" 19 #include "net/base/net_util.h" 20 #include "net/http/http_response_headers.h" 21 #include "net/http/http_util.h" 22 #include "net/websockets/websocket_frame.h" 23 24 WebSocket::WebSocket(const GURL& url, WebSocketListener* listener) 25 : url_(url), 26 listener_(listener), 27 state_(INITIALIZED), 28 write_buffer_(new net::DrainableIOBuffer(new net::IOBuffer(0), 0)), 29 read_buffer_(new net::IOBufferWithSize(4096)) { 30 net::IPAddressNumber address; 31 CHECK(net::ParseIPLiteralToNumber(url_.HostNoBrackets(), &address)); 32 int port = 80; 33 base::StringToInt(url_.port(), &port); 34 net::AddressList addresses(net::IPEndPoint(address, port)); 35 net::NetLog::Source source; 36 socket_.reset(new net::TCPClientSocket(addresses, NULL, source)); 37 } 38 39 WebSocket::~WebSocket() { 40 CHECK(thread_checker_.CalledOnValidThread()); 41 } 42 43 void WebSocket::Connect(const net::CompletionCallback& callback) { 44 CHECK(thread_checker_.CalledOnValidThread()); 45 CHECK_EQ(INITIALIZED, state_); 46 state_ = CONNECTING; 47 connect_callback_ = callback; 48 int code = socket_->Connect(base::Bind( 49 &WebSocket::OnSocketConnect, base::Unretained(this))); 50 if (code != net::ERR_IO_PENDING) 51 OnSocketConnect(code); 52 } 53 54 bool WebSocket::Send(const std::string& message) { 55 CHECK(thread_checker_.CalledOnValidThread()); 56 if (state_ != OPEN) 57 return false; 58 59 net::WebSocketFrameHeader header(net::WebSocketFrameHeader::kOpCodeText); 60 header.final = true; 61 header.masked = true; 62 header.payload_length = message.length(); 63 int header_size = net::GetWebSocketFrameHeaderSize(header); 64 net::WebSocketMaskingKey masking_key = net::GenerateWebSocketMaskingKey(); 65 std::string header_str; 66 header_str.resize(header_size); 67 CHECK_EQ(header_size, net::WriteWebSocketFrameHeader( 68 header, &masking_key, &header_str[0], header_str.length())); 69 70 std::string masked_message = message; 71 net::MaskWebSocketFramePayload( 72 masking_key, 0, &masked_message[0], masked_message.length()); 73 Write(header_str + masked_message); 74 return true; 75 } 76 77 void WebSocket::OnSocketConnect(int code) { 78 if (code != net::OK) { 79 Close(code); 80 return; 81 } 82 83 CHECK(base::Base64Encode(base::RandBytesAsString(16), &sec_key_)); 84 std::string handshake = base::StringPrintf( 85 "GET %s HTTP/1.1\r\n" 86 "Host: %s\r\n" 87 "Upgrade: websocket\r\n" 88 "Connection: Upgrade\r\n" 89 "Sec-WebSocket-Key: %s\r\n" 90 "Sec-WebSocket-Version: 13\r\n" 91 "Pragma: no-cache\r\n" 92 "Cache-Control: no-cache\r\n" 93 "\r\n", 94 url_.path().c_str(), 95 url_.host().c_str(), 96 sec_key_.c_str()); 97 Write(handshake); 98 Read(); 99 } 100 101 void WebSocket::Write(const std::string& data) { 102 pending_write_ += data; 103 if (!write_buffer_->BytesRemaining()) 104 ContinueWritingIfNecessary(); 105 } 106 107 void WebSocket::OnWrite(int code) { 108 if (!socket_->IsConnected()) { 109 // Supposedly if |StreamSocket| is closed, the error code may be undefined. 110 Close(net::ERR_FAILED); 111 return; 112 } 113 if (code < 0) { 114 Close(code); 115 return; 116 } 117 118 write_buffer_->DidConsume(code); 119 ContinueWritingIfNecessary(); 120 } 121 122 void WebSocket::ContinueWritingIfNecessary() { 123 if (!write_buffer_->BytesRemaining()) { 124 if (pending_write_.empty()) 125 return; 126 write_buffer_ = new net::DrainableIOBuffer( 127 new net::StringIOBuffer(pending_write_), 128 pending_write_.length()); 129 pending_write_.clear(); 130 } 131 int code = 132 socket_->Write(write_buffer_.get(), 133 write_buffer_->BytesRemaining(), 134 base::Bind(&WebSocket::OnWrite, base::Unretained(this))); 135 if (code != net::ERR_IO_PENDING) 136 OnWrite(code); 137 } 138 139 void WebSocket::Read() { 140 int code = 141 socket_->Read(read_buffer_.get(), 142 read_buffer_->size(), 143 base::Bind(&WebSocket::OnRead, base::Unretained(this))); 144 if (code != net::ERR_IO_PENDING) 145 OnRead(code); 146 } 147 148 void WebSocket::OnRead(int code) { 149 if (code <= 0) { 150 Close(code ? code : net::ERR_FAILED); 151 return; 152 } 153 154 if (state_ == CONNECTING) 155 OnReadDuringHandshake(read_buffer_->data(), code); 156 else if (state_ == OPEN) 157 OnReadDuringOpen(read_buffer_->data(), code); 158 159 if (state_ != CLOSED) 160 Read(); 161 } 162 163 void WebSocket::OnReadDuringHandshake(const char* data, int len) { 164 handshake_response_ += std::string(data, len); 165 int headers_end = net::HttpUtil::LocateEndOfHeaders( 166 handshake_response_.data(), handshake_response_.size(), 0); 167 if (headers_end == -1) 168 return; 169 170 const char kMagicKey[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; 171 std::string websocket_accept; 172 CHECK(base::Base64Encode(base::SHA1HashString(sec_key_ + kMagicKey), 173 &websocket_accept)); 174 scoped_refptr<net::HttpResponseHeaders> headers( 175 new net::HttpResponseHeaders( 176 net::HttpUtil::AssembleRawHeaders( 177 handshake_response_.data(), headers_end))); 178 if (headers->response_code() != 101 || 179 !headers->HasHeaderValue("Upgrade", "WebSocket") || 180 !headers->HasHeaderValue("Connection", "Upgrade") || 181 !headers->HasHeaderValue("Sec-WebSocket-Accept", websocket_accept)) { 182 Close(net::ERR_FAILED); 183 return; 184 } 185 std::string leftover_message = handshake_response_.substr(headers_end); 186 handshake_response_.clear(); 187 sec_key_.clear(); 188 state_ = OPEN; 189 InvokeConnectCallback(net::OK); 190 if (!leftover_message.empty()) 191 OnReadDuringOpen(leftover_message.c_str(), leftover_message.length()); 192 } 193 194 void WebSocket::OnReadDuringOpen(const char* data, int len) { 195 ScopedVector<net::WebSocketFrameChunk> frame_chunks; 196 CHECK(parser_.Decode(data, len, &frame_chunks)); 197 for (size_t i = 0; i < frame_chunks.size(); ++i) { 198 scoped_refptr<net::IOBufferWithSize> buffer = frame_chunks[i]->data; 199 if (buffer.get()) 200 next_message_ += std::string(buffer->data(), buffer->size()); 201 if (frame_chunks[i]->final_chunk) { 202 listener_->OnMessageReceived(next_message_); 203 next_message_.clear(); 204 } 205 } 206 } 207 208 void WebSocket::InvokeConnectCallback(int code) { 209 net::CompletionCallback temp = connect_callback_; 210 connect_callback_.Reset(); 211 CHECK(!temp.is_null()); 212 temp.Run(code); 213 } 214 215 void WebSocket::Close(int code) { 216 socket_->Disconnect(); 217 if (!connect_callback_.is_null()) 218 InvokeConnectCallback(code); 219 if (state_ == OPEN) 220 listener_->OnClose(); 221 222 state_ = CLOSED; 223 } 224 225 226