Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_
      6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_
      7 #pragma once
      8 
      9 #include <cstring>
     10 #include <deque>
     11 #include <string>
     12 #include <vector>
     13 
     14 #include "base/basictypes.h"
     15 #include "base/callback.h"
     16 #include "base/logging.h"
     17 #include "base/memory/scoped_ptr.h"
     18 #include "base/memory/scoped_vector.h"
     19 #include "base/memory/weak_ptr.h"
     20 #include "base/string16.h"
     21 #include "net/base/address_list.h"
     22 #include "net/base/io_buffer.h"
     23 #include "net/base/net_errors.h"
     24 #include "net/base/net_log.h"
     25 #include "net/base/ssl_config_service.h"
     26 #include "net/base/test_completion_callback.h"
     27 #include "net/http/http_auth_controller.h"
     28 #include "net/http/http_proxy_client_socket_pool.h"
     29 #include "net/socket/client_socket_factory.h"
     30 #include "net/socket/client_socket_handle.h"
     31 #include "net/socket/socks_client_socket_pool.h"
     32 #include "net/socket/ssl_client_socket.h"
     33 #include "net/socket/ssl_client_socket_pool.h"
     34 #include "net/socket/transport_client_socket_pool.h"
     35 #include "testing/gtest/include/gtest/gtest.h"
     36 
     37 namespace net {
     38 
// Test-only error codes.  Declared as an unnamed enum so the value converts
// implicitly to int and can be stored wherever an int result code is expected
// (e.g. MockRead::result).
// NOTE(review): -10000 is presumably chosen to lie outside the range of real
// net error codes -- confirm against net/base/net_error_list.h.
enum {
  // A private network error code used by the socket test utility classes.
  // If the |result| member of a MockRead is
  // ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, that MockRead is just a
  // marker that indicates the peer will close the connection after the next
  // MockRead.  The other members of that MockRead are ignored.
  ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000,
};
     47 
     48 class ClientSocket;
     49 class MockClientSocket;
     50 class SSLClientSocket;
     51 class SSLHostInfo;
     52 
     53 struct MockConnect {
     54   // Asynchronous connection success.
     55   MockConnect() : async(true), result(OK) { }
     56   MockConnect(bool a, int r) : async(a), result(r) { }
     57 
     58   bool async;
     59   int result;
     60 };
     61 
     62 struct MockRead {
     63   // Flag to indicate that the message loop should be terminated.
     64   enum {
     65     STOPLOOP = 1 << 31
     66   };
     67 
     68   // Default
     69   MockRead() : async(false), result(0), data(NULL), data_len(0),
     70       sequence_number(0), time_stamp(base::Time::Now()) {}
     71 
     72   // Read failure (no data).
     73   MockRead(bool async, int result) : async(async) , result(result), data(NULL),
     74       data_len(0), sequence_number(0), time_stamp(base::Time::Now()) { }
     75 
     76   // Read failure (no data), with sequence information.
     77   MockRead(bool async, int result, int seq) : async(async) , result(result),
     78       data(NULL), data_len(0), sequence_number(seq),
     79       time_stamp(base::Time::Now()) { }
     80 
     81   // Asynchronous read success (inferred data length).
     82   explicit MockRead(const char* data) : async(true),  result(0), data(data),
     83       data_len(strlen(data)), sequence_number(0),
     84       time_stamp(base::Time::Now()) { }
     85 
     86   // Read success (inferred data length).
     87   MockRead(bool async, const char* data) : async(async), result(0), data(data),
     88       data_len(strlen(data)), sequence_number(0),
     89       time_stamp(base::Time::Now()) { }
     90 
     91   // Read success.
     92   MockRead(bool async, const char* data, int data_len) : async(async),
     93       result(0), data(data), data_len(data_len), sequence_number(0),
     94       time_stamp(base::Time::Now()) { }
     95 
     96   // Read success (inferred data length) with sequence information.
     97   MockRead(bool async, int seq, const char* data) : async(async),
     98       result(0), data(data), data_len(strlen(data)), sequence_number(seq),
     99       time_stamp(base::Time::Now()) { }
    100 
    101   // Read success with sequence information.
    102   MockRead(bool async, const char* data, int data_len, int seq) : async(async),
    103       result(0), data(data), data_len(data_len), sequence_number(seq),
    104       time_stamp(base::Time::Now()) { }
    105 
    106   bool async;
    107   int result;
    108   const char* data;
    109   int data_len;
    110 
    111   // For OrderedSocketData, which only allows reads to occur in a particular
    112   // sequence.  If a read occurs before the given |sequence_number| is reached,
    113   // an ERR_IO_PENDING is returned.
    114   int sequence_number;      // The sequence number at which a read is allowed
    115                             // to occur.
    116   base::Time time_stamp;    // The time stamp at which the operation occurred.
    117 };
    118 
// MockWrite uses the same member fields as MockRead, but with different
// meanings. The expected input to MockTCPClientSocket::Write() is given
// by {data, data_len}, and the return value of Write() is controlled by
// {async, result}.  |sequence_number| is shared as well; for example,
// DeterministicSocketData orders writes and reads by it.
//
// Because this is a typedef rather than a distinct type, MockRead and
// MockWrite are fully interchangeable: overloads cannot distinguish them.
typedef MockRead MockWrite;
    124 
    125 struct MockWriteResult {
    126   MockWriteResult(bool async, int result) : async(async), result(result) {}
    127 
    128   bool async;
    129   int result;
    130 };
    131 
// The SocketDataProvider is an interface used by the MockClientSocket
// for getting data about individual reads and writes on the socket.
// Copying is disallowed.
class SocketDataProvider {
 public:
  SocketDataProvider() : socket_(NULL) {}

  virtual ~SocketDataProvider() {}

  // Returns the buffer and result code for the next simulated read.
  // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller
  // that it will be called via the MockClientSocket::OnReadComplete()
  // function at a later time.
  virtual MockRead GetNextRead() = 0;
  // Receives the bytes the socket is writing and returns how that write
  // should complete (async flag and result value).
  virtual MockWriteResult OnWrite(const std::string& data) = 0;
  // Resets the provider's state; the exact meaning is defined by each
  // subclass.
  virtual void Reset() = 0;

  // Accessor for the socket which is using the SocketDataProvider.
  // The pointer is not owned and is NULL until set_socket() is called.
  MockClientSocket* socket() { return socket_; }
  void set_socket(MockClientSocket* socket) { socket_ = socket; }

  // How the socket's Connect() should complete.  Defaults to a
  // default-constructed MockConnect (asynchronous OK).  Returned by value.
  MockConnect connect_data() const { return connect_; }
  void set_connect_data(const MockConnect& connect) { connect_ = connect; }

 private:
  MockConnect connect_;
  MockClientSocket* socket_;  // Not owned.

  DISALLOW_COPY_AND_ASSIGN(SocketDataProvider);
};
    161 
// SocketDataProvider which responds based on static tables of mock reads and
// writes.
//
// The provider keeps raw pointers to the |reads| and |writes| arrays rather
// than copies.
// NOTE(review): presumably the arrays must outlive the provider -- confirm
// in socket_test_util.cc.
class StaticSocketDataProvider : public SocketDataProvider {
 public:
  StaticSocketDataProvider();
  StaticSocketDataProvider(MockRead* reads, size_t reads_count,
                           MockWrite* writes, size_t writes_count);
  virtual ~StaticSocketDataProvider();

  // These functions get access to the next available read and write data.
  const MockRead& PeekRead() const;
  const MockWrite& PeekWrite() const;
  // These functions get random access to the read and write data, for timing.
  const MockRead& PeekRead(size_t index) const;
  const MockWrite& PeekWrite(size_t index) const;
  // Current positions in, and total sizes of, the read and write tables.
  size_t read_index() const { return read_index_; }
  size_t write_index() const { return write_index_; }
  size_t read_count() const { return read_count_; }
  size_t write_count() const { return write_count_; }

  // True once the corresponding index has reached the end of its table.
  bool at_read_eof() const { return read_index_ >= read_count_; }
  bool at_write_eof() const { return write_index_ >= write_count_; }

  // Hook for subclasses; a no-op here.
  virtual void CompleteRead() {}

  // SocketDataProvider methods:
  virtual MockRead GetNextRead();
  virtual MockWriteResult OnWrite(const std::string& data);
  virtual void Reset();

 private:
  MockRead* reads_;       // Not copied; see class comment.
  size_t read_index_;
  size_t read_count_;
  MockWrite* writes_;     // Not copied; see class comment.
  size_t write_index_;
  size_t write_count_;

  DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider);
};
    202 
