Home | History | Annotate | Download | only in dns_responder
      1 /*
      2  * Copyright (C) 2016 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
#include "dns_responder.h"

#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <algorithm>
#include <iostream>
#include <limits>
#include <unordered_map>
#include <utility>
#include <vector>

#define LOG_TAG "DNSResponder"
#include <log/log.h>
     36 
     37 namespace test {
     38 
     39 std::string errno2str() {
     40     char error_msg[512] = { 0 };
     41     if (strerror_r(errno, error_msg, sizeof(error_msg)))
     42         return std::string();
     43     return std::string(error_msg);
     44 }
     45 
// Logs |fmt| with its arguments, followed by ": [<errno>] <errno description>".
// At least one argument must follow |fmt|: the expansion pastes __VA_ARGS__
// directly before ", errno", so an empty argument list would not compile.
// NOTE(review): errno and errno2str() are separate function arguments whose
// evaluation order is unspecified; an argument expression that modifies errno
// could make the number and the text disagree.
#define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
     47 
     48 std::string str2hex(const char* buffer, size_t len) {
     49     std::string str(len*2, '\0');
     50     for (size_t i = 0 ; i < len ; ++i) {
     51         static const char* hex = "0123456789ABCDEF";
     52         uint8_t c = buffer[i];
     53         str[i*2] = hex[c >> 4];
     54         str[i*2 + 1] = hex[c & 0x0F];
     55     }
     56     return str;
     57 }
     58 
     59 std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
     60     char host_str[NI_MAXHOST] = { 0 };
     61     int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0,
     62                          NI_NUMERICHOST);
     63     if (rv == 0) return std::string(host_str);
     64     return std::string();
     65 }
     66 
     67 /* DNS struct helpers */
     68 
     69 const char* dnstype2str(unsigned dnstype) {
     70     static std::unordered_map<unsigned, const char*> kTypeStrs = {
     71         { ns_type::ns_t_a, "A" },
     72         { ns_type::ns_t_ns, "NS" },
     73         { ns_type::ns_t_md, "MD" },
     74         { ns_type::ns_t_mf, "MF" },
     75         { ns_type::ns_t_cname, "CNAME" },
     76         { ns_type::ns_t_soa, "SOA" },
     77         { ns_type::ns_t_mb, "MB" },
     78         { ns_type::ns_t_mb, "MG" },
     79         { ns_type::ns_t_mr, "MR" },
     80         { ns_type::ns_t_null, "NULL" },
     81         { ns_type::ns_t_wks, "WKS" },
     82         { ns_type::ns_t_ptr, "PTR" },
     83         { ns_type::ns_t_hinfo, "HINFO" },
     84         { ns_type::ns_t_minfo, "MINFO" },
     85         { ns_type::ns_t_mx, "MX" },
     86         { ns_type::ns_t_txt, "TXT" },
     87         { ns_type::ns_t_rp, "RP" },
     88         { ns_type::ns_t_afsdb, "AFSDB" },
     89         { ns_type::ns_t_x25, "X25" },
     90         { ns_type::ns_t_isdn, "ISDN" },
     91         { ns_type::ns_t_rt, "RT" },
     92         { ns_type::ns_t_nsap, "NSAP" },
     93         { ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
     94         { ns_type::ns_t_sig, "SIG" },
     95         { ns_type::ns_t_key, "KEY" },
     96         { ns_type::ns_t_px, "PX" },
     97         { ns_type::ns_t_gpos, "GPOS" },
     98         { ns_type::ns_t_aaaa, "AAAA" },
     99         { ns_type::ns_t_loc, "LOC" },
    100         { ns_type::ns_t_nxt, "NXT" },
    101         { ns_type::ns_t_eid, "EID" },
    102         { ns_type::ns_t_nimloc, "NIMLOC" },
    103         { ns_type::ns_t_srv, "SRV" },
    104         { ns_type::ns_t_naptr, "NAPTR" },
    105         { ns_type::ns_t_kx, "KX" },
    106         { ns_type::ns_t_cert, "CERT" },
    107         { ns_type::ns_t_a6, "A6" },
    108         { ns_type::ns_t_dname, "DNAME" },
    109         { ns_type::ns_t_sink, "SINK" },
    110         { ns_type::ns_t_opt, "OPT" },
    111         { ns_type::ns_t_apl, "APL" },
    112         { ns_type::ns_t_tkey, "TKEY" },
    113         { ns_type::ns_t_tsig, "TSIG" },
    114         { ns_type::ns_t_ixfr, "IXFR" },
    115         { ns_type::ns_t_axfr, "AXFR" },
    116         { ns_type::ns_t_mailb, "MAILB" },
    117         { ns_type::ns_t_maila, "MAILA" },
    118         { ns_type::ns_t_any, "ANY" },
    119         { ns_type::ns_t_zxfr, "ZXFR" },
    120     };
    121     auto it = kTypeStrs.find(dnstype);
    122     static const char* kUnknownStr{ "UNKNOWN" };
    123     if (it == kTypeStrs.end()) return kUnknownStr;
    124     return it->second;
    125 }
    126 
    127 const char* dnsclass2str(unsigned dnsclass) {
    128     static std::unordered_map<unsigned, const char*> kClassStrs = {
    129         { ns_class::ns_c_in , "Internet" },
    130         { 2, "CSNet" },
    131         { ns_class::ns_c_chaos, "ChaosNet" },
    132         { ns_class::ns_c_hs, "Hesiod" },
    133         { ns_class::ns_c_none, "none" },
    134         { ns_class::ns_c_any, "any" }
    135     };
    136     auto it = kClassStrs.find(dnsclass);
    137     static const char* kUnknownStr{ "UNKNOWN" };
    138     if (it == kClassStrs.end()) return kUnknownStr;
    139     return it->second;
    140     return "unknown";
    141 }
    142 
// A domain name kept in dotted form where every label is followed by '.', so
// "www.example.com" is stored as "www.example.com." and the root name as "".
// Only uncompressed wire-format names are supported.
struct DNSName {
    // Dotted name; each label is terminated by '.' (see parseField()).
    std::string name;
    // Parses a wire-format name starting at |buffer| and stores its labels in
    // |name|. Returns a pointer just past the name, or nullptr on error.
    const char* read(const char* buffer, const char* buffer_end);
    // Serializes |name| in wire format at |buffer|. Returns a pointer just past
    // the written bytes, or nullptr on error.
    char* write(char* buffer, const char* buffer_end) const;
    // Returns |name| as a C string; valid while this object is unmodified.
    const char* toString() const;
private:
    // Parses one label at |buffer|; sets *last when the zero-length label that
    // terminates the name is reached.
    const char* parseField(const char* buffer, const char* buffer_end,
                           bool* last);
};
    152 
    153 const char* DNSName::toString() const {
    154     return name.c_str();
    155 }
    156 
    157 const char* DNSName::read(const char* buffer, const char* buffer_end) {
    158     const char* cur = buffer;
    159     bool last = false;
    160     do {
    161         cur = parseField(cur, buffer_end, &last);
    162         if (cur == nullptr) {
    163             ALOGI("parsing failed at line %d", __LINE__);
    164             return nullptr;
    165         }
    166     } while (!last);
    167     return cur;
    168 }
    169 
    170 char* DNSName::write(char* buffer, const char* buffer_end) const {
    171     char* buffer_cur = buffer;
    172     for (size_t pos = 0 ; pos < name.size() ; ) {
    173         size_t dot_pos = name.find('.', pos);
    174         if (dot_pos == std::string::npos) {
    175             // Sanity check, should never happen unless parseField is broken.
    176             ALOGI("logic error: all names are expected to end with a '.'");
    177             return nullptr;
    178         }
    179         size_t len = dot_pos - pos;
    180         if (len >= 256) {
    181             ALOGI("name component '%s' is %zu long, but max is 255",
    182                     name.substr(pos, dot_pos - pos).c_str(), len);
    183             return nullptr;
    184         }
    185         if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
    186             ALOGI("buffer overflow at line %d", __LINE__);
    187             return nullptr;
    188         }
    189         *buffer_cur++ = len;
    190         buffer_cur = std::copy(std::next(name.begin(), pos),
    191                                std::next(name.begin(), dot_pos),
    192                                buffer_cur);
    193         pos = dot_pos + 1;
    194     }
    195     // Write final zero.
    196     *buffer_cur++ = 0;
    197     return buffer_cur;
    198 }
    199 
    200 const char* DNSName::parseField(const char* buffer, const char* buffer_end,
    201                                 bool* last) {
    202     if (buffer + sizeof(uint8_t) > buffer_end) {
    203         ALOGI("parsing failed at line %d", __LINE__);
    204         return nullptr;
    205     }
    206     unsigned field_type = *buffer >> 6;
    207     unsigned ofs = *buffer & 0x3F;
    208     const char* cur = buffer + sizeof(uint8_t);
    209     if (field_type == 0) {
    210         // length + name component
    211         if (ofs == 0) {
    212             *last = true;
    213             return cur;
    214         }
    215         if (cur + ofs > buffer_end) {
    216             ALOGI("parsing failed at line %d", __LINE__);
    217             return nullptr;
    218         }
    219         name.append(cur, ofs);
    220         name.push_back('.');
    221         return cur + ofs;
    222     } else if (field_type == 3) {
    223         ALOGI("name compression not implemented");
    224         return nullptr;
    225     }
    226     ALOGI("invalid name field type");
    227     return nullptr;
    228 }
    229 
    230 struct DNSQuestion {
    231     DNSName qname;
    232     unsigned qtype;
    233     unsigned qclass;
    234     const char* read(const char* buffer, const char* buffer_end);
    235     char* write(char* buffer, const char* buffer_end) const;
    236     std::string toString() const;
    237 };
    238 
    239 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
    240     const char* cur = qname.read(buffer, buffer_end);
    241     if (cur == nullptr) {
    242         ALOGI("parsing failed at line %d", __LINE__);
    243         return nullptr;
    244     }
    245     if (cur + 2*sizeof(uint16_t) > buffer_end) {
    246         ALOGI("parsing failed at line %d", __LINE__);
    247         return nullptr;
    248     }
    249     qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
    250     qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
    251     return cur + 2*sizeof(uint16_t);
    252 }
    253 
    254 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
    255     char* buffer_cur = qname.write(buffer, buffer_end);
    256     if (buffer_cur == nullptr) return nullptr;
    257     if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
    258         ALOGI("buffer overflow on line %d", __LINE__);
    259         return nullptr;
    260     }
    261     *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
    262     *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
    263             htons(qclass);
    264     return buffer_cur + 2*sizeof(uint16_t);
    265 }
    266 
    267 std::string DNSQuestion::toString() const {
    268     char buffer[4096];
    269     int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
    270                        dnstype2str(qtype), dnsclass2str(qclass));
    271     return std::string(buffer, len);
    272 }
    273 
    274 struct DNSRecord {
    275     DNSName name;
    276     unsigned rtype;
    277     unsigned rclass;
    278     unsigned ttl;
    279     std::vector<char> rdata;
    280     const char* read(const char* buffer, const char* buffer_end);
    281     char* write(char* buffer, const char* buffer_end) const;
    282     std::string toString() const;
    283 private:
    284     struct IntFields {
    285         uint16_t rtype;
    286         uint16_t rclass;
    287         uint32_t ttl;
    288         uint16_t rdlen;
    289     } __attribute__((__packed__));
    290 
    291     const char* readIntFields(const char* buffer, const char* buffer_end,
    292             unsigned* rdlen);
    293     char* writeIntFields(unsigned rdlen, char* buffer,
    294                          const char* buffer_end) const;
    295 };
    296 
    297 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
    298     const char* cur = name.read(buffer, buffer_end);
    299     if (cur == nullptr) {
    300         ALOGI("parsing failed at line %d", __LINE__);
    301         return nullptr;
    302     }
    303     unsigned rdlen = 0;
    304     cur = readIntFields(cur, buffer_end, &rdlen);
    305     if (cur == nullptr) {
    306         ALOGI("parsing failed at line %d", __LINE__);
    307         return nullptr;
    308     }
    309     if (cur + rdlen > buffer_end) {
    310         ALOGI("parsing failed at line %d", __LINE__);
    311         return nullptr;
    312     }
    313     rdata.assign(cur, cur + rdlen);
    314     return cur + rdlen;
    315 }
    316 
    317 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
    318     char* buffer_cur = name.write(buffer, buffer_end);
    319     if (buffer_cur == nullptr) return nullptr;
    320     buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
    321     if (buffer_cur == nullptr) return nullptr;
    322     if (buffer_cur + rdata.size() > buffer_end) {
    323         ALOGI("buffer overflow on line %d", __LINE__);
    324         return nullptr;
    325     }
    326     return std::copy(rdata.begin(), rdata.end(), buffer_cur);
    327 }
    328 
    329 std::string DNSRecord::toString() const {
    330     char buffer[4096];
    331     int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
    332                        dnstype2str(rtype), dnsclass2str(rclass));
    333     return std::string(buffer, len);
    334 }
    335 
    336 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
    337                                      unsigned* rdlen) {
    338     if (buffer + sizeof(IntFields) > buffer_end ) {
    339         ALOGI("parsing failed at line %d", __LINE__);
    340         return nullptr;
    341     }
    342     const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
    343     rtype = ntohs(intfields.rtype);
    344     rclass = ntohs(intfields.rclass);
    345     ttl = ntohl(intfields.ttl);
    346     *rdlen = ntohs(intfields.rdlen);
    347     return buffer + sizeof(IntFields);
    348 }
    349 
    350 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
    351                                 const char* buffer_end) const {
    352     if (buffer + sizeof(IntFields) > buffer_end ) {
    353         ALOGI("buffer overflow on line %d", __LINE__);
    354         return nullptr;
    355     }
    356     auto& intfields = *reinterpret_cast<IntFields*>(buffer);
    357     intfields.rtype = htons(rtype);
    358     intfields.rclass = htons(rclass);
    359     intfields.ttl = htonl(ttl);
    360     intfields.rdlen = htons(rdlen);
    361     return buffer + sizeof(IntFields);
    362 }
    363 
    364 struct DNSHeader {
    365     unsigned id;
    366     bool ra;
    367     uint8_t rcode;
    368     bool qr;
    369     uint8_t opcode;
    370     bool aa;
    371     bool tr;
    372     bool rd;
    373     std::vector<DNSQuestion> questions;
    374     std::vector<DNSRecord> answers;
    375     std::vector<DNSRecord> authorities;
    376     std::vector<DNSRecord> additionals;
    377     const char* read(const char* buffer, const char* buffer_end);
    378     char* write(char* buffer, const char* buffer_end) const;
    379     std::string toString() const;
    380 
    381 private:
    382     struct Header {
    383         uint16_t id;
    384         uint8_t flags0;
    385         uint8_t flags1;
    386         uint16_t qdcount;
    387         uint16_t ancount;
    388         uint16_t nscount;
    389         uint16_t arcount;
    390     } __attribute__((__packed__));
    391 
    392     const char* readHeader(const char* buffer, const char* buffer_end,
    393                            unsigned* qdcount, unsigned* ancount,
    394                            unsigned* nscount, unsigned* arcount);
    395 };
    396 
    397 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
    398     unsigned qdcount;
    399     unsigned ancount;
    400     unsigned nscount;
    401     unsigned arcount;
    402     const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
    403                                  &nscount, &arcount);
    404     if (cur == nullptr) {
    405         ALOGI("parsing failed at line %d", __LINE__);
    406         return nullptr;
    407     }
    408     if (qdcount) {
    409         questions.resize(qdcount);
    410         for (unsigned i = 0 ; i < qdcount ; ++i) {
    411             cur = questions[i].read(cur, buffer_end);
    412             if (cur == nullptr) {
    413                 ALOGI("parsing failed at line %d", __LINE__);
    414                 return nullptr;
    415             }
    416         }
    417     }
    418     if (ancount) {
    419         answers.resize(ancount);
    420         for (unsigned i = 0 ; i < ancount ; ++i) {
    421             cur = answers[i].read(cur, buffer_end);
    422             if (cur == nullptr) {
    423                 ALOGI("parsing failed at line %d", __LINE__);
    424                 return nullptr;
    425             }
    426         }
    427     }
    428     if (nscount) {
    429         authorities.resize(nscount);
    430         for (unsigned i = 0 ; i < nscount ; ++i) {
    431             cur = authorities[i].read(cur, buffer_end);
    432             if (cur == nullptr) {
    433                 ALOGI("parsing failed at line %d", __LINE__);
    434                 return nullptr;
    435             }
    436         }
    437     }
    438     if (arcount) {
    439         additionals.resize(arcount);
    440         for (unsigned i = 0 ; i < arcount ; ++i) {
    441             cur = additionals[i].read(cur, buffer_end);
    442             if (cur == nullptr) {
    443                 ALOGI("parsing failed at line %d", __LINE__);
    444                 return nullptr;
    445             }
    446         }
    447     }
    448     return cur;
    449 }
    450 
    451 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
    452     if (buffer + sizeof(Header) > buffer_end) {
    453         ALOGI("buffer overflow on line %d", __LINE__);
    454         return nullptr;
    455     }
    456     Header& header = *reinterpret_cast<Header*>(buffer);
    457     // bytes 0-1
    458     header.id = htons(id);
    459     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
    460     header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
    461     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
    462     header.flags1 = rcode;
    463     // rest of header
    464     header.qdcount = htons(questions.size());
    465     header.ancount = htons(answers.size());
    466     header.nscount = htons(authorities.size());
    467     header.arcount = htons(additionals.size());
    468     char* buffer_cur = buffer + sizeof(Header);
    469     for (const DNSQuestion& question : questions) {
    470         buffer_cur = question.write(buffer_cur, buffer_end);
    471         if (buffer_cur == nullptr) return nullptr;
    472     }
    473     for (const DNSRecord& answer : answers) {
    474         buffer_cur = answer.write(buffer_cur, buffer_end);
    475         if (buffer_cur == nullptr) return nullptr;
    476     }
    477     for (const DNSRecord& authority : authorities) {
    478         buffer_cur = authority.write(buffer_cur, buffer_end);
    479         if (buffer_cur == nullptr) return nullptr;
    480     }
    481     for (const DNSRecord& additional : additionals) {
    482         buffer_cur = additional.write(buffer_cur, buffer_end);
    483         if (buffer_cur == nullptr) return nullptr;
    484     }
    485     return buffer_cur;
    486 }
    487 
    488 std::string DNSHeader::toString() const {
    489     // TODO
    490     return std::string();
    491 }
    492 
    493 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
    494                                   unsigned* qdcount, unsigned* ancount,
    495                                   unsigned* nscount, unsigned* arcount) {
    496     if (buffer + sizeof(Header) > buffer_end)
    497         return 0;
    498     const auto& header = *reinterpret_cast<const Header*>(buffer);
    499     // bytes 0-1
    500     id = ntohs(header.id);
    501     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
    502     qr = header.flags0 >> 7;
    503     opcode = (header.flags0 >> 3) & 0x0F;
    504     aa = (header.flags0 >> 2) & 1;
    505     tr = (header.flags0 >> 1) & 1;
    506     rd = header.flags0 & 1;
    507     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
    508     ra = header.flags1 >> 7;
    509     rcode = header.flags1 & 0xF;
    510     // rest of header
    511     *qdcount = ntohs(header.qdcount);
    512     *ancount = ntohs(header.ancount);
    513     *nscount = ntohs(header.nscount);
    514     *arcount = ntohs(header.arcount);
    515     return buffer + sizeof(Header);
    516 }
    517 
    518 /* DNS responder */
    519 
