// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2010 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

package proto

/*
 * Types and routines for supporting protocol buffer extensions.
 */
36 */ 37 38 import ( 39 "errors" 40 "fmt" 41 "reflect" 42 "strconv" 43 "sync" 44 ) 45 46 // ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message. 47 var ErrMissingExtension = errors.New("proto: missing extension") 48 49 // ExtensionRange represents a range of message extensions for a protocol buffer. 50 // Used in code generated by the protocol compiler. 51 type ExtensionRange struct { 52 Start, End int32 // both inclusive 53 } 54 55 // extendableProto is an interface implemented by any protocol buffer generated by the current 56 // proto compiler that may be extended. 57 type extendableProto interface { 58 Message 59 ExtensionRangeArray() []ExtensionRange 60 extensionsWrite() map[int32]Extension 61 extensionsRead() (map[int32]Extension, sync.Locker) 62 } 63 64 // extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous 65 // version of the proto compiler that may be extended. 66 type extendableProtoV1 interface { 67 Message 68 ExtensionRangeArray() []ExtensionRange 69 ExtensionMap() map[int32]Extension 70 } 71 72 // extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto. 73 type extensionAdapter struct { 74 extendableProtoV1 75 } 76 77 func (e extensionAdapter) extensionsWrite() map[int32]Extension { 78 return e.ExtensionMap() 79 } 80 81 func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { 82 return e.ExtensionMap(), notLocker{} 83 } 84 85 // notLocker is a sync.Locker whose Lock and Unlock methods are nops. 86 type notLocker struct{} 87 88 func (n notLocker) Lock() {} 89 func (n notLocker) Unlock() {} 90 91 // extendable returns the extendableProto interface for the given generated proto message. 92 // If the proto message has the old extension format, it returns a wrapper that implements 93 // the extendableProto interface. 
94 func extendable(p interface{}) (extendableProto, bool) { 95 if ep, ok := p.(extendableProto); ok { 96 return ep, ok 97 } 98 if ep, ok := p.(extendableProtoV1); ok { 99 return extensionAdapter{ep}, ok 100 } 101 return nil, false 102 } 103 104 // XXX_InternalExtensions is an internal representation of proto extensions. 105 // 106 // Each generated message struct type embeds an anonymous XXX_InternalExtensions field, 107 // thus gaining the unexported 'extensions' method, which can be called only from the proto package. 108 // 109 // The methods of XXX_InternalExtensions are not concurrency safe in general, 110 // but calls to logically read-only methods such as has and get may be executed concurrently. 111 type XXX_InternalExtensions struct { 112 // The struct must be indirect so that if a user inadvertently copies a 113 // generated message and its embedded XXX_InternalExtensions, they 114 // avoid the mayhem of a copied mutex. 115 // 116 // The mutex serializes all logically read-only operations to p.extensionMap. 117 // It is up to the client to ensure that write operations to p.extensionMap are 118 // mutually exclusive with other accesses. 119 p *struct { 120 mu sync.Mutex 121 extensionMap map[int32]Extension 122 } 123 } 124 125 // extensionsWrite returns the extension map, creating it on first use. 126 func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension { 127 if e.p == nil { 128 e.p = new(struct { 129 mu sync.Mutex 130 extensionMap map[int32]Extension 131 }) 132 e.p.extensionMap = make(map[int32]Extension) 133 } 134 return e.p.extensionMap 135 } 136 137 // extensionsRead returns the extensions map for read-only use. It may be nil. 138 // The caller must hold the returned mutex's lock when accessing Elements within the map. 
139 func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) { 140 if e.p == nil { 141 return nil, nil 142 } 143 return e.p.extensionMap, &e.p.mu 144 } 145 146 var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem() 147 var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem() 148 149 // ExtensionDesc represents an extension specification. 150 // Used in generated code from the protocol compiler. 151 type ExtensionDesc struct { 152 ExtendedType Message // nil pointer to the type that is being extended 153 ExtensionType interface{} // nil pointer to the extension type 154 Field int32 // field number 155 Name string // fully-qualified name of extension, for text formatting 156 Tag string // protobuf tag style 157 Filename string // name of the file in which the extension is defined 158 } 159 160 func (ed *ExtensionDesc) repeated() bool { 161 t := reflect.TypeOf(ed.ExtensionType) 162 return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 163 } 164 165 // Extension represents an extension in a message. 166 type Extension struct { 167 // When an extension is stored in a message using SetExtension 168 // only desc and value are set. When the message is marshaled 169 // enc will be set to the encoded form of the message. 170 // 171 // When a message is unmarshaled and contains extensions, each 172 // extension will have only enc set. When such an extension is 173 // accessed using GetExtension (or GetExtensions) desc and value 174 // will be set. 175 desc *ExtensionDesc 176 value interface{} 177 enc []byte 178 } 179 180 // SetRawExtension is for testing only. 181 func SetRawExtension(base Message, id int32, b []byte) { 182 epb, ok := extendable(base) 183 if !ok { 184 return 185 } 186 extmap := epb.extensionsWrite() 187 extmap[id] = Extension{enc: b} 188 } 189 190 // isExtensionField returns true iff the given field number is in an extension range. 
191 func isExtensionField(pb extendableProto, field int32) bool { 192 for _, er := range pb.ExtensionRangeArray() { 193 if er.Start <= field && field <= er.End { 194 return true 195 } 196 } 197 return false 198 } 199 200 // checkExtensionTypes checks that the given extension is valid for pb. 201 func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { 202 var pbi interface{} = pb 203 // Check the extended type. 204 if ea, ok := pbi.(extensionAdapter); ok { 205 pbi = ea.extendableProtoV1 206 } 207 if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b { 208 return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String()) 209 } 210 // Check the range. 211 if !isExtensionField(pb, extension.Field) { 212 return errors.New("proto: bad extension number; not in declared ranges") 213 } 214 return nil 215 } 216 217 // extPropKey is sufficient to uniquely identify an extension. 218 type extPropKey struct { 219 base reflect.Type 220 field int32 221 } 222 223 var extProp = struct { 224 sync.RWMutex 225 m map[extPropKey]*Properties 226 }{ 227 m: make(map[extPropKey]*Properties), 228 } 229 230 func extensionProperties(ed *ExtensionDesc) *Properties { 231 key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field} 232 233 extProp.RLock() 234 if prop, ok := extProp.m[key]; ok { 235 extProp.RUnlock() 236 return prop 237 } 238 extProp.RUnlock() 239 240 extProp.Lock() 241 defer extProp.Unlock() 242 // Check again. 243 if prop, ok := extProp.m[key]; ok { 244 return prop 245 } 246 247 prop := new(Properties) 248 prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil) 249 extProp.m[key] = prop 250 return prop 251 } 252 253 // encode encodes any unmarshaled (unencoded) extensions in e. 
func encodeExtensions(e *XXX_InternalExtensions) error {
	m, mu := e.extensionsRead()
	if m == nil {
		return nil // fast path
	}
	// Hold the extension lock while entries are re-encoded and written back.
	mu.Lock()
	defer mu.Unlock()
	return encodeExtensionsMap(m)
}

