1 #include <stdio.h> 2 #include <string.h> 3 #include <core.h> 4 #include "pxe.h" 5 6 /* DNS CLASS values we care about */ 7 #define CLASS_IN 1 8 9 /* DNS TYPE values we care about */ 10 #define TYPE_A 1 11 #define TYPE_CNAME 5 12 13 /* 14 * The DNS header structure 15 */ 16 struct dnshdr { 17 uint16_t id; 18 uint16_t flags; 19 /* number of entries in the question section */ 20 uint16_t qdcount; 21 /* number of resource records in the answer section */ 22 uint16_t ancount; 23 /* number of name server resource records in the authority records section*/ 24 uint16_t nscount; 25 /* number of resource records in the additional records section */ 26 uint16_t arcount; 27 } __attribute__ ((packed)); 28 29 /* 30 * The DNS query structure 31 */ 32 struct dnsquery { 33 uint16_t qtype; 34 uint16_t qclass; 35 } __attribute__ ((packed)); 36 37 /* 38 * The DNS Resource recodes structure 39 */ 40 struct dnsrr { 41 uint16_t type; 42 uint16_t class; 43 uint32_t ttl; 44 uint16_t rdlength; /* The lenght of this rr data */ 45 char rdata[]; 46 } __attribute__ ((packed)); 47 48 49 #define DNS_PORT htons(53) /* Default DNS port */ 50 #define DNS_MAX_SERVERS 4 /* Max no of DNS servers */ 51 52 uint32_t dns_server[DNS_MAX_SERVERS] = {0, }; 53 54 55 /* 56 * Turn a string in _src_ into a DNS "label set" in _dst_; returns the 57 * number of dots encountered. On return, *dst is updated. 58 */ 59 int dns_mangle(char **dst, const char *p) 60 { 61 char *q = *dst; 62 char *count_ptr; 63 char c; 64 int dots = 0; 65 66 count_ptr = q; 67 *q++ = 0; 68 69 while (1) { 70 c = *p++; 71 if (c == 0 || c == ':' || c == '/') 72 break; 73 if (c == '.') { 74 dots++; 75 count_ptr = q; 76 *q++ = 0; 77 continue; 78 } 79 80 *count_ptr += 1; 81 *q++ = c; 82 } 83 84 if (*count_ptr) 85 *q++ = 0; 86 87 /* update the strings */ 88 *dst = q; 89 return dots; 90 } 91 92 93 /* 94 * Compare two sets of DNS labels, in _s1_ and _s2_; the one in _s2_ 95 * is allowed pointers relative to a packet in buf. 
/*
 * Upper bound on the size of an encoded domain name, including the final
 * zero byte (RFC 1035, section 3.1).  It also bounds the number of steps
 * taken while walking a possibly-compressed name, so that a compression
 * pointer loop in a received packet cannot hang the resolver.
 */
#define DNS_MAX_NAME	255

/*
 * Compare two sets of DNS labels, in _s1_ and _s2_.  The one in _s1_ is a
 * plain label set; the one in _s2_ is allowed compression pointers, which
 * are resolved relative to the packet starting at _buf_.  The comparison
 * is byte-exact (length bytes included), i.e. case-sensitive.
 *
 * Returns true if the names are equal; false if they differ or if _s2_ is
 * malformed (pointer loop or more than DNS_MAX_NAME steps).
 */
static bool dns_compare(const void *s1, const void *s2, const void *buf)
{
    const uint8_t *q = s1;
    const uint8_t *p = s2;
    unsigned int c0, c1;
    unsigned int steps;

    for (steps = 0; steps < DNS_MAX_NAME; steps++) {
	c0 = p[0];
	if (c0 >= 0xc0) {
	    /* Follow pointer: 14-bit offset from the start of the packet */
	    c1 = p[1];
	    p = (const uint8_t *)buf + ((c0 - 0xc0) << 8) + c1;
	} else if (c0) {
	    c0++;		/* Include the length byte */
	    if (memcmp(q, p, c0))
		return false;
	    q += c0;
	    p += c0;
	} else {
	    return *q == 0;
	}
    }

    return false;		/* pointer loop or absurdly long name */
}

/*
 * Copy a DNS label set from _src_ into _dst_, expanding any compression
 * pointers relative to the packet starting at _buf_, so that _dst_
 * receives a plain label set.  At most DNS_MAX_NAME bytes are written.
 * If _src_ is malformed (pointer loop, or it would exceed DNS_MAX_NAME
 * bytes), _dst_ receives the empty name (a single zero byte) instead.
 *
 * Returns a pointer to the first free byte *after* the terminal null.
 */
