Home | History | Annotate | Download | only in io
      1 // Copyright 2010 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package io_test
      6 
      7 import (
      8 	"bytes"
      9 	"crypto/sha1"
     10 	"errors"
     11 	"fmt"
     12 	. "io"
     13 	"io/ioutil"
     14 	"runtime"
     15 	"strings"
     16 	"testing"
     17 	"time"
     18 )
     19 
     20 func TestMultiReader(t *testing.T) {
     21 	var mr Reader
     22 	var buf []byte
     23 	nread := 0
     24 	withFooBar := func(tests func()) {
     25 		r1 := strings.NewReader("foo ")
     26 		r2 := strings.NewReader("")
     27 		r3 := strings.NewReader("bar")
     28 		mr = MultiReader(r1, r2, r3)
     29 		buf = make([]byte, 20)
     30 		tests()
     31 	}
     32 	expectRead := func(size int, expected string, eerr error) {
     33 		nread++
     34 		n, gerr := mr.Read(buf[0:size])
     35 		if n != len(expected) {
     36 			t.Errorf("#%d, expected %d bytes; got %d",
     37 				nread, len(expected), n)
     38 		}
     39 		got := string(buf[0:n])
     40 		if got != expected {
     41 			t.Errorf("#%d, expected %q; got %q",
     42 				nread, expected, got)
     43 		}
     44 		if gerr != eerr {
     45 			t.Errorf("#%d, expected error %v; got %v",
     46 				nread, eerr, gerr)
     47 		}
     48 		buf = buf[n:]
     49 	}
     50 	withFooBar(func() {
     51 		expectRead(2, "fo", nil)
     52 		expectRead(5, "o ", nil)
     53 		expectRead(5, "bar", nil)
     54 		expectRead(5, "", EOF)
     55 	})
     56 	withFooBar(func() {
     57 		expectRead(4, "foo ", nil)
     58 		expectRead(1, "b", nil)
     59 		expectRead(3, "ar", nil)
     60 		expectRead(1, "", EOF)
     61 	})
     62 	withFooBar(func() {
     63 		expectRead(5, "foo ", nil)
     64 	})
     65 }
     66 
     67 func TestMultiWriter(t *testing.T) {
     68 	sink := new(bytes.Buffer)
     69 	// Hide bytes.Buffer's WriteString method:
     70 	testMultiWriter(t, struct {
     71 		Writer
     72 		fmt.Stringer
     73 	}{sink, sink})
     74 }
     75 
     76 func TestMultiWriter_String(t *testing.T) {
     77 	testMultiWriter(t, new(bytes.Buffer))
     78 }
     79 
     80 // test that a multiWriter.WriteString calls results in at most 1 allocation,
     81 // even if multiple targets don't support WriteString.
     82 func TestMultiWriter_WriteStringSingleAlloc(t *testing.T) {
     83 	var sink1, sink2 bytes.Buffer
     84 	type simpleWriter struct { // hide bytes.Buffer's WriteString
     85 		Writer
     86 	}
     87 	mw := MultiWriter(simpleWriter{&sink1}, simpleWriter{&sink2})
     88 	allocs := int(testing.AllocsPerRun(1000, func() {
     89 		WriteString(mw, "foo")
     90 	}))
     91 	if allocs != 1 {
     92 		t.Errorf("num allocations = %d; want 1", allocs)
     93 	}
     94 }
     95 
     96 type writeStringChecker struct{ called bool }
     97 
     98 func (c *writeStringChecker) WriteString(s string) (n int, err error) {
     99 	c.called = true
    100 	return len(s), nil
    101 }
    102 
    103 func (c *writeStringChecker) Write(p []byte) (n int, err error) {
    104 	return len(p), nil
    105 }
    106 
    107 func TestMultiWriter_StringCheckCall(t *testing.T) {
    108 	var c writeStringChecker
    109 	mw := MultiWriter(&c)
    110 	WriteString(mw, "foo")
    111 	if !c.called {
    112 		t.Error("did not see WriteString call to writeStringChecker")
    113 	}
    114 }
    115 
    116 func testMultiWriter(t *testing.T, sink interface {
    117 	Writer
    118 	fmt.Stringer
    119 }) {
    120 	sha1 := sha1.New()
    121 	mw := MultiWriter(sha1, sink)
    122 
    123 	sourceString := "My input text."
    124 	source := strings.NewReader(sourceString)
    125 	written, err := Copy(mw, source)
    126 
    127 	if written != int64(len(sourceString)) {
    128 		t.Errorf("short write of %d, not %d", written, len(sourceString))
    129 	}
    130 
    131 	if err != nil {
    132 		t.Errorf("unexpected error: %v", err)
    133 	}
    134 
    135 	sha1hex := fmt.Sprintf("%x", sha1.Sum(nil))
    136 	if sha1hex != "01cb303fa8c30a64123067c5aa6284ba7ec2d31b" {
    137 		t.Error("incorrect sha1 value")
    138 	}
    139 
    140 	if sink.String() != sourceString {
    141 		t.Errorf("expected %q; got %q", sourceString, sink.String())
    142 	}
    143 }
    144 
    145 // writerFunc is an io.Writer implemented by the underlying func.
    146 type writerFunc func(p []byte) (int, error)
    147 
    148 func (f writerFunc) Write(p []byte) (int, error) {
    149 	return f(p)
    150 }
    151 
    152 // Test that MultiWriter properly flattens chained multiWriters,
    153 func TestMultiWriterSingleChainFlatten(t *testing.T) {
    154 	pc := make([]uintptr, 1000) // 1000 should fit the full stack
    155 	n := runtime.Callers(0, pc)
    156 	var myDepth = callDepth(pc[:n])
    157 	var writeDepth int // will contain the depth from which writerFunc.Writer was called
    158 	var w Writer = MultiWriter(writerFunc(func(p []byte) (int, error) {
    159 		n := runtime.Callers(1, pc)
    160 		writeDepth += callDepth(pc[:n])
    161 		return 0, nil
    162 	}))
    163 
    164 	mw := w
    165 	// chain a bunch of multiWriters
    166 	for i := 0; i < 100; i++ {
    167 		mw = MultiWriter(w)
    168 	}
    169 
    170 	mw = MultiWriter(w, mw, w, mw)
    171 	mw.Write(nil) // don't care about errors, just want to check the call-depth for Write
    172 
    173 	if writeDepth != 4*(myDepth+2) { // 2 should be multiWriter.Write and writerFunc.Write
    174 		t.Errorf("multiWriter did not flatten chained multiWriters: expected writeDepth %d, got %d",
    175 			4*(myDepth+2), writeDepth)
    176 	}
    177 }
    178 
    179 func TestMultiWriterError(t *testing.T) {
    180 	f1 := writerFunc(func(p []byte) (int, error) {
    181 		return len(p) / 2, ErrShortWrite
    182 	})
    183 	f2 := writerFunc(func(p []byte) (int, error) {
    184 		t.Errorf("MultiWriter called f2.Write")
    185 		return len(p), nil
    186 	})
    187 	w := MultiWriter(f1, f2)
    188 	n, err := w.Write(make([]byte, 100))
    189 	if n != 50 || err != ErrShortWrite {
    190 		t.Errorf("Write = %d, %v, want 50, ErrShortWrite", n, err)
    191 	}
    192 }
    193 
    194 // Test that MultiReader copies the input slice and is insulated from future modification.
    195 func TestMultiReaderCopy(t *testing.T) {
    196 	slice := []Reader{strings.NewReader("hello world")}
    197 	r := MultiReader(slice...)
    198 	slice[0] = nil
    199 	data, err := ioutil.ReadAll(r)
    200 	if err != nil || string(data) != "hello world" {
    201 		t.Errorf("ReadAll() = %q, %v, want %q, nil", data, err, "hello world")
    202 	}
    203 }
    204 
    205 // Test that MultiWriter copies the input slice and is insulated from future modification.
    206 func TestMultiWriterCopy(t *testing.T) {
    207 	var buf bytes.Buffer
    208 	slice := []Writer{&buf}
    209 	w := MultiWriter(slice...)
    210 	slice[0] = nil
    211 	n, err := w.Write([]byte("hello world"))
    212 	if err != nil || n != 11 {
    213 		t.Errorf("Write(`hello world`) = %d, %v, want 11, nil", n, err)
    214 	}
    215 	if buf.String() != "hello world" {
    216 		t.Errorf("buf.String() = %q, want %q", buf.String(), "hello world")
    217 	}
    218 }
    219 
    220 // readerFunc is an io.Reader implemented by the underlying func.
    221 type readerFunc func(p []byte) (int, error)
    222 
    223 func (f readerFunc) Read(p []byte) (int, error) {
    224 	return f(p)
    225 }
    226 
    227 // callDepth returns the logical call depth for the given PCs.
    228 func callDepth(callers []uintptr) (depth int) {
    229 	frames := runtime.CallersFrames(callers)
    230 	more := true
    231 	for more {
    232 		_, more = frames.Next()
    233 		depth++
    234 	}
    235 	return
    236 }
    237 
    238 // Test that MultiReader properly flattens chained multiReaders when Read is called
    239 func TestMultiReaderFlatten(t *testing.T) {
    240 	pc := make([]uintptr, 1000) // 1000 should fit the full stack
    241 	n := runtime.Callers(0, pc)
    242 	var myDepth = callDepth(pc[:n])
    243 	var readDepth int // will contain the depth from which fakeReader.Read was called
    244 	var r Reader = MultiReader(readerFunc(func(p []byte) (int, error) {
    245 		n := runtime.Callers(1, pc)
    246 		readDepth = callDepth(pc[:n])
    247 		return 0, errors.New("irrelevant")
    248 	}))
    249 
    250 	// chain a bunch of multiReaders
    251 	for i := 0; i < 100; i++ {
    252 		r = MultiReader(r)
    253 	}
    254 
    255 	r.Read(nil) // don't care about errors, just want to check the call-depth for Read
    256 
    257 	if readDepth != myDepth+2 { // 2 should be multiReader.Read and fakeReader.Read
    258 		t.Errorf("multiReader did not flatten chained multiReaders: expected readDepth %d, got %d",
    259 			myDepth+2, readDepth)
    260 	}
    261 }
    262 
    263 // byteAndEOFReader is a Reader which reads one byte (the underlying
    264 // byte) and io.EOF at once in its Read call.
    265 type byteAndEOFReader byte
    266 
    267 func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
    268 	if len(p) == 0 {
    269 		// Read(0 bytes) is useless. We expect no such useless
    270 		// calls in this test.
    271 		panic("unexpected call")
    272 	}
    273 	p[0] = byte(b)
    274 	return 1, EOF
    275 }
    276 
    277 // This used to yield bytes forever; issue 16795.
    278 func TestMultiReaderSingleByteWithEOF(t *testing.T) {
    279 	got, err := ioutil.ReadAll(LimitReader(MultiReader(byteAndEOFReader('a'), byteAndEOFReader('b')), 10))
    280 	if err != nil {
    281 		t.Fatal(err)
    282 	}
    283 	const want = "ab"
    284 	if string(got) != want {
    285 		t.Errorf("got %q; want %q", got, want)
    286 	}
    287 }
    288 
    289 // Test that a reader returning (n, EOF) at the end of an MultiReader
    290 // chain continues to return EOF on its final read, rather than
    291 // yielding a (0, EOF).
    292 func TestMultiReaderFinalEOF(t *testing.T) {
    293 	r := MultiReader(bytes.NewReader(nil), byteAndEOFReader('a'))
    294 	buf := make([]byte, 2)
    295 	n, err := r.Read(buf)
    296 	if n != 1 || err != EOF {
    297 		t.Errorf("got %v, %v; want 1, EOF", n, err)
    298 	}
    299 }
    300 
    301 func TestMultiReaderFreesExhaustedReaders(t *testing.T) {
    302 	var mr Reader
    303 	closed := make(chan struct{})
    304 	// The closure ensures that we don't have a live reference to buf1
    305 	// on our stack after MultiReader is inlined (Issue 18819).  This
    306 	// is a work around for a limitation in liveness analysis.
    307 	func() {
    308 		buf1 := bytes.NewReader([]byte("foo"))
    309 		buf2 := bytes.NewReader([]byte("bar"))
    310 		mr = MultiReader(buf1, buf2)
    311 		runtime.SetFinalizer(buf1, func(*bytes.Reader) {
    312 			close(closed)
    313 		})
    314 	}()
    315 
    316 	buf := make([]byte, 4)
    317 	if n, err := ReadFull(mr, buf); err != nil || string(buf) != "foob" {
    318 		t.Fatalf(`ReadFull = %d (%q), %v; want 3, "foo", nil`, n, buf[:n], err)
    319 	}
    320 
    321 	runtime.GC()
    322 	select {
    323 	case <-closed:
    324 	case <-time.After(5 * time.Second):
    325 		t.Fatal("timeout waiting for collection of buf1")
    326 	}
    327 
    328 	if n, err := ReadFull(mr, buf[:2]); err != nil || string(buf[:2]) != "ar" {
    329 		t.Fatalf(`ReadFull = %d (%q), %v; want 2, "ar", nil`, n, buf[:n], err)
    330 	}
    331 }
    332 
    333 func TestInterleavedMultiReader(t *testing.T) {
    334 	r1 := strings.NewReader("123")
    335 	r2 := strings.NewReader("45678")
    336 
    337 	mr1 := MultiReader(r1, r2)
    338 	mr2 := MultiReader(mr1)
    339 
    340 	buf := make([]byte, 4)
    341 
    342 	// Have mr2 use mr1's []Readers.
    343 	// Consume r1 (and clear it for GC to handle) and consume part of r2.
    344 	n, err := ReadFull(mr2, buf)
    345 	if got := string(buf[:n]); got != "1234" || err != nil {
    346 		t.Errorf(`ReadFull(mr2) = (%q, %v), want ("1234", nil)`, got, err)
    347 	}
    348 
    349 	// Consume the rest of r2 via mr1.
    350 	// This should not panic even though mr2 cleared r1.
    351 	n, err = ReadFull(mr1, buf)
    352 	if got := string(buf[:n]); got != "5678" || err != nil {
    353 		t.Errorf(`ReadFull(mr1) = (%q, %v), want ("5678", nil)`, got, err)
    354 	}
    355 }
    356