// SocketDataProvider which can make decisions about next mock reads based on
// received writes. It can also be used to enforce order of operations, for
// example that tested code must send the "Hello!" message before receiving
// response. This is useful for testing conversation-like protocols like FTP.
//
// Subclasses implement OnWrite() and call SimulateRead() from it to queue
// the response to each write.
class DynamicSocketDataProvider : public SocketDataProvider {
 public:
  DynamicSocketDataProvider();
  virtual ~DynamicSocketDataProvider();

  // Max number of bytes returned per read; 0 means no limit.
  int short_read_limit() const { return short_read_limit_; }
  void set_short_read_limit(int limit) { short_read_limit_ = limit; }

  // See |allow_unconsumed_reads_|.
  void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; }

  // SocketDataProvider methods:
  virtual MockRead GetNextRead();
  // Still pure virtual: each subclass decides how to react to writes.
  virtual MockWriteResult OnWrite(const std::string& data) = 0;
  virtual void Reset();

 protected:
  // The next time there is a read from this socket, it will return |data|.
  // Before calling SimulateRead next time, the previous data must be consumed.
  void SimulateRead(const char* data, size_t length);
  // Convenience overload for NUL-terminated |data| (terminator not included).
  void SimulateRead(const char* data) {
    SimulateRead(data, std::strlen(data));
  }

 private:
  std::deque<MockRead> reads_;

  // Max number of bytes we will read at a time. 0 means no limit.
  int short_read_limit_;

  // If true, we'll not require the client to consume all data before we
  // mock the next read.
  bool allow_unconsumed_reads_;

  DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider);
};
    242 
