1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package main 6 7 import ( 8 "bytes" 9 "crypto/aes" 10 "crypto/cipher" 11 "crypto/hmac" 12 "crypto/sha256" 13 "crypto/subtle" 14 "errors" 15 "io" 16 ) 17 18 // sessionState contains the information that is serialized into a session 19 // ticket in order to later resume a connection. 20 type sessionState struct { 21 vers uint16 22 cipherSuite uint16 23 masterSecret []byte 24 handshakeHash []byte 25 certificates [][]byte 26 extendedMasterSecret bool 27 } 28 29 func (s *sessionState) equal(i interface{}) bool { 30 s1, ok := i.(*sessionState) 31 if !ok { 32 return false 33 } 34 35 if s.vers != s1.vers || 36 s.cipherSuite != s1.cipherSuite || 37 !bytes.Equal(s.masterSecret, s1.masterSecret) || 38 !bytes.Equal(s.handshakeHash, s1.handshakeHash) || 39 s.extendedMasterSecret != s1.extendedMasterSecret { 40 return false 41 } 42 43 if len(s.certificates) != len(s1.certificates) { 44 return false 45 } 46 47 for i := range s.certificates { 48 if !bytes.Equal(s.certificates[i], s1.certificates[i]) { 49 return false 50 } 51 } 52 53 return true 54 } 55 56 func (s *sessionState) marshal() []byte { 57 length := 2 + 2 + 2 + len(s.masterSecret) + 2 + len(s.handshakeHash) + 2 58 for _, cert := range s.certificates { 59 length += 4 + len(cert) 60 } 61 length++ 62 63 ret := make([]byte, length) 64 x := ret 65 x[0] = byte(s.vers >> 8) 66 x[1] = byte(s.vers) 67 x[2] = byte(s.cipherSuite >> 8) 68 x[3] = byte(s.cipherSuite) 69 x[4] = byte(len(s.masterSecret) >> 8) 70 x[5] = byte(len(s.masterSecret)) 71 x = x[6:] 72 copy(x, s.masterSecret) 73 x = x[len(s.masterSecret):] 74 75 x[0] = byte(len(s.handshakeHash) >> 8) 76 x[1] = byte(len(s.handshakeHash)) 77 x = x[2:] 78 copy(x, s.handshakeHash) 79 x = x[len(s.handshakeHash):] 80 81 x[0] = byte(len(s.certificates) >> 8) 82 x[1] = byte(len(s.certificates)) 83 x = x[2:] 84 85 for _, cert := range s.certificates { 86 x[0] = byte(len(cert) >> 24) 87 x[1] = byte(len(cert) >> 16) 88 x[2] = byte(len(cert) >> 8) 89 x[3] = byte(len(cert)) 90 copy(x[4:], cert) 91 x = x[4+len(cert):] 92 } 93 94 if s.extendedMasterSecret { 95 x[0] = 1 96 } 97 x = x[1:] 98 99 return ret 100 } 101 102 func (s *sessionState) unmarshal(data []byte) bool { 103 if len(data) < 8 { 104 return false 105 } 106 107 s.vers = uint16(data[0])<<8 | uint16(data[1]) 108 s.cipherSuite = uint16(data[2])<<8 | uint16(data[3]) 109 masterSecretLen := int(data[4])<<8 | int(data[5]) 110 data = data[6:] 111 if len(data) < masterSecretLen { 112 return false 113 } 114 115 s.masterSecret = data[:masterSecretLen] 116 data = data[masterSecretLen:] 117 118 if len(data) < 2 { 119 return false 120 } 121 122 handshakeHashLen := int(data[0])<<8 | int(data[1]) 123 data = data[2:] 124 if len(data) < handshakeHashLen { 125 return false 126 } 127 128 s.handshakeHash = data[:handshakeHashLen] 129 data = data[handshakeHashLen:] 130 131 if len(data) < 2 { 132 return false 133 } 134 135 numCerts := int(data[0])<<8 | int(data[1]) 136 data = data[2:] 137 138 s.certificates = make([][]byte, numCerts) 139 for i := range s.certificates { 140 if len(data) < 4 { 141 return false 142 } 143 certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) 144 data = data[4:] 145 if certLen < 0 { 146 return false 147 } 148 if len(data) < certLen { 149 return false 150 } 151 s.certificates[i] = data[:certLen] 152 data = data[certLen:] 153 } 154 155 if len(data) < 1 { 156 return false 157 } 158 159 s.extendedMasterSecret = false 160 if data[0] == 1 { 161 s.extendedMasterSecret = true 162 } 163 data = data[1:] 164 165 if len(data) > 0 { 166 return false 167 } 168 169 return true 170 } 171 172 func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) { 173 serialized := state.marshal() 174 encrypted := make([]byte, aes.BlockSize+len(serialized)+sha256.Size) 175 iv := encrypted[:aes.BlockSize] 176 macBytes := encrypted[len(encrypted)-sha256.Size:] 177 178 if _, err := io.ReadFull(c.config.rand(), iv); err != nil { 179 return nil, err 180 } 181 block, err := aes.NewCipher(c.config.SessionTicketKey[:16]) 182 if err != nil { 183 return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) 184 } 185 cipher.NewCTR(block, iv).XORKeyStream(encrypted[aes.BlockSize:], serialized) 186 187 mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32]) 188 mac.Write(encrypted[:len(encrypted)-sha256.Size]) 189 mac.Sum(macBytes[:0]) 190 191 return encrypted, nil 192 } 193 194 func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) { 195 if len(encrypted) < aes.BlockSize+sha256.Size { 196 return nil, false 197 } 198 199 iv := encrypted[:aes.BlockSize] 200 macBytes := encrypted[len(encrypted)-sha256.Size:] 201 202 mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32]) 203 mac.Write(encrypted[:len(encrypted)-sha256.Size]) 204 expected := mac.Sum(nil) 205 206 if subtle.ConstantTimeCompare(macBytes, expected) != 1 { 207 return nil, false 208 } 209 210 block, err := aes.NewCipher(c.config.SessionTicketKey[:16]) 211 if err != nil { 212 return nil, false 213 } 214 ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size] 215 plaintext := make([]byte, len(ciphertext)) 216 cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) 217 218 state := new(sessionState) 219 ok := state.unmarshal(plaintext) 220 return state, ok 221 } 222