Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/socket_test_util.h"
      6 
      7 #include <algorithm>
      8 
      9 #include "base/basictypes.h"
     10 #include "base/compiler_specific.h"
     11 #include "base/message_loop.h"
     12 #include "net/base/ssl_info.h"
     13 #include "net/socket/socket.h"
     14 #include "testing/gtest/include/gtest/gtest.h"
     15 
     16 namespace net {
     17 
     18 MockClientSocket::MockClientSocket()
     19     : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)),
     20       connected_(false) {
     21 }
     22 
     23 void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
     24   NOTREACHED();
     25 }
     26 
     27 void MockClientSocket::GetSSLCertRequestInfo(
     28     net::SSLCertRequestInfo* cert_request_info) {
     29   NOTREACHED();
     30 }
     31 
     32 SSLClientSocket::NextProtoStatus
     33 MockClientSocket::GetNextProto(std::string* proto) {
     34   proto->clear();
     35   return SSLClientSocket::kNextProtoUnsupported;
     36 }
     37 
     38 void MockClientSocket::Disconnect() {
     39   connected_ = false;
     40 }
     41 
     42 bool MockClientSocket::IsConnected() const {
     43   return connected_;
     44 }
     45 
     46 bool MockClientSocket::IsConnectedAndIdle() const {
     47   return connected_;
     48 }
     49 
     50 int MockClientSocket::GetPeerName(struct sockaddr* name, socklen_t* namelen) {
     51   memset(reinterpret_cast<char *>(name), 0, *namelen);
     52   return net::OK;
     53 }
     54 
     55 void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback,
     56                                         int result) {
     57   MessageLoop::current()->PostTask(FROM_HERE,
     58       method_factory_.NewRunnableMethod(
     59           &MockClientSocket::RunCallback, callback, result));
     60 }
     61 
     62 void MockClientSocket::RunCallback(net::CompletionCallback* callback,
     63                                    int result) {
     64   if (callback)
     65     callback->Run(result);
     66 }
     67 
     68 MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses,
     69                                          net::SocketDataProvider* data)
     70     : addresses_(addresses),
     71       data_(data),
     72       read_offset_(0),
     73       read_data_(false, net::ERR_UNEXPECTED),
     74       need_read_data_(true),
     75       peer_closed_connection_(false),
     76       pending_buf_(NULL),
     77       pending_buf_len_(0),
     78       pending_callback_(NULL) {
     79   DCHECK(data_);
     80   data_->Reset();
     81 }
     82 
     83 int MockTCPClientSocket::Connect(net::CompletionCallback* callback,
     84                                  LoadLog* load_log) {
     85   if (connected_)
     86     return net::OK;
     87   connected_ = true;
     88   if (data_->connect_data().async) {
     89     RunCallbackAsync(callback, data_->connect_data().result);
     90     return net::ERR_IO_PENDING;
     91   }
     92   return data_->connect_data().result;
     93 }
     94 
     95 bool MockTCPClientSocket::IsConnected() const {
     96   return connected_ && !peer_closed_connection_;
     97 }
     98 
     99 int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len,
    100                               net::CompletionCallback* callback) {
    101   if (!connected_)
    102     return net::ERR_UNEXPECTED;
    103 
    104   // If the buffer is already in use, a read is already in progress!
    105   DCHECK(pending_buf_ == NULL);
    106 
    107   // Store our async IO data.
    108   pending_buf_ = buf;
    109   pending_buf_len_ = buf_len;
    110   pending_callback_ = callback;
    111 
    112   if (need_read_data_) {
    113     read_data_ = data_->GetNextRead();
    114     if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) {
    115       // This MockRead is just a marker to instruct us to set
    116       // peer_closed_connection_.  Skip it and get the next one.
    117       read_data_ = data_->GetNextRead();
    118       peer_closed_connection_ = true;
    119     }
    120     // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility
    121     // to complete the async IO manually later (via OnReadComplete).
    122     if (read_data_.result == ERR_IO_PENDING) {
    123       DCHECK(callback);  // We need to be using async IO in this case.
    124       return ERR_IO_PENDING;
    125     }
    126     need_read_data_ = false;
    127   }
    128 
    129   return CompleteRead();
    130 }
    131 
    132 int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len,
    133                                net::CompletionCallback* callback) {
    134   DCHECK(buf);
    135   DCHECK_GT(buf_len, 0);
    136 
    137   if (!connected_)
    138     return net::ERR_UNEXPECTED;
    139 
    140   std::string data(buf->data(), buf_len);
    141   net::MockWriteResult write_result = data_->OnWrite(data);
    142 
    143   if (write_result.async) {
    144     RunCallbackAsync(callback, write_result.result);
    145     return net::ERR_IO_PENDING;
    146   }
    147   return write_result.result;
    148 }
    149 
    150 void MockTCPClientSocket::OnReadComplete(const MockRead& data) {
    151   // There must be a read pending.
    152   DCHECK(pending_buf_);
    153   // You can't complete a read with another ERR_IO_PENDING status code.
    154   DCHECK_NE(ERR_IO_PENDING, data.result);
    155   // Since we've been waiting for data, need_read_data_ should be true.
    156   DCHECK(need_read_data_);
    157 
    158   read_data_ = data;
    159   need_read_data_ = false;
    160 
    161   // The caller is simulating that this IO completes right now.  Don't
    162   // let CompleteRead() schedule a callback.
    163   read_data_.async = false;
    164 
    165   net::CompletionCallback* callback = pending_callback_;
    166   int rv = CompleteRead();
    167   RunCallback(callback, rv);
    168 }
    169 
    170 int MockTCPClientSocket::CompleteRead() {
    171   DCHECK(pending_buf_);
    172   DCHECK(pending_buf_len_ > 0);
    173 
    174   // Save the pending async IO data and reset our |pending_| state.
    175   net::IOBuffer* buf = pending_buf_;
    176   int buf_len = pending_buf_len_;
    177   net::CompletionCallback* callback = pending_callback_;
    178   pending_buf_ = NULL;
    179   pending_buf_len_ = 0;
    180   pending_callback_ = NULL;
    181 
    182   int result = read_data_.result;
    183   DCHECK(result != ERR_IO_PENDING);
    184 
    185   if (read_data_.data) {
    186     if (read_data_.data_len - read_offset_ > 0) {
    187       result = std::min(buf_len, read_data_.data_len - read_offset_);
    188       memcpy(buf->data(), read_data_.data + read_offset_, result);
    189       read_offset_ += result;
    190       if (read_offset_ == read_data_.data_len) {
    191         need_read_data_ = true;
    192         read_offset_ = 0;
    193       }
    194     } else {
    195       result = 0;  // EOF
    196     }
    197   }
    198 
    199   if (read_data_.async) {
    200     DCHECK(callback);
    201     RunCallbackAsync(callback, result);
    202     return net::ERR_IO_PENDING;
    203   }
    204   return result;
    205 }
    206 
    207 class MockSSLClientSocket::ConnectCallback :
    208     public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> {
    209  public:
    210   ConnectCallback(MockSSLClientSocket *ssl_client_socket,
    211                   net::CompletionCallback* user_callback,
    212                   int rv)
    213       : ALLOW_THIS_IN_INITIALIZER_LIST(
    214           net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>(
    215                 this, &ConnectCallback::Wrapper)),
    216         ssl_client_socket_(ssl_client_socket),
    217         user_callback_(user_callback),
    218         rv_(rv) {
    219   }
    220 
    221  private:
    222   void Wrapper(int rv) {
    223     if (rv_ == net::OK)
    224       ssl_client_socket_->connected_ = true;
    225     user_callback_->Run(rv_);
    226     delete this;
    227   }
    228 
    229   MockSSLClientSocket* ssl_client_socket_;
    230   net::CompletionCallback* user_callback_;
    231   int rv_;
    232 };
    233 
    234 MockSSLClientSocket::MockSSLClientSocket(
    235     net::ClientSocket* transport_socket,
    236     const std::string& hostname,
    237     const net::SSLConfig& ssl_config,
    238     net::SSLSocketDataProvider* data)
    239     : transport_(transport_socket),
    240       data_(data) {
    241   DCHECK(data_);
    242 }
    243 
    244 MockSSLClientSocket::~MockSSLClientSocket() {
    245   Disconnect();
    246 }
    247 
    248 void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) {
    249   ssl_info->Reset();
    250 }
    251 
    252 int MockSSLClientSocket::Connect(net::CompletionCallback* callback,
    253                                  LoadLog* load_log) {
    254   ConnectCallback* connect_callback = new ConnectCallback(
    255       this, callback, data_->connect.result);
    256   int rv = transport_->Connect(connect_callback, load_log);
    257   if (rv == net::OK) {
    258     delete connect_callback;
    259     if (data_->connect.async) {
    260       RunCallbackAsync(callback, data_->connect.result);
    261       return net::ERR_IO_PENDING;
    262     }
    263     if (data_->connect.result == net::OK)
    264       connected_ = true;
    265     return data_->connect.result;
    266   }
    267   return rv;
    268 }
    269 
    270 void MockSSLClientSocket::Disconnect() {
    271   MockClientSocket::Disconnect();
    272   if (transport_ != NULL)
    273     transport_->Disconnect();
    274 }
    275 
    276 int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len,
    277                               net::CompletionCallback* callback) {
    278   return transport_->Read(buf, buf_len, callback);
    279 }
    280 
    281 int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len,
    282                                net::CompletionCallback* callback) {
    283   return transport_->Write(buf, buf_len, callback);
    284 }
    285 
    286 MockRead StaticSocketDataProvider::GetNextRead() {
    287   MockRead rv = reads_[read_index_];
    288   if (reads_[read_index_].result != OK ||
    289       reads_[read_index_].data_len != 0)
    290     read_index_++;  // Don't advance past an EOF.
    291   return rv;
    292 }
    293 
    294 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) {
    295   if (!writes_) {
    296     // Not using mock writes; succeed synchronously.
    297     return MockWriteResult(false, data.length());
    298   }
    299 
    300   // Check that what we are writing matches the expectation.
    301   // Then give the mocked return value.
    302   net::MockWrite* w = &writes_[write_index_++];
    303   int result = w->result;
    304   if (w->data) {
    305     // Note - we can simulate a partial write here.  If the expected data
    306     // is a match, but shorter than the write actually written, that is legal.
    307     // Example:
    308     //   Application writes "foobarbaz" (9 bytes)
    309     //   Expected write was "foo" (3 bytes)
    310     //   This is a success, and we return 3 to the application.
    311     std::string expected_data(w->data, w->data_len);
    312     EXPECT_GE(data.length(), expected_data.length());
    313     std::string actual_data(data.substr(0, w->data_len));
    314     EXPECT_EQ(expected_data, actual_data);
    315     if (expected_data != actual_data)
    316       return MockWriteResult(false, net::ERR_UNEXPECTED);
    317     if (result == net::OK)
    318       result = w->data_len;
    319   }
    320   return MockWriteResult(w->async, result);
    321 }
    322 
    323 void StaticSocketDataProvider::Reset() {
    324   read_index_ = 0;
    325   write_index_ = 0;
    326 }
    327 
    328 DynamicSocketDataProvider::DynamicSocketDataProvider()
    329     : short_read_limit_(0),
    330       allow_unconsumed_reads_(false) {
    331 }
    332 
    333 MockRead DynamicSocketDataProvider::GetNextRead() {
    334   if (reads_.empty())
    335     return MockRead(false, ERR_UNEXPECTED);
    336   MockRead result = reads_.front();
    337   if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) {
    338     reads_.pop_front();
    339   } else {
    340     result.data_len = short_read_limit_;
    341     reads_.front().data += result.data_len;
    342     reads_.front().data_len -= result.data_len;
    343   }
    344   return result;
    345 }
    346 
    347 void DynamicSocketDataProvider::Reset() {
    348   reads_.clear();
    349 }
    350 
    351 void DynamicSocketDataProvider::SimulateRead(const char* data) {
    352   if (!allow_unconsumed_reads_) {
    353     EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data;
    354   }
    355   reads_.push_back(MockRead(data));
    356 }
    357 
    358 void MockClientSocketFactory::AddSocketDataProvider(
    359     SocketDataProvider* data) {
    360   mock_data_.Add(data);
    361 }
    362 
    363 void MockClientSocketFactory::AddSSLSocketDataProvider(
    364     SSLSocketDataProvider* data) {
    365   mock_ssl_data_.Add(data);
    366 }
    367 
    368 void MockClientSocketFactory::ResetNextMockIndexes() {
    369   mock_data_.ResetNextIndex();
    370   mock_ssl_data_.ResetNextIndex();
    371 }
    372 
    373 MockTCPClientSocket* MockClientSocketFactory::GetMockTCPClientSocket(
    374     int index) const {
    375   return tcp_client_sockets_[index];
    376 }
    377 
    378 MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket(
    379     int index) const {
    380   return ssl_client_sockets_[index];
    381 }
    382 
    383 ClientSocket* MockClientSocketFactory::CreateTCPClientSocket(
    384     const AddressList& addresses) {
    385   SocketDataProvider* data_provider = mock_data_.GetNext();
    386   MockTCPClientSocket* socket =
    387       new MockTCPClientSocket(addresses, data_provider);
    388   data_provider->set_socket(socket);
    389   tcp_client_sockets_.push_back(socket);
    390   return socket;
    391 }
    392 
    393 SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket(
    394     ClientSocket* transport_socket,
    395     const std::string& hostname,
    396     const SSLConfig& ssl_config) {
    397   MockSSLClientSocket* socket =
    398       new MockSSLClientSocket(transport_socket, hostname, ssl_config,
    399                               mock_ssl_data_.GetNext());
    400   ssl_client_sockets_.push_back(socket);
    401   return socket;
    402 }
    403 
    404 int TestSocketRequest::WaitForResult() {
    405   return callback_.WaitForResult();
    406 }
    407 
    408 void TestSocketRequest::RunWithParams(const Tuple1<int>& params) {
    409   callback_.RunWithParams(params);
    410   (*completion_count_)++;
    411   request_order_->push_back(this);
    412 }
    413 
    414 // static
    415 const int ClientSocketPoolTest::kIndexOutOfBounds = -1;
    416 
    417 // static
    418 const int ClientSocketPoolTest::kRequestNotFound = -2;
    419 
    420 void ClientSocketPoolTest::SetUp() {
    421   completion_count_ = 0;
    422 }
    423 
    424 void ClientSocketPoolTest::TearDown() {
    425   // The tests often call Reset() on handles at the end which may post
    426   // DoReleaseSocket() tasks.
    427   // Pending tasks created by client_socket_pool_base_unittest.cc are
    428   // posted two milliseconds into the future and thus won't become
    429   // scheduled until that time.
    430   // We wait a few milliseconds to make sure that all such future tasks
    431   // are ready to run, before calling RunAllPending(). This will work
    432   // correctly even if Sleep() finishes late (and it should never finish
    433   // early), as all we have to ensure is that actual wall-time has progressed
    434   // past the scheduled starting time of the pending task.
    435   PlatformThread::Sleep(10);
    436   MessageLoop::current()->RunAllPending();
    437 }
    438 
    439 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) {
    440   index--;
    441   if (index >= requests_.size())
    442     return kIndexOutOfBounds;
    443 
    444   for (size_t i = 0; i < request_order_.size(); i++)
    445     if (requests_[index] == request_order_[i])
    446       return i + 1;
    447 
    448   return kRequestNotFound;
    449 }
    450 
    451 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) {
    452   ScopedVector<TestSocketRequest>::iterator i;
    453   for (i = requests_.begin(); i != requests_.end(); ++i) {
    454     if ((*i)->handle()->is_initialized()) {
    455       if (keep_alive == NO_KEEP_ALIVE)
    456         (*i)->handle()->socket()->Disconnect();
    457       (*i)->handle()->Reset();
    458       MessageLoop::current()->RunAllPending();
    459       return true;
    460     }
    461   }
    462   return false;
    463 }
    464 
    465 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) {
    466   bool released_one;
    467   do {
    468     released_one = ReleaseOneConnection(keep_alive);
    469   } while (released_one);
    470 }
    471 
    472 }  // namespace net
    473