Home | History | Annotate | Download | only in test_tools
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/tools/quic/test_tools/quic_test_client.h"
      6 
      7 #include "base/time/time.h"
      8 #include "net/base/completion_callback.h"
      9 #include "net/base/net_errors.h"
     10 #include "net/cert/cert_verify_result.h"
     11 #include "net/cert/x509_certificate.h"
     12 #include "net/quic/crypto/proof_verifier.h"
     13 #include "net/quic/test_tools/quic_connection_peer.h"
     14 #include "net/tools/balsa/balsa_headers.h"
     15 #include "net/tools/quic/quic_epoll_connection_helper.h"
     16 #include "net/tools/quic/quic_spdy_client_stream.h"
     17 #include "net/tools/quic/test_tools/http_message_test_utils.h"
     18 #include "url/gurl.h"
     19 
     20 using base::StringPiece;
     21 using net::test::QuicConnectionPeer;
     22 using net::test::QuicTestWriter;
     23 using std::string;
     24 using std::vector;
     25 
     26 namespace {
     27 
     28 // RecordingProofVerifier accepts any certificate chain and records the common
     29 // name of the leaf.
     30 class RecordingProofVerifier : public net::ProofVerifier {
     31  public:
     32   // ProofVerifier interface.
     33   virtual net::ProofVerifier::Status VerifyProof(
     34       const string& hostname,
     35       const string& server_config,
     36       const vector<string>& certs,
     37       const string& signature,
     38       string* error_details,
     39       scoped_ptr<net::ProofVerifyDetails>* details,
     40       net::ProofVerifierCallback* callback) OVERRIDE {
     41     delete callback;
     42 
     43     common_name_.clear();
     44     if (certs.empty()) {
     45       return FAILURE;
     46     }
     47 
     48     // Convert certs to X509Certificate.
     49     vector<StringPiece> cert_pieces(certs.size());
     50     for (unsigned i = 0; i < certs.size(); i++) {
     51       cert_pieces[i] = StringPiece(certs[i]);
     52     }
     53     scoped_refptr<net::X509Certificate> cert =
     54         net::X509Certificate::CreateFromDERCertChain(cert_pieces);
     55     if (!cert.get()) {
     56       return FAILURE;
     57     }
     58 
     59     common_name_ = cert->subject().GetDisplayName();
     60     return SUCCESS;
     61   }
     62 
     63   const string& common_name() const { return common_name_; }
     64 
     65  private:
     66   string common_name_;
     67 };
     68 
     69 }  // anonymous namespace
     70 
     71 namespace net {
     72 namespace tools {
     73 namespace test {
     74 
     75 BalsaHeaders* MungeHeaders(const BalsaHeaders* const_headers,
     76                            bool secure) {
     77   StringPiece uri = const_headers->request_uri();
     78   if (uri.empty()) {
     79     return NULL;
     80   }
     81   if (const_headers->request_method() == "CONNECT") {
     82     return NULL;
     83   }
     84   BalsaHeaders* headers = new BalsaHeaders;
     85   headers->CopyFrom(*const_headers);
     86   if (!uri.starts_with("https://") &&
     87       !uri.starts_with("http://")) {
     88     // If we have a relative URL, set some defaults.
     89     string full_uri = secure ? "https://www.google.com" :
     90                                "http://www.google.com";
     91     full_uri.append(uri.as_string());
     92     headers->SetRequestUri(full_uri);
     93   }
     94   return headers;
     95 }
     96 
     97 // A quic client which allows mocking out writes.
     98 class QuicEpollClient : public QuicClient {
     99  public:
    100   typedef QuicClient Super;
    101 
    102   QuicEpollClient(IPEndPoint server_address,
    103              const string& server_hostname,
    104              const QuicVersionVector& supported_versions)
    105       : Super(server_address, server_hostname, supported_versions, false),
    106         override_guid_(0), test_writer_(NULL) {
    107   }
    108 
    109   QuicEpollClient(IPEndPoint server_address,
    110              const string& server_hostname,
    111              const QuicConfig& config,
    112              const QuicVersionVector& supported_versions)
    113       : Super(server_address, server_hostname, config, supported_versions),
    114         override_guid_(0), test_writer_(NULL) {
    115   }
    116 
    117   virtual ~QuicEpollClient() {
    118     if (connected()) {
    119       Disconnect();
    120     }
    121   }
    122 
    123   virtual QuicPacketWriter* CreateQuicPacketWriter() OVERRIDE {
    124     QuicPacketWriter* writer = Super::CreateQuicPacketWriter();
    125     if (!test_writer_) {
    126       return writer;
    127     }
    128     test_writer_->set_writer(writer);
    129     return test_writer_;
    130   }
    131 
    132   virtual QuicGuid GenerateGuid() OVERRIDE {
    133     return override_guid_ ? override_guid_ : Super::GenerateGuid();
    134   }
    135 
    136   // Takes ownership of writer.
    137   void UseWriter(QuicTestWriter* writer) { test_writer_ = writer; }
    138 
    139   void UseGuid(QuicGuid guid) {
    140     override_guid_ = guid;
    141   }
    142 
    143  private:
    144   QuicGuid override_guid_;  // GUID to use, if nonzero
    145   QuicTestWriter* test_writer_;
    146 };
    147 
    148 QuicTestClient::QuicTestClient(IPEndPoint address, const string& hostname,
    149                                const QuicVersionVector& supported_versions)
    150     : client_(new QuicEpollClient(address, hostname, supported_versions)) {
    151   Initialize(address, hostname, true);
    152 }
    153 
    154 QuicTestClient::QuicTestClient(IPEndPoint address,
    155                                const string& hostname,
    156                                bool secure,
    157                                const QuicVersionVector& supported_versions)
    158     : client_(new QuicEpollClient(address, hostname, supported_versions)) {
    159   Initialize(address, hostname, secure);
    160 }
    161 
    162 QuicTestClient::QuicTestClient(IPEndPoint address,
    163                                const string& hostname,
    164                                bool secure,
    165                                const QuicConfig& config,
    166                                const QuicVersionVector& supported_versions)
    167     : client_(new QuicEpollClient(address, hostname, config,
    168                                   supported_versions)) {
    169   Initialize(address, hostname, secure);
    170 }
    171 
    172 void QuicTestClient::Initialize(IPEndPoint address,
    173                                 const string& hostname,
    174                                 bool secure) {
    175   server_address_ = address;
    176   priority_ = 3;
    177   connect_attempted_ = false;
    178   secure_ = secure;
    179   auto_reconnect_ = false;
    180   buffer_body_ = true;
    181   proof_verifier_ = NULL;
    182   ClearPerRequestState();
    183   ExpectCertificates(secure_);
    184 }
    185 
    186 QuicTestClient::~QuicTestClient() {
    187   if (stream_) {
    188     stream_->set_visitor(NULL);
    189   }
    190 }
    191 
    192 void QuicTestClient::ExpectCertificates(bool on) {
    193   if (on) {
    194     proof_verifier_ = new RecordingProofVerifier;
    195     client_->SetProofVerifier(proof_verifier_);
    196   } else {
    197     proof_verifier_ = NULL;
    198     client_->SetProofVerifier(NULL);
    199   }
    200 }
    201 
    202 ssize_t QuicTestClient::SendRequest(const string& uri) {
    203   HTTPMessage message(HttpConstants::HTTP_1_1, HttpConstants::GET, uri);
    204   return SendMessage(message);
    205 }
    206 
    207 ssize_t QuicTestClient::SendMessage(const HTTPMessage& message) {
    208   stream_ = NULL;  // Always force creation of a stream for SendMessage.
    209 
    210   // If we're not connected, try to find an sni hostname.
    211   if (!connected()) {
    212     GURL url(message.headers()->request_uri().as_string());
    213     if (!url.host().empty()) {
    214       client_->set_server_hostname(url.host());
    215     }
    216   }
    217 
    218   QuicSpdyClientStream* stream = GetOrCreateStream();
    219   if (!stream) { return 0; }
    220 
    221   scoped_ptr<BalsaHeaders> munged_headers(MungeHeaders(message.headers(),
    222                                           secure_));
    223   ssize_t ret = GetOrCreateStream()->SendRequest(
    224       munged_headers.get() ? *munged_headers.get() : *message.headers(),
    225       message.body(),
    226       message.has_complete_message());
    227   WaitForWriteToFlush();
    228   return ret;
    229 }
    230 
    231 ssize_t QuicTestClient::SendData(string data, bool last_data) {
    232   QuicSpdyClientStream* stream = GetOrCreateStream();
    233   if (!stream) { return 0; }
    234   GetOrCreateStream()->SendBody(data, last_data);
    235   WaitForWriteToFlush();
    236   return data.length();
    237 }
    238 
    239 string QuicTestClient::SendCustomSynchronousRequest(
    240     const HTTPMessage& message) {
    241   SendMessage(message);
    242   WaitForResponse();
    243   return response_;
    244 }
    245 
    246 string QuicTestClient::SendSynchronousRequest(const string& uri) {
    247   if (SendRequest(uri) == 0) {
    248     DLOG(ERROR) << "Failed the request for uri:" << uri;
    249     return "";
    250   }
    251   WaitForResponse();
    252   return response_;
    253 }
    254 
    255 QuicSpdyClientStream* QuicTestClient::GetOrCreateStream() {
    256   if (!connect_attempted_ || auto_reconnect_) {
    257     if (!connected()) {
    258       Connect();
    259     }
    260     if (!connected()) {
    261       return NULL;
    262     }
    263   }
    264   if (!stream_) {
    265     stream_ = client_->CreateReliableClientStream();
    266     if (stream_ == NULL) {
    267       return NULL;
    268     }
    269     stream_->set_visitor(this);
    270     reinterpret_cast<QuicSpdyClientStream*>(stream_)->set_priority(priority_);
    271   }
    272 
    273   return stream_;
    274 }
    275 
    276 const string& QuicTestClient::cert_common_name() const {
    277   return reinterpret_cast<RecordingProofVerifier*>(proof_verifier_)
    278       ->common_name();
    279 }
    280 
    281 bool QuicTestClient::connected() const {
    282   return client_->connected();
    283 }
    284 
    285 void QuicTestClient::WaitForResponse() {
    286   if (stream_ == NULL) {
    287     // The client has likely disconnected.
    288     return;
    289   }
    290   client_->WaitForStreamToClose(stream_->id());
    291 }
    292 
    293 void QuicTestClient::Connect() {
    294   DCHECK(!connected());
    295   if (!connect_attempted_) {
    296     client_->Initialize();
    297   }
    298   client_->Connect();
    299   connect_attempted_ = true;
    300 }
    301 
    302 void QuicTestClient::ResetConnection() {
    303   Disconnect();
    304   Connect();
    305 }
    306 
    307 void QuicTestClient::Disconnect() {
    308   client_->Disconnect();
    309   connect_attempted_ = false;
    310 }
    311 
    312 IPEndPoint QuicTestClient::LocalSocketAddress() const {
    313   return client_->client_address();
    314 }
    315 
    316 void QuicTestClient::ClearPerRequestState() {
    317   stream_error_ = QUIC_STREAM_NO_ERROR;
    318   stream_ = NULL;
    319   response_ = "";
    320   response_complete_ = false;
    321   response_headers_complete_ = false;
    322   headers_.Clear();
    323   bytes_read_ = 0;
    324   bytes_written_ = 0;
    325   response_header_size_ = 0;
    326   response_body_size_ = 0;
    327 }
    328 
    329 void QuicTestClient::WaitForResponseForMs(int timeout_ms) {
    330   int64 timeout_us = timeout_ms * base::Time::kMicrosecondsPerMillisecond;
    331   int64 old_timeout_us = client()->epoll_server()->timeout_in_us();
    332   if (timeout_us > 0) {
    333     client()->epoll_server()->set_timeout_in_us(timeout_us);
    334   }
    335   const QuicClock* clock =
    336       QuicConnectionPeer::GetHelper(client()->session()->connection())->
    337           GetClock();
    338   QuicTime end_waiting_time = clock->Now().Add(
    339       QuicTime::Delta::FromMicroseconds(timeout_us));
    340   while (stream_ != NULL &&
    341          !client_->session()->IsClosedStream(stream_->id()) &&
    342          (timeout_us < 0 || clock->Now() < end_waiting_time)) {
    343     client_->WaitForEvents();
    344   }
    345   if (timeout_us > 0) {
    346     client()->epoll_server()->set_timeout_in_us(old_timeout_us);
    347   }
    348 }
    349 
    350 void QuicTestClient::WaitForInitialResponseForMs(int timeout_ms) {
    351   int64 timeout_us = timeout_ms * base::Time::kMicrosecondsPerMillisecond;
    352   int64 old_timeout_us = client()->epoll_server()->timeout_in_us();
    353   if (timeout_us > 0) {
    354     client()->epoll_server()->set_timeout_in_us(timeout_us);
    355   }
    356   const QuicClock* clock =
    357       QuicConnectionPeer::GetHelper(client()->session()->connection())->
    358           GetClock();
    359   QuicTime end_waiting_time = clock->Now().Add(
    360       QuicTime::Delta::FromMicroseconds(timeout_us));
    361   while (stream_ != NULL &&
    362          !client_->session()->IsClosedStream(stream_->id()) &&
    363          stream_->stream_bytes_read() == 0 &&
    364          (timeout_us < 0 || clock->Now() < end_waiting_time)) {
    365     client_->WaitForEvents();
    366   }
    367   if (timeout_us > 0) {
    368     client()->epoll_server()->set_timeout_in_us(old_timeout_us);
    369   }
    370 }
    371 
    372 ssize_t QuicTestClient::Send(const void *buffer, size_t size) {
    373   return SendData(string(static_cast<const char*>(buffer), size), false);
    374 }
    375 
    376 bool QuicTestClient::response_headers_complete() const {
    377   if (stream_ != NULL) {
    378     return stream_->headers_decompressed();
    379   } else {
    380     return response_headers_complete_;
    381   }
    382 }
    383 
    384 const BalsaHeaders* QuicTestClient::response_headers() const {
    385   if (stream_ != NULL) {
    386     return &stream_->headers();
    387   } else {
    388     return &headers_;
    389   }
    390 }
    391 
    392 int QuicTestClient::response_size() const {
    393   return bytes_read_;
    394 }
    395 
    396 size_t QuicTestClient::bytes_read() const {
    397   return bytes_read_;
    398 }
    399 
    400 size_t QuicTestClient::bytes_written() const {
    401   return bytes_written_;
    402 }
    403 
    404 void QuicTestClient::OnClose(QuicDataStream* stream) {
    405   if (stream_ != stream) {
    406     return;
    407   }
    408   if (buffer_body()) {
    409     // TODO(fnk): The stream still buffers the whole thing. Fix that.
    410     response_ = stream_->data();
    411   }
    412   response_complete_ = true;
    413   response_headers_complete_ = stream_->headers_decompressed();
    414   headers_.CopyFrom(stream_->headers());
    415   stream_error_ = stream_->stream_error();
    416   bytes_read_ = stream_->stream_bytes_read();
    417   bytes_written_ = stream_->stream_bytes_written();
    418   response_header_size_ = headers_.GetSizeForWriteBuffer();
    419   response_body_size_ = stream_->data().size();
    420   stream_ = NULL;
    421 }
    422 
    423 void QuicTestClient::UseWriter(QuicTestWriter* writer) {
    424   reinterpret_cast<QuicEpollClient*>(client_.get())->UseWriter(writer);
    425 }
    426 
    427 void QuicTestClient::UseGuid(QuicGuid guid) {
    428   DCHECK(!connected());
    429   reinterpret_cast<QuicEpollClient*>(client_.get())->UseGuid(guid);
    430 }
    431 
    432 void QuicTestClient::WaitForWriteToFlush() {
    433   while (connected() && client()->session()->HasQueuedData()) {
    434     client_->WaitForEvents();
    435   }
    436 }
    437 
    438 }  // namespace test
    439 }  // namespace tools
    440 }  // namespace net
    441