Home | History | Annotate | Download | only in syscall
      1 // Copyright 2012 The Go Authors.  All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // +build linux
      6 
      7 package syscall_test
      8 
      9 import (
     10 	"bytes"
     11 	"net"
     12 	"os"
     13 	"syscall"
     14 	"testing"
     15 )
     16 
     17 // TestSCMCredentials tests the sending and receiving of credentials
     18 // (PID, UID, GID) in an ancillary message between two UNIX
     19 // sockets. The SO_PASSCRED socket option is enabled on the sending
     20 // socket for this to work.
     21 func TestSCMCredentials(t *testing.T) {
     22 	fds, err := syscall.Socketpair(syscall.AF_LOCAL, syscall.SOCK_STREAM, 0)
     23 	if err != nil {
     24 		t.Fatalf("Socketpair: %v", err)
     25 	}
     26 	defer syscall.Close(fds[0])
     27 	defer syscall.Close(fds[1])
     28 
     29 	err = syscall.SetsockoptInt(fds[0], syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1)
     30 	if err != nil {
     31 		t.Fatalf("SetsockoptInt: %v", err)
     32 	}
     33 
     34 	srvFile := os.NewFile(uintptr(fds[0]), "server")
     35 	defer srvFile.Close()
     36 	srv, err := net.FileConn(srvFile)
     37 	if err != nil {
     38 		t.Errorf("FileConn: %v", err)
     39 		return
     40 	}
     41 	defer srv.Close()
     42 
     43 	cliFile := os.NewFile(uintptr(fds[1]), "client")
     44 	defer cliFile.Close()
     45 	cli, err := net.FileConn(cliFile)
     46 	if err != nil {
     47 		t.Errorf("FileConn: %v", err)
     48 		return
     49 	}
     50 	defer cli.Close()
     51 
     52 	var ucred syscall.Ucred
     53 	if os.Getuid() != 0 {
     54 		ucred.Pid = int32(os.Getpid())
     55 		ucred.Uid = 0
     56 		ucred.Gid = 0
     57 		oob := syscall.UnixCredentials(&ucred)
     58 		_, _, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
     59 		if op, ok := err.(*net.OpError); ok {
     60 			err = op.Err
     61 		}
     62 		if sys, ok := err.(*os.SyscallError); ok {
     63 			err = sys.Err
     64 		}
     65 		if err != syscall.EPERM {
     66 			t.Fatalf("WriteMsgUnix failed with %v, want EPERM", err)
     67 		}
     68 	}
     69 
     70 	ucred.Pid = int32(os.Getpid())
     71 	ucred.Uid = uint32(os.Getuid())
     72 	ucred.Gid = uint32(os.Getgid())
     73 	oob := syscall.UnixCredentials(&ucred)
     74 
     75 	// this is going to send a dummy byte
     76 	n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
     77 	if err != nil {
     78 		t.Fatalf("WriteMsgUnix: %v", err)
     79 	}
     80 	if n != 0 {
     81 		t.Fatalf("WriteMsgUnix n = %d, want 0", n)
     82 	}
     83 	if oobn != len(oob) {
     84 		t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob))
     85 	}
     86 
     87 	oob2 := make([]byte, 10*len(oob))
     88 	n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2)
     89 	if err != nil {
     90 		t.Fatalf("ReadMsgUnix: %v", err)
     91 	}
     92 	if flags != 0 {
     93 		t.Fatalf("ReadMsgUnix flags = 0x%x, want 0", flags)
     94 	}
     95 	if n != 1 {
     96 		t.Fatalf("ReadMsgUnix n = %d, want 1 (dummy byte)", n)
     97 	}
     98 	if oobn2 != oobn {
     99 		// without SO_PASSCRED set on the socket, ReadMsgUnix will
    100 		// return zero oob bytes
    101 		t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn)
    102 	}
    103 	oob2 = oob2[:oobn2]
    104 	if !bytes.Equal(oob, oob2) {
    105 		t.Fatal("ReadMsgUnix oob bytes don't match")
    106 	}
    107 
    108 	scm, err := syscall.ParseSocketControlMessage(oob2)
    109 	if err != nil {
    110 		t.Fatalf("ParseSocketControlMessage: %v", err)
    111 	}
    112 	newUcred, err := syscall.ParseUnixCredentials(&scm[0])
    113 	if err != nil {
    114 		t.Fatalf("ParseUnixCredentials: %v", err)
    115 	}
    116 	if *newUcred != ucred {
    117 		t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred)
    118 	}
    119 }
    120