Home | History | Annotate | Download | only in net
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // +build darwin dragonfly freebsd linux netbsd openbsd solaris
      6 
      7 // DNS client: see RFC 1035.
      8 // Has to be linked into package net for Dial.
      9 
     10 // TODO(rsc):
     11 //	Could potentially handle many outstanding lookups faster.
     12 //	Could have a small cache.
     13 //	Random UDP source port (net.Dial should do that for us).
     14 //	Random request IDs.
     15 
     16 package net
     17 
     18 import (
     19 	"context"
     20 	"errors"
     21 	"io"
     22 	"math/rand"
     23 	"os"
     24 	"sync"
     25 	"time"
     26 )
     27 
     28 // A dnsDialer provides dialing suitable for DNS queries.
     29 type dnsDialer interface {
     30 	dialDNS(ctx context.Context, network, addr string) (dnsConn, error)
     31 }
     32 
     33 var testHookDNSDialer = func() dnsDialer { return &Dialer{} }
     34 
     35 // A dnsConn represents a DNS transport endpoint.
     36 type dnsConn interface {
     37 	io.Closer
     38 
     39 	SetDeadline(time.Time) error
     40 
     41 	// dnsRoundTrip executes a single DNS transaction, returning a
     42 	// DNS response message for the provided DNS query message.
     43 	dnsRoundTrip(query *dnsMsg) (*dnsMsg, error)
     44 }
     45 
     46 func (c *UDPConn) dnsRoundTrip(query *dnsMsg) (*dnsMsg, error) {
     47 	return dnsRoundTripUDP(c, query)
     48 }
     49 
     50 // dnsRoundTripUDP implements the dnsRoundTrip interface for RFC 1035's
     51 // "UDP usage" transport mechanism. c should be a packet-oriented connection,
     52 // such as a *UDPConn.
     53 func dnsRoundTripUDP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
     54 	b, ok := query.Pack()
     55 	if !ok {
     56 		return nil, errors.New("cannot marshal DNS message")
     57 	}
     58 	if _, err := c.Write(b); err != nil {
     59 		return nil, err
     60 	}
     61 
     62 	b = make([]byte, 512) // see RFC 1035
     63 	for {
     64 		n, err := c.Read(b)
     65 		if err != nil {
     66 			return nil, err
     67 		}
     68 		resp := &dnsMsg{}
     69 		if !resp.Unpack(b[:n]) || !resp.IsResponseTo(query) {
     70 			// Ignore invalid responses as they may be malicious
     71 			// forgery attempts. Instead continue waiting until
     72 			// timeout. See golang.org/issue/13281.
     73 			continue
     74 		}
     75 		return resp, nil
     76 	}
     77 }
     78 
     79 func (c *TCPConn) dnsRoundTrip(out *dnsMsg) (*dnsMsg, error) {
     80 	return dnsRoundTripTCP(c, out)
     81 }
     82 
     83 // dnsRoundTripTCP implements the dnsRoundTrip interface for RFC 1035's
     84 // "TCP usage" transport mechanism. c should be a stream-oriented connection,
     85 // such as a *TCPConn.
     86 func dnsRoundTripTCP(c io.ReadWriter, query *dnsMsg) (*dnsMsg, error) {
     87 	b, ok := query.Pack()
     88 	if !ok {
     89 		return nil, errors.New("cannot marshal DNS message")
     90 	}
     91 	l := len(b)
     92 	b = append([]byte{byte(l >> 8), byte(l)}, b...)
     93 	if _, err := c.Write(b); err != nil {
     94 		return nil, err
     95 	}
     96 
     97 	b = make([]byte, 1280) // 1280 is a reasonable initial size for IP over Ethernet, see RFC 4035
     98 	if _, err := io.ReadFull(c, b[:2]); err != nil {
     99 		return nil, err
    100 	}
    101 	l = int(b[0])<<8 | int(b[1])
    102 	if l > len(b) {
    103 		b = make([]byte, l)
    104 	}
    105 	n, err := io.ReadFull(c, b[:l])
    106 	if err != nil {
    107 		return nil, err
    108 	}
    109 	resp := &dnsMsg{}
    110 	if !resp.Unpack(b[:n]) {
    111 		return nil, errors.New("cannot unmarshal DNS message")
    112 	}
    113 	if !resp.IsResponseTo(query) {
    114 		return nil, errors.New("invalid DNS response")
    115 	}
    116 	return resp, nil
    117 }
    118 
    119 func (d *Dialer) dialDNS(ctx context.Context, network, server string) (dnsConn, error) {
    120 	switch network {
    121 	case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
    122 	default:
    123 		return nil, UnknownNetworkError(network)
    124 	}
    125 	// Calling Dial here is scary -- we have to be sure not to
    126 	// dial a name that will require a DNS lookup, or Dial will
    127 	// call back here to translate it. The DNS config parser has
    128 	// already checked that all the cfg.servers are IP
    129 	// addresses, which Dial will use without a DNS lookup.
    130 	c, err := d.DialContext(ctx, network, server)
    131 	if err != nil {
    132 		return nil, mapErr(err)
    133 	}
    134 	switch network {
    135 	case "tcp", "tcp4", "tcp6":
    136 		return c.(*TCPConn), nil
    137 	case "udp", "udp4", "udp6":
    138 		return c.(*UDPConn), nil
    139 	}
    140 	panic("unreachable")
    141 }
    142 
    143 // exchange sends a query on the connection and hopes for a response.
    144 func exchange(ctx context.Context, server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
    145 	d := testHookDNSDialer()
    146 	out := dnsMsg{
    147 		dnsMsgHdr: dnsMsgHdr{
    148 			recursion_desired: true,
    149 		},
    150 		question: []dnsQuestion{
    151 			{name, qtype, dnsClassINET},
    152 		},
    153 	}
    154 	for _, network := range []string{"udp", "tcp"} {
    155 		// TODO(mdempsky): Refactor so defers from UDP-based
    156 		// exchanges happen before TCP-based exchange.
    157 
    158 		ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
    159 		defer cancel()
    160 
    161 		c, err := d.dialDNS(ctx, network, server)
    162 		if err != nil {
    163 			return nil, err
    164 		}
    165 		defer c.Close()
    166 		if d, ok := ctx.Deadline(); ok && !d.IsZero() {
    167 			c.SetDeadline(d)
    168 		}
    169 		out.id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
    170 		in, err := c.dnsRoundTrip(&out)
    171 		if err != nil {
    172 			return nil, mapErr(err)
    173 		}
    174 		if in.truncated { // see RFC 5966
    175 			continue
    176 		}
    177 		return in, nil
    178 	}
    179 	return nil, errors.New("no answer from DNS server")
    180 }
    181 
    182 // Do a lookup for a single name, which must be rooted
    183 // (otherwise answer will not find the answers).
    184 func tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, error) {
    185 	var lastErr error
    186 	serverOffset := cfg.serverOffset()
    187 	sLen := uint32(len(cfg.servers))
    188 
    189 	for i := 0; i < cfg.attempts; i++ {
    190 		for j := uint32(0); j < sLen; j++ {
    191 			server := cfg.servers[(serverOffset+j)%sLen]
    192 
    193 			msg, err := exchange(ctx, server, name, qtype, cfg.timeout)
    194 			if err != nil {
    195 				lastErr = &DNSError{
    196 					Err:    err.Error(),
    197 					Name:   name,
    198 					Server: server,
    199 				}
    200 				if nerr, ok := err.(Error); ok && nerr.Timeout() {
    201 					lastErr.(*DNSError).IsTimeout = true
    202 				}
    203 				continue
    204 			}
    205 			// libresolv continues to the next server when it receives
    206 			// an invalid referral response. See golang.org/issue/15434.
    207 			if msg.rcode == dnsRcodeSuccess && !msg.authoritative && !msg.recursion_available && len(msg.answer) == 0 && len(msg.extra) == 0 {
    208 				lastErr = &DNSError{Err: "lame referral", Name: name, Server: server}
    209 				continue
    210 			}
    211 			cname, rrs, err := answer(name, server, msg, qtype)
    212 			// If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
    213 			// it means the response in msg was not useful and trying another
    214 			// server probably won't help. Return now in those cases.
    215 			// TODO: indicate this in a more obvious way, such as a field on DNSError?
    216 			if err == nil || msg.rcode == dnsRcodeSuccess || msg.rcode == dnsRcodeNameError {
    217 				return cname, rrs, err
    218 			}
    219 			lastErr = err
    220 		}
    221 	}
    222 	return "", nil, lastErr
    223 }
    224 
    225 // addrRecordList converts and returns a list of IP addresses from DNS
    226 // address records (both A and AAAA). Other record types are ignored.
    227 func addrRecordList(rrs []dnsRR) []IPAddr {
    228 	addrs := make([]IPAddr, 0, 4)
    229 	for _, rr := range rrs {
    230 		switch rr := rr.(type) {
    231 		case *dnsRR_A:
    232 			addrs = append(addrs, IPAddr{IP: IPv4(byte(rr.A>>24), byte(rr.A>>16), byte(rr.A>>8), byte(rr.A))})
    233 		case *dnsRR_AAAA:
    234 			ip := make(IP, IPv6len)
    235 			copy(ip, rr.AAAA[:])
    236 			addrs = append(addrs, IPAddr{IP: ip})
    237 		}
    238 	}
    239 	return addrs
    240 }
    241 
    242 // A resolverConfig represents a DNS stub resolver configuration.
    243 type resolverConfig struct {
    244 	initOnce sync.Once // guards init of resolverConfig
    245 
    246 	// ch is used as a semaphore that only allows one lookup at a
    247 	// time to recheck resolv.conf.
    248 	ch          chan struct{} // guards lastChecked and modTime
    249 	lastChecked time.Time     // last time resolv.conf was checked
    250 
    251 	mu        sync.RWMutex // protects dnsConfig
    252 	dnsConfig *dnsConfig   // parsed resolv.conf structure used in lookups
    253 }
    254 
    255 var resolvConf resolverConfig
    256 
    257 // init initializes conf and is only called via conf.initOnce.
    258 func (conf *resolverConfig) init() {
    259 	// Set dnsConfig and lastChecked so we don't parse
    260 	// resolv.conf twice the first time.
    261 	conf.dnsConfig = systemConf().resolv
    262 	if conf.dnsConfig == nil {
    263 		conf.dnsConfig = dnsReadConfig("/etc/resolv.conf")
    264 	}
    265 	conf.lastChecked = time.Now()
    266 
    267 	// Prepare ch so that only one update of resolverConfig may
    268 	// run at once.
    269 	conf.ch = make(chan struct{}, 1)
    270 }
    271 
    272 // tryUpdate tries to update conf with the named resolv.conf file.
    273 // The name variable only exists for testing. It is otherwise always
    274 // "/etc/resolv.conf".
    275 func (conf *resolverConfig) tryUpdate(name string) {
    276 	conf.initOnce.Do(conf.init)
    277 
    278 	// Ensure only one update at a time checks resolv.conf.
    279 	if !conf.tryAcquireSema() {
    280 		return
    281 	}
    282 	defer conf.releaseSema()
    283 
    284 	now := time.Now()
    285 	if conf.lastChecked.After(now.Add(-5 * time.Second)) {
    286 		return
    287 	}
    288 	conf.lastChecked = now
    289 
    290 	var mtime time.Time
    291 	if fi, err := os.Stat(name); err == nil {
    292 		mtime = fi.ModTime()
    293 	}
    294 	if mtime.Equal(conf.dnsConfig.mtime) {
    295 		return
    296 	}
    297 
    298 	dnsConf := dnsReadConfig(name)
    299 	conf.mu.Lock()
    300 	conf.dnsConfig = dnsConf
    301 	conf.mu.Unlock()
    302 }
    303 
    304 func (conf *resolverConfig) tryAcquireSema() bool {
    305 	select {
    306 	case conf.ch <- struct{}{}:
    307 		return true
    308 	default:
    309 		return false
    310 	}
    311 }
    312 
    313 func (conf *resolverConfig) releaseSema() {
    314 	<-conf.ch
    315 }
    316 
    317 func lookup(ctx context.Context, name string, qtype uint16) (cname string, rrs []dnsRR, err error) {
    318 	if !isDomainName(name) {
    319 		// We used to use "invalid domain name" as the error,
    320 		// but that is a detail of the specific lookup mechanism.
    321 		// Other lookups might allow broader name syntax
    322 		// (for example Multicast DNS allows UTF-8; see RFC 6762).
    323 		// For consistency with libc resolvers, report no such host.
    324 		return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name}
    325 	}
    326 	resolvConf.tryUpdate("/etc/resolv.conf")
    327 	resolvConf.mu.RLock()
    328 	conf := resolvConf.dnsConfig
    329 	resolvConf.mu.RUnlock()
    330 	for _, fqdn := range conf.nameList(name) {
    331 		cname, rrs, err = tryOneName(ctx, conf, fqdn, qtype)
    332 		if err == nil {
    333 			break
    334 		}
    335 	}
    336 	if err, ok := err.(*DNSError); ok {
    337 		// Show original name passed to lookup, not suffixed one.
    338 		// In general we might have tried many suffixes; showing
    339 		// just one is misleading. See also golang.org/issue/6324.
    340 		err.Name = name
    341 	}
    342 	return
    343 }
    344 
    345 // avoidDNS reports whether this is a hostname for which we should not
    346 // use DNS. Currently this includes only .onion, per RFC 7686. See
    347 // golang.org/issue/13705. Does not cover .local names (RFC 6762),
    348 // see golang.org/issue/16739.
    349 func avoidDNS(name string) bool {
    350 	if name == "" {
    351 		return true
    352 	}
    353 	if name[len(name)-1] == '.' {
    354 		name = name[:len(name)-1]
    355 	}
    356 	return stringsHasSuffixFold(name, ".onion")
    357 }
    358 
    359 // nameList returns a list of names for sequential DNS queries.
    360 func (conf *dnsConfig) nameList(name string) []string {
    361 	if avoidDNS(name) {
    362 		return nil
    363 	}
    364 
    365 	// Check name length (see isDomainName).
    366 	l := len(name)
    367 	rooted := l > 0 && name[l-1] == '.'
    368 	if l > 254 || l == 254 && rooted {
    369 		return nil
    370 	}
    371 
    372 	// If name is rooted (trailing dot), try only that name.
    373 	if rooted {
    374 		return []string{name}
    375 	}
    376 
    377 	hasNdots := count(name, '.') >= conf.ndots
    378 	name += "."
    379 	l++
    380 
    381 	// Build list of search choices.
    382 	names := make([]string, 0, 1+len(conf.search))
    383 	// If name has enough dots, try unsuffixed first.
    384 	if hasNdots {
    385 		names = append(names, name)
    386 	}
    387 	// Try suffixes that are not too long (see isDomainName).
    388 	for _, suffix := range conf.search {
    389 		if l+len(suffix) <= 254 {
    390 			names = append(names, name+suffix)
    391 		}
    392 	}
    393 	// Try unsuffixed, if not tried first above.
    394 	if !hasNdots {
    395 		names = append(names, name)
    396 	}
    397 	return names
    398 }
    399 
    400 // hostLookupOrder specifies the order of LookupHost lookup strategies.
    401 // It is basically a simplified representation of nsswitch.conf.
    402 // "files" means /etc/hosts.
    403 type hostLookupOrder int
    404 
    405 const (
    406 	// hostLookupCgo means defer to cgo.
    407 	hostLookupCgo      hostLookupOrder = iota
    408 	hostLookupFilesDNS                 // files first
    409 	hostLookupDNSFiles                 // dns first
    410 	hostLookupFiles                    // only files
    411 	hostLookupDNS                      // only DNS
    412 )
    413 
    414 var lookupOrderName = map[hostLookupOrder]string{
    415 	hostLookupCgo:      "cgo",
    416 	hostLookupFilesDNS: "files,dns",
    417 	hostLookupDNSFiles: "dns,files",
    418 	hostLookupFiles:    "files",
    419 	hostLookupDNS:      "dns",
    420 }
    421 
    422 func (o hostLookupOrder) String() string {
    423 	if s, ok := lookupOrderName[o]; ok {
    424 		return s
    425 	}
    426 	return "hostLookupOrder=" + itoa(int(o)) + "??"
    427 }
    428 
    429 // goLookupHost is the native Go implementation of LookupHost.
    430 // Used only if cgoLookupHost refuses to handle the request
    431 // (that is, only if cgoLookupHost is the stub in cgo_stub.go).
    432 // Normally we let cgo use the C library resolver instead of
    433 // depending on our lookup code, so that Go and C get the same
    434 // answers.
    435 func goLookupHost(ctx context.Context, name string) (addrs []string, err error) {
    436 	return goLookupHostOrder(ctx, name, hostLookupFilesDNS)
    437 }
    438 
    439 func goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) {
    440 	if order == hostLookupFilesDNS || order == hostLookupFiles {
    441 		// Use entries from /etc/hosts if they match.
    442 		addrs = lookupStaticHost(name)
    443 		if len(addrs) > 0 || order == hostLookupFiles {
    444 			return
    445 		}
    446 	}
    447 	ips, _, err := goLookupIPCNAMEOrder(ctx, name, order)
    448 	if err != nil {
    449 		return
    450 	}
    451 	addrs = make([]string, 0, len(ips))
    452 	for _, ip := range ips {
    453 		addrs = append(addrs, ip.String())
    454 	}
    455 	return
    456 }
    457 
    458 // lookup entries from /etc/hosts
    459 func goLookupIPFiles(name string) (addrs []IPAddr) {
    460 	for _, haddr := range lookupStaticHost(name) {
    461 		haddr, zone := splitHostZone(haddr)
    462 		if ip := ParseIP(haddr); ip != nil {
    463 			addr := IPAddr{IP: ip, Zone: zone}
    464 			addrs = append(addrs, addr)
    465 		}
    466 	}
    467 	sortByRFC6724(addrs)
    468 	return
    469 }
    470 
    471 // goLookupIP is the native Go implementation of LookupIP.
    472 // The libc versions are in cgo_*.go.
    473 func goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
    474 	order := systemConf().hostLookupOrder(host)
    475 	addrs, _, err = goLookupIPCNAMEOrder(ctx, host, order)
    476 	return
    477 }
    478 
    479 func goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname string, err error) {
    480 	if order == hostLookupFilesDNS || order == hostLookupFiles {
    481 		addrs = goLookupIPFiles(name)
    482 		if len(addrs) > 0 || order == hostLookupFiles {
    483 			return addrs, name, nil
    484 		}
    485 	}
    486 	if !isDomainName(name) {
    487 		// See comment in func lookup above about use of errNoSuchHost.
    488 		return nil, "", &DNSError{Err: errNoSuchHost.Error(), Name: name}
    489 	}
    490 	resolvConf.tryUpdate("/etc/resolv.conf")
    491 	resolvConf.mu.RLock()
    492 	conf := resolvConf.dnsConfig
    493 	resolvConf.mu.RUnlock()
    494 	type racer struct {
    495 		cname string
    496 		rrs   []dnsRR
    497 		error
    498 	}
    499 	lane := make(chan racer, 1)
    500 	qtypes := [...]uint16{dnsTypeA, dnsTypeAAAA}
    501 	var lastErr error
    502 	for _, fqdn := range conf.nameList(name) {
    503 		for _, qtype := range qtypes {
    504 			go func(qtype uint16) {
    505 				cname, rrs, err := tryOneName(ctx, conf, fqdn, qtype)
    506 				lane <- racer{cname, rrs, err}
    507 			}(qtype)
    508 		}
    509 		for range qtypes {
    510 			racer := <-lane
    511 			if racer.error != nil {
    512 				// Prefer error for original name.
    513 				if lastErr == nil || fqdn == name+"." {
    514 					lastErr = racer.error
    515 				}
    516 				continue
    517 			}
    518 			addrs = append(addrs, addrRecordList(racer.rrs)...)
    519 			if cname == "" {
    520 				cname = racer.cname
    521 			}
    522 		}
    523 		if len(addrs) > 0 {
    524 			break
    525 		}
    526 	}
    527 	if lastErr, ok := lastErr.(*DNSError); ok {
    528 		// Show original name passed to lookup, not suffixed one.
    529 		// In general we might have tried many suffixes; showing
    530 		// just one is misleading. See also golang.org/issue/6324.
    531 		lastErr.Name = name
    532 	}
    533 	sortByRFC6724(addrs)
    534 	if len(addrs) == 0 {
    535 		if order == hostLookupDNSFiles {
    536 			addrs = goLookupIPFiles(name)
    537 		}
    538 		if len(addrs) == 0 && lastErr != nil {
    539 			return nil, "", lastErr
    540 		}
    541 	}
    542 	return addrs, cname, nil
    543 }
    544 
    545 // goLookupCNAME is the native Go (non-cgo) implementation of LookupCNAME.
    546 func goLookupCNAME(ctx context.Context, host string) (cname string, err error) {
    547 	order := systemConf().hostLookupOrder(host)
    548 	_, cname, err = goLookupIPCNAMEOrder(ctx, host, order)
    549 	return
    550 }
    551 
    552 // goLookupPTR is the native Go implementation of LookupAddr.
    553 // Used only if cgoLookupPTR refuses to handle the request (that is,
    554 // only if cgoLookupPTR is the stub in cgo_stub.go).
    555 // Normally we let cgo use the C library resolver instead of depending
    556 // on our lookup code, so that Go and C get the same answers.
    557 func goLookupPTR(ctx context.Context, addr string) ([]string, error) {
    558 	names := lookupStaticAddr(addr)
    559 	if len(names) > 0 {
    560 		return names, nil
    561 	}
    562 	arpa, err := reverseaddr(addr)
    563 	if err != nil {
    564 		return nil, err
    565 	}
    566 	_, rrs, err := lookup(ctx, arpa, dnsTypePTR)
    567 	if err != nil {
    568 		return nil, err
    569 	}
    570 	ptrs := make([]string, len(rrs))
    571 	for i, rr := range rrs {
    572 		ptrs[i] = rr.(*dnsRR_PTR).Ptr
    573 	}
    574 	return ptrs, nil
    575 }
    576