Home | History | Annotate | Download | only in test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2017 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 """Unit tests for xt_qtaguid."""
     18 
     19 import errno
     20 from socket import *  # pylint: disable=wildcard-import
     21 import unittest
     22 import os
     23 
     24 import net_test
     25 import packets
     26 import tcp_test
     27 
# Proc file that the tests below write tagging commands to and read back to
# inspect the current socket tags and module settings.
CTRL_PROCPATH = "/proc/net/xt_qtaguid/ctrl"
# UID/GID used in owner-match rules that are not built from the test
# process's own ids; presumably chosen to differ from them - TODO confirm.
OTHER_UID_GID = 12345
     30 
     31 class QtaguidTest(tcp_test.TcpBaseTest):
     32 
     33   def RunIptablesCommand(self, args):
     34     self.assertFalse(net_test.RunIptablesCommand(4, args))
     35     self.assertFalse(net_test.RunIptablesCommand(6, args))
     36 
     37   def setUp(self):
     38     self.RunIptablesCommand("-N qtaguid_test_OUTPUT")
     39     self.RunIptablesCommand("-A OUTPUT -j qtaguid_test_OUTPUT")
     40 
     41   def tearDown(self):
     42     self.RunIptablesCommand("-D OUTPUT -j qtaguid_test_OUTPUT")
     43     self.RunIptablesCommand("-F qtaguid_test_OUTPUT")
     44     self.RunIptablesCommand("-X qtaguid_test_OUTPUT")
     45 
     46   def WriteToCtrl(self, command):
     47     ctrl_file = open(CTRL_PROCPATH, 'w')
     48     ctrl_file.write(command)
     49     ctrl_file.close()
     50 
     51   def CheckTag(self, tag, uid):
     52     for line in open(CTRL_PROCPATH, 'r').readlines():
     53       if "tag=0x%x (uid=%d)" % ((tag|uid), uid) in line:
     54         return True
     55     return False
     56 
     57   def SetIptablesRule(self, version, is_add, is_gid, my_id, inverted):
     58     add_del = "-A" if is_add else "-D"
     59     uid_gid = "--gid-owner" if is_gid else "--uid-owner"
     60     if inverted:
     61       args = "%s qtaguid_test_OUTPUT -m owner ! %s %d -j DROP" % (add_del, uid_gid, my_id)
     62     else:
     63       args = "%s qtaguid_test_OUTPUT -m owner %s %d -j DROP" % (add_del, uid_gid, my_id)
     64     self.assertFalse(net_test.RunIptablesCommand(version, args))
     65 
     66   def AddIptablesRule(self, version, is_gid, myId):
     67     self.SetIptablesRule(version, True, is_gid, myId, False)
     68 
     69   def AddIptablesInvertedRule(self, version, is_gid, myId):
     70     self.SetIptablesRule(version, True, is_gid, myId, True)
     71 
     72   def DelIptablesRule(self, version, is_gid, myId):
     73     self.SetIptablesRule(version, False, is_gid, myId, False)
     74 
     75   def DelIptablesInvertedRule(self, version, is_gid, myId):
     76     self.SetIptablesRule(version, False, is_gid, myId, True)
     77 
     78   def CheckSocketOutput(self, version, is_gid):
     79     myId = os.getgid() if is_gid else os.getuid()
     80     self.AddIptablesRule(version, is_gid, myId)
     81     family = {4: AF_INET, 6: AF_INET6}[version]
     82     s = socket(family, SOCK_DGRAM, 0)
     83     addr = {4: "127.0.0.1", 6: "::1"}[version]
     84     s.bind((addr, 0))
     85     addr = s.getsockname()
     86     self.assertRaisesErrno(errno.EPERM, s.sendto, "foo", addr)
     87     self.DelIptablesRule(version, is_gid, myId)
     88     s.sendto("foo", addr)
     89     data, sockaddr = s.recvfrom(4096)
     90     self.assertEqual("foo", data)
     91     self.assertEqual(sockaddr, addr)
     92 
     93   def CheckSocketOutputInverted(self, version, is_gid):
     94     # Load a inverted iptable rule on current uid/gid 0, traffic from other
     95     # uid/gid should be blocked and traffic from current uid/gid should pass.
     96     myId = os.getgid() if is_gid else os.getuid()
     97     self.AddIptablesInvertedRule(version, is_gid, myId)
     98     family = {4: AF_INET, 6: AF_INET6}[version]
     99     s = socket(family, SOCK_DGRAM, 0)
    100     addr1 = {4: "127.0.0.1", 6: "::1"}[version]
    101     s.bind((addr1, 0))
    102     addr1 = s.getsockname()
    103     s.sendto("foo", addr1)
    104     data, sockaddr = s.recvfrom(4096)
    105     self.assertEqual("foo", data)
    106     self.assertEqual(sockaddr, addr1)
    107     with net_test.RunAsUidGid(0 if is_gid else 12345,
    108                               12345 if is_gid else 0):
    109       s2 = socket(family, SOCK_DGRAM, 0)
    110       addr2 = {4: "127.0.0.1", 6: "::1"}[version]
    111       s2.bind((addr2, 0))
    112       addr2 = s2.getsockname()
    113       self.assertRaisesErrno(errno.EPERM, s2.sendto, "foo", addr2)
    114     self.DelIptablesInvertedRule(version, is_gid, myId)
    115     s.sendto("foo", addr1)
    116     data, sockaddr = s.recvfrom(4096)
    117     self.assertEqual("foo", data)
    118     self.assertEqual(sockaddr, addr1)
    119 
    120   def SendRSTOnClosedSocket(self, version, netid, expect_rst):
    121     self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, netid)
    122     self.accepted.setsockopt(net_test.SOL_TCP, net_test.TCP_LINGER2, -1)
    123     net_test.EnableFinWait(self.accepted)
    124     self.accepted.shutdown(SHUT_WR)
    125     desc, fin = self.FinPacket()
    126     self.ExpectPacketOn(netid, "Closing FIN_WAIT1 socket", fin)
    127     finversion = 4 if version == 5 else version
    128     desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
    129     self.ReceivePacketOn(netid, finack)
    130     try:
    131       self.ExpectPacketOn(netid, "Closing FIN_WAIT1 socket", fin)
    132     except AssertionError:
    133       pass
    134     self.accepted.close()
    135     desc, rst = packets.RST(version, self.myaddr, self.remoteaddr, self.last_packet)
    136     if expect_rst:
    137       msg = "closing socket with linger2, expecting %s: " % desc
    138       self.ExpectPacketOn(netid, msg, rst)
    139     else:
    140       msg = "closing socket with linger2, expecting no packets"
    141       self.ExpectNoPacketsOn(netid, msg)
    142 
    143   def CheckUidGidCombination(self, version, invert_gid, invert_uid):
    144     my_uid = os.getuid()
    145     my_gid = os.getgid()
    146     if invert_gid:
    147       self.AddIptablesInvertedRule(version, True, my_gid)
    148     else:
    149       self.AddIptablesRule(version, True, OTHER_UID_GID)
    150     if invert_uid:
    151       self.AddIptablesInvertedRule(version, False, my_uid)
    152     else:
    153       self.AddIptablesRule(version, False, OTHER_UID_GID)
    154     for netid in self.NETIDS:
    155       self.SendRSTOnClosedSocket(version, netid, not invert_gid)
    156     if invert_gid:
    157       self.DelIptablesInvertedRule(version, True, my_gid)
    158     else:
    159       self.DelIptablesRule(version, True, OTHER_UID_GID)
    160     if invert_uid:
    161       self.AddIptablesInvertedRule(version, False, my_uid)
    162     else:
    163       self.DelIptablesRule(version, False, OTHER_UID_GID)
    164 
    165   def testCloseWithoutUntag(self):
    166     self.dev_file = open("/dev/xt_qtaguid", "r");
    167     sk = socket(AF_INET, SOCK_DGRAM, 0)
    168     uid = os.getuid()
    169     tag = 0xff00ff00 << 32
    170     command =  "t %d %d %d" % (sk.fileno(), tag, uid)
    171     self.WriteToCtrl(command)
    172     self.assertTrue(self.CheckTag(tag, uid))
    173     sk.close();
    174     self.assertFalse(self.CheckTag(tag, uid))
    175     self.dev_file.close();
    176 
    177   def testTagWithoutDeviceOpen(self):
    178     sk = socket(AF_INET, SOCK_DGRAM, 0)
    179     uid = os.getuid()
    180     tag = 0xff00ff00 << 32
    181     command = "t %d %d %d" % (sk.fileno(), tag, uid)
    182     self.WriteToCtrl(command)
    183     self.assertTrue(self.CheckTag(tag, uid))
    184     self.dev_file = open("/dev/xt_qtaguid", "r")
    185     sk.close()
    186     self.assertFalse(self.CheckTag(tag, uid))
    187     self.dev_file.close();
    188 
    189   def testUidGidMatch(self):
    190     self.CheckSocketOutput(4, False)
    191     self.CheckSocketOutput(6, False)
    192     self.CheckSocketOutput(4, True)
    193     self.CheckSocketOutput(6, True)
    194     self.CheckSocketOutputInverted(4, True)
    195     self.CheckSocketOutputInverted(6, True)
    196     self.CheckSocketOutputInverted(4, False)
    197     self.CheckSocketOutputInverted(6, False)
    198 
    199   def testCheckNotMatchGid(self):
    200     self.assertIn("match_no_sk_gid", open(CTRL_PROCPATH, 'r').read())
    201 
    202   def testRstPacketNotDropped(self):
    203     my_uid = os.getuid()
    204     self.AddIptablesInvertedRule(4, False, my_uid)
    205     for netid in self.NETIDS:
    206       self.SendRSTOnClosedSocket(4, netid, True)
    207     self.DelIptablesInvertedRule(4, False, my_uid)
    208     self.AddIptablesInvertedRule(6, False, my_uid)
    209     for netid in self.NETIDS:
    210       self.SendRSTOnClosedSocket(6, netid, True)
    211     self.DelIptablesInvertedRule(6, False, my_uid)
    212 
    213   def testUidGidCombineMatch(self):
    214     self.CheckUidGidCombination(4, invert_gid=True, invert_uid=True)
    215     self.CheckUidGidCombination(4, invert_gid=True, invert_uid=False)
    216     self.CheckUidGidCombination(4, invert_gid=False, invert_uid=True)
    217     self.CheckUidGidCombination(4, invert_gid=False, invert_uid=False)
    218     self.CheckUidGidCombination(6, invert_gid=True, invert_uid=True)
    219     self.CheckUidGidCombination(6, invert_gid=True, invert_uid=False)
    220     self.CheckUidGidCombination(6, invert_gid=False, invert_uid=True)
    221     self.CheckUidGidCombination(6, invert_gid=False, invert_uid=False)
    222 
    223 
    224 if __name__ == "__main__":
    225   unittest.main()
    226