Home | History | Annotate | Download | only in runner
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // Package runner partially implements TLS 1.2, as specified in RFC 5246.
      6 package runner
      7 
      8 import (
      9 	"crypto"
     10 	"crypto/ecdsa"
     11 	"crypto/rsa"
     12 	"crypto/x509"
     13 	"encoding/pem"
     14 	"errors"
     15 	"io/ioutil"
     16 	"net"
     17 	"strings"
     18 	"time"
     19 )
     20 
     21 // Server returns a new TLS server side connection
     22 // using conn as the underlying transport.
     23 // The configuration config must be non-nil and must have
     24 // at least one certificate.
     25 func Server(conn net.Conn, config *Config) *Conn {
     26 	c := &Conn{conn: conn, config: config}
     27 	c.init()
     28 	return c
     29 }
     30 
     31 // Client returns a new TLS client side connection
     32 // using conn as the underlying transport.
     33 // The config cannot be nil: users must set either ServerHostname or
     34 // InsecureSkipVerify in the config.
     35 func Client(conn net.Conn, config *Config) *Conn {
     36 	c := &Conn{conn: conn, config: config, isClient: true}
     37 	c.init()
     38 	return c
     39 }
     40 
     41 // A listener implements a network listener (net.Listener) for TLS connections.
     42 type listener struct {
     43 	net.Listener
     44 	config *Config
     45 }
     46 
     47 // Accept waits for and returns the next incoming TLS connection.
     48 // The returned connection c is a *tls.Conn.
     49 func (l *listener) Accept() (c net.Conn, err error) {
     50 	c, err = l.Listener.Accept()
     51 	if err != nil {
     52 		return
     53 	}
     54 	c = Server(c, l.config)
     55 	return
     56 }
     57 
     58 // NewListener creates a Listener which accepts connections from an inner
     59 // Listener and wraps each connection with Server.
     60 // The configuration config must be non-nil and must have
     61 // at least one certificate.
     62 func NewListener(inner net.Listener, config *Config) net.Listener {
     63 	l := new(listener)
     64 	l.Listener = inner
     65 	l.config = config
     66 	return l
     67 }
     68 
     69 // Listen creates a TLS listener accepting connections on the
     70 // given network address using net.Listen.
     71 // The configuration config must be non-nil and must have
     72 // at least one certificate.
     73 func Listen(network, laddr string, config *Config) (net.Listener, error) {
     74 	if config == nil || len(config.Certificates) == 0 {
     75 		return nil, errors.New("tls.Listen: no certificates in configuration")
     76 	}
     77 	l, err := net.Listen(network, laddr)
     78 	if err != nil {
     79 		return nil, err
     80 	}
     81 	return NewListener(l, config), nil
     82 }
     83 
     84 type timeoutError struct{}
     85 
     86 func (timeoutError) Error() string   { return "tls: DialWithDialer timed out" }
     87 func (timeoutError) Timeout() bool   { return true }
     88 func (timeoutError) Temporary() bool { return true }
     89 
     90 // DialWithDialer connects to the given network address using dialer.Dial and
     91 // then initiates a TLS handshake, returning the resulting TLS connection. Any
     92 // timeout or deadline given in the dialer apply to connection and TLS
     93 // handshake as a whole.
     94 //
     95 // DialWithDialer interprets a nil configuration as equivalent to the zero
     96 // configuration; see the documentation of Config for the defaults.
     97 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
     98 	// We want the Timeout and Deadline values from dialer to cover the
     99 	// whole process: TCP connection and TLS handshake. This means that we
    100 	// also need to start our own timers now.
    101 	timeout := dialer.Timeout
    102 
    103 	if !dialer.Deadline.IsZero() {
    104 		deadlineTimeout := dialer.Deadline.Sub(time.Now())
    105 		if timeout == 0 || deadlineTimeout < timeout {
    106 			timeout = deadlineTimeout
    107 		}
    108 	}
    109 
    110 	var errChannel chan error
    111 
    112 	if timeout != 0 {
    113 		errChannel = make(chan error, 2)
    114 		time.AfterFunc(timeout, func() {
    115 			errChannel <- timeoutError{}
    116 		})
    117 	}
    118 
    119 	rawConn, err := dialer.Dial(network, addr)
    120 	if err != nil {
    121 		return nil, err
    122 	}
    123 
    124 	colonPos := strings.LastIndex(addr, ":")
    125 	if colonPos == -1 {
    126 		colonPos = len(addr)
    127 	}
    128 	hostname := addr[:colonPos]
    129 
    130 	if config == nil {
    131 		config = defaultConfig()
    132 	}
    133 	// If no ServerName is set, infer the ServerName
    134 	// from the hostname we're connecting to.
    135 	if config.ServerName == "" {
    136 		// Make a copy to avoid polluting argument or default.
    137 		c := *config
    138 		c.ServerName = hostname
    139 		config = &c
    140 	}
    141 
    142 	conn := Client(rawConn, config)
    143 
    144 	if timeout == 0 {
    145 		err = conn.Handshake()
    146 	} else {
    147 		go func() {
    148 			errChannel <- conn.Handshake()
    149 		}()
    150 
    151 		err = <-errChannel
    152 	}
    153 
    154 	if err != nil {
    155 		rawConn.Close()
    156 		return nil, err
    157 	}
    158 
    159 	return conn, nil
    160 }
    161 
    162 // Dial connects to the given network address using net.Dial
    163 // and then initiates a TLS handshake, returning the resulting
    164 // TLS connection.
    165 // Dial interprets a nil configuration as equivalent to
    166 // the zero configuration; see the documentation of Config
    167 // for the defaults.
    168 func Dial(network, addr string, config *Config) (*Conn, error) {
    169 	return DialWithDialer(new(net.Dialer), network, addr, config)
    170 }
    171 
    172 // LoadX509KeyPair reads and parses a public/private key pair from a pair of
    173 // files. The files must contain PEM encoded data.
    174 func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
    175 	certPEMBlock, err := ioutil.ReadFile(certFile)
    176 	if err != nil {
    177 		return
    178 	}
    179 	keyPEMBlock, err := ioutil.ReadFile(keyFile)
    180 	if err != nil {
    181 		return
    182 	}
    183 	return X509KeyPair(certPEMBlock, keyPEMBlock)
    184 }
    185 
    186 // X509KeyPair parses a public/private key pair from a pair of
    187 // PEM encoded data.
    188 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (cert Certificate, err error) {
    189 	var certDERBlock *pem.Block
    190 	for {
    191 		certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
    192 		if certDERBlock == nil {
    193 			break
    194 		}
    195 		if certDERBlock.Type == "CERTIFICATE" {
    196 			cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
    197 		}
    198 	}
    199 
    200 	if len(cert.Certificate) == 0 {
    201 		err = errors.New("crypto/tls: failed to parse certificate PEM data")
    202 		return
    203 	}
    204 
    205 	var keyDERBlock *pem.Block
    206 	for {
    207 		keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
    208 		if keyDERBlock == nil {
    209 			err = errors.New("crypto/tls: failed to parse key PEM data")
    210 			return
    211 		}
    212 		if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
    213 			break
    214 		}
    215 	}
    216 
    217 	cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
    218 	if err != nil {
    219 		return
    220 	}
    221 
    222 	// We don't need to parse the public key for TLS, but we so do anyway
    223 	// to check that it looks sane and matches the private key.
    224 	x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
    225 	if err != nil {
    226 		return
    227 	}
    228 
    229 	switch pub := x509Cert.PublicKey.(type) {
    230 	case *rsa.PublicKey:
    231 		priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
    232 		if !ok {
    233 			err = errors.New("crypto/tls: private key type does not match public key type")
    234 			return
    235 		}
    236 		if pub.N.Cmp(priv.N) != 0 {
    237 			err = errors.New("crypto/tls: private key does not match public key")
    238 			return
    239 		}
    240 	case *ecdsa.PublicKey:
    241 		priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
    242 		if !ok {
    243 			err = errors.New("crypto/tls: private key type does not match public key type")
    244 			return
    245 
    246 		}
    247 		if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
    248 			err = errors.New("crypto/tls: private key does not match public key")
    249 			return
    250 		}
    251 	default:
    252 		err = errors.New("crypto/tls: unknown public key algorithm")
    253 		return
    254 	}
    255 
    256 	return
    257 }
    258 
    259 // Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
    260 // PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys.
    261 // OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
    262 func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
    263 	if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
    264 		return key, nil
    265 	}
    266 	if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
    267 		switch key := key.(type) {
    268 		case *rsa.PrivateKey, *ecdsa.PrivateKey:
    269 			return key, nil
    270 		default:
    271 			return nil, errors.New("crypto/tls: found unknown private key type in PKCS#8 wrapping")
    272 		}
    273 	}
    274 	if key, err := x509.ParseECPrivateKey(der); err == nil {
    275 		return key, nil
    276 	}
    277 
    278 	return nil, errors.New("crypto/tls: failed to parse private key")
    279 }
    280