1 // Go support for Protocol Buffers - Google's data interchange format 2 // 3 // Copyright 2010 The Go Authors. All rights reserved. 4 // https://github.com/golang/protobuf 5 // 6 // Redistribution and use in source and binary forms, with or without 7 // modification, are permitted provided that the following conditions are 8 // met: 9 // 10 // * Redistributions of source code must retain the above copyright 11 // notice, this list of conditions and the following disclaimer. 12 // * Redistributions in binary form must reproduce the above 13 // copyright notice, this list of conditions and the following disclaimer 14 // in the documentation and/or other materials provided with the 15 // distribution. 16 // * Neither the name of Google Inc. nor the names of its 17 // contributors may be used to endorse or promote products derived from 18 // this software without specific prior written permission. 19 // 20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 32 package proto 33 34 /* 35 * Types and routines for supporting protocol buffer extensions. 36 */ 37 38 import ( 39 "errors" 40 "fmt" 41 "io" 42 "reflect" 43 "strconv" 44 "sync" 45 ) 46 47 // ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message. 48 var ErrMissingExtension = errors.New("proto: missing extension") 49 50 // ExtensionRange represents a range of message extensions for a protocol buffer. 51 // Used in code generated by the protocol compiler. 52 type ExtensionRange struct { 53 Start, End int32 // both inclusive 54 } 55 56 // extendableProto is an interface implemented by any protocol buffer generated by the current 57 // proto compiler that may be extended. 58 type extendableProto interface { 59 Message 60 ExtensionRangeArray() []ExtensionRange 61 extensionsWrite() map[int32]Extension 62 extensionsRead() (map[int32]Extension, sync.Locker) 63 } 64 65 // extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous 66 // version of the proto compiler that may be extended. 67 type extendableProtoV1 interface { 68 Message 69 ExtensionRangeArray() []ExtensionRange 70 ExtensionMap() map[int32]Extension 71 } 72 73 // extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto. 74 type extensionAdapter struct { 75 extendableProtoV1 76 } 77 78 func (e extensionAdapter) extensionsWrite() map[int32]Extension { 79 return e.ExtensionMap() 80 } 81 82 func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) { 83 return e.ExtensionMap(), notLocker{} 84 } 85 86 // notLocker is a sync.Locker whose Lock and Unlock methods are nops. 87 type notLocker struct{} 88 89 func (n notLocker) Lock() {} 90 func (n notLocker) Unlock() {} 91 92 // extendable returns the extendableProto interface for the given generated proto message. 93 // If the proto message has the old extension format, it returns a wrapper that implements 94 // the extendableProto interface. 95 func extendable(p interface{}) (extendableProto, error) { 96 switch p := p.(type) { 97 case extendableProto: 98 if isNilPtr(p) { 99 return nil, fmt.Errorf("proto: nil %T is not extendable", p) 100 } 101 return p, nil 102 case extendableProtoV1: 103 if isNilPtr(p) { 104 return nil, fmt.Errorf("proto: nil %T is not extendable", p) 105 } 106 return extensionAdapter{p}, nil 107 } 108 // Don't allocate a specific error containing %T: 109 // this is the hot path for Clone and MarshalText. 110 return nil, errNotExtendable 111 } 112 113 var errNotExtendable = errors.New("proto: not an extendable proto.Message") 114 115 func isNilPtr(x interface{}) bool { 116 v := reflect.ValueOf(x) 117 return v.Kind() == reflect.Ptr && v.IsNil() 118 } 119 120 // XXX_InternalExtensions is an internal representation of proto extensions. 121 // 122 // Each generated message struct type embeds an anonymous XXX_InternalExtensions field, 123 // thus gaining the unexported 'extensions' method, which can be called only from the proto package. 124 // 125 // The methods of XXX_InternalExtensions are not concurrency safe in general, 126 // but calls to logically read-only methods such as has and get may be executed concurrently. 127 type XXX_InternalExtensions struct { 128 // The struct must be indirect so that if a user inadvertently copies a 129 // generated message and its embedded XXX_InternalExtensions, they 130 // avoid the mayhem of a copied mutex. 131 // 132 // The mutex serializes all logically read-only operations to p.extensionMap. 133 // It is up to the client to ensure that write operations to p.extensionMap are 134 // mutually exclusive with other accesses. 135 p *struct { 136 mu sync.Mutex 137 extensionMap map[int32]Extension 138 } 139 } 140 141 // extensionsWrite returns the extension map, creating it on first use. 142 func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension { 143 if e.p == nil { 144 e.p = new(struct { 145 mu sync.Mutex 146 extensionMap map[int32]Extension 147 }) 148 e.p.extensionMap = make(map[int32]Extension) 149 } 150 return e.p.extensionMap 151 } 152 153 // extensionsRead returns the extensions map for read-only use. It may be nil. 154 // The caller must hold the returned mutex's lock when accessing Elements within the map. 155 func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) { 156 if e.p == nil { 157 return nil, nil 158 } 159 return e.p.extensionMap, &e.p.mu 160 } 161 162 // ExtensionDesc represents an extension specification. 163 // Used in generated code from the protocol compiler. 164 type ExtensionDesc struct { 165 ExtendedType Message // nil pointer to the type that is being extended 166 ExtensionType interface{} // nil pointer to the extension type 167 Field int32 // field number 168 Name string // fully-qualified name of extension, for text formatting 169 Tag string // protobuf tag style 170 Filename string // name of the file in which the extension is defined 171 } 172 173 func (ed *ExtensionDesc) repeated() bool { 174 t := reflect.TypeOf(ed.ExtensionType) 175 return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8 176 } 177 178 // Extension represents an extension in a message. 179 type Extension struct { 180 // When an extension is stored in a message using SetExtension 181 // only desc and value are set. When the message is marshaled 182 // enc will be set to the encoded form of the message. 183 // 184 // When a message is unmarshaled and contains extensions, each 185 // extension will have only enc set. When such an extension is 186 // accessed using GetExtension (or GetExtensions) desc and value 187 // will be set. 188 desc *ExtensionDesc 189 value interface{} 190 enc []byte 191 } 192 193 // SetRawExtension is for testing only. 194 func SetRawExtension(base Message, id int32, b []byte) { 195 epb, err := extendable(base) 196 if err != nil { 197 return 198 } 199 extmap := epb.extensionsWrite() 200 extmap[id] = Extension{enc: b} 201 } 202 203 // isExtensionField returns true iff the given field number is in an extension range. 204 func isExtensionField(pb extendableProto, field int32) bool { 205 for _, er := range pb.ExtensionRangeArray() { 206 if er.Start <= field && field <= er.End { 207 return true 208 } 209 } 210 return false 211 } 212 213 // checkExtensionTypes checks that the given extension is valid for pb. 214 func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error { 215 var pbi interface{} = pb 216 // Check the extended type. 217 if ea, ok := pbi.(extensionAdapter); ok { 218 pbi = ea.extendableProtoV1 219 } 220 if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b { 221 return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a) 222 } 223 // Check the range. 224 if !isExtensionField(pb, extension.Field) { 225 return errors.New("proto: bad extension number; not in declared ranges") 226 } 227 return nil 228 } 229 230 // extPropKey is sufficient to uniquely identify an extension. 231 type extPropKey struct { 232 base reflect.Type 233 field int32 234 } 235 236 var extProp = struct { 237 sync.RWMutex 238 m map[extPropKey]*Properties 239 }{ 240 m: make(map[extPropKey]*Properties), 241 } 242 243 func extensionProperties(ed *ExtensionDesc) *Properties { 244 key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field} 245 246 extProp.RLock() 247 if prop, ok := extProp.m[key]; ok { 248 extProp.RUnlock() 249 return prop 250 } 251 extProp.RUnlock() 252 253 extProp.Lock() 254 defer extProp.Unlock() 255 // Check again. 256 if prop, ok := extProp.m[key]; ok { 257 return prop 258 } 259 260 prop := new(Properties) 261 prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil) 262 extProp.m[key] = prop 263 return prop 264 } 265 266 // HasExtension returns whether the given extension is present in pb. 267 func HasExtension(pb Message, extension *ExtensionDesc) bool { 268 // TODO: Check types, field numbers, etc.? 269 epb, err := extendable(pb) 270 if err != nil { 271 return false 272 } 273 extmap, mu := epb.extensionsRead() 274 if extmap == nil { 275 return false 276 } 277 mu.Lock() 278 _, ok := extmap[extension.Field] 279 mu.Unlock() 280 return ok 281 } 282 283 // ClearExtension removes the given extension from pb. 284 func ClearExtension(pb Message, extension *ExtensionDesc) { 285 epb, err := extendable(pb) 286 if err != nil { 287 return 288 } 289 // TODO: Check types, field numbers, etc.? 290 extmap := epb.extensionsWrite() 291 delete(extmap, extension.Field) 292 } 293 294 // GetExtension retrieves a proto2 extended field from pb. 295 // 296 // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil), 297 // then GetExtension parses the encoded field and returns a Go value of the specified type. 298 // If the field is not present, then the default value is returned (if one is specified), 299 // otherwise ErrMissingExtension is reported. 300 // 301 // If the descriptor is not type complete (i.e., ExtensionDesc.ExtensionType is nil), 302 // then GetExtension returns the raw encoded bytes of the field extension. 303 func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) { 304 epb, err := extendable(pb) 305 if err != nil { 306 return nil, err 307 } 308 309 if extension.ExtendedType != nil { 310 // can only check type if this is a complete descriptor 311 if err := checkExtensionTypes(epb, extension); err != nil { 312 return nil, err 313 } 314 } 315 316 emap, mu := epb.extensionsRead() 317 if emap == nil { 318 return defaultExtensionValue(extension) 319 } 320 mu.Lock() 321 defer mu.Unlock() 322 e, ok := emap[extension.Field] 323 if !ok { 324 // defaultExtensionValue returns the default value or 325 // ErrMissingExtension if there is no default. 326 return defaultExtensionValue(extension) 327 } 328 329 if e.value != nil { 330 // Already decoded. Check the descriptor, though. 331 if e.desc != extension { 332 // This shouldn't happen. If it does, it means that 333 // GetExtension was called twice with two different 334 // descriptors with the same field number. 335 return nil, errors.New("proto: descriptor conflict") 336 } 337 return e.value, nil 338 } 339 340 if extension.ExtensionType == nil { 341 // incomplete descriptor 342 return e.enc, nil 343 } 344 345 v, err := decodeExtension(e.enc, extension) 346 if err != nil { 347 return nil, err 348 } 349 350 // Remember the decoded version and drop the encoded version. 351 // That way it is safe to mutate what we return. 352 e.value = v 353 e.desc = extension 354 e.enc = nil 355 emap[extension.Field] = e 356 return e.value, nil 357 } 358 359 // defaultExtensionValue returns the default value for extension. 360 // If no default for an extension is defined ErrMissingExtension is returned. 361 func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) { 362 if extension.ExtensionType == nil { 363 // incomplete descriptor, so no default 364 return nil, ErrMissingExtension 365 } 366 367 t := reflect.TypeOf(extension.ExtensionType) 368 props := extensionProperties(extension) 369 370 sf, _, err := fieldDefault(t, props) 371 if err != nil { 372 return nil, err 373 } 374 375 if sf == nil || sf.value == nil { 376 // There is no default value. 377 return nil, ErrMissingExtension 378 } 379 380 if t.Kind() != reflect.Ptr { 381 // We do not need to return a Ptr, we can directly return sf.value. 382 return sf.value, nil 383 } 384 385 // We need to return an interface{} that is a pointer to sf.value. 386 value := reflect.New(t).Elem() 387 value.Set(reflect.New(value.Type().Elem())) 388 if sf.kind == reflect.Int32 { 389 // We may have an int32 or an enum, but the underlying data is int32. 390 // Since we can't set an int32 into a non int32 reflect.value directly 391 // set it as a int32. 392 value.Elem().SetInt(int64(sf.value.(int32))) 393 } else { 394 value.Elem().Set(reflect.ValueOf(sf.value)) 395 } 396 return value.Interface(), nil 397 } 398 399 // decodeExtension decodes an extension encoded in b. 400 func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) { 401 t := reflect.TypeOf(extension.ExtensionType) 402 unmarshal := typeUnmarshaler(t, extension.Tag) 403 404 // t is a pointer to a struct, pointer to basic type or a slice. 405 // Allocate space to store the pointer/slice. 406 value := reflect.New(t).Elem() 407 408 var err error 409 for { 410 x, n := decodeVarint(b) 411 if n == 0 { 412 return nil, io.ErrUnexpectedEOF 413 } 414 b = b[n:] 415 wire := int(x) & 7 416 417 b, err = unmarshal(b, valToPointer(value.Addr()), wire) 418 if err != nil { 419 return nil, err 420 } 421 422 if len(b) == 0 { 423 break 424 } 425 } 426 return value.Interface(), nil 427 } 428 429 // GetExtensions returns a slice of the extensions present in pb that are also listed in es. 430 // The returned slice has the same length as es; missing extensions will appear as nil elements. 431 func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) { 432 epb, err := extendable(pb) 433 if err != nil { 434 return nil, err 435 } 436 extensions = make([]interface{}, len(es)) 437 for i, e := range es { 438 extensions[i], err = GetExtension(epb, e) 439 if err == ErrMissingExtension { 440 err = nil 441 } 442 if err != nil { 443 return 444 } 445 } 446 return 447 } 448 449 // ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order. 450 // For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing 451 // just the Field field, which defines the extension's field number. 452 func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) { 453 epb, err := extendable(pb) 454 if err != nil { 455 return nil, err 456 } 457 registeredExtensions := RegisteredExtensions(pb) 458 459 emap, mu := epb.extensionsRead() 460 if emap == nil { 461 return nil, nil 462 } 463 mu.Lock() 464 defer mu.Unlock() 465 extensions := make([]*ExtensionDesc, 0, len(emap)) 466 for extid, e := range emap { 467 desc := e.desc 468 if desc == nil { 469 desc = registeredExtensions[extid] 470 if desc == nil { 471 desc = &ExtensionDesc{Field: extid} 472 } 473 } 474 475 extensions = append(extensions, desc) 476 } 477 return extensions, nil 478 } 479 480 // SetExtension sets the specified extension of pb to the specified value. 481 func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error { 482 epb, err := extendable(pb) 483 if err != nil { 484 return err 485 } 486 if err := checkExtensionTypes(epb, extension); err != nil { 487 return err 488 } 489 typ := reflect.TypeOf(extension.ExtensionType) 490 if typ != reflect.TypeOf(value) { 491 return errors.New("proto: bad extension value type") 492 } 493 // nil extension values need to be caught early, because the 494 // encoder can't distinguish an ErrNil due to a nil extension 495 // from an ErrNil due to a missing field. Extensions are 496 // always optional, so the encoder would just swallow the error 497 // and drop all the extensions from the encoded message. 498 if reflect.ValueOf(value).IsNil() { 499 return fmt.Errorf("proto: SetExtension called with nil value of type %T", value) 500 } 501 502 extmap := epb.extensionsWrite() 503 extmap[extension.Field] = Extension{desc: extension, value: value} 504 return nil 505 } 506 507 // ClearAllExtensions clears all extensions from pb. 508 func ClearAllExtensions(pb Message) { 509 epb, err := extendable(pb) 510 if err != nil { 511 return 512 } 513 m := epb.extensionsWrite() 514 for k := range m { 515 delete(m, k) 516 } 517 } 518 519 // A global registry of extensions. 520 // The generated code will register the generated descriptors by calling RegisterExtension. 521 522 var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc) 523 524 // RegisterExtension is called from the generated code. 525 func RegisterExtension(desc *ExtensionDesc) { 526 st := reflect.TypeOf(desc.ExtendedType).Elem() 527 m := extensionMaps[st] 528 if m == nil { 529 m = make(map[int32]*ExtensionDesc) 530 extensionMaps[st] = m 531 } 532 if _, ok := m[desc.Field]; ok { 533 panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field))) 534 } 535 m[desc.Field] = desc 536 } 537 538 // RegisteredExtensions returns a map of the registered extensions of a 539 // protocol buffer struct, indexed by the extension number. 540 // The argument pb should be a nil pointer to the struct type. 541 func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc { 542 return extensionMaps[reflect.TypeOf(pb).Elem()] 543 } 544