1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/tools/quic/quic_time_wait_list_manager.h" 6 7 #include <errno.h> 8 9 #include "net/quic/crypto/crypto_protocol.h" 10 #include "net/quic/crypto/null_encrypter.h" 11 #include "net/quic/crypto/quic_decrypter.h" 12 #include "net/quic/crypto/quic_encrypter.h" 13 #include "net/quic/quic_data_reader.h" 14 #include "net/quic/quic_framer.h" 15 #include "net/quic/quic_packet_writer.h" 16 #include "net/quic/quic_protocol.h" 17 #include "net/quic/quic_utils.h" 18 #include "net/quic/test_tools/quic_test_utils.h" 19 #include "net/tools/quic/test_tools/mock_epoll_server.h" 20 #include "net/tools/quic/test_tools/quic_test_utils.h" 21 #include "testing/gmock/include/gmock/gmock.h" 22 #include "testing/gtest/include/gtest/gtest.h" 23 24 using net::test::BuildUnsizedDataPacket; 25 using net::test::NoOpFramerVisitor; 26 using net::test::QuicVersionMax; 27 using net::test::QuicVersionMin; 28 using testing::Args; 29 using testing::Assign; 30 using testing::DoAll; 31 using testing::Matcher; 32 using testing::MatcherInterface; 33 using testing::NiceMock; 34 using testing::Return; 35 using testing::ReturnPointee; 36 using testing::SetArgPointee; 37 using testing::StrictMock; 38 using testing::Truly; 39 using testing::_; 40 41 namespace net { 42 namespace tools { 43 namespace test { 44 45 class FramerVisitorCapturingPublicReset : public NoOpFramerVisitor { 46 public: 47 FramerVisitorCapturingPublicReset() {} 48 virtual ~FramerVisitorCapturingPublicReset() OVERRIDE {} 49 50 virtual void OnPublicResetPacket( 51 const QuicPublicResetPacket& public_reset) OVERRIDE { 52 public_reset_packet_ = public_reset; 53 } 54 55 const QuicPublicResetPacket public_reset_packet() { 56 return public_reset_packet_; 57 } 58 59 private: 60 QuicPublicResetPacket public_reset_packet_; 61 }; 62 63 class QuicTimeWaitListManagerPeer { 64 public: 65 static bool ShouldSendResponse(QuicTimeWaitListManager* manager, 66 int received_packet_count) { 67 return manager->ShouldSendResponse(received_packet_count); 68 } 69 70 static QuicTime::Delta time_wait_period(QuicTimeWaitListManager* manager) { 71 return manager->kTimeWaitPeriod_; 72 } 73 74 static QuicVersion GetQuicVersionFromConnectionId( 75 QuicTimeWaitListManager* manager, 76 QuicConnectionId connection_id) { 77 return manager->GetQuicVersionFromConnectionId(connection_id); 78 } 79 }; 80 81 namespace { 82 83 class MockFakeTimeEpollServer : public FakeTimeEpollServer { 84 public: 85 MOCK_METHOD2(RegisterAlarm, void(int64 timeout_in_us, 86 EpollAlarmCallbackInterface* alarm)); 87 }; 88 89 class QuicTimeWaitListManagerTest : public ::testing::Test { 90 protected: 91 QuicTimeWaitListManagerTest() 92 : time_wait_list_manager_(&writer_, &visitor_, 93 &epoll_server_, QuicSupportedVersions()), 94 framer_(QuicSupportedVersions(), QuicTime::Zero(), true), 95 connection_id_(45), 96 client_address_(net::test::TestPeerIPAddress(), kTestPort), 97 writer_is_blocked_(false) {} 98 99 virtual ~QuicTimeWaitListManagerTest() OVERRIDE {} 100 101 virtual void SetUp() OVERRIDE { 102 EXPECT_CALL(writer_, IsWriteBlocked()) 103 .WillRepeatedly(ReturnPointee(&writer_is_blocked_)); 104 EXPECT_CALL(writer_, IsWriteBlockedDataBuffered()) 105 .WillRepeatedly(Return(false)); 106 } 107 108 void AddConnectionId(QuicConnectionId connection_id) { 109 AddConnectionId(connection_id, QuicVersionMax(), NULL); 110 } 111 112 void AddConnectionId(QuicConnectionId connection_id, 113 QuicVersion version, 114 QuicEncryptedPacket* packet) { 115 time_wait_list_manager_.AddConnectionIdToTimeWait( 116 connection_id, version, packet); 117 } 118 119 bool IsConnectionIdInTimeWait(QuicConnectionId connection_id) { 120 return time_wait_list_manager_.IsConnectionIdInTimeWait(connection_id); 121 } 122 123 void ProcessPacket(QuicConnectionId connection_id, 124 QuicPacketSequenceNumber sequence_number) { 125 QuicEncryptedPacket packet(NULL, 0); 126 time_wait_list_manager_.ProcessPacket(server_address_, 127 client_address_, 128 connection_id, 129 sequence_number, 130 packet); 131 } 132 133 QuicEncryptedPacket* ConstructEncryptedPacket( 134 EncryptionLevel level, 135 QuicConnectionId connection_id, 136 QuicPacketSequenceNumber sequence_number) { 137 QuicPacketHeader header; 138 header.public_header.connection_id = connection_id; 139 header.public_header.connection_id_length = PACKET_8BYTE_CONNECTION_ID; 140 header.public_header.version_flag = false; 141 header.public_header.reset_flag = false; 142 header.public_header.sequence_number_length = PACKET_6BYTE_SEQUENCE_NUMBER; 143 header.packet_sequence_number = sequence_number; 144 header.entropy_flag = false; 145 header.entropy_hash = 0; 146 header.fec_flag = false; 147 header.is_in_fec_group = NOT_IN_FEC_GROUP; 148 header.fec_group = 0; 149 QuicStreamFrame stream_frame(1, false, 0, MakeIOVector("data")); 150 QuicFrame frame(&stream_frame); 151 QuicFrames frames; 152 frames.push_back(frame); 153 scoped_ptr<QuicPacket> packet( 154 BuildUnsizedDataPacket(&framer_, header, frames).packet); 155 EXPECT_TRUE(packet != NULL); 156 QuicEncryptedPacket* encrypted = framer_.EncryptPacket(ENCRYPTION_NONE, 157 sequence_number, 158 *packet); 159 EXPECT_TRUE(encrypted != NULL); 160 return encrypted; 161 } 162 163 NiceMock<MockFakeTimeEpollServer> epoll_server_; 164 StrictMock<MockPacketWriter> writer_; 165 StrictMock<MockQuicServerSessionVisitor> visitor_; 166 QuicTimeWaitListManager time_wait_list_manager_; 167 QuicFramer framer_; 168 QuicConnectionId connection_id_; 169 IPEndPoint server_address_; 170 IPEndPoint client_address_; 171 bool writer_is_blocked_; 172 }; 173 174 class ValidatePublicResetPacketPredicate 175 : public MatcherInterface<const std::tr1::tuple<const char*, int> > { 176 public: 177 explicit ValidatePublicResetPacketPredicate(QuicConnectionId connection_id, 178 QuicPacketSequenceNumber number) 179 : connection_id_(connection_id), sequence_number_(number) { 180 } 181 182 virtual bool MatchAndExplain( 183 const std::tr1::tuple<const char*, int> packet_buffer, 184 testing::MatchResultListener* /* listener */) const OVERRIDE { 185 FramerVisitorCapturingPublicReset visitor; 186 QuicFramer framer(QuicSupportedVersions(), 187 QuicTime::Zero(), 188 false); 189 framer.set_visitor(&visitor); 190 QuicEncryptedPacket encrypted(std::tr1::get<0>(packet_buffer), 191 std::tr1::get<1>(packet_buffer)); 192 framer.ProcessPacket(encrypted); 193 QuicPublicResetPacket packet = visitor.public_reset_packet(); 194 return connection_id_ == packet.public_header.connection_id && 195 packet.public_header.reset_flag && !packet.public_header.version_flag && 196 sequence_number_ == packet.rejected_sequence_number && 197 net::test::TestPeerIPAddress() == packet.client_address.address() && 198 kTestPort == packet.client_address.port(); 199 } 200 201 virtual void DescribeTo(::std::ostream* os) const OVERRIDE {} 202 203 virtual void DescribeNegationTo(::std::ostream* os) const OVERRIDE {} 204 205 private: 206 QuicConnectionId connection_id_; 207 QuicPacketSequenceNumber sequence_number_; 208 }; 209 210 211 Matcher<const std::tr1::tuple<const char*, int> > PublicResetPacketEq( 212 QuicConnectionId connection_id, 213 QuicPacketSequenceNumber sequence_number) { 214 return MakeMatcher(new ValidatePublicResetPacketPredicate(connection_id, 215 sequence_number)); 216 } 217 218 TEST_F(QuicTimeWaitListManagerTest, CheckConnectionIdInTimeWait) { 219 EXPECT_FALSE(IsConnectionIdInTimeWait(connection_id_)); 220 AddConnectionId(connection_id_); 221 EXPECT_TRUE(IsConnectionIdInTimeWait(connection_id_)); 222 } 223 224 TEST_F(QuicTimeWaitListManagerTest, SendConnectionClose) { 225 size_t kConnectionCloseLength = 100; 226 AddConnectionId( 227 connection_id_, 228 QuicVersionMax(), 229 new QuicEncryptedPacket( 230 new char[kConnectionCloseLength], kConnectionCloseLength, true)); 231 const int kRandomSequenceNumber = 1; 232 EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength, 233 server_address_.address(), 234 client_address_)) 235 .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); 236 237 ProcessPacket(connection_id_, kRandomSequenceNumber); 238 } 239 240 TEST_F(QuicTimeWaitListManagerTest, SendPublicReset) { 241 AddConnectionId(connection_id_); 242 const int kRandomSequenceNumber = 1; 243 EXPECT_CALL(writer_, WritePacket(_, _, 244 server_address_.address(), 245 client_address_)) 246 .With(Args<0, 1>(PublicResetPacketEq(connection_id_, 247 kRandomSequenceNumber))) 248 .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 0))); 249 250 ProcessPacket(connection_id_, kRandomSequenceNumber); 251 } 252 253 TEST_F(QuicTimeWaitListManagerTest, SendPublicResetWithExponentialBackOff) { 254 AddConnectionId(connection_id_); 255 for (int sequence_number = 1; sequence_number < 101; ++sequence_number) { 256 if ((sequence_number & (sequence_number - 1)) == 0) { 257 EXPECT_CALL(writer_, WritePacket(_, _, _, _)) 258 .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); 259 } 260 ProcessPacket(connection_id_, sequence_number); 261 // Send public reset with exponential back off. 262 if ((sequence_number & (sequence_number - 1)) == 0) { 263 EXPECT_TRUE(QuicTimeWaitListManagerPeer::ShouldSendResponse( 264 &time_wait_list_manager_, sequence_number)); 265 } else { 266 EXPECT_FALSE(QuicTimeWaitListManagerPeer::ShouldSendResponse( 267 &time_wait_list_manager_, sequence_number)); 268 } 269 } 270 } 271 272 TEST_F(QuicTimeWaitListManagerTest, CleanUpOldConnectionIds) { 273 const int kConnectionIdCount = 100; 274 const int kOldConnectionIdCount = 31; 275 276 // Add connection_ids such that their expiry time is kTimeWaitPeriod_. 277 epoll_server_.set_now_in_usec(0); 278 for (int connection_id = 1; 279 connection_id <= kOldConnectionIdCount; 280 ++connection_id) { 281 AddConnectionId(connection_id); 282 } 283 284 // Add remaining connection_ids such that their add time is 285 // 2 * kTimeWaitPeriod. 286 const QuicTime::Delta time_wait_period = 287 QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_); 288 epoll_server_.set_now_in_usec(time_wait_period.ToMicroseconds()); 289 for (int connection_id = kOldConnectionIdCount + 1; 290 connection_id <= kConnectionIdCount; 291 ++connection_id) { 292 AddConnectionId(connection_id); 293 } 294 295 QuicTime::Delta offset = QuicTime::Delta::FromMicroseconds(39); 296 // Now set the current time as time_wait_period + offset usecs. 297 epoll_server_.set_now_in_usec(time_wait_period.Add(offset).ToMicroseconds()); 298 // After all the old connection_ids are cleaned up, check the next alarm 299 // interval. 300 int64 next_alarm_time = epoll_server_.ApproximateNowInUsec() + 301 time_wait_period.Subtract(offset).ToMicroseconds(); 302 EXPECT_CALL(epoll_server_, RegisterAlarm(next_alarm_time, _)); 303 304 time_wait_list_manager_.CleanUpOldConnectionIds(); 305 for (int connection_id = 1; 306 connection_id <= kConnectionIdCount; 307 ++connection_id) { 308 EXPECT_EQ(connection_id > kOldConnectionIdCount, 309 IsConnectionIdInTimeWait(connection_id)) 310 << "kOldConnectionIdCount: " << kOldConnectionIdCount 311 << " connection_id: " << connection_id; 312 } 313 } 314 315 TEST_F(QuicTimeWaitListManagerTest, SendQueuedPackets) { 316 QuicConnectionId connection_id = 1; 317 AddConnectionId(connection_id); 318 QuicPacketSequenceNumber sequence_number = 234; 319 scoped_ptr<QuicEncryptedPacket> packet(ConstructEncryptedPacket( 320 ENCRYPTION_NONE, connection_id, sequence_number)); 321 // Let first write through. 322 EXPECT_CALL(writer_, WritePacket(_, _, 323 server_address_.address(), 324 client_address_)) 325 .With(Args<0, 1>(PublicResetPacketEq(connection_id, 326 sequence_number))) 327 .WillOnce(Return(WriteResult(WRITE_STATUS_OK, packet->length()))); 328 ProcessPacket(connection_id, sequence_number); 329 330 // write block for the next packet. 331 EXPECT_CALL(writer_, WritePacket(_, _, 332 server_address_.address(), 333 client_address_)) 334 .With(Args<0, 1>(PublicResetPacketEq(connection_id, 335 sequence_number))) 336 .WillOnce(DoAll( 337 Assign(&writer_is_blocked_, true), 338 Return(WriteResult(WRITE_STATUS_BLOCKED, EAGAIN)))); 339 EXPECT_CALL(visitor_, OnWriteBlocked(&time_wait_list_manager_)); 340 ProcessPacket(connection_id, sequence_number); 341 // 3rd packet. No public reset should be sent; 342 ProcessPacket(connection_id, sequence_number); 343 344 // write packet should not be called since we are write blocked but the 345 // should be queued. 346 QuicConnectionId other_connection_id = 2; 347 AddConnectionId(other_connection_id); 348 QuicPacketSequenceNumber other_sequence_number = 23423; 349 scoped_ptr<QuicEncryptedPacket> other_packet( 350 ConstructEncryptedPacket( 351 ENCRYPTION_NONE, other_connection_id, other_sequence_number)); 352 EXPECT_CALL(writer_, WritePacket(_, _, _, _)) 353 .Times(0); 354 EXPECT_CALL(visitor_, OnWriteBlocked(&time_wait_list_manager_)); 355 ProcessPacket(other_connection_id, other_sequence_number); 356 357 // Now expect all the write blocked public reset packets to be sent again. 358 writer_is_blocked_ = false; 359 EXPECT_CALL(writer_, WritePacket(_, _, 360 server_address_.address(), 361 client_address_)) 362 .With(Args<0, 1>(PublicResetPacketEq(connection_id, 363 sequence_number))) 364 .WillOnce(Return(WriteResult(WRITE_STATUS_OK, packet->length()))); 365 EXPECT_CALL(writer_, WritePacket(_, _, 366 server_address_.address(), 367 client_address_)) 368 .With(Args<0, 1>(PublicResetPacketEq(other_connection_id, 369 other_sequence_number))) 370 .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 371 other_packet->length()))); 372 time_wait_list_manager_.OnCanWrite(); 373 } 374 375 TEST_F(QuicTimeWaitListManagerTest, GetQuicVersionFromMap) { 376 const int kConnectionId1 = 123; 377 const int kConnectionId2 = 456; 378 const int kConnectionId3 = 789; 379 380 AddConnectionId(kConnectionId1, QuicVersionMin(), NULL); 381 AddConnectionId(kConnectionId2, QuicVersionMax(), NULL); 382 AddConnectionId(kConnectionId3, QuicVersionMax(), NULL); 383 384 EXPECT_EQ(QuicVersionMin(), 385 QuicTimeWaitListManagerPeer::GetQuicVersionFromConnectionId( 386 &time_wait_list_manager_, kConnectionId1)); 387 EXPECT_EQ(QuicVersionMax(), 388 QuicTimeWaitListManagerPeer::GetQuicVersionFromConnectionId( 389 &time_wait_list_manager_, kConnectionId2)); 390 EXPECT_EQ(QuicVersionMax(), 391 QuicTimeWaitListManagerPeer::GetQuicVersionFromConnectionId( 392 &time_wait_list_manager_, kConnectionId3)); 393 } 394 395 TEST_F(QuicTimeWaitListManagerTest, AddConnectionIdTwice) { 396 // Add connection_ids such that their expiry time is kTimeWaitPeriod_. 397 epoll_server_.set_now_in_usec(0); 398 AddConnectionId(connection_id_); 399 EXPECT_TRUE(IsConnectionIdInTimeWait(connection_id_)); 400 size_t kConnectionCloseLength = 100; 401 AddConnectionId( 402 connection_id_, 403 QuicVersionMax(), 404 new QuicEncryptedPacket( 405 new char[kConnectionCloseLength], kConnectionCloseLength, true)); 406 EXPECT_TRUE(IsConnectionIdInTimeWait(connection_id_)); 407 408 EXPECT_CALL(writer_, WritePacket(_, 409 kConnectionCloseLength, 410 server_address_.address(), 411 client_address_)) 412 .WillOnce(Return(WriteResult(WRITE_STATUS_OK, 1))); 413 414 const int kRandomSequenceNumber = 1; 415 ProcessPacket(connection_id_, kRandomSequenceNumber); 416 417 const QuicTime::Delta time_wait_period = 418 QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_); 419 420 QuicTime::Delta offset = QuicTime::Delta::FromMicroseconds(39); 421 // Now set the current time as time_wait_period + offset usecs. 422 epoll_server_.set_now_in_usec(time_wait_period.Add(offset).ToMicroseconds()); 423 // After the connection_ids are cleaned up, check the next alarm interval. 424 int64 next_alarm_time = epoll_server_.ApproximateNowInUsec() + 425 time_wait_period.ToMicroseconds(); 426 427 EXPECT_CALL(epoll_server_, RegisterAlarm(next_alarm_time, _)); 428 time_wait_list_manager_.CleanUpOldConnectionIds(); 429 EXPECT_FALSE(IsConnectionIdInTimeWait(connection_id_)); 430 } 431 432 TEST_F(QuicTimeWaitListManagerTest, ConnectionIdsOrderedByTime) { 433 // Simple randomization: the values of connection_ids are swapped based on the 434 // current seconds on the clock. If the container is broken, the test will be 435 // 50% flaky. 436 int odd_second = static_cast<int>(epoll_server_.ApproximateNowInUsec()) % 2; 437 EXPECT_TRUE(odd_second == 0 || odd_second == 1); 438 const QuicConnectionId kConnectionId1 = odd_second; 439 const QuicConnectionId kConnectionId2 = 1 - odd_second; 440 441 // 1 will hash lower than 2, but we add it later. They should come out in the 442 // add order, not hash order. 443 epoll_server_.set_now_in_usec(0); 444 AddConnectionId(kConnectionId1); 445 epoll_server_.set_now_in_usec(10); 446 AddConnectionId(kConnectionId2); 447 448 const QuicTime::Delta time_wait_period = 449 QuicTimeWaitListManagerPeer::time_wait_period(&time_wait_list_manager_); 450 epoll_server_.set_now_in_usec(time_wait_period.ToMicroseconds() + 1); 451 452 EXPECT_CALL(epoll_server_, RegisterAlarm(_, _)); 453 454 time_wait_list_manager_.CleanUpOldConnectionIds(); 455 EXPECT_FALSE(IsConnectionIdInTimeWait(kConnectionId1)); 456 EXPECT_TRUE(IsConnectionIdInTimeWait(kConnectionId2)); 457 } 458 } // namespace 459 } // namespace test 460 } // namespace tools 461 } // namespace net 462