Home | History | Annotate | Download | only in tests
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "dns_tls_test"
     18 
     19 #include <gtest/gtest.h>
     20 
     21 #include "dns/DnsTlsDispatcher.h"
     22 #include "dns/DnsTlsQueryMap.h"
     23 #include "dns/DnsTlsServer.h"
     24 #include "dns/DnsTlsSessionCache.h"
     25 #include "dns/DnsTlsSocket.h"
     26 #include "dns/DnsTlsTransport.h"
     27 #include "dns/IDnsTlsSocket.h"
     28 #include "dns/IDnsTlsSocketFactory.h"
     29 #include "dns/IDnsTlsSocketObserver.h"
     30 
     31 #include <chrono>
     32 #include <arpa/inet.h>
     33 #include <android-base/macros.h>
     34 #include <netdutils/Slice.h>
     35 
     36 #include "log/log.h"
     37 
     38 namespace android {
     39 namespace net {
     40 
     41 using netdutils::Slice;
     42 using netdutils::makeSlice;
     43 
     44 typedef std::vector<uint8_t> bytevec;
     45 
     46 static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
     47     sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
     48     if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
     49         // IPv4 parse succeeded, so it's IPv4
     50         sin->sin_family = AF_INET;
     51         sin->sin_port = htons(port);
     52         return;
     53     }
     54     sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
     55     if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
     56         // IPv6 parse succeeded, so it's IPv6.
     57         sin6->sin6_family = AF_INET6;
     58         sin6->sin6_port = htons(port);
     59         return;
     60     }
     61     ALOGE("Failed to parse server address: %s", server);
     62 }
     63 
     64 bytevec FINGERPRINT1 = { 1 };
     65 bytevec FINGERPRINT2 = { 2 };
     66 
     67 std::string SERVERNAME1 = "dns.example.com";
     68 std::string SERVERNAME2 = "dns.example.org";
     69 
     70 // BaseTest just provides constants that are useful for the tests.
     71 class BaseTest : public ::testing::Test {
     72 protected:
     73     BaseTest() {
     74         parseServer("192.0.2.1", 853, &V4ADDR1);
     75         parseServer("192.0.2.2", 853, &V4ADDR2);
     76         parseServer("2001:db8::1", 853, &V6ADDR1);
     77         parseServer("2001:db8::2", 853, &V6ADDR2);
     78 
     79         SERVER1 = DnsTlsServer(V4ADDR1);
     80         SERVER1.fingerprints.insert(FINGERPRINT1);
     81         SERVER1.name = SERVERNAME1;
     82     }
     83 
     84     sockaddr_storage V4ADDR1;
     85     sockaddr_storage V4ADDR2;
     86     sockaddr_storage V6ADDR1;
     87     sockaddr_storage V6ADDR2;
     88 
     89     DnsTlsServer SERVER1;
     90 };
     91 
     92 bytevec make_query(uint16_t id, size_t size) {
     93     bytevec vec(size);
     94     vec[0] = id >> 8;
     95     vec[1] = id;
     96     // Arbitrarily fill the query body with unique data.
     97     for (size_t i = 2; i < size; ++i) {
     98         vec[i] = id + i;
     99     }
    100     return vec;
    101 }
    102 
    103 // Query constants
    104 const unsigned MARK = 123;
    105 const uint16_t ID = 52;
    106 const uint16_t SIZE = 22;
    107 const bytevec QUERY = make_query(ID, SIZE);
    108 
    109 template <class T>
    110 class FakeSocketFactory : public IDnsTlsSocketFactory {
    111 public:
    112     FakeSocketFactory() {}
    113     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
    114             const DnsTlsServer& server ATTRIBUTE_UNUSED,
    115             unsigned mark ATTRIBUTE_UNUSED,
    116             IDnsTlsSocketObserver* observer,
    117             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
    118         return std::make_unique<T>(observer);
    119     }
    120 };
    121 
    122 bytevec make_echo(uint16_t id, const Slice query) {
    123     bytevec response(query.size() + 2);
    124     response[0] = id >> 8;
    125     response[1] = id;
    126     // Echo the query as the fake response.
    127     memcpy(response.data() + 2, query.base(), query.size());
    128     return response;
    129 }
    130 
    131 // Simplest possible fake server.  This just echoes the query as the response.
    132 class FakeSocketEcho : public IDnsTlsSocket {
    133 public:
    134     FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
    135     bool query(uint16_t id, const Slice query) override {
    136         // Return the response immediately (asynchronously).
    137         std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
    138         return true;
    139     }
    140 private:
    141     IDnsTlsSocketObserver* const mObserver;
    142 };
    143 
    144 class TransportTest : public BaseTest {};
    145 
    146 TEST_F(TransportTest, Query) {
    147     FakeSocketFactory<FakeSocketEcho> factory;
    148     DnsTlsTransport transport(SERVER1, MARK, &factory);
    149     auto r = transport.query(makeSlice(QUERY)).get();
    150 
    151     EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    152     EXPECT_EQ(QUERY, r.response);
    153 }
    154 
    155 TEST_F(TransportTest, SerialQueries_100000) {
    156     FakeSocketFactory<FakeSocketEcho> factory;
    157     DnsTlsTransport transport(SERVER1, MARK, &factory);
    158     // Send more than 65536 queries serially.
    159     for (int i = 0; i < 100000; ++i) {
    160         auto r = transport.query(makeSlice(QUERY)).get();
    161 
    162         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    163         EXPECT_EQ(QUERY, r.response);
    164     }
    165 }
    166 
    167 // These queries might be handled in serial or parallel as they race the
    168 // responses.
    169 TEST_F(TransportTest, RacingQueries_10000) {
    170     FakeSocketFactory<FakeSocketEcho> factory;
    171     DnsTlsTransport transport(SERVER1, MARK, &factory);
    172     std::vector<std::future<DnsTlsTransport::Result>> results;
    173     // Fewer than 65536 queries to avoid ID exhaustion.
    174     for (int i = 0; i < 10000; ++i) {
    175         results.push_back(transport.query(makeSlice(QUERY)));
    176     }
    177     for (auto& result : results) {
    178         auto r = result.get();
    179         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    180         EXPECT_EQ(QUERY, r.response);
    181     }
    182 }
    183 
    184 // A server that waits until sDelay queries are queued before responding.
    185 class FakeSocketDelay : public IDnsTlsSocket {
    186 public:
    187     FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
    188     ~FakeSocketDelay() { std::lock_guard<std::mutex> guard(mLock); }
    189     static size_t sDelay;
    190     static bool sReverse;
    191 
    192     bool query(uint16_t id, const Slice query) override {
    193         ALOGD("FakeSocketDelay got query with ID %d", int(id));
    194         std::lock_guard<std::mutex> guard(mLock);
    195         // Check for duplicate IDs.
    196         EXPECT_EQ(0U, mIds.count(id));
    197         mIds.insert(id);
    198 
    199         // Store response.
    200         mResponses.push_back(make_echo(id, query));
    201 
    202         ALOGD("Up to %zu out of %zu queries", mResponses.size(), sDelay);
    203         if (mResponses.size() == sDelay) {
    204             std::thread(&FakeSocketDelay::sendResponses, this).detach();
    205         }
    206         return true;
    207     }
    208 private:
    209     void sendResponses() {
    210         std::lock_guard<std::mutex> guard(mLock);
    211         if (sReverse) {
    212             std::reverse(std::begin(mResponses), std::end(mResponses));
    213         }
    214         for (auto& response : mResponses) {
    215             mObserver->onResponse(response);
    216         }
    217         mIds.clear();
    218         mResponses.clear();
    219     }
    220 
    221     std::mutex mLock;
    222     IDnsTlsSocketObserver* const mObserver;
    223     std::set<uint16_t> mIds GUARDED_BY(mLock);
    224     std::vector<bytevec> mResponses GUARDED_BY(mLock);
    225 };
    226 
    227 size_t FakeSocketDelay::sDelay;
    228 bool FakeSocketDelay::sReverse;
    229 
    230 TEST_F(TransportTest, ParallelColliding) {
    231     FakeSocketDelay::sDelay = 10;
    232     FakeSocketDelay::sReverse = false;
    233     FakeSocketFactory<FakeSocketDelay> factory;
    234     DnsTlsTransport transport(SERVER1, MARK, &factory);
    235     std::vector<std::future<DnsTlsTransport::Result>> results;
    236     // Fewer than 65536 queries to avoid ID exhaustion.
    237     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
    238         results.push_back(transport.query(makeSlice(QUERY)));
    239     }
    240     for (auto& result : results) {
    241         auto r = result.get();
    242         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    243         EXPECT_EQ(QUERY, r.response);
    244     }
    245 }
    246 
    247 TEST_F(TransportTest, ParallelColliding_Max) {
    248     FakeSocketDelay::sDelay = 65536;
    249     FakeSocketDelay::sReverse = false;
    250     FakeSocketFactory<FakeSocketDelay> factory;
    251     DnsTlsTransport transport(SERVER1, MARK, &factory);
    252     std::vector<std::future<DnsTlsTransport::Result>> results;
    253     // Exactly 65536 queries should still be possible in parallel,
    254     // even if they all have the same original ID.
    255     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
    256         results.push_back(transport.query(makeSlice(QUERY)));
    257     }
    258     for (auto& result : results) {
    259         auto r = result.get();
    260         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    261         EXPECT_EQ(QUERY, r.response);
    262     }
    263 }
    264 
    265 TEST_F(TransportTest, ParallelUnique) {
    266     FakeSocketDelay::sDelay = 10;
    267     FakeSocketDelay::sReverse = false;
    268     FakeSocketFactory<FakeSocketDelay> factory;
    269     DnsTlsTransport transport(SERVER1, MARK, &factory);
    270     std::vector<bytevec> queries(FakeSocketDelay::sDelay);
    271     std::vector<std::future<DnsTlsTransport::Result>> results;
    272     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
    273         queries[i] = make_query(i, SIZE);
    274         results.push_back(transport.query(makeSlice(queries[i])));
    275     }
    276     for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
    277         auto r = results[i].get();
    278         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    279         EXPECT_EQ(queries[i], r.response);
    280     }
    281 }
    282 
    283 TEST_F(TransportTest, ParallelUnique_Max) {
    284     FakeSocketDelay::sDelay = 65536;
    285     FakeSocketDelay::sReverse = false;
    286     FakeSocketFactory<FakeSocketDelay> factory;
    287     DnsTlsTransport transport(SERVER1, MARK, &factory);
    288     std::vector<bytevec> queries(FakeSocketDelay::sDelay);
    289     std::vector<std::future<DnsTlsTransport::Result>> results;
    290     // Exactly 65536 queries should still be possible in parallel,
    291     // and they should all be mapped correctly back to the original ID.
    292     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
    293         queries[i] = make_query(i, SIZE);
    294         results.push_back(transport.query(makeSlice(queries[i])));
    295     }
    296     for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
    297         auto r = results[i].get();
    298         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    299         EXPECT_EQ(queries[i], r.response);
    300     }
    301 }
    302 
    303 TEST_F(TransportTest, IdExhaustion) {
    304     // A delay of 65537 is unreachable, because the maximum number
    305     // of outstanding queries is 65536.
    306     FakeSocketDelay::sDelay = 65537;
    307     FakeSocketDelay::sReverse = false;
    308     FakeSocketFactory<FakeSocketDelay> factory;
    309     DnsTlsTransport transport(SERVER1, MARK, &factory);
    310     std::vector<std::future<DnsTlsTransport::Result>> results;
    311     // Issue the maximum number of queries.
    312     for (int i = 0; i < 65536; ++i) {
    313         results.push_back(transport.query(makeSlice(QUERY)));
    314     }
    315 
    316     // The ID space is now full, so subsequent queries should fail immediately.
    317     auto r = transport.query(makeSlice(QUERY)).get();
    318     EXPECT_EQ(DnsTlsTransport::Response::internal_error, r.code);
    319     EXPECT_TRUE(r.response.empty());
    320 
    321     for (auto& result : results) {
    322         // All other queries should remain outstanding.
    323         EXPECT_EQ(std::future_status::timeout,
    324                 result.wait_for(std::chrono::duration<int>::zero()));
    325     }
    326 }
    327 
    328 // Responses can come back from the server in any order.  This should have no
    329 // effect on Transport's observed behavior.
    330 TEST_F(TransportTest, ReverseOrder) {
    331     FakeSocketDelay::sDelay = 10;
    332     FakeSocketDelay::sReverse = true;
    333     FakeSocketFactory<FakeSocketDelay> factory;
    334     DnsTlsTransport transport(SERVER1, MARK, &factory);
    335     std::vector<bytevec> queries(FakeSocketDelay::sDelay);
    336     std::vector<std::future<DnsTlsTransport::Result>> results;
    337     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
    338         queries[i] = make_query(i, SIZE);
    339         results.push_back(transport.query(makeSlice(queries[i])));
    340     }
    341     for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
    342         auto r = results[i].get();
    343         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    344         EXPECT_EQ(queries[i], r.response);
    345     }
    346 }
    347 
    348 TEST_F(TransportTest, ReverseOrder_Max) {
    349     FakeSocketDelay::sDelay = 65536;
    350     FakeSocketDelay::sReverse = true;
    351     FakeSocketFactory<FakeSocketDelay> factory;
    352     DnsTlsTransport transport(SERVER1, MARK, &factory);
    353     std::vector<bytevec> queries(FakeSocketDelay::sDelay);
    354     std::vector<std::future<DnsTlsTransport::Result>> results;
    355     for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
    356         queries[i] = make_query(i, SIZE);
    357         results.push_back(transport.query(makeSlice(queries[i])));
    358     }
    359     for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
    360         auto r = results[i].get();
    361         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    362         EXPECT_EQ(queries[i], r.response);
    363     }
    364 }
    365 
    366 // Returning null from the factory indicates a connection failure.
    367 class NullSocketFactory : public IDnsTlsSocketFactory {
    368 public:
    369     NullSocketFactory() {}
    370     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
    371             const DnsTlsServer& server ATTRIBUTE_UNUSED,
    372             unsigned mark ATTRIBUTE_UNUSED,
    373             IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
    374             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
    375         return nullptr;
    376     }
    377 };
    378 
    379 TEST_F(TransportTest, ConnectFail) {
    380     NullSocketFactory factory;
    381     DnsTlsTransport transport(SERVER1, MARK, &factory);
    382     auto r = transport.query(makeSlice(QUERY)).get();
    383 
    384     EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
    385     EXPECT_TRUE(r.response.empty());
    386 }
    387 
    388 // Simulate a socket that connects but then immediately receives a server
    389 // close notification.
    390 class FakeSocketClose : public IDnsTlsSocket {
    391 public:
    392     FakeSocketClose(IDnsTlsSocketObserver* observer) :
    393             mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
    394     ~FakeSocketClose() { mCloser.join(); }
    395     bool query(uint16_t id ATTRIBUTE_UNUSED,
    396                const Slice query ATTRIBUTE_UNUSED) override {
    397         return true;
    398     }
    399 private:
    400     std::thread mCloser;
    401 };
    402 
    403 TEST_F(TransportTest, CloseRetryFail) {
    404     FakeSocketFactory<FakeSocketClose> factory;
    405     DnsTlsTransport transport(SERVER1, MARK, &factory);
    406     auto r = transport.query(makeSlice(QUERY)).get();
    407 
    408     EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
    409     EXPECT_TRUE(r.response.empty());
    410 }
    411 
    412 // Simulate a server that occasionally closes the connection and silently
    413 // drops some queries.
    414 class FakeSocketLimited : public IDnsTlsSocket {
    415 public:
    416     static int sLimit;  // Number of queries to answer per socket.
    417     static size_t sMaxSize;  // Silently discard queries greater than this size.
    418     FakeSocketLimited(IDnsTlsSocketObserver* observer) :
    419             mObserver(observer), mQueries(0) {}
    420     ~FakeSocketLimited() {
    421         {
    422             ALOGD("~FakeSocketLimited acquiring mLock");
    423             std::lock_guard<std::mutex> guard(mLock);
    424             ALOGD("~FakeSocketLimited acquired mLock");
    425             for (auto& thread : mThreads) {
    426                 ALOGD("~FakeSocketLimited joining response thread");
    427                 thread.join();
    428                 ALOGD("~FakeSocketLimited joined response thread");
    429             }
    430             mThreads.clear();
    431         }
    432 
    433         if (mCloser) {
    434             ALOGD("~FakeSocketLimited joining closer thread");
    435             mCloser->join();
    436             ALOGD("~FakeSocketLimited joined closer thread");
    437         }
    438     }
    439     bool query(uint16_t id, const Slice query) override {
    440         ALOGD("FakeSocketLimited::query acquiring mLock");
    441         std::lock_guard<std::mutex> guard(mLock);
    442         ALOGD("FakeSocketLimited::query acquired mLock");
    443         ++mQueries;
    444 
    445         if (mQueries <= sLimit) {
    446             ALOGD("size %zu vs. limit of %zu", query.size(), sMaxSize);
    447             if (query.size() <= sMaxSize) {
    448                 // Return the response immediately (asynchronously).
    449                 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
    450             }
    451         }
    452         if (mQueries == sLimit) {
    453             mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
    454         }
    455         return mQueries <= sLimit;
    456     }
    457 private:
    458     void sendClose() {
    459         {
    460             ALOGD("FakeSocketLimited::sendClose acquiring mLock");
    461             std::lock_guard<std::mutex> guard(mLock);
    462             ALOGD("FakeSocketLimited::sendClose acquired mLock");
    463             for (auto& thread : mThreads) {
    464                 ALOGD("FakeSocketLimited::sendClose joining response thread");
    465                 thread.join();
    466                 ALOGD("FakeSocketLimited::sendClose joined response thread");
    467             }
    468             mThreads.clear();
    469         }
    470         mObserver->onClosed();
    471     }
    472     std::mutex mLock;
    473     IDnsTlsSocketObserver* const mObserver;
    474     int mQueries GUARDED_BY(mLock);
    475     std::vector<std::thread> mThreads GUARDED_BY(mLock);
    476     std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
    477 };
    478 
    479 int FakeSocketLimited::sLimit;
    480 size_t FakeSocketLimited::sMaxSize;
    481 
    482 TEST_F(TransportTest, SilentDrop) {
    483     FakeSocketLimited::sLimit = 10;  // Close the socket after 10 queries.
    484     FakeSocketLimited::sMaxSize = 0;  // Silently drop all queries
    485     FakeSocketFactory<FakeSocketLimited> factory;
    486     DnsTlsTransport transport(SERVER1, MARK, &factory);
    487 
    488     // Queue up 10 queries.  They will all be ignored, and after the 10th,
    489     // the socket will close.  Transport will retry them all, until they
    490     // all hit the retry limit and expire.
    491     std::vector<std::future<DnsTlsTransport::Result>> results;
    492     for (int i = 0; i < FakeSocketLimited::sLimit; ++i) {
    493         results.push_back(transport.query(makeSlice(QUERY)));
    494     }
    495     for (auto& result : results) {
    496         auto r = result.get();
    497         EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
    498         EXPECT_TRUE(r.response.empty());
    499     }
    500 }
    501 
    502 TEST_F(TransportTest, PartialDrop) {
    503     FakeSocketLimited::sLimit = 10;  // Close the socket after 10 queries.
    504     FakeSocketLimited::sMaxSize = SIZE - 2;  // Silently drop "long" queries
    505     FakeSocketFactory<FakeSocketLimited> factory;
    506     DnsTlsTransport transport(SERVER1, MARK, &factory);
    507 
    508     // Queue up 100 queries, alternating "short" which will be served and "long"
    509     // which will be dropped.
    510     int num_queries = 10 * FakeSocketLimited::sLimit;
    511     std::vector<bytevec> queries(num_queries);
    512     std::vector<std::future<DnsTlsTransport::Result>> results;
    513     for (int i = 0; i < num_queries; ++i) {
    514         queries[i] = make_query(i, SIZE + (i % 2));
    515         results.push_back(transport.query(makeSlice(queries[i])));
    516     }
    517     // Just check the short queries, which are at the even indices.
    518     for (int i = 0; i < num_queries; i += 2) {
    519         auto r = results[i].get();
    520         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    521         EXPECT_EQ(queries[i], r.response);
    522     }
    523 }
    524 
    525 // Simulate a malfunctioning server that injects extra miscellaneous
    526 // responses to queries that were not asked.  This will cause wrong answers but
    527 // must not crash the Transport.
    528 class FakeSocketGarbage : public IDnsTlsSocket {
    529 public:
    530     FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
    531         // Inject a garbage event.
    532         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
    533     }
    534     ~FakeSocketGarbage() {
    535         std::lock_guard<std::mutex> guard(mLock);
    536         for (auto& thread : mThreads) {
    537             thread.join();
    538         }
    539     }
    540     bool query(uint16_t id, const Slice query) override {
    541         std::lock_guard<std::mutex> guard(mLock);
    542         // Return the response twice.
    543         auto echo = make_echo(id, query);
    544         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
    545         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
    546         // Also return some other garbage
    547         mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
    548         return true;
    549     }
    550 private:
    551     std::mutex mLock;
    552     std::vector<std::thread> mThreads GUARDED_BY(mLock);
    553     IDnsTlsSocketObserver* const mObserver;
    554 };
    555 
    556 TEST_F(TransportTest, IgnoringGarbage) {
    557     FakeSocketFactory<FakeSocketGarbage> factory;
    558     DnsTlsTransport transport(SERVER1, MARK, &factory);
    559     for (int i = 0; i < 10; ++i) {
    560         auto r = transport.query(makeSlice(QUERY)).get();
    561 
    562         EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
    563         // Don't check the response because this server is malfunctioning.
    564     }
    565 }
    566 
    567 // Dispatcher tests
    568 class DispatcherTest : public BaseTest {};
    569 
    570 TEST_F(DispatcherTest, Query) {
    571     bytevec ans(4096);
    572     int resplen = 0;
    573 
    574     auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
    575     DnsTlsDispatcher dispatcher(std::move(factory));
    576     auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
    577                               makeSlice(ans), &resplen);
    578 
    579     EXPECT_EQ(DnsTlsTransport::Response::success, r);
    580     EXPECT_EQ(int(QUERY.size()), resplen);
    581     ans.resize(resplen);
    582     EXPECT_EQ(QUERY, ans);
    583 }
    584 
    585 TEST_F(DispatcherTest, AnswerTooLarge) {
    586     bytevec ans(SIZE - 1);  // Too small to hold the answer
    587     int resplen = 0;
    588 
    589     auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
    590     DnsTlsDispatcher dispatcher(std::move(factory));
    591     auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
    592                               makeSlice(ans), &resplen);
    593 
    594     EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
    595 }
    596 
    597 template<class T>
    598 class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
    599 public:
    600     TrackingFakeSocketFactory() {}
    601     std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
    602             const DnsTlsServer& server,
    603             unsigned mark,
    604             IDnsTlsSocketObserver* observer,
    605             DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
    606         std::lock_guard<std::mutex> guard(mLock);
    607         keys.emplace(mark, server);
    608         return std::make_unique<T>(observer);
    609     }
    610     std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
    611 private:
    612     std::mutex mLock;
    613 };
    614 
    615 TEST_F(DispatcherTest, Dispatching) {
    616     FakeSocketDelay::sDelay = 5;
    617     FakeSocketDelay::sReverse = true;
    618     auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
    619     auto* weak_factory = factory.get();  // Valid as long as dispatcher is in scope.
    620     DnsTlsDispatcher dispatcher(std::move(factory));
    621 
    622     // Populate a vector of two servers and two socket marks, four combinations
    623     // in total.
    624     std::vector<std::pair<unsigned, DnsTlsServer>> keys;
    625     keys.emplace_back(MARK, SERVER1);
    626     keys.emplace_back(MARK + 1, SERVER1);
    627     keys.emplace_back(MARK, V4ADDR2);
    628     keys.emplace_back(MARK + 1, V4ADDR2);
    629 
    630     // Do several queries on each server.  They should all succeed.
    631     std::vector<std::thread> threads;
    632     for (size_t i = 0; i < FakeSocketDelay::sDelay * keys.size(); ++i) {
    633         auto key = keys[i % keys.size()];
    634         threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) {
    635             auto q = make_query(i, SIZE);
    636             bytevec ans(4096);
    637             int resplen = 0;
    638             unsigned mark = key.first;
    639             const DnsTlsServer& server = key.second;
    640             auto r = dispatcher->query(server, mark, makeSlice(q),
    641                                        makeSlice(ans), &resplen);
    642             EXPECT_EQ(DnsTlsTransport::Response::success, r);
    643             EXPECT_EQ(int(q.size()), resplen);
    644             ans.resize(resplen);
    645             EXPECT_EQ(q, ans);
    646         }, &dispatcher);
    647     }
    648     for (auto& thread : threads) {
    649         thread.join();
    650     }
    651     // We expect that the factory created one socket for each key.
    652     EXPECT_EQ(keys.size(), weak_factory->keys.size());
    653     for (auto& key : keys) {
    654         EXPECT_EQ(1U, weak_factory->keys.count(key));
    655     }
    656 }
    657 
    658 // Check DnsTlsServer's comparison logic.
    659 AddressComparator ADDRESS_COMPARATOR;
    660 bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
    661     bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
    662     bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
    663     EXPECT_FALSE(cmp1 && cmp2);
    664     return !cmp1 && !cmp2;
    665 }
    666 
    667 void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
    668     EXPECT_TRUE(s1 == s1);
    669     EXPECT_TRUE(s2 == s2);
    670     EXPECT_TRUE(isAddressEqual(s1, s1));
    671     EXPECT_TRUE(isAddressEqual(s2, s2));
    672 
    673     EXPECT_TRUE(s1 < s2 ^ s2 < s1);
    674     EXPECT_FALSE(s1 == s2);
    675     EXPECT_FALSE(s2 == s1);
    676 }
    677 
    678 class ServerTest : public BaseTest {};
    679 
    680 TEST_F(ServerTest, IPv4) {
    681     checkUnequal(V4ADDR1, V4ADDR2);
    682     EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
    683 }
    684 
    685 TEST_F(ServerTest, IPv6) {
    686     checkUnequal(V6ADDR1, V6ADDR2);
    687     EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
    688 }
    689 
    690 TEST_F(ServerTest, MixedAddressFamily) {
    691     checkUnequal(V6ADDR1, V4ADDR1);
    692     EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1));
    693 }
    694 
    695 TEST_F(ServerTest, IPv6ScopeId) {
    696     DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
    697     sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
    698     addr1->sin6_scope_id = 1;
    699     sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
    700     addr2->sin6_scope_id = 2;
    701     checkUnequal(s1, s2);
    702     EXPECT_FALSE(isAddressEqual(s1, s2));
    703 
    704     EXPECT_FALSE(s1.wasExplicitlyConfigured());
    705     EXPECT_FALSE(s2.wasExplicitlyConfigured());
    706 }
    707 
    708 TEST_F(ServerTest, IPv6FlowInfo) {
    709     DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
    710     sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
    711     addr1->sin6_flowinfo = 1;
    712     sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
    713     addr2->sin6_flowinfo = 2;
    714     // All comparisons ignore flowinfo.
    715     EXPECT_EQ(s1, s2);
    716     EXPECT_TRUE(isAddressEqual(s1, s2));
    717 
    718     EXPECT_FALSE(s1.wasExplicitlyConfigured());
    719     EXPECT_FALSE(s2.wasExplicitlyConfigured());
    720 }
    721 
    722 TEST_F(ServerTest, Port) {
    723     DnsTlsServer s1, s2;
    724     parseServer("192.0.2.1", 853, &s1.ss);
    725     parseServer("192.0.2.1", 854, &s2.ss);
    726     checkUnequal(s1, s2);
    727     EXPECT_TRUE(isAddressEqual(s1, s2));
    728 
    729     DnsTlsServer s3, s4;
    730     parseServer("2001:db8::1", 853, &s3.ss);
    731     parseServer("2001:db8::1", 852, &s4.ss);
    732     checkUnequal(s3, s4);
    733     EXPECT_TRUE(isAddressEqual(s3, s4));
    734 
    735     EXPECT_FALSE(s1.wasExplicitlyConfigured());
    736     EXPECT_FALSE(s2.wasExplicitlyConfigured());
    737 }
    738 
    739 TEST_F(ServerTest, Name) {
    740     DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
    741     s1.name = SERVERNAME1;
    742     checkUnequal(s1, s2);
    743     s2.name = SERVERNAME2;
    744     checkUnequal(s1, s2);
    745     EXPECT_TRUE(isAddressEqual(s1, s2));
    746 
    747     EXPECT_TRUE(s1.wasExplicitlyConfigured());
    748     EXPECT_TRUE(s2.wasExplicitlyConfigured());
    749 }
    750 
    751 TEST_F(ServerTest, Fingerprint) {
    752     DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
    753 
    754     s1.fingerprints.insert(FINGERPRINT1);
    755     checkUnequal(s1, s2);
    756     EXPECT_TRUE(isAddressEqual(s1, s2));
    757 
    758     s2.fingerprints.insert(FINGERPRINT2);
    759     checkUnequal(s1, s2);
    760     EXPECT_TRUE(isAddressEqual(s1, s2));
    761 
    762     s2.fingerprints.insert(FINGERPRINT1);
    763     checkUnequal(s1, s2);
    764     EXPECT_TRUE(isAddressEqual(s1, s2));
    765 
    766     s1.fingerprints.insert(FINGERPRINT2);
    767     EXPECT_EQ(s1, s2);
    768     EXPECT_TRUE(isAddressEqual(s1, s2));
    769 
    770     EXPECT_TRUE(s1.wasExplicitlyConfigured());
    771     EXPECT_TRUE(s2.wasExplicitlyConfigured());
    772 }
    773 
    774 TEST(QueryMapTest, Basic) {
    775     DnsTlsQueryMap map;
    776 
    777     EXPECT_TRUE(map.empty());
    778 
    779     bytevec q0 = make_query(999, SIZE);
    780     bytevec q1 = make_query(888, SIZE);
    781     bytevec q2 = make_query(777, SIZE);
    782 
    783     auto f0 = map.recordQuery(makeSlice(q0));
    784     auto f1 = map.recordQuery(makeSlice(q1));
    785     auto f2 = map.recordQuery(makeSlice(q2));
    786 
    787     // Check return values of recordQuery
    788     EXPECT_EQ(0, f0->query.newId);
    789     EXPECT_EQ(1, f1->query.newId);
    790     EXPECT_EQ(2, f2->query.newId);
    791 
    792     // Check side effects of recordQuery
    793     EXPECT_FALSE(map.empty());
    794 
    795     auto all = map.getAll();
    796     EXPECT_EQ(3U, all.size());
    797 
    798     EXPECT_EQ(0, all[0].newId);
    799     EXPECT_EQ(1, all[1].newId);
    800     EXPECT_EQ(2, all[2].newId);
    801 
    802     EXPECT_EQ(makeSlice(q0), all[0].query);
    803     EXPECT_EQ(makeSlice(q1), all[1].query);
    804     EXPECT_EQ(makeSlice(q2), all[2].query);
    805 
    806     bytevec a0 = make_query(0, SIZE);
    807     bytevec a1 = make_query(1, SIZE);
    808     bytevec a2 = make_query(2, SIZE);
    809 
    810     // Return responses out of order
    811     map.onResponse(a2);
    812     map.onResponse(a0);
    813     map.onResponse(a1);
    814 
    815     EXPECT_TRUE(map.empty());
    816 
    817     auto r0 = f0->result.get();
    818     auto r1 = f1->result.get();
    819     auto r2 = f2->result.get();
    820 
    821     EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
    822     EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
    823     EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
    824 
    825     const bytevec& d0 = r0.response;
    826     const bytevec& d1 = r1.response;
    827     const bytevec& d2 = r2.response;
    828 
    829     // The ID should match the query
    830     EXPECT_EQ(999, d0[0] << 8 | d0[1]);
    831     EXPECT_EQ(888, d1[0] << 8 | d1[1]);
    832     EXPECT_EQ(777, d2[0] << 8 | d2[1]);
    833     // The body should match the answer
    834     EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
    835     EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
    836     EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
    837 }
    838 
    839 TEST(QueryMapTest, FillHole) {
    840     DnsTlsQueryMap map;
    841     std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
    842     for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
    843         futures[i] = map.recordQuery(makeSlice(QUERY));
    844         ASSERT_TRUE(futures[i]);  // answers[i] should be nonnull.
    845         EXPECT_EQ(i, futures[i]->query.newId);
    846     }
    847 
    848     // The map should now be full.
    849     EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
    850 
    851     // Trying to add another query should fail because the map is full.
    852     EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
    853 
    854     // Send an answer to query 40000
    855     auto answer = make_query(40000, SIZE);
    856     map.onResponse(answer);
    857     auto result = futures[40000]->result.get();
    858     EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
    859     EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
    860     EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
    861               bytevec(result.response.begin() + 2, result.response.end()));
    862 
    863     // There should now be room in the map.
    864     EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
    865     auto f = map.recordQuery(makeSlice(QUERY));
    866     ASSERT_TRUE(f);
    867     EXPECT_EQ(40000, f->query.newId);
    868 
    869     // The map should now be full again.
    870     EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
    871     EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
    872 }
    873 
    874 } // end of namespace net
    875 } // end of namespace android
    876