Home | History | Annotate | Download | only in socket_stream
      1 // Copyright (c) 2010 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <string>
      6 #include <vector>
      7 
      8 #include "base/callback.h"
      9 #include "base/utf_string_conversions.h"
     10 #include "net/base/auth.h"
     11 #include "net/base/mock_host_resolver.h"
     12 #include "net/base/net_log.h"
     13 #include "net/base/net_log_unittest.h"
     14 #include "net/base/test_completion_callback.h"
     15 #include "net/socket/socket_test_util.h"
     16 #include "net/socket_stream/socket_stream.h"
     17 #include "net/url_request/url_request_test_util.h"
     18 #include "testing/gtest/include/gtest/gtest.h"
     19 #include "testing/platform_test.h"
     20 
     21 struct SocketStreamEvent {
     22   enum EventType {
     23     EVENT_CONNECTED, EVENT_SENT_DATA, EVENT_RECEIVED_DATA, EVENT_CLOSE,
     24     EVENT_AUTH_REQUIRED,
     25   };
     26 
     27   SocketStreamEvent(EventType type, net::SocketStream* socket_stream,
     28                     int num, const std::string& str,
     29                     net::AuthChallengeInfo* auth_challenge_info)
     30       : event_type(type), socket(socket_stream), number(num), data(str),
     31         auth_info(auth_challenge_info) {}
     32 
     33   EventType event_type;
     34   net::SocketStream* socket;
     35   int number;
     36   std::string data;
     37   scoped_refptr<net::AuthChallengeInfo> auth_info;
     38 };
     39 
     40 class SocketStreamEventRecorder : public net::SocketStream::Delegate {
     41  public:
     42   explicit SocketStreamEventRecorder(net::CompletionCallback* callback)
     43       : on_connected_(NULL),
     44         on_sent_data_(NULL),
     45         on_received_data_(NULL),
     46         on_close_(NULL),
     47         on_auth_required_(NULL),
     48         callback_(callback) {}
     49   virtual ~SocketStreamEventRecorder() {
     50     delete on_connected_;
     51     delete on_sent_data_;
     52     delete on_received_data_;
     53     delete on_close_;
     54     delete on_auth_required_;
     55   }
     56 
     57   void SetOnConnected(Callback1<SocketStreamEvent*>::Type* callback) {
     58     on_connected_ = callback;
     59   }
     60   void SetOnSentData(Callback1<SocketStreamEvent*>::Type* callback) {
     61     on_sent_data_ = callback;
     62   }
     63   void SetOnReceivedData(Callback1<SocketStreamEvent*>::Type* callback) {
     64     on_received_data_ = callback;
     65   }
     66   void SetOnClose(Callback1<SocketStreamEvent*>::Type* callback) {
     67     on_close_ = callback;
     68   }
     69   void SetOnAuthRequired(Callback1<SocketStreamEvent*>::Type* callback) {
     70     on_auth_required_ = callback;
     71   }
     72 
     73   virtual void OnConnected(net::SocketStream* socket,
     74                            int num_pending_send_allowed) {
     75     events_.push_back(
     76         SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED,
     77                           socket, num_pending_send_allowed, std::string(),
     78                           NULL));
     79     if (on_connected_)
     80       on_connected_->Run(&events_.back());
     81   }
     82   virtual void OnSentData(net::SocketStream* socket,
     83                           int amount_sent) {
     84     events_.push_back(
     85         SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA,
     86                           socket, amount_sent, std::string(), NULL));
     87     if (on_sent_data_)
     88       on_sent_data_->Run(&events_.back());
     89   }
     90   virtual void OnReceivedData(net::SocketStream* socket,
     91                               const char* data, int len) {
     92     events_.push_back(
     93         SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA,
     94                           socket, len, std::string(data, len), NULL));
     95     if (on_received_data_)
     96       on_received_data_->Run(&events_.back());
     97   }
     98   virtual void OnClose(net::SocketStream* socket) {
     99     events_.push_back(
    100         SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE,
    101                           socket, 0, std::string(), NULL));
    102     if (on_close_)
    103       on_close_->Run(&events_.back());
    104     if (callback_)
    105       callback_->Run(net::OK);
    106   }
    107   virtual void OnAuthRequired(net::SocketStream* socket,
    108                               net::AuthChallengeInfo* auth_info) {
    109     events_.push_back(
    110         SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED,
    111                           socket, 0, std::string(), auth_info));
    112     if (on_auth_required_)
    113       on_auth_required_->Run(&events_.back());
    114   }
    115 
    116   void DoClose(SocketStreamEvent* event) {
    117     event->socket->Close();
    118   }
    119   void DoRestartWithAuth(SocketStreamEvent* event) {
    120     VLOG(1) << "RestartWithAuth username=" << username_
    121             << " password=" << password_;
    122     event->socket->RestartWithAuth(username_, password_);
    123   }
    124   void SetAuthInfo(const string16& username,
    125                    const string16& password) {
    126     username_ = username;
    127     password_ = password;
    128   }
    129 
    130   const std::vector<SocketStreamEvent>& GetSeenEvents() const {
    131     return events_;
    132   }
    133 
    134  private:
    135   std::vector<SocketStreamEvent> events_;
    136   Callback1<SocketStreamEvent*>::Type* on_connected_;
    137   Callback1<SocketStreamEvent*>::Type* on_sent_data_;
    138   Callback1<SocketStreamEvent*>::Type* on_received_data_;
    139   Callback1<SocketStreamEvent*>::Type* on_close_;
    140   Callback1<SocketStreamEvent*>::Type* on_auth_required_;
    141   net::CompletionCallback* callback_;
    142 
    143   string16 username_;
    144   string16 password_;
    145 
    146   DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder);
    147 };
    148 
    149 namespace net {
    150 
    151 class SocketStreamTest : public PlatformTest {
    152  public:
    153   virtual ~SocketStreamTest() {}
    154   virtual void SetUp() {
    155     mock_socket_factory_.reset();
    156     handshake_request_ = kWebSocketHandshakeRequest;
    157     handshake_response_ = kWebSocketHandshakeResponse;
    158   }
    159   virtual void TearDown() {
    160     mock_socket_factory_.reset();
    161   }
    162 
    163   virtual void SetWebSocketHandshakeMessage(
    164       const char* request, const char* response) {
    165     handshake_request_ = request;
    166     handshake_response_ = response;
    167   }
    168   virtual void AddWebSocketMessage(const std::string& message) {
    169     messages_.push_back(message);
    170   }
    171 
    172   virtual MockClientSocketFactory* GetMockClientSocketFactory() {
    173     mock_socket_factory_.reset(new MockClientSocketFactory);
    174     return mock_socket_factory_.get();
    175   }
    176 
    177   virtual void DoSendWebSocketHandshake(SocketStreamEvent* event) {
    178     event->socket->SendData(
    179         handshake_request_.data(), handshake_request_.size());
    180   }
    181 
    182   virtual void DoCloseFlushPendingWriteTest(SocketStreamEvent* event) {
    183     // handshake response received.
    184     for (size_t i = 0; i < messages_.size(); i++) {
    185       std::vector<char> frame;
    186       frame.push_back('\0');
    187       frame.insert(frame.end(), messages_[i].begin(), messages_[i].end());
    188       frame.push_back('\xff');
    189       EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size()));
    190     }
    191     // Actual ClientSocket close must happen after all frames queued by
    192     // SendData above are sent out.
    193     event->socket->Close();
    194   }
    195 
    196   static const char* kWebSocketHandshakeRequest;
    197   static const char* kWebSocketHandshakeResponse;
    198 
    199  private:
    200   std::string handshake_request_;
    201   std::string handshake_response_;
    202   std::vector<std::string> messages_;
    203 
    204   scoped_ptr<MockClientSocketFactory> mock_socket_factory_;
    205 };
    206 
