#!/usr/bin/python
#
# Copyright 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for xt_qtaguid."""

import errno
from socket import *  # pylint: disable=wildcard-import
import unittest
import os

import net_test
import packets
import tcp_test

CTRL_PROCPATH = "/proc/net/xt_qtaguid/ctrl"
OTHER_UID_GID = 12345


class QtaguidTest(tcp_test.TcpBaseTest):
  """Tests xt_qtaguid socket tagging and iptables owner matching.

  Every test runs with a dedicated iptables chain, qtaguid_test_OUTPUT,
  hooked into OUTPUT for both IPv4 and IPv6; setUp creates it and tearDown
  unhooks, flushes and deletes it.
  """

  def RunIptablesCommand(self, args):
    """Runs an iptables command for both IPv4 and IPv6.

    Args:
      args: the iptables argument string, e.g. "-N chain".

    Asserts that net_test.RunIptablesCommand returns a falsy value (treated
    as success) for each IP version.
    """
    self.assertFalse(net_test.RunIptablesCommand(4, args))
    self.assertFalse(net_test.RunIptablesCommand(6, args))

  def setUp(self):
    super(QtaguidTest, self).setUp()
    self.RunIptablesCommand("-N qtaguid_test_OUTPUT")
    self.RunIptablesCommand("-A OUTPUT -j qtaguid_test_OUTPUT")

  def tearDown(self):
    try:
      self.RunIptablesCommand("-D OUTPUT -j qtaguid_test_OUTPUT")
      # Flushing the chain also removes any rule a failed test left behind.
      self.RunIptablesCommand("-F qtaguid_test_OUTPUT")
      self.RunIptablesCommand("-X qtaguid_test_OUTPUT")
    finally:
      super(QtaguidTest, self).tearDown()

  def WriteToCtrl(self, command):
    """Writes a single command string to the xt_qtaguid ctrl proc file."""
    with open(CTRL_PROCPATH, "w") as ctrl_file:
      ctrl_file.write(command)

  def CheckTag(self, tag, uid):
    """Returns True if the ctrl file lists a socket with this tag and uid.

    The ctrl file shows the combined value (tag | uid), so that is what is
    searched for, in the form "tag=0x<tag|uid> (uid=<uid>)".
    """
    expected = "tag=0x%x (uid=%d)" % ((tag | uid), uid)
    with open(CTRL_PROCPATH, "r") as ctrl_file:
      return any(expected in line for line in ctrl_file)

  def SetIptablesRule(self, version, is_add, is_gid, my_id, inverted):
    """Adds or deletes an owner-match DROP rule in qtaguid_test_OUTPUT.

    Args:
      version: IP version passed to net_test.RunIptablesCommand (4 or 6).
      is_add: True to append the rule (-A), False to delete it (-D).
      is_gid: True to match --gid-owner, False to match --uid-owner.
      my_id: the uid or gid to match.
      inverted: True to negate the match ("! --uid-owner ...").
    """
    add_del = "-A" if is_add else "-D"
    uid_gid = "--gid-owner" if is_gid else "--uid-owner"
    negation = "! " if inverted else ""
    args = "%s qtaguid_test_OUTPUT -m owner %s%s %d -j DROP" % (
        add_del, negation, uid_gid, my_id)
    self.assertFalse(net_test.RunIptablesCommand(version, args))

  def AddIptablesRule(self, version, is_gid, myId):
    self.SetIptablesRule(version, True, is_gid, myId, False)

  def AddIptablesInvertedRule(self, version, is_gid, myId):
    self.SetIptablesRule(version, True, is_gid, myId, True)

  def DelIptablesRule(self, version, is_gid, myId):
    self.SetIptablesRule(version, False, is_gid, myId, False)

  def DelIptablesInvertedRule(self, version, is_gid, myId):
    self.SetIptablesRule(version, False, is_gid, myId, True)

  def _OpenLoopbackUdpSocket(self, version):
    """Creates a UDP socket bound to an ephemeral loopback port.

    Args:
      version: 4 for 127.0.0.1, 6 for ::1.

    Returns:
      (socket, addr), where addr is the socket's own bound address from
      getsockname(), so datagrams sent to addr loop back to the socket.
    """
    family = {4: AF_INET, 6: AF_INET6}[version]
    loopback = {4: "127.0.0.1", 6: "::1"}[version]
    s = socket(family, SOCK_DGRAM, 0)
    s.bind((loopback, 0))
    return s, s.getsockname()

  def _CheckLoopbackEcho(self, s, addr):
    """Sends "foo" from s to addr (s's own address) and expects it back."""
    s.sendto("foo", addr)
    data, sockaddr = s.recvfrom(4096)
    self.assertEqual("foo", data)
    self.assertEqual(sockaddr, addr)

  def CheckSocketOutput(self, version, is_gid):
    """Checks that a non-inverted owner rule drops our own traffic.

    A DROP rule matching the current uid (or gid) makes sendto fail with
    EPERM; once the rule is deleted the datagram goes through again.
    """
    my_id = os.getgid() if is_gid else os.getuid()
    self.AddIptablesRule(version, is_gid, my_id)
    s, addr = self._OpenLoopbackUdpSocket(version)
    try:
      self.assertRaisesErrno(errno.EPERM, s.sendto, "foo", addr)
      self.DelIptablesRule(version, is_gid, my_id)
      self._CheckLoopbackEcho(s, addr)
    finally:
      s.close()

  def CheckSocketOutputInverted(self, version, is_gid):
    """Checks that an inverted owner rule only drops other owners' traffic.

    Loads an inverted iptables rule on the current uid/gid. Traffic from the
    current uid/gid should pass; traffic from OTHER_UID_GID should be
    dropped.
    """
    my_id = os.getgid() if is_gid else os.getuid()
    self.AddIptablesInvertedRule(version, is_gid, my_id)
    s, addr = self._OpenLoopbackUdpSocket(version)
    try:
      self._CheckLoopbackEcho(s, addr)
      # Only the id under test changes; the other one is set to 0.
      # NOTE(review): this assumes the test runs as uid/gid 0 (root).
      other_uid = 0 if is_gid else OTHER_UID_GID
      other_gid = OTHER_UID_GID if is_gid else 0
      with net_test.RunAsUidGid(other_uid, other_gid):
        s2, addr2 = self._OpenLoopbackUdpSocket(version)
        try:
          self.assertRaisesErrno(errno.EPERM, s2.sendto, "foo", addr2)
        finally:
          s2.close()
      self.DelIptablesInvertedRule(version, is_gid, my_id)
      self._CheckLoopbackEcho(s, addr)
    finally:
      s.close()

  def SendRSTOnClosedSocket(self, version, netid, expect_rst):
    """Closes an accepted connection from FIN_WAIT and checks for a RST.

    Args:
      version: IP version of the connection (5 is treated as 4 when
          building the remote ACK).
      netid: the network to send and receive packets on.
      expect_rst: True if a RST must be seen after close(), False if no
          packets may be seen.
    """
    self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid)
    # TCP_LINGER2 = -1 changes how the orphaned FIN_WAIT2 socket is closed;
    # the expectations below look for a RST after close().
    self.accepted.setsockopt(net_test.SOL_TCP, net_test.TCP_LINGER2, -1)
    net_test.EnableFinWait(self.accepted)
    self.accepted.shutdown(SHUT_WR)
    _, fin = self.FinPacket()
    self.ExpectPacketOn(netid, "Closing FIN_WAIT1 socket", fin)
    finversion = 4 if version == 5 else version
    _, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
    self.ReceivePacketOn(netid, finack)
    # Both outcomes are tolerated: a second FIN may or may not be seen here.
    try:
      self.ExpectPacketOn(netid, "Closing FIN_WAIT1 socket", fin)
    except AssertionError:
      pass
    self.accepted.close()
    desc, rst = packets.RST(version, self.myaddr, self.remoteaddr,
                            self.last_packet)
    if expect_rst:
      msg = "closing socket with linger2, expecting %s: " % desc
      self.ExpectPacketOn(netid, msg, rst)
    else:
      msg = "closing socket with linger2, expecting no packets"
      self.ExpectNoPacketsOn(netid, msg)

  def CheckUidGidCombination(self, version, invert_gid, invert_uid):
    """Checks RST delivery with one gid rule and one uid rule installed.

    For each of gid and uid, installs either an inverted rule on the current
    id or a non-inverted rule on OTHER_UID_GID. A RST is expected on every
    network unless the gid rule is inverted. Both rules are removed
    afterwards.
    """
    my_uid = os.getuid()
    my_gid = os.getgid()
    # (is_gid, inverted, id) for each rule. Adding and deleting from the
    # same list keeps the cleanup symmetric with the setup.
    rules = [
        (True, invert_gid, my_gid if invert_gid else OTHER_UID_GID),
        (False, invert_uid, my_uid if invert_uid else OTHER_UID_GID),
    ]
    for is_gid, inverted, rule_id in rules:
      self.SetIptablesRule(version, True, is_gid, rule_id, inverted)
    for netid in self.NETIDS:
      self.SendRSTOnClosedSocket(version, netid, not invert_gid)
    for is_gid, inverted, rule_id in rules:
      self.SetIptablesRule(version, False, is_gid, rule_id, inverted)

  def testCloseWithoutUntag(self):
    # The qtaguid device is held open for the whole test.
    with open("/dev/xt_qtaguid", "r"):
      sk = socket(AF_INET, SOCK_DGRAM, 0)
      uid = os.getuid()
      tag = 0xff00ff00 << 32
      self.WriteToCtrl("t %d %d %d" % (sk.fileno(), tag, uid))
      self.assertTrue(self.CheckTag(tag, uid))
      # Closing the socket without an explicit untag must drop the tag.
      sk.close()
      self.assertFalse(self.CheckTag(tag, uid))

  def testTagWithoutDeviceOpen(self):
    sk = socket(AF_INET, SOCK_DGRAM, 0)
    uid = os.getuid()
    tag = 0xff00ff00 << 32
    # Tag while the qtaguid device is NOT open.
    self.WriteToCtrl("t %d %d %d" % (sk.fileno(), tag, uid))
    self.assertTrue(self.CheckTag(tag, uid))
    # Open the device only after tagging, then close the socket.
    with open("/dev/xt_qtaguid", "r"):
      sk.close()
      self.assertFalse(self.CheckTag(tag, uid))

  def testUidGidMatch(self):
    for is_gid in (False, True):
      for version in (4, 6):
        self.CheckSocketOutput(version, is_gid)
    for is_gid in (True, False):
      for version in (4, 6):
        self.CheckSocketOutputInverted(version, is_gid)

  def testCheckNotMatchGid(self):
    with open(CTRL_PROCPATH, "r") as ctrl_file:
      self.assertIn("match_no_sk_gid", ctrl_file.read())

  def testRstPacketNotDropped(self):
    my_uid = os.getuid()
    for version in (4, 6):
      self.AddIptablesInvertedRule(version, False, my_uid)
      for netid in self.NETIDS:
        self.SendRSTOnClosedSocket(version, netid, True)
      self.DelIptablesInvertedRule(version, False, my_uid)

  def testUidGidCombineMatch(self):
    for version in (4, 6):
      for invert_gid in (True, False):
        for invert_uid in (True, False):
          self.CheckUidGidCombination(version, invert_gid=invert_gid,
                                      invert_uid=invert_uid)


if __name__ == "__main__":
  unittest.main()