Home | History | Annotate | Download | only in xml
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package xml
      6 
      7 import (
      8 	"fmt"
      9 	"reflect"
     10 	"strings"
     11 	"sync"
     12 )
     13 
// typeInfo holds details for the xml representation of a type.
// getTypeInfo builds one value per reflect.Type and caches it in tinfoMap.
type typeInfo struct {
	xmlname *fieldInfo  // info for the XMLName field, own or taken from an embedded struct; nil if none was found
	fields  []fieldInfo // fields accepted by addFieldInfo, including those promoted from embedded structs
}
     19 
// fieldInfo holds details for the xml representation of a single field.
// Values are produced by structFieldInfo from a field's `xml:"..."` tag.
type fieldInfo struct {
	idx     []int      // index path to the field; value walks it one v.Field step at a time
	name    string     // XML name: last segment of an "a>b>c" tag, an XMLName-derived name, or the Go field name
	xmlns   string     // namespace: the part of the tag before the first space, if any
	flags   fieldFlags // mode bits (see fMode), optionally combined with fOmitEmpty
	parents []string   // segments preceding name in an "a>b>c" tag; nil when the tag has no path
}
     28 
// fieldFlags is a bit set describing how a struct field maps to XML.
// The bits are assigned by structFieldInfo while parsing the field tag.
type fieldFlags int