static void *dns_copylabel(void *dst, const void *src, const void *buf)
{
    uint8_t *const start = dst;
    uint8_t *q = start;
    const uint8_t *p = src;
    unsigned int c0, c1;
    unsigned int steps;

    for (steps = 0; steps < DNS_MAX_NAME; steps++) {
	c0 = p[0];
	if (c0 >= 0xc0) {
	    /* Follow pointer */
	    c1 = p[1];
	    p = (const uint8_t *)buf + ((c0 - 0xc0) << 8) + c1;
	} else if (c0) {
	    c0++;		/* Include the length byte */
	    /* Keep room for the terminating zero within DNS_MAX_NAME */
	    if ((unsigned int)(q - start) + c0 >= DNS_MAX_NAME)
		break;
	    memcpy(q, p, c0);
	    p += c0;
	    q += c0;
	} else {
	    *q++ = 0;
	    return q;
	}
    }

    /* Malformed name: emit the empty (root) name rather than overrun dst */
    q = start;
    *q++ = 0;
    return q;
}

/*
 * Skip past a DNS label set starting at _label_.  A compression pointer
 * ends the set (it is two bytes long) and is not followed.
 * Returns a pointer to the first byte after the set.
 */
static char *dns_skiplabel(char *label)
{
    uint8_t c;

    while (1) {
	c = *label++;
	if (c >= 0xc0)
	    return ++label;	/* pointer is two bytes */
	if (c == 0)
	    return label;
	label += c;
    }
}

