1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #define LOG_TAG "dns_tls_test" 18 19 #include <gtest/gtest.h> 20 21 #include "dns/DnsTlsDispatcher.h" 22 #include "dns/DnsTlsQueryMap.h" 23 #include "dns/DnsTlsServer.h" 24 #include "dns/DnsTlsSessionCache.h" 25 #include "dns/DnsTlsSocket.h" 26 #include "dns/DnsTlsTransport.h" 27 #include "dns/IDnsTlsSocket.h" 28 #include "dns/IDnsTlsSocketFactory.h" 29 #include "dns/IDnsTlsSocketObserver.h" 30 31 #include <chrono> 32 #include <arpa/inet.h> 33 #include <android-base/macros.h> 34 #include <netdutils/Slice.h> 35 36 #include "log/log.h" 37 38 namespace android { 39 namespace net { 40 41 using netdutils::Slice; 42 using netdutils::makeSlice; 43 44 typedef std::vector<uint8_t> bytevec; 45 46 static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) { 47 sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed); 48 if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) { 49 // IPv4 parse succeeded, so it's IPv4 50 sin->sin_family = AF_INET; 51 sin->sin_port = htons(port); 52 return; 53 } 54 sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed); 55 if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){ 56 // IPv6 parse succeeded, so it's IPv6. 57 sin6->sin6_family = AF_INET6; 58 sin6->sin6_port = htons(port); 59 return; 60 } 61 ALOGE("Failed to parse server address: %s", server); 62 } 63 64 bytevec FINGERPRINT1 = { 1 }; 65 bytevec FINGERPRINT2 = { 2 }; 66 67 std::string SERVERNAME1 = "dns.example.com"; 68 std::string SERVERNAME2 = "dns.example.org"; 69 70 // BaseTest just provides constants that are useful for the tests. 71 class BaseTest : public ::testing::Test { 72 protected: 73 BaseTest() { 74 parseServer("192.0.2.1", 853, &V4ADDR1); 75 parseServer("192.0.2.2", 853, &V4ADDR2); 76 parseServer("2001:db8::1", 853, &V6ADDR1); 77 parseServer("2001:db8::2", 853, &V6ADDR2); 78 79 SERVER1 = DnsTlsServer(V4ADDR1); 80 SERVER1.fingerprints.insert(FINGERPRINT1); 81 SERVER1.name = SERVERNAME1; 82 } 83 84 sockaddr_storage V4ADDR1; 85 sockaddr_storage V4ADDR2; 86 sockaddr_storage V6ADDR1; 87 sockaddr_storage V6ADDR2; 88 89 DnsTlsServer SERVER1; 90 }; 91 92 bytevec make_query(uint16_t id, size_t size) { 93 bytevec vec(size); 94 vec[0] = id >> 8; 95 vec[1] = id; 96 // Arbitrarily fill the query body with unique data. 97 for (size_t i = 2; i < size; ++i) { 98 vec[i] = id + i; 99 } 100 return vec; 101 } 102 103 // Query constants 104 const unsigned MARK = 123; 105 const uint16_t ID = 52; 106 const uint16_t SIZE = 22; 107 const bytevec QUERY = make_query(ID, SIZE); 108 109 template <class T> 110 class FakeSocketFactory : public IDnsTlsSocketFactory { 111 public: 112 FakeSocketFactory() {} 113 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket( 114 const DnsTlsServer& server ATTRIBUTE_UNUSED, 115 unsigned mark ATTRIBUTE_UNUSED, 116 IDnsTlsSocketObserver* observer, 117 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override { 118 return std::make_unique<T>(observer); 119 } 120 }; 121 122 bytevec make_echo(uint16_t id, const Slice query) { 123 bytevec response(query.size() + 2); 124 response[0] = id >> 8; 125 response[1] = id; 126 // Echo the query as the fake response. 127 memcpy(response.data() + 2, query.base(), query.size()); 128 return response; 129 } 130 131 // Simplest possible fake server. This just echoes the query as the response. 132 class FakeSocketEcho : public IDnsTlsSocket { 133 public: 134 FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {} 135 bool query(uint16_t id, const Slice query) override { 136 // Return the response immediately (asynchronously). 137 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach(); 138 return true; 139 } 140 private: 141 IDnsTlsSocketObserver* const mObserver; 142 }; 143 144 class TransportTest : public BaseTest {}; 145 146 TEST_F(TransportTest, Query) { 147 FakeSocketFactory<FakeSocketEcho> factory; 148 DnsTlsTransport transport(SERVER1, MARK, &factory); 149 auto r = transport.query(makeSlice(QUERY)).get(); 150 151 EXPECT_EQ(DnsTlsTransport::Response::success, r.code); 152 EXPECT_EQ(QUERY, r.response); 153 } 154 155 TEST_F(TransportTest, SerialQueries_100000) { 156 FakeSocketFactory<FakeSocketEcho> factory; 157 DnsTlsTransport transport(SERVER1, MARK, &factory); 158 // Send more than 65536 queries serially. 159 for (int i = 0; i < 100000; ++i) { 160 auto r = transport.query(makeSlice(QUERY)).get(); 161 162 EXPECT_EQ(DnsTlsTransport::Response::success, r.code); 163 EXPECT_EQ(QUERY, r.response); 164 } 165 } 166 167 // These queries might be handled in serial or parallel as they race the 168 // responses. 169 TEST_F(TransportTest, RacingQueries_10000) { 170 FakeSocketFactory<FakeSocketEcho> factory; 171 DnsTlsTransport transport(SERVER1, MARK, &factory); 172 std::vector<std::future<DnsTlsTransport::Result>> results; 173 // Fewer than 65536 queries to avoid ID exhaustion. 174 for (int i = 0; i < 10000; ++i) { 175 results.push_back(transport.query(makeSlice(QUERY))); 176 } 177 for (auto& result : results) { 178 auto r = result.get(); 179 EXPECT_EQ(DnsTlsTransport::Response::success, r.code); 180 EXPECT_EQ(QUERY, r.response); 181 } 182 } 183 184 // A server that waits until sDelay queries are queued before responding. 185 class FakeSocketDelay : public IDnsTlsSocket { 186 public: 187 FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {} 188 ~FakeSocketDelay() { std::lock_guard<std::mutex> guard(mLock); } 189 static size_t sDelay; 190 static bool sReverse; 191 192 bool query(uint16_t id, const Slice query) override { 193 ALOGD("FakeSocketDelay got query with ID %d", int(id)); 194 std::lock_guard<std::mutex> guard(mLock); 195 // Check for duplicate IDs. 196 EXPECT_EQ(0U, mIds.count(id)); 197 mIds.insert(id); 198 199 // Store response. 200 mResponses.push_back(make_echo(id, query)); 201 202 ALOGD("Up to %zu out of %zu queries", mResponses.size(), sDelay); 203 if (mResponses.size() == sDelay) { 204 std::thread(&FakeSocketDelay::sendResponses, this).detach(); 205 } 206 return true; 207 } 208 private: 209 void sendResponses() { 210 std::lock_guard<std::mutex> guard(mLock); 211 if (sReverse) { 212 std::reverse(std::begin(mResponses), std::end(mResponses)); 213 } 214 for (auto& response : mResponses) { 215 mObserver->onResponse(response); 216 } 217 mIds.clear(); 218 mResponses.clear(); 219 } 220 221 std::mutex mLock; 222 IDnsTlsSocketObserver* const mObserver; 223 std::set<uint16_t> mIds GUARDED_BY(mLock); 224 std::vector<bytevec> mResponses GUARDED_BY(mLock); 225 }; 226 227 size_t FakeSocketDelay::sDelay; 228 bool FakeSocketDelay::sReverse; 229 230 TEST_F(TransportTest, ParallelColliding) { 231 FakeSocketDelay::sDelay = 10; 232 FakeSocketDelay::sReverse = false; 233 FakeSocketFactory<FakeSocketDelay> factory; 234 DnsTlsTransport transport(SERVER1, MARK, &factory); 235 std::vector<std::future<DnsTlsTransport::Result>> results; 236 // Fewer than 65536 queries to avoid ID exhaustion. 237 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) { 238 results.push_back(transport.query(makeSlice(QUERY))); 239 } 240 for (auto& result : results) { 241 auto r = result.get(); 242 EXPECT_EQ(DnsTlsTransport::Response::success, r.code); 243 EXPECT_EQ(QUERY, r.response); 244 } 245 } 246 247 TEST_F(TransportTest, ParallelColliding_Max) { 248 FakeSocketDelay::sDelay = 65536; 249 FakeSocketDelay::sReverse = false; 250 FakeSocketFactory<FakeSocketDelay> factory; 251 DnsTlsTransport transport(SERVER1, MARK, &factory); 252 std::vector<std::future<DnsTlsTransport::Result>> results; 253 // Exactly 65536 queries should still be possible in parallel, 254 // even if they all have the same original ID. 255 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) { 256 results.push_back(transport.query(makeSlice(QUERY))); 257 } 258 for (auto& result : results) { 259 auto r = result.get(); 260 EXPECT_EQ(DnsTlsTransport::Response::success, r.code); 261 EXPECT_EQ(QUERY, r.response); 262 } 263 } 264 265 TEST_F(TransportTest, ParallelUnique) { 266 FakeSocketDelay::sDelay = 10; 267 FakeSocketDelay::sReverse = false; 268 FakeSocketFactory<FakeSocketDelay> factory; 269 DnsTlsTransport transport(SERVER1, MARK, &factory); 270 std::vector<bytevec> queries(FakeSocketDelay::sDelay); 271 std::vector<std::future<DnsTlsTransport::Result>> results; 272 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) { 273 queries[i] = make_query(i, SIZE); 274 results.push_back(transport.query(makeSlice(queries[i]))); 275 } 276 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) { 277 auto r = results[i].get(); 278 EXPECT_EQ(DnsTlsTransport::Response::success, r.code); 279 EXPECT_EQ(queries[i], r.response); 280 } 281 } 282 283 TEST_F(TransportTest, ParallelUnique_Max) { 284 FakeSocketDelay::sDelay = 65536; 285 FakeSocketDelay::sReverse = false; 286 FakeSocketFactory<FakeSocketDelay> factory; 287 DnsTlsTransport transport(SERVER1, MARK, &factory); 288 std::vector<bytevec> queries(FakeSocketDelay::sDelay); 289 std::vector<std::future<DnsTlsTransport::Result>> results; 290 // Exactly 65536 queries should still be possible in parallel, 291 // and they should all be mapped correctly back to the original ID. 292 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) { 293 queries[i] = make_query(i, SIZE); 294 results.push_back(transport.query(makeSlice(queries[i]))); 295 } 296 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) { 297 auto r = results[i].get(); 298 EXPECT_EQ(DnsTlsTransport::Response::success, r.code); 299 EXPECT_EQ(queries[i], r.response); 300 } 301 } 302 303 TEST_F(TransportTest, IdExhaustion) { 304 // A delay of 65537 is unreachable, because the maximum number 305 // of outstanding queries is 65536. 306 FakeSocketDelay::sDelay = 65537; 307 FakeSocketDelay::sReverse = false; 308 FakeSocketFactory<FakeSocketDelay> factory; 309 DnsTlsTransport transport(SERVER1, MARK, &factory); 310 std::vector<std::future<DnsTlsTransport::Result>> results; 311 // Issue the maximum number of queries. 312 for (int i = 0; i < 65536; ++i) { 313 results.push_back(transport.query(makeSlice(QUERY))); 314 } 315 316 // The ID space is now full, so subsequent queries should fail immediately. 317 auto r = transport.query(makeSlice(QUERY)).get(); 318 EXPECT_EQ(DnsTlsTransport::Response::internal_error, r.code); 319 EXPECT_TRUE(r.response.empty()); 320 321 for (auto& result : results) { 322 // All other queries should remain outstanding. 323 EXPECT_EQ(std::future_status::timeout, 324 result.wait_for(std::chrono::duration<int>::zero())); 325 } 326 } 327 328 // Responses can come back from the server in any order. This should have no 329 // effect on Transport's observed behavior. 330 TEST_F(TransportTest, ReverseOrder) { 331 FakeSocketDelay::sDelay = 10; 332 FakeSocketDelay::sReverse = true; 333 FakeSocketFactory<FakeSocketDelay> factory; 334 DnsTlsTransport transport(SERVER1, MARK, &factory); 335 std::vector<bytevec> queries(FakeSocketDelay::sDelay); 336 std::vector<std::future<DnsTlsTransport::Result>> results; 337 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) { 338 queries[i] = make_query(i, SIZE); 339 results.push_back(transport.query(makeSlice(queries[i]))); 340 } 341 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) { 342 auto r = results[i].get(); 343 EXPECT_EQ(DnsTlsTransport::Response::success, r.code); 344 EXPECT_EQ(queries[i], r.response); 345 } 346 } 347 348 TEST_F(TransportTest, ReverseOrder_Max) { 349 FakeSocketDelay::sDelay = 65536; 350 FakeSocketDelay::sReverse = true; 351 FakeSocketFactory<FakeSocketDelay> factory; 352 DnsTlsTransport transport(SERVER1, MARK, &factory); 353 std::vector<bytevec> queries(FakeSocketDelay::sDelay); 354 std::vector<std::future<DnsTlsTransport::Result>> results; 355 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) { 356 queries[i] = make_query(i, SIZE); 357 results.push_back(transport.query(makeSlice(queries[i]))); 358 } 359 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) { 360 auto r = results[i].get(); 361 EXPECT_EQ(DnsTlsTransport::Response::success, r.code); 362 EXPECT_EQ(queries[i], r.response); 363 } 364 } 365 366 // Returning null from the factory indicates a connection failure. 367 class NullSocketFactory : public IDnsTlsSocketFactory { 368 public: 369 NullSocketFactory() {} 370 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket( 371 const DnsTlsServer& server ATTRIBUTE_UNUSED, 372 unsigned mark ATTRIBUTE_UNUSED, 373 IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED, 374 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override { 375 return nullptr; 376 } 377 }; 378 379 TEST_F(TransportTest, ConnectFail) { 380 NullSocketFactory factory; 381 DnsTlsTransport transport(SERVER1, MARK, &factory); 382 auto r = transport.query(makeSlice(QUERY)).get(); 383 384 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code); 385 EXPECT_TRUE(r.response.empty()); 386 } 387 388 // Simulate a socket that connects but then immediately receives a server 389 // close notification. 390 class FakeSocketClose : public IDnsTlsSocket { 391 public: 392 FakeSocketClose(IDnsTlsSocketObserver* observer) : 393 mCloser(&IDnsTlsSocketObserver::onClosed, observer) {} 394 ~FakeSocketClose() { mCloser.join(); } 395 bool query(uint16_t id ATTRIBUTE_UNUSED, 396 const Slice query ATTRIBUTE_UNUSED) override { 397 return true; 398 } 399 private: 400 std::thread mCloser; 401 }; 402 403 TEST_F(TransportTest, CloseRetryFail) { 404 FakeSocketFactory<FakeSocketClose> factory; 405 DnsTlsTransport transport(SERVER1, MARK, &factory); 406 auto r = transport.query(makeSlice(QUERY)).get(); 407 408 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code); 409 EXPECT_TRUE(r.response.empty()); 410 } 411 412 // Simulate a server that occasionally closes the connection and silently 413 // drops some queries. 414 class FakeSocketLimited : public IDnsTlsSocket { 415 public: 416 static int sLimit; // Number of queries to answer per socket. 417 static size_t sMaxSize; // Silently discard queries greater than this size. 418 FakeSocketLimited(IDnsTlsSocketObserver* observer) : 419 mObserver(observer), mQueries(0) {} 420 ~FakeSocketLimited() { 421 { 422 ALOGD("~FakeSocketLimited acquiring mLock"); 423 std::lock_guard<std::mutex> guard(mLock); 424 ALOGD("~FakeSocketLimited acquired mLock"); 425 for (auto& thread : mThreads) { 426 ALOGD("~FakeSocketLimited joining response thread"); 427 thread.join(); 428 ALOGD("~FakeSocketLimited joined response thread"); 429 } 430 mThreads.clear(); 431 } 432 433 if (mCloser) { 434 ALOGD("~FakeSocketLimited joining closer thread"); 435 mCloser->join(); 436 ALOGD("~FakeSocketLimited joined closer thread"); 437 } 438 } 439 bool query(uint16_t id, const Slice query) override { 440 ALOGD("FakeSocketLimited::query acquiring mLock"); 441 std::lock_guard<std::mutex> guard(mLock); 442 ALOGD("FakeSocketLimited::query acquired mLock"); 443 ++mQueries; 444 445 if (mQueries <= sLimit) { 446 ALOGD("size %zu vs. limit of %zu", query.size(), sMaxSize); 447 if (query.size() <= sMaxSize) { 448 // Return the response immediately (asynchronously). 449 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)); 450 } 451 } 452 if (mQueries == sLimit) { 453 mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this); 454 } 455 return mQueries <= sLimit; 456 } 457 private: 458 void sendClose() { 459 { 460 ALOGD("FakeSocketLimited::sendClose acquiring mLock"); 461 std::lock_guard<std::mutex> guard(mLock); 462 ALOGD("FakeSocketLimited::sendClose acquired mLock"); 463 for (auto& thread : mThreads) { 464 ALOGD("FakeSocketLimited::sendClose joining response thread"); 465 thread.join(); 466 ALOGD("FakeSocketLimited::sendClose joined response thread"); 467 } 468 mThreads.clear(); 469 } 470 mObserver->onClosed(); 471 } 472 std::mutex mLock; 473 IDnsTlsSocketObserver* const mObserver; 474 int mQueries GUARDED_BY(mLock); 475 std::vector<std::thread> mThreads GUARDED_BY(mLock); 476 std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock); 477 }; 478 479 int FakeSocketLimited::sLimit; 480 size_t FakeSocketLimited::sMaxSize; 481 482 TEST_F(TransportTest, SilentDrop) { 483 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries. 484 FakeSocketLimited::sMaxSize = 0; // Silently drop all queries 485 FakeSocketFactory<FakeSocketLimited> factory; 486 DnsTlsTransport transport(SERVER1, MARK, &factory); 487 488 // Queue up 10 queries. They will all be ignored, and after the 10th, 489 // the socket will close. Transport will retry them all, until they 490 // all hit the retry limit and expire. 491 std::vector<std::future<DnsTlsTransport::Result>> results; 492 for (int i = 0; i < FakeSocketLimited::sLimit; ++i) { 493 results.push_back(transport.query(makeSlice(QUERY))); 494 } 495 for (auto& result : results) { 496 auto r = result.get(); 497 EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code); 498 EXPECT_TRUE(r.response.empty()); 499 } 500 } 501 502 TEST_F(TransportTest, PartialDrop) { 503 FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries. 504 FakeSocketLimited::sMaxSize = SIZE - 2; // Silently drop "long" queries 505 FakeSocketFactory<FakeSocketLimited> factory; 506 DnsTlsTransport transport(SERVER1, MARK, &factory); 507 508 // Queue up 100 queries, alternating "short" which will be served and "long" 509 // which will be dropped. 510 int num_queries = 10 * FakeSocketLimited::sLimit; 511 std::vector<bytevec> queries(num_queries); 512 std::vector<std::future<DnsTlsTransport::Result>> results; 513 for (int i = 0; i < num_queries; ++i) { 514 queries[i] = make_query(i, SIZE + (i % 2)); 515 results.push_back(transport.query(makeSlice(queries[i]))); 516 } 517 // Just check the short queries, which are at the even indices. 518 for (int i = 0; i < num_queries; i += 2) { 519 auto r = results[i].get(); 520 EXPECT_EQ(DnsTlsTransport::Response::success, r.code); 521 EXPECT_EQ(queries[i], r.response); 522 } 523 } 524 525 // Simulate a malfunctioning server that injects extra miscellaneous 526 // responses to queries that were not asked. This will cause wrong answers but 527 // must not crash the Transport. 528 class FakeSocketGarbage : public IDnsTlsSocket { 529 public: 530 FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) { 531 // Inject a garbage event. 532 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE)); 533 } 534 ~FakeSocketGarbage() { 535 std::lock_guard<std::mutex> guard(mLock); 536 for (auto& thread : mThreads) { 537 thread.join(); 538 } 539 } 540 bool query(uint16_t id, const Slice query) override { 541 std::lock_guard<std::mutex> guard(mLock); 542 // Return the response twice. 543 auto echo = make_echo(id, query); 544 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo); 545 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo); 546 // Also return some other garbage 547 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2)); 548 return true; 549 } 550 private: 551 std::mutex mLock; 552 std::vector<std::thread> mThreads GUARDED_BY(mLock); 553 IDnsTlsSocketObserver* const mObserver; 554 }; 555 556 TEST_F(TransportTest, IgnoringGarbage) { 557 FakeSocketFactory<FakeSocketGarbage> factory; 558 DnsTlsTransport transport(SERVER1, MARK, &factory); 559 for (int i = 0; i < 10; ++i) { 560 auto r = transport.query(makeSlice(QUERY)).get(); 561 562 EXPECT_EQ(DnsTlsTransport::Response::success, r.code); 563 // Don't check the response because this server is malfunctioning. 564 } 565 } 566 567 // Dispatcher tests 568 class DispatcherTest : public BaseTest {}; 569 570 TEST_F(DispatcherTest, Query) { 571 bytevec ans(4096); 572 int resplen = 0; 573 574 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>(); 575 DnsTlsDispatcher dispatcher(std::move(factory)); 576 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), 577 makeSlice(ans), &resplen); 578 579 EXPECT_EQ(DnsTlsTransport::Response::success, r); 580 EXPECT_EQ(int(QUERY.size()), resplen); 581 ans.resize(resplen); 582 EXPECT_EQ(QUERY, ans); 583 } 584 585 TEST_F(DispatcherTest, AnswerTooLarge) { 586 bytevec ans(SIZE - 1); // Too small to hold the answer 587 int resplen = 0; 588 589 auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>(); 590 DnsTlsDispatcher dispatcher(std::move(factory)); 591 auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY), 592 makeSlice(ans), &resplen); 593 594 EXPECT_EQ(DnsTlsTransport::Response::limit_error, r); 595 } 596 597 template<class T> 598 class TrackingFakeSocketFactory : public IDnsTlsSocketFactory { 599 public: 600 TrackingFakeSocketFactory() {} 601 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket( 602 const DnsTlsServer& server, 603 unsigned mark, 604 IDnsTlsSocketObserver* observer, 605 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override { 606 std::lock_guard<std::mutex> guard(mLock); 607 keys.emplace(mark, server); 608 return std::make_unique<T>(observer); 609 } 610 std::multiset<std::pair<unsigned, DnsTlsServer>> keys; 611 private: 612 std::mutex mLock; 613 }; 614 615 TEST_F(DispatcherTest, Dispatching) { 616 FakeSocketDelay::sDelay = 5; 617 FakeSocketDelay::sReverse = true; 618 auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>(); 619 auto* weak_factory = factory.get(); // Valid as long as dispatcher is in scope. 620 DnsTlsDispatcher dispatcher(std::move(factory)); 621 622 // Populate a vector of two servers and two socket marks, four combinations 623 // in total. 624 std::vector<std::pair<unsigned, DnsTlsServer>> keys; 625 keys.emplace_back(MARK, SERVER1); 626 keys.emplace_back(MARK + 1, SERVER1); 627 keys.emplace_back(MARK, V4ADDR2); 628 keys.emplace_back(MARK + 1, V4ADDR2); 629 630 // Do several queries on each server. They should all succeed. 631 std::vector<std::thread> threads; 632 for (size_t i = 0; i < FakeSocketDelay::sDelay * keys.size(); ++i) { 633 auto key = keys[i % keys.size()]; 634 threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) { 635 auto q = make_query(i, SIZE); 636 bytevec ans(4096); 637 int resplen = 0; 638 unsigned mark = key.first; 639 const DnsTlsServer& server = key.second; 640 auto r = dispatcher->query(server, mark, makeSlice(q), 641 makeSlice(ans), &resplen); 642 EXPECT_EQ(DnsTlsTransport::Response::success, r); 643 EXPECT_EQ(int(q.size()), resplen); 644 ans.resize(resplen); 645 EXPECT_EQ(q, ans); 646 }, &dispatcher); 647 } 648 for (auto& thread : threads) { 649 thread.join(); 650 } 651 // We expect that the factory created one socket for each key. 652 EXPECT_EQ(keys.size(), weak_factory->keys.size()); 653 for (auto& key : keys) { 654 EXPECT_EQ(1U, weak_factory->keys.count(key)); 655 } 656 } 657 658 // Check DnsTlsServer's comparison logic. 659 AddressComparator ADDRESS_COMPARATOR; 660 bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) { 661 bool cmp1 = ADDRESS_COMPARATOR(s1, s2); 662 bool cmp2 = ADDRESS_COMPARATOR(s2, s1); 663 EXPECT_FALSE(cmp1 && cmp2); 664 return !cmp1 && !cmp2; 665 } 666 667 void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) { 668 EXPECT_TRUE(s1 == s1); 669 EXPECT_TRUE(s2 == s2); 670 EXPECT_TRUE(isAddressEqual(s1, s1)); 671 EXPECT_TRUE(isAddressEqual(s2, s2)); 672 673 EXPECT_TRUE(s1 < s2 ^ s2 < s1); 674 EXPECT_FALSE(s1 == s2); 675 EXPECT_FALSE(s2 == s1); 676 } 677 678 class ServerTest : public BaseTest {}; 679 680 TEST_F(ServerTest, IPv4) { 681 checkUnequal(V4ADDR1, V4ADDR2); 682 EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2)); 683 } 684 685 TEST_F(ServerTest, IPv6) { 686 checkUnequal(V6ADDR1, V6ADDR2); 687 EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2)); 688 } 689 690 TEST_F(ServerTest, MixedAddressFamily) { 691 checkUnequal(V6ADDR1, V4ADDR1); 692 EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1)); 693 } 694 695 TEST_F(ServerTest, IPv6ScopeId) { 696 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1); 697 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss); 698 addr1->sin6_scope_id = 1; 699 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss); 700 addr2->sin6_scope_id = 2; 701 checkUnequal(s1, s2); 702 EXPECT_FALSE(isAddressEqual(s1, s2)); 703 704 EXPECT_FALSE(s1.wasExplicitlyConfigured()); 705 EXPECT_FALSE(s2.wasExplicitlyConfigured()); 706 } 707 708 TEST_F(ServerTest, IPv6FlowInfo) { 709 DnsTlsServer s1(V6ADDR1), s2(V6ADDR1); 710 sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss); 711 addr1->sin6_flowinfo = 1; 712 sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss); 713 addr2->sin6_flowinfo = 2; 714 // All comparisons ignore flowinfo. 715 EXPECT_EQ(s1, s2); 716 EXPECT_TRUE(isAddressEqual(s1, s2)); 717 718 EXPECT_FALSE(s1.wasExplicitlyConfigured()); 719 EXPECT_FALSE(s2.wasExplicitlyConfigured()); 720 } 721 722 TEST_F(ServerTest, Port) { 723 DnsTlsServer s1, s2; 724 parseServer("192.0.2.1", 853, &s1.ss); 725 parseServer("192.0.2.1", 854, &s2.ss); 726 checkUnequal(s1, s2); 727 EXPECT_TRUE(isAddressEqual(s1, s2)); 728 729 DnsTlsServer s3, s4; 730 parseServer("2001:db8::1", 853, &s3.ss); 731 parseServer("2001:db8::1", 852, &s4.ss); 732 checkUnequal(s3, s4); 733 EXPECT_TRUE(isAddressEqual(s3, s4)); 734 735 EXPECT_FALSE(s1.wasExplicitlyConfigured()); 736 EXPECT_FALSE(s2.wasExplicitlyConfigured()); 737 } 738 739 TEST_F(ServerTest, Name) { 740 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1); 741 s1.name = SERVERNAME1; 742 checkUnequal(s1, s2); 743 s2.name = SERVERNAME2; 744 checkUnequal(s1, s2); 745 EXPECT_TRUE(isAddressEqual(s1, s2)); 746 747 EXPECT_TRUE(s1.wasExplicitlyConfigured()); 748 EXPECT_TRUE(s2.wasExplicitlyConfigured()); 749 } 750 751 TEST_F(ServerTest, Fingerprint) { 752 DnsTlsServer s1(V4ADDR1), s2(V4ADDR1); 753 754 s1.fingerprints.insert(FINGERPRINT1); 755 checkUnequal(s1, s2); 756 EXPECT_TRUE(isAddressEqual(s1, s2)); 757 758 s2.fingerprints.insert(FINGERPRINT2); 759 checkUnequal(s1, s2); 760 EXPECT_TRUE(isAddressEqual(s1, s2)); 761 762 s2.fingerprints.insert(FINGERPRINT1); 763 checkUnequal(s1, s2); 764 EXPECT_TRUE(isAddressEqual(s1, s2)); 765 766 s1.fingerprints.insert(FINGERPRINT2); 767 EXPECT_EQ(s1, s2); 768 EXPECT_TRUE(isAddressEqual(s1, s2)); 769 770 EXPECT_TRUE(s1.wasExplicitlyConfigured()); 771 EXPECT_TRUE(s2.wasExplicitlyConfigured()); 772 } 773 774 TEST(QueryMapTest, Basic) { 775 DnsTlsQueryMap map; 776 777 EXPECT_TRUE(map.empty()); 778 779 bytevec q0 = make_query(999, SIZE); 780 bytevec q1 = make_query(888, SIZE); 781 bytevec q2 = make_query(777, SIZE); 782 783 auto f0 = map.recordQuery(makeSlice(q0)); 784 auto f1 = map.recordQuery(makeSlice(q1)); 785 auto f2 = map.recordQuery(makeSlice(q2)); 786 787 // Check return values of recordQuery 788 EXPECT_EQ(0, f0->query.newId); 789 EXPECT_EQ(1, f1->query.newId); 790 EXPECT_EQ(2, f2->query.newId); 791 792 // Check side effects of recordQuery 793 EXPECT_FALSE(map.empty()); 794 795 auto all = map.getAll(); 796 EXPECT_EQ(3U, all.size()); 797 798 EXPECT_EQ(0, all[0].newId); 799 EXPECT_EQ(1, all[1].newId); 800 EXPECT_EQ(2, all[2].newId); 801 802 EXPECT_EQ(makeSlice(q0), all[0].query); 803 EXPECT_EQ(makeSlice(q1), all[1].query); 804 EXPECT_EQ(makeSlice(q2), all[2].query); 805 806 bytevec a0 = make_query(0, SIZE); 807 bytevec a1 = make_query(1, SIZE); 808 bytevec a2 = make_query(2, SIZE); 809 810 // Return responses out of order 811 map.onResponse(a2); 812 map.onResponse(a0); 813 map.onResponse(a1); 814 815 EXPECT_TRUE(map.empty()); 816 817 auto r0 = f0->result.get(); 818 auto r1 = f1->result.get(); 819 auto r2 = f2->result.get(); 820 821 EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code); 822 EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code); 823 EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code); 824 825 const bytevec& d0 = r0.response; 826 const bytevec& d1 = r1.response; 827 const bytevec& d2 = r2.response; 828 829 // The ID should match the query 830 EXPECT_EQ(999, d0[0] << 8 | d0[1]); 831 EXPECT_EQ(888, d1[0] << 8 | d1[1]); 832 EXPECT_EQ(777, d2[0] << 8 | d2[1]); 833 // The body should match the answer 834 EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end())); 835 EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end())); 836 EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end())); 837 } 838 839 TEST(QueryMapTest, FillHole) { 840 DnsTlsQueryMap map; 841 std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1); 842 for (uint32_t i = 0; i <= UINT16_MAX; ++i) { 843 futures[i] = map.recordQuery(makeSlice(QUERY)); 844 ASSERT_TRUE(futures[i]); // answers[i] should be nonnull. 845 EXPECT_EQ(i, futures[i]->query.newId); 846 } 847 848 // The map should now be full. 849 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size()); 850 851 // Trying to add another query should fail because the map is full. 852 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY))); 853 854 // Send an answer to query 40000 855 auto answer = make_query(40000, SIZE); 856 map.onResponse(answer); 857 auto result = futures[40000]->result.get(); 858 EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code); 859 EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]); 860 EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()), 861 bytevec(result.response.begin() + 2, result.response.end())); 862 863 // There should now be room in the map. 864 EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size()); 865 auto f = map.recordQuery(makeSlice(QUERY)); 866 ASSERT_TRUE(f); 867 EXPECT_EQ(40000, f->query.newId); 868 869 // The map should now be full again. 870 EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size()); 871 EXPECT_FALSE(map.recordQuery(makeSlice(QUERY))); 872 } 873 874 } // end of namespace net 875 } // end of namespace android 876