Home | History | Annotate | Download | only in dns
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "DnsTlsQueryMap"
     18 //#define LOG_NDEBUG 0
     19 
     20 #include "dns/DnsTlsTransport.h"
     21 
     22 #include "log/log.h"
     23 
     24 namespace android {
     25 namespace net {
     26 
     27 std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(const Slice query) {
     28     std::lock_guard<std::mutex> guard(mLock);
     29 
     30     // Store the query so it can be matched to the response or reissued.
     31     if (query.size() < 2) {
     32         ALOGW("Query is too short");
     33         return nullptr;
     34     }
     35     int32_t newId = getFreeId();
     36     if (newId < 0) {
     37         ALOGW("All query IDs are in use");
     38         return nullptr;
     39     }
     40     Query q = { .newId = static_cast<uint16_t>(newId), .query = query };
     41     std::map<uint16_t, QueryPromise>::iterator it;
     42     bool inserted;
     43     std::tie(it, inserted) = mQueries.emplace(newId, q);
     44     if (!inserted) {
     45         ALOGE("Failed to store pending query");
     46         return nullptr;
     47     }
     48     return std::make_unique<QueryFuture>(q, it->second.result.get_future());
     49 }
     50 
     51 void DnsTlsQueryMap::expire(QueryPromise* p) {
     52     Result r = { .code = Response::network_error };
     53     p->result.set_value(r);
     54 }
     55 
     56 void DnsTlsQueryMap::markTried(uint16_t newId) {
     57     std::lock_guard<std::mutex> guard(mLock);
     58     auto it = mQueries.find(newId);
     59     if (it != mQueries.end()) {
     60         it->second.tries++;
     61     }
     62 }
     63 
     64 void DnsTlsQueryMap::cleanup() {
     65     std::lock_guard<std::mutex> guard(mLock);
     66     for (auto it = mQueries.begin(); it != mQueries.end();) {
     67         auto& p = it->second;
     68         if (p.tries >= kMaxTries) {
     69             expire(&p);
     70             it = mQueries.erase(it);
     71         } else {
     72             ++it;
     73         }
     74     }
     75 }
     76 
     77 int32_t DnsTlsQueryMap::getFreeId() {
     78     if (mQueries.empty()) {
     79         return 0;
     80     }
     81     uint16_t maxId = mQueries.rbegin()->first;
     82     if (maxId < UINT16_MAX) {
     83         return maxId + 1;
     84     }
     85     if (mQueries.size() == UINT16_MAX + 1) {
     86         // Map is full.
     87         return -1;
     88     }
     89     // Linear scan.
     90     uint16_t nextId = 0;
     91     for (auto& pair : mQueries) {
     92         uint16_t id = pair.first;
     93         if (id != nextId) {
     94             // Found a gap.
     95             return nextId;
     96         }
     97         nextId = id + 1;
     98     }
     99     // Unreachable (but the compiler isn't smart enough to prove it).
    100     return -1;
    101 }
    102 
    103 std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
    104     std::lock_guard<std::mutex> guard(mLock);
    105     std::vector<DnsTlsQueryMap::Query> queries;
    106     for (auto& q : mQueries) {
    107         queries.push_back(q.second.query);
    108     }
    109     return queries;
    110 }
    111 
    112 bool DnsTlsQueryMap::empty() {
    113     std::lock_guard<std::mutex> guard(mLock);
    114     return mQueries.empty();
    115 }
    116 
    117 void DnsTlsQueryMap::clear() {
    118     std::lock_guard<std::mutex> guard(mLock);
    119     for (auto& q : mQueries) {
    120         expire(&q.second);
    121     }
    122     mQueries.clear();
    123 }
    124 
    125 void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
    126     ALOGV("Got response of size %zu", response.size());
    127     if (response.size() < 2) {
    128         ALOGW("Response is too short");
    129         return;
    130     }
    131     uint16_t id = response[0] << 8 | response[1];
    132     std::lock_guard<std::mutex> guard(mLock);
    133     auto it = mQueries.find(id);
    134     if (it == mQueries.end()) {
    135         ALOGW("Discarding response: unknown ID %d", id);
    136         return;
    137     }
    138     Result r = { .code = Response::success, .response = std::move(response) };
    139     // Rewrite ID to match the query
    140     const uint8_t* data = it->second.query.query.base();
    141     r.response[0] = data[0];
    142     r.response[1] = data[1];
    143     ALOGV("Sending result to dispatcher");
    144     it->second.result.set_value(std::move(r));
    145     mQueries.erase(it);
    146 }
    147 
    148 }  // end of namespace net
    149 }  // end of namespace android
    150