#!/usr/bin/python
#
# Copyright 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for xt_qtaguid."""

import errno
import os
from socket import *  # pylint: disable=wildcard-import
import unittest

import net_test

CTRL_PROCPATH = "/proc/net/xt_qtaguid/ctrl"

# Device node opened by the socket-tagging tests.
DEVICE_PATH = "/dev/xt_qtaguid"


class QtaguidTest(net_test.NetworkTest):
  """Tests for xt_qtaguid socket tagging and iptables owner matching."""

  # Socket family and loopback address for each IP version under test.
  _FAMILY = {4: AF_INET, 6: AF_INET6}
  _LOOPBACK = {4: "127.0.0.1", 6: "::1"}

  # Datagram sent over loopback. A bytes literal is identical to "foo" on
  # Python 2 and is what sendto() requires on Python 3.
  _PAYLOAD = b"foo"

  # Tag value written to the ctrl file by the tagging tests.
  _TEST_TAG = 0xff00ff00 << 32

  def RunIptablesCommand(self, args):
    """Run an iptables command for both IPv4 and IPv6, asserting success."""
    for version in (4, 6):
      self.assertFalse(net_test.RunIptablesCommand(version, args))

  def setUp(self):
    self.RunIptablesCommand("-N qtaguid_test_OUTPUT")
    self.RunIptablesCommand("-A OUTPUT -j qtaguid_test_OUTPUT")

  def tearDown(self):
    # Flushing the chain also removes any rule a failed test left behind.
    self.RunIptablesCommand("-D OUTPUT -j qtaguid_test_OUTPUT")
    self.RunIptablesCommand("-F qtaguid_test_OUTPUT")
    self.RunIptablesCommand("-X qtaguid_test_OUTPUT")

  def WriteToCtrl(self, command):
    """Write a single command string to the xt_qtaguid ctrl proc file."""
    with open(CTRL_PROCPATH, "w") as ctrl_file:
      ctrl_file.write(command)

  def CheckTag(self, tag, uid):
    """Return True if the ctrl file contains an entry for tag and uid.

    The entry searched for is "tag=0x<tag | uid> (uid=<uid>)", i.e. the uid
    is OR-ed into the tag value before formatting.
    """
    expected = "tag=0x%x (uid=%d)" % (tag | uid, uid)
    with open(CTRL_PROCPATH, "r") as ctrl_file:
      return any(expected in line for line in ctrl_file)

  def SetIptablesRule(self, version, is_add, is_gid, my_id, inverted):
    """Add or delete a DROP rule in qtaguid_test_OUTPUT keyed on an owner.

    Args:
      version: IP version (4 or 6) passed to net_test.RunIptablesCommand.
      is_add: True to append the rule (-A), False to delete it (-D).
      is_gid: match with --gid-owner if True, otherwise --uid-owner.
      my_id: the uid or gid to match.
      inverted: if True, negate the owner match ("! --uid-owner N").
    """
    add_del = "-A" if is_add else "-D"
    uid_gid = "--gid-owner" if is_gid else "--uid-owner"
    negation = "! " if inverted else ""
    args = "%s qtaguid_test_OUTPUT -m owner %s%s %d -j DROP" % (
        add_del, negation, uid_gid, my_id)
    self.assertFalse(net_test.RunIptablesCommand(version, args))

  def AddIptablesRule(self, version, is_gid, myId):
    self.SetIptablesRule(version, True, is_gid, myId, False)

  def AddIptablesInvertedRule(self, version, is_gid, myId):
    self.SetIptablesRule(version, True, is_gid, myId, True)

  def DelIptablesRule(self, version, is_gid, myId):
    self.SetIptablesRule(version, False, is_gid, myId, False)

  def DelIptablesInvertedRule(self, version, is_gid, myId):
    self.SetIptablesRule(version, False, is_gid, myId, True)

  def _OpenLoopbackSocket(self, version):
    """Return (sock, addr): a UDP socket bound to loopback and its address.

    The socket is bound to an ephemeral port; addr is what getsockname()
    reports, so datagrams sent to addr come back to the same socket.
    """
    sock = socket(self._FAMILY[version], SOCK_DGRAM, 0)
    try:
      sock.bind((self._LOOPBACK[version], 0))
      return sock, sock.getsockname()
    except Exception:
      sock.close()
      raise

  def _CheckLoopbackEcho(self, sock, addr):
    """Send the payload from sock to addr and assert it arrives unchanged."""
    sock.sendto(self._PAYLOAD, addr)
    data, sockaddr = sock.recvfrom(4096)
    self.assertEqual(self._PAYLOAD, data)
    self.assertEqual(sockaddr, addr)

  def CheckSocketOutput(self, version, is_gid):
    """Check that a DROP rule on our own uid/gid blocks our traffic.

    While the rule is installed, sendto() must fail with EPERM. After the
    rule is deleted, the datagram must be delivered.
    """
    my_id = os.getgid() if is_gid else os.getuid()
    self.AddIptablesRule(version, is_gid, my_id)
    sock, addr = self._OpenLoopbackSocket(version)
    try:
      self.assertRaisesErrno(errno.EPERM, sock.sendto, self._PAYLOAD, addr)
      self.DelIptablesRule(version, is_gid, my_id)
      self._CheckLoopbackEcho(sock, addr)
    finally:
      sock.close()

  def CheckSocketOutputInverted(self, version, is_gid):
    """Check a negated owner match on our current uid/gid.

    The inverted rule drops traffic whose uid/gid is NOT ours. A socket
    created under our credentials must still send. A socket created while
    the id under test is switched to 12345 must get EPERM. After the rule
    is deleted, our own traffic must still pass.
    """
    my_id = os.getgid() if is_gid else os.getuid()
    self.AddIptablesInvertedRule(version, is_gid, my_id)
    sock, addr = self._OpenLoopbackSocket(version)
    try:
      self._CheckLoopbackEcho(sock, addr)
      # Assuming RunAsUidGid(uid, gid): set only the id under test to 12345
      # and the other one to 0.
      with net_test.RunAsUidGid(0 if is_gid else 12345,
                                12345 if is_gid else 0):
        other_sock, other_addr = self._OpenLoopbackSocket(version)
        try:
          self.assertRaisesErrno(errno.EPERM, other_sock.sendto,
                                 self._PAYLOAD, other_addr)
        finally:
          other_sock.close()
      self.DelIptablesInvertedRule(version, is_gid, my_id)
      self._CheckLoopbackEcho(sock, addr)
    finally:
      sock.close()

  def testCloseWithoutUntag(self):
    """Closing a tagged socket, with the device open, must remove its tag."""
    uid = os.getuid()
    tag = self._TEST_TAG
    self.dev_file = open(DEVICE_PATH, "r")
    try:
      sk = socket(AF_INET, SOCK_DGRAM, 0)
      try:
        self.WriteToCtrl("t %d %d %d" % (sk.fileno(), tag, uid))
        self.assertTrue(self.CheckTag(tag, uid))
      finally:
        sk.close()
      # The socket was never explicitly untagged; closing it must be enough.
      self.assertFalse(self.CheckTag(tag, uid))
    finally:
      self.dev_file.close()

  def testTagWithoutDeviceOpen(self):
    """Tag a socket before the device is opened; closing it must untag."""
    uid = os.getuid()
    tag = self._TEST_TAG
    sk = socket(AF_INET, SOCK_DGRAM, 0)
    try:
      self.WriteToCtrl("t %d %d %d" % (sk.fileno(), tag, uid))
      self.assertTrue(self.CheckTag(tag, uid))
      self.dev_file = open(DEVICE_PATH, "r")
      try:
        sk.close()
        self.assertFalse(self.CheckTag(tag, uid))
      finally:
        self.dev_file.close()
    finally:
      # Closing an already-closed Python socket is a no-op.
      sk.close()

  def testUidGidMatch(self):
    """Exercise plain and inverted uid/gid owner matches on IPv4 and IPv6."""
    for is_gid in (False, True):
      for version in (4, 6):
        self.CheckSocketOutput(version, is_gid)
    for is_gid in (True, False):
      for version in (4, 6):
        self.CheckSocketOutputInverted(version, is_gid)

  @unittest.skip("does not pass on current kernels")
  def testCheckNotMatchGid(self):
    with open(CTRL_PROCPATH, "r") as ctrl_file:
      self.assertIn("match_no_sk_gid", ctrl_file.read())


if __name__ == "__main__":
  unittest.main()