1 /* 2 * Copyright (C) 2016 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "dns_responder.h" 18 19 #include <arpa/inet.h> 20 #include <fcntl.h> 21 #include <netdb.h> 22 #include <stdarg.h> 23 #include <stdio.h> 24 #include <stdlib.h> 25 #include <string.h> 26 #include <sys/epoll.h> 27 #include <sys/socket.h> 28 #include <sys/types.h> 29 #include <unistd.h> 30 31 #include <iostream> 32 #include <vector> 33 34 #define LOG_TAG "DNSResponder" 35 #include <log/log.h> 36 37 namespace test { 38 39 std::string errno2str() { 40 char error_msg[512] = { 0 }; 41 if (strerror_r(errno, error_msg, sizeof(error_msg))) 42 return std::string(); 43 return std::string(error_msg); 44 } 45 46 #define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str()) 47 48 std::string str2hex(const char* buffer, size_t len) { 49 std::string str(len*2, '\0'); 50 for (size_t i = 0 ; i < len ; ++i) { 51 static const char* hex = "0123456789ABCDEF"; 52 uint8_t c = buffer[i]; 53 str[i*2] = hex[c >> 4]; 54 str[i*2 + 1] = hex[c & 0x0F]; 55 } 56 return str; 57 } 58 59 std::string addr2str(const sockaddr* sa, socklen_t sa_len) { 60 char host_str[NI_MAXHOST] = { 0 }; 61 int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0, 62 NI_NUMERICHOST); 63 if (rv == 0) return std::string(host_str); 64 return std::string(); 65 } 66 67 /* DNS struct helpers */ 68 69 const char* dnstype2str(unsigned dnstype) { 70 static std::unordered_map<unsigned, const char*> kTypeStrs = { 71 { ns_type::ns_t_a, "A" }, 72 { ns_type::ns_t_ns, "NS" }, 73 { ns_type::ns_t_md, "MD" }, 74 { ns_type::ns_t_mf, "MF" }, 75 { ns_type::ns_t_cname, "CNAME" }, 76 { ns_type::ns_t_soa, "SOA" }, 77 { ns_type::ns_t_mb, "MB" }, 78 { ns_type::ns_t_mb, "MG" }, 79 { ns_type::ns_t_mr, "MR" }, 80 { ns_type::ns_t_null, "NULL" }, 81 { ns_type::ns_t_wks, "WKS" }, 82 { ns_type::ns_t_ptr, "PTR" }, 83 { ns_type::ns_t_hinfo, "HINFO" }, 84 { ns_type::ns_t_minfo, "MINFO" }, 85 { ns_type::ns_t_mx, "MX" }, 86 { ns_type::ns_t_txt, "TXT" }, 87 { ns_type::ns_t_rp, "RP" }, 88 { ns_type::ns_t_afsdb, "AFSDB" }, 89 { ns_type::ns_t_x25, "X25" }, 90 { ns_type::ns_t_isdn, "ISDN" }, 91 { ns_type::ns_t_rt, "RT" }, 92 { ns_type::ns_t_nsap, "NSAP" }, 93 { ns_type::ns_t_nsap_ptr, "NSAP-PTR" }, 94 { ns_type::ns_t_sig, "SIG" }, 95 { ns_type::ns_t_key, "KEY" }, 96 { ns_type::ns_t_px, "PX" }, 97 { ns_type::ns_t_gpos, "GPOS" }, 98 { ns_type::ns_t_aaaa, "AAAA" }, 99 { ns_type::ns_t_loc, "LOC" }, 100 { ns_type::ns_t_nxt, "NXT" }, 101 { ns_type::ns_t_eid, "EID" }, 102 { ns_type::ns_t_nimloc, "NIMLOC" }, 103 { ns_type::ns_t_srv, "SRV" }, 104 { ns_type::ns_t_naptr, "NAPTR" }, 105 { ns_type::ns_t_kx, "KX" }, 106 { ns_type::ns_t_cert, "CERT" }, 107 { ns_type::ns_t_a6, "A6" }, 108 { ns_type::ns_t_dname, "DNAME" }, 109 { ns_type::ns_t_sink, "SINK" }, 110 { ns_type::ns_t_opt, "OPT" }, 111 { ns_type::ns_t_apl, "APL" }, 112 { ns_type::ns_t_tkey, "TKEY" }, 113 { ns_type::ns_t_tsig, "TSIG" }, 114 { ns_type::ns_t_ixfr, "IXFR" }, 115 { ns_type::ns_t_axfr, "AXFR" }, 116 { ns_type::ns_t_mailb, "MAILB" }, 117 { ns_type::ns_t_maila, "MAILA" }, 118 { ns_type::ns_t_any, "ANY" }, 119 { ns_type::ns_t_zxfr, "ZXFR" }, 120 }; 121 auto it = kTypeStrs.find(dnstype); 122 static const char* kUnknownStr{ "UNKNOWN" }; 123 if (it == kTypeStrs.end()) return kUnknownStr; 124 return it->second; 125 } 126 127 const char* dnsclass2str(unsigned dnsclass) { 128 static std::unordered_map<unsigned, const char*> kClassStrs = { 129 { ns_class::ns_c_in , "Internet" }, 130 { 2, "CSNet" }, 131 { ns_class::ns_c_chaos, "ChaosNet" }, 132 { ns_class::ns_c_hs, "Hesiod" }, 133 { ns_class::ns_c_none, "none" }, 134 { ns_class::ns_c_any, "any" } 135 }; 136 auto it = kClassStrs.find(dnsclass); 137 static const char* kUnknownStr{ "UNKNOWN" }; 138 if (it == kClassStrs.end()) return kUnknownStr; 139 return it->second; 140 return "unknown"; 141 } 142 143 struct DNSName { 144 std::string name; 145 const char* read(const char* buffer, const char* buffer_end); 146 char* write(char* buffer, const char* buffer_end) const; 147 const char* toString() const; 148 private: 149 const char* parseField(const char* buffer, const char* buffer_end, 150 bool* last); 151 }; 152 153 const char* DNSName::toString() const { 154 return name.c_str(); 155 } 156 157 const char* DNSName::read(const char* buffer, const char* buffer_end) { 158 const char* cur = buffer; 159 bool last = false; 160 do { 161 cur = parseField(cur, buffer_end, &last); 162 if (cur == nullptr) { 163 ALOGI("parsing failed at line %d", __LINE__); 164 return nullptr; 165 } 166 } while (!last); 167 return cur; 168 } 169 170 char* DNSName::write(char* buffer, const char* buffer_end) const { 171 char* buffer_cur = buffer; 172 for (size_t pos = 0 ; pos < name.size() ; ) { 173 size_t dot_pos = name.find('.', pos); 174 if (dot_pos == std::string::npos) { 175 // Sanity check, should never happen unless parseField is broken. 176 ALOGI("logic error: all names are expected to end with a '.'"); 177 return nullptr; 178 } 179 size_t len = dot_pos - pos; 180 if (len >= 256) { 181 ALOGI("name component '%s' is %zu long, but max is 255", 182 name.substr(pos, dot_pos - pos).c_str(), len); 183 return nullptr; 184 } 185 if (buffer_cur + sizeof(uint8_t) + len > buffer_end) { 186 ALOGI("buffer overflow at line %d", __LINE__); 187 return nullptr; 188 } 189 *buffer_cur++ = len; 190 buffer_cur = std::copy(std::next(name.begin(), pos), 191 std::next(name.begin(), dot_pos), 192 buffer_cur); 193 pos = dot_pos + 1; 194 } 195 // Write final zero. 196 *buffer_cur++ = 0; 197 return buffer_cur; 198 } 199 200 const char* DNSName::parseField(const char* buffer, const char* buffer_end, 201 bool* last) { 202 if (buffer + sizeof(uint8_t) > buffer_end) { 203 ALOGI("parsing failed at line %d", __LINE__); 204 return nullptr; 205 } 206 unsigned field_type = *buffer >> 6; 207 unsigned ofs = *buffer & 0x3F; 208 const char* cur = buffer + sizeof(uint8_t); 209 if (field_type == 0) { 210 // length + name component 211 if (ofs == 0) { 212 *last = true; 213 return cur; 214 } 215 if (cur + ofs > buffer_end) { 216 ALOGI("parsing failed at line %d", __LINE__); 217 return nullptr; 218 } 219 name.append(cur, ofs); 220 name.push_back('.'); 221 return cur + ofs; 222 } else if (field_type == 3) { 223 ALOGI("name compression not implemented"); 224 return nullptr; 225 } 226 ALOGI("invalid name field type"); 227 return nullptr; 228 } 229 230 struct DNSQuestion { 231 DNSName qname; 232 unsigned qtype; 233 unsigned qclass; 234 const char* read(const char* buffer, const char* buffer_end); 235 char* write(char* buffer, const char* buffer_end) const; 236 std::string toString() const; 237 }; 238 239 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) { 240 const char* cur = qname.read(buffer, buffer_end); 241 if (cur == nullptr) { 242 ALOGI("parsing failed at line %d", __LINE__); 243 return nullptr; 244 } 245 if (cur + 2*sizeof(uint16_t) > buffer_end) { 246 ALOGI("parsing failed at line %d", __LINE__); 247 return nullptr; 248 } 249 qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur)); 250 qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t))); 251 return cur + 2*sizeof(uint16_t); 252 } 253 254 char* DNSQuestion::write(char* buffer, const char* buffer_end) const { 255 char* buffer_cur = qname.write(buffer, buffer_end); 256 if (buffer_cur == nullptr) return nullptr; 257 if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) { 258 ALOGI("buffer overflow on line %d", __LINE__); 259 return nullptr; 260 } 261 *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype); 262 *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) = 263 htons(qclass); 264 return buffer_cur + 2*sizeof(uint16_t); 265 } 266 267 std::string DNSQuestion::toString() const { 268 char buffer[4096]; 269 int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(), 270 dnstype2str(qtype), dnsclass2str(qclass)); 271 return std::string(buffer, len); 272 } 273 274 struct DNSRecord { 275 DNSName name; 276 unsigned rtype; 277 unsigned rclass; 278 unsigned ttl; 279 std::vector<char> rdata; 280 const char* read(const char* buffer, const char* buffer_end); 281 char* write(char* buffer, const char* buffer_end) const; 282 std::string toString() const; 283 private: 284 struct IntFields { 285 uint16_t rtype; 286 uint16_t rclass; 287 uint32_t ttl; 288 uint16_t rdlen; 289 } __attribute__((__packed__)); 290 291 const char* readIntFields(const char* buffer, const char* buffer_end, 292 unsigned* rdlen); 293 char* writeIntFields(unsigned rdlen, char* buffer, 294 const char* buffer_end) const; 295 }; 296 297 const char* DNSRecord::read(const char* buffer, const char* buffer_end) { 298 const char* cur = name.read(buffer, buffer_end); 299 if (cur == nullptr) { 300 ALOGI("parsing failed at line %d", __LINE__); 301 return nullptr; 302 } 303 unsigned rdlen = 0; 304 cur = readIntFields(cur, buffer_end, &rdlen); 305 if (cur == nullptr) { 306 ALOGI("parsing failed at line %d", __LINE__); 307 return nullptr; 308 } 309 if (cur + rdlen > buffer_end) { 310 ALOGI("parsing failed at line %d", __LINE__); 311 return nullptr; 312 } 313 rdata.assign(cur, cur + rdlen); 314 return cur + rdlen; 315 } 316 317 char* DNSRecord::write(char* buffer, const char* buffer_end) const { 318 char* buffer_cur = name.write(buffer, buffer_end); 319 if (buffer_cur == nullptr) return nullptr; 320 buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end); 321 if (buffer_cur == nullptr) return nullptr; 322 if (buffer_cur + rdata.size() > buffer_end) { 323 ALOGI("buffer overflow on line %d", __LINE__); 324 return nullptr; 325 } 326 return std::copy(rdata.begin(), rdata.end(), buffer_cur); 327 } 328 329 std::string DNSRecord::toString() const { 330 char buffer[4096]; 331 int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(), 332 dnstype2str(rtype), dnsclass2str(rclass)); 333 return std::string(buffer, len); 334 } 335 336 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end, 337 unsigned* rdlen) { 338 if (buffer + sizeof(IntFields) > buffer_end ) { 339 ALOGI("parsing failed at line %d", __LINE__); 340 return nullptr; 341 } 342 const auto& intfields = *reinterpret_cast<const IntFields*>(buffer); 343 rtype = ntohs(intfields.rtype); 344 rclass = ntohs(intfields.rclass); 345 ttl = ntohl(intfields.ttl); 346 *rdlen = ntohs(intfields.rdlen); 347 return buffer + sizeof(IntFields); 348 } 349 350 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer, 351 const char* buffer_end) const { 352 if (buffer + sizeof(IntFields) > buffer_end ) { 353 ALOGI("buffer overflow on line %d", __LINE__); 354 return nullptr; 355 } 356 auto& intfields = *reinterpret_cast<IntFields*>(buffer); 357 intfields.rtype = htons(rtype); 358 intfields.rclass = htons(rclass); 359 intfields.ttl = htonl(ttl); 360 intfields.rdlen = htons(rdlen); 361 return buffer + sizeof(IntFields); 362 } 363 364 struct DNSHeader { 365 unsigned id; 366 bool ra; 367 uint8_t rcode; 368 bool qr; 369 uint8_t opcode; 370 bool aa; 371 bool tr; 372 bool rd; 373 std::vector<DNSQuestion> questions; 374 std::vector<DNSRecord> answers; 375 std::vector<DNSRecord> authorities; 376 std::vector<DNSRecord> additionals; 377 const char* read(const char* buffer, const char* buffer_end); 378 char* write(char* buffer, const char* buffer_end) const; 379 std::string toString() const; 380 381 private: 382 struct Header { 383 uint16_t id; 384 uint8_t flags0; 385 uint8_t flags1; 386 uint16_t qdcount; 387 uint16_t ancount; 388 uint16_t nscount; 389 uint16_t arcount; 390 } __attribute__((__packed__)); 391 392 const char* readHeader(const char* buffer, const char* buffer_end, 393 unsigned* qdcount, unsigned* ancount, 394 unsigned* nscount, unsigned* arcount); 395 }; 396 397 const char* DNSHeader::read(const char* buffer, const char* buffer_end) { 398 unsigned qdcount; 399 unsigned ancount; 400 unsigned nscount; 401 unsigned arcount; 402 const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount, 403 &nscount, &arcount); 404 if (cur == nullptr) { 405 ALOGI("parsing failed at line %d", __LINE__); 406 return nullptr; 407 } 408 if (qdcount) { 409 questions.resize(qdcount); 410 for (unsigned i = 0 ; i < qdcount ; ++i) { 411 cur = questions[i].read(cur, buffer_end); 412 if (cur == nullptr) { 413 ALOGI("parsing failed at line %d", __LINE__); 414 return nullptr; 415 } 416 } 417 } 418 if (ancount) { 419 answers.resize(ancount); 420 for (unsigned i = 0 ; i < ancount ; ++i) { 421 cur = answers[i].read(cur, buffer_end); 422 if (cur == nullptr) { 423 ALOGI("parsing failed at line %d", __LINE__); 424 return nullptr; 425 } 426 } 427 } 428 if (nscount) { 429 authorities.resize(nscount); 430 for (unsigned i = 0 ; i < nscount ; ++i) { 431 cur = authorities[i].read(cur, buffer_end); 432 if (cur == nullptr) { 433 ALOGI("parsing failed at line %d", __LINE__); 434 return nullptr; 435 } 436 } 437 } 438 if (arcount) { 439 additionals.resize(arcount); 440 for (unsigned i = 0 ; i < arcount ; ++i) { 441 cur = additionals[i].read(cur, buffer_end); 442 if (cur == nullptr) { 443 ALOGI("parsing failed at line %d", __LINE__); 444 return nullptr; 445 } 446 } 447 } 448 return cur; 449 } 450 451 char* DNSHeader::write(char* buffer, const char* buffer_end) const { 452 if (buffer + sizeof(Header) > buffer_end) { 453 ALOGI("buffer overflow on line %d", __LINE__); 454 return nullptr; 455 } 456 Header& header = *reinterpret_cast<Header*>(buffer); 457 // bytes 0-1 458 header.id = htons(id); 459 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd 460 header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd; 461 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode 462 header.flags1 = rcode; 463 // rest of header 464 header.qdcount = htons(questions.size()); 465 header.ancount = htons(answers.size()); 466 header.nscount = htons(authorities.size()); 467 header.arcount = htons(additionals.size()); 468 char* buffer_cur = buffer + sizeof(Header); 469 for (const DNSQuestion& question : questions) { 470 buffer_cur = question.write(buffer_cur, buffer_end); 471 if (buffer_cur == nullptr) return nullptr; 472 } 473 for (const DNSRecord& answer : answers) { 474 buffer_cur = answer.write(buffer_cur, buffer_end); 475 if (buffer_cur == nullptr) return nullptr; 476 } 477 for (const DNSRecord& authority : authorities) { 478 buffer_cur = authority.write(buffer_cur, buffer_end); 479 if (buffer_cur == nullptr) return nullptr; 480 } 481 for (const DNSRecord& additional : additionals) { 482 buffer_cur = additional.write(buffer_cur, buffer_end); 483 if (buffer_cur == nullptr) return nullptr; 484 } 485 return buffer_cur; 486 } 487 488 std::string DNSHeader::toString() const { 489 // TODO 490 return std::string(); 491 } 492 493 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end, 494 unsigned* qdcount, unsigned* ancount, 495 unsigned* nscount, unsigned* arcount) { 496 if (buffer + sizeof(Header) > buffer_end) 497 return 0; 498 const auto& header = *reinterpret_cast<const Header*>(buffer); 499 // bytes 0-1 500 id = ntohs(header.id); 501 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd 502 qr = header.flags0 >> 7; 503 opcode = (header.flags0 >> 3) & 0x0F; 504 aa = (header.flags0 >> 2) & 1; 505 tr = (header.flags0 >> 1) & 1; 506 rd = header.flags0 & 1; 507 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode 508 ra = header.flags1 >> 7; 509 rcode = header.flags1 & 0xF; 510 // rest of header 511 *qdcount = ntohs(header.qdcount); 512 *ancount = ntohs(header.ancount); 513 *nscount = ntohs(header.nscount); 514 *arcount = ntohs(header.arcount); 515 return buffer + sizeof(Header); 516 } 517 518 /* DNS responder */ 519 520 DNSResponder::DNSResponder(std::string listen_address, 521 std::string listen_service, int poll_timeout_ms, 522 uint16_t error_rcode, double response_probability) : 523 listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)), 524 poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode), 525 response_probability_(response_probability), 526 socket_(-1), epoll_fd_(-1), terminate_(false) { } 527 528 DNSResponder::~DNSResponder() { 529 stopServer(); 530 } 531 532 void DNSResponder::addMapping(const char* name, ns_type type, 533 const char* addr) { 534 std::lock_guard<std::mutex> lock(mappings_mutex_); 535 auto it = mappings_.find(QueryKey(name, type)); 536 if (it != mappings_.end()) { 537 ALOGI("Overwriting mapping for (%s, %s), previous address %s, new " 538 "address %s", name, dnstype2str(type), it->second.c_str(), 539 addr); 540 it->second = addr; 541 return; 542 } 543 mappings_.emplace(std::piecewise_construct, 544 std::forward_as_tuple(name, type), 545 std::forward_as_tuple(addr)); 546 } 547 548 void DNSResponder::removeMapping(const char* name, ns_type type) { 549 std::lock_guard<std::mutex> lock(mappings_mutex_); 550 auto it = mappings_.find(QueryKey(name, type)); 551 if (it != mappings_.end()) { 552 ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name, 553 dnstype2str(type)); 554 return; 555 } 556 mappings_.erase(it); 557 } 558 559 void DNSResponder::setResponseProbability(double response_probability) { 560 response_probability_ = response_probability; 561 } 562 563 bool DNSResponder::running() const { 564 return socket_ != -1; 565 } 566 567 bool DNSResponder::startServer() { 568 if (running()) { 569 ALOGI("server already running"); 570 return false; 571 } 572 addrinfo ai_hints{ 573 .ai_family = AF_UNSPEC, 574 .ai_socktype = SOCK_DGRAM, 575 .ai_flags = AI_PASSIVE 576 }; 577 addrinfo* ai_res; 578 int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), 579 &ai_hints, &ai_res); 580 if (rv) { 581 ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(), 582 listen_service_.c_str(), gai_strerror(rv)); 583 return false; 584 } 585 int s = -1; 586 for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) { 587 s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); 588 if (s < 0) continue; 589 const int one = 1; 590 setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one)); 591 if (bind(s, ai->ai_addr, ai->ai_addrlen)) { 592 APLOGI("bind failed for socket %d", s); 593 close(s); 594 s = -1; 595 continue; 596 } 597 std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen); 598 ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str()); 599 break; 600 } 601 freeaddrinfo(ai_res); 602 if (s < 0) { 603 ALOGI("bind() failed"); 604 return false; 605 } 606 607 int flags = fcntl(s, F_GETFL, 0); 608 if (flags < 0) flags = 0; 609 if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) { 610 APLOGI("fcntl(F_SETFL) failed for socket %d", s); 611 close(s); 612 return false; 613 } 614 615 int ep_fd = epoll_create(1); 616 if (ep_fd < 0) { 617 char error_msg[512] = { 0 }; 618 if (strerror_r(errno, error_msg, sizeof(error_msg))) 619 strncpy(error_msg, "UNKNOWN", sizeof(error_msg)); 620 APLOGI("epoll_create() failed: %s", error_msg); 621 close(s); 622 return false; 623 } 624 epoll_event ev; 625 ev.events = EPOLLIN; 626 ev.data.fd = s; 627 if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) { 628 APLOGI("epoll_ctl() failed for socket %d", s); 629 close(ep_fd); 630 close(s); 631 return false; 632 } 633 634 epoll_fd_ = ep_fd; 635 socket_ = s; 636 { 637 std::lock_guard<std::mutex> lock(update_mutex_); 638 handler_thread_ = std::thread(&DNSResponder::requestHandler, this); 639 } 640 ALOGI("server started successfully"); 641 return true; 642 } 643 644 bool DNSResponder::stopServer() { 645 std::lock_guard<std::mutex> lock(update_mutex_); 646 if (!running()) { 647 ALOGI("server not running"); 648 return false; 649 } 650 if (terminate_) { 651 ALOGI("LOGIC ERROR"); 652 return false; 653 } 654 ALOGI("stopping server"); 655 terminate_ = true; 656 handler_thread_.join(); 657 close(epoll_fd_); 658 close(socket_); 659 terminate_ = false; 660 socket_ = -1; 661 ALOGI("server stopped successfully"); 662 return true; 663 } 664 665 std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const { 666 std::lock_guard<std::mutex> lock(queries_mutex_); 667 return queries_; 668 } 669 670 void DNSResponder::clearQueries() { 671 std::lock_guard<std::mutex> lock(queries_mutex_); 672 queries_.clear(); 673 } 674 675 void DNSResponder::requestHandler() { 676 epoll_event evs[1]; 677 while (!terminate_) { 678 int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_); 679 if (n == 0) continue; 680 if (n < 0) { 681 ALOGI("epoll_wait() failed"); 682 // TODO(imaipi): terminate on error. 683 return; 684 } 685 char buffer[4096]; 686 sockaddr_storage sa; 687 socklen_t sa_len = sizeof(sa); 688 ssize_t len; 689 do { 690 len = recvfrom(socket_, buffer, sizeof(buffer), 0, 691 (sockaddr*) &sa, &sa_len); 692 } while (len < 0 && (errno == EAGAIN || errno == EINTR)); 693 if (len <= 0) { 694 ALOGI("recvfrom() failed"); 695 continue; 696 } 697 ALOGI("read %zd bytes", len); 698 char response[4096]; 699 size_t response_len = sizeof(response); 700 if (handleDNSRequest(buffer, len, response, &response_len) && 701 response_len > 0) { 702 len = sendto(socket_, response, response_len, 0, 703 reinterpret_cast<const sockaddr*>(&sa), sa_len); 704 std::string host_str = 705 addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len); 706 if (len > 0) { 707 ALOGI("sent %zu bytes to %s", len, host_str.c_str()); 708 } else { 709 APLOGI("sendto() failed for %s", host_str.c_str()); 710 } 711 // Test that the response is actually a correct DNS message. 712 const char* response_end = response + len; 713 DNSHeader header; 714 const char* cur = header.read(response, response_end); 715 if (cur == nullptr) ALOGI("response is flawed"); 716 717 } else { 718 ALOGI("not responding"); 719 } 720 } 721 } 722 723 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len, 724 char* response, size_t* response_len) 725 const { 726 ALOGI("request: '%s'", str2hex(buffer, len).c_str()); 727 const char* buffer_end = buffer + len; 728 DNSHeader header; 729 const char* cur = header.read(buffer, buffer_end); 730 // TODO(imaipi): for now, unparsable messages are silently dropped, fix. 731 if (cur == nullptr) { 732 ALOGI("failed to parse query"); 733 return false; 734 } 735 if (header.qr) { 736 ALOGI("response received instead of a query"); 737 return false; 738 } 739 if (header.opcode != ns_opcode::ns_o_query) { 740 ALOGI("unsupported request opcode received"); 741 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response, 742 response_len); 743 } 744 if (header.questions.empty()) { 745 ALOGI("no questions present"); 746 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, 747 response_len); 748 } 749 if (!header.answers.empty()) { 750 ALOGI("already %zu answers present in query", header.answers.size()); 751 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, 752 response_len); 753 } 754 { 755 std::lock_guard<std::mutex> lock(queries_mutex_); 756 for (const DNSQuestion& question : header.questions) { 757 queries_.push_back(make_pair(question.qname.name, 758 ns_type(question.qtype))); 759 } 760 } 761 762 // Ignore requests with the preset probability. 763 auto constexpr bound = std::numeric_limits<unsigned>::max(); 764 if (arc4random_uniform(bound) > bound*response_probability_) { 765 ALOGI("returning SRVFAIL in accordance with probability distribution"); 766 return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response, 767 response_len); 768 } 769 770 for (const DNSQuestion& question : header.questions) { 771 if (question.qclass != ns_class::ns_c_in && 772 question.qclass != ns_class::ns_c_any) { 773 ALOGI("unsupported question class %u", question.qclass); 774 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response, 775 response_len); 776 } 777 if (!addAnswerRecords(question, &header.answers)) { 778 return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response, 779 response_len); 780 } 781 } 782 header.qr = true; 783 char* response_cur = header.write(response, response + *response_len); 784 if (response_cur == nullptr) { 785 return false; 786 } 787 *response_len = response_cur - response; 788 return true; 789 } 790 791 bool DNSResponder::addAnswerRecords(const DNSQuestion& question, 792 std::vector<DNSRecord>* answers) const { 793 auto it = mappings_.find(QueryKey(question.qname.name, question.qtype)); 794 if (it == mappings_.end()) { 795 // TODO(imaipi): handle correctly 796 ALOGI("no mapping found for %s %s, lazily refusing to add an answer", 797 question.qname.name.c_str(), dnstype2str(question.qtype)); 798 return true; 799 } 800 ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(), 801 dnstype2str(question.qtype), it->second.c_str()); 802 DNSRecord record; 803 record.name = question.qname; 804 record.rtype = question.qtype; 805 record.rclass = ns_class::ns_c_in; 806 record.ttl = 5; // seconds 807 if (question.qtype == ns_type::ns_t_a) { 808 record.rdata.resize(4); 809 if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) { 810 ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str()); 811 return false; 812 } 813 } else if (question.qtype == ns_type::ns_t_aaaa) { 814 record.rdata.resize(16); 815 if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) { 816 ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str()); 817 return false; 818 } 819 } else { 820 ALOGI("unhandled qtype %s", dnstype2str(question.qtype)); 821 return false; 822 } 823 answers->push_back(std::move(record)); 824 return true; 825 } 826 827 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode, 828 char* response, size_t* response_len) 829 const { 830 header->answers.clear(); 831 header->authorities.clear(); 832 header->additionals.clear(); 833 header->rcode = rcode; 834 header->qr = true; 835 char* response_cur = header->write(response, response + *response_len); 836 if (response_cur == nullptr) return false; 837 *response_len = response_cur - response; 838 return true; 839 } 840 841 } // namespace test 842 843