1 // Copyright (c) 2016, Google Inc. 2 // 3 // Permission to use, copy, modify, and/or distribute this software for any 4 // purpose with or without fee is hereby granted, provided that the above 5 // copyright notice and this permission notice appear in all copies. 6 // 7 // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY 10 // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION 12 // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN 13 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 15 package runner 16 17 import ( 18 "bytes" 19 "crypto/aes" 20 "crypto/cipher" 21 "crypto/hmac" 22 "crypto/sha256" 23 "encoding/asn1" 24 "errors" 25 ) 26 27 // TestShimTicketKey is the testing key assumed for the shim. 28 var TestShimTicketKey = make([]byte, 48) 29 30 func DecryptShimTicket(in []byte) ([]byte, error) { 31 name := TestShimTicketKey[:16] 32 macKey := TestShimTicketKey[16:32] 33 encKey := TestShimTicketKey[32:48] 34 35 h := hmac.New(sha256.New, macKey) 36 37 block, err := aes.NewCipher(encKey) 38 if err != nil { 39 panic(err) 40 } 41 42 if len(in) < len(name)+block.BlockSize()+1+h.Size() { 43 return nil, errors.New("tls: shim ticket too short") 44 } 45 46 // Check the key name. 47 if !bytes.Equal(name, in[:len(name)]) { 48 return nil, errors.New("tls: shim ticket name mismatch") 49 } 50 51 // Check the MAC at the end of the ticket. 52 mac := in[len(in)-h.Size():] 53 in = in[:len(in)-h.Size()] 54 h.Write(in) 55 if !hmac.Equal(mac, h.Sum(nil)) { 56 return nil, errors.New("tls: shim ticket MAC mismatch") 57 } 58 59 // The MAC covers the key name, but the encryption does not. 60 in = in[len(name):] 61 62 // Decrypt in-place. 63 iv := in[:block.BlockSize()] 64 in = in[block.BlockSize():] 65 if l := len(in); l == 0 || l%block.BlockSize() != 0 { 66 return nil, errors.New("tls: ticket ciphertext not a multiple of the block size") 67 } 68 out := make([]byte, len(in)) 69 cbc := cipher.NewCBCDecrypter(block, iv) 70 cbc.CryptBlocks(out, in) 71 72 // Remove the padding. 73 pad := int(out[len(out)-1]) 74 if pad == 0 || pad > block.BlockSize() || pad > len(in) { 75 return nil, errors.New("tls: bad shim ticket CBC pad") 76 } 77 78 for i := 0; i < pad; i++ { 79 if out[len(out)-1-i] != byte(pad) { 80 return nil, errors.New("tls: bad shim ticket CBC pad") 81 } 82 } 83 84 return out[:len(out)-pad], nil 85 } 86 87 func EncryptShimTicket(in []byte) []byte { 88 name := TestShimTicketKey[:16] 89 macKey := TestShimTicketKey[16:32] 90 encKey := TestShimTicketKey[32:48] 91 92 h := hmac.New(sha256.New, macKey) 93 94 block, err := aes.NewCipher(encKey) 95 if err != nil { 96 panic(err) 97 } 98 99 // Use the zero IV for rewritten tickets. 100 iv := make([]byte, block.BlockSize()) 101 cbc := cipher.NewCBCEncrypter(block, iv) 102 pad := block.BlockSize() - (len(in) % block.BlockSize()) 103 104 out := make([]byte, 0, len(name)+len(iv)+len(in)+pad+h.Size()) 105 out = append(out, name...) 106 out = append(out, iv...) 107 out = append(out, in...) 108 for i := 0; i < pad; i++ { 109 out = append(out, byte(pad)) 110 } 111 112 ciphertext := out[len(name)+len(iv):] 113 cbc.CryptBlocks(ciphertext, ciphertext) 114 115 h.Write(out) 116 return h.Sum(out) 117 } 118 119 const asn1Constructed = 0x20 120 121 func parseDERElement(in []byte) (tag byte, body, rest []byte, ok bool) { 122 rest = in 123 if len(rest) < 1 { 124 return 125 } 126 127 tag = rest[0] 128 rest = rest[1:] 129 130 if tag&0x1f == 0x1f { 131 // Long-form tags not supported. 132 return 133 } 134 135 if len(rest) < 1 { 136 return 137 } 138 139 length := int(rest[0]) 140 rest = rest[1:] 141 if length > 0x7f { 142 lengthLength := length & 0x7f 143 length = 0 144 if lengthLength == 0 { 145 // No indefinite-length encoding. 146 return 147 } 148 149 // Decode long-form lengths. 150 for lengthLength > 0 { 151 if len(rest) < 1 || (length<<8)>>8 != length { 152 return 153 } 154 if length == 0 && rest[0] == 0 { 155 // Length not minimally-encoded. 156 return 157 } 158 length <<= 8 159 length |= int(rest[0]) 160 rest = rest[1:] 161 lengthLength-- 162 } 163 164 if length < 0x80 { 165 // Length not minimally-encoded. 166 return 167 } 168 } 169 170 if len(rest) < length { 171 return 172 } 173 174 body = rest[:length] 175 rest = rest[length:] 176 ok = true 177 return 178 } 179 180 func SetShimTicketVersion(in []byte, vers uint16) ([]byte, error) { 181 plaintext, err := DecryptShimTicket(in) 182 if err != nil { 183 return nil, err 184 } 185 186 tag, session, _, ok := parseDERElement(plaintext) 187 if !ok || tag != asn1.TagSequence|asn1Constructed { 188 return nil, errors.New("tls: could not decode shim session") 189 } 190 191 // Skip the session version. 192 tag, _, session, ok = parseDERElement(session) 193 if !ok || tag != asn1.TagInteger { 194 return nil, errors.New("tls: could not decode shim session") 195 } 196 197 // Next field is the protocol version. 198 tag, version, _, ok := parseDERElement(session) 199 if !ok || tag != asn1.TagInteger { 200 return nil, errors.New("tls: could not decode shim session") 201 } 202 203 // This code assumes both old and new versions are encoded in two 204 // bytes. This isn't quite right as INTEGERs are minimally-encoded, but 205 // we do not need to support other caess for now. 206 if len(version) != 2 || vers < 0x80 || vers >= 0x8000 { 207 return nil, errors.New("tls: unsupported version in shim session") 208 } 209 210 version[0] = byte(vers >> 8) 211 version[1] = byte(vers) 212 213 return EncryptShimTicket(plaintext), nil 214 } 215 216 func SetShimTicketCipherSuite(in []byte, id uint16) ([]byte, error) { 217 plaintext, err := DecryptShimTicket(in) 218 if err != nil { 219 return nil, err 220 } 221 222 tag, session, _, ok := parseDERElement(plaintext) 223 if !ok || tag != asn1.TagSequence|asn1Constructed { 224 return nil, errors.New("tls: could not decode shim session") 225 } 226 227 // Skip the session version. 228 tag, _, session, ok = parseDERElement(session) 229 if !ok || tag != asn1.TagInteger { 230 return nil, errors.New("tls: could not decode shim session") 231 } 232 233 // Skip the protocol version. 234 tag, _, session, ok = parseDERElement(session) 235 if !ok || tag != asn1.TagInteger { 236 return nil, errors.New("tls: could not decode shim session") 237 } 238 239 // Next field is the cipher suite. 240 tag, cipherSuite, _, ok := parseDERElement(session) 241 if !ok || tag != asn1.TagOctetString || len(cipherSuite) != 2 { 242 return nil, errors.New("tls: could not decode shim session") 243 } 244 245 cipherSuite[0] = byte(id >> 8) 246 cipherSuite[1] = byte(id) 247 248 return EncryptShimTicket(plaintext), nil 249 } 250