Home | History | Annotate | Download | only in websockets
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <string>
      6 #include <vector>
      7 
      8 #include "base/memory/ref_counted.h"
      9 #include "base/string_split.h"
     10 #include "base/string_util.h"
     11 #include "googleurl/src/gurl.h"
     12 #include "net/base/cookie_policy.h"
     13 #include "net/base/cookie_store.h"
     14 #include "net/base/net_errors.h"
     15 #include "net/base/sys_addrinfo.h"
     16 #include "net/base/transport_security_state.h"
     17 #include "net/socket_stream/socket_stream.h"
     18 #include "net/url_request/url_request_context.h"
     19 #include "net/websockets/websocket_job.h"
     20 #include "net/websockets/websocket_throttle.h"
     21 #include "testing/gtest/include/gtest/gtest.h"
     22 #include "testing/gmock/include/gmock/gmock.h"
     23 #include "testing/platform_test.h"
     24 
     25 namespace net {
     26 
     27 class MockSocketStream : public SocketStream {
     28  public:
     29   MockSocketStream(const GURL& url, SocketStream::Delegate* delegate)
     30       : SocketStream(url, delegate) {}
     31   virtual ~MockSocketStream() {}
     32 
     33   virtual void Connect() {}
     34   virtual bool SendData(const char* data, int len) {
     35     sent_data_ += std::string(data, len);
     36     return true;
     37   }
     38 
     39   virtual void Close() {}
     40   virtual void RestartWithAuth(
     41       const string16& username, const string16& password) {}
     42   virtual void DetachDelegate() {
     43     delegate_ = NULL;
     44   }
     45 
     46   const std::string& sent_data() const {
     47     return sent_data_;
     48   }
     49 
     50  private:
     51   std::string sent_data_;
     52 };
     53 
     54 class MockSocketStreamDelegate : public SocketStream::Delegate {
     55  public:
     56   MockSocketStreamDelegate()
     57       : amount_sent_(0) {}
     58   virtual ~MockSocketStreamDelegate() {}
     59 
     60   virtual void OnConnected(SocketStream* socket, int max_pending_send_allowed) {
     61   }
     62   virtual void OnSentData(SocketStream* socket, int amount_sent) {
     63     amount_sent_ += amount_sent;
     64   }
     65   virtual void OnReceivedData(SocketStream* socket,
     66                               const char* data, int len) {
     67     received_data_ += std::string(data, len);
     68   }
     69   virtual void OnClose(SocketStream* socket) {
     70   }
     71 
     72   size_t amount_sent() const { return amount_sent_; }
     73   const std::string& received_data() const { return received_data_; }
     74 
     75  private:
     76   int amount_sent_;
     77   std::string received_data_;
     78 };
     79 
     80 class MockCookieStore : public CookieStore {
     81  public:
     82   struct Entry {
     83     GURL url;
     84     std::string cookie_line;
     85     CookieOptions options;
     86   };
     87   MockCookieStore() {}
     88 
     89   virtual bool SetCookieWithOptions(const GURL& url,
     90                                     const std::string& cookie_line,
     91                                     const CookieOptions& options) {
     92     Entry entry;
     93     entry.url = url;
     94     entry.cookie_line = cookie_line;
     95     entry.options = options;
     96     entries_.push_back(entry);
     97     return true;
     98   }
     99   virtual std::string GetCookiesWithOptions(const GURL& url,
    100                                             const CookieOptions& options) {
    101     std::string result;
    102     for (size_t i = 0; i < entries_.size(); i++) {
    103       Entry &entry = entries_[i];
    104       if (url == entry.url) {
    105         if (!result.empty()) {
    106           result += "; ";
    107         }
    108         result += entry.cookie_line;
    109       }
    110     }
    111     return result;
    112   }
    113   virtual void DeleteCookie(const GURL& url,
    114                             const std::string& cookie_name) {}
    115   virtual CookieMonster* GetCookieMonster() { return NULL; }
    116 
    117   const std::vector<Entry>& entries() const { return entries_; }
    118 
    119  private:
    120   friend class base::RefCountedThreadSafe<MockCookieStore>;
    121   virtual ~MockCookieStore() {}
    122 
    123   std::vector<Entry> entries_;
    124 };
    125 
    126 class MockCookiePolicy : public CookiePolicy {
    127  public:
    128   MockCookiePolicy() : allow_all_cookies_(true) {}
    129   virtual ~MockCookiePolicy() {}
    130 
    131   void set_allow_all_cookies(bool allow_all_cookies) {
    132     allow_all_cookies_ = allow_all_cookies;
    133   }
    134 
    135   virtual int CanGetCookies(const GURL& url,
    136                             const GURL& first_party_for_cookies) const {
    137     if (allow_all_cookies_)
    138       return OK;
    139     return ERR_ACCESS_DENIED;
    140   }
    141 
    142   virtual int CanSetCookie(const GURL& url,
    143                            const GURL& first_party_for_cookies,
    144                            const std::string& cookie_line) const {
    145     if (allow_all_cookies_)
    146       return OK;
    147     return ERR_ACCESS_DENIED;
    148   }
    149 
    150  private:
    151   bool allow_all_cookies_;
    152 };
    153 
    154 class MockURLRequestContext : public URLRequestContext {
    155  public:
    156   MockURLRequestContext(CookieStore* cookie_store,
    157                         CookiePolicy* cookie_policy) {
    158     set_cookie_store(cookie_store);
    159     set_cookie_policy(cookie_policy);
    160     transport_security_state_ = new TransportSecurityState();
    161     set_transport_security_state(transport_security_state_.get());
    162     TransportSecurityState::DomainState state;
    163     state.expiry = base::Time::Now() + base::TimeDelta::FromSeconds(1000);
    164     transport_security_state_->EnableHost("upgrademe.com", state);
    165   }
    166 
    167  private:
    168   friend class base::RefCountedThreadSafe<MockURLRequestContext>;
    169   virtual ~MockURLRequestContext() {}
    170 
    171   scoped_refptr<TransportSecurityState> transport_security_state_;
    172 };
    173 
    174 class WebSocketJobTest : public PlatformTest {
    175  public:
    176   virtual void SetUp() {
    177     cookie_store_ = new MockCookieStore;
    178     cookie_policy_.reset(new MockCookiePolicy);
    179     context_ = new MockURLRequestContext(
    180         cookie_store_.get(), cookie_policy_.get());
    181   }
    182   virtual void TearDown() {
    183     cookie_store_ = NULL;
    184     cookie_policy_.reset();
    185     context_ = NULL;
    186     websocket_ = NULL;
    187     socket_ = NULL;
    188   }
    189  protected:
    190   void InitWebSocketJob(const GURL& url, MockSocketStreamDelegate* delegate) {
    191     websocket_ = new WebSocketJob(delegate);
    192     socket_ = new MockSocketStream(url, websocket_.get());
    193     websocket_->InitSocketStream(socket_.get());
    194     websocket_->set_context(context_.get());
    195     websocket_->state_ = WebSocketJob::CONNECTING;
    196     struct addrinfo addr;
    197     memset(&addr, 0, sizeof(struct addrinfo));
    198     addr.ai_family = AF_INET;
    199     addr.ai_addrlen = sizeof(struct sockaddr_in);
    200     struct sockaddr_in sa_in;
    201     memset(&sa_in, 0, sizeof(struct sockaddr_in));
    202     memcpy(&sa_in.sin_addr, "\x7f\0\0\1", 4);
    203     addr.ai_addr = reinterpret_cast<sockaddr*>(&sa_in);
    204     addr.ai_next = NULL;
    205     websocket_->addresses_.Copy(&addr, true);
    206     WebSocketThrottle::GetInstance()->PutInQueue(websocket_);
    207   }
    208   WebSocketJob::State GetWebSocketJobState() {
    209     return websocket_->state_;
    210   }
    211   void CloseWebSocketJob() {
    212     if (websocket_->socket_) {
    213       websocket_->socket_->DetachDelegate();
    214       WebSocketThrottle::GetInstance()->RemoveFromQueue(websocket_);
    215     }
    216     websocket_->state_ = WebSocketJob::CLOSED;
    217     websocket_->delegate_ = NULL;
    218     websocket_->socket_ = NULL;
    219   }
    220   SocketStream* GetSocket(SocketStreamJob* job) {
    221     return job->socket_.get();
    222   }
    223 
    224   scoped_refptr<MockCookieStore> cookie_store_;
    225   scoped_ptr<MockCookiePolicy> cookie_policy_;
    226   scoped_refptr<MockURLRequestContext> context_;
    227   scoped_refptr<WebSocketJob> websocket_;
    228   scoped_refptr<MockSocketStream> socket_;
    229 };
    230 
    231 TEST_F(WebSocketJobTest, SimpleHandshake) {
    232   GURL url("ws://example.com/demo");
    233   MockSocketStreamDelegate delegate;
    234   InitWebSocketJob(url, &delegate);
    235 
    236   static const char* kHandshakeRequestMessage =
    237       "GET /demo HTTP/1.1\r\n"
    238       "Host: example.com\r\n"
    239       "Connection: Upgrade\r\n"
    240       "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
    241       "Sec-WebSocket-Protocol: sample\r\n"
    242       "Upgrade: WebSocket\r\n"
    243       "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
    244       "Origin: http://example.com\r\n"
    245       "\r\n"
    246       "^n:ds[4U";
    247 
    248   bool sent = websocket_->SendData(kHandshakeRequestMessage,
    249                                    strlen(kHandshakeRequestMessage));
    250   EXPECT_TRUE(sent);
    251   MessageLoop::current()->RunAllPending();
    252   EXPECT_EQ(kHandshakeRequestMessage, socket_->sent_data());
    253   EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
    254   websocket_->OnSentData(socket_.get(), strlen(kHandshakeRequestMessage));
    255   EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());
    256 
    257   const char kHandshakeResponseMessage[] =
    258       "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    259       "Upgrade: WebSocket\r\n"
    260       "Connection: Upgrade\r\n"
    261       "Sec-WebSocket-Origin: http://example.com\r\n"
    262       "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    263       "Sec-WebSocket-Protocol: sample\r\n"
    264       "\r\n"
    265       "8jKS'y:G*Co,Wxa-";
    266 
    267   websocket_->OnReceivedData(socket_.get(),
    268                              kHandshakeResponseMessage,
    269                              strlen(kHandshakeResponseMessage));
    270   MessageLoop::current()->RunAllPending();
    271   EXPECT_EQ(kHandshakeResponseMessage, delegate.received_data());
    272   EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());
    273   CloseWebSocketJob();
    274 }
    275 
    276 TEST_F(WebSocketJobTest, SlowHandshake) {
    277   GURL url("ws://example.com/demo");
    278   MockSocketStreamDelegate delegate;
    279   InitWebSocketJob(url, &delegate);
    280 
    281   static const char* kHandshakeRequestMessage =
    282       "GET /demo HTTP/1.1\r\n"
    283       "Host: example.com\r\n"
    284       "Connection: Upgrade\r\n"
    285       "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
    286       "Sec-WebSocket-Protocol: sample\r\n"
    287       "Upgrade: WebSocket\r\n"
    288       "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
    289       "Origin: http://example.com\r\n"
    290       "\r\n"
    291       "^n:ds[4U";
    292 
    293   bool sent = websocket_->SendData(kHandshakeRequestMessage,
    294                                    strlen(kHandshakeRequestMessage));
    295   EXPECT_TRUE(sent);
    296   // We assume request is sent in one data chunk (from WebKit)
    297   // We don't support streaming request.
    298   MessageLoop::current()->RunAllPending();
    299   EXPECT_EQ(kHandshakeRequestMessage, socket_->sent_data());
    300   EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
    301   websocket_->OnSentData(socket_.get(), strlen(kHandshakeRequestMessage));
    302   EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());
    303 
    304   const char kHandshakeResponseMessage[] =
    305       "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    306       "Upgrade: WebSocket\r\n"
    307       "Connection: Upgrade\r\n"
    308       "Sec-WebSocket-Origin: http://example.com\r\n"
    309       "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    310       "Sec-WebSocket-Protocol: sample\r\n"
    311       "\r\n"
    312       "8jKS'y:G*Co,Wxa-";
    313 
    314   std::vector<std::string> lines;
    315   base::SplitString(kHandshakeResponseMessage, '\n', &lines);
    316   for (size_t i = 0; i < lines.size() - 2; i++) {
    317     std::string line = lines[i] + "\r\n";
    318     SCOPED_TRACE("Line: " + line);
    319     websocket_->OnReceivedData(socket_,
    320                                line.c_str(),
    321                                line.size());
    322     MessageLoop::current()->RunAllPending();
    323     EXPECT_TRUE(delegate.received_data().empty());
    324     EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
    325   }
    326   websocket_->OnReceivedData(socket_.get(), "\r\n", 2);
    327   MessageLoop::current()->RunAllPending();
    328   EXPECT_TRUE(delegate.received_data().empty());
    329   EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
    330   websocket_->OnReceivedData(socket_.get(), "8jKS'y:G*Co,Wxa-", 16);
    331   EXPECT_EQ(kHandshakeResponseMessage, delegate.received_data());
    332   EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());
    333   CloseWebSocketJob();
    334 }
    335 
    336 TEST_F(WebSocketJobTest, HandshakeWithCookie) {
    337   GURL url("ws://example.com/demo");
    338   GURL cookieUrl("http://example.com/demo");
    339   CookieOptions cookie_options;
    340   cookie_store_->SetCookieWithOptions(
    341       cookieUrl, "CR-test=1", cookie_options);
    342   cookie_options.set_include_httponly();
    343   cookie_store_->SetCookieWithOptions(
    344       cookieUrl, "CR-test-httponly=1", cookie_options);
    345 
    346   MockSocketStreamDelegate delegate;
    347   InitWebSocketJob(url, &delegate);
    348 
    349   static const char* kHandshakeRequestMessage =
    350       "GET /demo HTTP/1.1\r\n"
    351       "Host: example.com\r\n"
    352       "Connection: Upgrade\r\n"
    353       "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
    354       "Sec-WebSocket-Protocol: sample\r\n"
    355       "Upgrade: WebSocket\r\n"
    356       "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
    357       "Origin: http://example.com\r\n"
    358       "Cookie: WK-test=1\r\n"
    359       "\r\n"
    360       "^n:ds[4U";
    361 
    362   static const char* kHandshakeRequestExpected =
    363       "GET /demo HTTP/1.1\r\n"
    364       "Host: example.com\r\n"
    365       "Connection: Upgrade\r\n"
    366       "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
    367       "Sec-WebSocket-Protocol: sample\r\n"
    368       "Upgrade: WebSocket\r\n"
    369       "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
    370       "Origin: http://example.com\r\n"
    371       "Cookie: CR-test=1; CR-test-httponly=1\r\n"
    372       "\r\n"
    373       "^n:ds[4U";
    374 
    375   bool sent = websocket_->SendData(kHandshakeRequestMessage,
    376                                    strlen(kHandshakeRequestMessage));
    377   EXPECT_TRUE(sent);
    378   MessageLoop::current()->RunAllPending();
    379   EXPECT_EQ(kHandshakeRequestExpected, socket_->sent_data());
    380   EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
    381   websocket_->OnSentData(socket_, strlen(kHandshakeRequestExpected));
    382   EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());
    383 
    384   const char kHandshakeResponseMessage[] =
    385       "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    386       "Upgrade: WebSocket\r\n"
    387       "Connection: Upgrade\r\n"
    388       "Sec-WebSocket-Origin: http://example.com\r\n"
    389       "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    390       "Sec-WebSocket-Protocol: sample\r\n"
    391       "Set-Cookie: CR-set-test=1\r\n"
    392       "\r\n"
    393       "8jKS'y:G*Co,Wxa-";
    394 
    395   static const char* kHandshakeResponseExpected =
    396       "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    397       "Upgrade: WebSocket\r\n"
    398       "Connection: Upgrade\r\n"
    399       "Sec-WebSocket-Origin: http://example.com\r\n"
    400       "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    401       "Sec-WebSocket-Protocol: sample\r\n"
    402       "\r\n"
    403       "8jKS'y:G*Co,Wxa-";
    404 
    405   websocket_->OnReceivedData(socket_.get(),
    406                              kHandshakeResponseMessage,
    407                              strlen(kHandshakeResponseMessage));
    408   MessageLoop::current()->RunAllPending();
    409   EXPECT_EQ(kHandshakeResponseExpected, delegate.received_data());
    410   EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());
    411 
    412   EXPECT_EQ(3U, cookie_store_->entries().size());
    413   EXPECT_EQ(cookieUrl, cookie_store_->entries()[0].url);
    414   EXPECT_EQ("CR-test=1", cookie_store_->entries()[0].cookie_line);
    415   EXPECT_EQ(cookieUrl, cookie_store_->entries()[1].url);
    416   EXPECT_EQ("CR-test-httponly=1", cookie_store_->entries()[1].cookie_line);
    417   EXPECT_EQ(cookieUrl, cookie_store_->entries()[2].url);
    418   EXPECT_EQ("CR-set-test=1", cookie_store_->entries()[2].cookie_line);
    419 
    420   CloseWebSocketJob();
    421 }
    422 
    423 TEST_F(WebSocketJobTest, HandshakeWithCookieButNotAllowed) {
    424   GURL url("ws://example.com/demo");
    425   GURL cookieUrl("http://example.com/demo");
    426   CookieOptions cookie_options;
    427   cookie_store_->SetCookieWithOptions(
    428       cookieUrl, "CR-test=1", cookie_options);
    429   cookie_options.set_include_httponly();
    430   cookie_store_->SetCookieWithOptions(
    431       cookieUrl, "CR-test-httponly=1", cookie_options);
    432   cookie_policy_->set_allow_all_cookies(false);
    433 
    434   MockSocketStreamDelegate delegate;
    435   InitWebSocketJob(url, &delegate);
    436 
    437   static const char* kHandshakeRequestMessage =
    438       "GET /demo HTTP/1.1\r\n"
    439       "Host: example.com\r\n"
    440       "Connection: Upgrade\r\n"
    441       "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
    442       "Sec-WebSocket-Protocol: sample\r\n"
    443       "Upgrade: WebSocket\r\n"
    444       "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
    445       "Origin: http://example.com\r\n"
    446       "Cookie: WK-test=1\r\n"
    447       "\r\n"
    448       "^n:ds[4U";
    449 
    450   static const char* kHandshakeRequestExpected =
    451       "GET /demo HTTP/1.1\r\n"
    452       "Host: example.com\r\n"
    453       "Connection: Upgrade\r\n"
    454       "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
    455       "Sec-WebSocket-Protocol: sample\r\n"
    456       "Upgrade: WebSocket\r\n"
    457       "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
    458       "Origin: http://example.com\r\n"
    459       "\r\n"
    460       "^n:ds[4U";
    461 
    462   bool sent = websocket_->SendData(kHandshakeRequestMessage,
    463                                    strlen(kHandshakeRequestMessage));
    464   EXPECT_TRUE(sent);
    465   MessageLoop::current()->RunAllPending();
    466   EXPECT_EQ(kHandshakeRequestExpected, socket_->sent_data());
    467   EXPECT_EQ(WebSocketJob::CONNECTING, GetWebSocketJobState());
    468   websocket_->OnSentData(socket_, strlen(kHandshakeRequestExpected));
    469   EXPECT_EQ(strlen(kHandshakeRequestMessage), delegate.amount_sent());
    470 
    471   const char kHandshakeResponseMessage[] =
    472       "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    473       "Upgrade: WebSocket\r\n"
    474       "Connection: Upgrade\r\n"
    475       "Sec-WebSocket-Origin: http://example.com\r\n"
    476       "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    477       "Sec-WebSocket-Protocol: sample\r\n"
    478       "Set-Cookie: CR-set-test=1\r\n"
    479       "\r\n"
    480       "8jKS'y:G*Co,Wxa-";
    481 
    482   static const char* kHandshakeResponseExpected =
    483       "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    484       "Upgrade: WebSocket\r\n"
    485       "Connection: Upgrade\r\n"
    486       "Sec-WebSocket-Origin: http://example.com\r\n"
    487       "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    488       "Sec-WebSocket-Protocol: sample\r\n"
    489       "\r\n"
    490       "8jKS'y:G*Co,Wxa-";
    491 
    492   websocket_->OnReceivedData(socket_.get(),
    493                              kHandshakeResponseMessage,
    494                              strlen(kHandshakeResponseMessage));
    495   MessageLoop::current()->RunAllPending();
    496   EXPECT_EQ(kHandshakeResponseExpected, delegate.received_data());
    497   EXPECT_EQ(WebSocketJob::OPEN, GetWebSocketJobState());
    498 
    499   EXPECT_EQ(2U, cookie_store_->entries().size());
    500   EXPECT_EQ(cookieUrl, cookie_store_->entries()[0].url);
    501   EXPECT_EQ("CR-test=1", cookie_store_->entries()[0].cookie_line);
    502   EXPECT_EQ(cookieUrl, cookie_store_->entries()[1].url);
    503   EXPECT_EQ("CR-test-httponly=1", cookie_store_->entries()[1].cookie_line);
    504 
    505   CloseWebSocketJob();
    506 }
    507 
    508 TEST_F(WebSocketJobTest, HSTSUpgrade) {
    509   GURL url("ws://upgrademe.com/");
    510   MockSocketStreamDelegate delegate;
    511   scoped_refptr<SocketStreamJob> job = SocketStreamJob::CreateSocketStreamJob(
    512       url, &delegate, *context_.get());
    513   EXPECT_TRUE(GetSocket(job.get())->is_secure());
    514   job->DetachDelegate();
    515 
    516   url = GURL("ws://donotupgrademe.com/");
    517   job = SocketStreamJob::CreateSocketStreamJob(
    518       url, &delegate, *context_.get());
    519   EXPECT_FALSE(GetSocket(job.get())->is_secure());
    520   job->DetachDelegate();
    521 }
    522 
    523 }  // namespace net
    524