1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <algorithm> 6 #include <limits> 7 8 #include "net/websockets/websocket.h" 9 10 #include "base/message_loop.h" 11 #include "net/http/http_response_headers.h" 12 #include "net/http/http_util.h" 13 14 namespace net { 15 16 static const int kWebSocketPort = 80; 17 static const int kSecureWebSocketPort = 443; 18 19 static const char kServerHandshakeHeader[] = 20 "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"; 21 static const size_t kServerHandshakeHeaderLength = 22 sizeof(kServerHandshakeHeader) - 1; 23 24 static const char kUpgradeHeader[] = "Upgrade: WebSocket\r\n"; 25 static const size_t kUpgradeHeaderLength = sizeof(kUpgradeHeader) - 1; 26 27 static const char kConnectionHeader[] = "Connection: Upgrade\r\n"; 28 static const size_t kConnectionHeaderLength = sizeof(kConnectionHeader) - 1; 29 30 bool WebSocket::Request::is_secure() const { 31 return url_.SchemeIs("wss"); 32 } 33 34 WebSocket::WebSocket(Request* request, WebSocketDelegate* delegate) 35 : ready_state_(INITIALIZED), 36 mode_(MODE_INCOMPLETE), 37 request_(request), 38 delegate_(delegate), 39 origin_loop_(MessageLoop::current()), 40 socket_stream_(NULL), 41 max_pending_send_allowed_(0), 42 current_read_buf_(NULL), 43 read_consumed_len_(0), 44 current_write_buf_(NULL) { 45 DCHECK(request_.get()); 46 DCHECK(delegate_); 47 DCHECK(origin_loop_); 48 } 49 50 WebSocket::~WebSocket() { 51 DCHECK(ready_state_ == INITIALIZED || !delegate_); 52 DCHECK(!socket_stream_); 53 DCHECK(!delegate_); 54 } 55 56 void WebSocket::Connect() { 57 DCHECK(ready_state_ == INITIALIZED); 58 DCHECK(request_.get()); 59 DCHECK(delegate_); 60 DCHECK(!socket_stream_); 61 DCHECK(MessageLoop::current() == origin_loop_); 62 63 socket_stream_ = new SocketStream(request_->url(), this); 64 socket_stream_->set_context(request_->context()); 65 66 if (request_->host_resolver()) 67 socket_stream_->SetHostResolver(request_->host_resolver()); 68 if (request_->client_socket_factory()) 69 socket_stream_->SetClientSocketFactory(request_->client_socket_factory()); 70 71 AddRef(); // Release in DoClose(). 72 ready_state_ = CONNECTING; 73 socket_stream_->Connect(); 74 } 75 76 void WebSocket::Send(const std::string& msg) { 77 DCHECK(ready_state_ == OPEN); 78 DCHECK(MessageLoop::current() == origin_loop_); 79 80 IOBufferWithSize* buf = new IOBufferWithSize(msg.size() + 2); 81 char* p = buf->data(); 82 *p = '\0'; 83 memcpy(p + 1, msg.data(), msg.size()); 84 *(p + 1 + msg.size()) = '\xff'; 85 pending_write_bufs_.push_back(buf); 86 SendPending(); 87 } 88 89 void WebSocket::Close() { 90 DCHECK(MessageLoop::current() == origin_loop_); 91 92 if (ready_state_ == INITIALIZED) { 93 DCHECK(!socket_stream_); 94 ready_state_ = CLOSED; 95 return; 96 } 97 if (ready_state_ != CLOSED) { 98 DCHECK(socket_stream_); 99 socket_stream_->Close(); 100 return; 101 } 102 } 103 104 void WebSocket::DetachDelegate() { 105 if (!delegate_) 106 return; 107 delegate_ = NULL; 108 Close(); 109 } 110 111 void WebSocket::OnConnected(SocketStream* socket_stream, 112 int max_pending_send_allowed) { 113 DCHECK(socket_stream == socket_stream_); 114 max_pending_send_allowed_ = max_pending_send_allowed; 115 116 // Use |max_pending_send_allowed| as hint for initial size of read buffer. 117 current_read_buf_ = new GrowableIOBuffer(); 118 current_read_buf_->SetCapacity(max_pending_send_allowed_); 119 read_consumed_len_ = 0; 120 121 DCHECK(!current_write_buf_); 122 const std::string msg = request_->CreateClientHandshakeMessage(); 123 IOBufferWithSize* buf = new IOBufferWithSize(msg.size()); 124 memcpy(buf->data(), msg.data(), msg.size()); 125 pending_write_bufs_.push_back(buf); 126 origin_loop_->PostTask(FROM_HERE, 127 NewRunnableMethod(this, &WebSocket::SendPending)); 128 } 129 130 void WebSocket::OnSentData(SocketStream* socket_stream, int amount_sent) { 131 DCHECK(socket_stream == socket_stream_); 132 DCHECK(current_write_buf_); 133 current_write_buf_->DidConsume(amount_sent); 134 DCHECK_GE(current_write_buf_->BytesRemaining(), 0); 135 if (current_write_buf_->BytesRemaining() == 0) { 136 current_write_buf_ = NULL; 137 pending_write_bufs_.pop_front(); 138 } 139 origin_loop_->PostTask(FROM_HERE, 140 NewRunnableMethod(this, &WebSocket::SendPending)); 141 } 142 143 void WebSocket::OnReceivedData(SocketStream* socket_stream, 144 const char* data, int len) { 145 DCHECK(socket_stream == socket_stream_); 146 AddToReadBuffer(data, len); 147 origin_loop_->PostTask(FROM_HERE, 148 NewRunnableMethod(this, &WebSocket::DoReceivedData)); 149 } 150 151 void WebSocket::OnClose(SocketStream* socket_stream) { 152 origin_loop_->PostTask(FROM_HERE, 153 NewRunnableMethod(this, &WebSocket::DoClose)); 154 } 155 156 void WebSocket::OnError(const SocketStream* socket_stream, int error) { 157 origin_loop_->PostTask(FROM_HERE, 158 NewRunnableMethod(this, &WebSocket::DoError, error)); 159 } 160 161 std::string WebSocket::Request::CreateClientHandshakeMessage() const { 162 std::string msg; 163 msg = "GET "; 164 msg += url_.path(); 165 if (url_.has_query()) { 166 msg += "?"; 167 msg += url_.query(); 168 } 169 msg += " HTTP/1.1\r\n"; 170 msg += kUpgradeHeader; 171 msg += kConnectionHeader; 172 msg += "Host: "; 173 msg += StringToLowerASCII(url_.host()); 174 if (url_.has_port()) { 175 bool secure = is_secure(); 176 int port = url_.EffectiveIntPort(); 177 if ((!secure && 178 port != kWebSocketPort && port != url_parse::PORT_UNSPECIFIED) || 179 (secure && 180 port != kSecureWebSocketPort && port != url_parse::PORT_UNSPECIFIED)) { 181 msg += ":"; 182 msg += IntToString(port); 183 } 184 } 185 msg += "\r\n"; 186 msg += "Origin: "; 187 // It's OK to lowercase the origin as the Origin header does not contain 188 // the path or query portions, as per 189 // http://tools.ietf.org/html/draft-abarth-origin-00. 190 // 191 // TODO(satorux): Should we trim the port portion here if it's 80 for 192 // http:// or 443 for https:// ? Or can we assume it's done by the 193 // client of the library? 194 msg += StringToLowerASCII(origin_); 195 msg += "\r\n"; 196 if (!protocol_.empty()) { 197 msg += "WebSocket-Protocol: "; 198 msg += protocol_; 199 msg += "\r\n"; 200 } 201 // TODO(ukai): Add cookie if necessary. 202 msg += "\r\n"; 203 return msg; 204 } 205 206 int WebSocket::CheckHandshake() { 207 DCHECK(current_read_buf_); 208 DCHECK(ready_state_ == CONNECTING); 209 mode_ = MODE_INCOMPLETE; 210 const char *start = current_read_buf_->StartOfBuffer() + read_consumed_len_; 211 const char *p = start; 212 size_t len = current_read_buf_->offset() - read_consumed_len_; 213 if (len < kServerHandshakeHeaderLength) { 214 return -1; 215 } 216 if (!memcmp(p, kServerHandshakeHeader, kServerHandshakeHeaderLength)) { 217 mode_ = MODE_NORMAL; 218 } else { 219 int eoh = HttpUtil::LocateEndOfHeaders(p, len); 220 if (eoh < 0) 221 return -1; 222 scoped_refptr<HttpResponseHeaders> headers( 223 new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(p, eoh))); 224 if (headers->response_code() == 407) { 225 mode_ = MODE_AUTHENTICATE; 226 // TODO(ukai): Implement authentication handlers. 227 } 228 DLOG(INFO) << "non-normal websocket connection. " 229 << "response_code=" << headers->response_code() 230 << " mode=" << mode_; 231 // Invalid response code. 232 ready_state_ = CLOSED; 233 return eoh; 234 } 235 const char* end = p + len + 1; 236 p += kServerHandshakeHeaderLength; 237 238 if (mode_ == MODE_NORMAL) { 239 size_t header_size = end - p; 240 if (header_size < kUpgradeHeaderLength) 241 return -1; 242 if (memcmp(p, kUpgradeHeader, kUpgradeHeaderLength)) { 243 DLOG(INFO) << "Bad Upgrade Header " 244 << std::string(p, kUpgradeHeaderLength); 245 ready_state_ = CLOSED; 246 return p - start; 247 } 248 p += kUpgradeHeaderLength; 249 250 header_size = end - p; 251 if (header_size < kConnectionHeaderLength) 252 return -1; 253 if (memcmp(p, kConnectionHeader, kConnectionHeaderLength)) { 254 DLOG(INFO) << "Bad Connection Header " 255 << std::string(p, kConnectionHeaderLength); 256 ready_state_ = CLOSED; 257 return p - start; 258 } 259 p += kConnectionHeaderLength; 260 } 261 int eoh = HttpUtil::LocateEndOfHeaders(start, len); 262 if (eoh == -1) 263 return eoh; 264 scoped_refptr<HttpResponseHeaders> headers( 265 new HttpResponseHeaders(HttpUtil::AssembleRawHeaders(start, eoh))); 266 if (!ProcessHeaders(*headers)) { 267 DLOG(INFO) << "Process Headers failed: " 268 << std::string(start, eoh); 269 ready_state_ = CLOSED; 270 return eoh; 271 } 272 switch (mode_) { 273 case MODE_NORMAL: 274 if (CheckResponseHeaders()) { 275 ready_state_ = OPEN; 276 } else { 277 ready_state_ = CLOSED; 278 } 279 break; 280 default: 281 ready_state_ = CLOSED; 282 break; 283 } 284 if (ready_state_ == CLOSED) 285 DLOG(INFO) << "CheckHandshake mode=" << mode_ 286 << " " << std::string(start, eoh); 287 return eoh; 288 } 289 290 // Gets the value of the specified header. 291 // It assures only one header of |name| in |headers|. 292 // Returns true iff single header of |name| is found in |headers| 293 // and |value| is filled with the value. 294 // Returns false otherwise. 295 static bool GetSingleHeader(const HttpResponseHeaders& headers, 296 const std::string& name, 297 std::string* value) { 298 std::string first_value; 299 void* iter = NULL; 300 if (!headers.EnumerateHeader(&iter, name, &first_value)) 301 return false; 302 303 // Checks no more |name| found in |headers|. 304 // Second call of EnumerateHeader() must return false. 305 std::string second_value; 306 if (headers.EnumerateHeader(&iter, name, &second_value)) 307 return false; 308 *value = first_value; 309 return true; 310 } 311 312 bool WebSocket::ProcessHeaders(const HttpResponseHeaders& headers) { 313 if (!GetSingleHeader(headers, "websocket-origin", &ws_origin_)) 314 return false; 315 316 if (!GetSingleHeader(headers, "websocket-location", &ws_location_)) 317 return false; 318 319 if (!request_->protocol().empty() 320 && !GetSingleHeader(headers, "websocket-protocol", &ws_protocol_)) 321 return false; 322 return true; 323 } 324 325 bool WebSocket::CheckResponseHeaders() const { 326 DCHECK(mode_ == MODE_NORMAL); 327 if (!LowerCaseEqualsASCII(request_->origin(), ws_origin_.c_str())) 328 return false; 329 if (request_->location() != ws_location_) 330 return false; 331 if (request_->protocol() != ws_protocol_) 332 return false; 333 return true; 334 } 335 336 void WebSocket::SendPending() { 337 DCHECK(MessageLoop::current() == origin_loop_); 338 DCHECK(socket_stream_); 339 if (!current_write_buf_) { 340 if (pending_write_bufs_.empty()) 341 return; 342 current_write_buf_ = new DrainableIOBuffer( 343 pending_write_bufs_.front(), pending_write_bufs_.front()->size()); 344 } 345 DCHECK_GT(current_write_buf_->BytesRemaining(), 0); 346 bool sent = socket_stream_->SendData( 347 current_write_buf_->data(), 348 std::min(current_write_buf_->BytesRemaining(), 349 max_pending_send_allowed_)); 350 DCHECK(sent); 351 } 352 353 void WebSocket::DoReceivedData() { 354 DCHECK(MessageLoop::current() == origin_loop_); 355 switch (ready_state_) { 356 case CONNECTING: 357 { 358 int eoh = CheckHandshake(); 359 if (eoh < 0) { 360 // Not enough data, Retry when more data is available. 361 return; 362 } 363 SkipReadBuffer(eoh); 364 } 365 if (ready_state_ != OPEN) { 366 // Handshake failed. 367 socket_stream_->Close(); 368 return; 369 } 370 if (delegate_) 371 delegate_->OnOpen(this); 372 if (current_read_buf_->offset() == read_consumed_len_) { 373 // No remaining data after handshake message. 374 break; 375 } 376 // FALL THROUGH 377 case OPEN: 378 ProcessFrameData(); 379 break; 380 381 case CLOSED: 382 // Closed just after DoReceivedData is queued on |origin_loop_|. 383 break; 384 default: 385 NOTREACHED(); 386 break; 387 } 388 } 389 390 void WebSocket::ProcessFrameData() { 391 DCHECK(current_read_buf_); 392 const char* start_frame = 393 current_read_buf_->StartOfBuffer() + read_consumed_len_; 394 const char* next_frame = start_frame; 395 const char* p = next_frame; 396 const char* end = 397 current_read_buf_->StartOfBuffer() + current_read_buf_->offset(); 398 while (p < end) { 399 unsigned char frame_byte = static_cast<unsigned char>(*p++); 400 if ((frame_byte & 0x80) == 0x80) { 401 int length = 0; 402 while (p < end) { 403 if (length > std::numeric_limits<int>::max() / 128) { 404 // frame length overflow. 405 socket_stream_->Close(); 406 return; 407 } 408 unsigned char c = static_cast<unsigned char>(*p); 409 length = length * 128 + (c & 0x7f); 410 ++p; 411 if ((c & 0x80) != 0x80) 412 break; 413 } 414 // Checks if the frame body hasn't been completely received yet. 415 // It also checks the case the frame length bytes haven't been completely 416 // received yet, because p == end and length > 0 in such case. 417 if (p + length < end) { 418 p += length; 419 next_frame = p; 420 } else { 421 break; 422 } 423 } else { 424 const char* msg_start = p; 425 while (p < end && *p != '\xff') 426 ++p; 427 if (p < end && *p == '\xff') { 428 if (frame_byte == 0x00 && delegate_) 429 delegate_->OnMessage(this, std::string(msg_start, p - msg_start)); 430 ++p; 431 next_frame = p; 432 } 433 } 434 } 435 SkipReadBuffer(next_frame - start_frame); 436 } 437 438 void WebSocket::AddToReadBuffer(const char* data, int len) { 439 DCHECK(current_read_buf_); 440 // Check if |current_read_buf_| has enough space to store |len| of |data|. 441 if (len >= current_read_buf_->RemainingCapacity()) { 442 current_read_buf_->SetCapacity( 443 current_read_buf_->offset() + len); 444 } 445 446 DCHECK(current_read_buf_->RemainingCapacity() >= len); 447 memcpy(current_read_buf_->data(), data, len); 448 current_read_buf_->set_offset(current_read_buf_->offset() + len); 449 } 450 451 void WebSocket::SkipReadBuffer(int len) { 452 if (len == 0) 453 return; 454 DCHECK_GT(len, 0); 455 read_consumed_len_ += len; 456 int remaining = current_read_buf_->offset() - read_consumed_len_; 457 DCHECK_GE(remaining, 0); 458 if (remaining < read_consumed_len_ && 459 current_read_buf_->RemainingCapacity() < read_consumed_len_) { 460 // Pre compaction: 461 // 0 v-read_consumed_len_ v-offset v- capacity 462 // |..processed..| .. remaining .. | .. RemainingCapacity | 463 // 464 memmove(current_read_buf_->StartOfBuffer(), 465 current_read_buf_->StartOfBuffer() + read_consumed_len_, 466 remaining); 467 read_consumed_len_ = 0; 468 current_read_buf_->set_offset(remaining); 469 // Post compaction: 470 // 0read_consumed_len_ v- offset v- capacity 471 // |.. remaining .. | .. RemainingCapacity ... | 472 // 473 } 474 } 475 476 void WebSocket::DoClose() { 477 DCHECK(MessageLoop::current() == origin_loop_); 478 WebSocketDelegate* delegate = delegate_; 479 delegate_ = NULL; 480 ready_state_ = CLOSED; 481 if (!socket_stream_) 482 return; 483 socket_stream_ = NULL; 484 if (delegate) 485 delegate->OnClose(this); 486 Release(); 487 } 488 489 void WebSocket::DoError(int error) { 490 DCHECK(MessageLoop::current() == origin_loop_); 491 if (delegate_) 492 delegate_->OnError(this, error); 493 } 494 495 } // namespace net 496