Home | History | Annotate | Download | only in net
      1 // Copyright 2010 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package net
      6 
      7 import (
      8 	"context"
      9 	"internal/poll"
     10 	"os"
     11 	"runtime"
     12 	"syscall"
     13 	"unsafe"
     14 )
     15 
     16 // canUseConnectEx reports whether we can use the ConnectEx Windows API call
     17 // for the given network type.
     18 func canUseConnectEx(net string) bool {
     19 	switch net {
     20 	case "tcp", "tcp4", "tcp6":
     21 		return true
     22 	}
     23 	// ConnectEx windows API does not support connectionless sockets.
     24 	return false
     25 }
     26 
     27 // Network file descriptor.
     28 type netFD struct {
     29 	pfd poll.FD
     30 
     31 	// immutable until Close
     32 	family      int
     33 	sotype      int
     34 	isConnected bool
     35 	net         string
     36 	laddr       Addr
     37 	raddr       Addr
     38 }
     39 
     40 func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error) {
     41 	ret := &netFD{
     42 		pfd: poll.FD{
     43 			Sysfd:         sysfd,
     44 			IsStream:      sotype == syscall.SOCK_STREAM,
     45 			ZeroReadIsEOF: sotype != syscall.SOCK_DGRAM && sotype != syscall.SOCK_RAW,
     46 		},
     47 		family: family,
     48 		sotype: sotype,
     49 		net:    net,
     50 	}
     51 	return ret, nil
     52 }
     53 
     54 func (fd *netFD) init() error {
     55 	errcall, err := fd.pfd.Init(fd.net, true)
     56 	if errcall != "" {
     57 		err = wrapSyscallError(errcall, err)
     58 	}
     59 	return err
     60 }
     61 
     62 func (fd *netFD) setAddr(laddr, raddr Addr) {
     63 	fd.laddr = laddr
     64 	fd.raddr = raddr
     65 	runtime.SetFinalizer(fd, (*netFD).Close)
     66 }
     67 
