Home | History | Annotate | Download | only in tools
      1 #!/usr/bin/python
      2 #
      3 # Copyright (C) 2016 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 #      http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 #
     17 # This script will take any number of trace files generated by strace(1)
     18 # and output a system call filtering policy suitable for use with Minijail.
     19 
     20 from collections import namedtuple
     21 import sys
     22 
     23 NOTICE = """# Copyright (C) 2016 The Android Open Source Project
     24 #
     25 # Licensed under the Apache License, Version 2.0 (the "License");
     26 # you may not use this file except in compliance with the License.
     27 # You may obtain a copy of the License at
     28 #
     29 #      http://www.apache.org/licenses/LICENSE-2.0
     30 #
     31 # Unless required by applicable law or agreed to in writing, software
     32 # distributed under the License is distributed on an "AS IS" BASIS,
     33 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     34 # See the License for the specific language governing permissions and
     35 # limitations under the License.
     36 """
     37 
     38 ALLOW = "%s: 1"
     39 
     40 SOCKETCALLS = ["accept", "bind", "connect", "getpeername", "getsockname",
     41                "getsockopt", "listen", "recv", "recvfrom", "recvmsg", "send",
     42                "sendmsg", "sendto", "setsockopt", "shutdown", "socket",
     43                "socketpair"]
     44 
     45 # /* Protocol families.  */
     46 # #define PF_UNSPEC     0       /* Unspecified.  */
     47 # #define PF_LOCAL      1       /* Local to host (pipes and file-domain).  */
     48 # #define PF_UNIX       PF_LOCAL /* POSIX name for PF_LOCAL.  */
     49 # #define PF_FILE       PF_LOCAL /* Another non-standard name for PF_LOCAL.  */
     50 # #define PF_INET       2       /* IP protocol family.  */
     51 # #define PF_AX25       3       /* Amateur Radio AX.25.  */
     52 # #define PF_IPX        4       /* Novell Internet Protocol.  */
     53 # #define PF_APPLETALK  5       /* Appletalk DDP.  */
     54 # #define PF_NETROM     6       /* Amateur radio NetROM.  */
     55 # #define PF_BRIDGE     7       /* Multiprotocol bridge.  */
     56 # #define PF_ATMPVC     8       /* ATM PVCs.  */
     57 # #define PF_X25        9       /* Reserved for X.25 project.  */
     58 # #define PF_INET6     10      /* IP version 6.  */
     59 # #define PF_ROSE      11      /* Amateur Radio X.25 PLP.  */
     60 # #define PF_DECnet    12      /* Reserved for DECnet project.  */
     61 # #define PF_NETBEUI   13      /* Reserved for 802.2LLC project.  */
     62 # #define PF_SECURITY  14      /* Security callback pseudo AF.  */
     63 # #define PF_KEY       15      /* PF_KEY key management API.  */
     64 # #define PF_NETLINK   16
     65 
     66 ArgInspectionEntry = namedtuple("ArgInspectionEntry", "arg_index value_set")
     67 
     68 
     69 def usage(argv):
     70     print "%s <trace file> [trace files...]" % argv[0]
     71 
     72 
     73 def main(traces):
     74     syscalls = {}
     75 
     76     uses_socketcall = False
     77 
     78     basic_set = ["restart_syscall", "exit", "exit_group",
     79                  "rt_sigreturn"]
     80     frequent_set = []
     81 
     82     syscall_sets = {}
     83     syscall_set_list = [["sigreturn", "rt_sigreturn"],
     84                         ["sigaction", "rt_sigaction"],
     85                         ["sigprocmask", "rt_sigprocmask"],
     86                         ["open", "openat"],
     87                         ["mmap", "mremap"],
     88                         ["mmap2", "mremap"]]
     89 
     90     arg_inspection = {
     91         "socket": ArgInspectionEntry(0, set([])),   # int domain
     92         "ioctl": ArgInspectionEntry(1, set([])),    # int request
     93         "prctl": ArgInspectionEntry(0, set([]))     # int option
     94     }
     95 
     96     for syscall_list in syscall_set_list:
     97         for syscall in syscall_list:
     98             other_syscalls = syscall_list[:]
     99             other_syscalls.remove(syscall)
    100             syscall_sets[syscall] = other_syscalls
    101 
    102     for trace_filename in traces:
    103         if "i386" in trace_filename or ("x86" in trace_filename and
    104                                         "64" not in trace_filename):
    105             uses_socketcall = True
    106 
    107         trace_file = open(trace_filename)
    108         for line in trace_file:
    109             if "---" in line or '(' not in line:
    110                 continue
    111 
    112             syscall, args = line.strip().split('(', 1)
    113             if uses_socketcall and syscall in SOCKETCALLS:
    114                 syscall = "socketcall"
    115 
    116             if syscall in syscalls:
    117                 syscalls[syscall] += 1
    118             else:
    119                 syscalls[syscall] = 1
    120 
    121             args = [arg.strip() for arg in args.split(')', 1)[0].split(',')]
    122 
    123             if syscall in arg_inspection:
    124                 arg_value = args[arg_inspection[syscall].arg_index]
    125                 arg_inspection[syscall].value_set.add(arg_value)
    126 
    127     sorted_syscalls = list(zip(*sorted(syscalls.iteritems(),
    128                                        key=lambda pair: pair[1],
    129                                        reverse=True))[0])
    130 
    131     print NOTICE
    132 
    133     # Add frequent syscalls first.
    134     for frequent_syscall in frequent_set:
    135         sorted_syscalls.remove(frequent_syscall)
    136 
    137     all_syscalls = frequent_set + sorted_syscalls
    138 
    139     # Add the basic set once the frequency drops below 2.
    140     below_ten_index = -1
    141     for sorted_syscall in sorted_syscalls:
    142         if syscalls[sorted_syscall] < 2:
    143             below_ten_index = all_syscalls.index(sorted_syscall)
    144             break
    145 
    146     first_half = all_syscalls[:below_ten_index]
    147     for basic_syscall in basic_set:
    148         if basic_syscall not in all_syscalls:
    149             first_half.append(basic_syscall)
    150 
    151     all_syscalls = first_half + all_syscalls[below_ten_index:]
    152 
    153     for syscall in all_syscalls:
    154         if syscall in arg_inspection:
    155             arg_index = arg_inspection[syscall].arg_index
    156             arg_values = arg_inspection[syscall].value_set
    157             arg_filter = " || ".join(["arg%d == %s" % (arg_index, arg_value)
    158                                       for arg_value in arg_values])
    159             print syscall + ": " + arg_filter
    160         else:
    161             print ALLOW % syscall
    162 
    163 
    164 if __name__ == "__main__":
    165     if len(sys.argv) < 2:
    166         usage(sys.argv)
    167         sys.exit(1)
    168 
    169     main(sys.argv[1:])
    170