1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "jingle/glue/pseudotcp_adapter.h" 6 7 #include <vector> 8 9 #include "base/bind.h" 10 #include "base/bind_helpers.h" 11 #include "base/compiler_specific.h" 12 #include "jingle/glue/thread_wrapper.h" 13 #include "net/base/io_buffer.h" 14 #include "net/base/net_errors.h" 15 #include "net/base/test_completion_callback.h" 16 #include "net/udp/udp_socket.h" 17 #include "testing/gmock/include/gmock/gmock.h" 18 #include "testing/gtest/include/gtest/gtest.h" 19 20 21 namespace jingle_glue { 22 namespace { 23 class FakeSocket; 24 } // namespace 25 } // namespace jingle_glue 26 27 namespace jingle_glue { 28 29 namespace { 30 31 // The range is chosen arbitrarily. It must be big enough so that we 32 // always have at least two UDP ports available. 33 const int kMinPort = 32000; 34 const int kMaxPort = 33000; 35 36 const int kMessageSize = 1024; 37 const int kMessages = 100; 38 const int kTestDataSize = kMessages * kMessageSize; 39 40 class RateLimiter { 41 public: 42 virtual ~RateLimiter() { }; 43 // Returns true if the new packet needs to be dropped, false otherwise. 44 virtual bool DropNextPacket() = 0; 45 }; 46 47 class LeakyBucket : public RateLimiter { 48 public: 49 // |rate| is in drops per second. 50 LeakyBucket(double volume, double rate) 51 : volume_(volume), 52 rate_(rate), 53 level_(0.0), 54 last_update_(base::TimeTicks::HighResNow()) { 55 } 56 57 virtual ~LeakyBucket() { } 58 59 virtual bool DropNextPacket() OVERRIDE { 60 base::TimeTicks now = base::TimeTicks::HighResNow(); 61 double interval = (now - last_update_).InSecondsF(); 62 last_update_ = now; 63 level_ = level_ + 1.0 - interval * rate_; 64 if (level_ > volume_) { 65 level_ = volume_; 66 return true; 67 } else if (level_ < 0.0) { 68 level_ = 0.0; 69 } 70 return false; 71 } 72 73 private: 74 double volume_; 75 double rate_; 76 double level_; 77 base::TimeTicks last_update_; 78 }; 79 80 class FakeSocket : public net::Socket { 81 public: 82 FakeSocket() 83 : rate_limiter_(NULL), 84 latency_ms_(0) { 85 } 86 virtual ~FakeSocket() { } 87 88 void AppendInputPacket(const std::vector<char>& data) { 89 if (rate_limiter_ && rate_limiter_->DropNextPacket()) 90 return; // Lose the packet. 91 92 if (!read_callback_.is_null()) { 93 int size = std::min(read_buffer_size_, static_cast<int>(data.size())); 94 memcpy(read_buffer_->data(), &data[0], data.size()); 95 net::CompletionCallback cb = read_callback_; 96 read_callback_.Reset(); 97 read_buffer_ = NULL; 98 cb.Run(size); 99 } else { 100 incoming_packets_.push_back(data); 101 } 102 } 103 104 void Connect(FakeSocket* peer_socket) { 105 peer_socket_ = peer_socket; 106 } 107 108 void set_rate_limiter(RateLimiter* rate_limiter) { 109 rate_limiter_ = rate_limiter; 110 }; 111 112 void set_latency(int latency_ms) { latency_ms_ = latency_ms; }; 113 114 // net::Socket interface. 115 virtual int Read(net::IOBuffer* buf, int buf_len, 116 const net::CompletionCallback& callback) OVERRIDE { 117 CHECK(read_callback_.is_null()); 118 CHECK(buf); 119 120 if (incoming_packets_.size() > 0) { 121 scoped_refptr<net::IOBuffer> buffer(buf); 122 int size = std::min( 123 static_cast<int>(incoming_packets_.front().size()), buf_len); 124 memcpy(buffer->data(), &*incoming_packets_.front().begin(), size); 125 incoming_packets_.pop_front(); 126 return size; 127 } else { 128 read_callback_ = callback; 129 read_buffer_ = buf; 130 read_buffer_size_ = buf_len; 131 return net::ERR_IO_PENDING; 132 } 133 } 134 135 virtual int Write(net::IOBuffer* buf, int buf_len, 136 const net::CompletionCallback& callback) OVERRIDE { 137 DCHECK(buf); 138 if (peer_socket_) { 139 base::MessageLoop::current()->PostDelayedTask( 140 FROM_HERE, 141 base::Bind(&FakeSocket::AppendInputPacket, 142 base::Unretained(peer_socket_), 143 std::vector<char>(buf->data(), buf->data() + buf_len)), 144 base::TimeDelta::FromMilliseconds(latency_ms_)); 145 } 146 147 return buf_len; 148 } 149 150 virtual bool SetReceiveBufferSize(int32 size) OVERRIDE { 151 NOTIMPLEMENTED(); 152 return false; 153 } 154 virtual bool SetSendBufferSize(int32 size) OVERRIDE { 155 NOTIMPLEMENTED(); 156 return false; 157 } 158 159 private: 160 scoped_refptr<net::IOBuffer> read_buffer_; 161 int read_buffer_size_; 162 net::CompletionCallback read_callback_; 163 164 std::deque<std::vector<char> > incoming_packets_; 165 166 FakeSocket* peer_socket_; 167 RateLimiter* rate_limiter_; 168 int latency_ms_; 169 }; 170 171 class TCPChannelTester : public base::RefCountedThreadSafe<TCPChannelTester> { 172 public: 173 TCPChannelTester(base::MessageLoop* message_loop, 174 net::Socket* client_socket, 175 net::Socket* host_socket) 176 : message_loop_(message_loop), 177 host_socket_(host_socket), 178 client_socket_(client_socket), 179 done_(false), 180 write_errors_(0), 181 read_errors_(0) {} 182 183 void Start() { 184 message_loop_->PostTask( 185 FROM_HERE, base::Bind(&TCPChannelTester::DoStart, this)); 186 } 187 188 void CheckResults() { 189 EXPECT_EQ(0, write_errors_); 190 EXPECT_EQ(0, read_errors_); 191 192 ASSERT_EQ(kTestDataSize + kMessageSize, input_buffer_->capacity()); 193 194 output_buffer_->SetOffset(0); 195 ASSERT_EQ(kTestDataSize, output_buffer_->size()); 196 197 EXPECT_EQ(0, memcmp(output_buffer_->data(), 198 input_buffer_->StartOfBuffer(), kTestDataSize)); 199 } 200 201 protected: 202 virtual ~TCPChannelTester() {} 203 204 void Done() { 205 done_ = true; 206 message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure()); 207 } 208 209 void DoStart() { 210 InitBuffers(); 211 DoRead(); 212 DoWrite(); 213 } 214 215 void InitBuffers() { 216 output_buffer_ = new net::DrainableIOBuffer( 217 new net::IOBuffer(kTestDataSize), kTestDataSize); 218 memset(output_buffer_->data(), 123, kTestDataSize); 219 220 input_buffer_ = new net::GrowableIOBuffer(); 221 // Always keep kMessageSize bytes available at the end of the input buffer. 222 input_buffer_->SetCapacity(kMessageSize); 223 } 224 225 void DoWrite() { 226 int result = 1; 227 while (result > 0) { 228 if (output_buffer_->BytesRemaining() == 0) 229 break; 230 231 int bytes_to_write = std::min(output_buffer_->BytesRemaining(), 232 kMessageSize); 233 result = client_socket_->Write( 234 output_buffer_.get(), 235 bytes_to_write, 236 base::Bind(&TCPChannelTester::OnWritten, base::Unretained(this))); 237 HandleWriteResult(result); 238 } 239 } 240 241 void OnWritten(int result) { 242 HandleWriteResult(result); 243 DoWrite(); 244 } 245 246 void HandleWriteResult(int result) { 247 if (result <= 0 && result != net::ERR_IO_PENDING) { 248 LOG(ERROR) << "Received error " << result << " when trying to write"; 249 write_errors_++; 250 Done(); 251 } else if (result > 0) { 252 output_buffer_->DidConsume(result); 253 } 254 } 255 256 void DoRead() { 257 int result = 1; 258 while (result > 0) { 259 input_buffer_->set_offset(input_buffer_->capacity() - kMessageSize); 260 261 result = host_socket_->Read( 262 input_buffer_.get(), 263 kMessageSize, 264 base::Bind(&TCPChannelTester::OnRead, base::Unretained(this))); 265 HandleReadResult(result); 266 }; 267 } 268 269 void OnRead(int result) { 270 HandleReadResult(result); 271 DoRead(); 272 } 273 274 void HandleReadResult(int result) { 275 if (result <= 0 && result != net::ERR_IO_PENDING) { 276 if (!done_) { 277 LOG(ERROR) << "Received error " << result << " when trying to read"; 278 read_errors_++; 279 Done(); 280 } 281 } else if (result > 0) { 282 // Allocate memory for the next read. 283 input_buffer_->SetCapacity(input_buffer_->capacity() + result); 284 if (input_buffer_->capacity() == kTestDataSize + kMessageSize) 285 Done(); 286 } 287 } 288 289 private: 290 friend class base::RefCountedThreadSafe<TCPChannelTester>; 291 292 base::MessageLoop* message_loop_; 293 net::Socket* host_socket_; 294 net::Socket* client_socket_; 295 bool done_; 296 297 scoped_refptr<net::DrainableIOBuffer> output_buffer_; 298 scoped_refptr<net::GrowableIOBuffer> input_buffer_; 299 300 int write_errors_; 301 int read_errors_; 302 }; 303 304 class PseudoTcpAdapterTest : public testing::Test { 305 protected: 306 virtual void SetUp() OVERRIDE { 307 JingleThreadWrapper::EnsureForCurrentMessageLoop(); 308 309 host_socket_ = new FakeSocket(); 310 client_socket_ = new FakeSocket(); 311 312 host_socket_->Connect(client_socket_); 313 client_socket_->Connect(host_socket_); 314 315 host_pseudotcp_.reset(new PseudoTcpAdapter(host_socket_)); 316 client_pseudotcp_.reset(new PseudoTcpAdapter(client_socket_)); 317 } 318 319 FakeSocket* host_socket_; 320 FakeSocket* client_socket_; 321 322 scoped_ptr<PseudoTcpAdapter> host_pseudotcp_; 323 scoped_ptr<PseudoTcpAdapter> client_pseudotcp_; 324 base::MessageLoop message_loop_; 325 }; 326 327 TEST_F(PseudoTcpAdapterTest, DataTransfer) { 328 net::TestCompletionCallback host_connect_cb; 329 net::TestCompletionCallback client_connect_cb; 330 331 int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback()); 332 int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback()); 333 334 if (rv1 == net::ERR_IO_PENDING) 335 rv1 = host_connect_cb.WaitForResult(); 336 if (rv2 == net::ERR_IO_PENDING) 337 rv2 = client_connect_cb.WaitForResult(); 338 ASSERT_EQ(net::OK, rv1); 339 ASSERT_EQ(net::OK, rv2); 340 341 scoped_refptr<TCPChannelTester> tester = 342 new TCPChannelTester(&message_loop_, host_pseudotcp_.get(), 343 client_pseudotcp_.get()); 344 345 tester->Start(); 346 message_loop_.Run(); 347 tester->CheckResults(); 348 } 349 350 TEST_F(PseudoTcpAdapterTest, LimitedChannel) { 351 const int kLatencyMs = 20; 352 const int kPacketsPerSecond = 400; 353 const int kBurstPackets = 10; 354 355 LeakyBucket host_limiter(kBurstPackets, kPacketsPerSecond); 356 host_socket_->set_latency(kLatencyMs); 357 host_socket_->set_rate_limiter(&host_limiter); 358 359 LeakyBucket client_limiter(kBurstPackets, kPacketsPerSecond); 360 host_socket_->set_latency(kLatencyMs); 361 client_socket_->set_rate_limiter(&client_limiter); 362 363 net::TestCompletionCallback host_connect_cb; 364 net::TestCompletionCallback client_connect_cb; 365 366 int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback()); 367 int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback()); 368 369 if (rv1 == net::ERR_IO_PENDING) 370 rv1 = host_connect_cb.WaitForResult(); 371 if (rv2 == net::ERR_IO_PENDING) 372 rv2 = client_connect_cb.WaitForResult(); 373 ASSERT_EQ(net::OK, rv1); 374 ASSERT_EQ(net::OK, rv2); 375 376 scoped_refptr<TCPChannelTester> tester = 377 new TCPChannelTester(&message_loop_, host_pseudotcp_.get(), 378 client_pseudotcp_.get()); 379 380 tester->Start(); 381 message_loop_.Run(); 382 tester->CheckResults(); 383 } 384 385 class DeleteOnConnected { 386 public: 387 DeleteOnConnected(base::MessageLoop* message_loop, 388 scoped_ptr<PseudoTcpAdapter>* adapter) 389 : message_loop_(message_loop), adapter_(adapter) {} 390 void OnConnected(int error) { 391 adapter_->reset(); 392 message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure()); 393 } 394 base::MessageLoop* message_loop_; 395 scoped_ptr<PseudoTcpAdapter>* adapter_; 396 }; 397 398 TEST_F(PseudoTcpAdapterTest, DeleteOnConnected) { 399 // This test verifies that deleting the adapter mid-callback doesn't lead 400 // to deleted structures being touched as the stack unrolls, so the failure 401 // mode is a crash rather than a normal test failure. 402 net::TestCompletionCallback client_connect_cb; 403 DeleteOnConnected host_delete(&message_loop_, &host_pseudotcp_); 404 405 host_pseudotcp_->Connect(base::Bind(&DeleteOnConnected::OnConnected, 406 base::Unretained(&host_delete))); 407 client_pseudotcp_->Connect(client_connect_cb.callback()); 408 message_loop_.Run(); 409 410 ASSERT_EQ(NULL, host_pseudotcp_.get()); 411 } 412 413 // Verify that we can send/receive data with the write-waits-for-send 414 // flag set. 415 TEST_F(PseudoTcpAdapterTest, WriteWaitsForSendLetsDataThrough) { 416 net::TestCompletionCallback host_connect_cb; 417 net::TestCompletionCallback client_connect_cb; 418 419 host_pseudotcp_->SetWriteWaitsForSend(true); 420 client_pseudotcp_->SetWriteWaitsForSend(true); 421 422 // Disable Nagle's algorithm because the test is slow when it is 423 // enabled. 424 host_pseudotcp_->SetNoDelay(true); 425 426 int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback()); 427 int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback()); 428 429 if (rv1 == net::ERR_IO_PENDING) 430 rv1 = host_connect_cb.WaitForResult(); 431 if (rv2 == net::ERR_IO_PENDING) 432 rv2 = client_connect_cb.WaitForResult(); 433 ASSERT_EQ(net::OK, rv1); 434 ASSERT_EQ(net::OK, rv2); 435 436 scoped_refptr<TCPChannelTester> tester = 437 new TCPChannelTester(&message_loop_, host_pseudotcp_.get(), 438 client_pseudotcp_.get()); 439 440 tester->Start(); 441 message_loop_.Run(); 442 tester->CheckResults(); 443 } 444 445 } // namespace 446 447 } // namespace jingle_glue 448