Home | History | Annotate | Download | only in tls
      1 // Copyright 2013 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package tls
      6 
      7 import (
      8 	"bufio"
      9 	"encoding/hex"
     10 	"errors"
     11 	"flag"
     12 	"fmt"
     13 	"io"
     14 	"io/ioutil"
     15 	"net"
     16 	"os/exec"
     17 	"strconv"
     18 	"strings"
     19 	"sync"
     20 	"testing"
     21 )
     22 
// TLS reference tests run a connection against a reference implementation
// (OpenSSL) of TLS and record the bytes of the resulting connection. The Go
// code, during a test, is configured with deterministic randomness and so the
// reference test can be reproduced exactly in the future.
//
// So that people who only want to run the tests do not need the reference
// implementation installed, the reference connections are saved in files in
// the testdata directory. Thus running the tests involves nothing external,
// but creating and updating them requires the reference implementation.
//
// Tests can be updated by running them with the -update flag. This will cause
// the test files to be regenerated. Generally one should combine the -update
// flag with -test.run to update a specific test. Since the reference
// implementation will always generate fresh random numbers, large parts of
// the reference connection will always change.
     39 
     40 var (
     41 	update = flag.Bool("update", false, "update golden files on disk")
     42 
     43 	opensslVersionTestOnce sync.Once
     44 	opensslVersionTestErr  error
     45 )
     46 
     47 func checkOpenSSLVersion(t *testing.T) {
     48 	opensslVersionTestOnce.Do(testOpenSSLVersion)
     49 	if opensslVersionTestErr != nil {
     50 		t.Fatal(opensslVersionTestErr)
     51 	}
     52 }
     53 
     54 func testOpenSSLVersion() {
     55 	// This test ensures that the version of OpenSSL looks reasonable
     56 	// before updating the test data.
     57 
     58 	if !*update {
     59 		return
     60 	}
     61 
     62 	openssl := exec.Command("openssl", "version")
     63 	output, err := openssl.CombinedOutput()
     64 	if err != nil {
     65 		opensslVersionTestErr = err
     66 		return
     67 	}
     68 
     69 	version := string(output)
     70 	if strings.HasPrefix(version, "OpenSSL 1.1.0") {
     71 		return
     72 	}
     73 
     74 	println("***********************************************")
     75 	println("")
     76 	println("You need to build OpenSSL 1.1.0 from source in order")
     77 	println("to update the test data.")
     78 	println("")
     79 	println("Configure it with:")
     80 	println("./Configure enable-weak-ssl-ciphers enable-ssl3 enable-ssl3-method -static linux-x86_64")
     81 	println("and then add the apps/ directory at the front of your PATH.")
     82 	println("***********************************************")
     83 
     84 	opensslVersionTestErr = errors.New("version of OpenSSL does not appear to be suitable for updating test data")
     85 }
     86 
