Home | History | Annotate | Download | only in scapy
      1 ## This file is part of Scapy
## See http://www.secdev.org/projects/scapy for more information
      3 ## Copyright (C) Philippe Biondi <phil (at] secdev.org>
      4 ## Copyright (C) Gabriel Potter <gabriel (at] potter.fr>
      5 ## This program is published under a GPLv2 license
      6 
      7 """
      8 Automata with states, transitions and actions.
      9 """
     10 
     11 from __future__ import absolute_import
     12 import types,itertools,time,os,sys,socket,traceback
     13 from select import select
     14 from collections import deque
     15 import threading
     16 from scapy.config import conf
     17 from scapy.utils import do_graph
     18 from scapy.error import log_interactive
     19 from scapy.plist import PacketList
     20 from scapy.data import MTU
     21 from scapy.supersocket import SuperSocket
     22 from scapy.consts import WINDOWS
     23 from scapy.compat import *
     24 import scapy.modules.six as six
     25 
     26 try:
     27     import thread
     28 except ImportError:
     29     THREAD_EXCEPTION = RuntimeError
     30 else:
     31     THREAD_EXCEPTION = thread.error
     32 
     33 if WINDOWS:
     34     from scapy.error import Scapy_Exception
     35     recv_error = Scapy_Exception
     36 else:
     37     recv_error = ()
     38 