const (
	fElement fieldFlags = 1 << iota // default mode when the tag names no other mode; also added alongside fAny
	fAttr                           // set by the ",attr" tag option
	fCDATA                          // set by the ",cdata" tag option
	fCharData                       // set by the ",chardata" tag option
	fInnerXml                       // set by the ",innerxml" tag option
	fComment                        // set by the ",comment" tag option
	fAny                            // set by the ",any" tag option

	fOmitEmpty // set by the ",omitempty" tag option; rejected unless combined with fElement or fAttr

	// fMode masks every mode bit, leaving out option bits such as fOmitEmpty.
	fMode = fElement | fAttr | fCDATA | fCharData | fInnerXml | fComment | fAny
)
     44 
     45 var tinfoMap = make(map[reflect.Type]*typeInfo)
     46 var tinfoLock sync.RWMutex
     47 
     48 var nameType = reflect.TypeOf(Name{})
     49 
     50 // getTypeInfo returns the typeInfo structure with details necessary
     51 // for marshaling and unmarshaling typ.
     52 func getTypeInfo(typ reflect.Type) (*typeInfo, error) {
     53 	tinfoLock.RLock()
     54 	tinfo, ok := tinfoMap[typ]
     55 	tinfoLock.RUnlock()
     56 	if ok {
     57 		return tinfo, nil
     58 	}
     59 	tinfo = &typeInfo{}
     60 	if typ.Kind() == reflect.Struct && typ != nameType {
     61 		n := typ.NumField()
     62 		for i := 0; i < n; i++ {
     63 			f := typ.Field(i)
     64 			if (f.PkgPath != "" && !f.Anonymous) || f.Tag.Get("xml") == "-" {
     65 				continue // Private field
     66 			}
     67 
     68 			// For embedded structs, embed its fields.
     69 			if f.Anonymous {
     70 				t := f.Type
     71 				if t.Kind() == reflect.Ptr {
     72 					t = t.Elem()
     73 				}
     74 				if t.Kind() == reflect.Struct {
     75 					inner, err := getTypeInfo(t)
     76 					if err != nil {
     77 						return nil, err
     78 					}
     79 					if tinfo.xmlname == nil {
     80 						tinfo.xmlname = inner.xmlname
     81 					}
     82 					for _, finfo := range inner.fields {
     83 						finfo.idx = append([]int{i}, finfo.idx...)
     84 						if err := addFieldInfo(typ, tinfo, &finfo); err != nil {
     85 							return nil, err
     86 						}
     87 					}
     88 					continue
     89 				}
     90 			}
     91 
     92 			finfo, err := structFieldInfo(typ, &f)
     93 			if err != nil {
     94 				return nil, err
     95 			}
     96 
     97 			if f.Name == "XMLName" {
     98 				tinfo.xmlname = finfo
     99 				continue
    100 			}
    101 
    102 			// Add the field if it doesn't conflict with other fields.
    103 			if err := addFieldInfo(typ, tinfo, finfo); err != nil {
    104 				return nil, err
    105 			}
    106 		}
    107 	}
    108 	tinfoLock.Lock()
    109 	tinfoMap[typ] = tinfo
    110 	tinfoLock.Unlock()
    111 	return tinfo, nil
    112 }
    113 
    114 // structFieldInfo builds and returns a fieldInfo for f.
    115 func structFieldInfo(typ reflect.Type, f *reflect.StructField) (*fieldInfo, error) {
    116 	finfo := &fieldInfo{idx: f.Index}
    117 
    118 	// Split the tag from the xml namespace if necessary.
    119 	tag := f.Tag.Get("xml")
    120 	if i := strings.Index(tag, " "); i >= 0 {
    121 		finfo.xmlns, tag = tag[:i], tag[i+1:]
    122 	}
    123 
    124 	// Parse flags.
    125 	tokens := strings.Split(tag, ",")
    126 	if len(tokens) == 1 {
    127 		finfo.flags = fElement
    128 	} else {
    129 		tag = tokens[0]
    130 		for _, flag := range tokens[1:] {
    131 			switch flag {
    132 			case "attr":
    133 				finfo.flags |= fAttr
    134 			case "cdata":
    135 				finfo.flags |= fCDATA
    136 			case "chardata":
    137 				finfo.flags |= fCharData
    138 			case "innerxml":
    139 				finfo.flags |= fInnerXml
    140 			case "comment":
    141 				finfo.flags |= fComment
    142 			case "any":
    143 				finfo.flags |= fAny
    144 			case "omitempty":
    145 				finfo.flags |= fOmitEmpty
    146 			}
    147 		}
    148 
    149 		// Validate the flags used.
    150 		valid := true
    151 		switch mode := finfo.flags & fMode; mode {
    152 		case 0:
    153 			finfo.flags |= fElement
    154 		case fAttr, fCDATA, fCharData, fInnerXml, fComment, fAny, fAny | fAttr:
    155 			if f.Name == "XMLName" || tag != "" && mode != fAttr {
    156 				valid = false
    157 			}
    158 		default:
    159 			// This will also catch multiple modes in a single field.
    160 			valid = false
    161 		}
    162 		if finfo.flags&fMode == fAny {
    163 			finfo.flags |= fElement
    164 		}
    165 		if finfo.flags&fOmitEmpty != 0 && finfo.flags&(fElement|fAttr) == 0 {
    166 			valid = false
    167 		}
    168 		if !valid {
    169 			return nil, fmt.Errorf("xml: invalid tag in field %s of type %s: %q",
    170 				f.Name, typ, f.Tag.Get("xml"))
    171 		}
    172 	}
    173 
    174 	// Use of xmlns without a name is not allowed.
    175 	if finfo.xmlns != "" && tag == "" {
    176 		return nil, fmt.Errorf("xml: namespace without name in field %s of type %s: %q",
    177 			f.Name, typ, f.Tag.Get("xml"))
    178 	}
    179 
    180 	if f.Name == "XMLName" {
    181 		// The XMLName field records the XML element name. Don't
    182 		// process it as usual because its name should default to
    183 		// empty rather than to the field name.
    184 		finfo.name = tag
    185 		return finfo, nil
    186 	}
    187 
    188 	if tag == "" {
    189 		// If the name part of the tag is completely empty, get
    190 		// default from XMLName of underlying struct if feasible,
    191 		// or field name otherwise.
    192 		if xmlname := lookupXMLName(f.Type); xmlname != nil {
    193 			finfo.xmlns, finfo.name = xmlname.xmlns, xmlname.name
    194 		} else {
    195 			finfo.name = f.Name
    196 		}
    197 		return finfo, nil
    198 	}
    199 
    200 	// Prepare field name and parents.
    201 	parents := strings.Split(tag, ">")
    202 	if parents[0] == "" {
    203 		parents[0] = f.Name
    204 	}
    205 	if parents[len(parents)-1] == "" {
    206 		return nil, fmt.Errorf("xml: trailing '>' in field %s of type %s", f.Name, typ)
    207 	}
    208 	finfo.name = parents[len(parents)-1]
    209 	if len(parents) > 1 {
    210 		if (finfo.flags & fElement) == 0 {
    211 			return nil, fmt.Errorf("xml: %s chain not valid with %s flag", tag, strings.Join(tokens[1:], ","))
    212 		}
    213 		finfo.parents = parents[:len(parents)-1]
    214 	}
    215 
    216 	// If the field type has an XMLName field, the names must match
    217 	// so that the behavior of both marshaling and unmarshaling
    218 	// is straightforward and unambiguous.
    219 	if finfo.flags&fElement != 0 {
    220 		ftyp := f.Type
    221 		xmlname := lookupXMLName(ftyp)
    222 		if xmlname != nil && xmlname.name != finfo.name {
    223 			return nil, fmt.Errorf("xml: name %q in tag of %s.%s conflicts with name %q in %s.XMLName",
    224 				finfo.name, typ, f.Name, xmlname.name, ftyp)
    225 		}
    226 	}
    227 	return finfo, nil
    228 }
    229 
    230 // lookupXMLName returns the fieldInfo for typ's XMLName field
    231 // in case it exists and has a valid xml field tag, otherwise
    232 // it returns nil.
    233 func lookupXMLName(typ reflect.Type) (xmlname *fieldInfo) {
    234 	for typ.Kind() == reflect.Ptr {
    235 		typ = typ.Elem()
    236 	}
    237 	if typ.Kind() != reflect.Struct {
    238 		return nil
    239 	}
    240 	for i, n := 0, typ.NumField(); i < n; i++ {
    241 		f := typ.Field(i)
    242 		if f.Name != "XMLName" {
    243 			continue
    244 		}
    245 		finfo, err := structFieldInfo(typ, &f)
    246 		if finfo.name != "" && err == nil {
    247 			return finfo
    248 		}
    249 		// Also consider errors as a non-existent field tag
    250 		// and let getTypeInfo itself report the error.
    251 		break
    252 	}
    253 	return nil
    254 }
    255 
    256 func min(a, b int) int {
    257 	if a <= b {
    258 		return a
    259 	}
    260 	return b
    261 }
    262 
    263 // addFieldInfo adds finfo to tinfo.fields if there are no
    264 // conflicts, or if conflicts arise from previous fields that were
    265 // obtained from deeper embedded structures than finfo. In the latter
    266 // case, the conflicting entries are dropped.
    267 // A conflict occurs when the path (parent + name) to a field is
    268 // itself a prefix of another path, or when two paths match exactly.
    269 // It is okay for field paths to share a common, shorter prefix.
    270 func addFieldInfo(typ reflect.Type, tinfo *typeInfo, newf *fieldInfo) error {
    271 	var conflicts []int
    272 Loop:
    273 	// First, figure all conflicts. Most working code will have none.
    274 	for i := range tinfo.fields {
    275 		oldf := &tinfo.fields[i]
    276 		if oldf.flags&fMode != newf.flags&fMode {
    277 			continue
    278 		}
    279 		if oldf.xmlns != "" && newf.xmlns != "" && oldf.xmlns != newf.xmlns {
    280 			continue
    281 		}
    282 		minl := min(len(newf.parents), len(oldf.parents))
    283 		for p := 0; p < minl; p++ {
    284 			if oldf.parents[p] != newf.parents[p] {
    285 				continue Loop
    286 			}
    287 		}
    288 		if len(oldf.parents) > len(newf.parents) {
    289 			if oldf.parents[len(newf.parents)] == newf.name {
    290 				conflicts = append(conflicts, i)
    291 			}
    292 		} else if len(oldf.parents) < len(newf.parents) {
    293 			if newf.parents[len(oldf.parents)] == oldf.name {
    294 				conflicts = append(conflicts, i)
    295 			}
    296 		} else {
    297 			if newf.name == oldf.name {
    298 				conflicts = append(conflicts, i)
    299 			}
    300 		}
    301 	}
    302 	// Without conflicts, add the new field and return.
    303 	if conflicts == nil {
    304 		tinfo.fields = append(tinfo.fields, *newf)
    305 		return nil
    306 	}
    307 
    308 	// If any conflict is shallower, ignore the new field.
    309 	// This matches the Go field resolution on embedding.
    310 	for _, i := range conflicts {
    311 		if len(tinfo.fields[i].idx) < len(newf.idx) {
    312 			return nil
    313 		}
    314 	}
    315 
    316 	// Otherwise, if any of them is at the same depth level, it's an error.
    317 	for _, i := range conflicts {
    318 		oldf := &tinfo.fields[i]
    319 		if len(oldf.idx) == len(newf.idx) {
    320 			f1 := typ.FieldByIndex(oldf.idx)
    321 			f2 := typ.FieldByIndex(newf.idx)
    322 			return &TagPathError{typ, f1.Name, f1.Tag.Get("xml"), f2.Name, f2.Tag.Get("xml")}
    323 		}
    324 	}
    325 
    326 	// Otherwise, the new field is shallower, and thus takes precedence,
    327 	// so drop the conflicting fields from tinfo and append the new one.
    328 	for c := len(conflicts) - 1; c >= 0; c-- {
    329 		i := conflicts[c]
    330 		copy(tinfo.fields[i:], tinfo.fields[i+1:])
    331 		tinfo.fields = tinfo.fields[:len(tinfo.fields)-1]
    332 	}
    333 	tinfo.fields = append(tinfo.fields, *newf)
    334 	return nil
    335 }
    336 
    337 // A TagPathError represents an error in the unmarshaling process
    338 // caused by the use of field tags with conflicting paths.
    339 type TagPathError struct {
    340 	Struct       reflect.Type
    341 	Field1, Tag1 string
    342 	Field2, Tag2 string
    343 }
    344 
    345 func (e *TagPathError) Error() string {
    346 	return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
    347 }
    348 
    349 // value returns v's field value corresponding to finfo.
    350 // It's equivalent to v.FieldByIndex(finfo.idx), but initializes
    351 // and dereferences pointers as necessary.
    352 func (finfo *fieldInfo) value(v reflect.Value) reflect.Value {
    353 	for i, x := range finfo.idx {
    354 		if i > 0 {
    355 			t := v.Type()
    356 			if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
    357 				if v.IsNil() {
    358 					v.Set(reflect.New(v.Type().Elem()))
    359 				}
    360 				v = v.Elem()
    361 			}
    362 		}
    363 		v = v.Field(x)
    364 	}
    365 	return v
    366 }
    367