1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket_stream/socket_stream.h" 6 7 #include <string> 8 #include <vector> 9 10 #include "base/bind.h" 11 #include "base/bind_helpers.h" 12 #include "base/callback.h" 13 #include "base/strings/utf_string_conversions.h" 14 #include "net/base/auth.h" 15 #include "net/base/net_log.h" 16 #include "net/base/net_log_unittest.h" 17 #include "net/base/test_completion_callback.h" 18 #include "net/dns/mock_host_resolver.h" 19 #include "net/http/http_network_session.h" 20 #include "net/proxy/proxy_service.h" 21 #include "net/socket/socket_test_util.h" 22 #include "net/url_request/url_request_test_util.h" 23 #include "testing/gtest/include/gtest/gtest.h" 24 #include "testing/platform_test.h" 25 26 namespace net { 27 28 namespace { 29 30 struct SocketStreamEvent { 31 enum EventType { 32 EVENT_START_OPEN_CONNECTION, EVENT_CONNECTED, EVENT_SENT_DATA, 33 EVENT_RECEIVED_DATA, EVENT_CLOSE, EVENT_AUTH_REQUIRED, EVENT_ERROR, 34 }; 35 36 SocketStreamEvent(EventType type, 37 SocketStream* socket_stream, 38 int num, 39 const std::string& str, 40 AuthChallengeInfo* auth_challenge_info, 41 int error) 42 : event_type(type), socket(socket_stream), number(num), data(str), 43 auth_info(auth_challenge_info), error_code(error) {} 44 45 EventType event_type; 46 SocketStream* socket; 47 int number; 48 std::string data; 49 scoped_refptr<AuthChallengeInfo> auth_info; 50 int error_code; 51 }; 52 53 class SocketStreamEventRecorder : public SocketStream::Delegate { 54 public: 55 // |callback| will be run when the OnClose() or OnError() method is called. 56 // For OnClose(), |callback| is called with OK. For OnError(), it's called 57 // with the error code. 
58 explicit SocketStreamEventRecorder(const CompletionCallback& callback) 59 : callback_(callback) {} 60 virtual ~SocketStreamEventRecorder() {} 61 62 void SetOnStartOpenConnection( 63 const base::Callback<int(SocketStreamEvent*)>& callback) { 64 on_start_open_connection_ = callback; 65 } 66 void SetOnConnected( 67 const base::Callback<void(SocketStreamEvent*)>& callback) { 68 on_connected_ = callback; 69 } 70 void SetOnSentData( 71 const base::Callback<void(SocketStreamEvent*)>& callback) { 72 on_sent_data_ = callback; 73 } 74 void SetOnReceivedData( 75 const base::Callback<void(SocketStreamEvent*)>& callback) { 76 on_received_data_ = callback; 77 } 78 void SetOnClose(const base::Callback<void(SocketStreamEvent*)>& callback) { 79 on_close_ = callback; 80 } 81 void SetOnAuthRequired( 82 const base::Callback<void(SocketStreamEvent*)>& callback) { 83 on_auth_required_ = callback; 84 } 85 void SetOnError(const base::Callback<void(SocketStreamEvent*)>& callback) { 86 on_error_ = callback; 87 } 88 89 virtual int OnStartOpenConnection( 90 SocketStream* socket, 91 const CompletionCallback& callback) OVERRIDE { 92 connection_callback_ = callback; 93 events_.push_back( 94 SocketStreamEvent(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, 95 socket, 0, std::string(), NULL, OK)); 96 if (!on_start_open_connection_.is_null()) 97 return on_start_open_connection_.Run(&events_.back()); 98 return OK; 99 } 100 virtual void OnConnected(SocketStream* socket, 101 int num_pending_send_allowed) OVERRIDE { 102 events_.push_back( 103 SocketStreamEvent(SocketStreamEvent::EVENT_CONNECTED, 104 socket, num_pending_send_allowed, std::string(), 105 NULL, OK)); 106 if (!on_connected_.is_null()) 107 on_connected_.Run(&events_.back()); 108 } 109 virtual void OnSentData(SocketStream* socket, 110 int amount_sent) OVERRIDE { 111 events_.push_back( 112 SocketStreamEvent(SocketStreamEvent::EVENT_SENT_DATA, socket, 113 amount_sent, std::string(), NULL, OK)); 114 if (!on_sent_data_.is_null()) 115 
on_sent_data_.Run(&events_.back()); 116 } 117 virtual void OnReceivedData(SocketStream* socket, 118 const char* data, int len) OVERRIDE { 119 events_.push_back( 120 SocketStreamEvent(SocketStreamEvent::EVENT_RECEIVED_DATA, socket, len, 121 std::string(data, len), NULL, OK)); 122 if (!on_received_data_.is_null()) 123 on_received_data_.Run(&events_.back()); 124 } 125 virtual void OnClose(SocketStream* socket) OVERRIDE { 126 events_.push_back( 127 SocketStreamEvent(SocketStreamEvent::EVENT_CLOSE, socket, 0, 128 std::string(), NULL, OK)); 129 if (!on_close_.is_null()) 130 on_close_.Run(&events_.back()); 131 if (!callback_.is_null()) 132 callback_.Run(OK); 133 } 134 virtual void OnAuthRequired(SocketStream* socket, 135 AuthChallengeInfo* auth_info) OVERRIDE { 136 events_.push_back( 137 SocketStreamEvent(SocketStreamEvent::EVENT_AUTH_REQUIRED, socket, 0, 138 std::string(), auth_info, OK)); 139 if (!on_auth_required_.is_null()) 140 on_auth_required_.Run(&events_.back()); 141 } 142 virtual void OnError(const SocketStream* socket, int error) OVERRIDE { 143 events_.push_back( 144 SocketStreamEvent(SocketStreamEvent::EVENT_ERROR, NULL, 0, 145 std::string(), NULL, error)); 146 if (!on_error_.is_null()) 147 on_error_.Run(&events_.back()); 148 if (!callback_.is_null()) 149 callback_.Run(error); 150 } 151 152 void DoClose(SocketStreamEvent* event) { 153 event->socket->Close(); 154 } 155 void DoRestartWithAuth(SocketStreamEvent* event) { 156 VLOG(1) << "RestartWithAuth username=" << credentials_.username() 157 << " password=" << credentials_.password(); 158 event->socket->RestartWithAuth(credentials_); 159 } 160 void SetAuthInfo(const AuthCredentials& credentials) { 161 credentials_ = credentials; 162 } 163 // Wakes up the SocketStream waiting for completion of OnStartOpenConnection() 164 // of its delegate. 
165 void CompleteConnection(int result) { 166 connection_callback_.Run(result); 167 } 168 169 const std::vector<SocketStreamEvent>& GetSeenEvents() const { 170 return events_; 171 } 172 173 private: 174 std::vector<SocketStreamEvent> events_; 175 base::Callback<int(SocketStreamEvent*)> on_start_open_connection_; 176 base::Callback<void(SocketStreamEvent*)> on_connected_; 177 base::Callback<void(SocketStreamEvent*)> on_sent_data_; 178 base::Callback<void(SocketStreamEvent*)> on_received_data_; 179 base::Callback<void(SocketStreamEvent*)> on_close_; 180 base::Callback<void(SocketStreamEvent*)> on_auth_required_; 181 base::Callback<void(SocketStreamEvent*)> on_error_; 182 const CompletionCallback callback_; 183 CompletionCallback connection_callback_; 184 AuthCredentials credentials_; 185 186 DISALLOW_COPY_AND_ASSIGN(SocketStreamEventRecorder); 187 }; 188 189 // This is used for the test OnErrorDetachDelegate. 190 class SelfDeletingDelegate : public SocketStream::Delegate { 191 public: 192 // |callback| must cause the test message loop to exit when called. 193 explicit SelfDeletingDelegate(const CompletionCallback& callback) 194 : socket_stream_(), callback_(callback) {} 195 196 virtual ~SelfDeletingDelegate() {} 197 198 // Call DetachDelegate(), delete |this|, then run the callback. 199 virtual void OnError(const SocketStream* socket, int error) OVERRIDE { 200 // callback_ will be deleted when we delete |this|, so copy it to call it 201 // afterwards. 202 CompletionCallback callback = callback_; 203 socket_stream_->DetachDelegate(); 204 delete this; 205 callback.Run(OK); 206 } 207 208 // This can't be passed in the constructor because this object needs to be 209 // created before SocketStream. 
210 void set_socket_stream(const scoped_refptr<SocketStream>& socket_stream) { 211 socket_stream_ = socket_stream; 212 EXPECT_EQ(socket_stream_->delegate(), this); 213 } 214 215 virtual void OnConnected(SocketStream* socket, int max_pending_send_allowed) 216 OVERRIDE { 217 ADD_FAILURE() << "OnConnected() should not be called"; 218 } 219 virtual void OnSentData(SocketStream* socket, int amount_sent) OVERRIDE { 220 ADD_FAILURE() << "OnSentData() should not be called"; 221 } 222 virtual void OnReceivedData(SocketStream* socket, const char* data, int len) 223 OVERRIDE { 224 ADD_FAILURE() << "OnReceivedData() should not be called"; 225 } 226 virtual void OnClose(SocketStream* socket) OVERRIDE { 227 ADD_FAILURE() << "OnClose() should not be called"; 228 } 229 230 private: 231 scoped_refptr<SocketStream> socket_stream_; 232 const CompletionCallback callback_; 233 234 DISALLOW_COPY_AND_ASSIGN(SelfDeletingDelegate); 235 }; 236 237 class TestURLRequestContextWithProxy : public TestURLRequestContext { 238 public: 239 explicit TestURLRequestContextWithProxy(const std::string& proxy) 240 : TestURLRequestContext(true) { 241 context_storage_.set_proxy_service(ProxyService::CreateFixed(proxy)); 242 Init(); 243 } 244 virtual ~TestURLRequestContextWithProxy() {} 245 }; 246 247 class TestSocketStreamNetworkDelegate : public TestNetworkDelegate { 248 public: 249 TestSocketStreamNetworkDelegate() 250 : before_connect_result_(OK) {} 251 virtual ~TestSocketStreamNetworkDelegate() {} 252 253 virtual int OnBeforeSocketStreamConnect( 254 SocketStream* stream, 255 const CompletionCallback& callback) OVERRIDE { 256 return before_connect_result_; 257 } 258 259 void SetBeforeConnectResult(int result) { 260 before_connect_result_ = result; 261 } 262 263 private: 264 int before_connect_result_; 265 }; 266 267 } // namespace 268 269 class SocketStreamTest : public PlatformTest { 270 public: 271 virtual ~SocketStreamTest() {} 272 virtual void SetUp() { 273 mock_socket_factory_.reset(); 274 
handshake_request_ = kWebSocketHandshakeRequest; 275 handshake_response_ = kWebSocketHandshakeResponse; 276 } 277 virtual void TearDown() { 278 mock_socket_factory_.reset(); 279 } 280 281 virtual void SetWebSocketHandshakeMessage( 282 const char* request, const char* response) { 283 handshake_request_ = request; 284 handshake_response_ = response; 285 } 286 virtual void AddWebSocketMessage(const std::string& message) { 287 messages_.push_back(message); 288 } 289 290 virtual MockClientSocketFactory* GetMockClientSocketFactory() { 291 mock_socket_factory_.reset(new MockClientSocketFactory); 292 return mock_socket_factory_.get(); 293 } 294 295 // Functions for SocketStreamEventRecorder to handle calls to the 296 // SocketStream::Delegate methods from the SocketStream. 297 298 virtual void DoSendWebSocketHandshake(SocketStreamEvent* event) { 299 event->socket->SendData( 300 handshake_request_.data(), handshake_request_.size()); 301 } 302 303 virtual void DoCloseFlushPendingWriteTest(SocketStreamEvent* event) { 304 // handshake response received. 305 for (size_t i = 0; i < messages_.size(); i++) { 306 std::vector<char> frame; 307 frame.push_back('\0'); 308 frame.insert(frame.end(), messages_[i].begin(), messages_[i].end()); 309 frame.push_back('\xff'); 310 EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size())); 311 } 312 // Actual StreamSocket close must happen after all frames queued by 313 // SendData above are sent out. 314 event->socket->Close(); 315 } 316 317 virtual void DoCloseFlushPendingWriteTestWithSetContextNull( 318 SocketStreamEvent* event) { 319 event->socket->set_context(NULL); 320 // handshake response received. 
321 for (size_t i = 0; i < messages_.size(); i++) { 322 std::vector<char> frame; 323 frame.push_back('\0'); 324 frame.insert(frame.end(), messages_[i].begin(), messages_[i].end()); 325 frame.push_back('\xff'); 326 EXPECT_TRUE(event->socket->SendData(&frame[0], frame.size())); 327 } 328 // Actual StreamSocket close must happen after all frames queued by 329 // SendData above are sent out. 330 event->socket->Close(); 331 } 332 333 virtual void DoFailByTooBigDataAndClose(SocketStreamEvent* event) { 334 std::string frame(event->number + 1, 0x00); 335 VLOG(1) << event->number; 336 EXPECT_FALSE(event->socket->SendData(&frame[0], frame.size())); 337 event->socket->Close(); 338 } 339 340 virtual int DoSwitchToSpdyTest(SocketStreamEvent* event) { 341 return ERR_PROTOCOL_SWITCHED; 342 } 343 344 // Notifies |io_test_callback_| of that this method is called, and keeps the 345 // SocketStream waiting. 346 virtual int DoIOPending(SocketStreamEvent* event) { 347 io_test_callback_.callback().Run(OK); 348 return ERR_IO_PENDING; 349 } 350 351 static const char kWebSocketHandshakeRequest[]; 352 static const char kWebSocketHandshakeResponse[]; 353 354 protected: 355 TestCompletionCallback io_test_callback_; 356 357 private: 358 std::string handshake_request_; 359 std::string handshake_response_; 360 std::vector<std::string> messages_; 361 362 scoped_ptr<MockClientSocketFactory> mock_socket_factory_; 363 }; 364 365 const char SocketStreamTest::kWebSocketHandshakeRequest[] = 366 "GET /demo HTTP/1.1\r\n" 367 "Host: example.com\r\n" 368 "Connection: Upgrade\r\n" 369 "Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n" 370 "Sec-WebSocket-Protocol: sample\r\n" 371 "Upgrade: WebSocket\r\n" 372 "Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n" 373 "Origin: http://example.com\r\n" 374 "\r\n" 375 "^n:ds[4U"; 376 377 const char SocketStreamTest::kWebSocketHandshakeResponse[] = 378 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n" 379 "Upgrade: WebSocket\r\n" 380 "Connection: Upgrade\r\n" 381 
"Sec-WebSocket-Origin: http://example.com\r\n" 382 "Sec-WebSocket-Location: ws://example.com/demo\r\n" 383 "Sec-WebSocket-Protocol: sample\r\n" 384 "\r\n" 385 "8jKS'y:G*Co,Wxa-"; 386 387 TEST_F(SocketStreamTest, CloseFlushPendingWrite) { 388 TestCompletionCallback test_callback; 389 390 scoped_ptr<SocketStreamEventRecorder> delegate( 391 new SocketStreamEventRecorder(test_callback.callback())); 392 delegate->SetOnConnected(base::Bind( 393 &SocketStreamTest::DoSendWebSocketHandshake, base::Unretained(this))); 394 delegate->SetOnReceivedData(base::Bind( 395 &SocketStreamTest::DoCloseFlushPendingWriteTest, 396 base::Unretained(this))); 397 398 TestURLRequestContext context; 399 400 scoped_refptr<SocketStream> socket_stream( 401 new SocketStream(GURL("ws://example.com/demo"), delegate.get())); 402 403 socket_stream->set_context(&context); 404 405 MockWrite data_writes[] = { 406 MockWrite(SocketStreamTest::kWebSocketHandshakeRequest), 407 MockWrite(ASYNC, "\0message1\xff", 10), 408 MockWrite(ASYNC, "\0message2\xff", 10) 409 }; 410 MockRead data_reads[] = { 411 MockRead(SocketStreamTest::kWebSocketHandshakeResponse), 412 // Server doesn't close the connection after handshake. 
413 MockRead(ASYNC, ERR_IO_PENDING) 414 }; 415 AddWebSocketMessage("message1"); 416 AddWebSocketMessage("message2"); 417 418 DelayedSocketData data_provider( 419 1, data_reads, arraysize(data_reads), 420 data_writes, arraysize(data_writes)); 421 422 MockClientSocketFactory* mock_socket_factory = 423 GetMockClientSocketFactory(); 424 mock_socket_factory->AddSocketDataProvider(&data_provider); 425 426 socket_stream->SetClientSocketFactory(mock_socket_factory); 427 428 socket_stream->Connect(); 429 430 test_callback.WaitForResult(); 431 432 EXPECT_TRUE(data_provider.at_read_eof()); 433 EXPECT_TRUE(data_provider.at_write_eof()); 434 435 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 436 ASSERT_EQ(7U, events.size()); 437 438 EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, 439 events[0].event_type); 440 EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type); 441 EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[2].event_type); 442 EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[3].event_type); 443 EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type); 444 EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[5].event_type); 445 EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[6].event_type); 446 } 447 448 TEST_F(SocketStreamTest, ResolveFailure) { 449 TestCompletionCallback test_callback; 450 451 scoped_ptr<SocketStreamEventRecorder> delegate( 452 new SocketStreamEventRecorder(test_callback.callback())); 453 454 scoped_refptr<SocketStream> socket_stream( 455 new SocketStream(GURL("ws://example.com/demo"), delegate.get())); 456 457 // Make resolver fail. 458 TestURLRequestContext context; 459 scoped_ptr<MockHostResolver> mock_host_resolver( 460 new MockHostResolver()); 461 mock_host_resolver->rules()->AddSimulatedFailure("example.com"); 462 context.set_host_resolver(mock_host_resolver.get()); 463 socket_stream->set_context(&context); 464 465 // No read/write on socket is expected. 
466 StaticSocketDataProvider data_provider(NULL, 0, NULL, 0); 467 MockClientSocketFactory* mock_socket_factory = 468 GetMockClientSocketFactory(); 469 mock_socket_factory->AddSocketDataProvider(&data_provider); 470 socket_stream->SetClientSocketFactory(mock_socket_factory); 471 472 socket_stream->Connect(); 473 474 test_callback.WaitForResult(); 475 476 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 477 ASSERT_EQ(2U, events.size()); 478 479 EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[0].event_type); 480 EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[1].event_type); 481 } 482 483 TEST_F(SocketStreamTest, ExceedMaxPendingSendAllowed) { 484 TestCompletionCallback test_callback; 485 486 scoped_ptr<SocketStreamEventRecorder> delegate( 487 new SocketStreamEventRecorder(test_callback.callback())); 488 delegate->SetOnConnected(base::Bind( 489 &SocketStreamTest::DoFailByTooBigDataAndClose, base::Unretained(this))); 490 491 TestURLRequestContext context; 492 493 scoped_refptr<SocketStream> socket_stream( 494 new SocketStream(GURL("ws://example.com/demo"), delegate.get())); 495 496 socket_stream->set_context(&context); 497 498 DelayedSocketData data_provider(1, NULL, 0, NULL, 0); 499 500 MockClientSocketFactory* mock_socket_factory = 501 GetMockClientSocketFactory(); 502 mock_socket_factory->AddSocketDataProvider(&data_provider); 503 504 socket_stream->SetClientSocketFactory(mock_socket_factory); 505 506 socket_stream->Connect(); 507 508 test_callback.WaitForResult(); 509 510 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 511 ASSERT_EQ(4U, events.size()); 512 513 EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, 514 events[0].event_type); 515 EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type); 516 EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[2].event_type); 517 EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[3].event_type); 518 } 519 520 TEST_F(SocketStreamTest, BasicAuthProxy) { 521 
MockClientSocketFactory mock_socket_factory; 522 MockWrite data_writes1[] = { 523 MockWrite("CONNECT example.com:80 HTTP/1.1\r\n" 524 "Host: example.com\r\n" 525 "Proxy-Connection: keep-alive\r\n\r\n"), 526 }; 527 MockRead data_reads1[] = { 528 MockRead("HTTP/1.1 407 Proxy Authentication Required\r\n"), 529 MockRead("Proxy-Authenticate: Basic realm=\"MyRealm1\"\r\n"), 530 MockRead("\r\n"), 531 }; 532 StaticSocketDataProvider data1(data_reads1, arraysize(data_reads1), 533 data_writes1, arraysize(data_writes1)); 534 mock_socket_factory.AddSocketDataProvider(&data1); 535 536 MockWrite data_writes2[] = { 537 MockWrite("CONNECT example.com:80 HTTP/1.1\r\n" 538 "Host: example.com\r\n" 539 "Proxy-Connection: keep-alive\r\n" 540 "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), 541 }; 542 MockRead data_reads2[] = { 543 MockRead("HTTP/1.1 200 Connection Established\r\n"), 544 MockRead("Proxy-agent: Apache/2.2.8\r\n"), 545 MockRead("\r\n"), 546 // SocketStream::DoClose is run asynchronously. Socket can be read after 547 // "\r\n". We have to give ERR_IO_PENDING to SocketStream then to indicate 548 // server doesn't close the connection. 
549 MockRead(ASYNC, ERR_IO_PENDING) 550 }; 551 StaticSocketDataProvider data2(data_reads2, arraysize(data_reads2), 552 data_writes2, arraysize(data_writes2)); 553 mock_socket_factory.AddSocketDataProvider(&data2); 554 555 TestCompletionCallback test_callback; 556 557 scoped_ptr<SocketStreamEventRecorder> delegate( 558 new SocketStreamEventRecorder(test_callback.callback())); 559 delegate->SetOnConnected(base::Bind(&SocketStreamEventRecorder::DoClose, 560 base::Unretained(delegate.get()))); 561 delegate->SetAuthInfo(AuthCredentials(ASCIIToUTF16("foo"), 562 ASCIIToUTF16("bar"))); 563 delegate->SetOnAuthRequired(base::Bind( 564 &SocketStreamEventRecorder::DoRestartWithAuth, 565 base::Unretained(delegate.get()))); 566 567 scoped_refptr<SocketStream> socket_stream( 568 new SocketStream(GURL("ws://example.com/demo"), delegate.get())); 569 570 TestURLRequestContextWithProxy context("myproxy:70"); 571 572 socket_stream->set_context(&context); 573 socket_stream->SetClientSocketFactory(&mock_socket_factory); 574 575 socket_stream->Connect(); 576 577 test_callback.WaitForResult(); 578 579 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 580 ASSERT_EQ(5U, events.size()); 581 582 EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, 583 events[0].event_type); 584 EXPECT_EQ(SocketStreamEvent::EVENT_AUTH_REQUIRED, events[1].event_type); 585 EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[2].event_type); 586 EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[3].event_type); 587 EXPECT_EQ(ERR_ABORTED, events[3].error_code); 588 EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[4].event_type); 589 590 // TODO(eroman): Add back NetLogTest here... 591 } 592 593 TEST_F(SocketStreamTest, BasicAuthProxyWithAuthCache) { 594 MockClientSocketFactory mock_socket_factory; 595 MockWrite data_writes[] = { 596 // WebSocket(SocketStream) always uses CONNECT when it is configured to use 597 // proxy so the port may not be 443. 
598 MockWrite("CONNECT example.com:80 HTTP/1.1\r\n" 599 "Host: example.com\r\n" 600 "Proxy-Connection: keep-alive\r\n" 601 "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), 602 }; 603 MockRead data_reads[] = { 604 MockRead("HTTP/1.1 200 Connection Established\r\n"), 605 MockRead("Proxy-agent: Apache/2.2.8\r\n"), 606 MockRead("\r\n"), 607 MockRead(ASYNC, ERR_IO_PENDING) 608 }; 609 StaticSocketDataProvider data(data_reads, arraysize(data_reads), 610 data_writes, arraysize(data_writes)); 611 mock_socket_factory.AddSocketDataProvider(&data); 612 613 TestCompletionCallback test_callback; 614 scoped_ptr<SocketStreamEventRecorder> delegate( 615 new SocketStreamEventRecorder(test_callback.callback())); 616 delegate->SetOnConnected(base::Bind(&SocketStreamEventRecorder::DoClose, 617 base::Unretained(delegate.get()))); 618 619 scoped_refptr<SocketStream> socket_stream( 620 new SocketStream(GURL("ws://example.com/demo"), delegate.get())); 621 622 TestURLRequestContextWithProxy context("myproxy:70"); 623 HttpAuthCache* auth_cache = 624 context.http_transaction_factory()->GetSession()->http_auth_cache(); 625 auth_cache->Add(GURL("http://myproxy:70"), 626 "MyRealm1", 627 HttpAuth::AUTH_SCHEME_BASIC, 628 "Basic realm=MyRealm1", 629 AuthCredentials(ASCIIToUTF16("foo"), 630 ASCIIToUTF16("bar")), 631 "/"); 632 633 socket_stream->set_context(&context); 634 socket_stream->SetClientSocketFactory(&mock_socket_factory); 635 636 socket_stream->Connect(); 637 638 test_callback.WaitForResult(); 639 640 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 641 ASSERT_EQ(4U, events.size()); 642 EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, 643 events[0].event_type); 644 EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type); 645 EXPECT_EQ(ERR_ABORTED, events[2].error_code); 646 EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[3].event_type); 647 } 648 649 TEST_F(SocketStreamTest, WSSBasicAuthProxyWithAuthCache) { 650 MockClientSocketFactory 
mock_socket_factory; 651 MockWrite data_writes1[] = { 652 MockWrite("CONNECT example.com:443 HTTP/1.1\r\n" 653 "Host: example.com\r\n" 654 "Proxy-Connection: keep-alive\r\n" 655 "Proxy-Authorization: Basic Zm9vOmJhcg==\r\n\r\n"), 656 }; 657 MockRead data_reads1[] = { 658 MockRead("HTTP/1.1 200 Connection Established\r\n"), 659 MockRead("Proxy-agent: Apache/2.2.8\r\n"), 660 MockRead("\r\n"), 661 MockRead(ASYNC, ERR_IO_PENDING) 662 }; 663 StaticSocketDataProvider data1(data_reads1, arraysize(data_reads1), 664 data_writes1, arraysize(data_writes1)); 665 mock_socket_factory.AddSocketDataProvider(&data1); 666 667 SSLSocketDataProvider data2(ASYNC, OK); 668 mock_socket_factory.AddSSLSocketDataProvider(&data2); 669 670 TestCompletionCallback test_callback; 671 scoped_ptr<SocketStreamEventRecorder> delegate( 672 new SocketStreamEventRecorder(test_callback.callback())); 673 delegate->SetOnConnected(base::Bind(&SocketStreamEventRecorder::DoClose, 674 base::Unretained(delegate.get()))); 675 676 scoped_refptr<SocketStream> socket_stream( 677 new SocketStream(GURL("wss://example.com/demo"), delegate.get())); 678 679 TestURLRequestContextWithProxy context("myproxy:70"); 680 HttpAuthCache* auth_cache = 681 context.http_transaction_factory()->GetSession()->http_auth_cache(); 682 auth_cache->Add(GURL("http://myproxy:70"), 683 "MyRealm1", 684 HttpAuth::AUTH_SCHEME_BASIC, 685 "Basic realm=MyRealm1", 686 AuthCredentials(ASCIIToUTF16("foo"), 687 ASCIIToUTF16("bar")), 688 "/"); 689 690 socket_stream->set_context(&context); 691 socket_stream->SetClientSocketFactory(&mock_socket_factory); 692 693 socket_stream->Connect(); 694 695 test_callback.WaitForResult(); 696 697 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 698 ASSERT_EQ(4U, events.size()); 699 EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, 700 events[0].event_type); 701 EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type); 702 EXPECT_EQ(ERR_ABORTED, events[2].error_code); 703 
EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[3].event_type); 704 } 705 706 TEST_F(SocketStreamTest, IOPending) { 707 TestCompletionCallback test_callback; 708 709 scoped_ptr<SocketStreamEventRecorder> delegate( 710 new SocketStreamEventRecorder(test_callback.callback())); 711 delegate->SetOnStartOpenConnection(base::Bind( 712 &SocketStreamTest::DoIOPending, base::Unretained(this))); 713 delegate->SetOnConnected(base::Bind( 714 &SocketStreamTest::DoSendWebSocketHandshake, base::Unretained(this))); 715 delegate->SetOnReceivedData(base::Bind( 716 &SocketStreamTest::DoCloseFlushPendingWriteTest, 717 base::Unretained(this))); 718 719 TestURLRequestContext context; 720 721 scoped_refptr<SocketStream> socket_stream( 722 new SocketStream(GURL("ws://example.com/demo"), delegate.get())); 723 724 socket_stream->set_context(&context); 725 726 MockWrite data_writes[] = { 727 MockWrite(SocketStreamTest::kWebSocketHandshakeRequest), 728 MockWrite(ASYNC, "\0message1\xff", 10), 729 MockWrite(ASYNC, "\0message2\xff", 10) 730 }; 731 MockRead data_reads[] = { 732 MockRead(SocketStreamTest::kWebSocketHandshakeResponse), 733 // Server doesn't close the connection after handshake. 
734 MockRead(ASYNC, ERR_IO_PENDING) 735 }; 736 AddWebSocketMessage("message1"); 737 AddWebSocketMessage("message2"); 738 739 DelayedSocketData data_provider( 740 1, data_reads, arraysize(data_reads), 741 data_writes, arraysize(data_writes)); 742 743 MockClientSocketFactory* mock_socket_factory = 744 GetMockClientSocketFactory(); 745 mock_socket_factory->AddSocketDataProvider(&data_provider); 746 747 socket_stream->SetClientSocketFactory(mock_socket_factory); 748 749 socket_stream->Connect(); 750 io_test_callback_.WaitForResult(); 751 EXPECT_EQ(SocketStream::STATE_RESOLVE_PROTOCOL_COMPLETE, 752 socket_stream->next_state_); 753 delegate->CompleteConnection(OK); 754 755 EXPECT_EQ(OK, test_callback.WaitForResult()); 756 757 EXPECT_TRUE(data_provider.at_read_eof()); 758 EXPECT_TRUE(data_provider.at_write_eof()); 759 760 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 761 ASSERT_EQ(7U, events.size()); 762 763 EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, 764 events[0].event_type); 765 EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type); 766 EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[2].event_type); 767 EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[3].event_type); 768 EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[4].event_type); 769 EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[5].event_type); 770 EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[6].event_type); 771 } 772 773 TEST_F(SocketStreamTest, SwitchToSpdy) { 774 TestCompletionCallback test_callback; 775 776 scoped_ptr<SocketStreamEventRecorder> delegate( 777 new SocketStreamEventRecorder(test_callback.callback())); 778 delegate->SetOnStartOpenConnection(base::Bind( 779 &SocketStreamTest::DoSwitchToSpdyTest, base::Unretained(this))); 780 781 TestURLRequestContext context; 782 783 scoped_refptr<SocketStream> socket_stream( 784 new SocketStream(GURL("ws://example.com/demo"), delegate.get())); 785 786 
socket_stream->set_context(&context); 787 788 socket_stream->Connect(); 789 790 EXPECT_EQ(ERR_PROTOCOL_SWITCHED, test_callback.WaitForResult()); 791 792 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 793 ASSERT_EQ(2U, events.size()); 794 795 EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, 796 events[0].event_type); 797 EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[1].event_type); 798 EXPECT_EQ(ERR_PROTOCOL_SWITCHED, events[1].error_code); 799 } 800 801 TEST_F(SocketStreamTest, SwitchAfterPending) { 802 TestCompletionCallback test_callback; 803 804 scoped_ptr<SocketStreamEventRecorder> delegate( 805 new SocketStreamEventRecorder(test_callback.callback())); 806 delegate->SetOnStartOpenConnection(base::Bind( 807 &SocketStreamTest::DoIOPending, base::Unretained(this))); 808 809 TestURLRequestContext context; 810 811 scoped_refptr<SocketStream> socket_stream( 812 new SocketStream(GURL("ws://example.com/demo"), delegate.get())); 813 814 socket_stream->set_context(&context); 815 816 socket_stream->Connect(); 817 io_test_callback_.WaitForResult(); 818 819 EXPECT_EQ(SocketStream::STATE_RESOLVE_PROTOCOL_COMPLETE, 820 socket_stream->next_state_); 821 delegate->CompleteConnection(ERR_PROTOCOL_SWITCHED); 822 823 EXPECT_EQ(ERR_PROTOCOL_SWITCHED, test_callback.WaitForResult()); 824 825 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 826 ASSERT_EQ(2U, events.size()); 827 828 EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, 829 events[0].event_type); 830 EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[1].event_type); 831 EXPECT_EQ(ERR_PROTOCOL_SWITCHED, events[1].error_code); 832 } 833 834 // Test a connection though a secure proxy. 
TEST_F(SocketStreamTest, SecureProxyConnectError) {
  MockClientSocketFactory mock_socket_factory;
  MockWrite data_writes[] = {
    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
              "Host: example.com\r\n"
              "Proxy-Connection: keep-alive\r\n\r\n")
  };
  MockRead data_reads[] = {
    MockRead("HTTP/1.1 200 Connection Established\r\n"),
    MockRead("Proxy-agent: Apache/2.2.8\r\n"),
    MockRead("\r\n"),
    // SocketStream::DoClose is run asynchronously. Socket can be read after
    // "\r\n". We have to give ERR_IO_PENDING to SocketStream then to indicate
    // server doesn't close the connection.
    MockRead(ASYNC, ERR_IO_PENDING)
  };
  StaticSocketDataProvider data(data_reads, arraysize(data_reads),
                                data_writes, arraysize(data_writes));
  mock_socket_factory.AddSocketDataProvider(&data);
  // The SSL socket is configured to fail synchronously with a protocol error.
  SSLSocketDataProvider ssl(SYNCHRONOUS, ERR_SSL_PROTOCOL_ERROR);
  mock_socket_factory.AddSSLSocketDataProvider(&ssl);

  TestCompletionCallback test_callback;
  // The "https://" scheme selects a secure proxy.
  TestURLRequestContextWithProxy context("https://myproxy:70");

  scoped_ptr<SocketStreamEventRecorder> delegate(
      new SocketStreamEventRecorder(test_callback.callback()));
  delegate->SetOnConnected(base::Bind(&SocketStreamEventRecorder::DoClose,
                                      base::Unretained(delegate.get())));

  scoped_refptr<SocketStream> socket_stream(
      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));

  socket_stream->set_context(&context);
  socket_stream->SetClientSocketFactory(&mock_socket_factory);

  socket_stream->Connect();

  test_callback.WaitForResult();

  // No EVENT_CONNECTED is expected: the SSL error is reported and the stream
  // closes.
  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
  ASSERT_EQ(3U, events.size());

  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
            events[0].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[1].event_type);
  EXPECT_EQ(ERR_SSL_PROTOCOL_ERROR, events[1].error_code);
  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[2].event_type);
}