""" On Windows, select.select is not available for custom objects. Here is Scapy's implementation that re-creates this functionality.
# Passive way: using no-resources locks
     41                +---------+             +---------------+      +-------------------------+
     42                |  Start  +------------->Select_objects +----->+Linux: call select.select|
     43                +---------+             |(select.select)|      +-------------------------+
     44                                        +-------+-------+
     45                                                |
     46                                           +----v----+               +--------+
     47                                           | Windows |               |Time Out+----------------------------------+
     48                                           +----+----+               +----+---+                                  |
     49                                                |                         ^                                      |
     50       Event                                    |                         |                                      |
     51         +                                      |                         |                                      |
     52         |                              +-------v-------+                 |                                      |
     53         |                       +------+Selectable Sel.+-----+-----------------+-----------+                    |
     54         |                       |      +-------+-------+     |           |     |           v              +-----v-----+
     55 +-------v----------+            |              |             |           |     |        Passive lock<-----+release_all<------+
     56 |Data added to list|       +----v-----+  +-----v-----+  +----v-----+     v     v            +             +-----------+      |
     57 +--------+---------+       |Selectable|  |Selectable |  |Selectable|   ............         |                                |
     58          |                 +----+-----+  +-----------+  +----------+                        |                                |
     59          |                      v                                                           |                                |
     60          v                 +----+------+   +------------------+               +-------------v-------------------+            |
     61    +-----+------+          |wait_return+-->+  check_recv:     |               |                                 |            |
     62    |call_release|          +----+------+   |If data is in list|               |  END state: selectable returned |        +---+--------+
     63    +-----+--------              v          +-------+----------+               |                                 |        | exit door  |
     64          |                    else                 |                          +---------------------------------+        +---+--------+
     65          |                      +                  |                                                                         |
     66          |                 +----v-------+          |                                                                         |
     67          +--------->free -->Passive lock|          |                                                                         |
     68                            +----+-------+          |                                                                         |
     69                                 |                  |                                                                         |
     70                                 |                  v                                                                         |
     71                                 +------------------Selectable-Selector-is-advertised-that-the-selectable-is-readable---------+
     72 """
     73 
     74 class SelectableObject:
     75     """DEV: to implement one of those, you need to add 2 things to your object:
     76     - add "check_recv" function
     77     - call "self.call_release" once you are ready to be read
     78 
     79     You can set the __selectable_force_select__ to True in the class, if you want to
     80     force the handler to use fileno(). This may only be useable on sockets created using
     81     the builtin socket API."""
     82     __selectable_force_select__ = False
     83     def check_recv(self):
     84         """DEV: will be called only once (at beginning) to check if the object is ready."""
     85         raise OSError("This method must be overwriten.")
     86 
     87     def _wait_non_ressources(self, callback):
     88         """This get started as a thread, and waits for the data lock to be freed then advertise itself to the SelectableSelector using the callback"""
     89         self.trigger = threading.Lock()
     90         self.was_ended = False
     91         self.trigger.acquire()
     92         self.trigger.acquire()
     93         if not self.was_ended:
     94             callback(self)
     95 
     96     def wait_return(self, callback):
     97         """Entry point of SelectableObject: register the callback"""
     98         if self.check_recv():
     99             return callback(self)
    100         _t = threading.Thread(target=self._wait_non_ressources, args=(callback,))
    101         _t.setDaemon(True)
    102         _t.start()
    103         
    104     def call_release(self, arborted=False):
    105         """DEV: Must be call when the object becomes ready to read.
    106            Relesases the lock of _wait_non_ressources"""
    107         self.was_ended = arborted
    108         try:
    109             self.trigger.release()
    110         except (THREAD_EXCEPTION, AttributeError):
    111             pass
    112 
    113 class SelectableSelector(object):
    114     """
    115     Select SelectableObject objects.
    116     
    117     inputs: objects to process
    118     remain: timeout. If 0, return [].
    119     customTypes: types of the objects that have the check_recv function.
    120     """
    121     def _release_all(self):
    122         """Releases all locks to kill all threads"""
    123         for i in self.inputs:
    124             i.call_release(True)
    125         self.available_lock.release()
    126 
    127     def _timeout_thread(self, remain):
    128         """Timeout before releasing every thing, if nothing was returned"""
    129         time.sleep(remain)
    130         if not self._ended:
    131             self._ended = True
    132             self._release_all()
    133 
    134     def _exit_door(self, _input):
    135         """This function is passed to each SelectableObject as a callback
    136         The SelectableObjects have to call it once there are ready"""
    137         self.results.append(_input)
    138         if self._ended:
    139             return
    140         self._ended = True
    141         self._release_all()
    142     
    143     def __init__(self, inputs, remain):
    144         self.results = []
    145         self.inputs = list(inputs)
    146         self.remain = remain
    147         self.available_lock = threading.Lock()
    148         self.available_lock.acquire()
    149         self._ended = False
    150 
    151     def process(self):
    152         """Entry point of SelectableSelector"""
    153         if WINDOWS:
    154             select_inputs = []
    155             for i in self.inputs:
    156                 if not isinstance(i, SelectableObject):
    157                     warning("Unknown ignored object type: %s", type(i))
    158                 elif i.__selectable_force_select__:
    159                     # Then use select.select
    160                     select_inputs.append(i)
    161                 elif not self.remain and i.check_recv():
    162                     self.results.append(i)
    163                 else:
    164                     i.wait_return(self._exit_door)
    165             if select_inputs:
    166                 # Use default select function
    167                 self.results.extend(select(select_inputs, [], [], self.remain)[0])
    168             if not self.remain:
    169                 return self.results
    170 
    171             threading.Thread(target=self._timeout_thread, args=(self.remain,)).start()
    172             if not self._ended:
    173                 self.available_lock.acquire()
    174             return self.results
    175         else:
    176             r,_,_ = select(self.inputs,[],[],self.remain)
    177             return r
    178 
    179 def select_objects(inputs, remain):
    180     """
    181     Select SelectableObject objects. Same than:
    182         select.select([inputs], [], [], remain)
    183     But also works on Windows, only on SelectableObject.
    184     
    185     inputs: objects to process
    186     remain: timeout. If 0, return [].
    187     """
    188     handler = SelectableSelector(inputs, remain)
    189     return handler.process()
    190 
    191 class ObjectPipe(SelectableObject):
    192     def __init__(self):
    193         self.rd,self.wr = os.pipe()
    194         self.queue = deque()
    195     def fileno(self):
    196         return self.rd
    197     def check_recv(self):
    198         return len(self.queue) > 0
    199     def send(self, obj):
    200         self.queue.append(obj)
    201         os.write(self.wr,b"X")
    202         self.call_release()
    203     def write(self, obj):
    204         self.send(obj)
    205     def recv(self, n=0):
    206         os.read(self.rd, 1)
    207         return self.queue.popleft()
    208     def read(self, n=0):
    209         return self.recv(n)
    210 
    211 class Message:
    212     def __init__(self, **args):
    213         self.__dict__.update(args)
    214     def __repr__(self):
    215         return "<Message %s>" % " ".join("%s=%r"%(k,v)
    216                                          for (k,v) in six.iteritems(self.__dict__)
    217                                          if not k.startswith("_"))
    218 
    219 class _instance_state:
    220     def __init__(self, instance):
    221         self.__self__ = instance.__self__
    222         self.__func__ = instance.__func__
    223         self.__self__.__class__ = instance.__self__.__class__
    224     def __getattr__(self, attr):
    225         return getattr(self.__func__, attr)
    226     def __call__(self, *args, **kargs):
    227         return self.__func__(self.__self__, *args, **kargs)
    228     def breaks(self):
    229         return self.__self__.add_breakpoints(self.__func__)
    230     def intercepts(self):
    231         return self.__self__.add_interception_points(self.__func__)
    232     def unbreaks(self):
    233         return self.__self__.remove_breakpoints(self.__func__)
    234     def unintercepts(self):
    235         return self.__self__.remove_interception_points(self.__func__)
    236         
    237 
    238 ##############
    239 ## Automata ##
    240 ##############
    241 
    242 class ATMT:
    243     STATE = "State"
    244     ACTION = "Action"
    245     CONDITION = "Condition"
    246     RECV = "Receive condition"
    247     TIMEOUT = "Timeout condition"
    248     IOEVENT = "I/O event"
    249 
    250     class NewStateRequested(Exception):
    251         def __init__(self, state_func, automaton, *args, **kargs):
    252             self.func = state_func
    253             self.state = state_func.atmt_state
    254             self.initial = state_func.atmt_initial
    255             self.error = state_func.atmt_error
    256             self.final = state_func.atmt_final
    257             Exception.__init__(self, "Request state [%s]" % self.state)
    258             self.automaton = automaton
    259             self.args = args
    260             self.kargs = kargs
    261             self.action_parameters() # init action parameters
    262         def action_parameters(self, *args, **kargs):
    263             self.action_args = args
    264             self.action_kargs = kargs
    265             return self
    266         def run(self):
    267             return self.func(self.automaton, *self.args, **self.kargs)
    268         def __repr__(self):
    269             return "NewStateRequested(%s)" % self.state
    270 
    271     @staticmethod
    272     def state(initial=0,final=0,error=0):
    273         def deco(f,initial=initial, final=final):
    274             f.atmt_type = ATMT.STATE
    275             f.atmt_state = f.__name__
    276             f.atmt_initial = initial
    277             f.atmt_final = final
    278             f.atmt_error = error
    279             def state_wrapper(self, *args, **kargs):
    280                 return ATMT.NewStateRequested(f, self, *args, **kargs)
    281 
    282             state_wrapper.__name__ = "%s_wrapper" % f.__name__
    283             state_wrapper.atmt_type = ATMT.STATE
    284             state_wrapper.atmt_state = f.__name__
    285             state_wrapper.atmt_initial = initial
    286             state_wrapper.atmt_final = final
    287             state_wrapper.atmt_error = error
    288             state_wrapper.atmt_origfunc = f
    289             return state_wrapper
    290         return deco
    291     @staticmethod
    292     def action(cond, prio=0):
    293         def deco(f,cond=cond):
    294             if not hasattr(f,"atmt_type"):
    295                 f.atmt_cond = {}
    296             f.atmt_type = ATMT.ACTION
    297             f.atmt_cond[cond.atmt_condname] = prio
    298             return f
    299         return deco
    300     @staticmethod
    301     def condition(state, prio=0):
    302         def deco(f, state=state):
    303             f.atmt_type = ATMT.CONDITION
    304             f.atmt_state = state.atmt_state
    305             f.atmt_condname = f.__name__
    306             f.atmt_prio = prio
    307             return f
    308         return deco
    309     @staticmethod
    310     def receive_condition(state, prio=0):
    311         def deco(f, state=state):
    312             f.atmt_type = ATMT.RECV
    313             f.atmt_state = state.atmt_state
    314             f.atmt_condname = f.__name__
    315             f.atmt_prio = prio
    316             return f
    317         return deco
    318     @staticmethod
    319     def ioevent(state, name, prio=0, as_supersocket=None):
    320         def deco(f, state=state):
    321             f.atmt_type = ATMT.IOEVENT
    322             f.atmt_state = state.atmt_state
    323             f.atmt_condname = f.__name__
    324             f.atmt_ioname = name
    325             f.atmt_prio = prio
    326             f.atmt_as_supersocket = as_supersocket
    327             return f
    328         return deco
    329     @staticmethod
    330     def timeout(state, timeout):
    331         def deco(f, state=state, timeout=timeout):
    332             f.atmt_type = ATMT.TIMEOUT
    333             f.atmt_state = state.atmt_state
    334             f.atmt_timeout = timeout
    335             f.atmt_condname = f.__name__
    336             return f
    337         return deco
    338 
    339 class _ATMT_Command:
    340     RUN = "RUN"
    341     NEXT = "NEXT"
    342     FREEZE = "FREEZE"
    343     STOP = "STOP"
    344     END = "END"
    345     EXCEPTION = "EXCEPTION"
    346     SINGLESTEP = "SINGLESTEP"
    347     BREAKPOINT = "BREAKPOINT"
    348     INTERCEPT = "INTERCEPT"
    349     ACCEPT = "ACCEPT"
    350     REPLACE = "REPLACE"
    351     REJECT = "REJECT"
    352 
    353 class _ATMT_supersocket(SuperSocket):
    354     def __init__(self, name, ioevent, automaton, proto, args, kargs):
    355         self.name = name
    356         self.ioevent = ioevent
    357         self.proto = proto
    358         self.spa,self.spb = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
    359         kargs["external_fd"] = {ioevent:self.spb}
    360         self.atmt = automaton(*args, **kargs)
    361         self.atmt.runbg()
    362     def fileno(self):
    363         return self.spa.fileno()
    364     def send(self, s):
    365         if not isinstance(s, bytes):
    366             s = bytes(s)
    367         return self.spa.send(s)
    368     def recv(self, n=MTU):
    369         try:
    370             r = self.spa.recv(n)
    371         except recv_error:
    372             if not WINDOWS:
    373                 raise
    374             return None
    375         if self.proto is not None:
    376             r = self.proto(r)
    377         return r
    378     def close(self):
    379         pass
    380 
    381 class _ATMT_to_supersocket:
    382     def __init__(self, name, ioevent, automaton):
    383         self.name = name
    384         self.ioevent = ioevent
    385         self.automaton = automaton
    386     def __call__(self, proto, *args, **kargs):
    387         return _ATMT_supersocket(self.name, self.ioevent, self.automaton, proto, args, kargs)
    388 
    389 class Automaton_metaclass(type):
    390     def __new__(cls, name, bases, dct):
    391         cls = super(Automaton_metaclass, cls).__new__(cls, name, bases, dct)
    392         cls.states={}
    393         cls.state = None
    394         cls.recv_conditions={}
    395         cls.conditions={}
    396         cls.ioevents={}
    397         cls.timeout={}
    398         cls.actions={}
    399         cls.initial_states=[]
    400         cls.ionames = []
    401         cls.iosupersockets = []
    402 
    403         members = {}
    404         classes = [cls]
    405         while classes:
    406             c = classes.pop(0) # order is important to avoid breaking method overloading
    407             classes += list(c.__bases__)
    408             for k,v in six.iteritems(c.__dict__):
    409                 if k not in members:
    410                     members[k] = v
    411 
    412         decorated = [v for v in six.itervalues(members)
    413                      if isinstance(v, types.FunctionType) and hasattr(v, "atmt_type")]
    414         
    415         for m in decorated:
    416             if m.atmt_type == ATMT.STATE:
    417                 s = m.atmt_state
    418                 cls.states[s] = m
    419                 cls.recv_conditions[s]=[]
    420                 cls.ioevents[s]=[]
    421                 cls.conditions[s]=[]
    422                 cls.timeout[s]=[]
    423                 if m.atmt_initial:
    424                     cls.initial_states.append(m)
    425             elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT]:
    426                 cls.actions[m.atmt_condname] = []
    427     
    428         for m in decorated:
    429             if m.atmt_type == ATMT.CONDITION:
    430                 cls.conditions[m.atmt_state].append(m)
    431             elif m.atmt_type == ATMT.RECV:
    432                 cls.recv_conditions[m.atmt_state].append(m)
    433             elif m.atmt_type == ATMT.IOEVENT:
    434                 cls.ioevents[m.atmt_state].append(m)
    435                 cls.ionames.append(m.atmt_ioname)
    436                 if m.atmt_as_supersocket is not None:
    437                     cls.iosupersockets.append(m)
    438             elif m.atmt_type == ATMT.TIMEOUT:
    439                 cls.timeout[m.atmt_state].append((m.atmt_timeout, m))
    440             elif m.atmt_type == ATMT.ACTION:
    441                 for c in m.atmt_cond:
    442                     cls.actions[c].append(m)
    443             
    444 
    445         for v in six.itervalues(cls.timeout):
    446             v.sort(key=cmp_to_key(lambda t1_f1,t2_f2: cmp(t1_f1[0],t2_f2[0])))
    447             v.append((None, None))
    448         for v in itertools.chain(six.itervalues(cls.conditions),
    449                                  six.itervalues(cls.recv_conditions),
    450                                  six.itervalues(cls.ioevents)):
    451             v.sort(key=cmp_to_key(lambda c1,c2: cmp(c1.atmt_prio,c2.atmt_prio)))
    452         for condname,actlst in six.iteritems(cls.actions):
    453             actlst.sort(key=cmp_to_key(lambda c1,c2: cmp(c1.atmt_cond[condname], c2.atmt_cond[condname])))
    454 
    455         for ioev in cls.iosupersockets:
    456             setattr(cls, ioev.atmt_as_supersocket, _ATMT_to_supersocket(ioev.atmt_as_supersocket, ioev.atmt_ioname, cls))
    457 
    458         return cls
    459 
    460     def graph(self, **kargs):
    461         s = 'digraph "%s" {\n'  % self.__class__.__name__
    462         
    463         se = "" # Keep initial nodes at the begining for better rendering
    464         for st in six.itervalues(self.states):
    465             if st.atmt_initial:
    466                 se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state)+se
    467             elif st.atmt_final:
    468                 se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state
    469             elif st.atmt_error:
    470                 se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state
    471         s += se
    472 
    473         for st in six.itervalues(self.states):
    474             for n in st.atmt_origfunc.__code__.co_names+st.atmt_origfunc.__code__.co_consts:
    475                 if n in self.states:
    476                     s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state,n)
    477             
    478 
    479         for c,k,v in ([("purple",k,v) for k,v in self.conditions.items()]+
    480                       [("red",k,v) for k,v in self.recv_conditions.items()]+
    481                       [("orange",k,v) for k,v in self.ioevents.items()]):
    482             for f in v:
    483                 for n in f.__code__.co_names+f.__code__.co_consts:
    484                     if n in self.states:
    485                         l = f.atmt_condname
    486                         for x in self.actions[f.atmt_condname]:
    487                             l += "\\l>[%s]" % x.__name__
    488                         s += '\t"%s" -> "%s" [label="%s", color=%s];\n' % (k,n,l,c)
    489         for k,v in six.iteritems(self.timeout):
    490             for t,f in v:
    491                 if f is None:
    492                     continue
    493                 for n in f.__code__.co_names+f.__code__.co_consts:
    494                     if n in self.states:
    495                         l = "%s/%.1fs" % (f.atmt_condname,t)                        
    496                         for x in self.actions[f.atmt_condname]:
    497                             l += "\\l>[%s]" % x.__name__
    498                         s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k,n,l)
    499         s += "}\n"
    500         return do_graph(s, **kargs)
    501 
    502 class Automaton(six.with_metaclass(Automaton_metaclass)):
    def parse_args(self, debug=0, store=1, **kargs):
        """Record run options.

        debug: verbosity threshold used by debug().
        store: when truthy, sent/received packets are kept in self.packets.
        kargs: remaining keyword arguments, kept in self.socket_kargs.
        """
        self.store_packets = store
        self.socket_kargs = kargs
        self.debug_level = debug
    507 
    def master_filter(self, pkt):
        """Default packet filter hook: accepts every packet."""
        accepted = True
        return accepted
    510 
    def my_send(self, pkt):
        """Low-level send hook used by send(): write ``pkt`` to send_sock."""
        sock = self.send_sock
        sock.send(pkt)
    513 
    514 
    515     ## Utility classes and exceptions
    class _IO_fdwrapper(SelectableObject):
        """Adapter giving an external endpoint the automaton I/O interface.

        Outside Windows, ``rd``/``wr`` are file descriptors (or objects with
        fileno()) used with os.read/os.write. On Windows only an ObjectPipe
        is accepted, and it is used both for reading and for sending.
        """
        def __init__(self, rd, wr):
            if not WINDOWS:
                if not (rd is None or isinstance(rd, int)):
                    rd = rd.fileno()
                if not (wr is None or isinstance(wr, int)):
                    wr = wr.fileno()
                self.rd = rd
                self.wr = wr
            elif isinstance(rd, ObjectPipe):
                # rd will be used for reading and sending
                self.rd = rd
            else:
                raise OSError("On windows, only instances of ObjectPipe are externally available")

        def fileno(self):
            return self.rd

        def check_recv(self):
            return self.rd.check_recv()

        def read(self, n=65535):
            if not WINDOWS:
                return os.read(self.rd, n)
            return self.rd.recv(n)

        def write(self, msg):
            if not WINDOWS:
                return os.write(self.wr, msg)
            self.rd.send(msg)
            return self.call_release()

        def recv(self, n=65535):
            return self.read(n)

        def send(self, msg):
            return self.write(msg)
    548 
    class _IO_mixer(SelectableObject):
        """I/O endpoint that reads from ``rd`` and writes to ``wr``; every
        send also wakes waiters through call_release()."""
        def __init__(self, rd, wr):
            self.rd, self.wr = rd, wr

        def fileno(self):
            source = self.rd
            return source if isinstance(source, int) else source.fileno()

        def check_recv(self):
            return self.rd.check_recv()

        def recv(self, n=None):
            return self.rd.recv(n)

        def read(self, n=None):
            return self.recv(n)

        def send(self, msg):
            self.wr.send(msg)
            return self.call_release()

        def write(self, msg):
            return self.send(msg)
    568 
    569 
    class AutomatonException(Exception):
        """Base automaton exception, carrying the related state and result."""
        def __init__(self, msg, state=None, result=None):
            self.state, self.result = state, result
            Exception.__init__(self, msg)

    class AutomatonError(AutomatonException):
        """Generic automaton error (e.g. an unknown interception verdict)."""

    class ErrorState(AutomatonException):
        """AutomatonException subclass; presumably signals an error state — confirm."""

    class Stuck(AutomatonException):
        """AutomatonException subclass; presumably signals no usable transition — confirm."""

    class AutomatonStopped(AutomatonException):
        """Base for the exceptions reporting that the automaton stopped."""

    class Breakpoint(AutomatonStopped):
        """AutomatonStopped subclass, presumably for a breakpoint hit."""

    class Singlestep(AutomatonStopped):
        """AutomatonStopped subclass, presumably for single-step mode."""

    class InterceptionPoint(AutomatonStopped):
        """AutomatonStopped subclass that also carries the intercepted packet."""
        def __init__(self, msg, state=None, result=None, packet=None):
            self.packet = packet
            Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result)

    class CommandMessage(AutomatonException):
        """AutomatonException subclass; its use is not visible in this part of the file."""
    596 
    597 
    598     ## Services
    def debug(self, lvl, msg):
        """Log ``msg`` via log_interactive.debug when debug_level >= lvl."""
        if self.debug_level < lvl:
            return
        log_interactive.debug(msg)
    602 
    def send(self, pkt):
        """Send ``pkt`` with my_send(), honouring interception points.

        If the current state is an interception point, the packet is offered
        on cmdout as an INTERCEPT message and the verdict read from cmdin
        decides: ACCEPT sends it as is, REPLACE sends the verdict's packet,
        REJECT drops it. Any other verdict raises AutomatonError. A copy of
        each sent packet is kept in self.packets when store_packets is set.
        """
        if self.state.state in self.interception_points:
            self.debug(3, "INTERCEPT: packet intercepted: %s" % pkt.summary())
            self.intercepted_packet = pkt
            request = Message(type=_ATMT_Command.INTERCEPT, state=self.state, pkt=pkt)
            self.cmdout.send(request)
            verdict = self.cmdin.recv()
            self.intercepted_packet = None
            kind = verdict.type
            if kind == _ATMT_Command.ACCEPT:
                self.debug(3, "INTERCEPT: packet accepted")
            elif kind == _ATMT_Command.REPLACE:
                pkt = verdict.pkt
                self.debug(3, "INTERCEPT: packet replaced by: %s" % pkt.summary())
            elif kind == _ATMT_Command.REJECT:
                self.debug(3, "INTERCEPT: packet rejected")
                return
            else:
                raise self.AutomatonError("INTERCEPT: unkown verdict: %r" % kind)
        self.my_send(pkt)
        self.debug(3, "SENT : %s" % pkt.summary())
        if self.store_packets:
            self.packets.append(pkt.copy())
    626 
    627 
    628     ## Internals
    def __init__(self, *args, **kargs):
        """Build the automaton, then call self.start().

        Keyword arguments consumed here (the remaining ones are kept in
        init_kargs, the positional ones in init_args):
          external_fd -- dict mapping an I/O name to an endpoint, or (not on
                         Windows) to an (input, output) tuple. A missing
                         endpoint gets an ObjectPipe; endpoints that are not
                         SelectableObject are wrapped in _IO_fdwrapper.
          ll          -- send socket class (default conf.L3socket)
          recvsock    -- receive socket class (default conf.L2listen)
        """
        external_fd = kargs.pop("external_fd",{})
        self.send_sock_class = kargs.pop("ll", conf.L3socket)
        self.recv_sock_class = kargs.pop("recvsock", conf.L2listen)
        self.started = threading.Lock()
        self.threadid = None
        self.breakpointed = None
        self.breakpoints = set()
        self.interception_points = set()
        self.intercepted_packet = None
        self.debug_level=0
        self.init_args=args
        self.init_kargs=kargs
        # Two fresh, empty classes used as attribute namespaces for the I/O
        # channels (filled in the loop below).
        self.io = type.__new__(type, "IOnamespace",(),{})
        self.oi = type.__new__(type, "IOnamespace",(),{})
        # Command pipes (send() uses cmdout for INTERCEPT and reads the
        # verdict from cmdin).
        self.cmdin = ObjectPipe()
        self.cmdout = ObjectPipe()
        self.ioin = {}
        self.ioout = {}
        for n in self.ionames:
            extfd = external_fd.get(n)
            # A single endpoint (or None) is used for both directions.
            if not isinstance(extfd, tuple):
                extfd = (extfd,extfd)
            elif WINDOWS:
                raise OSError("Tuples are not allowed as external_fd on windows")
            ioin,ioout = extfd
            if ioin is None:
                ioin = ObjectPipe()
            elif not isinstance(ioin, SelectableObject):
                ioin = self._IO_fdwrapper(ioin,None)
            if ioout is None:
                # On Windows a single ObjectPipe serves both directions.
                ioout = ioin if WINDOWS else ObjectPipe()
            elif not isinstance(ioout, SelectableObject):
                ioout = self._IO_fdwrapper(None,ioout)

            self.ioin[n] = ioin
            self.ioout[n] = ioout
            ioin.ioname = n
            ioout.ioname = n
            # self.io.<n> reads from ioout and writes to ioin;
            # self.oi.<n> is the reverse.
            setattr(self.io, n, self._IO_mixer(ioout,ioin))
            setattr(self.oi, n, self._IO_mixer(ioin,ioout))

        # Shadow each state method with a per-instance _instance_state proxy.
        for stname in self.states:
            setattr(self, stname,
                    _instance_state(getattr(self, stname)))

        self.start()
    676 
    def __iter__(self):
        """Make the automaton its own iterator; each step is a ``next()`` call."""
        return self
    679 
    def __del__(self):
        """Stop the control thread when the automaton is garbage-collected.

        stop() needs the ``started`` lock and both command pipes. If
        construction was aborted before they were created (for example, a
        subclass ``__init__`` raising before this class's ``__init__`` ran),
        there is no control thread to stop. Calling stop() would then only
        raise AttributeError from inside the finalizer.
        """
        if not all(hasattr(self, attr) for attr in ("started", "cmdin", "cmdout")):
            return
        self.stop()
    682 
    def _run_condition(self, cond, *args, **kargs):
        """Evaluate one condition and, if it fires, run its actions.

        A condition signals a transition by raising ATMT.NewStateRequested.
        When that happens, the following steps run before the exception is
        re-raised to the caller:

        - for RECV conditions with ``store_packets`` set, the first
          positional argument (the packet) is appended to ``self.packets``;
        - every action registered for the condition is run with the
          transition's action arguments.

        Any other exception is logged at debug level 2 and re-raised. A
        normal return means the condition was not taken.
        """
        try:
            self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname))
            cond(self, *args, **kargs)
        except ATMT.NewStateRequested as transition:
            self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, transition.state))
            if cond.atmt_type == ATMT.RECV and self.store_packets:
                self.packets.append(args[0])
            for act in self.actions[cond.atmt_condname]:
                self.debug(2, "   + Running action [%s]" % act.__name__)
                act(self, *transition.action_args, **transition.action_kargs)
            raise
        except Exception as err:
            self.debug(2, "%s [%s] raised exception [%s]" % (cond.atmt_type, cond.atmt_condname, err))
            raise
        else:
            self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname))
    701 
    def _do_start(self, *args, **kargs):
        """Spawn the control thread and block until it signals readiness.

        Positional and keyword arguments are forwarded to _do_control()
        after the readiness event. The thread is a daemon, so it never keeps
        the interpreter alive on exit.
        """
        ready = threading.Event()
        ctrl = threading.Thread(target=self._do_control,
                                args=(ready,) + args, kwargs=kargs)
        # Thread.setDaemon() is deprecated since Python 3.10; the ``daemon``
        # attribute is available on every supported Python version.
        ctrl.daemon = True
        ctrl.start()
        ready.wait()
    708 
    def _do_control(self, ready, *args, **kargs):
        """Main loop of the control thread spawned by _do_start().

        Holds ``self.started`` for the thread's whole lifetime. Setup:

        - merge ``args``/``kargs`` with the constructor defaults and call
          parse_args();
        - create the initial state, both sockets and the packet list;
        - set ``ready``.

        It then serves commands read from ``self.cmdin`` (RUN, NEXT,
        FREEZE, STOP) by driving the _do_iter() generator, and reports on
        ``self.cmdout``.

        Reaching a final state is reported as an END message carrying that
        state's output. Any exception, including one raised during setup,
        is transferred to the caller as an EXCEPTION message. ``ready`` is
        always set before the thread exits, so _do_start() can never hang,
        and the sockets opened here are closed when the thread exits.
        """
        with self.started:
            self.threadid = threading.current_thread().ident
            self.send_sock = None
            self.listen_sock = None
            # Filled in by _do_iter() when a final state is reached
            self.final_state_output = None
            try:
                # Update default parameters
                a = args + self.init_args[len(args):]
                k = self.init_kargs.copy()
                k.update(kargs)
                self.parse_args(*a, **k)

                # Start the automaton
                self.state = self.initial_states[0](self)
                self.send_sock = self.send_sock_class(**self.socket_kargs)
                self.listen_sock = self.recv_sock_class(**self.socket_kargs)
                self.packets = PacketList(name="session[%s]" % self.__class__.__name__)

                singlestep = True
                iterator = self._do_iter()
                self.debug(3, "Starting control thread [tid=%i]" % self.threadid)
                # Sync threads
                ready.set()
                while True:
                    c = self.cmdin.recv()
                    self.debug(5, "Received command %s" % c.type)
                    if c.type == _ATMT_Command.RUN:
                        singlestep = False
                    elif c.type == _ATMT_Command.NEXT:
                        singlestep = True
                    elif c.type == _ATMT_Command.FREEZE:
                        continue
                    elif c.type == _ATMT_Command.STOP:
                        break
                    while True:
                        state = next(iterator)
                        if isinstance(state, self.CommandMessage):
                            break
                        elif isinstance(state, self.Breakpoint):
                            c = Message(type=_ATMT_Command.BREAKPOINT, state=state)
                            self.cmdout.send(c)
                            break
                        if singlestep:
                            c = Message(type=_ATMT_Command.SINGLESTEP, state=state)
                            self.cmdout.send(c)
                            break
            except StopIteration:
                # _do_iter() returned: a final state was reached
                c = Message(type=_ATMT_Command.END, result=self.final_state_output)
                self.cmdout.send(c)
            except Exception as e:
                exc_info = sys.exc_info()
                self.debug(3, "Transfering exception from tid=%i:\n%s" % (self.threadid, traceback.format_exception(*exc_info)))
                m = Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info)
                self.cmdout.send(m)
            finally:
                # Unblock _do_start() even when setup failed before ready.set()
                ready.set()
                self.debug(3, "Stopping control thread (tid=%i)" % self.threadid)
                self.threadid = None
                # Release the sockets opened for this run; a restart opens new ones
                for sock in (self.listen_sock, self.send_sock):
                    if sock is not None:
                        sock.close()

    def _do_iter(self):
        """Generator driving the state machine for the control thread.

        For each state it:

        1. runs the state function;
        2. evaluates the state's immediate conditions;
        3. waits (via select_objects) on the command pipe, the listening
           socket (if the state has receive conditions) and the state's I/O
           channels, firing timeout conditions when they expire.

        It yields:

        - a Breakpoint before running a state listed in ``self.breakpoints``;
        - a CommandMessage whenever ``self.cmdin`` becomes readable;
        - the NewStateRequested object on every state change.

        When a final state is reached, its output is stored in
        ``self.final_state_output`` and the generator returns.

        :raises ErrorState: after an error state's function has run.
        :raises Stuck: when the current state has no receive conditions, no
            I/O events and only one entry in its timeout list (presumably
            the terminating sentinel).
        """
        while True:
            try:
                self.debug(1, "## state=[%s]" % self.state.state)

                # Entering a new state. First, call new state function
                if self.state.state in self.breakpoints and self.state.state != self.breakpointed:
                    self.breakpointed = self.state.state
                    yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state,
                                          state=self.state.state)
                self.breakpointed = None
                state_output = self.state.run()
                if self.state.error:
                    raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output),
                                          result=state_output, state=self.state.state)
                if self.state.final:
                    # Raising StopIteration inside a generator becomes a
                    # RuntimeError under PEP 479 (Python 3.7+). Store the
                    # result and return instead, which works on Python 2 and 3.
                    self.final_state_output = state_output
                    return

                if state_output is None:
                    state_output = ()
                elif not isinstance(state_output, list):
                    state_output = (state_output,)

                # Then check immediate conditions
                for cond in self.conditions[self.state.state]:
                    self._run_condition(cond, *state_output)

                # If still there and no conditions left, we are stuck!
                if (len(self.recv_conditions[self.state.state]) == 0 and
                        len(self.ioevents[self.state.state]) == 0 and
                        len(self.timeout[self.state.state]) == 1):
                    raise self.Stuck("stuck in [%s]" % self.state.state,
                                     state=self.state.state, result=state_output)

                # Finally listen and pay attention to timeouts
                expirations = iter(self.timeout[self.state.state])
                next_timeout, timeout_func = next(expirations)
                t0 = time.time()

                fds = [self.cmdin]
                if len(self.recv_conditions[self.state.state]) > 0:
                    fds.append(self.listen_sock)
                for ioev in self.ioevents[self.state.state]:
                    fds.append(self.ioin[ioev.atmt_ioname])
                while True:
                    t = time.time() - t0
                    # Fire every timeout that is already due, not just the
                    # first one. Otherwise a second overdue timeout would
                    # yield a negative select timeout. The list is expected
                    # to end with a None timeout, which stops this loop.
                    while next_timeout is not None and next_timeout <= t:
                        self._run_condition(timeout_func, *state_output)
                        next_timeout, timeout_func = next(expirations)
                    if next_timeout is None:
                        remain = None
                    else:
                        remain = next_timeout - t

                    self.debug(5, "Select on %r" % fds)
                    r = select_objects(fds, remain)
                    self.debug(5, "Selected %r" % r)
                    for fd in r:
                        self.debug(5, "Looking at %r" % fd)
                        if fd == self.cmdin:
                            yield self.CommandMessage("Received command message")
                        elif fd == self.listen_sock:
                            try:
                                pkt = self.listen_sock.recv(MTU)
                            except recv_error:
                                pass
                            else:
                                if pkt is not None:
                                    if self.master_filter(pkt):
                                        self.debug(3, "RECVD: %s" % pkt.summary())
                                        for rcvcond in self.recv_conditions[self.state.state]:
                                            self._run_condition(rcvcond, pkt, *state_output)
                                    else:
                                        self.debug(4, "FILTR: %s" % pkt.summary())
                        else:
                            self.debug(3, "IOEVENT on %s" % fd.ioname)
                            for ioevt in self.ioevents[self.state.state]:
                                if ioevt.atmt_ioname == fd.ioname:
                                    self._run_condition(ioevt, fd, *state_output)

            except ATMT.NewStateRequested as state_req:
                self.debug(2, "switching from [%s] to [%s]" % (self.state.state, state_req.state))
                self.state = state_req
                yield state_req
    850 
    851     ## Public API
    def add_interception_points(self, *ipts):
        """Add states to the set of interception points.

        Each argument is either a state name or an object exposing an
        ``atmt_state`` attribute, in which case that attribute's value is
        used as the state name.
        """
        for point in ipts:
            state_name = point.atmt_state if hasattr(point, "atmt_state") else point
            self.interception_points.add(state_name)
    857         
    def remove_interception_points(self, *ipts):
        """Remove states from the set of interception points.

        Accepts the same argument forms as add_interception_points(). States
        that are not currently registered are silently ignored.
        """
        for item in ipts:
            self.interception_points.discard(
                item.atmt_state if hasattr(item, "atmt_state") else item)
    863 
    def add_breakpoints(self, *bps):
        """Add states to the breakpoint set.

        The state iterator yields a Breakpoint before running a state found
        in this set. Each argument is a state name or an object with an
        ``atmt_state`` attribute, whose value is used instead.
        """
        for target in bps:
            self.breakpoints.add(
                target.atmt_state if hasattr(target, "atmt_state") else target)
    869 
    def remove_breakpoints(self, *bps):
        """Remove states from the breakpoint set.

        Accepts the same argument forms as add_breakpoints(). Unknown states
        are ignored.
        """
        for target in bps:
            name = target.atmt_state if hasattr(target, "atmt_state") else target
            self.breakpoints.discard(name)
    875 
    def start(self, *args, **kargs):
        """Launch the control thread with the given parameters.

        Does nothing while the ``started`` lock is held, i.e. while a
        control thread is already running.
        """
        if self.started.locked():
            return
        self._do_start(*args, **kargs)
    879         
    def run(self, resume=None, wait=True):
        """Send a command to the control thread and optionally wait for its reply.

        ``resume`` defaults to a RUN command. When ``wait`` is false the call
        returns None immediately. Otherwise the control thread's reply is
        handled as follows:

        - END: its result is returned;
        - INTERCEPT, SINGLESTEP, BREAKPOINT: raise InterceptionPoint,
          Singlestep or Breakpoint respectively;
        - EXCEPTION: the exception transferred from the control thread is
          re-raised with its original traceback.

        A KeyboardInterrupt while waiting sends FREEZE and returns None.
        """
        self.cmdin.send(Message(type=_ATMT_Command.RUN) if resume is None else resume)
        if not wait:
            return None
        try:
            reply = self.cmdout.recv()
        except KeyboardInterrupt:
            self.cmdin.send(Message(type=_ATMT_Command.FREEZE))
            return None
        if reply.type == _ATMT_Command.END:
            return reply.result
        if reply.type == _ATMT_Command.INTERCEPT:
            raise self.InterceptionPoint("packet intercepted", state=reply.state.state, packet=reply.pkt)
        if reply.type == _ATMT_Command.SINGLESTEP:
            raise self.Singlestep("singlestep state=[%s]" % reply.state.state, state=reply.state.state)
        if reply.type == _ATMT_Command.BREAKPOINT:
            raise self.Breakpoint("breakpoint triggered on state [%s]" % reply.state.state, state=reply.state.state)
        if reply.type == _ATMT_Command.EXCEPTION:
            six.reraise(reply.exc_info[0], reply.exc_info[1], reply.exc_info[2])
        return None
    900 
    def runbg(self, resume=None, wait=False):
        """Like run(), but does not wait for the control thread by default.

        Returns whatever run() returns. That is None when ``wait`` is false,
        or the END result when ``wait`` is true. Previously the value was
        discarded, unlike next()/accept_packet()/reject_packet().
        """
        return self.run(resume, wait)
    903 
    def next(self):
        """Advance the automaton by one step (NEXT command) and wait for it."""
        step = Message(type=_ATMT_Command.NEXT)
        return self.run(resume=step)
    __next__ = next
    907 
    def stop(self):
        """Ask the control thread to stop and discard any pending commands.

        The control thread holds ``started`` for its whole lifetime, so
        acquiring it waits for a running thread to exit. Both command pipes
        are then drained until nothing is readable.
        """
        self.cmdin.send(Message(type=_ATMT_Command.STOP))
        with self.started:
            readable = select_objects([self.cmdin, self.cmdout], 0)
            while readable:
                for pipe in readable:
                    pipe.recv()
                readable = select_objects([self.cmdin, self.cmdout], 0)
    918                 
    def restart(self, *args, **kargs):
        """Stop the control thread, then start a new one with the given parameters."""
        stopper, starter = self.stop, self.start
        stopper()
        starter(*args, **kargs)
    922 
    def accept_packet(self, pkt=None, wait=False):
        """Resume after an interception, accepting or replacing the packet.

        With ``pkt`` None an ACCEPT command is sent. Otherwise a REPLACE
        command carrying ``pkt`` is sent. The reply is handled by run().
        """
        resume = Message()
        if pkt is not None:
            resume.type = _ATMT_Command.REPLACE
            resume.pkt = pkt
        else:
            resume.type = _ATMT_Command.ACCEPT
        return self.run(resume=resume, wait=wait)
    931 
    def reject_packet(self, wait=False):
        """Resume after an interception by sending a REJECT command via run()."""
        return self.run(resume=Message(type=_ATMT_Command.REJECT), wait=wait)
    935 
    936     
    937 
    938