1 // Copyright (c) 2009 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/socket_test_util.h" 6 7 #include <algorithm> 8 9 #include "base/basictypes.h" 10 #include "base/compiler_specific.h" 11 #include "base/message_loop.h" 12 #include "net/base/ssl_info.h" 13 #include "net/socket/socket.h" 14 #include "testing/gtest/include/gtest/gtest.h" 15 16 namespace net { 17 18 MockClientSocket::MockClientSocket() 19 : ALLOW_THIS_IN_INITIALIZER_LIST(method_factory_(this)), 20 connected_(false) { 21 } 22 23 void MockClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { 24 NOTREACHED(); 25 } 26 27 void MockClientSocket::GetSSLCertRequestInfo( 28 net::SSLCertRequestInfo* cert_request_info) { 29 NOTREACHED(); 30 } 31 32 SSLClientSocket::NextProtoStatus 33 MockClientSocket::GetNextProto(std::string* proto) { 34 proto->clear(); 35 return SSLClientSocket::kNextProtoUnsupported; 36 } 37 38 void MockClientSocket::Disconnect() { 39 connected_ = false; 40 } 41 42 bool MockClientSocket::IsConnected() const { 43 return connected_; 44 } 45 46 bool MockClientSocket::IsConnectedAndIdle() const { 47 return connected_; 48 } 49 50 int MockClientSocket::GetPeerName(struct sockaddr* name, socklen_t* namelen) { 51 memset(reinterpret_cast<char *>(name), 0, *namelen); 52 return net::OK; 53 } 54 55 void MockClientSocket::RunCallbackAsync(net::CompletionCallback* callback, 56 int result) { 57 MessageLoop::current()->PostTask(FROM_HERE, 58 method_factory_.NewRunnableMethod( 59 &MockClientSocket::RunCallback, callback, result)); 60 } 61 62 void MockClientSocket::RunCallback(net::CompletionCallback* callback, 63 int result) { 64 if (callback) 65 callback->Run(result); 66 } 67 68 MockTCPClientSocket::MockTCPClientSocket(const net::AddressList& addresses, 69 net::SocketDataProvider* data) 70 : addresses_(addresses), 71 data_(data), 72 read_offset_(0), 73 read_data_(false, net::ERR_UNEXPECTED), 74 need_read_data_(true), 75 peer_closed_connection_(false), 76 pending_buf_(NULL), 77 pending_buf_len_(0), 78 pending_callback_(NULL) { 79 DCHECK(data_); 80 data_->Reset(); 81 } 82 83 int MockTCPClientSocket::Connect(net::CompletionCallback* callback, 84 LoadLog* load_log) { 85 if (connected_) 86 return net::OK; 87 connected_ = true; 88 if (data_->connect_data().async) { 89 RunCallbackAsync(callback, data_->connect_data().result); 90 return net::ERR_IO_PENDING; 91 } 92 return data_->connect_data().result; 93 } 94 95 bool MockTCPClientSocket::IsConnected() const { 96 return connected_ && !peer_closed_connection_; 97 } 98 99 int MockTCPClientSocket::Read(net::IOBuffer* buf, int buf_len, 100 net::CompletionCallback* callback) { 101 if (!connected_) 102 return net::ERR_UNEXPECTED; 103 104 // If the buffer is already in use, a read is already in progress! 105 DCHECK(pending_buf_ == NULL); 106 107 // Store our async IO data. 108 pending_buf_ = buf; 109 pending_buf_len_ = buf_len; 110 pending_callback_ = callback; 111 112 if (need_read_data_) { 113 read_data_ = data_->GetNextRead(); 114 if (read_data_.result == ERR_TEST_PEER_CLOSE_AFTER_NEXT_MOCK_READ) { 115 // This MockRead is just a marker to instruct us to set 116 // peer_closed_connection_. Skip it and get the next one. 117 read_data_ = data_->GetNextRead(); 118 peer_closed_connection_ = true; 119 } 120 // ERR_IO_PENDING means that the SocketDataProvider is taking responsibility 121 // to complete the async IO manually later (via OnReadComplete). 122 if (read_data_.result == ERR_IO_PENDING) { 123 DCHECK(callback); // We need to be using async IO in this case. 124 return ERR_IO_PENDING; 125 } 126 need_read_data_ = false; 127 } 128 129 return CompleteRead(); 130 } 131 132 int MockTCPClientSocket::Write(net::IOBuffer* buf, int buf_len, 133 net::CompletionCallback* callback) { 134 DCHECK(buf); 135 DCHECK_GT(buf_len, 0); 136 137 if (!connected_) 138 return net::ERR_UNEXPECTED; 139 140 std::string data(buf->data(), buf_len); 141 net::MockWriteResult write_result = data_->OnWrite(data); 142 143 if (write_result.async) { 144 RunCallbackAsync(callback, write_result.result); 145 return net::ERR_IO_PENDING; 146 } 147 return write_result.result; 148 } 149 150 void MockTCPClientSocket::OnReadComplete(const MockRead& data) { 151 // There must be a read pending. 152 DCHECK(pending_buf_); 153 // You can't complete a read with another ERR_IO_PENDING status code. 154 DCHECK_NE(ERR_IO_PENDING, data.result); 155 // Since we've been waiting for data, need_read_data_ should be true. 156 DCHECK(need_read_data_); 157 158 read_data_ = data; 159 need_read_data_ = false; 160 161 // The caller is simulating that this IO completes right now. Don't 162 // let CompleteRead() schedule a callback. 163 read_data_.async = false; 164 165 net::CompletionCallback* callback = pending_callback_; 166 int rv = CompleteRead(); 167 RunCallback(callback, rv); 168 } 169 170 int MockTCPClientSocket::CompleteRead() { 171 DCHECK(pending_buf_); 172 DCHECK(pending_buf_len_ > 0); 173 174 // Save the pending async IO data and reset our |pending_| state. 175 net::IOBuffer* buf = pending_buf_; 176 int buf_len = pending_buf_len_; 177 net::CompletionCallback* callback = pending_callback_; 178 pending_buf_ = NULL; 179 pending_buf_len_ = 0; 180 pending_callback_ = NULL; 181 182 int result = read_data_.result; 183 DCHECK(result != ERR_IO_PENDING); 184 185 if (read_data_.data) { 186 if (read_data_.data_len - read_offset_ > 0) { 187 result = std::min(buf_len, read_data_.data_len - read_offset_); 188 memcpy(buf->data(), read_data_.data + read_offset_, result); 189 read_offset_ += result; 190 if (read_offset_ == read_data_.data_len) { 191 need_read_data_ = true; 192 read_offset_ = 0; 193 } 194 } else { 195 result = 0; // EOF 196 } 197 } 198 199 if (read_data_.async) { 200 DCHECK(callback); 201 RunCallbackAsync(callback, result); 202 return net::ERR_IO_PENDING; 203 } 204 return result; 205 } 206 207 class MockSSLClientSocket::ConnectCallback : 208 public net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback> { 209 public: 210 ConnectCallback(MockSSLClientSocket *ssl_client_socket, 211 net::CompletionCallback* user_callback, 212 int rv) 213 : ALLOW_THIS_IN_INITIALIZER_LIST( 214 net::CompletionCallbackImpl<MockSSLClientSocket::ConnectCallback>( 215 this, &ConnectCallback::Wrapper)), 216 ssl_client_socket_(ssl_client_socket), 217 user_callback_(user_callback), 218 rv_(rv) { 219 } 220 221 private: 222 void Wrapper(int rv) { 223 if (rv_ == net::OK) 224 ssl_client_socket_->connected_ = true; 225 user_callback_->Run(rv_); 226 delete this; 227 } 228 229 MockSSLClientSocket* ssl_client_socket_; 230 net::CompletionCallback* user_callback_; 231 int rv_; 232 }; 233 234 MockSSLClientSocket::MockSSLClientSocket( 235 net::ClientSocket* transport_socket, 236 const std::string& hostname, 237 const net::SSLConfig& ssl_config, 238 net::SSLSocketDataProvider* data) 239 : transport_(transport_socket), 240 data_(data) { 241 DCHECK(data_); 242 } 243 244 MockSSLClientSocket::~MockSSLClientSocket() { 245 Disconnect(); 246 } 247 248 void MockSSLClientSocket::GetSSLInfo(net::SSLInfo* ssl_info) { 249 ssl_info->Reset(); 250 } 251 252 int MockSSLClientSocket::Connect(net::CompletionCallback* callback, 253 LoadLog* load_log) { 254 ConnectCallback* connect_callback = new ConnectCallback( 255 this, callback, data_->connect.result); 256 int rv = transport_->Connect(connect_callback, load_log); 257 if (rv == net::OK) { 258 delete connect_callback; 259 if (data_->connect.async) { 260 RunCallbackAsync(callback, data_->connect.result); 261 return net::ERR_IO_PENDING; 262 } 263 if (data_->connect.result == net::OK) 264 connected_ = true; 265 return data_->connect.result; 266 } 267 return rv; 268 } 269 270 void MockSSLClientSocket::Disconnect() { 271 MockClientSocket::Disconnect(); 272 if (transport_ != NULL) 273 transport_->Disconnect(); 274 } 275 276 int MockSSLClientSocket::Read(net::IOBuffer* buf, int buf_len, 277 net::CompletionCallback* callback) { 278 return transport_->Read(buf, buf_len, callback); 279 } 280 281 int MockSSLClientSocket::Write(net::IOBuffer* buf, int buf_len, 282 net::CompletionCallback* callback) { 283 return transport_->Write(buf, buf_len, callback); 284 } 285 286 MockRead StaticSocketDataProvider::GetNextRead() { 287 MockRead rv = reads_[read_index_]; 288 if (reads_[read_index_].result != OK || 289 reads_[read_index_].data_len != 0) 290 read_index_++; // Don't advance past an EOF. 291 return rv; 292 } 293 294 MockWriteResult StaticSocketDataProvider::OnWrite(const std::string& data) { 295 if (!writes_) { 296 // Not using mock writes; succeed synchronously. 297 return MockWriteResult(false, data.length()); 298 } 299 300 // Check that what we are writing matches the expectation. 301 // Then give the mocked return value. 302 net::MockWrite* w = &writes_[write_index_++]; 303 int result = w->result; 304 if (w->data) { 305 // Note - we can simulate a partial write here. If the expected data 306 // is a match, but shorter than the write actually written, that is legal. 307 // Example: 308 // Application writes "foobarbaz" (9 bytes) 309 // Expected write was "foo" (3 bytes) 310 // This is a success, and we return 3 to the application. 311 std::string expected_data(w->data, w->data_len); 312 EXPECT_GE(data.length(), expected_data.length()); 313 std::string actual_data(data.substr(0, w->data_len)); 314 EXPECT_EQ(expected_data, actual_data); 315 if (expected_data != actual_data) 316 return MockWriteResult(false, net::ERR_UNEXPECTED); 317 if (result == net::OK) 318 result = w->data_len; 319 } 320 return MockWriteResult(w->async, result); 321 } 322 323 void StaticSocketDataProvider::Reset() { 324 read_index_ = 0; 325 write_index_ = 0; 326 } 327 328 DynamicSocketDataProvider::DynamicSocketDataProvider() 329 : short_read_limit_(0), 330 allow_unconsumed_reads_(false) { 331 } 332 333 MockRead DynamicSocketDataProvider::GetNextRead() { 334 if (reads_.empty()) 335 return MockRead(false, ERR_UNEXPECTED); 336 MockRead result = reads_.front(); 337 if (short_read_limit_ == 0 || result.data_len <= short_read_limit_) { 338 reads_.pop_front(); 339 } else { 340 result.data_len = short_read_limit_; 341 reads_.front().data += result.data_len; 342 reads_.front().data_len -= result.data_len; 343 } 344 return result; 345 } 346 347 void DynamicSocketDataProvider::Reset() { 348 reads_.clear(); 349 } 350 351 void DynamicSocketDataProvider::SimulateRead(const char* data) { 352 if (!allow_unconsumed_reads_) { 353 EXPECT_TRUE(reads_.empty()) << "Unconsumed read: " << reads_.front().data; 354 } 355 reads_.push_back(MockRead(data)); 356 } 357 358 void MockClientSocketFactory::AddSocketDataProvider( 359 SocketDataProvider* data) { 360 mock_data_.Add(data); 361 } 362 363 void MockClientSocketFactory::AddSSLSocketDataProvider( 364 SSLSocketDataProvider* data) { 365 mock_ssl_data_.Add(data); 366 } 367 368 void MockClientSocketFactory::ResetNextMockIndexes() { 369 mock_data_.ResetNextIndex(); 370 mock_ssl_data_.ResetNextIndex(); 371 } 372 373 MockTCPClientSocket* MockClientSocketFactory::GetMockTCPClientSocket( 374 int index) const { 375 return tcp_client_sockets_[index]; 376 } 377 378 MockSSLClientSocket* MockClientSocketFactory::GetMockSSLClientSocket( 379 int index) const { 380 return ssl_client_sockets_[index]; 381 } 382 383 ClientSocket* MockClientSocketFactory::CreateTCPClientSocket( 384 const AddressList& addresses) { 385 SocketDataProvider* data_provider = mock_data_.GetNext(); 386 MockTCPClientSocket* socket = 387 new MockTCPClientSocket(addresses, data_provider); 388 data_provider->set_socket(socket); 389 tcp_client_sockets_.push_back(socket); 390 return socket; 391 } 392 393 SSLClientSocket* MockClientSocketFactory::CreateSSLClientSocket( 394 ClientSocket* transport_socket, 395 const std::string& hostname, 396 const SSLConfig& ssl_config) { 397 MockSSLClientSocket* socket = 398 new MockSSLClientSocket(transport_socket, hostname, ssl_config, 399 mock_ssl_data_.GetNext()); 400 ssl_client_sockets_.push_back(socket); 401 return socket; 402 } 403 404 int TestSocketRequest::WaitForResult() { 405 return callback_.WaitForResult(); 406 } 407 408 void TestSocketRequest::RunWithParams(const Tuple1<int>& params) { 409 callback_.RunWithParams(params); 410 (*completion_count_)++; 411 request_order_->push_back(this); 412 } 413 414 // static 415 const int ClientSocketPoolTest::kIndexOutOfBounds = -1; 416 417 // static 418 const int ClientSocketPoolTest::kRequestNotFound = -2; 419 420 void ClientSocketPoolTest::SetUp() { 421 completion_count_ = 0; 422 } 423 424 void ClientSocketPoolTest::TearDown() { 425 // The tests often call Reset() on handles at the end which may post 426 // DoReleaseSocket() tasks. 427 // Pending tasks created by client_socket_pool_base_unittest.cc are 428 // posted two milliseconds into the future and thus won't become 429 // scheduled until that time. 430 // We wait a few milliseconds to make sure that all such future tasks 431 // are ready to run, before calling RunAllPending(). This will work 432 // correctly even if Sleep() finishes late (and it should never finish 433 // early), as all we have to ensure is that actual wall-time has progressed 434 // past the scheduled starting time of the pending task. 435 PlatformThread::Sleep(10); 436 MessageLoop::current()->RunAllPending(); 437 } 438 439 int ClientSocketPoolTest::GetOrderOfRequest(size_t index) { 440 index--; 441 if (index >= requests_.size()) 442 return kIndexOutOfBounds; 443 444 for (size_t i = 0; i < request_order_.size(); i++) 445 if (requests_[index] == request_order_[i]) 446 return i + 1; 447 448 return kRequestNotFound; 449 } 450 451 bool ClientSocketPoolTest::ReleaseOneConnection(KeepAlive keep_alive) { 452 ScopedVector<TestSocketRequest>::iterator i; 453 for (i = requests_.begin(); i != requests_.end(); ++i) { 454 if ((*i)->handle()->is_initialized()) { 455 if (keep_alive == NO_KEEP_ALIVE) 456 (*i)->handle()->socket()->Disconnect(); 457 (*i)->handle()->Reset(); 458 MessageLoop::current()->RunAllPending(); 459 return true; 460 } 461 } 462 return false; 463 } 464 465 void ClientSocketPoolTest::ReleaseAllConnections(KeepAlive keep_alive) { 466 bool released_one; 467 do { 468 released_one = ReleaseOneConnection(keep_alive); 469 } while (released_one); 470 } 471 472 } // namespace net 473