1 /* 2 * Copyright (C) 2016 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "dns_responder.h" 18 19 #include <arpa/inet.h> 20 #include <fcntl.h> 21 #include <netdb.h> 22 #include <stdarg.h> 23 #include <stdio.h> 24 #include <stdlib.h> 25 #include <string.h> 26 #include <sys/epoll.h> 27 #include <sys/socket.h> 28 #include <sys/types.h> 29 #include <unistd.h> 30 31 #include <iostream> 32 #include <vector> 33 34 #include <log/log.h> 35 36 namespace test { 37 38 std::string errno2str() { 39 char error_msg[512] = { 0 }; 40 if (strerror_r(errno, error_msg, sizeof(error_msg))) 41 return std::string(); 42 return std::string(error_msg); 43 } 44 45 #define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str()) 46 47 std::string str2hex(const char* buffer, size_t len) { 48 std::string str(len*2, '\0'); 49 for (size_t i = 0 ; i < len ; ++i) { 50 static const char* hex = "0123456789ABCDEF"; 51 uint8_t c = buffer[i]; 52 str[i*2] = hex[c >> 4]; 53 str[i*2 + 1] = hex[c & 0x0F]; 54 } 55 return str; 56 } 57 58 std::string addr2str(const sockaddr* sa, socklen_t sa_len) { 59 char host_str[NI_MAXHOST] = { 0 }; 60 int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0, 61 NI_NUMERICHOST); 62 if (rv == 0) return std::string(host_str); 63 return std::string(); 64 } 65 66 /* DNS struct helpers */ 67 68 const char* dnstype2str(unsigned dnstype) { 69 static std::unordered_map<unsigned, const char*> kTypeStrs = { 70 { ns_type::ns_t_a, "A" }, 71 { ns_type::ns_t_ns, "NS" }, 72 { ns_type::ns_t_md, "MD" }, 73 { ns_type::ns_t_mf, "MF" }, 74 { ns_type::ns_t_cname, "CNAME" }, 75 { ns_type::ns_t_soa, "SOA" }, 76 { ns_type::ns_t_mb, "MB" }, 77 { ns_type::ns_t_mb, "MG" }, 78 { ns_type::ns_t_mr, "MR" }, 79 { ns_type::ns_t_null, "NULL" }, 80 { ns_type::ns_t_wks, "WKS" }, 81 { ns_type::ns_t_ptr, "PTR" }, 82 { ns_type::ns_t_hinfo, "HINFO" }, 83 { ns_type::ns_t_minfo, "MINFO" }, 84 { ns_type::ns_t_mx, "MX" }, 85 { ns_type::ns_t_txt, "TXT" }, 86 { ns_type::ns_t_rp, "RP" }, 87 { ns_type::ns_t_afsdb, "AFSDB" }, 88 { ns_type::ns_t_x25, "X25" }, 89 { ns_type::ns_t_isdn, "ISDN" }, 90 { ns_type::ns_t_rt, "RT" }, 91 { ns_type::ns_t_nsap, "NSAP" }, 92 { ns_type::ns_t_nsap_ptr, "NSAP-PTR" }, 93 { ns_type::ns_t_sig, "SIG" }, 94 { ns_type::ns_t_key, "KEY" }, 95 { ns_type::ns_t_px, "PX" }, 96 { ns_type::ns_t_gpos, "GPOS" }, 97 { ns_type::ns_t_aaaa, "AAAA" }, 98 { ns_type::ns_t_loc, "LOC" }, 99 { ns_type::ns_t_nxt, "NXT" }, 100 { ns_type::ns_t_eid, "EID" }, 101 { ns_type::ns_t_nimloc, "NIMLOC" }, 102 { ns_type::ns_t_srv, "SRV" }, 103 { ns_type::ns_t_naptr, "NAPTR" }, 104 { ns_type::ns_t_kx, "KX" }, 105 { ns_type::ns_t_cert, "CERT" }, 106 { ns_type::ns_t_a6, "A6" }, 107 { ns_type::ns_t_dname, "DNAME" }, 108 { ns_type::ns_t_sink, "SINK" }, 109 { ns_type::ns_t_opt, "OPT" }, 110 { ns_type::ns_t_apl, "APL" }, 111 { ns_type::ns_t_tkey, "TKEY" }, 112 { ns_type::ns_t_tsig, "TSIG" }, 113 { ns_type::ns_t_ixfr, "IXFR" }, 114 { ns_type::ns_t_axfr, "AXFR" }, 115 { ns_type::ns_t_mailb, "MAILB" }, 116 { ns_type::ns_t_maila, "MAILA" }, 117 { ns_type::ns_t_any, "ANY" }, 118 { ns_type::ns_t_zxfr, "ZXFR" }, 119 }; 120 auto it = kTypeStrs.find(dnstype); 121 static const char* kUnknownStr{ "UNKNOWN" }; 122 if (it == kTypeStrs.end()) return kUnknownStr; 123 return it->second; 124 } 125 126 const char* dnsclass2str(unsigned dnsclass) { 127 static std::unordered_map<unsigned, const char*> kClassStrs = { 128 { ns_class::ns_c_in , "Internet" }, 129 { 2, "CSNet" }, 130 { ns_class::ns_c_chaos, "ChaosNet" }, 131 { ns_class::ns_c_hs, "Hesiod" }, 132 { ns_class::ns_c_none, "none" }, 133 { ns_class::ns_c_any, "any" } 134 }; 135 auto it = kClassStrs.find(dnsclass); 136 static const char* kUnknownStr{ "UNKNOWN" }; 137 if (it == kClassStrs.end()) return kUnknownStr; 138 return it->second; 139 return "unknown"; 140 } 141 142 struct DNSName { 143 std::string name; 144 const char* read(const char* buffer, const char* buffer_end); 145 char* write(char* buffer, const char* buffer_end) const; 146 const char* toString() const; 147 private: 148 const char* parseField(const char* buffer, const char* buffer_end, 149 bool* last); 150 }; 151 152 const char* DNSName::toString() const { 153 return name.c_str(); 154 } 155 156 const char* DNSName::read(const char* buffer, const char* buffer_end) { 157 const char* cur = buffer; 158 bool last = false; 159 do { 160 cur = parseField(cur, buffer_end, &last); 161 if (cur == nullptr) { 162 ALOGI("parsing failed at line %d", __LINE__); 163 return nullptr; 164 } 165 } while (!last); 166 return cur; 167 } 168 169 char* DNSName::write(char* buffer, const char* buffer_end) const { 170 char* buffer_cur = buffer; 171 for (size_t pos = 0 ; pos < name.size() ; ) { 172 size_t dot_pos = name.find('.', pos); 173 if (dot_pos == std::string::npos) { 174 // Sanity check, should never happen unless parseField is broken. 175 ALOGI("logic error: all names are expected to end with a '.'"); 176 return nullptr; 177 } 178 size_t len = dot_pos - pos; 179 if (len >= 256) { 180 ALOGI("name component '%s' is %zu long, but max is 255", 181 name.substr(pos, dot_pos - pos).c_str(), len); 182 return nullptr; 183 } 184 if (buffer_cur + sizeof(uint8_t) + len > buffer_end) { 185 ALOGI("buffer overflow at line %d", __LINE__); 186 return nullptr; 187 } 188 *buffer_cur++ = len; 189 buffer_cur = std::copy(std::next(name.begin(), pos), 190 std::next(name.begin(), dot_pos), 191 buffer_cur); 192 pos = dot_pos + 1; 193 } 194 // Write final zero. 195 *buffer_cur++ = 0; 196 return buffer_cur; 197 } 198 199 const char* DNSName::parseField(const char* buffer, const char* buffer_end, 200 bool* last) { 201 if (buffer + sizeof(uint8_t) > buffer_end) { 202 ALOGI("parsing failed at line %d", __LINE__); 203 return nullptr; 204 } 205 unsigned field_type = *buffer >> 6; 206 unsigned ofs = *buffer & 0x3F; 207 const char* cur = buffer + sizeof(uint8_t); 208 if (field_type == 0) { 209 // length + name component 210 if (ofs == 0) { 211 *last = true; 212 return cur; 213 } 214 if (cur + ofs > buffer_end) { 215 ALOGI("parsing failed at line %d", __LINE__); 216 return nullptr; 217 } 218 name.append(cur, ofs); 219 name.push_back('.'); 220 return cur + ofs; 221 } else if (field_type == 3) { 222 ALOGI("name compression not implemented"); 223 return nullptr; 224 } 225 ALOGI("invalid name field type"); 226 return nullptr; 227 } 228 229 struct DNSQuestion { 230 DNSName qname; 231 unsigned qtype; 232 unsigned qclass; 233 const char* read(const char* buffer, const char* buffer_end); 234 char* write(char* buffer, const char* buffer_end) const; 235 std::string toString() const; 236 }; 237 238 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) { 239 const char* cur = qname.read(buffer, buffer_end); 240 if (cur == nullptr) { 241 ALOGI("parsing failed at line %d", __LINE__); 242 return nullptr; 243 } 244 if (cur + 2*sizeof(uint16_t) > buffer_end) { 245 ALOGI("parsing failed at line %d", __LINE__); 246 return nullptr; 247 } 248 qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur)); 249 qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t))); 250 return cur + 2*sizeof(uint16_t); 251 } 252 253 char* DNSQuestion::write(char* buffer, const char* buffer_end) const { 254 char* buffer_cur = qname.write(buffer, buffer_end); 255 if (buffer_cur == nullptr) return nullptr; 256 if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) { 257 ALOGI("buffer overflow on line %d", __LINE__); 258 return nullptr; 259 } 260 *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype); 261 *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) = 262 htons(qclass); 263 return buffer_cur + 2*sizeof(uint16_t); 264 } 265 266 std::string DNSQuestion::toString() const { 267 char buffer[4096]; 268 int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(), 269 dnstype2str(qtype), dnsclass2str(qclass)); 270 return std::string(buffer, len); 271 } 272 273 struct DNSRecord { 274 DNSName name; 275 unsigned rtype; 276 unsigned rclass; 277 unsigned ttl; 278 std::vector<char> rdata; 279 const char* read(const char* buffer, const char* buffer_end); 280 char* write(char* buffer, const char* buffer_end) const; 281 std::string toString() const; 282 private: 283 struct IntFields { 284 uint16_t rtype; 285 uint16_t rclass; 286 uint32_t ttl; 287 uint16_t rdlen; 288 } __attribute__((__packed__)); 289 290 const char* readIntFields(const char* buffer, const char* buffer_end, 291 unsigned* rdlen); 292 char* writeIntFields(unsigned rdlen, char* buffer, 293 const char* buffer_end) const; 294 }; 295 296 const char* DNSRecord::read(const char* buffer, const char* buffer_end) { 297 const char* cur = name.read(buffer, buffer_end); 298 if (cur == nullptr) { 299 ALOGI("parsing failed at line %d", __LINE__); 300 return nullptr; 301 } 302 unsigned rdlen = 0; 303 cur = readIntFields(cur, buffer_end, &rdlen); 304 if (cur == nullptr) { 305 ALOGI("parsing failed at line %d", __LINE__); 306 return nullptr; 307 } 308 if (cur + rdlen > buffer_end) { 309 ALOGI("parsing failed at line %d", __LINE__); 310 return nullptr; 311 } 312 rdata.assign(cur, cur + rdlen); 313 return cur + rdlen; 314 } 315 316 char* DNSRecord::write(char* buffer, const char* buffer_end) const { 317 char* buffer_cur = name.write(buffer, buffer_end); 318 if (buffer_cur == nullptr) return nullptr; 319 buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end); 320 if (buffer_cur == nullptr) return nullptr; 321 if (buffer_cur + rdata.size() > buffer_end) { 322 ALOGI("buffer overflow on line %d", __LINE__); 323 return nullptr; 324 } 325 return std::copy(rdata.begin(), rdata.end(), buffer_cur); 326 } 327 328 std::string DNSRecord::toString() const { 329 char buffer[4096]; 330 int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(), 331 dnstype2str(rtype), dnsclass2str(rclass)); 332 return std::string(buffer, len); 333 } 334 335 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end, 336 unsigned* rdlen) { 337 if (buffer + sizeof(IntFields) > buffer_end ) { 338 ALOGI("parsing failed at line %d", __LINE__); 339 return nullptr; 340 } 341 const auto& intfields = *reinterpret_cast<const IntFields*>(buffer); 342 rtype = ntohs(intfields.rtype); 343 rclass = ntohs(intfields.rclass); 344 ttl = ntohl(intfields.ttl); 345 *rdlen = ntohs(intfields.rdlen); 346 return buffer + sizeof(IntFields); 347 } 348 349 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer, 350 const char* buffer_end) const { 351 if (buffer + sizeof(IntFields) > buffer_end ) { 352 ALOGI("buffer overflow on line %d", __LINE__); 353 return nullptr; 354 } 355 auto& intfields = *reinterpret_cast<IntFields*>(buffer); 356 intfields.rtype = htons(rtype); 357 intfields.rclass = htons(rclass); 358 intfields.ttl = htonl(ttl); 359 intfields.rdlen = htons(rdlen); 360 return buffer + sizeof(IntFields); 361 } 362 363 struct DNSHeader { 364 unsigned id; 365 bool ra; 366 uint8_t rcode; 367 bool qr; 368 uint8_t opcode; 369 bool aa; 370 bool tr; 371 bool rd; 372 std::vector<DNSQuestion> questions; 373 std::vector<DNSRecord> answers; 374 std::vector<DNSRecord> authorities; 375 std::vector<DNSRecord> additionals; 376 const char* read(const char* buffer, const char* buffer_end); 377 char* write(char* buffer, const char* buffer_end) const; 378 std::string toString() const; 379 380 private: 381 struct Header { 382 uint16_t id; 383 uint8_t flags0; 384 uint8_t flags1; 385 uint16_t qdcount; 386 uint16_t ancount; 387 uint16_t nscount; 388 uint16_t arcount; 389 } __attribute__((__packed__)); 390 391 const char* readHeader(const char* buffer, const char* buffer_end, 392 unsigned* qdcount, unsigned* ancount, 393 unsigned* nscount, unsigned* arcount); 394 }; 395 396 const char* DNSHeader::read(const char* buffer, const char* buffer_end) { 397 unsigned qdcount; 398 unsigned ancount; 399 unsigned nscount; 400 unsigned arcount; 401 const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount, 402 &nscount, &arcount); 403 if (cur == nullptr) { 404 ALOGI("parsing failed at line %d", __LINE__); 405 return nullptr; 406 } 407 if (qdcount) { 408 questions.resize(qdcount); 409 for (unsigned i = 0 ; i < qdcount ; ++i) { 410 cur = questions[i].read(cur, buffer_end); 411 if (cur == nullptr) { 412 ALOGI("parsing failed at line %d", __LINE__); 413 return nullptr; 414 } 415 } 416 } 417 if (ancount) { 418 answers.resize(ancount); 419 for (unsigned i = 0 ; i < ancount ; ++i) { 420 cur = answers[i].read(cur, buffer_end); 421 if (cur == nullptr) { 422 ALOGI("parsing failed at line %d", __LINE__); 423 return nullptr; 424 } 425 } 426 } 427 if (nscount) { 428 authorities.resize(nscount); 429 for (unsigned i = 0 ; i < nscount ; ++i) { 430 cur = authorities[i].read(cur, buffer_end); 431 if (cur == nullptr) { 432 ALOGI("parsing failed at line %d", __LINE__); 433 return nullptr; 434 } 435 } 436 } 437 if (arcount) { 438 additionals.resize(arcount); 439 for (unsigned i = 0 ; i < arcount ; ++i) { 440 cur = additionals[i].read(cur, buffer_end); 441 if (cur == nullptr) { 442 ALOGI("parsing failed at line %d", __LINE__); 443 return nullptr; 444 } 445 } 446 } 447 return cur; 448 } 449 450 char* DNSHeader::write(char* buffer, const char* buffer_end) const { 451 if (buffer + sizeof(Header) > buffer_end) { 452 ALOGI("buffer overflow on line %d", __LINE__); 453 return nullptr; 454 } 455 Header& header = *reinterpret_cast<Header*>(buffer); 456 // bytes 0-1 457 header.id = htons(id); 458 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd 459 header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd; 460 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode 461 header.flags1 = rcode; 462 // rest of header 463 header.qdcount = htons(questions.size()); 464 header.ancount = htons(answers.size()); 465 header.nscount = htons(authorities.size()); 466 header.arcount = htons(additionals.size()); 467 char* buffer_cur = buffer + sizeof(Header); 468 for (const DNSQuestion& question : questions) { 469 buffer_cur = question.write(buffer_cur, buffer_end); 470 if (buffer_cur == nullptr) return nullptr; 471 } 472 for (const DNSRecord& answer : answers) { 473 buffer_cur = answer.write(buffer_cur, buffer_end); 474 if (buffer_cur == nullptr) return nullptr; 475 } 476 for (const DNSRecord& authority : authorities) { 477 buffer_cur = authority.write(buffer_cur, buffer_end); 478 if (buffer_cur == nullptr) return nullptr; 479 } 480 for (const DNSRecord& additional : additionals) { 481 buffer_cur = additional.write(buffer_cur, buffer_end); 482 if (buffer_cur == nullptr) return nullptr; 483 } 484 return buffer_cur; 485 } 486 487 std::string DNSHeader::toString() const { 488 // TODO 489 return std::string(); 490 } 491 492 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end, 493 unsigned* qdcount, unsigned* ancount, 494 unsigned* nscount, unsigned* arcount) { 495 if (buffer + sizeof(Header) > buffer_end) 496 return 0; 497 const auto& header = *reinterpret_cast<const Header*>(buffer); 498 // bytes 0-1 499 id = ntohs(header.id); 500 // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd 501 qr = header.flags0 >> 7; 502 opcode = (header.flags0 >> 3) & 0x0F; 503 aa = (header.flags0 >> 2) & 1; 504 tr = (header.flags0 >> 1) & 1; 505 rd = header.flags0 & 1; 506 // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode 507 ra = header.flags1 >> 7; 508 rcode = header.flags1 & 0xF; 509 // rest of header 510 *qdcount = ntohs(header.qdcount); 511 *ancount = ntohs(header.ancount); 512 *nscount = ntohs(header.nscount); 513 *arcount = ntohs(header.arcount); 514 return buffer + sizeof(Header); 515 } 516 517 /* DNS responder */ 518 519 DNSResponder::DNSResponder(std::string listen_address, 520 std::string listen_service, int poll_timeout_ms, 521 uint16_t error_rcode, double response_probability) : 522 listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)), 523 poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode), 524 response_probability_(response_probability), 525 socket_(-1), epoll_fd_(-1), terminate_(false) { } 526 527 DNSResponder::~DNSResponder() { 528 stopServer(); 529 } 530 531 void DNSResponder::addMapping(const char* name, ns_type type, 532 const char* addr) { 533 std::lock_guard<std::mutex> lock(mappings_mutex_); 534 auto it = mappings_.find(QueryKey(name, type)); 535 if (it != mappings_.end()) { 536 ALOGI("Overwriting mapping for (%s, %s), previous address %s, new " 537 "address %s", name, dnstype2str(type), it->second.c_str(), 538 addr); 539 it->second = addr; 540 return; 541 } 542 mappings_.emplace(std::piecewise_construct, 543 std::forward_as_tuple(name, type), 544 std::forward_as_tuple(addr)); 545 } 546 547 void DNSResponder::removeMapping(const char* name, ns_type type) { 548 std::lock_guard<std::mutex> lock(mappings_mutex_); 549 auto it = mappings_.find(QueryKey(name, type)); 550 if (it != mappings_.end()) { 551 ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name, 552 dnstype2str(type)); 553 return; 554 } 555 mappings_.erase(it); 556 } 557 558 void DNSResponder::setResponseProbability(double response_probability) { 559 response_probability_ = response_probability; 560 } 561 562 bool DNSResponder::running() const { 563 return socket_ != -1; 564 } 565 566 bool DNSResponder::startServer() { 567 if (running()) { 568 ALOGI("server already running"); 569 return false; 570 } 571 addrinfo ai_hints{ 572 .ai_family = AF_UNSPEC, 573 .ai_socktype = SOCK_DGRAM, 574 .ai_flags = AI_PASSIVE 575 }; 576 addrinfo* ai_res; 577 int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(), 578 &ai_hints, &ai_res); 579 if (rv) { 580 ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(), 581 listen_service_.c_str(), gai_strerror(rv)); 582 return false; 583 } 584 int s = -1; 585 for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) { 586 s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol); 587 if (s < 0) continue; 588 const int one = 1; 589 setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one)); 590 if (bind(s, ai->ai_addr, ai->ai_addrlen)) { 591 APLOGI("bind failed for socket %d", s); 592 close(s); 593 s = -1; 594 continue; 595 } 596 std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen); 597 ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str()); 598 break; 599 } 600 freeaddrinfo(ai_res); 601 if (s < 0) { 602 ALOGI("bind() failed"); 603 return false; 604 } 605 606 int flags = fcntl(s, F_GETFL, 0); 607 if (flags < 0) flags = 0; 608 if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) { 609 APLOGI("fcntl(F_SETFL) failed for socket %d", s); 610 close(s); 611 return false; 612 } 613 614 int ep_fd = epoll_create(1); 615 if (ep_fd < 0) { 616 char error_msg[512] = { 0 }; 617 if (strerror_r(errno, error_msg, sizeof(error_msg))) 618 strncpy(error_msg, "UNKNOWN", sizeof(error_msg)); 619 APLOGI("epoll_create() failed: %s", error_msg); 620 close(s); 621 return false; 622 } 623 epoll_event ev; 624 ev.events = EPOLLIN; 625 ev.data.fd = s; 626 if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) { 627 APLOGI("epoll_ctl() failed for socket %d", s); 628 close(ep_fd); 629 close(s); 630 return false; 631 } 632 633 epoll_fd_ = ep_fd; 634 socket_ = s; 635 { 636 std::lock_guard<std::mutex> lock(update_mutex_); 637 handler_thread_ = std::thread(&DNSResponder::requestHandler, this); 638 } 639 ALOGI("server started successfully"); 640 return true; 641 } 642 643 bool DNSResponder::stopServer() { 644 std::lock_guard<std::mutex> lock(update_mutex_); 645 if (!running()) { 646 ALOGI("server not running"); 647 return false; 648 } 649 if (terminate_) { 650 ALOGI("LOGIC ERROR"); 651 return false; 652 } 653 ALOGI("stopping server"); 654 terminate_ = true; 655 handler_thread_.join(); 656 close(epoll_fd_); 657 close(socket_); 658 terminate_ = false; 659 socket_ = -1; 660 ALOGI("server stopped successfully"); 661 return true; 662 } 663 664 std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const { 665 std::lock_guard<std::mutex> lock(queries_mutex_); 666 return queries_; 667 } 668 669 void DNSResponder::clearQueries() { 670 std::lock_guard<std::mutex> lock(queries_mutex_); 671 queries_.clear(); 672 } 673 674 void DNSResponder::requestHandler() { 675 epoll_event evs[1]; 676 while (!terminate_) { 677 int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_); 678 if (n == 0) continue; 679 if (n < 0) { 680 ALOGI("epoll_wait() failed"); 681 // TODO(imaipi): terminate on error. 682 return; 683 } 684 char buffer[4096]; 685 sockaddr_storage sa; 686 socklen_t sa_len = sizeof(sa); 687 ssize_t len; 688 do { 689 len = recvfrom(socket_, buffer, sizeof(buffer), 0, 690 (sockaddr*) &sa, &sa_len); 691 } while (len < 0 && (errno == EAGAIN || errno == EINTR)); 692 if (len <= 0) { 693 ALOGI("recvfrom() failed"); 694 continue; 695 } 696 ALOGI("read %zd bytes", len); 697 char response[4096]; 698 size_t response_len = sizeof(response); 699 if (handleDNSRequest(buffer, len, response, &response_len) && 700 response_len > 0) { 701 len = sendto(socket_, response, response_len, 0, 702 reinterpret_cast<const sockaddr*>(&sa), sa_len); 703 std::string host_str = 704 addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len); 705 if (len > 0) { 706 ALOGI("sent %zu bytes to %s", len, host_str.c_str()); 707 } else { 708 APLOGI("sendto() failed for %s", host_str.c_str()); 709 } 710 // Test that the response is actually a correct DNS message. 711 const char* response_end = response + len; 712 DNSHeader header; 713 const char* cur = header.read(response, response_end); 714 if (cur == nullptr) ALOGI("response is flawed"); 715 716 } else { 717 ALOGI("not responding"); 718 } 719 } 720 } 721 722 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len, 723 char* response, size_t* response_len) 724 const { 725 ALOGI("request: '%s'", str2hex(buffer, len).c_str()); 726 const char* buffer_end = buffer + len; 727 DNSHeader header; 728 const char* cur = header.read(buffer, buffer_end); 729 // TODO(imaipi): for now, unparsable messages are silently dropped, fix. 730 if (cur == nullptr) { 731 ALOGI("failed to parse query"); 732 return false; 733 } 734 if (header.qr) { 735 ALOGI("response received instead of a query"); 736 return false; 737 } 738 if (header.opcode != ns_opcode::ns_o_query) { 739 ALOGI("unsupported request opcode received"); 740 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response, 741 response_len); 742 } 743 if (header.questions.empty()) { 744 ALOGI("no questions present"); 745 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, 746 response_len); 747 } 748 if (!header.answers.empty()) { 749 ALOGI("already %zu answers present in query", header.answers.size()); 750 return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response, 751 response_len); 752 } 753 { 754 std::lock_guard<std::mutex> lock(queries_mutex_); 755 for (const DNSQuestion& question : header.questions) { 756 queries_.push_back(make_pair(question.qname.name, 757 ns_type(question.qtype))); 758 } 759 } 760 761 // Ignore requests with the preset probability. 762 auto constexpr bound = std::numeric_limits<unsigned>::max(); 763 if (arc4random_uniform(bound) > bound*response_probability_) { 764 ALOGI("returning SRVFAIL in accordance with probability distribution"); 765 return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response, 766 response_len); 767 } 768 769 for (const DNSQuestion& question : header.questions) { 770 if (question.qclass != ns_class::ns_c_in && 771 question.qclass != ns_class::ns_c_any) { 772 ALOGI("unsupported question class %u", question.qclass); 773 return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response, 774 response_len); 775 } 776 if (!addAnswerRecords(question, &header.answers)) { 777 return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response, 778 response_len); 779 } 780 } 781 header.qr = true; 782 char* response_cur = header.write(response, response + *response_len); 783 if (response_cur == nullptr) { 784 return false; 785 } 786 *response_len = response_cur - response; 787 return true; 788 } 789 790 bool DNSResponder::addAnswerRecords(const DNSQuestion& question, 791 std::vector<DNSRecord>* answers) const { 792 auto it = mappings_.find(QueryKey(question.qname.name, question.qtype)); 793 if (it == mappings_.end()) { 794 // TODO(imaipi): handle correctly 795 ALOGI("no mapping found for %s %s, lazily refusing to add an answer", 796 question.qname.name.c_str(), dnstype2str(question.qtype)); 797 return true; 798 } 799 ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(), 800 dnstype2str(question.qtype), it->second.c_str()); 801 DNSRecord record; 802 record.name = question.qname; 803 record.rtype = question.qtype; 804 record.rclass = ns_class::ns_c_in; 805 record.ttl = 5; // seconds 806 if (question.qtype == ns_type::ns_t_a) { 807 record.rdata.resize(4); 808 if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) { 809 ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str()); 810 return false; 811 } 812 } else if (question.qtype == ns_type::ns_t_aaaa) { 813 record.rdata.resize(16); 814 if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) { 815 ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str()); 816 return false; 817 } 818 } else { 819 ALOGI("unhandled qtype %s", dnstype2str(question.qtype)); 820 return false; 821 } 822 answers->push_back(std::move(record)); 823 return true; 824 } 825 826 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode, 827 char* response, size_t* response_len) 828 const { 829 header->answers.clear(); 830 header->authorities.clear(); 831 header->additionals.clear(); 832 header->rcode = rcode; 833 header->qr = true; 834 char* response_cur = header->write(response, response + *response_len); 835 if (response_cur == nullptr) return false; 836 *response_len = response_cur - response; 837 return true; 838 } 839 840 } // namespace test 841 842