// encodeExtensionsMap re-encodes every extension in m that holds a decoded
// value (both value and desc set) and stores the bytes in its enc field.
// Entries that exist only in encoded form are left untouched. It returns the
// first error reported by an extension's encoder.
func encodeExtensionsMap(m map[int32]Extension) error {
	for k, e := range m {
		if e.value == nil || e.desc == nil {
			// Extension is only in its encoded form.
			continue
		}

		// We don't skip extensions that have an encoded form set,
		// because the extension value may have been mutated after
		// the last time this function was called.

		et := reflect.TypeOf(e.desc.ExtensionType)
		props := extensionProperties(e.desc)

		p := NewBuffer(nil)
		// If e.value has type T, the encoder expects a *struct{ X T }.
		// Pass a *T with a zero field and hope it all works out.
		x := reflect.New(et)
		x.Elem().Set(reflect.ValueOf(e.value))
		if err := props.enc(p, props, toStructPointer(x)); err != nil {
			return err
		}
		e.enc = p.buf
		// e is a copy of the map entry; store it back so the new encoding
		// is retained.
		m[k] = e
	}
	return nil
}

// extensionsSize returns the total encoded size of the extensions in e,
// holding the extension lock while it is computed. It returns 0 when no
// extension map has been allocated.
func extensionsSize(e *XXX_InternalExtensions) (n int) {
	m, mu := e.extensionsRead()
	if m == nil {
		return 0
	}
	mu.Lock()
	defer mu.Unlock()
	return extensionsMapSize(m)
}

