Home | History | Annotate | Download | only in profile
      1 // Copyright 2014 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // This file is a simple protocol buffer encoder and decoder.
      6 //
      7 // A protocol message must implement the message interface:
      8 //   decoder() []decoder
      9 //   encode(*buffer)
     10 //
// The decoder method returns a slice indexed by field number that gives the
     12 // function to decode that field.
     13 // The encode method encodes its receiver into the given buffer.
     14 //
     15 // The two methods are simple enough to be implemented by hand rather than
     16 // by using a protocol compiler.
     17 //
     18 // See profile.go for examples of messages implementing this interface.
     19 //
     20 // There is no support for groups, message sets, or "has" bits.
     21 
     22 package profile
     23 
import (
	"errors"
	"strconv"
)
     25 
     26 type buffer struct {
     27 	field int
     28 	typ   int
     29 	u64   uint64
     30 	data  []byte
     31 	tmp   [16]byte
     32 }
     33 
     34 type decoder func(*buffer, message) error
     35 
     36 type message interface {
     37 	decoder() []decoder
     38 	encode(*buffer)
     39 }
     40 
     41 func marshal(m message) []byte {
     42 	var b buffer
     43 	m.encode(&b)
     44 	return b.data
     45 }
     46 
     47 func encodeVarint(b *buffer, x uint64) {
     48 	for x >= 128 {
     49 		b.data = append(b.data, byte(x)|0x80)
     50 		x >>= 7
     51 	}
     52 	b.data = append(b.data, byte(x))
     53 }
     54 
     55 func encodeLength(b *buffer, tag int, len int) {
     56 	encodeVarint(b, uint64(tag)<<3|2)
     57 	encodeVarint(b, uint64(len))
     58 }
     59 
     60 func encodeUint64(b *buffer, tag int, x uint64) {
     61 	// append varint to b.data
     62 	encodeVarint(b, uint64(tag)<<3|0)
     63 	encodeVarint(b, x)
     64 }
     65 
     66 func encodeUint64s(b *buffer, tag int, x []uint64) {
     67 	for _, u := range x {
     68 		encodeUint64(b, tag, u)
     69 	}
     70 }
     71 
     72 func encodeUint64Opt(b *buffer, tag int, x uint64) {
     73 	if x == 0 {
     74 		return
     75 	}
     76 	encodeUint64(b, tag, x)
     77 }
     78 
     79 func encodeInt64(b *buffer, tag int, x int64) {
     80 	u := uint64(x)
     81 	encodeUint64(b, tag, u)
     82 }
     83 
     84 func encodeInt64Opt(b *buffer, tag int, x int64) {
     85 	if x == 0 {
     86 		return
     87 	}
     88 	encodeInt64(b, tag, x)
     89 }
     90 
     91 func encodeString(b *buffer, tag int, x string) {
     92 	encodeLength(b, tag, len(x))
     93 	b.data = append(b.data, x...)
     94 }
     95 
     96 func encodeStrings(b *buffer, tag int, x []string) {
     97 	for _, s := range x {
     98 		encodeString(b, tag, s)
     99 	}
    100 }
    101 
    102 func encodeStringOpt(b *buffer, tag int, x string) {
    103 	if x == "" {
    104 		return
    105 	}
    106 	encodeString(b, tag, x)
    107 }
    108 
    109 func encodeBool(b *buffer, tag int, x bool) {
    110 	if x {
    111 		encodeUint64(b, tag, 1)
    112 	} else {
    113 		encodeUint64(b, tag, 0)
    114 	}
    115 }
    116 
    117 func encodeBoolOpt(b *buffer, tag int, x bool) {
    118 	if x == false {
    119 		return
    120 	}
    121 	encodeBool(b, tag, x)
    122 }
    123 
    124 func encodeMessage(b *buffer, tag int, m message) {
    125 	n1 := len(b.data)
    126 	m.encode(b)
    127 	n2 := len(b.data)
    128 	encodeLength(b, tag, n2-n1)
    129 	n3 := len(b.data)
    130 	copy(b.tmp[:], b.data[n2:n3])
    131 	copy(b.data[n1+(n3-n2):], b.data[n1:n2])
    132 	copy(b.data[n1:], b.tmp[:n3-n2])
    133 }
    134 
    135 func unmarshal(data []byte, m message) (err error) {
    136 	b := buffer{data: data, typ: 2}
    137 	return decodeMessage(&b, m)
    138 }
    139 
    140 func le64(p []byte) uint64 {
    141 	return uint64(p[0]) | uint64(p[1])<<8 | uint64(p[2])<<16 | uint64(p[3])<<24 | uint64(p[4])<<32 | uint64(p[5])<<40 | uint64(p[6])<<48 | uint64(p[7])<<56
    142 }
    143 
    144 func le32(p []byte) uint32 {
    145 	return uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
    146 }
    147 
    148 func decodeVarint(data []byte) (uint64, []byte, error) {
    149 	var i int
    150 	var u uint64
    151 	for i = 0; ; i++ {
    152 		if i >= 10 || i >= len(data) {
    153 			return 0, nil, errors.New("bad varint")
    154 		}
    155 		u |= uint64(data[i]&0x7F) << uint(7*i)
    156 		if data[i]&0x80 == 0 {
    157 			return u, data[i+1:], nil
    158 		}
    159 	}
    160 }
    161 
    162 func decodeField(b *buffer, data []byte) ([]byte, error) {
    163 	x, data, err := decodeVarint(data)
    164 	if err != nil {
    165 		return nil, err
    166 	}
    167 	b.field = int(x >> 3)
    168 	b.typ = int(x & 7)
    169 	b.data = nil
    170 	b.u64 = 0
    171 	switch b.typ {
    172 	case 0:
    173 		b.u64, data, err = decodeVarint(data)
    174 		if err != nil {
    175 			return nil, err
    176 		}
    177 	case 1:
    178 		if len(data) < 8 {
    179 			return nil, errors.New("not enough data")
    180 		}
    181 		b.u64 = le64(data[:8])
    182 		data = data[8:]
    183 	case 2:
    184 		var n uint64
    185 		n, data, err = decodeVarint(data)
    186 		if err != nil {
    187 			return nil, err
    188 		}
    189 		if n > uint64(len(data)) {
    190 			return nil, errors.New("too much data")
    191 		}
    192 		b.data = data[:n]
    193 		data = data[n:]
    194 	case 5:
    195 		if len(data) < 4 {
    196 			return nil, errors.New("not enough data")
    197 		}
    198 		b.u64 = uint64(le32(data[:4]))
    199 		data = data[4:]
    200 	default:
    201 		return nil, errors.New("unknown type: " + string(b.typ))
    202 	}
    203 
    204 	return data, nil
    205 }
    206 
    207 func checkType(b *buffer, typ int) error {
    208 	if b.typ != typ {
    209 		return errors.New("type mismatch")
    210 	}
    211 	return nil
    212 }
    213 
    214 func decodeMessage(b *buffer, m message) error {
    215 	if err := checkType(b, 2); err != nil {
    216 		return err
    217 	}
    218 	dec := m.decoder()
    219 	data := b.data
    220 	for len(data) > 0 {
    221 		// pull varint field# + type
    222 		var err error
    223 		data, err = decodeField(b, data)
    224 		if err != nil {
    225 			return err
    226 		}
    227 		if b.field >= len(dec) || dec[b.field] == nil {
    228 			continue
    229 		}
    230 		if err := dec[b.field](b, m); err != nil {
    231 			return err
    232 		}
    233 	}
    234 	return nil
    235 }
    236 
    237 func decodeInt64(b *buffer, x *int64) error {
    238 	if err := checkType(b, 0); err != nil {
    239 		return err
    240 	}
    241 	*x = int64(b.u64)
    242 	return nil
    243 }
    244 
    245 func decodeInt64s(b *buffer, x *[]int64) error {
    246 	var i int64
    247 	if err := decodeInt64(b, &i); err != nil {
    248 		return err
    249 	}
    250 	*x = append(*x, i)
    251 	return nil
    252 }
    253 
    254 func decodeUint64(b *buffer, x *uint64) error {
    255 	if err := checkType(b, 0); err != nil {
    256 		return err
    257 	}
    258 	*x = b.u64
    259 	return nil
    260 }
    261 
    262 func decodeUint64s(b *buffer, x *[]uint64) error {
    263 	var u uint64
    264 	if err := decodeUint64(b, &u); err != nil {
    265 		return err
    266 	}
    267 	*x = append(*x, u)
    268 	return nil
    269 }
    270 
    271 func decodeString(b *buffer, x *string) error {
    272 	if err := checkType(b, 2); err != nil {
    273 		return err
    274 	}
    275 	*x = string(b.data)
    276 	return nil
    277 }
    278 
    279 func decodeStrings(b *buffer, x *[]string) error {
    280 	var s string
    281 	if err := decodeString(b, &s); err != nil {
    282 		return err
    283 	}
    284 	*x = append(*x, s)
    285 	return nil
    286 }
    287 
    288 func decodeBool(b *buffer, x *bool) error {
    289 	if err := checkType(b, 0); err != nil {
    290 		return err
    291 	}
    292 	if int64(b.u64) == 0 {
    293 		*x = false
    294 	} else {
    295 		*x = true
    296 	}
    297 	return nil
    298 }
    299