Home | History | Annotate | Download | only in test_tools
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/quic/test_tools/quic_test_utils.h"
      6 
      7 #include "base/sha1.h"
      8 #include "base/stl_util.h"
      9 #include "base/strings/string_number_conversions.h"
     10 #include "net/quic/crypto/crypto_framer.h"
     11 #include "net/quic/crypto/crypto_handshake.h"
     12 #include "net/quic/crypto/crypto_utils.h"
     13 #include "net/quic/crypto/null_encrypter.h"
     14 #include "net/quic/crypto/quic_decrypter.h"
     15 #include "net/quic/crypto/quic_encrypter.h"
     16 #include "net/quic/quic_framer.h"
     17 #include "net/quic/quic_packet_creator.h"
     18 #include "net/quic/quic_utils.h"
     19 #include "net/quic/test_tools/quic_connection_peer.h"
     20 #include "net/spdy/spdy_frame_builder.h"
     21 
     22 using base::StringPiece;
     23 using std::max;
     24 using std::min;
     25 using std::string;
     26 using testing::AnyNumber;
     27 using testing::_;
     28 
     29 namespace net {
     30 namespace test {
     31 namespace {
     32 
     33 // No-op alarm implementation used by MockHelper.
     34 class TestAlarm : public QuicAlarm {
     35  public:
     36   explicit TestAlarm(QuicAlarm::Delegate* delegate)
     37       : QuicAlarm(delegate) {
     38   }
     39 
     40   virtual void SetImpl() OVERRIDE {}
     41   virtual void CancelImpl() OVERRIDE {}
     42 };
     43 
     44 }  // namespace
     45 
     46 QuicAckFrame MakeAckFrame(QuicPacketSequenceNumber largest_observed) {
     47   QuicAckFrame ack;
     48   ack.largest_observed = largest_observed;
     49   ack.entropy_hash = 0;
     50   return ack;
     51 }
     52 
     53 QuicAckFrame MakeAckFrameWithNackRanges(
     54     size_t num_nack_ranges, QuicPacketSequenceNumber least_unacked) {
     55   QuicAckFrame ack = MakeAckFrame(2 * num_nack_ranges + least_unacked);
     56   // Add enough missing packets to get num_nack_ranges nack ranges.
     57   for (QuicPacketSequenceNumber i = 1; i < 2 * num_nack_ranges; i += 2) {
     58     ack.missing_packets.insert(least_unacked + i);
     59   }
     60   return ack;
     61 }
     62 
     63 SerializedPacket BuildUnsizedDataPacket(QuicFramer* framer,
     64                                         const QuicPacketHeader& header,
     65                                         const QuicFrames& frames) {
     66   const size_t max_plaintext_size = framer->GetMaxPlaintextSize(kMaxPacketSize);
     67   size_t packet_size = GetPacketHeaderSize(header);
     68   for (size_t i = 0; i < frames.size(); ++i) {
     69     DCHECK_LE(packet_size, max_plaintext_size);
     70     bool first_frame = i == 0;
     71     bool last_frame = i == frames.size() - 1;
     72     const size_t frame_size = framer->GetSerializedFrameLength(
     73         frames[i], max_plaintext_size - packet_size, first_frame, last_frame,
     74         header.is_in_fec_group,
     75         header.public_header.sequence_number_length);
     76     DCHECK(frame_size);
     77     packet_size += frame_size;
     78   }
     79   return framer->BuildDataPacket(header, frames, packet_size);
     80 }
     81 
     82 uint64 SimpleRandom::RandUint64() {
     83   unsigned char hash[base::kSHA1Length];
     84   base::SHA1HashBytes(reinterpret_cast<unsigned char*>(&seed_), sizeof(seed_),
     85                       hash);
     86   memcpy(&seed_, hash, sizeof(seed_));
     87   return seed_;
     88 }
     89 
     90 MockFramerVisitor::MockFramerVisitor() {
     91   // By default, we want to accept packets.
     92   ON_CALL(*this, OnProtocolVersionMismatch(_))
     93       .WillByDefault(testing::Return(false));
     94 
     95   // By default, we want to accept packets.
     96   ON_CALL(*this, OnUnauthenticatedHeader(_))
     97       .WillByDefault(testing::Return(true));
     98 
     99   ON_CALL(*this, OnUnauthenticatedPublicHeader(_))
    100       .WillByDefault(testing::Return(true));
    101 
    102   ON_CALL(*this, OnPacketHeader(_))
    103       .WillByDefault(testing::Return(true));
    104 
    105   ON_CALL(*this, OnStreamFrame(_))
    106       .WillByDefault(testing::Return(true));
    107 
    108   ON_CALL(*this, OnAckFrame(_))
    109       .WillByDefault(testing::Return(true));
    110 
    111   ON_CALL(*this, OnCongestionFeedbackFrame(_))
    112       .WillByDefault(testing::Return(true));
    113 
    114   ON_CALL(*this, OnStopWaitingFrame(_))
    115       .WillByDefault(testing::Return(true));
    116 
    117   ON_CALL(*this, OnPingFrame(_))
    118       .WillByDefault(testing::Return(true));
    119 
    120   ON_CALL(*this, OnRstStreamFrame(_))
    121       .WillByDefault(testing::Return(true));
    122 
    123   ON_CALL(*this, OnConnectionCloseFrame(_))
    124       .WillByDefault(testing::Return(true));
    125 
    126   ON_CALL(*this, OnGoAwayFrame(_))
    127       .WillByDefault(testing::Return(true));
    128 }
    129 
    130 MockFramerVisitor::~MockFramerVisitor() {
    131 }
    132 
    133 bool NoOpFramerVisitor::OnProtocolVersionMismatch(QuicVersion version) {
    134   return false;
    135 }
    136 
    137 bool NoOpFramerVisitor::OnUnauthenticatedPublicHeader(
    138     const QuicPacketPublicHeader& header) {
    139   return true;
    140 }
    141 
    142 bool NoOpFramerVisitor::OnUnauthenticatedHeader(
    143     const QuicPacketHeader& header) {
    144   return true;
    145 }
    146 
    147 bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& header) {
    148   return true;
    149 }
    150 
    151 bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& frame) {
    152   return true;
    153 }
    154 
    155 bool NoOpFramerVisitor::OnAckFrame(const QuicAckFrame& frame) {
    156   return true;
    157 }
    158 
    159 bool NoOpFramerVisitor::OnCongestionFeedbackFrame(
    160     const QuicCongestionFeedbackFrame& frame) {
    161   return true;
    162 }
    163 
    164 bool NoOpFramerVisitor::OnStopWaitingFrame(
    165     const QuicStopWaitingFrame& frame) {
    166   return true;
    167 }
    168 
    169 bool NoOpFramerVisitor::OnPingFrame(const QuicPingFrame& frame) {
    170   return true;
    171 }
    172 
    173 bool NoOpFramerVisitor::OnRstStreamFrame(
    174     const QuicRstStreamFrame& frame) {
    175   return true;
    176 }
    177 
    178 bool NoOpFramerVisitor::OnConnectionCloseFrame(
    179     const QuicConnectionCloseFrame& frame) {
    180   return true;
    181 }
    182 
    183 bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& frame) {
    184   return true;
    185 }
    186 
    187 bool NoOpFramerVisitor::OnWindowUpdateFrame(
    188     const QuicWindowUpdateFrame& frame) {
    189   return true;
    190 }
    191 
    192 bool NoOpFramerVisitor::OnBlockedFrame(const QuicBlockedFrame& frame) {
    193   return true;
    194 }
    195 
    196 MockConnectionVisitor::MockConnectionVisitor() {
    197 }
    198 
    199 MockConnectionVisitor::~MockConnectionVisitor() {
    200 }
    201 
    202 MockHelper::MockHelper() {
    203 }
    204 
    205 MockHelper::~MockHelper() {
    206 }
    207 
    208 const QuicClock* MockHelper::GetClock() const {
    209   return &clock_;
    210 }
    211 
    212 QuicRandom* MockHelper::GetRandomGenerator() {
    213   return &random_generator_;
    214 }
    215 
    216 QuicAlarm* MockHelper::CreateAlarm(QuicAlarm::Delegate* delegate) {
    217   return new TestAlarm(delegate);
    218 }
    219 
    220 void MockHelper::AdvanceTime(QuicTime::Delta delta) {
    221   clock_.AdvanceTime(delta);
    222 }
    223 
    224 namespace {
    225 class NiceMockPacketWriterFactory
    226     : public QuicConnection::PacketWriterFactory {
    227  public:
    228   NiceMockPacketWriterFactory() {}
    229   virtual ~NiceMockPacketWriterFactory() {}
    230 
    231   virtual QuicPacketWriter* Create(
    232       QuicConnection* /*connection*/) const OVERRIDE {
    233     return new testing::NiceMock<MockPacketWriter>();
    234   }
    235 
    236  private:
    237   DISALLOW_COPY_AND_ASSIGN(NiceMockPacketWriterFactory);
    238 };
    239 }  // namespace
    240 
    241 MockConnection::MockConnection(bool is_server)
    242     : QuicConnection(kTestConnectionId,
    243                      IPEndPoint(TestPeerIPAddress(), kTestPort),
    244                      new testing::NiceMock<MockHelper>(),
    245                      NiceMockPacketWriterFactory(),
    246                      /* owns_writer= */ true,
    247                      is_server, QuicSupportedVersions()),
    248       helper_(helper()) {
    249 }
    250 
    251 MockConnection::MockConnection(IPEndPoint address,
    252                                bool is_server)
    253     : QuicConnection(kTestConnectionId, address,
    254                      new testing::NiceMock<MockHelper>(),
    255                      NiceMockPacketWriterFactory(),
    256                      /* owns_writer= */ true,
    257                      is_server, QuicSupportedVersions()),
    258       helper_(helper()) {
    259 }
    260 
    261 MockConnection::MockConnection(QuicConnectionId connection_id,
    262                                bool is_server)
    263     : QuicConnection(connection_id,
    264                      IPEndPoint(TestPeerIPAddress(), kTestPort),
    265                      new testing::NiceMock<MockHelper>(),
    266                      NiceMockPacketWriterFactory(),
    267                      /* owns_writer= */ true,
    268                      is_server, QuicSupportedVersions()),
    269       helper_(helper()) {
    270 }
    271 
    272 MockConnection::MockConnection(bool is_server,
    273                                const QuicVersionVector& supported_versions)
    274     : QuicConnection(kTestConnectionId,
    275                      IPEndPoint(TestPeerIPAddress(), kTestPort),
    276                      new testing::NiceMock<MockHelper>(),
    277                      NiceMockPacketWriterFactory(),
    278                      /* owns_writer= */ true,
    279                      is_server, supported_versions),
    280       helper_(helper()) {
    281 }
    282 
    283 MockConnection::~MockConnection() {
    284 }
    285 
    286 void MockConnection::AdvanceTime(QuicTime::Delta delta) {
    287   static_cast<MockHelper*>(helper())->AdvanceTime(delta);
    288 }
    289 
    290 PacketSavingConnection::PacketSavingConnection(bool is_server)
    291     : MockConnection(is_server) {
    292 }
    293 
    294 PacketSavingConnection::PacketSavingConnection(
    295     bool is_server,
    296     const QuicVersionVector& supported_versions)
    297     : MockConnection(is_server, supported_versions) {
    298 }
    299 
    300 PacketSavingConnection::~PacketSavingConnection() {
    301   STLDeleteElements(&packets_);
    302   STLDeleteElements(&encrypted_packets_);
    303 }
    304 
    305 void PacketSavingConnection::SendOrQueuePacket(QueuedPacket packet) {
    306   packets_.push_back(packet.serialized_packet.packet);
    307   QuicEncryptedPacket* encrypted = QuicConnectionPeer::GetFramer(this)->
    308       EncryptPacket(packet.encryption_level,
    309                     packet.serialized_packet.sequence_number,
    310                     *packet.serialized_packet.packet);
    311   encrypted_packets_.push_back(encrypted);
    312   // Transfer ownership of the packet to the SentPacketManager and the
    313   // ack notifier to the AckNotifierManager.
    314   sent_packet_manager_.OnSerializedPacket(packet.serialized_packet);
    315 }
    316 
    317 MockSession::MockSession(QuicConnection* connection)
    318     : QuicSession(connection, DefaultQuicConfig()) {
    319   InitializeSession();
    320   ON_CALL(*this, WritevData(_, _, _, _, _, _))
    321       .WillByDefault(testing::Return(QuicConsumedData(0, false)));
    322 }
    323 
    324 MockSession::~MockSession() {
    325 }
    326 
    327 TestSession::TestSession(QuicConnection* connection, const QuicConfig& config)
    328     : QuicSession(connection, config),
    329       crypto_stream_(NULL) {
    330   InitializeSession();
    331 }
    332 
    333 TestSession::~TestSession() {}
    334 
    335 void TestSession::SetCryptoStream(QuicCryptoStream* stream) {
    336   crypto_stream_ = stream;
    337 }
    338 
    339 QuicCryptoStream* TestSession::GetCryptoStream() {
    340   return crypto_stream_;
    341 }
    342 
    343 TestClientSession::TestClientSession(QuicConnection* connection,
    344                                      const QuicConfig& config)
    345     : QuicClientSessionBase(connection, config),
    346       crypto_stream_(NULL) {
    347   EXPECT_CALL(*this, OnProofValid(_)).Times(AnyNumber());
    348   InitializeSession();
    349 }
    350 
    351 TestClientSession::~TestClientSession() {}
    352 
    353 void TestClientSession::SetCryptoStream(QuicCryptoStream* stream) {
    354   crypto_stream_ = stream;
    355 }
    356 
    357 QuicCryptoStream* TestClientSession::GetCryptoStream() {
    358   return crypto_stream_;
    359 }
    360 
    361 MockPacketWriter::MockPacketWriter() {
    362 }
    363 
    364 MockPacketWriter::~MockPacketWriter() {
    365 }
    366 
    367 MockSendAlgorithm::MockSendAlgorithm() {
    368 }
    369 
    370 MockSendAlgorithm::~MockSendAlgorithm() {
    371 }
    372 
    373 MockLossAlgorithm::MockLossAlgorithm() {
    374 }
    375 
    376 MockLossAlgorithm::~MockLossAlgorithm() {
    377 }
    378 
    379 MockAckNotifierDelegate::MockAckNotifierDelegate() {
    380 }
    381 
    382 MockAckNotifierDelegate::~MockAckNotifierDelegate() {
    383 }
    384 
    385 MockNetworkChangeVisitor::MockNetworkChangeVisitor() {
    386 }
    387 
    388 MockNetworkChangeVisitor::~MockNetworkChangeVisitor() {
    389 }
    390 
    391 namespace {
    392 
    393 string HexDumpWithMarks(const char* data, int length,
    394                         const bool* marks, int mark_length) {
    395   static const char kHexChars[] = "0123456789abcdef";
    396   static const int kColumns = 4;
    397 
    398   const int kSizeLimit = 1024;
    399   if (length > kSizeLimit || mark_length > kSizeLimit) {
    400     LOG(ERROR) << "Only dumping first " << kSizeLimit << " bytes.";
    401     length = min(length, kSizeLimit);
    402     mark_length = min(mark_length, kSizeLimit);
    403   }
    404 
    405   string hex;
    406   for (const char* row = data; length > 0;
    407        row += kColumns, length -= kColumns) {
    408     for (const char *p = row; p < row + 4; ++p) {
    409       if (p < row + length) {
    410         const bool mark =
    411             (marks && (p - data) < mark_length && marks[p - data]);
    412         hex += mark ? '*' : ' ';
    413         hex += kHexChars[(*p & 0xf0) >> 4];
    414         hex += kHexChars[*p & 0x0f];
    415         hex += mark ? '*' : ' ';
    416       } else {
    417         hex += "    ";
    418       }
    419     }
    420     hex = hex + "  ";
    421 
    422     for (const char *p = row; p < row + 4 && p < row + length; ++p)
    423       hex += (*p >= 0x20 && *p <= 0x7f) ? (*p) : '.';
    424 
    425     hex = hex + '\n';
    426   }
    427   return hex;
    428 }
    429 
    430 }  // namespace
    431 
    432 IPAddressNumber TestPeerIPAddress() { return Loopback4(); }
    433 
    434 QuicVersion QuicVersionMax() { return QuicSupportedVersions().front(); }
    435 
    436 QuicVersion QuicVersionMin() { return QuicSupportedVersions().back(); }
    437 
    438 IPAddressNumber Loopback4() {
    439   IPAddressNumber addr;
    440   CHECK(ParseIPLiteralToNumber("127.0.0.1", &addr));
    441   return addr;
    442 }
    443 
    444 IPAddressNumber Loopback6() {
    445   IPAddressNumber addr;
    446   CHECK(ParseIPLiteralToNumber("::1", &addr));
    447   return addr;
    448 }
    449 
    450 void GenerateBody(string* body, int length) {
    451   body->clear();
    452   body->reserve(length);
    453   for (int i = 0; i < length; ++i) {
    454     body->append(1, static_cast<char>(32 + i % (126 - 32)));
    455   }
    456 }
    457 
    458 QuicEncryptedPacket* ConstructEncryptedPacket(
    459     QuicConnectionId connection_id,
    460     bool version_flag,
    461     bool reset_flag,
    462     QuicPacketSequenceNumber sequence_number,
    463     const string& data) {
    464   QuicPacketHeader header;
    465   header.public_header.connection_id = connection_id;
    466   header.public_header.connection_id_length = PACKET_8BYTE_CONNECTION_ID;
    467   header.public_header.version_flag = version_flag;
    468   header.public_header.reset_flag = reset_flag;
    469   header.public_header.sequence_number_length = PACKET_6BYTE_SEQUENCE_NUMBER;
    470   header.packet_sequence_number = sequence_number;
    471   header.entropy_flag = false;
    472   header.entropy_hash = 0;
    473   header.fec_flag = false;
    474   header.is_in_fec_group = NOT_IN_FEC_GROUP;
    475   header.fec_group = 0;
    476   QuicStreamFrame stream_frame(1, false, 0, MakeIOVector(data));
    477   QuicFrame frame(&stream_frame);
    478   QuicFrames frames;
    479   frames.push_back(frame);
    480   QuicFramer framer(QuicSupportedVersions(), QuicTime::Zero(), false);
    481   scoped_ptr<QuicPacket> packet(
    482       BuildUnsizedDataPacket(&framer, header, frames).packet);
    483   EXPECT_TRUE(packet != NULL);
    484   QuicEncryptedPacket* encrypted = framer.EncryptPacket(ENCRYPTION_NONE,
    485                                                         sequence_number,
    486                                                         *packet);
    487   EXPECT_TRUE(encrypted != NULL);
    488   return encrypted;
    489 }
    490 
    491 void CompareCharArraysWithHexError(
    492     const string& description,
    493     const char* actual,
    494     const int actual_len,
    495     const char* expected,
    496     const int expected_len) {
    497   EXPECT_EQ(actual_len, expected_len);
    498   const int min_len = min(actual_len, expected_len);
    499   const int max_len = max(actual_len, expected_len);
    500   scoped_ptr<bool[]> marks(new bool[max_len]);
    501   bool identical = (actual_len == expected_len);
    502   for (int i = 0; i < min_len; ++i) {
    503     if (actual[i] != expected[i]) {
    504       marks[i] = true;
    505       identical = false;
    506     } else {
    507       marks[i] = false;
    508     }
    509   }
    510   for (int i = min_len; i < max_len; ++i) {
    511     marks[i] = true;
    512   }
    513   if (identical) return;
    514   ADD_FAILURE()
    515       << "Description:\n"
    516       << description
    517       << "\n\nExpected:\n"
    518       << HexDumpWithMarks(expected, expected_len, marks.get(), max_len)
    519       << "\nActual:\n"
    520       << HexDumpWithMarks(actual, actual_len, marks.get(), max_len);
    521 }
    522 
    523 bool DecodeHexString(const base::StringPiece& hex, std::string* bytes) {
    524   bytes->clear();
    525   if (hex.empty())
    526     return true;
    527   std::vector<uint8> v;
    528   if (!base::HexStringToBytes(hex.as_string(), &v))
    529     return false;
    530   if (!v.empty())
    531     bytes->assign(reinterpret_cast<const char*>(&v[0]), v.size());
    532   return true;
    533 }
    534 
    535 static QuicPacket* ConstructPacketFromHandshakeMessage(
    536     QuicConnectionId connection_id,
    537     const CryptoHandshakeMessage& message,
    538     bool should_include_version) {
    539   CryptoFramer crypto_framer;
    540   scoped_ptr<QuicData> data(crypto_framer.ConstructHandshakeMessage(message));
    541   QuicFramer quic_framer(QuicSupportedVersions(), QuicTime::Zero(), false);
    542 
    543   QuicPacketHeader header;
    544   header.public_header.connection_id = connection_id;
    545   header.public_header.reset_flag = false;
    546   header.public_header.version_flag = should_include_version;
    547   header.packet_sequence_number = 1;
    548   header.entropy_flag = false;
    549   header.entropy_hash = 0;
    550   header.fec_flag = false;
    551   header.fec_group = 0;
    552 
    553   QuicStreamFrame stream_frame(kCryptoStreamId, false, 0,
    554                                MakeIOVector(data->AsStringPiece()));
    555 
    556   QuicFrame frame(&stream_frame);
    557   QuicFrames frames;
    558   frames.push_back(frame);
    559   return BuildUnsizedDataPacket(&quic_framer, header, frames).packet;
    560 }
    561 
    562 QuicPacket* ConstructHandshakePacket(QuicConnectionId connection_id,
    563                                      QuicTag tag) {
    564   CryptoHandshakeMessage message;
    565   message.set_tag(tag);
    566   return ConstructPacketFromHandshakeMessage(connection_id, message, false);
    567 }
    568 
    569 size_t GetPacketLengthForOneStream(
    570     QuicVersion version,
    571     bool include_version,
    572     QuicSequenceNumberLength sequence_number_length,
    573     InFecGroup is_in_fec_group,
    574     size_t* payload_length) {
    575   *payload_length = 1;
    576   const size_t stream_length =
    577       NullEncrypter().GetCiphertextSize(*payload_length) +
    578       QuicPacketCreator::StreamFramePacketOverhead(
    579           PACKET_8BYTE_CONNECTION_ID, include_version,
    580           sequence_number_length, 0u, is_in_fec_group);
    581   const size_t ack_length = NullEncrypter().GetCiphertextSize(
    582       QuicFramer::GetMinAckFrameSize(
    583           sequence_number_length, PACKET_1BYTE_SEQUENCE_NUMBER)) +
    584       GetPacketHeaderSize(PACKET_8BYTE_CONNECTION_ID, include_version,
    585                           sequence_number_length, is_in_fec_group);
    586   if (stream_length < ack_length) {
    587     *payload_length = 1 + ack_length - stream_length;
    588   }
    589 
    590   return NullEncrypter().GetCiphertextSize(*payload_length) +
    591       QuicPacketCreator::StreamFramePacketOverhead(
    592           PACKET_8BYTE_CONNECTION_ID, include_version,
    593           sequence_number_length, 0u, is_in_fec_group);
    594 }
    595 
    596 TestEntropyCalculator::TestEntropyCalculator() {}
    597 
    598 TestEntropyCalculator::~TestEntropyCalculator() {}
    599 
    600 QuicPacketEntropyHash TestEntropyCalculator::EntropyHash(
    601     QuicPacketSequenceNumber sequence_number) const {
    602   return 1u;
    603 }
    604 
    605 MockEntropyCalculator::MockEntropyCalculator() {}
    606 
    607 MockEntropyCalculator::~MockEntropyCalculator() {}
    608 
    609 QuicConfig DefaultQuicConfig() {
    610   QuicConfig config;
    611   config.SetDefaults();
    612   config.SetInitialFlowControlWindowToSend(
    613       kInitialSessionFlowControlWindowForTest);
    614   config.SetInitialStreamFlowControlWindowToSend(
    615       kInitialStreamFlowControlWindowForTest);
    616   config.SetInitialSessionFlowControlWindowToSend(
    617       kInitialSessionFlowControlWindowForTest);
    618   return config;
    619 }
    620 
    621 QuicVersionVector SupportedVersions(QuicVersion version) {
    622   QuicVersionVector versions;
    623   versions.push_back(version);
    624   return versions;
    625 }
    626 
    627 TestWriterFactory::TestWriterFactory() : current_writer_(NULL) {}
    628 TestWriterFactory::~TestWriterFactory() {}
    629 
    630 QuicPacketWriter* TestWriterFactory::Create(QuicServerPacketWriter* writer,
    631                                             QuicConnection* connection) {
    632   return new PerConnectionPacketWriter(this, writer, connection);
    633 }
    634 
    635 void TestWriterFactory::OnPacketSent(WriteResult result) {
    636   if (current_writer_ != NULL && result.status == WRITE_STATUS_ERROR) {
    637     current_writer_->connection()->OnWriteError(result.error_code);
    638     current_writer_ = NULL;
    639   }
    640 }
    641 
    642 void TestWriterFactory::Unregister(PerConnectionPacketWriter* writer) {
    643   if (current_writer_ == writer) {
    644     current_writer_ = NULL;
    645   }
    646 }
    647 
    648 TestWriterFactory::PerConnectionPacketWriter::PerConnectionPacketWriter(
    649     TestWriterFactory* factory,
    650     QuicServerPacketWriter* writer,
    651     QuicConnection* connection)
    652     : QuicPerConnectionPacketWriter(writer, connection),
    653       factory_(factory) {
    654 }
    655 
    656 TestWriterFactory::PerConnectionPacketWriter::~PerConnectionPacketWriter() {
    657   factory_->Unregister(this);
    658 }
    659 
    660 WriteResult TestWriterFactory::PerConnectionPacketWriter::WritePacket(
    661     const char* buffer,
    662     size_t buf_len,
    663     const IPAddressNumber& self_address,
    664     const IPEndPoint& peer_address) {
    665   // A DCHECK(factory_current_writer_ == NULL) would be wrong here -- this class
    666   // may be used in a setting where connection()->OnPacketSent() is called in a
    667   // different way, so TestWriterFactory::OnPacketSent might never be called.
    668   factory_->current_writer_ = this;
    669   return QuicPerConnectionPacketWriter::WritePacket(buffer,
    670                                                     buf_len,
    671                                                     self_address,
    672                                                     peer_address);
    673 }
    674 
    675 }  // namespace test
    676 }  // namespace net
    677