1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #ifndef NET_SOCKET_SOCKET_TEST_UTIL_H_ 6 #define NET_SOCKET_SOCKET_TEST_UTIL_H_ 7 8 #include <deque> 9 #include <string> 10 #include <vector> 11 12 #include "base/basictypes.h" 13 #include "base/logging.h" 14 #include "base/scoped_ptr.h" 15 #include "base/scoped_vector.h" 16 #include "net/base/address_list.h" 17 #include "net/base/io_buffer.h" 18 #include "net/base/net_errors.h" 19 #include "net/base/ssl_config_service.h" 20 #include "net/base/test_completion_callback.h" 21 #include "net/socket/client_socket_factory.h" 22 #include "net/socket/client_socket_handle.h" 23 #include "net/socket/ssl_client_socket.h" 24 #include "testing/gtest/include/gtest/gtest.h" 25 26 namespace net { 27 28 enum { 29 // A private network error code used by the socket test utility classes. 30 // If the |result| member of a MockRead is 31 // ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ, that MockRead is just a 32 // marker that indicates the peer will close the connection after the next 33 // MockRead. The other members of that MockRead are ignored. 34 ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ = -10000, 35 }; 36 37 class ClientSocket; 38 class LoadLog; 39 class MockClientSocket; 40 class SSLClientSocket; 41 42 struct MockConnect { 43 // Asynchronous connection success. 44 MockConnect() : async(true), result(OK) { } 45 MockConnect(bool a, int r) : async(a), result(r) { } 46 47 bool async; 48 int result; 49 }; 50 51 struct MockRead { 52 // Default 53 MockRead() : async(false), result(0), data(NULL), data_len(0) {} 54 55 // Read failure (no data). 56 MockRead(bool async, int result) : async(async) , result(result), data(NULL), 57 data_len(0) { } 58 59 // Asynchronous read success (inferred data length). 60 explicit MockRead(const char* data) : async(true), result(0), data(data), 61 data_len(strlen(data)) { } 62 63 // Read success (inferred data length). 64 MockRead(bool async, const char* data) : async(async), result(0), data(data), 65 data_len(strlen(data)) { } 66 67 // Read success. 68 MockRead(bool async, const char* data, int data_len) : async(async), 69 result(0), data(data), data_len(data_len) { } 70 71 bool async; 72 int result; 73 const char* data; 74 int data_len; 75 }; 76 77 // MockWrite uses the same member fields as MockRead, but with different 78 // meanings. The expected input to MockTCPClientSocket::Write() is given 79 // by {data, data_len}, and the return value of Write() is controlled by 80 // {async, result}. 81 typedef MockRead MockWrite; 82 83 struct MockWriteResult { 84 MockWriteResult(bool async, int result) : async(async), result(result) {} 85 86 bool async; 87 int result; 88 }; 89 90 // The SocketDataProvider is an interface used by the MockClientSocket 91 // for getting data about individual reads and writes on the socket. 92 class SocketDataProvider { 93 public: 94 SocketDataProvider() : socket_(NULL) {} 95 96 virtual ~SocketDataProvider() {} 97 98 // Returns the buffer and result code for the next simulated read. 99 // If the |MockRead.result| is ERR_IO_PENDING, it informs the caller 100 // that it will be called via the MockClientSocket::OnReadComplete() 101 // function at a later time. 102 virtual MockRead GetNextRead() = 0; 103 virtual MockWriteResult OnWrite(const std::string& data) = 0; 104 virtual void Reset() = 0; 105 106 // Accessor for the socket which is using the SocketDataProvider. 107 MockClientSocket* socket() { return socket_; } 108 void set_socket(MockClientSocket* socket) { socket_ = socket; } 109 110 MockConnect connect_data() const { return connect_; } 111 void set_connect_data(const MockConnect& connect) { connect_ = connect; } 112 113 private: 114 MockConnect connect_; 115 MockClientSocket* socket_; 116 117 DISALLOW_COPY_AND_ASSIGN(SocketDataProvider); 118 }; 119 120 // SocketDataProvider which responds based on static tables of mock reads and 121 // writes. 122 class StaticSocketDataProvider : public SocketDataProvider { 123 public: 124 StaticSocketDataProvider() : reads_(NULL), read_index_(0), 125 writes_(NULL), write_index_(0) {} 126 StaticSocketDataProvider(MockRead* r, MockWrite* w) : reads_(r), 127 read_index_(0), writes_(w), write_index_(0) {} 128 129 // SocketDataProvider methods: 130 virtual MockRead GetNextRead(); 131 virtual MockWriteResult OnWrite(const std::string& data); 132 virtual void Reset(); 133 134 // If the test wishes to verify that all data is consumed, it can include 135 // a EOF MockRead or MockWrite, which is a zero-length Read or Write. 136 // The test can then call at_read_eof() or at_write_eof() to verify that 137 // all data has been consumed. 138 bool at_read_eof() const { return reads_[read_index_].data_len == 0; } 139 bool at_write_eof() const { return writes_[write_index_].data_len == 0; } 140 141 private: 142 MockRead* reads_; 143 int read_index_; 144 MockWrite* writes_; 145 int write_index_; 146 147 DISALLOW_COPY_AND_ASSIGN(StaticSocketDataProvider); 148 }; 149 150 // SocketDataProvider which can make decisions about next mock reads based on 151 // received writes. It can also be used to enforce order of operations, for 152 // example that tested code must send the "Hello!" message before receiving 153 // response. This is useful for testing conversation-like protocols like FTP. 154 class DynamicSocketDataProvider : public SocketDataProvider { 155 public: 156 DynamicSocketDataProvider(); 157 158 // SocketDataProvider methods: 159 virtual MockRead GetNextRead(); 160 virtual MockWriteResult OnWrite(const std::string& data) = 0; 161 virtual void Reset(); 162 163 int short_read_limit() const { return short_read_limit_; } 164 void set_short_read_limit(int limit) { short_read_limit_ = limit; } 165 166 void allow_unconsumed_reads(bool allow) { allow_unconsumed_reads_ = allow; } 167 168 protected: 169 // The next time there is a read from this socket, it will return |data|. 170 // Before calling SimulateRead next time, the previous data must be consumed. 171 void SimulateRead(const char* data); 172 173 private: 174 std::deque<MockRead> reads_; 175 176 // Max number of bytes we will read at a time. 0 means no limit. 177 int short_read_limit_; 178 179 // If true, we'll not require the client to consume all data before we 180 // mock the next read. 181 bool allow_unconsumed_reads_; 182 183 DISALLOW_COPY_AND_ASSIGN(DynamicSocketDataProvider); 184 }; 185 186 // SSLSocketDataProviders only need to keep track of the return code from calls 187 // to Connect(). 188 struct SSLSocketDataProvider { 189 SSLSocketDataProvider(bool async, int result) : connect(async, result) { } 190 191 MockConnect connect; 192 }; 193 194 // Holds an array of SocketDataProvider elements. As Mock{TCP,SSL}ClientSocket 195 // objects get instantiated, they take their data from the i'th element of this 196 // array. 197 template<typename T> 198 class SocketDataProviderArray { 199 public: 200 SocketDataProviderArray() : next_index_(0) { 201 } 202 203 T* GetNext() { 204 DCHECK(next_index_ < data_providers_.size()); 205 return data_providers_[next_index_++]; 206 } 207 208 void Add(T* data_provider) { 209 DCHECK(data_provider); 210 data_providers_.push_back(data_provider); 211 } 212 213 void ResetNextIndex() { 214 next_index_ = 0; 215 } 216 217 private: 218 // Index of the next |data_providers_| element to use. Not an iterator 219 // because those are invalidated on vector reallocation. 220 size_t next_index_; 221 222 // SocketDataProviders to be returned. 223 std::vector<T*> data_providers_; 224 }; 225 226 class MockTCPClientSocket; 227 class MockSSLClientSocket; 228 229 // ClientSocketFactory which contains arrays of sockets of each type. 230 // You should first fill the arrays using AddMock{SSL,}Socket. When the factory 231 // is asked to create a socket, it takes next entry from appropriate array. 232 // You can use ResetNextMockIndexes to reset that next entry index for all mock 233 // socket types. 234 class MockClientSocketFactory : public ClientSocketFactory { 235 public: 236 void AddSocketDataProvider(SocketDataProvider* socket); 237 void AddSSLSocketDataProvider(SSLSocketDataProvider* socket); 238 void ResetNextMockIndexes(); 239 240 // Return |index|-th MockTCPClientSocket (starting from 0) that the factory 241 // created. 242 MockTCPClientSocket* GetMockTCPClientSocket(int index) const; 243 244 // Return |index|-th MockSSLClientSocket (starting from 0) that the factory 245 // created. 246 MockSSLClientSocket* GetMockSSLClientSocket(int index) const; 247 248 // ClientSocketFactory 249 virtual ClientSocket* CreateTCPClientSocket(const AddressList& addresses); 250 virtual SSLClientSocket* CreateSSLClientSocket( 251 ClientSocket* transport_socket, 252 const std::string& hostname, 253 const SSLConfig& ssl_config); 254 255 private: 256 SocketDataProviderArray<SocketDataProvider> mock_data_; 257 SocketDataProviderArray<SSLSocketDataProvider> mock_ssl_data_; 258 259 // Store pointers to handed out sockets in case the test wants to get them. 260 std::vector<MockTCPClientSocket*> tcp_client_sockets_; 261 std::vector<MockSSLClientSocket*> ssl_client_sockets_; 262 }; 263 264 class MockClientSocket : public net::SSLClientSocket { 265 public: 266 MockClientSocket(); 267 268 // ClientSocket methods: 269 virtual int Connect(net::CompletionCallback* callback, LoadLog* load_log) = 0; 270 virtual void Disconnect(); 271 virtual bool IsConnected() const; 272 virtual bool IsConnectedAndIdle() const; 273 virtual int GetPeerName(struct sockaddr* name, socklen_t* namelen); 274 275 // SSLClientSocket methods: 276 virtual void GetSSLInfo(net::SSLInfo* ssl_info); 277 virtual void GetSSLCertRequestInfo( 278 net::SSLCertRequestInfo* cert_request_info); 279 virtual NextProtoStatus GetNextProto(std::string* proto); 280 281 // Socket methods: 282 virtual int Read(net::IOBuffer* buf, int buf_len, 283 net::CompletionCallback* callback) = 0; 284 virtual int Write(net::IOBuffer* buf, int buf_len, 285 net::CompletionCallback* callback) = 0; 286 virtual bool SetReceiveBufferSize(int32 size) { return true; } 287 virtual bool SetSendBufferSize(int32 size) { return true; } 288 289 // If an async IO is pending because the SocketDataProvider returned 290 // ERR_IO_PENDING, then the MockClientSocket waits until this OnReadComplete 291 // is called to complete the asynchronous read operation. 292 // data.async is ignored, and this read is completed synchronously as 293 // part of this call. 294 virtual void OnReadComplete(const MockRead& data) = 0; 295 296 protected: 297 void RunCallbackAsync(net::CompletionCallback* callback, int result); 298 void RunCallback(net::CompletionCallback*, int result); 299 300 ScopedRunnableMethodFactory<MockClientSocket> method_factory_; 301 302 // True if Connect completed successfully and Disconnect hasn't been called. 303 bool connected_; 304 }; 305 306 class MockTCPClientSocket : public MockClientSocket { 307 public: 308 MockTCPClientSocket(const net::AddressList& addresses, 309 net::SocketDataProvider* socket); 310 311 // ClientSocket methods: 312 virtual int Connect(net::CompletionCallback* callback, 313 LoadLog* load_log); 314 virtual bool IsConnected() const; 315 virtual bool IsConnectedAndIdle() const { return IsConnected(); } 316 317 // Socket methods: 318 virtual int Read(net::IOBuffer* buf, int buf_len, 319 net::CompletionCallback* callback); 320 virtual int Write(net::IOBuffer* buf, int buf_len, 321 net::CompletionCallback* callback); 322 323 virtual void OnReadComplete(const MockRead& data); 324 325 net::AddressList addresses() const { return addresses_; } 326 327 private: 328 int CompleteRead(); 329 330 net::AddressList addresses_; 331 332 net::SocketDataProvider* data_; 333 int read_offset_; 334 net::MockRead read_data_; 335 bool need_read_data_; 336 337 // True if the peer has closed the connection. This allows us to simulate 338 // the recv(..., MSG_PEEK) call in the IsConnectedAndIdle method of the real 339 // TCPClientSocket. 340 bool peer_closed_connection_; 341 342 // While an asynchronous IO is pending, we save our user-buffer state. 343 net::IOBuffer* pending_buf_; 344 int pending_buf_len_; 345 net::CompletionCallback* pending_callback_; 346 }; 347 348 class MockSSLClientSocket : public MockClientSocket { 349 public: 350 MockSSLClientSocket( 351 net::ClientSocket* transport_socket, 352 const std::string& hostname, 353 const net::SSLConfig& ssl_config, 354 net::SSLSocketDataProvider* socket); 355 ~MockSSLClientSocket(); 356 357 virtual void GetSSLInfo(net::SSLInfo* ssl_info); 358 359 virtual int Connect(net::CompletionCallback* callback, LoadLog* load_log); 360 virtual void Disconnect(); 361 362 // Socket methods: 363 virtual int Read(net::IOBuffer* buf, int buf_len, 364 net::CompletionCallback* callback); 365 virtual int Write(net::IOBuffer* buf, int buf_len, 366 net::CompletionCallback* callback); 367 368 // This MockSocket does not implement the manual async IO feature. 369 virtual void OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); } 370 371 private: 372 class ConnectCallback; 373 374 scoped_ptr<ClientSocket> transport_; 375 net::SSLSocketDataProvider* data_; 376 }; 377 378 class TestSocketRequest : public CallbackRunner< Tuple1<int> > { 379 public: 380 TestSocketRequest( 381 std::vector<TestSocketRequest*>* request_order, 382 size_t* completion_count) 383 : request_order_(request_order), 384 completion_count_(completion_count) { 385 DCHECK(request_order); 386 DCHECK(completion_count); 387 } 388 389 ClientSocketHandle* handle() { return &handle_; } 390 391 int WaitForResult(); 392 virtual void RunWithParams(const Tuple1<int>& params); 393 394 private: 395 ClientSocketHandle handle_; 396 std::vector<TestSocketRequest*>* request_order_; 397 size_t* completion_count_; 398 TestCompletionCallback callback_; 399 }; 400 401 class ClientSocketPoolTest : public testing::Test { 402 protected: 403 enum KeepAlive { 404 KEEP_ALIVE, 405 406 // A socket will be disconnected in addition to handle being reset. 407 NO_KEEP_ALIVE, 408 }; 409 410 static const int kIndexOutOfBounds; 411 static const int kRequestNotFound; 412 413 virtual void SetUp(); 414 virtual void TearDown(); 415 416 template <typename PoolType, typename SocketParams> 417 int StartRequestUsingPool(PoolType* socket_pool, 418 const std::string& group_name, 419 RequestPriority priority, 420 const SocketParams& socket_params) { 421 DCHECK(socket_pool); 422 TestSocketRequest* request = new TestSocketRequest(&request_order_, 423 &completion_count_); 424 requests_.push_back(request); 425 int rv = request->handle()->Init( 426 group_name, socket_params, priority, request, 427 socket_pool, NULL); 428 if (rv != ERR_IO_PENDING) 429 request_order_.push_back(request); 430 return rv; 431 } 432 433 // Provided there were n requests started, takes |index| in range 1..n 434 // and returns order in which that request completed, in range 1..n, 435 // or kIndexOutOfBounds if |index| is out of bounds, or kRequestNotFound 436 // if that request did not complete (for example was canceled). 437 int GetOrderOfRequest(size_t index); 438 439 // Resets first initialized socket handle from |requests_|. If found such 440 // a handle, returns true. 441 bool ReleaseOneConnection(KeepAlive keep_alive); 442 443 // Releases connections until there is nothing to release. 444 void ReleaseAllConnections(KeepAlive keep_alive); 445 446 ScopedVector<TestSocketRequest> requests_; 447 std::vector<TestSocketRequest*> request_order_; 448 size_t completion_count_; 449 }; 450 451 } // namespace net 452 453 #endif // NET_SOCKET_SOCKET_TEST_UTIL_H_ 454