1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package asn1 6 7 import ( 8 "bytes" 9 "errors" 10 "fmt" 11 "io" 12 "math/big" 13 "reflect" 14 "time" 15 "unicode/utf8" 16 ) 17 18 // A forkableWriter is an in-memory buffer that can be 19 // 'forked' to create new forkableWriters that bracket the 20 // original. After 21 // pre, post := w.fork() 22 // the overall sequence of bytes represented is logically w+pre+post. 23 type forkableWriter struct { 24 *bytes.Buffer 25 pre, post *forkableWriter 26 } 27 28 func newForkableWriter() *forkableWriter { 29 return &forkableWriter{new(bytes.Buffer), nil, nil} 30 } 31 32 func (f *forkableWriter) fork() (pre, post *forkableWriter) { 33 if f.pre != nil || f.post != nil { 34 panic("have already forked") 35 } 36 f.pre = newForkableWriter() 37 f.post = newForkableWriter() 38 return f.pre, f.post 39 } 40 41 func (f *forkableWriter) Len() (l int) { 42 l += f.Buffer.Len() 43 if f.pre != nil { 44 l += f.pre.Len() 45 } 46 if f.post != nil { 47 l += f.post.Len() 48 } 49 return 50 } 51 52 func (f *forkableWriter) writeTo(out io.Writer) (n int, err error) { 53 n, err = out.Write(f.Bytes()) 54 if err != nil { 55 return 56 } 57 58 var nn int 59 60 if f.pre != nil { 61 nn, err = f.pre.writeTo(out) 62 n += nn 63 if err != nil { 64 return 65 } 66 } 67 68 if f.post != nil { 69 nn, err = f.post.writeTo(out) 70 n += nn 71 } 72 return 73 } 74 75 func marshalBase128Int(out *forkableWriter, n int64) (err error) { 76 if n == 0 { 77 err = out.WriteByte(0) 78 return 79 } 80 81 l := 0 82 for i := n; i > 0; i >>= 7 { 83 l++ 84 } 85 86 for i := l - 1; i >= 0; i-- { 87 o := byte(n >> uint(i*7)) 88 o &= 0x7f 89 if i != 0 { 90 o |= 0x80 91 } 92 err = out.WriteByte(o) 93 if err != nil { 94 return 95 } 96 } 97 98 return nil 99 } 100 101 func marshalInt64(out *forkableWriter, i int64) (err error) { 102 n := int64Length(i) 103 104 for ; n > 0; n-- { 105 err = out.WriteByte(byte(i >> uint((n-1)*8))) 106 if err != nil { 107 return 108 } 109 } 110 111 return nil 112 } 113 114 func int64Length(i int64) (numBytes int) { 115 numBytes = 1 116 117 for i > 127 { 118 numBytes++ 119 i >>= 8 120 } 121 122 for i < -128 { 123 numBytes++ 124 i >>= 8 125 } 126 127 return 128 } 129 130 func marshalBigInt(out *forkableWriter, n *big.Int) (err error) { 131 if n.Sign() < 0 { 132 // A negative number has to be converted to two's-complement 133 // form. So we'll subtract 1 and invert. If the 134 // most-significant-bit isn't set then we'll need to pad the 135 // beginning with 0xff in order to keep the number negative. 136 nMinus1 := new(big.Int).Neg(n) 137 nMinus1.Sub(nMinus1, bigOne) 138 bytes := nMinus1.Bytes() 139 for i := range bytes { 140 bytes[i] ^= 0xff 141 } 142 if len(bytes) == 0 || bytes[0]&0x80 == 0 { 143 err = out.WriteByte(0xff) 144 if err != nil { 145 return 146 } 147 } 148 _, err = out.Write(bytes) 149 } else if n.Sign() == 0 { 150 // Zero is written as a single 0 zero rather than no bytes. 151 err = out.WriteByte(0x00) 152 } else { 153 bytes := n.Bytes() 154 if len(bytes) > 0 && bytes[0]&0x80 != 0 { 155 // We'll have to pad this with 0x00 in order to stop it 156 // looking like a negative number. 157 err = out.WriteByte(0) 158 if err != nil { 159 return 160 } 161 } 162 _, err = out.Write(bytes) 163 } 164 return 165 } 166 167 func marshalLength(out *forkableWriter, i int) (err error) { 168 n := lengthLength(i) 169 170 for ; n > 0; n-- { 171 err = out.WriteByte(byte(i >> uint((n-1)*8))) 172 if err != nil { 173 return 174 } 175 } 176 177 return nil 178 } 179 180 func lengthLength(i int) (numBytes int) { 181 numBytes = 1 182 for i > 255 { 183 numBytes++ 184 i >>= 8 185 } 186 return 187 } 188 189 func marshalTagAndLength(out *forkableWriter, t tagAndLength) (err error) { 190 b := uint8(t.class) << 6 191 if t.isCompound { 192 b |= 0x20 193 } 194 if t.tag >= 31 { 195 b |= 0x1f 196 err = out.WriteByte(b) 197 if err != nil { 198 return 199 } 200 err = marshalBase128Int(out, int64(t.tag)) 201 if err != nil { 202 return 203 } 204 } else { 205 b |= uint8(t.tag) 206 err = out.WriteByte(b) 207 if err != nil { 208 return 209 } 210 } 211 212 if t.length >= 128 { 213 l := lengthLength(t.length) 214 err = out.WriteByte(0x80 | byte(l)) 215 if err != nil { 216 return 217 } 218 err = marshalLength(out, t.length) 219 if err != nil { 220 return 221 } 222 } else { 223 err = out.WriteByte(byte(t.length)) 224 if err != nil { 225 return 226 } 227 } 228 229 return nil 230 } 231 232 func marshalBitString(out *forkableWriter, b BitString) (err error) { 233 paddingBits := byte((8 - b.BitLength%8) % 8) 234 err = out.WriteByte(paddingBits) 235 if err != nil { 236 return 237 } 238 _, err = out.Write(b.Bytes) 239 return 240 } 241 242 func marshalObjectIdentifier(out *forkableWriter, oid []int) (err error) { 243 if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { 244 return StructuralError{"invalid object identifier"} 245 } 246 247 err = marshalBase128Int(out, int64(oid[0]*40+oid[1])) 248 if err != nil { 249 return 250 } 251 for i := 2; i < len(oid); i++ { 252 err = marshalBase128Int(out, int64(oid[i])) 253 if err != nil { 254 return 255 } 256 } 257 258 return 259 } 260 261 func marshalPrintableString(out *forkableWriter, s string) (err error) { 262 b := []byte(s) 263 for _, c := range b { 264 if !isPrintable(c) { 265 return StructuralError{"PrintableString contains invalid character"} 266 } 267 } 268 269 _, err = out.Write(b) 270 return 271 } 272 273 func marshalIA5String(out *forkableWriter, s string) (err error) { 274 b := []byte(s) 275 for _, c := range b { 276 if c > 127 { 277 return StructuralError{"IA5String contains invalid character"} 278 } 279 } 280 281 _, err = out.Write(b) 282 return 283 } 284 285 func marshalUTF8String(out *forkableWriter, s string) (err error) { 286 _, err = out.Write([]byte(s)) 287 return 288 } 289 290 func marshalTwoDigits(out *forkableWriter, v int) (err error) { 291 err = out.WriteByte(byte('0' + (v/10)%10)) 292 if err != nil { 293 return 294 } 295 return out.WriteByte(byte('0' + v%10)) 296 } 297 298 func marshalFourDigits(out *forkableWriter, v int) (err error) { 299 var bytes [4]byte 300 for i := range bytes { 301 bytes[3-i] = '0' + byte(v%10) 302 v /= 10 303 } 304 _, err = out.Write(bytes[:]) 305 return 306 } 307 308 func outsideUTCRange(t time.Time) bool { 309 year := t.Year() 310 return year < 1950 || year >= 2050 311 } 312 313 func marshalUTCTime(out *forkableWriter, t time.Time) (err error) { 314 year := t.Year() 315 316 switch { 317 case 1950 <= year && year < 2000: 318 err = marshalTwoDigits(out, int(year-1900)) 319 case 2000 <= year && year < 2050: 320 err = marshalTwoDigits(out, int(year-2000)) 321 default: 322 return StructuralError{"cannot represent time as UTCTime"} 323 } 324 if err != nil { 325 return 326 } 327 328 return marshalTimeCommon(out, t) 329 } 330 331 func marshalGeneralizedTime(out *forkableWriter, t time.Time) (err error) { 332 year := t.Year() 333 if year < 0 || year > 9999 { 334 return StructuralError{"cannot represent time as GeneralizedTime"} 335 } 336 if err = marshalFourDigits(out, year); err != nil { 337 return 338 } 339 340 return marshalTimeCommon(out, t) 341 } 342 343 func marshalTimeCommon(out *forkableWriter, t time.Time) (err error) { 344 _, month, day := t.Date() 345 346 err = marshalTwoDigits(out, int(month)) 347 if err != nil { 348 return 349 } 350 351 err = marshalTwoDigits(out, day) 352 if err != nil { 353 return 354 } 355 356 hour, min, sec := t.Clock() 357 358 err = marshalTwoDigits(out, hour) 359 if err != nil { 360 return 361 } 362 363 err = marshalTwoDigits(out, min) 364 if err != nil { 365 return 366 } 367 368 err = marshalTwoDigits(out, sec) 369 if err != nil { 370 return 371 } 372 373 _, offset := t.Zone() 374 375 switch { 376 case offset/60 == 0: 377 err = out.WriteByte('Z') 378 return 379 case offset > 0: 380 err = out.WriteByte('+') 381 case offset < 0: 382 err = out.WriteByte('-') 383 } 384 385 if err != nil { 386 return 387 } 388 389 offsetMinutes := offset / 60 390 if offsetMinutes < 0 { 391 offsetMinutes = -offsetMinutes 392 } 393 394 err = marshalTwoDigits(out, offsetMinutes/60) 395 if err != nil { 396 return 397 } 398 399 err = marshalTwoDigits(out, offsetMinutes%60) 400 return 401 } 402 403 func stripTagAndLength(in []byte) []byte { 404 _, offset, err := parseTagAndLength(in, 0) 405 if err != nil { 406 return in 407 } 408 return in[offset:] 409 } 410 411 func marshalBody(out *forkableWriter, value reflect.Value, params fieldParameters) (err error) { 412 switch value.Type() { 413 case flagType: 414 return nil 415 case timeType: 416 t := value.Interface().(time.Time) 417 if params.timeType == tagGeneralizedTime || outsideUTCRange(t) { 418 return marshalGeneralizedTime(out, t) 419 } else { 420 return marshalUTCTime(out, t) 421 } 422 case bitStringType: 423 return marshalBitString(out, value.Interface().(BitString)) 424 case objectIdentifierType: 425 return marshalObjectIdentifier(out, value.Interface().(ObjectIdentifier)) 426 case bigIntType: 427 return marshalBigInt(out, value.Interface().(*big.Int)) 428 } 429 430 switch v := value; v.Kind() { 431 case reflect.Bool: 432 if v.Bool() { 433 return out.WriteByte(255) 434 } else { 435 return out.WriteByte(0) 436 } 437 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 438 return marshalInt64(out, int64(v.Int())) 439 case reflect.Struct: 440 t := v.Type() 441 442 startingField := 0 443 444 // If the first element of the structure is a non-empty 445 // RawContents, then we don't bother serializing the rest. 446 if t.NumField() > 0 && t.Field(0).Type == rawContentsType { 447 s := v.Field(0) 448 if s.Len() > 0 { 449 bytes := make([]byte, s.Len()) 450 for i := 0; i < s.Len(); i++ { 451 bytes[i] = uint8(s.Index(i).Uint()) 452 } 453 /* The RawContents will contain the tag and 454 * length fields but we'll also be writing 455 * those ourselves, so we strip them out of 456 * bytes */ 457 _, err = out.Write(stripTagAndLength(bytes)) 458 return 459 } else { 460 startingField = 1 461 } 462 } 463 464 for i := startingField; i < t.NumField(); i++ { 465 var pre *forkableWriter 466 pre, out = out.fork() 467 err = marshalField(pre, v.Field(i), parseFieldParameters(t.Field(i).Tag.Get("asn1"))) 468 if err != nil { 469 return 470 } 471 } 472 return 473 case reflect.Slice: 474 sliceType := v.Type() 475 if sliceType.Elem().Kind() == reflect.Uint8 { 476 bytes := make([]byte, v.Len()) 477 for i := 0; i < v.Len(); i++ { 478 bytes[i] = uint8(v.Index(i).Uint()) 479 } 480 _, err = out.Write(bytes) 481 return 482 } 483 484 var fp fieldParameters 485 for i := 0; i < v.Len(); i++ { 486 var pre *forkableWriter 487 pre, out = out.fork() 488 err = marshalField(pre, v.Index(i), fp) 489 if err != nil { 490 return 491 } 492 } 493 return 494 case reflect.String: 495 switch params.stringType { 496 case tagIA5String: 497 return marshalIA5String(out, v.String()) 498 case tagPrintableString: 499 return marshalPrintableString(out, v.String()) 500 default: 501 return marshalUTF8String(out, v.String()) 502 } 503 } 504 505 return StructuralError{"unknown Go type"} 506 } 507 508 func marshalField(out *forkableWriter, v reflect.Value, params fieldParameters) (err error) { 509 // If the field is an interface{} then recurse into it. 510 if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 { 511 return marshalField(out, v.Elem(), params) 512 } 513 514 if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty { 515 return 516 } 517 518 if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) { 519 defaultValue := reflect.New(v.Type()).Elem() 520 defaultValue.SetInt(*params.defaultValue) 521 522 if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) { 523 return 524 } 525 } 526 527 // If no default value is given then the zero value for the type is 528 // assumed to be the default value. This isn't obviously the correct 529 // behaviour, but it's what Go has traditionally done. 530 if params.optional && params.defaultValue == nil { 531 if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) { 532 return 533 } 534 } 535 536 if v.Type() == rawValueType { 537 rv := v.Interface().(RawValue) 538 if len(rv.FullBytes) != 0 { 539 _, err = out.Write(rv.FullBytes) 540 } else { 541 err = marshalTagAndLength(out, tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}) 542 if err != nil { 543 return 544 } 545 _, err = out.Write(rv.Bytes) 546 } 547 return 548 } 549 550 tag, isCompound, ok := getUniversalType(v.Type()) 551 if !ok { 552 err = StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())} 553 return 554 } 555 class := classUniversal 556 557 if params.timeType != 0 && tag != tagUTCTime { 558 return StructuralError{"explicit time type given to non-time member"} 559 } 560 561 if params.stringType != 0 && tag != tagPrintableString { 562 return StructuralError{"explicit string type given to non-string member"} 563 } 564 565 switch tag { 566 case tagPrintableString: 567 if params.stringType == 0 { 568 // This is a string without an explicit string type. We'll use 569 // a PrintableString if the character set in the string is 570 // sufficiently limited, otherwise we'll use a UTF8String. 571 for _, r := range v.String() { 572 if r >= utf8.RuneSelf || !isPrintable(byte(r)) { 573 if !utf8.ValidString(v.String()) { 574 return errors.New("asn1: string not valid UTF-8") 575 } 576 tag = tagUTF8String 577 break 578 } 579 } 580 } else { 581 tag = params.stringType 582 } 583 case tagUTCTime: 584 if params.timeType == tagGeneralizedTime || outsideUTCRange(v.Interface().(time.Time)) { 585 tag = tagGeneralizedTime 586 } 587 } 588 589 if params.set { 590 if tag != tagSequence { 591 return StructuralError{"non sequence tagged as set"} 592 } 593 tag = tagSet 594 } 595 596 tags, body := out.fork() 597 598 err = marshalBody(body, v, params) 599 if err != nil { 600 return 601 } 602 603 bodyLen := body.Len() 604 605 var explicitTag *forkableWriter 606 if params.explicit { 607 explicitTag, tags = tags.fork() 608 } 609 610 if !params.explicit && params.tag != nil { 611 // implicit tag. 612 tag = *params.tag 613 class = classContextSpecific 614 } 615 616 err = marshalTagAndLength(tags, tagAndLength{class, tag, bodyLen, isCompound}) 617 if err != nil { 618 return 619 } 620 621 if params.explicit { 622 err = marshalTagAndLength(explicitTag, tagAndLength{ 623 class: classContextSpecific, 624 tag: *params.tag, 625 length: bodyLen + tags.Len(), 626 isCompound: true, 627 }) 628 } 629 630 return nil 631 } 632 633 // Marshal returns the ASN.1 encoding of val. 634 // 635 // In addition to the struct tags recognised by Unmarshal, the following can be 636 // used: 637 // 638 // ia5: causes strings to be marshaled as ASN.1, IA5 strings 639 // omitempty: causes empty slices to be skipped 640 // printable: causes strings to be marshaled as ASN.1, PrintableString strings. 641 // utf8: causes strings to be marshaled as ASN.1, UTF8 strings 642 func Marshal(val interface{}) ([]byte, error) { 643 var out bytes.Buffer 644 v := reflect.ValueOf(val) 645 f := newForkableWriter() 646 err := marshalField(f, v, fieldParameters{}) 647 if err != nil { 648 return nil, err 649 } 650 _, err = f.writeTo(&out) 651 return out.Bytes(), nil 652 } 653