// connect initializes fd and connects it to the remote address ra,
// honoring the deadline and cancellation of ctx.
//
// For networks where canUseConnectEx reports false it issues a plain
// connect call via connectFunc. Otherwise it uses ConnectEx, first
// binding the socket to a zero-valued local address of ra's family when
// la is nil (when la is non-nil no bind is performed here; presumably
// the caller has already bound the socket — confirm against callers).
//
// Always returns nil for connected peer address result.
func (fd *netFD) connect(ctx context.Context, la, ra syscall.Sockaddr) (syscall.Sockaddr, error) {
	// Do not need to call fd.writeLock here,
	// because fd is not yet accessible to user,
	// so no concurrent operations are possible.
	if err := fd.init(); err != nil {
		return nil, err
	}
	// Apply the context deadline as a write deadline for the duration of
	// the connect, and clear it again when this function returns.
	if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
		fd.pfd.SetWriteDeadline(deadline)
		defer fd.pfd.SetWriteDeadline(noDeadline)
	}
	if !canUseConnectEx(fd.net) {
		err := connectFunc(fd.pfd.Sysfd, ra)
		return nil, os.NewSyscallError("connect", err)
	}
	// ConnectEx windows API requires an unconnected, previously bound socket.
	if la == nil {
		// Pick a zero-valued local address of the same family as ra.
		switch ra.(type) {
		case *syscall.SockaddrInet4:
			la = &syscall.SockaddrInet4{}
		case *syscall.SockaddrInet6:
			la = &syscall.SockaddrInet6{}
		default:
			panic("unexpected type in connect")
		}
		if err := syscall.Bind(fd.pfd.Sysfd, la); err != nil {
			return nil, os.NewSyscallError("bind", err)
		}
	}

	// Wait for the goroutine converting context.Done into a write timeout
	// to exist, otherwise our caller might cancel the context and
	// cause fd.setWriteDeadline(aLongTimeAgo) to cancel a successful dial.
	//
	// Because done is unbuffered, the deferred send below blocks until the
	// goroutine receives it, so connect does not return while the
	// goroutine could still touch the write deadline. This defer runs
	// before the deferred SetWriteDeadline(noDeadline) registered above
	// (defers run in LIFO order).
	done := make(chan bool) // must be unbuffered
	defer func() { done <- true }()
	go func() {
		select {
		case <-ctx.Done():
			// Force the runtime's poller to immediately give
			// up waiting for writability.
			fd.pfd.SetWriteDeadline(aLongTimeAgo)
			// Still consume the send from the deferred function so
			// that connect can return.
			<-done
		case <-done:
		}
	}()

	// Call ConnectEx API.
	if err := fd.pfd.ConnectEx(ra); err != nil {
		select {
		case <-ctx.Done():
			// The context finished; report its error (via mapErr)
			// rather than the error returned by ConnectEx.
			return nil, mapErr(ctx.Err())
		default:
			// Wrap raw errno values so the failing call is named.
			if _, ok := err.(syscall.Errno); ok {
				err = os.NewSyscallError("connectex", err)
			}
			return nil, err
		}
	}
	// Refresh socket properties.
	// The option value passed to SO_UPDATE_CONNECT_CONTEXT is the socket
	// handle itself (pointer to Sysfd and its size).
	return nil, os.NewSyscallError("setsockopt", syscall.Setsockopt(fd.pfd.Sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.pfd.Sysfd)), int32(unsafe.Sizeof(fd.pfd.Sysfd))))
}
    130 
    131 func (fd *netFD) Close() error {
    132 	runtime.SetFinalizer(fd, nil)
    133 	return fd.pfd.Close()
    134 }
    135 
    136 func (fd *netFD) shutdown(how int) error {
    137 	err := fd.pfd.Shutdown(how)
    138 	runtime.KeepAlive(fd)
    139 	return err
    140 }
    141 
    142 func (fd *netFD) closeRead() error {
    143 	return fd.shutdown(syscall.SHUT_RD)
    144 }
    145 
    146 func (fd *netFD) closeWrite() error {
    147 	return fd.shutdown(syscall.SHUT_WR)
    148 }
    149 
    150 func (fd *netFD) Read(buf []byte) (int, error) {
    151 	n, err := fd.pfd.Read(buf)
    152 	runtime.KeepAlive(fd)
    153 	return n, wrapSyscallError("wsarecv", err)
    154 }
    155 
    156 func (fd *netFD) readFrom(buf []byte) (int, syscall.Sockaddr, error) {
    157 	n, sa, err := fd.pfd.ReadFrom(buf)
    158 	runtime.KeepAlive(fd)
    159 	return n, sa, wrapSyscallError("wsarecvfrom", err)
    160 }
    161 
    162 func (fd *netFD) Write(buf []byte) (int, error) {
    163 	n, err := fd.pfd.Write(buf)
    164 	runtime.KeepAlive(fd)
    165 	return n, wrapSyscallError("wsasend", err)
    166 }
    167 
    168 func (c *conn) writeBuffers(v *Buffers) (int64, error) {
    169 	if !c.ok() {
    170 		return 0, syscall.EINVAL
    171 	}
    172 	n, err := c.fd.writeBuffers(v)
    173 	if err != nil {
    174 		return n, &OpError{Op: "wsasend", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
    175 	}
    176 	return n, nil
    177 }
    178 
    179 func (fd *netFD) writeBuffers(buf *Buffers) (int64, error) {
    180 	n, err := fd.pfd.Writev((*[][]byte)(buf))
    181 	runtime.KeepAlive(fd)
    182 	return n, wrapSyscallError("wsasend", err)
    183 }
    184 
    185 func (fd *netFD) writeTo(buf []byte, sa syscall.Sockaddr) (int, error) {
    186 	n, err := fd.pfd.WriteTo(buf, sa)
    187 	runtime.KeepAlive(fd)
    188 	return n, wrapSyscallError("wsasendto", err)
    189 }
    190 
    191 func (fd *netFD) accept() (*netFD, error) {
    192 	s, rawsa, rsan, errcall, err := fd.pfd.Accept(func() (syscall.Handle, error) {
    193 		return sysSocket(fd.family, fd.sotype, 0)
    194 	})
    195 
    196 	if err != nil {
    197 		if errcall != "" {
    198 			err = wrapSyscallError(errcall, err)
    199 		}
    200 		return nil, err
    201 	}
    202 
    203 	// Associate our new socket with IOCP.
    204 	netfd, err := newFD(s, fd.family, fd.sotype, fd.net)
    205 	if err != nil {
    206 		poll.CloseFunc(s)
    207 		return nil, err
    208 	}
    209 	if err := netfd.init(); err != nil {
    210 		fd.Close()
    211 		return nil, err
    212 	}
    213 
    214 	// Get local and peer addr out of AcceptEx buffer.
    215 	var lrsa, rrsa *syscall.RawSockaddrAny
    216 	var llen, rlen int32
    217 	syscall.GetAcceptExSockaddrs((*byte)(unsafe.Pointer(&rawsa[0])),
    218 		0, rsan, rsan, &lrsa, &llen, &rrsa, &rlen)
    219 	lsa, _ := lrsa.Sockaddr()
    220 	rsa, _ := rrsa.Sockaddr()
    221 
    222 	netfd.setAddr(netfd.addrFunc()(lsa), netfd.addrFunc()(rsa))
    223 	return netfd, nil
    224 }
    225 
    226 func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
    227 	n, oobn, flags, sa, err = fd.pfd.ReadMsg(p, oob)
    228 	runtime.KeepAlive(fd)
    229 	return n, oobn, flags, sa, wrapSyscallError("wsarecvmsg", err)
    230 }
    231 
    232 func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
    233 	n, oobn, err = fd.pfd.WriteMsg(p, oob, sa)
    234 	runtime.KeepAlive(fd)
    235 	return n, oobn, wrapSyscallError("wsasendmsg", err)
    236 }
    237 
    238 // Unimplemented functions.
    239 
// dup is not implemented in this file: it always returns a nil *os.File
// together with syscall.EWINDOWS.
func (fd *netFD) dup() (*os.File, error) {
	// TODO: Implement this
	return nil, syscall.EWINDOWS
}
    244