Home | History | Annotate | Download | only in httptest
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // Implementation of Server
      6 
      7 package httptest
      8 
      9 import (
     10 	"crypto/tls"
     11 	"flag"
     12 	"fmt"
     13 	"net"
     14 	"net/http"
     15 	"os"
     16 	"sync"
     17 )
     18 
     19 // A Server is an HTTP server listening on a system-chosen port on the
     20 // local loopback interface, for use in end-to-end HTTP tests.
     21 type Server struct {
     22 	URL      string // base URL of form http://ipaddr:port with no trailing slash
     23 	Listener net.Listener
     24 
     25 	// TLS is the optional TLS configuration, populated with a new config
     26 	// after TLS is started. If set on an unstarted server before StartTLS
     27 	// is called, existing fields are copied into the new config.
     28 	TLS *tls.Config
     29 
     30 	// Config may be changed after calling NewUnstartedServer and
     31 	// before Start or StartTLS.
     32 	Config *http.Server
     33 
     34 	// wg counts the number of outstanding HTTP requests on this server.
     35 	// Close blocks until all requests are finished.
     36 	wg sync.WaitGroup
     37 }
     38 
     39 // historyListener keeps track of all connections that it's ever
     40 // accepted.
     41 type historyListener struct {
     42 	net.Listener
     43 	sync.Mutex // protects history
     44 	history    []net.Conn
     45 }
     46 
     47 func (hs *historyListener) Accept() (c net.Conn, err error) {
     48 	c, err = hs.Listener.Accept()
     49 	if err == nil {
     50 		hs.Lock()
     51 		hs.history = append(hs.history, c)
     52 		hs.Unlock()
     53 	}
     54 	return
     55 }
     56 
     57 func newLocalListener() net.Listener {
     58 	if *serve != "" {
     59 		l, err := net.Listen("tcp", *serve)
     60 		if err != nil {
     61 			panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
     62 		}
     63 		return l
     64 	}
     65 	l, err := net.Listen("tcp", "127.0.0.1:0")
     66 	if err != nil {
     67 		if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
     68 			panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
     69 		}
     70 	}
     71 	return l
     72 }
     73 
     74 // When debugging a particular http server-based test,
     75 // this flag lets you run
     76 //	go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
     77 // to start the broken server so you can interact with it manually.
     78 var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks")
     79 
     80 // NewServer starts and returns a new Server.
     81 // The caller should call Close when finished, to shut it down.
     82 func NewServer(handler http.Handler) *Server {
     83 	ts := NewUnstartedServer(handler)
     84 	ts.Start()
     85 	return ts
     86 }
     87 
     88 // NewUnstartedServer returns a new Server but doesn't start it.
     89 //
     90 // After changing its configuration, the caller should call Start or
     91 // StartTLS.
     92 //
     93 // The caller should call Close when finished, to shut it down.
     94 func NewUnstartedServer(handler http.Handler) *Server {
     95 	return &Server{
     96 		Listener: newLocalListener(),
     97 		Config:   &http.Server{Handler: handler},
     98 	}
     99 }
    100 
    101 // Start starts a server from NewUnstartedServer.
    102 func (s *Server) Start() {
    103 	if s.URL != "" {
    104 		panic("Server already started")
    105 	}
    106 	s.Listener = &historyListener{Listener: s.Listener}
    107 	s.URL = "http://" + s.Listener.Addr().String()
    108 	s.wrapHandler()
    109 	go s.Config.Serve(s.Listener)
    110 	if *serve != "" {
    111 		fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
    112 		select {}
    113 	}
    114 }
    115 
    116 // StartTLS starts TLS on a server from NewUnstartedServer.
    117 func (s *Server) StartTLS() {
    118 	if s.URL != "" {
    119 		panic("Server already started")
    120 	}
    121 	cert, err := tls.X509KeyPair(localhostCert, localhostKey)
    122 	if err != nil {
    123 		panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
    124 	}
    125 
    126 	existingConfig := s.TLS
    127 	s.TLS = new(tls.Config)
    128 	if existingConfig != nil {
    129 		*s.TLS = *existingConfig
    130 	}
    131 	if s.TLS.NextProtos == nil {
    132 		s.TLS.NextProtos = []string{"http/1.1"}
    133 	}
    134 	if len(s.TLS.Certificates) == 0 {
    135 		s.TLS.Certificates = []tls.Certificate{cert}
    136 	}
    137 	tlsListener := tls.NewListener(s.Listener, s.TLS)
    138 
    139 	s.Listener = &historyListener{Listener: tlsListener}
    140 	s.URL = "https://" + s.Listener.Addr().String()
    141 	s.wrapHandler()
    142 	go s.Config.Serve(s.Listener)
    143 }
    144 
    145 func (s *Server) wrapHandler() {
    146 	h := s.Config.Handler
    147 	if h == nil {
    148 		h = http.DefaultServeMux
    149 	}
    150 	s.Config.Handler = &waitGroupHandler{
    151 		s: s,
    152 		h: h,
    153 	}
    154 }
    155 
    156 // NewTLSServer starts and returns a new Server using TLS.
    157 // The caller should call Close when finished, to shut it down.
    158 func NewTLSServer(handler http.Handler) *Server {
    159 	ts := NewUnstartedServer(handler)
    160 	ts.StartTLS()
    161 	return ts
    162 }
    163 
    164 // Close shuts down the server and blocks until all outstanding
    165 // requests on this server have completed.
    166 func (s *Server) Close() {
    167 	s.Listener.Close()
    168 	s.wg.Wait()
    169 	s.CloseClientConnections()
    170 	if t, ok := http.DefaultTransport.(*http.Transport); ok {
    171 		t.CloseIdleConnections()
    172 	}
    173 }
    174 
    175 // CloseClientConnections closes any currently open HTTP connections
    176 // to the test Server.
    177 func (s *Server) CloseClientConnections() {
    178 	hl, ok := s.Listener.(*historyListener)
    179 	if !ok {
    180 		return
    181 	}
    182 	hl.Lock()
    183 	for _, conn := range hl.history {
    184 		conn.Close()
    185 	}
    186 	hl.Unlock()
    187 }
    188 
    189 // waitGroupHandler wraps a handler, incrementing and decrementing a
    190 // sync.WaitGroup on each request, to enable Server.Close to block
    191 // until outstanding requests are finished.
    192 type waitGroupHandler struct {
    193 	s *Server
    194 	h http.Handler // non-nil
    195 }
    196 
    197 func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    198 	h.s.wg.Add(1)
    199 	defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
    200 	h.h.ServeHTTP(w, r)
    201 }
    202 
    203 // localhostCert is a PEM-encoded TLS cert with SAN IPs
    204 // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
    205 // of ASN.1 time).
    206 // generated from src/crypto/tls:
    207 // go run generate_cert.go  --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
    208 var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
    209 MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS
    210 MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
    211 MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
    212 iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4
    213 iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul
    214 rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO
    215 BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
    216 AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA
    217 AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9
    218 tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs
    219 h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM
    220 fblo6RBxUQ==
    221 -----END CERTIFICATE-----`)
    222 
    223 // localhostKey is the private key for localhostCert.
    224 var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
    225 MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9
    226 SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB
    227 l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB
    228 AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet
    229 3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb
    230 uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H
    231 qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp
    232 jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY
    233 fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U
    234 fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU
    235 y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX
    236 qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo
    237 f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
    238 -----END RSA PRIVATE KEY-----`)
    239