Home | History | Annotate | Download | only in net
      1 // Copyright 2016 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package net
      6 
      7 import (
      8 	"bytes"
      9 	"fmt"
     10 	"io"
     11 	"io/ioutil"
     12 	"reflect"
     13 	"runtime"
     14 	"sync"
     15 	"testing"
     16 )
     17 
     18 func TestBuffers_read(t *testing.T) {
     19 	const story = "once upon a time in Gopherland ... "
     20 	buffers := Buffers{
     21 		[]byte("once "),
     22 		[]byte("upon "),
     23 		[]byte("a "),
     24 		[]byte("time "),
     25 		[]byte("in "),
     26 		[]byte("Gopherland ... "),
     27 	}
     28 	got, err := ioutil.ReadAll(&buffers)
     29 	if err != nil {
     30 		t.Fatal(err)
     31 	}
     32 	if string(got) != story {
     33 		t.Errorf("read %q; want %q", got, story)
     34 	}
     35 	if len(buffers) != 0 {
     36 		t.Errorf("len(buffers) = %d; want 0", len(buffers))
     37 	}
     38 }
     39 
     40 func TestBuffers_consume(t *testing.T) {
     41 	tests := []struct {
     42 		in      Buffers
     43 		consume int64
     44 		want    Buffers
     45 	}{
     46 		{
     47 			in:      Buffers{[]byte("foo"), []byte("bar")},
     48 			consume: 0,
     49 			want:    Buffers{[]byte("foo"), []byte("bar")},
     50 		},
     51 		{
     52 			in:      Buffers{[]byte("foo"), []byte("bar")},
     53 			consume: 2,
     54 			want:    Buffers{[]byte("o"), []byte("bar")},
     55 		},
     56 		{
     57 			in:      Buffers{[]byte("foo"), []byte("bar")},
     58 			consume: 3,
     59 			want:    Buffers{[]byte("bar")},
     60 		},
     61 		{
     62 			in:      Buffers{[]byte("foo"), []byte("bar")},
     63 			consume: 4,
     64 			want:    Buffers{[]byte("ar")},
     65 		},
     66 		{
     67 			in:      Buffers{nil, nil, nil, []byte("bar")},
     68 			consume: 1,
     69 			want:    Buffers{[]byte("ar")},
     70 		},
     71 		{
     72 			in:      Buffers{nil, nil, nil, []byte("foo")},
     73 			consume: 0,
     74 			want:    Buffers{[]byte("foo")},
     75 		},
     76 		{
     77 			in:      Buffers{nil, nil, nil},
     78 			consume: 0,
     79 			want:    Buffers{},
     80 		},
     81 	}
     82 	for i, tt := range tests {
     83 		in := tt.in
     84 		in.consume(tt.consume)
     85 		if !reflect.DeepEqual(in, tt.want) {
     86 			t.Errorf("%d. after consume(%d) = %+v, want %+v", i, tt.consume, in, tt.want)
     87 		}
     88 	}
     89 }
     90 
     91 func TestBuffers_WriteTo(t *testing.T) {
     92 	for _, name := range []string{"WriteTo", "Copy"} {
     93 		for _, size := range []int{0, 10, 1023, 1024, 1025} {
     94 			t.Run(fmt.Sprintf("%s/%d", name, size), func(t *testing.T) {
     95 				testBuffer_writeTo(t, size, name == "Copy")
     96 			})
     97 		}
     98 	}
     99 }
    100 
    101 func testBuffer_writeTo(t *testing.T, chunks int, useCopy bool) {
    102 	oldHook := testHookDidWritev
    103 	defer func() { testHookDidWritev = oldHook }()
    104 	var writeLog struct {
    105 		sync.Mutex
    106 		log []int
    107 	}
    108 	testHookDidWritev = func(size int) {
    109 		writeLog.Lock()
    110 		writeLog.log = append(writeLog.log, size)
    111 		writeLog.Unlock()
    112 	}
    113 	var want bytes.Buffer
    114 	for i := 0; i < chunks; i++ {
    115 		want.WriteByte(byte(i))
    116 	}
    117 
    118 	withTCPConnPair(t, func(c *TCPConn) error {
    119 		buffers := make(Buffers, chunks)
    120 		for i := range buffers {
    121 			buffers[i] = want.Bytes()[i : i+1]
    122 		}
    123 		var n int64
    124 		var err error
    125 		if useCopy {
    126 			n, err = io.Copy(c, &buffers)
    127 		} else {
    128 			n, err = buffers.WriteTo(c)
    129 		}
    130 		if err != nil {
    131 			return err
    132 		}
    133 		if len(buffers) != 0 {
    134 			return fmt.Errorf("len(buffers) = %d; want 0", len(buffers))
    135 		}
    136 		if n != int64(want.Len()) {
    137 			return fmt.Errorf("Buffers.WriteTo returned %d; want %d", n, want.Len())
    138 		}
    139 		return nil
    140 	}, func(c *TCPConn) error {
    141 		all, err := ioutil.ReadAll(c)
    142 		if !bytes.Equal(all, want.Bytes()) || err != nil {
    143 			return fmt.Errorf("client read %q, %v; want %q, nil", all, err, want.Bytes())
    144 		}
    145 
    146 		writeLog.Lock() // no need to unlock
    147 		var gotSum int
    148 		for _, v := range writeLog.log {
    149 			gotSum += v
    150 		}
    151 
    152 		var wantSum int
    153 		switch runtime.GOOS {
    154 		case "android", "darwin", "dragonfly", "freebsd", "linux", "netbsd", "openbsd":
    155 			var wantMinCalls int
    156 			wantSum = want.Len()
    157 			v := chunks
    158 			for v > 0 {
    159 				wantMinCalls++
    160 				v -= 1024
    161 			}
    162 			if len(writeLog.log) < wantMinCalls {
    163 				t.Errorf("write calls = %v < wanted min %v", len(writeLog.log), wantMinCalls)
    164 			}
    165 		case "windows":
    166 			var wantCalls int
    167 			wantSum = want.Len()
    168 			if wantSum > 0 {
    169 				wantCalls = 1 // windows will always do 1 syscall, unless sending empty buffer
    170 			}
    171 			if len(writeLog.log) != wantCalls {
    172 				t.Errorf("write calls = %v; want %v", len(writeLog.log), wantCalls)
    173 			}
    174 		}
    175 		if gotSum != wantSum {
    176 			t.Errorf("writev call sum  = %v; want %v", gotSum, wantSum)
    177 		}
    178 		return nil
    179 	})
    180 }
    181 
    182 func TestWritevError(t *testing.T) {
    183 	if runtime.GOOS == "windows" {
    184 		t.Skipf("skipping the test: windows does not have problem sending large chunks of data")
    185 	}
    186 
    187 	ln, err := newLocalListener("tcp")
    188 	if err != nil {
    189 		t.Fatal(err)
    190 	}
    191 	defer ln.Close()
    192 
    193 	ch := make(chan Conn, 1)
    194 	go func() {
    195 		defer close(ch)
    196 		c, err := ln.Accept()
    197 		if err != nil {
    198 			t.Error(err)
    199 			return
    200 		}
    201 		ch <- c
    202 	}()
    203 	c1, err := Dial("tcp", ln.Addr().String())
    204 	if err != nil {
    205 		t.Fatal(err)
    206 	}
    207 	defer c1.Close()
    208 	c2 := <-ch
    209 	if c2 == nil {
    210 		t.Fatal("no server side connection")
    211 	}
    212 	c2.Close()
    213 
    214 	// 1 GB of data should be enough to notice the connection is gone.
    215 	// Just a few bytes is not enough.
    216 	// Arrange to reuse the same 1 MB buffer so that we don't allocate much.
    217 	buf := make([]byte, 1<<20)
    218 	buffers := make(Buffers, 1<<10)
    219 	for i := range buffers {
    220 		buffers[i] = buf
    221 	}
    222 	if _, err := buffers.WriteTo(c1); err == nil {
    223 		t.Fatal("Buffers.WriteTo(closed conn) succeeded, want error")
    224 	}
    225 }
    226