// recordingConn is a net.Conn that records the traffic that passes through it.
// The recording is a list of flows. A run of consecutive Reads (or of
// consecutive Writes) is appended to a single flow, and each change of
// direction starts a new one. WriteTo can be used to produce output that can
// later be loaded with parseTestData.
type recordingConn struct {
	net.Conn
	// The embedded Mutex is held by Read and Write while they update flows
	// and reading.
	sync.Mutex
	// flows holds the recorded bytes, one slice per direction change.
	flows [][]byte
	// reading reports whether the most recently recorded operation was a Read.
	reading bool
}
     96 
     97 func (r *recordingConn) Read(b []byte) (n int, err error) {
     98 	if n, err = r.Conn.Read(b); n == 0 {
     99 		return
    100 	}
    101 	b = b[:n]
    102 
    103 	r.Lock()
    104 	defer r.Unlock()
    105 
    106 	if l := len(r.flows); l == 0 || !r.reading {
    107 		buf := make([]byte, len(b))
    108 		copy(buf, b)
    109 		r.flows = append(r.flows, buf)
    110 	} else {
    111 		r.flows[l-1] = append(r.flows[l-1], b[:n]...)
    112 	}
    113 	r.reading = true
    114 	return
    115 }
    116 
    117 func (r *recordingConn) Write(b []byte) (n int, err error) {
    118 	if n, err = r.Conn.Write(b); n == 0 {
    119 		return
    120 	}
    121 	b = b[:n]
    122 
    123 	r.Lock()
    124 	defer r.Unlock()
    125 
    126 	if l := len(r.flows); l == 0 || r.reading {
    127 		buf := make([]byte, len(b))
    128 		copy(buf, b)
    129 		r.flows = append(r.flows, buf)
    130 	} else {
    131 		r.flows[l-1] = append(r.flows[l-1], b[:n]...)
    132 	}
    133 	r.reading = false
    134 	return
    135 }
    136 
    137 // WriteTo writes Go source code to w that contains the recorded traffic.
    138 func (r *recordingConn) WriteTo(w io.Writer) (int64, error) {
    139 	// TLS always starts with a client to server flow.
    140 	clientToServer := true
    141 	var written int64
    142 	for i, flow := range r.flows {
    143 		source, dest := "client", "server"
    144 		if !clientToServer {
    145 			source, dest = dest, source
    146 		}
    147 		n, err := fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest)
    148 		written += int64(n)
    149 		if err != nil {
    150 			return written, err
    151 		}
    152 		dumper := hex.Dumper(w)
    153 		n, err = dumper.Write(flow)
    154 		written += int64(n)
    155 		if err != nil {
    156 			return written, err
    157 		}
    158 		err = dumper.Close()
    159 		if err != nil {
    160 			return written, err
    161 		}
    162 		clientToServer = !clientToServer
    163 	}
    164 	return written, nil
    165 }
    166 
    167 func parseTestData(r io.Reader) (flows [][]byte, err error) {
    168 	var currentFlow []byte
    169 
    170 	scanner := bufio.NewScanner(r)
    171 	for scanner.Scan() {
    172 		line := scanner.Text()
    173 		// If the line starts with ">>> " then it marks the beginning
    174 		// of a new flow.
    175 		if strings.HasPrefix(line, ">>> ") {
    176 			if len(currentFlow) > 0 || len(flows) > 0 {
    177 				flows = append(flows, currentFlow)
    178 				currentFlow = nil
    179 			}
    180 			continue
    181 		}
    182 
    183 		// Otherwise the line is a line of hex dump that looks like:
    184 		// 00000170  fc f5 06 bf (...)  |.....X{&?......!|
    185 		// (Some bytes have been omitted from the middle section.)
    186 
    187 		if i := strings.IndexByte(line, ' '); i >= 0 {
    188 			line = line[i:]
    189 		} else {
    190 			return nil, errors.New("invalid test data")
    191 		}
    192 
    193 		if i := strings.IndexByte(line, '|'); i >= 0 {
    194 			line = line[:i]
    195 		} else {
    196 			return nil, errors.New("invalid test data")
    197 		}
    198 
    199 		hexBytes := strings.Fields(line)
    200 		for _, hexByte := range hexBytes {
    201 			val, err := strconv.ParseUint(hexByte, 16, 8)
    202 			if err != nil {
    203 				return nil, errors.New("invalid hex byte in test data: " + err.Error())
    204 			}
    205 			currentFlow = append(currentFlow, byte(val))
    206 		}
    207 	}
    208 
    209 	if len(currentFlow) > 0 {
    210 		flows = append(flows, currentFlow)
    211 	}
    212 
    213 	return flows, nil
    214 }
    215 
    216 // tempFile creates a temp file containing contents and returns its path.
    217 func tempFile(contents string) string {
    218 	file, err := ioutil.TempFile("", "go-tls-test")
    219 	if err != nil {
    220 		panic("failed to create temp file: " + err.Error())
    221 	}
    222 	path := file.Name()
    223 	file.WriteString(contents)
    224 	file.Close()
    225 	return path
    226 }
    227