1 // Copyright (c) 2010 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <string> 6 #include <vector> 7 8 #include "base/callback.h" 9 #include "base/utf_string_conversions.h" 10 #include "net/base/auth.h" 11 #include "net/base/mock_host_resolver.h" 12 #include "net/base/net_log.h" 13 #include "net/base/net_log_unittest.h" 14 #include "net/base/test_completion_callback.h" 15 #include "net/socket/socket_test_util.h" 16 #include "net/socket_stream/socket_stream.h" 17 #include "net/url_request/url_request_test_util.h" 18 #include "testing/gtest/include/gtest/gtest.h" 19 #include "testing/platform_test.h" 20 21 struct SocketStreamEvent { 22 enum EventType { 23 EVENT_CONNECTED, EVENT_SENT_DATA, EVENT_RECEIVED_DATA, EVENT_CLOSE, 24 EVENT_AUTH_REQUIRED, 25 }; 26 27 SocketStreamEvent(EventType type, net::SocketStream* socket_stream, 28 int num, const std::string& str, 29 net::AuthChallengeInfo* auth_challenge_info) 30 : event_type(type), socket(socket_stream), number(num), data(str), 31 auth_info(auth_challenge_info) {} 32 33 EventType event_type; 34 net::SocketStream* socket; 35 int number; 36 std::string data; 37 scoped_refptr<net::AuthChallengeInfo> auth_info; 38 }; 39 40 class SocketStreamEventRecorder : public net::SocketStream::Delegate { 41 public: 42 explicit SocketStreamEventRecorder(net::CompletionCallback* callback) 43 : on_connected_(NULL), 44 on_sent_data_(NULL), 45 on_received_data_(NULL), 46 on_close_(NULL), 47 on_auth_required_(NULL), 48 callback_(callback) {} 49 virtual ~SocketStreamEventRecorder() { 50 delete on_connected_; 51 delete on_sent_data_; 52 delete on_received_data_; 53 delete on_close_; 54 delete on_auth_required_; 55 } 56 57 void SetOnConnected(Callback1<SocketStreamEvent*>::Type* callback) { 58 on_connected_ = callback; 59 } 60 void SetOnSentData(Callback1<SocketStreamEvent*>::Type* callback) { 61 on_sent_data_ = callback; 62 } 63 void SetOnReceivedData(Callback1<SocketStreamEvent*>::Type* callback) { 64 on_received_data_ = callback; 65 } 66 void SetOnClose(Callback1<SocketStreamEvent*>::Type* callback) { 67 on_close_ = callback; 68 } 69 void SetOnAuthRequired(Callback1<SocketStreamEvent*>::Type* callback) { 70 on_auth_required_ = callback; 71 } 72 73 virtual void OnConnected(net::SocketStream* socket, 74 int num_pending_send_allowed) { 75 events_.push_back( 76 SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED, 77 socket, num_pending_send_allowed, std::string(), 78 NULL)); 79 if (on_connected_) 80 on_connected_->Run(&events_.back()); 81 } 82 virtual void OnSentData(net::SocketStream* socket, 83 int amount_sent) { 84 events_.push_back( 85 SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA, 86 socket, amount_sent, std::string(), NULL)); 87 if (on_sent_data_) 88 on_sent_data_->Run(&events_.back()); 89 } 90 virtual void OnReceivedData(net::SocketStream* socket, 91 const char* data, int len) { 92 events_.push_back( 93 SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA, 94 socket, len, std::string(data, len), NULL)); 95 if (on_received_data_) 96 on_received_data_->Run(&events_.back()); 97 } 98 virtual void OnClose(net::SocketStream* socket) { 99 events_.push_back( 100 SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE, 101 socket, 0, std::string(), NULL)); 102 if (on_close_) 103 on_close_->Run(&events_.back()); 104 if (callback_) 105 callback_->Run(net::OK); 106 } 107 virtual void OnAuthRequired(net::SocketStream* socket, 108 net::AuthChallengeInfo* auth_info) { 109 events_.push_back( 110 SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED, 111 socket, 0, std::string(), auth_info)); 112 if (on_auth_required_) 113 on_auth_required_->Run(&events_.back()); 114 } 115 116 void DoClose(SocketStreamEvent* event) { 117 event->socket->Close(); 118 } 119 void DoRestartWithAuth(SocketStreamEvent* event) { 120 VLOG(1) << "RestartWithAuth username=" << username_ 121 << " password=" << password_; 122 event->socket->RestartWithAuth(username_, password_); 123 } 124 void SetAuthInfo(const string16& username, 125 const string16& password) { 126 username_ = username; 127 password_ = password; 128 } 129 130 const std::vector<SocketStreamEvent>& GetSeenEvents() const { 131 return events_; 132 } 133 134 private: 135 std::vector<SocketStreamEvent> events_; 136 Callback1<SocketStreamEvent*>::Type* on_connected_; 137 Callback1<SocketStreamEvent*>::Type* on_sent_data_; 138 Callback1<SocketStreamEvent*>::Type* on_received_data_; 139 Callback1<SocketStreamEvent*>::Type* on_close_; 140 Callback1<SocketStreamEvent*>::Type* on_auth_required_; 141 net::CompletionCallback* callback_; 142 143 string16 username_; 144 string16 password_; 145 146 DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder); 147 }; 148 149 namespace net { 150 151 class SocketStreamTest : public PlatformTest { 152 public: 153 virtual ~SocketStreamTest() {} 154 virtual void SetUp() { 155 mock_socket_factory_.reset(); 156 handshake_request_ = kWebSocketHandshakeRequest; 157 handshake_response_ = kWebSocketHandshakeResponse; 158 } 159 virtual void TearDown() { 160 mock_socket_factory_.reset(); 161 } 162 163 virtual void SetWebSocketHandshakeMessage( 164 const char* request, const char* response) { 165 handshake_request_ = request; 166 handshake_response_ = response; 167 } 168 virtual void AddWebSocketMessage(const std::string& message) { 169 messages_.push_back(message); 170 } 171 172 virtual MockClientSocketFactory* GetMockClientSocketFactory() { 173 mock_socket_factory_.reset(new MockClientSocketFactory); 174 return mock_socket_factory_.get(); 175 } 176 177 virtual void DoSendWebSocketHandshake(SocketStreamEvent* event) { 178 event->socket->SendData( 179 handshake_request_.data(), handshake_request_.size()); 180 } 181 182 virtual void DoCloseFlushPendingWriteTest(SocketStreamEvent* event) { 183 // handshake response received. 184 for (size_t i = 0; i < messages_.size(); i++) { 185 std::vector<char> frame; 186 frame.push_back('\0'); 187 frame.insert(frame.end(), messages_[i].begin(), messages_[i].end()); 188 frame.push_back('\xff'); 189 EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size())); 190 } 191 // Actual ClientSocket close must happen after all frames queued by 192 // SendData above are sent out. 193 event->socket->Close(); 194 } 195 196 static const char* kWebSocketHandshakeRequest; 197 static const char* kWebSocketHandshakeResponse; 198 199 private: 200 std::string handshake_request_; 201 std::string handshake_response_; 202 std::vector<std::string> messages_; 203 204 scoped_ptr<MockClientSocketFactory> mock_socket_factory_; 205 }; 206 207 const char* SocketStreamTest::kWebSocketHandshakeRequest = 208 "GET /demo HTTP/1.1\r\n" 209 "Host: example.com\r\n" 210 "Connection: Upgrade\r\n" 211 "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" 212 "Sec-WebSocket-Protocol: sample\r\n" 213 "Upgrade: WebSocket\r\n" 214 "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" 215 "Origin: http://example.com\r\n" 216 "\r\n" 217 "^n:ds[4U"; 218 219 const char* SocketStreamTest::kWebSocketHandshakeResponse = 220 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" 221 "Upgrade: WebSocket\r\n" 222 "Connection: Upgrade\r\n" 223 "Sec-WebSocket-Origin: http://example.com\r\n" 224 "Sec-WebSocket-Location: ws://example.com/demo\r\n" 225 "Sec-WebSocket-Protocol: sample\r\n" 226 "\r\n" 227 "8jKS'y:G*Co,Wxa-"; 228 229 TEST_F(SocketStreamTest, CloseFlushPendingWrite) { 230 TestCompletionCallback callback; 231 232 scoped_ptr<SocketStreamEventRecorder> delegate( 233 new SocketStreamEventRecorder(&callback)); 234 // Necessary for NewCallback. 235 SocketStreamTest* test = this; 236 delegate->SetOnConnected(NewCallback( 237 test, &SocketStreamTest::DoSendWebSocketHandshake)); 238 delegate->SetOnReceivedData(NewCallback( 239 test, &SocketStreamTest::DoCloseFlushPendingWriteTest)); 240 241 MockHostResolver host_resolver; 242 243 scoped_refptr<SocketStream> socket_stream( 244 new SocketStream(GURL("ws://example.com/demo"), delegate.get())); 245 246 socket_stream->set_context(new TestURLRequestContext()); 247 socket_stream->SetHostResolver(&host_resolver); 248 249 MockWrite data_writes[] = { 250 MockWrite(SocketStreamTest::kWebSocketHandshakeRequest), 251 MockWrite(true, "\0message1\xff", 10), 252 MockWrite(true, "\0message2\xff", 10) 253 }; 254 MockRead data_reads[] = { 255 MockRead(SocketStreamTest::kWebSocketHandshakeResponse), 256 // Server doesn't close the connection after handshake. 257 MockRead(true, ERR_IO_PENDING) 258 }; 259 AddWebSocketMessage("message1"); 260 AddWebSocketMessage("message2"); 261 262 scoped_refptr<DelayedSocketData> data_provider( 263 new DelayedSocketData(1, 264 data_reads, arraysize(data_reads), 265 data_writes, arraysize(data_writes))); 266 267 MockClientSocketFactory* mock_socket_factory = 268 GetMockClientSocketFactory(); 269 mock_socket_factory->AddSocketDataProvider(data_provider.get()); 270 271 socket_stream->SetClientSocketFactory(mock_socket_factory); 272 273 socket_stream->Connect(); 274 275 callback.WaitForResult(); 276 277 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 278 EXPECT_EQ(6U, events.size()); 279 280 EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[0].event_type); 281 EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[1].event_type); 282 EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[2].event_type); 283 EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[3].event_type); 284 EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type); 285 EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[5].event_type); 286 } 287 288 TEST_F(SocketStreamTest, BasicAuthProxy) { 289 MockClientSocketFactory mock_socket_factory; 290 MockWrite data_writes1[] = { 291 MockWrite("CONNECT example.com:80 HTTP/1.1\r\n" 292 "Host: example.com\r\n" 293 "Proxy-Connection: keep-alive\r\n\r\n"), 294 }; 295 MockRead data_reads1[] = { 296 MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"), 297 MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"), 298 MockRead("\r\n"), 299 }; 300 StaticSocketDataProvider data1(data_reads1, arraysize(data_reads1), 301 data_writes1, arraysize(data_writes1)); 302 mock_socket_factory.AddSocketDataProvider(&data1); 303 304 MockWrite data_writes2[] = { 305 MockWrite("CONNECT example.com:80 HTTP/1.1\r\n" 306 "Host: example.com\r\n" 307 "Proxy-Connection: keep-alive\r\n" 308 "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), 309 }; 310 MockRead data_reads2[] = { 311 MockRead("HTTP/1.1 200 Connection Established\r\n"), 312 MockRead("Proxy-agent: Apache/2.2.8\r\n"), 313 MockRead("\r\n"), 314 // SocketStream::DoClose is run asynchronously. Socket can be read after 315 // "\r\n". We have to give ERR_IO_PENDING to SocketStream then to indicate 316 // server doesn't close the connection. 317 MockRead(true, ERR_IO_PENDING) 318 }; 319 StaticSocketDataProvider data2(data_reads2, arraysize(data_reads2), 320 data_writes2, arraysize(data_writes2)); 321 mock_socket_factory.AddSocketDataProvider(&data2); 322 323 TestCompletionCallback callback; 324 325 scoped_ptr<SocketStreamEventRecorder> delegate( 326 new SocketStreamEventRecorder(&callback)); 327 delegate->SetOnConnected(NewCallback(delegate.get(), 328 &SocketStreamEventRecorder::DoClose)); 329 delegate->SetAuthInfo(ASCIIToUTF16("foo"), ASCIIToUTF16("bar")); 330 delegate->SetOnAuthRequired( 331 NewCallback(delegate.get(), 332 &SocketStreamEventRecorder::DoRestartWithAuth)); 333 334 scoped_refptr<SocketStream> socket_stream( 335 new SocketStream(GURL("ws://example.com/demo"), delegate.get())); 336 337 socket_stream->set_context(new TestURLRequestContext("myproxy:70")); 338 MockHostResolver host_resolver; 339 socket_stream->SetHostResolver(&host_resolver); 340 socket_stream->SetClientSocketFactory(&mock_socket_factory); 341 342 socket_stream->Connect(); 343 344 callback.WaitForResult(); 345 346 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 347 EXPECT_EQ(3U, events.size()); 348 349 EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[0].event_type); 350 EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type); 351 EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type); 352 353 // TODO(eroman): Add back NetLogTest here... 354 } 355 356 } // namespace net 357