Home | History | Annotate | Download | only in runner
      1 // Copyright 2014 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // DTLS implementation.
      6 //
      7 // NOTE: This is not even a remotely production-quality DTLS
      8 // implementation. It is the bare minimum necessary to be able to
      9 // achieve coverage on BoringSSL's implementation. Of note is that
     10 // this implementation assumes the underlying net.PacketConn is not
     11 // only reliable but also ordered. BoringSSL will be expected to deal
     12 // with simulated loss, but there is no point in forcing the test
     13 // driver to.
     14 
     15 package runner
     16 
     17 import (
     18 	"bytes"
     19 	"errors"
     20 	"fmt"
     21 	"io"
     22 	"math/rand"
     23 	"net"
     24 )
     25 
     26 func wireToVersion(vers uint16, isDTLS bool) (uint16, bool) {
     27 	if isDTLS {
     28 		switch vers {
     29 		case VersionDTLS12:
     30 			return VersionTLS12, true
     31 		case VersionDTLS10:
     32 			return VersionTLS10, true
     33 		}
     34 	} else {
     35 		switch vers {
     36 		case VersionSSL30, VersionTLS10, VersionTLS11, VersionTLS12:
     37 			return vers, true
     38 		case tls13DraftVersion, tls13ExperimentVersion, tls13RecordTypeExperimentVersion:
     39 			return VersionTLS13, true
     40 		}
     41 	}
     42 
     43 	return 0, false
     44 }
     45 
