Home | History | Annotate | Download | only in net
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package net
      6 
      7 import (
      8 	"context"
      9 	"os"
     10 	"runtime"
     11 	"syscall"
     12 	"unsafe"
     13 )
     14 
     15 const _WSAHOST_NOT_FOUND = syscall.Errno(11001)
     16 
     17 func winError(call string, err error) error {
     18 	switch err {
     19 	case _WSAHOST_NOT_FOUND:
     20 		return errNoSuchHost
     21 	}
     22 	return os.NewSyscallError(call, err)
     23 }
     24 
     25 func getprotobyname(name string) (proto int, err error) {
     26 	p, err := syscall.GetProtoByName(name)
     27 	if err != nil {
     28 		return 0, winError("getprotobyname", err)
     29 	}
     30 	return int(p.Proto), nil
     31 }
     32 
     33 // lookupProtocol looks up IP protocol name and returns correspondent protocol number.
     34 func lookupProtocol(ctx context.Context, name string) (int, error) {
     35 	// GetProtoByName return value is stored in thread local storage.
     36 	// Start new os thread before the call to prevent races.
     37 	type result struct {
     38 		proto int
     39 		err   error
     40 	}
     41 	ch := make(chan result) // unbuffered
     42 	go func() {
     43 		acquireThread()
     44 		defer releaseThread()
     45 		runtime.LockOSThread()
     46 		defer runtime.UnlockOSThread()
     47 		proto, err := getprotobyname(name)
     48 		select {
     49 		case ch <- result{proto: proto, err: err}:
     50 		case <-ctx.Done():
     51 		}
     52 	}()
     53 	select {
     54 	case r := <-ch:
     55 		if r.err != nil {
     56 			if proto, err := lookupProtocolMap(name); err == nil {
     57 				return proto, nil
     58 			}
     59 			r.err = &DNSError{Err: r.err.Error(), Name: name}
     60 		}
     61 		return r.proto, r.err
     62 	case <-ctx.Done():
     63 		return 0, mapErr(ctx.Err())
     64 	}
     65 }
     66 
     67 func (r *Resolver) lookupHost(ctx context.Context, name string) ([]string, error) {
     68 	ips, err := r.lookupIP(ctx, name)
     69 	if err != nil {
     70 		return nil, err
     71 	}
     72 	addrs := make([]string, 0, len(ips))
     73 	for _, ip := range ips {
     74 		addrs = append(addrs, ip.String())
     75 	}
     76 	return addrs, nil
     77 }
     78 
     79 func (r *Resolver) lookupIP(ctx context.Context, name string) ([]IPAddr, error) {
     80 	// TODO(bradfitz,brainman): use ctx more. See TODO below.
     81 
     82 	type ret struct {
     83 		addrs []IPAddr
     84 		err   error
     85 	}
     86 	ch := make(chan ret, 1)
     87 	go func() {
     88 		acquireThread()
     89 		defer releaseThread()
     90 		hints := syscall.AddrinfoW{
     91 			Family:   syscall.AF_UNSPEC,
     92 			Socktype: syscall.SOCK_STREAM,
     93 			Protocol: syscall.IPPROTO_IP,
     94 		}
     95 		var result *syscall.AddrinfoW
     96 		e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
     97 		if e != nil {
     98 			ch <- ret{err: &DNSError{Err: winError("getaddrinfow", e).Error(), Name: name}}
     99 		}
    100 		defer syscall.FreeAddrInfoW(result)
    101 		addrs := make([]IPAddr, 0, 5)
    102 		for ; result != nil; result = result.Next {
    103 			addr := unsafe.Pointer(result.Addr)
    104 			switch result.Family {
    105 			case syscall.AF_INET:
    106 				a := (*syscall.RawSockaddrInet4)(addr).Addr
    107 				addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
    108 			case syscall.AF_INET6:
    109 				a := (*syscall.RawSockaddrInet6)(addr).Addr
    110 				zone := zoneCache.name(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
    111 				addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
    112 			default:
    113 				ch <- ret{err: &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}}
    114 			}
    115 		}
    116 		ch <- ret{addrs: addrs}
    117 	}()
    118 	select {
    119 	case r := <-ch:
    120 		return r.addrs, r.err
    121 	case <-ctx.Done():
    122 		// TODO(bradfitz,brainman): cancel the ongoing
    123 		// GetAddrInfoW? It would require conditionally using
    124 		// GetAddrInfoEx with lpOverlapped, which requires
    125 		// Windows 8 or newer. I guess we'll need oldLookupIP,
    126 		// newLookupIP, and newerLookUP.
    127 		//
    128 		// For now we just let it finish and write to the
    129 		// buffered channel.
    130 		return nil, &DNSError{
    131 			Name:      name,
    132 			Err:       ctx.Err().Error(),
    133 			IsTimeout: ctx.Err() == context.DeadlineExceeded,
    134 		}
    135 	}
    136 }
    137 
    138 func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int, error) {
    139 	if r.PreferGo {
    140 		return lookupPortMap(network, service)
    141 	}
    142 
    143 	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
    144 	acquireThread()
    145 	defer releaseThread()
    146 	var stype int32
    147 	switch network {
    148 	case "tcp4", "tcp6":
    149 		stype = syscall.SOCK_STREAM
    150 	case "udp4", "udp6":
    151 		stype = syscall.SOCK_DGRAM
    152 	}
    153 	hints := syscall.AddrinfoW{
    154 		Family:   syscall.AF_UNSPEC,
    155 		Socktype: stype,
    156 		Protocol: syscall.IPPROTO_IP,
    157 	}
    158 	var result *syscall.AddrinfoW
    159 	e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
    160 	if e != nil {
    161 		if port, err := lookupPortMap(network, service); err == nil {
    162 			return port, nil
    163 		}
    164 		return 0, &DNSError{Err: winError("getaddrinfow", e).Error(), Name: network + "/" + service}
    165 	}
    166 	defer syscall.FreeAddrInfoW(result)
    167 	if result == nil {
    168 		return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
    169 	}
    170 	addr := unsafe.Pointer(result.Addr)
    171 	switch result.Family {
    172 	case syscall.AF_INET:
    173 		a := (*syscall.RawSockaddrInet4)(addr)
    174 		return int(syscall.Ntohs(a.Port)), nil
    175 	case syscall.AF_INET6:
    176 		a := (*syscall.RawSockaddrInet6)(addr)
    177 		return int(syscall.Ntohs(a.Port)), nil
    178 	}
    179 	return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
    180 }
    181 
    182 func (*Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
    183 	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
    184 	acquireThread()
    185 	defer releaseThread()
    186 	var r *syscall.DNSRecord
    187 	e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
    188 	// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
    189 	if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
    190 		// if there are no aliases, the canonical name is the input name
    191 		return absDomainName([]byte(name)), nil
    192 	}
    193 	if e != nil {
    194 		return "", &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
    195 	}
    196 	defer syscall.DnsRecordListFree(r, 1)
    197 
    198 	resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
    199 	cname := syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:])
    200 	return absDomainName([]byte(cname)), nil
    201 }
    202 
    203 func (*Resolver) lookupSRV(ctx context.Context, service, proto, name string) (string, []*SRV, error) {
    204 	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
    205 	acquireThread()
    206 	defer releaseThread()
    207 	var target string
    208 	if service == "" && proto == "" {
    209 		target = name
    210 	} else {
    211 		target = "_" + service + "._" + proto + "." + name
    212 	}
    213 	var r *syscall.DNSRecord
    214 	e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil)
    215 	if e != nil {
    216 		return "", nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: target}
    217 	}
    218 	defer syscall.DnsRecordListFree(r, 1)
    219 
    220 	srvs := make([]*SRV, 0, 10)
    221 	for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
    222 		v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
    223 		srvs = append(srvs, &SRV{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]))), v.Port, v.Priority, v.Weight})
    224 	}
    225 	byPriorityWeight(srvs).sort()
    226 	return absDomainName([]byte(target)), srvs, nil
    227 }
    228 
    229 func (*Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
    230 	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
    231 	acquireThread()
    232 	defer releaseThread()
    233 	var r *syscall.DNSRecord
    234 	e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil)
    235 	if e != nil {
    236 		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
    237 	}
    238 	defer syscall.DnsRecordListFree(r, 1)
    239 
    240 	mxs := make([]*MX, 0, 10)
    241 	for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
    242 		v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
    243 		mxs = append(mxs, &MX{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]))), v.Preference})
    244 	}
    245 	byPref(mxs).sort()
    246 	return mxs, nil
    247 }
    248 
    249 func (*Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
    250 	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
    251 	acquireThread()
    252 	defer releaseThread()
    253 	var r *syscall.DNSRecord
    254 	e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil)
    255 	if e != nil {
    256 		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
    257 	}
    258 	defer syscall.DnsRecordListFree(r, 1)
    259 
    260 	nss := make([]*NS, 0, 10)
    261 	for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
    262 		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
    263 		nss = append(nss, &NS{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:])))})
    264 	}
    265 	return nss, nil
    266 }
    267 
    268 func (*Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) {
    269 	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
    270 	acquireThread()
    271 	defer releaseThread()
    272 	var r *syscall.DNSRecord
    273 	e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
    274 	if e != nil {
    275 		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: name}
    276 	}
    277 	defer syscall.DnsRecordListFree(r, 1)
    278 
    279 	txts := make([]string, 0, 10)
    280 	for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
    281 		d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
    282 		s := ""
    283 		for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
    284 			s += syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
    285 		}
    286 		txts = append(txts, s)
    287 	}
    288 	return txts, nil
    289 }
    290 
    291 func (*Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error) {
    292 	// TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
    293 	acquireThread()
    294 	defer releaseThread()
    295 	arpa, err := reverseaddr(addr)
    296 	if err != nil {
    297 		return nil, err
    298 	}
    299 	var r *syscall.DNSRecord
    300 	e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil)
    301 	if e != nil {
    302 		return nil, &DNSError{Err: winError("dnsquery", e).Error(), Name: addr}
    303 	}
    304 	defer syscall.DnsRecordListFree(r, 1)
    305 
    306 	ptrs := make([]string, 0, 10)
    307 	for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
    308 		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
    309 		ptrs = append(ptrs, absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))))
    310 	}
    311 	return ptrs, nil
    312 }
    313 
    314 const dnsSectionMask = 0x0003
    315 
    316 // returns only results applicable to name and resolves CNAME entries
    317 func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
    318 	cname := syscall.StringToUTF16Ptr(name)
    319 	if dnstype != syscall.DNS_TYPE_CNAME {
    320 		cname = resolveCNAME(cname, r)
    321 	}
    322 	rec := make([]*syscall.DNSRecord, 0, 10)
    323 	for p := r; p != nil; p = p.Next {
    324 		if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
    325 			continue
    326 		}
    327 		if p.Type != dnstype {
    328 			continue
    329 		}
    330 		if !syscall.DnsNameCompare(cname, p.Name) {
    331 			continue
    332 		}
    333 		rec = append(rec, p)
    334 	}
    335 	return rec
    336 }
    337 
    338 // returns the last CNAME in chain
    339 func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
    340 	// limit cname resolving to 10 in case of a infinite CNAME loop
    341 Cname:
    342 	for cnameloop := 0; cnameloop < 10; cnameloop++ {
    343 		for p := r; p != nil; p = p.Next {
    344 			if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
    345 				continue
    346 			}
    347 			if p.Type != syscall.DNS_TYPE_CNAME {
    348 				continue
    349 			}
    350 			if !syscall.DnsNameCompare(name, p.Name) {
    351 				continue
    352 			}
    353 			name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
    354 			continue Cname
    355 		}
    356 		break
    357 	}
    358 	return name
    359 }
    360