Home | History | Annotate | Download | only in tls
      1 // Copyright 2012 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package tls
      6 
      7 import (
      8 	"bytes"
      9 	"crypto/aes"
     10 	"crypto/cipher"
     11 	"crypto/hmac"
     12 	"crypto/sha256"
     13 	"crypto/subtle"
     14 	"errors"
     15 	"io"
     16 )
     17 
     18 // sessionState contains the information that is serialized into a session
     19 // ticket in order to later resume a connection.
     20 type sessionState struct {
     21 	vers         uint16
     22 	cipherSuite  uint16
     23 	masterSecret []byte
     24 	certificates [][]byte
     25 	// usedOldKey is true if the ticket from which this session came from
     26 	// was encrypted with an older key and thus should be refreshed.
     27 	usedOldKey bool
     28 }
     29 
     30 func (s *sessionState) equal(i interface{}) bool {
     31 	s1, ok := i.(*sessionState)
     32 	if !ok {
     33 		return false
     34 	}
     35 
     36 	if s.vers != s1.vers ||
     37 		s.cipherSuite != s1.cipherSuite ||
     38 		!bytes.Equal(s.masterSecret, s1.masterSecret) {
     39 		return false
     40 	}
     41 
     42 	if len(s.certificates) != len(s1.certificates) {
     43 		return false
     44 	}
     45 
     46 	for i := range s.certificates {
     47 		if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
     48 			return false
     49 		}
     50 	}
     51 
     52 	return true
     53 }
     54 
     55 func (s *sessionState) marshal() []byte {
     56 	length := 2 + 2 + 2 + len(s.masterSecret) + 2
     57 	for _, cert := range s.certificates {
     58 		length += 4 + len(cert)
     59 	}
     60 
     61 	ret := make([]byte, length)
     62 	x := ret
     63 	x[0] = byte(s.vers >> 8)
     64 	x[1] = byte(s.vers)
     65 	x[2] = byte(s.cipherSuite >> 8)
     66 	x[3] = byte(s.cipherSuite)
     67 	x[4] = byte(len(s.masterSecret) >> 8)
     68 	x[5] = byte(len(s.masterSecret))
     69 	x = x[6:]
     70 	copy(x, s.masterSecret)
     71 	x = x[len(s.masterSecret):]
     72 
     73 	x[0] = byte(len(s.certificates) >> 8)
     74 	x[1] = byte(len(s.certificates))
     75 	x = x[2:]
     76 
     77 	for _, cert := range s.certificates {
     78 		x[0] = byte(len(cert) >> 24)
     79 		x[1] = byte(len(cert) >> 16)
     80 		x[2] = byte(len(cert) >> 8)
     81 		x[3] = byte(len(cert))
     82 		copy(x[4:], cert)
     83 		x = x[4+len(cert):]
     84 	}
     85 
     86 	return ret
     87 }
     88 
     89 func (s *sessionState) unmarshal(data []byte) bool {
     90 	if len(data) < 8 {
     91 		return false
     92 	}
     93 
     94 	s.vers = uint16(data[0])<<8 | uint16(data[1])
     95 	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
     96 	masterSecretLen := int(data[4])<<8 | int(data[5])
     97 	data = data[6:]
     98 	if len(data) < masterSecretLen {
     99 		return false
    100 	}
    101 
    102 	s.masterSecret = data[:masterSecretLen]
    103 	data = data[masterSecretLen:]
    104 
    105 	if len(data) < 2 {
    106 		return false
    107 	}
    108 
    109 	numCerts := int(data[0])<<8 | int(data[1])
    110 	data = data[2:]
    111 
    112 	s.certificates = make([][]byte, numCerts)
    113 	for i := range s.certificates {
    114 		if len(data) < 4 {
    115 			return false
    116 		}
    117 		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
    118 		data = data[4:]
    119 		if certLen < 0 {
    120 			return false
    121 		}
    122 		if len(data) < certLen {
    123 			return false
    124 		}
    125 		s.certificates[i] = data[:certLen]
    126 		data = data[certLen:]
    127 	}
    128 
    129 	return len(data) == 0
    130 }
    131 
    132 func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
    133 	serialized := state.marshal()
    134 	encrypted := make([]byte, ticketKeyNameLen+aes.BlockSize+len(serialized)+sha256.Size)
    135 	keyName := encrypted[:ticketKeyNameLen]
    136 	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
    137 	macBytes := encrypted[len(encrypted)-sha256.Size:]
    138 
    139 	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
    140 		return nil, err
    141 	}
    142 	key := c.config.ticketKeys()[0]
    143 	copy(keyName, key.keyName[:])
    144 	block, err := aes.NewCipher(key.aesKey[:])
    145 	if err != nil {
    146 		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
    147 	}
    148 	cipher.NewCTR(block, iv).XORKeyStream(encrypted[ticketKeyNameLen+aes.BlockSize:], serialized)
    149 
    150 	mac := hmac.New(sha256.New, key.hmacKey[:])
    151 	mac.Write(encrypted[:len(encrypted)-sha256.Size])
    152 	mac.Sum(macBytes[:0])
    153 
    154 	return encrypted, nil
    155 }
    156 
    157 func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
    158 	if c.config.SessionTicketsDisabled ||
    159 		len(encrypted) < ticketKeyNameLen+aes.BlockSize+sha256.Size {
    160 		return nil, false
    161 	}
    162 
    163 	keyName := encrypted[:ticketKeyNameLen]
    164 	iv := encrypted[ticketKeyNameLen : ticketKeyNameLen+aes.BlockSize]
    165 	macBytes := encrypted[len(encrypted)-sha256.Size:]
    166 
    167 	keys := c.config.ticketKeys()
    168 	keyIndex := -1
    169 	for i, candidateKey := range keys {
    170 		if bytes.Equal(keyName, candidateKey.keyName[:]) {
    171 			keyIndex = i
    172 			break
    173 		}
    174 	}
    175 
    176 	if keyIndex == -1 {
    177 		return nil, false
    178 	}
    179 	key := &keys[keyIndex]
    180 
    181 	mac := hmac.New(sha256.New, key.hmacKey[:])
    182 	mac.Write(encrypted[:len(encrypted)-sha256.Size])
    183 	expected := mac.Sum(nil)
    184 
    185 	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
    186 		return nil, false
    187 	}
    188 
    189 	block, err := aes.NewCipher(key.aesKey[:])
    190 	if err != nil {
    191 		return nil, false
    192 	}
    193 	ciphertext := encrypted[ticketKeyNameLen+aes.BlockSize : len(encrypted)-sha256.Size]
    194 	plaintext := ciphertext
    195 	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
    196 
    197 	state := &sessionState{usedOldKey: keyIndex > 0}
    198 	ok := state.unmarshal(plaintext)
    199 	return state, ok
    200 }
    201