Home | History | Annotate | Download | only in proto
      1 // Go support for Protocol Buffers - Google's data interchange format
      2 //
      3 // Copyright 2014 The Go Authors.  All rights reserved.
      4 // https://github.com/golang/protobuf
      5 //
      6 // Redistribution and use in source and binary forms, with or without
      7 // modification, are permitted provided that the following conditions are
      8 // met:
      9 //
     10 //     * Redistributions of source code must retain the above copyright
     11 // notice, this list of conditions and the following disclaimer.
     12 //     * Redistributions in binary form must reproduce the above
     13 // copyright notice, this list of conditions and the following disclaimer
     14 // in the documentation and/or other materials provided with the
     15 // distribution.
     16 //     * Neither the name of Google Inc. nor the names of its
     17 // contributors may be used to endorse or promote products derived from
     18 // this software without specific prior written permission.
     19 //
     20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     31 
     32 // AOSP change: ignore this file, since AOSP does not include
     33 // golang.org/x/sync/errgroup
     34 // +build ignore
     35 
     36 package proto_test
     37 
     38 import (
     39 	"bytes"
     40 	"fmt"
     41 	"io"
     42 	"reflect"
     43 	"sort"
     44 	"strings"
     45 	"testing"
     46 
     47 	"github.com/golang/protobuf/proto"
     48 	pb "github.com/golang/protobuf/proto/test_proto"
     49 	"golang.org/x/sync/errgroup"
     50 )
     51 
     52 func TestGetExtensionsWithMissingExtensions(t *testing.T) {
     53 	msg := &pb.MyMessage{}
     54 	ext1 := &pb.Ext{}
     55 	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
     56 		t.Fatalf("Could not set ext1: %s", err)
     57 	}
     58 	exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{
     59 		pb.E_Ext_More,
     60 		pb.E_Ext_Text,
     61 	})
     62 	if err != nil {
     63 		t.Fatalf("GetExtensions() failed: %s", err)
     64 	}
     65 	if exts[0] != ext1 {
     66 		t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0])
     67 	}
     68 	if exts[1] != nil {
     69 		t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1])
     70 	}
     71 }
     72 
     73 func TestGetExtensionWithEmptyBuffer(t *testing.T) {
     74 	// Make sure that GetExtension returns an error if its
     75 	// undecoded buffer is empty.
     76 	msg := &pb.MyMessage{}
     77 	proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{})
     78 	_, err := proto.GetExtension(msg, pb.E_Ext_More)
     79 	if want := io.ErrUnexpectedEOF; err != want {
     80 		t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want)
     81 	}
     82 }
     83 
     84 func TestGetExtensionForIncompleteDesc(t *testing.T) {
     85 	msg := &pb.MyMessage{Count: proto.Int32(0)}
     86 	extdesc1 := &proto.ExtensionDesc{
     87 		ExtendedType:  (*pb.MyMessage)(nil),
     88 		ExtensionType: (*bool)(nil),
     89 		Field:         123456789,
     90 		Name:          "a.b",
     91 		Tag:           "varint,123456789,opt",
     92 	}
     93 	ext1 := proto.Bool(true)
     94 	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
     95 		t.Fatalf("Could not set ext1: %s", err)
     96 	}
     97 	extdesc2 := &proto.ExtensionDesc{
     98 		ExtendedType:  (*pb.MyMessage)(nil),
     99 		ExtensionType: ([]byte)(nil),
    100 		Field:         123456790,
    101 		Name:          "a.c",
    102 		Tag:           "bytes,123456790,opt",
    103 	}
    104 	ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7}
    105 	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
    106 		t.Fatalf("Could not set ext2: %s", err)
    107 	}
    108 	extdesc3 := &proto.ExtensionDesc{
    109 		ExtendedType:  (*pb.MyMessage)(nil),
    110 		ExtensionType: (*pb.Ext)(nil),
    111 		Field:         123456791,
    112 		Name:          "a.d",
    113 		Tag:           "bytes,123456791,opt",
    114 	}
    115 	ext3 := &pb.Ext{Data: proto.String("foo")}
    116 	if err := proto.SetExtension(msg, extdesc3, ext3); err != nil {
    117 		t.Fatalf("Could not set ext3: %s", err)
    118 	}
    119 
    120 	b, err := proto.Marshal(msg)
    121 	if err != nil {
    122 		t.Fatalf("Could not marshal msg: %v", err)
    123 	}
    124 	if err := proto.Unmarshal(b, msg); err != nil {
    125 		t.Fatalf("Could not unmarshal into msg: %v", err)
    126 	}
    127 
    128 	var expected proto.Buffer
    129 	if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil {
    130 		t.Fatalf("failed to compute expected prefix for ext1: %s", err)
    131 	}
    132 	if err := expected.EncodeVarint(1 /* bool true */); err != nil {
    133 		t.Fatalf("failed to compute expected value for ext1: %s", err)
    134 	}
    135 
    136 	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil {
    137 		t.Fatalf("Failed to get raw value for ext1: %s", err)
    138 	} else if !reflect.DeepEqual(b, expected.Bytes()) {
    139 		t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes())
    140 	}
    141 
    142 	expected = proto.Buffer{} // reset
    143 	if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil {
    144 		t.Fatalf("failed to compute expected prefix for ext2: %s", err)
    145 	}
    146 	if err := expected.EncodeRawBytes(ext2); err != nil {
    147 		t.Fatalf("failed to compute expected value for ext2: %s", err)
    148 	}
    149 
    150 	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil {
    151 		t.Fatalf("Failed to get raw value for ext2: %s", err)
    152 	} else if !reflect.DeepEqual(b, expected.Bytes()) {
    153 		t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes())
    154 	}
    155 
    156 	expected = proto.Buffer{} // reset
    157 	if err := expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil {
    158 		t.Fatalf("failed to compute expected prefix for ext3: %s", err)
    159 	}
    160 	if b, err := proto.Marshal(ext3); err != nil {
    161 		t.Fatalf("failed to compute expected value for ext3: %s", err)
    162 	} else if err := expected.EncodeRawBytes(b); err != nil {
    163 		t.Fatalf("failed to compute expected value for ext3: %s", err)
    164 	}
    165 
    166 	if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil {
    167 		t.Fatalf("Failed to get raw value for ext3: %s", err)
    168 	} else if !reflect.DeepEqual(b, expected.Bytes()) {
    169 		t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes())
    170 	}
    171 }
    172 
    173 func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) {
    174 	msg := &pb.MyMessage{Count: proto.Int32(0)}
    175 	extdesc1 := pb.E_Ext_More
    176 	if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil {
    177 		t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err)
    178 	}
    179 
    180 	ext1 := &pb.Ext{}
    181 	if err := proto.SetExtension(msg, extdesc1, ext1); err != nil {
    182 		t.Fatalf("Could not set ext1: %s", err)
    183 	}
    184 	extdesc2 := &proto.ExtensionDesc{
    185 		ExtendedType:  (*pb.MyMessage)(nil),
    186 		ExtensionType: (*bool)(nil),
    187 		Field:         123456789,
    188 		Name:          "a.b",
    189 		Tag:           "varint,123456789,opt",
    190 	}
    191 	ext2 := proto.Bool(false)
    192 	if err := proto.SetExtension(msg, extdesc2, ext2); err != nil {
    193 		t.Fatalf("Could not set ext2: %s", err)
    194 	}
    195 
    196 	b, err := proto.Marshal(msg)
    197 	if err != nil {
    198 		t.Fatalf("Could not marshal msg: %v", err)
    199 	}
    200 	if err := proto.Unmarshal(b, msg); err != nil {
    201 		t.Fatalf("Could not unmarshal into msg: %v", err)
    202 	}
    203 
    204 	descs, err := proto.ExtensionDescs(msg)
    205 	if err != nil {
    206 		t.Fatalf("proto.ExtensionDescs: got error %v", err)
    207 	}
    208 	sortExtDescs(descs)
    209 	wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}}
    210 	if !reflect.DeepEqual(descs, wantDescs) {
    211 		t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs)
    212 	}
    213 }
    214 
    215 type ExtensionDescSlice []*proto.ExtensionDesc
    216 
    217 func (s ExtensionDescSlice) Len() int           { return len(s) }
    218 func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field }
    219 func (s ExtensionDescSlice) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
    220 
    221 func sortExtDescs(s []*proto.ExtensionDesc) {
    222 	sort.Sort(ExtensionDescSlice(s))
    223 }
    224 
    225 func TestGetExtensionStability(t *testing.T) {
    226 	check := func(m *pb.MyMessage) bool {
    227 		ext1, err := proto.GetExtension(m, pb.E_Ext_More)
    228 		if err != nil {
    229 			t.Fatalf("GetExtension() failed: %s", err)
    230 		}
    231 		ext2, err := proto.GetExtension(m, pb.E_Ext_More)
    232 		if err != nil {
    233 			t.Fatalf("GetExtension() failed: %s", err)
    234 		}
    235 		return ext1 == ext2
    236 	}
    237 	msg := &pb.MyMessage{Count: proto.Int32(4)}
    238 	ext0 := &pb.Ext{}
    239 	if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil {
    240 		t.Fatalf("Could not set ext1: %s", ext0)
    241 	}
    242 	if !check(msg) {
    243 		t.Errorf("GetExtension() not stable before marshaling")
    244 	}
    245 	bb, err := proto.Marshal(msg)
    246 	if err != nil {
    247 		t.Fatalf("Marshal() failed: %s", err)
    248 	}
    249 	msg1 := &pb.MyMessage{}
    250 	err = proto.Unmarshal(bb, msg1)
    251 	if err != nil {
    252 		t.Fatalf("Unmarshal() failed: %s", err)
    253 	}
    254 	if !check(msg1) {
    255 		t.Errorf("GetExtension() not stable after unmarshaling")
    256 	}
    257 }
    258 
    259 func TestGetExtensionDefaults(t *testing.T) {
    260 	var setFloat64 float64 = 1
    261 	var setFloat32 float32 = 2
    262 	var setInt32 int32 = 3
    263 	var setInt64 int64 = 4
    264 	var setUint32 uint32 = 5
    265 	var setUint64 uint64 = 6
    266 	var setBool = true
    267 	var setBool2 = false
    268 	var setString = "Goodnight string"
    269 	var setBytes = []byte("Goodnight bytes")
    270 	var setEnum = pb.DefaultsMessage_TWO
    271 
    272 	type testcase struct {
    273 		ext  *proto.ExtensionDesc // Extension we are testing.
    274 		want interface{}          // Expected value of extension, or nil (meaning that GetExtension will fail).
    275 		def  interface{}          // Expected value of extension after ClearExtension().
    276 	}
    277 	tests := []testcase{
    278 		{pb.E_NoDefaultDouble, setFloat64, nil},
    279 		{pb.E_NoDefaultFloat, setFloat32, nil},
    280 		{pb.E_NoDefaultInt32, setInt32, nil},
    281 		{pb.E_NoDefaultInt64, setInt64, nil},
    282 		{pb.E_NoDefaultUint32, setUint32, nil},
    283 		{pb.E_NoDefaultUint64, setUint64, nil},
    284 		{pb.E_NoDefaultSint32, setInt32, nil},
    285 		{pb.E_NoDefaultSint64, setInt64, nil},
    286 		{pb.E_NoDefaultFixed32, setUint32, nil},
    287 		{pb.E_NoDefaultFixed64, setUint64, nil},
    288 		{pb.E_NoDefaultSfixed32, setInt32, nil},
    289 		{pb.E_NoDefaultSfixed64, setInt64, nil},
    290 		{pb.E_NoDefaultBool, setBool, nil},
    291 		{pb.E_NoDefaultBool, setBool2, nil},
    292 		{pb.E_NoDefaultString, setString, nil},
    293 		{pb.E_NoDefaultBytes, setBytes, nil},
    294 		{pb.E_NoDefaultEnum, setEnum, nil},
    295 		{pb.E_DefaultDouble, setFloat64, float64(3.1415)},
    296 		{pb.E_DefaultFloat, setFloat32, float32(3.14)},
    297 		{pb.E_DefaultInt32, setInt32, int32(42)},
    298 		{pb.E_DefaultInt64, setInt64, int64(43)},
    299 		{pb.E_DefaultUint32, setUint32, uint32(44)},
    300 		{pb.E_DefaultUint64, setUint64, uint64(45)},
    301 		{pb.E_DefaultSint32, setInt32, int32(46)},
    302 		{pb.E_DefaultSint64, setInt64, int64(47)},
    303 		{pb.E_DefaultFixed32, setUint32, uint32(48)},
    304 		{pb.E_DefaultFixed64, setUint64, uint64(49)},
    305 		{pb.E_DefaultSfixed32, setInt32, int32(50)},
    306 		{pb.E_DefaultSfixed64, setInt64, int64(51)},
    307 		{pb.E_DefaultBool, setBool, true},
    308 		{pb.E_DefaultBool, setBool2, true},
    309 		{pb.E_DefaultString, setString, "Hello, string,def=foo"},
    310 		{pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")},
    311 		{pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE},
    312 	}
    313 
    314 	checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error {
    315 		val, err := proto.GetExtension(msg, test.ext)
    316 		if err != nil {
    317 			if valWant != nil {
    318 				return fmt.Errorf("GetExtension(): %s", err)
    319 			}
    320 			if want := proto.ErrMissingExtension; err != want {
    321 				return fmt.Errorf("Unexpected error: got %v, want %v", err, want)
    322 			}
    323 			return nil
    324 		}
    325 
    326 		// All proto2 extension values are either a pointer to a value or a slice of values.
    327 		ty := reflect.TypeOf(val)
    328 		tyWant := reflect.TypeOf(test.ext.ExtensionType)
    329 		if got, want := ty, tyWant; got != want {
    330 			return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want)
    331 		}
    332 		tye := ty.Elem()
    333 		tyeWant := tyWant.Elem()
    334 		if got, want := tye, tyeWant; got != want {
    335 			return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want)
    336 		}
    337 
    338 		// Check the name of the type of the value.
    339 		// If it is an enum it will be type int32 with the name of the enum.
    340 		if got, want := tye.Name(), tye.Name(); got != want {
    341 			return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want)
    342 		}
    343 
    344 		// Check that value is what we expect.
    345 		// If we have a pointer in val, get the value it points to.
    346 		valExp := val
    347 		if ty.Kind() == reflect.Ptr {
    348 			valExp = reflect.ValueOf(val).Elem().Interface()
    349 		}
    350 		if got, want := valExp, valWant; !reflect.DeepEqual(got, want) {
    351 			return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want)
    352 		}
    353 
    354 		return nil
    355 	}
    356 
    357 	setTo := func(test testcase) interface{} {
    358 		setTo := reflect.ValueOf(test.want)
    359 		if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr {
    360 			setTo = reflect.New(typ).Elem()
    361 			setTo.Set(reflect.New(setTo.Type().Elem()))
    362 			setTo.Elem().Set(reflect.ValueOf(test.want))
    363 		}
    364 		return setTo.Interface()
    365 	}
    366 
    367 	for _, test := range tests {
    368 		msg := &pb.DefaultsMessage{}
    369 		name := test.ext.Name
    370 
    371 		// Check the initial value.
    372 		if err := checkVal(test, msg, test.def); err != nil {
    373 			t.Errorf("%s: %v", name, err)
    374 		}
    375 
    376 		// Set the per-type value and check value.
    377 		name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want)
    378 		if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil {
    379 			t.Errorf("%s: SetExtension(): %v", name, err)
    380 			continue
    381 		}
    382 		if err := checkVal(test, msg, test.want); err != nil {
    383 			t.Errorf("%s: %v", name, err)
    384 			continue
    385 		}
    386 
    387 		// Set and check the value.
    388 		name += " (cleared)"
    389 		proto.ClearExtension(msg, test.ext)
    390 		if err := checkVal(test, msg, test.def); err != nil {
    391 			t.Errorf("%s: %v", name, err)
    392 		}
    393 	}
    394 }
    395 
    396 func TestNilMessage(t *testing.T) {
    397 	name := "nil interface"
    398 	if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil {
    399 		t.Errorf("%s: got %T %v, expected to fail", name, got, got)
    400 	} else if !strings.Contains(err.Error(), "extendable") {
    401 		t.Errorf("%s: got error %v, expected not-extendable error", name, err)
    402 	}
    403 
    404 	// Regression tests: all functions of the Extension API
    405 	// used to panic when passed (*M)(nil), where M is a concrete message
    406 	// type.  Now they handle this gracefully as a no-op or reported error.
    407 	var nilMsg *pb.MyMessage
    408 	desc := pb.E_Ext_More
    409 
    410 	isNotExtendable := func(err error) bool {
    411 		return strings.Contains(fmt.Sprint(err), "not extendable")
    412 	}
    413 
    414 	if proto.HasExtension(nilMsg, desc) {
    415 		t.Error("HasExtension(nil) = true")
    416 	}
    417 
    418 	if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) {
    419 		t.Errorf("GetExtensions(nil) = %q (wrong error)", err)
    420 	}
    421 
    422 	if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) {
    423 		t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err)
    424 	}
    425 
    426 	if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) {
    427 		t.Errorf("SetExtension(nil) = %q (wrong error)", err)
    428 	}
    429 
    430 	proto.ClearExtension(nilMsg, desc) // no-op
    431 	proto.ClearAllExtensions(nilMsg)   // no-op
    432 }
    433 
    434 func TestExtensionsRoundTrip(t *testing.T) {
    435 	msg := &pb.MyMessage{}
    436 	ext1 := &pb.Ext{
    437 		Data: proto.String("hi"),
    438 	}
    439 	ext2 := &pb.Ext{
    440 		Data: proto.String("there"),
    441 	}
    442 	exists := proto.HasExtension(msg, pb.E_Ext_More)
    443 	if exists {
    444 		t.Error("Extension More present unexpectedly")
    445 	}
    446 	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
    447 		t.Error(err)
    448 	}
    449 	if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
    450 		t.Error(err)
    451 	}
    452 	e, err := proto.GetExtension(msg, pb.E_Ext_More)
    453 	if err != nil {
    454 		t.Error(err)
    455 	}
    456 	x, ok := e.(*pb.Ext)
    457 	if !ok {
    458 		t.Errorf("e has type %T, expected test_proto.Ext", e)
    459 	} else if *x.Data != "there" {
    460 		t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
    461 	}
    462 	proto.ClearExtension(msg, pb.E_Ext_More)
    463 	if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
    464 		t.Errorf("got %v, expected ErrMissingExtension", e)
    465 	}
    466 	if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
    467 		t.Error("expected bad extension error, got nil")
    468 	}
    469 	if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
    470 		t.Error("expected extension err")
    471 	}
    472 	if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
    473 		t.Error("expected some sort of type mismatch error, got nil")
    474 	}
    475 }
    476 
    477 func TestNilExtension(t *testing.T) {
    478 	msg := &pb.MyMessage{
    479 		Count: proto.Int32(1),
    480 	}
    481 	if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
    482 		t.Fatal(err)
    483 	}
    484 	if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
    485 		t.Error("expected SetExtension to fail due to a nil extension")
    486 	} else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want {
    487 		t.Errorf("expected error %v, got %v", want, err)
    488 	}
    489 	// Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
    490 	// this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
    491 }
    492 
    493 func TestMarshalUnmarshalRepeatedExtension(t *testing.T) {
    494 	// Add a repeated extension to the result.
    495 	tests := []struct {
    496 		name string
    497 		ext  []*pb.ComplexExtension
    498 	}{
    499 		{
    500 			"two fields",
    501 			[]*pb.ComplexExtension{
    502 				{First: proto.Int32(7)},
    503 				{Second: proto.Int32(11)},
    504 			},
    505 		},
    506 		{
    507 			"repeated field",
    508 			[]*pb.ComplexExtension{
    509 				{Third: []int32{1000}},
    510 				{Third: []int32{2000}},
    511 			},
    512 		},
    513 		{
    514 			"two fields and repeated field",
    515 			[]*pb.ComplexExtension{
    516 				{Third: []int32{1000}},
    517 				{First: proto.Int32(9)},
    518 				{Second: proto.Int32(21)},
    519 				{Third: []int32{2000}},
    520 			},
    521 		},
    522 	}
    523 	for _, test := range tests {
    524 		// Marshal message with a repeated extension.
    525 		msg1 := new(pb.OtherMessage)
    526 		err := proto.SetExtension(msg1, pb.E_RComplex, test.ext)
    527 		if err != nil {
    528 			t.Fatalf("[%s] Error setting extension: %v", test.name, err)
    529 		}
    530 		b, err := proto.Marshal(msg1)
    531 		if err != nil {
    532 			t.Fatalf("[%s] Error marshaling message: %v", test.name, err)
    533 		}
    534 
    535 		// Unmarshal and read the merged proto.
    536 		msg2 := new(pb.OtherMessage)
    537 		err = proto.Unmarshal(b, msg2)
    538 		if err != nil {
    539 			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
    540 		}
    541 		e, err := proto.GetExtension(msg2, pb.E_RComplex)
    542 		if err != nil {
    543 			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
    544 		}
    545 		ext := e.([]*pb.ComplexExtension)
    546 		if ext == nil {
    547 			t.Fatalf("[%s] Invalid extension", test.name)
    548 		}
    549 		if len(ext) != len(test.ext) {
    550 			t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext))
    551 		}
    552 		for i := range test.ext {
    553 			if !proto.Equal(ext[i], test.ext[i]) {
    554 				t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i])
    555 			}
    556 		}
    557 	}
    558 }
    559 
    560 func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) {
    561 	// We may see multiple instances of the same extension in the wire
    562 	// format. For example, the proto compiler may encode custom options in
    563 	// this way. Here, we verify that we merge the extensions together.
    564 	tests := []struct {
    565 		name string
    566 		ext  []*pb.ComplexExtension
    567 	}{
    568 		{
    569 			"two fields",
    570 			[]*pb.ComplexExtension{
    571 				{First: proto.Int32(7)},
    572 				{Second: proto.Int32(11)},
    573 			},
    574 		},
    575 		{
    576 			"repeated field",
    577 			[]*pb.ComplexExtension{
    578 				{Third: []int32{1000}},
    579 				{Third: []int32{2000}},
    580 			},
    581 		},
    582 		{
    583 			"two fields and repeated field",
    584 			[]*pb.ComplexExtension{
    585 				{Third: []int32{1000}},
    586 				{First: proto.Int32(9)},
    587 				{Second: proto.Int32(21)},
    588 				{Third: []int32{2000}},
    589 			},
    590 		},
    591 	}
    592 	for _, test := range tests {
    593 		var buf bytes.Buffer
    594 		var want pb.ComplexExtension
    595 
    596 		// Generate a serialized representation of a repeated extension
    597 		// by catenating bytes together.
    598 		for i, e := range test.ext {
    599 			// Merge to create the wanted proto.
    600 			proto.Merge(&want, e)
    601 
    602 			// serialize the message
    603 			msg := new(pb.OtherMessage)
    604 			err := proto.SetExtension(msg, pb.E_Complex, e)
    605 			if err != nil {
    606 				t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err)
    607 			}
    608 			b, err := proto.Marshal(msg)
    609 			if err != nil {
    610 				t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err)
    611 			}
    612 			buf.Write(b)
    613 		}
    614 
    615 		// Unmarshal and read the merged proto.
    616 		msg2 := new(pb.OtherMessage)
    617 		err := proto.Unmarshal(buf.Bytes(), msg2)
    618 		if err != nil {
    619 			t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err)
    620 		}
    621 		e, err := proto.GetExtension(msg2, pb.E_Complex)
    622 		if err != nil {
    623 			t.Fatalf("[%s] Error getting extension: %v", test.name, err)
    624 		}
    625 		ext := e.(*pb.ComplexExtension)
    626 		if ext == nil {
    627 			t.Fatalf("[%s] Invalid extension", test.name)
    628 		}
    629 		if !proto.Equal(ext, &want) {
    630 			t.Errorf("[%s] Wrong value for ComplexExtension: got: %s want: %s\n", test.name, ext, &want)
    631 		}
    632 	}
    633 }
    634 
    635 func TestClearAllExtensions(t *testing.T) {
    636 	// unregistered extension
    637 	desc := &proto.ExtensionDesc{
    638 		ExtendedType:  (*pb.MyMessage)(nil),
    639 		ExtensionType: (*bool)(nil),
    640 		Field:         101010100,
    641 		Name:          "emptyextension",
    642 		Tag:           "varint,0,opt",
    643 	}
    644 	m := &pb.MyMessage{}
    645 	if proto.HasExtension(m, desc) {
    646 		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
    647 	}
    648 	if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil {
    649 		t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
    650 	}
    651 	if !proto.HasExtension(m, desc) {
    652 		t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m))
    653 	}
    654 	proto.ClearAllExtensions(m)
    655 	if proto.HasExtension(m, desc) {
    656 		t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m))
    657 	}
    658 }
    659 
    660 func TestMarshalRace(t *testing.T) {
    661 	ext := &pb.Ext{}
    662 	m := &pb.MyMessage{Count: proto.Int32(4)}
    663 	if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil {
    664 		t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err)
    665 	}
    666 
    667 	b, err := proto.Marshal(m)
    668 	if err != nil {
    669 		t.Fatalf("Could not marshal message: %v", err)
    670 	}
    671 	if err := proto.Unmarshal(b, m); err != nil {
    672 		t.Fatalf("Could not unmarshal message: %v", err)
    673 	}
    674 	// after Unmarshal, the extension is in undecoded form.
    675 	// GetExtension will decode it lazily. Make sure this does
    676 	// not race against Marshal.
    677 
    678 	var g errgroup.Group
    679 	for n := 3; n > 0; n-- {
    680 		g.Go(func() error {
    681 			_, err := proto.Marshal(m)
    682 			return err
    683 		})
    684 		g.Go(func() error {
    685 			_, err := proto.GetExtension(m, pb.E_Ext_More)
    686 			return err
    687 		})
    688 	}
    689 	if err := g.Wait(); err != nil {
    690 		t.Fatal(err)
    691 	}
    692 }
    693