Home | History | Annotate | Download | only in net
      1 // Copyright 2010 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package net
      6 
      7 import (
      8 	"context"
      9 	"internal/nettrace"
     10 	"time"
     11 )
     12 
     13 // A Dialer contains options for connecting to an address.
     14 //
     15 // The zero value for each field is equivalent to dialing
     16 // without that option. Dialing with the zero value of Dialer
     17 // is therefore equivalent to just calling the Dial function.
     18 type Dialer struct {
     19 	// Timeout is the maximum amount of time a dial will wait for
     20 	// a connect to complete. If Deadline is also set, it may fail
     21 	// earlier.
     22 	//
     23 	// The default is no timeout.
     24 	//
     25 	// When dialing a name with multiple IP addresses, the timeout
     26 	// may be divided between them.
     27 	//
     28 	// With or without a timeout, the operating system may impose
     29 	// its own earlier timeout. For instance, TCP timeouts are
     30 	// often around 3 minutes.
     31 	Timeout time.Duration
     32 
     33 	// Deadline is the absolute point in time after which dials
     34 	// will fail. If Timeout is set, it may fail earlier.
     35 	// Zero means no deadline, or dependent on the operating system
     36 	// as with the Timeout option.
     37 	Deadline time.Time
     38 
     39 	// LocalAddr is the local address to use when dialing an
     40 	// address. The address must be of a compatible type for the
     41 	// network being dialed.
     42 	// If nil, a local address is automatically chosen.
     43 	LocalAddr Addr
     44 
     45 	// DualStack enables RFC 6555-compliant "Happy Eyeballs" dialing
     46 	// when the network is "tcp" and the destination is a host name
     47 	// with both IPv4 and IPv6 addresses. This allows a client to
     48 	// tolerate networks where one address family is silently broken.
     49 	DualStack bool
     50 
     51 	// FallbackDelay specifies the length of time to wait before
     52 	// spawning a fallback connection, when DualStack is enabled.
     53 	// If zero, a default delay of 300ms is used.
     54 	FallbackDelay time.Duration
     55 
     56 	// KeepAlive specifies the keep-alive period for an active
     57 	// network connection.
     58 	// If zero, keep-alives are not enabled. Network protocols
     59 	// that do not support keep-alives ignore this field.
     60 	KeepAlive time.Duration
     61 
     62 	// Resolver optionally specifies an alternate resolver to use.
     63 	Resolver *Resolver
     64 
     65 	// Cancel is an optional channel whose closure indicates that
     66 	// the dial should be canceled. Not all types of dials support
     67 	// cancelation.
     68 	//
     69 	// Deprecated: Use DialContext instead.
     70 	Cancel <-chan struct{}
     71 }
     72 
     73 func minNonzeroTime(a, b time.Time) time.Time {
     74 	if a.IsZero() {
     75 		return b
     76 	}
     77 	if b.IsZero() || a.Before(b) {
     78 		return a
     79 	}
     80 	return b
     81 }
     82 
     83 // deadline returns the earliest of:
     84 //   - now+Timeout
     85 //   - d.Deadline
     86 //   - the context's deadline
     87 // Or zero, if none of Timeout, Deadline, or context's deadline is set.
     88 func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
     89 	if d.Timeout != 0 { // including negative, for historical reasons
     90 		earliest = now.Add(d.Timeout)
     91 	}
     92 	if d, ok := ctx.Deadline(); ok {
     93 		earliest = minNonzeroTime(earliest, d)
     94 	}
     95 	return minNonzeroTime(earliest, d.Deadline)
     96 }
     97 
     98 func (d *Dialer) resolver() *Resolver {
     99 	if d.Resolver != nil {
    100 		return d.Resolver
    101 	}
    102 	return DefaultResolver
    103 }
    104 
    105 // partialDeadline returns the deadline to use for a single address,
    106 // when multiple addresses are pending.
    107 func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
    108 	if deadline.IsZero() {
    109 		return deadline, nil
    110 	}
    111 	timeRemaining := deadline.Sub(now)
    112 	if timeRemaining <= 0 {
    113 		return time.Time{}, errTimeout
    114 	}
    115 	// Tentatively allocate equal time to each remaining address.
    116 	timeout := timeRemaining / time.Duration(addrsRemaining)
    117 	// If the time per address is too short, steal from the end of the list.
    118 	const saneMinimum = 2 * time.Second
    119 	if timeout < saneMinimum {
    120 		if timeRemaining < saneMinimum {
    121 			timeout = timeRemaining
    122 		} else {
    123 			timeout = saneMinimum
    124 		}
    125 	}
    126 	return now.Add(timeout), nil
    127 }
    128 
    129 func (d *Dialer) fallbackDelay() time.Duration {
    130 	if d.FallbackDelay > 0 {
    131 		return d.FallbackDelay
    132 	} else {
    133 		return 300 * time.Millisecond
    134 	}
    135 }
    136 
    137 func parseNetwork(ctx context.Context, net string) (afnet string, proto int, err error) {
    138 	i := last(net, ':')
    139 	if i < 0 { // no colon
    140 		switch net {
    141 		case "tcp", "tcp4", "tcp6":
    142 		case "udp", "udp4", "udp6":
    143 		case "ip", "ip4", "ip6":
    144 		case "unix", "unixgram", "unixpacket":
    145 		default:
    146 			return "", 0, UnknownNetworkError(net)
    147 		}
    148 		return net, 0, nil
    149 	}
    150 	afnet = net[:i]
    151 	switch afnet {
    152 	case "ip", "ip4", "ip6":
    153 		protostr := net[i+1:]
    154 		proto, i, ok := dtoi(protostr)
    155 		if !ok || i != len(protostr) {
    156 			proto, err = lookupProtocol(ctx, protostr)
    157 			if err != nil {
    158 				return "", 0, err
    159 			}
    160 		}
    161 		return afnet, proto, nil
    162 	}
    163 	return "", 0, UnknownNetworkError(net)
    164 }
    165 
    166 // resolveAddrList resolves addr using hint and returns a list of
    167 // addresses. The result contains at least one address when error is
    168 // nil.
    169 func (r *Resolver) resolveAddrList(ctx context.Context, op, network, addr string, hint Addr) (addrList, error) {
    170 	afnet, _, err := parseNetwork(ctx, network)
    171 	if err != nil {
    172 		return nil, err
    173 	}
    174 	if op == "dial" && addr == "" {
    175 		return nil, errMissingAddress
    176 	}
    177 	switch afnet {
    178 	case "unix", "unixgram", "unixpacket":
    179 		addr, err := ResolveUnixAddr(afnet, addr)
    180 		if err != nil {
    181 			return nil, err
    182 		}
    183 		if op == "dial" && hint != nil && addr.Network() != hint.Network() {
    184 			return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
    185 		}
    186 		return addrList{addr}, nil
    187 	}
    188 	addrs, err := r.internetAddrList(ctx, afnet, addr)
    189 	if err != nil || op != "dial" || hint == nil {
    190 		return addrs, err
    191 	}
    192 	var (
    193 		tcp      *TCPAddr
    194 		udp      *UDPAddr
    195 		ip       *IPAddr
    196 		wildcard bool
    197 	)
    198 	switch hint := hint.(type) {
    199 	case *TCPAddr:
    200 		tcp = hint
    201 		wildcard = tcp.isWildcard()
    202 	case *UDPAddr:
    203 		udp = hint
    204 		wildcard = udp.isWildcard()
    205 	case *IPAddr:
    206 		ip = hint
    207 		wildcard = ip.isWildcard()
    208 	}
    209 	naddrs := addrs[:0]
    210 	for _, addr := range addrs {
    211 		if addr.Network() != hint.Network() {
    212 			return nil, &AddrError{Err: "mismatched local address type", Addr: hint.String()}
    213 		}
    214 		switch addr := addr.(type) {
    215 		case *TCPAddr:
    216 			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(tcp.IP) {
    217 				continue
    218 			}
    219 			naddrs = append(naddrs, addr)
    220 		case *UDPAddr:
    221 			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(udp.IP) {
    222 				continue
    223 			}
    224 			naddrs = append(naddrs, addr)
    225 		case *IPAddr:
    226 			if !wildcard && !addr.isWildcard() && !addr.IP.matchAddrFamily(ip.IP) {
    227 				continue
    228 			}
    229 			naddrs = append(naddrs, addr)
    230 		}
    231 	}
    232 	if len(naddrs) == 0 {
    233 		return nil, &AddrError{Err: errNoSuitableAddress.Error(), Addr: hint.String()}
    234 	}
    235 	return naddrs, nil
    236 }
    237 
    238 // Dial connects to the address on the named network.
    239 //
    240 // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only),
    241 // "udp", "udp4" (IPv4-only), "udp6" (IPv6-only), "ip", "ip4"
    242 // (IPv4-only), "ip6" (IPv6-only), "unix", "unixgram" and
    243 // "unixpacket".
    244 //
    245 // For TCP and UDP networks, addresses have the form host:port.
    246 // If host is a literal IPv6 address it must be enclosed
    247 // in square brackets as in "[::1]:80" or "[ipv6-host%zone]:80".
    248 // The functions JoinHostPort and SplitHostPort manipulate addresses
    249 // in this form.
    250 // If the host is empty, as in ":80", the local system is assumed.
    251 //
    252 // Examples:
    253 //	Dial("tcp", "192.0.2.1:80")
    254 //	Dial("tcp", "golang.org:http")
    255 //	Dial("tcp", "[2001:db8::1]:http")
    256 //	Dial("tcp", "[fe80::1%lo0]:80")
    257 //	Dial("tcp", ":80")
    258 //
    259 // For IP networks, the network must be "ip", "ip4" or "ip6" followed
    260 // by a colon and a protocol number or name and the addr must be a
    261 // literal IP address.
    262 //
    263 // Examples:
    264 //	Dial("ip4:1", "192.0.2.1")
    265 //	Dial("ip6:ipv6-icmp", "2001:db8::1")
    266 //
    267 // For Unix networks, the address must be a file system path.
    268 //
    269 // If the host is resolved to multiple addresses,
    270 // Dial will try each address in order until one succeeds.
    271 func Dial(network, address string) (Conn, error) {
    272 	var d Dialer
    273 	return d.Dial(network, address)
    274 }
    275 
    276 // DialTimeout acts like Dial but takes a timeout.
    277 // The timeout includes name resolution, if required.
    278 func DialTimeout(network, address string, timeout time.Duration) (Conn, error) {
    279 	d := Dialer{Timeout: timeout}
    280 	return d.Dial(network, address)
    281 }
    282 
    283 // dialParam contains a Dial's parameters and configuration.
    284 type dialParam struct {
    285 	Dialer
    286 	network, address string
    287 }
    288 
    289 // Dial connects to the address on the named network.
    290 //
    291 // See func Dial for a description of the network and address
    292 // parameters.
    293 func (d *Dialer) Dial(network, address string) (Conn, error) {
    294 	return d.DialContext(context.Background(), network, address)
    295 }
    296 
    297 // DialContext connects to the address on the named network using
    298 // the provided context.
    299 //
    300 // The provided Context must be non-nil. If the context expires before
    301 // the connection is complete, an error is returned. Once successfully
    302 // connected, any expiration of the context will not affect the
    303 // connection.
    304 //
    305 // When using TCP, and the host in the address parameter resolves to multiple
    306 // network addresses, any dial timeout (from d.Timeout or ctx) is spread
    307 // over each consecutive dial, such that each is given an appropriate
    308 // fraction of the time to connect.
    309 // For example, if a host has 4 IP addresses and the timeout is 1 minute,
    310 // the connect to each single address will be given 15 seconds to complete
    311 // before trying the next one.
    312 //
    313 // See func Dial for a description of the network and address
    314 // parameters.
    315 func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn, error) {
    316 	if ctx == nil {
    317 		panic("nil context")
    318 	}
    319 	deadline := d.deadline(ctx, time.Now())
    320 	if !deadline.IsZero() {
    321 		if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
    322 			subCtx, cancel := context.WithDeadline(ctx, deadline)
    323 			defer cancel()
    324 			ctx = subCtx
    325 		}
    326 	}
    327 	if oldCancel := d.Cancel; oldCancel != nil {
    328 		subCtx, cancel := context.WithCancel(ctx)
    329 		defer cancel()
    330 		go func() {
    331 			select {
    332 			case <-oldCancel:
    333 				cancel()
    334 			case <-subCtx.Done():
    335 			}
    336 		}()
    337 		ctx = subCtx
    338 	}
    339 
    340 	// Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups.
    341 	resolveCtx := ctx
    342 	if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
    343 		shadow := *trace
    344 		shadow.ConnectStart = nil
    345 		shadow.ConnectDone = nil
    346 		resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
    347 	}
    348 
    349 	addrs, err := d.resolver().resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
    350 	if err != nil {
    351 		return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
    352 	}
    353 
    354 	dp := &dialParam{
    355 		Dialer:  *d,
    356 		network: network,
    357 		address: address,
    358 	}
    359 
    360 	var primaries, fallbacks addrList
    361 	if d.DualStack && network == "tcp" {
    362 		primaries, fallbacks = addrs.partition(isIPv4)
    363 	} else {
    364 		primaries = addrs
    365 	}
    366 
    367 	var c Conn
    368 	if len(fallbacks) > 0 {
    369 		c, err = dialParallel(ctx, dp, primaries, fallbacks)
    370 	} else {
    371 		c, err = dialSerial(ctx, dp, primaries)
    372 	}
    373 	if err != nil {
    374 		return nil, err
    375 	}
    376 
    377 	if tc, ok := c.(*TCPConn); ok && d.KeepAlive > 0 {
    378 		setKeepAlive(tc.fd, true)
    379 		setKeepAlivePeriod(tc.fd, d.KeepAlive)
    380 		testHookSetKeepAlive()
    381 	}
    382 	return c, nil
    383 }
    384 
    385 // dialParallel races two copies of dialSerial, giving the first a
    386 // head start. It returns the first established connection and
    387 // closes the others. Otherwise it returns an error from the first
    388 // primary address.
    389 func dialParallel(ctx context.Context, dp *dialParam, primaries, fallbacks addrList) (Conn, error) {
    390 	if len(fallbacks) == 0 {
    391 		return dialSerial(ctx, dp, primaries)
    392 	}
    393 
    394 	returned := make(chan struct{})
    395 	defer close(returned)
    396 
    397 	type dialResult struct {
    398 		Conn
    399 		error
    400 		primary bool
    401 		done    bool
    402 	}
    403 	results := make(chan dialResult) // unbuffered
    404 
    405 	startRacer := func(ctx context.Context, primary bool) {
    406 		ras := primaries
    407 		if !primary {
    408 			ras = fallbacks
    409 		}
    410 		c, err := dialSerial(ctx, dp, ras)
    411 		select {
    412 		case results <- dialResult{Conn: c, error: err, primary: primary, done: true}:
    413 		case <-returned:
    414 			if c != nil {
    415 				c.Close()
    416 			}
    417 		}
    418 	}
    419 
    420 	var primary, fallback dialResult
    421 
    422 	// Start the main racer.
    423 	primaryCtx, primaryCancel := context.WithCancel(ctx)
    424 	defer primaryCancel()
    425 	go startRacer(primaryCtx, true)
    426 
    427 	// Start the timer for the fallback racer.
    428 	fallbackTimer := time.NewTimer(dp.fallbackDelay())
    429 	defer fallbackTimer.Stop()
    430 
    431 	for {
    432 		select {
    433 		case <-fallbackTimer.C:
    434 			fallbackCtx, fallbackCancel := context.WithCancel(ctx)
    435 			defer fallbackCancel()
    436 			go startRacer(fallbackCtx, false)
    437 
    438 		case res := <-results:
    439 			if res.error == nil {
    440 				return res.Conn, nil
    441 			}
    442 			if res.primary {
    443 				primary = res
    444 			} else {
    445 				fallback = res
    446 			}
    447 			if primary.done && fallback.done {
    448 				return nil, primary.error
    449 			}
    450 			if res.primary && fallbackTimer.Stop() {
    451 				// If we were able to stop the timer, that means it
    452 				// was running (hadn't yet started the fallback), but
    453 				// we just got an error on the primary path, so start
    454 				// the fallback immediately (in 0 nanoseconds).
    455 				fallbackTimer.Reset(0)
    456 			}
    457 		}
    458 	}
    459 }
    460 
    461 // dialSerial connects to a list of addresses in sequence, returning
    462 // either the first successful connection, or the first error.
    463 func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error) {
    464 	var firstErr error // The error from the first address is most relevant.
    465 
    466 	for i, ra := range ras {
    467 		select {
    468 		case <-ctx.Done():
    469 			return nil, &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: mapErr(ctx.Err())}
    470 		default:
    471 		}
    472 
    473 		deadline, _ := ctx.Deadline()
    474 		partialDeadline, err := partialDeadline(time.Now(), deadline, len(ras)-i)
    475 		if err != nil {
    476 			// Ran out of time.
    477 			if firstErr == nil {
    478 				firstErr = &OpError{Op: "dial", Net: dp.network, Source: dp.LocalAddr, Addr: ra, Err: err}
    479 			}
    480 			break
    481 		}
    482 		dialCtx := ctx
    483 		if partialDeadline.Before(deadline) {
    484 			var cancel context.CancelFunc
    485 			dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
    486 			defer cancel()
    487 		}
    488 
    489 		c, err := dialSingle(dialCtx, dp, ra)
    490 		if err == nil {
    491 			return c, nil
    492 		}
    493 		if firstErr == nil {
    494 			firstErr = err
    495 		}
    496 	}
    497 
    498 	if firstErr == nil {
    499 		firstErr = &OpError{Op: "dial", Net: dp.network, Source: nil, Addr: nil, Err: errMissingAddress}
    500 	}
    501 	return nil, firstErr
    502 }
    503 
    504 // dialSingle attempts to establish and returns a single connection to
    505 // the destination address.
    506 func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) {
    507 	trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
    508 	if trace != nil {
    509 		raStr := ra.String()
    510 		if trace.ConnectStart != nil {
    511 			trace.ConnectStart(dp.network, raStr)
    512 		}
    513 		if trace.ConnectDone != nil {
    514 			defer func() { trace.ConnectDone(dp.network, raStr, err) }()
    515 		}
    516 	}
    517 	la := dp.LocalAddr
    518 	switch ra := ra.(type) {
    519 	case *TCPAddr:
    520 		la, _ := la.(*TCPAddr)
    521 		c, err = dialTCP(ctx, dp.network, la, ra)
    522 	case *UDPAddr:
    523 		la, _ := la.(*UDPAddr)
    524 		c, err = dialUDP(ctx, dp.network, la, ra)
    525 	case *IPAddr:
    526 		la, _ := la.(*IPAddr)
    527 		c, err = dialIP(ctx, dp.network, la, ra)
    528 	case *UnixAddr:
    529 		la, _ := la.(*UnixAddr)
    530 		c, err = dialUnix(ctx, dp.network, la, ra)
    531 	default:
    532 		return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: &AddrError{Err: "unexpected address type", Addr: dp.address}}
    533 	}
    534 	if err != nil {
    535 		return nil, &OpError{Op: "dial", Net: dp.network, Source: la, Addr: ra, Err: err} // c is non-nil interface containing nil pointer
    536 	}
    537 	return c, nil
    538 }
    539 
    540 // Listen announces on the local network address laddr.
    541 // The network net must be a stream-oriented network: "tcp", "tcp4",
    542 // "tcp6", "unix" or "unixpacket".
    543 // For TCP and UDP, the syntax of laddr is "host:port", like "127.0.0.1:8080".
    544 // If host is omitted, as in ":8080", Listen listens on all available interfaces
    545 // instead of just the interface with the given host address.
    546 // See Dial for more details about address syntax.
    547 //
    548 // Listening on a hostname is not recommended because this creates a socket
    549 // for at most one of its IP addresses.
    550 func Listen(net, laddr string) (Listener, error) {
    551 	addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil)
    552 	if err != nil {
    553 		return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
    554 	}
    555 	var l Listener
    556 	switch la := addrs.first(isIPv4).(type) {
    557 	case *TCPAddr:
    558 		l, err = ListenTCP(net, la)
    559 	case *UnixAddr:
    560 		l, err = ListenUnix(net, la)
    561 	default:
    562 		return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: laddr}}
    563 	}
    564 	if err != nil {
    565 		return nil, err // l is non-nil interface containing nil pointer
    566 	}
    567 	return l, nil
    568 }
    569 
    570 // ListenPacket announces on the local network address laddr.
    571 // The network net must be a packet-oriented network: "udp", "udp4",
    572 // "udp6", "ip", "ip4", "ip6" or "unixgram".
    573 // For TCP and UDP, the syntax of laddr is "host:port", like "127.0.0.1:8080".
    574 // If host is omitted, as in ":8080", ListenPacket listens on all available interfaces
    575 // instead of just the interface with the given host address.
    576 // See Dial for the syntax of laddr.
    577 //
    578 // Listening on a hostname is not recommended because this creates a socket
    579 // for at most one of its IP addresses.
    580 func ListenPacket(net, laddr string) (PacketConn, error) {
    581 	addrs, err := DefaultResolver.resolveAddrList(context.Background(), "listen", net, laddr, nil)
    582 	if err != nil {
    583 		return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: nil, Err: err}
    584 	}
    585 	var l PacketConn
    586 	switch la := addrs.first(isIPv4).(type) {
    587 	case *UDPAddr:
    588 		l, err = ListenUDP(net, la)
    589 	case *IPAddr:
    590 		l, err = ListenIP(net, la)
    591 	case *UnixAddr:
    592 		l, err = ListenUnixgram(net, la)
    593 	default:
    594 		return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: la, Err: &AddrError{Err: "unexpected address type", Addr: laddr}}
    595 	}
    596 	if err != nil {
    597 		return nil, err // l is non-nil interface containing nil pointer
    598 	}
    599 	return l, nil
    600 }
    601