Home | History | Annotate | Download | only in iptables
      1 #!/usr/bin/python
      2 #
      3 # (C) 2012-2013 by Pablo Neira Ayuso <pablo@netfilter.org>
      4 #
      5 # This program is free software; you can redistribute it and/or modify
      6 # it under the terms of the GNU General Public License as published by
      7 # the Free Software Foundation; either version 2 of the License, or
      8 # (at your option) any later version.
      9 #
     10 # This software has been sponsored by Sophos Astaro <http://www.sophos.com>
     11 #
     12 
     13 import sys
     14 import os
     15 import subprocess
     16 import argparse
     17 
     18 IPTABLES = "iptables"
     19 IP6TABLES = "ip6tables"
     20 #IPTABLES = "xtables -4"
     21 #IP6TABLES = "xtables -6"
     22 
     23 IPTABLES_SAVE = "iptables-save"
     24 IP6TABLES_SAVE = "ip6tables-save"
     25 #IPTABLES_SAVE = ['xtables-save','-4']
     26 #IP6TABLES_SAVE = ['xtables-save','-6']
     27 
     28 EXTENSIONS_PATH = "extensions"
     29 LOGFILE="/tmp/iptables-test.log"
     30 log_file = None
     31 
     32 
     33 class Colors:
     34     HEADER = '\033[95m'
     35     BLUE = '\033[94m'
     36     GREEN = '\033[92m'
     37     YELLOW = '\033[93m'
     38     RED = '\033[91m'
     39     ENDC = '\033[0m'
     40 
     41 
     42 def print_error(reason, filename=None, lineno=None):
     43     '''
     44     Prints an error with nice colors, indicating file and line number.
     45     '''
     46     print (filename + ": " + Colors.RED + "ERROR" +
     47         Colors.ENDC + ": line %d (%s)" % (lineno, reason))
     48 
     49 
     50 def delete_rule(iptables, rule, filename, lineno):
     51     '''
     52     Removes an iptables rule
     53     '''
     54     cmd = iptables + " -D " + rule
     55     ret = execute_cmd(cmd, filename, lineno)
     56     if ret == 1:
     57         reason = "cannot delete: " + iptables + " -I " + rule
     58         print_error(reason, filename, lineno)
     59         return -1
     60 
     61     return 0
     62 
     63 
     64 def run_test(iptables, rule, rule_save, res, filename, lineno):
     65     '''
     66     Executes an unit test. Returns the output of delete_rule().
     67 
     68     Parameters:
     69     :param  iptables: string with the iptables command to execute
     70     :param rule: string with iptables arguments for the rule to test
     71     :param rule_save: string to find the rule in the output of iptables -save
     72     :param res: expected result of the rule. Valid values: "OK", "FAIL"
     73     :param filename: name of the file tested (used for print_error purposes)
     74     :param lineno: line number being tested (used for print_error purposes)
     75     '''
     76     ret = 0
     77 
     78     cmd = iptables + " -A " + rule
     79     ret = execute_cmd(cmd, filename, lineno)
     80 
     81     #
     82     # report failed test
     83     #
     84     if ret:
     85         if res == "OK":
     86             reason = "cannot load: " + cmd
     87             print_error(reason, filename, lineno)
     88             return -1
     89         else:
     90             # do not report this error
     91             return 0
     92     else:
     93         if res == "FAIL":
     94             reason = "should fail: " + cmd
     95             print_error(reason, filename, lineno)
     96             delete_rule(iptables, rule, filename, lineno)
     97             return -1
     98 
     99     matching = 0
    100     splitted = iptables.split(" ")
    101     if len(splitted) == 2:
    102         if splitted[1] == '-4':
    103             command = IPTABLES_SAVE
    104         elif splitted[1] == '-6':
    105             command = IP6TABLES_SAVE
    106     elif len(splitted) == 1:
    107         if splitted[0] == IPTABLES:
    108             command = IPTABLES_SAVE
    109         elif splitted[0] == IP6TABLES:
    110             command = IP6TABLES_SAVE
    111     args = splitted[1:]
    112     proc = subprocess.Popen(command, stdin=subprocess.PIPE,
    113                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    114     out, err = proc.communicate()
    115 
    116     #
    117     # check for segfaults
    118     #
    119     if proc.returncode == -11:
    120         reason = "iptables-save segfaults: " + cmd
    121         print_error(reason, filename, lineno)
    122         delete_rule(iptables, rule, filename, lineno)
    123         return -1
    124 
    125     # find the rule
    126     matching = out.find(rule_save)
    127     if matching < 0:
    128         reason = "cannot find: " + iptables + " -I " + rule
    129         print_error(reason, filename, lineno)
    130         delete_rule(iptables, rule, filename, lineno)
    131         return -1
    132 
    133     return delete_rule(iptables, rule, filename, lineno)
    134 
    135 
    136 def execute_cmd(cmd, filename, lineno):
    137     '''
    138     Executes a command, checking for segfaults and returning the command exit
    139     code.
    140 
    141     :param cmd: string with the command to be executed
    142     :param filename: name of the file tested (used for print_error purposes)
    143     :param lineno: line number being tested (used for print_error purposes)
    144     '''
    145     global log_file
    146     print >> log_file, "command: %s" % cmd
    147     ret = subprocess.call(cmd, shell=True, universal_newlines=True,
    148         stderr=subprocess.STDOUT, stdout=log_file)
    149     log_file.flush()
    150 
    151     # generic check for segfaults
    152     if ret  == -11:
    153         reason = "command segfaults: " + cmd
    154         print_error(reason, filename, lineno)
    155     return ret
    156 
    157 
    158 def run_test_file(filename):
    159     '''
    160     Runs a test file
    161 
    162     :param filename: name of the file with the test rules
    163     '''
    164     #
    165     # if this is not a test file, skip.
    166     #
    167     if not filename.endswith(".t"):
    168         return 0, 0
    169 
    170     if "libipt_" in filename:
    171         iptables = IPTABLES
    172     elif "libip6t_" in filename:
    173         iptables = IP6TABLES
    174     elif "libxt_"  in filename:
    175         iptables = IPTABLES
    176     else:
    177         # default to iptables if not known prefix
    178         iptables = IPTABLES
    179 
    180     f = open(filename)
    181 
    182     tests = 0
    183     passed = 0
    184     table = ""
    185     total_test_passed = True
    186 
    187     for lineno, line in enumerate(f):
    188         if line[0] == "#":
    189             continue
    190 
    191         if line[0] == ":":
    192             chain_array = line.rstrip()[1:].split(",")
    193             continue
    194 
    195         # external non-iptables invocation, executed as is.
    196         if line[0] == "@":
    197             external_cmd = line.rstrip()[1:]
    198             execute_cmd(external_cmd, filename, lineno)
    199             continue
    200 
    201         if line[0] == "*":
    202             table = line.rstrip()[1:]
    203             continue
    204 
    205         if len(chain_array) == 0:
    206             print "broken test, missing chain, leaving"
    207             sys.exit()
    208 
    209         test_passed = True
    210         tests += 1
    211 
    212         for chain in chain_array:
    213             item = line.split(";")
    214             if table == "":
    215                 rule = chain + " " + item[0]
    216             else:
    217                 rule = chain + " -t " + table + " " + item[0]
    218 
    219             if item[1] == "=":
    220                 rule_save = chain + " " + item[0]
    221             else:
    222                 rule_save = chain + " " + item[1]
    223 
    224             res = item[2].rstrip()
    225 
    226             ret = run_test(iptables, rule, rule_save,
    227                            res, filename, lineno + 1)
    228             if ret < 0:
    229                 test_passed = False
    230                 total_test_passed = False
    231                 break
    232 
    233         if test_passed:
    234             passed += 1
    235 
    236     if total_test_passed:
    237         print filename + ": " + Colors.GREEN + "OK" + Colors.ENDC
    238 
    239     f.close()
    240     return tests, passed
    241 
    242 
    243 def show_missing():
    244     '''
    245     Show the list of missing test files
    246     '''
    247     file_list = os.listdir(EXTENSIONS_PATH)
    248     testfiles = [i for i in file_list if i.endswith('.t')]
    249     libfiles = [i for i in file_list
    250                 if i.startswith('lib') and i.endswith('.c')]
    251 
    252     def test_name(x):
    253         return x[0:-2] + '.t'
    254     missing = [test_name(i) for i in libfiles
    255                if not test_name(i) in testfiles]
    256 
    257     print '\n'.join(missing)
    258 
    259 
    260 #
    261 # main
    262 #
    263 def main():
    264     parser = argparse.ArgumentParser(description='Run iptables tests')
    265     parser.add_argument('filename', nargs='?',
    266                         metavar='path/to/file.t',
    267                         help='Run only this test')
    268     parser.add_argument('-m', '--missing', action='store_true',
    269                         help='Check for missing tests')
    270     args = parser.parse_args()
    271 
    272     #
    273     # show list of missing test files
    274     #
    275     if args.missing:
    276         show_missing()
    277         return
    278 
    279     if os.getuid() != 0:
    280         print "You need to be root to run this, sorry"
    281         return
    282 
    283     test_files = 0
    284     tests = 0
    285     passed = 0
    286 
    287     # setup global var log file
    288     global log_file
    289     try:
    290         log_file = open(LOGFILE, 'w')
    291     except IOError:
    292         print "Couldn't open log file %s" % LOGFILE
    293         return
    294 
    295     file_list = [os.path.join(EXTENSIONS_PATH, i)
    296                  for i in os.listdir(EXTENSIONS_PATH)]
    297     if args.filename:
    298         file_list = [args.filename]
    299     for filename in file_list:
    300         file_tests, file_passed = run_test_file(filename)
    301         if file_tests:
    302             tests += file_tests
    303             passed += file_passed
    304             test_files += 1
    305 
    306     print ("%d test files, %d unit tests, %d passed" %
    307            (test_files, tests, passed))
    308 
    309 
    310 if __name__ == '__main__':
    311     main()
    312