// dtlsDoReadRecord reads the next DTLS record from the peer and decrypts it.
//
// If c.rawInput holds no buffered packet data, exactly one packet is read
// from c.conn (rejected if it exceeds Bugs.MaxPacketLength, when set). One
// record is then parsed off the front of the buffered packet; any bytes after
// it stay in c.rawInput for the next call. The record's version, epoch and
// sequence number are validated, the record is decrypted with c.in, and the
// record type plus the decrypted block are returned, with b.off set to the
// offset reported by c.in.decrypt.
//
// Validation failures are returned as errors; most of them also send an
// alert and are recorded via c.in.setErrorLocked.
//
// NOTE(review): the want parameter is not consulted anywhere in this
// function.
func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) {
	recordHeaderLen := dtlsRecordHeaderLen

	if c.rawInput == nil {
		c.rawInput = c.in.newBlock()
	}
	b := c.rawInput

	// Read a new packet only if the current one is empty.
	var newPacket bool
	if len(b.data) == 0 {
		// Pick some absurdly large buffer size.
		b.resize(maxCiphertext + recordHeaderLen)
		n, err := c.conn.Read(c.rawInput.data)
		if err != nil {
			return 0, nil, err
		}
		if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength {
			return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length")
		}
		// Shrink the buffer to the bytes actually received.
		c.rawInput.resize(n)
		newPacket = true
	}

	// Read out one record.
	//
	// A real DTLS implementation should be tolerant of errors,
	// but this is test code. We should not be tolerant of our
	// peer sending garbage.
	if len(b.data) < recordHeaderLen {
		return 0, nil, errors.New("dtls: failed to read record header")
	}
	typ := recordType(b.data[0])
	vers := uint16(b.data[1])<<8 | uint16(b.data[2])
	// Alerts sent near version negotiation do not have a well-defined
	// record-layer version prior to TLS 1.3. (In TLS 1.3, the record-layer
	// version is irrelevant.)
	if typ != recordTypeAlert {
		if c.haveVers {
			if vers != c.wireVersion {
				c.sendAlert(alertProtocolVersion)
				return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.wireVersion))
			}
		} else {
			// Pre-version-negotiation alerts may be sent with any version.
			if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect {
				c.sendAlert(alertProtocolVersion)
				return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect))
			}
		}
	}
	// Bytes 3-4 are the epoch and bytes 5-10 the 48-bit sequence number;
	// c.in.seq stores them in the same layout.
	epoch := b.data[3:5]
	seq := b.data[5:11]
	// For test purposes, require the sequence number be monotonically
	// increasing, so c.in includes the minimum next sequence number. Gaps
	// may occur if packets failed to be sent out. A real implementation
	// would maintain a replay window and such.
	if !bytes.Equal(epoch, c.in.seq[:2]) {
		c.sendAlert(alertIllegalParameter)
		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad epoch"))
	}
	if bytes.Compare(seq, c.in.seq[2:]) < 0 {
		c.sendAlert(alertIllegalParameter)
		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number"))
	}
	// Adopt the peer's sequence number (which may skip ahead) before
	// decrypting; this happens even if decryption later fails.
	copy(c.in.seq[2:], seq)
	n := int(b.data[11])<<8 | int(b.data[12])
	// This also catches a record whose declared length runs past the end
	// of the packet (truncation), although the alert and error message
	// describe it as oversized.
	if n > maxCiphertext || len(b.data) < recordHeaderLen+n {
		c.sendAlert(alertRecordOverflow)
		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n))
	}

	// Process message.
	// splitBlock separates this record from the rest of the packet; the
	// remainder (if any) becomes c.rawInput for subsequent calls.
	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
	ok, off, _, alertValue := c.in.decrypt(b)
	if !ok {
		// A real DTLS implementation would silently ignore bad records,
		// but we want to notice errors from the implementation under
		// test.
		return 0, nil, c.in.setErrorLocked(c.sendAlert(alertValue))
	}
	b.off = off

	// TODO(nharper): Once DTLS 1.3 is defined, handle the extra
	// parameter from decrypt.

	// Require that ChangeCipherSpec always share a packet with either the
	// previous or next handshake message.
	// newPacket && c.rawInput == nil means this record arrived in a freshly
	// read packet and nothing followed it, i.e. it was alone in the packet.
	if newPacket && typ == recordTypeChangeCipherSpec && c.rawInput == nil {
		return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: ChangeCipherSpec not packed together with Finished"))
	}

	return typ, b, nil
}
    140 
    141 func (c *Conn) makeFragment(header, data []byte, fragOffset, fragLen int) []byte {
    142 	fragment := make([]byte, 0, 12+fragLen)
    143 	fragment = append(fragment, header...)
    144 	fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq))
    145 	fragment = append(fragment, byte(fragOffset>>16), byte(fragOffset>>8), byte(fragOffset))
    146 	fragment = append(fragment, byte(fragLen>>16), byte(fragLen>>8), byte(fragLen))
    147 	fragment = append(fragment, data[fragOffset:fragOffset+fragLen]...)
    148 	return fragment
    149 }
    150 
    151 func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) {
    152 	if typ != recordTypeHandshake {
    153 		// Only handshake messages are fragmented.
    154 		n, err = c.dtlsWriteRawRecord(typ, data)
    155 		if err != nil {
    156 			return
    157 		}
    158 
    159 		if typ == recordTypeChangeCipherSpec {
    160 			err = c.out.changeCipherSpec(c.config)
    161 			if err != nil {
    162 				return n, c.sendAlertLocked(alertLevelError, err.(alert))
    163 			}
    164 		}
    165 		return
    166 	}
    167 
    168 	if c.out.cipher == nil && c.config.Bugs.StrayChangeCipherSpec {
    169 		_, err = c.dtlsWriteRawRecord(recordTypeChangeCipherSpec, []byte{1})
    170 		if err != nil {
    171 			return
    172 		}
    173 	}
    174 
    175 	maxLen := c.config.Bugs.MaxHandshakeRecordLength
    176 	if maxLen <= 0 {
    177 		maxLen = 1024
    178 	}
    179 
    180 	// Handshake messages have to be modified to include fragment
    181 	// offset and length and with the header replicated. Save the
    182 	// TLS header here.
    183 	//
    184 	// TODO(davidben): This assumes that data contains exactly one
    185 	// handshake message. This is incompatible with
    186 	// FragmentAcrossChangeCipherSpec. (Which is unfortunate
    187 	// because OpenSSL's DTLS implementation will probably accept
    188 	// such fragmentation and could do with a fix + tests.)
    189 	header := data[:4]
    190 	data = data[4:]
    191 
    192 	isFinished := header[0] == typeFinished
    193 
    194 	if c.config.Bugs.SendEmptyFragments {
    195 		fragment := c.makeFragment(header, data, 0, 0)
    196 		c.pendingFragments = append(c.pendingFragments, fragment)
    197 	}
    198 
    199 	firstRun := true
    200 	fragOffset := 0
    201 	for firstRun || fragOffset < len(data) {
    202 		firstRun = false
    203 		fragLen := len(data) - fragOffset
    204 		if fragLen > maxLen {
    205 			fragLen = maxLen
    206 		}
    207 
    208 		fragment := c.makeFragment(header, data, fragOffset, fragLen)
    209 		if c.config.Bugs.FragmentMessageTypeMismatch && fragOffset > 0 {
    210 			fragment[0]++
    211 		}
    212 		if c.config.Bugs.FragmentMessageLengthMismatch && fragOffset > 0 {
    213 			fragment[3]++
    214 		}
    215 
    216 		// Buffer the fragment for later. They will be sent (and
    217 		// reordered) on flush.
    218 		c.pendingFragments = append(c.pendingFragments, fragment)
    219 		if c.config.Bugs.ReorderHandshakeFragments {
    220 			// Don't duplicate Finished to avoid the peer
    221 			// interpreting it as a retransmit request.
    222 			if !isFinished {
    223 				c.pendingFragments = append(c.pendingFragments, fragment)
    224 			}
    225 
    226 			if fragLen > (maxLen+1)/2 {
    227 				// Overlap each fragment by half.
    228 				fragLen = (maxLen + 1) / 2
    229 			}
    230 		}
    231 		fragOffset += fragLen
    232 		n += fragLen
    233 	}
    234 	if !isFinished && c.config.Bugs.MixCompleteMessageWithFragments {
    235 		fragment := c.makeFragment(header, data, 0, len(data))
    236 		c.pendingFragments = append(c.pendingFragments, fragment)
    237 	}
    238 
    239 	// Increment the handshake sequence number for the next
    240 	// handshake message.
    241 	c.sendHandshakeSeq++
    242 	return
    243 }
    244 
    245 func (c *Conn) dtlsFlushHandshake() error {
    246 	// This is a test-only DTLS implementation, so there is no need to
    247 	// retain |c.pendingFragments| for a future retransmit.
    248 	var fragments [][]byte
    249 	fragments, c.pendingFragments = c.pendingFragments, fragments
    250 
    251 	if c.config.Bugs.ReorderHandshakeFragments {
    252 		perm := rand.New(rand.NewSource(0)).Perm(len(fragments))
    253 		tmp := make([][]byte, len(fragments))
    254 		for i := range tmp {
    255 			tmp[i] = fragments[perm[i]]
    256 		}
    257 		fragments = tmp
    258 	} else if c.config.Bugs.ReverseHandshakeFragments {
    259 		tmp := make([][]byte, len(fragments))
    260 		for i := range tmp {
    261 			tmp[i] = fragments[len(fragments)-i-1]
    262 		}
    263 		fragments = tmp
    264 	}
    265 
    266 	maxRecordLen := c.config.Bugs.PackHandshakeFragments
    267 	maxPacketLen := c.config.Bugs.PackHandshakeRecords
    268 
    269 	// Pack handshake fragments into records.
    270 	var records [][]byte
    271 	for _, fragment := range fragments {
    272 		if n := c.config.Bugs.SplitFragments; n > 0 {
    273 			if len(fragment) > n {
    274 				records = append(records, fragment[:n])
    275 				records = append(records, fragment[n:])
    276 			} else {
    277 				records = append(records, fragment)
    278 			}
    279 		} else if i := len(records) - 1; len(records) > 0 && len(records[i])+len(fragment) <= maxRecordLen {
    280 			records[i] = append(records[i], fragment...)
    281 		} else {
    282 			// The fragment will be appended to, so copy it.
    283 			records = append(records, append([]byte{}, fragment...))
    284 		}
    285 	}
    286 
    287 	// Format them into packets.
    288 	var packets [][]byte
    289 	for _, record := range records {
    290 		b, err := c.dtlsSealRecord(recordTypeHandshake, record)
    291 		if err != nil {
    292 			return err
    293 		}
    294 
    295 		if i := len(packets) - 1; len(packets) > 0 && len(packets[i])+len(b.data) <= maxPacketLen {
    296 			packets[i] = append(packets[i], b.data...)
    297 		} else {
    298 			// The sealed record will be appended to and reused by
    299 			// |c.out|, so copy it.
    300 			packets = append(packets, append([]byte{}, b.data...))
    301 		}
    302 		c.out.freeBlock(b)
    303 	}
    304 
    305 	// Send all the packets.
    306 	for _, packet := range packets {
    307 		if _, err := c.conn.Write(packet); err != nil {
    308 			return err
    309 		}
    310 	}
    311 	return nil
    312 }
    313 
    314 // dtlsSealRecord seals a record into a block from |c.out|'s pool.
    315 func (c *Conn) dtlsSealRecord(typ recordType, data []byte) (b *block, err error) {
    316 	recordHeaderLen := dtlsRecordHeaderLen
    317 	maxLen := c.config.Bugs.MaxHandshakeRecordLength
    318 	if maxLen <= 0 {
    319 		maxLen = 1024
    320 	}
    321 
    322 	b = c.out.newBlock()
    323 
    324 	explicitIVLen := 0
    325 	explicitIVIsSeq := false
    326 
    327 	if cbc, ok := c.out.cipher.(cbcMode); ok {
    328 		// Block cipher modes have an explicit IV.
    329 		explicitIVLen = cbc.BlockSize()
    330 	} else if aead, ok := c.out.cipher.(*tlsAead); ok {
    331 		if aead.explicitNonce {
    332 			explicitIVLen = 8
    333 			// The AES-GCM construction in TLS has an explicit nonce so that
    334 			// the nonce can be random. However, the nonce is only 8 bytes
    335 			// which is too small for a secure, random nonce. Therefore we
    336 			// use the sequence number as the nonce.
    337 			explicitIVIsSeq = true
    338 		}
    339 	} else if _, ok := c.out.cipher.(nullCipher); !ok && c.out.cipher != nil {
    340 		panic("Unknown cipher")
    341 	}
    342 	b.resize(recordHeaderLen + explicitIVLen + len(data))
    343 	// TODO(nharper): DTLS 1.3 will likely need to set this to
    344 	// recordTypeApplicationData if c.out.cipher != nil.
    345 	b.data[0] = byte(typ)
    346 	vers := c.wireVersion
    347 	if vers == 0 {
    348 		// Some TLS servers fail if the record version is greater than
    349 		// TLS 1.0 for the initial ClientHello.
    350 		if c.isDTLS {
    351 			vers = VersionDTLS10
    352 		} else {
    353 			vers = VersionTLS10
    354 		}
    355 	}
    356 	b.data[1] = byte(vers >> 8)
    357 	b.data[2] = byte(vers)
    358 	// DTLS records include an explicit sequence number.
    359 	copy(b.data[3:11], c.out.outSeq[0:])
    360 	b.data[11] = byte(len(data) >> 8)
    361 	b.data[12] = byte(len(data))
    362 	if explicitIVLen > 0 {
    363 		explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
    364 		if explicitIVIsSeq {
    365 			copy(explicitIV, c.out.outSeq[:])
    366 		} else {
    367 			if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil {
    368 				return
    369 			}
    370 		}
    371 	}
    372 	copy(b.data[recordHeaderLen+explicitIVLen:], data)
    373 	c.out.encrypt(b, explicitIVLen, typ)
    374 	return
    375 }
    376 
    377 func (c *Conn) dtlsWriteRawRecord(typ recordType, data []byte) (n int, err error) {
    378 	b, err := c.dtlsSealRecord(typ, data)
    379 	if err != nil {
    380 		return
    381 	}
    382 
    383 	_, err = c.conn.Write(b.data)
    384 	if err != nil {
    385 		return
    386 	}
    387 	n = len(data)
    388 
    389 	c.out.freeBlock(b)
    390 	return
    391 }
    392 
    393 func (c *Conn) dtlsDoReadHandshake() ([]byte, error) {
    394 	// Assemble a full handshake message.  For test purposes, this
    395 	// implementation assumes fragments arrive in order. It may
    396 	// need to be cleverer if we ever test BoringSSL's retransmit
    397 	// behavior.
    398 	for len(c.handMsg) < 4+c.handMsgLen {
    399 		// Get a new handshake record if the previous has been
    400 		// exhausted.
    401 		if c.hand.Len() == 0 {
    402 			if err := c.in.err; err != nil {
    403 				return nil, err
    404 			}
    405 			if err := c.readRecord(recordTypeHandshake); err != nil {
    406 				return nil, err
    407 			}
    408 		}
    409 
    410 		// Read the next fragment. It must fit entirely within
    411 		// the record.
    412 		if c.hand.Len() < 12 {
    413 			return nil, errors.New("dtls: bad handshake record")
    414 		}
    415 		header := c.hand.Next(12)
    416 		fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3])
    417 		fragSeq := uint16(header[4])<<8 | uint16(header[5])
    418 		fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8])
    419 		fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11])
    420 
    421 		if c.hand.Len() < fragLen {
    422 			return nil, errors.New("dtls: fragment length too long")
    423 		}
    424 		fragment := c.hand.Next(fragLen)
    425 
    426 		// Check it's a fragment for the right message.
    427 		if fragSeq != c.recvHandshakeSeq {
    428 			return nil, errors.New("dtls: bad handshake sequence number")
    429 		}
    430 
    431 		// Check that the length is consistent.
    432 		if c.handMsg == nil {
    433 			c.handMsgLen = fragN
    434 			if c.handMsgLen > maxHandshake {
    435 				return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError))
    436 			}
    437 			// Start with the TLS handshake header,
    438 			// without the DTLS bits.
    439 			c.handMsg = append([]byte{}, header[:4]...)
    440 		} else if fragN != c.handMsgLen {
    441 			return nil, errors.New("dtls: bad handshake length")
    442 		}
    443 
    444 		// Add the fragment to the pending message.
    445 		if 4+fragOff != len(c.handMsg) {
    446 			return nil, errors.New("dtls: bad fragment offset")
    447 		}
    448 		if fragOff+fragLen > c.handMsgLen {
    449 			return nil, errors.New("dtls: bad fragment length")
    450 		}
    451 		c.handMsg = append(c.handMsg, fragment...)
    452 	}
    453 	c.recvHandshakeSeq++
    454 	ret := c.handMsg
    455 	c.handMsg, c.handMsgLen = nil, 0
    456 	return ret, nil
    457 }
    458 
    459 // DTLSServer returns a new DTLS server side connection
    460 // using conn as the underlying transport.
    461 // The configuration config must be non-nil and must have
    462 // at least one certificate.
    463 func DTLSServer(conn net.Conn, config *Config) *Conn {
    464 	c := &Conn{config: config, isDTLS: true, conn: conn}
    465 	c.init()
    466 	return c
    467 }
    468 
    469 // DTLSClient returns a new DTLS client side connection
    470 // using conn as the underlying transport.
    471 // The config cannot be nil: users must set either ServerHostname or
    472 // InsecureSkipVerify in the config.
    473 func DTLSClient(conn net.Conn, config *Config) *Conn {
    474 	c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn}
    475 	c.init()
    476 	return c
    477 }
    478