1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socket_test_util.h" 6 7 #include <algorithm> 8 #include <vector> 9 10 #include "base/basictypes.h" 11 #include "base/bind.h" 12 #include "base/bind_helpers.h" 13 #include "base/compiler_specific.h" 14 #include "base/message_loop/message_loop.h" 15 #include "base/run_loop.h" 16 #include "base/time/time.h" 17 #include "net/base/address_family.h" 18 #include "net/base/address_list.h" 19 #include "net/base/auth.h" 20 #include "net/base/load_timing_info.h" 21 #include "net/http/http_network_session.h" 22 #include "net/http/http_request_headers.h" 23 #include "net/http/http_response_headers.h" 24 #include "net/socket/client_socket_pool_histograms.h" 25 #include "net/socket/socket.h" 26 #include "net/ssl/ssl_cert_request_info.h" 27 #include "net/ssl/ssl_info.h" 28 #include "testing/gtest/include/gtest/gtest.h" 29 30 // Socket events are easier to debug if you log individual reads and writes. 31 // Enable these if locally debugging, but they are too noisy for the waterfall. 32 #if 0 33 #define NET_TRACE(level, s) DLOG(level) << s << __FUNCTION__ << "() " 34 #else 35 #define NET_TRACE(level, s) EAT_STREAM_PARAMETERS 36 #endif 37 38 namespace net { 39 40 namespace { 41 42 inline char AsciifyHigh(char x) { 43 char nybble = static_cast<char>((x >> 4) & 0x0F); 44 return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10); 45 } 46 47 inline char AsciifyLow(char x) { 48 char nybble = static_cast<char>((x >> 0) & 0x0F); 49 return nybble + ((nybble < 0x0A) ? '0' : 'A' - 10); 50 } 51 52 inline char Asciify(char x) { 53 if ((x < 0) || !isprint(x)) 54 return '.'; 55 return x; 56 } 57 58 void DumpData(const char* data, int data_len) { 59 if (logging::LOG_INFO < logging::GetMinLogLevel()) 60 return; 61 DVLOG(1) << "Length: " << data_len; 62 const char* pfx = "Data: "; 63 if (!data || (data_len <= 0)) { 64 DVLOG(1) << pfx << "<None>"; 65 } else { 66 int i; 67 for (i = 0; i <= (data_len - 4); i += 4) { 68 DVLOG(1) << pfx 69 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) 70 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1]) 71 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2]) 72 << AsciifyHigh(data[i + 3]) << AsciifyLow(data[i + 3]) 73 << " '" 74 << Asciify(data[i + 0]) 75 << Asciify(data[i + 1]) 76 << Asciify(data[i + 2]) 77 << Asciify(data[i + 3]) 78 << "'"; 79 pfx = " "; 80 } 81 // Take care of any 'trailing' bytes, if data_len was not a multiple of 4. 82 switch (data_len - i) { 83 case 3: 84 DVLOG(1) << pfx 85 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) 86 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1]) 87 << AsciifyHigh(data[i + 2]) << AsciifyLow(data[i + 2]) 88 << " '" 89 << Asciify(data[i + 0]) 90 << Asciify(data[i + 1]) 91 << Asciify(data[i + 2]) 92 << " '"; 93 break; 94 case 2: 95 DVLOG(1) << pfx 96 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) 97 << AsciifyHigh(data[i + 1]) << AsciifyLow(data[i + 1]) 98 << " '" 99 << Asciify(data[i + 0]) 100 << Asciify(data[i + 1]) 101 << " '"; 102 break; 103 case 1: 104 DVLOG(1) << pfx 105 << AsciifyHigh(data[i + 0]) << AsciifyLow(data[i + 0]) 106 << " '" 107 << Asciify(data[i + 0]) 108 << " '"; 109 break; 110 } 111 } 112 } 113 114 template <MockReadWriteType type> 115 void DumpMockReadWrite(const MockReadWrite<type>& r) { 116 if (logging::LOG_INFO < logging::GetMinLogLevel()) 117 return; 118 DVLOG(1) << "Async: " << (r.mode == ASYNC) 119 << "\nResult: " << r.result; 120 DumpData(r.data, r.data_len); 121 const char* stop = (r.sequence_number & MockRead::STOPLOOP) ? " (STOP)" : ""; 122 DVLOG(1) << "Stage: " << (r.sequence_number & ~MockRead::STOPLOOP) << stop 123 << "\nTime: " << r.time_stamp.ToInternalValue(); 124 } 125 126 } // namespace 127 128 MockConnect::MockConnect() : mode(ASYNC), result(OK) { 129 IPAddressNumber ip; 130 CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); 131 peer_addr = IPEndPoint(ip, 0); 132 } 133 134 MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) { 135 IPAddressNumber ip; 136 CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); 137 peer_addr = IPEndPoint(ip, 0); 138 } 139 140 MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr) : 141 mode(io_mode), 142 result(r), 143 peer_addr(addr) { 144 } 145 146 MockConnect::~MockConnect() {} 147 148 StaticSocketDataProvider::StaticSocketDataProvider() 149 : reads_(NULL), 150 read_index_(0), 151 read_count_(0), 152 writes_(NULL), 153 write_index_(0), 154 write_count_(0) { 155 } 156 157 StaticSocketDataProvider::StaticSocketDataProvider(MockRead* reads, 158 size_t reads_count, 159 MockWrite* writes, 160 size_t writes_count) 161 : reads_(reads), 162 read_index_(0), 163 read_count_(reads_count), 164 writes_(writes), 165 write_index_(0), 166 write_count_(writes_count) { 167 } 168 169 StaticSocketDataProvider::~StaticSocketDataProvider() {} 170 171 const MockRead& StaticSocketDataProvider::PeekRead() const { 172 DCHECK(!at_read_eof()); 173 return reads_[read_index_]; 174 } 175 176 const MockWrite& StaticSocketDataProvider::PeekWrite() const { 177 DCHECK(!at_write_eof()); 178 return writes_[write_index_]; 179 } 180 181 const MockRead& StaticSocketDataProvider::PeekRead(size_t index) const { 182 DCHECK_LT(index, read_count_); 183 return reads_[index]; 184 } 185 186 const MockWrite& StaticSocketDataProvider::PeekWrite(size_t index) const { 187 DCHECK_LT(index, write_count_); 188 return writes_[index]; 189 } 190 191 MockRead StaticSocketDataProvider::GetNextRead() { 192 DCHECK(!at_read_eof()); 193 reads_[read_index_].time_stamp = base::Time::Now(); 194 return reads_[read_index_++]; 195 } 196 197 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) { 198 if (!writes_) { 199 // Not using mock writes; succeed synchronously. 200 return MockWriteResult(SYNCHRONOUS, data.length()); 201 } 202 DCHECK(!at_write_eof()); 203 204 // Check that what we are writing matches the expectation. 205 // Then give the mocked return value. 206 MockWrite* w = &writes_[write_index_++]; 207 w->time_stamp = base::Time::Now(); 208 int result = w->result; 209 if (w->data) { 210 // Note - we can simulate a partial write here. If the expected data 211 // is a match, but shorter than the write actually written, that is legal. 212 // Example: 213 // Application writes "foobarbaz" (9 bytes) 214 // Expected write was "foo" (3 bytes) 215 // This is a success, and we return 3 to the application. 216 std::string expected_data(w->data, w->data_len); 217 EXPECT_GE(data.length(), expected_data.length()); 218 std::string actual_data(data.substr(0, w->data_len)); 219 EXPECT_EQ(expected_data, actual_data); 220 if (expected_data != actual_data) 221 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); 222 if (result == OK) 223 result = w->data_len; 224 } 225 return MockWriteResult(w->mode, result); 226 } 227 228 void StaticSocketDataProvider::Reset() { 229 read_index_ = 0; 230 write_index_ = 0; 231 } 232 233 DynamicSocketDataProvider::DynamicSocketDataProvider() 234 : short_read_limit_(0), 235 allow_unconsumed_reads_(false) { 236 } 237 238 DynamicSocketDataProvider::~DynamicSocketDataProvider() {} 239 240 MockRead DynamicSocketDataProvider::GetNextRead() { 241 if (reads_.empty()) 242 return MockRead(SYNCHRONOUS, ERR_UNEXPECTED); 243 MockRead result = reads_.front(); 244 if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) { 245 reads_.pop_front(); 246 } else { 247 result.data_len = short_read_limit_; 248 reads_.front().data += result.data_len; 249 reads_.front().data_len -= result.data_len; 250 } 251 return result; 252 } 253 254 void DynamicSocketDataProvider::Reset() { 255 reads_.clear(); 256 } 257 258 void DynamicSocketDataProvider::SimulateRead(const char* data, 259 const size_t length) { 260 if (!allow_unconsumed_reads_) { 261 EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data; 262 } 263 reads_.push_back(MockRead(ASYNC, data, length)); 264 } 265 266 SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result) 267 : connect(mode, result), 268 next_proto_status(SSLClientSocket::kNextProtoUnsupported), 269 was_npn_negotiated(false), 270 protocol_negotiated(kProtoUnknown), 271 client_cert_sent(false), 272 cert_request_info(NULL), 273 channel_id_sent(false) { 274 } 275 276 SSLSocketDataProvider::~SSLSocketDataProvider() { 277 } 278 279 void SSLSocketDataProvider::SetNextProto(NextProto proto) { 280 was_npn_negotiated = true; 281 next_proto_status = SSLClientSocket::kNextProtoNegotiated; 282 protocol_negotiated = proto; 283 next_proto = SSLClientSocket::NextProtoToString(proto); 284 } 285 286 DelayedSocketData::DelayedSocketData( 287 int write_delay, MockRead* reads, size_t reads_count, 288 MockWrite* writes, size_t writes_count) 289 : StaticSocketDataProvider(reads, reads_count, writes, writes_count), 290 write_delay_(write_delay), 291 read_in_progress_(false), 292 weak_factory_(this) { 293 DCHECK_GE(write_delay_, 0); 294 } 295 296 DelayedSocketData::DelayedSocketData( 297 const MockConnect& connect, int write_delay, MockRead* reads, 298 size_t reads_count, MockWrite* writes, size_t writes_count) 299 : StaticSocketDataProvider(reads, reads_count, writes, writes_count), 300 write_delay_(write_delay), 301 read_in_progress_(false), 302 weak_factory_(this) { 303 DCHECK_GE(write_delay_, 0); 304 set_connect_data(connect); 305 } 306 307 DelayedSocketData::~DelayedSocketData() { 308 } 309 310 void DelayedSocketData::ForceNextRead() { 311 DCHECK(read_in_progress_); 312 write_delay_ = 0; 313 CompleteRead(); 314 } 315 316 MockRead DelayedSocketData::GetNextRead() { 317 MockRead out = MockRead(ASYNC, ERR_IO_PENDING); 318 if (write_delay_ <= 0) 319 out = StaticSocketDataProvider::GetNextRead(); 320 read_in_progress_ = (out.result == ERR_IO_PENDING); 321 return out; 322 } 323 324 MockWriteResult DelayedSocketData::OnWrite(const std::string& data) { 325 MockWriteResult rv = StaticSocketDataProvider::OnWrite(data); 326 // Now that our write has completed, we can allow reads to continue. 327 if (!--write_delay_ && read_in_progress_) 328 base::MessageLoop::current()->PostDelayedTask( 329 FROM_HERE, 330 base::Bind(&DelayedSocketData::CompleteRead, 331 weak_factory_.GetWeakPtr()), 332 base::TimeDelta::FromMilliseconds(100)); 333 return rv; 334 } 335 336 void DelayedSocketData::Reset() { 337 set_socket(NULL); 338 read_in_progress_ = false; 339 weak_factory_.InvalidateWeakPtrs(); 340 StaticSocketDataProvider::Reset(); 341 } 342 343 void DelayedSocketData::CompleteRead() { 344 if (socket() && read_in_progress_) 345 socket()->OnReadComplete(GetNextRead()); 346 } 347 348 OrderedSocketData::OrderedSocketData( 349 MockRead* reads, size_t reads_count, MockWrite* writes, size_t writes_count) 350 : StaticSocketDataProvider(reads, reads_count, writes, writes_count), 351 sequence_number_(0), loop_stop_stage_(0), 352 blocked_(false), weak_factory_(this) { 353 } 354 355 OrderedSocketData::OrderedSocketData( 356 const MockConnect& connect, 357 MockRead* reads, size_t reads_count, 358 MockWrite* writes, size_t writes_count) 359 : StaticSocketDataProvider(reads, reads_count, writes, writes_count), 360 sequence_number_(0), loop_stop_stage_(0), 361 blocked_(false), weak_factory_(this) { 362 set_connect_data(connect); 363 } 364 365 void OrderedSocketData::EndLoop() { 366 // If we've already stopped the loop, don't do it again until we've advanced 367 // to the next sequence_number. 368 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ << ": EndLoop()"; 369 if (loop_stop_stage_ > 0) { 370 const MockRead& next_read = StaticSocketDataProvider::PeekRead(); 371 if ((next_read.sequence_number & ~MockRead::STOPLOOP) > 372 loop_stop_stage_) { 373 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 374 << ": Clearing stop index"; 375 loop_stop_stage_ = 0; 376 } else { 377 return; 378 } 379 } 380 // Record the sequence_number at which we stopped the loop. 381 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 382 << ": Posting Quit at read " << read_index(); 383 loop_stop_stage_ = sequence_number_; 384 } 385 386 MockRead OrderedSocketData::GetNextRead() { 387 weak_factory_.InvalidateWeakPtrs(); 388 blocked_ = false; 389 const MockRead& next_read = StaticSocketDataProvider::PeekRead(); 390 if (next_read.sequence_number & MockRead::STOPLOOP) 391 EndLoop(); 392 if ((next_read.sequence_number & ~MockRead::STOPLOOP) <= 393 sequence_number_++) { 394 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1 395 << ": Read " << read_index(); 396 DumpMockReadWrite(next_read); 397 blocked_ = (next_read.result == ERR_IO_PENDING); 398 return StaticSocketDataProvider::GetNextRead(); 399 } 400 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ - 1 401 << ": I/O Pending"; 402 MockRead result = MockRead(ASYNC, ERR_IO_PENDING); 403 DumpMockReadWrite(result); 404 blocked_ = true; 405 return result; 406 } 407 408 MockWriteResult OrderedSocketData::OnWrite(const std::string& data) { 409 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 410 << ": Write " << write_index(); 411 DumpMockReadWrite(PeekWrite()); 412 ++sequence_number_; 413 if (blocked_) { 414 // TODO(willchan): This 100ms delay seems to work around some weirdness. We 415 // should probably fix the weirdness. One example is in SpdyStream, 416 // DoSendRequest() will return ERR_IO_PENDING, and there's a race. If the 417 // SYN_REPLY causes OnResponseReceived() to get called before 418 // SpdyStream::ReadResponseHeaders() is called, we hit a NOTREACHED(). 419 base::MessageLoop::current()->PostDelayedTask( 420 FROM_HERE, 421 base::Bind(&OrderedSocketData::CompleteRead, 422 weak_factory_.GetWeakPtr()), 423 base::TimeDelta::FromMilliseconds(100)); 424 } 425 return StaticSocketDataProvider::OnWrite(data); 426 } 427 428 void OrderedSocketData::Reset() { 429 NET_TRACE(INFO, " *** ") << "Stage " 430 << sequence_number_ << ": Reset()"; 431 sequence_number_ = 0; 432 loop_stop_stage_ = 0; 433 set_socket(NULL); 434 weak_factory_.InvalidateWeakPtrs(); 435 StaticSocketDataProvider::Reset(); 436 } 437 438 void OrderedSocketData::CompleteRead() { 439 if (socket() && blocked_) { 440 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_; 441 socket()->OnReadComplete(GetNextRead()); 442 } 443 } 444 445 OrderedSocketData::~OrderedSocketData() {} 446 447 DeterministicSocketData::DeterministicSocketData(MockRead* reads, 448 size_t reads_count, MockWrite* writes, size_t writes_count) 449 : StaticSocketDataProvider(reads, reads_count, writes, writes_count), 450 sequence_number_(0), 451 current_read_(), 452 current_write_(), 453 stopping_sequence_number_(0), 454 stopped_(false), 455 print_debug_(false), 456 is_running_(false) { 457 VerifyCorrectSequenceNumbers(reads, reads_count, writes, writes_count); 458 } 459 460 DeterministicSocketData::~DeterministicSocketData() {} 461 462 void DeterministicSocketData::Run() { 463 DCHECK(!is_running_); 464 is_running_ = true; 465 466 SetStopped(false); 467 int counter = 0; 468 // Continue to consume data until all data has run out, or the stopped_ flag 469 // has been set. Consuming data requires two separate operations -- running 470 // the tasks in the message loop, and explicitly invoking the read/write 471 // callbacks (simulating network I/O). We check our conditions between each, 472 // since they can change in either. 473 while ((!at_write_eof() || !at_read_eof()) && !stopped()) { 474 if (counter % 2 == 0) 475 base::RunLoop().RunUntilIdle(); 476 if (counter % 2 == 1) { 477 InvokeCallbacks(); 478 } 479 counter++; 480 } 481 // We're done consuming new data, but it is possible there are still some 482 // pending callbacks which we expect to complete before returning. 483 while (delegate_.get() && 484 (delegate_->WritePending() || delegate_->ReadPending()) && 485 !stopped()) { 486 InvokeCallbacks(); 487 base::RunLoop().RunUntilIdle(); 488 } 489 SetStopped(false); 490 is_running_ = false; 491 } 492 493 void DeterministicSocketData::RunFor(int steps) { 494 StopAfter(steps); 495 Run(); 496 } 497 498 void DeterministicSocketData::SetStop(int seq) { 499 DCHECK_LT(sequence_number_, seq); 500 stopping_sequence_number_ = seq; 501 stopped_ = false; 502 } 503 504 void DeterministicSocketData::StopAfter(int seq) { 505 SetStop(sequence_number_ + seq); 506 } 507 508 MockRead DeterministicSocketData::GetNextRead() { 509 current_read_ = StaticSocketDataProvider::PeekRead(); 510 511 // Synchronous read while stopped is an error 512 if (stopped() && current_read_.mode == SYNCHRONOUS) { 513 LOG(ERROR) << "Unable to perform synchronous IO while stopped"; 514 return MockRead(SYNCHRONOUS, ERR_UNEXPECTED); 515 } 516 517 // Async read which will be called back in a future step. 518 if (sequence_number_ < current_read_.sequence_number) { 519 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 520 << ": I/O Pending"; 521 MockRead result = MockRead(SYNCHRONOUS, ERR_IO_PENDING); 522 if (current_read_.mode == SYNCHRONOUS) { 523 LOG(ERROR) << "Unable to perform synchronous read: " 524 << current_read_.sequence_number 525 << " at stage: " << sequence_number_; 526 result = MockRead(SYNCHRONOUS, ERR_UNEXPECTED); 527 } 528 if (print_debug_) 529 DumpMockReadWrite(result); 530 return result; 531 } 532 533 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 534 << ": Read " << read_index(); 535 if (print_debug_) 536 DumpMockReadWrite(current_read_); 537 538 // Increment the sequence number if IO is complete 539 if (current_read_.mode == SYNCHRONOUS) 540 NextStep(); 541 542 DCHECK_NE(ERR_IO_PENDING, current_read_.result); 543 StaticSocketDataProvider::GetNextRead(); 544 545 return current_read_; 546 } 547 548 MockWriteResult DeterministicSocketData::OnWrite(const std::string& data) { 549 const MockWrite& next_write = StaticSocketDataProvider::PeekWrite(); 550 current_write_ = next_write; 551 552 // Synchronous write while stopped is an error 553 if (stopped() && next_write.mode == SYNCHRONOUS) { 554 LOG(ERROR) << "Unable to perform synchronous IO while stopped"; 555 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); 556 } 557 558 // Async write which will be called back in a future step. 559 if (sequence_number_ < next_write.sequence_number) { 560 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 561 << ": I/O Pending"; 562 if (next_write.mode == SYNCHRONOUS) { 563 LOG(ERROR) << "Unable to perform synchronous write: " 564 << next_write.sequence_number << " at stage: " << sequence_number_; 565 return MockWriteResult(SYNCHRONOUS, ERR_UNEXPECTED); 566 } 567 } else { 568 NET_TRACE(INFO, " *** ") << "Stage " << sequence_number_ 569 << ": Write " << write_index(); 570 } 571 572 if (print_debug_) 573 DumpMockReadWrite(next_write); 574 575 // Move to the next step if I/O is synchronous, since the operation will 576 // complete when this method returns. 577 if (next_write.mode == SYNCHRONOUS) 578 NextStep(); 579 580 // This is either a sync write for this step, or an async write. 581 return StaticSocketDataProvider::OnWrite(data); 582 } 583 584 void DeterministicSocketData::Reset() { 585 NET_TRACE(INFO, " *** ") << "Stage " 586 << sequence_number_ << ": Reset()"; 587 sequence_number_ = 0; 588 StaticSocketDataProvider::Reset(); 589 NOTREACHED(); 590 } 591 592 void DeterministicSocketData::InvokeCallbacks() { 593 if (delegate_.get() && delegate_->WritePending() && 594 (current_write().sequence_number == sequence_number())) { 595 NextStep(); 596 delegate_->CompleteWrite(); 597 return; 598 } 599 if (delegate_.get() && delegate_->ReadPending() && 600 (current_read().sequence_number == sequence_number())) { 601 NextStep(); 602 delegate_->CompleteRead(); 603 return; 604 } 605 } 606 607 void DeterministicSocketData::NextStep() { 608 // Invariant: Can never move *past* the stopping step. 609 DCHECK_LT(sequence_number_, stopping_sequence_number_); 610 sequence_number_++; 611 if (sequence_number_ == stopping_sequence_number_) 612 SetStopped(true); 613 } 614 615 void DeterministicSocketData::VerifyCorrectSequenceNumbers( 616 MockRead* reads, size_t reads_count, 617 MockWrite* writes, size_t writes_count) { 618 size_t read = 0; 619 size_t write = 0; 620 int expected = 0; 621 while (read < reads_count || write < writes_count) { 622 // Check to see that we have a read or write at the expected 623 // state. 624 if (read < reads_count && reads[read].sequence_number == expected) { 625 ++read; 626 ++expected; 627 continue; 628 } 629 if (write < writes_count && writes[write].sequence_number == expected) { 630 ++write; 631 ++expected; 632 continue; 633 } 634 NOTREACHED() << "Missing sequence number: " << expected; 635 return; 636 } 637 DCHECK_EQ(read, reads_count); 638 DCHECK_EQ(write, writes_count); 639 } 640 641 MockClientSocketFactory::MockClientSocketFactory() {} 642 643 MockClientSocketFactory::~MockClientSocketFactory() {} 644 645 void MockClientSocketFactory::AddSocketDataProvider( 646 SocketDataProvider* data) { 647 mock_data_.Add(data); 648 } 649 650 void MockClientSocketFactory::AddSSLSocketDataProvider( 651 SSLSocketDataProvider* data) { 652 mock_ssl_data_.Add(data); 653 } 654 655 void MockClientSocketFactory::ResetNextMockIndexes() { 656 mock_data_.ResetNextIndex(); 657 mock_ssl_data_.ResetNextIndex(); 658 } 659 660 DatagramClientSocket* MockClientSocketFactory::CreateDatagramClientSocket( 661 DatagramSocket::BindType bind_type, 662 const RandIntCallback& rand_int_cb, 663 net::NetLog* net_log, 664 const net::NetLog::Source& source) { 665 SocketDataProvider* data_provider = mock_data_.GetNext(); 666 MockUDPClientSocket* socket = new MockUDPClientSocket(data_provider, net_log); 667 data_provider->set_socket(socket); 668 return socket; 669 } 670 671 StreamSocket* MockClientSocketFactory::CreateTransportClientSocket( 672 const AddressList& addresses, 673 net::NetLog* net_log, 674 const net::NetLog::Source& source) { 675 SocketDataProvider* data_provider = mock_data_.GetNext(); 676 MockTCPClientSocket* socket = 677 new MockTCPClientSocket(addresses, net_log, data_provider); 678 data_provider->set_socket(socket); 679 return socket; 680 } 681 682 SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( 683 ClientSocketHandle* transport_socket, 684 const HostPortPair& host_and_port, 685 const SSLConfig& ssl_config, 686 const SSLClientSocketContext& context) { 687 MockSSLClientSocket* socket = 688 new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, 689 mock_ssl_data_.GetNext()); 690 return socket; 691 } 692 693 void MockClientSocketFactory::ClearSSLSessionCache() { 694 } 695 696 const char MockClientSocket::kTlsUnique[] = "MOCK_TLSUNIQ"; 697 698 MockClientSocket::MockClientSocket(const BoundNetLog& net_log) 699 : weak_factory_(this), 700 connected_(false), 701 net_log_(net_log) { 702 IPAddressNumber ip; 703 CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); 704 peer_addr_ = IPEndPoint(ip, 0); 705 } 706 707 bool MockClientSocket::SetReceiveBufferSize(int32 size) { 708 return true; 709 } 710 711 bool MockClientSocket::SetSendBufferSize(int32 size) { 712 return true; 713 } 714 715 void MockClientSocket::Disconnect() { 716 connected_ = false; 717 } 718 719 bool MockClientSocket::IsConnected() const { 720 return connected_; 721 } 722 723 bool MockClientSocket::IsConnectedAndIdle() const { 724 return connected_; 725 } 726 727 int MockClientSocket::GetPeerAddress(IPEndPoint* address) const { 728 if (!IsConnected()) 729 return ERR_SOCKET_NOT_CONNECTED; 730 *address = peer_addr_; 731 return OK; 732 } 733 734 int MockClientSocket::GetLocalAddress(IPEndPoint* address) const { 735 IPAddressNumber ip; 736 bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip); 737 CHECK(rv); 738 *address = IPEndPoint(ip, 123); 739 return OK; 740 } 741 742 const BoundNetLog& MockClientSocket::NetLog() const { 743 return net_log_; 744 } 745 746 void MockClientSocket::GetSSLCertRequestInfo( 747 SSLCertRequestInfo* cert_request_info) { 748 } 749 750 int MockClientSocket::ExportKeyingMaterial(const base::StringPiece& label, 751 bool has_context, 752 const base::StringPiece& context, 753 unsigned char* out, 754 unsigned int outlen) { 755 memset(out, 'A', outlen); 756 return OK; 757 } 758 759 int MockClientSocket::GetTLSUniqueChannelBinding(std::string* out) { 760 out->assign(MockClientSocket::kTlsUnique); 761 return OK; 762 } 763 764 ServerBoundCertService* MockClientSocket::GetServerBoundCertService() const { 765 NOTREACHED(); 766 return NULL; 767 } 768 769 SSLClientSocket::NextProtoStatus 770 MockClientSocket::GetNextProto(std::string* proto, std::string* server_protos) { 771 proto->clear(); 772 server_protos->clear(); 773 return SSLClientSocket::kNextProtoUnsupported; 774 } 775 776 MockClientSocket::~MockClientSocket() {} 777 778 void MockClientSocket::RunCallbackAsync(const CompletionCallback& callback, 779 int result) { 780 base::MessageLoop::current()->PostTask( 781 FROM_HERE, 782 base::Bind(&MockClientSocket::RunCallback, 783 weak_factory_.GetWeakPtr(), 784 callback, 785 result)); 786 } 787 788 void MockClientSocket::RunCallback(const net::CompletionCallback& callback, 789 int result) { 790 if (!callback.is_null()) 791 callback.Run(result); 792 } 793 794 MockTCPClientSocket::MockTCPClientSocket(const AddressList& addresses, 795 net::NetLog* net_log, 796 SocketDataProvider* data) 797 : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), 798 addresses_(addresses), 799 data_(data), 800 read_offset_(0), 801 read_data_(SYNCHRONOUS, ERR_UNEXPECTED), 802 need_read_data_(true), 803 peer_closed_connection_(false), 804 pending_buf_(NULL), 805 pending_buf_len_(0), 806 was_used_to_convey_data_(false) { 807 DCHECK(data_); 808 peer_addr_ = data->connect_data().peer_addr; 809 data_->Reset(); 810 } 811 812 MockTCPClientSocket::~MockTCPClientSocket() {} 813 814 int MockTCPClientSocket::Read(IOBuffer* buf, int buf_len, 815 const CompletionCallback& callback) { 816 if (!connected_) 817 return ERR_UNEXPECTED; 818 819 // If the buffer is already in use, a read is already in progress! 820 DCHECK(pending_buf_ == NULL); 821 822 // Store our async IO data. 823 pending_buf_ = buf; 824 pending_buf_len_ = buf_len; 825 pending_callback_ = callback; 826 827 if (need_read_data_) { 828 read_data_ = data_->GetNextRead(); 829 if (read_data_.result == ERR_CONNECTION_CLOSED) { 830 // This MockRead is just a marker to instruct us to set 831 // peer_closed_connection_. 832 peer_closed_connection_ = true; 833 } 834 if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { 835 // This MockRead is just a marker to instruct us to set 836 // peer_closed_connection_. Skip it and get the next one. 837 read_data_ = data_->GetNextRead(); 838 peer_closed_connection_ = true; 839 } 840 // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility 841 // to complete the async IO manually later (via OnReadComplete). 842 if (read_data_.result == ERR_IO_PENDING) { 843 // We need to be using async IO in this case. 844 DCHECK(!callback.is_null()); 845 return ERR_IO_PENDING; 846 } 847 need_read_data_ = false; 848 } 849 850 return CompleteRead(); 851 } 852 853 int MockTCPClientSocket::Write(IOBuffer* buf, int buf_len, 854 const CompletionCallback& callback) { 855 DCHECK(buf); 856 DCHECK_GT(buf_len, 0); 857 858 if (!connected_) 859 return ERR_UNEXPECTED; 860 861 std::string data(buf->data(), buf_len); 862 MockWriteResult write_result = data_->OnWrite(data); 863 864 was_used_to_convey_data_ = true; 865 866 if (write_result.mode == ASYNC) { 867 RunCallbackAsync(callback, write_result.result); 868 return ERR_IO_PENDING; 869 } 870 871 return write_result.result; 872 } 873 874 int MockTCPClientSocket::Connect(const CompletionCallback& callback) { 875 if (connected_) 876 return OK; 877 connected_ = true; 878 peer_closed_connection_ = false; 879 if (data_->connect_data().mode == ASYNC) { 880 if (data_->connect_data().result == ERR_IO_PENDING) 881 pending_callback_ = callback; 882 else 883 RunCallbackAsync(callback, data_->connect_data().result); 884 return ERR_IO_PENDING; 885 } 886 return data_->connect_data().result; 887 } 888 889 void MockTCPClientSocket::Disconnect() { 890 MockClientSocket::Disconnect(); 891 pending_callback_.Reset(); 892 } 893 894 bool MockTCPClientSocket::IsConnected() const { 895 return connected_ && !peer_closed_connection_; 896 } 897 898 bool MockTCPClientSocket::IsConnectedAndIdle() const { 899 return IsConnected(); 900 } 901 902 int MockTCPClientSocket::GetPeerAddress(IPEndPoint* address) const { 903 if (addresses_.empty()) 904 return MockClientSocket::GetPeerAddress(address); 905 906 *address = addresses_[0]; 907 return OK; 908 } 909 910 bool MockTCPClientSocket::WasEverUsed() const { 911 return was_used_to_convey_data_; 912 } 913 914 bool MockTCPClientSocket::UsingTCPFastOpen() const { 915 return false; 916 } 917 918 bool MockTCPClientSocket::WasNpnNegotiated() const { 919 return false; 920 } 921 922 bool MockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { 923 return false; 924 } 925 926 void MockTCPClientSocket::OnReadComplete(const MockRead& data) { 927 // There must be a read pending. 928 DCHECK(pending_buf_); 929 // You can't complete a read with another ERR_IO_PENDING status code. 930 DCHECK_NE(ERR_IO_PENDING, data.result); 931 // Since we've been waiting for data, need_read_data_ should be true. 932 DCHECK(need_read_data_); 933 934 read_data_ = data; 935 need_read_data_ = false; 936 937 // The caller is simulating that this IO completes right now. Don't 938 // let CompleteRead() schedule a callback. 939 read_data_.mode = SYNCHRONOUS; 940 941 CompletionCallback callback = pending_callback_; 942 int rv = CompleteRead(); 943 RunCallback(callback, rv); 944 } 945 946 void MockTCPClientSocket::OnConnectComplete(const MockConnect& data) { 947 CompletionCallback callback = pending_callback_; 948 RunCallback(callback, data.result); 949 } 950 951 int MockTCPClientSocket::CompleteRead() { 952 DCHECK(pending_buf_); 953 DCHECK(pending_buf_len_ > 0); 954 955 was_used_to_convey_data_ = true; 956 957 // Save the pending async IO data and reset our |pending_| state. 958 IOBuffer* buf = pending_buf_; 959 int buf_len = pending_buf_len_; 960 CompletionCallback callback = pending_callback_; 961 pending_buf_ = NULL; 962 pending_buf_len_ = 0; 963 pending_callback_.Reset(); 964 965 int result = read_data_.result; 966 DCHECK(result != ERR_IO_PENDING); 967 968 if (read_data_.data) { 969 if (read_data_.data_len - read_offset_ > 0) { 970 result = std::min(buf_len, read_data_.data_len - read_offset_); 971 memcpy(buf->data(), read_data_.data + read_offset_, result); 972 read_offset_ += result; 973 if (read_offset_ == read_data_.data_len) { 974 need_read_data_ = true; 975 read_offset_ = 0; 976 } 977 } else { 978 result = 0; // EOF 979 } 980 } 981 982 if (read_data_.mode == ASYNC) { 983 DCHECK(!callback.is_null()); 984 RunCallbackAsync(callback, result); 985 return ERR_IO_PENDING; 986 } 987 return result; 988 } 989 990 DeterministicSocketHelper::DeterministicSocketHelper( 991 net::NetLog* net_log, 992 DeterministicSocketData* data) 993 : write_pending_(false), 994 write_result_(0), 995 read_data_(), 996 read_buf_(NULL), 997 read_buf_len_(0), 998 read_pending_(false), 999 data_(data), 1000 was_used_to_convey_data_(false), 1001 peer_closed_connection_(false), 1002 net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)) { 1003 } 1004 1005 DeterministicSocketHelper::~DeterministicSocketHelper() {} 1006 1007 void DeterministicSocketHelper::CompleteWrite() { 1008 was_used_to_convey_data_ = true; 1009 write_pending_ = false; 1010 write_callback_.Run(write_result_); 1011 } 1012 1013 int DeterministicSocketHelper::CompleteRead() { 1014 DCHECK_GT(read_buf_len_, 0); 1015 DCHECK_LE(read_data_.data_len, read_buf_len_); 1016 DCHECK(read_buf_); 1017 1018 was_used_to_convey_data_ = true; 1019 1020 if (read_data_.result == ERR_IO_PENDING) 1021 read_data_ = data_->GetNextRead(); 1022 DCHECK_NE(ERR_IO_PENDING, read_data_.result); 1023 // If read_data_.mode is ASYNC, we do not need to wait, since this is already 1024 // the callback. Therefore we don't even bother to check it. 1025 int result = read_data_.result; 1026 1027 if (read_data_.data_len > 0) { 1028 DCHECK(read_data_.data); 1029 result = std::min(read_buf_len_, read_data_.data_len); 1030 memcpy(read_buf_->data(), read_data_.data, result); 1031 } 1032 1033 if (read_pending_) { 1034 read_pending_ = false; 1035 read_callback_.Run(result); 1036 } 1037 1038 return result; 1039 } 1040 1041 int DeterministicSocketHelper::Write( 1042 IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 1043 DCHECK(buf); 1044 DCHECK_GT(buf_len, 0); 1045 1046 std::string data(buf->data(), buf_len); 1047 MockWriteResult write_result = data_->OnWrite(data); 1048 1049 if (write_result.mode == ASYNC) { 1050 write_callback_ = callback; 1051 write_result_ = write_result.result; 1052 DCHECK(!write_callback_.is_null()); 1053 write_pending_ = true; 1054 return ERR_IO_PENDING; 1055 } 1056 1057 was_used_to_convey_data_ = true; 1058 write_pending_ = false; 1059 return write_result.result; 1060 } 1061 1062 int DeterministicSocketHelper::Read( 1063 IOBuffer* buf, int buf_len, const CompletionCallback& callback) { 1064 1065 read_data_ = data_->GetNextRead(); 1066 // The buffer should always be big enough to contain all the MockRead data. To 1067 // use small buffers, split the data into multiple MockReads. 1068 DCHECK_LE(read_data_.data_len, buf_len); 1069 1070 if (read_data_.result == ERR_CONNECTION_CLOSED) { 1071 // This MockRead is just a marker to instruct us to set 1072 // peer_closed_connection_. 1073 peer_closed_connection_ = true; 1074 } 1075 if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { 1076 // This MockRead is just a marker to instruct us to set 1077 // peer_closed_connection_. Skip it and get the next one. 1078 read_data_ = data_->GetNextRead(); 1079 peer_closed_connection_ = true; 1080 } 1081 1082 read_buf_ = buf; 1083 read_buf_len_ = buf_len; 1084 read_callback_ = callback; 1085 1086 if (read_data_.mode == ASYNC || (read_data_.result == ERR_IO_PENDING)) { 1087 read_pending_ = true; 1088 DCHECK(!read_callback_.is_null()); 1089 return ERR_IO_PENDING; 1090 } 1091 1092 was_used_to_convey_data_ = true; 1093 return CompleteRead(); 1094 } 1095 1096 DeterministicMockUDPClientSocket::DeterministicMockUDPClientSocket( 1097 net::NetLog* net_log, 1098 DeterministicSocketData* data) 1099 : connected_(false), 1100 helper_(net_log, data) { 1101 } 1102 1103 DeterministicMockUDPClientSocket::~DeterministicMockUDPClientSocket() {} 1104 1105 bool DeterministicMockUDPClientSocket::WritePending() const { 1106 return helper_.write_pending(); 1107 } 1108 1109 bool DeterministicMockUDPClientSocket::ReadPending() const { 1110 return helper_.read_pending(); 1111 } 1112 1113 void DeterministicMockUDPClientSocket::CompleteWrite() { 1114 helper_.CompleteWrite(); 1115 } 1116 1117 int DeterministicMockUDPClientSocket::CompleteRead() { 1118 return helper_.CompleteRead(); 1119 } 1120 1121 int DeterministicMockUDPClientSocket::Connect(const IPEndPoint& address) { 1122 if (connected_) 1123 return OK; 1124 connected_ = true; 1125 peer_address_ = address; 1126 return helper_.data()->connect_data().result; 1127 }; 1128 1129 int DeterministicMockUDPClientSocket::Write( 1130 IOBuffer* buf, 1131 int buf_len, 1132 const CompletionCallback& callback) { 1133 if (!connected_) 1134 return ERR_UNEXPECTED; 1135 1136 return helper_.Write(buf, buf_len, callback); 1137 } 1138 1139 int DeterministicMockUDPClientSocket::Read( 1140 IOBuffer* buf, 1141 int buf_len, 1142 const CompletionCallback& callback) { 1143 if (!connected_) 1144 return ERR_UNEXPECTED; 1145 1146 return helper_.Read(buf, buf_len, callback); 1147 } 1148 1149 bool DeterministicMockUDPClientSocket::SetReceiveBufferSize(int32 size) { 1150 return true; 1151 } 1152 1153 bool DeterministicMockUDPClientSocket::SetSendBufferSize(int32 size) { 1154 return true; 1155 } 1156 1157 void DeterministicMockUDPClientSocket::Close() { 1158 connected_ = false; 1159 } 1160 1161 int DeterministicMockUDPClientSocket::GetPeerAddress( 1162 IPEndPoint* address) const { 1163 *address = peer_address_; 1164 return OK; 1165 } 1166 1167 int DeterministicMockUDPClientSocket::GetLocalAddress( 1168 IPEndPoint* address) const { 1169 IPAddressNumber ip; 1170 bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip); 1171 CHECK(rv); 1172 *address = IPEndPoint(ip, 123); 1173 return OK; 1174 } 1175 1176 const BoundNetLog& DeterministicMockUDPClientSocket::NetLog() const { 1177 return helper_.net_log(); 1178 } 1179 1180 void DeterministicMockUDPClientSocket::OnReadComplete(const MockRead& data) {} 1181 1182 void DeterministicMockUDPClientSocket::OnConnectComplete( 1183 const MockConnect& data) { 1184 NOTIMPLEMENTED(); 1185 } 1186 1187 DeterministicMockTCPClientSocket::DeterministicMockTCPClientSocket( 1188 net::NetLog* net_log, 1189 DeterministicSocketData* data) 1190 : MockClientSocket(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), 1191 helper_(net_log, data) { 1192 peer_addr_ = data->connect_data().peer_addr; 1193 } 1194 1195 DeterministicMockTCPClientSocket::~DeterministicMockTCPClientSocket() {} 1196 1197 bool DeterministicMockTCPClientSocket::WritePending() const { 1198 return helper_.write_pending(); 1199 } 1200 1201 bool DeterministicMockTCPClientSocket::ReadPending() const { 1202 return helper_.read_pending(); 1203 } 1204 1205 void DeterministicMockTCPClientSocket::CompleteWrite() { 1206 helper_.CompleteWrite(); 1207 } 1208 1209 int DeterministicMockTCPClientSocket::CompleteRead() { 1210 return helper_.CompleteRead(); 1211 } 1212 1213 int DeterministicMockTCPClientSocket::Write( 1214 IOBuffer* buf, 1215 int buf_len, 1216 const CompletionCallback& callback) { 1217 if (!connected_) 1218 return ERR_UNEXPECTED; 1219 1220 return helper_.Write(buf, buf_len, callback); 1221 } 1222 1223 int DeterministicMockTCPClientSocket::Read( 1224 IOBuffer* buf, 1225 int buf_len, 1226 const CompletionCallback& callback) { 1227 if (!connected_) 1228 return ERR_UNEXPECTED; 1229 1230 return helper_.Read(buf, buf_len, callback); 1231 } 1232 1233 // TODO(erikchen): Support connect sequencing. 1234 int DeterministicMockTCPClientSocket::Connect( 1235 const CompletionCallback& callback) { 1236 if (connected_) 1237 return OK; 1238 connected_ = true; 1239 if (helper_.data()->connect_data().mode == ASYNC) { 1240 RunCallbackAsync(callback, helper_.data()->connect_data().result); 1241 return ERR_IO_PENDING; 1242 } 1243 return helper_.data()->connect_data().result; 1244 } 1245 1246 void DeterministicMockTCPClientSocket::Disconnect() { 1247 MockClientSocket::Disconnect(); 1248 } 1249 1250 bool DeterministicMockTCPClientSocket::IsConnected() const { 1251 return connected_ && !helper_.peer_closed_connection(); 1252 } 1253 1254 bool DeterministicMockTCPClientSocket::IsConnectedAndIdle() const { 1255 return IsConnected(); 1256 } 1257 1258 bool DeterministicMockTCPClientSocket::WasEverUsed() const { 1259 return helper_.was_used_to_convey_data(); 1260 } 1261 1262 bool DeterministicMockTCPClientSocket::UsingTCPFastOpen() const { 1263 return false; 1264 } 1265 1266 bool DeterministicMockTCPClientSocket::WasNpnNegotiated() const { 1267 return false; 1268 } 1269 1270 bool DeterministicMockTCPClientSocket::GetSSLInfo(SSLInfo* ssl_info) { 1271 return false; 1272 } 1273 1274 void DeterministicMockTCPClientSocket::OnReadComplete(const MockRead& data) {} 1275 1276 void DeterministicMockTCPClientSocket::OnConnectComplete( 1277 const MockConnect& data) {} 1278 1279 // static 1280 void MockSSLClientSocket::ConnectCallback( 1281 MockSSLClientSocket *ssl_client_socket, 1282 const CompletionCallback& callback, 1283 int rv) { 1284 if (rv == OK) 1285 ssl_client_socket->connected_ = true; 1286 callback.Run(rv); 1287 } 1288 1289 MockSSLClientSocket::MockSSLClientSocket( 1290 ClientSocketHandle* transport_socket, 1291 const HostPortPair& host_port_pair, 1292 const SSLConfig& ssl_config, 1293 SSLSocketDataProvider* data) 1294 : MockClientSocket( 1295 // Have to use the right BoundNetLog for LoadTimingInfo regression 1296 // tests. 1297 transport_socket->socket()->NetLog()), 1298 transport_(transport_socket), 1299 data_(data), 1300 is_npn_state_set_(false), 1301 new_npn_value_(false), 1302 is_protocol_negotiated_set_(false), 1303 protocol_negotiated_(kProtoUnknown) { 1304 DCHECK(data_); 1305 peer_addr_ = data->connect.peer_addr; 1306 } 1307 1308 MockSSLClientSocket::~MockSSLClientSocket() { 1309 Disconnect(); 1310 } 1311 1312 int MockSSLClientSocket::Read(IOBuffer* buf, int buf_len, 1313 const CompletionCallback& callback) { 1314 return transport_->socket()->Read(buf, buf_len, callback); 1315 } 1316 1317 int MockSSLClientSocket::Write(IOBuffer* buf, int buf_len, 1318 const CompletionCallback& callback) { 1319 return transport_->socket()->Write(buf, buf_len, callback); 1320 } 1321 1322 int MockSSLClientSocket::Connect(const CompletionCallback& callback) { 1323 int rv = transport_->socket()->Connect( 1324 base::Bind(&ConnectCallback, base::Unretained(this), callback)); 1325 if (rv == OK) { 1326 if (data_->connect.result == OK) 1327 connected_ = true; 1328 if (data_->connect.mode == ASYNC) { 1329 RunCallbackAsync(callback, data_->connect.result); 1330 return ERR_IO_PENDING; 1331 } 1332 return data_->connect.result; 1333 } 1334 return rv; 1335 } 1336 1337 void MockSSLClientSocket::Disconnect() { 1338 MockClientSocket::Disconnect(); 1339 if (transport_->socket() != NULL) 1340 transport_->socket()->Disconnect(); 1341 } 1342 1343 bool MockSSLClientSocket::IsConnected() const { 1344 return transport_->socket()->IsConnected(); 1345 } 1346 1347 bool MockSSLClientSocket::WasEverUsed() const { 1348 return transport_->socket()->WasEverUsed(); 1349 } 1350 1351 bool MockSSLClientSocket::UsingTCPFastOpen() const { 1352 return transport_->socket()->UsingTCPFastOpen(); 1353 } 1354 1355 int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const { 1356 return transport_->socket()->GetPeerAddress(address); 1357 } 1358 1359 bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { 1360 ssl_info->Reset(); 1361 ssl_info->cert = data_->cert; 1362 ssl_info->client_cert_sent = data_->client_cert_sent; 1363 ssl_info->channel_id_sent = data_->channel_id_sent; 1364 return true; 1365 } 1366 1367 void MockSSLClientSocket::GetSSLCertRequestInfo( 1368 SSLCertRequestInfo* cert_request_info) { 1369 DCHECK(cert_request_info); 1370 if (data_->cert_request_info) { 1371 cert_request_info->host_and_port = 1372 data_->cert_request_info->host_and_port; 1373 cert_request_info->client_certs = data_->cert_request_info->client_certs; 1374 } else { 1375 cert_request_info->Reset(); 1376 } 1377 } 1378 1379 SSLClientSocket::NextProtoStatus MockSSLClientSocket::GetNextProto( 1380 std::string* proto, std::string* server_protos) { 1381 *proto = data_->next_proto; 1382 *server_protos = data_->server_protos; 1383 return data_->next_proto_status; 1384 } 1385 1386 bool MockSSLClientSocket::set_was_npn_negotiated(bool negotiated) { 1387 is_npn_state_set_ = true; 1388 return new_npn_value_ = negotiated; 1389 } 1390 1391 bool MockSSLClientSocket::WasNpnNegotiated() const { 1392 if (is_npn_state_set_) 1393 return new_npn_value_; 1394 return data_->was_npn_negotiated; 1395 } 1396 1397 NextProto MockSSLClientSocket::GetNegotiatedProtocol() const { 1398 if (is_protocol_negotiated_set_) 1399 return protocol_negotiated_; 1400 return data_->protocol_negotiated; 1401 } 1402 1403 void MockSSLClientSocket::set_protocol_negotiated( 1404 NextProto protocol_negotiated) { 1405 is_protocol_negotiated_set_ = true; 1406 protocol_negotiated_ = protocol_negotiated; 1407 } 1408 1409 bool MockSSLClientSocket::WasChannelIDSent() const { 1410 return data_->channel_id_sent; 1411 } 1412 1413 void MockSSLClientSocket::set_channel_id_sent(bool channel_id_sent) { 1414 data_->channel_id_sent = channel_id_sent; 1415 } 1416 1417 ServerBoundCertService* MockSSLClientSocket::GetServerBoundCertService() const { 1418 return data_->server_bound_cert_service; 1419 } 1420 1421 void MockSSLClientSocket::OnReadComplete(const MockRead& data) { 1422 NOTIMPLEMENTED(); 1423 } 1424 1425 void MockSSLClientSocket::OnConnectComplete(const MockConnect& data) { 1426 NOTIMPLEMENTED(); 1427 } 1428 1429 MockUDPClientSocket::MockUDPClientSocket(SocketDataProvider* data, 1430 net::NetLog* net_log) 1431 : connected_(false), 1432 data_(data), 1433 read_offset_(0), 1434 read_data_(SYNCHRONOUS, ERR_UNEXPECTED), 1435 need_read_data_(true), 1436 pending_buf_(NULL), 1437 pending_buf_len_(0), 1438 net_log_(BoundNetLog::Make(net_log, net::NetLog::SOURCE_NONE)), 1439 weak_factory_(this) { 1440 DCHECK(data_); 1441 data_->Reset(); 1442 peer_addr_ = data->connect_data().peer_addr; 1443 } 1444 1445 MockUDPClientSocket::~MockUDPClientSocket() {} 1446 1447 int MockUDPClientSocket::Read(IOBuffer* buf, int buf_len, 1448 const CompletionCallback& callback) { 1449 if (!connected_) 1450 return ERR_UNEXPECTED; 1451 1452 // If the buffer is already in use, a read is already in progress! 1453 DCHECK(pending_buf_ == NULL); 1454 1455 // Store our async IO data. 1456 pending_buf_ = buf; 1457 pending_buf_len_ = buf_len; 1458 pending_callback_ = callback; 1459 1460 if (need_read_data_) { 1461 read_data_ = data_->GetNextRead(); 1462 // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility 1463 // to complete the async IO manually later (via OnReadComplete). 1464 if (read_data_.result == ERR_IO_PENDING) { 1465 // We need to be using async IO in this case. 1466 DCHECK(!callback.is_null()); 1467 return ERR_IO_PENDING; 1468 } 1469 need_read_data_ = false; 1470 } 1471 1472 return CompleteRead(); 1473 } 1474 1475 int MockUDPClientSocket::Write(IOBuffer* buf, int buf_len, 1476 const CompletionCallback& callback) { 1477 DCHECK(buf); 1478 DCHECK_GT(buf_len, 0); 1479 1480 if (!connected_) 1481 return ERR_UNEXPECTED; 1482 1483 std::string data(buf->data(), buf_len); 1484 MockWriteResult write_result = data_->OnWrite(data); 1485 1486 if (write_result.mode == ASYNC) { 1487 RunCallbackAsync(callback, write_result.result); 1488 return ERR_IO_PENDING; 1489 } 1490 return write_result.result; 1491 } 1492 1493 bool MockUDPClientSocket::SetReceiveBufferSize(int32 size) { 1494 return true; 1495 } 1496 1497 bool MockUDPClientSocket::SetSendBufferSize(int32 size) { 1498 return true; 1499 } 1500 1501 void MockUDPClientSocket::Close() { 1502 connected_ = false; 1503 } 1504 1505 int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const { 1506 *address = peer_addr_; 1507 return OK; 1508 } 1509 1510 int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const { 1511 IPAddressNumber ip; 1512 bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip); 1513 CHECK(rv); 1514 *address = IPEndPoint(ip, 123); 1515 return OK; 1516 } 1517 1518 const BoundNetLog& MockUDPClientSocket::NetLog() const { 1519 return net_log_; 1520 } 1521 1522 int MockUDPClientSocket::Connect(const IPEndPoint& address) { 1523 connected_ = true; 1524 peer_addr_ = address; 1525 return OK; 1526 } 1527 1528 void MockUDPClientSocket::OnReadComplete(const MockRead& data) { 1529 // There must be a read pending. 1530 DCHECK(pending_buf_); 1531 // You can't complete a read with another ERR_IO_PENDING status code. 1532 DCHECK_NE(ERR_IO_PENDING, data.result); 1533 // Since we've been waiting for data, need_read_data_ should be true. 1534 DCHECK(need_read_data_); 1535 1536 read_data_ = data; 1537 need_read_data_ = false; 1538 1539 // The caller is simulating that this IO completes right now. Don't 1540 // let CompleteRead() schedule a callback. 1541 read_data_.mode = SYNCHRONOUS; 1542 1543 net::CompletionCallback callback = pending_callback_; 1544 int rv = CompleteRead(); 1545 RunCallback(callback, rv); 1546 } 1547 1548 void MockUDPClientSocket::OnConnectComplete(const MockConnect& data) { 1549 NOTIMPLEMENTED(); 1550 } 1551 1552 int MockUDPClientSocket::CompleteRead() { 1553 DCHECK(pending_buf_); 1554 DCHECK(pending_buf_len_ > 0); 1555 1556 // Save the pending async IO data and reset our |pending_| state. 1557 IOBuffer* buf = pending_buf_; 1558 int buf_len = pending_buf_len_; 1559 CompletionCallback callback = pending_callback_; 1560 pending_buf_ = NULL; 1561 pending_buf_len_ = 0; 1562 pending_callback_.Reset(); 1563 1564 int result = read_data_.result; 1565 DCHECK(result != ERR_IO_PENDING); 1566 1567 if (read_data_.data) { 1568 if (read_data_.data_len - read_offset_ > 0) { 1569 result = std::min(buf_len, read_data_.data_len - read_offset_); 1570 memcpy(buf->data(), read_data_.data + read_offset_, result); 1571 read_offset_ += result; 1572 if (read_offset_ == read_data_.data_len) { 1573 need_read_data_ = true; 1574 read_offset_ = 0; 1575 } 1576 } else { 1577 result = 0; // EOF 1578 } 1579 } 1580 1581 if (read_data_.mode == ASYNC) { 1582 DCHECK(!callback.is_null()); 1583 RunCallbackAsync(callback, result); 1584 return ERR_IO_PENDING; 1585 } 1586 return result; 1587 } 1588 1589 void MockUDPClientSocket::RunCallbackAsync(const CompletionCallback& callback, 1590 int result) { 1591 base::MessageLoop::current()->PostTask( 1592 FROM_HERE, 1593 base::Bind(&MockUDPClientSocket::RunCallback, 1594 weak_factory_.GetWeakPtr(), 1595 callback, 1596 result)); 1597 } 1598 1599 void MockUDPClientSocket::RunCallback(const CompletionCallback& callback, 1600 int result) { 1601 if (!callback.is_null()) 1602 callback.Run(result); 1603 } 1604 1605 TestSocketRequest::TestSocketRequest( 1606 std::vector<TestSocketRequest*>* request_order, size_t* completion_count) 1607 : request_order_(request_order), 1608 completion_count_(completion_count), 1609 callback_(base::Bind(&TestSocketRequest::OnComplete, 1610 base::Unretained(this))) { 1611 DCHECK(request_order); 1612 DCHECK(completion_count); 1613 } 1614 1615 TestSocketRequest::~TestSocketRequest() { 1616 } 1617 1618 void TestSocketRequest::OnComplete(int result) { 1619 SetResult(result); 1620 (*completion_count_)++; 1621 request_order_->push_back(this); 1622 } 1623 1624 // static 1625 const int ClientSocketPoolTest::kIndexOutOfBounds = -1; 1626 1627 // static 1628 const int ClientSocketPoolTest::kRequestNotFound = -2; 1629 1630 ClientSocketPoolTest::ClientSocketPoolTest() : completion_count_(0) {} 1631 ClientSocketPoolTest::~ClientSocketPoolTest() {} 1632 1633 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) const { 1634 index--; 1635 if (index >= requests_.size()) 1636 return kIndexOutOfBounds; 1637 1638 for (size_t i = 0; i < request_order_.size(); i++) 1639 if (requests_[index] == request_order_[i]) 1640 return i + 1; 1641 1642 return kRequestNotFound; 1643 } 1644 1645 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) { 1646 ScopedVector<TestSocketRequest>::iterator i; 1647 for (i = requests_.begin(); i != requests_.end(); ++i) { 1648 if ((*i)->handle()->is_initialized()) { 1649 if (keep_alive == NO_KEEP_ALIVE) 1650 (*i)->handle()->socket()->Disconnect(); 1651 (*i)->handle()->Reset(); 1652 base::RunLoop().RunUntilIdle(); 1653 return true; 1654 } 1655 } 1656 return false; 1657 } 1658 1659 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) { 1660 bool released_one; 1661 do { 1662 released_one = ReleaseOneConnection(keep_alive); 1663 } while (released_one); 1664 } 1665 1666 MockTransportClientSocketPool::MockConnectJob::MockConnectJob( 1667 StreamSocket* socket, 1668 ClientSocketHandle* handle, 1669 const CompletionCallback& callback) 1670 : socket_(socket), 1671 handle_(handle), 1672 user_callback_(callback) { 1673 } 1674 1675 MockTransportClientSocketPool::MockConnectJob::~MockConnectJob() {} 1676 1677 int MockTransportClientSocketPool::MockConnectJob::Connect() { 1678 int rv = socket_->Connect(base::Bind(&MockConnectJob::OnConnect, 1679 base::Unretained(this))); 1680 if (rv == OK) { 1681 user_callback_.Reset(); 1682 OnConnect(OK); 1683 } 1684 return rv; 1685 } 1686 1687 bool MockTransportClientSocketPool::MockConnectJob::CancelHandle( 1688 const ClientSocketHandle* handle) { 1689 if (handle != handle_) 1690 return false; 1691 socket_.reset(); 1692 handle_ = NULL; 1693 user_callback_.Reset(); 1694 return true; 1695 } 1696 1697 void MockTransportClientSocketPool::MockConnectJob::OnConnect(int rv) { 1698 if (!socket_.get()) 1699 return; 1700 if (rv == OK) { 1701 handle_->set_socket(socket_.release()); 1702 1703 // Needed for socket pool tests that layer other sockets on top of mock 1704 // sockets. 1705 LoadTimingInfo::ConnectTiming connect_timing; 1706 base::TimeTicks now = base::TimeTicks::Now(); 1707 connect_timing.dns_start = now; 1708 connect_timing.dns_end = now; 1709 connect_timing.connect_start = now; 1710 connect_timing.connect_end = now; 1711 handle_->set_connect_timing(connect_timing); 1712 } else { 1713 socket_.reset(); 1714 } 1715 1716 handle_ = NULL; 1717 1718 if (!user_callback_.is_null()) { 1719 CompletionCallback callback = user_callback_; 1720 user_callback_.Reset(); 1721 callback.Run(rv); 1722 } 1723 } 1724 1725 MockTransportClientSocketPool::MockTransportClientSocketPool( 1726 int max_sockets, 1727 int max_sockets_per_group, 1728 ClientSocketPoolHistograms* histograms, 1729 ClientSocketFactory* socket_factory) 1730 : TransportClientSocketPool(max_sockets, max_sockets_per_group, histograms, 1731 NULL, NULL, NULL), 1732 client_socket_factory_(socket_factory), 1733 release_count_(0), 1734 cancel_count_(0) { 1735 } 1736 1737 MockTransportClientSocketPool::~MockTransportClientSocketPool() {} 1738 1739 int MockTransportClientSocketPool::RequestSocket( 1740 const std::string& group_name, const void* socket_params, 1741 RequestPriority priority, ClientSocketHandle* handle, 1742 const CompletionCallback& callback, const BoundNetLog& net_log) { 1743 StreamSocket* socket = client_socket_factory_->CreateTransportClientSocket( 1744 AddressList(), net_log.net_log(), net::NetLog::Source()); 1745 MockConnectJob* job = new MockConnectJob(socket, handle, callback); 1746 job_list_.push_back(job); 1747 handle->set_pool_id(1); 1748 return job->Connect(); 1749 } 1750 1751 void MockTransportClientSocketPool::CancelRequest(const std::string& group_name, 1752 ClientSocketHandle* handle) { 1753 std::vector<MockConnectJob*>::iterator i; 1754 for (i = job_list_.begin(); i != job_list_.end(); ++i) { 1755 if ((*i)->CancelHandle(handle)) { 1756 cancel_count_++; 1757 break; 1758 } 1759 } 1760 } 1761 1762 void MockTransportClientSocketPool::ReleaseSocket(const std::string& group_name, 1763 StreamSocket* socket, int id) { 1764 EXPECT_EQ(1, id); 1765 release_count_++; 1766 delete socket; 1767 } 1768 1769 DeterministicMockClientSocketFactory::DeterministicMockClientSocketFactory() {} 1770 1771 DeterministicMockClientSocketFactory::~DeterministicMockClientSocketFactory() {} 1772 1773 void DeterministicMockClientSocketFactory::AddSocketDataProvider( 1774 DeterministicSocketData* data) { 1775 mock_data_.Add(data); 1776 } 1777 1778 void DeterministicMockClientSocketFactory::AddSSLSocketDataProvider( 1779 SSLSocketDataProvider* data) { 1780 mock_ssl_data_.Add(data); 1781 } 1782 1783 void DeterministicMockClientSocketFactory::ResetNextMockIndexes() { 1784 mock_data_.ResetNextIndex(); 1785 mock_ssl_data_.ResetNextIndex(); 1786 } 1787 1788 MockSSLClientSocket* DeterministicMockClientSocketFactory:: 1789 GetMockSSLClientSocket(size_t index) const { 1790 DCHECK_LT(index, ssl_client_sockets_.size()); 1791 return ssl_client_sockets_[index]; 1792 } 1793 1794 DatagramClientSocket* 1795 DeterministicMockClientSocketFactory::CreateDatagramClientSocket( 1796 DatagramSocket::BindType bind_type, 1797 const RandIntCallback& rand_int_cb, 1798 net::NetLog* net_log, 1799 const NetLog::Source& source) { 1800 DeterministicSocketData* data_provider = mock_data().GetNext(); 1801 DeterministicMockUDPClientSocket* socket = 1802 new DeterministicMockUDPClientSocket(net_log, data_provider); 1803 data_provider->set_delegate(socket->AsWeakPtr()); 1804 udp_client_sockets().push_back(socket); 1805 return socket; 1806 } 1807 1808 StreamSocket* DeterministicMockClientSocketFactory::CreateTransportClientSocket( 1809 const AddressList& addresses, 1810 net::NetLog* net_log, 1811 const net::NetLog::Source& source) { 1812 DeterministicSocketData* data_provider = mock_data().GetNext(); 1813 DeterministicMockTCPClientSocket* socket = 1814 new DeterministicMockTCPClientSocket(net_log, data_provider); 1815 data_provider->set_delegate(socket->AsWeakPtr()); 1816 tcp_client_sockets().push_back(socket); 1817 return socket; 1818 } 1819 1820 SSLClientSocket* DeterministicMockClientSocketFactory::CreateSSLClientSocket( 1821 ClientSocketHandle* transport_socket, 1822 const HostPortPair& host_and_port, 1823 const SSLConfig& ssl_config, 1824 const SSLClientSocketContext& context) { 1825 MockSSLClientSocket* socket = 1826 new MockSSLClientSocket(transport_socket, host_and_port, ssl_config, 1827 mock_ssl_data_.GetNext()); 1828 ssl_client_sockets_.push_back(socket); 1829 return socket; 1830 } 1831 1832 void DeterministicMockClientSocketFactory::ClearSSLSessionCache() { 1833 } 1834 1835 MockSOCKSClientSocketPool::MockSOCKSClientSocketPool( 1836 int max_sockets, 1837 int max_sockets_per_group, 1838 ClientSocketPoolHistograms* histograms, 1839 TransportClientSocketPool* transport_pool) 1840 : SOCKSClientSocketPool(max_sockets, max_sockets_per_group, histograms, 1841 NULL, transport_pool, NULL), 1842 transport_pool_(transport_pool) { 1843 } 1844 1845 MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {} 1846 1847 int MockSOCKSClientSocketPool::RequestSocket( 1848 const std::string& group_name, const void* socket_params, 1849 RequestPriority priority, ClientSocketHandle* handle, 1850 const CompletionCallback& callback, const BoundNetLog& net_log) { 1851 return transport_pool_->RequestSocket( 1852 group_name, socket_params, priority, handle, callback, net_log); 1853 } 1854 1855 void MockSOCKSClientSocketPool::CancelRequest( 1856 const std::string& group_name, 1857 ClientSocketHandle* handle) { 1858 return transport_pool_->CancelRequest(group_name, handle); 1859 } 1860 1861 void MockSOCKSClientSocketPool::ReleaseSocket(const std::string& group_name, 1862 StreamSocket* socket, int id) { 1863 return transport_pool_->ReleaseSocket(group_name, socket, id); 1864 } 1865 1866 const char kSOCKS5GreetRequest[] = { 0x05, 0x01, 0x00 }; 1867 const int kSOCKS5GreetRequestLength = arraysize(kSOCKS5GreetRequest); 1868 1869 const char kSOCKS5GreetResponse[] = { 0x05, 0x00 }; 1870 const int kSOCKS5GreetResponseLength = arraysize(kSOCKS5GreetResponse); 1871 1872 const char kSOCKS5OkRequest[] = 1873 { 0x05, 0x01, 0x00, 0x03, 0x04, 'h', 'o', 's', 't', 0x00, 0x50 }; 1874 const int kSOCKS5OkRequestLength = arraysize(kSOCKS5OkRequest); 1875 1876 const char kSOCKS5OkResponse[] = 1877 { 0x05, 0x00, 0x00, 0x01, 127, 0, 0, 1, 0x00, 0x50 }; 1878 const int kSOCKS5OkResponseLength = arraysize(kSOCKS5OkResponse); 1879 1880 } // namespace net 1881