Home | History | Annotate | Download | only in runner
      1 // Copyright 2012 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package runner
      6 
      7 import (
      8 	"crypto/aes"
      9 	"crypto/cipher"
     10 	"crypto/hmac"
     11 	"crypto/sha256"
     12 	"crypto/subtle"
     13 	"encoding/binary"
     14 	"errors"
     15 	"io"
     16 	"time"
     17 )
     18 
     19 // sessionState contains the information that is serialized into a session
     20 // ticket in order to later resume a connection.
     21 type sessionState struct {
     22 	vers                 uint16
     23 	cipherSuite          uint16
     24 	masterSecret         []byte
     25 	handshakeHash        []byte
     26 	certificates         [][]byte
     27 	extendedMasterSecret bool
     28 	earlyALPN            []byte
     29 	ticketCreationTime   time.Time
     30 	ticketExpiration     time.Time
     31 	ticketFlags          uint32
     32 	ticketAgeAdd         uint32
     33 }
     34 
     35 func (s *sessionState) marshal() []byte {
     36 	msg := newByteBuilder()
     37 	msg.addU16(s.vers)
     38 	msg.addU16(s.cipherSuite)
     39 	masterSecret := msg.addU16LengthPrefixed()
     40 	masterSecret.addBytes(s.masterSecret)
     41 	handshakeHash := msg.addU16LengthPrefixed()
     42 	handshakeHash.addBytes(s.handshakeHash)
     43 	msg.addU16(uint16(len(s.certificates)))
     44 	for _, cert := range s.certificates {
     45 		certMsg := msg.addU32LengthPrefixed()
     46 		certMsg.addBytes(cert)
     47 	}
     48 
     49 	if s.extendedMasterSecret {
     50 		msg.addU8(1)
     51 	} else {
     52 		msg.addU8(0)
     53 	}
     54 
     55 	if s.vers >= VersionTLS13 {
     56 		msg.addU64(uint64(s.ticketCreationTime.UnixNano()))
     57 		msg.addU64(uint64(s.ticketExpiration.UnixNano()))
     58 		msg.addU32(s.ticketFlags)
     59 		msg.addU32(s.ticketAgeAdd)
     60 	}
     61 
     62 	earlyALPN := msg.addU16LengthPrefixed()
     63 	earlyALPN.addBytes(s.earlyALPN)
     64 
     65 	return msg.finish()
     66 }
     67 
     68 func (s *sessionState) unmarshal(data []byte) bool {
     69 	if len(data) < 8 {
     70 		return false
     71 	}
     72 
     73 	s.vers = uint16(data[0])<<8 | uint16(data[1])
     74 	s.cipherSuite = uint16(data[2])<<8 | uint16(data[3])
     75 	masterSecretLen := int(data[4])<<8 | int(data[5])
     76 	data = data[6:]
     77 	if len(data) < masterSecretLen {
     78 		return false
     79 	}
     80 
     81 	s.masterSecret = data[:masterSecretLen]
     82 	data = data[masterSecretLen:]
     83 
     84 	if len(data) < 2 {
     85 		return false
     86 	}
     87 
     88 	handshakeHashLen := int(data[0])<<8 | int(data[1])
     89 	data = data[2:]
     90 	if len(data) < handshakeHashLen {
     91 		return false
     92 	}
     93 
     94 	s.handshakeHash = data[:handshakeHashLen]
     95 	data = data[handshakeHashLen:]
     96 
     97 	if len(data) < 2 {
     98 		return false
     99 	}
    100 
    101 	numCerts := int(data[0])<<8 | int(data[1])
    102 	data = data[2:]
    103 
    104 	s.certificates = make([][]byte, numCerts)
    105 	for i := range s.certificates {
    106 		if len(data) < 4 {
    107 			return false
    108 		}
    109 		certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3])
    110 		data = data[4:]
    111 		if certLen < 0 {
    112 			return false
    113 		}
    114 		if len(data) < certLen {
    115 			return false
    116 		}
    117 		s.certificates[i] = data[:certLen]
    118 		data = data[certLen:]
    119 	}
    120 
    121 	if len(data) < 1 {
    122 		return false
    123 	}
    124 
    125 	s.extendedMasterSecret = false
    126 	if data[0] == 1 {
    127 		s.extendedMasterSecret = true
    128 	}
    129 	data = data[1:]
    130 
    131 	if s.vers >= VersionTLS13 {
    132 		if len(data) < 24 {
    133 			return false
    134 		}
    135 		s.ticketCreationTime = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
    136 		data = data[8:]
    137 		s.ticketExpiration = time.Unix(0, int64(binary.BigEndian.Uint64(data)))
    138 		data = data[8:]
    139 		s.ticketFlags = binary.BigEndian.Uint32(data)
    140 		data = data[4:]
    141 		s.ticketAgeAdd = binary.BigEndian.Uint32(data)
    142 		data = data[4:]
    143 	}
    144 
    145 	earlyALPNLen := int(data[0])<<8 | int(data[1])
    146 	data = data[2:]
    147 	if len(data) < earlyALPNLen {
    148 		return false
    149 	}
    150 	s.earlyALPN = data[:earlyALPNLen]
    151 	data = data[earlyALPNLen:]
    152 
    153 	if len(data) > 0 {
    154 		return false
    155 	}
    156 
    157 	return true
    158 }
    159 
    160 func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) {
    161 	serialized := state.marshal()
    162 	encrypted := make([]byte, aes.BlockSize+len(serialized)+sha256.Size)
    163 	iv := encrypted[:aes.BlockSize]
    164 	macBytes := encrypted[len(encrypted)-sha256.Size:]
    165 
    166 	if _, err := io.ReadFull(c.config.rand(), iv); err != nil {
    167 		return nil, err
    168 	}
    169 	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
    170 	if err != nil {
    171 		return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error())
    172 	}
    173 	cipher.NewCTR(block, iv).XORKeyStream(encrypted[aes.BlockSize:], serialized)
    174 
    175 	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
    176 	mac.Write(encrypted[:len(encrypted)-sha256.Size])
    177 	mac.Sum(macBytes[:0])
    178 
    179 	return encrypted, nil
    180 }
    181 
    182 func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) {
    183 	if len(encrypted) < aes.BlockSize+sha256.Size {
    184 		return nil, false
    185 	}
    186 
    187 	iv := encrypted[:aes.BlockSize]
    188 	macBytes := encrypted[len(encrypted)-sha256.Size:]
    189 
    190 	mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32])
    191 	mac.Write(encrypted[:len(encrypted)-sha256.Size])
    192 	expected := mac.Sum(nil)
    193 
    194 	if subtle.ConstantTimeCompare(macBytes, expected) != 1 {
    195 		return nil, false
    196 	}
    197 
    198 	block, err := aes.NewCipher(c.config.SessionTicketKey[:16])
    199 	if err != nil {
    200 		return nil, false
    201 	}
    202 	ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size]
    203 	plaintext := make([]byte, len(ciphertext))
    204 	cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext)
    205 
    206 	state := new(sessionState)
    207 	ok := state.unmarshal(plaintext)
    208 	return state, ok
    209 }
    210