Home | History | Annotate | Download | only in tools
      1 #!/usr/bin/env python
      2 #
      3 # This script generates a BPF program with structure inspired by trace.py. The
      4 # generated program operates on PID-indexed stacks. Generally speaking,
      5 # bookkeeping is done at every intermediate function kprobe/kretprobe to enforce
      6 # the goal of "fail iff this call chain and these predicates".
      7 #
# Top level functions (the ones at the end of the call chain) are responsible
# for creating the pid_struct and deleting it from the map in kprobe and
# kretprobe respectively.
     11 #
# Intermediate functions (between should_fail_whatever and the top level
# functions) are responsible for updating the stack to indicate "I have been
# called and one of my predicate(s) passed" in their entry probes. In their exit
# probes, they do the opposite, popping their stack to maintain correctness.
# This implementation aims to ensure correctness in edge cases like recursive
# calls, so there's some additional information stored in pid_struct for that.
     18 #
# At the bottom level function (should_fail_whatever), we do a simple check to
# ensure all necessary calls/predicates have passed before error injection.
     21 #
     22 # Note: presently there are a few hacks to get around various rewriter/verifier
     23 # issues.
     24 #
     25 # Note: this tool requires:
     26 # - CONFIG_BPF_KPROBE_OVERRIDE
     27 #
     28 # USAGE: inject [-h] [-I header] [-P probability] [-v] mode spec
     29 #
     30 # Copyright (c) 2018 Facebook, Inc.
     31 # Licensed under the Apache License, Version 2.0 (the "License")
     32 #
     33 # 16-Mar-2018   Howard McLauchlan   Created this.
     34 
     35 import argparse
     36 import re
     37 from bcc import BPF
     38 
     39 
     40 class Probe:
     41     errno_mapping = {
     42         "kmalloc": "-ENOMEM",
     43         "bio": "-EIO",
     44     }
     45 
     46     @classmethod
     47     def configure(cls, mode, probability):
     48         cls.mode = mode
     49         cls.probability = probability
     50 
     51     def __init__(self, func, preds, length, entry):
     52         # length of call chain
     53         self.length = length
     54         self.func = func
     55         self.preds = preds
     56         self.is_entry = entry
     57 
     58     def _bail(self, err):
     59         raise ValueError("error in probe '%s': %s" %
     60                 (self.spec, err))
     61 
     62     def _get_err(self):
     63         return Probe.errno_mapping[Probe.mode]
     64 
     65     def _get_if_top(self):
     66         # ordering guarantees that if this function is top, the last tup is top
     67         chk = self.preds[0][1] == 0
     68         if not chk:
     69             return ""
     70 
     71         if Probe.probability == 1:
     72             early_pred = "false"
     73         else:
     74             early_pred = "bpf_get_prandom_u32() > %s" % str(int((1<<32)*Probe.probability))
     75         # init the map
     76         # dont do an early exit here so the singular case works automatically
     77         # have an early exit for probability option
     78         enter = """
     79         /*
     80          * Early exit for probability case
     81          */
     82         if (%s)
     83                return 0;
     84         /*
     85          * Top level function init map
     86          */
     87         struct pid_struct p_struct = {0, 0};
     88         m.insert(&pid, &p_struct);
     89         """ % early_pred
     90 
     91         # kill the entry
     92         exit = """
     93         /*
     94          * Top level function clean up map
     95          */
     96         m.delete(&pid);
     97         """
     98 
     99         return enter if self.is_entry else exit
    100 
    101     def _get_heading(self):
    102 
    103         # we need to insert identifier and ctx into self.func
    104         # gonna make a lot of formatting assumptions to make this work
    105         left = self.func.find("(")
    106         right = self.func.rfind(")")
    107 
    108         # self.event and self.func_name need to be accessible
    109         self.event = self.func[0:left]
    110         self.func_name = self.event + ("_entry" if self.is_entry else "_exit")
    111         func_sig = "struct pt_regs *ctx"
    112 
    113         # assume theres something in there, no guarantee its well formed
    114         if right > left + 1 and self.is_entry:
    115             func_sig += ", " + self.func[left + 1:right]
    116 
    117         return "int %s(%s)" % (self.func_name, func_sig)
    118 
    119     def _get_entry_logic(self):
    120         # there is at least one tup(pred, place) for this function
    121         text = """
    122 
    123         if (p->conds_met >= %s)
    124                 return 0;
    125         if (p->conds_met == %s && %s) {
    126                 p->stack[%s] = p->curr_call;
    127                 p->conds_met++;
    128         }"""
    129         text = text % (self.length, self.preds[0][1], self.preds[0][0],
    130                 self.preds[0][1])
    131 
    132         # for each additional pred
    133         for tup in self.preds[1:]:
    134             text += """
    135         else if (p->conds_met == %s && %s) {
    136                 p->stack[%s] = p->curr_call;
    137                 p->conds_met++;
    138         }
    139             """ % (tup[1], tup[0], tup[1])
    140         return text
    141 
    142     def _generate_entry(self):
    143         prog = self._get_heading() + """
    144 {
    145         u32 pid = bpf_get_current_pid_tgid();
    146         %s
    147 
    148         struct pid_struct *p = m.lookup(&pid);
    149 
    150         if (!p)
    151                 return 0;
    152 
    153         /*
    154          * preparation for predicate, if necessary
    155          */
    156          %s
    157         /*
    158          * Generate entry logic
    159          */
    160         %s
    161 
    162         p->curr_call++;
    163 
    164         return 0;
    165 }"""
    166 
    167         prog = prog % (self._get_if_top(), self.prep, self._get_entry_logic())
    168         return prog
    169 
    170     # only need to check top of stack
    171     def _get_exit_logic(self):
    172         text = """
    173         if (p->conds_met < 1 || p->conds_met >= %s)
    174                 return 0;
    175 
    176         if (p->stack[p->conds_met - 1] == p->curr_call)
    177                 p->conds_met--;
    178         """
    179         return text % str(self.length + 1)
    180 
    181     def _generate_exit(self):
    182         prog = self._get_heading() + """
    183 {
    184         u32 pid = bpf_get_current_pid_tgid();
    185 
    186         struct pid_struct *p = m.lookup(&pid);
    187 
    188         if (!p)
    189                 return 0;
    190 
    191         p->curr_call--;
    192 
    193         /*
    194          * Generate exit logic
    195          */
    196         %s
    197         %s
    198         return 0;
    199 }"""
    200 
    201         prog = prog % (self._get_exit_logic(), self._get_if_top())
    202 
    203         return prog
    204 
    205     # Special case for should_fail_whatever
    206     def _generate_bottom(self):
    207         pred = self.preds[0][0]
    208         text = self._get_heading() + """
    209 {
    210         /*
    211          * preparation for predicate, if necessary
    212          */
    213          %s
    214         /*
    215          * If this is the only call in the chain and predicate passes
    216          */
    217         if (%s == 1 && %s) {
    218                 bpf_override_return(ctx, %s);
    219                 return 0;
    220         }
    221         u32 pid = bpf_get_current_pid_tgid();
    222 
    223         struct pid_struct *p = m.lookup(&pid);
    224 
    225         if (!p)
    226                 return 0;
    227 
    228         /*
    229          * If all conds have been met and predicate passes
    230          */
    231         if (p->conds_met == %s && %s)
    232                 bpf_override_return(ctx, %s);
    233         return 0;
    234 }"""
    235         return text % (self.prep, self.length, pred, self._get_err(),
    236                     self.length - 1, pred, self._get_err())
    237 
    238     # presently parses and replaces STRCMP
    239     # STRCMP exists because string comparison is inconvenient and somewhat buggy
    240     # https://github.com/iovisor/bcc/issues/1617
    241     def _prepare_pred(self):
    242         self.prep = ""
    243         for i in range(len(self.preds)):
    244             new_pred = ""
    245             pred = self.preds[i][0]
    246             place = self.preds[i][1]
    247             start, ind = 0, 0
    248             while start < len(pred):
    249                 ind = pred.find("STRCMP(", start)
    250                 if ind == -1:
    251                     break
    252                 new_pred += pred[start:ind]
    253                 # 7 is len("STRCMP(")
    254                 start = pred.find(")", start + 7) + 1
    255 
    256                 # then ind ... start is STRCMP(...)
    257                 ptr, literal = pred[ind + 7:start - 1].split(",")
    258                 literal = literal.strip()
    259 
    260                 # x->y->z, some string literal
    261                 # we make unique id with place_ind
    262                 uuid = "%s_%s" % (place, ind)
    263                 unique_bool = "is_true_%s" % uuid
    264                 self.prep += """
    265         char *str_%s = %s;
    266         bool %s = true;\n""" % (uuid, ptr.strip(), unique_bool)
    267 
    268                 check = "\t%s &= *(str_%s++) == '%%s';\n" % (unique_bool, uuid)
    269 
    270                 for ch in literal:
    271                     self.prep += check % ch
    272                 self.prep += check % r'\0'
    273                 new_pred += unique_bool
    274 
    275             new_pred += pred[start:]
    276             self.preds[i] = (new_pred, place)
    277 
    278     def generate_program(self):
    279         # generate code to work around various rewriter issues
    280         self._prepare_pred()
    281 
    282         # special case for bottom
    283         if self.preds[-1][1] == self.length - 1:
    284             return self._generate_bottom()
    285 
    286         return self._generate_entry() if self.is_entry else self._generate_exit()
    287 
    288     def attach(self, bpf):
    289         if self.is_entry:
    290             bpf.attach_kprobe(event=self.event,
    291                     fn_name=self.func_name)
    292         else:
    293             bpf.attach_kretprobe(event=self.event,
    294                     fn_name=self.func_name)
    295 
    296 
    297 class Tool:
    298 
    299     examples ="""
    300 EXAMPLES:
    301 # ./inject.py kmalloc -v 'SyS_mount()'
    302     Fails all calls to syscall mount
    303 # ./inject.py kmalloc -v '(true) => SyS_mount()(true)'
    304     Explicit rewriting of above
    305 # ./inject.py kmalloc -v 'mount_subtree() => btrfs_mount()'
    306     Fails btrfs mounts only
    307 # ./inject.py kmalloc -v 'd_alloc_parallel(struct dentry *parent, const struct \\
    308     qstr *name)(STRCMP(name->name, 'bananas'))'
    309     Fails dentry allocations of files named 'bananas'
    310 # ./inject.py kmalloc -v -P 0.01 'SyS_mount()'
    311     Fails calls to syscall mount with 1% probability
    312     """
    313     # add cases as necessary
    314     error_injection_mapping = {
    315         "kmalloc": "should_failslab(struct kmem_cache *s, gfp_t gfpflags)",
    316         "bio": "should_fail_bio(struct bio *bio)",
    317     }
    318 
    319     def __init__(self):
    320         parser = argparse.ArgumentParser(description="Fail specified kernel" +
    321                 " functionality when call chain and predicates are met",
    322                 formatter_class=argparse.RawDescriptionHelpFormatter,
    323                 epilog=Tool.examples)
    324         parser.add_argument(dest="mode", choices=['kmalloc','bio'],
    325                 help="indicate which base kernel function to fail")
    326         parser.add_argument(metavar="spec", dest="spec",
    327                 help="specify call chain")
    328         parser.add_argument("-I", "--include", action="append",
    329                 metavar="header",
    330                 help="additional header files to include in the BPF program")
    331         parser.add_argument("-P", "--probability", default=1,
    332                 metavar="probability", type=float,
    333                 help="probability that this call chain will fail")
    334         parser.add_argument("-v", "--verbose", action="store_true",
    335                 help="print BPF program")
    336         self.args = parser.parse_args()
    337 
    338         self.program = ""
    339         self.spec = self.args.spec
    340         self.map = {}
    341         self.probes = []
    342         self.key = Tool.error_injection_mapping[self.args.mode]
    343 
    344     # create_probes and associated stuff
    345     def _create_probes(self):
    346         self._parse_spec()
    347         Probe.configure(self.args.mode, self.args.probability)
    348         # self, func, preds, total, entry
    349 
    350         # create all the pair probes
    351         for fx, preds in self.map.items():
    352 
    353             # do the enter
    354             self.probes.append(Probe(fx, preds, self.length, True))
    355 
    356             if self.key == fx:
    357                 continue
    358 
    359             # do the exit
    360             self.probes.append(Probe(fx, preds, self.length, False))
    361 
    362     def _parse_frames(self):
    363         # sentinel
    364         data = self.spec + '\0'
    365         start, count = 0, 0
    366 
    367         frames = []
    368         cur_frame = []
    369         i = 0
    370         last_frame_added = 0
    371 
    372         while i < len(data):
    373             # improper input
    374             if count < 0:
    375                 raise Exception("Check your parentheses")
    376             c = data[i]
    377             count += c == '('
    378             count -= c == ')'
    379             if not count:
    380                 if c == '\0' or (c == '=' and data[i + 1] == '>'):
    381                     # This block is closing a chunk. This means cur_frame must
    382                     # have something in it.
    383                     if not cur_frame:
    384                         raise Exception("Cannot parse spec, missing parens")
    385                     if len(cur_frame) == 2:
    386                         frame = tuple(cur_frame)
    387                     elif cur_frame[0][0] == '(':
    388                         frame = self.key, cur_frame[0]
    389                     else:
    390                         frame = cur_frame[0], '(true)'
    391                     frames.append(frame)
    392                     del cur_frame[:]
    393                     i += 1
    394                     start = i + 1
    395                 elif c == ')':
    396                     cur_frame.append(data[start:i + 1].strip())
    397                     start = i + 1
    398                     last_frame_added = start
    399             i += 1
    400 
    401         # We only permit spaces after the last frame
    402         if self.spec[last_frame_added:].strip():
    403             raise Exception("Invalid characters found after last frame");
    404         # improper input
    405         if count:
    406             raise Exception("Check your parentheses")
    407         return frames
    408 
    409     def _parse_spec(self):
    410         frames = self._parse_frames()
    411         frames.reverse()
    412 
    413         absolute_order = 0
    414         for f in frames:
    415             # default case
    416             func, pred = f[0], f[1]
    417 
    418             if not self._validate_predicate(pred):
    419                 raise Exception("Invalid predicate")
    420             if not self._validate_identifier(func):
    421                 raise Exception("Invalid function identifier")
    422             tup = (pred, absolute_order)
    423 
    424             if func not in self.map:
    425                 self.map[func] = [tup]
    426             else:
    427                 self.map[func].append(tup)
    428 
    429             absolute_order += 1
    430 
    431         if self.key not in self.map:
    432             self.map[self.key] = [('(true)', absolute_order)]
    433             absolute_order += 1
    434 
    435         self.length = absolute_order
    436 
    437     def _validate_identifier(self, func):
    438         # We've already established paren balancing. We will only look for
    439         # identifier validity here.
    440         paren_index = func.find("(")
    441         potential_id = func[:paren_index]
    442         pattern = '[_a-zA-z][_a-zA-Z0-9]*$'
    443         if re.match(pattern, potential_id):
    444             return True
    445         return False
    446 
    447     def _validate_predicate(self, pred):
    448 
    449         if len(pred) > 0 and pred[0] == "(":
    450             open = 1
    451             for i in range(1, len(pred)):
    452                 if pred[i] == "(":
    453                     open += 1
    454                 elif pred[i] == ")":
    455                     open -= 1
    456             if open != 0:
    457                 # not well formed, break
    458                 return False
    459 
    460         return True
    461 
    462     def _def_pid_struct(self):
    463         text = """
    464 struct pid_struct {
    465     u64 curr_call; /* book keeping to handle recursion */
    466     u64 conds_met; /* stack pointer */
    467     u64 stack[%s];
    468 };
    469 """ % self.length
    470         return text
    471 
    472     def _attach_probes(self):
    473         self.bpf = BPF(text=self.program)
    474         for p in self.probes:
    475             p.attach(self.bpf)
    476 
    477     def _generate_program(self):
    478         # leave out auto includes for now
    479         self.program += '#include <linux/mm.h>\n'
    480         for include in (self.args.include or []):
    481             self.program += "#include <%s>\n" % include
    482 
    483         self.program += self._def_pid_struct()
    484         self.program += "BPF_HASH(m, u32, struct pid_struct);\n"
    485         for p in self.probes:
    486             self.program += p.generate_program() + "\n"
    487 
    488         if self.args.verbose:
    489             print(self.program)
    490 
    491     def _main_loop(self):
    492         while True:
    493             self.bpf.perf_buffer_poll()
    494 
    495     def run(self):
    496         self._create_probes()
    497         self._generate_program()
    498         self._attach_probes()
    499         self._main_loop()
    500 
    501 
    502 if __name__ == "__main__":
    503     Tool().run()
    504