Home | History | Annotate | Download | only in doc
      1 // Copyright 2011 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // This file implements export filtering of an AST.
      6 
      7 package doc
      8 
      9 import (
     10 	"go/ast"
     11 	"go/token"
     12 )
     13 
     14 // filterIdentList removes unexported names from list in place
     15 // and returns the resulting list.
     16 //
     17 func filterIdentList(list []*ast.Ident) []*ast.Ident {
     18 	j := 0
     19 	for _, x := range list {
     20 		if ast.IsExported(x.Name) {
     21 			list[j] = x
     22 			j++
     23 		}
     24 	}
     25 	return list[0:j]
     26 }
     27 
     28 // hasExportedName reports whether list contains any exported names.
     29 //
     30 func hasExportedName(list []*ast.Ident) bool {
     31 	for _, x := range list {
     32 		if x.IsExported() {
     33 			return true
     34 		}
     35 	}
     36 	return false
     37 }
     38 
     39 // removeErrorField removes anonymous fields named "error" from an interface.
     40 // This is called when "error" has been determined to be a local name,
     41 // not the predeclared type.
     42 //
     43 func removeErrorField(ityp *ast.InterfaceType) {
     44 	list := ityp.Methods.List // we know that ityp.Methods != nil
     45 	j := 0
     46 	for _, field := range list {
     47 		keepField := true
     48 		if n := len(field.Names); n == 0 {
     49 			// anonymous field
     50 			if fname, _ := baseTypeName(field.Type); fname == "error" {
     51 				keepField = false
     52 			}
     53 		}
     54 		if keepField {
     55 			list[j] = field
     56 			j++
     57 		}
     58 	}
     59 	if j < len(list) {
     60 		ityp.Incomplete = true
     61 	}
     62 	ityp.Methods.List = list[0:j]
     63 }
     64 
     65 // filterFieldList removes unexported fields (field names) from the field list
     66 // in place and reports whether fields were removed. Anonymous fields are
     67 // recorded with the parent type. filterType is called with the types of
     68 // all remaining fields.
     69 //
     70 func (r *reader) filterFieldList(parent *namedType, fields *ast.FieldList, ityp *ast.InterfaceType) (removedFields bool) {
     71 	if fields == nil {
     72 		return
     73 	}
     74 	list := fields.List
     75 	j := 0
     76 	for _, field := range list {
     77 		keepField := false
     78 		if n := len(field.Names); n == 0 {
     79 			// anonymous field
     80 			fname := r.recordAnonymousField(parent, field.Type)
     81 			if ast.IsExported(fname) {
     82 				keepField = true
     83 			} else if ityp != nil && fname == "error" {
     84 				// possibly the predeclared error interface; keep
     85 				// it for now but remember this interface so that
     86 				// it can be fixed if error is also defined locally
     87 				keepField = true
     88 				r.remember(ityp)
     89 			}
     90 		} else {
     91 			field.Names = filterIdentList(field.Names)
     92 			if len(field.Names) < n {
     93 				removedFields = true
     94 			}
     95 			if len(field.Names) > 0 {
     96 				keepField = true
     97 			}
     98 		}
     99 		if keepField {
    100 			r.filterType(nil, field.Type)
    101 			list[j] = field
    102 			j++
    103 		}
    104 	}
    105 	if j < len(list) {
    106 		removedFields = true
    107 	}
    108 	fields.List = list[0:j]
    109 	return
    110 }
    111 
    112 // filterParamList applies filterType to each parameter type in fields.
    113 //
    114 func (r *reader) filterParamList(fields *ast.FieldList) {
    115 	if fields != nil {
    116 		for _, f := range fields.List {
    117 			r.filterType(nil, f.Type)
    118 		}
    119 	}
    120 }
    121 
    122 // filterType strips any unexported struct fields or method types from typ
    123 // in place. If fields (or methods) have been removed, the corresponding
    124 // struct or interface type has the Incomplete field set to true.
    125 //
    126 func (r *reader) filterType(parent *namedType, typ ast.Expr) {
    127 	switch t := typ.(type) {
    128 	case *ast.Ident:
    129 		// nothing to do
    130 	case *ast.ParenExpr:
    131 		r.filterType(nil, t.X)
    132 	case *ast.ArrayType:
    133 		r.filterType(nil, t.Elt)
    134 	case *ast.StructType:
    135 		if r.filterFieldList(parent, t.Fields, nil) {
    136 			t.Incomplete = true
    137 		}
    138 	case *ast.FuncType:
    139 		r.filterParamList(t.Params)
    140 		r.filterParamList(t.Results)
    141 	case *ast.InterfaceType:
    142 		if r.filterFieldList(parent, t.Methods, t) {
    143 			t.Incomplete = true
    144 		}
    145 	case *ast.MapType:
    146 		r.filterType(nil, t.Key)
    147 		r.filterType(nil, t.Value)
    148 	case *ast.ChanType:
    149 		r.filterType(nil, t.Value)
    150 	}
    151 }
    152 
    153 func (r *reader) filterSpec(spec ast.Spec, tok token.Token) bool {
    154 	switch s := spec.(type) {
    155 	case *ast.ImportSpec:
    156 		// always keep imports so we can collect them
    157 		return true
    158 	case *ast.ValueSpec:
    159 		s.Names = filterIdentList(s.Names)
    160 		if len(s.Names) > 0 {
    161 			r.filterType(nil, s.Type)
    162 			return true
    163 		}
    164 	case *ast.TypeSpec:
    165 		if name := s.Name.Name; ast.IsExported(name) {
    166 			r.filterType(r.lookupType(s.Name.Name), s.Type)
    167 			return true
    168 		} else if name == "error" {
    169 			// special case: remember that error is declared locally
    170 			r.errorDecl = true
    171 		}
    172 	}
    173 	return false
    174 }
    175 
    176 // copyConstType returns a copy of typ with position pos.
    177 // typ must be a valid constant type.
    178 // In practice, only (possibly qualified) identifiers are possible.
    179 //
    180 func copyConstType(typ ast.Expr, pos token.Pos) ast.Expr {
    181 	switch typ := typ.(type) {
    182 	case *ast.Ident:
    183 		return &ast.Ident{Name: typ.Name, NamePos: pos}
    184 	case *ast.SelectorExpr:
    185 		if id, ok := typ.X.(*ast.Ident); ok {
    186 			// presumably a qualified identifier
    187 			return &ast.SelectorExpr{
    188 				Sel: ast.NewIdent(typ.Sel.Name),
    189 				X:   &ast.Ident{Name: id.Name, NamePos: pos},
    190 			}
    191 		}
    192 	}
    193 	return nil // shouldn't happen, but be conservative and don't panic
    194 }
    195 
    196 func (r *reader) filterSpecList(list []ast.Spec, tok token.Token) []ast.Spec {
    197 	if tok == token.CONST {
    198 		// Propagate any type information that would get lost otherwise
    199 		// when unexported constants are filtered.
    200 		var prevType ast.Expr
    201 		for _, spec := range list {
    202 			spec := spec.(*ast.ValueSpec)
    203 			if spec.Type == nil && prevType != nil {
    204 				// provide current spec with an explicit type
    205 				spec.Type = copyConstType(prevType, spec.Pos())
    206 			}
    207 			if hasExportedName(spec.Names) {
    208 				// exported names are preserved so there's no need to propagate the type
    209 				prevType = nil
    210 			} else {
    211 				prevType = spec.Type
    212 			}
    213 		}
    214 	}
    215 
    216 	j := 0
    217 	for _, s := range list {
    218 		if r.filterSpec(s, tok) {
    219 			list[j] = s
    220 			j++
    221 		}
    222 	}
    223 	return list[0:j]
    224 }
    225 
    226 func (r *reader) filterDecl(decl ast.Decl) bool {
    227 	switch d := decl.(type) {
    228 	case *ast.GenDecl:
    229 		d.Specs = r.filterSpecList(d.Specs, d.Tok)
    230 		return len(d.Specs) > 0
    231 	case *ast.FuncDecl:
    232 		// ok to filter these methods early because any
    233 		// conflicting method will be filtered here, too -
    234 		// thus, removing these methods early will not lead
    235 		// to the false removal of possible conflicts
    236 		return ast.IsExported(d.Name.Name)
    237 	}
    238 	return false
    239 }
    240 
    241 // fileExports removes unexported declarations from src in place.
    242 //
    243 func (r *reader) fileExports(src *ast.File) {
    244 	j := 0
    245 	for _, d := range src.Decls {
    246 		if r.filterDecl(d) {
    247 			src.Decls[j] = d
    248 			j++
    249 		}
    250 	}
    251 	src.Decls = src.Decls[0:j]
    252 }
    253