// SSLSocketDataProviders only need to keep track of the return code from calls
// to Connect().  The remaining fields are values the mock SSL socket reports
// back through its SSLClientSocket accessors.
struct SSLSocketDataProvider {
  SSLSocketDataProvider(bool async, int result);
  ~SSLSocketDataProvider();

  // How the SSL handshake (Connect()) completes.
  MockConnect connect;
  // Reported by the mock socket's GetNextProto().
  SSLClientSocket::NextProtoStatus next_proto_status;
  std::string next_proto;
  bool was_npn_negotiated;
  // Raw pointer.
  // NOTE(review): ownership is not visible here -- confirm in
  // socket_test_util.cc.
  net::SSLCertRequestInfo* cert_request_info;
  // NOTE(review): trailing underscore is inconsistent with the other public
  // fields; kept because callers reference it by this name.
  scoped_refptr<X509Certificate> cert_;
};
    256 
// A DataProvider where the client must write a request before the reads (e.g.
// the response) will complete.
class DelayedSocketData : public StaticSocketDataProvider,
                          public base::RefCounted<DelayedSocketData> {
 public:
  // |write_delay| the number of MockWrites to complete before allowing
  //               a MockRead to complete.
  // |reads| the list of MockRead completions.
  // |writes| the list of MockWrite completions.
  // Note: All MockReads and MockWrites must be async.
  // Note: The MockRead and MockWrite lists must end with an EOF
  //       e.g. a MockRead(true, 0, 0);
  DelayedSocketData(int write_delay,
                    MockRead* reads, size_t reads_count,
                    MockWrite* writes, size_t writes_count);

  // |connect| the result for the connect phase.
  // |write_delay| the number of MockWrites to complete before allowing
  //               a MockRead to complete.
  // |reads| the list of MockRead completions.
  // |writes| the list of MockWrite completions.
  // Note: All MockReads and MockWrites must be async.
  // Note: The MockRead and MockWrite lists must end with an EOF
  //       e.g. a MockRead(true, 0, 0);
  DelayedSocketData(const MockConnect& connect, int write_delay,
                    MockRead* reads, size_t reads_count,
                    MockWrite* writes, size_t writes_count);
  // NOTE(review): unlike OrderedSocketData, this RefCounted class has a
  // public destructor; consider making it private with a
  // `friend class base::RefCounted<DelayedSocketData>` declaration.
  ~DelayedSocketData();

  void ForceNextRead();

  // StaticSocketDataProvider:
  virtual MockRead GetNextRead();
  virtual MockWriteResult OnWrite(const std::string& data);
  virtual void Reset();
  virtual void CompleteRead();

 private:
  int write_delay_;  // MockWrites still to complete before reads may complete.
  ScopedRunnableMethodFactory<DelayedSocketData> factory_;
};
    298 
