Home | History | Annotate | Download | only in runtime
      1 // Copyright 2010 The Go Authors.  All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 package runtime_test
      6 
      7 import (
      8 	"fmt"
      9 	"io/ioutil"
     10 	"os"
     11 	"os/exec"
     12 	"path/filepath"
     13 	"runtime"
     14 	"strings"
     15 	"syscall"
     16 	"testing"
     17 	"unsafe"
     18 )
     19 
     20 type DLL struct {
     21 	*syscall.DLL
     22 	t *testing.T
     23 }
     24 
     25 func GetDLL(t *testing.T, name string) *DLL {
     26 	d, e := syscall.LoadDLL(name)
     27 	if e != nil {
     28 		t.Fatal(e)
     29 	}
     30 	return &DLL{DLL: d, t: t}
     31 }
     32 
     33 func (d *DLL) Proc(name string) *syscall.Proc {
     34 	p, e := d.FindProc(name)
     35 	if e != nil {
     36 		d.t.Fatal(e)
     37 	}
     38 	return p
     39 }
     40 
     41 func TestStdCall(t *testing.T) {
     42 	type Rect struct {
     43 		left, top, right, bottom int32
     44 	}
     45 	res := Rect{}
     46 	expected := Rect{1, 1, 40, 60}
     47 	a, _, _ := GetDLL(t, "user32.dll").Proc("UnionRect").Call(
     48 		uintptr(unsafe.Pointer(&res)),
     49 		uintptr(unsafe.Pointer(&Rect{10, 1, 14, 60})),
     50 		uintptr(unsafe.Pointer(&Rect{1, 2, 40, 50})))
     51 	if a != 1 || res.left != expected.left ||
     52 		res.top != expected.top ||
     53 		res.right != expected.right ||
     54 		res.bottom != expected.bottom {
     55 		t.Error("stdcall USER32.UnionRect returns", a, "res=", res)
     56 	}
     57 }
     58 
     59 func Test64BitReturnStdCall(t *testing.T) {
     60 
     61 	const (
     62 		VER_BUILDNUMBER      = 0x0000004
     63 		VER_MAJORVERSION     = 0x0000002
     64 		VER_MINORVERSION     = 0x0000001
     65 		VER_PLATFORMID       = 0x0000008
     66 		VER_PRODUCT_TYPE     = 0x0000080
     67 		VER_SERVICEPACKMAJOR = 0x0000020
     68 		VER_SERVICEPACKMINOR = 0x0000010
     69 		VER_SUITENAME        = 0x0000040
     70 
     71 		VER_EQUAL         = 1
     72 		VER_GREATER       = 2
     73 		VER_GREATER_EQUAL = 3
     74 		VER_LESS          = 4
     75 		VER_LESS_EQUAL    = 5
     76 
     77 		ERROR_OLD_WIN_VERSION syscall.Errno = 1150
     78 	)
     79 
     80 	type OSVersionInfoEx struct {
     81 		OSVersionInfoSize uint32
     82 		MajorVersion      uint32
     83 		MinorVersion      uint32
     84 		BuildNumber       uint32
     85 		PlatformId        uint32
     86 		CSDVersion        [128]uint16
     87 		ServicePackMajor  uint16
     88 		ServicePackMinor  uint16
     89 		SuiteMask         uint16
     90 		ProductType       byte
     91 		Reserve           byte
     92 	}
     93 
     94 	d := GetDLL(t, "kernel32.dll")
     95 
     96 	var m1, m2 uintptr
     97 	VerSetConditionMask := d.Proc("VerSetConditionMask")
     98 	m1, m2, _ = VerSetConditionMask.Call(m1, m2, VER_MAJORVERSION, VER_GREATER_EQUAL)
     99 	m1, m2, _ = VerSetConditionMask.Call(m1, m2, VER_MINORVERSION, VER_GREATER_EQUAL)
    100 	m1, m2, _ = VerSetConditionMask.Call(m1, m2, VER_SERVICEPACKMAJOR, VER_GREATER_EQUAL)
    101 	m1, m2, _ = VerSetConditionMask.Call(m1, m2, VER_SERVICEPACKMINOR, VER_GREATER_EQUAL)
    102 
    103 	vi := OSVersionInfoEx{
    104 		MajorVersion:     5,
    105 		MinorVersion:     1,
    106 		ServicePackMajor: 2,
    107 		ServicePackMinor: 0,
    108 	}
    109 	vi.OSVersionInfoSize = uint32(unsafe.Sizeof(vi))
    110 	r, _, e2 := d.Proc("VerifyVersionInfoW").Call(
    111 		uintptr(unsafe.Pointer(&vi)),
    112 		VER_MAJORVERSION|VER_MINORVERSION|VER_SERVICEPACKMAJOR|VER_SERVICEPACKMINOR,
    113 		m1, m2)
    114 	if r == 0 && e2 != ERROR_OLD_WIN_VERSION {
    115 		t.Errorf("VerifyVersionInfo failed: %s", e2)
    116 	}
    117 }
    118 
    119 func TestCDecl(t *testing.T) {
    120 	var buf [50]byte
    121 	fmtp, _ := syscall.BytePtrFromString("%d %d %d")
    122 	a, _, _ := GetDLL(t, "user32.dll").Proc("wsprintfA").Call(
    123 		uintptr(unsafe.Pointer(&buf[0])),
    124 		uintptr(unsafe.Pointer(fmtp)),
    125 		1000, 2000, 3000)
    126 	if string(buf[:a]) != "1000 2000 3000" {
    127 		t.Error("cdecl USER32.wsprintfA returns", a, "buf=", buf[:a])
    128 	}
    129 }
    130 
    131 func TestEnumWindows(t *testing.T) {
    132 	d := GetDLL(t, "user32.dll")
    133 	isWindows := d.Proc("IsWindow")
    134 	counter := 0
    135 	cb := syscall.NewCallback(func(hwnd syscall.Handle, lparam uintptr) uintptr {
    136 		if lparam != 888 {
    137 			t.Error("lparam was not passed to callback")
    138 		}
    139 		b, _, _ := isWindows.Call(uintptr(hwnd))
    140 		if b == 0 {
    141 			t.Error("USER32.IsWindow returns FALSE")
    142 		}
    143 		counter++
    144 		return 1 // continue enumeration
    145 	})
    146 	a, _, _ := d.Proc("EnumWindows").Call(cb, 888)
    147 	if a == 0 {
    148 		t.Error("USER32.EnumWindows returns FALSE")
    149 	}
    150 	if counter == 0 {
    151 		t.Error("Callback has been never called or your have no windows")
    152 	}
    153 }
    154 
    155 func callback(hwnd syscall.Handle, lparam uintptr) uintptr {
    156 	(*(*func())(unsafe.Pointer(&lparam)))()
    157 	return 0 // stop enumeration
    158 }
    159 
    160 // nestedCall calls into Windows, back into Go, and finally to f.
    161 func nestedCall(t *testing.T, f func()) {
    162 	c := syscall.NewCallback(callback)
    163 	d := GetDLL(t, "user32.dll")
    164 	defer d.Release()
    165 	d.Proc("EnumWindows").Call(c, uintptr(*(*unsafe.Pointer)(unsafe.Pointer(&f))))
    166 }
    167 
    168 func TestCallback(t *testing.T) {
    169 	var x = false
    170 	nestedCall(t, func() { x = true })
    171 	if !x {
    172 		t.Fatal("nestedCall did not call func")
    173 	}
    174 }
    175 
    176 func TestCallbackGC(t *testing.T) {
    177 	nestedCall(t, runtime.GC)
    178 }
    179 
    180 func TestCallbackPanicLocked(t *testing.T) {
    181 	runtime.LockOSThread()
    182 	defer runtime.UnlockOSThread()
    183 
    184 	if !runtime.LockedOSThread() {
    185 		t.Fatal("runtime.LockOSThread didn't")
    186 	}
    187 	defer func() {
    188 		s := recover()
    189 		if s == nil {
    190 			t.Fatal("did not panic")
    191 		}
    192 		if s.(string) != "callback panic" {
    193 			t.Fatal("wrong panic:", s)
    194 		}
    195 		if !runtime.LockedOSThread() {
    196 			t.Fatal("lost lock on OS thread after panic")
    197 		}
    198 	}()
    199 	nestedCall(t, func() { panic("callback panic") })
    200 	panic("nestedCall returned")
    201 }
    202 
    203 func TestCallbackPanic(t *testing.T) {
    204 	// Make sure panic during callback unwinds properly.
    205 	if runtime.LockedOSThread() {
    206 		t.Fatal("locked OS thread on entry to TestCallbackPanic")
    207 	}
    208 	defer func() {
    209 		s := recover()
    210 		if s == nil {
    211 			t.Fatal("did not panic")
    212 		}
    213 		if s.(string) != "callback panic" {
    214 			t.Fatal("wrong panic:", s)
    215 		}
    216 		if runtime.LockedOSThread() {
    217 			t.Fatal("locked OS thread on exit from TestCallbackPanic")
    218 		}
    219 	}()
    220 	nestedCall(t, func() { panic("callback panic") })
    221 	panic("nestedCall returned")
    222 }
    223 
    224 func TestCallbackPanicLoop(t *testing.T) {
    225 	// Make sure we don't blow out m->g0 stack.
    226 	for i := 0; i < 100000; i++ {
    227 		TestCallbackPanic(t)
    228 	}
    229 }
    230 
    231 func TestBlockingCallback(t *testing.T) {
    232 	c := make(chan int)
    233 	go func() {
    234 		for i := 0; i < 10; i++ {
    235 			c <- <-c
    236 		}
    237 	}()
    238 	nestedCall(t, func() {
    239 		for i := 0; i < 10; i++ {
    240 			c <- i
    241 			if j := <-c; j != i {
    242 				t.Errorf("out of sync %d != %d", j, i)
    243 			}
    244 		}
    245 	})
    246 }
    247 
    248 func TestCallbackInAnotherThread(t *testing.T) {
    249 	// TODO: test a function which calls back in another thread: QueueUserAPC() or CreateThread()
    250 }
    251 
    252 type cbDLLFunc int // int determines number of callback parameters
    253 
    254 func (f cbDLLFunc) stdcallName() string {
    255 	return fmt.Sprintf("stdcall%d", f)
    256 }
    257 
    258 func (f cbDLLFunc) cdeclName() string {
    259 	return fmt.Sprintf("cdecl%d", f)
    260 }
    261 
    262 func (f cbDLLFunc) buildOne(stdcall bool) string {
    263 	var funcname, attr string
    264 	if stdcall {
    265 		funcname = f.stdcallName()
    266 		attr = "__stdcall"
    267 	} else {
    268 		funcname = f.cdeclName()
    269 		attr = "__cdecl"
    270 	}
    271 	typename := "t" + funcname
    272 	p := make([]string, f)
    273 	for i := range p {
    274 		p[i] = "uintptr_t"
    275 	}
    276 	params := strings.Join(p, ",")
    277 	for i := range p {
    278 		p[i] = fmt.Sprintf("%d", i+1)
    279 	}
    280 	args := strings.Join(p, ",")
    281 	return fmt.Sprintf(`
    282 typedef void %s (*%s)(%s);
    283 void %s(%s f, uintptr_t n) {
    284 	uintptr_t i;
    285 	for(i=0;i<n;i++){
    286 		f(%s);
    287 	}
    288 }
    289 	`, attr, typename, params, funcname, typename, args)
    290 }
    291 
    292 func (f cbDLLFunc) build() string {
    293 	return "#include <stdint.h>\n\n" + f.buildOne(false) + f.buildOne(true)
    294 }
    295 
    296 var cbFuncs = [...]interface{}{
    297 	2: func(i1, i2 uintptr) uintptr {
    298 		if i1+i2 != 3 {
    299 			panic("bad input")
    300 		}
    301 		return 0
    302 	},
    303 	3: func(i1, i2, i3 uintptr) uintptr {
    304 		if i1+i2+i3 != 6 {
    305 			panic("bad input")
    306 		}
    307 		return 0
    308 	},
    309 	4: func(i1, i2, i3, i4 uintptr) uintptr {
    310 		if i1+i2+i3+i4 != 10 {
    311 			panic("bad input")
    312 		}
    313 		return 0
    314 	},
    315 	5: func(i1, i2, i3, i4, i5 uintptr) uintptr {
    316 		if i1+i2+i3+i4+i5 != 15 {
    317 			panic("bad input")
    318 		}
    319 		return 0
    320 	},
    321 	6: func(i1, i2, i3, i4, i5, i6 uintptr) uintptr {
    322 		if i1+i2+i3+i4+i5+i6 != 21 {
    323 			panic("bad input")
    324 		}
    325 		return 0
    326 	},
    327 	7: func(i1, i2, i3, i4, i5, i6, i7 uintptr) uintptr {
    328 		if i1+i2+i3+i4+i5+i6+i7 != 28 {
    329 			panic("bad input")
    330 		}
    331 		return 0
    332 	},
    333 	8: func(i1, i2, i3, i4, i5, i6, i7, i8 uintptr) uintptr {
    334 		if i1+i2+i3+i4+i5+i6+i7+i8 != 36 {
    335 			panic("bad input")
    336 		}
    337 		return 0
    338 	},
    339 	9: func(i1, i2, i3, i4, i5, i6, i7, i8, i9 uintptr) uintptr {
    340 		if i1+i2+i3+i4+i5+i6+i7+i8+i9 != 45 {
    341 			panic("bad input")
    342 		}
    343 		return 0
    344 	},
    345 }
    346 
    347 type cbDLL struct {
    348 	name      string
    349 	buildArgs func(out, src string) []string
    350 }
    351 
    352 func (d *cbDLL) buildSrc(t *testing.T, path string) {
    353 	f, err := os.Create(path)
    354 	if err != nil {
    355 		t.Fatalf("failed to create source file: %v", err)
    356 	}
    357 	defer f.Close()
    358 
    359 	for i := 2; i < 10; i++ {
    360 		fmt.Fprint(f, cbDLLFunc(i).build())
    361 	}
    362 }
    363 
    364 func (d *cbDLL) build(t *testing.T, dir string) string {
    365 	srcname := d.name + ".c"
    366 	d.buildSrc(t, filepath.Join(dir, srcname))
    367 	outname := d.name + ".dll"
    368 	args := d.buildArgs(outname, srcname)
    369 	cmd := exec.Command(args[0], args[1:]...)
    370 	cmd.Dir = dir
    371 	out, err := cmd.CombinedOutput()
    372 	if err != nil {
    373 		t.Fatalf("failed to build dll: %v - %v", err, string(out))
    374 	}
    375 	return filepath.Join(dir, outname)
    376 }
    377 
    378 var cbDLLs = []cbDLL{
    379 	{
    380 		"test",
    381 		func(out, src string) []string {
    382 			return []string{"gcc", "-shared", "-s", "-Werror", "-o", out, src}
    383 		},
    384 	},
    385 	{
    386 		"testO2",
    387 		func(out, src string) []string {
    388 			return []string{"gcc", "-shared", "-s", "-Werror", "-o", out, "-O2", src}
    389 		},
    390 	},
    391 }
    392 
    393 type cbTest struct {
    394 	n     int     // number of callback parameters
    395 	param uintptr // dll function parameter
    396 }
    397 
    398 func (test *cbTest) run(t *testing.T, dllpath string) {
    399 	dll := syscall.MustLoadDLL(dllpath)
    400 	defer dll.Release()
    401 	cb := cbFuncs[test.n]
    402 	stdcall := syscall.NewCallback(cb)
    403 	f := cbDLLFunc(test.n)
    404 	test.runOne(t, dll, f.stdcallName(), stdcall)
    405 	cdecl := syscall.NewCallbackCDecl(cb)
    406 	test.runOne(t, dll, f.cdeclName(), cdecl)
    407 }
    408 
    409 func (test *cbTest) runOne(t *testing.T, dll *syscall.DLL, proc string, cb uintptr) {
    410 	defer func() {
    411 		if r := recover(); r != nil {
    412 			t.Errorf("dll call %v(..., %d) failed: %v", proc, test.param, r)
    413 		}
    414 	}()
    415 	dll.MustFindProc(proc).Call(cb, test.param)
    416 }
    417 
    418 var cbTests = []cbTest{
    419 	{2, 1},
    420 	{2, 10000},
    421 	{3, 3},
    422 	{4, 5},
    423 	{4, 6},
    424 	{5, 2},
    425 	{6, 7},
    426 	{6, 8},
    427 	{7, 6},
    428 	{8, 1},
    429 	{9, 8},
    430 	{9, 10000},
    431 	{3, 4},
    432 	{5, 3},
    433 	{7, 7},
    434 	{8, 2},
    435 	{9, 9},
    436 }
    437 
    438 func TestStdcallAndCDeclCallbacks(t *testing.T) {
    439 	if _, err := exec.LookPath("gcc"); err != nil {
    440 		t.Skip("skipping test: gcc is missing")
    441 	}
    442 	tmp, err := ioutil.TempDir("", "TestCDeclCallback")
    443 	if err != nil {
    444 		t.Fatal("TempDir failed: ", err)
    445 	}
    446 	defer os.RemoveAll(tmp)
    447 
    448 	for _, dll := range cbDLLs {
    449 		dllPath := dll.build(t, tmp)
    450 		for _, test := range cbTests {
    451 			test.run(t, dllPath)
    452 		}
    453 	}
    454 }
    455 
    456 func TestRegisterClass(t *testing.T) {
    457 	kernel32 := GetDLL(t, "kernel32.dll")
    458 	user32 := GetDLL(t, "user32.dll")
    459 	mh, _, _ := kernel32.Proc("GetModuleHandleW").Call(0)
    460 	cb := syscall.NewCallback(func(hwnd syscall.Handle, msg uint32, wparam, lparam uintptr) (rc uintptr) {
    461 		t.Fatal("callback should never get called")
    462 		return 0
    463 	})
    464 	type Wndclassex struct {
    465 		Size       uint32
    466 		Style      uint32
    467 		WndProc    uintptr
    468 		ClsExtra   int32
    469 		WndExtra   int32
    470 		Instance   syscall.Handle
    471 		Icon       syscall.Handle
    472 		Cursor     syscall.Handle
    473 		Background syscall.Handle
    474 		MenuName   *uint16
    475 		ClassName  *uint16
    476 		IconSm     syscall.Handle
    477 	}
    478 	name := syscall.StringToUTF16Ptr("test_window")
    479 	wc := Wndclassex{
    480 		WndProc:   cb,
    481 		Instance:  syscall.Handle(mh),
    482 		ClassName: name,
    483 	}
    484 	wc.Size = uint32(unsafe.Sizeof(wc))
    485 	a, _, err := user32.Proc("RegisterClassExW").Call(uintptr(unsafe.Pointer(&wc)))
    486 	if a == 0 {
    487 		t.Fatalf("RegisterClassEx failed: %v", err)
    488 	}
    489 	r, _, err := user32.Proc("UnregisterClassW").Call(uintptr(unsafe.Pointer(name)), 0)
    490 	if r == 0 {
    491 		t.Fatalf("UnregisterClass failed: %v", err)
    492 	}
    493 }
    494 
    495 func TestOutputDebugString(t *testing.T) {
    496 	d := GetDLL(t, "kernel32.dll")
    497 	p := syscall.StringToUTF16Ptr("testing OutputDebugString")
    498 	d.Proc("OutputDebugStringW").Call(uintptr(unsafe.Pointer(p)))
    499 }
    500 
    501 func TestRaiseException(t *testing.T) {
    502 	o := executeTest(t, raiseExceptionSource, nil)
    503 	if strings.Contains(o, "RaiseException should not return") {
    504 		t.Fatalf("RaiseException did not crash program: %v", o)
    505 	}
    506 	if !strings.Contains(o, "Exception 0xbad") {
    507 		t.Fatalf("No stack trace: %v", o)
    508 	}
    509 }
    510 
    511 const raiseExceptionSource = `
    512 package main
    513 import "syscall"
    514 func main() {
    515 	const EXCEPTION_NONCONTINUABLE = 1
    516 	mod := syscall.MustLoadDLL("kernel32.dll")
    517 	proc := mod.MustFindProc("RaiseException")
    518 	proc.Call(0xbad, EXCEPTION_NONCONTINUABLE, 0, 0)
    519 	println("RaiseException should not return")
    520 }
    521 `
    522 
    523 func TestZeroDivisionException(t *testing.T) {
    524 	o := executeTest(t, zeroDivisionExceptionSource, nil)
    525 	if !strings.Contains(o, "panic: runtime error: integer divide by zero") {
    526 		t.Fatalf("No stack trace: %v", o)
    527 	}
    528 }
    529 
    530 const zeroDivisionExceptionSource = `
    531 package main
    532 func main() {
    533 	x := 1
    534 	y := 0
    535 	z := x / y
    536 	println(z)
    537 }
    538 `
    539 
    540 func TestWERDialogue(t *testing.T) {
    541 	if os.Getenv("TESTING_WER_DIALOGUE") == "1" {
    542 		defer os.Exit(0)
    543 
    544 		*runtime.TestingWER = true
    545 		const EXCEPTION_NONCONTINUABLE = 1
    546 		mod := syscall.MustLoadDLL("kernel32.dll")
    547 		proc := mod.MustFindProc("RaiseException")
    548 		proc.Call(0xbad, EXCEPTION_NONCONTINUABLE, 0, 0)
    549 		println("RaiseException should not return")
    550 		return
    551 	}
    552 	cmd := exec.Command(os.Args[0], "-test.run=TestWERDialogue")
    553 	cmd.Env = []string{"TESTING_WER_DIALOGUE=1"}
    554 	// Child process should not open WER dialogue, but return immediately instead.
    555 	cmd.CombinedOutput()
    556 }
    557 
    558 var used byte
    559 
    560 func use(buf []byte) {
    561 	for _, c := range buf {
    562 		used += c
    563 	}
    564 }
    565 
    566 func forceStackCopy() (r int) {
    567 	var f func(int) int
    568 	f = func(i int) int {
    569 		var buf [256]byte
    570 		use(buf[:])
    571 		if i == 0 {
    572 			return 0
    573 		}
    574 		return i + f(i-1)
    575 	}
    576 	r = f(128)
    577 	return
    578 }
    579 
    580 func TestReturnAfterStackGrowInCallback(t *testing.T) {
    581 	if _, err := exec.LookPath("gcc"); err != nil {
    582 		t.Skip("skipping test: gcc is missing")
    583 	}
    584 
    585 	const src = `
    586 #include <stdint.h>
    587 #include <windows.h>
    588 
    589 typedef uintptr_t __stdcall (*callback)(uintptr_t);
    590 
    591 uintptr_t cfunc(callback f, uintptr_t n) {
    592    uintptr_t r;
    593    r = f(n);
    594    SetLastError(333);
    595    return r;
    596 }
    597 `
    598 	tmpdir, err := ioutil.TempDir("", "TestReturnAfterStackGrowInCallback")
    599 	if err != nil {
    600 		t.Fatal("TempDir failed: ", err)
    601 	}
    602 	defer os.RemoveAll(tmpdir)
    603 
    604 	srcname := "mydll.c"
    605 	err = ioutil.WriteFile(filepath.Join(tmpdir, srcname), []byte(src), 0)
    606 	if err != nil {
    607 		t.Fatal(err)
    608 	}
    609 	outname := "mydll.dll"
    610 	cmd := exec.Command("gcc", "-shared", "-s", "-Werror", "-o", outname, srcname)
    611 	cmd.Dir = tmpdir
    612 	out, err := cmd.CombinedOutput()
    613 	if err != nil {
    614 		t.Fatalf("failed to build dll: %v - %v", err, string(out))
    615 	}
    616 	dllpath := filepath.Join(tmpdir, outname)
    617 
    618 	dll := syscall.MustLoadDLL(dllpath)
    619 	defer dll.Release()
    620 
    621 	proc := dll.MustFindProc("cfunc")
    622 
    623 	cb := syscall.NewCallback(func(n uintptr) uintptr {
    624 		forceStackCopy()
    625 		return n
    626 	})
    627 
    628 	// Use a new goroutine so that we get a small stack.
    629 	type result struct {
    630 		r   uintptr
    631 		err syscall.Errno
    632 	}
    633 	c := make(chan result)
    634 	go func() {
    635 		r, _, err := proc.Call(cb, 100)
    636 		c <- result{r, err.(syscall.Errno)}
    637 	}()
    638 	want := result{r: 100, err: 333}
    639 	if got := <-c; got != want {
    640 		t.Errorf("got %d want %d", got, want)
    641 	}
    642 }
    643