Home | History | Annotate | Download | only in test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2017 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 """Unit tests for xt_qtaguid."""
     18 
     19 import errno
     20 from socket import *  # pylint: disable=wildcard-import
     21 import unittest
     22 import os
     23 import net_test
     24 
     25 CTRL_PROCPATH = "/proc/net/xt_qtaguid/ctrl"
     26 
     27 class QtaguidTest(net_test.NetworkTest):
     28 
     29   def RunIptablesCommand(self, args):
     30     self.assertFalse(net_test.RunIptablesCommand(4, args))
     31     self.assertFalse(net_test.RunIptablesCommand(6, args))
     32 
     33   def setUp(self):
     34     self.RunIptablesCommand("-N qtaguid_test_OUTPUT")
     35     self.RunIptablesCommand("-A OUTPUT -j qtaguid_test_OUTPUT")
     36 
     37   def tearDown(self):
     38     self.RunIptablesCommand("-D OUTPUT -j qtaguid_test_OUTPUT")
     39     self.RunIptablesCommand("-F qtaguid_test_OUTPUT")
     40     self.RunIptablesCommand("-X qtaguid_test_OUTPUT")
     41 
     42   def WriteToCtrl(self, command):
     43     ctrl_file = open(CTRL_PROCPATH, 'w')
     44     ctrl_file.write(command)
     45     ctrl_file.close()
     46 
     47   def CheckTag(self, tag, uid):
     48     for line in open(CTRL_PROCPATH, 'r').readlines():
     49       if "tag=0x%x (uid=%d)" % ((tag|uid), uid) in line:
     50         return True
     51     return False
     52 
     53   def SetIptablesRule(self, version, is_add, is_gid, my_id, inverted):
     54     add_del = "-A" if is_add else "-D"
     55     uid_gid = "--gid-owner" if is_gid else "--uid-owner"
     56     if inverted:
     57       args = "%s qtaguid_test_OUTPUT -m owner ! %s %d -j DROP" % (add_del, uid_gid, my_id)
     58     else:
     59       args = "%s qtaguid_test_OUTPUT -m owner %s %d -j DROP" % (add_del, uid_gid, my_id)
     60     self.assertFalse(net_test.RunIptablesCommand(version, args))
     61 
     62   def AddIptablesRule(self, version, is_gid, myId):
     63     self.SetIptablesRule(version, True, is_gid, myId, False)
     64 
     65   def AddIptablesInvertedRule(self, version, is_gid, myId):
     66     self.SetIptablesRule(version, True, is_gid, myId, True)
     67 
     68   def DelIptablesRule(self, version, is_gid, myId):
     69     self.SetIptablesRule(version, False, is_gid, myId, False)
     70 
     71   def DelIptablesInvertedRule(self, version, is_gid, myId):
     72     self.SetIptablesRule(version, False, is_gid, myId, True)
     73 
     74   def CheckSocketOutput(self, version, is_gid):
     75     myId = os.getgid() if is_gid else os.getuid()
     76     self.AddIptablesRule(version, is_gid, myId)
     77     family = {4: AF_INET, 6: AF_INET6}[version]
     78     s = socket(family, SOCK_DGRAM, 0)
     79     addr = {4: "127.0.0.1", 6: "::1"}[version]
     80     s.bind((addr, 0))
     81     addr = s.getsockname()
     82     self.assertRaisesErrno(errno.EPERM, s.sendto, "foo", addr)
     83     self.DelIptablesRule(version, is_gid, myId)
     84     s.sendto("foo", addr)
     85     data, sockaddr = s.recvfrom(4096)
     86     self.assertEqual("foo", data)
     87     self.assertEqual(sockaddr, addr)
     88 
     89   def CheckSocketOutputInverted(self, version, is_gid):
     90     # Load a inverted iptable rule on current uid/gid 0, traffic from other
     91     # uid/gid should be blocked and traffic from current uid/gid should pass.
     92     myId = os.getgid() if is_gid else os.getuid()
     93     self.AddIptablesInvertedRule(version, is_gid, myId)
     94     family = {4: AF_INET, 6: AF_INET6}[version]
     95     s = socket(family, SOCK_DGRAM, 0)
     96     addr1 = {4: "127.0.0.1", 6: "::1"}[version]
     97     s.bind((addr1, 0))
     98     addr1 = s.getsockname()
     99     s.sendto("foo", addr1)
    100     data, sockaddr = s.recvfrom(4096)
    101     self.assertEqual("foo", data)
    102     self.assertEqual(sockaddr, addr1)
    103     with net_test.RunAsUidGid(0 if is_gid else 12345,
    104                               12345 if is_gid else 0):
    105       s2 = socket(family, SOCK_DGRAM, 0)
    106       addr2 = {4: "127.0.0.1", 6: "::1"}[version]
    107       s2.bind((addr2, 0))
    108       addr2 = s2.getsockname()
    109       self.assertRaisesErrno(errno.EPERM, s2.sendto, "foo", addr2)
    110     self.DelIptablesInvertedRule(version, is_gid, myId)
    111     s.sendto("foo", addr1)
    112     data, sockaddr = s.recvfrom(4096)
    113     self.assertEqual("foo", data)
    114     self.assertEqual(sockaddr, addr1)
    115 
    116   def testCloseWithoutUntag(self):
    117     self.dev_file = open("/dev/xt_qtaguid", "r");
    118     sk = socket(AF_INET, SOCK_DGRAM, 0)
    119     uid = os.getuid()
    120     tag = 0xff00ff00 << 32
    121     command =  "t %d %d %d" % (sk.fileno(), tag, uid)
    122     self.WriteToCtrl(command)
    123     self.assertTrue(self.CheckTag(tag, uid))
    124     sk.close();
    125     self.assertFalse(self.CheckTag(tag, uid))
    126     self.dev_file.close();
    127 
    128   def testTagWithoutDeviceOpen(self):
    129     sk = socket(AF_INET, SOCK_DGRAM, 0)
    130     uid = os.getuid()
    131     tag = 0xff00ff00 << 32
    132     command = "t %d %d %d" % (sk.fileno(), tag, uid)
    133     self.WriteToCtrl(command)
    134     self.assertTrue(self.CheckTag(tag, uid))
    135     self.dev_file = open("/dev/xt_qtaguid", "r")
    136     sk.close()
    137     self.assertFalse(self.CheckTag(tag, uid))
    138     self.dev_file.close();
    139 
    140   def testUidGidMatch(self):
    141     self.CheckSocketOutput(4, False)
    142     self.CheckSocketOutput(6, False)
    143     self.CheckSocketOutput(4, True)
    144     self.CheckSocketOutput(6, True)
    145     self.CheckSocketOutputInverted(4, True)
    146     self.CheckSocketOutputInverted(6, True)
    147     self.CheckSocketOutputInverted(4, False)
    148     self.CheckSocketOutputInverted(6, False)
    149 
    150   @unittest.skip("does not pass on current kernels")
    151   def testCheckNotMatchGid(self):
    152     self.assertIn("match_no_sk_gid", open(CTRL_PROCPATH, 'r').read())
    153 
    154 
    155 if __name__ == "__main__":
    156   unittest.main()
    157