// A DataProvider where the reads are ordered.
// If a read is requested before its sequence number is reached, we return an
// ERR_IO_PENDING (that way we don't have to explicitly add a MockRead just to
// wait).
// The sequence number is incremented on every read and write operation.
// The message loop may be interrupted by setting the high bit of the sequence
// number in the MockRead's sequence number (MockRead::STOPLOOP).  When that
// MockRead is reached, we post a Quit message to the loop.  This allows us to
// interrupt the reading of data before a complete message has arrived, and
// provides support for testing server push when the request is issued while
// the response is in the middle of being received.
class OrderedSocketData : public StaticSocketDataProvider,
                          public base::RefCounted<OrderedSocketData> {
 public:
  // |reads| the list of MockRead completions.
  // |writes| the list of MockWrite completions.
  // Note: All MockReads and MockWrites must be async.
  // Note: The MockRead and MockWrite lists must end with an EOF
  //       e.g. a MockRead(true, 0, 0);
  OrderedSocketData(MockRead* reads, size_t reads_count,
                    MockWrite* writes, size_t writes_count);

  // |connect| the result for the connect phase.
  // |reads| the list of MockRead completions.
  // |writes| the list of MockWrite completions.
  // Note: All MockReads and MockWrites must be async.
  // Note: The MockRead and MockWrite lists must end with an EOF
  //       e.g. a MockRead(true, 0, 0);
  OrderedSocketData(const MockConnect& connect,
                    MockRead* reads, size_t reads_count,
                    MockWrite* writes, size_t writes_count);

  // Stores |callback| as a raw pointer; it is not owned.
  void SetCompletionCallback(CompletionCallback* callback) {
    callback_ = callback;
  }

  // Posts a quit message to the current message loop, if one is running.
  void EndLoop();

  // StaticSocketDataProvider:
  virtual MockRead GetNextRead();
  virtual MockWriteResult OnWrite(const std::string& data);
  virtual void Reset();
  virtual void CompleteRead();

 private:
  // Destroyed only through base::RefCounted when the last reference goes away.
  friend class base::RefCounted<OrderedSocketData>;
  virtual ~OrderedSocketData();

  int sequence_number_;
  int loop_stop_stage_;
  CompletionCallback* callback_;  // Not owned.
  bool blocked_;
  ScopedRunnableMethodFactory<OrderedSocketData> factory_;
};
    354 
    355 class DeterministicMockTCPClientSocket;
    356 
    357 // This class gives the user full control over the network activity,
    358 // specifically the timing of the COMPLETION of I/O operations.  Regardless of
    359 // the order in which I/O operations are initiated, this class ensures that they
    360 // complete in the correct order.
    361 //
    362 // Network activity is modeled as a sequence of numbered steps which is
    363 // incremented whenever an I/O operation completes.  This can happen under two
    364 // different circumstances:
    365 //
    366 // 1) Performing a synchronous I/O operation.  (Invoking Read() or Write()
    367 //    when the corresponding MockRead or MockWrite is marked !async).
    368 // 2) Running the Run() method of this class.  The run method will invoke
    369 //    the current MessageLoop, running all pending events, and will then
    370 //    invoke any pending IO callbacks.
    371 //
    372 // In addition, this class allows for I/O processing to "stop" at a specified
    373 // step, by calling SetStop(int) or StopAfter(int).  Initiating an I/O operation
    374 // by calling Read() or Write() while stopped is permitted if the operation is
    375 // asynchronous.  It is an error to perform synchronous I/O while stopped.
    376 //
    377 // When creating the MockReads and MockWrites, note that the sequence number
    378 // refers to the number of the step in which the I/O will complete.  In the
    379 // case of synchronous I/O, this will be the same step as the I/O is initiated.
    380 // However, in the case of asynchronous I/O, this I/O may be initiated in
    381 // a much earlier step. Furthermore, when the a Read() or Write() is separated
    382 // from its completion by other Read() or Writes()'s, it can not be marked
    383 // synchronous.  If it is, ERR_UNUEXPECTED will be returned indicating that a
    384 // synchronous Read() or Write() could not be completed synchronously because of
    385 // the specific ordering constraints.
    386 //
    387 // Sequence numbers are preserved across both reads and writes. There should be
    388 // no gaps in sequence numbers, and no repeated sequence numbers. i.e.
    389 //  MockRead reads[] = {
    390 //    MockRead(false, "first read", length, 0)   // sync
    391 //    MockRead(true, "second read", length, 2)   // async
    392 //  };
    393 //  MockWrite writes[] = {
    394 //    MockWrite(true, "first write", length, 1),    // async
    395 //    MockWrite(false, "second write", length, 3),  // sync
    396 //  };
    397 //
    398 // Example control flow:
    399 // Read() is called.  The current step is 0.  The first available read is
    400 // synchronous, so the call to Read() returns length.  The current step is
    401 // now 1.  Next, Read() is called again.  The next available read can
    402 // not be completed until step 2, so Read() returns ERR_IO_PENDING.  The current
    403 // step is still 1.  Write is called().  The first available write is able to
    404 // complete in this step, but is marked asynchronous.  Write() returns
    405 // ERR_IO_PENDING.  The current step is still 1.  At this point RunFor(1) is
    406 // called which will cause the write callback to be invoked, and will then
    407 // stop.  The current state is now 2.  RunFor(1) is called again, which
    408 // causes the read callback to be invoked, and will then stop.  Then current
    409 // step is 2.  Write() is called again.  Then next available write is
    410 // synchronous so the call to Write() returns length.
    411 //
    412 // For examples of how to use this class, see:
    413 //   deterministic_socket_data_unittests.cc
    414 class DeterministicSocketData : public StaticSocketDataProvider,
    415     public base::RefCounted<DeterministicSocketData> {
    416  public:
    417   // |reads| the list of MockRead completions.
    418   // |writes| the list of MockWrite completions.
    419   DeterministicSocketData(MockRead* reads, size_t reads_count,
    420                           MockWrite* writes, size_t writes_count);
    421   virtual ~DeterministicSocketData();
    422 
    423   // Consume all the data up to the give stop point (via SetStop()).
    424   void Run();
    425 
    426   // Set the stop point to be |steps| from now, and then invoke Run().
    427   void RunFor(int steps);
    428 
    429   // Stop at step |seq|, which must be in the future.
    430   virtual void SetStop(int seq);
    431 
    432   // Stop |seq| steps after the current step.
    433   virtual void StopAfter(int seq);
    434   bool stopped() const { return stopped_; }
    435   void SetStopped(bool val) { stopped_ = val; }
    436   MockRead& current_read() { return current_read_; }
    437   MockRead& current_write() { return current_write_; }
    438   int sequence_number() const { return sequence_number_; }
    439   void set_socket(base::WeakPtr<DeterministicMockTCPClientSocket> socket) {
    440     socket_ = socket;
    441   }
    442 
    443   // StaticSocketDataProvider:
    444 
    445   // When the socket calls Read(), that calls GetNextRead(), and expects either
    446   // ERR_IO_PENDING or data.
    447   virtual MockRead GetNextRead();
    448 
    449   // When the socket calls Write(), it always completes synchronously. OnWrite()
    450   // checks to make sure the written data matches the expected data. The
    451   // callback will not be invoked until its sequence number is reached.
    452   virtual MockWriteResult OnWrite(const std::string& data);
    453   virtual void Reset();
    454   virtual void CompleteRead() {}
    455 
    456  private:
    457   // Invoke the read and write callbacks, if the timing is appropriate.
    458   void InvokeCallbacks();
    459 
    460   void NextStep();
    461 
    462   int sequence_number_;
    463   MockRead current_read_;
    464   MockWrite current_write_;
    465   int stopping_sequence_number_;
    466   bool stopped_;
    467   base::WeakPtr<DeterministicMockTCPClientSocket> socket_;
    468   bool print_debug_;
    469 };
    470 
    471 // Holds an array of SocketDataProvider elements.  As Mock{TCP,SSL}ClientSocket
    472 // objects get instantiated, they take their data from the i'th element of this
    473 // array.
    474 template<typename T>
    475 class SocketDataProviderArray {
    476  public:
    477   SocketDataProviderArray() : next_index_(0) {
    478   }
    479 
    480   T* GetNext() {
    481     DCHECK_LT(next_index_, data_providers_.size());
    482     return data_providers_[next_index_++];
    483   }
    484 
    485   void Add(T* data_provider) {
    486     DCHECK(data_provider);
    487     data_providers_.push_back(data_provider);
    488   }
    489 
    490   void ResetNextIndex() {
    491     next_index_ = 0;
    492   }
    493 
    494  private:
    495   // Index of the next |data_providers_| element to use. Not an iterator
    496   // because those are invalidated on vector reallocation.
    497   size_t next_index_;
    498 
    499   // SocketDataProviders to be returned.
    500   std::vector<T*> data_providers_;
    501 };
    502 
    503 class MockTCPClientSocket;
    504 class MockSSLClientSocket;
    505 
