Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2009 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
      6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_
      7 
      8 #include <deque>
      9 #include <string>
     10 #include <vector>
     11 
     12 #include "base/basictypes.h"
     13 #include "base/logging.h"
     14 #include "base/scoped_ptr.h"
     15 #include "base/scoped_vector.h"
     16 #include "net/base/address_list.h"
     17 #include "net/base/io_buffer.h"
     18 #include "net/base/net_errors.h"
     19 #include "net/base/ssl_config_service.h"
     20 #include "net/base/test_completion_callback.h"
     21 #include "net/socket/client_socket_factory.h"
     22 #include "net/socket/client_socket_handle.h"
     23 #include "net/socket/ssl_client_socket.h"
     24 #include "testing/gtest/include/gtest/gtest.h"
     25 
     26 namespace net {
     27 
     28 enum {
     29   // A private network error code used by the socket test utility classes.
     30   // If the |result| member of a MockRead is
     31   // ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, that MockRead is just a
     32   // marker that indicates the peer will close the connection after the next
     33   // MockRead.  The other members of that MockRead are ignored.
     34   ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000,
     35 };
     36 
     37 class ClientSocket;
     38 class LoadLog;
     39 class MockClientSocket;
     40 class SSLClientSocket;
     41 
     42 struct MockConnect {
     43   // Asynchronous connection success.
     44   MockConnect() : async(true), result(OK) { }
     45   MockConnect(bool a, int r) : async(a), result(r) { }
     46 
     47   bool async;
     48   int result;
     49 };
     50 
     51 struct MockRead {
     52   // Default
     53   MockRead() : async(false), result(0), data(NULL), data_len(0) {}
     54 
     55   // Read failure (no data).
     56   MockRead(bool async, int result) : async(async) , result(result), data(NULL),
     57       data_len(0) { }
     58 
     59   // Asynchronous read success (inferred data length).
     60   explicit MockRead(const char* data) : async(true),  result(0), data(data),
     61       data_len(strlen(data)) { }
     62 
     63   // Read success (inferred data length).
     64   MockRead(bool async, const char* data) : async(async), result(0), data(data),
     65       data_len(strlen(data)) { }
     66 
     67   // Read success.
     68   MockRead(bool async, const char* data, int data_len) : async(async),
     69       result(0), data(data), data_len(data_len) { }
     70 
     71   bool async;
     72   int result;
     73   const char* data;
     74   int data_len;
     75 };
     76 
// MockWrite reuses MockRead's fields, but gives them different meanings:
// {data, data_len} is the input that MockTCPClientSocket::Write() is
// expected to receive, while {async, result} control what Write() returns
// and whether it completes asynchronously.
typedef MockRead MockWrite;
     82 
     83 struct MockWriteResult {
     84   MockWriteResult(bool async, int result) : async(async), result(result) {}
     85 
     86   bool async;
     87   int result;
     88 };
     89 
     90 // The SocketDataProvider is an interface used by the MockClientSocket
     91 // for getting data about individual reads and writes on the socket.
     92 class SocketDataProvider {
     93  public:
     94   SocketDataProvider() : socket_(NULL) {}
     95 
     96   virtual ~SocketDataProvider() {}
     97 
     98   // Returns the buffer and result code for the next simulated read.
     99   // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
    100   // that it will be called via the MockClientSocket::OnReadComplete()
    101   // function at a later time.
    102   virtual MockRead GetNextRead() = 0;
    103   virtual MockWriteResult OnWrite(const std::string& data) = 0;
    104   virtual void Reset() = 0;
    105 
    106   // Accessor for the socket which is using the SocketDataProvider.
    107   MockClientSocket* socket() { return socket_; }
    108   void set_socket(MockClientSocket* socket) { socket_ = socket; }
    109 
    110   MockConnect connect_data() const { return connect_; }
    111   void set_connect_data(const MockConnect& connect) { connect_ = connect; }
    112 
    113  private:
    114   MockConnect connect_;
    115   MockClientSocket* socket_;
    116 
    117   DISALLOW_COPY_AND_ASSIGN(SocketDataProvider);
    118 };
    119 
    120 // SocketDataProvider which responds based on static tables of mock reads and
    121 // writes.
    122 class StaticSocketDataProvider : public SocketDataProvider {
    123  public:
    124   StaticSocketDataProvider() : reads_(NULL), read_index_(0),
    125       writes_(NULL), write_index_(0) {}
    126   StaticSocketDataProvider(MockRead* r, MockWrite* w) : reads_(r),
    127       read_index_(0), writes_(w), write_index_(0) {}
    128 
    129   // SocketDataProvider methods:
    130   virtual MockRead GetNextRead();
    131   virtual MockWriteResult OnWrite(const std::string& data);
    132   virtual void Reset();
    133 
    134   // If the test wishes to verify that all data is consumed, it can include
    135   // a EOF MockRead or MockWrite, which is a zero-length Read or Write.
    136   // The test can then call at_read_eof() or at_write_eof() to verify that
    137   // all data has been consumed.
    138   bool at_read_eof() const { return reads_[read_index_].data_len == 0; }
    139   bool at_write_eof() const { return writes_[write_index_].data_len == 0; }
    140 
    141  private:
    142   MockRead* reads_;
    143   int read_index_;
    144   MockWrite* writes_;
    145   int write_index_;
    146 
    147   DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider);
    148 };
    149 
    150 // SocketDataProvider which can make decisions about next mock reads based on
    151 // received writes. It can also be used to enforce order of operations, for
    152 // example that tested code must send the "Hello!" message before receiving
    153 // response. This is useful for testing conversation-like protocols like FTP.
    154 class DynamicSocketDataProvider : public SocketDataProvider {
    155  public:
    156   DynamicSocketDataProvider();
    157 
    158   // SocketDataProvider methods:
    159   virtual MockRead GetNextRead();
    160   virtual MockWriteResult OnWrite(const std::string& data) = 0;
    161   virtual void Reset();
    162 
    163   int short_read_limit() const { return short_read_limit_; }
    164   void set_short_read_limit(int limit) { short_read_limit_ = limit; }
    165 
    166   void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; }
    167 
    168  protected:
    169   // The next time there is a read from this socket, it will return |data|.
    170   // Before calling SimulateRead next time, the previous data must be consumed.
    171   void SimulateRead(const char* data);
    172 
    173  private:
    174   std::deque<MockRead> reads_;
    175 
    176   // Max number of bytes we will read at a time. 0 means no limit.
    177   int short_read_limit_;
    178 
    179   // If true, we'll not require the client to consume all data before we
    180   // mock the next read.
    181   bool allow_unconsumed_reads_;
    182 
    183   DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider);
    184 };
    185 
    186 // SSLSocketDataProviders only need to keep track of the return code from calls
    187 // to Connect().
    188 struct SSLSocketDataProvider {
    189   SSLSocketDataProvider(bool async, int result) : connect(async, result) { }
    190 
    191   MockConnect connect;
    192 };
    193 
    194 // Holds an array of SocketDataProvider elements.  As Mock{TCP,SSL}ClientSocket
    195 // objects get instantiated, they take their data from the i'th element of this
    196 // array.
    197 template<typename T>
    198 class SocketDataProviderArray {
    199  public:
    200   SocketDataProviderArray() : next_index_(0) {
    201   }
    202 
    203   T* GetNext() {
    204     DCHECK(next_index_ < data_providers_.size());
    205     return data_providers_[next_index_++];
    206   }
    207 
    208   void Add(T* data_provider) {
    209     DCHECK(data_provider);
    210     data_providers_.push_back(data_provider);
    211   }
    212 
    213   void ResetNextIndex() {
    214     next_index_ = 0;
    215   }
    216 
    217  private:
    218   // Index of the next |data_providers_| element to use. Not an iterator
    219   // because those are invalidated on vector reallocation.
    220   size_t next_index_;
    221 
    222   // SocketDataProviders to be returned.
    223   std::vector<T*> data_providers_;
    224 };
    225 
    226 class MockTCPClientSocket;
    227 class MockSSLClientSocket;
    228 
    229 // ClientSocketFactory which contains arrays of sockets of each type.
    230 // You should first fill the arrays using AddMock{SSL,}Socket. When the factory
    231 // is asked to create a socket, it takes next entry from appropriate array.
    232 // You can use ResetNextMockIndexes to reset that next entry index for all mock
    233 // socket types.
    234 class MockClientSocketFactory : public ClientSocketFactory {
    235  public:
    236   void AddSocketDataProvider(SocketDataProvider* socket);
    237   void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
    238   void ResetNextMockIndexes();
    239 
    240   // Return |index|-th MockTCPClientSocket (starting from 0) that the factory
    241   // created.
    242   MockTCPClientSocket* GetMockTCPClientSocket(int index) const;
    243 
    244   // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
    245   // created.
    246   MockSSLClientSocket* GetMockSSLClientSocket(int index) const;
    247 
    248   // ClientSocketFactory
    249   virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses);
    250   virtual SSLClientSocket* CreateSSLClientSocket(
    251       ClientSocket* transport_socket,
    252       const std::string& hostname,
    253       const SSLConfig& ssl_config);
    254 
    255  private:
    256   SocketDataProviderArray<SocketDataProvider> mock_data_;
    257   SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
    258 
    259   // Store pointers to handed out sockets in case the test wants to get them.
    260   std::vector<MockTCPClientSocket*> tcp_client_sockets_;
    261   std::vector<MockSSLClientSocket*> ssl_client_sockets_;
    262 };
    263 
    264 class MockClientSocket : public net::SSLClientSocket {
    265  public:
    266   MockClientSocket();
    267 
    268   // ClientSocket methods:
    269   virtual int Connect(net::CompletionCallback* callback, LoadLog* load_log) = 0;
    270   virtual void Disconnect();
    271   virtual bool IsConnected() const;
    272   virtual bool IsConnectedAndIdle() const;
    273   virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen);
    274 
    275   // SSLClientSocket methods:
    276   virtual void GetSSLInfo(net::SSLInfo* ssl_info);
    277   virtual void GetSSLCertRequestInfo(
    278       net::SSLCertRequestInfo* cert_request_info);
    279   virtual NextProtoStatus GetNextProto(std::string* proto);
    280 
    281   // Socket methods:
    282   virtual int Read(net::IOBuffer* buf, int buf_len,
    283                    net::CompletionCallback* callback) = 0;
    284   virtual int Write(net::IOBuffer* buf, int buf_len,
    285                     net::CompletionCallback* callback) = 0;
    286   virtual bool SetReceiveBufferSize(int32 size) { return true; }
    287   virtual bool SetSendBufferSize(int32 size) { return true; }
    288 
    289   // If an async IO is pending because the SocketDataProvider returned
    290   // ERR_IO_PENDING, then the MockClientSocket waits until this OnReadComplete
    291   // is called to complete the asynchronous read operation.
    292   // data.async is ignored, and this read is completed synchronously as
    293   // part of this call.
    294   virtual void OnReadComplete(const MockRead& data) = 0;
    295 
    296  protected:
    297   void RunCallbackAsync(net::CompletionCallback* callback, int result);
    298   void RunCallback(net::CompletionCallback*, int result);
    299 
    300   ScopedRunnableMethodFactory<MockClientSocket> method_factory_;
    301 
    302   // True if Connect completed successfully and Disconnect hasn't been called.
    303   bool connected_;
    304 };
    305 
    306 class MockTCPClientSocket : public MockClientSocket {
    307  public:
    308   MockTCPClientSocket(const net::AddressList& addresses,
    309                       net::SocketDataProvider* socket);
    310 
    311   // ClientSocket methods:
    312   virtual int Connect(net::CompletionCallback* callback,
    313                       LoadLog* load_log);
    314   virtual bool IsConnected() const;
    315   virtual bool IsConnectedAndIdle() const { return IsConnected(); }
    316 
    317   // Socket methods:
    318   virtual int Read(net::IOBuffer* buf, int buf_len,
    319                    net::CompletionCallback* callback);
    320   virtual int Write(net::IOBuffer* buf, int buf_len,
    321                     net::CompletionCallback* callback);
    322 
    323   virtual void OnReadComplete(const MockRead& data);
    324 
    325   net::AddressList addresses() const { return addresses_; }
    326 
    327  private:
    328   int CompleteRead();
    329 
    330   net::AddressList addresses_;
    331 
    332   net::SocketDataProvider* data_;
    333   int read_offset_;
    334   net::MockRead read_data_;
    335   bool need_read_data_;
    336 
    337   // True if the peer has closed the connection.  This allows us to simulate
    338   // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
    339   // TCPClientSocket.
    340   bool peer_closed_connection_;
    341 
    342   // While an asynchronous IO is pending, we save our user-buffer state.
    343   net::IOBuffer* pending_buf_;
    344   int pending_buf_len_;
    345   net::CompletionCallback* pending_callback_;
    346 };
    347 
    348 class MockSSLClientSocket : public MockClientSocket {
    349  public:
    350   MockSSLClientSocket(
    351       net::ClientSocket* transport_socket,
    352       const std::string& hostname,
    353       const net::SSLConfig& ssl_config,
    354       net::SSLSocketDataProvider* socket);
    355   ~MockSSLClientSocket();
    356 
    357   virtual void GetSSLInfo(net::SSLInfo* ssl_info);
    358 
    359   virtual int Connect(net::CompletionCallback* callback, LoadLog* load_log);
    360   virtual void Disconnect();
    361 
    362   // Socket methods:
    363   virtual int Read(net::IOBuffer* buf, int buf_len,
    364                    net::CompletionCallback* callback);
    365   virtual int Write(net::IOBuffer* buf, int buf_len,
    366                     net::CompletionCallback* callback);
    367 
    368   // This MockSocket does not implement the manual async IO feature.
    369   virtual void OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); }
    370 
    371  private:
    372   class ConnectCallback;
    373 
    374   scoped_ptr<ClientSocket> transport_;
    375   net::SSLSocketDataProvider* data_;
    376 };
    377 
    378 class TestSocketRequest : public CallbackRunner< Tuple1<int> > {
    379  public:
    380   TestSocketRequest(
    381       std::vector<TestSocketRequest*>* request_order,
    382       size_t* completion_count)
    383       : request_order_(request_order),
    384         completion_count_(completion_count) {
    385     DCHECK(request_order);
    386     DCHECK(completion_count);
    387   }
    388 
    389   ClientSocketHandle* handle() { return &handle_; }
    390 
    391   int WaitForResult();
    392   virtual void RunWithParams(const Tuple1<int>& params);
    393 
    394  private:
    395   ClientSocketHandle handle_;
    396   std::vector<TestSocketRequest*>* request_order_;
    397   size_t* completion_count_;
    398   TestCompletionCallback callback_;
    399 };
    400 
    401 class ClientSocketPoolTest : public testing::Test {
    402  protected:
    403   enum KeepAlive {
    404     KEEP_ALIVE,
    405 
    406     // A socket will be disconnected in addition to handle being reset.
    407     NO_KEEP_ALIVE,
    408   };
    409 
    410   static const int kIndexOutOfBounds;
    411   static const int kRequestNotFound;
    412 
    413   virtual void SetUp();
    414   virtual void TearDown();
    415 
    416   template <typename PoolType, typename SocketParams>
    417   int StartRequestUsingPool(PoolType* socket_pool,
    418                             const std::string& group_name,
    419                             RequestPriority priority,
    420                             const SocketParams& socket_params) {
    421     DCHECK(socket_pool);
    422     TestSocketRequest* request = new TestSocketRequest(&request_order_,
    423                                                        &completion_count_);
    424     requests_.push_back(request);
    425     int rv = request->handle()->Init(
    426         group_name, socket_params, priority, request,
    427         socket_pool, NULL);
    428     if (rv != ERR_IO_PENDING)
    429       request_order_.push_back(request);
    430     return rv;
    431   }
    432 
    433   // Provided there were n requests started, takes |index| in range 1..n
    434   // and returns order in which that request completed, in range 1..n,
    435   // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
    436   // if that request did not complete (for example was canceled).
    437   int GetOrderOfRequest(size_t index);
    438 
    439   // Resets first initialized socket handle from |requests_|. If found such
    440   // a handle, returns true.
    441   bool ReleaseOneConnection(KeepAlive keep_alive);
    442 
    443   // Releases connections until there is nothing to release.
    444   void ReleaseAllConnections(KeepAlive keep_alive);
    445 
    446   ScopedVector<TestSocketRequest> requests_;
    447   std::vector<TestSocketRequest*> request_order_;
    448   size_t completion_count_;
    449 };
    450 
    451 }  // namespace net
    452 
    453 #endif  // NET_SOCKET_SOCKET_TEST_UTIL_H_
    454