// extensionsMapSize returns the total encoded size of the extensions in m.
// Entries held only in encoded form contribute len(enc). Entries with a
// decoded value are sized from that value, since it may have changed since
// it was last encoded.
func extensionsMapSize(m map[int32]Extension) (n int) {
	for _, e := range m {
		if e.value == nil || e.desc == nil {
			// Extension is only in its encoded form.
			n += len(e.enc)
			continue
		}

		// We don't skip extensions that have an encoded form set,
		// because the extension value may have been mutated after
		// the last time this function was called.

		et := reflect.TypeOf(e.desc.ExtensionType)
		props := extensionProperties(e.desc)

		// If e.value has type T, the encoder expects a *struct{ X T }.
		// Pass a *T with a zero field and hope it all works out.
		x := reflect.New(et)
		x.Elem().Set(reflect.ValueOf(e.value))
		n += props.size(props, toStructPointer(x))
	}
	return
}

// HasExtension returns whether the given extension is present in pb.
// Only the field number is consulted; the descriptor's types are not checked.
// It returns false for messages that are not extendable.
func HasExtension(pb Message, extension *ExtensionDesc) bool {
	// TODO: Check types, field numbers, etc.?
	epb, ok := extendable(pb)
	if !ok {
		return false
	}
	extmap, mu := epb.extensionsRead()
	if extmap == nil {
		return false
	}
	mu.Lock()
	_, ok = extmap[extension.Field]
	mu.Unlock()
	return ok
}

// ClearExtension removes the given extension from pb, keyed only by its field
// number. It does nothing if pb is not extendable.
func ClearExtension(pb Message, extension *ExtensionDesc) {
	epb, ok := extendable(pb)
	if !ok {
		return
	}
	// TODO: Check types, field numbers, etc.?
	extmap := epb.extensionsWrite()
	delete(extmap, extension.Field)
}

// GetExtension parses and returns the given extension of pb.
// If the extension is not present, its default value is returned, or
// ErrMissingExtension if it declares no default. The first successful call
// decodes the stored bytes and caches the decoded value in place of the
// encoded form, so later calls return that same value and it is safe for
// callers to mutate it.
func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
	epb, ok := extendable(pb)
	if !ok {
		return nil, errors.New("proto: not an extendable proto")
	}

	if err := checkExtensionTypes(epb, extension); err != nil {
		return nil, err
	}

	emap, mu := epb.extensionsRead()
	if emap == nil {
		return defaultExtensionValue(extension)
	}
	// The lock is held for the rest of the call, including the write-back of
	// the decoded value into emap below.
	mu.Lock()
	defer mu.Unlock()
	e, ok := emap[extension.Field]
	if !ok {
		// defaultExtensionValue returns the default value or
		// ErrMissingExtension if there is no default.
		return defaultExtensionValue(extension)
	}

	if e.value != nil {
		// Already decoded. Check the descriptor, though.
		if e.desc != extension {
			// This shouldn't happen. If it does, it means that
			// GetExtension was called twice with two different
			// descriptors with the same field number.
			return nil, errors.New("proto: descriptor conflict")
		}
		return e.value, nil
	}

	v, err := decodeExtension(e.enc, extension)
	if err != nil {
		return nil, err
	}

	// Remember the decoded version and drop the encoded version.
	// That way it is safe to mutate what we return.
	e.value = v
	e.desc = extension
	e.enc = nil
	emap[extension.Field] = e
	return e.value, nil
}

// defaultExtensionValue returns the declared default value for extension, or
// ErrMissingExtension if none is declared. For pointer-typed extensions it
// returns a pointer to a newly allocated copy of the default.
func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
	t := reflect.TypeOf(extension.ExtensionType)
	props := extensionProperties(extension)

	sf, _, err := fieldDefault(t, props)
	if err != nil {
		return nil, err
	}

	if sf == nil || sf.value == nil {
		// There is no default value.
		return nil, ErrMissingExtension
	}

	if t.Kind() != reflect.Ptr {
		// We do not need to return a Ptr, we can directly return sf.value.
		return sf.value, nil
	}

	// We need to return an interface{} that is a pointer to sf.value.
	value := reflect.New(t).Elem()
	value.Set(reflect.New(value.Type().Elem()))
	if sf.kind == reflect.Int32 {
		// We may have an int32 or an enum, but the underlying data is int32.
		// Since we can't set an int32 into a non int32 reflect.value directly
		// set it as a int32.
		value.Elem().SetInt(int64(sf.value.(int32)))
	} else {
		value.Elem().Set(reflect.ValueOf(sf.value))
	}
	return value.Interface(), nil
}