// ClientSocketFactory which contains arrays of sockets of each type.
// You should first fill the arrays using AddSocketDataProvider() and
// AddSSLSocketDataProvider().  When the factory is asked to create a socket,
// it takes the next entry from the appropriate array.
// You can use ResetNextMockIndexes to reset that next entry index for all mock
// socket types.
class MockClientSocketFactory : public ClientSocketFactory {
 public:
  MockClientSocketFactory();
  virtual ~MockClientSocketFactory();

  // Providers are stored in SocketDataProviderArrays, which keep raw
  // pointers to them.
  void AddSocketDataProvider(SocketDataProvider* socket);
  void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
  void ResetNextMockIndexes();

  // Return |index|-th MockTCPClientSocket (starting from 0) that the factory
  // created.
  MockTCPClientSocket* GetMockTCPClientSocket(size_t index) const;

  // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
  // created.
  MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;

  // Mutable access to the plain-socket providers.
  SocketDataProviderArray<SocketDataProvider>& mock_data() {
    return mock_data_;
  }
  // Mutable access to the list of TCP sockets handed out so far.
  std::vector<MockTCPClientSocket*>& tcp_client_sockets() {
    return tcp_client_sockets_;
  }

  // ClientSocketFactory
  virtual ClientSocket* CreateTransportClientSocket(
      const AddressList& addresses,
      NetLog* net_log,
      const NetLog::Source& source);
  virtual SSLClientSocket* CreateSSLClientSocket(
      ClientSocketHandle* transport_socket,
      const HostPortPair& host_and_port,
      const SSLConfig& ssl_config,
      SSLHostInfo* ssl_host_info,
      CertVerifier* cert_verifier,
      DnsCertProvenanceChecker* dns_cert_checker);
  virtual void ClearSSLSessionCache();

 private:
  SocketDataProviderArray<SocketDataProvider> mock_data_;
  SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;

  // Store pointers to handed out sockets in case the test wants to get them.
  std::vector<MockTCPClientSocket*> tcp_client_sockets_;
  std::vector<MockSSLClientSocket*> ssl_client_sockets_;
};
    557 
// Abstract base for the mock sockets below.  It derives from SSLClientSocket,
// so the SSL accessors are declared here alongside the plain socket ones.
// Subclasses supply Read(), Write(), Connect() and OnReadComplete().
class MockClientSocket : public net::SSLClientSocket {
 public:
  explicit MockClientSocket(net::NetLog* net_log);

  // If an async IO is pending because the SocketDataProvider returned
  // ERR_IO_PENDING, then the MockClientSocket waits until this OnReadComplete
  // is called to complete the asynchronous read operation.
  // data.async is ignored, and this read is completed synchronously as
  // part of this call.
  virtual void OnReadComplete(const MockRead& data) = 0;

  // Socket methods:
  virtual int Read(net::IOBuffer* buf, int buf_len,
                   net::CompletionCallback* callback) = 0;
  virtual int Write(net::IOBuffer* buf, int buf_len,
                    net::CompletionCallback* callback) = 0;
  virtual bool SetReceiveBufferSize(int32 size);
  virtual bool SetSendBufferSize(int32 size);

  // ClientSocket methods:
  virtual int Connect(net::CompletionCallback* callback) = 0;
  virtual void Disconnect();
  virtual bool IsConnected() const;
  virtual bool IsConnectedAndIdle() const;
  virtual int GetPeerAddress(AddressList* address) const;
  virtual int GetLocalAddress(IPEndPoint* address) const;
  virtual const BoundNetLog& NetLog() const;
  // Speculation hints are ignored by the mock.
  virtual void SetSubresourceSpeculation() {}
  virtual void SetOmniboxSpeculation() {}

