1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/quic/test_tools/crypto_test_utils.h" 6 7 #include "net/quic/crypto/channel_id.h" 8 #include "net/quic/crypto/common_cert_set.h" 9 #include "net/quic/crypto/crypto_handshake.h" 10 #include "net/quic/crypto/crypto_server_config.h" 11 #include "net/quic/crypto/quic_decrypter.h" 12 #include "net/quic/crypto/quic_encrypter.h" 13 #include "net/quic/crypto/quic_random.h" 14 #include "net/quic/quic_clock.h" 15 #include "net/quic/quic_crypto_client_stream.h" 16 #include "net/quic/quic_crypto_server_stream.h" 17 #include "net/quic/quic_crypto_stream.h" 18 #include "net/quic/test_tools/quic_connection_peer.h" 19 #include "net/quic/test_tools/quic_test_utils.h" 20 #include "net/quic/test_tools/simple_quic_framer.h" 21 22 using base::StringPiece; 23 using std::string; 24 using std::vector; 25 26 namespace net { 27 namespace test { 28 29 namespace { 30 31 // CryptoFramerVisitor is a framer visitor that records handshake messages. 32 class CryptoFramerVisitor : public CryptoFramerVisitorInterface { 33 public: 34 CryptoFramerVisitor() 35 : error_(false) { 36 } 37 38 virtual void OnError(CryptoFramer* framer) OVERRIDE { 39 error_ = true; 40 } 41 42 virtual void OnHandshakeMessage( 43 const CryptoHandshakeMessage& message) OVERRIDE { 44 messages_.push_back(message); 45 } 46 47 bool error() const { 48 return error_; 49 } 50 51 const vector<CryptoHandshakeMessage>& messages() const { 52 return messages_; 53 } 54 55 private: 56 bool error_; 57 vector<CryptoHandshakeMessage> messages_; 58 }; 59 60 // MovePackets parses crypto handshake messages from packet number 61 // |*inout_packet_index| through to the last packet and has |dest_stream| 62 // process them. |*inout_packet_index| is updated with an index one greater 63 // than the last packet processed. 64 void MovePackets(PacketSavingConnection* source_conn, 65 size_t *inout_packet_index, 66 QuicCryptoStream* dest_stream, 67 PacketSavingConnection* dest_conn) { 68 SimpleQuicFramer framer; 69 CryptoFramer crypto_framer; 70 CryptoFramerVisitor crypto_visitor; 71 72 // In order to properly test the code we need to perform encryption and 73 // decryption so that the crypters latch when expected. The crypters are in 74 // |dest_conn|, but we don't want to try and use them there. Instead we swap 75 // them into |framer|, perform the decryption with them, and then swap them 76 // back. 77 QuicConnectionPeer::SwapCrypters(dest_conn, framer.framer()); 78 79 crypto_framer.set_visitor(&crypto_visitor); 80 81 size_t index = *inout_packet_index; 82 for (; index < source_conn->encrypted_packets_.size(); index++) { 83 ASSERT_TRUE(framer.ProcessPacket(*source_conn->encrypted_packets_[index])); 84 for (vector<QuicStreamFrame>::const_iterator 85 i = framer.stream_frames().begin(); 86 i != framer.stream_frames().end(); ++i) { 87 ASSERT_TRUE(crypto_framer.ProcessInput(i->data)); 88 ASSERT_FALSE(crypto_visitor.error()); 89 } 90 } 91 *inout_packet_index = index; 92 93 QuicConnectionPeer::SwapCrypters(dest_conn, framer.framer()); 94 95 ASSERT_EQ(0u, crypto_framer.InputBytesRemaining()); 96 97 for (vector<CryptoHandshakeMessage>::const_iterator 98 i = crypto_visitor.messages().begin(); 99 i != crypto_visitor.messages().end(); ++i) { 100 dest_stream->OnHandshakeMessage(*i); 101 } 102 } 103 104 // HexChar parses |c| as a hex character. If valid, it sets |*value| to the 105 // value of the hex character and returns true. Otherwise it returns false. 106 bool HexChar(char c, uint8* value) { 107 if (c >= '0' && c <= '9') { 108 *value = c - '0'; 109 return true; 110 } 111 if (c >= 'a' && c <= 'f') { 112 *value = c - 'a' + 10; 113 return true; 114 } 115 if (c >= 'A' && c <= 'F') { 116 *value = c - 'A' + 10; 117 return true; 118 } 119 return false; 120 } 121 122 } // anonymous namespace 123 124 CryptoTestUtils::FakeClientOptions::FakeClientOptions() 125 : dont_verify_certs(false), 126 channel_id_enabled(false) { 127 } 128 129 // static 130 int CryptoTestUtils::HandshakeWithFakeServer( 131 PacketSavingConnection* client_conn, 132 QuicCryptoClientStream* client) { 133 QuicGuid guid(1); 134 IPAddressNumber ip; 135 CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); 136 IPEndPoint addr = IPEndPoint(ip, 1); 137 PacketSavingConnection* server_conn = 138 new PacketSavingConnection(guid, addr, true); 139 TestSession server_session(server_conn, QuicConfig(), true); 140 141 QuicCryptoServerConfig crypto_config(QuicCryptoServerConfig::TESTING, 142 QuicRandom::GetInstance()); 143 SetupCryptoServerConfigForTest( 144 server_session.connection()->clock(), 145 server_session.connection()->random_generator(), 146 server_session.config(), &crypto_config); 147 148 QuicCryptoServerStream server(crypto_config, &server_session); 149 server_session.SetCryptoStream(&server); 150 151 // The client's handshake must have been started already. 152 CHECK_NE(0u, client_conn->packets_.size()); 153 154 CommunicateHandshakeMessages(client_conn, client, server_conn, &server); 155 156 CompareClientAndServerKeys(client, &server); 157 158 return client->num_sent_client_hellos(); 159 } 160 161 // static 162 int CryptoTestUtils::HandshakeWithFakeClient( 163 PacketSavingConnection* server_conn, 164 QuicCryptoServerStream* server, 165 const FakeClientOptions& options) { 166 QuicGuid guid(1); 167 IPAddressNumber ip; 168 CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); 169 IPEndPoint addr = IPEndPoint(ip, 1); 170 PacketSavingConnection* client_conn = 171 new PacketSavingConnection(guid, addr, false); 172 TestSession client_session(client_conn, QuicConfig(), false); 173 QuicCryptoClientConfig crypto_config; 174 175 client_session.config()->SetDefaults(); 176 crypto_config.SetDefaults(); 177 // TODO(rtenneti): Enable testing of ProofVerifier. 178 // if (!options.dont_verify_certs) { 179 // crypto_config.SetProofVerifier(ProofVerifierForTesting()); 180 // } 181 if (options.channel_id_enabled) { 182 crypto_config.SetChannelIDSigner(ChannelIDSignerForTesting()); 183 } 184 QuicCryptoClientStream client("test.example.com", &client_session, 185 &crypto_config); 186 client_session.SetCryptoStream(&client); 187 188 CHECK(client.CryptoConnect()); 189 CHECK_EQ(1u, client_conn->packets_.size()); 190 191 CommunicateHandshakeMessages(client_conn, &client, server_conn, server); 192 193 CompareClientAndServerKeys(&client, server); 194 195 if (options.channel_id_enabled) { 196 EXPECT_EQ(crypto_config.channel_id_signer()->GetKeyForHostname( 197 "test.example.com"), 198 server->crypto_negotiated_params().channel_id); 199 } 200 201 return client.num_sent_client_hellos(); 202 } 203 204 // static 205 void CryptoTestUtils::SetupCryptoServerConfigForTest( 206 const QuicClock* clock, 207 QuicRandom* rand, 208 QuicConfig* config, 209 QuicCryptoServerConfig* crypto_config) { 210 config->SetDefaults(); 211 QuicCryptoServerConfig::ConfigOptions options; 212 options.channel_id_enabled = true; 213 scoped_ptr<CryptoHandshakeMessage> scfg( 214 crypto_config->AddDefaultConfig(rand, clock, options)); 215 } 216 217 // static 218 void CryptoTestUtils::CommunicateHandshakeMessages( 219 PacketSavingConnection* a_conn, 220 QuicCryptoStream* a, 221 PacketSavingConnection* b_conn, 222 QuicCryptoStream* b) { 223 size_t a_i = 0, b_i = 0; 224 while (!a->handshake_confirmed()) { 225 ASSERT_GT(a_conn->packets_.size(), a_i); 226 LOG(INFO) << "Processing " << a_conn->packets_.size() - a_i 227 << " packets a->b"; 228 MovePackets(a_conn, &a_i, b, b_conn); 229 230 ASSERT_GT(b_conn->packets_.size(), b_i); 231 LOG(INFO) << "Processing " << b_conn->packets_.size() - b_i 232 << " packets b->a"; 233 if (b_conn->packets_.size() - b_i == 2) { 234 LOG(INFO) << "here"; 235 } 236 MovePackets(b_conn, &b_i, a, a_conn); 237 } 238 } 239 240 // static 241 string CryptoTestUtils::GetValueForTag(const CryptoHandshakeMessage& message, 242 QuicTag tag) { 243 QuicTagValueMap::const_iterator it = message.tag_value_map().find(tag); 244 if (it == message.tag_value_map().end()) { 245 return string(); 246 } 247 return it->second; 248 } 249 250 class MockCommonCertSets : public CommonCertSets { 251 public: 252 MockCommonCertSets(StringPiece cert, uint64 hash, uint32 index) 253 : cert_(cert.as_string()), 254 hash_(hash), 255 index_(index) { 256 } 257 258 virtual StringPiece GetCommonHashes() const OVERRIDE { 259 CHECK(false) << "not implemented"; 260 return StringPiece(); 261 } 262 263 virtual StringPiece GetCert(uint64 hash, uint32 index) const OVERRIDE { 264 if (hash == hash_ && index == index_) { 265 return cert_; 266 } 267 return StringPiece(); 268 } 269 270 virtual bool MatchCert(StringPiece cert, 271 StringPiece common_set_hashes, 272 uint64* out_hash, 273 uint32* out_index) const OVERRIDE { 274 if (cert != cert_) { 275 return false; 276 } 277 278 if (common_set_hashes.size() % sizeof(uint64) != 0) { 279 return false; 280 } 281 bool client_has_set = false; 282 for (size_t i = 0; i < common_set_hashes.size(); i += sizeof(uint64)) { 283 uint64 hash; 284 memcpy(&hash, common_set_hashes.data() + i, sizeof(hash)); 285 if (hash == hash_) { 286 client_has_set = true; 287 break; 288 } 289 } 290 291 if (!client_has_set) { 292 return false; 293 } 294 295 *out_hash = hash_; 296 *out_index = index_; 297 return true; 298 } 299 300 private: 301 const string cert_; 302 const uint64 hash_; 303 const uint32 index_; 304 }; 305 306 CommonCertSets* CryptoTestUtils::MockCommonCertSets(StringPiece cert, 307 uint64 hash, 308 uint32 index) { 309 return new class MockCommonCertSets(cert, hash, index); 310 } 311 312 void CryptoTestUtils::CompareClientAndServerKeys( 313 QuicCryptoClientStream* client, 314 QuicCryptoServerStream* server) { 315 const QuicEncrypter* client_encrypter( 316 client->session()->connection()->encrypter(ENCRYPTION_INITIAL)); 317 const QuicDecrypter* client_decrypter( 318 client->session()->connection()->decrypter()); 319 const QuicEncrypter* client_forward_secure_encrypter( 320 client->session()->connection()->encrypter(ENCRYPTION_FORWARD_SECURE)); 321 const QuicDecrypter* client_forward_secure_decrypter( 322 client->session()->connection()->alternative_decrypter()); 323 const QuicEncrypter* server_encrypter( 324 server->session()->connection()->encrypter(ENCRYPTION_INITIAL)); 325 const QuicDecrypter* server_decrypter( 326 server->session()->connection()->decrypter()); 327 const QuicEncrypter* server_forward_secure_encrypter( 328 server->session()->connection()->encrypter(ENCRYPTION_FORWARD_SECURE)); 329 const QuicDecrypter* server_forward_secure_decrypter( 330 server->session()->connection()->alternative_decrypter()); 331 332 StringPiece client_encrypter_key = client_encrypter->GetKey(); 333 StringPiece client_encrypter_iv = client_encrypter->GetNoncePrefix(); 334 StringPiece client_decrypter_key = client_decrypter->GetKey(); 335 StringPiece client_decrypter_iv = client_decrypter->GetNoncePrefix(); 336 StringPiece client_forward_secure_encrypter_key = 337 client_forward_secure_encrypter->GetKey(); 338 StringPiece client_forward_secure_encrypter_iv = 339 client_forward_secure_encrypter->GetNoncePrefix(); 340 StringPiece client_forward_secure_decrypter_key = 341 client_forward_secure_decrypter->GetKey(); 342 StringPiece client_forward_secure_decrypter_iv = 343 client_forward_secure_decrypter->GetNoncePrefix(); 344 StringPiece server_encrypter_key = server_encrypter->GetKey(); 345 StringPiece server_encrypter_iv = server_encrypter->GetNoncePrefix(); 346 StringPiece server_decrypter_key = server_decrypter->GetKey(); 347 StringPiece server_decrypter_iv = server_decrypter->GetNoncePrefix(); 348 StringPiece server_forward_secure_encrypter_key = 349 server_forward_secure_encrypter->GetKey(); 350 StringPiece server_forward_secure_encrypter_iv = 351 server_forward_secure_encrypter->GetNoncePrefix(); 352 StringPiece server_forward_secure_decrypter_key = 353 server_forward_secure_decrypter->GetKey(); 354 StringPiece server_forward_secure_decrypter_iv = 355 server_forward_secure_decrypter->GetNoncePrefix(); 356 357 CompareCharArraysWithHexError("client write key", 358 client_encrypter_key.data(), 359 client_encrypter_key.length(), 360 server_decrypter_key.data(), 361 server_decrypter_key.length()); 362 CompareCharArraysWithHexError("client write IV", 363 client_encrypter_iv.data(), 364 client_encrypter_iv.length(), 365 server_decrypter_iv.data(), 366 server_decrypter_iv.length()); 367 CompareCharArraysWithHexError("server write key", 368 server_encrypter_key.data(), 369 server_encrypter_key.length(), 370 client_decrypter_key.data(), 371 client_decrypter_key.length()); 372 CompareCharArraysWithHexError("server write IV", 373 server_encrypter_iv.data(), 374 server_encrypter_iv.length(), 375 client_decrypter_iv.data(), 376 client_decrypter_iv.length()); 377 CompareCharArraysWithHexError("client forward secure write key", 378 client_forward_secure_encrypter_key.data(), 379 client_forward_secure_encrypter_key.length(), 380 server_forward_secure_decrypter_key.data(), 381 server_forward_secure_decrypter_key.length()); 382 CompareCharArraysWithHexError("client forward secure write IV", 383 client_forward_secure_encrypter_iv.data(), 384 client_forward_secure_encrypter_iv.length(), 385 server_forward_secure_decrypter_iv.data(), 386 server_forward_secure_decrypter_iv.length()); 387 CompareCharArraysWithHexError("server forward secure write key", 388 server_forward_secure_encrypter_key.data(), 389 server_forward_secure_encrypter_key.length(), 390 client_forward_secure_decrypter_key.data(), 391 client_forward_secure_decrypter_key.length()); 392 CompareCharArraysWithHexError("server forward secure write IV", 393 server_forward_secure_encrypter_iv.data(), 394 server_forward_secure_encrypter_iv.length(), 395 client_forward_secure_decrypter_iv.data(), 396 client_forward_secure_decrypter_iv.length()); 397 } 398 399 // static 400 QuicTag CryptoTestUtils::ParseTag(const char* tagstr) { 401 const size_t len = strlen(tagstr); 402 CHECK_NE(0u, len); 403 404 QuicTag tag = 0; 405 406 if (tagstr[0] == '#') { 407 CHECK_EQ(static_cast<size_t>(1 + 2*4), len); 408 tagstr++; 409 410 for (size_t i = 0; i < 8; i++) { 411 tag <<= 4; 412 413 uint8 v = 0; 414 CHECK(HexChar(tagstr[i], &v)); 415 tag |= v; 416 } 417 418 return tag; 419 } 420 421 CHECK_LE(len, 4u); 422 for (size_t i = 0; i < 4; i++) { 423 tag >>= 8; 424 if (i < len) { 425 tag |= static_cast<uint32>(tagstr[i]) << 24; 426 } 427 } 428 429 return tag; 430 } 431 432 // static 433 CryptoHandshakeMessage CryptoTestUtils::Message(const char* message_tag, ...) { 434 va_list ap; 435 va_start(ap, message_tag); 436 437 CryptoHandshakeMessage message = BuildMessage(message_tag, ap); 438 va_end(ap); 439 return message; 440 } 441 442 // static 443 CryptoHandshakeMessage CryptoTestUtils::BuildMessage(const char* message_tag, 444 va_list ap) { 445 CryptoHandshakeMessage msg; 446 msg.set_tag(ParseTag(message_tag)); 447 448 for (;;) { 449 const char* tagstr = va_arg(ap, const char*); 450 if (tagstr == NULL) { 451 break; 452 } 453 454 if (tagstr[0] == '$') { 455 // Special value. 456 const char* const special = tagstr + 1; 457 if (strcmp(special, "padding") == 0) { 458 const int min_bytes = va_arg(ap, int); 459 msg.set_minimum_size(min_bytes); 460 } else { 461 CHECK(false) << "Unknown special value: " << special; 462 } 463 464 continue; 465 } 466 467 const QuicTag tag = ParseTag(tagstr); 468 const char* valuestr = va_arg(ap, const char*); 469 470 size_t len = strlen(valuestr); 471 if (len > 0 && valuestr[0] == '#') { 472 valuestr++; 473 len--; 474 475 CHECK(len % 2 == 0); 476 scoped_ptr<uint8[]> buf(new uint8[len/2]); 477 478 for (size_t i = 0; i < len/2; i++) { 479 uint8 v = 0; 480 CHECK(HexChar(valuestr[i*2], &v)); 481 buf[i] = v << 4; 482 CHECK(HexChar(valuestr[i*2 + 1], &v)); 483 buf[i] |= v; 484 } 485 486 msg.SetStringPiece( 487 tag, StringPiece(reinterpret_cast<char*>(buf.get()), len/2)); 488 continue; 489 } 490 491 msg.SetStringPiece(tag, valuestr); 492 } 493 494 // The CryptoHandshakeMessage needs to be serialized and parsed to ensure 495 // that any padding is included. 496 scoped_ptr<QuicData> bytes(CryptoFramer::ConstructHandshakeMessage(msg)); 497 scoped_ptr<CryptoHandshakeMessage> parsed( 498 CryptoFramer::ParseMessage(bytes->AsStringPiece())); 499 CHECK(parsed.get()); 500 501 return *parsed; 502 } 503 504 } // namespace test 505 } // namespace net 506