Home | History | Annotate | Download | only in runner
      1 // Copyright 2012 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package runner
      6 
      7 import (
      8 	"bytes"
      9 	"crypto/aes"
     10 	"crypto/cipher"
     11 	"crypto/hmac"
     12 	"crypto/sha256"
     13 	"crypto/subtle"
     14 	"errors"
     15 	"io"
     16 )
     17 
// sessionState contains the information that is serialized into a session
// ticket in order to later resume a connection.
//
// marshal encodes it as big-endian fixed-width integers and length-prefixed
// byte strings, and unmarshal decodes that form; the size limits noted on the
// fields below come from the widths of those prefixes.
type sessionState struct {
	// vers is the protocol version, encoded as a 16-bit value.
	vers                 uint16
	// cipherSuite is the cipher suite identifier, encoded as a 16-bit value.
	cipherSuite          uint16
	// masterSecret is encoded with a 16-bit length prefix, so it must be at
	// most 65535 bytes long.
	masterSecret         []byte
	// handshakeHash is encoded with a 16-bit length prefix, so it must be at
	// most 65535 bytes long. NOTE(review): presumably a hash of the original
	// handshake transcript; its producer is outside this file.
	handshakeHash        []byte
	// certificates holds raw certificate bytes (presumably DER; confirm with
	// the caller). The encoding stores a 16-bit entry count followed by each
	// entry with a 32-bit length prefix.
	certificates         [][]byte
	// extendedMasterSecret is encoded as a single trailing byte: 1 when true,
	// 0 when false. When decoding, any byte other than 1 yields false.
	extendedMasterSecret bool
}
     28 
     29 func (s *sessionState) equal(i interface{}) bool {
     30 	s1, ok := i.(*sessionState)
     31 	if !ok {
     32 		return false
     33 	}
     34 
     35 	if s.vers != s1.vers ||
     36 		s.cipherSuite != s1.cipherSuite ||
     37 		!bytes.Equal(s.masterSecret, s1.masterSecret) ||
     38 		!bytes.Equal(s.handshakeHash, s1.handshakeHash) ||
     39 		s.extendedMasterSecret != s1.extendedMasterSecret {
     40 		return false
     41 	}
     42 
     43 	if len(s.certificates) != len(s1.certificates) {
     44 		return false
     45 	}
     46 
     47 	for i := range s.certificates {
     48 		if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
     49 			return false
     50 		}
     51 	}
     52 
     53 	return true
     54 }
     55 
     56 func (s *sessionState) marshal() []byte {
     57 	length := 2 + 2 + 2 + len(s.masterSecret) + 2 + len(s.handshakeHash) + 2
     58 	for _, cert := range s.certificates {
     59 		length += 4 + len(cert)
     60 	}
     61 	length++
     62 
     63 	ret := make([]byte, length)
     64 	x := ret
     65 	x[0] = byte(s.vers >> 8)
     66 	x[1] = byte(s.vers)
     67 	x[2] = byte(s.cipherSuite >> 8)
     68 	x[3] = byte(s.cipherSuite)
     69 	x[4] = byte(len(s.masterSecret) >> 8)
     70 	x[5] = byte(len(s.masterSecret))
     71 	x = x[6:]
     72 	copy(x, s.masterSecret)
     73 	x = x[len(s.masterSecret):]
     74 
     75 	x[0] = byte(len(s.handshakeHash) >> 8)
     76 	x[1] = byte(len(s.handshakeHash))
     77 	x = x[2:]
     78 	copy(x, s.handshakeHash)
     79 	x = x[len(s.handshakeHash):]
     80 
     81 	x[0] = byte(len(s.certificates) >> 8)
     82 	x[1] = byte(len(s.certificates))
     83 	x = x[2:]
     84 
     85 	for _, cert := range s.certificates {
     86 		x[0] = byte(len(cert) >> 24)
     87 		x[1] = byte(len(cert) >> 16)
     88 		x[2] = byte(len(cert) >> 8)
     89 		x[3] = byte(len(cert))
     90 		copy(x[4:], cert)
     91 		x = x[4+len(cert):]
     92 	}
     93 
     94 	if s.extendedMasterSecret {
     95 		x[0] = 1
     96 	}
     97 	x = x[1:]
     98 
     99 	return ret
    100 }
    101 
    102 func (s *sessionState) unmarshal(data []byte) bool {
    103 	if len(data) < 8 {
    104 		return false
    105 	}
    106 
    107 	s.vers = uint16(data[0])<<8 | uint16(data[1])
    108 	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
    109 	masterSecretLen := int(data[4])<<8 | int(data[5])
    110 	data = data[6:]
    111 	if len(data) < masterSecretLen {
    112 		return false
    113 	}
    114 
    115 	s.masterSecret = data[:masterSecretLen]
    116 	data = data[masterSecretLen:]
    117 
    118 	if len(data) < 2 {
    119 		return false
    120 	}
    121 
    122 	handshakeHashLen := int(data[0])<<8 | int(data[1])
    123 	data = data[2:]
    124 	if len(data) < handshakeHashLen {
    125 		return false
    126 	}
    127 
    128 	s.handshakeHash = data[:handshakeHashLen]
    129 	data = data[handshakeHashLen:]
    130 
    131 	if len(data) < 2 {
    132 		return false
    133 	}
    134 
    135 	numCerts := int(data[0])<<8 | int(data[1])
    136 	data = data[2:]
    137 
    138 	s.certificates = make([][]byte, numCerts)
    139 	for i := range s.certificates {
    140 		if len(data) < 4 {
    141 			return false
    142 		}
    143 		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
    144 		data = data[4:]
    145 		if certLen < 0 {
    146 			return false
    147 		}
    148 		if len(data) < certLen {
    149 			return false
    150 		}
    151 		s.certificates[i] = data[:certLen]
    152 		data = data[certLen:]
    153 	}
    154 
    155 	if len(data) < 1 {
    156 		return false
    157 	}
    158 
    159 	s.extendedMasterSecret = false
    160 	if data[0] == 1 {
    161 		s.extendedMasterSecret = true
    162 	}
    163 	data = data[1:]
    164 
    165 	if len(data) > 0 {
    166 		return false
    167 	}
    168 
    169 	return true
    170 }
    171 
    172 func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
    173 	serialized := state.marshal()
    174 	encrypted := make([]byte, aes.BlockSize+len(serialized)+sha256.Size)
    175 	iv := encrypted[:aes.BlockSize]
    176 	macBytes := encrypted[len(encrypted)-sha256.Size:]
    177 
    178 	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
    179 		return nil, err
    180 	}
    181 	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
    182 	if err != nil {
    183 		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
    184 	}
    185 	cipher.NewCTR(block, iv).XORKeyStream(encrypted[aes.BlockSize:], serialized)
    186 
    187 	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
    188 	mac.Write(encrypted[:len(encrypted)-sha256.Size])
    189 	mac.Sum(macBytes[:0])
    190 
    191 	return encrypted, nil
    192 }
    193 
    194 func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
    195 	if len(encrypted) < aes.BlockSize+sha256.Size {
    196 		return nil, false
    197 	}
    198 
    199 	iv := encrypted[:aes.BlockSize]
    200 	macBytes := encrypted[len(encrypted)-sha256.Size:]
    201 
    202 	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
    203 	mac.Write(encrypted[:len(encrypted)-sha256.Size])
    204 	expected := mac.Sum(nil)
    205 
    206 	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
    207 		return nil, false
    208 	}
    209 
    210 	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
    211 	if err != nil {
    212 		return nil, false
    213 	}
    214 	ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
    215 	plaintext := make([]byte, len(ciphertext))
    216 	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
    217 
    218 	state := new(sessionState)
    219 	ok := state.unmarshal(plaintext)
    220 	return state, ok
    221 }
    222