Home | History | Annotate | Download | only in glue
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "jingle/glue/pseudotcp_adapter.h"
      6 
      7 #include <vector>
      8 
      9 #include "base/bind.h"
     10 #include "base/bind_helpers.h"
     11 #include "base/compiler_specific.h"
     12 #include "jingle/glue/thread_wrapper.h"
     13 #include "net/base/io_buffer.h"
     14 #include "net/base/net_errors.h"
     15 #include "net/base/test_completion_callback.h"
     16 #include "net/udp/udp_socket.h"
     17 #include "testing/gmock/include/gmock/gmock.h"
     18 #include "testing/gtest/include/gtest/gtest.h"
     19 
     20 
     21 namespace jingle_glue {
     22 namespace {
     23 class FakeSocket;
     24 }  // namespace
     25 }  // namespace jingle_glue
     26 
     27 namespace jingle_glue {
     28 
     29 namespace {
     30 
     31 // The range is chosen arbitrarily. It must be big enough so that we
     32 // always have at least two UDP ports available.
     33 const int kMinPort = 32000;
     34 const int kMaxPort = 33000;
     35 
     36 const int kMessageSize = 1024;
     37 const int kMessages = 100;
     38 const int kTestDataSize = kMessages * kMessageSize;
     39 
     40 class RateLimiter {
     41  public:
     42   virtual ~RateLimiter() { };
     43   // Returns true if the new packet needs to be dropped, false otherwise.
     44   virtual bool DropNextPacket() = 0;
     45 };
     46 
     47 class LeakyBucket : public RateLimiter {
     48  public:
     49   // |rate| is in drops per second.
     50   LeakyBucket(double volume, double rate)
     51       : volume_(volume),
     52         rate_(rate),
     53         level_(0.0),
     54         last_update_(base::TimeTicks::HighResNow()) {
     55   }
     56 
     57   virtual ~LeakyBucket() { }
     58 
     59   virtual bool DropNextPacket() OVERRIDE {
     60     base::TimeTicks now = base::TimeTicks::HighResNow();
     61     double interval = (now - last_update_).InSecondsF();
     62     last_update_ = now;
     63     level_ = level_ + 1.0 - interval * rate_;
     64     if (level_ > volume_) {
     65       level_ = volume_;
     66       return true;
     67     } else if (level_ < 0.0) {
     68       level_ = 0.0;
     69     }
     70     return false;
     71   }
     72 
     73  private:
     74   double volume_;
     75   double rate_;
     76   double level_;
     77   base::TimeTicks last_update_;
     78 };
     79 
     80 class FakeSocket : public net::Socket {
     81  public:
     82   FakeSocket()
     83       : rate_limiter_(NULL),
     84         latency_ms_(0) {
     85   }
     86   virtual ~FakeSocket() { }
     87 
     88   void AppendInputPacket(const std::vector<char>& data) {
     89     if (rate_limiter_ && rate_limiter_->DropNextPacket())
     90       return;  // Lose the packet.
     91 
     92     if (!read_callback_.is_null()) {
     93       int size = std::min(read_buffer_size_, static_cast<int>(data.size()));
     94       memcpy(read_buffer_->data(), &data[0], data.size());
     95       net::CompletionCallback cb = read_callback_;
     96       read_callback_.Reset();
     97       read_buffer_ = NULL;
     98       cb.Run(size);
     99     } else {
    100       incoming_packets_.push_back(data);
    101     }
    102   }
    103 
    104   void Connect(FakeSocket* peer_socket) {
    105     peer_socket_ = peer_socket;
    106   }
    107 
    108   void set_rate_limiter(RateLimiter* rate_limiter) {
    109     rate_limiter_ = rate_limiter;
    110   };
    111 
    112   void set_latency(int latency_ms) { latency_ms_ = latency_ms; };
    113 
    114   // net::Socket interface.
    115   virtual int Read(net::IOBuffer* buf, int buf_len,
    116                    const net::CompletionCallback& callback) OVERRIDE {
    117     CHECK(read_callback_.is_null());
    118     CHECK(buf);
    119 
    120     if (incoming_packets_.size() > 0) {
    121       scoped_refptr<net::IOBuffer> buffer(buf);
    122       int size = std::min(
    123           static_cast<int>(incoming_packets_.front().size()), buf_len);
    124       memcpy(buffer->data(), &*incoming_packets_.front().begin(), size);
    125       incoming_packets_.pop_front();
    126       return size;
    127     } else {
    128       read_callback_ = callback;
    129       read_buffer_ = buf;
    130       read_buffer_size_ = buf_len;
    131       return net::ERR_IO_PENDING;
    132     }
    133   }
    134 
    135   virtual int Write(net::IOBuffer* buf, int buf_len,
    136                     const net::CompletionCallback& callback) OVERRIDE {
    137     DCHECK(buf);
    138     if (peer_socket_) {
    139       base::MessageLoop::current()->PostDelayedTask(
    140           FROM_HERE,
    141           base::Bind(&FakeSocket::AppendInputPacket,
    142                      base::Unretained(peer_socket_),
    143                      std::vector<char>(buf->data(), buf->data() + buf_len)),
    144           base::TimeDelta::FromMilliseconds(latency_ms_));
    145     }
    146 
    147     return buf_len;
    148   }
    149 
    150   virtual bool SetReceiveBufferSize(int32 size) OVERRIDE {
    151     NOTIMPLEMENTED();
    152     return false;
    153   }
    154   virtual bool SetSendBufferSize(int32 size) OVERRIDE {
    155     NOTIMPLEMENTED();
    156     return false;
    157   }
    158 
    159  private:
    160   scoped_refptr<net::IOBuffer> read_buffer_;
    161   int read_buffer_size_;
    162   net::CompletionCallback read_callback_;
    163 
    164   std::deque<std::vector<char> > incoming_packets_;
    165 
    166   FakeSocket* peer_socket_;
    167   RateLimiter* rate_limiter_;
    168   int latency_ms_;
    169 };
    170 
    171 class TCPChannelTester : public base::RefCountedThreadSafe<TCPChannelTester> {
    172  public:
    173   TCPChannelTester(base::MessageLoop* message_loop,
    174                    net::Socket* client_socket,
    175                    net::Socket* host_socket)
    176       : message_loop_(message_loop),
    177         host_socket_(host_socket),
    178         client_socket_(client_socket),
    179         done_(false),
    180         write_errors_(0),
    181         read_errors_(0) {}
    182 
    183   void Start() {
    184     message_loop_->PostTask(
    185         FROM_HERE, base::Bind(&TCPChannelTester::DoStart, this));
    186   }
    187 
    188   void CheckResults() {
    189     EXPECT_EQ(0, write_errors_);
    190     EXPECT_EQ(0, read_errors_);
    191 
    192     ASSERT_EQ(kTestDataSize + kMessageSize, input_buffer_->capacity());
    193 
    194     output_buffer_->SetOffset(0);
    195     ASSERT_EQ(kTestDataSize, output_buffer_->size());
    196 
    197     EXPECT_EQ(0, memcmp(output_buffer_->data(),
    198                         input_buffer_->StartOfBuffer(), kTestDataSize));
    199   }
    200 
    201  protected:
    202   virtual ~TCPChannelTester() {}
    203 
    204   void Done() {
    205     done_ = true;
    206     message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
    207   }
    208 
    209   void DoStart() {
    210     InitBuffers();
    211     DoRead();
    212     DoWrite();
    213   }
    214 
    215   void InitBuffers() {
    216     output_buffer_ = new net::DrainableIOBuffer(
    217         new net::IOBuffer(kTestDataSize), kTestDataSize);
    218     memset(output_buffer_->data(), 123, kTestDataSize);
    219 
    220     input_buffer_ = new net::GrowableIOBuffer();
    221     // Always keep kMessageSize bytes available at the end of the input buffer.
    222     input_buffer_->SetCapacity(kMessageSize);
    223   }
    224 
    225   void DoWrite() {
    226     int result = 1;
    227     while (result > 0) {
    228       if (output_buffer_->BytesRemaining() == 0)
    229         break;
    230 
    231       int bytes_to_write = std::min(output_buffer_->BytesRemaining(),
    232                                     kMessageSize);
    233       result = client_socket_->Write(
    234           output_buffer_.get(),
    235           bytes_to_write,
    236           base::Bind(&TCPChannelTester::OnWritten, base::Unretained(this)));
    237       HandleWriteResult(result);
    238     }
    239   }
    240 
    241   void OnWritten(int result) {
    242     HandleWriteResult(result);
    243     DoWrite();
    244   }
    245 
    246   void HandleWriteResult(int result) {
    247     if (result <= 0 && result != net::ERR_IO_PENDING) {
    248       LOG(ERROR) << "Received error " << result << " when trying to write";
    249       write_errors_++;
    250       Done();
    251     } else if (result > 0) {
    252       output_buffer_->DidConsume(result);
    253     }
    254   }
    255 
    256   void DoRead() {
    257     int result = 1;
    258     while (result > 0) {
    259       input_buffer_->set_offset(input_buffer_->capacity() - kMessageSize);
    260 
    261       result = host_socket_->Read(
    262           input_buffer_.get(),
    263           kMessageSize,
    264           base::Bind(&TCPChannelTester::OnRead, base::Unretained(this)));
    265       HandleReadResult(result);
    266     };
    267   }
    268 
    269   void OnRead(int result) {
    270     HandleReadResult(result);
    271     DoRead();
    272   }
    273 
    274   void HandleReadResult(int result) {
    275     if (result <= 0 && result != net::ERR_IO_PENDING) {
    276       if (!done_) {
    277         LOG(ERROR) << "Received error " << result << " when trying to read";
    278         read_errors_++;
    279         Done();
    280       }
    281     } else if (result > 0) {
    282       // Allocate memory for the next read.
    283       input_buffer_->SetCapacity(input_buffer_->capacity() + result);
    284       if (input_buffer_->capacity() == kTestDataSize + kMessageSize)
    285         Done();
    286     }
    287   }
    288 
    289  private:
    290   friend class base::RefCountedThreadSafe<TCPChannelTester>;
    291 
    292   base::MessageLoop* message_loop_;
    293   net::Socket* host_socket_;
    294   net::Socket* client_socket_;
    295   bool done_;
    296 
    297   scoped_refptr<net::DrainableIOBuffer> output_buffer_;
    298   scoped_refptr<net::GrowableIOBuffer> input_buffer_;
    299 
    300   int write_errors_;
    301   int read_errors_;
    302 };
    303 
    304 class PseudoTcpAdapterTest : public testing::Test {
    305  protected:
    306   virtual void SetUp() OVERRIDE {
    307     JingleThreadWrapper::EnsureForCurrentMessageLoop();
    308 
    309     host_socket_ = new FakeSocket();
    310     client_socket_ = new FakeSocket();
    311 
    312     host_socket_->Connect(client_socket_);
    313     client_socket_->Connect(host_socket_);
    314 
    315     host_pseudotcp_.reset(new PseudoTcpAdapter(host_socket_));
    316     client_pseudotcp_.reset(new PseudoTcpAdapter(client_socket_));
    317   }
    318 
    319   FakeSocket* host_socket_;
    320   FakeSocket* client_socket_;
    321 
    322   scoped_ptr<PseudoTcpAdapter> host_pseudotcp_;
    323   scoped_ptr<PseudoTcpAdapter> client_pseudotcp_;
    324   base::MessageLoop message_loop_;
    325 };
    326 
    327 TEST_F(PseudoTcpAdapterTest, DataTransfer) {
    328   net::TestCompletionCallback host_connect_cb;
    329   net::TestCompletionCallback client_connect_cb;
    330 
    331   int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback());
    332   int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback());
    333 
    334   if (rv1 == net::ERR_IO_PENDING)
    335     rv1 = host_connect_cb.WaitForResult();
    336   if (rv2 == net::ERR_IO_PENDING)
    337     rv2 = client_connect_cb.WaitForResult();
    338   ASSERT_EQ(net::OK, rv1);
    339   ASSERT_EQ(net::OK, rv2);
    340 
    341   scoped_refptr<TCPChannelTester> tester =
    342       new TCPChannelTester(&message_loop_, host_pseudotcp_.get(),
    343                            client_pseudotcp_.get());
    344 
    345   tester->Start();
    346   message_loop_.Run();
    347   tester->CheckResults();
    348 }
    349 
    350 TEST_F(PseudoTcpAdapterTest, LimitedChannel) {
    351   const int kLatencyMs = 20;
    352   const int kPacketsPerSecond = 400;
    353   const int kBurstPackets = 10;
    354 
    355   LeakyBucket host_limiter(kBurstPackets, kPacketsPerSecond);
    356   host_socket_->set_latency(kLatencyMs);
    357   host_socket_->set_rate_limiter(&host_limiter);
    358 
    359   LeakyBucket client_limiter(kBurstPackets, kPacketsPerSecond);
    360   host_socket_->set_latency(kLatencyMs);
    361   client_socket_->set_rate_limiter(&client_limiter);
    362 
    363   net::TestCompletionCallback host_connect_cb;
    364   net::TestCompletionCallback client_connect_cb;
    365 
    366   int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback());
    367   int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback());
    368 
    369   if (rv1 == net::ERR_IO_PENDING)
    370     rv1 = host_connect_cb.WaitForResult();
    371   if (rv2 == net::ERR_IO_PENDING)
    372     rv2 = client_connect_cb.WaitForResult();
    373   ASSERT_EQ(net::OK, rv1);
    374   ASSERT_EQ(net::OK, rv2);
    375 
    376   scoped_refptr<TCPChannelTester> tester =
    377       new TCPChannelTester(&message_loop_, host_pseudotcp_.get(),
    378                            client_pseudotcp_.get());
    379 
    380   tester->Start();
    381   message_loop_.Run();
    382   tester->CheckResults();
    383 }
    384 
    385 class DeleteOnConnected {
    386  public:
    387   DeleteOnConnected(base::MessageLoop* message_loop,
    388                     scoped_ptr<PseudoTcpAdapter>* adapter)
    389       : message_loop_(message_loop), adapter_(adapter) {}
    390   void OnConnected(int error) {
    391     adapter_->reset();
    392     message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
    393   }
    394   base::MessageLoop* message_loop_;
    395   scoped_ptr<PseudoTcpAdapter>* adapter_;
    396 };
    397 
    398 TEST_F(PseudoTcpAdapterTest, DeleteOnConnected) {
    399   // This test verifies that deleting the adapter mid-callback doesn't lead
    400   // to deleted structures being touched as the stack unrolls, so the failure
    401   // mode is a crash rather than a normal test failure.
    402   net::TestCompletionCallback client_connect_cb;
    403   DeleteOnConnected host_delete(&message_loop_, &host_pseudotcp_);
    404 
    405   host_pseudotcp_->Connect(base::Bind(&DeleteOnConnected::OnConnected,
    406                                       base::Unretained(&host_delete)));
    407   client_pseudotcp_->Connect(client_connect_cb.callback());
    408   message_loop_.Run();
    409 
    410   ASSERT_EQ(NULL, host_pseudotcp_.get());
    411 }
    412 
    413 // Verify that we can send/receive data with the write-waits-for-send
    414 // flag set.
    415 TEST_F(PseudoTcpAdapterTest, WriteWaitsForSendLetsDataThrough) {
    416   net::TestCompletionCallback host_connect_cb;
    417   net::TestCompletionCallback client_connect_cb;
    418 
    419   host_pseudotcp_->SetWriteWaitsForSend(true);
    420   client_pseudotcp_->SetWriteWaitsForSend(true);
    421 
    422   // Disable Nagle's algorithm because the test is slow when it is
    423   // enabled.
    424   host_pseudotcp_->SetNoDelay(true);
    425 
    426   int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback());
    427   int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback());
    428 
    429   if (rv1 == net::ERR_IO_PENDING)
    430     rv1 = host_connect_cb.WaitForResult();
    431   if (rv2 == net::ERR_IO_PENDING)
    432     rv2 = client_connect_cb.WaitForResult();
    433   ASSERT_EQ(net::OK, rv1);
    434   ASSERT_EQ(net::OK, rv2);
    435 
    436   scoped_refptr<TCPChannelTester> tester =
    437       new TCPChannelTester(&message_loop_, host_pseudotcp_.get(),
    438                            client_pseudotcp_.get());
    439 
    440   tester->Start();
    441   message_loop_.Run();
    442   tester->CheckResults();
    443 }
    444 
    445 }  // namespace
    446 
    447 }  // namespace jingle_glue
    448