Home | History | Annotate | Download | only in syscall
      1 // Copyright 2012 The Go Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style
      3 // license that can be found in the LICENSE file.
      4 
      5 // +build linux
      6 
      7 package syscall_test
      8 
      9 import (
     10 	"bytes"
     11 	"net"
     12 	"os"
     13 	"syscall"
     14 	"testing"
     15 )
     16 
     17 // TestSCMCredentials tests the sending and receiving of credentials
     18 // (PID, UID, GID) in an ancillary message between two UNIX
     19 // sockets. The SO_PASSCRED socket option is enabled on the sending
     20 // socket for this to work.
     21 func TestSCMCredentials(t *testing.T) {
     22 	socketTypeTests := []struct {
     23 		socketType int
     24 		dataLen    int
     25 	}{
     26 		{
     27 			syscall.SOCK_STREAM,
     28 			1,
     29 		}, {
     30 			syscall.SOCK_DGRAM,
     31 			0,
     32 		},
     33 	}
     34 
     35 	for _, tt := range socketTypeTests {
     36 		fds, err := syscall.Socketpair(syscall.AF_LOCAL, tt.socketType, 0)
     37 		if err != nil {
     38 			t.Fatalf("Socketpair: %v", err)
     39 		}
     40 		defer syscall.Close(fds[0])
     41 		defer syscall.Close(fds[1])
     42 
     43 		err = syscall.SetsockoptInt(fds[0], syscall.SOL_SOCKET, syscall.SO_PASSCRED, 1)
     44 		if err != nil {
     45 			t.Fatalf("SetsockoptInt: %v", err)
     46 		}
     47 
     48 		srvFile := os.NewFile(uintptr(fds[0]), "server")
     49 		defer srvFile.Close()
     50 		srv, err := net.FileConn(srvFile)
     51 		if err != nil {
     52 			t.Errorf("FileConn: %v", err)
     53 			return
     54 		}
     55 		defer srv.Close()
     56 
     57 		cliFile := os.NewFile(uintptr(fds[1]), "client")
     58 		defer cliFile.Close()
     59 		cli, err := net.FileConn(cliFile)
     60 		if err != nil {
     61 			t.Errorf("FileConn: %v", err)
     62 			return
     63 		}
     64 		defer cli.Close()
     65 
     66 		var ucred syscall.Ucred
     67 		if os.Getuid() != 0 {
     68 			ucred.Pid = int32(os.Getpid())
     69 			ucred.Uid = 0
     70 			ucred.Gid = 0
     71 			oob := syscall.UnixCredentials(&ucred)
     72 			_, _, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
     73 			if op, ok := err.(*net.OpError); ok {
     74 				err = op.Err
     75 			}
     76 			if sys, ok := err.(*os.SyscallError); ok {
     77 				err = sys.Err
     78 			}
     79 			if err != syscall.EPERM {
     80 				t.Fatalf("WriteMsgUnix failed with %v, want EPERM", err)
     81 			}
     82 		}
     83 
     84 		ucred.Pid = int32(os.Getpid())
     85 		ucred.Uid = uint32(os.Getuid())
     86 		ucred.Gid = uint32(os.Getgid())
     87 		oob := syscall.UnixCredentials(&ucred)
     88 
     89 		// On SOCK_STREAM, this is internally going to send a dummy byte
     90 		n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
     91 		if err != nil {
     92 			t.Fatalf("WriteMsgUnix: %v", err)
     93 		}
     94 		if n != 0 {
     95 			t.Fatalf("WriteMsgUnix n = %d, want 0", n)
     96 		}
     97 		if oobn != len(oob) {
     98 			t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob))
     99 		}
    100 
    101 		oob2 := make([]byte, 10*len(oob))
    102 		n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2)
    103 		if err != nil {
    104 			t.Fatalf("ReadMsgUnix: %v", err)
    105 		}
    106 		if flags != 0 {
    107 			t.Fatalf("ReadMsgUnix flags = 0x%x, want 0", flags)
    108 		}
    109 		if n != tt.dataLen {
    110 			t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen)
    111 		}
    112 		if oobn2 != oobn {
    113 			// without SO_PASSCRED set on the socket, ReadMsgUnix will
    114 			// return zero oob bytes
    115 			t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn)
    116 		}
    117 		oob2 = oob2[:oobn2]
    118 		if !bytes.Equal(oob, oob2) {
    119 			t.Fatal("ReadMsgUnix oob bytes don't match")
    120 		}
    121 
    122 		scm, err := syscall.ParseSocketControlMessage(oob2)
    123 		if err != nil {
    124 			t.Fatalf("ParseSocketControlMessage: %v", err)
    125 		}
    126 		newUcred, err := syscall.ParseUnixCredentials(&scm[0])
    127 		if err != nil {
    128 			t.Fatalf("ParseUnixCredentials: %v", err)
    129 		}
    130 		if *newUcred != ucred {
    131 			t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred)
    132 		}
    133 	}
    134 }
    135