Home | History | Annotate | Download | only in glue
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "jingle/glue/pseudotcp_adapter.h"
      6 
      7 #include <vector>
      8 
      9 #include "base/bind.h"
     10 #include "base/bind_helpers.h"
     11 #include "base/compiler_specific.h"
     12 #include "jingle/glue/thread_wrapper.h"
     13 #include "net/base/io_buffer.h"
     14 #include "net/base/net_errors.h"
     15 #include "net/base/test_completion_callback.h"
     16 #include "net/udp/udp_socket.h"
     17 #include "testing/gmock/include/gmock/gmock.h"
     18 #include "testing/gtest/include/gtest/gtest.h"
     19 
     20 
     21 namespace jingle_glue {
     22 namespace {
     23 class FakeSocket;
     24 }  // namespace
     25 }  // namespace jingle_glue
     26 
     27 namespace jingle_glue {
     28 
     29 namespace {
     30 
     31 const int kMessageSize = 1024;
     32 const int kMessages = 100;
     33 const int kTestDataSize = kMessages * kMessageSize;
     34 
     35 class RateLimiter {
     36  public:
     37   virtual ~RateLimiter() { };
     38   // Returns true if the new packet needs to be dropped, false otherwise.
     39   virtual bool DropNextPacket() = 0;
     40 };
     41 
     42 class LeakyBucket : public RateLimiter {
     43  public:
     44   // |rate| is in drops per second.
     45   LeakyBucket(double volume, double rate)
     46       : volume_(volume),
     47         rate_(rate),
     48         level_(0.0),
     49         last_update_(base::TimeTicks::HighResNow()) {
     50   }
     51 
     52   virtual ~LeakyBucket() { }
     53 
     54   virtual bool DropNextPacket() OVERRIDE {
     55     base::TimeTicks now = base::TimeTicks::HighResNow();
     56     double interval = (now - last_update_).InSecondsF();
     57     last_update_ = now;
     58     level_ = level_ + 1.0 - interval * rate_;
     59     if (level_ > volume_) {
     60       level_ = volume_;
     61       return true;
     62     } else if (level_ < 0.0) {
     63       level_ = 0.0;
     64     }
     65     return false;
     66   }
     67 
     68  private:
     69   double volume_;
     70   double rate_;
     71   double level_;
     72   base::TimeTicks last_update_;
     73 };
     74 
     75 class FakeSocket : public net::Socket {
     76  public:
     77   FakeSocket()
     78       : rate_limiter_(NULL),
     79         latency_ms_(0) {
     80   }
     81   virtual ~FakeSocket() { }
     82 
     83   void AppendInputPacket(const std::vector<char>& data) {
     84     if (rate_limiter_ && rate_limiter_->DropNextPacket())
     85       return;  // Lose the packet.
     86 
     87     if (!read_callback_.is_null()) {
     88       int size = std::min(read_buffer_size_, static_cast<int>(data.size()));
     89       memcpy(read_buffer_->data(), &data[0], data.size());
     90       net::CompletionCallback cb = read_callback_;
     91       read_callback_.Reset();
     92       read_buffer_ = NULL;
     93       cb.Run(size);
     94     } else {
     95       incoming_packets_.push_back(data);
     96     }
     97   }
     98 
     99   void Connect(FakeSocket* peer_socket) {
    100     peer_socket_ = peer_socket;
    101   }
    102 
    103   void set_rate_limiter(RateLimiter* rate_limiter) {
    104     rate_limiter_ = rate_limiter;
    105   };
    106 
    107   void set_latency(int latency_ms) { latency_ms_ = latency_ms; };
    108 
    109   // net::Socket interface.
    110   virtual int Read(net::IOBuffer* buf, int buf_len,
    111                    const net::CompletionCallback& callback) OVERRIDE {
    112     CHECK(read_callback_.is_null());
    113     CHECK(buf);
    114 
    115     if (incoming_packets_.size() > 0) {
    116       scoped_refptr<net::IOBuffer> buffer(buf);
    117       int size = std::min(
    118           static_cast<int>(incoming_packets_.front().size()), buf_len);
    119       memcpy(buffer->data(), &*incoming_packets_.front().begin(), size);
    120       incoming_packets_.pop_front();
    121       return size;
    122     } else {
    123       read_callback_ = callback;
    124       read_buffer_ = buf;
    125       read_buffer_size_ = buf_len;
    126       return net::ERR_IO_PENDING;
    127     }
    128   }
    129 
    130   virtual int Write(net::IOBuffer* buf, int buf_len,
    131                     const net::CompletionCallback& callback) OVERRIDE {
    132     DCHECK(buf);
    133     if (peer_socket_) {
    134       base::MessageLoop::current()->PostDelayedTask(
    135           FROM_HERE,
    136           base::Bind(&FakeSocket::AppendInputPacket,
    137                      base::Unretained(peer_socket_),
    138                      std::vector<char>(buf->data(), buf->data() + buf_len)),
    139           base::TimeDelta::FromMilliseconds(latency_ms_));
    140     }
    141 
    142     return buf_len;
    143   }
    144 
    145   virtual int SetReceiveBufferSize(int32 size) OVERRIDE {
    146     NOTIMPLEMENTED();
    147     return net::ERR_NOT_IMPLEMENTED;
    148   }
    149   virtual int SetSendBufferSize(int32 size) OVERRIDE {
    150     NOTIMPLEMENTED();
    151     return net::ERR_NOT_IMPLEMENTED;
    152   }
    153 
    154  private:
    155   scoped_refptr<net::IOBuffer> read_buffer_;
    156   int read_buffer_size_;
    157   net::CompletionCallback read_callback_;
    158 
    159   std::deque<std::vector<char> > incoming_packets_;
    160 
    161   FakeSocket* peer_socket_;
    162   RateLimiter* rate_limiter_;
    163   int latency_ms_;
    164 };
    165 
    166 class TCPChannelTester : public base::RefCountedThreadSafe<TCPChannelTester> {
    167  public:
    168   TCPChannelTester(base::MessageLoop* message_loop,
    169                    net::Socket* client_socket,
    170                    net::Socket* host_socket)
    171       : message_loop_(message_loop),
    172         host_socket_(host_socket),
    173         client_socket_(client_socket),
    174         done_(false),
    175         write_errors_(0),
    176         read_errors_(0) {}
    177 
    178   void Start() {
    179     message_loop_->PostTask(
    180         FROM_HERE, base::Bind(&TCPChannelTester::DoStart, this));
    181   }
    182 
    183   void CheckResults() {
    184     EXPECT_EQ(0, write_errors_);
    185     EXPECT_EQ(0, read_errors_);
    186 
    187     ASSERT_EQ(kTestDataSize + kMessageSize, input_buffer_->capacity());
    188 
    189     output_buffer_->SetOffset(0);
    190     ASSERT_EQ(kTestDataSize, output_buffer_->size());
    191 
    192     EXPECT_EQ(0, memcmp(output_buffer_->data(),
    193                         input_buffer_->StartOfBuffer(), kTestDataSize));
    194   }
    195 
    196  protected:
    197   virtual ~TCPChannelTester() {}
    198 
    199   void Done() {
    200     done_ = true;
    201     message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
    202   }
    203 
    204   void DoStart() {
    205     InitBuffers();
    206     DoRead();
    207     DoWrite();
    208   }
    209 
    210   void InitBuffers() {
    211     output_buffer_ = new net::DrainableIOBuffer(
    212         new net::IOBuffer(kTestDataSize), kTestDataSize);
    213     memset(output_buffer_->data(), 123, kTestDataSize);
    214 
    215     input_buffer_ = new net::GrowableIOBuffer();
    216     // Always keep kMessageSize bytes available at the end of the input buffer.
    217     input_buffer_->SetCapacity(kMessageSize);
    218   }
    219 
    220   void DoWrite() {
    221     int result = 1;
    222     while (result > 0) {
    223       if (output_buffer_->BytesRemaining() == 0)
    224         break;
    225 
    226       int bytes_to_write = std::min(output_buffer_->BytesRemaining(),
    227                                     kMessageSize);
    228       result = client_socket_->Write(
    229           output_buffer_.get(),
    230           bytes_to_write,
    231           base::Bind(&TCPChannelTester::OnWritten, base::Unretained(this)));
    232       HandleWriteResult(result);
    233     }
    234   }
    235 
    236   void OnWritten(int result) {
    237     HandleWriteResult(result);
    238     DoWrite();
    239   }
    240 
    241   void HandleWriteResult(int result) {
    242     if (result <= 0 && result != net::ERR_IO_PENDING) {
    243       LOG(ERROR) << "Received error " << result << " when trying to write";
    244       write_errors_++;
    245       Done();
    246     } else if (result > 0) {
    247       output_buffer_->DidConsume(result);
    248     }
    249   }
    250 
    251   void DoRead() {
    252     int result = 1;
    253     while (result > 0) {
    254       input_buffer_->set_offset(input_buffer_->capacity() - kMessageSize);
    255 
    256       result = host_socket_->Read(
    257           input_buffer_.get(),
    258           kMessageSize,
    259           base::Bind(&TCPChannelTester::OnRead, base::Unretained(this)));
    260       HandleReadResult(result);
    261     };
    262   }
    263 
    264   void OnRead(int result) {
    265     HandleReadResult(result);
    266     DoRead();
    267   }
    268 
    269   void HandleReadResult(int result) {
    270     if (result <= 0 && result != net::ERR_IO_PENDING) {
    271       if (!done_) {
    272         LOG(ERROR) << "Received error " << result << " when trying to read";
    273         read_errors_++;
    274         Done();
    275       }
    276     } else if (result > 0) {
    277       // Allocate memory for the next read.
    278       input_buffer_->SetCapacity(input_buffer_->capacity() + result);
    279       if (input_buffer_->capacity() == kTestDataSize + kMessageSize)
    280         Done();
    281     }
    282   }
    283 
    284  private:
    285   friend class base::RefCountedThreadSafe<TCPChannelTester>;
    286 
    287   base::MessageLoop* message_loop_;
    288   net::Socket* host_socket_;
    289   net::Socket* client_socket_;
    290   bool done_;
    291 
    292   scoped_refptr<net::DrainableIOBuffer> output_buffer_;
    293   scoped_refptr<net::GrowableIOBuffer> input_buffer_;
    294 
    295   int write_errors_;
    296   int read_errors_;
    297 };
    298 
    299 class PseudoTcpAdapterTest : public testing::Test {
    300  protected:
    301   virtual void SetUp() OVERRIDE {
    302     JingleThreadWrapper::EnsureForCurrentMessageLoop();
    303 
    304     host_socket_ = new FakeSocket();
    305     client_socket_ = new FakeSocket();
    306 
    307     host_socket_->Connect(client_socket_);
    308     client_socket_->Connect(host_socket_);
    309 
    310     host_pseudotcp_.reset(new PseudoTcpAdapter(host_socket_));
    311     client_pseudotcp_.reset(new PseudoTcpAdapter(client_socket_));
    312   }
    313 
    314   FakeSocket* host_socket_;
    315   FakeSocket* client_socket_;
    316 
    317   scoped_ptr<PseudoTcpAdapter> host_pseudotcp_;
    318   scoped_ptr<PseudoTcpAdapter> client_pseudotcp_;
    319   base::MessageLoop message_loop_;
    320 };
    321 
    322 TEST_F(PseudoTcpAdapterTest, DataTransfer) {
    323   net::TestCompletionCallback host_connect_cb;
    324   net::TestCompletionCallback client_connect_cb;
    325 
    326   int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback());
    327   int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback());
    328 
    329   if (rv1 == net::ERR_IO_PENDING)
    330     rv1 = host_connect_cb.WaitForResult();
    331   if (rv2 == net::ERR_IO_PENDING)
    332     rv2 = client_connect_cb.WaitForResult();
    333   ASSERT_EQ(net::OK, rv1);
    334   ASSERT_EQ(net::OK, rv2);
    335 
    336   scoped_refptr<TCPChannelTester> tester =
    337       new TCPChannelTester(&message_loop_, host_pseudotcp_.get(),
    338                            client_pseudotcp_.get());
    339 
    340   tester->Start();
    341   message_loop_.Run();
    342   tester->CheckResults();
    343 }
    344 
    345 TEST_F(PseudoTcpAdapterTest, LimitedChannel) {
    346   const int kLatencyMs = 20;
    347   const int kPacketsPerSecond = 400;
    348   const int kBurstPackets = 10;
    349 
    350   LeakyBucket host_limiter(kBurstPackets, kPacketsPerSecond);
    351   host_socket_->set_latency(kLatencyMs);
    352   host_socket_->set_rate_limiter(&host_limiter);
    353 
    354   LeakyBucket client_limiter(kBurstPackets, kPacketsPerSecond);
    355   host_socket_->set_latency(kLatencyMs);
    356   client_socket_->set_rate_limiter(&client_limiter);
    357 
    358   net::TestCompletionCallback host_connect_cb;
    359   net::TestCompletionCallback client_connect_cb;
    360 
    361   int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback());
    362   int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback());
    363 
    364   if (rv1 == net::ERR_IO_PENDING)
    365     rv1 = host_connect_cb.WaitForResult();
    366   if (rv2 == net::ERR_IO_PENDING)
    367     rv2 = client_connect_cb.WaitForResult();
    368   ASSERT_EQ(net::OK, rv1);
    369   ASSERT_EQ(net::OK, rv2);
    370 
    371   scoped_refptr<TCPChannelTester> tester =
    372       new TCPChannelTester(&message_loop_, host_pseudotcp_.get(),
    373                            client_pseudotcp_.get());
    374 
    375   tester->Start();
    376   message_loop_.Run();
    377   tester->CheckResults();
    378 }
    379 
    380 class DeleteOnConnected {
    381  public:
    382   DeleteOnConnected(base::MessageLoop* message_loop,
    383                     scoped_ptr<PseudoTcpAdapter>* adapter)
    384       : message_loop_(message_loop), adapter_(adapter) {}
    385   void OnConnected(int error) {
    386     adapter_->reset();
    387     message_loop_->PostTask(FROM_HERE, base::MessageLoop::QuitClosure());
    388   }
    389   base::MessageLoop* message_loop_;
    390   scoped_ptr<PseudoTcpAdapter>* adapter_;
    391 };
    392 
    393 TEST_F(PseudoTcpAdapterTest, DeleteOnConnected) {
    394   // This test verifies that deleting the adapter mid-callback doesn't lead
    395   // to deleted structures being touched as the stack unrolls, so the failure
    396   // mode is a crash rather than a normal test failure.
    397   net::TestCompletionCallback client_connect_cb;
    398   DeleteOnConnected host_delete(&message_loop_, &host_pseudotcp_);
    399 
    400   host_pseudotcp_->Connect(base::Bind(&DeleteOnConnected::OnConnected,
    401                                       base::Unretained(&host_delete)));
    402   client_pseudotcp_->Connect(client_connect_cb.callback());
    403   message_loop_.Run();
    404 
    405   ASSERT_EQ(NULL, host_pseudotcp_.get());
    406 }
    407 
    408 // Verify that we can send/receive data with the write-waits-for-send
    409 // flag set.
    410 TEST_F(PseudoTcpAdapterTest, WriteWaitsForSendLetsDataThrough) {
    411   net::TestCompletionCallback host_connect_cb;
    412   net::TestCompletionCallback client_connect_cb;
    413 
    414   host_pseudotcp_->SetWriteWaitsForSend(true);
    415   client_pseudotcp_->SetWriteWaitsForSend(true);
    416 
    417   // Disable Nagle's algorithm because the test is slow when it is
    418   // enabled.
    419   host_pseudotcp_->SetNoDelay(true);
    420 
    421   int rv1 = host_pseudotcp_->Connect(host_connect_cb.callback());
    422   int rv2 = client_pseudotcp_->Connect(client_connect_cb.callback());
    423 
    424   if (rv1 == net::ERR_IO_PENDING)
    425     rv1 = host_connect_cb.WaitForResult();
    426   if (rv2 == net::ERR_IO_PENDING)
    427     rv2 = client_connect_cb.WaitForResult();
    428   ASSERT_EQ(net::OK, rv1);
    429   ASSERT_EQ(net::OK, rv2);
    430 
    431   scoped_refptr<TCPChannelTester> tester =
    432       new TCPChannelTester(&message_loop_, host_pseudotcp_.get(),
    433                            client_pseudotcp_.get());
    434 
    435   tester->Start();
    436   message_loop_.Run();
    437   tester->CheckResults();
    438 }
    439 
    440 }  // namespace
    441 
    442 }  // namespace jingle_glue
    443