Home | History | Annotate | Download | only in io
      1 // Copyright 2010 The Go Authors.  All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package io_test
      6 
      7 import (
      8 	"bytes"
      9 	"crypto/sha1"
     10 	"fmt"
     11 	. "io"
     12 	"io/ioutil"
     13 	"strings"
     14 	"testing"
     15 )
     16 
     17 func TestMultiReader(t *testing.T) {
     18 	var mr Reader
     19 	var buf []byte
     20 	nread := 0
     21 	withFooBar := func(tests func()) {
     22 		r1 := strings.NewReader("foo ")
     23 		r2 := strings.NewReader("")
     24 		r3 := strings.NewReader("bar")
     25 		mr = MultiReader(r1, r2, r3)
     26 		buf = make([]byte, 20)
     27 		tests()
     28 	}
     29 	expectRead := func(size int, expected string, eerr error) {
     30 		nread++
     31 		n, gerr := mr.Read(buf[0:size])
     32 		if n != len(expected) {
     33 			t.Errorf("#%d, expected %d bytes; got %d",
     34 				nread, len(expected), n)
     35 		}
     36 		got := string(buf[0:n])
     37 		if got != expected {
     38 			t.Errorf("#%d, expected %q; got %q",
     39 				nread, expected, got)
     40 		}
     41 		if gerr != eerr {
     42 			t.Errorf("#%d, expected error %v; got %v",
     43 				nread, eerr, gerr)
     44 		}
     45 		buf = buf[n:]
     46 	}
     47 	withFooBar(func() {
     48 		expectRead(2, "fo", nil)
     49 		expectRead(5, "o ", nil)
     50 		expectRead(5, "bar", nil)
     51 		expectRead(5, "", EOF)
     52 	})
     53 	withFooBar(func() {
     54 		expectRead(4, "foo ", nil)
     55 		expectRead(1, "b", nil)
     56 		expectRead(3, "ar", nil)
     57 		expectRead(1, "", EOF)
     58 	})
     59 	withFooBar(func() {
     60 		expectRead(5, "foo ", nil)
     61 	})
     62 }
     63 
     64 func TestMultiWriter(t *testing.T) {
     65 	sha1 := sha1.New()
     66 	sink := new(bytes.Buffer)
     67 	mw := MultiWriter(sha1, sink)
     68 
     69 	sourceString := "My input text."
     70 	source := strings.NewReader(sourceString)
     71 	written, err := Copy(mw, source)
     72 
     73 	if written != int64(len(sourceString)) {
     74 		t.Errorf("short write of %d, not %d", written, len(sourceString))
     75 	}
     76 
     77 	if err != nil {
     78 		t.Errorf("unexpected error: %v", err)
     79 	}
     80 
     81 	sha1hex := fmt.Sprintf("%x", sha1.Sum(nil))
     82 	if sha1hex != "01cb303fa8c30a64123067c5aa6284ba7ec2d31b" {
     83 		t.Error("incorrect sha1 value")
     84 	}
     85 
     86 	if sink.String() != sourceString {
     87 		t.Errorf("expected %q; got %q", sourceString, sink.String())
     88 	}
     89 }
     90 
     91 // Test that MultiReader copies the input slice and is insulated from future modification.
     92 func TestMultiReaderCopy(t *testing.T) {
     93 	slice := []Reader{strings.NewReader("hello world")}
     94 	r := MultiReader(slice...)
     95 	slice[0] = nil
     96 	data, err := ioutil.ReadAll(r)
     97 	if err != nil || string(data) != "hello world" {
     98 		t.Errorf("ReadAll() = %q, %v, want %q, nil", data, err, "hello world")
     99 	}
    100 }
    101 
    102 // Test that MultiWriter copies the input slice and is insulated from future modification.
    103 func TestMultiWriterCopy(t *testing.T) {
    104 	var buf bytes.Buffer
    105 	slice := []Writer{&buf}
    106 	w := MultiWriter(slice...)
    107 	slice[0] = nil
    108 	n, err := w.Write([]byte("hello world"))
    109 	if err != nil || n != 11 {
    110 		t.Errorf("Write(`hello world`) = %d, %v, want 11, nil", n, err)
    111 	}
    112 	if buf.String() != "hello world" {
    113 		t.Errorf("buf.String() = %q, want %q", buf.String(), "hello world")
    114 	}
    115 }
    116