Home | History | Annotate | Download | only in asn1
      1 // Copyright 2009 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package asn1
      6 
      7 import (
      8 	"bytes"
      9 	"errors"
     10 	"fmt"
     11 	"io"
     12 	"math/big"
     13 	"reflect"
     14 	"time"
     15 	"unicode/utf8"
     16 )
     17 
     18 // A forkableWriter is an in-memory buffer that can be
     19 // 'forked' to create new forkableWriters that bracket the
     20 // original.  After
     21 //    pre, post := w.fork()
     22 // the overall sequence of bytes represented is logically w+pre+post.
     23 type forkableWriter struct {
     24 	*bytes.Buffer
     25 	pre, post *forkableWriter
     26 }
     27 
     28 func newForkableWriter() *forkableWriter {
     29 	return &forkableWriter{new(bytes.Buffer), nil, nil}
     30 }
     31 
     32 func (f *forkableWriter) fork() (pre, post *forkableWriter) {
     33 	if f.pre != nil || f.post != nil {
     34 		panic("have already forked")
     35 	}
     36 	f.pre = newForkableWriter()
     37 	f.post = newForkableWriter()
     38 	return f.pre, f.post
     39 }
     40 
     41 func (f *forkableWriter) Len() (l int) {
     42 	l += f.Buffer.Len()
     43 	if f.pre != nil {
     44 		l += f.pre.Len()
     45 	}
     46 	if f.post != nil {
     47 		l += f.post.Len()
     48 	}
     49 	return
     50 }
     51 
     52 func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) {
     53 	n, err = out.Write(f.Bytes())
     54 	if err != nil {
     55 		return
     56 	}
     57 
     58 	var nn int
     59 
     60 	if f.pre != nil {
     61 		nn, err = f.pre.writeTo(out)
     62 		n += nn
     63 		if err != nil {
     64 			return
     65 		}
     66 	}
     67 
     68 	if f.post != nil {
     69 		nn, err = f.post.writeTo(out)
     70 		n += nn
     71 	}
     72 	return
     73 }
     74 
     75 func marshalBase128Int(out *forkableWriter, n int64) (err error) {
     76 	if n == 0 {
     77 		err = out.WriteByte(0)
     78 		return
     79 	}
     80 
     81 	l := 0
     82 	for i := n; i > 0; i >>= 7 {
     83 		l++
     84 	}
     85 
     86 	for i := l - 1; i >= 0; i-- {
     87 		o := byte(n >> uint(i*7))
     88 		o &= 0x7f
     89 		if i != 0 {
     90 			o |= 0x80
     91 		}
     92 		err = out.WriteByte(o)
     93 		if err != nil {
     94 			return
     95 		}
     96 	}
     97 
     98 	return nil
     99 }
    100 
    101 func marshalInt64(out *forkableWriter, i int64) (err error) {
    102 	n := int64Length(i)
    103 
    104 	for ; n > 0; n-- {
    105 		err = out.WriteByte(byte(i >> uint((n-1)*8)))
    106 		if err != nil {
    107 			return
    108 		}
    109 	}
    110 
    111 	return nil
    112 }
    113 
    114 func int64Length(i int64) (numBytes int) {
    115 	numBytes = 1
    116 
    117 	for i > 127 {
    118 		numBytes++
    119 		i >>= 8
    120 	}
    121 
    122 	for i < -128 {
    123 		numBytes++
    124 		i >>= 8
    125 	}
    126 
    127 	return
    128 }
    129 
    130 func marshalBigInt(out *forkableWriter, n *big.Int) (err error) {
    131 	if n.Sign() < 0 {
    132 		// A negative number has to be converted to two's-complement
    133 		// form. So we'll subtract 1 and invert. If the
    134 		// most-significant-bit isn't set then we'll need to pad the
    135 		// beginning with 0xff in order to keep the number negative.
    136 		nMinus1 := new(big.Int).Neg(n)
    137 		nMinus1.Sub(nMinus1, bigOne)
    138 		bytes := nMinus1.Bytes()
    139 		for i := range bytes {
    140 			bytes[i] ^= 0xff
    141 		}
    142 		if len(bytes) == 0 || bytes[0]&0x80 == 0 {
    143 			err = out.WriteByte(0xff)
    144 			if err != nil {
    145 				return
    146 			}
    147 		}
    148 		_, err = out.Write(bytes)
    149 	} else if n.Sign() == 0 {
    150 		// Zero is written as a single 0 zero rather than no bytes.
    151 		err = out.WriteByte(0x00)
    152 	} else {
    153 		bytes := n.Bytes()
    154 		if len(bytes) > 0 && bytes[0]&0x80 != 0 {
    155 			// We'll have to pad this with 0x00 in order to stop it
    156 			// looking like a negative number.
    157 			err = out.WriteByte(0)
    158 			if err != nil {
    159 				return
    160 			}
    161 		}
    162 		_, err = out.Write(bytes)
    163 	}
    164 	return
    165 }
    166 
    167 func marshalLength(out *forkableWriter, i int) (err error) {
    168 	n := lengthLength(i)
    169 
    170 	for ; n > 0; n-- {
    171 		err = out.WriteByte(byte(i >> uint((n-1)*8)))
    172 		if err != nil {
    173 			return
    174 		}
    175 	}
    176 
    177 	return nil
    178 }
    179 
    180 func lengthLength(i int) (numBytes int) {
    181 	numBytes = 1
    182 	for i > 255 {
    183 		numBytes++
    184 		i >>= 8
    185 	}
    186 	return
    187 }
    188 
    189 func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) {
    190 	b := uint8(t.class) << 6
    191 	if t.isCompound {
    192 		b |= 0x20
    193 	}
    194 	if t.tag >= 31 {
    195 		b |= 0x1f
    196 		err = out.WriteByte(b)
    197 		if err != nil {
    198 			return
    199 		}
    200 		err = marshalBase128Int(out, int64(t.tag))
    201 		if err != nil {
    202 			return
    203 		}
    204 	} else {
    205 		b |= uint8(t.tag)
    206 		err = out.WriteByte(b)
    207 		if err != nil {
    208 			return
    209 		}
    210 	}
    211 
    212 	if t.length >= 128 {
    213 		l := lengthLength(t.length)
    214 		err = out.WriteByte(0x80 | byte(l))
    215 		if err != nil {
    216 			return
    217 		}
    218 		err = marshalLength(out, t.length)
    219 		if err != nil {
    220 			return
    221 		}
    222 	} else {
    223 		err = out.WriteByte(byte(t.length))
    224 		if err != nil {
    225 			return
    226 		}
    227 	}
    228 
    229 	return nil
    230 }
    231 
    232 func marshalBitString(out *forkableWriter, b BitString) (err error) {
    233 	paddingBits := byte((8 - b.BitLength%8) % 8)
    234 	err = out.WriteByte(paddingBits)
    235 	if err != nil {
    236 		return
    237 	}
    238 	_, err = out.Write(b.Bytes)
    239 	return
    240 }
    241 
    242 func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) {
    243 	if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
    244 		return StructuralError{"invalid object identifier"}
    245 	}
    246 
    247 	err = marshalBase128Int(out, int64(oid[0]*40+oid[1]))
    248 	if err != nil {
    249 		return
    250 	}
    251 	for i := 2; i < len(oid); i++ {
    252 		err = marshalBase128Int(out, int64(oid[i]))
    253 		if err != nil {
    254 			return
    255 		}
    256 	}
    257 
    258 	return
    259 }
    260 
    261 func marshalPrintableString(out *forkableWriter, s string) (err error) {
    262 	b := []byte(s)
    263 	for _, c := range b {
    264 		if !isPrintable(c) {
    265 			return StructuralError{"PrintableString contains invalid character"}
    266 		}
    267 	}
    268 
    269 	_, err = out.Write(b)
    270 	return
    271 }
    272 
    273 func marshalIA5String(out *forkableWriter, s string) (err error) {
    274 	b := []byte(s)
    275 	for _, c := range b {
    276 		if c > 127 {
    277 			return StructuralError{"IA5String contains invalid character"}
    278 		}
    279 	}
    280 
    281 	_, err = out.Write(b)
    282 	return
    283 }
    284 
    285 func marshalUTF8String(out *forkableWriter, s string) (err error) {
    286 	_, err = out.Write([]byte(s))
    287 	return
    288 }
    289 
    290 func marshalTwoDigits(out *forkableWriter, v int) (err error) {
    291 	err = out.WriteByte(byte('0' + (v/10)%10))
    292 	if err != nil {
    293 		return
    294 	}
    295 	return out.WriteByte(byte('0' + v%10))
    296 }
    297 
    298 func marshalFourDigits(out *forkableWriter, v int) (err error) {
    299 	var bytes [4]byte
    300 	for i := range bytes {
    301 		bytes[3-i] = '0' + byte(v%10)
    302 		v /= 10
    303 	}
    304 	_, err = out.Write(bytes[:])
    305 	return
    306 }
    307 
    308 func outsideUTCRange(t time.Time) bool {
    309 	year := t.Year()
    310 	return year < 1950 || year >= 2050
    311 }
    312 
    313 func marshalUTCTime(out *forkableWriter, t time.Time) (err error) {
    314 	year := t.Year()
    315 
    316 	switch {
    317 	case 1950 <= year && year < 2000:
    318 		err = marshalTwoDigits(out, int(year-1900))
    319 	case 2000 <= year && year < 2050:
    320 		err = marshalTwoDigits(out, int(year-2000))
    321 	default:
    322 		return StructuralError{"cannot represent time as UTCTime"}
    323 	}
    324 	if err != nil {
    325 		return
    326 	}
    327 
    328 	return marshalTimeCommon(out, t)
    329 }
    330 
    331 func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) {
    332 	year := t.Year()
    333 	if year < 0 || year > 9999 {
    334 		return StructuralError{"cannot represent time as GeneralizedTime"}
    335 	}
    336 	if err = marshalFourDigits(out, year); err != nil {
    337 		return
    338 	}
    339 
    340 	return marshalTimeCommon(out, t)
    341 }
    342 
    343 func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) {
    344 	_, month, day := t.Date()
    345 
    346 	err = marshalTwoDigits(out, int(month))
    347 	if err != nil {
    348 		return
    349 	}
    350 
    351 	err = marshalTwoDigits(out, day)
    352 	if err != nil {
    353 		return
    354 	}
    355 
    356 	hour, min, sec := t.Clock()
    357 
    358 	err = marshalTwoDigits(out, hour)
    359 	if err != nil {
    360 		return
    361 	}
    362 
    363 	err = marshalTwoDigits(out, min)
    364 	if err != nil {
    365 		return
    366 	}
    367 
    368 	err = marshalTwoDigits(out, sec)
    369 	if err != nil {
    370 		return
    371 	}
    372 
    373 	_, offset := t.Zone()
    374 
    375 	switch {
    376 	case offset/60 == 0:
    377 		err = out.WriteByte('Z')
    378 		return
    379 	case offset > 0:
    380 		err = out.WriteByte('+')
    381 	case offset < 0:
    382 		err = out.WriteByte('-')
    383 	}
    384 
    385 	if err != nil {
    386 		return
    387 	}
    388 
    389 	offsetMinutes := offset / 60
    390 	if offsetMinutes < 0 {
    391 		offsetMinutes = -offsetMinutes
    392 	}
    393 
    394 	err = marshalTwoDigits(out, offsetMinutes/60)
    395 	if err != nil {
    396 		return
    397 	}
    398 
    399 	err = marshalTwoDigits(out, offsetMinutes%60)
    400 	return
    401 }
    402 
    403 func stripTagAndLength(in []byte) []byte {
    404 	_, offset, err := parseTagAndLength(in, 0)
    405 	if err != nil {
    406 		return in
    407 	}
    408 	return in[offset:]
    409 }
    410 
    411 func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) {
    412 	switch value.Type() {
    413 	case flagType:
    414 		return nil
    415 	case timeType:
    416 		t := value.Interface().(time.Time)
    417 		if params.timeType == tagGeneralizedTime || outsideUTCRange(t) {
    418 			return marshalGeneralizedTime(out, t)
    419 		} else {
    420 			return marshalUTCTime(out, t)
    421 		}
    422 	case bitStringType:
    423 		return marshalBitString(out, value.Interface().(BitString))
    424 	case objectIdentifierType:
    425 		return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier))
    426 	case bigIntType:
    427 		return marshalBigInt(out, value.Interface().(*big.Int))
    428 	}
    429 
    430 	switch v := value; v.Kind() {
    431 	case reflect.Bool:
    432 		if v.Bool() {
    433 			return out.WriteByte(255)
    434 		} else {
    435 			return out.WriteByte(0)
    436 		}
    437 	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    438 		return marshalInt64(out, int64(v.Int()))
    439 	case reflect.Struct:
    440 		t := v.Type()
    441 
    442 		startingField := 0
    443 
    444 		// If the first element of the structure is a non-empty
    445 		// RawContents, then we don't bother serializing the rest.
    446 		if t.NumField() > 0 && t.Field(0).Type == rawContentsType {
    447 			s := v.Field(0)
    448 			if s.Len() > 0 {
    449 				bytes := make([]byte, s.Len())
    450 				for i := 0; i < s.Len(); i++ {
    451 					bytes[i] = uint8(s.Index(i).Uint())
    452 				}
    453 				/* The RawContents will contain the tag and
    454 				 * length fields but we'll also be writing
    455 				 * those ourselves, so we strip them out of
    456 				 * bytes */
    457 				_, err = out.Write(stripTagAndLength(bytes))
    458 				return
    459 			} else {
    460 				startingField = 1
    461 			}
    462 		}
    463 
    464 		for i := startingField; i < t.NumField(); i++ {
    465 			var pre *forkableWriter
    466 			pre, out = out.fork()
    467 			err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1")))
    468 			if err != nil {
    469 				return
    470 			}
    471 		}
    472 		return
    473 	case reflect.Slice:
    474 		sliceType := v.Type()
    475 		if sliceType.Elem().Kind() == reflect.Uint8 {
    476 			bytes := make([]byte, v.Len())
    477 			for i := 0; i < v.Len(); i++ {
    478 				bytes[i] = uint8(v.Index(i).Uint())
    479 			}
    480 			_, err = out.Write(bytes)
    481 			return
    482 		}
    483 
    484 		var fp fieldParameters
    485 		for i := 0; i < v.Len(); i++ {
    486 			var pre *forkableWriter
    487 			pre, out = out.fork()
    488 			err = marshalField(pre, v.Index(i), fp)
    489 			if err != nil {
    490 				return
    491 			}
    492 		}
    493 		return
    494 	case reflect.String:
    495 		switch params.stringType {
    496 		case tagIA5String:
    497 			return marshalIA5String(out, v.String())
    498 		case tagPrintableString:
    499 			return marshalPrintableString(out, v.String())
    500 		default:
    501 			return marshalUTF8String(out, v.String())
    502 		}
    503 	}
    504 
    505 	return StructuralError{"unknown Go type"}
    506 }
    507 
    508 func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) {
    509 	// If the field is an interface{} then recurse into it.
    510 	if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
    511 		return marshalField(out, v.Elem(), params)
    512 	}
    513 
    514 	if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
    515 		return
    516 	}
    517 
    518 	if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
    519 		defaultValue := reflect.New(v.Type()).Elem()
    520 		defaultValue.SetInt(*params.defaultValue)
    521 
    522 		if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
    523 			return
    524 		}
    525 	}
    526 
    527 	// If no default value is given then the zero value for the type is
    528 	// assumed to be the default value. This isn't obviously the correct
    529 	// behaviour, but it's what Go has traditionally done.
    530 	if params.optional && params.defaultValue == nil {
    531 		if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
    532 			return
    533 		}
    534 	}
    535 
    536 	if v.Type() == rawValueType {
    537 		rv := v.Interface().(RawValue)
    538 		if len(rv.FullBytes) != 0 {
    539 			_, err = out.Write(rv.FullBytes)
    540 		} else {
    541 			err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound})
    542 			if err != nil {
    543 				return
    544 			}
    545 			_, err = out.Write(rv.Bytes)
    546 		}
    547 		return
    548 	}
    549 
    550 	tag, isCompound, ok := getUniversalType(v.Type())
    551 	if !ok {
    552 		err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
    553 		return
    554 	}
    555 	class := classUniversal
    556 
    557 	if params.timeType != 0 && tag != tagUTCTime {
    558 		return StructuralError{"explicit time type given to non-time member"}
    559 	}
    560 
    561 	if params.stringType != 0 && tag != tagPrintableString {
    562 		return StructuralError{"explicit string type given to non-string member"}
    563 	}
    564 
    565 	switch tag {
    566 	case tagPrintableString:
    567 		if params.stringType == 0 {
    568 			// This is a string without an explicit string type. We'll use
    569 			// a PrintableString if the character set in the string is
    570 			// sufficiently limited, otherwise we'll use a UTF8String.
    571 			for _, r := range v.String() {
    572 				if r >= utf8.RuneSelf || !isPrintable(byte(r)) {
    573 					if !utf8.ValidString(v.String()) {
    574 						return errors.New("asn1: string not valid UTF-8")
    575 					}
    576 					tag = tagUTF8String
    577 					break
    578 				}
    579 			}
    580 		} else {
    581 			tag = params.stringType
    582 		}
    583 	case tagUTCTime:
    584 		if params.timeType == tagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) {
    585 			tag = tagGeneralizedTime
    586 		}
    587 	}
    588 
    589 	if params.set {
    590 		if tag != tagSequence {
    591 			return StructuralError{"non sequence tagged as set"}
    592 		}
    593 		tag = tagSet
    594 	}
    595 
    596 	tags, body := out.fork()
    597 
    598 	err = marshalBody(body, v, params)
    599 	if err != nil {
    600 		return
    601 	}
    602 
    603 	bodyLen := body.Len()
    604 
    605 	var explicitTag *forkableWriter
    606 	if params.explicit {
    607 		explicitTag, tags = tags.fork()
    608 	}
    609 
    610 	if !params.explicit && params.tag != nil {
    611 		// implicit tag.
    612 		tag = *params.tag
    613 		class = classContextSpecific
    614 	}
    615 
    616 	err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound})
    617 	if err != nil {
    618 		return
    619 	}
    620 
    621 	if params.explicit {
    622 		err = marshalTagAndLength(explicitTag, tagAndLength{
    623 			class:      classContextSpecific,
    624 			tag:        *params.tag,
    625 			length:     bodyLen + tags.Len(),
    626 			isCompound: true,
    627 		})
    628 	}
    629 
    630 	return nil
    631 }
    632 
    633 // Marshal returns the ASN.1 encoding of val.
    634 //
    635 // In addition to the struct tags recognised by Unmarshal, the following can be
    636 // used:
    637 //
    638 //	ia5:		causes strings to be marshaled as ASN.1, IA5 strings
    639 //	omitempty:	causes empty slices to be skipped
    640 //	printable:	causes strings to be marshaled as ASN.1, PrintableString strings.
    641 //	utf8:		causes strings to be marshaled as ASN.1, UTF8 strings
    642 func Marshal(val interface{}) ([]byte, error) {
    643 	var out bytes.Buffer
    644 	v := reflect.ValueOf(val)
    645 	f := newForkableWriter()
    646 	err := marshalField(f, v, fieldParameters{})
    647 	if err != nil {
    648 		return nil, err
    649 	}
    650 	_, err = f.writeTo(&out)
    651 	return out.Bytes(), nil
    652 }
    653