Home | History | Annotate | Download | only in net
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package net
      6 
      7 import (
      8 	"os"
      9 	"runtime"
     10 	"syscall"
     11 	"unsafe"
     12 )
     13 
     14 var (
     15 	lookupPort = oldLookupPort
     16 	lookupIP   = oldLookupIP
     17 )
     18 
     19 func getprotobyname(name string) (proto int, err error) {
     20 	p, err := syscall.GetProtoByName(name)
     21 	if err != nil {
     22 		return 0, os.NewSyscallError("getorotobyname", err)
     23 	}
     24 	return int(p.Proto), nil
     25 }
     26 
     27 // lookupProtocol looks up IP protocol name and returns correspondent protocol number.
     28 func lookupProtocol(name string) (int, error) {
     29 	// GetProtoByName return value is stored in thread local storage.
     30 	// Start new os thread before the call to prevent races.
     31 	type result struct {
     32 		proto int
     33 		err   error
     34 	}
     35 	ch := make(chan result)
     36 	go func() {
     37 		acquireThread()
     38 		defer releaseThread()
     39 		runtime.LockOSThread()
     40 		defer runtime.UnlockOSThread()
     41 		proto, err := getprotobyname(name)
     42 		ch <- result{proto: proto, err: err}
     43 	}()
     44 	r := <-ch
     45 	if r.err != nil {
     46 		if proto, ok := protocols[name]; ok {
     47 			return proto, nil
     48 		}
     49 		r.err = &DNSError{Err: r.err.Error(), Name: name}
     50 	}
     51 	return r.proto, r.err
     52 }
     53 
     54 func lookupHost(name string) ([]string, error) {
     55 	ips, err := LookupIP(name)
     56 	if err != nil {
     57 		return nil, err
     58 	}
     59 	addrs := make([]string, 0, len(ips))
     60 	for _, ip := range ips {
     61 		addrs = append(addrs, ip.String())
     62 	}
     63 	return addrs, nil
     64 }
     65 
     66 func gethostbyname(name string) (addrs []IPAddr, err error) {
     67 	// caller already acquired thread
     68 	h, err := syscall.GetHostByName(name)
     69 	if err != nil {
     70 		return nil, os.NewSyscallError("gethostbyname", err)
     71 	}
     72 	switch h.AddrType {
     73 	case syscall.AF_INET:
     74 		i := 0
     75 		addrs = make([]IPAddr, 100) // plenty of room to grow
     76 		for p := (*[100](*[4]byte))(unsafe.Pointer(h.AddrList)); i < cap(addrs) && p[i] != nil; i++ {
     77 			addrs[i] = IPAddr{IP: IPv4(p[i][0], p[i][1], p[i][2], p[i][3])}
     78 		}
     79 		addrs = addrs[0:i]
     80 	default: // TODO(vcc): Implement non IPv4 address lookups.
     81 		return nil, syscall.EWINDOWS
     82 	}
     83 	return addrs, nil
     84 }
     85 
     86 func oldLookupIP(name string) ([]IPAddr, error) {
     87 	// GetHostByName return value is stored in thread local storage.
     88 	// Start new os thread before the call to prevent races.
     89 	type result struct {
     90 		addrs []IPAddr
     91 		err   error
     92 	}
     93 	ch := make(chan result)
     94 	go func() {
     95 		acquireThread()
     96 		defer releaseThread()
     97 		runtime.LockOSThread()
     98 		defer runtime.UnlockOSThread()
     99 		addrs, err := gethostbyname(name)
    100 		ch <- result{addrs: addrs, err: err}
    101 	}()
    102 	r := <-ch
    103 	if r.err != nil {
    104 		r.err = &DNSError{Err: r.err.Error(), Name: name}
    105 	}
    106 	return r.addrs, r.err
    107 }
    108 
    109 func newLookupIP(name string) ([]IPAddr, error) {
    110 	acquireThread()
    111 	defer releaseThread()
    112 	hints := syscall.AddrinfoW{
    113 		Family:   syscall.AF_UNSPEC,
    114 		Socktype: syscall.SOCK_STREAM,
    115 		Protocol: syscall.IPPROTO_IP,
    116 	}
    117 	var result *syscall.AddrinfoW
    118 	e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
    119 	if e != nil {
    120 		return nil, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}
    121 	}
    122 	defer syscall.FreeAddrInfoW(result)
    123 	addrs := make([]IPAddr, 0, 5)
    124 	for ; result != nil; result = result.Next {
    125 		addr := unsafe.Pointer(result.Addr)
    126 		switch result.Family {
    127 		case syscall.AF_INET:
    128 			a := (*syscall.RawSockaddrInet4)(addr).Addr
    129 			addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
    130 		case syscall.AF_INET6:
    131 			a := (*syscall.RawSockaddrInet6)(addr).Addr
    132 			zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
    133 			addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
    134 		default:
    135 			return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}
    136 		}
    137 	}
    138 	return addrs, nil
    139 }
    140 
    141 func getservbyname(network, service string) (int, error) {
    142 	acquireThread()
    143 	defer releaseThread()
    144 	switch network {
    145 	case "tcp4", "tcp6":
    146 		network = "tcp"
    147 	case "udp4", "udp6":
    148 		network = "udp"
    149 	}
    150 	s, err := syscall.GetServByName(service, network)
    151 	if err != nil {
    152 		return 0, os.NewSyscallError("getservbyname", err)
    153 	}
    154 	return int(syscall.Ntohs(s.Port)), nil
    155 }
    156 
    157 func oldLookupPort(network, service string) (int, error) {
    158 	// GetServByName return value is stored in thread local storage.
    159 	// Start new os thread before the call to prevent races.
    160 	type result struct {
    161 		port int
    162 		err  error
    163 	}
    164 	ch := make(chan result)
    165 	go func() {
    166 		acquireThread()
    167 		defer releaseThread()
    168 		runtime.LockOSThread()
    169 		defer runtime.UnlockOSThread()
    170 		port, err := getservbyname(network, service)
    171 		ch <- result{port: port, err: err}
    172 	}()
    173 	r := <-ch
    174 	if r.err != nil {
    175 		r.err = &DNSError{Err: r.err.Error(), Name: network + "/" + service}
    176 	}
    177 	return r.port, r.err
    178 }
    179 
    180 func newLookupPort(network, service string) (int, error) {
    181 	acquireThread()
    182 	defer releaseThread()
    183 	var stype int32
    184 	switch network {
    185 	case "tcp4", "tcp6":
    186 		stype = syscall.SOCK_STREAM
    187 	case "udp4", "udp6":
    188 		stype = syscall.SOCK_DGRAM
    189 	}
    190 	hints := syscall.AddrinfoW{
    191 		Family:   syscall.AF_UNSPEC,
    192 		Socktype: stype,
    193 		Protocol: syscall.IPPROTO_IP,
    194 	}
    195 	var result *syscall.AddrinfoW
    196 	e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
    197 	if e != nil {
    198 		return 0, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: network + "/" + service}
    199 	}
    200 	defer syscall.FreeAddrInfoW(result)
    201 	if result == nil {
    202 		return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
    203 	}
    204 	addr := unsafe.Pointer(result.Addr)
    205 	switch result.Family {
    206 	case syscall.AF_INET:
    207 		a := (*syscall.RawSockaddrInet4)(addr)
    208 		return int(syscall.Ntohs(a.Port)), nil
    209 	case syscall.AF_INET6:
    210 		a := (*syscall.RawSockaddrInet6)(addr)
    211 		return int(syscall.Ntohs(a.Port)), nil
    212 	}
    213 	return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
    214 }
    215 
    216 func lookupCNAME(name string) (string, error) {
    217 	acquireThread()
    218 	defer releaseThread()
    219 	var r *syscall.DNSRecord
    220 	e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
    221 	// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
    222 	if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
    223 		// if there are no aliases, the canonical name is the input name
    224 		if name == "" || name[len(name)-1] != '.' {
    225 			return name + ".", nil
    226 		}
    227 		return name, nil
    228 	}
    229 	if e != nil {
    230 		return "", &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
    231 	}
    232 	defer syscall.DnsRecordListFree(r, 1)
    233 
    234 	resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
    235 	cname := syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:]) + "."
    236 	return cname, nil
    237 }
    238 
    239 func lookupSRV(service, proto, name string) (string, []*SRV, error) {
    240 	acquireThread()
    241 	defer releaseThread()
    242 	var target string
    243 	if service == "" && proto == "" {
    244 		target = name
    245 	} else {
    246 		target = "_" + service + "._" + proto + "." + name
    247 	}
    248 	var r *syscall.DNSRecord
    249 	e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil)
    250 	if e != nil {
    251 		return "", nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: target}
    252 	}
    253 	defer syscall.DnsRecordListFree(r, 1)
    254 
    255 	srvs := make([]*SRV, 0, 10)
    256 	for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
    257 		v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
    258 		srvs = append(srvs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight})
    259 	}
    260 	byPriorityWeight(srvs).sort()
    261 	return name, srvs, nil
    262 }
    263 
    264 func lookupMX(name string) ([]*MX, error) {
    265 	acquireThread()
    266 	defer releaseThread()
    267 	var r *syscall.DNSRecord
    268 	e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil)
    269 	if e != nil {
    270 		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
    271 	}
    272 	defer syscall.DnsRecordListFree(r, 1)
    273 
    274 	mxs := make([]*MX, 0, 10)
    275 	for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
    276 		v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
    277 		mxs = append(mxs, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference})
    278 	}
    279 	byPref(mxs).sort()
    280 	return mxs, nil
    281 }
    282 
    283 func lookupNS(name string) ([]*NS, error) {
    284 	acquireThread()
    285 	defer releaseThread()
    286 	var r *syscall.DNSRecord
    287 	e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil)
    288 	if e != nil {
    289 		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
    290 	}
    291 	defer syscall.DnsRecordListFree(r, 1)
    292 
    293 	nss := make([]*NS, 0, 10)
    294 	for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
    295 		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
    296 		nss = append(nss, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."})
    297 	}
    298 	return nss, nil
    299 }
    300 
    301 func lookupTXT(name string) ([]string, error) {
    302 	acquireThread()
    303 	defer releaseThread()
    304 	var r *syscall.DNSRecord
    305 	e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
    306 	if e != nil {
    307 		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
    308 	}
    309 	defer syscall.DnsRecordListFree(r, 1)
    310 
    311 	txts := make([]string, 0, 10)
    312 	for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
    313 		d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
    314 		for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
    315 			s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
    316 			txts = append(txts, s)
    317 		}
    318 	}
    319 	return txts, nil
    320 }
    321 
    322 func lookupAddr(addr string) ([]string, error) {
    323 	acquireThread()
    324 	defer releaseThread()
    325 	arpa, err := reverseaddr(addr)
    326 	if err != nil {
    327 		return nil, err
    328 	}
    329 	var r *syscall.DNSRecord
    330 	e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil)
    331 	if e != nil {
    332 		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: addr}
    333 	}
    334 	defer syscall.DnsRecordListFree(r, 1)
    335 
    336 	ptrs := make([]string, 0, 10)
    337 	for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
    338 		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
    339 		ptrs = append(ptrs, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))
    340 	}
    341 	return ptrs, nil
    342 }
    343 
    344 const dnsSectionMask = 0x0003
    345 
    346 // returns only results applicable to name and resolves CNAME entries
    347 func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
    348 	cname := syscall.StringToUTF16Ptr(name)
    349 	if dnstype != syscall.DNS_TYPE_CNAME {
    350 		cname = resolveCNAME(cname, r)
    351 	}
    352 	rec := make([]*syscall.DNSRecord, 0, 10)
    353 	for p := r; p != nil; p = p.Next {
    354 		if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
    355 			continue
    356 		}
    357 		if p.Type != dnstype {
    358 			continue
    359 		}
    360 		if !syscall.DnsNameCompare(cname, p.Name) {
    361 			continue
    362 		}
    363 		rec = append(rec, p)
    364 	}
    365 	return rec
    366 }
    367 
    368 // returns the last CNAME in chain
    369 func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
    370 	// limit cname resolving to 10 in case of a infinite CNAME loop
    371 Cname:
    372 	for cnameloop := 0; cnameloop < 10; cnameloop++ {
    373 		for p := r; p != nil; p = p.Next {
    374 			if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
    375 				continue
    376 			}
    377 			if p.Type != syscall.DNS_TYPE_CNAME {
    378 				continue
    379 			}
    380 			if !syscall.DnsNameCompare(name, p.Name) {
    381 				continue
    382 			}
    383 			name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
    384 			continue Cname
    385 		}
    386 		break
    387 	}
    388 	return name
    389 }
    390