1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/tools/quic/test_tools/quic_test_client.h" 6 7 #include "base/time/time.h" 8 #include "net/base/completion_callback.h" 9 #include "net/base/net_errors.h" 10 #include "net/cert/cert_verify_result.h" 11 #include "net/cert/x509_certificate.h" 12 #include "net/quic/crypto/proof_verifier.h" 13 #include "net/quic/test_tools/quic_connection_peer.h" 14 #include "net/tools/balsa/balsa_headers.h" 15 #include "net/tools/quic/quic_epoll_connection_helper.h" 16 #include "net/tools/quic/quic_spdy_client_stream.h" 17 #include "net/tools/quic/test_tools/http_message_test_utils.h" 18 #include "url/gurl.h" 19 20 using base::StringPiece; 21 using net::test::QuicConnectionPeer; 22 using net::test::QuicTestWriter; 23 using std::string; 24 using std::vector; 25 26 namespace { 27 28 // RecordingProofVerifier accepts any certificate chain and records the common 29 // name of the leaf. 30 class RecordingProofVerifier : public net::ProofVerifier { 31 public: 32 // ProofVerifier interface. 33 virtual net::ProofVerifier::Status VerifyProof( 34 const string& hostname, 35 const string& server_config, 36 const vector<string>& certs, 37 const string& signature, 38 string* error_details, 39 scoped_ptr<net::ProofVerifyDetails>* details, 40 net::ProofVerifierCallback* callback) OVERRIDE { 41 delete callback; 42 43 common_name_.clear(); 44 if (certs.empty()) { 45 return FAILURE; 46 } 47 48 // Convert certs to X509Certificate. 49 vector<StringPiece> cert_pieces(certs.size()); 50 for (unsigned i = 0; i < certs.size(); i++) { 51 cert_pieces[i] = StringPiece(certs[i]); 52 } 53 scoped_refptr<net::X509Certificate> cert = 54 net::X509Certificate::CreateFromDERCertChain(cert_pieces); 55 if (!cert.get()) { 56 return FAILURE; 57 } 58 59 common_name_ = cert->subject().GetDisplayName(); 60 return SUCCESS; 61 } 62 63 const string& common_name() const { return common_name_; } 64 65 private: 66 string common_name_; 67 }; 68 69 } // anonymous namespace 70 71 namespace net { 72 namespace tools { 73 namespace test { 74 75 BalsaHeaders* MungeHeaders(const BalsaHeaders* const_headers, 76 bool secure) { 77 StringPiece uri = const_headers->request_uri(); 78 if (uri.empty()) { 79 return NULL; 80 } 81 if (const_headers->request_method() == "CONNECT") { 82 return NULL; 83 } 84 BalsaHeaders* headers = new BalsaHeaders; 85 headers->CopyFrom(*const_headers); 86 if (!uri.starts_with("https://") && 87 !uri.starts_with("http://")) { 88 // If we have a relative URL, set some defaults. 89 string full_uri = secure ? "https://www.google.com" : 90 "http://www.google.com"; 91 full_uri.append(uri.as_string()); 92 headers->SetRequestUri(full_uri); 93 } 94 return headers; 95 } 96 97 // A quic client which allows mocking out writes. 98 class QuicEpollClient : public QuicClient { 99 public: 100 typedef QuicClient Super; 101 102 QuicEpollClient(IPEndPoint server_address, 103 const string& server_hostname, 104 const QuicVersionVector& supported_versions) 105 : Super(server_address, server_hostname, supported_versions, false), 106 override_guid_(0), test_writer_(NULL) { 107 } 108 109 QuicEpollClient(IPEndPoint server_address, 110 const string& server_hostname, 111 const QuicConfig& config, 112 const QuicVersionVector& supported_versions) 113 : Super(server_address, server_hostname, config, supported_versions), 114 override_guid_(0), test_writer_(NULL) { 115 } 116 117 virtual ~QuicEpollClient() { 118 if (connected()) { 119 Disconnect(); 120 } 121 } 122 123 virtual QuicPacketWriter* CreateQuicPacketWriter() OVERRIDE { 124 QuicPacketWriter* writer = Super::CreateQuicPacketWriter(); 125 if (!test_writer_) { 126 return writer; 127 } 128 test_writer_->set_writer(writer); 129 return test_writer_; 130 } 131 132 virtual QuicGuid GenerateGuid() OVERRIDE { 133 return override_guid_ ? override_guid_ : Super::GenerateGuid(); 134 } 135 136 // Takes ownership of writer. 137 void UseWriter(QuicTestWriter* writer) { test_writer_ = writer; } 138 139 void UseGuid(QuicGuid guid) { 140 override_guid_ = guid; 141 } 142 143 private: 144 QuicGuid override_guid_; // GUID to use, if nonzero 145 QuicTestWriter* test_writer_; 146 }; 147 148 QuicTestClient::QuicTestClient(IPEndPoint address, const string& hostname, 149 const QuicVersionVector& supported_versions) 150 : client_(new QuicEpollClient(address, hostname, supported_versions)) { 151 Initialize(address, hostname, true); 152 } 153 154 QuicTestClient::QuicTestClient(IPEndPoint address, 155 const string& hostname, 156 bool secure, 157 const QuicVersionVector& supported_versions) 158 : client_(new QuicEpollClient(address, hostname, supported_versions)) { 159 Initialize(address, hostname, secure); 160 } 161 162 QuicTestClient::QuicTestClient(IPEndPoint address, 163 const string& hostname, 164 bool secure, 165 const QuicConfig& config, 166 const QuicVersionVector& supported_versions) 167 : client_(new QuicEpollClient(address, hostname, config, 168 supported_versions)) { 169 Initialize(address, hostname, secure); 170 } 171 172 void QuicTestClient::Initialize(IPEndPoint address, 173 const string& hostname, 174 bool secure) { 175 server_address_ = address; 176 priority_ = 3; 177 connect_attempted_ = false; 178 secure_ = secure; 179 auto_reconnect_ = false; 180 buffer_body_ = true; 181 proof_verifier_ = NULL; 182 ClearPerRequestState(); 183 ExpectCertificates(secure_); 184 } 185 186 QuicTestClient::~QuicTestClient() { 187 if (stream_) { 188 stream_->set_visitor(NULL); 189 } 190 } 191 192 void QuicTestClient::ExpectCertificates(bool on) { 193 if (on) { 194 proof_verifier_ = new RecordingProofVerifier; 195 client_->SetProofVerifier(proof_verifier_); 196 } else { 197 proof_verifier_ = NULL; 198 client_->SetProofVerifier(NULL); 199 } 200 } 201 202 ssize_t QuicTestClient::SendRequest(const string& uri) { 203 HTTPMessage message(HttpConstants::HTTP_1_1, HttpConstants::GET, uri); 204 return SendMessage(message); 205 } 206 207 ssize_t QuicTestClient::SendMessage(const HTTPMessage& message) { 208 stream_ = NULL; // Always force creation of a stream for SendMessage. 209 210 // If we're not connected, try to find an sni hostname. 211 if (!connected()) { 212 GURL url(message.headers()->request_uri().as_string()); 213 if (!url.host().empty()) { 214 client_->set_server_hostname(url.host()); 215 } 216 } 217 218 QuicSpdyClientStream* stream = GetOrCreateStream(); 219 if (!stream) { return 0; } 220 221 scoped_ptr<BalsaHeaders> munged_headers(MungeHeaders(message.headers(), 222 secure_)); 223 ssize_t ret = GetOrCreateStream()->SendRequest( 224 munged_headers.get() ? *munged_headers.get() : *message.headers(), 225 message.body(), 226 message.has_complete_message()); 227 WaitForWriteToFlush(); 228 return ret; 229 } 230 231 ssize_t QuicTestClient::SendData(string data, bool last_data) { 232 QuicSpdyClientStream* stream = GetOrCreateStream(); 233 if (!stream) { return 0; } 234 GetOrCreateStream()->SendBody(data, last_data); 235 WaitForWriteToFlush(); 236 return data.length(); 237 } 238 239 string QuicTestClient::SendCustomSynchronousRequest( 240 const HTTPMessage& message) { 241 SendMessage(message); 242 WaitForResponse(); 243 return response_; 244 } 245 246 string QuicTestClient::SendSynchronousRequest(const string& uri) { 247 if (SendRequest(uri) == 0) { 248 DLOG(ERROR) << "Failed the request for uri:" << uri; 249 return ""; 250 } 251 WaitForResponse(); 252 return response_; 253 } 254 255 QuicSpdyClientStream* QuicTestClient::GetOrCreateStream() { 256 if (!connect_attempted_ || auto_reconnect_) { 257 if (!connected()) { 258 Connect(); 259 } 260 if (!connected()) { 261 return NULL; 262 } 263 } 264 if (!stream_) { 265 stream_ = client_->CreateReliableClientStream(); 266 if (stream_ == NULL) { 267 return NULL; 268 } 269 stream_->set_visitor(this); 270 reinterpret_cast<QuicSpdyClientStream*>(stream_)->set_priority(priority_); 271 } 272 273 return stream_; 274 } 275 276 const string& QuicTestClient::cert_common_name() const { 277 return reinterpret_cast<RecordingProofVerifier*>(proof_verifier_) 278 ->common_name(); 279 } 280 281 bool QuicTestClient::connected() const { 282 return client_->connected(); 283 } 284 285 void QuicTestClient::WaitForResponse() { 286 if (stream_ == NULL) { 287 // The client has likely disconnected. 288 return; 289 } 290 client_->WaitForStreamToClose(stream_->id()); 291 } 292 293 void QuicTestClient::Connect() { 294 DCHECK(!connected()); 295 if (!connect_attempted_) { 296 client_->Initialize(); 297 } 298 client_->Connect(); 299 connect_attempted_ = true; 300 } 301 302 void QuicTestClient::ResetConnection() { 303 Disconnect(); 304 Connect(); 305 } 306 307 void QuicTestClient::Disconnect() { 308 client_->Disconnect(); 309 connect_attempted_ = false; 310 } 311 312 IPEndPoint QuicTestClient::LocalSocketAddress() const { 313 return client_->client_address(); 314 } 315 316 void QuicTestClient::ClearPerRequestState() { 317 stream_error_ = QUIC_STREAM_NO_ERROR; 318 stream_ = NULL; 319 response_ = ""; 320 response_complete_ = false; 321 response_headers_complete_ = false; 322 headers_.Clear(); 323 bytes_read_ = 0; 324 bytes_written_ = 0; 325 response_header_size_ = 0; 326 response_body_size_ = 0; 327 } 328 329 void QuicTestClient::WaitForResponseForMs(int timeout_ms) { 330 int64 timeout_us = timeout_ms * base::Time::kMicrosecondsPerMillisecond; 331 int64 old_timeout_us = client()->epoll_server()->timeout_in_us(); 332 if (timeout_us > 0) { 333 client()->epoll_server()->set_timeout_in_us(timeout_us); 334 } 335 const QuicClock* clock = 336 QuicConnectionPeer::GetHelper(client()->session()->connection())-> 337 GetClock(); 338 QuicTime end_waiting_time = clock->Now().Add( 339 QuicTime::Delta::FromMicroseconds(timeout_us)); 340 while (stream_ != NULL && 341 !client_->session()->IsClosedStream(stream_->id()) && 342 (timeout_us < 0 || clock->Now() < end_waiting_time)) { 343 client_->WaitForEvents(); 344 } 345 if (timeout_us > 0) { 346 client()->epoll_server()->set_timeout_in_us(old_timeout_us); 347 } 348 } 349 350 void QuicTestClient::WaitForInitialResponseForMs(int timeout_ms) { 351 int64 timeout_us = timeout_ms * base::Time::kMicrosecondsPerMillisecond; 352 int64 old_timeout_us = client()->epoll_server()->timeout_in_us(); 353 if (timeout_us > 0) { 354 client()->epoll_server()->set_timeout_in_us(timeout_us); 355 } 356 const QuicClock* clock = 357 QuicConnectionPeer::GetHelper(client()->session()->connection())-> 358 GetClock(); 359 QuicTime end_waiting_time = clock->Now().Add( 360 QuicTime::Delta::FromMicroseconds(timeout_us)); 361 while (stream_ != NULL && 362 !client_->session()->IsClosedStream(stream_->id()) && 363 stream_->stream_bytes_read() == 0 && 364 (timeout_us < 0 || clock->Now() < end_waiting_time)) { 365 client_->WaitForEvents(); 366 } 367 if (timeout_us > 0) { 368 client()->epoll_server()->set_timeout_in_us(old_timeout_us); 369 } 370 } 371 372 ssize_t QuicTestClient::Send(const void *buffer, size_t size) { 373 return SendData(string(static_cast<const char*>(buffer), size), false); 374 } 375 376 bool QuicTestClient::response_headers_complete() const { 377 if (stream_ != NULL) { 378 return stream_->headers_decompressed(); 379 } else { 380 return response_headers_complete_; 381 } 382 } 383 384 const BalsaHeaders* QuicTestClient::response_headers() const { 385 if (stream_ != NULL) { 386 return &stream_->headers(); 387 } else { 388 return &headers_; 389 } 390 } 391 392 int QuicTestClient::response_size() const { 393 return bytes_read_; 394 } 395 396 size_t QuicTestClient::bytes_read() const { 397 return bytes_read_; 398 } 399 400 size_t QuicTestClient::bytes_written() const { 401 return bytes_written_; 402 } 403 404 void QuicTestClient::OnClose(QuicDataStream* stream) { 405 if (stream_ != stream) { 406 return; 407 } 408 if (buffer_body()) { 409 // TODO(fnk): The stream still buffers the whole thing. Fix that. 410 response_ = stream_->data(); 411 } 412 response_complete_ = true; 413 response_headers_complete_ = stream_->headers_decompressed(); 414 headers_.CopyFrom(stream_->headers()); 415 stream_error_ = stream_->stream_error(); 416 bytes_read_ = stream_->stream_bytes_read(); 417 bytes_written_ = stream_->stream_bytes_written(); 418 response_header_size_ = headers_.GetSizeForWriteBuffer(); 419 response_body_size_ = stream_->data().size(); 420 stream_ = NULL; 421 } 422 423 void QuicTestClient::UseWriter(QuicTestWriter* writer) { 424 reinterpret_cast<QuicEpollClient*>(client_.get())->UseWriter(writer); 425 } 426 427 void QuicTestClient::UseGuid(QuicGuid guid) { 428 DCHECK(!connected()); 429 reinterpret_cast<QuicEpollClient*>(client_.get())->UseGuid(guid); 430 } 431 432 void QuicTestClient::WaitForWriteToFlush() { 433 while (connected() && client()->session()->HasQueuedData()) { 434 client_->WaitForEvents(); 435 } 436 } 437 438 } // namespace test 439 } // namespace tools 440 } // namespace net 441