Home | History | Annotate | Download | only in http
      1 // Copyright 2013 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package http_test
      6 
      7 import (
      8 	"bufio"
      9 	"bytes"
     10 	"crypto/tls"
     11 	"fmt"
     12 	"io"
     13 	"io/ioutil"
     14 	. "net/http"
     15 	"net/http/httptest"
     16 	"strings"
     17 	"testing"
     18 )
     19 
     20 func TestNextProtoUpgrade(t *testing.T) {
     21 	defer afterTest(t)
     22 	ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
     23 		fmt.Fprintf(w, "path=%s,proto=", r.URL.Path)
     24 		if r.TLS != nil {
     25 			w.Write([]byte(r.TLS.NegotiatedProtocol))
     26 		}
     27 		if r.RemoteAddr == "" {
     28 			t.Error("request with no RemoteAddr")
     29 		}
     30 		if r.Body == nil {
     31 			t.Errorf("request with nil Body")
     32 		}
     33 	}))
     34 	ts.TLS = &tls.Config{
     35 		NextProtos: []string{"unhandled-proto", "tls-0.9"},
     36 	}
     37 	ts.Config.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){
     38 		"tls-0.9": handleTLSProtocol09,
     39 	}
     40 	ts.StartTLS()
     41 	defer ts.Close()
     42 
     43 	// Normal request, without NPN.
     44 	{
     45 		tr := newTLSTransport(t, ts)
     46 		defer tr.CloseIdleConnections()
     47 		c := &Client{Transport: tr}
     48 
     49 		res, err := c.Get(ts.URL)
     50 		if err != nil {
     51 			t.Fatal(err)
     52 		}
     53 		body, err := ioutil.ReadAll(res.Body)
     54 		if err != nil {
     55 			t.Fatal(err)
     56 		}
     57 		if want := "path=/,proto="; string(body) != want {
     58 			t.Errorf("plain request = %q; want %q", body, want)
     59 		}
     60 	}
     61 
     62 	// Request to an advertised but unhandled NPN protocol.
     63 	// Server will hang up.
     64 	{
     65 		tr := newTLSTransport(t, ts)
     66 		tr.TLSClientConfig.NextProtos = []string{"unhandled-proto"}
     67 		defer tr.CloseIdleConnections()
     68 		c := &Client{Transport: tr}
     69 
     70 		res, err := c.Get(ts.URL)
     71 		if err == nil {
     72 			defer res.Body.Close()
     73 			var buf bytes.Buffer
     74 			res.Write(&buf)
     75 			t.Errorf("expected error on unhandled-proto request; got: %s", buf.Bytes())
     76 		}
     77 	}
     78 
     79 	// Request using the "tls-0.9" protocol, which we register here.
     80 	// It is HTTP/0.9 over TLS.
     81 	{
     82 		tlsConfig := newTLSTransport(t, ts).TLSClientConfig
     83 		tlsConfig.NextProtos = []string{"tls-0.9"}
     84 		conn, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
     85 		if err != nil {
     86 			t.Fatal(err)
     87 		}
     88 		conn.Write([]byte("GET /foo\n"))
     89 		body, err := ioutil.ReadAll(conn)
     90 		if err != nil {
     91 			t.Fatal(err)
     92 		}
     93 		if want := "path=/foo,proto=tls-0.9"; string(body) != want {
     94 			t.Errorf("plain request = %q; want %q", body, want)
     95 		}
     96 	}
     97 }
     98 
     99 // handleTLSProtocol09 implements the HTTP/0.9 protocol over TLS, for the
    100 // TestNextProtoUpgrade test.
    101 func handleTLSProtocol09(srv *Server, conn *tls.Conn, h Handler) {
    102 	br := bufio.NewReader(conn)
    103 	line, err := br.ReadString('\n')
    104 	if err != nil {
    105 		return
    106 	}
    107 	line = strings.TrimSpace(line)
    108 	path := strings.TrimPrefix(line, "GET ")
    109 	if path == line {
    110 		return
    111 	}
    112 	req, _ := NewRequest("GET", path, nil)
    113 	req.Proto = "HTTP/0.9"
    114 	req.ProtoMajor = 0
    115 	req.ProtoMinor = 9
    116 	rw := &http09Writer{conn, make(Header)}
    117 	h.ServeHTTP(rw, req)
    118 }
    119 
    120 type http09Writer struct {
    121 	io.Writer
    122 	h Header
    123 }
    124 
    125 func (w http09Writer) Header() Header  { return w.h }
    126 func (w http09Writer) WriteHeader(int) {} // no headers
    127