Home | History | Annotate | Download | only in http
      1 // Copyright 2016 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // White-box tests for transport.go (in package http instead of http_test).
      6 
      7 package http
      8 
      9 import (
     10 	"errors"
     11 	"net"
     12 	"testing"
     13 )
     14 
     15 // Issue 15446: incorrect wrapping of errors when server closes an idle connection.
     16 func TestTransportPersistConnReadLoopEOF(t *testing.T) {
     17 	ln := newLocalListener(t)
     18 	defer ln.Close()
     19 
     20 	connc := make(chan net.Conn, 1)
     21 	go func() {
     22 		defer close(connc)
     23 		c, err := ln.Accept()
     24 		if err != nil {
     25 			t.Error(err)
     26 			return
     27 		}
     28 		connc <- c
     29 	}()
     30 
     31 	tr := new(Transport)
     32 	req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
     33 	treq := &transportRequest{Request: req}
     34 	cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
     35 	pc, err := tr.getConn(treq, cm)
     36 	if err != nil {
     37 		t.Fatal(err)
     38 	}
     39 	defer pc.close(errors.New("test over"))
     40 
     41 	conn := <-connc
     42 	if conn == nil {
     43 		// Already called t.Error in the accept goroutine.
     44 		return
     45 	}
     46 	conn.Close() // simulate the server hanging up on the client
     47 
     48 	_, err = pc.roundTrip(treq)
     49 	if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
     50 		t.Fatalf("roundTrip = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err)
     51 	}
     52 
     53 	<-pc.closech
     54 	err = pc.closed
     55 	if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
     56 		t.Fatalf("pc.closed = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err)
     57 	}
     58 }
     59 
     60 func isTransportReadFromServerError(err error) bool {
     61 	_, ok := err.(transportReadFromServerError)
     62 	return ok
     63 }
     64 
     65 func newLocalListener(t *testing.T) net.Listener {
     66 	ln, err := net.Listen("tcp", "127.0.0.1:0")
     67 	if err != nil {
     68 		ln, err = net.Listen("tcp6", "[::1]:0")
     69 	}
     70 	if err != nil {
     71 		t.Fatal(err)
     72 	}
     73 	return ln
     74 }
     75 
     76 func dummyRequest(method string) *Request {
     77 	req, err := NewRequest(method, "http://fake.tld/", nil)
     78 	if err != nil {
     79 		panic(err)
     80 	}
     81 	return req
     82 }
     83 
     84 func TestTransportShouldRetryRequest(t *testing.T) {
     85 	tests := []struct {
     86 		pc  *persistConn
     87 		req *Request
     88 
     89 		err  error
     90 		want bool
     91 	}{
     92 		0: {
     93 			pc:   &persistConn{reused: false},
     94 			req:  dummyRequest("POST"),
     95 			err:  nothingWrittenError{},
     96 			want: false,
     97 		},
     98 		1: {
     99 			pc:   &persistConn{reused: true},
    100 			req:  dummyRequest("POST"),
    101 			err:  nothingWrittenError{},
    102 			want: true,
    103 		},
    104 		2: {
    105 			pc:   &persistConn{reused: true},
    106 			req:  dummyRequest("POST"),
    107 			err:  http2ErrNoCachedConn,
    108 			want: true,
    109 		},
    110 		3: {
    111 			pc:   &persistConn{reused: true},
    112 			req:  dummyRequest("POST"),
    113 			err:  errMissingHost,
    114 			want: false,
    115 		},
    116 		4: {
    117 			pc:   &persistConn{reused: true},
    118 			req:  dummyRequest("POST"),
    119 			err:  transportReadFromServerError{},
    120 			want: false,
    121 		},
    122 		5: {
    123 			pc:   &persistConn{reused: true},
    124 			req:  dummyRequest("GET"),
    125 			err:  transportReadFromServerError{},
    126 			want: true,
    127 		},
    128 		6: {
    129 			pc:   &persistConn{reused: true},
    130 			req:  dummyRequest("GET"),
    131 			err:  errServerClosedIdle,
    132 			want: true,
    133 		},
    134 	}
    135 	for i, tt := range tests {
    136 		got := tt.pc.shouldRetryRequest(tt.req, tt.err)
    137 		if got != tt.want {
    138 			t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want)
    139 		}
    140 	}
    141 }
    142