Home | History | Annotate | Download | only in fake_dns
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <arpa/inet.h>
      6 #include <errno.h>
      7 #include <netinet/in.h>
      8 #include <signal.h>
      9 #include <stdio.h>
     10 #include <stdlib.h>
     11 #include <string.h>
     12 #include <sys/socket.h>
     13 #include <sys/types.h>
     14 #include <unistd.h>
     15 
     16 #include <string>
     17 
     18 #include "base/basictypes.h"
     19 #include "base/command_line.h"
     20 #include "base/logging.h"
     21 #include "base/posix/eintr_wrapper.h"
     22 #include "base/safe_strerror_posix.h"
     23 #include "net/base/big_endian.h"
     24 #include "net/base/net_util.h"
     25 #include "net/dns/dns_protocol.h"
     26 #include "tools/android/common/daemon.h"
     27 #include "tools/android/common/net.h"
     28 
     29 namespace {
     30 
     31 // Mininum request size: 1 question containing 1 QNAME, 1 TYPE and 1 CLASS.
     32 const size_t kMinRequestSize = sizeof(net::dns_protocol::Header) + 6;
     33 
     34 // The name reference in the answer pointing to the name in the query.
     35 // Its format is: highest two bits set to 1, then the offset of the name
     36 // which just follows the header.
     37 const uint16 kPointerToQueryName =
     38     static_cast<uint16>(0xc000 | sizeof(net::dns_protocol::Header));
     39 
     40 const uint32 kTTL = 86400;  // One day.
     41 
     42 void PError(const char* msg) {
     43   int current_errno = errno;
     44   LOG(ERROR) << "ERROR: " << msg << ": " << safe_strerror(current_errno);
     45 }
     46 
     47 void SendTo(int sockfd, const void* buf, size_t len, int flags,
     48             const sockaddr* dest_addr, socklen_t addrlen) {
     49   if (HANDLE_EINTR(sendto(sockfd, buf, len, flags, dest_addr, addrlen)) == -1)
     50     PError("sendto()");
     51 }
     52 
     53 void CloseFileDescriptor(int fd) {
     54   int old_errno = errno;
     55   (void) HANDLE_EINTR(close(fd));
     56   errno = old_errno;
     57 }
     58 
     59 void SendRefusedResponse(int sock, const sockaddr_in& client_addr, uint16 id) {
     60   net::dns_protocol::Header response;
     61   response.id = htons(id);
     62   response.flags = htons(net::dns_protocol::kFlagResponse |
     63                          net::dns_protocol::kFlagAA |
     64                          net::dns_protocol::kFlagRD |
     65                          net::dns_protocol::kFlagRA |
     66                          net::dns_protocol::kRcodeREFUSED);
     67   response.qdcount = 0;
     68   response.ancount = 0;
     69   response.nscount = 0;
     70   response.arcount = 0;
     71   SendTo(sock, &response, sizeof(response), 0,
     72          reinterpret_cast<const sockaddr*>(&client_addr), sizeof(client_addr));
     73 }
     74 
     75 void SendResponse(int sock, const sockaddr_in& client_addr, uint16 id,
     76                   uint16 qtype, const char* question, size_t question_length) {
     77   net::dns_protocol::Header header;
     78   header.id = htons(id);
     79   header.flags = htons(net::dns_protocol::kFlagResponse |
     80                        net::dns_protocol::kFlagAA |
     81                        net::dns_protocol::kFlagRD |
     82                        net::dns_protocol::kFlagRA |
     83                        net::dns_protocol::kRcodeNOERROR);
     84   header.qdcount = htons(1);
     85   header.ancount = htons(1);
     86   header.nscount = 0;
     87   header.arcount = 0;
     88 
     89   // Size of RDATA which is a IPv4 or IPv6 address.
     90   size_t rdata_size = qtype == net::dns_protocol::kTypeA ?
     91                       net::kIPv4AddressSize : net::kIPv6AddressSize;
     92 
     93   // Size of the whole response which contains the header, the question and
     94   // the answer. 12 is the sum of sizes of the compressed name reference, TYPE,
     95   // CLASS, TTL and RDLENGTH.
     96   size_t response_size = sizeof(header) + question_length + 12 + rdata_size;
     97 
     98   if (response_size > net::dns_protocol::kMaxUDPSize) {
     99     LOG(ERROR) << "Response is too large: " << response_size;
    100     SendRefusedResponse(sock, client_addr, id);
    101     return;
    102   }
    103 
    104   char response[net::dns_protocol::kMaxUDPSize];
    105   net::BigEndianWriter writer(response, arraysize(response));
    106   writer.WriteBytes(&header, sizeof(header));
    107 
    108   // Repeat the question in the response. Some clients (e.g. ping) needs this.
    109   writer.WriteBytes(question, question_length);
    110 
    111   // Construct the answer.
    112   writer.WriteU16(kPointerToQueryName);
    113   writer.WriteU16(qtype);
    114   writer.WriteU16(net::dns_protocol::kClassIN);
    115   writer.WriteU32(kTTL);
    116   writer.WriteU16(rdata_size);
    117   if (qtype == net::dns_protocol::kTypeA)
    118     writer.WriteU32(INADDR_LOOPBACK);
    119   else
    120     writer.WriteBytes(&in6addr_loopback, sizeof(in6_addr));
    121   DCHECK(writer.ptr() - response == response_size);
    122 
    123   SendTo(sock, response, response_size, 0,
    124          reinterpret_cast<const sockaddr*>(&client_addr), sizeof(client_addr));
    125 }
    126 
    127 void HandleRequest(int sock, const char* request, size_t size,
    128                    const sockaddr_in& client_addr) {
    129   if (size < kMinRequestSize) {
    130     LOG(ERROR) << "Request is too small " << size
    131                << "\n" << tools::DumpBinary(request, size);
    132     return;
    133   }
    134 
    135   net::BigEndianReader reader(request, size);
    136   net::dns_protocol::Header header;
    137   reader.ReadBytes(&header, sizeof(header));
    138   uint16 id = ntohs(header.id);
    139   uint16 flags = ntohs(header.flags);
    140   uint16 qdcount = ntohs(header.qdcount);
    141   uint16 ancount = ntohs(header.ancount);
    142   uint16 nscount = ntohs(header.nscount);
    143   uint16 arcount = ntohs(header.arcount);
    144 
    145   const uint16 kAllowedFlags = 0x07ff;
    146   if ((flags & ~kAllowedFlags) ||
    147       qdcount != 1 || ancount || nscount || arcount) {
    148     LOG(ERROR) << "Unsupported request: FLAGS=" << flags
    149                << " QDCOUNT=" << qdcount
    150                << " ANCOUNT=" << ancount
    151                << " NSCOUNT=" << nscount
    152                << " ARCOUNT=" << arcount
    153                << "\n" << tools::DumpBinary(request, size);
    154     SendRefusedResponse(sock, client_addr, id);
    155     return;
    156   }
    157 
    158   // request[size - 5] should be the end of the QNAME (a zero byte).
    159   // We don't care about the validity of QNAME because we don't parse it.
    160   const char* qname_end = &request[size - 5];
    161   if (*qname_end) {
    162     LOG(ERROR) << "Error parsing QNAME\n" << tools::DumpBinary(request, size);
    163     SendRefusedResponse(sock, client_addr, id);
    164     return;
    165   }
    166 
    167   reader.Skip(qname_end - reader.ptr() + 1);
    168 
    169   uint16 qtype;
    170   uint16 qclass;
    171   reader.ReadU16(&qtype);
    172   reader.ReadU16(&qclass);
    173   if ((qtype != net::dns_protocol::kTypeA &&
    174        qtype != net::dns_protocol::kTypeAAAA) ||
    175       qclass != net::dns_protocol::kClassIN) {
    176     LOG(ERROR) << "Unsupported query: QTYPE=" << qtype << " QCLASS=" << qclass
    177                << "\n" << tools::DumpBinary(request, size);
    178     SendRefusedResponse(sock, client_addr, id);
    179     return;
    180   }
    181 
    182   SendResponse(sock, client_addr, id, qtype,
    183                request + sizeof(header), size - sizeof(header));
    184 }
    185 
    186 }  // namespace
    187 
    188 int main(int argc, char** argv) {
    189   printf("Fake DNS server\n");
    190 
    191   CommandLine command_line(argc, argv);
    192   if (tools::HasHelpSwitch(command_line) || command_line.GetArgs().size()) {
    193     tools::ShowHelp(argv[0], "", "");
    194     return 0;
    195   }
    196 
    197   int sock = socket(AF_INET, SOCK_DGRAM, 0);
    198   if (sock < 0) {
    199     PError("create socket");
    200     return 1;
    201   }
    202 
    203   sockaddr_in addr;
    204   memset(&addr, 0, sizeof(addr));
    205   addr.sin_family = AF_INET;
    206   addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    207   addr.sin_port = htons(53);
    208   int reuse_addr = 1;
    209   if (HANDLE_EINTR(bind(sock, reinterpret_cast<sockaddr*>(&addr),
    210                         sizeof(addr))) < 0) {
    211     PError("server bind");
    212     CloseFileDescriptor(sock);
    213     return 1;
    214   }
    215 
    216   if (!tools::HasNoSpawnDaemonSwitch(command_line))
    217     tools::SpawnDaemon(0);
    218 
    219   while (true) {
    220     sockaddr_in client_addr;
    221     socklen_t client_addr_len = sizeof(client_addr);
    222     char request[net::dns_protocol::kMaxUDPSize];
    223     int size = HANDLE_EINTR(recvfrom(sock, request, sizeof(request),
    224                                      MSG_WAITALL,
    225                                      reinterpret_cast<sockaddr*>(&client_addr),
    226                                      &client_addr_len));
    227     if (size < 0) {
    228       // Unrecoverable error, can only exit.
    229       LOG(ERROR) << "Failed to receive a request: " << strerror(errno);
    230       CloseFileDescriptor(sock);
    231       return 1;
    232     }
    233 
    234     if (size > 0)
    235       HandleRequest(sock, request, size, client_addr);
    236   }
    237 }
    238 
    239