Home | History | Annotate | Download | only in syscall
      1 // Copyright 2013 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // +build ignore
      6 
      7 /*
      8 mksyscall_windows generates windows system call bodies
      9 
     10 It parses all files specified on command line containing function
     11 prototypes (like syscall_windows.go) and prints system call bodies
     12 to standard output.
     13 
     14 The prototypes are marked by lines beginning with "//sys" and read
     15 like func declarations if //sys is replaced by func, but:
     16 
     17 * The parameter lists must give a name for each argument. This
     18   includes return parameters.
     19 
     20 * The parameter lists must give a type for each argument:
     21   the (x, y, z int) shorthand is not allowed.
     22 
     23 * If the return parameter is an error number, it must be named err.
     24 
* If the Go func name needs to be different from its winapi dll name,
  the winapi name can be specified at the end, after the "=" sign, like
  //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA
  The dll can be named too, as in "= kernel32.LoadLibraryA"; kernel32 is
  the default.
     28 
* Each function that returns err needs to supply a condition that the
  return value of the winapi call is tested against to detect failure.
  On failure err is set to the Windows "last-error"; otherwise it is nil.
  The condition can be provided at the end of the //sys declaration, like
  //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA
  and is [failretval==0] by default.
     35 
     36 Usage:
     37 	mksyscall_windows [flags] [path ...]
     38 
The flags are:
	-output
		Specify output file name (outputs to console if blank).
	-systemdll
		Load all DLLs from the Windows system directory (default true).
	-trace
		Generate print statement after every syscall.
     44 */
     45 package main
     46 
     47 import (
     48 	"bufio"
     49 	"bytes"
     50 	"errors"
     51 	"flag"
     52 	"fmt"
     53 	"go/format"
     54 	"go/parser"
     55 	"go/token"
     56 	"io"
     57 	"io/ioutil"
     58 	"log"
     59 	"os"
     60 	"path/filepath"
     61 	"runtime"
     62 	"sort"
     63 	"strconv"
     64 	"strings"
     65 	"text/template"
     66 )
     67 
     68 var (
     69 	filename       = flag.String("output", "", "output file name (standard output if omitted)")
     70 	printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall")
     71 	systemDLL      = flag.Bool("systemdll", true, "whether all DLLs should be loaded from the Windows system directory")
     72 )
     73 
     74 func trim(s string) string {
     75 	return strings.Trim(s, " \t")
     76 }
     77 
     78 var packageName string
     79 
     80 func packagename() string {
     81 	return packageName
     82 }
     83 
     84 func syscalldot() string {
     85 	if packageName == "syscall" {
     86 		return ""
     87 	}
     88 	return "syscall."
     89 }
     90 
