Home | History | Annotate | Download | only in runner
      1 // Copyright (c) 2016, Google Inc.
      2 //
      3 // Permission to use, copy, modify, and/or distribute this software for any
      4 // purpose with or without fee is hereby granted, provided that the above
      5 // copyright notice and this permission notice appear in all copies.
      6 //
      7 // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
      8 // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
      9 // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
     10 // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
     11 // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
     12 // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
     13 // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
     14 
     15 package runner
     16 
     17 import (
     18 	"bytes"
     19 	"crypto/aes"
     20 	"crypto/cipher"
     21 	"crypto/hmac"
     22 	"crypto/sha256"
     23 	"encoding/asn1"
     24 	"errors"
     25 )
     26 
     27 // TestShimTicketKey is the testing key assumed for the shim.
     28 var TestShimTicketKey = make([]byte, 48)
     29 
     30 func DecryptShimTicket(in []byte) ([]byte, error) {
     31 	name := TestShimTicketKey[:16]
     32 	macKey := TestShimTicketKey[16:32]
     33 	encKey := TestShimTicketKey[32:48]
     34 
     35 	h := hmac.New(sha256.New, macKey)
     36 
     37 	block, err := aes.NewCipher(encKey)
     38 	if err != nil {
     39 		panic(err)
     40 	}
     41 
     42 	if len(in) < len(name)+block.BlockSize()+1+h.Size() {
     43 		return nil, errors.New("tls: shim ticket too short")
     44 	}
     45 
     46 	// Check the key name.
     47 	if !bytes.Equal(name, in[:len(name)]) {
     48 		return nil, errors.New("tls: shim ticket name mismatch")
     49 	}
     50 
     51 	// Check the MAC at the end of the ticket.
     52 	mac := in[len(in)-h.Size():]
     53 	in = in[:len(in)-h.Size()]
     54 	h.Write(in)
     55 	if !hmac.Equal(mac, h.Sum(nil)) {
     56 		return nil, errors.New("tls: shim ticket MAC mismatch")
     57 	}
     58 
     59 	// The MAC covers the key name, but the encryption does not.
     60 	in = in[len(name):]
     61 
     62 	// Decrypt in-place.
     63 	iv := in[:block.BlockSize()]
     64 	in = in[block.BlockSize():]
     65 	if l := len(in); l == 0 || l%block.BlockSize() != 0 {
     66 		return nil, errors.New("tls: ticket ciphertext not a multiple of the block size")
     67 	}
     68 	out := make([]byte, len(in))
     69 	cbc := cipher.NewCBCDecrypter(block, iv)
     70 	cbc.CryptBlocks(out, in)
     71 
     72 	// Remove the padding.
     73 	pad := int(out[len(out)-1])
     74 	if pad == 0 || pad > block.BlockSize() || pad > len(in) {
     75 		return nil, errors.New("tls: bad shim ticket CBC pad")
     76 	}
     77 
     78 	for i := 0; i < pad; i++ {
     79 		if out[len(out)-1-i] != byte(pad) {
     80 			return nil, errors.New("tls: bad shim ticket CBC pad")
     81 		}
     82 	}
     83 
     84 	return out[:len(out)-pad], nil
     85 }
     86 
     87 func EncryptShimTicket(in []byte) []byte {
     88 	name := TestShimTicketKey[:16]
     89 	macKey := TestShimTicketKey[16:32]
     90 	encKey := TestShimTicketKey[32:48]
     91 
     92 	h := hmac.New(sha256.New, macKey)
     93 
     94 	block, err := aes.NewCipher(encKey)
     95 	if err != nil {
     96 		panic(err)
     97 	}
     98 
     99 	// Use the zero IV for rewritten tickets.
    100 	iv := make([]byte, block.BlockSize())
    101 	cbc := cipher.NewCBCEncrypter(block, iv)
    102 	pad := block.BlockSize() - (len(in) % block.BlockSize())
    103 
    104 	out := make([]byte, 0, len(name)+len(iv)+len(in)+pad+h.Size())
    105 	out = append(out, name...)
    106 	out = append(out, iv...)
    107 	out = append(out, in...)
    108 	for i := 0; i < pad; i++ {
    109 		out = append(out, byte(pad))
    110 	}
    111 
    112 	ciphertext := out[len(name)+len(iv):]
    113 	cbc.CryptBlocks(ciphertext, ciphertext)
    114 
    115 	h.Write(out)
    116 	return h.Sum(out)
    117 }
    118 
    119 const asn1Constructed = 0x20
    120 
    121 func parseDERElement(in []byte) (tag byte, body, rest []byte, ok bool) {
    122 	rest = in
    123 	if len(rest) < 1 {
    124 		return
    125 	}
    126 
    127 	tag = rest[0]
    128 	rest = rest[1:]
    129 
    130 	if tag&0x1f == 0x1f {
    131 		// Long-form tags not supported.
    132 		return
    133 	}
    134 
    135 	if len(rest) < 1 {
    136 		return
    137 	}
    138 
    139 	length := int(rest[0])
    140 	rest = rest[1:]
    141 	if length > 0x7f {
    142 		lengthLength := length & 0x7f
    143 		length = 0
    144 		if lengthLength == 0 {
    145 			// No indefinite-length encoding.
    146 			return
    147 		}
    148 
    149 		// Decode long-form lengths.
    150 		for lengthLength > 0 {
    151 			if len(rest) < 1 || (length<<8)>>8 != length {
    152 				return
    153 			}
    154 			if length == 0 && rest[0] == 0 {
    155 				// Length not minimally-encoded.
    156 				return
    157 			}
    158 			length <<= 8
    159 			length |= int(rest[0])
    160 			rest = rest[1:]
    161 			lengthLength--
    162 		}
    163 
    164 		if length < 0x80 {
    165 			// Length not minimally-encoded.
    166 			return
    167 		}
    168 	}
    169 
    170 	if len(rest) < length {
    171 		return
    172 	}
    173 
    174 	body = rest[:length]
    175 	rest = rest[length:]
    176 	ok = true
    177 	return
    178 }
    179 
    180 func SetShimTicketVersion(in []byte, vers uint16) ([]byte, error) {
    181 	plaintext, err := DecryptShimTicket(in)
    182 	if err != nil {
    183 		return nil, err
    184 	}
    185 
    186 	tag, session, _, ok := parseDERElement(plaintext)
    187 	if !ok || tag != asn1.TagSequence|asn1Constructed {
    188 		return nil, errors.New("tls: could not decode shim session")
    189 	}
    190 
    191 	// Skip the session version.
    192 	tag, _, session, ok = parseDERElement(session)
    193 	if !ok || tag != asn1.TagInteger {
    194 		return nil, errors.New("tls: could not decode shim session")
    195 	}
    196 
    197 	// Next field is the protocol version.
    198 	tag, version, _, ok := parseDERElement(session)
    199 	if !ok || tag != asn1.TagInteger {
    200 		return nil, errors.New("tls: could not decode shim session")
    201 	}
    202 
    203 	// This code assumes both old and new versions are encoded in two
    204 	// bytes. This isn't quite right as INTEGERs are minimally-encoded, but
    205 	// we do not need to support other caess for now.
    206 	if len(version) != 2 || vers < 0x80 || vers >= 0x8000 {
    207 		return nil, errors.New("tls: unsupported version in shim session")
    208 	}
    209 
    210 	version[0] = byte(vers >> 8)
    211 	version[1] = byte(vers)
    212 
    213 	return EncryptShimTicket(plaintext), nil
    214 }
    215 
    216 func SetShimTicketCipherSuite(in []byte, id uint16) ([]byte, error) {
    217 	plaintext, err := DecryptShimTicket(in)
    218 	if err != nil {
    219 		return nil, err
    220 	}
    221 
    222 	tag, session, _, ok := parseDERElement(plaintext)
    223 	if !ok || tag != asn1.TagSequence|asn1Constructed {
    224 		return nil, errors.New("tls: could not decode shim session")
    225 	}
    226 
    227 	// Skip the session version.
    228 	tag, _, session, ok = parseDERElement(session)
    229 	if !ok || tag != asn1.TagInteger {
    230 		return nil, errors.New("tls: could not decode shim session")
    231 	}
    232 
    233 	// Skip the protocol version.
    234 	tag, _, session, ok = parseDERElement(session)
    235 	if !ok || tag != asn1.TagInteger {
    236 		return nil, errors.New("tls: could not decode shim session")
    237 	}
    238 
    239 	// Next field is the cipher suite.
    240 	tag, cipherSuite, _, ok := parseDERElement(session)
    241 	if !ok || tag != asn1.TagOctetString || len(cipherSuite) != 2 {
    242 		return nil, errors.New("tls: could not decode shim session")
    243 	}
    244 
    245 	cipherSuite[0] = byte(id >> 8)
    246 	cipherSuite[1] = byte(id)
    247 
    248 	return EncryptShimTicket(plaintext), nil
    249 }
    250