Home | History | Annotate | Download | only in py
      1 # Copyright (c) 2013 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 
      5 import dpkt
      6 import os
      7 import select
      8 import struct
      9 import sys
     10 import threading
     11 import time
     12 import traceback
     13 
     14 
     15 class SimulatorError(Exception):
     16     "A Simulator generic error."
     17 
     18 
     19 class NullContext(object):
     20     """A context manager without any functionality."""
     21     def __enter__(self):
     22         return self
     23 
     24 
     25     def __exit__(self, exc_type, exc_val, exc_tb):
     26         return False # raises the exception if passed.
     27 
     28 
     29 class Simulator(object):
     30     """A TUN/TAP network interface simulator class.
     31 
     32     This class allows several implementations of different fake hosts to
     33     coexists on the same TUN/TAP interface. It will dispatch the same packet
     34     to each one of the registered hosts, providing some basic filtering
     35     to simplify these implementations.
     36     """
     37 
     38     def __init__(self, iface):
     39         """Initialize the instance.
     40 
     41         @param tuntap.TunTap iface: the interface over which this interface
     42         runs. Should not be shared with other modules.
     43         """
     44         self._iface = iface
     45         self._rules = []
     46         # _events holds a lists of events that need to be fired for each
     47         # timestamp stored on the key. The event list is a list of callback
     48         # functions that will be called if the simulation reaches that
     49         # timestamp. This is used to fire time-based events.
     50         self._events = {}
     51         self._write_queue = []
     52         # A pipe used to wake up the run() method from a diffent thread calling
     53         # stop(). See the stop() method for details.
     54         self._pipe_rd, self._pipe_wr = os.pipe()
     55         self._running = False
     56         # Lock object used for _events if multithreading is required.
     57         self._lock = NullContext()
     58 
     59 
     60     def __del__(self):
     61         os.close(self._pipe_rd)
     62         os.close(self._pipe_wr)
     63 
     64 
     65     def add_match(self, rule, callback):
     66         """Add a new match rule to the outbound traffic.
     67 
     68         This function adds a new rule that will be matched against each packet
     69         that the host sends through the interface and will call a callback if
     70         it matches. The rule can be specified in the following ways:
     71           * A python function that takes a packet as a single argument and
     72             returns True when the packet matches.
     73           * A dictionary of key=value pairs that all of them need to be matched.
     74             A pair matches when the packet has the provided chain of attributes
     75             and its value is equal to the provided value. For example, this will
     76             match any DNS traffic sent to the host 192.168.0.1:
     77             {"ip.dst": socket.inet_aton("192.168.0.1"),
     78              "ip.upd.dport": 53}
     79 
     80         @param rule: The rule description.
     81         @param callback: A callback function that receives the dpkt packet as
     82         the only argument.
     83         """
     84         if not callable(callback):
     85             raise SimulatorError("|callback| must be a callable object.")
     86 
     87         if callable(rule):
     88             self._rules.append((rule, callback))
     89         if isinstance(rule, dict):
     90             rule = dict(rule) # Makes a copy of the dict, but not the contents.
     91             self._rules.append((lambda p: self._dict_rule(rule, p), callback))
     92         else:
     93             raise SimulatorError("Unknown rule format: %r" % rule)
     94 
     95 
     96     def add_timeout(self, timeout, callback):
     97         """Add a new callback function to be called after a timeout.
     98 
     99         This method schedules the given |callback| to be called after |timeout|
    100         seconds. The callback will be called at most once while the simulator
    101         is running (see the run() method). To have a repetitive event call again
    102         add_timeout() from the callback.
    103 
    104         @param timeout: The rule description.
    105         @param callback: A callback function that doesn't receive any argument.
    106         """
    107         if not callable(callback):
    108             raise SimulatorError("|callback| must be a callable object.")
    109         timestamp = time.time() + timeout
    110         with self._lock:
    111             if timestamp not in self._events:
    112                 self._events[timestamp] = [callback]
    113             else:
    114                 self._events[timestamp].append(callback)
    115 
    116 
    117     def remove_timeout(self, callback):
    118         """Removes the every scheduled timeout call to the passed callback.
    119 
    120         When a callable object is passed to add_timeout() it is scheduled to be
    121         called once the timeout is reached. This method removes all the
    122         scheduled calls to that object.
    123 
    124         @param callback: The callable object passed to add_timeout().
    125         @return: Wether the callback was found and removed at least once.
    126         """
    127         removed = False
    128         for _ts, ev_list in self._events.iteritems():
    129             try:
    130                 while True:
    131                     ev_list.remove(callback)
    132                     removed = True
    133             except ValueError:
    134                 pass
    135         return removed
    136 
    137 
    138     def _dict_rule(self, rules, pkt):
    139         """Returns wether a given packet matches a set of rules.
    140 
    141         The maching rules passed in |rules| need to be a dict() as described
    142         on the add_match() method. The packet |pkt| is any dpkt packet.
    143         """
    144         for key, value in rules.iteritems():
    145             p = pkt
    146             for member in key.split('.'):
    147                 if not hasattr(p, member):
    148                     return False
    149                 p = getattr(p, member)
    150             if p != value:
    151                 return False
    152         return True
    153 
    154 
    155     def write(self, pkt):
    156         """Writes a packet to the network interface.
    157 
    158         @param pkt: The dpkt.Packet to be received on the network interface.
    159         """
    160         # Converts the dpkt packet to: flags, proto, buffer.
    161         self._write_queue.append(struct.pack("!HH", 0, pkt.type) + str(pkt))
    162 
    163 
    164     def run(self, timeout=None, until=None):
    165         """Runs the Simulator.
    166 
    167         This method blocks the caller thread until the timeout is reached (if
    168         a timeout is passed), until stop() is called or until the function
    169         passed in until returns a True value (if a function is passed);
    170         whichever occurs first. stop() can be called from any other thread or
    171         from a callback called from this thread.
    172 
    173         @param timeout: The timeout in seconds. Can be a float value, or None
    174         for no timeout.
    175         @param until: A callable object called during the loop returning True
    176         when the loop should stop.
    177         """
    178         if not self._iface.is_up():
    179             raise SimulatorError("Interface is down.")
    180 
    181         stop_callback = None
    182         if timeout != None:
    183             # We use a newly created callable object to avoid remove another
    184             # scheduled call to self.stop.
    185             stop_callback = lambda: self.stop()
    186             self.add_timeout(timeout, stop_callback)
    187 
    188         self._running = True
    189         iface_fd = self._iface.fileno()
    190         # Check the until function.
    191         while not (until and until()):
    192             # The main purpose of this loop is to wait (block) until the next
    193             # event is required to be fired. There are four kinds of events:
    194             #  * a packet is received.
    195             #  * a packet waiting to be sent can now be sent.
    196             #  * a time-based event needs to be fired.
    197             #  * the simulator was stopped from a different thread.
    198             # To achieve this we use select.select() to wait simultaneously on
    199             # all those event sources.
    200 
    201             # Fires all the time-based events that need to be fired and computes
    202             # the timeout for the next event if there's one.
    203             timeout = None
    204             cur_time = time.time()
    205             with self._lock:
    206                 if self._events:
    207                     # Check events that should be fired.
    208                     while self._events and min(self._events) <= cur_time:
    209                         key = min(self._events)
    210                         lst = self._events[key]
    211                         del self._events[key]
    212                         for callback in lst:
    213                             callback()
    214                         cur_time = time.time()
    215                 # Check if there is an event to attend. Here we know that
    216                 # min(self._events) > cur_time because the previous while
    217                 # finished.
    218                 if self._events:
    219                     timeout = min(self._events) - cur_time # in seconds
    220 
    221             # Pool the until() function at least once a second.
    222             if timeout is None or timeout > 1.0:
    223                 timeout = 1.0
    224 
    225             # Compute the list of file descriptors that select.select() needs to
    226             # monitor to attend the required events. select() will return when
    227             # any of the following occurs:
    228             #  * rlist: is possible to read from the interface or another
    229             #           thread want's to wake up the simulator loop.
    230             #  * wlist: is possible to write to network, if there's a packet
    231             #           pending.
    232             #  * xlist: an error on the network fd occured. Likely the TAP
    233             #           interface was closed.
    234             #  * timeout: The previously computed timeout was reached.
    235             rlist = iface_fd, self._pipe_rd
    236             wlist = tuple()
    237             if self._write_queue:
    238                 wlist = iface_fd,
    239             xlist = iface_fd,
    240 
    241             rlist, wlist, xlist = select.select(rlist, wlist, xlist, timeout)
    242 
    243             if self._pipe_rd in rlist:
    244                 msg = os.read(self._pipe_rd, 1)
    245                 # stop() breaks the loop sending a '*'.
    246                 if '*' in msg:
    247                     break
    248                 # Other messages are ignored.
    249 
    250             if xlist:
    251                 break
    252 
    253             if iface_fd in wlist:
    254                 self._iface.write(self._write_queue.pop(0))
    255                 # Attempt to send all the scheduled packets before reading more
    256                 continue
    257 
    258             # Process the given packet:
    259             if iface_fd in rlist:
    260                 raw = self._iface.read()
    261                 flag, proto = struct.unpack("!HH", raw[:4])
    262                 pkt = dpkt.ethernet.Ethernet(raw[4:])
    263                 for rule, callback in self._rules:
    264                     if rule(pkt):
    265                         # Parse again the packet to allow callbacks to modify
    266                         # it.
    267                         callback(dpkt.ethernet.Ethernet(raw[4:]))
    268 
    269         if stop_callback:
    270             self.remove_timeout(stop_callback)
    271         self._running = False
    272 
    273 
    274     def stop(self):
    275         """Stops the run() method if it is running."""
    276         os.write(self._pipe_wr, '*')
    277 
    278 
    279 class SimulatorThread(threading.Thread, Simulator):
    280     """A threaded version of the Simulator.
    281 
    282     This class exposses a similar interface as the Simulator class with the
    283     difference that it runs on its own thread. This exposes an extra method
    284     start() that should be called instead of Simulator.run(). start() will make
    285     the process run continuosly until stop() is called, after which the
    286     simulator can't be restarted.
    287 
    288     The methods used to add new matches can be called from any thread *before*
    289     the method start() is caller. After that point, only the callbacks, running
    290     from this thread, are allowed to create new matches and timeouts.
    291 
    292     Example:
    293         simu = SimulatorThread(tap_interface)
    294         simu.add_match({"ip.tcp.dport": 80}, some_callback)
    295         simu.start()
    296         time.sleep(100)
    297         simu.stop()
    298         simu.join() # Optional
    299     """
    300 
    301     def __init__(self, iface, timeout=None):
    302         threading.Thread.__init__(self)
    303         Simulator.__init__(self, iface)
    304         self._timeout = timeout
    305         # We allow the same thread to acquire the lock more than once. This is
    306         # useful if a callback want's to add itself.
    307         self._lock = threading.RLock()
    308         self.error = None
    309 
    310 
    311     def run_on_simulator(self, callback):
    312         """Runs the given callback on the SimulatorThread thread.
    313 
    314         Before calling start() on the SimulatorThread, all the calls seting up
    315         the simulator are allowed, but once the thread is running, concurrency
    316         problems should be considered. This method runs the provided callback
    317         on the simulator.
    318 
    319         @param callback: A callback function without arguments.
    320         """
    321         self.add_timeout(0, callback)
    322         # Wake up the main loop with an ignored message.
    323         os.write(self._pipe_wr, ' ')
    324 
    325 
    326     def wait_for_condition(self, condition, timeout=None):
    327         """Blocks until the condition is met or timeout is exceeded.
    328 
    329         This method should be called from a different thread while the simulator
    330         thread is running as it blocks the calling thread's execution until a
    331         condition is met. The condition function is evaluated in a callback
    332         running on the simulator thread and thus can safely access objects owned
    333         by the simulator.
    334 
    335         @param condition: A function called on the simulator thread that returns
    336         a value indicating if the condition is met.
    337         @param timeout: The timeout in seconds. None for no timeout.
    338         @return: The value returned by condition the last time it was called.
    339         This means that in the event of a timeout, this function will return a
    340         value that evaluates to False since the condition wasn't met the last
    341         time it was checked.
    342         """
    343         # Lock and Condition used to wait until the passed condition is met.
    344         lock_cond = threading.Lock()
    345         cond_var = threading.Condition(lock_cond)
    346         # We use a mutable object like the [] to pass the reference by value
    347         # to the simulator's callback and let it modify the contents.
    348         ret = [None]
    349 
    350         # Create the actual callback that will be running on the simulator
    351         # thread and pass a reference to it to keep including it
    352         callback = lambda: self._condition_poller(
    353                 callback, ret, cond_var, condition)
    354 
    355         # Let the simulator keep calling our function, it will keep calling
    356         # itself until the condition is met (or we remove it).
    357         self.run_on_simulator(callback)
    358 
    359         # Condition variable waiting loop.
    360         cur_time = time.time()
    361         start_time = cur_time
    362         with cond_var:
    363             while not ret[0]:
    364                 if timeout is None:
    365                     cond_var.wait()
    366                 else:
    367                     cur_timeout = timeout - (cur_time - start_time)
    368                     if cur_timeout < 0:
    369                         break
    370                     cond_var.wait(cur_timeout)
    371                     cur_time = time.time()
    372         self.remove_timeout(callback)
    373 
    374         return ret[0]
    375 
    376 
    377     def _condition_poller(self, callback, ref_value, cond_var, func):
    378         """Callback function used to poll for a condition.
    379 
    380         This method keeps scheduling itself in the simulator until the passed
    381         condition evaluates to a True value. This effectivelly implements a
    382         polling mechanism. See wait_for_condition() for details.
    383         """
    384         with cond_var:
    385             ref_value[0] = func()
    386             if ref_value[0]:
    387                 cond_var.notify()
    388             else:
    389                 self.add_timeout(1., callback)
    390 
    391 
    392     def run(self):
    393         """Runs the simulation on the thread, called by start().
    394 
    395         This method wraps the Simulator.run() to pass the timeout value passed
    396         during construction.
    397         """
    398         try:
    399             Simulator.run(self, self._timeout)
    400         except Exception, e:
    401             self.error = e
    402             exc_type, exc_value, exc_traceback = sys.exc_info()
    403             self.traceback = ''.join(traceback.format_exception(
    404                     exc_type, exc_value, exc_traceback))
    405