Home | History | Annotate | Download | only in runner
      1 // Copyright 2014 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // DTLS implementation.
      6 //
// NOTE: This is not even a remotely production-quality DTLS
      8 // implementation. It is the bare minimum necessary to be able to
      9 // achieve coverage on BoringSSL's implementation. Of note is that
     10 // this implementation assumes the underlying net.PacketConn is not
     11 // only reliable but also ordered. BoringSSL will be expected to deal
     12 // with simulated loss, but there is no point in forcing the test
     13 // driver to.
     14 
     15 package runner
     16 
     17 import (
     18 	"bytes"
     19 	"errors"
     20 	"fmt"
     21 	"io"
     22 	"math/rand"
     23 	"net"
     24 )
     25 
     26 func versionToWire(vers uint16, isDTLS bool) uint16 {
     27 	if isDTLS {
     28 		return ^(vers - 0x0201)
     29 	}
     30 	return vers
     31 }
     32 
     33 func wireToVersion(vers uint16, isDTLS bool) uint16 {
     34 	if isDTLS {
     35 		return ^vers + 0x0201
     36 	}
     37 	return vers
     38 }
     39 
     40 func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
     41 	recordHeaderLen := dtlsRecordHeaderLen
     42 
     43 	if c.rawInput == nil {
     44 		c.rawInput = c.in.newBlock()
     45 	}
     46 	b := c.rawInput
     47 
     48 	// Read a new packet only if the current one is empty.
     49 	if len(b.data) == 0 {
     50 		// Pick some absurdly large buffer size.
     51 		b.resize(maxCiphertext + recordHeaderLen)
     52 		n, err := c.conn.Read(c.rawInput.data)
     53 		if err != nil {
     54 			return 0, nil, err
     55 		}
     56 		if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength {
     57 			return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length")
     58 		}
     59 		c.rawInput.resize(n)
     60 	}
     61 
     62 	// Read out one record.
     63 	//
     64 	// A real DTLS implementation should be tolerant of errors,
     65 	// but this is test code. We should not be tolerant of our
     66 	// peer sending garbage.
     67 	if len(b.data) < recordHeaderLen {
     68 		return 0, nil, errors.New("dtls: failed to read record header")
     69 	}
     70 	typ := recordType(b.data[0])
     71 	vers := wireToVersion(uint16(b.data[1])<<8|uint16(b.data[2]), c.isDTLS)
     72 	if c.haveVers {
     73 		if vers != c.vers {
     74 			c.sendAlert(alertProtocolVersion)
     75 			return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.vers))
     76 		}
     77 	} else {
     78 		if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect {
     79 			c.sendAlert(alertProtocolVersion)
     80 			return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
     81 		}
     82 	}
     83 	epoch := b.data[3:5]
     84 	seq := b.data[5:11]
     85 	// For test purposes, require the sequence number be monotonically
     86 	// increasing, so c.in includes the minimum next sequence number. Gaps
     87 	// may occur if packets failed to be sent out. A real implementation
     88 	// would maintain a replay window and such.
     89 	if !bytes.Equal(epoch, c.in.seq[:2]) {
     90 		c.sendAlert(alertIllegalParameter)
     91 		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
     92 	}
     93 	if bytes.Compare(seq, c.in.seq[2:]) < 0 {
     94 		c.sendAlert(alertIllegalParameter)
     95 		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
     96 	}
     97 	copy(c.in.seq[2:], seq)
     98 	n := int(b.data[11])<<8 | int(b.data[12])
     99 	if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
    100 		c.sendAlert(alertRecordOverflow)
    101 		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
    102 	}
    103 
    104 	// Process message.
    105 	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
    106 	ok, off, err := c.in.decrypt(b)
    107 	if !ok {
    108 		c.in.setErrorLocked(c.sendAlert(err))
    109 	}
    110 	b.off = off
    111 	return typ, b, nil
    112 }
    113 
    114 func (c *Conn) makeFragment(header, data []byte, fragOffset, fragLen int) []byte {
    115 	fragment := make([]byte, 0, 12+fragLen)
    116 	fragment = append(fragment, header...)
    117 	fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq))
    118 	fragment = append(fragment, byte(fragOffset>>16), byte(fragOffset>>8), byte(fragOffset))
    119 	fragment = append(fragment, byte(fragLen>>16), byte(fragLen>>8), byte(fragLen))
    120 	fragment = append(fragment, data[fragOffset:fragOffset+fragLen]...)
    121 	return fragment
    122 }
    123 
    124 func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) {
    125 	if typ != recordTypeHandshake {
    126 		// Only handshake messages are fragmented.
    127 		return c.dtlsWriteRawRecord(typ, data)
    128 	}
    129 
    130 	maxLen := c.config.Bugs.MaxHandshakeRecordLength
    131 	if maxLen <= 0 {
    132 		maxLen = 1024
    133 	}
    134 
    135 	// Handshake messages have to be modified to include fragment
    136 	// offset and length and with the header replicated. Save the
    137 	// TLS header here.
    138 	//
    139 	// TODO(davidben): This assumes that data contains exactly one
    140 	// handshake message. This is incompatible with
    141 	// FragmentAcrossChangeCipherSpec. (Which is unfortunate
    142 	// because OpenSSL's DTLS implementation will probably accept
    143 	// such fragmentation and could do with a fix + tests.)
    144 	header := data[:4]
    145 	data = data[4:]
    146 
    147 	isFinished := header[0] == typeFinished
    148 
    149 	if c.config.Bugs.SendEmptyFragments {
    150 		fragment := c.makeFragment(header, data, 0, 0)
    151 		c.pendingFragments = append(c.pendingFragments, fragment)
    152 	}
    153 
    154 	firstRun := true
    155 	fragOffset := 0
    156 	for firstRun || fragOffset < len(data) {
    157 		firstRun = false
    158 		fragLen := len(data) - fragOffset
    159 		if fragLen > maxLen {
    160 			fragLen = maxLen
    161 		}
    162 
    163 		fragment := c.makeFragment(header, data, fragOffset, fragLen)
    164 		if c.config.Bugs.FragmentMessageTypeMismatch && fragOffset > 0 {
    165 			fragment[0]++
    166 		}
    167 		if c.config.Bugs.FragmentMessageLengthMismatch && fragOffset > 0 {
    168 			fragment[3]++
    169 		}
    170 
    171 		// Buffer the fragment for later. They will be sent (and
    172 		// reordered) on flush.
    173 		c.pendingFragments = append(c.pendingFragments, fragment)
    174 		if c.config.Bugs.ReorderHandshakeFragments {
    175 			// Don't duplicate Finished to avoid the peer
    176 			// interpreting it as a retransmit request.
    177 			if !isFinished {
    178 				c.pendingFragments = append(c.pendingFragments, fragment)
    179 			}
    180 
    181 			if fragLen > (maxLen+1)/2 {
    182 				// Overlap each fragment by half.
    183 				fragLen = (maxLen + 1) / 2
    184 			}
    185 		}
    186 		fragOffset += fragLen
    187 		n += fragLen
    188 	}
    189 	if !isFinished && c.config.Bugs.MixCompleteMessageWithFragments {
    190 		fragment := c.makeFragment(header, data, 0, len(data))
    191 		c.pendingFragments = append(c.pendingFragments, fragment)
    192 	}
    193 
    194 	// Increment the handshake sequence number for the next
    195 	// handshake message.
    196 	c.sendHandshakeSeq++
    197 	return
    198 }
    199 
    200 func (c *Conn) dtlsFlushHandshake() error {
    201 	if !c.isDTLS {
    202 		return nil
    203 	}
    204 
    205 	// This is a test-only DTLS implementation, so there is no need to
    206 	// retain |c.pendingFragments| for a future retransmit.
    207 	var fragments [][]byte
    208 	fragments, c.pendingFragments = c.pendingFragments, fragments
    209 
    210 	if c.config.Bugs.ReorderHandshakeFragments {
    211 		perm := rand.New(rand.NewSource(0)).Perm(len(fragments))
    212 		tmp := make([][]byte, len(fragments))
    213 		for i := range tmp {
    214 			tmp[i] = fragments[perm[i]]
    215 		}
    216 		fragments = tmp
    217 	}
    218 
    219 	maxRecordLen := c.config.Bugs.PackHandshakeFragments
    220 	maxPacketLen := c.config.Bugs.PackHandshakeRecords
    221 
    222 	// Pack handshake fragments into records.
    223 	var records [][]byte
    224 	for _, fragment := range fragments {
    225 		if n := c.config.Bugs.SplitFragments; n > 0 {
    226 			if len(fragment) > n {
    227 				records = append(records, fragment[:n])
    228 				records = append(records, fragment[n:])
    229 			} else {
    230 				records = append(records, fragment)
    231 			}
    232 		} else if i := len(records) - 1; len(records) > 0 && len(records[i])+len(fragment) <= maxRecordLen {
    233 			records[i] = append(records[i], fragment...)
    234 		} else {
    235 			// The fragment will be appended to, so copy it.
    236 			records = append(records, append([]byte{}, fragment...))
    237 		}
    238 	}
    239 
    240 	// Format them into packets.
    241 	var packets [][]byte
    242 	for _, record := range records {
    243 		b, err := c.dtlsSealRecord(recordTypeHandshake, record)
    244 		if err != nil {
    245 			return err
    246 		}
    247 
    248 		if i := len(packets) - 1; len(packets) > 0 && len(packets[i])+len(b.data) <= maxPacketLen {
    249 			packets[i] = append(packets[i], b.data...)
    250 		} else {
    251 			// The sealed record will be appended to and reused by
    252 			// |c.out|, so copy it.
    253 			packets = append(packets, append([]byte{}, b.data...))
    254 		}
    255 		c.out.freeBlock(b)
    256 	}
    257 
    258 	// Send all the packets.
    259 	for _, packet := range packets {
    260 		if _, err := c.conn.Write(packet); err != nil {
    261 			return err
    262 		}
    263 	}
    264 	return nil
    265 }
    266 
    267 // dtlsSealRecord seals a record into a block from |c.out|'s pool.
    268 func (c *Conn) dtlsSealRecord(typ recordType, data []byte) (b *block, err error) {
    269 	recordHeaderLen := dtlsRecordHeaderLen
    270 	maxLen := c.config.Bugs.MaxHandshakeRecordLength
    271 	if maxLen <= 0 {
    272 		maxLen = 1024
    273 	}
    274 
    275 	b = c.out.newBlock()
    276 
    277 	explicitIVLen := 0
    278 	explicitIVIsSeq := false
    279 
    280 	if cbc, ok := c.out.cipher.(cbcMode); ok {
    281 		// Block cipher modes have an explicit IV.
    282 		explicitIVLen = cbc.BlockSize()
    283 	} else if aead, ok := c.out.cipher.(*tlsAead); ok {
    284 		if aead.explicitNonce {
    285 			explicitIVLen = 8
    286 			// The AES-GCM construction in TLS has an explicit nonce so that
    287 			// the nonce can be random. However, the nonce is only 8 bytes
    288 			// which is too small for a secure, random nonce. Therefore we
    289 			// use the sequence number as the nonce.
    290 			explicitIVIsSeq = true
    291 		}
    292 	} else if c.out.cipher != nil {
    293 		panic("Unknown cipher")
    294 	}
    295 	b.resize(recordHeaderLen + explicitIVLen + len(data))
    296 	b.data[0] = byte(typ)
    297 	vers := c.vers
    298 	if vers == 0 {
    299 		// Some TLS servers fail if the record version is greater than
    300 		// TLS 1.0 for the initial ClientHello.
    301 		vers = VersionTLS10
    302 	}
    303 	vers = versionToWire(vers, c.isDTLS)
    304 	b.data[1] = byte(vers >> 8)
    305 	b.data[2] = byte(vers)
    306 	// DTLS records include an explicit sequence number.
    307 	copy(b.data[3:11], c.out.outSeq[0:])
    308 	b.data[11] = byte(len(data) >> 8)
    309 	b.data[12] = byte(len(data))
    310 	if explicitIVLen > 0 {
    311 		explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
    312 		if explicitIVIsSeq {
    313 			copy(explicitIV, c.out.outSeq[:])
    314 		} else {
    315 			if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
    316 				return
    317 			}
    318 		}
    319 	}
    320 	copy(b.data[recordHeaderLen+explicitIVLen:], data)
    321 	c.out.encrypt(b, explicitIVLen)
    322 	return
    323 }
    324 
    325 func (c *Conn) dtlsWriteRawRecord(typ recordType, data []byte) (n int, err error) {
    326 	b, err := c.dtlsSealRecord(typ, data)
    327 	if err != nil {
    328 		return
    329 	}
    330 
    331 	_, err = c.conn.Write(b.data)
    332 	if err != nil {
    333 		return
    334 	}
    335 	n = len(data)
    336 
    337 	c.out.freeBlock(b)
    338 
    339 	if typ == recordTypeChangeCipherSpec {
    340 		err = c.out.changeCipherSpec(c.config)
    341 		if err != nil {
    342 			// Cannot call sendAlert directly,
    343 			// because we already hold c.out.Mutex.
    344 			c.tmp[0] = alertLevelError
    345 			c.tmp[1] = byte(err.(alert))
    346 			c.writeRecord(recordTypeAlert, c.tmp[0:2])
    347 			return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
    348 		}
    349 	}
    350 	return
    351 }
    352 
    353 func (c *Conn) dtlsDoReadHandshake() ([]byte, error) {
    354 	// Assemble a full handshake message.  For test purposes, this
    355 	// implementation assumes fragments arrive in order. It may
    356 	// need to be cleverer if we ever test BoringSSL's retransmit
    357 	// behavior.
    358 	for len(c.handMsg) < 4+c.handMsgLen {
    359 		// Get a new handshake record if the previous has been
    360 		// exhausted.
    361 		if c.hand.Len() == 0 {
    362 			if err := c.in.err; err != nil {
    363 				return nil, err
    364 			}
    365 			if err := c.readRecord(recordTypeHandshake); err != nil {
    366 				return nil, err
    367 			}
    368 		}
    369 
    370 		// Read the next fragment. It must fit entirely within
    371 		// the record.
    372 		if c.hand.Len() < 12 {
    373 			return nil, errors.New("dtls: bad handshake record")
    374 		}
    375 		header := c.hand.Next(12)
    376 		fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3])
    377 		fragSeq := uint16(header[4])<<8 | uint16(header[5])
    378 		fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8])
    379 		fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11])
    380 
    381 		if c.hand.Len() < fragLen {
    382 			return nil, errors.New("dtls: fragment length too long")
    383 		}
    384 		fragment := c.hand.Next(fragLen)
    385 
    386 		// Check it's a fragment for the right message.
    387 		if fragSeq != c.recvHandshakeSeq {
    388 			return nil, errors.New("dtls: bad handshake sequence number")
    389 		}
    390 
    391 		// Check that the length is consistent.
    392 		if c.handMsg == nil {
    393 			c.handMsgLen = fragN
    394 			if c.handMsgLen > maxHandshake {
    395 				return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
    396 			}
    397 			// Start with the TLS handshake header,
    398 			// without the DTLS bits.
    399 			c.handMsg = append([]byte{}, header[:4]...)
    400 		} else if fragN != c.handMsgLen {
    401 			return nil, errors.New("dtls: bad handshake length")
    402 		}
    403 
    404 		// Add the fragment to the pending message.
    405 		if 4+fragOff != len(c.handMsg) {
    406 			return nil, errors.New("dtls: bad fragment offset")
    407 		}
    408 		if fragOff+fragLen > c.handMsgLen {
    409 			return nil, errors.New("dtls: bad fragment length")
    410 		}
    411 		c.handMsg = append(c.handMsg, fragment...)
    412 	}
    413 	c.recvHandshakeSeq++
    414 	ret := c.handMsg
    415 	c.handMsg, c.handMsgLen = nil, 0
    416 	return ret, nil
    417 }
    418 
    419 // DTLSServer returns a new DTLS server side connection
    420 // using conn as the underlying transport.
    421 // The configuration config must be non-nil and must have
    422 // at least one certificate.
    423 func DTLSServer(conn net.Conn, config *Config) *Conn {
    424 	c := &Conn{config: config, isDTLS: true, conn: conn}
    425 	c.init()
    426 	return c
    427 }
    428 
    429 // DTLSClient returns a new DTLS client side connection
    430 // using conn as the underlying transport.
    431 // The config cannot be nil: users must set either ServerHostname or
    432 // InsecureSkipVerify in the config.
    433 func DTLSClient(conn net.Conn, config *Config) *Conn {
    434 	c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn}
    435 	c.init()
    436 	return c
    437 }
    438