Home | History | Annotate | Download | only in net
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package net
      6 
      7 import (
      8 	"errors"
      9 	"fmt"
     10 	"io"
     11 	"net/internal/socktest"
     12 	"os"
     13 	"runtime"
     14 	"testing"
     15 	"time"
     16 )
     17 
     18 func TestCloseRead(t *testing.T) {
     19 	switch runtime.GOOS {
     20 	case "plan9":
     21 		t.Skipf("not supported on %s", runtime.GOOS)
     22 	}
     23 
     24 	for _, network := range []string{"tcp", "unix", "unixpacket"} {
     25 		if !testableNetwork(network) {
     26 			t.Logf("skipping %s test", network)
     27 			continue
     28 		}
     29 
     30 		ln, err := newLocalListener(network)
     31 		if err != nil {
     32 			t.Fatal(err)
     33 		}
     34 		switch network {
     35 		case "unix", "unixpacket":
     36 			defer os.Remove(ln.Addr().String())
     37 		}
     38 		defer ln.Close()
     39 
     40 		c, err := Dial(ln.Addr().Network(), ln.Addr().String())
     41 		if err != nil {
     42 			t.Fatal(err)
     43 		}
     44 		switch network {
     45 		case "unix", "unixpacket":
     46 			defer os.Remove(c.LocalAddr().String())
     47 		}
     48 		defer c.Close()
     49 
     50 		switch c := c.(type) {
     51 		case *TCPConn:
     52 			err = c.CloseRead()
     53 		case *UnixConn:
     54 			err = c.CloseRead()
     55 		}
     56 		if err != nil {
     57 			if perr := parseCloseError(err, true); perr != nil {
     58 				t.Error(perr)
     59 			}
     60 			t.Fatal(err)
     61 		}
     62 		var b [1]byte
     63 		n, err := c.Read(b[:])
     64 		if n != 0 || err == nil {
     65 			t.Fatalf("got (%d, %v); want (0, error)", n, err)
     66 		}
     67 	}
     68 }
     69 
     70 func TestCloseWrite(t *testing.T) {
     71 	switch runtime.GOOS {
     72 	case "nacl", "plan9":
     73 		t.Skipf("not supported on %s", runtime.GOOS)
     74 	}
     75 
     76 	handler := func(ls *localServer, ln Listener) {
     77 		c, err := ln.Accept()
     78 		if err != nil {
     79 			t.Error(err)
     80 			return
     81 		}
     82 		defer c.Close()
     83 
     84 		var b [1]byte
     85 		n, err := c.Read(b[:])
     86 		if n != 0 || err != io.EOF {
     87 			t.Errorf("got (%d, %v); want (0, io.EOF)", n, err)
     88 			return
     89 		}
     90 		switch c := c.(type) {
     91 		case *TCPConn:
     92 			err = c.CloseWrite()
     93 		case *UnixConn:
     94 			err = c.CloseWrite()
     95 		}
     96 		if err != nil {
     97 			if perr := parseCloseError(err, true); perr != nil {
     98 				t.Error(perr)
     99 			}
    100 			t.Error(err)
    101 			return
    102 		}
    103 		n, err = c.Write(b[:])
    104 		if err == nil {
    105 			t.Errorf("got (%d, %v); want (any, error)", n, err)
    106 			return
    107 		}
    108 	}
    109 
    110 	for _, network := range []string{"tcp", "unix", "unixpacket"} {
    111 		if !testableNetwork(network) {
    112 			t.Logf("skipping %s test", network)
    113 			continue
    114 		}
    115 
    116 		ls, err := newLocalServer(network)
    117 		if err != nil {
    118 			t.Fatal(err)
    119 		}
    120 		defer ls.teardown()
    121 		if err := ls.buildup(handler); err != nil {
    122 			t.Fatal(err)
    123 		}
    124 
    125 		c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
    126 		if err != nil {
    127 			t.Fatal(err)
    128 		}
    129 		switch network {
    130 		case "unix", "unixpacket":
    131 			defer os.Remove(c.LocalAddr().String())
    132 		}
    133 		defer c.Close()
    134 
    135 		switch c := c.(type) {
    136 		case *TCPConn:
    137 			err = c.CloseWrite()
    138 		case *UnixConn:
    139 			err = c.CloseWrite()
    140 		}
    141 		if err != nil {
    142 			if perr := parseCloseError(err, true); perr != nil {
    143 				t.Error(perr)
    144 			}
    145 			t.Fatal(err)
    146 		}
    147 		var b [1]byte
    148 		n, err := c.Read(b[:])
    149 		if n != 0 || err != io.EOF {
    150 			t.Fatalf("got (%d, %v); want (0, io.EOF)", n, err)
    151 		}
    152 		n, err = c.Write(b[:])
    153 		if err == nil {
    154 			t.Fatalf("got (%d, %v); want (any, error)", n, err)
    155 		}
    156 	}
    157 }
    158 
    159 func TestConnClose(t *testing.T) {
    160 	for _, network := range []string{"tcp", "unix", "unixpacket"} {
    161 		if !testableNetwork(network) {
    162 			t.Logf("skipping %s test", network)
    163 			continue
    164 		}
    165 
    166 		ln, err := newLocalListener(network)
    167 		if err != nil {
    168 			t.Fatal(err)
    169 		}
    170 		switch network {
    171 		case "unix", "unixpacket":
    172 			defer os.Remove(ln.Addr().String())
    173 		}
    174 		defer ln.Close()
    175 
    176 		c, err := Dial(ln.Addr().Network(), ln.Addr().String())
    177 		if err != nil {
    178 			t.Fatal(err)
    179 		}
    180 		switch network {
    181 		case "unix", "unixpacket":
    182 			defer os.Remove(c.LocalAddr().String())
    183 		}
    184 		defer c.Close()
    185 
    186 		if err := c.Close(); err != nil {
    187 			if perr := parseCloseError(err, false); perr != nil {
    188 				t.Error(perr)
    189 			}
    190 			t.Fatal(err)
    191 		}
    192 		var b [1]byte
    193 		n, err := c.Read(b[:])
    194 		if n != 0 || err == nil {
    195 			t.Fatalf("got (%d, %v); want (0, error)", n, err)
    196 		}
    197 	}
    198 }
    199 
    200 func TestListenerClose(t *testing.T) {
    201 	for _, network := range []string{"tcp", "unix", "unixpacket"} {
    202 		if !testableNetwork(network) {
    203 			t.Logf("skipping %s test", network)
    204 			continue
    205 		}
    206 
    207 		ln, err := newLocalListener(network)
    208 		if err != nil {
    209 			t.Fatal(err)
    210 		}
    211 		switch network {
    212 		case "unix", "unixpacket":
    213 			defer os.Remove(ln.Addr().String())
    214 		}
    215 
    216 		dst := ln.Addr().String()
    217 		if err := ln.Close(); err != nil {
    218 			if perr := parseCloseError(err, false); perr != nil {
    219 				t.Error(perr)
    220 			}
    221 			t.Fatal(err)
    222 		}
    223 		c, err := ln.Accept()
    224 		if err == nil {
    225 			c.Close()
    226 			t.Fatal("should fail")
    227 		}
    228 
    229 		if network == "tcp" {
    230 			// We will have two TCP FSMs inside the
    231 			// kernel here. There's no guarantee that a
    232 			// signal comes from the far end FSM will be
    233 			// delivered immediately to the near end FSM,
    234 			// especially on the platforms that allow
    235 			// multiple consumer threads to pull pending
    236 			// established connections at the same time by
    237 			// enabling SO_REUSEPORT option such as Linux,
    238 			// DragonFly BSD. So we need to give some time
    239 			// quantum to the kernel.
    240 			//
    241 			// Note that net.inet.tcp.reuseport_ext=1 by
    242 			// default on DragonFly BSD.
    243 			time.Sleep(time.Millisecond)
    244 
    245 			cc, err := Dial("tcp", dst)
    246 			if err == nil {
    247 				t.Error("Dial to closed TCP listener succeeded.")
    248 				cc.Close()
    249 			}
    250 		}
    251 	}
    252 }
    253 
    254 func TestPacketConnClose(t *testing.T) {
    255 	for _, network := range []string{"udp", "unixgram"} {
    256 		if !testableNetwork(network) {
    257 			t.Logf("skipping %s test", network)
    258 			continue
    259 		}
    260 
    261 		c, err := newLocalPacketListener(network)
    262 		if err != nil {
    263 			t.Fatal(err)
    264 		}
    265 		switch network {
    266 		case "unixgram":
    267 			defer os.Remove(c.LocalAddr().String())
    268 		}
    269 		defer c.Close()
    270 
    271 		if err := c.Close(); err != nil {
    272 			if perr := parseCloseError(err, false); perr != nil {
    273 				t.Error(perr)
    274 			}
    275 			t.Fatal(err)
    276 		}
    277 		var b [1]byte
    278 		n, _, err := c.ReadFrom(b[:])
    279 		if n != 0 || err == nil {
    280 			t.Fatalf("got (%d, %v); want (0, error)", n, err)
    281 		}
    282 	}
    283 }
    284 
    285 // nacl was previous failing to reuse an address.
    286 func TestListenCloseListen(t *testing.T) {
    287 	const maxTries = 10
    288 	for tries := 0; tries < maxTries; tries++ {
    289 		ln, err := newLocalListener("tcp")
    290 		if err != nil {
    291 			t.Fatal(err)
    292 		}
    293 		addr := ln.Addr().String()
    294 		if err := ln.Close(); err != nil {
    295 			if perr := parseCloseError(err, false); perr != nil {
    296 				t.Error(perr)
    297 			}
    298 			t.Fatal(err)
    299 		}
    300 		ln, err = Listen("tcp", addr)
    301 		if err == nil {
    302 			// Success. nacl couldn't do this before.
    303 			ln.Close()
    304 			return
    305 		}
    306 		t.Errorf("failed on try %d/%d: %v", tries+1, maxTries, err)
    307 	}
    308 	t.Fatalf("failed to listen/close/listen on same address after %d tries", maxTries)
    309 }
    310 
    311 // See golang.org/issue/6163, golang.org/issue/6987.
    312 func TestAcceptIgnoreAbortedConnRequest(t *testing.T) {
    313 	switch runtime.GOOS {
    314 	case "plan9":
    315 		t.Skipf("%s does not have full support of socktest", runtime.GOOS)
    316 	}
    317 
    318 	syserr := make(chan error)
    319 	go func() {
    320 		defer close(syserr)
    321 		for _, err := range abortedConnRequestErrors {
    322 			syserr <- err
    323 		}
    324 	}()
    325 	sw.Set(socktest.FilterAccept, func(so *socktest.Status) (socktest.AfterFilter, error) {
    326 		if err, ok := <-syserr; ok {
    327 			return nil, err
    328 		}
    329 		return nil, nil
    330 	})
    331 	defer sw.Set(socktest.FilterAccept, nil)
    332 
    333 	operr := make(chan error, 1)
    334 	handler := func(ls *localServer, ln Listener) {
    335 		defer close(operr)
    336 		c, err := ln.Accept()
    337 		if err != nil {
    338 			if perr := parseAcceptError(err); perr != nil {
    339 				operr <- perr
    340 			}
    341 			operr <- err
    342 			return
    343 		}
    344 		c.Close()
    345 	}
    346 	ls, err := newLocalServer("tcp")
    347 	if err != nil {
    348 		t.Fatal(err)
    349 	}
    350 	defer ls.teardown()
    351 	if err := ls.buildup(handler); err != nil {
    352 		t.Fatal(err)
    353 	}
    354 
    355 	c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
    356 	if err != nil {
    357 		t.Fatal(err)
    358 	}
    359 	c.Close()
    360 
    361 	for err := range operr {
    362 		t.Error(err)
    363 	}
    364 }
    365 
    366 func TestZeroByteRead(t *testing.T) {
    367 	for _, network := range []string{"tcp", "unix", "unixpacket"} {
    368 		if !testableNetwork(network) {
    369 			t.Logf("skipping %s test", network)
    370 			continue
    371 		}
    372 
    373 		ln, err := newLocalListener(network)
    374 		if err != nil {
    375 			t.Fatal(err)
    376 		}
    377 		connc := make(chan Conn, 1)
    378 		go func() {
    379 			defer ln.Close()
    380 			c, err := ln.Accept()
    381 			if err != nil {
    382 				t.Error(err)
    383 			}
    384 			connc <- c // might be nil
    385 		}()
    386 		c, err := Dial(network, ln.Addr().String())
    387 		if err != nil {
    388 			t.Fatal(err)
    389 		}
    390 		defer c.Close()
    391 		sc := <-connc
    392 		if sc == nil {
    393 			continue
    394 		}
    395 		defer sc.Close()
    396 
    397 		if runtime.GOOS == "windows" {
    398 			// A zero byte read on Windows caused a wait for readability first.
    399 			// Rather than change that behavior, satisfy it in this test.
    400 			// See Issue 15735.
    401 			go io.WriteString(sc, "a")
    402 		}
    403 
    404 		n, err := c.Read(nil)
    405 		if n != 0 || err != nil {
    406 			t.Errorf("%s: zero byte client read = %v, %v; want 0, nil", network, n, err)
    407 		}
    408 
    409 		if runtime.GOOS == "windows" {
    410 			// Same as comment above.
    411 			go io.WriteString(c, "a")
    412 		}
    413 		n, err = sc.Read(nil)
    414 		if n != 0 || err != nil {
    415 			t.Errorf("%s: zero byte server read = %v, %v; want 0, nil", network, n, err)
    416 		}
    417 	}
    418 }
    419 
    420 // withTCPConnPair sets up a TCP connection between two peers, then
    421 // runs peer1 and peer2 concurrently. withTCPConnPair returns when
    422 // both have completed.
    423 func withTCPConnPair(t *testing.T, peer1, peer2 func(c *TCPConn) error) {
    424 	ln, err := newLocalListener("tcp")
    425 	if err != nil {
    426 		t.Fatal(err)
    427 	}
    428 	defer ln.Close()
    429 	errc := make(chan error, 2)
    430 	go func() {
    431 		c1, err := ln.Accept()
    432 		if err != nil {
    433 			errc <- err
    434 			return
    435 		}
    436 		defer c1.Close()
    437 		errc <- peer1(c1.(*TCPConn))
    438 	}()
    439 	go func() {
    440 		c2, err := Dial("tcp", ln.Addr().String())
    441 		if err != nil {
    442 			errc <- err
    443 			return
    444 		}
    445 		defer c2.Close()
    446 		errc <- peer2(c2.(*TCPConn))
    447 	}()
    448 	for i := 0; i < 2; i++ {
    449 		if err := <-errc; err != nil {
    450 			t.Fatal(err)
    451 		}
    452 	}
    453 }
    454 
    455 // Tests that a blocked Read is interrupted by a concurrent SetReadDeadline
    456 // modifying that Conn's read deadline to the past.
    457 // See golang.org/cl/30164 which documented this. The net/http package
    458 // depends on this.
    459 func TestReadTimeoutUnblocksRead(t *testing.T) {
    460 	serverDone := make(chan struct{})
    461 	server := func(cs *TCPConn) error {
    462 		defer close(serverDone)
    463 		errc := make(chan error, 1)
    464 		go func() {
    465 			defer close(errc)
    466 			go func() {
    467 				// TODO: find a better way to wait
    468 				// until we're blocked in the cs.Read
    469 				// call below. Sleep is lame.
    470 				time.Sleep(100 * time.Millisecond)
    471 
    472 				// Interrupt the upcoming Read, unblocking it:
    473 				cs.SetReadDeadline(time.Unix(123, 0)) // time in the past
    474 			}()
    475 			var buf [1]byte
    476 			n, err := cs.Read(buf[:1])
    477 			if n != 0 || err == nil {
    478 				errc <- fmt.Errorf("Read = %v, %v; want 0, non-nil", n, err)
    479 			}
    480 		}()
    481 		select {
    482 		case err := <-errc:
    483 			return err
    484 		case <-time.After(5 * time.Second):
    485 			buf := make([]byte, 2<<20)
    486 			buf = buf[:runtime.Stack(buf, true)]
    487 			println("Stacks at timeout:\n", string(buf))
    488 			return errors.New("timeout waiting for Read to finish")
    489 		}
    490 
    491 	}
    492 	// Do nothing in the client. Never write. Just wait for the
    493 	// server's half to be done.
    494 	client := func(*TCPConn) error {
    495 		<-serverDone
    496 		return nil
    497 	}
    498 	withTCPConnPair(t, client, server)
    499 }
    500 
    501 // Issue 17695: verify that a blocked Read is woken up by a Close.
    502 func TestCloseUnblocksRead(t *testing.T) {
    503 	t.Parallel()
    504 	server := func(cs *TCPConn) error {
    505 		// Give the client time to get stuck in a Read:
    506 		time.Sleep(20 * time.Millisecond)
    507 		cs.Close()
    508 		return nil
    509 	}
    510 	client := func(ss *TCPConn) error {
    511 		n, err := ss.Read([]byte{0})
    512 		if n != 0 || err != io.EOF {
    513 			return fmt.Errorf("Read = %v, %v; want 0, EOF", n, err)
    514 		}
    515 		return nil
    516 	}
    517 	withTCPConnPair(t, client, server)
    518 }
    519