1 ## This file is part of Scapy 2 ## See http://www.secdev.org/projects/scapy for more informations 3 ## Copyright (C) Philippe Biondi <phil (at] secdev.org> 4 ## Copyright (C) Gabriel Potter <gabriel (at] potter.fr> 5 ## This program is published under a GPLv2 license 6 7 """ 8 Automata with states, transitions and actions. 9 """ 10 11 from __future__ import absolute_import 12 import types,itertools,time,os,sys,socket,traceback 13 from select import select 14 from collections import deque 15 import threading 16 from scapy.config import conf 17 from scapy.utils import do_graph 18 from scapy.error import log_interactive 19 from scapy.plist import PacketList 20 from scapy.data import MTU 21 from scapy.supersocket import SuperSocket 22 from scapy.consts import WINDOWS 23 from scapy.compat import * 24 import scapy.modules.six as six 25 26 try: 27 import thread 28 except ImportError: 29 THREAD_EXCEPTION = RuntimeError 30 else: 31 THREAD_EXCEPTION = thread.error 32 33 if WINDOWS: 34 from scapy.error import Scapy_Exception 35 recv_error = Scapy_Exception 36 else: 37 recv_error = () 38 39 """ In Windows, select.select is not available for custom objects. Here's the implementation of scapy to re-create this functionnality 40 # Passive way: using no-ressources locks 41 +---------+ +---------------+ +-------------------------+ 42 | Start +------------->Select_objects +----->+Linux: call select.select| 43 +---------+ |(select.select)| +-------------------------+ 44 +-------+-------+ 45 | 46 +----v----+ +--------+ 47 | Windows | |Time Out+----------------------------------+ 48 +----+----+ +----+---+ | 49 | ^ | 50 Event | | | 51 + | | | 52 | +-------v-------+ | | 53 | +------+Selectable Sel.+-----+-----------------+-----------+ | 54 | | +-------+-------+ | | | v +-----v-----+ 55 +-------v----------+ | | | | | Passive lock<-----+release_all<------+ 56 |Data added to list| +----v-----+ +-----v-----+ +----v-----+ v v + +-----------+ | 57 +--------+---------+ |Selectable| |Selectable | |Selectable| ............ | | 58 | +----+-----+ +-----------+ +----------+ | | 59 | v | | 60 v +----+------+ +------------------+ +-------------v-------------------+ | 61 +-----+------+ |wait_return+-->+ check_recv: | | | | 62 |call_release| +----+------+ |If data is in list| | END state: selectable returned | +---+--------+ 63 +-----+-------- v +-------+----------+ | | | exit door | 64 | else | +---------------------------------+ +---+--------+ 65 | + | | 66 | +----v-------+ | | 67 +--------->free -->Passive lock| | | 68 +----+-------+ | | 69 | | | 70 | v | 71 +------------------Selectable-Selector-is-advertised-that-the-selectable-is-readable---------+ 72 """ 73 74 class SelectableObject: 75 """DEV: to implement one of those, you need to add 2 things to your object: 76 - add "check_recv" function 77 - call "self.call_release" once you are ready to be read 78 79 You can set the __selectable_force_select__ to True in the class, if you want to 80 force the handler to use fileno(). This may only be useable on sockets created using 81 the builtin socket API.""" 82 __selectable_force_select__ = False 83 def check_recv(self): 84 """DEV: will be called only once (at beginning) to check if the object is ready.""" 85 raise OSError("This method must be overwriten.") 86 87 def _wait_non_ressources(self, callback): 88 """This get started as a thread, and waits for the data lock to be freed then advertise itself to the SelectableSelector using the callback""" 89 self.trigger = threading.Lock() 90 self.was_ended = False 91 self.trigger.acquire() 92 self.trigger.acquire() 93 if not self.was_ended: 94 callback(self) 95 96 def wait_return(self, callback): 97 """Entry point of SelectableObject: register the callback""" 98 if self.check_recv(): 99 return callback(self) 100 _t = threading.Thread(target=self._wait_non_ressources, args=(callback,)) 101 _t.setDaemon(True) 102 _t.start() 103 104 def call_release(self, arborted=False): 105 """DEV: Must be call when the object becomes ready to read. 106 Relesases the lock of _wait_non_ressources""" 107 self.was_ended = arborted 108 try: 109 self.trigger.release() 110 except (THREAD_EXCEPTION, AttributeError): 111 pass 112 113 class SelectableSelector(object): 114 """ 115 Select SelectableObject objects. 116 117 inputs: objects to process 118 remain: timeout. If 0, return []. 119 customTypes: types of the objects that have the check_recv function. 120 """ 121 def _release_all(self): 122 """Releases all locks to kill all threads""" 123 for i in self.inputs: 124 i.call_release(True) 125 self.available_lock.release() 126 127 def _timeout_thread(self, remain): 128 """Timeout before releasing every thing, if nothing was returned""" 129 time.sleep(remain) 130 if not self._ended: 131 self._ended = True 132 self._release_all() 133 134 def _exit_door(self, _input): 135 """This function is passed to each SelectableObject as a callback 136 The SelectableObjects have to call it once there are ready""" 137 self.results.append(_input) 138 if self._ended: 139 return 140 self._ended = True 141 self._release_all() 142 143 def __init__(self, inputs, remain): 144 self.results = [] 145 self.inputs = list(inputs) 146 self.remain = remain 147 self.available_lock = threading.Lock() 148 self.available_lock.acquire() 149 self._ended = False 150 151 def process(self): 152 """Entry point of SelectableSelector""" 153 if WINDOWS: 154 select_inputs = [] 155 for i in self.inputs: 156 if not isinstance(i, SelectableObject): 157 warning("Unknown ignored object type: %s", type(i)) 158 elif i.__selectable_force_select__: 159 # Then use select.select 160 select_inputs.append(i) 161 elif not self.remain and i.check_recv(): 162 self.results.append(i) 163 else: 164 i.wait_return(self._exit_door) 165 if select_inputs: 166 # Use default select function 167 self.results.extend(select(select_inputs, [], [], self.remain)[0]) 168 if not self.remain: 169 return self.results 170 171 threading.Thread(target=self._timeout_thread, args=(self.remain,)).start() 172 if not self._ended: 173 self.available_lock.acquire() 174 return self.results 175 else: 176 r,_,_ = select(self.inputs,[],[],self.remain) 177 return r 178 179 def select_objects(inputs, remain): 180 """ 181 Select SelectableObject objects. Same than: 182 select.select([inputs], [], [], remain) 183 But also works on Windows, only on SelectableObject. 184 185 inputs: objects to process 186 remain: timeout. If 0, return []. 187 """ 188 handler = SelectableSelector(inputs, remain) 189 return handler.process() 190 191 class ObjectPipe(SelectableObject): 192 def __init__(self): 193 self.rd,self.wr = os.pipe() 194 self.queue = deque() 195 def fileno(self): 196 return self.rd 197 def check_recv(self): 198 return len(self.queue) > 0 199 def send(self, obj): 200 self.queue.append(obj) 201 os.write(self.wr,b"X") 202 self.call_release() 203 def write(self, obj): 204 self.send(obj) 205 def recv(self, n=0): 206 os.read(self.rd, 1) 207 return self.queue.popleft() 208 def read(self, n=0): 209 return self.recv(n) 210 211 class Message: 212 def __init__(self, **args): 213 self.__dict__.update(args) 214 def __repr__(self): 215 return "<Message %s>" % " ".join("%s=%r"%(k,v) 216 for (k,v) in six.iteritems(self.__dict__) 217 if not k.startswith("_")) 218 219 class _instance_state: 220 def __init__(self, instance): 221 self.__self__ = instance.__self__ 222 self.__func__ = instance.__func__ 223 self.__self__.__class__ = instance.__self__.__class__ 224 def __getattr__(self, attr): 225 return getattr(self.__func__, attr) 226 def __call__(self, *args, **kargs): 227 return self.__func__(self.__self__, *args, **kargs) 228 def breaks(self): 229 return self.__self__.add_breakpoints(self.__func__) 230 def intercepts(self): 231 return self.__self__.add_interception_points(self.__func__) 232 def unbreaks(self): 233 return self.__self__.remove_breakpoints(self.__func__) 234 def unintercepts(self): 235 return self.__self__.remove_interception_points(self.__func__) 236 237 238 ############## 239 ## Automata ## 240 ############## 241 242 class ATMT: 243 STATE = "State" 244 ACTION = "Action" 245 CONDITION = "Condition" 246 RECV = "Receive condition" 247 TIMEOUT = "Timeout condition" 248 IOEVENT = "I/O event" 249 250 class NewStateRequested(Exception): 251 def __init__(self, state_func, automaton, *args, **kargs): 252 self.func = state_func 253 self.state = state_func.atmt_state 254 self.initial = state_func.atmt_initial 255 self.error = state_func.atmt_error 256 self.final = state_func.atmt_final 257 Exception.__init__(self, "Request state [%s]" % self.state) 258 self.automaton = automaton 259 self.args = args 260 self.kargs = kargs 261 self.action_parameters() # init action parameters 262 def action_parameters(self, *args, **kargs): 263 self.action_args = args 264 self.action_kargs = kargs 265 return self 266 def run(self): 267 return self.func(self.automaton, *self.args, **self.kargs) 268 def __repr__(self): 269 return "NewStateRequested(%s)" % self.state 270 271 @staticmethod 272 def state(initial=0,final=0,error=0): 273 def deco(f,initial=initial, final=final): 274 f.atmt_type = ATMT.STATE 275 f.atmt_state = f.__name__ 276 f.atmt_initial = initial 277 f.atmt_final = final 278 f.atmt_error = error 279 def state_wrapper(self, *args, **kargs): 280 return ATMT.NewStateRequested(f, self, *args, **kargs) 281 282 state_wrapper.__name__ = "%s_wrapper" % f.__name__ 283 state_wrapper.atmt_type = ATMT.STATE 284 state_wrapper.atmt_state = f.__name__ 285 state_wrapper.atmt_initial = initial 286 state_wrapper.atmt_final = final 287 state_wrapper.atmt_error = error 288 state_wrapper.atmt_origfunc = f 289 return state_wrapper 290 return deco 291 @staticmethod 292 def action(cond, prio=0): 293 def deco(f,cond=cond): 294 if not hasattr(f,"atmt_type"): 295 f.atmt_cond = {} 296 f.atmt_type = ATMT.ACTION 297 f.atmt_cond[cond.atmt_condname] = prio 298 return f 299 return deco 300 @staticmethod 301 def condition(state, prio=0): 302 def deco(f, state=state): 303 f.atmt_type = ATMT.CONDITION 304 f.atmt_state = state.atmt_state 305 f.atmt_condname = f.__name__ 306 f.atmt_prio = prio 307 return f 308 return deco 309 @staticmethod 310 def receive_condition(state, prio=0): 311 def deco(f, state=state): 312 f.atmt_type = ATMT.RECV 313 f.atmt_state = state.atmt_state 314 f.atmt_condname = f.__name__ 315 f.atmt_prio = prio 316 return f 317 return deco 318 @staticmethod 319 def ioevent(state, name, prio=0, as_supersocket=None): 320 def deco(f, state=state): 321 f.atmt_type = ATMT.IOEVENT 322 f.atmt_state = state.atmt_state 323 f.atmt_condname = f.__name__ 324 f.atmt_ioname = name 325 f.atmt_prio = prio 326 f.atmt_as_supersocket = as_supersocket 327 return f 328 return deco 329 @staticmethod 330 def timeout(state, timeout): 331 def deco(f, state=state, timeout=timeout): 332 f.atmt_type = ATMT.TIMEOUT 333 f.atmt_state = state.atmt_state 334 f.atmt_timeout = timeout 335 f.atmt_condname = f.__name__ 336 return f 337 return deco 338 339 class _ATMT_Command: 340 RUN = "RUN" 341 NEXT = "NEXT" 342 FREEZE = "FREEZE" 343 STOP = "STOP" 344 END = "END" 345 EXCEPTION = "EXCEPTION" 346 SINGLESTEP = "SINGLESTEP" 347 BREAKPOINT = "BREAKPOINT" 348 INTERCEPT = "INTERCEPT" 349 ACCEPT = "ACCEPT" 350 REPLACE = "REPLACE" 351 REJECT = "REJECT" 352 353 class _ATMT_supersocket(SuperSocket): 354 def __init__(self, name, ioevent, automaton, proto, args, kargs): 355 self.name = name 356 self.ioevent = ioevent 357 self.proto = proto 358 self.spa,self.spb = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM) 359 kargs["external_fd"] = {ioevent:self.spb} 360 self.atmt = automaton(*args, **kargs) 361 self.atmt.runbg() 362 def fileno(self): 363 return self.spa.fileno() 364 def send(self, s): 365 if not isinstance(s, bytes): 366 s = bytes(s) 367 return self.spa.send(s) 368 def recv(self, n=MTU): 369 try: 370 r = self.spa.recv(n) 371 except recv_error: 372 if not WINDOWS: 373 raise 374 return None 375 if self.proto is not None: 376 r = self.proto(r) 377 return r 378 def close(self): 379 pass 380 381 class _ATMT_to_supersocket: 382 def __init__(self, name, ioevent, automaton): 383 self.name = name 384 self.ioevent = ioevent 385 self.automaton = automaton 386 def __call__(self, proto, *args, **kargs): 387 return _ATMT_supersocket(self.name, self.ioevent, self.automaton, proto, args, kargs) 388 389 class Automaton_metaclass(type): 390 def __new__(cls, name, bases, dct): 391 cls = super(Automaton_metaclass, cls).__new__(cls, name, bases, dct) 392 cls.states={} 393 cls.state = None 394 cls.recv_conditions={} 395 cls.conditions={} 396 cls.ioevents={} 397 cls.timeout={} 398 cls.actions={} 399 cls.initial_states=[] 400 cls.ionames = [] 401 cls.iosupersockets = [] 402 403 members = {} 404 classes = [cls] 405 while classes: 406 c = classes.pop(0) # order is important to avoid breaking method overloading 407 classes += list(c.__bases__) 408 for k,v in six.iteritems(c.__dict__): 409 if k not in members: 410 members[k] = v 411 412 decorated = [v for v in six.itervalues(members) 413 if isinstance(v, types.FunctionType) and hasattr(v, "atmt_type")] 414 415 for m in decorated: 416 if m.atmt_type == ATMT.STATE: 417 s = m.atmt_state 418 cls.states[s] = m 419 cls.recv_conditions[s]=[] 420 cls.ioevents[s]=[] 421 cls.conditions[s]=[] 422 cls.timeout[s]=[] 423 if m.atmt_initial: 424 cls.initial_states.append(m) 425 elif m.atmt_type in [ATMT.CONDITION, ATMT.RECV, ATMT.TIMEOUT, ATMT.IOEVENT]: 426 cls.actions[m.atmt_condname] = [] 427 428 for m in decorated: 429 if m.atmt_type == ATMT.CONDITION: 430 cls.conditions[m.atmt_state].append(m) 431 elif m.atmt_type == ATMT.RECV: 432 cls.recv_conditions[m.atmt_state].append(m) 433 elif m.atmt_type == ATMT.IOEVENT: 434 cls.ioevents[m.atmt_state].append(m) 435 cls.ionames.append(m.atmt_ioname) 436 if m.atmt_as_supersocket is not None: 437 cls.iosupersockets.append(m) 438 elif m.atmt_type == ATMT.TIMEOUT: 439 cls.timeout[m.atmt_state].append((m.atmt_timeout, m)) 440 elif m.atmt_type == ATMT.ACTION: 441 for c in m.atmt_cond: 442 cls.actions[c].append(m) 443 444 445 for v in six.itervalues(cls.timeout): 446 v.sort(key=cmp_to_key(lambda t1_f1,t2_f2: cmp(t1_f1[0],t2_f2[0]))) 447 v.append((None, None)) 448 for v in itertools.chain(six.itervalues(cls.conditions), 449 six.itervalues(cls.recv_conditions), 450 six.itervalues(cls.ioevents)): 451 v.sort(key=cmp_to_key(lambda c1,c2: cmp(c1.atmt_prio,c2.atmt_prio))) 452 for condname,actlst in six.iteritems(cls.actions): 453 actlst.sort(key=cmp_to_key(lambda c1,c2: cmp(c1.atmt_cond[condname], c2.atmt_cond[condname]))) 454 455 for ioev in cls.iosupersockets: 456 setattr(cls, ioev.atmt_as_supersocket, _ATMT_to_supersocket(ioev.atmt_as_supersocket, ioev.atmt_ioname, cls)) 457 458 return cls 459 460 def graph(self, **kargs): 461 s = 'digraph "%s" {\n' % self.__class__.__name__ 462 463 se = "" # Keep initial nodes at the begining for better rendering 464 for st in six.itervalues(self.states): 465 if st.atmt_initial: 466 se = ('\t"%s" [ style=filled, fillcolor=blue, shape=box, root=true];\n' % st.atmt_state)+se 467 elif st.atmt_final: 468 se += '\t"%s" [ style=filled, fillcolor=green, shape=octagon ];\n' % st.atmt_state 469 elif st.atmt_error: 470 se += '\t"%s" [ style=filled, fillcolor=red, shape=octagon ];\n' % st.atmt_state 471 s += se 472 473 for st in six.itervalues(self.states): 474 for n in st.atmt_origfunc.__code__.co_names+st.atmt_origfunc.__code__.co_consts: 475 if n in self.states: 476 s += '\t"%s" -> "%s" [ color=green ];\n' % (st.atmt_state,n) 477 478 479 for c,k,v in ([("purple",k,v) for k,v in self.conditions.items()]+ 480 [("red",k,v) for k,v in self.recv_conditions.items()]+ 481 [("orange",k,v) for k,v in self.ioevents.items()]): 482 for f in v: 483 for n in f.__code__.co_names+f.__code__.co_consts: 484 if n in self.states: 485 l = f.atmt_condname 486 for x in self.actions[f.atmt_condname]: 487 l += "\\l>[%s]" % x.__name__ 488 s += '\t"%s" -> "%s" [label="%s", color=%s];\n' % (k,n,l,c) 489 for k,v in six.iteritems(self.timeout): 490 for t,f in v: 491 if f is None: 492 continue 493 for n in f.__code__.co_names+f.__code__.co_consts: 494 if n in self.states: 495 l = "%s/%.1fs" % (f.atmt_condname,t) 496 for x in self.actions[f.atmt_condname]: 497 l += "\\l>[%s]" % x.__name__ 498 s += '\t"%s" -> "%s" [label="%s",color=blue];\n' % (k,n,l) 499 s += "}\n" 500 return do_graph(s, **kargs) 501 502 class Automaton(six.with_metaclass(Automaton_metaclass)): 503 def parse_args(self, debug=0, store=1, **kargs): 504 self.debug_level=debug 505 self.socket_kargs = kargs 506 self.store_packets = store 507 508 def master_filter(self, pkt): 509 return True 510 511 def my_send(self, pkt): 512 self.send_sock.send(pkt) 513 514 515 ## Utility classes and exceptions 516 class _IO_fdwrapper(SelectableObject): 517 def __init__(self,rd,wr): 518 if WINDOWS: 519 # rd will be used for reading and sending 520 if isinstance(rd, ObjectPipe): 521 self.rd = rd 522 else: 523 raise OSError("On windows, only instances of ObjectPipe are externally available") 524 else: 525 if rd is not None and not isinstance(rd, int): 526 rd = rd.fileno() 527 if wr is not None and not isinstance(wr, int): 528 wr = wr.fileno() 529 self.rd = rd 530 self.wr = wr 531 def fileno(self): 532 return self.rd 533 def check_recv(self): 534 return self.rd.check_recv() 535 def read(self, n=65535): 536 if WINDOWS: 537 return self.rd.recv(n) 538 return os.read(self.rd, n) 539 def write(self, msg): 540 if WINDOWS: 541 self.rd.send(msg) 542 return self.call_release() 543 return os.write(self.wr,msg) 544 def recv(self, n=65535): 545 return self.read(n) 546 def send(self, msg): 547 return self.write(msg) 548 549 class _IO_mixer(SelectableObject): 550 def __init__(self,rd,wr): 551 self.rd = rd 552 self.wr = wr 553 def fileno(self): 554 if isinstance(self.rd, int): 555 return self.rd 556 return self.rd.fileno() 557 def check_recv(self): 558 return self.rd.check_recv() 559 def recv(self, n=None): 560 return self.rd.recv(n) 561 def read(self, n=None): 562 return self.recv(n) 563 def send(self, msg): 564 self.wr.send(msg) 565 return self.call_release() 566 def write(self, msg): 567 return self.send(msg) 568 569 570 class AutomatonException(Exception): 571 def __init__(self, msg, state=None, result=None): 572 Exception.__init__(self, msg) 573 self.state = state 574 self.result = result 575 576 class AutomatonError(AutomatonException): 577 pass 578 class ErrorState(AutomatonException): 579 pass 580 class Stuck(AutomatonException): 581 pass 582 class AutomatonStopped(AutomatonException): 583 pass 584 585 class Breakpoint(AutomatonStopped): 586 pass 587 class Singlestep(AutomatonStopped): 588 pass 589 class InterceptionPoint(AutomatonStopped): 590 def __init__(self, msg, state=None, result=None, packet=None): 591 Automaton.AutomatonStopped.__init__(self, msg, state=state, result=result) 592 self.packet = packet 593 594 class CommandMessage(AutomatonException): 595 pass 596 597 598 ## Services 599 def debug(self, lvl, msg): 600 if self.debug_level >= lvl: 601 log_interactive.debug(msg) 602 603 def send(self, pkt): 604 if self.state.state in self.interception_points: 605 self.debug(3,"INTERCEPT: packet intercepted: %s" % pkt.summary()) 606 self.intercepted_packet = pkt 607 cmd = Message(type = _ATMT_Command.INTERCEPT, state=self.state, pkt=pkt) 608 self.cmdout.send(cmd) 609 cmd = self.cmdin.recv() 610 self.intercepted_packet = None 611 if cmd.type == _ATMT_Command.REJECT: 612 self.debug(3,"INTERCEPT: packet rejected") 613 return 614 elif cmd.type == _ATMT_Command.REPLACE: 615 pkt = cmd.pkt 616 self.debug(3,"INTERCEPT: packet replaced by: %s" % pkt.summary()) 617 elif cmd.type == _ATMT_Command.ACCEPT: 618 self.debug(3,"INTERCEPT: packet accepted") 619 else: 620 raise self.AutomatonError("INTERCEPT: unkown verdict: %r" % cmd.type) 621 self.my_send(pkt) 622 self.debug(3,"SENT : %s" % pkt.summary()) 623 624 if self.store_packets: 625 self.packets.append(pkt.copy()) 626 627 628 ## Internals 629 def __init__(self, *args, **kargs): 630 external_fd = kargs.pop("external_fd",{}) 631 self.send_sock_class = kargs.pop("ll", conf.L3socket) 632 self.recv_sock_class = kargs.pop("recvsock", conf.L2listen) 633 self.started = threading.Lock() 634 self.threadid = None 635 self.breakpointed = None 636 self.breakpoints = set() 637 self.interception_points = set() 638 self.intercepted_packet = None 639 self.debug_level=0 640 self.init_args=args 641 self.init_kargs=kargs 642 self.io = type.__new__(type, "IOnamespace",(),{}) 643 self.oi = type.__new__(type, "IOnamespace",(),{}) 644 self.cmdin = ObjectPipe() 645 self.cmdout = ObjectPipe() 646 self.ioin = {} 647 self.ioout = {} 648 for n in self.ionames: 649 extfd = external_fd.get(n) 650 if not isinstance(extfd, tuple): 651 extfd = (extfd,extfd) 652 elif WINDOWS: 653 raise OSError("Tuples are not allowed as external_fd on windows") 654 ioin,ioout = extfd 655 if ioin is None: 656 ioin = ObjectPipe() 657 elif not isinstance(ioin, SelectableObject): 658 ioin = self._IO_fdwrapper(ioin,None) 659 if ioout is None: 660 ioout = ioin if WINDOWS else ObjectPipe() 661 elif not isinstance(ioout, SelectableObject): 662 ioout = self._IO_fdwrapper(None,ioout) 663 664 self.ioin[n] = ioin 665 self.ioout[n] = ioout 666 ioin.ioname = n 667 ioout.ioname = n 668 setattr(self.io, n, self._IO_mixer(ioout,ioin)) 669 setattr(self.oi, n, self._IO_mixer(ioin,ioout)) 670 671 for stname in self.states: 672 setattr(self, stname, 673 _instance_state(getattr(self, stname))) 674 675 self.start() 676 677 def __iter__(self): 678 return self 679 680 def __del__(self): 681 self.stop() 682 683 def _run_condition(self, cond, *args, **kargs): 684 try: 685 self.debug(5, "Trying %s [%s]" % (cond.atmt_type, cond.atmt_condname)) 686 cond(self,*args, **kargs) 687 except ATMT.NewStateRequested as state_req: 688 self.debug(2, "%s [%s] taken to state [%s]" % (cond.atmt_type, cond.atmt_condname, state_req.state)) 689 if cond.atmt_type == ATMT.RECV: 690 if self.store_packets: 691 self.packets.append(args[0]) 692 for action in self.actions[cond.atmt_condname]: 693 self.debug(2, " + Running action [%s]" % action.__name__) 694 action(self, *state_req.action_args, **state_req.action_kargs) 695 raise 696 except Exception as e: 697 self.debug(2, "%s [%s] raised exception [%s]" % (cond.atmt_type, cond.atmt_condname, e)) 698 raise 699 else: 700 self.debug(2, "%s [%s] not taken" % (cond.atmt_type, cond.atmt_condname)) 701 702 def _do_start(self, *args, **kargs): 703 ready = threading.Event() 704 _t = threading.Thread(target=self._do_control, args=(ready,) + (args), kwargs=kargs) 705 _t.setDaemon(True) 706 _t.start() 707 ready.wait() 708 709 def _do_control(self, ready, *args, **kargs): 710 with self.started: 711 self.threadid = threading.currentThread().ident 712 713 # Update default parameters 714 a = args+self.init_args[len(args):] 715 k = self.init_kargs.copy() 716 k.update(kargs) 717 self.parse_args(*a,**k) 718 719 # Start the automaton 720 self.state=self.initial_states[0](self) 721 self.send_sock = self.send_sock_class(**self.socket_kargs) 722 self.listen_sock = self.recv_sock_class(**self.socket_kargs) 723 self.packets = PacketList(name="session[%s]"%self.__class__.__name__) 724 725 singlestep = True 726 iterator = self._do_iter() 727 self.debug(3, "Starting control thread [tid=%i]" % self.threadid) 728 # Sync threads 729 ready.set() 730 try: 731 while True: 732 c = self.cmdin.recv() 733 self.debug(5, "Received command %s" % c.type) 734 if c.type == _ATMT_Command.RUN: 735 singlestep = False 736 elif c.type == _ATMT_Command.NEXT: 737 singlestep = True 738 elif c.type == _ATMT_Command.FREEZE: 739 continue 740 elif c.type == _ATMT_Command.STOP: 741 break 742 while True: 743 state = next(iterator) 744 if isinstance(state, self.CommandMessage): 745 break 746 elif isinstance(state, self.Breakpoint): 747 c = Message(type=_ATMT_Command.BREAKPOINT,state=state) 748 self.cmdout.send(c) 749 break 750 if singlestep: 751 c = Message(type=_ATMT_Command.SINGLESTEP,state=state) 752 self.cmdout.send(c) 753 break 754 except StopIteration as e: 755 c = Message(type=_ATMT_Command.END, result=e.args[0]) 756 self.cmdout.send(c) 757 except Exception as e: 758 exc_info = sys.exc_info() 759 self.debug(3, "Transfering exception from tid=%i:\n%s"% (self.threadid, traceback.format_exception(*exc_info))) 760 m = Message(type=_ATMT_Command.EXCEPTION, exception=e, exc_info=exc_info) 761 self.cmdout.send(m) 762 self.debug(3, "Stopping control thread (tid=%i)"%self.threadid) 763 self.threadid = None 764 765 def _do_iter(self): 766 while True: 767 try: 768 self.debug(1, "## state=[%s]" % self.state.state) 769 770 # Entering a new state. First, call new state function 771 if self.state.state in self.breakpoints and self.state.state != self.breakpointed: 772 self.breakpointed = self.state.state 773 yield self.Breakpoint("breakpoint triggered on state %s" % self.state.state, 774 state = self.state.state) 775 self.breakpointed = None 776 state_output = self.state.run() 777 if self.state.error: 778 raise self.ErrorState("Reached %s: [%r]" % (self.state.state, state_output), 779 result=state_output, state=self.state.state) 780 if self.state.final: 781 raise StopIteration(state_output) 782 783 if state_output is None: 784 state_output = () 785 elif not isinstance(state_output, list): 786 state_output = state_output, 787 788 # Then check immediate conditions 789 for cond in self.conditions[self.state.state]: 790 self._run_condition(cond, *state_output) 791 792 # If still there and no conditions left, we are stuck! 793 if ( len(self.recv_conditions[self.state.state]) == 0 and 794 len(self.ioevents[self.state.state]) == 0 and 795 len(self.timeout[self.state.state]) == 1 ): 796 raise self.Stuck("stuck in [%s]" % self.state.state, 797 state=self.state.state, result=state_output) 798 799 # Finally listen and pay attention to timeouts 800 expirations = iter(self.timeout[self.state.state]) 801 next_timeout,timeout_func = next(expirations) 802 t0 = time.time() 803 804 fds = [self.cmdin] 805 if len(self.recv_conditions[self.state.state]) > 0: 806 fds.append(self.listen_sock) 807 for ioev in self.ioevents[self.state.state]: 808 fds.append(self.ioin[ioev.atmt_ioname]) 809 while True: 810 t = time.time()-t0 811 if next_timeout is not None: 812 if next_timeout <= t: 813 self._run_condition(timeout_func, *state_output) 814 next_timeout,timeout_func = next(expirations) 815 if next_timeout is None: 816 remain = None 817 else: 818 remain = next_timeout-t 819 820 self.debug(5, "Select on %r" % fds) 821 r = select_objects(fds, remain) 822 self.debug(5, "Selected %r" % r) 823 for fd in r: 824 self.debug(5, "Looking at %r" % fd) 825 if fd == self.cmdin: 826 yield self.CommandMessage("Received command message") 827 elif fd == self.listen_sock: 828 try: 829 pkt = self.listen_sock.recv(MTU) 830 except recv_error: 831 pass 832 else: 833 if pkt is not None: 834 if self.master_filter(pkt): 835 self.debug(3, "RECVD: %s" % pkt.summary()) 836 for rcvcond in self.recv_conditions[self.state.state]: 837 self._run_condition(rcvcond, pkt, *state_output) 838 else: 839 self.debug(4, "FILTR: %s" % pkt.summary()) 840 else: 841 self.debug(3, "IOEVENT on %s" % fd.ioname) 842 for ioevt in self.ioevents[self.state.state]: 843 if ioevt.atmt_ioname == fd.ioname: 844 self._run_condition(ioevt, fd, *state_output) 845 846 except ATMT.NewStateRequested as state_req: 847 self.debug(2, "switching from [%s] to [%s]" % (self.state.state,state_req.state)) 848 self.state = state_req 849 yield state_req 850 851 ## Public API 852 def add_interception_points(self, *ipts): 853 for ipt in ipts: 854 if hasattr(ipt,"atmt_state"): 855 ipt = ipt.atmt_state 856 self.interception_points.add(ipt) 857 858 def remove_interception_points(self, *ipts): 859 for ipt in ipts: 860 if hasattr(ipt,"atmt_state"): 861 ipt = ipt.atmt_state 862 self.interception_points.discard(ipt) 863 864 def add_breakpoints(self, *bps): 865 for bp in bps: 866 if hasattr(bp,"atmt_state"): 867 bp = bp.atmt_state 868 self.breakpoints.add(bp) 869 870 def remove_breakpoints(self, *bps): 871 for bp in bps: 872 if hasattr(bp,"atmt_state"): 873 bp = bp.atmt_state 874 self.breakpoints.discard(bp) 875 876 def start(self, *args, **kargs): 877 if not self.started.locked(): 878 self._do_start(*args, **kargs) 879 880 def run(self, resume=None, wait=True): 881 if resume is None: 882 resume = Message(type = _ATMT_Command.RUN) 883 self.cmdin.send(resume) 884 if wait: 885 try: 886 c = self.cmdout.recv() 887 except KeyboardInterrupt: 888 self.cmdin.send(Message(type = _ATMT_Command.FREEZE)) 889 return 890 if c.type == _ATMT_Command.END: 891 return c.result 892 elif c.type == _ATMT_Command.INTERCEPT: 893 raise self.InterceptionPoint("packet intercepted", state=c.state.state, packet=c.pkt) 894 elif c.type == _ATMT_Command.SINGLESTEP: 895 raise self.Singlestep("singlestep state=[%s]"%c.state.state, state=c.state.state) 896 elif c.type == _ATMT_Command.BREAKPOINT: 897 raise self.Breakpoint("breakpoint triggered on state [%s]"%c.state.state, state=c.state.state) 898 elif c.type == _ATMT_Command.EXCEPTION: 899 six.reraise(c.exc_info[0], c.exc_info[1], c.exc_info[2]) 900 901 def runbg(self, resume=None, wait=False): 902 self.run(resume, wait) 903 904 def next(self): 905 return self.run(resume = Message(type=_ATMT_Command.NEXT)) 906 __next__ = next 907 908 def stop(self): 909 self.cmdin.send(Message(type=_ATMT_Command.STOP)) 910 with self.started: 911 # Flush command pipes 912 while True: 913 r = select_objects([self.cmdin, self.cmdout], 0) 914 if not r: 915 break 916 for fd in r: 917 fd.recv() 918 919 def restart(self, *args, **kargs): 920 self.stop() 921 self.start(*args, **kargs) 922 923 def accept_packet(self, pkt=None, wait=False): 924 rsm = Message() 925 if pkt is None: 926 rsm.type = _ATMT_Command.ACCEPT 927 else: 928 rsm.type = _ATMT_Command.REPLACE 929 rsm.pkt = pkt 930 return self.run(resume=rsm, wait=wait) 931 932 def reject_packet(self, wait=False): 933 rsm = Message(type = _ATMT_Command.REJECT) 934 return self.run(resume=rsm, wait=wait) 935 936 937 938