Home | History | Annotate | Download | only in proto
      1 // Go support for Protocol Buffers - Google's data interchange format
      2 //
      3 // Copyright 2010 The Go Authors.  All rights reserved.
      4 // https://github.com/golang/protobuf
      5 //
      6 // Redistribution and use in source and binary forms, with or without
      7 // modification, are permitted provided that the following conditions are
      8 // met:
      9 //
     10 //     * Redistributions of source code must retain the above copyright
     11 // notice, this list of conditions and the following disclaimer.
     12 //     * Redistributions in binary form must reproduce the above
     13 // copyright notice, this list of conditions and the following disclaimer
     14 // in the documentation and/or other materials provided with the
     15 // distribution.
     16 //     * Neither the name of Google Inc. nor the names of its
     17 // contributors may be used to endorse or promote products derived from
     18 // this software without specific prior written permission.
     19 //
     20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     31 
     32 package proto
     33 
     34 /*
     35  * Types and routines for supporting protocol buffer extensions.
     36  */
     37 
     38 import (
     39 	"errors"
     40 	"fmt"
     41 	"io"
     42 	"reflect"
     43 	"strconv"
     44 	"sync"
     45 )
     46 
     47 // ErrMissingExtension is the error returned by GetExtension if the named extension is not in the message.
     48 var ErrMissingExtension = errors.New("proto: missing extension")
     49 
     50 // ExtensionRange represents a range of message extensions for a protocol buffer.
     51 // Used in code generated by the protocol compiler.
     52 type ExtensionRange struct {
     53 	Start, End int32 // both inclusive
     54 }
     55 
     56 // extendableProto is an interface implemented by any protocol buffer generated by the current
     57 // proto compiler that may be extended.
     58 type extendableProto interface {
     59 	Message
     60 	ExtensionRangeArray() []ExtensionRange
     61 	extensionsWrite() map[int32]Extension
     62 	extensionsRead() (map[int32]Extension, sync.Locker)
     63 }
     64 
     65 // extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
     66 // version of the proto compiler that may be extended.
     67 type extendableProtoV1 interface {
     68 	Message
     69 	ExtensionRangeArray() []ExtensionRange
     70 	ExtensionMap() map[int32]Extension
     71 }
     72 
     73 // extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
     74 type extensionAdapter struct {
     75 	extendableProtoV1
     76 }
     77 
     78 func (e extensionAdapter) extensionsWrite() map[int32]Extension {
     79 	return e.ExtensionMap()
     80 }
     81 
     82 func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
     83 	return e.ExtensionMap(), notLocker{}
     84 }
     85 
     86 // notLocker is a sync.Locker whose Lock and Unlock methods are nops.
     87 type notLocker struct{}
     88 
     89 func (n notLocker) Lock()   {}
     90 func (n notLocker) Unlock() {}
     91 
     92 // extendable returns the extendableProto interface for the given generated proto message.
     93 // If the proto message has the old extension format, it returns a wrapper that implements
     94 // the extendableProto interface.
     95 func extendable(p interface{}) (extendableProto, error) {
     96 	switch p := p.(type) {
     97 	case extendableProto:
     98 		if isNilPtr(p) {
     99 			return nil, fmt.Errorf("proto: nil %T is not extendable", p)
    100 		}
    101 		return p, nil
    102 	case extendableProtoV1:
    103 		if isNilPtr(p) {
    104 			return nil, fmt.Errorf("proto: nil %T is not extendable", p)
    105 		}
    106 		return extensionAdapter{p}, nil
    107 	}
    108 	// Don't allocate a specific error containing %T:
    109 	// this is the hot path for Clone and MarshalText.
    110 	return nil, errNotExtendable
    111 }
    112 
    113 var errNotExtendable = errors.New("proto: not an extendable proto.Message")
    114 
    115 func isNilPtr(x interface{}) bool {
    116 	v := reflect.ValueOf(x)
    117 	return v.Kind() == reflect.Ptr && v.IsNil()
    118 }
    119 
