Home | History | Annotate | Download | only in legacynet
      1 #include <stdio.h>
      2 #include <string.h>
      3 #include <core.h>
      4 #include "pxe.h"
      5 
      6 /* DNS CLASS values we care about */
      7 #define CLASS_IN	1
      8 
      9 /* DNS TYPE values we care about */
     10 #define TYPE_A		1
     11 #define TYPE_CNAME	5
     12 
     13 /*
     14  * The DNS header structure
     15  */
     16 struct dnshdr {
     17     uint16_t id;
     18     uint16_t flags;
     19     /* number of entries in the question section */
     20     uint16_t qdcount;
     21     /* number of resource records in the answer section */
     22     uint16_t ancount;
     23     /* number of name server resource records in the authority records section*/
     24     uint16_t nscount;
     25     /* number of resource records in the additional records section */
     26     uint16_t arcount;
     27 } __attribute__ ((packed));
     28 
     29 /*
     30  * The DNS query structure
     31  */
     32 struct dnsquery {
     33     uint16_t qtype;
     34     uint16_t qclass;
     35 } __attribute__ ((packed));
     36 
     37 /*
     38  * The DNS Resource recodes structure
     39  */
     40 struct dnsrr {
     41     uint16_t type;
     42     uint16_t class;
     43     uint32_t ttl;
     44     uint16_t rdlength;   /* The lenght of this rr data */
     45     char     rdata[];
     46 } __attribute__ ((packed));
     47 
     48 
     49 #define DNS_PORT	htons(53)               /* Default DNS port */
     50 #define DNS_MAX_SERVERS 4		/* Max no of DNS servers */
     51 
     52 uint32_t dns_server[DNS_MAX_SERVERS] = {0, };
     53 
     54 
     55 /*
     56  * Turn a string in _src_ into a DNS "label set" in _dst_; returns the
     57  * number of dots encountered. On return, *dst is updated.
     58  */
     59 int dns_mangle(char **dst, const char *p)
     60 {
     61     char *q = *dst;
     62     char *count_ptr;
     63     char c;
     64     int dots = 0;
     65 
     66     count_ptr = q;
     67     *q++ = 0;
     68 
     69     while (1) {
     70         c = *p++;
     71         if (c == 0 || c == ':' || c == '/')
     72             break;
     73         if (c == '.') {
     74             dots++;
     75             count_ptr = q;
     76             *q++ = 0;
     77             continue;
     78         }
     79 
     80         *count_ptr += 1;
     81         *q++ = c;
     82     }
     83 
     84     if (*count_ptr)
     85         *q++ = 0;
     86 
     87     /* update the strings */
     88     *dst = q;
     89     return dots;
     90 }
     91 
     92 
     93 /*
     94  * Compare two sets of DNS labels, in _s1_ and _s2_; the one in _s2_
     95  * is allowed pointers relative to a packet in buf.
     96  *
     97  */
     98 static bool dns_compare(const void *s1, const void *s2, const void *buf)
     99 {
    100     const uint8_t *q = s1;
    101     const uint8_t *p = s2;
    102     unsigned int c0, c1;
    103 
    104     while (1) {
    105 	c0 = p[0];
    106         if (c0 >= 0xc0) {
    107 	    /* Follow pointer */
    108 	    c1 = p[1];
    109 	    p = (const uint8_t *)buf + ((c0 - 0xc0) << 8) + c1;
    110 	} else if (c0) {
    111 	    c0++;		/* Include the length byte */
    112 	    if (memcmp(q, p, c0))
    113 		return false;
    114 	    q += c0;
    115 	    p += c0;
    116 	} else {
    117 	    return *q == 0;
    118 	}
    119     }
    120 }
    121 
    122 /*
    123  * Copy a DNS label into a buffer, considering the possibility that we might
    124  * have to follow pointers relative to "buf".
    125  * Returns a pointer to the first free byte *after* the terminal null.
    126  */
    127 static void *dns_copylabel(void *dst, const void *src, const void *buf)
    128 {
    129     uint8_t *q = dst;
    130     const uint8_t *p = src;
    131     unsigned int c0, c1;
    132 
    133     while (1) {
    134 	c0 = p[0];
    135         if (c0 >= 0xc0) {
    136 	    /* Follow pointer */
    137 	    c1 = p[1];
    138 	    p = (const uint8_t *)buf + ((c0 - 0xc0) << 8) + c1;
    139 	} else if (c0) {
    140 	    c0++;		/* Include the length byte */
    141 	    memcpy(q, p, c0);
    142 	    p += c0;
    143 	    q += c0;
    144 	} else {
    145 	    *q++ = 0;
    146 	    return q;
    147 	}
    148     }
    149 }
    150 
    151 /*
    152  * Skip past a DNS label set in DS:SI
    153  */
    154 static char *dns_skiplabel(char *label)
    155 {
    156     uint8_t c;
    157 
    158     while (1) {
    159         c = *label++;
    160         if (c >= 0xc0)
    161             return ++label; /* pointer is two bytes */
    162         if (c == 0)
    163             return label;
    164         label += c;
    165     }
    166 }
    167 
    168 extern const uint8_t TimeoutTable[];
    169 extern uint16_t get_port(void);
    170 extern void free_port(uint16_t port);
    171 
    172 /*
    173  * parse the ip_str and return the ip address with *res.
    174  * return true if the whole string was consumed and the result
    175  * was valid.
    176  *
    177  */
    178 static bool parse_dotquad(const char *ip_str, uint32_t *res)
    179 {
    180     const char *p = ip_str;
    181     uint8_t part = 0;
    182     uint32_t ip = 0;
    183     int i;
    184 
    185     for (i = 0; i < 4; i++) {
    186         while (is_digit(*p)) {
    187             part = part * 10 + *p - '0';
    188             p++;
    189         }
    190         if (i != 3 && *p != '.')
    191             return false;
    192 
    193         ip = (ip << 8) | part;
    194         part = 0;
    195         p++;
    196     }
    197     p--;
    198 
    199     *res = htonl(ip);
    200     return *p == '\0';
    201 }
    202 
    203 /*
    204  * Actual resolver function
    205  * Points to a null-terminated or :-terminated string in _name_
    206  * and returns the ip addr in _ip_ if it exists and can be found.
    207  * If _ip_ = 0 on exit, the lookup failed. _name_ will be updated
    208  *
    209  * XXX: probably need some caching here.
    210  */
    211 __export uint32_t dns_resolv(const char *name)
    212 {
    213     static char __lowmem DNSSendBuf[PKTBUF_SIZE];
    214     static char __lowmem DNSRecvBuf[PKTBUF_SIZE];
    215     char *p;
    216     int err;
    217     int dots;
    218     int same;
    219     int rd_len;
    220     int ques, reps;    /* number of questions and replies */
    221     uint8_t timeout;
    222     const uint8_t *timeout_ptr = TimeoutTable;
    223     uint32_t oldtime;
    224     uint32_t srv;
    225     uint32_t *srv_ptr;
    226     struct dnshdr *hd1 = (struct dnshdr *)DNSSendBuf;
    227     struct dnshdr *hd2 = (struct dnshdr *)DNSRecvBuf;
    228     struct dnsquery *query;
    229     struct dnsrr *rr;
    230     static __lowmem struct s_PXENV_UDP_WRITE udp_write;
    231     static __lowmem struct s_PXENV_UDP_READ  udp_read;
    232     uint16_t local_port;
    233     uint32_t result = 0;
    234 
    235     /*
    236      * Return failure on an empty input... this can happen during
    237      * some types of URL parsing, and this is the easiest place to
    238      * check for it.
    239      */
    240     if (!name || !*name)
    241 	return 0;
    242 
    243     /* If it is a valid dot quad, just return that value */
    244     if (parse_dotquad(name, &result))
    245 	return result;
    246 
    247     /* Make sure we have at least one valid DNS server */
    248     if (!dns_server[0])
    249 	return 0;
    250 
    251     /* Get a local port number */
    252     local_port = get_port();
    253 
    254     /* First, fill the DNS header struct */
    255     hd1->id++;                      /* New query ID */
    256     hd1->flags   = htons(0x0100);   /* Recursion requested */
    257     hd1->qdcount = htons(1);        /* One question */
    258     hd1->ancount = 0;               /* No answers */
    259     hd1->nscount = 0;               /* No NS */
    260     hd1->arcount = 0;               /* No AR */
    261 
    262     p = DNSSendBuf + sizeof(struct dnshdr);
    263     dots = dns_mangle(&p, name);   /* store the CNAME */
    264 
    265     if (!dots) {
    266         p--; /* Remove final null */
    267         /* Uncompressed DNS label set so it ends in null */
    268         p = stpcpy(p, LocalDomain);
    269     }
    270 
    271     /* Fill the DNS query packet */
    272     query = (struct dnsquery *)p;
    273     query->qtype  = htons(TYPE_A);
    274     query->qclass = htons(CLASS_IN);
    275     p += sizeof(struct dnsquery);
    276 
    277     /* Now send it to name server */
    278     timeout_ptr = TimeoutTable;
    279     timeout = *timeout_ptr++;
    280     srv_ptr = dns_server;
    281     while (timeout) {
    282 	srv = *srv_ptr++;
    283 	if (!srv) {
    284 	    srv_ptr = dns_server;
    285 	    srv = *srv_ptr++;
    286 	}
    287 
    288         udp_write.status      = 0;
    289         udp_write.ip          = srv;
    290         udp_write.gw          = gateway(srv);
    291         udp_write.src_port    = local_port;
    292         udp_write.dst_port    = DNS_PORT;
    293         udp_write.buffer_size = p - DNSSendBuf;
    294         udp_write.buffer      = FAR_PTR(DNSSendBuf);
    295         err = pxe_call(PXENV_UDP_WRITE, &udp_write);
    296         if (err || udp_write.status)
    297             continue;
    298 
    299         oldtime = jiffies();
    300 	do {
    301 	    if (jiffies() - oldtime >= timeout)
    302 		goto again;
    303 
    304             udp_read.status      = 0;
    305             udp_read.src_ip      = srv;
    306             udp_read.dest_ip     = IPInfo.myip;
    307             udp_read.s_port      = DNS_PORT;
    308             udp_read.d_port      = local_port;
    309             udp_read.buffer_size = PKTBUF_SIZE;
    310             udp_read.buffer      = FAR_PTR(DNSRecvBuf);
    311             err = pxe_call(PXENV_UDP_READ, &udp_read);
    312 	} while (err || udp_read.status || hd2->id != hd1->id);
    313 
    314         if ((hd2->flags ^ 0x80) & htons(0xf80f))
    315             goto badness;
    316 
    317         ques = htons(hd2->qdcount);   /* Questions */
    318         reps = htons(hd2->ancount);   /* Replies   */
    319         p = DNSRecvBuf + sizeof(struct dnshdr);
    320         while (ques--) {
    321             p = dns_skiplabel(p); /* Skip name */
    322             p += 4;               /* Skip question trailer */
    323         }
    324 
    325         /* Parse the replies */
    326         while (reps--) {
    327             same = dns_compare(DNSSendBuf + sizeof(struct dnshdr),
    328 			       p, DNSRecvBuf);
    329             p = dns_skiplabel(p);
    330             rr = (struct dnsrr *)p;
    331             rd_len = ntohs(rr->rdlength);
    332             if (same && ntohs(rr->class) == CLASS_IN) {
    333 		switch (ntohs(rr->type)) {
    334 		case TYPE_A:
    335 		    if (rd_len == 4) {
    336 			result = *(uint32_t *)rr->rdata;
    337 			goto done;
    338 		    }
    339 		    break;
    340 		case TYPE_CNAME:
    341 		    dns_copylabel(DNSSendBuf + sizeof(struct dnshdr),
    342 				  rr->rdata, DNSRecvBuf);
    343 		    /*
    344 		     * We should probably rescan the packet from the top
    345 		     * here, and technically we might have to send a whole
    346 		     * new request here...
    347 		     */
    348 		    break;
    349 		default:
    350 		    break;
    351 		}
    352 	    }
    353 
    354             /* not the one we want, try next */
    355             p += sizeof(struct dnsrr) + rd_len;
    356         }
    357 
    358     badness:
    359         /*
    360          *
    361          ; We got back no data from this server.
    362          ; Unfortunately, for a recursive, non-authoritative
    363          ; query there is no such thing as an NXDOMAIN reply,
    364          ; which technically means we can't draw any
    365          ; conclusions.  However, in practice that means the
    366          ; domain doesn't exist.  If this turns out to be a
    367          ; problem, we may want to add code to go through all
    368          ; the servers before giving up.
    369 
    370          ; If the DNS server wasn't capable of recursion, and
    371          ; isn't capable of giving us an authoritative reply
    372          ; (i.e. neither AA or RA set), then at least try a
    373          ; different setver...
    374         */
    375         if (hd2->flags == htons(0x480))
    376             continue;
    377 
    378         break; /* failed */
    379 
    380     again:
    381 	continue;
    382     }
    383 
    384 done:
    385     free_port(local_port);	/* Return port number to the free pool */
    386 
    387     return result;
    388 }
    389