// Test a connection through a secure proxy.
TEST_F(SocketStreamTest, SecureProxyConnect) {
  MockClientSocketFactory mock_socket_factory;
  MockWrite data_writes[] = {
    MockWrite("CONNECT example.com:80 HTTP/1.1\r\n"
              "Host: example.com\r\n"
              "Proxy-Connection: keep-alive\r\n\r\n")
  };
  MockRead data_reads[] = {
    MockRead("HTTP/1.1 200 Connection Established\r\n"),
    MockRead("Proxy-agent: Apache/2.2.8\r\n"),
    MockRead("\r\n"),
    // SocketStream::DoClose is run asynchronously. Socket can be read after
    // "\r\n". We have to give ERR_IO_PENDING to SocketStream then to indicate
    // server doesn't close the connection.
    MockRead(ASYNC, ERR_IO_PENDING)
  };
  StaticSocketDataProvider data(data_reads, arraysize(data_reads),
                                data_writes, arraysize(data_writes));
  mock_socket_factory.AddSocketDataProvider(&data);
  // Unlike SecureProxyConnectError, the SSL socket succeeds here.
  SSLSocketDataProvider ssl(SYNCHRONOUS, OK);
  mock_socket_factory.AddSSLSocketDataProvider(&ssl);

  TestCompletionCallback test_callback;
  TestURLRequestContextWithProxy context("https://myproxy:70");

  scoped_ptr<SocketStreamEventRecorder> delegate(
      new SocketStreamEventRecorder(test_callback.callback()));
  // Close the stream as soon as it reports being connected.
  delegate->SetOnConnected(base::Bind(&SocketStreamEventRecorder::DoClose,
                                      base::Unretained(delegate.get())));

  scoped_refptr<SocketStream> socket_stream(
      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));

  socket_stream->set_context(&context);
  socket_stream->SetClientSocketFactory(&mock_socket_factory);

  socket_stream->Connect();

  test_callback.WaitForResult();

  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
  ASSERT_EQ(4U, events.size());

  EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION,
            events[0].event_type);
  EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type);
  // Closing from OnConnected() is expected to surface as ERR_ABORTED.
  EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[2].event_type);
  EXPECT_EQ(ERR_ABORTED, events[2].error_code);
  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[3].event_type);
}

