Home | History | Annotate | Download | only in tests
      1 /*
      2  * Copyright (C) 2016 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "dns_responder.h"
     18 
     19 #include <arpa/inet.h>
     20 #include <fcntl.h>
     21 #include <netdb.h>
     22 #include <stdarg.h>
     23 #include <stdio.h>
     24 #include <stdlib.h>
     25 #include <string.h>
     26 #include <sys/epoll.h>
     27 #include <sys/socket.h>
     28 #include <sys/types.h>
     29 #include <unistd.h>
     30 
     31 #include <iostream>
     32 #include <vector>
     33 
     34 #include <log/log.h>
     35 
     36 namespace test {
     37 
     38 std::string errno2str() {
     39     char error_msg[512] = { 0 };
     40     if (strerror_r(errno, error_msg, sizeof(error_msg)))
     41         return std::string();
     42     return std::string(error_msg);
     43 }
     44 
     45 #define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
     46 
     47 std::string str2hex(const char* buffer, size_t len) {
     48     std::string str(len*2, '\0');
     49     for (size_t i = 0 ; i < len ; ++i) {
     50         static const char* hex = "0123456789ABCDEF";
     51         uint8_t c = buffer[i];
     52         str[i*2] = hex[c >> 4];
     53         str[i*2 + 1] = hex[c & 0x0F];
     54     }
     55     return str;
     56 }
     57 
     58 std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
     59     char host_str[NI_MAXHOST] = { 0 };
     60     int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0,
     61                          NI_NUMERICHOST);
     62     if (rv == 0) return std::string(host_str);
     63     return std::string();
     64 }
     65 
     66 /* DNS struct helpers */
     67 
     68 const char* dnstype2str(unsigned dnstype) {
     69     static std::unordered_map<unsigned, const char*> kTypeStrs = {
     70         { ns_type::ns_t_a, "A" },
     71         { ns_type::ns_t_ns, "NS" },
     72         { ns_type::ns_t_md, "MD" },
     73         { ns_type::ns_t_mf, "MF" },
     74         { ns_type::ns_t_cname, "CNAME" },
     75         { ns_type::ns_t_soa, "SOA" },
     76         { ns_type::ns_t_mb, "MB" },
     77         { ns_type::ns_t_mb, "MG" },
     78         { ns_type::ns_t_mr, "MR" },
     79         { ns_type::ns_t_null, "NULL" },
     80         { ns_type::ns_t_wks, "WKS" },
     81         { ns_type::ns_t_ptr, "PTR" },
     82         { ns_type::ns_t_hinfo, "HINFO" },
     83         { ns_type::ns_t_minfo, "MINFO" },
     84         { ns_type::ns_t_mx, "MX" },
     85         { ns_type::ns_t_txt, "TXT" },
     86         { ns_type::ns_t_rp, "RP" },
     87         { ns_type::ns_t_afsdb, "AFSDB" },
     88         { ns_type::ns_t_x25, "X25" },
     89         { ns_type::ns_t_isdn, "ISDN" },
     90         { ns_type::ns_t_rt, "RT" },
     91         { ns_type::ns_t_nsap, "NSAP" },
     92         { ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
     93         { ns_type::ns_t_sig, "SIG" },
     94         { ns_type::ns_t_key, "KEY" },
     95         { ns_type::ns_t_px, "PX" },
     96         { ns_type::ns_t_gpos, "GPOS" },
     97         { ns_type::ns_t_aaaa, "AAAA" },
     98         { ns_type::ns_t_loc, "LOC" },
     99         { ns_type::ns_t_nxt, "NXT" },
    100         { ns_type::ns_t_eid, "EID" },
    101         { ns_type::ns_t_nimloc, "NIMLOC" },
    102         { ns_type::ns_t_srv, "SRV" },
    103         { ns_type::ns_t_naptr, "NAPTR" },
    104         { ns_type::ns_t_kx, "KX" },
    105         { ns_type::ns_t_cert, "CERT" },
    106         { ns_type::ns_t_a6, "A6" },
    107         { ns_type::ns_t_dname, "DNAME" },
    108         { ns_type::ns_t_sink, "SINK" },
    109         { ns_type::ns_t_opt, "OPT" },
    110         { ns_type::ns_t_apl, "APL" },
    111         { ns_type::ns_t_tkey, "TKEY" },
    112         { ns_type::ns_t_tsig, "TSIG" },
    113         { ns_type::ns_t_ixfr, "IXFR" },
    114         { ns_type::ns_t_axfr, "AXFR" },
    115         { ns_type::ns_t_mailb, "MAILB" },
    116         { ns_type::ns_t_maila, "MAILA" },
    117         { ns_type::ns_t_any, "ANY" },
    118         { ns_type::ns_t_zxfr, "ZXFR" },
    119     };
    120     auto it = kTypeStrs.find(dnstype);
    121     static const char* kUnknownStr{ "UNKNOWN" };
    122     if (it == kTypeStrs.end()) return kUnknownStr;
    123     return it->second;
    124 }
    125 
    126 const char* dnsclass2str(unsigned dnsclass) {
    127     static std::unordered_map<unsigned, const char*> kClassStrs = {
    128         { ns_class::ns_c_in , "Internet" },
    129         { 2, "CSNet" },
    130         { ns_class::ns_c_chaos, "ChaosNet" },
    131         { ns_class::ns_c_hs, "Hesiod" },
    132         { ns_class::ns_c_none, "none" },
    133         { ns_class::ns_c_any, "any" }
    134     };
    135     auto it = kClassStrs.find(dnsclass);
    136     static const char* kUnknownStr{ "UNKNOWN" };
    137     if (it == kClassStrs.end()) return kUnknownStr;
    138     return it->second;
    139     return "unknown";
    140 }
    141 
    142 struct DNSName {
    143     std::string name;
    144     const char* read(const char* buffer, const char* buffer_end);
    145     char* write(char* buffer, const char* buffer_end) const;
    146     const char* toString() const;
    147 private:
    148     const char* parseField(const char* buffer, const char* buffer_end,
    149                            bool* last);
    150 };
    151 
    152 const char* DNSName::toString() const {
    153     return name.c_str();
    154 }
    155 
    156 const char* DNSName::read(const char* buffer, const char* buffer_end) {
    157     const char* cur = buffer;
    158     bool last = false;
    159     do {
    160         cur = parseField(cur, buffer_end, &last);
    161         if (cur == nullptr) {
    162             ALOGI("parsing failed at line %d", __LINE__);
    163             return nullptr;
    164         }
    165     } while (!last);
    166     return cur;
    167 }
    168 
    169 char* DNSName::write(char* buffer, const char* buffer_end) const {
    170     char* buffer_cur = buffer;
    171     for (size_t pos = 0 ; pos < name.size() ; ) {
    172         size_t dot_pos = name.find('.', pos);
    173         if (dot_pos == std::string::npos) {
    174             // Sanity check, should never happen unless parseField is broken.
    175             ALOGI("logic error: all names are expected to end with a '.'");
    176             return nullptr;
    177         }
    178         size_t len = dot_pos - pos;
    179         if (len >= 256) {
    180             ALOGI("name component '%s' is %zu long, but max is 255",
    181                     name.substr(pos, dot_pos - pos).c_str(), len);
    182             return nullptr;
    183         }
    184         if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
    185             ALOGI("buffer overflow at line %d", __LINE__);
    186             return nullptr;
    187         }
    188         *buffer_cur++ = len;
    189         buffer_cur = std::copy(std::next(name.begin(), pos),
    190                                std::next(name.begin(), dot_pos),
    191                                buffer_cur);
    192         pos = dot_pos + 1;
    193     }
    194     // Write final zero.
    195     *buffer_cur++ = 0;
    196     return buffer_cur;
    197 }
    198 
    199 const char* DNSName::parseField(const char* buffer, const char* buffer_end,
    200                                 bool* last) {
    201     if (buffer + sizeof(uint8_t) > buffer_end) {
    202         ALOGI("parsing failed at line %d", __LINE__);
    203         return nullptr;
    204     }
    205     unsigned field_type = *buffer >> 6;
    206     unsigned ofs = *buffer & 0x3F;
    207     const char* cur = buffer + sizeof(uint8_t);
    208     if (field_type == 0) {
    209         // length + name component
    210         if (ofs == 0) {
    211             *last = true;
    212             return cur;
    213         }
    214         if (cur + ofs > buffer_end) {
    215             ALOGI("parsing failed at line %d", __LINE__);
    216             return nullptr;
    217         }
    218         name.append(cur, ofs);
    219         name.push_back('.');
    220         return cur + ofs;
    221     } else if (field_type == 3) {
    222         ALOGI("name compression not implemented");
    223         return nullptr;
    224     }
    225     ALOGI("invalid name field type");
    226     return nullptr;
    227 }
    228 
    229 struct DNSQuestion {
    230     DNSName qname;
    231     unsigned qtype;
    232     unsigned qclass;
    233     const char* read(const char* buffer, const char* buffer_end);
    234     char* write(char* buffer, const char* buffer_end) const;
    235     std::string toString() const;
    236 };
    237 
    238 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
    239     const char* cur = qname.read(buffer, buffer_end);
    240     if (cur == nullptr) {
    241         ALOGI("parsing failed at line %d", __LINE__);
    242         return nullptr;
    243     }
    244     if (cur + 2*sizeof(uint16_t) > buffer_end) {
    245         ALOGI("parsing failed at line %d", __LINE__);
    246         return nullptr;
    247     }
    248     qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
    249     qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
    250     return cur + 2*sizeof(uint16_t);
    251 }
    252 
    253 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
    254     char* buffer_cur = qname.write(buffer, buffer_end);
    255     if (buffer_cur == nullptr) return nullptr;
    256     if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
    257         ALOGI("buffer overflow on line %d", __LINE__);
    258         return nullptr;
    259     }
    260     *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
    261     *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
    262             htons(qclass);
    263     return buffer_cur + 2*sizeof(uint16_t);
    264 }
    265 
    266 std::string DNSQuestion::toString() const {
    267     char buffer[4096];
    268     int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
    269                        dnstype2str(qtype), dnsclass2str(qclass));
    270     return std::string(buffer, len);
    271 }
    272 
    273 struct DNSRecord {
    274     DNSName name;
    275     unsigned rtype;
    276     unsigned rclass;
    277     unsigned ttl;
    278     std::vector<char> rdata;
    279     const char* read(const char* buffer, const char* buffer_end);
    280     char* write(char* buffer, const char* buffer_end) const;
    281     std::string toString() const;
    282 private:
    283     struct IntFields {
    284         uint16_t rtype;
    285         uint16_t rclass;
    286         uint32_t ttl;
    287         uint16_t rdlen;
    288     } __attribute__((__packed__));
    289 
    290     const char* readIntFields(const char* buffer, const char* buffer_end,
    291             unsigned* rdlen);
    292     char* writeIntFields(unsigned rdlen, char* buffer,
    293                          const char* buffer_end) const;
    294 };
    295 
    296 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
    297     const char* cur = name.read(buffer, buffer_end);
    298     if (cur == nullptr) {
    299         ALOGI("parsing failed at line %d", __LINE__);
    300         return nullptr;
    301     }
    302     unsigned rdlen = 0;
    303     cur = readIntFields(cur, buffer_end, &rdlen);
    304     if (cur == nullptr) {
    305         ALOGI("parsing failed at line %d", __LINE__);
    306         return nullptr;
    307     }
    308     if (cur + rdlen > buffer_end) {
    309         ALOGI("parsing failed at line %d", __LINE__);
    310         return nullptr;
    311     }
    312     rdata.assign(cur, cur + rdlen);
    313     return cur + rdlen;
    314 }
    315 
    316 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
    317     char* buffer_cur = name.write(buffer, buffer_end);
    318     if (buffer_cur == nullptr) return nullptr;
    319     buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
    320     if (buffer_cur == nullptr) return nullptr;
    321     if (buffer_cur + rdata.size() > buffer_end) {
    322         ALOGI("buffer overflow on line %d", __LINE__);
    323         return nullptr;
    324     }
    325     return std::copy(rdata.begin(), rdata.end(), buffer_cur);
    326 }
    327 
    328 std::string DNSRecord::toString() const {
    329     char buffer[4096];
    330     int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
    331                        dnstype2str(rtype), dnsclass2str(rclass));
    332     return std::string(buffer, len);
    333 }
    334 
    335 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
    336                                      unsigned* rdlen) {
    337     if (buffer + sizeof(IntFields) > buffer_end ) {
    338         ALOGI("parsing failed at line %d", __LINE__);
    339         return nullptr;
    340     }
    341     const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
    342     rtype = ntohs(intfields.rtype);
    343     rclass = ntohs(intfields.rclass);
    344     ttl = ntohl(intfields.ttl);
    345     *rdlen = ntohs(intfields.rdlen);
    346     return buffer + sizeof(IntFields);
    347 }
    348 
    349 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
    350                                 const char* buffer_end) const {
    351     if (buffer + sizeof(IntFields) > buffer_end ) {
    352         ALOGI("buffer overflow on line %d", __LINE__);
    353         return nullptr;
    354     }
    355     auto& intfields = *reinterpret_cast<IntFields*>(buffer);
    356     intfields.rtype = htons(rtype);
    357     intfields.rclass = htons(rclass);
    358     intfields.ttl = htonl(ttl);
    359     intfields.rdlen = htons(rdlen);
    360     return buffer + sizeof(IntFields);
    361 }
    362 
    363 struct DNSHeader {
    364     unsigned id;
    365     bool ra;
    366     uint8_t rcode;
    367     bool qr;
    368     uint8_t opcode;
    369     bool aa;
    370     bool tr;
    371     bool rd;
    372     std::vector<DNSQuestion> questions;
    373     std::vector<DNSRecord> answers;
    374     std::vector<DNSRecord> authorities;
    375     std::vector<DNSRecord> additionals;
    376     const char* read(const char* buffer, const char* buffer_end);
    377     char* write(char* buffer, const char* buffer_end) const;
    378     std::string toString() const;
    379 
    380 private:
    381     struct Header {
    382         uint16_t id;
    383         uint8_t flags0;
    384         uint8_t flags1;
    385         uint16_t qdcount;
    386         uint16_t ancount;
    387         uint16_t nscount;
    388         uint16_t arcount;
    389     } __attribute__((__packed__));
    390 
    391     const char* readHeader(const char* buffer, const char* buffer_end,
    392                            unsigned* qdcount, unsigned* ancount,
    393                            unsigned* nscount, unsigned* arcount);
    394 };
    395 
    396 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
    397     unsigned qdcount;
    398     unsigned ancount;
    399     unsigned nscount;
    400     unsigned arcount;
    401     const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
    402                                  &nscount, &arcount);
    403     if (cur == nullptr) {
    404         ALOGI("parsing failed at line %d", __LINE__);
    405         return nullptr;
    406     }
    407     if (qdcount) {
    408         questions.resize(qdcount);
    409         for (unsigned i = 0 ; i < qdcount ; ++i) {
    410             cur = questions[i].read(cur, buffer_end);
    411             if (cur == nullptr) {
    412                 ALOGI("parsing failed at line %d", __LINE__);
    413                 return nullptr;
    414             }
    415         }
    416     }
    417     if (ancount) {
    418         answers.resize(ancount);
    419         for (unsigned i = 0 ; i < ancount ; ++i) {
    420             cur = answers[i].read(cur, buffer_end);
    421             if (cur == nullptr) {
    422                 ALOGI("parsing failed at line %d", __LINE__);
    423                 return nullptr;
    424             }
    425         }
    426     }
    427     if (nscount) {
    428         authorities.resize(nscount);
    429         for (unsigned i = 0 ; i < nscount ; ++i) {
    430             cur = authorities[i].read(cur, buffer_end);
    431             if (cur == nullptr) {
    432                 ALOGI("parsing failed at line %d", __LINE__);
    433                 return nullptr;
    434             }
    435         }
    436     }
    437     if (arcount) {
    438         additionals.resize(arcount);
    439         for (unsigned i = 0 ; i < arcount ; ++i) {
    440             cur = additionals[i].read(cur, buffer_end);
    441             if (cur == nullptr) {
    442                 ALOGI("parsing failed at line %d", __LINE__);
    443                 return nullptr;
    444             }
    445         }
    446     }
    447     return cur;
    448 }
    449 
    450 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
    451     if (buffer + sizeof(Header) > buffer_end) {
    452         ALOGI("buffer overflow on line %d", __LINE__);
    453         return nullptr;
    454     }
    455     Header& header = *reinterpret_cast<Header*>(buffer);
    456     // bytes 0-1
    457     header.id = htons(id);
    458     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
    459     header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
    460     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
    461     header.flags1 = rcode;
    462     // rest of header
    463     header.qdcount = htons(questions.size());
    464     header.ancount = htons(answers.size());
    465     header.nscount = htons(authorities.size());
    466     header.arcount = htons(additionals.size());
    467     char* buffer_cur = buffer + sizeof(Header);
    468     for (const DNSQuestion& question : questions) {
    469         buffer_cur = question.write(buffer_cur, buffer_end);
    470         if (buffer_cur == nullptr) return nullptr;
    471     }
    472     for (const DNSRecord& answer : answers) {
    473         buffer_cur = answer.write(buffer_cur, buffer_end);
    474         if (buffer_cur == nullptr) return nullptr;
    475     }
    476     for (const DNSRecord& authority : authorities) {
    477         buffer_cur = authority.write(buffer_cur, buffer_end);
    478         if (buffer_cur == nullptr) return nullptr;
    479     }
    480     for (const DNSRecord& additional : additionals) {
    481         buffer_cur = additional.write(buffer_cur, buffer_end);
    482         if (buffer_cur == nullptr) return nullptr;
    483     }
    484     return buffer_cur;
    485 }
    486 
    487 std::string DNSHeader::toString() const {
    488     // TODO
    489     return std::string();
    490 }
    491 
    492 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
    493                                   unsigned* qdcount, unsigned* ancount,
    494                                   unsigned* nscount, unsigned* arcount) {
    495     if (buffer + sizeof(Header) > buffer_end)
    496         return 0;
    497     const auto& header = *reinterpret_cast<const Header*>(buffer);
    498     // bytes 0-1
    499     id = ntohs(header.id);
    500     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
    501     qr = header.flags0 >> 7;
    502     opcode = (header.flags0 >> 3) & 0x0F;
    503     aa = (header.flags0 >> 2) & 1;
    504     tr = (header.flags0 >> 1) & 1;
    505     rd = header.flags0 & 1;
    506     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
    507     ra = header.flags1 >> 7;
    508     rcode = header.flags1 & 0xF;
    509     // rest of header
    510     *qdcount = ntohs(header.qdcount);
    511     *ancount = ntohs(header.ancount);
    512     *nscount = ntohs(header.nscount);
    513     *arcount = ntohs(header.arcount);
    514     return buffer + sizeof(Header);
    515 }
    516 
    517 /* DNS responder */
    518 
    519 DNSResponder::DNSResponder(std::string listen_address,
    520                            std::string listen_service, int poll_timeout_ms,
    521                            uint16_t error_rcode, double response_probability) :
    522     listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)),
    523     poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode),
    524     response_probability_(response_probability),
    525     socket_(-1), epoll_fd_(-1), terminate_(false) { }
    526 
    527 DNSResponder::~DNSResponder() {
    528     stopServer();
    529 }
    530 
    531 void DNSResponder::addMapping(const char* name, ns_type type,
    532         const char* addr) {
    533     std::lock_guard<std::mutex> lock(mappings_mutex_);
    534     auto it = mappings_.find(QueryKey(name, type));
    535     if (it != mappings_.end()) {
    536         ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
    537             "address %s", name, dnstype2str(type), it->second.c_str(),
    538             addr);
    539         it->second = addr;
    540         return;
    541     }
    542     mappings_.emplace(std::piecewise_construct,
    543                       std::forward_as_tuple(name, type),
    544                       std::forward_as_tuple(addr));
    545 }
    546 
    547 void DNSResponder::removeMapping(const char* name, ns_type type) {
    548     std::lock_guard<std::mutex> lock(mappings_mutex_);
    549     auto it = mappings_.find(QueryKey(name, type));
    550     if (it != mappings_.end()) {
    551         ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name,
    552             dnstype2str(type));
    553         return;
    554     }
    555     mappings_.erase(it);
    556 }
    557 
    558 void DNSResponder::setResponseProbability(double response_probability) {
    559     response_probability_ = response_probability;
    560 }
    561 
    562 bool DNSResponder::running() const {
    563     return socket_ != -1;
    564 }
    565 
    566 bool DNSResponder::startServer() {
    567     if (running()) {
    568         ALOGI("server already running");
    569         return false;
    570     }
    571     addrinfo ai_hints{
    572         .ai_family = AF_UNSPEC,
    573         .ai_socktype = SOCK_DGRAM,
    574         .ai_flags = AI_PASSIVE
    575     };
    576     addrinfo* ai_res;
    577     int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
    578                          &ai_hints, &ai_res);
    579     if (rv) {
    580         ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
    581             listen_service_.c_str(), gai_strerror(rv));
    582         return false;
    583     }
    584     int s = -1;
    585     for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) {
    586         s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
    587         if (s < 0) continue;
    588         const int one = 1;
    589         setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
    590         if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
    591             APLOGI("bind failed for socket %d", s);
    592             close(s);
    593             s = -1;
    594             continue;
    595         }
    596         std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
    597         ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
    598         break;
    599     }
    600     freeaddrinfo(ai_res);
    601     if (s < 0) {
    602         ALOGI("bind() failed");
    603         return false;
    604     }
    605 
    606     int flags = fcntl(s, F_GETFL, 0);
    607     if (flags < 0) flags = 0;
    608     if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
    609         APLOGI("fcntl(F_SETFL) failed for socket %d", s);
    610         close(s);
    611         return false;
    612     }
    613 
    614     int ep_fd = epoll_create(1);
    615     if (ep_fd < 0) {
    616         char error_msg[512] = { 0 };
    617         if (strerror_r(errno, error_msg, sizeof(error_msg)))
    618             strncpy(error_msg, "UNKNOWN", sizeof(error_msg));
    619         APLOGI("epoll_create() failed: %s", error_msg);
    620         close(s);
    621         return false;
    622     }
    623     epoll_event ev;
    624     ev.events = EPOLLIN;
    625     ev.data.fd = s;
    626     if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) {
    627         APLOGI("epoll_ctl() failed for socket %d", s);
    628         close(ep_fd);
    629         close(s);
    630         return false;
    631     }
    632 
    633     epoll_fd_ = ep_fd;
    634     socket_ = s;
    635     {
    636         std::lock_guard<std::mutex> lock(update_mutex_);
    637         handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
    638     }
    639     ALOGI("server started successfully");
    640     return true;
    641 }
    642 
    643 bool DNSResponder::stopServer() {
    644     std::lock_guard<std::mutex> lock(update_mutex_);
    645     if (!running()) {
    646         ALOGI("server not running");
    647         return false;
    648     }
    649     if (terminate_) {
    650         ALOGI("LOGIC ERROR");
    651         return false;
    652     }
    653     ALOGI("stopping server");
    654     terminate_ = true;
    655     handler_thread_.join();
    656     close(epoll_fd_);
    657     close(socket_);
    658     terminate_ = false;
    659     socket_ = -1;
    660     ALOGI("server stopped successfully");
    661     return true;
    662 }
    663 
    664 std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
    665     std::lock_guard<std::mutex> lock(queries_mutex_);
    666     return queries_;
    667 }
    668 
    669 void DNSResponder::clearQueries() {
    670     std::lock_guard<std::mutex> lock(queries_mutex_);
    671     queries_.clear();
    672 }
    673 
    674 void DNSResponder::requestHandler() {
    675     epoll_event evs[1];
    676     while (!terminate_) {
    677         int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_);
    678         if (n == 0) continue;
    679         if (n < 0) {
    680             ALOGI("epoll_wait() failed");
    681             // TODO(imaipi): terminate on error.
    682             return;
    683         }
    684         char buffer[4096];
    685         sockaddr_storage sa;
    686         socklen_t sa_len = sizeof(sa);
    687         ssize_t len;
    688         do {
    689             len = recvfrom(socket_, buffer, sizeof(buffer), 0,
    690                            (sockaddr*) &sa, &sa_len);
    691         } while (len < 0 && (errno == EAGAIN || errno == EINTR));
    692         if (len <= 0) {
    693             ALOGI("recvfrom() failed");
    694             continue;
    695         }
    696         ALOGI("read %zd bytes", len);
    697         char response[4096];
    698         size_t response_len = sizeof(response);
    699         if (handleDNSRequest(buffer, len, response, &response_len) &&
    700             response_len > 0) {
    701             len = sendto(socket_, response, response_len, 0,
    702                          reinterpret_cast<const sockaddr*>(&sa), sa_len);
    703             std::string host_str =
    704                 addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
    705             if (len > 0) {
    706                 ALOGI("sent %zu bytes to %s", len, host_str.c_str());
    707             } else {
    708                 APLOGI("sendto() failed for %s", host_str.c_str());
    709             }
    710             // Test that the response is actually a correct DNS message.
    711             const char* response_end = response + len;
    712             DNSHeader header;
    713             const char* cur = header.read(response, response_end);
    714             if (cur == nullptr) ALOGI("response is flawed");
    715 
    716         } else {
    717             ALOGI("not responding");
    718         }
    719     }
    720 }
    721 
    722 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
    723                                     char* response, size_t* response_len)
    724                                     const {
    725     ALOGI("request: '%s'", str2hex(buffer, len).c_str());
    726     const char* buffer_end = buffer + len;
    727     DNSHeader header;
    728     const char* cur = header.read(buffer, buffer_end);
    729     // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
    730     if (cur == nullptr) {
    731         ALOGI("failed to parse query");
    732         return false;
    733     }
    734     if (header.qr) {
    735         ALOGI("response received instead of a query");
    736         return false;
    737     }
    738     if (header.opcode != ns_opcode::ns_o_query) {
    739         ALOGI("unsupported request opcode received");
    740         return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
    741                                  response_len);
    742     }
    743     if (header.questions.empty()) {
    744         ALOGI("no questions present");
    745         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
    746                                  response_len);
    747     }
    748     if (!header.answers.empty()) {
    749         ALOGI("already %zu answers present in query", header.answers.size());
    750         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
    751                                  response_len);
    752     }
    753     {
    754         std::lock_guard<std::mutex> lock(queries_mutex_);
    755         for (const DNSQuestion& question : header.questions) {
    756             queries_.push_back(make_pair(question.qname.name,
    757                                          ns_type(question.qtype)));
    758         }
    759     }
    760 
    761     // Ignore requests with the preset probability.
    762     auto constexpr bound = std::numeric_limits<unsigned>::max();
    763     if (arc4random_uniform(bound) > bound*response_probability_) {
    764         ALOGI("returning SRVFAIL in accordance with probability distribution");
    765         return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
    766                                  response_len);
    767     }
    768 
    769     for (const DNSQuestion& question : header.questions) {
    770         if (question.qclass != ns_class::ns_c_in &&
    771             question.qclass != ns_class::ns_c_any) {
    772             ALOGI("unsupported question class %u", question.qclass);
    773             return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
    774                                      response_len);
    775         }
    776         if (!addAnswerRecords(question, &header.answers)) {
    777             return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
    778                                      response_len);
    779         }
    780     }
    781     header.qr = true;
    782     char* response_cur = header.write(response, response + *response_len);
    783     if (response_cur == nullptr) {
    784         return false;
    785     }
    786     *response_len = response_cur - response;
    787     return true;
    788 }
    789 
    790 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
    791                                     std::vector<DNSRecord>* answers) const {
    792     auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
    793     if (it == mappings_.end()) {
    794         // TODO(imaipi): handle correctly
    795         ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
    796             question.qname.name.c_str(), dnstype2str(question.qtype));
    797         return true;
    798     }
    799     ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(),
    800         dnstype2str(question.qtype), it->second.c_str());
    801     DNSRecord record;
    802     record.name = question.qname;
    803     record.rtype = question.qtype;
    804     record.rclass = ns_class::ns_c_in;
    805     record.ttl = 5;  // seconds
    806     if (question.qtype == ns_type::ns_t_a) {
    807         record.rdata.resize(4);
    808         if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
    809             ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
    810             return false;
    811         }
    812     } else if (question.qtype == ns_type::ns_t_aaaa) {
    813         record.rdata.resize(16);
    814         if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
    815             ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
    816             return false;
    817         }
    818     } else {
    819         ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
    820         return false;
    821     }
    822     answers->push_back(std::move(record));
    823     return true;
    824 }
    825 
    826 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
    827                                      char* response, size_t* response_len)
    828                                      const {
    829     header->answers.clear();
    830     header->authorities.clear();
    831     header->additionals.clear();
    832     header->rcode = rcode;
    833     header->qr = true;
    834     char* response_cur = header->write(response, response + *response_len);
    835     if (response_cur == nullptr) return false;
    836     *response_len = response_cur - response;
    837     return true;
    838 }
    839 
    840 }  // namespace test
    841 
    842