1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/websockets/websocket_throttle.h" 6 7 #include <string> 8 9 #include "base/message_loop.h" 10 #include "base/ref_counted.h" 11 #include "base/singleton.h" 12 #include "base/string_util.h" 13 #include "net/base/io_buffer.h" 14 #include "net/base/sys_addrinfo.h" 15 #include "net/socket_stream/socket_stream.h" 16 17 namespace net { 18 19 static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) { 20 switch (addrinfo->ai_family) { 21 case AF_INET: { 22 const struct sockaddr_in* const addr = 23 reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr); 24 return StringPrintf("%d:%s", 25 addrinfo->ai_family, 26 HexEncode(&addr->sin_addr, 4).c_str()); 27 } 28 case AF_INET6: { 29 const struct sockaddr_in6* const addr6 = 30 reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr); 31 return StringPrintf("%d:%s", 32 addrinfo->ai_family, 33 HexEncode(&addr6->sin6_addr, 34 sizeof(addr6->sin6_addr)).c_str()); 35 } 36 default: 37 return StringPrintf("%d:%s", 38 addrinfo->ai_family, 39 HexEncode(addrinfo->ai_addr, 40 addrinfo->ai_addrlen).c_str()); 41 } 42 } 43 44 // State for WebSocket protocol on each SocketStream. 45 // This is owned in SocketStream as UserData keyed by WebSocketState::kKeyName. 46 // This is alive between connection starts and handshake is finished. 47 // In this class, it doesn't check actual handshake finishes, but only checks 48 // end of header is found in read data. 49 class WebSocketThrottle::WebSocketState : public SocketStream::UserData { 50 public: 51 explicit WebSocketState(const AddressList& addrs) 52 : address_list_(addrs), 53 callback_(NULL), 54 waiting_(false), 55 handshake_finished_(false), 56 buffer_(NULL) { 57 } 58 ~WebSocketState() {} 59 60 int OnStartOpenConnection(CompletionCallback* callback) { 61 DCHECK(!callback_); 62 if (!waiting_) 63 return OK; 64 callback_ = callback; 65 return ERR_IO_PENDING; 66 } 67 68 int OnRead(const char* data, int len, CompletionCallback* callback) { 69 DCHECK(!waiting_); 70 DCHECK(!callback_); 71 DCHECK(!handshake_finished_); 72 static const int kBufferSize = 8129; 73 74 if (!buffer_) { 75 // Fast path. 76 int eoh = HttpUtil::LocateEndOfHeaders(data, len, 0); 77 if (eoh > 0) { 78 handshake_finished_ = true; 79 return OK; 80 } 81 buffer_ = new GrowableIOBuffer(); 82 buffer_->SetCapacity(kBufferSize); 83 } else if (buffer_->RemainingCapacity() < len) { 84 buffer_->SetCapacity(buffer_->capacity() + kBufferSize); 85 } 86 memcpy(buffer_->data(), data, len); 87 buffer_->set_offset(buffer_->offset() + len); 88 89 int eoh = HttpUtil::LocateEndOfHeaders(buffer_->StartOfBuffer(), 90 buffer_->offset(), 0); 91 handshake_finished_ = (eoh > 0); 92 return OK; 93 } 94 95 const AddressList& address_list() const { return address_list_; } 96 void SetWaiting() { waiting_ = true; } 97 bool IsWaiting() const { return waiting_; } 98 bool HandshakeFinished() const { return handshake_finished_; } 99 void Wakeup() { 100 waiting_ = false; 101 // We wrap |callback_| to keep this alive while this is released. 102 scoped_refptr<CompletionCallbackRunner> runner = 103 new CompletionCallbackRunner(callback_); 104 callback_ = NULL; 105 MessageLoopForIO::current()->PostTask( 106 FROM_HERE, 107 NewRunnableMethod(runner.get(), 108 &CompletionCallbackRunner::Run)); 109 } 110 111 static const char* kKeyName; 112 113 private: 114 class CompletionCallbackRunner 115 : public base::RefCountedThreadSafe<CompletionCallbackRunner> { 116 public: 117 explicit CompletionCallbackRunner(CompletionCallback* callback) 118 : callback_(callback) { 119 DCHECK(callback_); 120 } 121 void Run() { 122 callback_->Run(OK); 123 } 124 private: 125 friend class base::RefCountedThreadSafe<CompletionCallbackRunner>; 126 127 virtual ~CompletionCallbackRunner() {} 128 129 CompletionCallback* callback_; 130 131 DISALLOW_COPY_AND_ASSIGN(CompletionCallbackRunner); 132 }; 133 134 const AddressList& address_list_; 135 136 CompletionCallback* callback_; 137 // True if waiting another websocket connection is established. 138 // False if the websocket is performing handshaking. 139 bool waiting_; 140 141 // True if the websocket handshake is completed. 142 // If true, it will be removed from queue and deleted from the SocketStream 143 // UserData soon. 144 bool handshake_finished_; 145 146 // Buffer for read data to check handshake response message. 147 scoped_refptr<GrowableIOBuffer> buffer_; 148 149 DISALLOW_COPY_AND_ASSIGN(WebSocketState); 150 }; 151 152 const char* WebSocketThrottle::WebSocketState::kKeyName = "WebSocketState"; 153 154 WebSocketThrottle::WebSocketThrottle() { 155 SocketStreamThrottle::RegisterSocketStreamThrottle("ws", this); 156 SocketStreamThrottle::RegisterSocketStreamThrottle("wss", this); 157 } 158 159 WebSocketThrottle::~WebSocketThrottle() { 160 DCHECK(queue_.empty()); 161 DCHECK(addr_map_.empty()); 162 } 163 164 int WebSocketThrottle::OnStartOpenConnection( 165 SocketStream* socket, CompletionCallback* callback) { 166 WebSocketState* state = new WebSocketState(socket->address_list()); 167 PutInQueue(socket, state); 168 return state->OnStartOpenConnection(callback); 169 } 170 171 int WebSocketThrottle::OnRead(SocketStream* socket, 172 const char* data, int len, 173 CompletionCallback* callback) { 174 WebSocketState* state = static_cast<WebSocketState*>( 175 socket->GetUserData(WebSocketState::kKeyName)); 176 // If no state, handshake was already completed. Do nothing. 177 if (!state) 178 return OK; 179 180 int result = state->OnRead(data, len, callback); 181 if (state->HandshakeFinished()) { 182 RemoveFromQueue(socket, state); 183 WakeupSocketIfNecessary(); 184 } 185 return result; 186 } 187 188 int WebSocketThrottle::OnWrite(SocketStream* socket, 189 const char* data, int len, 190 CompletionCallback* callback) { 191 // Do nothing. 192 return OK; 193 } 194 195 void WebSocketThrottle::OnClose(SocketStream* socket) { 196 WebSocketState* state = static_cast<WebSocketState*>( 197 socket->GetUserData(WebSocketState::kKeyName)); 198 if (!state) 199 return; 200 RemoveFromQueue(socket, state); 201 WakeupSocketIfNecessary(); 202 } 203 204 void WebSocketThrottle::PutInQueue(SocketStream* socket, 205 WebSocketState* state) { 206 socket->SetUserData(WebSocketState::kKeyName, state); 207 queue_.push_back(state); 208 const AddressList& address_list = socket->address_list(); 209 for (const struct addrinfo* addrinfo = address_list.head(); 210 addrinfo != NULL; 211 addrinfo = addrinfo->ai_next) { 212 std::string addrkey = AddrinfoToHashkey(addrinfo); 213 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); 214 if (iter == addr_map_.end()) { 215 ConnectingQueue* queue = new ConnectingQueue(); 216 queue->push_back(state); 217 addr_map_[addrkey] = queue; 218 } else { 219 iter->second->push_back(state); 220 state->SetWaiting(); 221 } 222 } 223 } 224 225 void WebSocketThrottle::RemoveFromQueue(SocketStream* socket, 226 WebSocketState* state) { 227 const AddressList& address_list = socket->address_list(); 228 for (const struct addrinfo* addrinfo = address_list.head(); 229 addrinfo != NULL; 230 addrinfo = addrinfo->ai_next) { 231 std::string addrkey = AddrinfoToHashkey(addrinfo); 232 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); 233 DCHECK(iter != addr_map_.end()); 234 ConnectingQueue* queue = iter->second; 235 DCHECK(state == queue->front()); 236 queue->pop_front(); 237 if (queue->empty()) { 238 delete queue; 239 addr_map_.erase(iter); 240 } 241 } 242 for (ConnectingQueue::iterator iter = queue_.begin(); 243 iter != queue_.end(); 244 ++iter) { 245 if (*iter == state) { 246 queue_.erase(iter); 247 break; 248 } 249 } 250 socket->SetUserData(WebSocketState::kKeyName, NULL); 251 } 252 253 void WebSocketThrottle::WakeupSocketIfNecessary() { 254 for (ConnectingQueue::iterator iter = queue_.begin(); 255 iter != queue_.end(); 256 ++iter) { 257 WebSocketState* state = *iter; 258 if (!state->IsWaiting()) 259 continue; 260 261 bool should_wakeup = true; 262 const AddressList& address_list = state->address_list(); 263 for (const struct addrinfo* addrinfo = address_list.head(); 264 addrinfo != NULL; 265 addrinfo = addrinfo->ai_next) { 266 std::string addrkey = AddrinfoToHashkey(addrinfo); 267 ConnectingAddressMap::iterator iter = addr_map_.find(addrkey); 268 DCHECK(iter != addr_map_.end()); 269 ConnectingQueue* queue = iter->second; 270 if (state != queue->front()) { 271 should_wakeup = false; 272 break; 273 } 274 } 275 if (should_wakeup) 276 state->Wakeup(); 277 } 278 } 279 280 /* static */ 281 void WebSocketThrottle::Init() { 282 Singleton<WebSocketThrottle>::get(); 283 } 284 285 } // namespace net 286