// Creates a responder that, once startServer() is called, serves UDP on
// |listen_address|:|listen_service|.
//  - |poll_timeout_ms|: bounds each epoll_wait() in the handler loop, and so
//    how quickly the handler notices a stop request.
//  - |response_probability|: the chance that a query is answered normally
//    rather than with SERVFAIL.
//  - |error_rcode|: stored; its use is not visible in this file.
// The socket and epoll descriptors start at -1 (not running).
DNSResponder::DNSResponder(std::string listen_address,
                           std::string listen_service, int poll_timeout_ms,
                           uint16_t error_rcode, double response_probability) :
    listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)),
    poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode),
    response_probability_(response_probability),
    socket_(-1), epoll_fd_(-1), terminate_(false) { }
    527 
// Stops the server if it is running (joining the handler thread and closing
// its descriptors); logs "server not running" otherwise.
DNSResponder::~DNSResponder() {
    stopServer();
}
    531 
    532 void DNSResponder::addMapping(const char* name, ns_type type,
    533         const char* addr) {
    534     std::lock_guard<std::mutex> lock(mappings_mutex_);
    535     auto it = mappings_.find(QueryKey(name, type));
    536     if (it != mappings_.end()) {
    537         ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
    538             "address %s", name, dnstype2str(type), it->second.c_str(),
    539             addr);
    540         it->second = addr;
    541         return;
    542     }
    543     mappings_.emplace(std::piecewise_construct,
    544                       std::forward_as_tuple(name, type),
    545                       std::forward_as_tuple(addr));
    546 }
    547 
    548 void DNSResponder::removeMapping(const char* name, ns_type type) {
    549     std::lock_guard<std::mutex> lock(mappings_mutex_);
    550     auto it = mappings_.find(QueryKey(name, type));
    551     if (it != mappings_.end()) {
    552         ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name,
    553             dnstype2str(type));
    554         return;
    555     }
    556     mappings_.erase(it);
    557 }
    558 