// Test that a failure returned by the network delegate's
// OnBeforeSocketStreamConnect() is reported as an error followed by a close,
// with no EVENT_START_OPEN_CONNECTION recorded.
TEST_F(SocketStreamTest, BeforeConnectFailed) {
  TestCompletionCallback test_callback;

  scoped_ptr<SocketStreamEventRecorder> delegate(
      new SocketStreamEventRecorder(test_callback.callback()));

  TestURLRequestContext context;
  TestSocketStreamNetworkDelegate network_delegate;
  network_delegate.SetBeforeConnectResult(ERR_ACCESS_DENIED);
  context.set_network_delegate(&network_delegate);

  scoped_refptr<SocketStream> socket_stream(
      new SocketStream(GURL("ws://example.com/demo"), delegate.get()));

  socket_stream->set_context(&context);

  socket_stream->Connect();

  test_callback.WaitForResult();

  const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents();
  ASSERT_EQ(2U, events.size());

  EXPECT_EQ(SocketStreamEvent::EVENT_ERROR, events[0].event_type);
  EXPECT_EQ(ERR_ACCESS_DENIED, events[0].error_code);
  EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[1].event_type);
}

// Check that a connect failure, followed by the delegate calling DetachDelegate
// and deleting itself in the OnError callback, is handled correctly.
TEST_F(SocketStreamTest, OnErrorDetachDelegate) {
  MockClientSocketFactory mock_socket_factory;
  TestCompletionCallback test_callback;

  // SelfDeletingDelegate is self-owning; we just need a pointer to it to
  // connect it and the SocketStream.
973 SelfDeletingDelegate* delegate = 974 new SelfDeletingDelegate(test_callback.callback()); 975 MockConnect mock_connect(ASYNC, ERR_CONNECTION_REFUSED); 976 StaticSocketDataProvider data; 977 data.set_connect_data(mock_connect); 978 mock_socket_factory.AddSocketDataProvider(&data); 979 980 TestURLRequestContext context; 981 scoped_refptr<SocketStream> socket_stream( 982 new SocketStream(GURL("ws://localhost:9998/echo"), delegate)); 983 socket_stream->set_context(&context); 984 socket_stream->SetClientSocketFactory(&mock_socket_factory); 985 delegate->set_socket_stream(socket_stream); 986 // The delegate pointer will become invalid during the test. Set it to NULL to 987 // avoid holding a dangling pointer. 988 delegate = NULL; 989 990 socket_stream->Connect(); 991 992 EXPECT_EQ(OK, test_callback.WaitForResult()); 993 } 994 995 TEST_F(SocketStreamTest, NullContextSocketStreamShouldNotCrash) { 996 TestCompletionCallback test_callback; 997 998 scoped_ptr<SocketStreamEventRecorder> delegate( 999 new SocketStreamEventRecorder(test_callback.callback())); 1000 TestURLRequestContext context; 1001 scoped_refptr<SocketStream> socket_stream( 1002 new SocketStream(GURL("ws://example.com/demo"), delegate.get())); 1003 delegate->SetOnStartOpenConnection(base::Bind( 1004 &SocketStreamTest::DoIOPending, base::Unretained(this))); 1005 delegate->SetOnConnected(base::Bind( 1006 &SocketStreamTest::DoSendWebSocketHandshake, base::Unretained(this))); 1007 delegate->SetOnReceivedData(base::Bind( 1008 &SocketStreamTest::DoCloseFlushPendingWriteTestWithSetContextNull, 1009 base::Unretained(this))); 1010 1011 socket_stream->set_context(&context); 1012 1013 MockWrite data_writes[] = { 1014 MockWrite(SocketStreamTest::kWebSocketHandshakeRequest), 1015 }; 1016 MockRead data_reads[] = { 1017 MockRead(SocketStreamTest::kWebSocketHandshakeResponse), 1018 }; 1019 AddWebSocketMessage("message1"); 1020 AddWebSocketMessage("message2"); 1021 1022 DelayedSocketData data_provider( 1023 1, data_reads, 
arraysize(data_reads), 1024 data_writes, arraysize(data_writes)); 1025 1026 MockClientSocketFactory* mock_socket_factory = GetMockClientSocketFactory(); 1027 mock_socket_factory->AddSocketDataProvider(&data_provider); 1028 socket_stream->SetClientSocketFactory(mock_socket_factory); 1029 1030 socket_stream->Connect(); 1031 io_test_callback_.WaitForResult(); 1032 delegate->CompleteConnection(OK); 1033 EXPECT_EQ(OK, test_callback.WaitForResult()); 1034 1035 EXPECT_TRUE(data_provider.at_read_eof()); 1036 EXPECT_TRUE(data_provider.at_write_eof()); 1037 1038 const std::vector<SocketStreamEvent>& events = delegate->GetSeenEvents(); 1039 ASSERT_EQ(5U, events.size()); 1040 1041 EXPECT_EQ(SocketStreamEvent::EVENT_START_OPEN_CONNECTION, 1042 events[0].event_type); 1043 EXPECT_EQ(SocketStreamEvent::EVENT_CONNECTED, events[1].event_type); 1044 EXPECT_EQ(SocketStreamEvent::EVENT_SENT_DATA, events[2].event_type); 1045 EXPECT_EQ(SocketStreamEvent::EVENT_RECEIVED_DATA, events[3].event_type); 1046 EXPECT_EQ(SocketStreamEvent::EVENT_CLOSE, events[4].event_type); 1047 } 1048 1049 } // namespace net 1050