Home | History | Annotate | Download | only in websockets
      1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/websockets/websocket_throttle.h"
      6 
      7 #include <string>
      8 
      9 #include "base/message_loop.h"
     10 #include "base/ref_counted.h"
     11 #include "base/singleton.h"
     12 #include "base/string_util.h"
     13 #include "net/base/io_buffer.h"
     14 #include "net/base/sys_addrinfo.h"
     15 #include "net/socket_stream/socket_stream.h"
     16 
     17 namespace net {
     18 
     19 static std::string AddrinfoToHashkey(const struct addrinfo* addrinfo) {
     20   switch (addrinfo->ai_family) {
     21     case AF_INET: {
     22       const struct sockaddr_in* const addr =
     23           reinterpret_cast<const sockaddr_in*>(addrinfo->ai_addr);
     24       return StringPrintf("%d:%s",
     25                           addrinfo->ai_family,
     26                           HexEncode(&addr->sin_addr, 4).c_str());
     27       }
     28     case AF_INET6: {
     29       const struct sockaddr_in6* const addr6 =
     30           reinterpret_cast<const sockaddr_in6*>(addrinfo->ai_addr);
     31       return StringPrintf("%d:%s",
     32                           addrinfo->ai_family,
     33                           HexEncode(&addr6->sin6_addr,
     34                                     sizeof(addr6->sin6_addr)).c_str());
     35       }
     36     default:
     37       return StringPrintf("%d:%s",
     38                           addrinfo->ai_family,
     39                           HexEncode(addrinfo->ai_addr,
     40                                     addrinfo->ai_addrlen).c_str());
     41   }
     42 }
     43 
     44 // State for WebSocket protocol on each SocketStream.
     45 // This is owned in SocketStream as UserData keyed by WebSocketState::kKeyName.
     46 // This is alive between connection starts and handshake is finished.
     47 // In this class, it doesn't check actual handshake finishes, but only checks
     48 // end of header is found in read data.
     49 class WebSocketThrottle::WebSocketState : public SocketStream::UserData {
     50  public:
     51   explicit WebSocketState(const AddressList& addrs)
     52       : address_list_(addrs),
     53         callback_(NULL),
     54         waiting_(false),
     55         handshake_finished_(false),
     56         buffer_(NULL) {
     57   }
     58   ~WebSocketState() {}
     59 
     60   int OnStartOpenConnection(CompletionCallback* callback) {
     61     DCHECK(!callback_);
     62     if (!waiting_)
     63       return OK;
     64     callback_ = callback;
     65     return ERR_IO_PENDING;
     66   }
     67 
     68   int OnRead(const char* data, int len, CompletionCallback* callback) {
     69     DCHECK(!waiting_);
     70     DCHECK(!callback_);
     71     DCHECK(!handshake_finished_);
     72     static const int kBufferSize = 8129;
     73 
     74     if (!buffer_) {
     75       // Fast path.
     76       int eoh = HttpUtil::LocateEndOfHeaders(data, len, 0);
     77       if (eoh > 0) {
     78         handshake_finished_ = true;
     79         return OK;
     80       }
     81       buffer_ = new GrowableIOBuffer();
     82       buffer_->SetCapacity(kBufferSize);
     83     } else if (buffer_->RemainingCapacity() < len) {
     84       buffer_->SetCapacity(buffer_->capacity() + kBufferSize);
     85     }
     86     memcpy(buffer_->data(), data, len);
     87     buffer_->set_offset(buffer_->offset() + len);
     88 
     89     int eoh = HttpUtil::LocateEndOfHeaders(buffer_->StartOfBuffer(),
     90                                            buffer_->offset(), 0);
     91     handshake_finished_ = (eoh > 0);
     92     return OK;
     93   }
     94 
     95   const AddressList& address_list() const { return address_list_; }
     96   void SetWaiting() { waiting_ = true; }
     97   bool IsWaiting() const { return waiting_; }
     98   bool HandshakeFinished() const { return handshake_finished_; }
     99   void Wakeup() {
    100     waiting_ = false;
    101     // We wrap |callback_| to keep this alive while this is released.
    102     scoped_refptr<CompletionCallbackRunner> runner =
    103         new CompletionCallbackRunner(callback_);
    104     callback_ = NULL;
    105     MessageLoopForIO::current()->PostTask(
    106         FROM_HERE,
    107         NewRunnableMethod(runner.get(),
    108                           &CompletionCallbackRunner::Run));
    109   }
    110 
    111   static const char* kKeyName;
    112 
    113  private:
    114   class CompletionCallbackRunner
    115       : public base::RefCountedThreadSafe<CompletionCallbackRunner> {
    116    public:
    117     explicit CompletionCallbackRunner(CompletionCallback* callback)
    118         : callback_(callback) {
    119       DCHECK(callback_);
    120     }
    121     void Run() {
    122       callback_->Run(OK);
    123     }
    124    private:
    125     friend class base::RefCountedThreadSafe<CompletionCallbackRunner>;
    126 
    127     virtual ~CompletionCallbackRunner() {}
    128 
    129     CompletionCallback* callback_;
    130 
    131     DISALLOW_COPY_AND_ASSIGN(CompletionCallbackRunner);
    132   };
    133 
    134   const AddressList& address_list_;
    135 
    136   CompletionCallback* callback_;
    137   // True if waiting another websocket connection is established.
    138   // False if the websocket is performing handshaking.
    139   bool waiting_;
    140 
    141   // True if the websocket handshake is completed.
    142   // If true, it will be removed from queue and deleted from the SocketStream
    143   // UserData soon.
    144   bool handshake_finished_;
    145 
    146   // Buffer for read data to check handshake response message.
    147   scoped_refptr<GrowableIOBuffer> buffer_;
    148 
    149   DISALLOW_COPY_AND_ASSIGN(WebSocketState);
    150 };
    151 
    152 const char* WebSocketThrottle::WebSocketState::kKeyName = "WebSocketState";
    153 
    154 WebSocketThrottle::WebSocketThrottle() {
    155   SocketStreamThrottle::RegisterSocketStreamThrottle("ws", this);
    156   SocketStreamThrottle::RegisterSocketStreamThrottle("wss", this);
    157 }
    158 
    159 WebSocketThrottle::~WebSocketThrottle() {
    160   DCHECK(queue_.empty());
    161   DCHECK(addr_map_.empty());
    162 }
    163 
    164 int WebSocketThrottle::OnStartOpenConnection(
    165     SocketStream* socket, CompletionCallback* callback) {
    166   WebSocketState* state = new WebSocketState(socket->address_list());
    167   PutInQueue(socket, state);
    168   return state->OnStartOpenConnection(callback);
    169 }
    170 
    171 int WebSocketThrottle::OnRead(SocketStream* socket,
    172                               const char* data, int len,
    173                               CompletionCallback* callback) {
    174   WebSocketState* state = static_cast<WebSocketState*>(
    175       socket->GetUserData(WebSocketState::kKeyName));
    176   // If no state, handshake was already completed. Do nothing.
    177   if (!state)
    178     return OK;
    179 
    180   int result = state->OnRead(data, len, callback);
    181   if (state->HandshakeFinished()) {
    182     RemoveFromQueue(socket, state);
    183     WakeupSocketIfNecessary();
    184   }
    185   return result;
    186 }
    187 
    188 int WebSocketThrottle::OnWrite(SocketStream* socket,
    189                                const char* data, int len,
    190                                CompletionCallback* callback) {
    191   // Do nothing.
    192   return OK;
    193 }
    194 
    195 void WebSocketThrottle::OnClose(SocketStream* socket) {
    196   WebSocketState* state = static_cast<WebSocketState*>(
    197       socket->GetUserData(WebSocketState::kKeyName));
    198   if (!state)
    199     return;
    200   RemoveFromQueue(socket, state);
    201   WakeupSocketIfNecessary();
    202 }
    203 
    204 void WebSocketThrottle::PutInQueue(SocketStream* socket,
    205                                    WebSocketState* state) {
    206   socket->SetUserData(WebSocketState::kKeyName, state);
    207   queue_.push_back(state);
    208   const AddressList& address_list = socket->address_list();
    209   for (const struct addrinfo* addrinfo = address_list.head();
    210        addrinfo != NULL;
    211        addrinfo = addrinfo->ai_next) {
    212     std::string addrkey = AddrinfoToHashkey(addrinfo);
    213     ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
    214     if (iter == addr_map_.end()) {
    215       ConnectingQueue* queue = new ConnectingQueue();
    216       queue->push_back(state);
    217       addr_map_[addrkey] = queue;
    218     } else {
    219       iter->second->push_back(state);
    220       state->SetWaiting();
    221     }
    222   }
    223 }
    224 
    225 void WebSocketThrottle::RemoveFromQueue(SocketStream* socket,
    226                                         WebSocketState* state) {
    227   const AddressList& address_list = socket->address_list();
    228   for (const struct addrinfo* addrinfo = address_list.head();
    229        addrinfo != NULL;
    230        addrinfo = addrinfo->ai_next) {
    231     std::string addrkey = AddrinfoToHashkey(addrinfo);
    232     ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
    233     DCHECK(iter != addr_map_.end());
    234     ConnectingQueue* queue = iter->second;
    235     DCHECK(state == queue->front());
    236     queue->pop_front();
    237     if (queue->empty()) {
    238       delete queue;
    239       addr_map_.erase(iter);
    240     }
    241   }
    242   for (ConnectingQueue::iterator iter = queue_.begin();
    243        iter != queue_.end();
    244        ++iter) {
    245     if (*iter == state) {
    246       queue_.erase(iter);
    247       break;
    248     }
    249   }
    250   socket->SetUserData(WebSocketState::kKeyName, NULL);
    251 }
    252 
    253 void WebSocketThrottle::WakeupSocketIfNecessary() {
    254   for (ConnectingQueue::iterator iter = queue_.begin();
    255        iter != queue_.end();
    256        ++iter) {
    257     WebSocketState* state = *iter;
    258     if (!state->IsWaiting())
    259       continue;
    260 
    261     bool should_wakeup = true;
    262     const AddressList& address_list = state->address_list();
    263     for (const struct addrinfo* addrinfo = address_list.head();
    264          addrinfo != NULL;
    265          addrinfo = addrinfo->ai_next) {
    266       std::string addrkey = AddrinfoToHashkey(addrinfo);
    267       ConnectingAddressMap::iterator iter = addr_map_.find(addrkey);
    268       DCHECK(iter != addr_map_.end());
    269       ConnectingQueue* queue = iter->second;
    270       if (state != queue->front()) {
    271         should_wakeup = false;
    272         break;
    273       }
    274     }
    275     if (should_wakeup)
    276       state->Wakeup();
    277   }
    278 }
    279 
    280 /* static */
    281 void WebSocketThrottle::Init() {
    282   Singleton<WebSocketThrottle>::get();
    283 }
    284 
    285 }  // namespace net
    286