// Client half of the WebSocket opening handshake used by these tests.
// SetUp() copies it into the fixture and DoSendWebSocketHandshake() sends it.
// CloseFlushPendingWrite's mock expects these exact bytes as the first write.
// The Sec-WebSocket-Key1/Key2 headers and the eight bytes after the blank line
// ("^n:ds[4U") appear to be the sample handshake from
// draft-hixie-thewebsocketprotocol-76.
const char* SocketStreamTest::kWebSocketHandshakeRequest =
    "GET /demo HTTP/1.1\r\n"
    "Host: example.com\r\n"
    "Connection: Upgrade\r\n"
    "Sec-WebSocket-Key2: 12998 5 Y3 1  .P00\r\n"
    "Sec-WebSocket-Protocol: sample\r\n"
    "Upgrade: WebSocket\r\n"
    "Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5\r\n"
    "Origin: http://example.com\r\n"
    "\r\n"
    "^n:ds[4U";
    218 
// Server half of the handshake, which CloseFlushPendingWrite's mock socket
// returns as its first read. The 16 bytes after the blank line appear to be
// the response key from the same draft-hixie-76 sample as the request above.
const char* SocketStreamTest::kWebSocketHandshakeResponse =
    "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
    "Upgrade: WebSocket\r\n"
    "Connection: Upgrade\r\n"
    "Sec-WebSocket-Origin: http://example.com\r\n"
    "Sec-WebSocket-Location: ws://example.com/demo\r\n"
    "Sec-WebSocket-Protocol: sample\r\n"
    "\r\n"
    "8jKS'y:G*Co,Wxa-";
    228 
    229 TEST_F(SocketStreamTest, CloseFlushPendingWrite) {
    230   TestCompletionCallback callback;
    231 
    232   scoped_ptr<SocketStreamEventRecorder> delegate(
    233       new SocketStreamEventRecorder(&callback));
    234   // Necessary for NewCallback.
    235   SocketStreamTest* test = this;
    236   delegate->SetOnConnected(NewCallback(
    237       test, &SocketStreamTest::DoSendWebSocketHandshake));
    238   delegate->SetOnReceivedData(NewCallback(
    239       test, &SocketStreamTest::DoCloseFlushPendingWriteTest));
    240 
    241   MockHostResolver host_resolver;
    242 
    243   scoped_refptr<SocketStream> socket_stream(
    244       new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
    245 
    246   socket_stream->set_context(new TestURLRequestContext());
    247   socket_stream->SetHostResolver(&host_resolver);
    248 
    249   MockWrite data_writes[] = {
    250     MockWrite(SocketStreamTest::kWebSocketHandshakeRequest),
    251     MockWrite(true, "\0message1\xff", 10),
    252     MockWrite(true, "\0message2\xff", 10)
    253   };
    254   MockRead data_reads[] = {
    255     MockRead(SocketStreamTest::kWebSocketHandshakeResponse),
    256     // Server doesn't close the connection after handshake.
    257     MockRead(true, ERR_IO_PENDING)
    258   };
    259   AddWebSocketMessage("message1");
    260   AddWebSocketMessage("message2");
    261 
    262   scoped_refptr<DelayedSocketData> data_provider(
    263       new DelayedSocketData(1,
    264                             data_reads, arraysize(data_reads),
    265                             data_writes, arraysize(data_writes)));
    266 
    267   MockClientSocketFactory* mock_socket_factory =
    268       GetMockClientSocketFactory();
    269   mock_socket_factory->AddSocketDataProvider(data_provider.get());
    270 
    271   socket_stream->SetClientSocketFactory(mock_socket_factory);
    272 
    273   socket_stream->Connect();
    274 
    275   callback.WaitForResult();
    276 
    277   const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
    278   EXPECT_EQ(6U, events.size());
    279 
    280   EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[0].event_type);
    281   EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[1].event_type);
    282   EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[2].event_type);
    283   EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[3].event_type);
    284   EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type);
    285   EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[5].event_type);
    286 }
    287 
    288 TEST_F(SocketStreamTest, BasicAuthProxy) {
    289   MockClientSocketFactory mock_socket_factory;
    290   MockWrite data_writes1[] = {
    291     MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
    292               "Host: example.com\r\n"
    293               "Proxy-Connection: keep-alive\r\n\r\n"),
    294   };
    295   MockRead data_reads1[] = {
    296     MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"),
    297     MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"),
    298     MockRead("\r\n"),
    299   };
    300   StaticSocketDataProvider data1(data_reads1, arraysize(data_reads1),
    301                                  data_writes1, arraysize(data_writes1));
    302   mock_socket_factory.AddSocketDataProvider(&data1);
    303 
    304   MockWrite data_writes2[] = {
    305     MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
    306               "Host: example.com\r\n"
    307               "Proxy-Connection: keep-alive\r\n"
    308               "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"),
    309   };
    310   MockRead data_reads2[] = {
    311     MockRead("HTTP/1.1 200 Connection Established\r\n"),
    312     MockRead("Proxy-agent: Apache/2.2.8\r\n"),
    313     MockRead("\r\n"),
    314     // SocketStream::DoClose is run asynchronously.  Socket can be read after
    315     // "\r\n".  We have to give ERR_IO_PENDING to SocketStream then to indicate
    316     // server doesn't close the connection.
    317     MockRead(true, ERR_IO_PENDING)
    318   };
    319   StaticSocketDataProvider data2(data_reads2, arraysize(data_reads2),
    320                                  data_writes2, arraysize(data_writes2));
    321   mock_socket_factory.AddSocketDataProvider(&data2);
    322 
    323   TestCompletionCallback callback;
    324 
    325   scoped_ptr<SocketStreamEventRecorder> delegate(
    326       new SocketStreamEventRecorder(&callback));
    327   delegate->SetOnConnected(NewCallback(delegate.get(),
    328                                        &SocketStreamEventRecorder::DoClose));
    329   delegate->SetAuthInfo(ASCIIToUTF16("foo"), ASCIIToUTF16("bar"));
    330   delegate->SetOnAuthRequired(
    331       NewCallback(delegate.get(),
    332                   &SocketStreamEventRecorder::DoRestartWithAuth));
    333 
    334   scoped_refptr<SocketStream> socket_stream(
    335       new SocketStream(GURL("ws://example.com/demo"), delegate.get()));
    336 
    337   socket_stream->set_context(new TestURLRequestContext("myproxy:70"));
    338   MockHostResolver host_resolver;
    339   socket_stream->SetHostResolver(&host_resolver);
    340   socket_stream->SetClientSocketFactory(&mock_socket_factory);
    341 
    342   socket_stream->Connect();
    343 
    344   callback.WaitForResult();
    345 
    346   const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
    347   EXPECT_EQ(3U, events.size());
    348 
    349   EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[0].event_type);
    350   EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
    351   EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type);
    352 
    353   // TODO(eroman): Add back NetLogTest here...
    354 }
    355 
    356 }  // namespace net
    357