  // SSLClientSocket methods:
  virtual void GetSSLInfo(net::SSLInfo* ssl_info);
  virtual void GetSSLCertRequestInfo(
      net::SSLCertRequestInfo* cert_request_info);
  virtual NextProtoStatus GetNextProto(std::string* proto);

 protected:
  virtual ~MockClientSocket();
  // Delivers |result| to |callback|; the Async variant presumably posts the
  // call via |method_factory_| instead of running it inline -- confirm in
  // socket_test_util.cc.
  void RunCallbackAsync(net::CompletionCallback* callback, int result);
  void RunCallback(net::CompletionCallback*, int result);

  ScopedRunnableMethodFactory<MockClientSocket> method_factory_;

  // True if Connect completed successfully and Disconnect hasn't been called.
  bool connected_;

  net::BoundNetLog net_log_;
};
    606 
// Mock TCP socket whose reads, writes and connect result come from a
// SocketDataProvider.
class MockTCPClientSocket : public MockClientSocket {
 public:
  // |socket| is the data provider consulted for every operation; it is held
  // as a raw pointer and not owned.
  MockTCPClientSocket(const net::AddressList& addresses, net::NetLog* net_log,
                      net::SocketDataProvider* socket);

  // Returns a copy of the addresses passed to the constructor.
  net::AddressList addresses() const { return addresses_; }

  // Socket methods:
  virtual int Read(net::IOBuffer* buf, int buf_len,
                   net::CompletionCallback* callback);
  virtual int Write(net::IOBuffer* buf, int buf_len,
                    net::CompletionCallback* callback);

  // ClientSocket methods:
  virtual int Connect(net::CompletionCallback* callback);
  virtual void Disconnect();
  virtual bool IsConnected() const;
  virtual bool IsConnectedAndIdle() const;
  virtual int GetPeerAddress(AddressList* address) const;
  virtual bool WasEverUsed() const;
  virtual bool UsingTCPFastOpen() const;

  // MockClientSocket:
  virtual void OnReadComplete(const MockRead& data);

 private:
  int CompleteRead();

  net::AddressList addresses_;

  net::SocketDataProvider* data_;  // Not owned.
  int read_offset_;                // Bytes of |read_data_| already consumed.
  net::MockRead read_data_;
  bool need_read_data_;

  // True if the peer has closed the connection.  This allows us to simulate
  // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real
  // TCPClientSocket.
  bool peer_closed_connection_;

  // While an asynchronous IO is pending, we save our user-buffer state.
  // NOTE(review): raw pointers; ownership/lifetime are not visible here.
  net::IOBuffer* pending_buf_;
  int pending_buf_len_;
  net::CompletionCallback* pending_callback_;
  bool was_used_to_convey_data_;
};
    653 
// Mock TCP socket driven by a DeterministicSocketData, which decides when its
// pending reads and writes complete (via CompleteRead()/CompleteWrite()).
// Supports weak pointers so the data object can refer back to it without
// owning it.
class DeterministicMockTCPClientSocket : public MockClientSocket,
    public base::SupportsWeakPtr<DeterministicMockTCPClientSocket> {
 public:
  // |data| is held as a raw pointer and not owned.
  DeterministicMockTCPClientSocket(net::NetLog* net_log,
      net::DeterministicSocketData* data);
  virtual ~DeterministicMockTCPClientSocket();

  // True while an asynchronous write / read is waiting to be completed.
  bool write_pending() const { return write_pending_; }
  bool read_pending() const { return read_pending_; }

  void CompleteWrite();
  int CompleteRead();

  // Socket:
  virtual int Write(net::IOBuffer* buf, int buf_len,
                    net::CompletionCallback* callback);
  virtual int Read(net::IOBuffer* buf, int buf_len,
                   net::CompletionCallback* callback);

  // ClientSocket:
  virtual int Connect(net::CompletionCallback* callback);
  virtual void Disconnect();
  virtual bool IsConnected() const;
  virtual bool IsConnectedAndIdle() const;
  virtual bool WasEverUsed() const;
  virtual bool UsingTCPFastOpen() const;

  // MockClientSocket:
  virtual void OnReadComplete(const MockRead& data);

 private:
  bool write_pending_;
  net::CompletionCallback* write_callback_;
  int write_result_;

  net::MockRead read_data_;

  net::IOBuffer* read_buf_;
  int read_buf_len_;
  bool read_pending_;
  net::CompletionCallback* read_callback_;
  net::DeterministicSocketData* data_;  // Not owned.
  bool was_used_to_convey_data_;
};
    698 
// Mock SSL socket layered on |transport_socket|; its handshake result and
// SSL/NPN details come from an SSLSocketDataProvider.
class MockSSLClientSocket : public MockClientSocket {
 public:
  // Takes ownership of |transport_socket| (stored in a scoped_ptr).
  // |socket| is held as a raw pointer and not owned.
  MockSSLClientSocket(
      net::ClientSocketHandle* transport_socket,
      const HostPortPair& host_and_port,
      const net::SSLConfig& ssl_config,
      SSLHostInfo* ssl_host_info,
      net::SSLSocketDataProvider* socket);
  virtual ~MockSSLClientSocket();