extern const uint8_t TimeoutTable[];
extern uint16_t get_port(void);
extern void free_port(uint16_t port);
176 * 177 */ 178 static bool parse_dotquad(const char *ip_str, uint32_t *res) 179 { 180 const char *p = ip_str; 181 uint8_t part = 0; 182 uint32_t ip = 0; 183 int i; 184 185 for (i = 0; i < 4; i++) { 186 while (is_digit(*p)) { 187 part = part * 10 + *p - '0'; 188 p++; 189 } 190 if (i != 3 && *p != '.') 191 return false; 192 193 ip = (ip << 8) | part; 194 part = 0; 195 p++; 196 } 197 p--; 198 199 *res = htonl(ip); 200 return *p == '\0'; 201 } 202 203 /* 204 * Actual resolver function 205 * Points to a null-terminated or :-terminated string in _name_ 206 * and returns the ip addr in _ip_ if it exists and can be found. 207 * If _ip_ = 0 on exit, the lookup failed. _name_ will be updated 208 * 209 * XXX: probably need some caching here. 210 */ 211 __export uint32_t dns_resolv(const char *name) 212 { 213 static char __lowmem DNSSendBuf[PKTBUF_SIZE]; 214 static char __lowmem DNSRecvBuf[PKTBUF_SIZE]; 215 char *p; 216 int err; 217 int dots; 218 int same; 219 int rd_len; 220 int ques, reps; /* number of questions and replies */ 221 uint8_t timeout; 222 const uint8_t *timeout_ptr = TimeoutTable; 223 uint32_t oldtime; 224 uint32_t srv; 225 uint32_t *srv_ptr; 226 struct dnshdr *hd1 = (struct dnshdr *)DNSSendBuf; 227 struct dnshdr *hd2 = (struct dnshdr *)DNSRecvBuf; 228 struct dnsquery *query; 229 struct dnsrr *rr; 230 static __lowmem struct s_PXENV_UDP_WRITE udp_write; 231 static __lowmem struct s_PXENV_UDP_READ udp_read; 232 uint16_t local_port; 233 uint32_t result = 0; 234 235 /* 236 * Return failure on an empty input... this can happen during 237 * some types of URL parsing, and this is the easiest place to 238 * check for it. 
239 */ 240 if (!name || !*name) 241 return 0; 242 243 /* If it is a valid dot quad, just return that value */ 244 if (parse_dotquad(name, &result)) 245 return result; 246 247 /* Make sure we have at least one valid DNS server */ 248 if (!dns_server[0]) 249 return 0; 250 251 /* Get a local port number */ 252 local_port = get_port(); 253 254 /* First, fill the DNS header struct */ 255 hd1->id++; /* New query ID */ 256 hd1->flags = htons(0x0100); /* Recursion requested */ 257 hd1->qdcount = htons(1); /* One question */ 258 hd1->ancount = 0; /* No answers */ 259 hd1->nscount = 0; /* No NS */ 260 hd1->arcount = 0; /* No AR */ 261 262 p = DNSSendBuf + sizeof(struct dnshdr); 263 dots = dns_mangle(&p, name); /* store the CNAME */ 264 265 if (!dots) { 266 p--; /* Remove final null */ 267 /* Uncompressed DNS label set so it ends in null */ 268 p = stpcpy(p, LocalDomain); 269 } 270 271 /* Fill the DNS query packet */ 272 query = (struct dnsquery *)p; 273 query->qtype = htons(TYPE_A); 274 query->qclass = htons(CLASS_IN); 275 p += sizeof(struct dnsquery); 276 277 /* Now send it to name server */ 278 timeout_ptr = TimeoutTable; 279 timeout = *timeout_ptr++; 280 srv_ptr = dns_server; 281 while (timeout) { 282 srv = *srv_ptr++; 283 if (!srv) { 284 srv_ptr = dns_server; 285 srv = *srv_ptr++; 286 } 287 288 udp_write.status = 0; 289 udp_write.ip = srv; 290 udp_write.gw = gateway(srv); 291 udp_write.src_port = local_port; 292 udp_write.dst_port = DNS_PORT; 293 udp_write.buffer_size = p - DNSSendBuf; 294 udp_write.buffer = FAR_PTR(DNSSendBuf); 295 err = pxe_call(PXENV_UDP_WRITE, &udp_write); 296 if (err || udp_write.status) 297 continue; 298 299 oldtime = jiffies(); 300 do { 301 if (jiffies() - oldtime >= timeout) 302 goto again; 303 304 udp_read.status = 0; 305 udp_read.src_ip = srv; 306 udp_read.dest_ip = IPInfo.myip; 307 udp_read.s_port = DNS_PORT; 308 udp_read.d_port = local_port; 309 udp_read.buffer_size = PKTBUF_SIZE; 310 udp_read.buffer = FAR_PTR(DNSRecvBuf); 311 err = 
pxe_call(PXENV_UDP_READ, &udp_read); 312 } while (err || udp_read.status || hd2->id != hd1->id); 313 314 if ((hd2->flags ^ 0x80) & htons(0xf80f)) 315 goto badness; 316 317 ques = htons(hd2->qdcount); /* Questions */ 318 reps = htons(hd2->ancount); /* Replies */ 319 p = DNSRecvBuf + sizeof(struct dnshdr); 320 while (ques--) { 321 p = dns_skiplabel(p); /* Skip name */ 322 p += 4; /* Skip question trailer */ 323 } 324 325 /* Parse the replies */ 326 while (reps--) { 327 same = dns_compare(DNSSendBuf + sizeof(struct dnshdr), 328 p, DNSRecvBuf); 329 p = dns_skiplabel(p); 330 rr = (struct dnsrr *)p; 331 rd_len = ntohs(rr->rdlength); 332 if (same && ntohs(rr->class) == CLASS_IN) { 333 switch (ntohs(rr->type)) { 334 case TYPE_A: 335 if (rd_len == 4) { 336 result = *(uint32_t *)rr->rdata; 337 goto done; 338 } 339 break; 340 case TYPE_CNAME: 341 dns_copylabel(DNSSendBuf + sizeof(struct dnshdr), 342 rr->rdata, DNSRecvBuf); 343 /* 344 * We should probably rescan the packet from the top 345 * here, and technically we might have to send a whole 346 * new request here... 347 */ 348 break; 349 default: 350 break; 351 } 352 } 353 354 /* not the one we want, try next */ 355 p += sizeof(struct dnsrr) + rd_len; 356 } 357 358 badness: 359 /* 360 * 361 ; We got back no data from this server. 362 ; Unfortunately, for a recursive, non-authoritative 363 ; query there is no such thing as an NXDOMAIN reply, 364 ; which technically means we can't draw any 365 ; conclusions. However, in practice that means the 366 ; domain doesn't exist. If this turns out to be a 367 ; problem, we may want to add code to go through all 368 ; the servers before giving up. 369 370 ; If the DNS server wasn't capable of recursion, and 371 ; isn't capable of giving us an authoritative reply 372 ; (i.e. neither AA or RA set), then at least try a 373 ; different setver... 
374 */ 375 if (hd2->flags == htons(0x480)) 376 continue; 377 378 break; /* failed */ 379 380 again: 381 continue; 382 } 383 384 done: 385 free_port(local_port); /* Return port number to the free pool */ 386 387 return result; 388 } 389