// XXX_InternalExtensions is an internal representation of proto extensions.
//
// Each generated message struct type embeds an anonymous XXX_InternalExtensions
// field, thereby gaining the unexported extensionsWrite and extensionsRead
// methods, which can be called only from within the proto package.
//
// The methods of XXX_InternalExtensions are not concurrency safe in general,
// but logically read-only operations (such as HasExtension and GetExtension,
// which go through extensionsRead and take its lock) may run concurrently.
type XXX_InternalExtensions struct {
	// The struct must be indirect so that if a user inadvertently copies a
	// generated message and its embedded XXX_InternalExtensions, the copy
	// shares this pointer instead of duplicating the mutex.
	//
	// mu serializes all logically read-only operations on extensionMap.
	// It is up to the client to ensure that write operations to extensionMap
	// are mutually exclusive with other accesses.
	//
	// p is nil until extensionsWrite allocates it lazily.
	//
	// NOTE(review): this anonymous struct type is spelled out again in
	// extensionsWrite; both must keep the same fields in the same order or
	// the assignment there stops compiling.
	p *struct {
		mu           sync.Mutex
		extensionMap map[int32]Extension
	}
}
    140 
    141 // extensionsWrite returns the extension map, creating it on first use.
    142 func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
    143 	if e.p == nil {
    144 		e.p = new(struct {
    145 			mu           sync.Mutex
    146 			extensionMap map[int32]Extension
    147 		})
    148 		e.p.extensionMap = make(map[int32]Extension)
    149 	}
    150 	return e.p.extensionMap
    151 }
    152 
    153 // extensionsRead returns the extensions map for read-only use.  It may be nil.
    154 // The caller must hold the returned mutex's lock when accessing Elements within the map.
    155 func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
    156 	if e.p == nil {
    157 		return nil, nil
    158 	}
    159 	return e.p.extensionMap, &e.p.mu
    160 }
    161 
    162 // ExtensionDesc represents an extension specification.
    163 // Used in generated code from the protocol compiler.
    164 type ExtensionDesc struct {
    165 	ExtendedType  Message     // nil pointer to the type that is being extended
    166 	ExtensionType interface{} // nil pointer to the extension type
    167 	Field         int32       // field number
    168 	Name          string      // fully-qualified name of extension, for text formatting
    169 	Tag           string      // protobuf tag style
    170 	Filename      string      // name of the file in which the extension is defined
    171 }
    172 
    173 func (ed *ExtensionDesc) repeated() bool {
    174 	t := reflect.TypeOf(ed.ExtensionType)
    175 	return t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
    176 }
    177 
    178 // Extension represents an extension in a message.
    179 type Extension struct {
    180 	// When an extension is stored in a message using SetExtension
    181 	// only desc and value are set. When the message is marshaled
    182 	// enc will be set to the encoded form of the message.
    183 	//
    184 	// When a message is unmarshaled and contains extensions, each
    185 	// extension will have only enc set. When such an extension is
    186 	// accessed using GetExtension (or GetExtensions) desc and value
    187 	// will be set.
    188 	desc  *ExtensionDesc
    189 	value interface{}
    190 	enc   []byte
    191 }
    192 
    193 // SetRawExtension is for testing only.
    194 func SetRawExtension(base Message, id int32, b []byte) {
    195 	epb, err := extendable(base)
    196 	if err != nil {
    197 		return
    198 	}
    199 	extmap := epb.extensionsWrite()
    200 	extmap[id] = Extension{enc: b}
    201 }
    202 
    203 // isExtensionField returns true iff the given field number is in an extension range.
    204 func isExtensionField(pb extendableProto, field int32) bool {
    205 	for _, er := range pb.ExtensionRangeArray() {
    206 		if er.Start <= field && field <= er.End {
    207 			return true
    208 		}
    209 	}
    210 	return false
    211 }
    212 
    213 // checkExtensionTypes checks that the given extension is valid for pb.
    214 func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
    215 	var pbi interface{} = pb
    216 	// Check the extended type.
    217 	if ea, ok := pbi.(extensionAdapter); ok {
    218 		pbi = ea.extendableProtoV1
    219 	}
    220 	if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
    221 		return fmt.Errorf("proto: bad extended type; %v does not extend %v", b, a)
    222 	}
    223 	// Check the range.
    224 	if !isExtensionField(pb, extension.Field) {
    225 		return errors.New("proto: bad extension number; not in declared ranges")
    226 	}
    227 	return nil
    228 }
    229 
    230 // extPropKey is sufficient to uniquely identify an extension.
    231 type extPropKey struct {
    232 	base  reflect.Type
    233 	field int32
    234 }
    235 
    236 var extProp = struct {
    237 	sync.RWMutex
    238 	m map[extPropKey]*Properties
    239 }{
    240 	m: make(map[extPropKey]*Properties),
    241 }
    242 
    243 func extensionProperties(ed *ExtensionDesc) *Properties {
    244 	key := extPropKey{base: reflect.TypeOf(ed.ExtendedType), field: ed.Field}
    245 
    246 	extProp.RLock()
    247 	if prop, ok := extProp.m[key]; ok {
    248 		extProp.RUnlock()
    249 		return prop
    250 	}
    251 	extProp.RUnlock()
    252 
    253 	extProp.Lock()
    254 	defer extProp.Unlock()
    255 	// Check again.
    256 	if prop, ok := extProp.m[key]; ok {
    257 		return prop
    258 	}
    259 
    260 	prop := new(Properties)
    261 	prop.Init(reflect.TypeOf(ed.ExtensionType), "unknown_name", ed.Tag, nil)
    262 	extProp.m[key] = prop
    263 	return prop
    264 }
    265 
    266 // HasExtension returns whether the given extension is present in pb.
    267 func HasExtension(pb Message, extension *ExtensionDesc) bool {
    268 	// TODO: Check types, field numbers, etc.?
    269 	epb, err := extendable(pb)
    270 	if err != nil {
    271 		return false
    272 	}
    273 	extmap, mu := epb.extensionsRead()
    274 	if extmap == nil {
    275 		return false
    276 	}
    277 	mu.Lock()
    278 	_, ok := extmap[extension.Field]
    279 	mu.Unlock()
    280 	return ok
    281 }
    282 
    283 // ClearExtension removes the given extension from pb.
    284 func ClearExtension(pb Message, extension *ExtensionDesc) {
    285 	epb, err := extendable(pb)
    286 	if err != nil {
    287 		return
    288 	}
    289 	// TODO: Check types, field numbers, etc.?
    290 	extmap := epb.extensionsWrite()
    291 	delete(extmap, extension.Field)
    292 }
    293 
    294 // GetExtension retrieves a proto2 extended field from pb.
    295 //
    296 // If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
    297 // then GetExtension parses the encoded field and returns a Go value of the specified type.
    298 // If the field is not present, then the default value is returned (if one is specified),
    299 // otherwise ErrMissingExtension is reported.
    300 //
    301 // If the descriptor is not type complete (i.e., ExtensionDesc.ExtensionType is nil),
    302 // then GetExtension returns the raw encoded bytes of the field extension.
    303 func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
    304 	epb, err := extendable(pb)
    305 	if err != nil {
    306 		return nil, err
    307 	}
    308 
    309 	if extension.ExtendedType != nil {
    310 		// can only check type if this is a complete descriptor
    311 		if err := checkExtensionTypes(epb, extension); err != nil {
    312 			return nil, err
    313 		}
    314 	}
    315 
    316 	emap, mu := epb.extensionsRead()
    317 	if emap == nil {
    318 		return defaultExtensionValue(extension)
    319 	}
    320 	mu.Lock()
    321 	defer mu.Unlock()
    322 	e, ok := emap[extension.Field]
    323 	if !ok {
    324 		// defaultExtensionValue returns the default value or
    325 		// ErrMissingExtension if there is no default.
    326 		return defaultExtensionValue(extension)
    327 	}
    328 
    329 	if e.value != nil {
    330 		// Already decoded. Check the descriptor, though.
    331 		if e.desc != extension {
    332 			// This shouldn't happen. If it does, it means that
    333 			// GetExtension was called twice with two different
    334 			// descriptors with the same field number.
    335 			return nil, errors.New("proto: descriptor conflict")
    336 		}
    337 		return e.value, nil
    338 	}
    339 
    340 	if extension.ExtensionType == nil {
    341 		// incomplete descriptor
    342 		return e.enc, nil
    343 	}
    344 
    345 	v, err := decodeExtension(e.enc, extension)
    346 	if err != nil {
    347 		return nil, err
    348 	}
    349 
    350 	// Remember the decoded version and drop the encoded version.
    351 	// That way it is safe to mutate what we return.
    352 	e.value = v
    353 	e.desc = extension
    354 	e.enc = nil
    355 	emap[extension.Field] = e
    356 	return e.value, nil
    357 }
    358 
    359 // defaultExtensionValue returns the default value for extension.
    360 // If no default for an extension is defined ErrMissingExtension is returned.
    361 func defaultExtensionValue(extension *ExtensionDesc) (interface{}, error) {
    362 	if extension.ExtensionType == nil {
    363 		// incomplete descriptor, so no default
    364 		return nil, ErrMissingExtension
    365 	}
    366 
    367 	t := reflect.TypeOf(extension.ExtensionType)
    368 	props := extensionProperties(extension)
    369 
    370 	sf, _, err := fieldDefault(t, props)
    371 	if err != nil {
    372 		return nil, err
    373 	}
    374 
    375 	if sf == nil || sf.value == nil {
    376 		// There is no default value.
    377 		return nil, ErrMissingExtension
    378 	}
    379 
    380 	if t.Kind() != reflect.Ptr {
    381 		// We do not need to return a Ptr, we can directly return sf.value.
    382 		return sf.value, nil
    383 	}
    384 
    385 	// We need to return an interface{} that is a pointer to sf.value.
    386 	value := reflect.New(t).Elem()
    387 	value.Set(reflect.New(value.Type().Elem()))
    388 	if sf.kind == reflect.Int32 {
    389 		// We may have an int32 or an enum, but the underlying data is int32.
    390 		// Since we can't set an int32 into a non int32 reflect.value directly
    391 		// set it as a int32.
    392 		value.Elem().SetInt(int64(sf.value.(int32)))
    393 	} else {
    394 		value.Elem().Set(reflect.ValueOf(sf.value))
    395 	}
    396 	return value.Interface(), nil
    397 }
    398 
    399 // decodeExtension decodes an extension encoded in b.
    400 func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
    401 	t := reflect.TypeOf(extension.ExtensionType)
    402 	unmarshal := typeUnmarshaler(t, extension.Tag)
    403 
    404 	// t is a pointer to a struct, pointer to basic type or a slice.
    405 	// Allocate space to store the pointer/slice.
    406 	value := reflect.New(t).Elem()
    407 
    408 	var err error
    409 	for {
    410 		x, n := decodeVarint(b)
    411 		if n == 0 {
    412 			return nil, io.ErrUnexpectedEOF
    413 		}
    414 		b = b[n:]
    415 		wire := int(x) & 7
    416 
    417 		b, err = unmarshal(b, valToPointer(value.Addr()), wire)
    418 		if err != nil {
    419 			return nil, err
    420 		}
    421 
    422 		if len(b) == 0 {
    423 			break
    424 		}
    425 	}
    426 	return value.Interface(), nil
    427 }
    428 
    429 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.
    430 // The returned slice has the same length as es; missing extensions will appear as nil elements.
    431 func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
    432 	epb, err := extendable(pb)
    433 	if err != nil {
    434 		return nil, err
    435 	}
    436 	extensions = make([]interface{}, len(es))
    437 	for i, e := range es {
    438 		extensions[i], err = GetExtension(epb, e)
    439 		if err == ErrMissingExtension {
    440 			err = nil
    441 		}
    442 		if err != nil {
    443 			return
    444 		}
    445 	}
    446 	return
    447 }
    448 
    449 // ExtensionDescs returns a new slice containing pb's extension descriptors, in undefined order.
    450 // For non-registered extensions, ExtensionDescs returns an incomplete descriptor containing
    451 // just the Field field, which defines the extension's field number.
    452 func ExtensionDescs(pb Message) ([]*ExtensionDesc, error) {
    453 	epb, err := extendable(pb)
    454 	if err != nil {
    455 		return nil, err
    456 	}
    457 	registeredExtensions := RegisteredExtensions(pb)
    458 
    459 	emap, mu := epb.extensionsRead()
    460 	if emap == nil {
    461 		return nil, nil
    462 	}
    463 	mu.Lock()
    464 	defer mu.Unlock()
    465 	extensions := make([]*ExtensionDesc, 0, len(emap))
    466 	for extid, e := range emap {
    467 		desc := e.desc
    468 		if desc == nil {
    469 			desc = registeredExtensions[extid]
    470 			if desc == nil {
    471 				desc = &ExtensionDesc{Field: extid}
    472 			}
    473 		}
    474 
    475 		extensions = append(extensions, desc)
    476 	}
    477 	return extensions, nil
    478 }
    479 
    480 // SetExtension sets the specified extension of pb to the specified value.
    481 func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
    482 	epb, err := extendable(pb)
    483 	if err != nil {
    484 		return err
    485 	}
    486 	if err := checkExtensionTypes(epb, extension); err != nil {
    487 		return err
    488 	}
    489 	typ := reflect.TypeOf(extension.ExtensionType)
    490 	if typ != reflect.TypeOf(value) {
    491 		return errors.New("proto: bad extension value type")
    492 	}
    493 	// nil extension values need to be caught early, because the
    494 	// encoder can't distinguish an ErrNil due to a nil extension
    495 	// from an ErrNil due to a missing field. Extensions are
    496 	// always optional, so the encoder would just swallow the error
    497 	// and drop all the extensions from the encoded message.
    498 	if reflect.ValueOf(value).IsNil() {
    499 		return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
    500 	}
    501 
    502 	extmap := epb.extensionsWrite()
    503 	extmap[extension.Field] = Extension{desc: extension, value: value}
    504 	return nil
    505 }
    506 
    507 // ClearAllExtensions clears all extensions from pb.
    508 func ClearAllExtensions(pb Message) {
    509 	epb, err := extendable(pb)
    510 	if err != nil {
    511 		return
    512 	}
    513 	m := epb.extensionsWrite()
    514 	for k := range m {
    515 		delete(m, k)
    516 	}
    517 }
    518 
    519 // A global registry of extensions.
    520 // The generated code will register the generated descriptors by calling RegisterExtension.
    521 
    522 var extensionMaps = make(map[reflect.Type]map[int32]*ExtensionDesc)
    523 
    524 // RegisterExtension is called from the generated code.
    525 func RegisterExtension(desc *ExtensionDesc) {
    526 	st := reflect.TypeOf(desc.ExtendedType).Elem()
    527 	m := extensionMaps[st]
    528 	if m == nil {
    529 		m = make(map[int32]*ExtensionDesc)
    530 		extensionMaps[st] = m
    531 	}
    532 	if _, ok := m[desc.Field]; ok {
    533 		panic("proto: duplicate extension registered: " + st.String() + " " + strconv.Itoa(int(desc.Field)))
    534 	}
    535 	m[desc.Field] = desc
    536 }
    537 
    538 // RegisteredExtensions returns a map of the registered extensions of a
    539 // protocol buffer struct, indexed by the extension number.
    540 // The argument pb should be a nil pointer to the struct type.
    541 func RegisteredExtensions(pb Message) map[int32]*ExtensionDesc {
    542 	return extensionMaps[reflect.TypeOf(pb).Elem()]
    543 }
    544