  // Socket methods:
  virtual int Read(net::IOBuffer* buf, int buf_len,
                   net::CompletionCallback* callback);
  virtual int Write(net::IOBuffer* buf, int buf_len,
                    net::CompletionCallback* callback);

  // ClientSocket methods:
  virtual int Connect(net::CompletionCallback* callback);
  virtual void Disconnect();
  virtual bool IsConnected() const;
  virtual bool WasEverUsed() const;
  virtual bool UsingTCPFastOpen() const;

  // SSLClientSocket methods:
  virtual void GetSSLInfo(net::SSLInfo* ssl_info);
  virtual void GetSSLCertRequestInfo(
      net::SSLCertRequestInfo* cert_request_info);
  virtual NextProtoStatus GetNextProto(std::string* proto);
  virtual bool was_npn_negotiated() const;
  virtual bool set_was_npn_negotiated(bool negotiated);

  // This MockSocket does not implement the manual async IO feature.
  virtual void OnReadComplete(const MockRead& data);

 private:
  class ConnectCallback;

  scoped_ptr<ClientSocketHandle> transport_;
  net::SSLSocketDataProvider* data_;  // Not owned.
  // Set by set_was_npn_negotiated(); presumably overrides the provider's
  // value when |is_npn_state_set_| is true -- confirm in the .cc.
  bool is_npn_state_set_;
  bool new_npn_value_;
  bool was_used_to_convey_data_;
};
    742 
    743 class TestSocketRequest : public CallbackRunner< Tuple1<int> > {
    744  public:
    745   TestSocketRequest(
    746       std::vector<TestSocketRequest*>* request_order,
    747       size_t* completion_count);
    748   virtual ~TestSocketRequest();
    749 
    750   ClientSocketHandle* handle() { return &handle_; }
    751 
    752   int WaitForResult();
    753   virtual void RunWithParams(const Tuple1<int>& params);
    754 
    755  private:
    756   ClientSocketHandle handle_;
    757   std::vector<TestSocketRequest*>* request_order_;
    758   size_t* completion_count_;
    759   TestCompletionCallback callback_;
    760 };
    761 
    762 class ClientSocketPoolTest {
    763  public:
    764   enum KeepAlive {
    765     KEEP_ALIVE,
    766 
    767     // A socket will be disconnected in addition to handle being reset.
    768     NO_KEEP_ALIVE,
    769   };
    770 
    771   static const int kIndexOutOfBounds;
    772   static const int kRequestNotFound;
    773 
    774   ClientSocketPoolTest();
    775   ~ClientSocketPoolTest();
    776 
    777   template <typename PoolType, typename SocketParams>
    778   int StartRequestUsingPool(PoolType* socket_pool,
    779                             const std::string& group_name,
    780                             RequestPriority priority,
    781                             const scoped_refptr<SocketParams>& socket_params) {
    782     DCHECK(socket_pool);
    783     TestSocketRequest* request = new TestSocketRequest(&request_order_,
    784                                                        &completion_count_);
    785     requests_.push_back(request);
    786     int rv = request->handle()->Init(
    787         group_name, socket_params, priority, request,
    788         socket_pool, BoundNetLog());
    789     if (rv != ERR_IO_PENDING)
    790       request_order_.push_back(request);
    791     return rv;
    792   }
    793 
    794   // Provided there were n requests started, takes |index| in range 1..n
    795   // and returns order in which that request completed, in range 1..n,
    796   // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound
    797   // if that request did not complete (for example was canceled).
    798   int GetOrderOfRequest(size_t index) const;
    799 
    800   // Resets first initialized socket handle from |requests_|. If found such
    801   // a handle, returns true.
    802   bool ReleaseOneConnection(KeepAlive keep_alive);
    803 
    804   // Releases connections until there is nothing to release.
    805   void ReleaseAllConnections(KeepAlive keep_alive);
    806 
    807   TestSocketRequest* request(int i) { return requests_[i]; }
    808   size_t requests_size() const { return requests_.size(); }
    809   ScopedVector<TestSocketRequest>* requests() { return &requests_; }
    810   size_t completion_count() const { return completion_count_; }
    811 
    812  private:
    813   ScopedVector<TestSocketRequest> requests_;
    814   std::vector<TestSocketRequest*> request_order_;
    815   size_t completion_count_;
    816 };
    817 
    818 class MockTransportClientSocketPool : public TransportClientSocketPool {
    819  public:
    820   class MockConnectJob {
    821    public:
    822     MockConnectJob(ClientSocket* socket, ClientSocketHandle* handle,
    823                    CompletionCallback* callback);
    824     ~MockConnectJob();
    825 
    826     int Connect();
    827     bool CancelHandle(const ClientSocketHandle* handle);
    828 
    829    private:
    830     void OnConnect(int rv);
    831 
    832     scoped_ptr<ClientSocket> socket_;
    833     ClientSocketHandle* handle_;
    834     CompletionCallback* user_callback_;
    835     CompletionCallbackImpl<MockConnectJob> connect_callback_;
    836 
    837     DISALLOW_COPY_AND_ASSIGN(MockConnectJob);
    838   };
    839 
    840   MockTransportClientSocketPool(
    841       int max_sockets,
    842       int max_sockets_per_group,
    843       ClientSocketPoolHistograms* histograms,
    844       ClientSocketFactory* socket_factory);
    845 
    846   virtual ~MockTransportClientSocketPool();
    847 
    848   int release_count() const { return release_count_; }
    849   int cancel_count() const { return cancel_count_; }
    850 
    851   // TransportClientSocketPool methods.
    852   virtual int RequestSocket(const std::string& group_name,
    853                             const void* socket_params,
    854                             RequestPriority priority,
    855                             ClientSocketHandle* handle,
    856                             CompletionCallback* callback,
    857                             const BoundNetLog& net_log);
    858 
    859   virtual void CancelRequest(const std::string& group_name,
    860                              ClientSocketHandle* handle);
    861   virtual void ReleaseSocket(const std::string& group_name,
    862                              ClientSocket* socket, int id);
    863 
    864  private:
    865   ClientSocketFactory* client_socket_factory_;
    866   ScopedVector<MockConnectJob> job_list_;
    867   int release_count_;
    868   int cancel_count_;
    869 
    870   DISALLOW_COPY_AND_ASSIGN(MockTransportClientSocketPool);
    871 };
    872 
    873 class DeterministicMockClientSocketFactory : public ClientSocketFactory {
    874  public:
    875   DeterministicMockClientSocketFactory();
    876   virtual ~DeterministicMockClientSocketFactory();
    877 
    878   void AddSocketDataProvider(DeterministicSocketData* socket);
    879   void AddSSLSocketDataProvider(SSLSocketDataProvider* socket);
    880   void ResetNextMockIndexes();
    881 
    882   // Return |index|-th MockSSLClientSocket (starting from 0) that the factory
    883   // created.
    884   MockSSLClientSocket* GetMockSSLClientSocket(size_t index) const;
    885 
    886   SocketDataProviderArray<DeterministicSocketData>& mock_data() {
    887     return mock_data_;
    888   }
    889   std::vector<DeterministicMockTCPClientSocket*>& tcp_client_sockets() {
    890     return tcp_client_sockets_;
    891   }
    892 
    893   // ClientSocketFactory
    894   virtual ClientSocket* CreateTransportClientSocket(
    895       const AddressList& addresses,
    896       NetLog* net_log,
    897       const NetLog::Source& source);
    898   virtual SSLClientSocket* CreateSSLClientSocket(
    899       ClientSocketHandle* transport_socket,
    900       const HostPortPair& host_and_port,
    901       const SSLConfig& ssl_config,
    902       SSLHostInfo* ssl_host_info,
    903       CertVerifier* cert_verifier,
    904       DnsCertProvenanceChecker* dns_cert_checker);
    905   virtual void ClearSSLSessionCache();
    906 
    907  private:
    908   SocketDataProviderArray<DeterministicSocketData> mock_data_;
    909   SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_;
    910 
    911   // Store pointers to handed out sockets in case the test wants to get them.
    912   std::vector<DeterministicMockTCPClientSocket*> tcp_client_sockets_;
    913   std::vector<MockSSLClientSocket*> ssl_client_sockets_;
    914 };
    915 
    916 class MockSOCKSClientSocketPool : public SOCKSClientSocketPool {
    917  public:
    918   MockSOCKSClientSocketPool(
    919       int max_sockets,
    920       int max_sockets_per_group,
    921       ClientSocketPoolHistograms* histograms,
    922       TransportClientSocketPool* transport_pool);
    923 
    924   virtual ~MockSOCKSClientSocketPool();
    925 
    926   // SOCKSClientSocketPool methods.
    927   virtual int RequestSocket(const std::string& group_name,
    928                             const void* socket_params,
    929                             RequestPriority priority,
    930                             ClientSocketHandle* handle,
    931                             CompletionCallback* callback,
    932                             const BoundNetLog& net_log);
    933 
    934   virtual void CancelRequest(const std::string& group_name,
    935                              ClientSocketHandle* handle);
    936   virtual void ReleaseSocket(const std::string& group_name,
    937                              ClientSocket* socket, int id);
    938 
    939  private:
    940   TransportClientSocketPool* const transport_pool_;
    941 
    942   DISALLOW_COPY_AND_ASSIGN(MockSOCKSClientSocketPool);
    943 };
    944 
    945 // Constants for a successful SOCKS v5 handshake.
    946 extern const char kSOCKS5GreetRequest[];
    947 extern const int kSOCKS5GreetRequestLength;
    948 
    949 extern const char kSOCKS5GreetResponse[];
    950 extern const int kSOCKS5GreetResponseLength;
    951 
    952 extern const char kSOCKS5OkRequest[];
    953 extern const int kSOCKS5OkRequestLength;
    954 
    955 extern const char kSOCKS5OkResponse[];
    956 extern const int kSOCKS5OkResponseLength;
    957 
    958 }  // namespace net
    959 
    960 #endif  // NET_SOCKET_SOCKET_TEST_UTIL_H_
    961