Home | History | Annotate | Download | only in httptest
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // Implementation of Server
      6 
      7 package httptest
      8 
      9 import (
     10 	"bytes"
     11 	"crypto/tls"
     12 	"flag"
     13 	"fmt"
     14 	"log"
     15 	"net"
     16 	"net/http"
     17 	"net/http/internal"
     18 	"os"
     19 	"sync"
     20 	"time"
     21 )
     22 
     23 // A Server is an HTTP server listening on a system-chosen port on the
     24 // local loopback interface, for use in end-to-end HTTP tests.
     25 type Server struct {
     26 	URL      string // base URL of form http://ipaddr:port with no trailing slash
     27 	Listener net.Listener
     28 
     29 	// TLS is the optional TLS configuration, populated with a new config
     30 	// after TLS is started. If set on an unstarted server before StartTLS
     31 	// is called, existing fields are copied into the new config.
     32 	TLS *tls.Config
     33 
     34 	// Config may be changed after calling NewUnstartedServer and
     35 	// before Start or StartTLS.
     36 	Config *http.Server
     37 
     38 	// wg counts the number of outstanding HTTP requests on this server.
     39 	// Close blocks until all requests are finished.
     40 	wg sync.WaitGroup
     41 
     42 	mu     sync.Mutex // guards closed and conns
     43 	closed bool
     44 	conns  map[net.Conn]http.ConnState // except terminal states
     45 }
     46 
     47 func newLocalListener() net.Listener {
     48 	if *serve != "" {
     49 		l, err := net.Listen("tcp", *serve)
     50 		if err != nil {
     51 			panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
     52 		}
     53 		return l
     54 	}
     55 	l, err := net.Listen("tcp", "127.0.0.1:0")
     56 	if err != nil {
     57 		if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
     58 			panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
     59 		}
     60 	}
     61 	return l
     62 }
     63 
     64 // When debugging a particular http server-based test,
     65 // this flag lets you run
     66 //	go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
     67 // to start the broken server so you can interact with it manually.
     68 var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks")
     69 
     70 // NewServer starts and returns a new Server.
     71 // The caller should call Close when finished, to shut it down.
     72 func NewServer(handler http.Handler) *Server {
     73 	ts := NewUnstartedServer(handler)
     74 	ts.Start()
     75 	return ts
     76 }
     77 
     78 // NewUnstartedServer returns a new Server but doesn't start it.
     79 //
     80 // After changing its configuration, the caller should call Start or
     81 // StartTLS.
     82 //
     83 // The caller should call Close when finished, to shut it down.
     84 func NewUnstartedServer(handler http.Handler) *Server {
     85 	return &Server{
     86 		Listener: newLocalListener(),
     87 		Config:   &http.Server{Handler: handler},
     88 	}
     89 }
     90 
     91 // Start starts a server from NewUnstartedServer.
     92 func (s *Server) Start() {
     93 	if s.URL != "" {
     94 		panic("Server already started")
     95 	}
     96 	s.URL = "http://" + s.Listener.Addr().String()
     97 	s.wrap()
     98 	s.goServe()
     99 	if *serve != "" {
    100 		fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
    101 		select {}
    102 	}
    103 }
    104 
    105 // StartTLS starts TLS on a server from NewUnstartedServer.
    106 func (s *Server) StartTLS() {
    107 	if s.URL != "" {
    108 		panic("Server already started")
    109 	}
    110 	cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
    111 	if err != nil {
    112 		panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
    113 	}
    114 
    115 	existingConfig := s.TLS
    116 	if existingConfig != nil {
    117 		s.TLS = existingConfig.Clone()
    118 	} else {
    119 		s.TLS = new(tls.Config)
    120 	}
    121 	if s.TLS.NextProtos == nil {
    122 		s.TLS.NextProtos = []string{"http/1.1"}
    123 	}
    124 	if len(s.TLS.Certificates) == 0 {
    125 		s.TLS.Certificates = []tls.Certificate{cert}
    126 	}
    127 	s.Listener = tls.NewListener(s.Listener, s.TLS)
    128 	s.URL = "https://" + s.Listener.Addr().String()
    129 	s.wrap()
    130 	s.goServe()
    131 }
    132 
    133 // NewTLSServer starts and returns a new Server using TLS.
    134 // The caller should call Close when finished, to shut it down.
    135 func NewTLSServer(handler http.Handler) *Server {
    136 	ts := NewUnstartedServer(handler)
    137 	ts.StartTLS()
    138 	return ts
    139 }
    140 
    141 type closeIdleTransport interface {
    142 	CloseIdleConnections()
    143 }
    144 
    145 // Close shuts down the server and blocks until all outstanding
    146 // requests on this server have completed.
    147 func (s *Server) Close() {
    148 	s.mu.Lock()
    149 	if !s.closed {
    150 		s.closed = true
    151 		s.Listener.Close()
    152 		s.Config.SetKeepAlivesEnabled(false)
    153 		for c, st := range s.conns {
    154 			// Force-close any idle connections (those between
    155 			// requests) and new connections (those which connected
    156 			// but never sent a request). StateNew connections are
    157 			// super rare and have only been seen (in
    158 			// previously-flaky tests) in the case of
    159 			// socket-late-binding races from the http Client
    160 			// dialing this server and then getting an idle
    161 			// connection before the dial completed. There is thus
    162 			// a connected connection in StateNew with no
    163 			// associated Request. We only close StateIdle and
    164 			// StateNew because they're not doing anything. It's
    165 			// possible StateNew is about to do something in a few
    166 			// milliseconds, but a previous CL to check again in a
    167 			// few milliseconds wasn't liked (early versions of
    168 			// https://golang.org/cl/15151) so now we just
    169 			// forcefully close StateNew. The docs for Server.Close say
    170 			// we wait for "outstanding requests", so we don't close things
    171 			// in StateActive.
    172 			if st == http.StateIdle || st == http.StateNew {
    173 				s.closeConn(c)
    174 			}
    175 		}
    176 		// If this server doesn't shut down in 5 seconds, tell the user why.
    177 		t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
    178 		defer t.Stop()
    179 	}
    180 	s.mu.Unlock()
    181 
    182 	// Not part of httptest.Server's correctness, but assume most
    183 	// users of httptest.Server will be using the standard
    184 	// transport, so help them out and close any idle connections for them.
    185 	if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
    186 		t.CloseIdleConnections()
    187 	}
    188 
    189 	s.wg.Wait()
    190 }
    191 
    192 func (s *Server) logCloseHangDebugInfo() {
    193 	s.mu.Lock()
    194 	defer s.mu.Unlock()
    195 	var buf bytes.Buffer
    196 	buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
    197 	for c, st := range s.conns {
    198 		fmt.Fprintf(&buf, "  %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
    199 	}
    200 	log.Print(buf.String())
    201 }
    202 
    203 // CloseClientConnections closes any open HTTP connections to the test Server.
    204 func (s *Server) CloseClientConnections() {
    205 	s.mu.Lock()
    206 	nconn := len(s.conns)
    207 	ch := make(chan struct{}, nconn)
    208 	for c := range s.conns {
    209 		s.closeConnChan(c, ch)
    210 	}
    211 	s.mu.Unlock()
    212 
    213 	// Wait for outstanding closes to finish.
    214 	//
    215 	// Out of paranoia for making a late change in Go 1.6, we
    216 	// bound how long this can wait, since golang.org/issue/14291
    217 	// isn't fully understood yet. At least this should only be used
    218 	// in tests.
    219 	timer := time.NewTimer(5 * time.Second)
    220 	defer timer.Stop()
    221 	for i := 0; i < nconn; i++ {
    222 		select {
    223 		case <-ch:
    224 		case <-timer.C:
    225 			// Too slow. Give up.
    226 			return
    227 		}
    228 	}
    229 }
    230 
    231 func (s *Server) goServe() {
    232 	s.wg.Add(1)
    233 	go func() {
    234 		defer s.wg.Done()
    235 		s.Config.Serve(s.Listener)
    236 	}()
    237 }
    238 
    239 // wrap installs the connection state-tracking hook to know which
    240 // connections are idle.
    241 func (s *Server) wrap() {
    242 	oldHook := s.Config.ConnState
    243 	s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
    244 		s.mu.Lock()
    245 		defer s.mu.Unlock()
    246 		switch cs {
    247 		case http.StateNew:
    248 			s.wg.Add(1)
    249 			if _, exists := s.conns[c]; exists {
    250 				panic("invalid state transition")
    251 			}
    252 			if s.conns == nil {
    253 				s.conns = make(map[net.Conn]http.ConnState)
    254 			}
    255 			s.conns[c] = cs
    256 			if s.closed {
    257 				// Probably just a socket-late-binding dial from
    258 				// the default transport that lost the race (and
    259 				// thus this connection is now idle and will
    260 				// never be used).
    261 				s.closeConn(c)
    262 			}
    263 		case http.StateActive:
    264 			if oldState, ok := s.conns[c]; ok {
    265 				if oldState != http.StateNew && oldState != http.StateIdle {
    266 					panic("invalid state transition")
    267 				}
    268 				s.conns[c] = cs
    269 			}
    270 		case http.StateIdle:
    271 			if oldState, ok := s.conns[c]; ok {
    272 				if oldState != http.StateActive {
    273 					panic("invalid state transition")
    274 				}
    275 				s.conns[c] = cs
    276 			}
    277 			if s.closed {
    278 				s.closeConn(c)
    279 			}
    280 		case http.StateHijacked, http.StateClosed:
    281 			s.forgetConn(c)
    282 		}
    283 		if oldHook != nil {
    284 			oldHook(c, cs)
    285 		}
    286 	}
    287 }
    288 
    289 // closeConn closes c.
    290 // s.mu must be held.
    291 func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
    292 
    293 // closeConnChan is like closeConn, but takes an optional channel to receive a value
    294 // when the goroutine closing c is done.
    295 func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
    296 	c.Close()
    297 	if done != nil {
    298 		done <- struct{}{}
    299 	}
    300 }
    301 
    302 // forgetConn removes c from the set of tracked conns and decrements it from the
    303 // waitgroup, unless it was previously removed.
    304 // s.mu must be held.
    305 func (s *Server) forgetConn(c net.Conn) {
    306 	if _, ok := s.conns[c]; ok {
    307 		delete(s.conns, c)
    308 		s.wg.Done()
    309 	}
    310 }
    311