// decodeExtension decodes the wire-format bytes b into a new value of
// extension.ExtensionType. b may hold several consecutive occurrences of the
// field. For each one, the leading tag varint is skipped and the rest is
// decoded by props.dec into the same value. It returns the first error
// encountered.
func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
	o := NewBuffer(b)

	t := reflect.TypeOf(extension.ExtensionType)

	props := extensionProperties(extension)

	// t is a pointer to a struct, pointer to basic type or a slice.
	// Allocate a "field" to store the pointer/slice itself; the
	// pointer/slice will be stored here. We pass
	// the address of this field to props.dec.
	// This passes a zero field and a *t and lets props.dec
	// interpret it as a *struct{ x t }.
	value := reflect.New(t).Elem()

	for {
		// Discard wire type and field number varint. It isn't needed.
		if _, err := o.DecodeVarint(); err != nil {
			return nil, err
		}

		if err := props.dec(o, props, toStructPointer(value.Addr())); err != nil {
			return nil, err
		}

		// Stop once every byte of b has been consumed.
		if o.index >= len(o.buf) {
			break
		}
	}
	return value.Interface(), nil
}

// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
// The returned slice has the same length as es and is index-aligned with it.
// An extension that is absent and has no default appears as nil; one that is
// absent but declares a default yields that default. On any other error, the
// partially filled slice is returned together with the error.
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
	epb, ok := extendable(pb)
	if !ok {
		return nil, errors.New("proto: not an extendable proto")
	}
	extensions = make([]interface{}, len(es))
	for i, e := range es {
		extensions[i], err = GetExtension(epb, e)
		// A missing extension is not an error here; it just leaves a nil slot.
		if err == ErrMissingExtension {
			err = nil
		}
		if err != nil {
			return
		}
	}
	return
}

// ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order.
// A descriptor already attached to an extension (by SetExtension or a prior
// GetExtension) is used first; otherwise the registered descriptor is used.
// For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
// just the Field field, which defines the extension's field number.
// If pb has no extension map it returns nil, nil.
496 func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) { 497 epb, ok := extendable(pb) 498 if !ok { 499 return nil, fmt.Errorf("proto: %T is not an extendable proto.Message", pb) 500 } 501 registeredExtensions := RegisteredExtensions(pb) 502 503 emap, mu := epb.extensionsRead() 504 if emap == nil { 505 return nil, nil 506 } 507 mu.Lock() 508 defer mu.Unlock() 509 extensions := make([]*ExtensionDesc, 0, len(emap)) 510 for extid, e := range emap { 511 desc := e.desc 512 if desc == nil { 513 desc = registeredExtensions[extid] 514 if desc == nil { 515 desc = &ExtensionDesc{Field: extid} 516 } 517 } 518 519 extensions = append(extensions, desc) 520 } 521 return extensions, nil 522 } 523 524 // SetExtension sets the specified extension of pb to the specified value. 525 func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { 526 epb, ok := extendable(pb) 527 if !ok { 528 return errors.New("proto: not an extendable proto") 529 } 530 if err := checkExtensionTypes(epb, extension); err != nil { 531 return err 532 } 533 typ := reflect.TypeOf(extension.ExtensionType) 534 if typ != reflect.TypeOf(value) { 535 return errors.New("proto: bad extension value type") 536 } 537 // nil extension values need to be caught early, because the 538 // encoder can't distinguish an ErrNil due to a nil extension 539 // from an ErrNil due to a missing field. Extensions are 540 // always optional, so the encoder would just swallow the error 541 // and drop all the extensions from the encoded message. 542 if reflect.ValueOf(value).IsNil() { 543 return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) 544 } 545 546 extmap := epb.extensionsWrite() 547 extmap[extension.Field] = Extension{desc: extension, value: value} 548 return nil 549 } 550 551 // ClearAllExtensions clears all extensions from pb. 
552 func ClearAllExtensions(pb Message) { 553 epb, ok := extendable(pb) 554 if !ok { 555 return 556 } 557 m := epb.extensionsWrite() 558 for k := range m { 559 delete(m, k) 560 } 561 } 562 563 // A global registry of extensions. 564 // The generated code will register the generated descriptors by calling RegisterExtension. 565 566 var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc) 567 568 // RegisterExtension is called from the generated code. 569 func RegisterExtension(desc *ExtensionDesc) { 570 st := reflect.TypeOf(desc.ExtendedType).Elem() 571 m := extensionMaps[st] 572 if m == nil { 573 m = make(map[int32]*ExtensionDesc) 574 extensionMaps[st] = m 575 } 576 if _, ok := m[desc.Field]; ok { 577 panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field))) 578 } 579 m[desc.Field] = desc 580 } 581 582 // RegisteredExtensions returns a map of the registered extensions of a 583 // protocol buffer struct, indexed by the extension number. 584 // The argument pb should be a nil pointer to the struct type. 585 func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc { 586 return extensionMaps[reflect.TypeOf(pb).Elem()] 587 } 588