Home | History | Annotate | Download | only in test_tools
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/quic/test_tools/quic_test_utils.h"
      6 
      7 #include "base/stl_util.h"
      8 #include "net/quic/crypto/crypto_framer.h"
      9 #include "net/quic/crypto/crypto_handshake.h"
     10 #include "net/quic/crypto/crypto_utils.h"
     11 #include "net/quic/crypto/null_encrypter.h"
     12 #include "net/quic/crypto/quic_decrypter.h"
     13 #include "net/quic/crypto/quic_encrypter.h"
     14 #include "net/quic/quic_framer.h"
     15 #include "net/quic/quic_packet_creator.h"
     16 #include "net/spdy/spdy_frame_builder.h"
     17 
     18 using base::StringPiece;
     19 using std::max;
     20 using std::min;
     21 using std::string;
     22 using testing::_;
     23 
     24 namespace net {
     25 namespace test {
     26 namespace {
     27 
     28 // No-op alarm implementation used by MockHelper.
     29 class TestAlarm : public QuicAlarm {
     30  public:
     31   explicit TestAlarm(QuicAlarm::Delegate* delegate)
     32       : QuicAlarm(delegate) {
     33   }
     34 
     35   virtual void SetImpl() OVERRIDE {}
     36   virtual void CancelImpl() OVERRIDE {}
     37 };
     38 
     39 }  // namespace
     40 
     41 MockFramerVisitor::MockFramerVisitor() {
     42   // By default, we want to accept packets.
     43   ON_CALL(*this, OnProtocolVersionMismatch(_))
     44       .WillByDefault(testing::Return(false));
     45 
     46   // By default, we want to accept packets.
     47   ON_CALL(*this, OnPacketHeader(_))
     48       .WillByDefault(testing::Return(true));
     49 
     50   ON_CALL(*this, OnStreamFrame(_))
     51       .WillByDefault(testing::Return(true));
     52 
     53   ON_CALL(*this, OnAckFrame(_))
     54       .WillByDefault(testing::Return(true));
     55 
     56   ON_CALL(*this, OnCongestionFeedbackFrame(_))
     57       .WillByDefault(testing::Return(true));
     58 
     59   ON_CALL(*this, OnRstStreamFrame(_))
     60       .WillByDefault(testing::Return(true));
     61 
     62   ON_CALL(*this, OnConnectionCloseFrame(_))
     63       .WillByDefault(testing::Return(true));
     64 
     65   ON_CALL(*this, OnGoAwayFrame(_))
     66       .WillByDefault(testing::Return(true));
     67 }
     68 
     69 MockFramerVisitor::~MockFramerVisitor() {
     70 }
     71 
     72 bool NoOpFramerVisitor::OnProtocolVersionMismatch(QuicVersion version) {
     73   return false;
     74 }
     75 
     76 bool NoOpFramerVisitor::OnPacketHeader(const QuicPacketHeader& header) {
     77   return true;
     78 }
     79 
     80 bool NoOpFramerVisitor::OnStreamFrame(const QuicStreamFrame& frame) {
     81   return true;
     82 }
     83 
     84 bool NoOpFramerVisitor::OnAckFrame(const QuicAckFrame& frame) {
     85   return true;
     86 }
     87 
     88 bool NoOpFramerVisitor::OnCongestionFeedbackFrame(
     89     const QuicCongestionFeedbackFrame& frame) {
     90   return true;
     91 }
     92 
     93 bool NoOpFramerVisitor::OnRstStreamFrame(
     94     const QuicRstStreamFrame& frame) {
     95   return true;
     96 }
     97 
     98 bool NoOpFramerVisitor::OnConnectionCloseFrame(
     99     const QuicConnectionCloseFrame& frame) {
    100   return true;
    101 }
    102 
    103 bool NoOpFramerVisitor::OnGoAwayFrame(const QuicGoAwayFrame& frame) {
    104   return true;
    105 }
    106 
    107 FramerVisitorCapturingFrames::FramerVisitorCapturingFrames() : frame_count_(0) {
    108 }
    109 
    110 FramerVisitorCapturingFrames::~FramerVisitorCapturingFrames() {
    111 }
    112 
    113 bool FramerVisitorCapturingFrames::OnPacketHeader(
    114     const QuicPacketHeader& header) {
    115   header_ = header;
    116   frame_count_ = 0;
    117   return true;
    118 }
    119 
    120 bool FramerVisitorCapturingFrames::OnStreamFrame(const QuicStreamFrame& frame) {
    121   // TODO(ianswett): Own the underlying string, so it will not exist outside
    122   // this callback.
    123   stream_frames_.push_back(frame);
    124   ++frame_count_;
    125   return true;
    126 }
    127 
    128 bool FramerVisitorCapturingFrames::OnAckFrame(const QuicAckFrame& frame) {
    129   ack_.reset(new QuicAckFrame(frame));
    130   ++frame_count_;
    131   return true;
    132 }
    133 
    134 bool FramerVisitorCapturingFrames::OnCongestionFeedbackFrame(
    135     const QuicCongestionFeedbackFrame& frame) {
    136   feedback_.reset(new QuicCongestionFeedbackFrame(frame));
    137   ++frame_count_;
    138   return true;
    139 }
    140 
    141 bool FramerVisitorCapturingFrames::OnRstStreamFrame(
    142     const QuicRstStreamFrame& frame) {
    143   rst_.reset(new QuicRstStreamFrame(frame));
    144   ++frame_count_;
    145   return true;
    146 }
    147 
    148 bool FramerVisitorCapturingFrames::OnConnectionCloseFrame(
    149     const QuicConnectionCloseFrame& frame) {
    150   close_.reset(new QuicConnectionCloseFrame(frame));
    151   ++frame_count_;
    152   return true;
    153 }
    154 
    155 bool FramerVisitorCapturingFrames::OnGoAwayFrame(const QuicGoAwayFrame& frame) {
    156   goaway_.reset(new QuicGoAwayFrame(frame));
    157   ++frame_count_;
    158   return true;
    159 }
    160 
    161 void FramerVisitorCapturingFrames::OnVersionNegotiationPacket(
    162     const QuicVersionNegotiationPacket& packet) {
    163   version_negotiation_packet_.reset(new QuicVersionNegotiationPacket(packet));
    164   frame_count_ = 0;
    165 }
    166 
    167 FramerVisitorCapturingPublicReset::FramerVisitorCapturingPublicReset() {
    168 }
    169 
    170 FramerVisitorCapturingPublicReset::~FramerVisitorCapturingPublicReset() {
    171 }
    172 
    173 void FramerVisitorCapturingPublicReset::OnPublicResetPacket(
    174     const QuicPublicResetPacket& public_reset) {
    175   public_reset_packet_ = public_reset;
    176 }
    177 
    178 MockConnectionVisitor::MockConnectionVisitor() {
    179 }
    180 
    181 MockConnectionVisitor::~MockConnectionVisitor() {
    182 }
    183 
    184 MockHelper::MockHelper() {
    185 }
    186 
    187 MockHelper::~MockHelper() {
    188 }
    189 
    190 const QuicClock* MockHelper::GetClock() const {
    191   return &clock_;
    192 }
    193 
    194 QuicRandom* MockHelper::GetRandomGenerator() {
    195   return &random_generator_;
    196 }
    197 
    198 QuicAlarm* MockHelper::CreateAlarm(QuicAlarm::Delegate* delegate) {
    199   return new TestAlarm(delegate);
    200 }
    201 
    202 void MockHelper::AdvanceTime(QuicTime::Delta delta) {
    203   clock_.AdvanceTime(delta);
    204 }
    205 
    206 MockConnection::MockConnection(QuicGuid guid,
    207                                IPEndPoint address,
    208                                bool is_server)
    209     : QuicConnection(guid, address, new testing::NiceMock<MockHelper>(),
    210                      is_server, QuicVersionMax()),
    211       has_mock_helper_(true) {
    212 }
    213 
    214 MockConnection::MockConnection(QuicGuid guid,
    215                                IPEndPoint address,
    216                                QuicConnectionHelperInterface* helper,
    217                                bool is_server)
    218     : QuicConnection(guid, address, helper, is_server, QuicVersionMax()),
    219       has_mock_helper_(false) {
    220 }
    221 
    222 MockConnection::~MockConnection() {
    223 }
    224 
    225 void MockConnection::AdvanceTime(QuicTime::Delta delta) {
    226   CHECK(has_mock_helper_) << "Cannot advance time unless a MockClock is being"
    227                              " used";
    228   static_cast<MockHelper*>(helper())->AdvanceTime(delta);
    229 }
    230 
    231 PacketSavingConnection::PacketSavingConnection(QuicGuid guid,
    232                                                IPEndPoint address,
    233                                                bool is_server)
    234     : MockConnection(guid, address, is_server) {
    235 }
    236 
    237 PacketSavingConnection::~PacketSavingConnection() {
    238   STLDeleteElements(&packets_);
    239   STLDeleteElements(&encrypted_packets_);
    240 }
    241 
    242 bool PacketSavingConnection::SendOrQueuePacket(
    243     EncryptionLevel level,
    244     QuicPacketSequenceNumber sequence_number,
    245     QuicPacket* packet,
    246     QuicPacketEntropyHash entropy_hash,
    247     HasRetransmittableData retransmittable) {
    248   packets_.push_back(packet);
    249   QuicEncryptedPacket* encrypted =
    250       framer_.EncryptPacket(level, sequence_number, *packet);
    251   encrypted_packets_.push_back(encrypted);
    252   return true;
    253 }
    254 
    255 MockSession::MockSession(QuicConnection* connection, bool is_server)
    256     : QuicSession(connection, DefaultQuicConfig(), is_server) {
    257   ON_CALL(*this, WriteData(_, _, _, _))
    258       .WillByDefault(testing::Return(QuicConsumedData(0, false)));
    259 }
    260 
    261 MockSession::~MockSession() {
    262 }
    263 
    264 TestSession::TestSession(QuicConnection* connection,
    265                          const QuicConfig& config,
    266                          bool is_server)
    267     : QuicSession(connection, config, is_server),
    268       crypto_stream_(NULL) {
    269 }
    270 
    271 TestSession::~TestSession() {}
    272 
    273 void TestSession::SetCryptoStream(QuicCryptoStream* stream) {
    274   crypto_stream_ = stream;
    275 }
    276 
    277 QuicCryptoStream* TestSession::GetCryptoStream() {
    278   return crypto_stream_;
    279 }
    280 
    281 MockSendAlgorithm::MockSendAlgorithm() {
    282 }
    283 
    284 MockSendAlgorithm::~MockSendAlgorithm() {
    285 }
    286 
    287 namespace {
    288 
    289 string HexDumpWithMarks(const char* data, int length,
    290                         const bool* marks, int mark_length) {
    291   static const char kHexChars[] = "0123456789abcdef";
    292   static const int kColumns = 4;
    293 
    294   const int kSizeLimit = 1024;
    295   if (length > kSizeLimit || mark_length > kSizeLimit) {
    296     LOG(ERROR) << "Only dumping first " << kSizeLimit << " bytes.";
    297     length = min(length, kSizeLimit);
    298     mark_length = min(mark_length, kSizeLimit);
    299   }
    300 
    301   string hex;
    302   for (const char* row = data; length > 0;
    303        row += kColumns, length -= kColumns) {
    304     for (const char *p = row; p < row + 4; ++p) {
    305       if (p < row + length) {
    306         const bool mark =
    307             (marks && (p - data) < mark_length && marks[p - data]);
    308         hex += mark ? '*' : ' ';
    309         hex += kHexChars[(*p & 0xf0) >> 4];
    310         hex += kHexChars[*p & 0x0f];
    311         hex += mark ? '*' : ' ';
    312       } else {
    313         hex += "    ";
    314       }
    315     }
    316     hex = hex + "  ";
    317 
    318     for (const char *p = row; p < row + 4 && p < row + length; ++p)
    319       hex += (*p >= 0x20 && *p <= 0x7f) ? (*p) : '.';
    320 
    321     hex = hex + '\n';
    322   }
    323   return hex;
    324 }
    325 
    326 }  // namespace
    327 
    328 void CompareCharArraysWithHexError(
    329     const string& description,
    330     const char* actual,
    331     const int actual_len,
    332     const char* expected,
    333     const int expected_len) {
    334   const int min_len = min(actual_len, expected_len);
    335   const int max_len = max(actual_len, expected_len);
    336   scoped_ptr<bool[]> marks(new bool[max_len]);
    337   bool identical = (actual_len == expected_len);
    338   for (int i = 0; i < min_len; ++i) {
    339     if (actual[i] != expected[i]) {
    340       marks[i] = true;
    341       identical = false;
    342     } else {
    343       marks[i] = false;
    344     }
    345   }
    346   for (int i = min_len; i < max_len; ++i) {
    347     marks[i] = true;
    348   }
    349   if (identical) return;
    350   ADD_FAILURE()
    351       << "Description:\n"
    352       << description
    353       << "\n\nExpected:\n"
    354       << HexDumpWithMarks(expected, expected_len, marks.get(), max_len)
    355       << "\nActual:\n"
    356       << HexDumpWithMarks(actual, actual_len, marks.get(), max_len);
    357 }
    358 
    359 void CompareQuicDataWithHexError(
    360     const string& description,
    361     QuicData* actual,
    362     QuicData* expected) {
    363   CompareCharArraysWithHexError(
    364       description,
    365       actual->data(), actual->length(),
    366       expected->data(), expected->length());
    367 }
    368 
    369 static QuicPacket* ConstructPacketFromHandshakeMessage(
    370     QuicGuid guid,
    371     const CryptoHandshakeMessage& message,
    372     bool should_include_version) {
    373   CryptoFramer crypto_framer;
    374   scoped_ptr<QuicData> data(crypto_framer.ConstructHandshakeMessage(message));
    375   QuicFramer quic_framer(QuicVersionMax(), QuicTime::Zero(), false);
    376 
    377   QuicPacketHeader header;
    378   header.public_header.guid = guid;
    379   header.public_header.reset_flag = false;
    380   header.public_header.version_flag = should_include_version;
    381   header.packet_sequence_number = 1;
    382   header.entropy_flag = false;
    383   header.entropy_hash = 0;
    384   header.fec_flag = false;
    385   header.fec_group = 0;
    386 
    387   QuicStreamFrame stream_frame(kCryptoStreamId, false, 0,
    388                                data->AsStringPiece());
    389 
    390   QuicFrame frame(&stream_frame);
    391   QuicFrames frames;
    392   frames.push_back(frame);
    393   return quic_framer.BuildUnsizedDataPacket(header, frames).packet;
    394 }
    395 
    396 QuicPacket* ConstructHandshakePacket(QuicGuid guid, QuicTag tag) {
    397   CryptoHandshakeMessage message;
    398   message.set_tag(tag);
    399   return ConstructPacketFromHandshakeMessage(guid, message, false);
    400 }
    401 
    402 size_t GetPacketLengthForOneStream(QuicVersion version,
    403                                    bool include_version,
    404                                    InFecGroup is_in_fec_group,
    405                                    size_t* payload_length) {
    406   *payload_length = 1;
    407   const size_t stream_length =
    408       NullEncrypter().GetCiphertextSize(*payload_length) +
    409       QuicPacketCreator::StreamFramePacketOverhead(
    410           version, PACKET_8BYTE_GUID, include_version,
    411           PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group);
    412   const size_t ack_length = NullEncrypter().GetCiphertextSize(
    413       QuicFramer::GetMinAckFrameSize()) +
    414       GetPacketHeaderSize(PACKET_8BYTE_GUID, include_version,
    415                           PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group);
    416   if (stream_length < ack_length) {
    417     *payload_length = 1 + ack_length - stream_length;
    418   }
    419 
    420   return NullEncrypter().GetCiphertextSize(*payload_length) +
    421       QuicPacketCreator::StreamFramePacketOverhead(
    422           version, PACKET_8BYTE_GUID, include_version,
    423           PACKET_6BYTE_SEQUENCE_NUMBER, is_in_fec_group);
    424 }
    425 
    426 // Size in bytes of the stream frame fields for an arbitrary StreamID and
    427 // offset and the last frame in a packet.
    428 size_t GetMinStreamFrameSize(QuicVersion version) {
    429   return kQuicFrameTypeSize + kQuicMaxStreamIdSize + kQuicMaxStreamOffsetSize;
    430 }
    431 
    432 QuicPacketEntropyHash TestEntropyCalculator::EntropyHash(
    433     QuicPacketSequenceNumber sequence_number) const {
    434   return 1u;
    435 }
    436 
    437 QuicConfig DefaultQuicConfig() {
    438   QuicConfig config;
    439   config.SetDefaults();
    440   return config;
    441 }
    442 
    443 bool TestDecompressorVisitor::OnDecompressedData(StringPiece data) {
    444   data.AppendToString(&data_);
    445   return true;
    446 }
    447 
    448 void TestDecompressorVisitor::OnDecompressionError() {
    449   error_ = true;
    450 }
    451 
    452 }  // namespace test
    453 }  // namespace net
    454