Home | History | Annotate | Download | only in runner
      1 // Copyright 2014 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // DTLS implementation.
      6 //
      7 // NOTE: This is not even a remotely production-quality DTLS
      8 // implementation. It is the bare minimum necessary to be able to
      9 // achieve coverage on BoringSSL's implementation. Of note is that
     10 // this implementation assumes the underlying net.PacketConn is not
     11 // only reliable but also ordered. BoringSSL will be expected to deal
     12 // with simulated loss, but there is no point in forcing the test
     13 // driver to.
     14 
     15 package runner
     16 
     17 import (
     18 	"bytes"
     19 	"errors"
     20 	"fmt"
     21 	"io"
     22 	"math/rand"
     23 	"net"
     24 )
     25 
     26 func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
     27 	recordHeaderLen := dtlsRecordHeaderLen
     28 
     29 	if c.rawInput == nil {
     30 		c.rawInput = c.in.newBlock()
     31 	}
     32 	b := c.rawInput
     33 
     34 	// Read a new packet only if the current one is empty.
     35 	var newPacket bool
     36 	if len(b.data) == 0 {
     37 		// Pick some absurdly large buffer size.
     38 		b.resize(maxCiphertext + recordHeaderLen)
     39 		n, err := c.conn.Read(c.rawInput.data)
     40 		if err != nil {
     41 			return 0, nil, err
     42 		}
     43 		if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength {
     44 			return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length")
     45 		}
     46 		c.rawInput.resize(n)
     47 		newPacket = true
     48 	}
     49 
     50 	// Read out one record.
     51 	//
     52 	// A real DTLS implementation should be tolerant of errors,
     53 	// but this is test code. We should not be tolerant of our
     54 	// peer sending garbage.
     55 	if len(b.data) < recordHeaderLen {
     56 		return 0, nil, errors.New("dtls: failed to read record header")
     57 	}
     58 	typ := recordType(b.data[0])
     59 	vers := uint16(b.data[1])<<8 | uint16(b.data[2])
     60 	// Alerts sent near version negotiation do not have a well-defined
     61 	// record-layer version prior to TLS 1.3. (In TLS 1.3, the record-layer
     62 	// version is irrelevant.)
     63 	if typ != recordTypeAlert {
     64 		if c.haveVers {
     65 			if vers != c.wireVersion {
     66 				c.sendAlert(alertProtocolVersion)
     67 				return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.wireVersion))
     68 			}
     69 		} else {
     70 			// Pre-version-negotiation alerts may be sent with any version.
     71 			if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect {
     72 				c.sendAlert(alertProtocolVersion)
     73 				return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
     74 			}
     75 		}
     76 	}
     77 	epoch := b.data[3:5]
     78 	seq := b.data[5:11]
     79 	// For test purposes, require the sequence number be monotonically
     80 	// increasing, so c.in includes the minimum next sequence number. Gaps
     81 	// may occur if packets failed to be sent out. A real implementation
     82 	// would maintain a replay window and such.
     83 	if !bytes.Equal(epoch, c.in.seq[:2]) {
     84 		c.sendAlert(alertIllegalParameter)
     85 		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
     86 	}
     87 	if bytes.Compare(seq, c.in.seq[2:]) < 0 {
     88 		c.sendAlert(alertIllegalParameter)
     89 		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
     90 	}
     91 	copy(c.in.seq[2:], seq)
     92 	n := int(b.data[11])<<8 | int(b.data[12])
     93 	if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
     94 		c.sendAlert(alertRecordOverflow)
     95 		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
     96 	}
     97 
     98 	// Process message.
     99 	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
    100 	ok, off, _, alertValue := c.in.decrypt(b)
    101 	if !ok {
    102 		// A real DTLS implementation would silently ignore bad records,
    103 		// but we want to notice errors from the implementation under
    104 		// test.
    105 		return 0, nil, c.in.setErrorLocked(c.sendAlert(alertValue))
    106 	}
    107 	b.off = off
    108 
    109 	// TODO(nharper): Once DTLS 1.3 is defined, handle the extra
    110 	// parameter from decrypt.
    111 
    112 	// Require that ChangeCipherSpec always share a packet with either the
    113 	// previous or next handshake message.
    114 	if newPacket && typ == recordTypeChangeCipherSpec && c.rawInput == nil {
    115 		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: ChangeCipherSpec not packed together with Finished"))
    116 	}
    117 
    118 	return typ, b, nil
    119 }
    120 
    121 func (c *Conn) makeFragment(header, data []byte, fragOffset, fragLen int) []byte {
    122 	fragment := make([]byte, 0, 12+fragLen)
    123 	fragment = append(fragment, header...)
    124 	fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq))
    125 	fragment = append(fragment, byte(fragOffset>>16), byte(fragOffset>>8), byte(fragOffset))
    126 	fragment = append(fragment, byte(fragLen>>16), byte(fragLen>>8), byte(fragLen))
    127 	fragment = append(fragment, data[fragOffset:fragOffset+fragLen]...)
    128 	return fragment
    129 }
    130 
    131 func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) {
    132 	// Only handshake messages are fragmented.
    133 	if typ != recordTypeHandshake {
    134 		reorder := typ == recordTypeChangeCipherSpec && c.config.Bugs.ReorderChangeCipherSpec
    135 
    136 		// Flush pending handshake messages before encrypting a new record.
    137 		if !reorder {
    138 			err = c.dtlsPackHandshake()
    139 			if err != nil {
    140 				return
    141 			}
    142 		}
    143 
    144 		if typ == recordTypeApplicationData && len(data) > 1 && c.config.Bugs.SplitAndPackAppData {
    145 			_, err = c.dtlsPackRecord(typ, data[:len(data)/2], false)
    146 			if err != nil {
    147 				return
    148 			}
    149 			_, err = c.dtlsPackRecord(typ, data[len(data)/2:], true)
    150 			if err != nil {
    151 				return
    152 			}
    153 			n = len(data)
    154 		} else {
    155 			n, err = c.dtlsPackRecord(typ, data, false)
    156 			if err != nil {
    157 				return
    158 			}
    159 		}
    160 
    161 		if reorder {
    162 			err = c.dtlsPackHandshake()
    163 			if err != nil {
    164 				return
    165 			}
    166 		}
    167 
    168 		if typ == recordTypeChangeCipherSpec {
    169 			err = c.out.changeCipherSpec(c.config)
    170 			if err != nil {
    171 				return n, c.sendAlertLocked(alertLevelError, err.(alert))
    172 			}
    173 		} else {
    174 			// ChangeCipherSpec is part of the handshake and not
    175 			// flushed until dtlsFlushPacket.
    176 			err = c.dtlsFlushPacket()
    177 			if err != nil {
    178 				return
    179 			}
    180 		}
    181 		return
    182 	}
    183 
    184 	if c.out.cipher == nil && c.config.Bugs.StrayChangeCipherSpec {
    185 		_, err = c.dtlsPackRecord(recordTypeChangeCipherSpec, []byte{1}, false)
    186 		if err != nil {
    187 			return
    188 		}
    189 	}
    190 
    191 	maxLen := c.config.Bugs.MaxHandshakeRecordLength
    192 	if maxLen <= 0 {
    193 		maxLen = 1024
    194 	}
    195 
    196 	// Handshake messages have to be modified to include fragment
    197 	// offset and length and with the header replicated. Save the
    198 	// TLS header here.
    199 	//
    200 	// TODO(davidben): This assumes that data contains exactly one
    201 	// handshake message. This is incompatible with
    202 	// FragmentAcrossChangeCipherSpec. (Which is unfortunate
    203 	// because OpenSSL's DTLS implementation will probably accept
    204 	// such fragmentation and could do with a fix + tests.)
    205 	header := data[:4]
    206 	data = data[4:]
    207 
    208 	isFinished := header[0] == typeFinished
    209 
    210 	if c.config.Bugs.SendEmptyFragments {
    211 		c.pendingFragments = append(c.pendingFragments, c.makeFragment(header, data, 0, 0))
    212 		c.pendingFragments = append(c.pendingFragments, c.makeFragment(header, data, len(data), 0))
    213 	}
    214 
    215 	firstRun := true
    216 	fragOffset := 0
    217 	for firstRun || fragOffset < len(data) {
    218 		firstRun = false
    219 		fragLen := len(data) - fragOffset
    220 		if fragLen > maxLen {
    221 			fragLen = maxLen
    222 		}
    223 
    224 		fragment := c.makeFragment(header, data, fragOffset, fragLen)
    225 		if c.config.Bugs.FragmentMessageTypeMismatch && fragOffset > 0 {
    226 			fragment[0]++
    227 		}
    228 		if c.config.Bugs.FragmentMessageLengthMismatch && fragOffset > 0 {
    229 			fragment[3]++
    230 		}
    231 
    232 		// Buffer the fragment for later. They will be sent (and
    233 		// reordered) on flush.
    234 		c.pendingFragments = append(c.pendingFragments, fragment)
    235 		if c.config.Bugs.ReorderHandshakeFragments {
    236 			// Don't duplicate Finished to avoid the peer
    237 			// interpreting it as a retransmit request.
    238 			if !isFinished {
    239 				c.pendingFragments = append(c.pendingFragments, fragment)
    240 			}
    241 
    242 			if fragLen > (maxLen+1)/2 {
    243 				// Overlap each fragment by half.
    244 				fragLen = (maxLen + 1) / 2
    245 			}
    246 		}
    247 		fragOffset += fragLen
    248 		n += fragLen
    249 	}
    250 	shouldSendTwice := c.config.Bugs.MixCompleteMessageWithFragments
    251 	if isFinished {
    252 		shouldSendTwice = c.config.Bugs.RetransmitFinished
    253 	}
    254 	if shouldSendTwice {
    255 		fragment := c.makeFragment(header, data, 0, len(data))
    256 		c.pendingFragments = append(c.pendingFragments, fragment)
    257 	}
    258 
    259 	// Increment the handshake sequence number for the next
    260 	// handshake message.
    261 	c.sendHandshakeSeq++
    262 	return
    263 }
    264 
    265 // dtlsPackHandshake packs the pending handshake flight into the pending
    266 // record. Callers should follow up with dtlsFlushPacket to write the packets.
    267 func (c *Conn) dtlsPackHandshake() error {
    268 	// This is a test-only DTLS implementation, so there is no need to
    269 	// retain |c.pendingFragments| for a future retransmit.
    270 	var fragments [][]byte
    271 	fragments, c.pendingFragments = c.pendingFragments, fragments
    272 
    273 	if c.config.Bugs.ReorderHandshakeFragments {
    274 		perm := rand.New(rand.NewSource(0)).Perm(len(fragments))
    275 		tmp := make([][]byte, len(fragments))
    276 		for i := range tmp {
    277 			tmp[i] = fragments[perm[i]]
    278 		}
    279 		fragments = tmp
    280 	} else if c.config.Bugs.ReverseHandshakeFragments {
    281 		tmp := make([][]byte, len(fragments))
    282 		for i := range tmp {
    283 			tmp[i] = fragments[len(fragments)-i-1]
    284 		}
    285 		fragments = tmp
    286 	}
    287 
    288 	maxRecordLen := c.config.Bugs.PackHandshakeFragments
    289 
    290 	// Pack handshake fragments into records.
    291 	var records [][]byte
    292 	for _, fragment := range fragments {
    293 		if n := c.config.Bugs.SplitFragments; n > 0 {
    294 			if len(fragment) > n {
    295 				records = append(records, fragment[:n])
    296 				records = append(records, fragment[n:])
    297 			} else {
    298 				records = append(records, fragment)
    299 			}
    300 		} else if i := len(records) - 1; len(records) > 0 && len(records[i])+len(fragment) <= maxRecordLen {
    301 			records[i] = append(records[i], fragment...)
    302 		} else {
    303 			// The fragment will be appended to, so copy it.
    304 			records = append(records, append([]byte{}, fragment...))
    305 		}
    306 	}
    307 
    308 	// Send the records.
    309 	for _, record := range records {
    310 		_, err := c.dtlsPackRecord(recordTypeHandshake, record, false)
    311 		if err != nil {
    312 			return err
    313 		}
    314 	}
    315 
    316 	return nil
    317 }
    318 
    319 func (c *Conn) dtlsFlushHandshake() error {
    320 	if err := c.dtlsPackHandshake(); err != nil {
    321 		return err
    322 	}
    323 	if err := c.dtlsFlushPacket(); err != nil {
    324 		return err
    325 	}
    326 
    327 	return nil
    328 }
    329 
    330 // dtlsPackRecord packs a single record to the pending packet, flushing it
    331 // if necessary. The caller should call dtlsFlushPacket to flush the current
    332 // pending packet afterwards.
    333 func (c *Conn) dtlsPackRecord(typ recordType, data []byte, mustPack bool) (n int, err error) {
    334 	recordHeaderLen := dtlsRecordHeaderLen
    335 	maxLen := c.config.Bugs.MaxHandshakeRecordLength
    336 	if maxLen <= 0 {
    337 		maxLen = 1024
    338 	}
    339 
    340 	b := c.out.newBlock()
    341 
    342 	explicitIVLen := 0
    343 	explicitIVIsSeq := false
    344 
    345 	if cbc, ok := c.out.cipher.(cbcMode); ok {
    346 		// Block cipher modes have an explicit IV.
    347 		explicitIVLen = cbc.BlockSize()
    348 	} else if aead, ok := c.out.cipher.(*tlsAead); ok {
    349 		if aead.explicitNonce {
    350 			explicitIVLen = 8
    351 			// The AES-GCM construction in TLS has an explicit nonce so that
    352 			// the nonce can be random. However, the nonce is only 8 bytes
    353 			// which is too small for a secure, random nonce. Therefore we
    354 			// use the sequence number as the nonce.
    355 			explicitIVIsSeq = true
    356 		}
    357 	} else if _, ok := c.out.cipher.(nullCipher); !ok && c.out.cipher != nil {
    358 		panic("Unknown cipher")
    359 	}
    360 	b.resize(recordHeaderLen + explicitIVLen + len(data))
    361 	// TODO(nharper): DTLS 1.3 will likely need to set this to
    362 	// recordTypeApplicationData if c.out.cipher != nil.
    363 	b.data[0] = byte(typ)
    364 	vers := c.wireVersion
    365 	if vers == 0 {
    366 		// Some TLS servers fail if the record version is greater than
    367 		// TLS 1.0 for the initial ClientHello.
    368 		if c.isDTLS {
    369 			vers = VersionDTLS10
    370 		} else {
    371 			vers = VersionTLS10
    372 		}
    373 	}
    374 	b.data[1] = byte(vers >> 8)
    375 	b.data[2] = byte(vers)
    376 	// DTLS records include an explicit sequence number.
    377 	copy(b.data[3:11], c.out.outSeq[0:])
    378 	b.data[11] = byte(len(data) >> 8)
    379 	b.data[12] = byte(len(data))
    380 	if explicitIVLen > 0 {
    381 		explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
    382 		if explicitIVIsSeq {
    383 			copy(explicitIV, c.out.outSeq[:])
    384 		} else {
    385 			if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
    386 				return
    387 			}
    388 		}
    389 	}
    390 	copy(b.data[recordHeaderLen+explicitIVLen:], data)
    391 	c.out.encrypt(b, explicitIVLen, typ)
    392 
    393 	// Flush the current pending packet if necessary.
    394 	if !mustPack && len(b.data)+len(c.pendingPacket) > c.config.Bugs.PackHandshakeRecords {
    395 		err = c.dtlsFlushPacket()
    396 		if err != nil {
    397 			c.out.freeBlock(b)
    398 			return
    399 		}
    400 	}
    401 
    402 	// Add the record to the pending packet.
    403 	c.pendingPacket = append(c.pendingPacket, b.data...)
    404 	c.out.freeBlock(b)
    405 	n = len(data)
    406 	return
    407 }
    408 
    409 func (c *Conn) dtlsFlushPacket() error {
    410 	if len(c.pendingPacket) == 0 {
    411 		return nil
    412 	}
    413 	_, err := c.conn.Write(c.pendingPacket)
    414 	c.pendingPacket = nil
    415 	return err
    416 }
    417 
    418 func (c *Conn) dtlsDoReadHandshake() ([]byte, error) {
    419 	// Assemble a full handshake message.  For test purposes, this
    420 	// implementation assumes fragments arrive in order. It may
    421 	// need to be cleverer if we ever test BoringSSL's retransmit
    422 	// behavior.
    423 	for len(c.handMsg) < 4+c.handMsgLen {
    424 		// Get a new handshake record if the previous has been
    425 		// exhausted.
    426 		if c.hand.Len() == 0 {
    427 			if err := c.in.err; err != nil {
    428 				return nil, err
    429 			}
    430 			if err := c.readRecord(recordTypeHandshake); err != nil {
    431 				return nil, err
    432 			}
    433 		}
    434 
    435 		// Read the next fragment. It must fit entirely within
    436 		// the record.
    437 		if c.hand.Len() < 12 {
    438 			return nil, errors.New("dtls: bad handshake record")
    439 		}
    440 		header := c.hand.Next(12)
    441 		fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3])
    442 		fragSeq := uint16(header[4])<<8 | uint16(header[5])
    443 		fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8])
    444 		fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11])
    445 
    446 		if c.hand.Len() < fragLen {
    447 			return nil, errors.New("dtls: fragment length too long")
    448 		}
    449 		fragment := c.hand.Next(fragLen)
    450 
    451 		// Check it's a fragment for the right message.
    452 		if fragSeq != c.recvHandshakeSeq {
    453 			return nil, errors.New("dtls: bad handshake sequence number")
    454 		}
    455 
    456 		// Check that the length is consistent.
    457 		if c.handMsg == nil {
    458 			c.handMsgLen = fragN
    459 			if c.handMsgLen > maxHandshake {
    460 				return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
    461 			}
    462 			// Start with the TLS handshake header,
    463 			// without the DTLS bits.
    464 			c.handMsg = append([]byte{}, header[:4]...)
    465 		} else if fragN != c.handMsgLen {
    466 			return nil, errors.New("dtls: bad handshake length")
    467 		}
    468 
    469 		// Add the fragment to the pending message.
    470 		if 4+fragOff != len(c.handMsg) {
    471 			return nil, errors.New("dtls: bad fragment offset")
    472 		}
    473 		if fragOff+fragLen > c.handMsgLen {
    474 			return nil, errors.New("dtls: bad fragment length")
    475 		}
    476 		c.handMsg = append(c.handMsg, fragment...)
    477 	}
    478 	c.recvHandshakeSeq++
    479 	ret := c.handMsg
    480 	c.handMsg, c.handMsgLen = nil, 0
    481 	return ret, nil
    482 }
    483 
    484 // DTLSServer returns a new DTLS server side connection
    485 // using conn as the underlying transport.
    486 // The configuration config must be non-nil and must have
    487 // at least one certificate.
    488 func DTLSServer(conn net.Conn, config *Config) *Conn {
    489 	c := &Conn{config: config, isDTLS: true, conn: conn}
    490 	c.init()
    491 	return c
    492 }
    493 
    494 // DTLSClient returns a new DTLS client side connection
    495 // using conn as the underlying transport.
    496 // The config cannot be nil: users must set either ServerHostname or
    497 // InsecureSkipVerify in the config.
    498 func DTLSClient(conn net.Conn, config *Config) *Conn {
    499 	c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn}
    500 	c.init()
    501 	return c
    502 }
    503