// Param is function parameter: one "name type" entry from the parameter
// or result list of a //sys declaration.
type Param struct {
	Name      string // name as written in the declaration
	Type      string // Go type as written in the declaration
	fn        *Fn    // function the parameter belongs to; nil for Params built by Rets.ToParams
	tmpVarIdx int    // N of the _pN temporary; tmpVar allocates one while this is negative
}
     98 
     99 // tmpVar returns temp variable name that will be used to represent p during syscall.
    100 func (p *Param) tmpVar() string {
    101 	if p.tmpVarIdx < 0 {
    102 		p.tmpVarIdx = p.fn.curTmpVarIdx
    103 		p.fn.curTmpVarIdx++
    104 	}
    105 	return fmt.Sprintf("_p%d", p.tmpVarIdx)
    106 }
    107 
    108 // BoolTmpVarCode returns source code for bool temp variable.
    109 func (p *Param) BoolTmpVarCode() string {
    110 	const code = `var %s uint32
    111 	if %s {
    112 		%s = 1
    113 	} else {
    114 		%s = 0
    115 	}`
    116 	tmp := p.tmpVar()
    117 	return fmt.Sprintf(code, tmp, p.Name, tmp, tmp)
    118 }
    119 
    120 // SliceTmpVarCode returns source code for slice temp variable.
    121 func (p *Param) SliceTmpVarCode() string {
    122 	const code = `var %s *%s
    123 	if len(%s) > 0 {
    124 		%s = &%s[0]
    125 	}`
    126 	tmp := p.tmpVar()
    127 	return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name)
    128 }
    129 
    130 // StringTmpVarCode returns source code for string temp variable.
    131 func (p *Param) StringTmpVarCode() string {
    132 	errvar := p.fn.Rets.ErrorVarName()
    133 	if errvar == "" {
    134 		errvar = "_"
    135 	}
    136 	tmp := p.tmpVar()
    137 	const code = `var %s %s
    138 	%s, %s = %s(%s)`
    139 	s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name)
    140 	if errvar == "-" {
    141 		return s
    142 	}
    143 	const morecode = `
    144 	if %s != nil {
    145 		return
    146 	}`
    147 	return s + fmt.Sprintf(morecode, errvar)
    148 }
    149 
    150 // TmpVarCode returns source code for temp variable.
    151 func (p *Param) TmpVarCode() string {
    152 	switch {
    153 	case p.Type == "bool":
    154 		return p.BoolTmpVarCode()
    155 	case strings.HasPrefix(p.Type, "[]"):
    156 		return p.SliceTmpVarCode()
    157 	default:
    158 		return ""
    159 	}
    160 }
    161 
    162 // TmpVarHelperCode returns source code for helper's temp variable.
    163 func (p *Param) TmpVarHelperCode() string {
    164 	if p.Type != "string" {
    165 		return ""
    166 	}
    167 	return p.StringTmpVarCode()
    168 }
    169 
    170 // SyscallArgList returns source code fragments representing p parameter
    171 // in syscall. Slices are translated into 2 syscall parameters: pointer to
    172 // the first element and length.
    173 func (p *Param) SyscallArgList() []string {
    174 	t := p.HelperType()
    175 	var s string
    176 	switch {
    177 	case t[0] == '*':
    178 		s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name)
    179 	case t == "bool":
    180 		s = p.tmpVar()
    181 	case strings.HasPrefix(t, "[]"):
    182 		return []string{
    183 			fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()),
    184 			fmt.Sprintf("uintptr(len(%s))", p.Name),
    185 		}
    186 	default:
    187 		s = p.Name
    188 	}
    189 	return []string{fmt.Sprintf("uintptr(%s)", s)}
    190 }
    191 
    192 // IsError determines if p parameter is used to return error.
    193 func (p *Param) IsError() bool {
    194 	return p.Name == "err" && p.Type == "error"
    195 }
    196 
    197 // HelperType returns type of parameter p used in helper function.
    198 func (p *Param) HelperType() string {
    199 	if p.Type == "string" {
    200 		return p.fn.StrconvType()
    201 	}
    202 	return p.Type
    203 }
    204 
    205 // join concatenates parameters ps into a string with sep separator.
    206 // Each parameter is converted into string by applying fn to it
    207 // before conversion.
    208 func join(ps []*Param, fn func(*Param) string, sep string) string {
    209 	if len(ps) == 0 {
    210 		return ""
    211 	}
    212 	a := make([]string, 0)
    213 	for _, p := range ps {
    214 		a = append(a, fn(p))
    215 	}
    216 	return strings.Join(a, sep)
    217 }
    218 
