Home | History | Annotate | Download | only in httputil
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // Reverse proxy tests.
      6 
      7 package httputil
      8 
      9 import (
     10 	"bufio"
     11 	"io/ioutil"
     12 	"log"
     13 	"net/http"
     14 	"net/http/httptest"
     15 	"net/url"
     16 	"reflect"
     17 	"runtime"
     18 	"strings"
     19 	"testing"
     20 	"time"
     21 )
     22 
     23 const fakeHopHeader = "X-Fake-Hop-Header-For-Test"
     24 
     25 func init() {
     26 	hopHeaders = append(hopHeaders, fakeHopHeader)
     27 }
     28 
     29 func TestReverseProxy(t *testing.T) {
     30 	const backendResponse = "I am the backend"
     31 	const backendStatus = 404
     32 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
     33 		if len(r.TransferEncoding) > 0 {
     34 			t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding)
     35 		}
     36 		if r.Header.Get("X-Forwarded-For") == "" {
     37 			t.Errorf("didn't get X-Forwarded-For header")
     38 		}
     39 		if c := r.Header.Get("Connection"); c != "" {
     40 			t.Errorf("handler got Connection header value %q", c)
     41 		}
     42 		if c := r.Header.Get("Upgrade"); c != "" {
     43 			t.Errorf("handler got Upgrade header value %q", c)
     44 		}
     45 		if g, e := r.Host, "some-name"; g != e {
     46 			t.Errorf("backend got Host header %q, want %q", g, e)
     47 		}
     48 		w.Header().Set("Trailer", "X-Trailer")
     49 		w.Header().Set("X-Foo", "bar")
     50 		w.Header().Set("Upgrade", "foo")
     51 		w.Header().Set(fakeHopHeader, "foo")
     52 		w.Header().Add("X-Multi-Value", "foo")
     53 		w.Header().Add("X-Multi-Value", "bar")
     54 		http.SetCookie(w, &http.Cookie{Name: "flavor", Value: "chocolateChip"})
     55 		w.WriteHeader(backendStatus)
     56 		w.Write([]byte(backendResponse))
     57 		w.Header().Set("X-Trailer", "trailer_value")
     58 	}))
     59 	defer backend.Close()
     60 	backendURL, err := url.Parse(backend.URL)
     61 	if err != nil {
     62 		t.Fatal(err)
     63 	}
     64 	proxyHandler := NewSingleHostReverseProxy(backendURL)
     65 	frontend := httptest.NewServer(proxyHandler)
     66 	defer frontend.Close()
     67 
     68 	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
     69 	getReq.Host = "some-name"
     70 	getReq.Header.Set("Connection", "close")
     71 	getReq.Header.Set("Upgrade", "foo")
     72 	getReq.Close = true
     73 	res, err := http.DefaultClient.Do(getReq)
     74 	if err != nil {
     75 		t.Fatalf("Get: %v", err)
     76 	}
     77 	if g, e := res.StatusCode, backendStatus; g != e {
     78 		t.Errorf("got res.StatusCode %d; expected %d", g, e)
     79 	}
     80 	if g, e := res.Header.Get("X-Foo"), "bar"; g != e {
     81 		t.Errorf("got X-Foo %q; expected %q", g, e)
     82 	}
     83 	if c := res.Header.Get(fakeHopHeader); c != "" {
     84 		t.Errorf("got %s header value %q", fakeHopHeader, c)
     85 	}
     86 	if g, e := len(res.Header["X-Multi-Value"]), 2; g != e {
     87 		t.Errorf("got %d X-Multi-Value header values; expected %d", g, e)
     88 	}
     89 	if g, e := len(res.Header["Set-Cookie"]), 1; g != e {
     90 		t.Fatalf("got %d SetCookies, want %d", g, e)
     91 	}
     92 	if g, e := res.Trailer, (http.Header{"X-Trailer": nil}); !reflect.DeepEqual(g, e) {
     93 		t.Errorf("before reading body, Trailer = %#v; want %#v", g, e)
     94 	}
     95 	if cookie := res.Cookies()[0]; cookie.Name != "flavor" {
     96 		t.Errorf("unexpected cookie %q", cookie.Name)
     97 	}
     98 	bodyBytes, _ := ioutil.ReadAll(res.Body)
     99 	if g, e := string(bodyBytes), backendResponse; g != e {
    100 		t.Errorf("got body %q; expected %q", g, e)
    101 	}
    102 	if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
    103 		t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
    104 	}
    105 
    106 }
    107 
    108 func TestXForwardedFor(t *testing.T) {
    109 	const prevForwardedFor = "client ip"
    110 	const backendResponse = "I am the backend"
    111 	const backendStatus = 404
    112 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    113 		if r.Header.Get("X-Forwarded-For") == "" {
    114 			t.Errorf("didn't get X-Forwarded-For header")
    115 		}
    116 		if !strings.Contains(r.Header.Get("X-Forwarded-For"), prevForwardedFor) {
    117 			t.Errorf("X-Forwarded-For didn't contain prior data")
    118 		}
    119 		w.WriteHeader(backendStatus)
    120 		w.Write([]byte(backendResponse))
    121 	}))
    122 	defer backend.Close()
    123 	backendURL, err := url.Parse(backend.URL)
    124 	if err != nil {
    125 		t.Fatal(err)
    126 	}
    127 	proxyHandler := NewSingleHostReverseProxy(backendURL)
    128 	frontend := httptest.NewServer(proxyHandler)
    129 	defer frontend.Close()
    130 
    131 	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    132 	getReq.Host = "some-name"
    133 	getReq.Header.Set("Connection", "close")
    134 	getReq.Header.Set("X-Forwarded-For", prevForwardedFor)
    135 	getReq.Close = true
    136 	res, err := http.DefaultClient.Do(getReq)
    137 	if err != nil {
    138 		t.Fatalf("Get: %v", err)
    139 	}
    140 	if g, e := res.StatusCode, backendStatus; g != e {
    141 		t.Errorf("got res.StatusCode %d; expected %d", g, e)
    142 	}
    143 	bodyBytes, _ := ioutil.ReadAll(res.Body)
    144 	if g, e := string(bodyBytes), backendResponse; g != e {
    145 		t.Errorf("got body %q; expected %q", g, e)
    146 	}
    147 }
    148 
    149 var proxyQueryTests = []struct {
    150 	baseSuffix string // suffix to add to backend URL
    151 	reqSuffix  string // suffix to add to frontend's request URL
    152 	want       string // what backend should see for final request URL (without ?)
    153 }{
    154 	{"", "", ""},
    155 	{"?sta=tic", "?us=er", "sta=tic&us=er"},
    156 	{"", "?us=er", "us=er"},
    157 	{"?sta=tic", "", "sta=tic"},
    158 }
    159 
    160 func TestReverseProxyQuery(t *testing.T) {
    161 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    162 		w.Header().Set("X-Got-Query", r.URL.RawQuery)
    163 		w.Write([]byte("hi"))
    164 	}))
    165 	defer backend.Close()
    166 
    167 	for i, tt := range proxyQueryTests {
    168 		backendURL, err := url.Parse(backend.URL + tt.baseSuffix)
    169 		if err != nil {
    170 			t.Fatal(err)
    171 		}
    172 		frontend := httptest.NewServer(NewSingleHostReverseProxy(backendURL))
    173 		req, _ := http.NewRequest("GET", frontend.URL+tt.reqSuffix, nil)
    174 		req.Close = true
    175 		res, err := http.DefaultClient.Do(req)
    176 		if err != nil {
    177 			t.Fatalf("%d. Get: %v", i, err)
    178 		}
    179 		if g, e := res.Header.Get("X-Got-Query"), tt.want; g != e {
    180 			t.Errorf("%d. got query %q; expected %q", i, g, e)
    181 		}
    182 		res.Body.Close()
    183 		frontend.Close()
    184 	}
    185 }
    186 
    187 func TestReverseProxyFlushInterval(t *testing.T) {
    188 	const expected = "hi"
    189 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    190 		w.Write([]byte(expected))
    191 	}))
    192 	defer backend.Close()
    193 
    194 	backendURL, err := url.Parse(backend.URL)
    195 	if err != nil {
    196 		t.Fatal(err)
    197 	}
    198 
    199 	proxyHandler := NewSingleHostReverseProxy(backendURL)
    200 	proxyHandler.FlushInterval = time.Microsecond
    201 
    202 	done := make(chan bool)
    203 	onExitFlushLoop = func() { done <- true }
    204 	defer func() { onExitFlushLoop = nil }()
    205 
    206 	frontend := httptest.NewServer(proxyHandler)
    207 	defer frontend.Close()
    208 
    209 	req, _ := http.NewRequest("GET", frontend.URL, nil)
    210 	req.Close = true
    211 	res, err := http.DefaultClient.Do(req)
    212 	if err != nil {
    213 		t.Fatalf("Get: %v", err)
    214 	}
    215 	defer res.Body.Close()
    216 	if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected {
    217 		t.Errorf("got body %q; expected %q", bodyBytes, expected)
    218 	}
    219 
    220 	select {
    221 	case <-done:
    222 		// OK
    223 	case <-time.After(5 * time.Second):
    224 		t.Error("maxLatencyWriter flushLoop() never exited")
    225 	}
    226 }
    227 
    228 func TestReverseProxyCancellation(t *testing.T) {
    229 	if runtime.GOOS == "plan9" {
    230 		t.Skip("skipping test; see https://golang.org/issue/9554")
    231 	}
    232 	const backendResponse = "I am the backend"
    233 
    234 	reqInFlight := make(chan struct{})
    235 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    236 		close(reqInFlight)
    237 
    238 		select {
    239 		case <-time.After(10 * time.Second):
    240 			// Note: this should only happen in broken implementations, and the
    241 			// closenotify case should be instantaneous.
    242 			t.Log("Failed to close backend connection")
    243 			t.Fail()
    244 		case <-w.(http.CloseNotifier).CloseNotify():
    245 		}
    246 
    247 		w.WriteHeader(http.StatusOK)
    248 		w.Write([]byte(backendResponse))
    249 	}))
    250 
    251 	defer backend.Close()
    252 
    253 	backend.Config.ErrorLog = log.New(ioutil.Discard, "", 0)
    254 
    255 	backendURL, err := url.Parse(backend.URL)
    256 	if err != nil {
    257 		t.Fatal(err)
    258 	}
    259 
    260 	proxyHandler := NewSingleHostReverseProxy(backendURL)
    261 
    262 	// Discards errors of the form:
    263 	// http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection
    264 	proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0)
    265 
    266 	frontend := httptest.NewServer(proxyHandler)
    267 	defer frontend.Close()
    268 
    269 	getReq, _ := http.NewRequest("GET", frontend.URL, nil)
    270 	go func() {
    271 		<-reqInFlight
    272 		http.DefaultTransport.(*http.Transport).CancelRequest(getReq)
    273 	}()
    274 	res, err := http.DefaultClient.Do(getReq)
    275 	if res != nil {
    276 		t.Fatal("Non-nil response")
    277 	}
    278 	if err == nil {
    279 		// This should be an error like:
    280 		// Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079:
    281 		//    use of closed network connection
    282 		t.Fatal("DefaultClient.Do() returned nil error")
    283 	}
    284 }
    285 
    286 func req(t *testing.T, v string) *http.Request {
    287 	req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(v)))
    288 	if err != nil {
    289 		t.Fatal(err)
    290 	}
    291 	return req
    292 }
    293 
    294 // Issue 12344
    295 func TestNilBody(t *testing.T) {
    296 	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    297 		w.Write([]byte("hi"))
    298 	}))
    299 	defer backend.Close()
    300 
    301 	frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
    302 		backURL, _ := url.Parse(backend.URL)
    303 		rp := NewSingleHostReverseProxy(backURL)
    304 		r := req(t, "GET / HTTP/1.0\r\n\r\n")
    305 		r.Body = nil // this accidentally worked in Go 1.4 and below, so keep it working
    306 		rp.ServeHTTP(w, r)
    307 	}))
    308 	defer frontend.Close()
    309 
    310 	res, err := http.Get(frontend.URL)
    311 	if err != nil {
    312 		t.Fatal(err)
    313 	}
    314 	defer res.Body.Close()
    315 	slurp, err := ioutil.ReadAll(res.Body)
    316 	if err != nil {
    317 		t.Fatal(err)
    318 	}
    319 	if string(slurp) != "hi" {
    320 		t.Errorf("Got %q; want %q", slurp, "hi")
    321 	}
    322 }
    323