Home | History | Annotate | Download | only in net
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package net
      6 
      7 import (
      8 	"io"
      9 	"os"
     10 	"runtime"
     11 	"testing"
     12 )
     13 
     14 func TestCloseRead(t *testing.T) {
     15 	switch runtime.GOOS {
     16 	case "nacl", "plan9":
     17 		t.Skipf("not supported on %s", runtime.GOOS)
     18 	}
     19 
     20 	for _, network := range []string{"tcp", "unix", "unixpacket"} {
     21 		if !testableNetwork(network) {
     22 			t.Logf("skipping %s test", network)
     23 			continue
     24 		}
     25 
     26 		ln, err := newLocalListener(network)
     27 		if err != nil {
     28 			t.Fatal(err)
     29 		}
     30 		switch network {
     31 		case "unix", "unixpacket":
     32 			defer os.Remove(ln.Addr().String())
     33 		}
     34 		defer ln.Close()
     35 
     36 		c, err := Dial(ln.Addr().Network(), ln.Addr().String())
     37 		if err != nil {
     38 			t.Fatal(err)
     39 		}
     40 		switch network {
     41 		case "unix", "unixpacket":
     42 			defer os.Remove(c.LocalAddr().String())
     43 		}
     44 		defer c.Close()
     45 
     46 		switch c := c.(type) {
     47 		case *TCPConn:
     48 			err = c.CloseRead()
     49 		case *UnixConn:
     50 			err = c.CloseRead()
     51 		}
     52 		if err != nil {
     53 			if perr := parseCloseError(err); perr != nil {
     54 				t.Error(perr)
     55 			}
     56 			t.Fatal(err)
     57 		}
     58 		var b [1]byte
     59 		n, err := c.Read(b[:])
     60 		if n != 0 || err == nil {
     61 			t.Fatalf("got (%d, %v); want (0, error)", n, err)
     62 		}
     63 	}
     64 }
     65 
     66 func TestCloseWrite(t *testing.T) {
     67 	switch runtime.GOOS {
     68 	case "nacl", "plan9":
     69 		t.Skipf("not supported on %s", runtime.GOOS)
     70 	}
     71 
     72 	handler := func(ls *localServer, ln Listener) {
     73 		c, err := ln.Accept()
     74 		if err != nil {
     75 			t.Error(err)
     76 			return
     77 		}
     78 		defer c.Close()
     79 
     80 		var b [1]byte
     81 		n, err := c.Read(b[:])
     82 		if n != 0 || err != io.EOF {
     83 			t.Errorf("got (%d, %v); want (0, io.EOF)", n, err)
     84 			return
     85 		}
     86 		switch c := c.(type) {
     87 		case *TCPConn:
     88 			err = c.CloseWrite()
     89 		case *UnixConn:
     90 			err = c.CloseWrite()
     91 		}
     92 		if err != nil {
     93 			if perr := parseCloseError(err); perr != nil {
     94 				t.Error(perr)
     95 			}
     96 			t.Error(err)
     97 			return
     98 		}
     99 		n, err = c.Write(b[:])
    100 		if err == nil {
    101 			t.Errorf("got (%d, %v); want (any, error)", n, err)
    102 			return
    103 		}
    104 	}
    105 
    106 	for _, network := range []string{"tcp", "unix", "unixpacket"} {
    107 		if !testableNetwork(network) {
    108 			t.Logf("skipping %s test", network)
    109 			continue
    110 		}
    111 
    112 		ls, err := newLocalServer(network)
    113 		if err != nil {
    114 			t.Fatal(err)
    115 		}
    116 		defer ls.teardown()
    117 		if err := ls.buildup(handler); err != nil {
    118 			t.Fatal(err)
    119 		}
    120 
    121 		c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
    122 		if err != nil {
    123 			t.Fatal(err)
    124 		}
    125 		switch network {
    126 		case "unix", "unixpacket":
    127 			defer os.Remove(c.LocalAddr().String())
    128 		}
    129 		defer c.Close()
    130 
    131 		switch c := c.(type) {
    132 		case *TCPConn:
    133 			err = c.CloseWrite()
    134 		case *UnixConn:
    135 			err = c.CloseWrite()
    136 		}
    137 		if err != nil {
    138 			if perr := parseCloseError(err); perr != nil {
    139 				t.Error(perr)
    140 			}
    141 			t.Fatal(err)
    142 		}
    143 		var b [1]byte
    144 		n, err := c.Read(b[:])
    145 		if n != 0 || err != io.EOF {
    146 			t.Fatalf("got (%d, %v); want (0, io.EOF)", n, err)
    147 		}
    148 		n, err = c.Write(b[:])
    149 		if err == nil {
    150 			t.Fatalf("got (%d, %v); want (any, error)", n, err)
    151 		}
    152 	}
    153 }
    154 
    155 func TestConnClose(t *testing.T) {
    156 	for _, network := range []string{"tcp", "unix", "unixpacket"} {
    157 		if !testableNetwork(network) {
    158 			t.Logf("skipping %s test", network)
    159 			continue
    160 		}
    161 
    162 		ln, err := newLocalListener(network)
    163 		if err != nil {
    164 			t.Fatal(err)
    165 		}
    166 		switch network {
    167 		case "unix", "unixpacket":
    168 			defer os.Remove(ln.Addr().String())
    169 		}
    170 		defer ln.Close()
    171 
    172 		c, err := Dial(ln.Addr().Network(), ln.Addr().String())
    173 		if err != nil {
    174 			t.Fatal(err)
    175 		}
    176 		switch network {
    177 		case "unix", "unixpacket":
    178 			defer os.Remove(c.LocalAddr().String())
    179 		}
    180 		defer c.Close()
    181 
    182 		if err := c.Close(); err != nil {
    183 			if perr := parseCloseError(err); perr != nil {
    184 				t.Error(perr)
    185 			}
    186 			t.Fatal(err)
    187 		}
    188 		var b [1]byte
    189 		n, err := c.Read(b[:])
    190 		if n != 0 || err == nil {
    191 			t.Fatalf("got (%d, %v); want (0, error)", n, err)
    192 		}
    193 	}
    194 }
    195 
    196 func TestListenerClose(t *testing.T) {
    197 	for _, network := range []string{"tcp", "unix", "unixpacket"} {
    198 		if !testableNetwork(network) {
    199 			t.Logf("skipping %s test", network)
    200 			continue
    201 		}
    202 
    203 		ln, err := newLocalListener(network)
    204 		if err != nil {
    205 			t.Fatal(err)
    206 		}
    207 		switch network {
    208 		case "unix", "unixpacket":
    209 			defer os.Remove(ln.Addr().String())
    210 		}
    211 		defer ln.Close()
    212 
    213 		if err := ln.Close(); err != nil {
    214 			if perr := parseCloseError(err); perr != nil {
    215 				t.Error(perr)
    216 			}
    217 			t.Fatal(err)
    218 		}
    219 		c, err := ln.Accept()
    220 		if err == nil {
    221 			c.Close()
    222 			t.Fatal("should fail")
    223 		}
    224 	}
    225 }
    226 
    227 func TestPacketConnClose(t *testing.T) {
    228 	for _, network := range []string{"udp", "unixgram"} {
    229 		if !testableNetwork(network) {
    230 			t.Logf("skipping %s test", network)
    231 			continue
    232 		}
    233 
    234 		c, err := newLocalPacketListener(network)
    235 		if err != nil {
    236 			t.Fatal(err)
    237 		}
    238 		switch network {
    239 		case "unixgram":
    240 			defer os.Remove(c.LocalAddr().String())
    241 		}
    242 		defer c.Close()
    243 
    244 		if err := c.Close(); err != nil {
    245 			if perr := parseCloseError(err); perr != nil {
    246 				t.Error(perr)
    247 			}
    248 			t.Fatal(err)
    249 		}
    250 		var b [1]byte
    251 		n, _, err := c.ReadFrom(b[:])
    252 		if n != 0 || err == nil {
    253 			t.Fatalf("got (%d, %v); want (0, error)", n, err)
    254 		}
    255 	}
    256 }
    257