// Rets describes function return parameters.
type Rets struct {
	Name         string // name of the non-error result; empty if there is none
	Type         string // Go type of the non-error result
	ReturnsError bool   // whether the function also returns "err error"
	FailCond     string // body of the optional [...] fail condition; empty selects the "== 0" test
}
    226 
    227 // ErrorVarName returns error variable name for r.
    228 func (r *Rets) ErrorVarName() string {
    229 	if r.ReturnsError {
    230 		return "err"
    231 	}
    232 	if r.Type == "error" {
    233 		return r.Name
    234 	}
    235 	return ""
    236 }
    237 
    238 // ToParams converts r into slice of *Param.
    239 func (r *Rets) ToParams() []*Param {
    240 	ps := make([]*Param, 0)
    241 	if len(r.Name) > 0 {
    242 		ps = append(ps, &Param{Name: r.Name, Type: r.Type})
    243 	}
    244 	if r.ReturnsError {
    245 		ps = append(ps, &Param{Name: "err", Type: "error"})
    246 	}
    247 	return ps
    248 }
    249 
    250 // List returns source code of syscall return parameters.
    251 func (r *Rets) List() string {
    252 	s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ")
    253 	if len(s) > 0 {
    254 		s = "(" + s + ")"
    255 	}
    256 	return s
    257 }
    258 
    259 // PrintList returns source code of trace printing part correspondent
    260 // to syscall return values.
    261 func (r *Rets) PrintList() string {
    262 	return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
    263 }
    264 
    265 // SetReturnValuesCode returns source code that accepts syscall return values.
    266 func (r *Rets) SetReturnValuesCode() string {
    267 	if r.Name == "" && !r.ReturnsError {
    268 		return ""
    269 	}
    270 	retvar := "r0"
    271 	if r.Name == "" {
    272 		retvar = "r1"
    273 	}
    274 	errvar := "_"
    275 	if r.ReturnsError {
    276 		errvar = "e1"
    277 	}
    278 	return fmt.Sprintf("%s, _, %s := ", retvar, errvar)
    279 }
    280 
    281 func (r *Rets) useLongHandleErrorCode(retvar string) string {
    282 	const code = `if %s {
    283 		if e1 != 0 {
    284 			err = errnoErr(e1)
    285 		} else {
    286 			err = %sEINVAL
    287 		}
    288 	}`
    289 	cond := retvar + " == 0"
    290 	if r.FailCond != "" {
    291 		cond = strings.Replace(r.FailCond, "failretval", retvar, 1)
    292 	}
    293 	return fmt.Sprintf(code, cond, syscalldot())
    294 }
    295 
    296 // SetErrorCode returns source code that sets return parameters.
    297 func (r *Rets) SetErrorCode() string {
    298 	const code = `if r0 != 0 {
    299 		%s = %sErrno(r0)
    300 	}`
    301 	if r.Name == "" && !r.ReturnsError {
    302 		return ""
    303 	}
    304 	if r.Name == "" {
    305 		return r.useLongHandleErrorCode("r1")
    306 	}
    307 	if r.Type == "error" {
    308 		return fmt.Sprintf(code, r.Name, syscalldot())
    309 	}
    310 	s := ""
    311 	switch {
    312 	case r.Type[0] == '*':
    313 		s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type)
    314 	case r.Type == "bool":
    315 		s = fmt.Sprintf("%s = r0 != 0", r.Name)
    316 	default:
    317 		s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type)
    318 	}
    319 	if !r.ReturnsError {
    320 		return s
    321 	}
    322 	return s + "\n\t" + r.useLongHandleErrorCode(r.Name)
    323 }
    324 
