1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // +build ignore 6 7 /* 8 mksyscall_windows generates windows system call bodies 9 10 It parses all files specified on command line containing function 11 prototypes (like syscall_windows.go) and prints system call bodies 12 to standard output. 13 14 The prototypes are marked by lines beginning with "//sys" and read 15 like func declarations if //sys is replaced by func, but: 16 17 * The parameter lists must give a name for each argument. This 18 includes return parameters. 19 20 * The parameter lists must give a type for each argument: 21 the (x, y, z int) shorthand is not allowed. 22 23 * If the return parameter is an error number, it must be named err. 24 25 * If go func name needs to be different from it's winapi dll name, 26 the winapi name could be specified at the end, after "=" sign, like 27 //sys LoadLibrary(libname string) (handle uint32, err error) = LoadLibraryA 28 29 * Each function that returns err needs to supply a condition, that 30 return value of winapi will be tested against to detect failure. 31 This would set err to windows "last-error", otherwise it will be nil. 32 The value can be provided at end of //sys declaration, like 33 //sys LoadLibrary(libname string) (handle uint32, err error) [failretval==-1] = LoadLibraryA 34 and is [failretval==0] by default. 35 36 Usage: 37 mksyscall_windows [flags] [path ...] 38 39 The flags are: 40 -output 41 Specify output file name (outputs to console if blank). 42 -trace 43 Generate print statement after every syscall. 44 */ 45 package main 46 47 import ( 48 "bufio" 49 "bytes" 50 "errors" 51 "flag" 52 "fmt" 53 "go/format" 54 "go/parser" 55 "go/token" 56 "io" 57 "io/ioutil" 58 "log" 59 "os" 60 "strconv" 61 "strings" 62 "text/template" 63 ) 64 65 var ( 66 filename = flag.String("output", "", "output file name (standard output if omitted)") 67 printTraceFlag = flag.Bool("trace", false, "generate print statement after every syscall") 68 ) 69 70 func trim(s string) string { 71 return strings.Trim(s, " \t") 72 } 73 74 var packageName string 75 76 func packagename() string { 77 return packageName 78 } 79 80 func syscalldot() string { 81 if packageName == "syscall" { 82 return "" 83 } 84 return "syscall." 85 } 86 87 // Param is function parameter 88 type Param struct { 89 Name string 90 Type string 91 fn *Fn 92 tmpVarIdx int 93 } 94 95 // tmpVar returns temp variable name that will be used to represent p during syscall. 96 func (p *Param) tmpVar() string { 97 if p.tmpVarIdx < 0 { 98 p.tmpVarIdx = p.fn.curTmpVarIdx 99 p.fn.curTmpVarIdx++ 100 } 101 return fmt.Sprintf("_p%d", p.tmpVarIdx) 102 } 103 104 // BoolTmpVarCode returns source code for bool temp variable. 105 func (p *Param) BoolTmpVarCode() string { 106 const code = `var %s uint32 107 if %s { 108 %s = 1 109 } else { 110 %s = 0 111 }` 112 tmp := p.tmpVar() 113 return fmt.Sprintf(code, tmp, p.Name, tmp, tmp) 114 } 115 116 // SliceTmpVarCode returns source code for slice temp variable. 117 func (p *Param) SliceTmpVarCode() string { 118 const code = `var %s *%s 119 if len(%s) > 0 { 120 %s = &%s[0] 121 }` 122 tmp := p.tmpVar() 123 return fmt.Sprintf(code, tmp, p.Type[2:], p.Name, tmp, p.Name) 124 } 125 126 // StringTmpVarCode returns source code for string temp variable. 127 func (p *Param) StringTmpVarCode() string { 128 errvar := p.fn.Rets.ErrorVarName() 129 if errvar == "" { 130 errvar = "_" 131 } 132 tmp := p.tmpVar() 133 const code = `var %s %s 134 %s, %s = %s(%s)` 135 s := fmt.Sprintf(code, tmp, p.fn.StrconvType(), tmp, errvar, p.fn.StrconvFunc(), p.Name) 136 if errvar == "-" { 137 return s 138 } 139 const morecode = ` 140 if %s != nil { 141 return 142 }` 143 return s + fmt.Sprintf(morecode, errvar) 144 } 145 146 // TmpVarCode returns source code for temp variable. 147 func (p *Param) TmpVarCode() string { 148 switch { 149 case p.Type == "bool": 150 return p.BoolTmpVarCode() 151 case strings.HasPrefix(p.Type, "[]"): 152 return p.SliceTmpVarCode() 153 default: 154 return "" 155 } 156 } 157 158 // TmpVarHelperCode returns source code for helper's temp variable. 159 func (p *Param) TmpVarHelperCode() string { 160 if p.Type != "string" { 161 return "" 162 } 163 return p.StringTmpVarCode() 164 } 165 166 // SyscallArgList returns source code fragments representing p parameter 167 // in syscall. Slices are translated into 2 syscall parameters: pointer to 168 // the first element and length. 169 func (p *Param) SyscallArgList() []string { 170 t := p.HelperType() 171 var s string 172 switch { 173 case t[0] == '*': 174 s = fmt.Sprintf("unsafe.Pointer(%s)", p.Name) 175 case t == "bool": 176 s = p.tmpVar() 177 case strings.HasPrefix(t, "[]"): 178 return []string{ 179 fmt.Sprintf("uintptr(unsafe.Pointer(%s))", p.tmpVar()), 180 fmt.Sprintf("uintptr(len(%s))", p.Name), 181 } 182 default: 183 s = p.Name 184 } 185 return []string{fmt.Sprintf("uintptr(%s)", s)} 186 } 187 188 // IsError determines if p parameter is used to return error. 189 func (p *Param) IsError() bool { 190 return p.Name == "err" && p.Type == "error" 191 } 192 193 // HelperType returns type of parameter p used in helper function. 194 func (p *Param) HelperType() string { 195 if p.Type == "string" { 196 return p.fn.StrconvType() 197 } 198 return p.Type 199 } 200 201 // join concatenates parameters ps into a string with sep separator. 202 // Each parameter is converted into string by applying fn to it 203 // before conversion. 204 func join(ps []*Param, fn func(*Param) string, sep string) string { 205 if len(ps) == 0 { 206 return "" 207 } 208 a := make([]string, 0) 209 for _, p := range ps { 210 a = append(a, fn(p)) 211 } 212 return strings.Join(a, sep) 213 } 214 215 // Rets describes function return parameters. 216 type Rets struct { 217 Name string 218 Type string 219 ReturnsError bool 220 FailCond string 221 } 222 223 // ErrorVarName returns error variable name for r. 224 func (r *Rets) ErrorVarName() string { 225 if r.ReturnsError { 226 return "err" 227 } 228 if r.Type == "error" { 229 return r.Name 230 } 231 return "" 232 } 233 234 // ToParams converts r into slice of *Param. 235 func (r *Rets) ToParams() []*Param { 236 ps := make([]*Param, 0) 237 if len(r.Name) > 0 { 238 ps = append(ps, &Param{Name: r.Name, Type: r.Type}) 239 } 240 if r.ReturnsError { 241 ps = append(ps, &Param{Name: "err", Type: "error"}) 242 } 243 return ps 244 } 245 246 // List returns source code of syscall return parameters. 247 func (r *Rets) List() string { 248 s := join(r.ToParams(), func(p *Param) string { return p.Name + " " + p.Type }, ", ") 249 if len(s) > 0 { 250 s = "(" + s + ")" 251 } 252 return s 253 } 254 255 // PrintList returns source code of trace printing part correspondent 256 // to syscall return values. 257 func (r *Rets) PrintList() string { 258 return join(r.ToParams(), func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) 259 } 260 261 // SetReturnValuesCode returns source code that accepts syscall return values. 262 func (r *Rets) SetReturnValuesCode() string { 263 if r.Name == "" && !r.ReturnsError { 264 return "" 265 } 266 retvar := "r0" 267 if r.Name == "" { 268 retvar = "r1" 269 } 270 errvar := "_" 271 if r.ReturnsError { 272 errvar = "e1" 273 } 274 return fmt.Sprintf("%s, _, %s := ", retvar, errvar) 275 } 276 277 func (r *Rets) useLongHandleErrorCode(retvar string) string { 278 const code = `if %s { 279 if e1 != 0 { 280 err = error(e1) 281 } else { 282 err = %sEINVAL 283 } 284 }` 285 cond := retvar + " == 0" 286 if r.FailCond != "" { 287 cond = strings.Replace(r.FailCond, "failretval", retvar, 1) 288 } 289 return fmt.Sprintf(code, cond, syscalldot()) 290 } 291 292 // SetErrorCode returns source code that sets return parameters. 293 func (r *Rets) SetErrorCode() string { 294 const code = `if r0 != 0 { 295 %s = %sErrno(r0) 296 }` 297 if r.Name == "" && !r.ReturnsError { 298 return "" 299 } 300 if r.Name == "" { 301 return r.useLongHandleErrorCode("r1") 302 } 303 if r.Type == "error" { 304 return fmt.Sprintf(code, r.Name, syscalldot()) 305 } 306 s := "" 307 switch { 308 case r.Type[0] == '*': 309 s = fmt.Sprintf("%s = (%s)(unsafe.Pointer(r0))", r.Name, r.Type) 310 case r.Type == "bool": 311 s = fmt.Sprintf("%s = r0 != 0", r.Name) 312 default: 313 s = fmt.Sprintf("%s = %s(r0)", r.Name, r.Type) 314 } 315 if !r.ReturnsError { 316 return s 317 } 318 return s + "\n\t" + r.useLongHandleErrorCode(r.Name) 319 } 320 321 // Fn describes syscall function. 322 type Fn struct { 323 Name string 324 Params []*Param 325 Rets *Rets 326 PrintTrace bool 327 dllname string 328 dllfuncname string 329 src string 330 // TODO: get rid of this field and just use parameter index instead 331 curTmpVarIdx int // insure tmp variables have uniq names 332 } 333 334 // extractParams parses s to extract function parameters. 335 func extractParams(s string, f *Fn) ([]*Param, error) { 336 s = trim(s) 337 if s == "" { 338 return nil, nil 339 } 340 a := strings.Split(s, ",") 341 ps := make([]*Param, len(a)) 342 for i := range ps { 343 s2 := trim(a[i]) 344 b := strings.Split(s2, " ") 345 if len(b) != 2 { 346 b = strings.Split(s2, "\t") 347 if len(b) != 2 { 348 return nil, errors.New("Could not extract function parameter from \"" + s2 + "\"") 349 } 350 } 351 ps[i] = &Param{ 352 Name: trim(b[0]), 353 Type: trim(b[1]), 354 fn: f, 355 tmpVarIdx: -1, 356 } 357 } 358 return ps, nil 359 } 360 361 // extractSection extracts text out of string s starting after start 362 // and ending just before end. found return value will indicate success, 363 // and prefix, body and suffix will contain correspondent parts of string s. 364 func extractSection(s string, start, end rune) (prefix, body, suffix string, found bool) { 365 s = trim(s) 366 if strings.HasPrefix(s, string(start)) { 367 // no prefix 368 body = s[1:] 369 } else { 370 a := strings.SplitN(s, string(start), 2) 371 if len(a) != 2 { 372 return "", "", s, false 373 } 374 prefix = a[0] 375 body = a[1] 376 } 377 a := strings.SplitN(body, string(end), 2) 378 if len(a) != 2 { 379 return "", "", "", false 380 } 381 return prefix, a[0], a[1], true 382 } 383 384 // newFn parses string s and return created function Fn. 385 func newFn(s string) (*Fn, error) { 386 s = trim(s) 387 f := &Fn{ 388 Rets: &Rets{}, 389 src: s, 390 PrintTrace: *printTraceFlag, 391 } 392 // function name and args 393 prefix, body, s, found := extractSection(s, '(', ')') 394 if !found || prefix == "" { 395 return nil, errors.New("Could not extract function name and parameters from \"" + f.src + "\"") 396 } 397 f.Name = prefix 398 var err error 399 f.Params, err = extractParams(body, f) 400 if err != nil { 401 return nil, err 402 } 403 // return values 404 _, body, s, found = extractSection(s, '(', ')') 405 if found { 406 r, err := extractParams(body, f) 407 if err != nil { 408 return nil, err 409 } 410 switch len(r) { 411 case 0: 412 case 1: 413 if r[0].IsError() { 414 f.Rets.ReturnsError = true 415 } else { 416 f.Rets.Name = r[0].Name 417 f.Rets.Type = r[0].Type 418 } 419 case 2: 420 if !r[1].IsError() { 421 return nil, errors.New("Only last windows error is allowed as second return value in \"" + f.src + "\"") 422 } 423 f.Rets.ReturnsError = true 424 f.Rets.Name = r[0].Name 425 f.Rets.Type = r[0].Type 426 default: 427 return nil, errors.New("Too many return values in \"" + f.src + "\"") 428 } 429 } 430 // fail condition 431 _, body, s, found = extractSection(s, '[', ']') 432 if found { 433 f.Rets.FailCond = body 434 } 435 // dll and dll function names 436 s = trim(s) 437 if s == "" { 438 return f, nil 439 } 440 if !strings.HasPrefix(s, "=") { 441 return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") 442 } 443 s = trim(s[1:]) 444 a := strings.Split(s, ".") 445 switch len(a) { 446 case 1: 447 f.dllfuncname = a[0] 448 case 2: 449 f.dllname = a[0] 450 f.dllfuncname = a[1] 451 default: 452 return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") 453 } 454 return f, nil 455 } 456 457 // DLLName returns DLL name for function f. 458 func (f *Fn) DLLName() string { 459 if f.dllname == "" { 460 return "kernel32" 461 } 462 return f.dllname 463 } 464 465 // DLLName returns DLL function name for function f. 466 func (f *Fn) DLLFuncName() string { 467 if f.dllfuncname == "" { 468 return f.Name 469 } 470 return f.dllfuncname 471 } 472 473 // ParamList returns source code for function f parameters. 474 func (f *Fn) ParamList() string { 475 return join(f.Params, func(p *Param) string { return p.Name + " " + p.Type }, ", ") 476 } 477 478 // HelperParamList returns source code for helper function f parameters. 479 func (f *Fn) HelperParamList() string { 480 return join(f.Params, func(p *Param) string { return p.Name + " " + p.HelperType() }, ", ") 481 } 482 483 // ParamPrintList returns source code of trace printing part correspondent 484 // to syscall input parameters. 485 func (f *Fn) ParamPrintList() string { 486 return join(f.Params, func(p *Param) string { return fmt.Sprintf(`"%s=", %s, `, p.Name, p.Name) }, `", ", `) 487 } 488 489 // ParamCount return number of syscall parameters for function f. 490 func (f *Fn) ParamCount() int { 491 n := 0 492 for _, p := range f.Params { 493 n += len(p.SyscallArgList()) 494 } 495 return n 496 } 497 498 // SyscallParamCount determines which version of Syscall/Syscall6/Syscall9/... 499 // to use. It returns parameter count for correspondent SyscallX function. 500 func (f *Fn) SyscallParamCount() int { 501 n := f.ParamCount() 502 switch { 503 case n <= 3: 504 return 3 505 case n <= 6: 506 return 6 507 case n <= 9: 508 return 9 509 case n <= 12: 510 return 12 511 case n <= 15: 512 return 15 513 default: 514 panic("too many arguments to system call") 515 } 516 } 517 518 // Syscall determines which SyscallX function to use for function f. 519 func (f *Fn) Syscall() string { 520 c := f.SyscallParamCount() 521 if c == 3 { 522 return syscalldot() + "Syscall" 523 } 524 return syscalldot() + "Syscall" + strconv.Itoa(c) 525 } 526 527 // SyscallParamList returns source code for SyscallX parameters for function f. 528 func (f *Fn) SyscallParamList() string { 529 a := make([]string, 0) 530 for _, p := range f.Params { 531 a = append(a, p.SyscallArgList()...) 532 } 533 for len(a) < f.SyscallParamCount() { 534 a = append(a, "0") 535 } 536 return strings.Join(a, ", ") 537 } 538 539 // HelperCallParamList returns source code of call into function f helper. 540 func (f *Fn) HelperCallParamList() string { 541 a := make([]string, 0, len(f.Params)) 542 for _, p := range f.Params { 543 s := p.Name 544 if p.Type == "string" { 545 s = p.tmpVar() 546 } 547 a = append(a, s) 548 } 549 return strings.Join(a, ", ") 550 } 551 552 // IsUTF16 is true, if f is W (utf16) function. It is false 553 // for all A (ascii) functions. 554 func (f *Fn) IsUTF16() bool { 555 s := f.DLLFuncName() 556 return s[len(s)-1] == 'W' 557 } 558 559 // StrconvFunc returns name of Go string to OS string function for f. 560 func (f *Fn) StrconvFunc() string { 561 if f.IsUTF16() { 562 return syscalldot() + "UTF16PtrFromString" 563 } 564 return syscalldot() + "BytePtrFromString" 565 } 566 567 // StrconvType returns Go type name used for OS string for f. 568 func (f *Fn) StrconvType() string { 569 if f.IsUTF16() { 570 return "*uint16" 571 } 572 return "*byte" 573 } 574 575 // HasStringParam is true, if f has at least one string parameter. 576 // Otherwise it is false. 577 func (f *Fn) HasStringParam() bool { 578 for _, p := range f.Params { 579 if p.Type == "string" { 580 return true 581 } 582 } 583 return false 584 } 585 586 // HelperName returns name of function f helper. 587 func (f *Fn) HelperName() string { 588 if !f.HasStringParam() { 589 return f.Name 590 } 591 return "_" + f.Name 592 } 593 594 // Source files and functions. 595 type Source struct { 596 Funcs []*Fn 597 Files []string 598 } 599 600 // ParseFiles parses files listed in fs and extracts all syscall 601 // functions listed in sys comments. It returns source files 602 // and functions collection *Source if successful. 603 func ParseFiles(fs []string) (*Source, error) { 604 src := &Source{ 605 Funcs: make([]*Fn, 0), 606 Files: make([]string, 0), 607 } 608 for _, file := range fs { 609 if err := src.ParseFile(file); err != nil { 610 return nil, err 611 } 612 } 613 return src, nil 614 } 615 616 // DLLs return dll names for a source set src. 617 func (src *Source) DLLs() []string { 618 uniq := make(map[string]bool) 619 r := make([]string, 0) 620 for _, f := range src.Funcs { 621 name := f.DLLName() 622 if _, found := uniq[name]; !found { 623 uniq[name] = true 624 r = append(r, name) 625 } 626 } 627 return r 628 } 629 630 // ParseFile adds additional file path to a source set src. 631 func (src *Source) ParseFile(path string) error { 632 file, err := os.Open(path) 633 if err != nil { 634 return err 635 } 636 defer file.Close() 637 638 s := bufio.NewScanner(file) 639 for s.Scan() { 640 t := trim(s.Text()) 641 if len(t) < 7 { 642 continue 643 } 644 if !strings.HasPrefix(t, "//sys") { 645 continue 646 } 647 t = t[5:] 648 if !(t[0] == ' ' || t[0] == '\t') { 649 continue 650 } 651 f, err := newFn(t[1:]) 652 if err != nil { 653 return err 654 } 655 src.Funcs = append(src.Funcs, f) 656 } 657 if err := s.Err(); err != nil { 658 return err 659 } 660 src.Files = append(src.Files, path) 661 662 // get package name 663 fset := token.NewFileSet() 664 _, err = file.Seek(0, 0) 665 if err != nil { 666 return err 667 } 668 pkg, err := parser.ParseFile(fset, "", file, parser.PackageClauseOnly) 669 if err != nil { 670 return err 671 } 672 packageName = pkg.Name.Name 673 674 return nil 675 } 676 677 // Generate output source file from a source set src. 678 func (src *Source) Generate(w io.Writer) error { 679 funcMap := template.FuncMap{ 680 "packagename": packagename, 681 "syscalldot": syscalldot, 682 } 683 t := template.Must(template.New("main").Funcs(funcMap).Parse(srcTemplate)) 684 err := t.Execute(w, src) 685 if err != nil { 686 return errors.New("Failed to execute template: " + err.Error()) 687 } 688 return nil 689 } 690 691 func usage() { 692 fmt.Fprintf(os.Stderr, "usage: mksyscall_windows [flags] [path ...]\n") 693 flag.PrintDefaults() 694 os.Exit(1) 695 } 696 697 func main() { 698 flag.Usage = usage 699 flag.Parse() 700 if len(flag.Args()) <= 0 { 701 fmt.Fprintf(os.Stderr, "no files to parse provided\n") 702 usage() 703 } 704 705 src, err := ParseFiles(flag.Args()) 706 if err != nil { 707 log.Fatal(err) 708 } 709 710 var buf bytes.Buffer 711 if err := src.Generate(&buf); err != nil { 712 log.Fatal(err) 713 } 714 715 data, err := format.Source(buf.Bytes()) 716 if err != nil { 717 log.Fatal(err) 718 } 719 if *filename == "" { 720 _, err = os.Stdout.Write(data) 721 } else { 722 err = ioutil.WriteFile(*filename, data, 0644) 723 } 724 if err != nil { 725 log.Fatal(err) 726 } 727 } 728 729 // TODO: use println instead to print in the following template 730 const srcTemplate = ` 731 732 {{define "main"}}// MACHINE GENERATED BY 'go generate' COMMAND; DO NOT EDIT 733 734 package {{packagename}} 735 736 import "unsafe"{{if syscalldot}} 737 import "syscall"{{end}} 738 739 var _ unsafe.Pointer 740 741 var ( 742 {{template "dlls" .}} 743 {{template "funcnames" .}}) 744 {{range .Funcs}}{{if .HasStringParam}}{{template "helperbody" .}}{{end}}{{template "funcbody" .}}{{end}} 745 {{end}} 746 747 {{/* help functions */}} 748 749 {{define "dlls"}}{{range .DLLs}} mod{{.}} = {{syscalldot}}NewLazyDLL("{{.}}.dll") 750 {{end}}{{end}} 751 752 {{define "funcnames"}}{{range .Funcs}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}") 753 {{end}}{{end}} 754 755 {{define "helperbody"}} 756 func {{.Name}}({{.ParamList}}) {{template "results" .}}{ 757 {{template "helpertmpvars" .}} return {{.HelperName}}({{.HelperCallParamList}}) 758 } 759 {{end}} 760 761 {{define "funcbody"}} 762 func {{.HelperName}}({{.HelperParamList}}) {{template "results" .}}{ 763 {{template "tmpvars" .}} {{template "syscall" .}} 764 {{template "seterror" .}}{{template "printtrace" .}} return 765 } 766 {{end}} 767 768 {{define "helpertmpvars"}}{{range .Params}}{{if .TmpVarHelperCode}} {{.TmpVarHelperCode}} 769 {{end}}{{end}}{{end}} 770 771 {{define "tmpvars"}}{{range .Params}}{{if .TmpVarCode}} {{.TmpVarCode}} 772 {{end}}{{end}}{{end}} 773 774 {{define "results"}}{{if .Rets.List}}{{.Rets.List}} {{end}}{{end}} 775 776 {{define "syscall"}}{{.Rets.SetReturnValuesCode}}{{.Syscall}}(proc{{.DLLFuncName}}.Addr(), {{.ParamCount}}, {{.SyscallParamList}}){{end}} 777 778 {{define "seterror"}}{{if .Rets.SetErrorCode}} {{.Rets.SetErrorCode}} 779 {{end}}{{end}} 780 781 {{define "printtrace"}}{{if .PrintTrace}} print("SYSCALL: {{.Name}}(", {{.ParamPrintList}}") (", {{.Rets.PrintList}}")\n") 782 {{end}}{{end}} 783 784 ` 785