1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "jingle/glue/pseudotcp_adapter.h" 6 7 #include <vector> 8 9 #include "base/bind.h" 10 #include "base/bind_helpers.h" 11 #include "base/compiler_specific.h" 12 #include "jingle/glue/thread_wrapper.h" 13 #include "net/base/io_buffer.h" 14 #include "net/base/net_errors.h" 15 #include "net/base/test_completion_callback.h" 16 #include "net/udp/udp_socket.h" 17 #include "testing/gmock/include/gmock/gmock.h" 18 #include "testing/gtest/include/gtest/gtest.h" 19 20 21 namespace jingle_glue { 22 namespace { 23 class FakeSocket; 24 } // namespace 25 } // namespace jingle_glue 26 27 namespace jingle_glue { 28 29 namespace { 30 31 const int kMessageSize = 1024; 32 const int kMessages = 100; 33 const int kTestDataSize = kMessages * kMessageSize; 34 35 class RateLimiter { 36 public: 37 virtual ~RateLimiter() { }; 38 // Returns true if the new packet needs to be dropped, false otherwise. 39 virtual bool DropNextPacket() = 0; 40 }; 41 42 class LeakyBucket : public RateLimiter { 43 public: 44 // |rate| is in drops per second. 45 LeakyBucket(double volume, double rate) 46 : volume_(volume), 47 rate_(rate), 48 level_(0.0), 49 last_update_(base::TimeTicks::HighResNow()) { 50 } 51 52 virtual ~LeakyBucket() { } 53 54 virtual bool DropNextPacket() OVERRIDE { 55 base::TimeTicks now = base::TimeTicks::HighResNow(); 56 double interval = (now - last_update_).InSecondsF(); 57 last_update_ = now; 58 level_ = level_ + 1.0 - interval * rate_; 59 if (level_ > volume_) { 60 level_ = volume_; 61 return true; 62 } else if (level_ < 0.0) { 63 level_ = 0.0; 64 } 65 return false; 66 } 67 68 private: 69 double volume_; 70 double rate_; 71 double level_; 72 base::TimeTicks last_update_; 73 }; 74 75 class FakeSocket : public net::Socket { 76 public: 77 FakeSocket() 78 : rate_limiter_(NULL), 79 latency_ms_(0) { 80 } 81 virtual ~FakeSocket() { } 82 83 void AppendInputPacket(const std::vector<char>& data) { 84 if (rate_limiter_ && rate_limiter_->DropNextPacket()) 85 return; // Lose the packet. 86 87 if (!read_callback_.is_null()) { 88 int size = std::min(read_buffer_size_, static_cast<int>(data.size())); 89 memcpy(read_buffer_->data(), &data[0], data.size()); 90 net::CompletionCallback cb = read_callback_; 91 read_callback_.Reset(); 92 read_buffer_ = NULL; 93 cb.Run(size); 94 } else { 95 incoming_packets_.push_back(data); 96 } 97 } 98 99 void Connect(FakeSocket* peer_socket) { 100 peer_socket_ = peer_socket; 101 } 102 103 void set_rate_limiter(RateLimiter* rate_limiter) { 104 rate_limiter_ = rate_limiter; 105 }; 106 107 void set_latency(int latency_ms) { latency_ms_ = latency_ms; }; 108 109 // net::Socket interface. 110 virtual int Read(net::IOBuffer* buf, int buf_len, 111 const net::CompletionCallback& callback) OVERRIDE { 112 CHECK(read_callback_.is_null()); 113 CHECK(buf); 114 115 if (incoming_packets_.size() > 0) { 116 scoped_refptr<net::IOBuffer> buffer(buf); 117 int size = std::min( 118 static_cast<int>(incoming_packets_.front().size()), buf_len); 119 memcpy(buffer->data(), &*incoming_packets_.front().begin(), size); 120 incoming_packets_.pop_front(); 121 return size; 122 } else { 123 read_callback_ = callback; 124 read_buffer_ = buf; 125 read_buffer_size_ = buf_len; 126 return net::ERR_IO_PENDING; 127 } 128 } 129 130 virtual int Write(net::IOBuffer* buf, int buf_len, 131 const net::CompletionCallback& callback) OVERRIDE { 132 DCHECK(buf); 133 if (peer_socket_) { 134 base::MessageLoop::current()->PostDelayedTask( 135 FROM_HERE, 136 base::Bind(&FakeSocket::AppendInputPacket, 137 base::Unretained(peer_socket_), 138 std::vector<char>(buf->data(), buf->data() + buf_len)), 139 base::TimeDelta::FromMilliseconds(latency_ms_)); 140 } 141 142 return buf_len; 143 } 144 145 virtual int SetReceiveBufferSize(int32 size) OVERRIDE { 146 NOTIMPLEMENTED(); 147 return net::ERR_NOT_IMPLEMENTED; 148 } 149 virtual int SetSendBufferSize(int32 size) OVERRIDE { 150 NOTIMPLEMENTED(); 151 return net::ERR_NOT_IMPLEMENTED; 152 } 153 154 private: 155 scoped_refptr<net::IOBuffer> read_buffer_; 156 int read_buffer_size_; 157 net::CompletionCallback read_callback_; 158 159 std::deque<std::vector<char> > incoming_packets_; 160 161 FakeSocket* peer_socket_; 162 RateLimiter* rate_limiter_; 163 int latency_ms_; 164 }; 165 166 class TCPChannelTester : public base::RefCountedThreadSafe<TCPChannelTester> { 167 public: 168 TCPChannelTester(base::MessageLoop* message_loop, 169 net::Socket* client_socket, 170 net::Socket* host_socket) 171 : message_loop_(message_loop), 172 host_socket_(host_socket), 173 client_socket_(client_socket), 174 done_(false), 175 write_errors_(0), 176 read_errors_(0) {} 177 178 void Start() { 179 message_loop_->PostTask( 180 FROM_HERE, base::Bind(&TCPChannelTester::DoStart, this)); 181 } 182 183 void CheckResults() { 184 EXPECT_EQ(0, write_errors_); 185 EXPECT_EQ(0, read_errors_); 186 187 ASSERT_EQ(kTestDataSize + kMessageSize, input_buffer_->capacity()); 188 189 output_buffer_->SetOffset(0); 190 ASSERT_EQ(kTestDataSize, output_buffer_->size()); 191 192 EXPECT_EQ(0, memcmp(output_buffer_->data(), 193 input_buffer_->StartOfBuffer(), kTestDataSize)); 194 } 195 196 protected: 197 virtual ~TCPChannelTester() {} 198 199 void Done() { 200 done_ = true; 201 message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure()); 202 } 203 204 void DoStart() { 205 InitBuffers(); 206 DoRead(); 207 DoWrite(); 208 } 209 210 void InitBuffers() { 211 output_buffer_ = new net::DrainableIOBuffer( 212 new net::IOBuffer(kTestDataSize), kTestDataSize); 213 memset(output_buffer_->data(), 123, kTestDataSize); 214 215 input_buffer_ = new net::GrowableIOBuffer(); 216 // Always keep kMessageSize bytes available at the end of the input buffer. 217 input_buffer_->SetCapacity(kMessageSize); 218 } 219 220 void DoWrite() { 221 int result = 1; 222 while (result > 0) { 223 if (output_buffer_->BytesRemaining() == 0) 224 break; 225 226 int bytes_to_write = std::min(output_buffer_->BytesRemaining(), 227 kMessageSize); 228 result = client_socket_->Write( 229 output_buffer_.get(), 230 bytes_to_write, 231 base::Bind(&TCPChannelTester::OnWritten, base::Unretained(this))); 232 HandleWriteResult(result); 233 } 234 } 235 236 void OnWritten(int result) { 237 HandleWriteResult(result); 238 DoWrite(); 239 } 240 241 void HandleWriteResult(int result) { 242 if (result <= 0 && result != net::ERR_IO_PENDING) { 243 LOG(ERROR) << "Received error " << result << " when trying to write"; 244 write_errors_++; 245 Done(); 246 } else if (result > 0) { 247 output_buffer_->DidConsume(result); 248 } 249 } 250 251 void DoRead() { 252 int result = 1; 253 while (result > 0) { 254 input_buffer_->set_offset(input_buffer_->capacity() - kMessageSize); 255 256 result = host_socket_->Read( 257 input_buffer_.get(), 258 kMessageSize, 259 base::Bind(&TCPChannelTester::OnRead, base::Unretained(this))); 260 HandleReadResult(result); 261 }; 262 } 263 264 void OnRead(int result) { 265 HandleReadResult(result); 266 DoRead(); 267 } 268 269 void HandleReadResult(int result) { 270 if (result <= 0 && result != net::ERR_IO_PENDING) { 271 if (!done_) { 272 LOG(ERROR) << "Received error " << result << " when trying to read"; 273 read_errors_++; 274 Done(); 275 } 276 } else if (result > 0) { 277 // Allocate memory for the next read. 278 input_buffer_->SetCapacity(input_buffer_->capacity() + result); 279 if (input_buffer_->capacity() == kTestDataSize + kMessageSize) 280 Done(); 281 } 282 } 283 284 private: 285 friend class base::RefCountedThreadSafe<TCPChannelTester>; 286 287 base::MessageLoop* message_loop_; 288 net::Socket* host_socket_; 289 net::Socket* client_socket_; 290 bool done_; 291 292 scoped_refptr<net::DrainableIOBuffer> output_buffer_; 293 scoped_refptr<net::GrowableIOBuffer> input_buffer_; 294 295 int write_errors_; 296 int read_errors_; 297 }; 298 299 class PseudoTcpAdapterTest : public testing::Test { 300 protected: 301 virtual void SetUp() OVERRIDE { 302 JingleThreadWrapper::EnsureForCurrentMessageLoop(); 303 304 host_socket_ = new FakeSocket(); 305 client_socket_ = new FakeSocket(); 306 307 host_socket_->Connect(client_socket_); 308 client_socket_->Connect(host_socket_); 309 310 host_pseudotcp_.reset(new PseudoTcpAdapter(host_socket_)); 311 client_pseudotcp_.reset(new PseudoTcpAdapter(client_socket_)); 312 } 313 314 FakeSocket* host_socket_; 315 FakeSocket* client_socket_; 316 317 scoped_ptr<PseudoTcpAdapter> host_pseudotcp_; 318 scoped_ptr<PseudoTcpAdapter> client_pseudotcp_; 319 base::MessageLoop message_loop_; 320 }; 321 322 TEST_F(PseudoTcpAdapterTest, DataTransfer) { 323 net::TestCompletionCallback host_connect_cb; 324 net::TestCompletionCallback client_connect_cb; 325 326 int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback()); 327 int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback()); 328 329 if (rv1 == net::ERR_IO_PENDING) 330 rv1 = host_connect_cb.WaitForResult(); 331 if (rv2 == net::ERR_IO_PENDING) 332 rv2 = client_connect_cb.WaitForResult(); 333 ASSERT_EQ(net::OK, rv1); 334 ASSERT_EQ(net::OK, rv2); 335 336 scoped_refptr<TCPChannelTester> tester = 337 new TCPChannelTester(&message_loop_, host_pseudotcp_.get(), 338 client_pseudotcp_.get()); 339 340 tester->Start(); 341 message_loop_.Run(); 342 tester->CheckResults(); 343 } 344 345 TEST_F(PseudoTcpAdapterTest, LimitedChannel) { 346 const int kLatencyMs = 20; 347 const int kPacketsPerSecond = 400; 348 const int kBurstPackets = 10; 349 350 LeakyBucket host_limiter(kBurstPackets, kPacketsPerSecond); 351 host_socket_->set_latency(kLatencyMs); 352 host_socket_->set_rate_limiter(&host_limiter); 353 354 LeakyBucket client_limiter(kBurstPackets, kPacketsPerSecond); 355 host_socket_->set_latency(kLatencyMs); 356 client_socket_->set_rate_limiter(&client_limiter); 357 358 net::TestCompletionCallback host_connect_cb; 359 net::TestCompletionCallback client_connect_cb; 360 361 int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback()); 362 int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback()); 363 364 if (rv1 == net::ERR_IO_PENDING) 365 rv1 = host_connect_cb.WaitForResult(); 366 if (rv2 == net::ERR_IO_PENDING) 367 rv2 = client_connect_cb.WaitForResult(); 368 ASSERT_EQ(net::OK, rv1); 369 ASSERT_EQ(net::OK, rv2); 370 371 scoped_refptr<TCPChannelTester> tester = 372 new TCPChannelTester(&message_loop_, host_pseudotcp_.get(), 373 client_pseudotcp_.get()); 374 375 tester->Start(); 376 message_loop_.Run(); 377 tester->CheckResults(); 378 } 379 380 class DeleteOnConnected { 381 public: 382 DeleteOnConnected(base::MessageLoop* message_loop, 383 scoped_ptr<PseudoTcpAdapter>* adapter) 384 : message_loop_(message_loop), adapter_(adapter) {} 385 void OnConnected(int error) { 386 adapter_->reset(); 387 message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure()); 388 } 389 base::MessageLoop* message_loop_; 390 scoped_ptr<PseudoTcpAdapter>* adapter_; 391 }; 392 393 TEST_F(PseudoTcpAdapterTest, DeleteOnConnected) { 394 // This test verifies that deleting the adapter mid-callback doesn't lead 395 // to deleted structures being touched as the stack unrolls, so the failure 396 // mode is a crash rather than a normal test failure. 397 net::TestCompletionCallback client_connect_cb; 398 DeleteOnConnected host_delete(&message_loop_, &host_pseudotcp_); 399 400 host_pseudotcp_->Connect(base::Bind(&DeleteOnConnected::OnConnected, 401 base::Unretained(&host_delete))); 402 client_pseudotcp_->Connect(client_connect_cb.callback()); 403 message_loop_.Run(); 404 405 ASSERT_EQ(NULL, host_pseudotcp_.get()); 406 } 407 408 // Verify that we can send/receive data with the write-waits-for-send 409 // flag set. 410 TEST_F(PseudoTcpAdapterTest, WriteWaitsForSendLetsDataThrough) { 411 net::TestCompletionCallback host_connect_cb; 412 net::TestCompletionCallback client_connect_cb; 413 414 host_pseudotcp_->SetWriteWaitsForSend(true); 415 client_pseudotcp_->SetWriteWaitsForSend(true); 416 417 // Disable Nagle's algorithm because the test is slow when it is 418 // enabled. 419 host_pseudotcp_->SetNoDelay(true); 420 421 int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback()); 422 int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback()); 423 424 if (rv1 == net::ERR_IO_PENDING) 425 rv1 = host_connect_cb.WaitForResult(); 426 if (rv2 == net::ERR_IO_PENDING) 427 rv2 = client_connect_cb.WaitForResult(); 428 ASSERT_EQ(net::OK, rv1); 429 ASSERT_EQ(net::OK, rv2); 430 431 scoped_refptr<TCPChannelTester> tester = 432 new TCPChannelTester(&message_loop_, host_pseudotcp_.get(), 433 client_pseudotcp_.get()); 434 435 tester->Start(); 436 message_loop_.Run(); 437 tester->CheckResults(); 438 } 439 440 } // namespace 441 442 } // namespace jingle_glue 443