Home | History | Annotate | Download | only in cgo
      1 // Copyright 2011 The Go Authors.  All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package main
      6 
      7 import (
      8 	"bytes"
      9 	"fmt"
     10 	"go/ast"
     11 	"go/printer"
     12 	"go/token"
     13 	"os"
     14 	"strings"
     15 )
     16 
     17 // godefs returns the output for -godefs mode.
     18 func (p *Package) godefs(f *File, srcfile string) string {
     19 	var buf bytes.Buffer
     20 
     21 	fmt.Fprintf(&buf, "// Created by cgo -godefs - DO NOT EDIT\n")
     22 	fmt.Fprintf(&buf, "// %s\n", strings.Join(os.Args, " "))
     23 	fmt.Fprintf(&buf, "\n")
     24 
     25 	override := make(map[string]string)
     26 
     27 	// Allow source file to specify override mappings.
     28 	// For example, the socket data structures refer
     29 	// to in_addr and in_addr6 structs but we want to be
     30 	// able to treat them as byte arrays, so the godefs
     31 	// inputs in package syscall say
     32 	//
     33 	//	// +godefs map struct_in_addr [4]byte
     34 	//	// +godefs map struct_in_addr6 [16]byte
     35 	//
     36 	for _, g := range f.Comments {
     37 		for _, c := range g.List {
     38 			i := strings.Index(c.Text, "+godefs map")
     39 			if i < 0 {
     40 				continue
     41 			}
     42 			s := strings.TrimSpace(c.Text[i+len("+godefs map"):])
     43 			i = strings.Index(s, " ")
     44 			if i < 0 {
     45 				fmt.Fprintf(os.Stderr, "invalid +godefs map comment: %s\n", c.Text)
     46 				continue
     47 			}
     48 			override["_Ctype_"+strings.TrimSpace(s[:i])] = strings.TrimSpace(s[i:])
     49 		}
     50 	}
     51 	for _, n := range f.Name {
     52 		if s := override[n.Go]; s != "" {
     53 			override[n.Mangle] = s
     54 		}
     55 	}
     56 
     57 	// Otherwise, if the source file says type T C.whatever,
     58 	// use "T" as the mangling of C.whatever,
     59 	// except in the definition (handled at end of function).
     60 	refName := make(map[*ast.Expr]*Name)
     61 	for _, r := range f.Ref {
     62 		refName[r.Expr] = r.Name
     63 	}
     64 	for _, d := range f.AST.Decls {
     65 		d, ok := d.(*ast.GenDecl)
     66 		if !ok || d.Tok != token.TYPE {
     67 			continue
     68 		}
     69 		for _, s := range d.Specs {
     70 			s := s.(*ast.TypeSpec)
     71 			n := refName[&s.Type]
     72 			if n != nil && n.Mangle != "" {
     73 				override[n.Mangle] = s.Name.Name
     74 			}
     75 		}
     76 	}
     77 
     78 	// Extend overrides using typedefs:
     79 	// If we know that C.xxx should format as T
     80 	// and xxx is a typedef for yyy, make C.yyy format as T.
     81 	for typ, def := range typedef {
     82 		if new := override[typ]; new != "" {
     83 			if id, ok := def.Go.(*ast.Ident); ok {
     84 				override[id.Name] = new
     85 			}
     86 		}
     87 	}
     88 
     89 	// Apply overrides.
     90 	for old, new := range override {
     91 		if id := goIdent[old]; id != nil {
     92 			id.Name = new
     93 		}
     94 	}
     95 
     96 	// Any names still using the _C syntax are not going to compile,
     97 	// although in general we don't know whether they all made it
     98 	// into the file, so we can't warn here.
     99 	//
    100 	// The most common case is union types, which begin with
    101 	// _Ctype_union and for which typedef[name] is a Go byte
    102 	// array of the appropriate size (such as [4]byte).
    103 	// Substitute those union types with byte arrays.
    104 	for name, id := range goIdent {
    105 		if id.Name == name && strings.Contains(name, "_Ctype_union") {
    106 			if def := typedef[name]; def != nil {
    107 				id.Name = gofmt(def)
    108 			}
    109 		}
    110 	}
    111 
    112 	conf.Fprint(&buf, fset, f.AST)
    113 
    114 	return buf.String()
    115 }
    116 
    117 var gofmtBuf bytes.Buffer
    118 
    119 // gofmt returns the gofmt-formatted string for an AST node.
    120 func gofmt(n interface{}) string {
    121 	gofmtBuf.Reset()
    122 	err := printer.Fprint(&gofmtBuf, fset, n)
    123 	if err != nil {
    124 		return "<" + err.Error() + ">"
    125 	}
    126 	return gofmtBuf.String()
    127 }
    128