// Fn describes syscall function: one //sys declaration and everything
// parsed from it.
type Fn struct {
	Name        string   // Go name of the function
	Params      []*Param // input parameters
	Rets        *Rets    // result parameters
	PrintTrace  bool     // emit a trace print after the call (the -trace flag)
	dllname     string   // DLL named in the "= dll.proc" suffix; empty means kernel32 (see DLLName)
	dllfuncname string   // procedure named after "="; empty means Name (see DLLFuncName)
	src         string   // trimmed declaration text, quoted in error messages
	// TODO: get rid of this field and just use parameter index instead
	curTmpVarIdx int // next free temporary index; ensures tmp variables get unique names
}
    337 
    338 // extractParams parses s to extract function parameters.
    339 func extractParams(s string, f *Fn) ([]*Param, error) {
    340 	s = trim(s)
    341 	if s == "" {
    342 		return nil, nil
    343 	}
    344 	a := strings.Split(s, ",")
    345 	ps := make([]*Param, len(a))
    346 	for i := range ps {
    347 		s2 := trim(a[i])
    348 		b := strings.Split(s2, " ")
    349 		if len(b) != 2 {
    350 			b = strings.Split(s2, "\t")
    351 			if len(b) != 2 {
    352 				return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"")
    353 			}
    354 		}
    355 		ps[i] = &Param{
    356 			Name:      trim(b[0]),
    357 			Type:      trim(b[1]),
    358 			fn:        f,
    359 			tmpVarIdx: -1,
    360 		}
    361 	}
    362 	return ps, nil
    363 }
    364 
    365 // extractSection extracts text out of string s starting after start
    366 // and ending just before end. found return value will indicate success,
    367 // and prefix, body and suffix will contain correspondent parts of string s.
    368 func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) {
    369 	s = trim(s)
    370 	if strings.HasPrefix(s, string(start)) {
    371 		// no prefix
    372 		body = s[1:]
    373 	} else {
    374 		a := strings.SplitN(s, string(start), 2)
    375 		if len(a) != 2 {
    376 			return "", "", s, false
    377 		}
    378 		prefix = a[0]
    379 		body = a[1]
    380 	}
    381 	a := strings.SplitN(body, string(end), 2)
    382 	if len(a) != 2 {
    383 		return "", "", "", false
    384 	}
    385 	return prefix, a[0], a[1], true
    386 }
    387 
    388 // newFn parses string s and return created function Fn.
    389 func newFn(s string) (*Fn, error) {
    390 	s = trim(s)
    391 	f := &Fn{
    392 		Rets:       &Rets{},
    393 		src:        s,
    394 		PrintTrace: *printTraceFlag,
    395 	}
    396 	// function name and args
    397 	prefix, body, s, found := extractSection(s, '(', ')')
    398 	if !found || prefix == "" {
    399 		return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"")
    400 	}
    401 	f.Name = prefix
    402 	var err error
    403 	f.Params, err = extractParams(body, f)
    404 	if err != nil {
    405 		return nil, err
    406 	}
    407 	// return values
    408 	_, body, s, found = extractSection(s, '(', ')')
    409 	if found {
    410 		r, err := extractParams(body, f)
    411 		if err != nil {
    412 			return nil, err
    413 		}
    414 		switch len(r) {
    415 		case 0:
    416 		case 1:
    417 			if r[0].IsError() {
    418 				f.Rets.ReturnsError = true
    419 			} else {
    420 				f.Rets.Name = r[0].Name
    421 				f.Rets.Type = r[0].Type
    422 			}
    423 		case 2:
    424 			if !r[1].IsError() {
    425 				return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"")
    426 			}
    427 			f.Rets.ReturnsError = true
    428 			f.Rets.Name = r[0].Name
    429 			f.Rets.Type = r[0].Type
    430 		default:
    431 			return nil, errors.New("Too many return values in \"" + f.src + "\"")
    432 		}
    433 	}
    434 	// fail condition
    435 	_, body, s, found = extractSection(s, '[', ']')
    436 	if found {
    437 		f.Rets.FailCond = body
    438 	}
    439 	// dll and dll function names
    440 	s = trim(s)
    441 	if s == "" {
    442 		return f, nil
    443 	}
    444 	if !strings.HasPrefix(s, "=") {
    445 		return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
    446 	}
    447 	s = trim(s[1:])
    448 	a := strings.Split(s, ".")
    449 	switch len(a) {
    450 	case 1:
    451 		f.dllfuncname = a[0]
    452 	case 2:
    453 		f.dllname = a[0]
    454 		f.dllfuncname = a[1]
    455 	default:
    456 		return nil, errors.New("Could not extract dll name from \"" + f.src + "\"")
    457 	}
    458 	return f, nil
    459 }
    460 
    461 // DLLName returns DLL name for function f.
    462 func (f *Fn) DLLName() string {
    463 	if f.dllname == "" {
    464 		return "kernel32"
    465 	}
    466 	return f.dllname
    467 }
    468 
    469 // DLLName returns DLL function name for function f.
    470 func (f *Fn) DLLFuncName() string {
    471 	if f.dllfuncname == "" {
    472 		return f.Name
    473 	}
    474 	return f.dllfuncname
    475 }
    476 
    477 // ParamList returns source code for function f parameters.
    478 func (f *Fn) ParamList() string {
    479 	return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ")
    480 }
    481 
    482 // HelperParamList returns source code for helper function f parameters.
    483 func (f *Fn) HelperParamList() string {
    484 	return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ")
    485 }
    486 
    487 // ParamPrintList returns source code of trace printing part correspondent
    488 // to syscall input parameters.
    489 func (f *Fn) ParamPrintList() string {
    490 	return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `)
    491 }
    492 
    493 // ParamCount return number of syscall parameters for function f.
    494 func (f *Fn) ParamCount() int {
    495 	n := 0
    496 	for _, p := range f.Params {
    497 		n += len(p.SyscallArgList())
    498 	}
    499 	return n
    500 }
    501 
    502 // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/...
    503 // to use. It returns parameter count for correspondent SyscallX function.
    504 func (f *Fn) SyscallParamCount() int {
    505 	n := f.ParamCount()
    506 	switch {
    507 	case n <= 3:
    508 		return 3
    509 	case n <= 6:
    510 		return 6
    511 	case n <= 9:
    512 		return 9
    513 	case n <= 12:
    514 		return 12
    515 	case n <= 15:
    516 		return 15
    517 	default:
    518 		panic("too many arguments to system call")
    519 	}
    520 }
    521 
    522 // Syscall determines which SyscallX function to use for function f.
    523 func (f *Fn) Syscall() string {
    524 	c := f.SyscallParamCount()
    525 	if c == 3 {
    526 		return syscalldot() + "Syscall"
    527 	}
    528 	return syscalldot() + "Syscall" + strconv.Itoa(c)
    529 }
    530 
    531 // SyscallParamList returns source code for SyscallX parameters for function f.
    532 func (f *Fn) SyscallParamList() string {
    533 	a := make([]string, 0)
    534 	for _, p := range f.Params {
    535 		a = append(a, p.SyscallArgList()...)
    536 	}
    537 	for len(a) < f.SyscallParamCount() {
    538 		a = append(a, "0")
    539 	}
    540 	return strings.Join(a, ", ")
    541 }
    542 
    543 // HelperCallParamList returns source code of call into function f helper.
    544 func (f *Fn) HelperCallParamList() string {
    545 	a := make([]string, 0, len(f.Params))
    546 	for _, p := range f.Params {
    547 		s := p.Name
    548 		if p.Type == "string" {
    549 			s = p.tmpVar()
    550 		}
    551 		a = append(a, s)
    552 	}
    553 	return strings.Join(a, ", ")
    554 }
    555 
    556 // IsUTF16 is true, if f is W (utf16) function. It is false
    557 // for all A (ascii) functions.
    558 func (f *Fn) IsUTF16() bool {
    559 	s := f.DLLFuncName()
    560 	return s[len(s)-1] == 'W'
    561 }
    562 
    563 // StrconvFunc returns name of Go string to OS string function for f.
    564 func (f *Fn) StrconvFunc() string {
    565 	if f.IsUTF16() {
    566 		return syscalldot() + "UTF16PtrFromString"
    567 	}
    568 	return syscalldot() + "BytePtrFromString"
    569 }
    570 
    571 // StrconvType returns Go type name used for OS string for f.
    572 func (f *Fn) StrconvType() string {
    573 	if f.IsUTF16() {
    574 		return "*uint16"
    575 	}
    576 	return "*byte"
    577 }
    578 
    579 // HasStringParam is true, if f has at least one string parameter.
    580 // Otherwise it is false.
    581 func (f *Fn) HasStringParam() bool {
    582 	for _, p := range f.Params {
    583 		if p.Type == "string" {
    584 			return true
    585 		}
    586 	}
    587 	return false
    588 }
    589 
    590 // HelperName returns name of function f helper.
    591 func (f *Fn) HelperName() string {
    592 	if !f.HasStringParam() {
    593 		return f.Name
    594 	}
    595 	return "_" + f.Name
    596 }
    597 
// Source files and functions: the parsed input files, the //sys functions
// found in them, and the imports the generated file needs.
type Source struct {
	Funcs           []*Fn    // functions in the order they were parsed
	Files           []string // paths of the parsed input files
	StdLibImports   []string // standard library imports, kept sorted by Import
	ExternalImports []string // other imports, kept sorted by ExternalImport
}
    605 
    606 func (src *Source) Import(pkg string) {
    607 	src.StdLibImports = append(src.StdLibImports, pkg)
    608 	sort.Strings(src.StdLibImports)
    609 }
    610 
    611 func (src *Source) ExternalImport(pkg string) {
    612 	src.ExternalImports = append(src.ExternalImports, pkg)
    613 	sort.Strings(src.ExternalImports)
    614 }
    615 
    616 // ParseFiles parses files listed in fs and extracts all syscall
    617 // functions listed in sys comments. It returns source files
    618 // and functions collection *Source if successful.
    619 func ParseFiles(fs []string) (*Source, error) {
    620 	src := &Source{
    621 		Funcs: make([]*Fn, 0),
    622 		Files: make([]string, 0),
    623 		StdLibImports: []string{
    624 			"unsafe",
    625 		},
    626 		ExternalImports: make([]string, 0),
    627 	}
    628 	for _, file := range fs {
    629 		if err := src.ParseFile(file); err != nil {
    630 			return nil, err
    631 		}
    632 	}
    633 	return src, nil
    634 }
    635 
    636 // DLLs return dll names for a source set src.
    637 func (src *Source) DLLs() []string {
    638 	uniq := make(map[string]bool)
    639 	r := make([]string, 0)
    640 	for _, f := range src.Funcs {
    641 		name := f.DLLName()
    642 		if _, found := uniq[name]; !found {
    643 			uniq[name] = true
    644 			r = append(r, name)
    645 		}
    646 	}
    647 	return r
    648 }
    649 
    650 // ParseFile adds additional file path to a source set src.
    651 func (src *Source) ParseFile(path string) error {
    652 	file, err := os.Open(path)
    653 	if err != nil {
    654 		return err
    655 	}
    656 	defer file.Close()
    657 
    658 	s := bufio.NewScanner(file)
    659 	for s.Scan() {
    660 		t := trim(s.Text())
    661 		if len(t) < 7 {
    662 			continue
    663 		}
    664 		if !strings.HasPrefix(t, "//sys") {
    665 			continue
    666 		}
    667 		t = t[5:]
    668 		if !(t[0] == ' ' || t[0] == '\t') {
    669 			continue
    670 		}
    671 		f, err := newFn(t[1:])
    672 		if err != nil {
    673 			return err
    674 		}
    675 		src.Funcs = append(src.Funcs, f)
    676 	}
    677 	if err := s.Err(); err != nil {
    678 		return err
    679 	}
    680 	src.Files = append(src.Files, path)
    681 
    682 	// get package name
    683 	fset := token.NewFileSet()
    684 	_, err = file.Seek(0, 0)
    685 	if err != nil {
    686 		return err
    687 	}
    688 	pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly)
    689 	if err != nil {
    690 		return err
    691 	}
    692 	packageName = pkg.Name.Name
    693 
    694 	return nil
    695 }
    696 
    697 // IsStdRepo returns true if src is part of standard library.
    698 func (src *Source) IsStdRepo() (bool, error) {
    699 	if len(src.Files) == 0 {
    700 		return false, errors.New("no input files provided")
    701 	}
    702 	abspath, err := filepath.Abs(src.Files[0])
    703 	if err != nil {
    704 		return false, err
    705 	}
    706 	goroot := runtime.GOROOT()
    707 	if runtime.GOOS == "windows" {
    708 		abspath = strings.ToLower(abspath)
    709 		goroot = strings.ToLower(goroot)
    710 	}
    711 	sep := string(os.PathSeparator)
    712 	if !strings.HasSuffix(goroot, sep) {
    713 		goroot += sep
    714 	}
    715 	return strings.HasPrefix(abspath, goroot), nil
    716 }
    717 
    718 // Generate output source file from a source set src.
    719 func (src *Source) Generate(w io.Writer) error {
    720 	const (
    721 		pkgStd         = iota // any package in std library
    722 		pkgXSysWindows        // x/sys/windows package
    723 		pkgOther
    724 	)
    725 	isStdRepo, err := src.IsStdRepo()
    726 	if err != nil {
    727 		return err
    728 	}
    729 	var pkgtype int
    730 	switch {
    731 	case isStdRepo:
    732 		pkgtype = pkgStd
    733 	case packageName == "windows":
    734 		// TODO: this needs better logic than just using package name
    735 		pkgtype = pkgXSysWindows
    736 	default:
    737 		pkgtype = pkgOther
    738 	}
    739 	if *systemDLL {
    740 		switch pkgtype {
    741 		case pkgStd:
    742 			src.Import("internal/syscall/windows/sysdll")
    743 		case pkgXSysWindows:
    744 		default:
    745 			src.ExternalImport("golang.org/x/sys/windows")
    746 		}
    747 	}
    748 	if packageName != "syscall" {
    749 		src.Import("syscall")
    750 	}
    751 	funcMap := template.FuncMap{
    752 		"packagename": packagename,
    753 		"syscalldot":  syscalldot,
    754 		"newlazydll": func(dll string) string {
    755 			arg := "\"" + dll + ".dll\""
    756 			if !*systemDLL {
    757 				return syscalldot() + "NewLazyDLL(" + arg + ")"
    758 			}
    759 			switch pkgtype {
    760 			case pkgStd:
    761 				return syscalldot() + "NewLazyDLL(sysdll.Add(" + arg + "))"
    762 			case pkgXSysWindows:
    763 				return "NewLazySystemDLL(" + arg + ")"
    764 			default:
    765 				return "windows.NewLazySystemDLL(" + arg + ")"
    766 			}
    767 		},
    768 	}
    769 	t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate))
    770 	err = t.Execute(w, src)
    771 	if err != nil {
    772 		return errors.New("Failed to execute template: " + err.Error())
    773 	}
    774 	return nil
    775 }
    776 
    777 func usage() {
    778 	fmt.Fprintf(os.Stderr, "usage: mksyscall_windows [flags] [path ...]\n")
    779 	flag.PrintDefaults()
    780 	os.Exit(1)
    781 }
    782 
    783 func main() {
    784 	flag.Usage = usage
    785 	flag.Parse()
    786 	if len(flag.Args()) <= 0 {
    787 		fmt.Fprintf(os.Stderr, "no files to parse provided\n")
    788 		usage()
    789 	}
    790 
    791 	src, err := ParseFiles(flag.Args())
    792 	if err != nil {
    793 		log.Fatal(err)
    794 	}
    795 
    796 	var buf bytes.Buffer
    797 	if err := src.Generate(&buf); err != nil {
    798 		log.Fatal(err)
    799 	}
    800 
    801 	data, err := format.Source(buf.Bytes())
    802 	if err != nil {
    803 		log.Fatal(err)
    804 	}
    805 	if *filename == "" {
    806 		_, err = os.Stdout.Write(data)
    807 	} else {
    808 		err = ioutil.WriteFile(*filename, data, 0644)
    809 	}
    810 	if err != nil {
    811 		log.Fatal(err)
    812 	}
    813 }
    814 
    815 // TODO: use println instead to print in the following template
    816 const srcTemplate = `
    817 
    818 {{define "main"}}// MACHINE GENERATED BY 'go generate' COMMAND; DO NOT EDIT
    819 
    820 package {{packagename}}
    821 
    822 import (
    823 {{range .StdLibImports}}"{{.}}"
    824 {{end}}
    825 
    826 {{range .ExternalImports}}"{{.}}"
    827 {{end}}
    828 )
    829 
    830 var _ unsafe.Pointer
    831 
    832 // Do the interface allocations only once for common
    833 // Errno values.
    834 const (
    835 	errnoERROR_IO_PENDING = 997
    836 )
    837 
    838 var (
    839 	errERROR_IO_PENDING error = {{syscalldot}}Errno(errnoERROR_IO_PENDING)
    840 )
    841 
    842 // errnoErr returns common boxed Errno values, to prevent
    843 // allocations at runtime.
    844 func errnoErr(e {{syscalldot}}Errno) error {
    845 	switch e {
    846 	case 0:
    847 		return nil
    848 	case errnoERROR_IO_PENDING:
    849 		return errERROR_IO_PENDING
    850 	}
    851 	// TODO: add more here, after collecting data on the common
    852 	// error values see on Windows. (perhaps when running
    853 	// all.bat?)
    854 	return e
    855 }
    856 
    857 var (
    858 {{template "dlls" .}}
    859 {{template "funcnames" .}})
    860 {{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}}
    861 {{end}}
    862 
    863 {{/* help functions */}}
    864 
    865 {{define "dlls"}}{{range .DLLs}}	mod{{.}} = {{newlazydll .}}
    866 {{end}}{{end}}
    867 
    868 {{define "funcnames"}}{{range .Funcs}}	proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}")
    869 {{end}}{{end}}
    870 
    871 {{define "helperbody"}}
    872 func {{.Name}}({{.ParamList}}) {{template "results" .}}{
    873 {{template "helpertmpvars" .}}	return {{.HelperName}}({{.HelperCallParamList}})
    874 }
    875 {{end}}
    876 
    877 {{define "funcbody"}}
    878 func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{
    879 {{template "tmpvars" .}}	{{template "syscall" .}}
    880 {{template "seterror" .}}{{template "printtrace" .}}	return
    881 }
    882 {{end}}
    883 
    884 {{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}}	{{.TmpVarHelperCode}}
    885 {{end}}{{end}}{{end}}
    886 
    887 {{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}}	{{.TmpVarCode}}
    888 {{end}}{{end}}{{end}}
    889 
    890 {{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}}
    891 
    892 {{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}}
    893 
    894 {{define "seterror"}}{{if .Rets.SetErrorCode}}	{{.Rets.SetErrorCode}}
    895 {{end}}{{end}}
    896 
    897 {{define "printtrace"}}{{if .PrintTrace}}	print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n")
    898 {{end}}{{end}}
    899 
    900 `
    901