1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socket_test_util.h" 6 7 #include <algorithm> 8 #include <vector> 9 10 #include "base/basictypes.h" 11 #include "base/bind.h" 12 #include "base/bind_helpers.h" 13 #include "base/compiler_specific.h" 14 #include "base/message_loop/message_loop.h" 15 #include "base/run_loop.h" 16 #include "base/time/time.h" 17 #include "net/base/address_family.h" 18 #include "net/base/address_list.h" 19 #include "net/base/auth.h" 20 #include "net/base/load_timing_info.h" 21 #include "net/http/http_network_session.h" 22 #include "net/http/http_request_headers.h" 23 #include "net/http/http_response_headers.h" 24 #include "net/socket/client_socket_pool_histograms.h" 25 #include "net/socket/socket.h" 26 #include "net/ssl/ssl_cert_request_info.h" 27 #include "net/ssl/ssl_info.h" 28 #include "testing/gtest/include/gtest/gtest.h" 29 30 // Socket events are easier to debug if you log individual reads and writes. 31 // Enable these if locally debugging, but they are too noisy for the waterfall. 32 #if 0 33 #define NET_TRACE(level, s) DLOG(level) << s << __FUNCTION__ << "() " 34 #else 35 #define NET_TRACE(level, s) EAT_STREAM_PARAMETERS 36 #endif 37 38 namespace net { 39 40 namespace { 41 42 inline char AsciifyHigh(char x) { 43 char nybble = static_cast<char>((x >> 4) & 0x0F); 44 return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10); 45 } 46 47 inline char AsciifyLow(char x) { 48 char nybble = static_cast<char>((x >> 0) & 0x0F); 49 return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10); 50 } 51 52 inline char Asciify(char x) { 53 if ((x < 0) || !isprint(x)) 54 return '.'; 55 return x; 56 } 57 58 void DumpData(const char* data, int data_len) { 59 if (logging::LOG_INFO < logging::GetMinLogLevel()) 60 return; 61 DVLOG(1) << "Length: " << data_len; 62 const char* pfx = "Data: "; 63 if (!data || (data_len <= 0)) { 64 DVLOG(1) << pfx << "<None>"; 65 } else { 66 int i; 67 for (i = 0; i <= (data_len - 4); i += 4) { 68 DVLOG(1) << pfx 69 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) 70 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1]) 71 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2]) 72 << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3]) 73 << " '" 74 << Asciify(data[i + 0]) 75 << Asciify(data[i + 1]) 76 << Asciify(data[i + 2]) 77 << Asciify(data[i + 3]) 78 << "'"; 79 pfx = " "; 80 } 81 // Take care of any 'trailing' bytes, if data_len was not a multiple of 4. 82 switch (data_len - i) { 83 case 3: 84 DVLOG(1) << pfx 85 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) 86 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1]) 87 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2]) 88 << " '" 89 << Asciify(data[i + 0]) 90 << Asciify(data[i + 1]) 91 << Asciify(data[i + 2]) 92 << " '"; 93 break; 94 case 2: 95 DVLOG(1) << pfx 96 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) 97 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1]) 98 << " '" 99 << Asciify(data[i + 0]) 100 << Asciify(data[i + 1]) 101 << " '"; 102 break; 103 case 1: 104 DVLOG(1) << pfx 105 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) 106 << " '" 107 << Asciify(data[i + 0]) 108 << " '"; 109 break; 110 } 111 } 112 } 113 114 template <MockReadWriteType type> 115 void DumpMockReadWrite(const MockReadWrite<type>& r) { 116 if (logging::LOG_INFO < logging::GetMinLogLevel()) 117 return; 118 DVLOG(1) << "Async: " << (r.mode == ASYNC) 119 << "\nResult: " << r.result; 120 DumpData(r.data, r.data_len); 121 const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : ""; 122 DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop 123 << "\nTime: " << r.time_stamp.ToInternalValue(); 124 } 125 126 } // namespace 127 128 MockConnect::MockConnect() : mode(ASYNC), result(OK) { 129 IPAddressNumber ip; 130 CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); 131 peer_addr = IPEndPoint(ip, 0); 132 } 133 134 MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) { 135 IPAddressNumber ip; 136 CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); 137 peer_addr = IPEndPoint(ip, 0); 138 } 139 140 MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr) : 141 mode(io_mode), 142 result(r), 143 peer_addr(addr) { 144 } 145 146 MockConnect::~MockConnect() {} 147 148 StaticSocketDataProvider::StaticSocketDataProvider() 149 : reads_(NULL), 150 read_index_(0), 151 read_count_(0), 152 writes_(NULL), 153 write_index_(0), 154 write_count_(0) { 155 } 156 157 StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads, 158 size_t reads_count, 159 MockWrite* writes, 160 size_t writes_count) 161 : reads_(reads), 162 read_index_(0), 163 read_count_(reads_count), 164 writes_(writes), 165 write_index_(0), 166 write_count_(writes_count) { 167 } 168 169 StaticSocketDataProvider::~StaticSocketDataProvider() {} 170 171 const MockRead& StaticSocketDataProvider::PeekRead() const { 172 DCHECK(!at_read_eof()); 173 return reads_[read_index_]; 174 } 175 176 const MockWrite& StaticSocketDataProvider::PeekWrite() const { 177 DCHECK(!at_write_eof()); 178 return writes_[write_index_]; 179 } 180 181 const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const { 182 DCHECK_LT(index, read_count_); 183 return reads_[index]; 184 } 185 186 const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const { 187 DCHECK_LT(index, write_count_); 188 return writes_[index]; 189 } 190 191 MockRead StaticSocketDataProvider::GetNextRead() { 192 DCHECK(!at_read_eof()); 193 reads_[read_index_].time_stamp = base::Time::Now(); 194 return reads_[read_index_++]; 195 } 196 197 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) { 198 if (!writes_) { 199 // Not using mock writes; succeed synchronously. 200 return MockWriteResult(SYNCHRONOUS, data.length()); 201 } 202 DCHECK(!at_write_eof()); 203 204 // Check that what we are writing matches the expectation. 205 // Then give the mocked return value. 206 MockWrite* w = &writes_[write_index_++]; 207 w->time_stamp = base::Time::Now(); 208 int result = w->result; 209 if (w->data) { 210 // Note - we can simulate a partial write here. If the expected data 211 // is a match, but shorter than the write actually written, that is legal. 212 // Example: 213 // Application writes "foobarbaz" (9 bytes) 214 // Expected write was "foo" (3 bytes) 215 // This is a success, and we return 3 to the application. 216 std::string expected_data(w->data, w->data_len); 217 EXPECT_GE(data.length(), expected_data.length()); 218 std::string actual_data(data.substr(0, w->data_len)); 219 EXPECT_EQ(expected_data, actual_data); 220 if (expected_data != actual_data) 221 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); 222 if (result == OK) 223 result = w->data_len; 224 } 225 return MockWriteResult(w->mode, result); 226 } 227 228 void StaticSocketDataProvider::Reset() { 229 read_index_ = 0; 230 write_index_ = 0; 231 } 232 233 DynamicSocketDataProvider::DynamicSocketDataProvider() 234 : short_read_limit_(0), 235 allow_unconsumed_reads_(false) { 236 } 237 238 DynamicSocketDataProvider::~DynamicSocketDataProvider() {} 239 240 MockRead DynamicSocketDataProvider::GetNextRead() { 241 if (reads_.empty()) 242 return MockRead(SYNCHRONOUS, ERR_UNEXPECTED); 243 MockRead result = reads_.front(); 244 if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) { 245 reads_.pop_front(); 246 } else { 247 result.data_len = short_read_limit_; 248 reads_.front().data += result.data_len; 249 reads_.front().data_len -= result.data_len; 250 } 251 return result; 252 } 253 254 void DynamicSocketDataProvider::Reset() { 255 reads_.clear(); 256 } 257 258 void DynamicSocketDataProvider::SimulateRead(const char* data, 259 const size_t length) { 260 if (!allow_unconsumed_reads_) { 261 EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data; 262 } 263 reads_.push_back(MockRead(ASYNC, data, length)); 264 } 265 266 SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result) 267 : connect(mode, result), 268 next_proto_status(SSLClientSocket::kNextProtoUnsupported), 269 was_npn_negotiated(false), 270 protocol_negotiated(kProtoUnknown), 271 client_cert_sent(false), 272 cert_request_info(NULL), 273 channel_id_sent(false) { 274 } 275 276 SSLSocketDataProvider::~SSLSocketDataProvider() { 277 } 278 279 void SSLSocketDataProvider::SetNextProto(NextProto proto) { 280 was_npn_negotiated = true; 281 next_proto_status = SSLClientSocket::kNextProtoNegotiated; 282 protocol_negotiated = proto; 283 next_proto = SSLClientSocket::NextProtoToString(proto); 284 } 285 286 DelayedSocketData::DelayedSocketData( 287 int write_delay, MockRead* reads, size_t reads_count, 288 MockWrite* writes, size_t writes_count) 289 : StaticSocketDataProvider(reads, reads_count, writes, writes_count), 290 write_delay_(write_delay), 291 read_in_progress_(false), 292 weak_factory_(this) { 293 DCHECK_GE(write_delay_, 0); 294 } 295 296 DelayedSocketData::DelayedSocketData( 297 const MockConnect& connect, int write_delay, MockRead* reads, 298 size_t reads_count, MockWrite* writes, size_t writes_count) 299 : StaticSocketDataProvider(reads, reads_count, writes, writes_count), 300 write_delay_(write_delay), 301 read_in_progress_(false), 302 weak_factory_(this) { 303 DCHECK_GE(write_delay_, 0); 304 set_connect_data(connect); 305 } 306 307 DelayedSocketData::~DelayedSocketData() { 308 } 309 310 void DelayedSocketData::ForceNextRead() { 311 DCHECK(read_in_progress_); 312 write_delay_ = 0; 313 CompleteRead(); 314 } 315 316 MockRead DelayedSocketData::GetNextRead() { 317 MockRead out = MockRead(ASYNC, ERR_IO_PENDING); 318 if (write_delay_ <= 0) 319 out = StaticSocketDataProvider::GetNextRead(); 320 read_in_progress_ = (out.result == ERR_IO_PENDING); 321 return out; 322 } 323 324 MockWriteResult DelayedSocketData::OnWrite(const std::string& data) { 325 MockWriteResult rv = StaticSocketDataProvider::OnWrite(data); 326 // Now that our write has completed, we can allow reads to continue. 327 if (!--write_delay_ && read_in_progress_) 328 base::MessageLoop::current()->PostDelayedTask( 329 FROM_HERE, 330 base::Bind(&DelayedSocketData::CompleteRead, 331 weak_factory_.GetWeakPtr()), 332 base::TimeDelta::FromMilliseconds(100)); 333 return rv; 334 } 335 336 void DelayedSocketData::Reset() { 337 set_socket(NULL); 338 read_in_progress_ = false; 339 weak_factory_.InvalidateWeakPtrs(); 340 StaticSocketDataProvider::Reset(); 341 } 342 343 void DelayedSocketData::CompleteRead() { 344 if (socket() && read_in_progress_) 345 socket()->OnReadComplete(GetNextRead()); 346 } 347 348 OrderedSocketData::OrderedSocketData( 349 MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count) 350 : StaticSocketDataProvider(reads, reads_count, writes, writes_count), 351 sequence_number_(0), loop_stop_stage_(0), 352 blocked_(false), weak_factory_(this) { 353 } 354 355 OrderedSocketData::OrderedSocketData( 356 const MockConnect& connect, 357 MockRead* reads, size_t reads_count, 358 MockWrite* writes, size_t writes_count) 359 : StaticSocketDataProvider(reads, reads_count, writes, writes_count), 360 sequence_number_(0), loop_stop_stage_(0), 361 blocked_(false), weak_factory_(this) { 362 set_connect_data(connect); 363 } 364 365 void OrderedSocketData::EndLoop() { 366 // If we've already stopped the loop, don't do it again until we've advanced 367 // to the next sequence_number. 368 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ << ": EndLoop()"; 369 if (loop_stop_stage_ > 0) { 370 const MockRead& next_read = StaticSocketDataProvider::PeekRead(); 371 if ((next_read.sequence_number & ~MockRead::STOPLOOP) > 372 loop_stop_stage_) { 373 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 374 << ": Clearing stop index"; 375 loop_stop_stage_ = 0; 376 } else { 377 return; 378 } 379 } 380 // Record the sequence_number at which we stopped the loop. 381 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 382 << ": Posting Quit at read " << read_index(); 383 loop_stop_stage_ = sequence_number_; 384 } 385 386 MockRead OrderedSocketData::GetNextRead() { 387 weak_factory_.InvalidateWeakPtrs(); 388 blocked_ = false; 389 const MockRead& next_read = StaticSocketDataProvider::PeekRead(); 390 if (next_read.sequence_number & MockRead::STOPLOOP) 391 EndLoop(); 392 if ((next_read.sequence_number & ~MockRead::STOPLOOP) <= 393 sequence_number_++) { 394 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1 395 << ": Read " << read_index(); 396 DumpMockReadWrite(next_read); 397 blocked_ = (next_read.result == ERR_IO_PENDING); 398 return StaticSocketDataProvider::GetNextRead(); 399 } 400 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1 401 << ": I/O Pending"; 402 MockRead result = MockRead(ASYNC, ERR_IO_PENDING); 403 DumpMockReadWrite(result); 404 blocked_ = true; 405 return result; 406 } 407 408 MockWriteResult OrderedSocketData::OnWrite(const std::string& data) { 409 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 410 << ": Write " << write_index(); 411 DumpMockReadWrite(PeekWrite()); 412 ++sequence_number_; 413 if (blocked_) { 414 // TODO(willchan): This 100ms delay seems to work around some weirdness. We 415 // should probably fix the weirdness. One example is in SpdyStream, 416 // DoSendRequest() will return ERR_IO_PENDING, and there's a race. If the 417 // SYN_REPLY causes OnResponseReceived() to get called before 418 // SpdyStream::ReadResponseHeaders() is called, we hit a NOTREACHED(). 419 base::MessageLoop::current()->PostDelayedTask( 420 FROM_HERE, 421 base::Bind(&OrderedSocketData::CompleteRead, 422 weak_factory_.GetWeakPtr()), 423 base::TimeDelta::FromMilliseconds(100)); 424 } 425 return StaticSocketDataProvider::OnWrite(data); 426 } 427 428 void OrderedSocketData::Reset() { 429 NET_TRACE(INFO, " *** ") << "Stage " 430 << sequence_number_ << ": Reset()"; 431 sequence_number_ = 0; 432 loop_stop_stage_ = 0; 433 set_socket(NULL); 434 weak_factory_.InvalidateWeakPtrs(); 435 StaticSocketDataProvider::Reset(); 436 } 437 438 void OrderedSocketData::CompleteRead() { 439 if (socket() && blocked_) { 440 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_; 441 socket()->OnReadComplete(GetNextRead()); 442 } 443 } 444 445 OrderedSocketData::~OrderedSocketData() {} 446 447 DeterministicSocketData::DeterministicSocketData(MockRead* reads, 448 size_t reads_count, MockWrite* writes, size_t writes_count) 449 : StaticSocketDataProvider(reads, reads_count, writes, writes_count), 450 sequence_number_(0), 451 current_read_(), 452 current_write_(), 453 stopping_sequence_number_(0), 454 stopped_(false), 455 print_debug_(false), 456 is_running_(false) { 457 VerifyCorrectSequenceNumbers(reads, reads_count, writes, writes_count); 458 } 459 460 DeterministicSocketData::~DeterministicSocketData() {} 461 462 void DeterministicSocketData::Run() { 463 DCHECK(!is_running_); 464 is_running_ = true; 465 466 SetStopped(false); 467 int counter = 0; 468 // Continue to consume data until all data has run out, or the stopped_ flag 469 // has been set. Consuming data requires two separate operations -- running 470 // the tasks in the message loop, and explicitly invoking the read/write 471 // callbacks (simulating network I/O). We check our conditions between each, 472 // since they can change in either. 473 while ((!at_write_eof() || !at_read_eof()) && !stopped()) { 474 if (counter % 2 == 0) 475 base::RunLoop().RunUntilIdle(); 476 if (counter % 2 == 1) { 477 InvokeCallbacks(); 478 } 479 counter++; 480 } 481 // We're done consuming new data, but it is possible there are still some 482 // pending callbacks which we expect to complete before returning. 483 while (delegate_.get() && 484 (delegate_->WritePending() || delegate_->ReadPending()) && 485 !stopped()) { 486 InvokeCallbacks(); 487 base::RunLoop().RunUntilIdle(); 488 } 489 SetStopped(false); 490 is_running_ = false; 491 } 492 493 void DeterministicSocketData::RunFor(int steps) { 494 StopAfter(steps); 495 Run(); 496 } 497 498 void DeterministicSocketData::SetStop(int seq) { 499 DCHECK_LT(sequence_number_, seq); 500 stopping_sequence_number_ = seq; 501 stopped_ = false; 502 } 503 504 void DeterministicSocketData::StopAfter(int seq) { 505 SetStop(sequence_number_ + seq); 506 } 507 508 MockRead DeterministicSocketData::GetNextRead() { 509 current_read_ = StaticSocketDataProvider::PeekRead(); 510 511 // Synchronous read while stopped is an error 512 if (stopped() && current_read_.mode == SYNCHRONOUS) { 513 LOG(ERROR) << "Unable to perform synchronous IO while stopped"; 514 return MockRead(SYNCHRONOUS, ERR_UNEXPECTED); 515 } 516 517 // Async read which will be called back in a future step. 518 if (sequence_number_ < current_read_.sequence_number) { 519 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 520 << ": I/O Pending"; 521 MockRead result = MockRead(SYNCHRONOUS, ERR_IO_PENDING); 522 if (current_read_.mode == SYNCHRONOUS) { 523 LOG(ERROR) << "Unable to perform synchronous read: " 524 << current_read_.sequence_number 525 << " at stage: " << sequence_number_; 526 result = MockRead(SYNCHRONOUS, ERR_UNEXPECTED); 527 } 528 if (print_debug_) 529 DumpMockReadWrite(result); 530 return result; 531 } 532 533 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 534 << ": Read " << read_index(); 535 if (print_debug_) 536 DumpMockReadWrite(current_read_); 537 538 // Increment the sequence number if IO is complete 539 if (current_read_.mode == SYNCHRONOUS) 540 NextStep(); 541 542 DCHECK_NE(ERR_IO_PENDING, current_read_.result); 543 StaticSocketDataProvider::GetNextRead(); 544 545 return current_read_; 546 } 547 548 MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) { 549 const MockWrite& next_write = StaticSocketDataProvider::PeekWrite(); 550 current_write_ = next_write; 551 552 // Synchronous write while stopped is an error 553 if (stopped() && next_write.mode == SYNCHRONOUS) { 554 LOG(ERROR) << "Unable to perform synchronous IO while stopped"; 555 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); 556 } 557 558 // Async write which will be called back in a future step. 559 if (sequence_number_ < next_write.sequence_number) { 560 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 561 << ": I/O Pending"; 562 if (next_write.mode == SYNCHRONOUS) { 563 LOG(ERROR) << "Unable to perform synchronous write: " 564 << next_write.sequence_number << " at stage: " << sequence_number_; 565 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); 566 } 567 } else { 568 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 569 << ": Write " << write_index(); 570 } 571 572 if (print_debug_) 573 DumpMockReadWrite(next_write); 574 575 // Move to the next step if I/O is synchronous, since the operation will 576 // complete when this method returns. 577 if (next_write.mode == SYNCHRONOUS) 578 NextStep(); 579 580 // This is either a sync write for this step, or an async write. 581 return StaticSocketDataProvider::OnWrite(data); 582 } 583 584 void DeterministicSocketData::Reset() { 585 NET_TRACE(INFO, " *** ") << "Stage " 586 << sequence_number_ << ": Reset()"; 587 sequence_number_ = 0; 588 StaticSocketDataProvider::Reset(); 589 NOTREACHED(); 590 } 591 592 void DeterministicSocketData::InvokeCallbacks() { 593 if (delegate_.get() && delegate_->WritePending() && 594 (current_write().sequence_number == sequence_number())) { 595 NextStep(); 596 delegate_->CompleteWrite(); 597 return; 598 } 599 if (delegate_.get() && delegate_->ReadPending() && 600 (current_read().sequence_number == sequence_number())) { 601 NextStep(); 602 delegate_->CompleteRead(); 603 return; 604 } 605 } 606 607 void DeterministicSocketData::NextStep() { 608 // Invariant: Can never move *past* the stopping step. 609 DCHECK_LT(sequence_number_, stopping_sequence_number_); 610 sequence_number_++; 611 if (sequence_number_ == stopping_sequence_number_) 612 SetStopped(true); 613 } 614 615 void DeterministicSocketData::VerifyCorrectSequenceNumbers( 616 MockRead* reads, size_t reads_count, 617 MockWrite* writes, size_t writes_count) { 618 size_t read = 0; 619 size_t write = 0; 620 int expected = 0; 621 while (read < reads_count || write < writes_count) { 622 // Check to see that we have a read or write at the expected 623 // state. 624 if (read < reads_count && reads[read].sequence_number == expected) { 625 ++read; 626 ++expected; 627 continue; 628 } 629 if (write < writes_count && writes[write].sequence_number == expected) { 630 ++write; 631 ++expected; 632 continue; 633 } 634 NOTREACHED() << "Missing sequence number: " << expected; 635 return; 636 } 637 DCHECK_EQ(read, reads_count); 638 DCHECK_EQ(write, writes_count); 639 } 640 641 MockClientSocketFactory::MockClientSocketFactory() {} 642 643 MockClientSocketFactory::~MockClientSocketFactory() {} 644 645 void MockClientSocketFactory::AddSocketDataProvider( 646 SocketDataProvider* data) { 647 mock_data_.Add(data); 648 } 649 650 void MockClientSocketFactory::AddSSLSocketDataProvider( 651 SSLSocketDataProvider* data) { 652 mock_ssl_data_.Add(data); 653 } 654 655 void MockClientSocketFactory::ResetNextMockIndexes() { 656 mock_data_.ResetNextIndex(); 657 mock_ssl_data_.ResetNextIndex(); 658 } 659 660 scoped_ptr<DatagramClientSocket> 661 MockClientSocketFactory::CreateDatagramClientSocket( 662 DatagramSocket::BindType bind_type, 663 const RandIntCallback& rand_int_cb, 664 net::NetLog* net_log, 665 const net::NetLog::Source& source) { 666 SocketDataProvider* data_provider = mock_data_.GetNext(); 667 scoped_ptr<MockUDPClientSocket> socket( 668 new MockUDPClientSocket(data_provider, net_log)); 669 data_provider->set_socket(socket.get()); 670 return socket.PassAs<DatagramClientSocket>(); 671 } 672 673 scoped_ptr<StreamSocket> MockClientSocketFactory::CreateTransportClientSocket( 674 const AddressList& addresses, 675 net::NetLog* net_log, 676 const net::NetLog::Source& source) { 677 SocketDataProvider* data_provider = mock_data_.GetNext(); 678 scoped_ptr<MockTCPClientSocket> socket( 679 new MockTCPClientSocket(addresses, net_log, data_provider)); 680 data_provider->set_socket(socket.get()); 681 return socket.PassAs<StreamSocket>(); 682 } 683 684 scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket( 685 scoped_ptr<ClientSocketHandle> transport_socket, 686 const HostPortPair& host_and_port, 687 const SSLConfig& ssl_config, 688 const SSLClientSocketContext& context) { 689 return scoped_ptr<SSLClientSocket>( 690 new MockSSLClientSocket(transport_socket.Pass(), 691 host_and_port, ssl_config, 692 mock_ssl_data_.GetNext())); 693 } 694 695 void MockClientSocketFactory::ClearSSLSessionCache() { 696 } 697 698 const char MockClientSocket::kTlsUnique[] = "MOCK_TLSUNIQ"; 699 700 MockClientSocket::MockClientSocket(const BoundNetLog& net_log) 701 : connected_(false), 702 net_log_(net_log), 703 weak_factory_(this) { 704 IPAddressNumber ip; 705 CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); 706 peer_addr_ = IPEndPoint(ip, 0); 707 } 708 709 bool MockClientSocket::SetReceiveBufferSize(int32 size) { 710 return true; 711 } 712 713 bool MockClientSocket::SetSendBufferSize(int32 size) { 714 return true; 715 } 716 717 void MockClientSocket::Disconnect() { 718 connected_ = false; 719 } 720 721 bool MockClientSocket::IsConnected() const { 722 return connected_; 723 } 724 725 bool MockClientSocket::IsConnectedAndIdle() const { 726 return connected_; 727 } 728 729 int MockClientSocket::GetPeerAddress(IPEndPoint* address) const { 730 if (!IsConnected()) 731 return ERR_SOCKET_NOT_CONNECTED; 732 *address = peer_addr_; 733 return OK; 734 } 735 736 int MockClientSocket::GetLocalAddress(IPEndPoint* address) const { 737 IPAddressNumber ip; 738 bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip); 739 CHECK(rv); 740 *address = IPEndPoint(ip, 123); 741 return OK; 742 } 743 744 const BoundNetLog& MockClientSocket::NetLog() const { 745 return net_log_; 746 } 747 748 void MockClientSocket::GetSSLCertRequestInfo( 749 SSLCertRequestInfo* cert_request_info) { 750 } 751 752 int MockClientSocket::ExportKeyingMaterial(const base::StringPiece& label, 753 bool has_context, 754 const base::StringPiece& context, 755 unsigned char* out, 756 unsigned int outlen) { 757 memset(out, 'A', outlen); 758 return OK; 759 } 760 761 int MockClientSocket::GetTLSUniqueChannelBinding(std::string* out) { 762 out->assign(MockClientSocket::kTlsUnique); 763 return OK; 764 } 765 766 ServerBoundCertService* MockClientSocket::GetServerBoundCertService() const { 767 NOTREACHED(); 768 return NULL; 769 } 770 771 SSLClientSocket::NextProtoStatus 772 MockClientSocket::GetNextProto(std::string* proto, std::string* server_protos) { 773 proto->clear(); 774 server_protos->clear(); 775 return SSLClientSocket::kNextProtoUnsupported; 776 } 777 778 MockClientSocket::~MockClientSocket() {} 779 780 void MockClientSocket::RunCallbackAsync(const CompletionCallback& callback, 781 int result) { 782 base::MessageLoop::current()->PostTask( 783 FROM_HERE, 784 base::Bind(&MockClientSocket::RunCallback, 785 weak_factory_.GetWeakPtr(), 786 callback, 787 result)); 788 } 789 790 void MockClientSocket::RunCallback(const net::CompletionCallback& callback, 791 int result) { 792 if (!callback.is_null()) 793 callback.Run(result); 794 } 795 796 MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses, 797 net::NetLog* net_log, 798 SocketDataProvider* data) 799 : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), 800 addresses_(addresses), 801 data_(data), 802 read_offset_(0), 803 read_data_(SYNCHRONOUS, ERR_UNEXPECTED), 804 need_read_data_(true), 805 peer_closed_connection_(false), 806 pending_buf_(NULL), 807 pending_buf_len_(0), 808 was_used_to_convey_data_(false) { 809 DCHECK(data_); 810 peer_addr_ = data->connect_data().peer_addr; 811 data_->Reset(); 812 } 813 814 MockTCPClientSocket::~MockTCPClientSocket() {} 815 816 int MockTCPClientSocket::Read(IOBuffer* buf, int buf_len, 817 const CompletionCallback& callback) { 818 if (!connected_) 819 return ERR_UNEXPECTED; 820 821 // If the buffer is already in use, a read is already in progress! 822 DCHECK(pending_buf_ == NULL); 823 824 // Store our async IO data. 825 pending_buf_ = buf; 826 pending_buf_len_ = buf_len; 827 pending_callback_ = callback; 828 829 if (need_read_data_) { 830 read_data_ = data_->GetNextRead(); 831 if (read_data_.result == ERR_CONNECTION_CLOSED) { 832 // This MockRead is just a marker to instruct us to set 833 // peer_closed_connection_. 834 peer_closed_connection_ = true; 835 } 836 if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { 837 // This MockRead is just a marker to instruct us to set 838 // peer_closed_connection_. Skip it and get the next one. 839 read_data_ = data_->GetNextRead(); 840 peer_closed_connection_ = true; 841 } 842 // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility 843 // to complete the async IO manually later (via OnReadComplete). 844 if (read_data_.result == ERR_IO_PENDING) { 845 // We need to be using async IO in this case. 846 DCHECK(!callback.is_null()); 847 return ERR_IO_PENDING; 848 } 849 need_read_data_ = false; 850 } 851 852 return CompleteRead(); 853 } 854 855 int MockTCPClientSocket::Write(IOBuffer* buf, int buf_len, 856 const CompletionCallback& callback) { 857 DCHECK(buf); 858 DCHECK_GT(buf_len, 0); 859 860 if (!connected_) 861 return ERR_UNEXPECTED; 862 863 std::string data(buf->data(), buf_len); 864 MockWriteResult write_result = data_->OnWrite(data); 865 866 was_used_to_convey_data_ = true; 867 868 if (write_result.mode == ASYNC) { 869 RunCallbackAsync(callback, write_result.result); 870 return ERR_IO_PENDING; 871 } 872 873 return write_result.result; 874 } 875 876 int MockTCPClientSocket::Connect(const CompletionCallback& callback) { 877 if (connected_) 878 return OK; 879 connected_ = true; 880 peer_closed_connection_ = false; 881 if (data_->connect_data().mode == ASYNC) { 882 if (data_->connect_data().result == ERR_IO_PENDING) 883 pending_callback_ = callback; 884 else 885 RunCallbackAsync(callback, data_->connect_data().result); 886 return ERR_IO_PENDING; 887 } 888 return data_->connect_data().result; 889 } 890 891 void MockTCPClientSocket::Disconnect() { 892 MockClientSocket::Disconnect(); 893 pending_callback_.Reset(); 894 } 895 896 bool MockTCPClientSocket::IsConnected() const { 897 return connected_ && !peer_closed_connection_; 898 } 899 900 bool MockTCPClientSocket::IsConnectedAndIdle() const { 901 return IsConnected(); 902 } 903 904 int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const { 905 if (addresses_.empty()) 906 return MockClientSocket::GetPeerAddress(address); 907 908 *address = addresses_[0]; 909 return OK; 910 } 911 912 bool MockTCPClientSocket::WasEverUsed() const { 913 return was_used_to_convey_data_; 914 } 915 916 bool MockTCPClientSocket::UsingTCPFastOpen() const { 917 return false; 918 } 919 920 bool MockTCPClientSocket::WasNpnNegotiated() const { 921 return false; 922 } 923 924 bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { 925 return false; 926 } 927 928 void MockTCPClientSocket::OnReadComplete(const MockRead& data) { 929 // There must be a read pending. 930 DCHECK(pending_buf_); 931 // You can't complete a read with another ERR_IO_PENDING status code. 932 DCHECK_NE(ERR_IO_PENDING, data.result); 933 // Since we've been waiting for data, need_read_data_ should be true. 934 DCHECK(need_read_data_); 935 936 read_data_ = data; 937 need_read_data_ = false; 938 939 // The caller is simulating that this IO completes right now. Don't 940 // let CompleteRead() schedule a callback. 941 read_data_.mode = SYNCHRONOUS; 942 943 CompletionCallback callback = pending_callback_; 944 int rv = CompleteRead(); 945 RunCallback(callback, rv); 946 } 947 948 void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) { 949 CompletionCallback callback = pending_callback_; 950 RunCallback(callback, data.result); 951 } 952 953 int MockTCPClientSocket::CompleteRead() { 954 DCHECK(pending_buf_); 955 DCHECK(pending_buf_len_ > 0); 956 957 was_used_to_convey_data_ = true; 958 959 // Save the pending async IO data and reset our |pending_| state. 960 scoped_refptr<IOBuffer> buf = pending_buf_; 961 int buf_len = pending_buf_len_; 962 CompletionCallback callback = pending_callback_; 963 pending_buf_ = NULL; 964 pending_buf_len_ = 0; 965 pending_callback_.Reset(); 966 967 int result = read_data_.result; 968 DCHECK(result != ERR_IO_PENDING); 969 970 if (read_data_.data) { 971 if (read_data_.data_len - read_offset_ > 0) { 972 result = std::min(buf_len, read_data_.data_len - read_offset_); 973 memcpy(buf->data(), read_data_.data + read_offset_, result); 974 read_offset_ += result; 975 if (read_offset_ == read_data_.data_len) { 976 need_read_data_ = true; 977 read_offset_ = 0; 978 } 979 } else { 980 result = 0; // EOF 981 } 982 } 983 984 if (read_data_.mode == ASYNC) { 985 DCHECK(!callback.is_null()); 986 RunCallbackAsync(callback, result); 987 return ERR_IO_PENDING; 988 } 989 return result; 990 } 991 992 DeterministicSocketHelper::DeterministicSocketHelper( 993 net::NetLog* net_log, 994 DeterministicSocketData* data) 995 : write_pending_(false), 996 write_result_(0), 997 read_data_(), 998 read_buf_(NULL), 999 read_buf_len_(0), 1000 read_pending_(false), 1001 data_(data), 1002 was_used_to_convey_data_(false), 1003 peer_closed_connection_(false), 1004 net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)) { 1005 } 1006 1007 DeterministicSocketHelper::~DeterministicSocketHelper() {} 1008 1009 void DeterministicSocketHelper::CompleteWrite() { 1010 was_used_to_convey_data_ = true; 1011 write_pending_ = false; 1012 write_callback_.Run(write_result_); 1013 } 1014 1015 int DeterministicSocketHelper::CompleteRead() { 1016 DCHECK_GT(read_buf_len_, 0); 1017 DCHECK_LE(read_data_.data_len, read_buf_len_); 1018 DCHECK(read_buf_); 1019 1020 was_used_to_convey_data_ = true; 1021 1022 if (read_data_.result == ERR_IO_PENDING) 1023 read_data_ = data_->GetNextRead(); 1024 DCHECK_NE(ERR_IO_PENDING, read_data_.result); 1025 // If read_data_.mode is ASYNC, we do not need to wait, since this is already 1026 // the callback. Therefore we don't even bother to check it. 1027 int result = read_data_.result; 1028 1029 if (read_data_.data_len > 0) { 1030 DCHECK(read_data_.data); 1031 result = std::min(read_buf_len_, read_data_.data_len); 1032 memcpy(read_buf_->data(), read_data_.data, result); 1033 } 1034 1035 if (read_pending_) { 1036 read_pending_ = false; 1037 read_callback_.Run(result); 1038 } 1039 1040 return result; 1041 } 1042 1043 int DeterministicSocketHelper::Write( 1044 IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 1045 DCHECK(buf); 1046 DCHECK_GT(buf_len, 0); 1047 1048 std::string data(buf->data(), buf_len); 1049 MockWriteResult write_result = data_->OnWrite(data); 1050 1051 if (write_result.mode == ASYNC) { 1052 write_callback_ = callback; 1053 write_result_ = write_result.result; 1054 DCHECK(!write_callback_.is_null()); 1055 write_pending_ = true; 1056 return ERR_IO_PENDING; 1057 } 1058 1059 was_used_to_convey_data_ = true; 1060 write_pending_ = false; 1061 return write_result.result; 1062 } 1063 1064 int DeterministicSocketHelper::Read( 1065 IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 1066 1067 read_data_ = data_->GetNextRead(); 1068 // The buffer should always be big enough to contain all the MockRead data. To 1069 // use small buffers, split the data into multiple MockReads. 1070 DCHECK_LE(read_data_.data_len, buf_len); 1071 1072 if (read_data_.result == ERR_CONNECTION_CLOSED) { 1073 // This MockRead is just a marker to instruct us to set 1074 // peer_closed_connection_. 1075 peer_closed_connection_ = true; 1076 } 1077 if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { 1078 // This MockRead is just a marker to instruct us to set 1079 // peer_closed_connection_. Skip it and get the next one. 1080 read_data_ = data_->GetNextRead(); 1081 peer_closed_connection_ = true; 1082 } 1083 1084 read_buf_ = buf; 1085 read_buf_len_ = buf_len; 1086 read_callback_ = callback; 1087 1088 if (read_data_.mode == ASYNC || (read_data_.result == ERR_IO_PENDING)) { 1089 read_pending_ = true; 1090 DCHECK(!read_callback_.is_null()); 1091 return ERR_IO_PENDING; 1092 } 1093 1094 was_used_to_convey_data_ = true; 1095 return CompleteRead(); 1096 } 1097 1098 DeterministicMockUDPClientSocket::DeterministicMockUDPClientSocket( 1099 net::NetLog* net_log, 1100 DeterministicSocketData* data) 1101 : connected_(false), 1102 helper_(net_log, data) { 1103 } 1104 1105 DeterministicMockUDPClientSocket::~DeterministicMockUDPClientSocket() {} 1106 1107 bool DeterministicMockUDPClientSocket::WritePending() const { 1108 return helper_.write_pending(); 1109 } 1110 1111 bool DeterministicMockUDPClientSocket::ReadPending() const { 1112 return helper_.read_pending(); 1113 } 1114 1115 void DeterministicMockUDPClientSocket::CompleteWrite() { 1116 helper_.CompleteWrite(); 1117 } 1118 1119 int DeterministicMockUDPClientSocket::CompleteRead() { 1120 return helper_.CompleteRead(); 1121 } 1122 1123 int DeterministicMockUDPClientSocket::Connect(const IPEndPoint& address) { 1124 if (connected_) 1125 return OK; 1126 connected_ = true; 1127 peer_address_ = address; 1128 return helper_.data()->connect_data().result; 1129 }; 1130 1131 int DeterministicMockUDPClientSocket::Write( 1132 IOBuffer* buf, 1133 int buf_len, 1134 const CompletionCallback& callback) { 1135 if (!connected_) 1136 return ERR_UNEXPECTED; 1137 1138 return helper_.Write(buf, buf_len, callback); 1139 } 1140 1141 int DeterministicMockUDPClientSocket::Read( 1142 IOBuffer* buf, 1143 int buf_len, 1144 const CompletionCallback& callback) { 1145 if (!connected_) 1146 return ERR_UNEXPECTED; 1147 1148 return helper_.Read(buf, buf_len, callback); 1149 } 1150 1151 bool DeterministicMockUDPClientSocket::SetReceiveBufferSize(int32 size) { 1152 return true; 1153 } 1154 1155 bool DeterministicMockUDPClientSocket::SetSendBufferSize(int32 size) { 1156 return true; 1157 } 1158 1159 void DeterministicMockUDPClientSocket::Close() { 1160 connected_ = false; 1161 } 1162 1163 int DeterministicMockUDPClientSocket::GetPeerAddress( 1164 IPEndPoint* address) const { 1165 *address = peer_address_; 1166 return OK; 1167 } 1168 1169 int DeterministicMockUDPClientSocket::GetLocalAddress( 1170 IPEndPoint* address) const { 1171 IPAddressNumber ip; 1172 bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip); 1173 CHECK(rv); 1174 *address = IPEndPoint(ip, 123); 1175 return OK; 1176 } 1177 1178 const BoundNetLog& DeterministicMockUDPClientSocket::NetLog() const { 1179 return helper_.net_log(); 1180 } 1181 1182 void DeterministicMockUDPClientSocket::OnReadComplete(const MockRead& data) {} 1183 1184 void DeterministicMockUDPClientSocket::OnConnectComplete( 1185 const MockConnect& data) { 1186 NOTIMPLEMENTED(); 1187 } 1188 1189 DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket( 1190 net::NetLog* net_log, 1191 DeterministicSocketData* data) 1192 : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), 1193 helper_(net_log, data) { 1194 peer_addr_ = data->connect_data().peer_addr; 1195 } 1196 1197 DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {} 1198 1199 bool DeterministicMockTCPClientSocket::WritePending() const { 1200 return helper_.write_pending(); 1201 } 1202 1203 bool DeterministicMockTCPClientSocket::ReadPending() const { 1204 return helper_.read_pending(); 1205 } 1206 1207 void DeterministicMockTCPClientSocket::CompleteWrite() { 1208 helper_.CompleteWrite(); 1209 } 1210 1211 int DeterministicMockTCPClientSocket::CompleteRead() { 1212 return helper_.CompleteRead(); 1213 } 1214 1215 int DeterministicMockTCPClientSocket::Write( 1216 IOBuffer* buf, 1217 int buf_len, 1218 const CompletionCallback& callback) { 1219 if (!connected_) 1220 return ERR_UNEXPECTED; 1221 1222 return helper_.Write(buf, buf_len, callback); 1223 } 1224 1225 int DeterministicMockTCPClientSocket::Read( 1226 IOBuffer* buf, 1227 int buf_len, 1228 const CompletionCallback& callback) { 1229 if (!connected_) 1230 return ERR_UNEXPECTED; 1231 1232 return helper_.Read(buf, buf_len, callback); 1233 } 1234 1235 // TODO(erikchen): Support connect sequencing. 1236 int DeterministicMockTCPClientSocket::Connect( 1237 const CompletionCallback& callback) { 1238 if (connected_) 1239 return OK; 1240 connected_ = true; 1241 if (helper_.data()->connect_data().mode == ASYNC) { 1242 RunCallbackAsync(callback, helper_.data()->connect_data().result); 1243 return ERR_IO_PENDING; 1244 } 1245 return helper_.data()->connect_data().result; 1246 } 1247 1248 void DeterministicMockTCPClientSocket::Disconnect() { 1249 MockClientSocket::Disconnect(); 1250 } 1251 1252 bool DeterministicMockTCPClientSocket::IsConnected() const { 1253 return connected_ && !helper_.peer_closed_connection(); 1254 } 1255 1256 bool DeterministicMockTCPClientSocket::IsConnectedAndIdle() const { 1257 return IsConnected(); 1258 } 1259 1260 bool DeterministicMockTCPClientSocket::WasEverUsed() const { 1261 return helper_.was_used_to_convey_data(); 1262 } 1263 1264 bool DeterministicMockTCPClientSocket::UsingTCPFastOpen() const { 1265 return false; 1266 } 1267 1268 bool DeterministicMockTCPClientSocket::WasNpnNegotiated() const { 1269 return false; 1270 } 1271 1272 bool DeterministicMockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { 1273 return false; 1274 } 1275 1276 void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} 1277 1278 void DeterministicMockTCPClientSocket::OnConnectComplete( 1279 const MockConnect& data) {} 1280 1281 // static 1282 void MockSSLClientSocket::ConnectCallback( 1283 MockSSLClientSocket* ssl_client_socket, 1284 const CompletionCallback& callback, 1285 int rv) { 1286 if (rv == OK) 1287 ssl_client_socket->connected_ = true; 1288 callback.Run(rv); 1289 } 1290 1291 MockSSLClientSocket::MockSSLClientSocket( 1292 scoped_ptr<ClientSocketHandle> transport_socket, 1293 const HostPortPair& host_port_pair, 1294 const SSLConfig& ssl_config, 1295 SSLSocketDataProvider* data) 1296 : MockClientSocket( 1297 // Have to use the right BoundNetLog for LoadTimingInfo regression 1298 // tests. 1299 transport_socket->socket()->NetLog()), 1300 transport_(transport_socket.Pass()), 1301 data_(data), 1302 is_npn_state_set_(false), 1303 new_npn_value_(false), 1304 is_protocol_negotiated_set_(false), 1305 protocol_negotiated_(kProtoUnknown) { 1306 DCHECK(data_); 1307 peer_addr_ = data->connect.peer_addr; 1308 } 1309 1310 MockSSLClientSocket::~MockSSLClientSocket() { 1311 Disconnect(); 1312 } 1313 1314 int MockSSLClientSocket::Read(IOBuffer* buf, int buf_len, 1315 const CompletionCallback& callback) { 1316 return transport_->socket()->Read(buf, buf_len, callback); 1317 } 1318 1319 int MockSSLClientSocket::Write(IOBuffer* buf, int buf_len, 1320 const CompletionCallback& callback) { 1321 return transport_->socket()->Write(buf, buf_len, callback); 1322 } 1323 1324 int MockSSLClientSocket::Connect(const CompletionCallback& callback) { 1325 int rv = transport_->socket()->Connect( 1326 base::Bind(&ConnectCallback, base::Unretained(this), callback)); 1327 if (rv == OK) { 1328 if (data_->connect.result == OK) 1329 connected_ = true; 1330 if (data_->connect.mode == ASYNC) { 1331 RunCallbackAsync(callback, data_->connect.result); 1332 return ERR_IO_PENDING; 1333 } 1334 return data_->connect.result; 1335 } 1336 return rv; 1337 } 1338 1339 void MockSSLClientSocket::Disconnect() { 1340 MockClientSocket::Disconnect(); 1341 if (transport_->socket() != NULL) 1342 transport_->socket()->Disconnect(); 1343 } 1344 1345 bool MockSSLClientSocket::IsConnected() const { 1346 return transport_->socket()->IsConnected(); 1347 } 1348 1349 bool MockSSLClientSocket::WasEverUsed() const { 1350 return transport_->socket()->WasEverUsed(); 1351 } 1352 1353 bool MockSSLClientSocket::UsingTCPFastOpen() const { 1354 return transport_->socket()->UsingTCPFastOpen(); 1355 } 1356 1357 int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const { 1358 return transport_->socket()->GetPeerAddress(address); 1359 } 1360 1361 bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { 1362 ssl_info->Reset(); 1363 ssl_info->cert = data_->cert; 1364 ssl_info->client_cert_sent = data_->client_cert_sent; 1365 ssl_info->channel_id_sent = data_->channel_id_sent; 1366 return true; 1367 } 1368 1369 void MockSSLClientSocket::GetSSLCertRequestInfo( 1370 SSLCertRequestInfo* cert_request_info) { 1371 DCHECK(cert_request_info); 1372 if (data_->cert_request_info) { 1373 cert_request_info->host_and_port = 1374 data_->cert_request_info->host_and_port; 1375 cert_request_info->client_certs = data_->cert_request_info->client_certs; 1376 } else { 1377 cert_request_info->Reset(); 1378 } 1379 } 1380 1381 SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto( 1382 std::string* proto, std::string* server_protos) { 1383 *proto = data_->next_proto; 1384 *server_protos = data_->server_protos; 1385 return data_->next_proto_status; 1386 } 1387 1388 bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) { 1389 is_npn_state_set_ = true; 1390 return new_npn_value_ = negotiated; 1391 } 1392 1393 bool MockSSLClientSocket::WasNpnNegotiated() const { 1394 if (is_npn_state_set_) 1395 return new_npn_value_; 1396 return data_->was_npn_negotiated; 1397 } 1398 1399 NextProto MockSSLClientSocket::GetNegotiatedProtocol() const { 1400 if (is_protocol_negotiated_set_) 1401 return protocol_negotiated_; 1402 return data_->protocol_negotiated; 1403 } 1404 1405 void MockSSLClientSocket::set_protocol_negotiated( 1406 NextProto protocol_negotiated) { 1407 is_protocol_negotiated_set_ = true; 1408 protocol_negotiated_ = protocol_negotiated; 1409 } 1410 1411 bool MockSSLClientSocket::WasChannelIDSent() const { 1412 return data_->channel_id_sent; 1413 } 1414 1415 void MockSSLClientSocket::set_channel_id_sent(bool channel_id_sent) { 1416 data_->channel_id_sent = channel_id_sent; 1417 } 1418 1419 ServerBoundCertService* MockSSLClientSocket::GetServerBoundCertService() const { 1420 return data_->server_bound_cert_service; 1421 } 1422 1423 void MockSSLClientSocket::OnReadComplete(const MockRead& data) { 1424 NOTIMPLEMENTED(); 1425 } 1426 1427 void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) { 1428 NOTIMPLEMENTED(); 1429 } 1430 1431 MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, 1432 net::NetLog* net_log) 1433 : connected_(false), 1434 data_(data), 1435 read_offset_(0), 1436 read_data_(SYNCHRONOUS, ERR_UNEXPECTED), 1437 need_read_data_(true), 1438 pending_buf_(NULL), 1439 pending_buf_len_(0), 1440 net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), 1441 weak_factory_(this) { 1442 DCHECK(data_); 1443 data_->Reset(); 1444 peer_addr_ = data->connect_data().peer_addr; 1445 } 1446 1447 MockUDPClientSocket::~MockUDPClientSocket() {} 1448 1449 int MockUDPClientSocket::Read(IOBuffer* buf, int buf_len, 1450 const CompletionCallback& callback) { 1451 if (!connected_) 1452 return ERR_UNEXPECTED; 1453 1454 // If the buffer is already in use, a read is already in progress! 1455 DCHECK(pending_buf_ == NULL); 1456 1457 // Store our async IO data. 1458 pending_buf_ = buf; 1459 pending_buf_len_ = buf_len; 1460 pending_callback_ = callback; 1461 1462 if (need_read_data_) { 1463 read_data_ = data_->GetNextRead(); 1464 // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility 1465 // to complete the async IO manually later (via OnReadComplete). 1466 if (read_data_.result == ERR_IO_PENDING) { 1467 // We need to be using async IO in this case. 1468 DCHECK(!callback.is_null()); 1469 return ERR_IO_PENDING; 1470 } 1471 need_read_data_ = false; 1472 } 1473 1474 return CompleteRead(); 1475 } 1476 1477 int MockUDPClientSocket::Write(IOBuffer* buf, int buf_len, 1478 const CompletionCallback& callback) { 1479 DCHECK(buf); 1480 DCHECK_GT(buf_len, 0); 1481 1482 if (!connected_) 1483 return ERR_UNEXPECTED; 1484 1485 std::string data(buf->data(), buf_len); 1486 MockWriteResult write_result = data_->OnWrite(data); 1487 1488 if (write_result.mode == ASYNC) { 1489 RunCallbackAsync(callback, write_result.result); 1490 return ERR_IO_PENDING; 1491 } 1492 return write_result.result; 1493 } 1494 1495 bool MockUDPClientSocket::SetReceiveBufferSize(int32 size) { 1496 return true; 1497 } 1498 1499 bool MockUDPClientSocket::SetSendBufferSize(int32 size) { 1500 return true; 1501 } 1502 1503 void MockUDPClientSocket::Close() { 1504 connected_ = false; 1505 } 1506 1507 int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const { 1508 *address = peer_addr_; 1509 return OK; 1510 } 1511 1512 int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const { 1513 IPAddressNumber ip; 1514 bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip); 1515 CHECK(rv); 1516 *address = IPEndPoint(ip, 123); 1517 return OK; 1518 } 1519 1520 const BoundNetLog& MockUDPClientSocket::NetLog() const { 1521 return net_log_; 1522 } 1523 1524 int MockUDPClientSocket::Connect(const IPEndPoint& address) { 1525 connected_ = true; 1526 peer_addr_ = address; 1527 return OK; 1528 } 1529 1530 void MockUDPClientSocket::OnReadComplete(const MockRead& data) { 1531 // There must be a read pending. 1532 DCHECK(pending_buf_); 1533 // You can't complete a read with another ERR_IO_PENDING status code. 1534 DCHECK_NE(ERR_IO_PENDING, data.result); 1535 // Since we've been waiting for data, need_read_data_ should be true. 1536 DCHECK(need_read_data_); 1537 1538 read_data_ = data; 1539 need_read_data_ = false; 1540 1541 // The caller is simulating that this IO completes right now. Don't 1542 // let CompleteRead() schedule a callback. 1543 read_data_.mode = SYNCHRONOUS; 1544 1545 net::CompletionCallback callback = pending_callback_; 1546 int rv = CompleteRead(); 1547 RunCallback(callback, rv); 1548 } 1549 1550 void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) { 1551 NOTIMPLEMENTED(); 1552 } 1553 1554 int MockUDPClientSocket::CompleteRead() { 1555 DCHECK(pending_buf_); 1556 DCHECK(pending_buf_len_ > 0); 1557 1558 // Save the pending async IO data and reset our |pending_| state. 1559 scoped_refptr<IOBuffer> buf = pending_buf_; 1560 int buf_len = pending_buf_len_; 1561 CompletionCallback callback = pending_callback_; 1562 pending_buf_ = NULL; 1563 pending_buf_len_ = 0; 1564 pending_callback_.Reset(); 1565 1566 int result = read_data_.result; 1567 DCHECK(result != ERR_IO_PENDING); 1568 1569 if (read_data_.data) { 1570 if (read_data_.data_len - read_offset_ > 0) { 1571 result = std::min(buf_len, read_data_.data_len - read_offset_); 1572 memcpy(buf->data(), read_data_.data + read_offset_, result); 1573 read_offset_ += result; 1574 if (read_offset_ == read_data_.data_len) { 1575 need_read_data_ = true; 1576 read_offset_ = 0; 1577 } 1578 } else { 1579 result = 0; // EOF 1580 } 1581 } 1582 1583 if (read_data_.mode == ASYNC) { 1584 DCHECK(!callback.is_null()); 1585 RunCallbackAsync(callback, result); 1586 return ERR_IO_PENDING; 1587 } 1588 return result; 1589 } 1590 1591 void MockUDPClientSocket::RunCallbackAsync(const CompletionCallback& callback, 1592 int result) { 1593 base::MessageLoop::current()->PostTask( 1594 FROM_HERE, 1595 base::Bind(&MockUDPClientSocket::RunCallback, 1596 weak_factory_.GetWeakPtr(), 1597 callback, 1598 result)); 1599 } 1600 1601 void MockUDPClientSocket::RunCallback(const CompletionCallback& callback, 1602 int result) { 1603 if (!callback.is_null()) 1604 callback.Run(result); 1605 } 1606 1607 TestSocketRequest::TestSocketRequest( 1608 std::vector<TestSocketRequest*>* request_order, size_t* completion_count) 1609 : request_order_(request_order), 1610 completion_count_(completion_count), 1611 callback_(base::Bind(&TestSocketRequest::OnComplete, 1612 base::Unretained(this))) { 1613 DCHECK(request_order); 1614 DCHECK(completion_count); 1615 } 1616 1617 TestSocketRequest::~TestSocketRequest() { 1618 } 1619 1620 void TestSocketRequest::OnComplete(int result) { 1621 SetResult(result); 1622 (*completion_count_)++; 1623 request_order_->push_back(this); 1624 } 1625 1626 // static 1627 const int ClientSocketPoolTest::kIndexOutOfBounds = -1; 1628 1629 // static 1630 const int ClientSocketPoolTest::kRequestNotFound = -2; 1631 1632 ClientSocketPoolTest::ClientSocketPoolTest() : completion_count_(0) {} 1633 ClientSocketPoolTest::~ClientSocketPoolTest() {} 1634 1635 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const { 1636 index--; 1637 if (index >= requests_.size()) 1638 return kIndexOutOfBounds; 1639 1640 for (size_t i = 0; i < request_order_.size(); i++) 1641 if (requests_[index] == request_order_[i]) 1642 return i + 1; 1643 1644 return kRequestNotFound; 1645 } 1646 1647 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) { 1648 ScopedVector<TestSocketRequest>::iterator i; 1649 for (i = requests_.begin(); i != requests_.end(); ++i) { 1650 if ((*i)->handle()->is_initialized()) { 1651 if (keep_alive == NO_KEEP_ALIVE) 1652 (*i)->handle()->socket()->Disconnect(); 1653 (*i)->handle()->Reset(); 1654 base::RunLoop().RunUntilIdle(); 1655 return true; 1656 } 1657 } 1658 return false; 1659 } 1660 1661 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) { 1662 bool released_one; 1663 do { 1664 released_one = ReleaseOneConnection(keep_alive); 1665 } while (released_one); 1666 } 1667 1668 MockTransportClientSocketPool::MockConnectJob::MockConnectJob( 1669 scoped_ptr<StreamSocket> socket, 1670 ClientSocketHandle* handle, 1671 const CompletionCallback& callback) 1672 : socket_(socket.Pass()), 1673 handle_(handle), 1674 user_callback_(callback) { 1675 } 1676 1677 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() {} 1678 1679 int MockTransportClientSocketPool::MockConnectJob::Connect() { 1680 int rv = socket_->Connect(base::Bind(&MockConnectJob::OnConnect, 1681 base::Unretained(this))); 1682 if (rv == OK) { 1683 user_callback_.Reset(); 1684 OnConnect(OK); 1685 } 1686 return rv; 1687 } 1688 1689 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle( 1690 const ClientSocketHandle* handle) { 1691 if (handle != handle_) 1692 return false; 1693 socket_.reset(); 1694 handle_ = NULL; 1695 user_callback_.Reset(); 1696 return true; 1697 } 1698 1699 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) { 1700 if (!socket_.get()) 1701 return; 1702 if (rv == OK) { 1703 handle_->SetSocket(socket_.Pass()); 1704 1705 // Needed for socket pool tests that layer other sockets on top of mock 1706 // sockets. 1707 LoadTimingInfo::ConnectTiming connect_timing; 1708 base::TimeTicks now = base::TimeTicks::Now(); 1709 connect_timing.dns_start = now; 1710 connect_timing.dns_end = now; 1711 connect_timing.connect_start = now; 1712 connect_timing.connect_end = now; 1713 handle_->set_connect_timing(connect_timing); 1714 } else { 1715 socket_.reset(); 1716 } 1717 1718 handle_ = NULL; 1719 1720 if (!user_callback_.is_null()) { 1721 CompletionCallback callback = user_callback_; 1722 user_callback_.Reset(); 1723 callback.Run(rv); 1724 } 1725 } 1726 1727 MockTransportClientSocketPool::MockTransportClientSocketPool( 1728 int max_sockets, 1729 int max_sockets_per_group, 1730 ClientSocketPoolHistograms* histograms, 1731 ClientSocketFactory* socket_factory) 1732 : TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms, 1733 NULL, NULL, NULL), 1734 client_socket_factory_(socket_factory), 1735 last_request_priority_(DEFAULT_PRIORITY), 1736 release_count_(0), 1737 cancel_count_(0) { 1738 } 1739 1740 MockTransportClientSocketPool::~MockTransportClientSocketPool() {} 1741 1742 int MockTransportClientSocketPool::RequestSocket( 1743 const std::string& group_name, const void* socket_params, 1744 RequestPriority priority, ClientSocketHandle* handle, 1745 const CompletionCallback& callback, const BoundNetLog& net_log) { 1746 last_request_priority_ = priority; 1747 scoped_ptr<StreamSocket> socket = 1748 client_socket_factory_->CreateTransportClientSocket( 1749 AddressList(), net_log.net_log(), net::NetLog::Source()); 1750 MockConnectJob* job = new MockConnectJob(socket.Pass(), handle, callback); 1751 job_list_.push_back(job); 1752 handle->set_pool_id(1); 1753 return job->Connect(); 1754 } 1755 1756 void MockTransportClientSocketPool::CancelRequest(const std::string& group_name, 1757 ClientSocketHandle* handle) { 1758 std::vector<MockConnectJob*>::iterator i; 1759 for (i = job_list_.begin(); i != job_list_.end(); ++i) { 1760 if ((*i)->CancelHandle(handle)) { 1761 cancel_count_++; 1762 break; 1763 } 1764 } 1765 } 1766 1767 void MockTransportClientSocketPool::ReleaseSocket( 1768 const std::string& group_name, 1769 scoped_ptr<StreamSocket> socket, 1770 int id) { 1771 EXPECT_EQ(1, id); 1772 release_count_++; 1773 } 1774 1775 DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {} 1776 1777 DeterministicMockClientSocketFactory::~DeterministicMockClientSocketFactory() {} 1778 1779 void DeterministicMockClientSocketFactory::AddSocketDataProvider( 1780 DeterministicSocketData* data) { 1781 mock_data_.Add(data); 1782 } 1783 1784 void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider( 1785 SSLSocketDataProvider* data) { 1786 mock_ssl_data_.Add(data); 1787 } 1788 1789 void DeterministicMockClientSocketFactory::ResetNextMockIndexes() { 1790 mock_data_.ResetNextIndex(); 1791 mock_ssl_data_.ResetNextIndex(); 1792 } 1793 1794 MockSSLClientSocket* DeterministicMockClientSocketFactory:: 1795 GetMockSSLClientSocket(size_t index) const { 1796 DCHECK_LT(index, ssl_client_sockets_.size()); 1797 return ssl_client_sockets_[index]; 1798 } 1799 1800 scoped_ptr<DatagramClientSocket> 1801 DeterministicMockClientSocketFactory::CreateDatagramClientSocket( 1802 DatagramSocket::BindType bind_type, 1803 const RandIntCallback& rand_int_cb, 1804 net::NetLog* net_log, 1805 const NetLog::Source& source) { 1806 DeterministicSocketData* data_provider = mock_data().GetNext(); 1807 scoped_ptr<DeterministicMockUDPClientSocket> socket( 1808 new DeterministicMockUDPClientSocket(net_log, data_provider)); 1809 data_provider->set_delegate(socket->AsWeakPtr()); 1810 udp_client_sockets().push_back(socket.get()); 1811 return socket.PassAs<DatagramClientSocket>(); 1812 } 1813 1814 scoped_ptr<StreamSocket> 1815 DeterministicMockClientSocketFactory::CreateTransportClientSocket( 1816 const AddressList& addresses, 1817 net::NetLog* net_log, 1818 const net::NetLog::Source& source) { 1819 DeterministicSocketData* data_provider = mock_data().GetNext(); 1820 scoped_ptr<DeterministicMockTCPClientSocket> socket( 1821 new DeterministicMockTCPClientSocket(net_log, data_provider)); 1822 data_provider->set_delegate(socket->AsWeakPtr()); 1823 tcp_client_sockets().push_back(socket.get()); 1824 return socket.PassAs<StreamSocket>(); 1825 } 1826 1827 scoped_ptr<SSLClientSocket> 1828 DeterministicMockClientSocketFactory::CreateSSLClientSocket( 1829 scoped_ptr<ClientSocketHandle> transport_socket, 1830 const HostPortPair& host_and_port, 1831 const SSLConfig& ssl_config, 1832 const SSLClientSocketContext& context) { 1833 scoped_ptr<MockSSLClientSocket> socket( 1834 new MockSSLClientSocket(transport_socket.Pass(), 1835 host_and_port, ssl_config, 1836 mock_ssl_data_.GetNext())); 1837 ssl_client_sockets_.push_back(socket.get()); 1838 return socket.PassAs<SSLClientSocket>(); 1839 } 1840 1841 void DeterministicMockClientSocketFactory::ClearSSLSessionCache() { 1842 } 1843 1844 MockSOCKSClientSocketPool::MockSOCKSClientSocketPool( 1845 int max_sockets, 1846 int max_sockets_per_group, 1847 ClientSocketPoolHistograms* histograms, 1848 TransportClientSocketPool* transport_pool) 1849 : SOCKSClientSocketPool(max_sockets, max_sockets_per_group, histograms, 1850 NULL, transport_pool, NULL), 1851 transport_pool_(transport_pool) { 1852 } 1853 1854 MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {} 1855 1856 int MockSOCKSClientSocketPool::RequestSocket( 1857 const std::string& group_name, const void* socket_params, 1858 RequestPriority priority, ClientSocketHandle* handle, 1859 const CompletionCallback& callback, const BoundNetLog& net_log) { 1860 return transport_pool_->RequestSocket( 1861 group_name, socket_params, priority, handle, callback, net_log); 1862 } 1863 1864 void MockSOCKSClientSocketPool::CancelRequest( 1865 const std::string& group_name, 1866 ClientSocketHandle* handle) { 1867 return transport_pool_->CancelRequest(group_name, handle); 1868 } 1869 1870 void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, 1871 scoped_ptr<StreamSocket> socket, 1872 int id) { 1873 return transport_pool_->ReleaseSocket(group_name, socket.Pass(), id); 1874 } 1875 1876 const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; 1877 const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest); 1878 1879 const char kSOCKS5GreetResponse[] = { 0x05, 0x00 }; 1880 const int kSOCKS5GreetResponseLength = arraysize(kSOCKS5GreetResponse); 1881 1882 const char kSOCKS5OkRequest[] = 1883 { 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 }; 1884 const int kSOCKS5OkRequestLength = arraysize(kSOCKS5OkRequest); 1885 1886 const char kSOCKS5OkResponse[] = 1887 { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 }; 1888 const int kSOCKS5OkResponseLength = arraysize(kSOCKS5OkResponse); 1889 1890 } // namespace net 1891