Home | History | Annotate | Download | only in httputil
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // HTTP reverse proxy handler
      6 
      7 package httputil
      8 
      9 import (
     10 	"context"
     11 	"io"
     12 	"log"
     13 	"net"
     14 	"net/http"
     15 	"net/url"
     16 	"strings"
     17 	"sync"
     18 	"time"
     19 )
     20 
     21 // onExitFlushLoop is a callback set by tests to detect the state of the
     22 // flushLoop() goroutine.
     23 var onExitFlushLoop func()
     24 
     25 // ReverseProxy is an HTTP Handler that takes an incoming request and
     26 // sends it to another server, proxying the response back to the
     27 // client.
     28 type ReverseProxy struct {
     29 	// Director must be a function which modifies
     30 	// the request into a new request to be sent
     31 	// using Transport. Its response is then copied
     32 	// back to the original client unmodified.
     33 	// Director must not access the provided Request
     34 	// after returning.
     35 	Director func(*http.Request)
     36 
     37 	// The transport used to perform proxy requests.
     38 	// If nil, http.DefaultTransport is used.
     39 	Transport http.RoundTripper
     40 
     41 	// FlushInterval specifies the flush interval
     42 	// to flush to the client while copying the
     43 	// response body.
     44 	// If zero, no periodic flushing is done.
     45 	FlushInterval time.Duration
     46 
     47 	// ErrorLog specifies an optional logger for errors
     48 	// that occur when attempting to proxy the request.
     49 	// If nil, logging goes to os.Stderr via the log package's
     50 	// standard logger.
     51 	ErrorLog *log.Logger
     52 
     53 	// BufferPool optionally specifies a buffer pool to
     54 	// get byte slices for use by io.CopyBuffer when
     55 	// copying HTTP response bodies.
     56 	BufferPool BufferPool
     57 
     58 	// ModifyResponse is an optional function that
     59 	// modifies the Response from the backend.
     60 	// If it returns an error, the proxy returns a StatusBadGateway error.
     61 	ModifyResponse func(*http.Response) error
     62 }
     63 
     64 // A BufferPool is an interface for getting and returning temporary
     65 // byte slices for use by io.CopyBuffer.
     66 type BufferPool interface {
     67 	Get() []byte
     68 	Put([]byte)
     69 }
     70 
     71 func singleJoiningSlash(a, b string) string {
     72 	aslash := strings.HasSuffix(a, "/")
     73 	bslash := strings.HasPrefix(b, "/")
     74 	switch {
     75 	case aslash && bslash:
     76 		return a + b[1:]
     77 	case !aslash && !bslash:
     78 		return a + "/" + b
     79 	}
     80 	return a + b
     81 }
     82 
     83 // NewSingleHostReverseProxy returns a new ReverseProxy that routes
     84 // URLs to the scheme, host, and base path provided in target. If the
     85 // target's path is "/base" and the incoming request was for "/dir",
     86 // the target request will be for /base/dir.
     87 // NewSingleHostReverseProxy does not rewrite the Host header.
     88 // To rewrite Host headers, use ReverseProxy directly with a custom
     89 // Director policy.
     90 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
     91 	targetQuery := target.RawQuery
     92 	director := func(req *http.Request) {
     93 		req.URL.Scheme = target.Scheme
     94 		req.URL.Host = target.Host
     95 		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
     96 		if targetQuery == "" || req.URL.RawQuery == "" {
     97 			req.URL.RawQuery = targetQuery + req.URL.RawQuery
     98 		} else {
     99 			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
    100 		}
    101 		if _, ok := req.Header["User-Agent"]; !ok {
    102 			// explicitly disable User-Agent so it's not set to default value
    103 			req.Header.Set("User-Agent", "")
    104 		}
    105 	}
    106 	return &ReverseProxy{Director: director}
    107 }
    108 
    109 func copyHeader(dst, src http.Header) {
    110 	for k, vv := range src {
    111 		for _, v := range vv {
    112 			dst.Add(k, v)
    113 		}
    114 	}
    115 }
    116 
    117 // Hop-by-hop headers. These are removed when sent to the backend.
    118 // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
    119 var hopHeaders = []string{
    120 	"Connection",
    121 	"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
    122 	"Keep-Alive",
    123 	"Proxy-Authenticate",
    124 	"Proxy-Authorization",
    125 	"Te",      // canonicalized version of "TE"
    126 	"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
    127 	"Transfer-Encoding",
    128 	"Upgrade",
    129 }
    130 
    131 func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
    132 	transport := p.Transport
    133 	if transport == nil {
    134 		transport = http.DefaultTransport
    135 	}
    136 
    137 	ctx := req.Context()
    138 	if cn, ok := rw.(http.CloseNotifier); ok {
    139 		var cancel context.CancelFunc
    140 		ctx, cancel = context.WithCancel(ctx)
    141 		defer cancel()
    142 		notifyChan := cn.CloseNotify()
    143 		go func() {
    144 			select {
    145 			case <-notifyChan:
    146 				cancel()
    147 			case <-ctx.Done():
    148 			}
    149 		}()
    150 	}
    151 
    152 	outreq := new(http.Request)
    153 	*outreq = *req // includes shallow copies of maps, but okay
    154 	if req.ContentLength == 0 {
    155 		outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
    156 	}
    157 	outreq = outreq.WithContext(ctx)
    158 
    159 	p.Director(outreq)
    160 	outreq.Close = false
    161 
    162 	// We are modifying the same underlying map from req (shallow
    163 	// copied above) so we only copy it if necessary.
    164 	copiedHeaders := false
    165 
    166 	// Remove hop-by-hop headers listed in the "Connection" header.
    167 	// See RFC 2616, section 14.10.
    168 	if c := outreq.Header.Get("Connection"); c != "" {
    169 		for _, f := range strings.Split(c, ",") {
    170 			if f = strings.TrimSpace(f); f != "" {
    171 				if !copiedHeaders {
    172 					outreq.Header = make(http.Header)
    173 					copyHeader(outreq.Header, req.Header)
    174 					copiedHeaders = true
    175 				}
    176 				outreq.Header.Del(f)
    177 			}
    178 		}
    179 	}
    180 
    181 	// Remove hop-by-hop headers to the backend. Especially
    182 	// important is "Connection" because we want a persistent
    183 	// connection, regardless of what the client sent to us.
    184 	for _, h := range hopHeaders {
    185 		if outreq.Header.Get(h) != "" {
    186 			if !copiedHeaders {
    187 				outreq.Header = make(http.Header)
    188 				copyHeader(outreq.Header, req.Header)
    189 				copiedHeaders = true
    190 			}
    191 			outreq.Header.Del(h)
    192 		}
    193 	}
    194 
    195 	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
    196 		// If we aren't the first proxy retain prior
    197 		// X-Forwarded-For information as a comma+space
    198 		// separated list and fold multiple headers into one.
    199 		if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
    200 			clientIP = strings.Join(prior, ", ") + ", " + clientIP
    201 		}
    202 		outreq.Header.Set("X-Forwarded-For", clientIP)
    203 	}
    204 
    205 	res, err := transport.RoundTrip(outreq)
    206 	if err != nil {
    207 		p.logf("http: proxy error: %v", err)
    208 		rw.WriteHeader(http.StatusBadGateway)
    209 		return
    210 	}
    211 
    212 	// Remove hop-by-hop headers listed in the
    213 	// "Connection" header of the response.
    214 	if c := res.Header.Get("Connection"); c != "" {
    215 		for _, f := range strings.Split(c, ",") {
    216 			if f = strings.TrimSpace(f); f != "" {
    217 				res.Header.Del(f)
    218 			}
    219 		}
    220 	}
    221 
    222 	for _, h := range hopHeaders {
    223 		res.Header.Del(h)
    224 	}
    225 
    226 	if p.ModifyResponse != nil {
    227 		if err := p.ModifyResponse(res); err != nil {
    228 			p.logf("http: proxy error: %v", err)
    229 			rw.WriteHeader(http.StatusBadGateway)
    230 			return
    231 		}
    232 	}
    233 
    234 	copyHeader(rw.Header(), res.Header)
    235 
    236 	// The "Trailer" header isn't included in the Transport's response,
    237 	// at least for *http.Transport. Build it up from Trailer.
    238 	if len(res.Trailer) > 0 {
    239 		trailerKeys := make([]string, 0, len(res.Trailer))
    240 		for k := range res.Trailer {
    241 			trailerKeys = append(trailerKeys, k)
    242 		}
    243 		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
    244 	}
    245 
    246 	rw.WriteHeader(res.StatusCode)
    247 	if len(res.Trailer) > 0 {
    248 		// Force chunking if we saw a response trailer.
    249 		// This prevents net/http from calculating the length for short
    250 		// bodies and adding a Content-Length.
    251 		if fl, ok := rw.(http.Flusher); ok {
    252 			fl.Flush()
    253 		}
    254 	}
    255 	p.copyResponse(rw, res.Body)
    256 	res.Body.Close() // close now, instead of defer, to populate res.Trailer
    257 	copyHeader(rw.Header(), res.Trailer)
    258 }
    259 
    260 func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
    261 	if p.FlushInterval != 0 {
    262 		if wf, ok := dst.(writeFlusher); ok {
    263 			mlw := &maxLatencyWriter{
    264 				dst:     wf,
    265 				latency: p.FlushInterval,
    266 				done:    make(chan bool),
    267 			}
    268 			go mlw.flushLoop()
    269 			defer mlw.stop()
    270 			dst = mlw
    271 		}
    272 	}
    273 
    274 	var buf []byte
    275 	if p.BufferPool != nil {
    276 		buf = p.BufferPool.Get()
    277 	}
    278 	p.copyBuffer(dst, src, buf)
    279 	if p.BufferPool != nil {
    280 		p.BufferPool.Put(buf)
    281 	}
    282 }
    283 
    284 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
    285 	if len(buf) == 0 {
    286 		buf = make([]byte, 32*1024)
    287 	}
    288 	var written int64
    289 	for {
    290 		nr, rerr := src.Read(buf)
    291 		if rerr != nil && rerr != io.EOF {
    292 			p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
    293 		}
    294 		if nr > 0 {
    295 			nw, werr := dst.Write(buf[:nr])
    296 			if nw > 0 {
    297 				written += int64(nw)
    298 			}
    299 			if werr != nil {
    300 				return written, werr
    301 			}
    302 			if nr != nw {
    303 				return written, io.ErrShortWrite
    304 			}
    305 		}
    306 		if rerr != nil {
    307 			return written, rerr
    308 		}
    309 	}
    310 }
    311 
    312 func (p *ReverseProxy) logf(format string, args ...interface{}) {
    313 	if p.ErrorLog != nil {
    314 		p.ErrorLog.Printf(format, args...)
    315 	} else {
    316 		log.Printf(format, args...)
    317 	}
    318 }
    319 
    320 type writeFlusher interface {
    321 	io.Writer
    322 	http.Flusher
    323 }
    324 
    325 type maxLatencyWriter struct {
    326 	dst     writeFlusher
    327 	latency time.Duration
    328 
    329 	mu   sync.Mutex // protects Write + Flush
    330 	done chan bool
    331 }
    332 
    333 func (m *maxLatencyWriter) Write(p []byte) (int, error) {
    334 	m.mu.Lock()
    335 	defer m.mu.Unlock()
    336 	return m.dst.Write(p)
    337 }
    338 
    339 func (m *maxLatencyWriter) flushLoop() {
    340 	t := time.NewTicker(m.latency)
    341 	defer t.Stop()
    342 	for {
    343 		select {
    344 		case <-m.done:
    345 			if onExitFlushLoop != nil {
    346 				onExitFlushLoop()
    347 			}
    348 			return
    349 		case <-t.C:
    350 			m.mu.Lock()
    351 			m.dst.Flush()
    352 			m.mu.Unlock()
    353 		}
    354 	}
    355 }
    356 
    357 func (m *maxLatencyWriter) stop() { m.done <- true }
    358