Home | History | Annotate | Download | only in quic
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/tools/quic/quic_time_wait_list_manager.h"
      6 
      7 #include <errno.h>
      8 
      9 #include "net/quic/crypto/crypto_protocol.h"
     10 #include "net/quic/crypto/null_encrypter.h"
     11 #include "net/quic/crypto/quic_decrypter.h"
     12 #include "net/quic/crypto/quic_encrypter.h"
     13 #include "net/quic/quic_data_reader.h"
     14 #include "net/quic/quic_framer.h"
     15 #include "net/quic/quic_packet_writer.h"
     16 #include "net/quic/quic_protocol.h"
     17 #include "net/quic/quic_utils.h"
     18 #include "net/quic/test_tools/quic_test_utils.h"
     19 #include "net/tools/quic/test_tools/mock_epoll_server.h"
     20 #include "net/tools/quic/test_tools/quic_test_utils.h"
     21 #include "testing/gmock/include/gmock/gmock.h"
     22 #include "testing/gtest/include/gtest/gtest.h"
     23 
     24 using net::test::BuildUnsizedDataPacket;
     25 using net::test::NoOpFramerVisitor;
     26 using net::test::QuicVersionMax;
     27 using net::test::QuicVersionMin;
     28 using testing::Args;
     29 using testing::Assign;
     30 using testing::DoAll;
     31 using testing::Matcher;
     32 using testing::MatcherInterface;
     33 using testing::NiceMock;
     34 using testing::Return;
     35 using testing::ReturnPointee;
     36 using testing::SetArgPointee;
     37 using testing::StrictMock;
     38 using testing::Truly;
     39 using testing::_;
     40 
     41 namespace net {
     42 namespace tools {
     43 namespace test {
     44 
     45 class FramerVisitorCapturingPublicReset : public NoOpFramerVisitor {
     46  public:
     47   FramerVisitorCapturingPublicReset() {}
     48   virtual ~FramerVisitorCapturingPublicReset() OVERRIDE {}
     49 
     50   virtual void OnPublicResetPacket(
     51       const QuicPublicResetPacket& public_reset) OVERRIDE {
     52     public_reset_packet_ = public_reset;
     53   }
     54 
     55   const QuicPublicResetPacket public_reset_packet() {
     56     return public_reset_packet_;
     57   }
     58 
     59  private:
     60   QuicPublicResetPacket public_reset_packet_;
     61 };
     62 
     63 class QuicTimeWaitListManagerPeer {
     64  public:
     65   static bool ShouldSendResponse(QuicTimeWaitListManager* manager,
     66                                  int received_packet_count) {
     67     return manager->ShouldSendResponse(received_packet_count);
     68   }
     69 
     70   static QuicTime::Delta time_wait_period(QuicTimeWaitListManager* manager) {
     71     return manager->kTimeWaitPeriod_;
     72   }
     73 
     74   static QuicVersion GetQuicVersionFromConnectionId(
     75       QuicTimeWaitListManager* manager,
     76       QuicConnectionId connection_id) {
     77     return manager->GetQuicVersionFromConnectionId(connection_id);
     78   }
     79 };
     80 
     81 namespace {
     82 
     83 class MockFakeTimeEpollServer : public FakeTimeEpollServer {
     84  public:
     85   MOCK_METHOD2(RegisterAlarm, void(int64 timeout_in_us,
     86                                    EpollAlarmCallbackInterface* alarm));
     87 };
     88 
     89 class QuicTimeWaitListManagerTest : public ::testing::Test {
     90  protected:
     91   QuicTimeWaitListManagerTest()
     92       : time_wait_list_manager_(&writer_, &visitor_,
     93                                 &epoll_server_, QuicSupportedVersions()),
     94         framer_(QuicSupportedVersions(), QuicTime::Zero(), true),
     95         connection_id_(45),
     96         client_address_(net::test::TestPeerIPAddress(), kTestPort),
     97         writer_is_blocked_(false) {}
     98 
     99   virtual ~QuicTimeWaitListManagerTest() OVERRIDE {}
    100 
    101   virtual void SetUp() OVERRIDE {
    102     EXPECT_CALL(writer_, IsWriteBlocked())
    103         .WillRepeatedly(ReturnPointee(&writer_is_blocked_));
    104     EXPECT_CALL(writer_, IsWriteBlockedDataBuffered())
    105         .WillRepeatedly(Return(false));
    106   }
    107 
    108   void AddConnectionId(QuicConnectionId connection_id) {
    109     AddConnectionId(connection_id, QuicVersionMax(), NULL);
    110   }
    111 
    112   void AddConnectionId(QuicConnectionId connection_id,
    113                        QuicVersion version,
    114                        QuicEncryptedPacket* packet) {
    115     time_wait_list_manager_.AddConnectionIdToTimeWait(
    116         connection_id, version, packet);
    117   }
    118 
    119   bool IsConnectionIdInTimeWait(QuicConnectionId connection_id) {
    120     return time_wait_list_manager_.IsConnectionIdInTimeWait(connection_id);
    121   }
    122 
    123   void ProcessPacket(QuicConnectionId connection_id,
    124                      QuicPacketSequenceNumber sequence_number) {
    125     QuicEncryptedPacket packet(NULL, 0);
    126     time_wait_list_manager_.ProcessPacket(server_address_,
    127                                           client_address_,
    128                                           connection_id,
    129                                           sequence_number,
    130                                           packet);
    131   }
    132 
    133   QuicEncryptedPacket* ConstructEncryptedPacket(
    134       EncryptionLevel level,
    135       QuicConnectionId connection_id,
    136       QuicPacketSequenceNumber sequence_number) {
    137     QuicPacketHeader header;
    138     header.public_header.connection_id = connection_id;
    139     header.public_header.connection_id_length = PACKET_8BYTE_CONNECTION_ID;
    140     header.public_header.version_flag = false;
    141     header.public_header.reset_flag = false;
    142     header.public_header.sequence_number_length = PACKET_6BYTE_SEQUENCE_NUMBER;
    143     header.packet_sequence_number = sequence_number;
    144     header.entropy_flag = false;
    145     header.entropy_hash = 0;
    146     header.fec_flag = false;
    147     header.is_in_fec_group = NOT_IN_FEC_GROUP;
    148     header.fec_group = 0;
    149     QuicStreamFrame stream_frame(1, false, 0, MakeIOVector("data"));
    150     QuicFrame frame(&stream_frame);
    151     QuicFrames frames;
    152     frames.push_back(frame);
    153     scoped_ptr<QuicPacket> packet(
    154         BuildUnsizedDataPacket(&framer_, header, frames).packet);
    155     EXPECT_TRUE(packet != NULL);
    156     QuicEncryptedPacket* encrypted = framer_.EncryptPacket(ENCRYPTION_NONE,
    157                                                            sequence_number,
    158                                                            *packet);
    159     EXPECT_TRUE(encrypted != NULL);
    160     return encrypted;
    161   }
    162 
    163   NiceMock<MockFakeTimeEpollServer> epoll_server_;
    164   StrictMock<MockPacketWriter> writer_;
    165   StrictMock<MockQuicServerSessionVisitor> visitor_;
    166   QuicTimeWaitListManager time_wait_list_manager_;
    167   QuicFramer framer_;
    168   QuicConnectionId connection_id_;
    169   IPEndPoint server_address_;
    170   IPEndPoint client_address_;
    171   bool writer_is_blocked_;
    172 };
    173 
    174 class ValidatePublicResetPacketPredicate
    175     : public MatcherInterface<const std::tr1::tuple<const char*, int> > {
    176  public:
    177   explicit ValidatePublicResetPacketPredicate(QuicConnectionId connection_id,
    178                                               QuicPacketSequenceNumber number)
    179       : connection_id_(connection_id), sequence_number_(number) {
    180   }
    181 
    182   virtual bool MatchAndExplain(
    183       const std::tr1::tuple<const char*, int> packet_buffer,
    184       testing::MatchResultListener* /* listener */) const OVERRIDE {
    185     FramerVisitorCapturingPublicReset visitor;
    186     QuicFramer framer(QuicSupportedVersions(),
    187                       QuicTime::Zero(),
    188                       false);
    189     framer.set_visitor(&visitor);
    190     QuicEncryptedPacket encrypted(std::tr1::get<0>(packet_buffer),
    191                                   std::tr1::get<1>(packet_buffer));
    192     framer.ProcessPacket(encrypted);
    193     QuicPublicResetPacket packet = visitor.public_reset_packet();
    194     return connection_id_ == packet.public_header.connection_id &&
    195         packet.public_header.reset_flag && !packet.public_header.version_flag &&
    196         sequence_number_ == packet.rejected_sequence_number &&
    197         net::test::TestPeerIPAddress() == packet.client_address.address() &&
    198         kTestPort == packet.client_address.port();
    199   }
    200 
    201   virtual void DescribeTo(::std::ostream* os) const OVERRIDE {}
    202 
    203   virtual void DescribeNegationTo(::std::ostream* os) const OVERRIDE {}
    204 
    205  private:
    206   QuicConnectionId connection_id_;
    207   QuicPacketSequenceNumber sequence_number_;
    208 };
    209 
    210 
    211 Matcher<const std::tr1::tuple<const char*, int> > PublicResetPacketEq(
    212     QuicConnectionId connection_id,
    213     QuicPacketSequenceNumber sequence_number) {
    214   return MakeMatcher(new ValidatePublicResetPacketPredicate(connection_id,
    215                                                             sequence_number));
    216 }
    217 
    218 TEST_F(QuicTimeWaitListManagerTest, CheckConnectionIdInTimeWait) {
    219   EXPECT_FALSE(IsConnectionIdInTimeWait(connection_id_));
    220   AddConnectionId(connection_id_);
    221   EXPECT_TRUE(IsConnectionIdInTimeWait(connection_id_));
    222 }
    223 
    224 TEST_F(QuicTimeWaitListManagerTest, SendConnectionClose) {
    225   size_t kConnectionCloseLength = 100;
    226   AddConnectionId(
    227       connection_id_,
    228       QuicVersionMax(),
    229       new QuicEncryptedPacket(
    230           new char[kConnectionCloseLength], kConnectionCloseLength, true));
    231   const int kRandomSequenceNumber = 1;
    232   EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength,
    233                                    server_address_.address(),
    234                                    client_address_))
    235       .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1)));
    236 
    237   ProcessPacket(connection_id_, kRandomSequenceNumber);
    238 }
    239 
    240 TEST_F(QuicTimeWaitListManagerTest, SendPublicReset) {
    241   AddConnectionId(connection_id_);
    242   const int kRandomSequenceNumber = 1;
    243   EXPECT_CALL(writer_, WritePacket(_, _,
    244                                    server_address_.address(),
    245                                    client_address_))
    246       .With(Args<0, 1>(PublicResetPacketEq(connection_id_,
    247                                            kRandomSequenceNumber)))
    248       .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0)));
    249 
    250   ProcessPacket(connection_id_, kRandomSequenceNumber);
    251 }
    252 
    253 TEST_F(QuicTimeWaitListManagerTest, SendPublicResetWithExponentialBackOff) {
    254   AddConnectionId(connection_id_);
    255   for (int sequence_number = 1; sequence_number < 101; ++sequence_number) {
    256     if ((sequence_number & (sequence_number - 1)) == 0) {
    257       EXPECT_CALL(writer_, WritePacket(_, _, _, _))
    258           .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1)));
    259     }
    260     ProcessPacket(connection_id_, sequence_number);
    261     // Send public reset with exponential back off.
    262     if ((sequence_number & (sequence_number - 1)) == 0) {
    263       EXPECT_TRUE(QuicTimeWaitListManagerPeer::ShouldSendResponse(
    264                       &time_wait_list_manager_, sequence_number));
    265     } else {
    266       EXPECT_FALSE(QuicTimeWaitListManagerPeer::ShouldSendResponse(
    267                        &time_wait_list_manager_, sequence_number));
    268     }
    269   }
    270 }
    271 
    272 TEST_F(QuicTimeWaitListManagerTest, CleanUpOldConnectionIds) {
    273   const int kConnectionIdCount = 100;
    274   const int kOldConnectionIdCount = 31;
    275 
    276   // Add connection_ids such that their expiry time is kTimeWaitPeriod_.
    277   epoll_server_.set_now_in_usec(0);
    278   for (int connection_id = 1;
    279        connection_id <= kOldConnectionIdCount;
    280        ++connection_id) {
    281     AddConnectionId(connection_id);
    282   }
    283 
    284   // Add remaining connection_ids such that their add time is
    285   // 2 * kTimeWaitPeriod.
    286   const QuicTime::Delta time_wait_period =
    287       QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_);
    288   epoll_server_.set_now_in_usec(time_wait_period.ToMicroseconds());
    289   for (int connection_id = kOldConnectionIdCount + 1;
    290        connection_id <= kConnectionIdCount;
    291        ++connection_id) {
    292     AddConnectionId(connection_id);
    293   }
    294 
    295   QuicTime::Delta offset = QuicTime::Delta::FromMicroseconds(39);
    296   // Now set the current time as time_wait_period + offset usecs.
    297   epoll_server_.set_now_in_usec(time_wait_period.Add(offset).ToMicroseconds());
    298   // After all the old connection_ids are cleaned up, check the next alarm
    299   // interval.
    300   int64 next_alarm_time = epoll_server_.ApproximateNowInUsec() +
    301       time_wait_period.Subtract(offset).ToMicroseconds();
    302   EXPECT_CALL(epoll_server_, RegisterAlarm(next_alarm_time, _));
    303 
    304   time_wait_list_manager_.CleanUpOldConnectionIds();
    305   for (int connection_id = 1;
    306        connection_id <= kConnectionIdCount;
    307        ++connection_id) {
    308     EXPECT_EQ(connection_id > kOldConnectionIdCount,
    309               IsConnectionIdInTimeWait(connection_id))
    310         << "kOldConnectionIdCount: " << kOldConnectionIdCount
    311         << " connection_id: " <<  connection_id;
    312   }
    313 }
    314 
    315 TEST_F(QuicTimeWaitListManagerTest, SendQueuedPackets) {
    316   QuicConnectionId connection_id = 1;
    317   AddConnectionId(connection_id);
    318   QuicPacketSequenceNumber sequence_number = 234;
    319   scoped_ptr<QuicEncryptedPacket> packet(ConstructEncryptedPacket(
    320       ENCRYPTION_NONE, connection_id, sequence_number));
    321   // Let first write through.
    322   EXPECT_CALL(writer_, WritePacket(_, _,
    323                                    server_address_.address(),
    324                                    client_address_))
    325       .With(Args<0, 1>(PublicResetPacketEq(connection_id,
    326                                            sequence_number)))
    327       .WillOnce(Return(WriteResult(WRITE_STATUS_OK, packet->length())));
    328   ProcessPacket(connection_id, sequence_number);
    329 
    330   // write block for the next packet.
    331   EXPECT_CALL(writer_, WritePacket(_, _,
    332                                    server_address_.address(),
    333                                    client_address_))
    334       .With(Args<0, 1>(PublicResetPacketEq(connection_id,
    335                                            sequence_number)))
    336       .WillOnce(DoAll(
    337           Assign(&writer_is_blocked_, true),
    338           Return(WriteResult(WRITE_STATUS_BLOCKED, EAGAIN))));
    339   EXPECT_CALL(visitor_, OnWriteBlocked(&time_wait_list_manager_));
    340   ProcessPacket(connection_id, sequence_number);
    341   // 3rd packet. No public reset should be sent;
    342   ProcessPacket(connection_id, sequence_number);
    343 
    344   // write packet should not be called since we are write blocked but the
    345   // should be queued.
    346   QuicConnectionId other_connection_id = 2;
    347   AddConnectionId(other_connection_id);
    348   QuicPacketSequenceNumber other_sequence_number = 23423;
    349   scoped_ptr<QuicEncryptedPacket> other_packet(
    350       ConstructEncryptedPacket(
    351           ENCRYPTION_NONE, other_connection_id, other_sequence_number));
    352   EXPECT_CALL(writer_, WritePacket(_, _, _, _))
    353       .Times(0);
    354   EXPECT_CALL(visitor_, OnWriteBlocked(&time_wait_list_manager_));
    355   ProcessPacket(other_connection_id, other_sequence_number);
    356 
    357   // Now expect all the write blocked public reset packets to be sent again.
    358   writer_is_blocked_ = false;
    359   EXPECT_CALL(writer_, WritePacket(_, _,
    360                                    server_address_.address(),
    361                                    client_address_))
    362       .With(Args<0, 1>(PublicResetPacketEq(connection_id,
    363                                            sequence_number)))
    364       .WillOnce(Return(WriteResult(WRITE_STATUS_OK, packet->length())));
    365   EXPECT_CALL(writer_, WritePacket(_, _,
    366                                    server_address_.address(),
    367                                    client_address_))
    368       .With(Args<0, 1>(PublicResetPacketEq(other_connection_id,
    369                                            other_sequence_number)))
    370       .WillOnce(Return(WriteResult(WRITE_STATUS_OK,
    371                                    other_packet->length())));
    372   time_wait_list_manager_.OnCanWrite();
    373 }
    374 
    375 TEST_F(QuicTimeWaitListManagerTest, GetQuicVersionFromMap) {
    376   const int kConnectionId1 = 123;
    377   const int kConnectionId2 = 456;
    378   const int kConnectionId3 = 789;
    379 
    380   AddConnectionId(kConnectionId1, QuicVersionMin(), NULL);
    381   AddConnectionId(kConnectionId2, QuicVersionMax(), NULL);
    382   AddConnectionId(kConnectionId3, QuicVersionMax(), NULL);
    383 
    384   EXPECT_EQ(QuicVersionMin(),
    385             QuicTimeWaitListManagerPeer::GetQuicVersionFromConnectionId(
    386                 &time_wait_list_manager_, kConnectionId1));
    387   EXPECT_EQ(QuicVersionMax(),
    388             QuicTimeWaitListManagerPeer::GetQuicVersionFromConnectionId(
    389                 &time_wait_list_manager_, kConnectionId2));
    390   EXPECT_EQ(QuicVersionMax(),
    391             QuicTimeWaitListManagerPeer::GetQuicVersionFromConnectionId(
    392                 &time_wait_list_manager_, kConnectionId3));
    393 }
    394 
    395 TEST_F(QuicTimeWaitListManagerTest, AddConnectionIdTwice) {
    396   // Add connection_ids such that their expiry time is kTimeWaitPeriod_.
    397   epoll_server_.set_now_in_usec(0);
    398   AddConnectionId(connection_id_);
    399   EXPECT_TRUE(IsConnectionIdInTimeWait(connection_id_));
    400   size_t kConnectionCloseLength = 100;
    401   AddConnectionId(
    402       connection_id_,
    403       QuicVersionMax(),
    404       new QuicEncryptedPacket(
    405           new char[kConnectionCloseLength], kConnectionCloseLength, true));
    406   EXPECT_TRUE(IsConnectionIdInTimeWait(connection_id_));
    407 
    408   EXPECT_CALL(writer_, WritePacket(_,
    409                                    kConnectionCloseLength,
    410                                    server_address_.address(),
    411                                    client_address_))
    412       .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1)));
    413 
    414   const int kRandomSequenceNumber = 1;
    415   ProcessPacket(connection_id_, kRandomSequenceNumber);
    416 
    417   const QuicTime::Delta time_wait_period =
    418       QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_);
    419 
    420   QuicTime::Delta offset = QuicTime::Delta::FromMicroseconds(39);
    421   // Now set the current time as time_wait_period + offset usecs.
    422   epoll_server_.set_now_in_usec(time_wait_period.Add(offset).ToMicroseconds());
    423   // After the connection_ids are cleaned up, check the next alarm interval.
    424   int64 next_alarm_time = epoll_server_.ApproximateNowInUsec() +
    425       time_wait_period.ToMicroseconds();
    426 
    427   EXPECT_CALL(epoll_server_, RegisterAlarm(next_alarm_time, _));
    428   time_wait_list_manager_.CleanUpOldConnectionIds();
    429   EXPECT_FALSE(IsConnectionIdInTimeWait(connection_id_));
    430 }
    431 
    432 TEST_F(QuicTimeWaitListManagerTest, ConnectionIdsOrderedByTime) {
    433   // Simple randomization: the values of connection_ids are swapped based on the
    434   // current seconds on the clock. If the container is broken, the test will be
    435   // 50% flaky.
    436   int odd_second = static_cast<int>(epoll_server_.ApproximateNowInUsec()) % 2;
    437   EXPECT_TRUE(odd_second == 0 || odd_second == 1);
    438   const QuicConnectionId kConnectionId1 = odd_second;
    439   const QuicConnectionId kConnectionId2 = 1 - odd_second;
    440 
    441   // 1 will hash lower than 2, but we add it later. They should come out in the
    442   // add order, not hash order.
    443   epoll_server_.set_now_in_usec(0);
    444   AddConnectionId(kConnectionId1);
    445   epoll_server_.set_now_in_usec(10);
    446   AddConnectionId(kConnectionId2);
    447 
    448   const QuicTime::Delta time_wait_period =
    449       QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_);
    450   epoll_server_.set_now_in_usec(time_wait_period.ToMicroseconds() + 1);
    451 
    452   EXPECT_CALL(epoll_server_, RegisterAlarm(_, _));
    453 
    454   time_wait_list_manager_.CleanUpOldConnectionIds();
    455   EXPECT_FALSE(IsConnectionIdInTimeWait(kConnectionId1));
    456   EXPECT_TRUE(IsConnectionIdInTimeWait(kConnectionId2));
    457 }
    458 }  // namespace
    459 }  // namespace test
    460 }  // namespace tools
    461 }  // namespace net
    462