// Sets the probability with which handleDNSRequest() answers a query normally
// instead of returning SERVFAIL.
// NOTE(review): handleDNSRequest() reads response_probability_ on the handler
// thread and no lock is taken here. Confirm in the header whether the member
// is atomic.
void DNSResponder::setResponseProbability(double response_probability) {
    response_probability_ = response_probability;
}
    562 
// Returns true while a listening socket is held, i.e. between a successful
// startServer() and the matching stopServer().
bool DNSResponder::running() const {
    return socket_ != -1;
}
    566 
    567 bool DNSResponder::startServer() {
    568     if (running()) {
    569         ALOGI("server already running");
    570         return false;
    571     }
    572     addrinfo ai_hints{
    573         .ai_family = AF_UNSPEC,
    574         .ai_socktype = SOCK_DGRAM,
    575         .ai_flags = AI_PASSIVE
    576     };
    577     addrinfo* ai_res;
    578     int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
    579                          &ai_hints, &ai_res);
    580     if (rv) {
    581         ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
    582             listen_service_.c_str(), gai_strerror(rv));
    583         return false;
    584     }
    585     int s = -1;
    586     for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) {
    587         s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
    588         if (s < 0) continue;
    589         const int one = 1;
    590         setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
    591         if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
    592             APLOGI("bind failed for socket %d", s);
    593             close(s);
    594             s = -1;
    595             continue;
    596         }
    597         std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
    598         ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
    599         break;
    600     }
    601     freeaddrinfo(ai_res);
    602     if (s < 0) {
    603         ALOGI("bind() failed");
    604         return false;
    605     }
    606 
    607     int flags = fcntl(s, F_GETFL, 0);
    608     if (flags < 0) flags = 0;
    609     if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
    610         APLOGI("fcntl(F_SETFL) failed for socket %d", s);
    611         close(s);
    612         return false;
    613     }
    614 
    615     int ep_fd = epoll_create(1);
    616     if (ep_fd < 0) {
    617         char error_msg[512] = { 0 };
    618         if (strerror_r(errno, error_msg, sizeof(error_msg)))
    619             strncpy(error_msg, "UNKNOWN", sizeof(error_msg));
    620         APLOGI("epoll_create() failed: %s", error_msg);
    621         close(s);
    622         return false;
    623     }
    624     epoll_event ev;
    625     ev.events = EPOLLIN;
    626     ev.data.fd = s;
    627     if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) {
    628         APLOGI("epoll_ctl() failed for socket %d", s);
    629         close(ep_fd);
    630         close(s);
    631         return false;
    632     }
    633 
    634     epoll_fd_ = ep_fd;
    635     socket_ = s;
    636     {
    637         std::lock_guard<std::mutex> lock(update_mutex_);
    638         handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
    639     }
    640     ALOGI("server started successfully");
    641     return true;
    642 }
    643 
    644 bool DNSResponder::stopServer() {
    645     std::lock_guard<std::mutex> lock(update_mutex_);
    646     if (!running()) {
    647         ALOGI("server not running");
    648         return false;
    649     }
    650     if (terminate_) {
    651         ALOGI("LOGIC ERROR");
    652         return false;
    653     }
    654     ALOGI("stopping server");
    655     terminate_ = true;
    656     handler_thread_.join();
    657     close(epoll_fd_);
    658     close(socket_);
    659     terminate_ = false;
    660     socket_ = -1;
    661     ALOGI("server stopped successfully");
    662     return true;
    663 }
    664 
    665 std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
    666     std::lock_guard<std::mutex> lock(queries_mutex_);
    667     return queries_;
    668 }
    669 
    670 void DNSResponder::clearQueries() {
    671     std::lock_guard<std::mutex> lock(queries_mutex_);
    672     queries_.clear();
    673 }
    674 
    675 void DNSResponder::requestHandler() {
    676     epoll_event evs[1];
    677     while (!terminate_) {
    678         int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_);
    679         if (n == 0) continue;
    680         if (n < 0) {
    681             ALOGI("epoll_wait() failed");
    682             // TODO(imaipi): terminate on error.
    683             return;
    684         }
    685         char buffer[4096];
    686         sockaddr_storage sa;
    687         socklen_t sa_len = sizeof(sa);
    688         ssize_t len;
    689         do {
    690             len = recvfrom(socket_, buffer, sizeof(buffer), 0,
    691                            (sockaddr*) &sa, &sa_len);
    692         } while (len < 0 && (errno == EAGAIN || errno == EINTR));
    693         if (len <= 0) {
    694             ALOGI("recvfrom() failed");
    695             continue;
    696         }
    697         ALOGI("read %zd bytes", len);
    698         char response[4096];
    699         size_t response_len = sizeof(response);
    700         if (handleDNSRequest(buffer, len, response, &response_len) &&
    701             response_len > 0) {
    702             len = sendto(socket_, response, response_len, 0,
    703                          reinterpret_cast<const sockaddr*>(&sa), sa_len);
    704             std::string host_str =
    705                 addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
    706             if (len > 0) {
    707                 ALOGI("sent %zu bytes to %s", len, host_str.c_str());
    708             } else {
    709                 APLOGI("sendto() failed for %s", host_str.c_str());
    710             }
    711             // Test that the response is actually a correct DNS message.
    712             const char* response_end = response + len;
    713             DNSHeader header;
    714             const char* cur = header.read(response, response_end);
    715             if (cur == nullptr) ALOGI("response is flawed");
    716 
    717         } else {
    718             ALOGI("not responding");
    719         }
    720     }
    721 }
    722 
    723 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
    724                                     char* response, size_t* response_len)
    725                                     const {
    726     ALOGI("request: '%s'", str2hex(buffer, len).c_str());
    727     const char* buffer_end = buffer + len;
    728     DNSHeader header;
    729     const char* cur = header.read(buffer, buffer_end);
    730     // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
    731     if (cur == nullptr) {
    732         ALOGI("failed to parse query");
    733         return false;
    734     }
    735     if (header.qr) {
    736         ALOGI("response received instead of a query");
    737         return false;
    738     }
    739     if (header.opcode != ns_opcode::ns_o_query) {
    740         ALOGI("unsupported request opcode received");
    741         return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
    742                                  response_len);
    743     }
    744     if (header.questions.empty()) {
    745         ALOGI("no questions present");
    746         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
    747                                  response_len);
    748     }
    749     if (!header.answers.empty()) {
    750         ALOGI("already %zu answers present in query", header.answers.size());
    751         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
    752                                  response_len);
    753     }
    754     {
    755         std::lock_guard<std::mutex> lock(queries_mutex_);
    756         for (const DNSQuestion& question : header.questions) {
    757             queries_.push_back(make_pair(question.qname.name,
    758                                          ns_type(question.qtype)));
    759         }
    760     }
    761 
    762     // Ignore requests with the preset probability.
    763     auto constexpr bound = std::numeric_limits<unsigned>::max();
    764     if (arc4random_uniform(bound) > bound*response_probability_) {
    765         ALOGI("returning SRVFAIL in accordance with probability distribution");
    766         return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
    767                                  response_len);
    768     }
    769 
    770     for (const DNSQuestion& question : header.questions) {
    771         if (question.qclass != ns_class::ns_c_in &&
    772             question.qclass != ns_class::ns_c_any) {
    773             ALOGI("unsupported question class %u", question.qclass);
    774             return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
    775                                      response_len);
    776         }
    777         if (!addAnswerRecords(question, &header.answers)) {
    778             return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
    779                                      response_len);
    780         }
    781     }
    782     header.qr = true;
    783     char* response_cur = header.write(response, response + *response_len);
    784     if (response_cur == nullptr) {
    785         return false;
    786     }
    787     *response_len = response_cur - response;
    788     return true;
    789 }
    790 
    791 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
    792                                     std::vector<DNSRecord>* answers) const {
    793     auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
    794     if (it == mappings_.end()) {
    795         // TODO(imaipi): handle correctly
    796         ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
    797             question.qname.name.c_str(), dnstype2str(question.qtype));
    798         return true;
    799     }
    800     ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(),
    801         dnstype2str(question.qtype), it->second.c_str());
    802     DNSRecord record;
    803     record.name = question.qname;
    804     record.rtype = question.qtype;
    805     record.rclass = ns_class::ns_c_in;
    806     record.ttl = 5;  // seconds
    807     if (question.qtype == ns_type::ns_t_a) {
    808         record.rdata.resize(4);
    809         if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
    810             ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
    811             return false;
    812         }
    813     } else if (question.qtype == ns_type::ns_t_aaaa) {
    814         record.rdata.resize(16);
    815         if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
    816             ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
    817             return false;
    818         }
    819     } else {
    820         ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
    821         return false;
    822     }
    823     answers->push_back(std::move(record));
    824     return true;
    825 }
    826 
    827 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
    828                                      char* response, size_t* response_len)
    829                                      const {
    830     header->answers.clear();
    831     header->authorities.clear();
    832     header->additionals.clear();
    833     header->rcode = rcode;
    834     header->qr = true;
    835     char* response_cur = header->write(response, response + *response_len);
    836     if (response_cur == nullptr) return false;
    837     *response_len = response_cur - response;
    838     return true;
    839 }
    840 
    841 }  // namespace test
    842 
    843