Home | History | Annotate | Download | only in runner
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // Package tls partially implements TLS 1.2, as specified in RFC 5246.
      6 package runner
      7 
      8 import (
      9 	"bytes"
     10 	"crypto"
     11 	"crypto/ecdsa"
     12 	"crypto/rsa"
     13 	"crypto/x509"
     14 	"encoding/pem"
     15 	"errors"
     16 	"io/ioutil"
     17 	"net"
     18 	"strings"
     19 	"time"
     20 
     21 	"./ed25519"
     22 )
     23 
     24 // Server returns a new TLS server side connection
     25 // using conn as the underlying transport.
     26 // The configuration config must be non-nil and must have
     27 // at least one certificate.
     28 func Server(conn net.Conn, config *Config) *Conn {
     29 	c := &Conn{conn: conn, config: config}
     30 	c.init()
     31 	return c
     32 }
     33 
     34 // Client returns a new TLS client side connection
     35 // using conn as the underlying transport.
     36 // The config cannot be nil: users must set either ServerHostname or
     37 // InsecureSkipVerify in the config.
     38 func Client(conn net.Conn, config *Config) *Conn {
     39 	c := &Conn{conn: conn, config: config, isClient: true}
     40 	c.init()
     41 	return c
     42 }
     43 
     44 // A listener implements a network listener (net.Listener) for TLS connections.
     45 type listener struct {
     46 	net.Listener
     47 	config *Config
     48 }
     49 
     50 // Accept waits for and returns the next incoming TLS connection.
     51 // The returned connection c is a *tls.Conn.
     52 func (l *listener) Accept() (c net.Conn, err error) {
     53 	c, err = l.Listener.Accept()
     54 	if err != nil {
     55 		return
     56 	}
     57 	c = Server(c, l.config)
     58 	return
     59 }
     60 
     61 // NewListener creates a Listener which accepts connections from an inner
     62 // Listener and wraps each connection with Server.
     63 // The configuration config must be non-nil and must have
     64 // at least one certificate.
     65 func NewListener(inner net.Listener, config *Config) net.Listener {
     66 	l := new(listener)
     67 	l.Listener = inner
     68 	l.config = config
     69 	return l
     70 }
     71 
     72 // Listen creates a TLS listener accepting connections on the
     73 // given network address using net.Listen.
     74 // The configuration config must be non-nil and must have
     75 // at least one certificate.
     76 func Listen(network, laddr string, config *Config) (net.Listener, error) {
     77 	if config == nil || len(config.Certificates) == 0 {
     78 		return nil, errors.New("tls.Listen: no certificates in configuration")
     79 	}
     80 	l, err := net.Listen(network, laddr)
     81 	if err != nil {
     82 		return nil, err
     83 	}
     84 	return NewListener(l, config), nil
     85 }
     86 
     87 type timeoutError struct{}
     88 
     89 func (timeoutError) Error() string   { return "tls: DialWithDialer timed out" }
     90 func (timeoutError) Timeout() bool   { return true }
     91 func (timeoutError) Temporary() bool { return true }
     92 
     93 // DialWithDialer connects to the given network address using dialer.Dial and
     94 // then initiates a TLS handshake, returning the resulting TLS connection. Any
     95 // timeout or deadline given in the dialer apply to connection and TLS
     96 // handshake as a whole.
     97 //
     98 // DialWithDialer interprets a nil configuration as equivalent to the zero
     99 // configuration; see the documentation of Config for the defaults.
    100 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
    101 	// We want the Timeout and Deadline values from dialer to cover the
    102 	// whole process: TCP connection and TLS handshake. This means that we
    103 	// also need to start our own timers now.
    104 	timeout := dialer.Timeout
    105 
    106 	if !dialer.Deadline.IsZero() {
    107 		deadlineTimeout := dialer.Deadline.Sub(time.Now())
    108 		if timeout == 0 || deadlineTimeout < timeout {
    109 			timeout = deadlineTimeout
    110 		}
    111 	}
    112 
    113 	var errChannel chan error
    114 
    115 	if timeout != 0 {
    116 		errChannel = make(chan error, 2)
    117 		time.AfterFunc(timeout, func() {
    118 			errChannel <- timeoutError{}
    119 		})
    120 	}
    121 
    122 	rawConn, err := dialer.Dial(network, addr)
    123 	if err != nil {
    124 		return nil, err
    125 	}
    126 
    127 	colonPos := strings.LastIndex(addr, ":")
    128 	if colonPos == -1 {
    129 		colonPos = len(addr)
    130 	}
    131 	hostname := addr[:colonPos]
    132 
    133 	if config == nil {
    134 		config = defaultConfig()
    135 	}
    136 	// If no ServerName is set, infer the ServerName
    137 	// from the hostname we're connecting to.
    138 	if config.ServerName == "" {
    139 		// Make a copy to avoid polluting argument or default.
    140 		c := *config
    141 		c.ServerName = hostname
    142 		config = &c
    143 	}
    144 
    145 	conn := Client(rawConn, config)
    146 
    147 	if timeout == 0 {
    148 		err = conn.Handshake()
    149 	} else {
    150 		go func() {
    151 			errChannel <- conn.Handshake()
    152 		}()
    153 
    154 		err = <-errChannel
    155 	}
    156 
    157 	if err != nil {
    158 		rawConn.Close()
    159 		return nil, err
    160 	}
    161 
    162 	return conn, nil
    163 }
    164 
    165 // Dial connects to the given network address using net.Dial
    166 // and then initiates a TLS handshake, returning the resulting
    167 // TLS connection.
    168 // Dial interprets a nil configuration as equivalent to
    169 // the zero configuration; see the documentation of Config
    170 // for the defaults.
    171 func Dial(network, addr string, config *Config) (*Conn, error) {
    172 	return DialWithDialer(new(net.Dialer), network, addr, config)
    173 }
    174 
    175 // LoadX509KeyPair reads and parses a public/private key pair from a pair of
    176 // files. The files must contain PEM encoded data.
    177 func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) {
    178 	certPEMBlock, err := ioutil.ReadFile(certFile)
    179 	if err != nil {
    180 		return
    181 	}
    182 	keyPEMBlock, err := ioutil.ReadFile(keyFile)
    183 	if err != nil {
    184 		return
    185 	}
    186 	return X509KeyPair(certPEMBlock, keyPEMBlock)
    187 }
    188 
    189 // X509KeyPair parses a public/private key pair from a pair of
    190 // PEM encoded data.
    191 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (cert Certificate, err error) {
    192 	var certDERBlock *pem.Block
    193 	for {
    194 		certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
    195 		if certDERBlock == nil {
    196 			break
    197 		}
    198 		if certDERBlock.Type == "CERTIFICATE" {
    199 			cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
    200 		}
    201 	}
    202 
    203 	if len(cert.Certificate) == 0 {
    204 		err = errors.New("crypto/tls: failed to parse certificate PEM data")
    205 		return
    206 	}
    207 
    208 	var keyDERBlock *pem.Block
    209 	for {
    210 		keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
    211 		if keyDERBlock == nil {
    212 			err = errors.New("crypto/tls: failed to parse key PEM data")
    213 			return
    214 		}
    215 		if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
    216 			break
    217 		}
    218 	}
    219 
    220 	cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
    221 	if err != nil {
    222 		return
    223 	}
    224 
    225 	// We don't need to parse the public key for TLS, but we so do anyway
    226 	// to check that it looks sane and matches the private key.
    227 	x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
    228 	if err != nil {
    229 		return
    230 	}
    231 
    232 	switch pub := getCertificatePublicKey(x509Cert).(type) {
    233 	case *rsa.PublicKey:
    234 		priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
    235 		if !ok {
    236 			err = errors.New("crypto/tls: private key type does not match public key type")
    237 			return
    238 		}
    239 		if pub.N.Cmp(priv.N) != 0 {
    240 			err = errors.New("crypto/tls: private key does not match public key")
    241 			return
    242 		}
    243 	case *ecdsa.PublicKey:
    244 		priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
    245 		if !ok {
    246 			err = errors.New("crypto/tls: private key type does not match public key type")
    247 			return
    248 
    249 		}
    250 		if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
    251 			err = errors.New("crypto/tls: private key does not match public key")
    252 			return
    253 		}
    254 	case ed25519.PublicKey:
    255 		priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
    256 		if !ok {
    257 			err = errors.New("crypto/tls: private key type does not match public key type")
    258 			return
    259 		}
    260 		if !bytes.Equal(priv[32:], pub) {
    261 			err = errors.New("crypto/tls: private key does not match public key")
    262 			return
    263 		}
    264 	default:
    265 		err = errors.New("crypto/tls: unknown public key algorithm")
    266 		return
    267 	}
    268 
    269 	return
    270 }
    271 
    272 var ed25519SPKIPrefix = []byte{0x30, 0x2a, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70, 0x03, 0x21, 0x00}
    273 
    274 func isEd25519Certificate(cert *x509.Certificate) bool {
    275 	return bytes.HasPrefix(cert.RawSubjectPublicKeyInfo, ed25519SPKIPrefix) && len(cert.RawSubjectPublicKeyInfo) == len(ed25519SPKIPrefix)+32
    276 }
    277 
    278 func getCertificatePublicKey(cert *x509.Certificate) crypto.PublicKey {
    279 	if cert.PublicKey != nil {
    280 		return cert.PublicKey
    281 	}
    282 
    283 	if isEd25519Certificate(cert) {
    284 		return ed25519.PublicKey(cert.RawSubjectPublicKeyInfo[len(ed25519SPKIPrefix):])
    285 	}
    286 
    287 	return nil
    288 }
    289 
    290 var ed25519PKCS8Prefix = []byte{0x30, 0x2e, 0x02, 0x01, 0x00, 0x30, 0x05, 0x06, 0x03, 0x2b, 0x65, 0x70,
    291 	0x04, 0x22, 0x04, 0x20}
    292 
    293 // Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates
    294 // PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys.
    295 // OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three.
    296 func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
    297 	if key, err := x509.ParsePKCS1PrivateKey(der); err == nil {
    298 		return key, nil
    299 	}
    300 	if key, err := x509.ParsePKCS8PrivateKey(der); err == nil {
    301 		switch key := key.(type) {
    302 		case *rsa.PrivateKey, *ecdsa.PrivateKey:
    303 			return key, nil
    304 		default:
    305 			return nil, errors.New("crypto/tls: found unknown private key type in PKCS#8 wrapping")
    306 		}
    307 	}
    308 	if key, err := x509.ParseECPrivateKey(der); err == nil {
    309 		return key, nil
    310 	}
    311 
    312 	if bytes.HasPrefix(der, ed25519PKCS8Prefix) && len(der) == len(ed25519PKCS8Prefix)+32 {
    313 		seed := der[len(ed25519PKCS8Prefix):]
    314 		_, key := ed25519.NewKeyPairFromSeed(seed)
    315 		return key, nil
    316 	}
    317 
    318 	return nil, errors.New("crypto/tls: failed to parse private key")
    319 }
    320