Home | History | Annotate | Download | only in scapy
      1 ## This file is part of Scapy
      2 ## See http://www.secdev.org/projects/scapy for more informations
      3 ## Copyright (C) Philippe Biondi <phil (at] secdev.org>
      4 ## This program is published under a GPLv2 license
      5 
      6 """
      7 PacketList: holds several packets and allows to do operations on them.
      8 """
      9 
     10 
     11 from __future__ import absolute_import
     12 from __future__ import print_function
     13 import os,subprocess
     14 from collections import defaultdict
     15 
     16 from scapy.config import conf
     17 from scapy.base_classes import BasePacket,BasePacketList
     18 from scapy.utils import do_graph,hexdump,make_table,make_lined_table,make_tex_table,get_temp_file
     19 
     20 from scapy.consts import plt, MATPLOTLIB_INLINED, MATPLOTLIB_DEFAULT_PLOT_KARGS
     21 from functools import reduce
     22 import scapy.modules.six as six
     23 from scapy.modules.six.moves import filter, range, zip
     24 
     25 
     26 #############
     27 ## Results ##
     28 #############
     29 
     30 class PacketList(BasePacketList):
     31     __slots__ = ["stats", "res", "listname"]
     32     def __init__(self, res=None, name="PacketList", stats=None):
     33         """create a packet list from a list of packets
     34            res: the list of packets
     35            stats: a list of classes that will appear in the stats (defaults to [TCP,UDP,ICMP])"""
     36         if stats is None:
     37             stats = conf.stats_classic_protocols
     38         self.stats = stats
     39         if res is None:
     40             res = []
     41         elif isinstance(res, PacketList):
     42             res = res.res
     43         self.res = res
     44         self.listname = name
     45     def __len__(self):
     46         return len(self.res)
     47     def _elt2pkt(self, elt):
     48         return elt
     49     def _elt2sum(self, elt):
     50         return elt.summary()
     51     def _elt2show(self, elt):
     52         return self._elt2sum(elt)
     53     def __repr__(self):
     54         stats = {x: 0 for x in self.stats}
     55         other = 0
     56         for r in self.res:
     57             f = 0
     58             for p in stats:
     59                 if self._elt2pkt(r).haslayer(p):
     60                     stats[p] += 1
     61                     f = 1
     62                     break
     63             if not f:
     64                 other += 1
     65         s = ""
     66         ct = conf.color_theme
     67         for p in self.stats:
     68             s += " %s%s%s" % (ct.packetlist_proto(p._name),
     69                               ct.punct(":"),
     70                               ct.packetlist_value(stats[p]))
     71         s += " %s%s%s" % (ct.packetlist_proto("Other"),
     72                           ct.punct(":"),
     73                           ct.packetlist_value(other))
     74         return "%s%s%s%s%s" % (ct.punct("<"),
     75                                ct.packetlist_name(self.listname),
     76                                ct.punct(":"),
     77                                s,
     78                                ct.punct(">"))
     79     def __getattr__(self, attr):
     80         return getattr(self.res, attr)
     81     def __getitem__(self, item):
     82         if isinstance(item,type) and issubclass(item,BasePacket):
     83             return self.__class__([x for x in self.res if item in self._elt2pkt(x)],
     84                                   name="%s from %s"%(item.__name__,self.listname))
     85         if isinstance(item, slice):
     86             return self.__class__(self.res.__getitem__(item),
     87                                   name = "mod %s" % self.listname)
     88         return self.res.__getitem__(item)
     89     def __getslice__(self, *args, **kargs):
     90         return self.__class__(self.res.__getslice__(*args, **kargs),
     91                               name="mod %s"%self.listname)
     92     def __add__(self, other):
     93         return self.__class__(self.res+other.res,
     94                               name="%s+%s"%(self.listname,other.listname))
     95     def summary(self, prn=None, lfilter=None):
     96         """prints a summary of each packet
     97 prn:     function to apply to each packet instead of lambda x:x.summary()
     98 lfilter: truth function to apply to each packet to decide whether it will be displayed"""
     99         for r in self.res:
    100             if lfilter is not None:
    101                 if not lfilter(r):
    102                     continue
    103             if prn is None:
    104                 print(self._elt2sum(r))
    105             else:
    106                 print(prn(r))
    107     def nsummary(self, prn=None, lfilter=None):
    108         """prints a summary of each packet with the packet's number
    109 prn:     function to apply to each packet instead of lambda x:x.summary()
    110 lfilter: truth function to apply to each packet to decide whether it will be displayed"""
    111         for i, res in enumerate(self.res):
    112             if lfilter is not None:
    113                 if not lfilter(res):
    114                     continue
    115             print(conf.color_theme.id(i,fmt="%04i"), end=' ')
    116             if prn is None:
    117                 print(self._elt2sum(res))
    118             else:
    119                 print(prn(res))
    120     def display(self): # Deprecated. Use show()
    121         """deprecated. is show()"""
    122         self.show()
    123     def show(self, *args, **kargs):
    124         """Best way to display the packet list. Defaults to nsummary() method"""
    125         return self.nsummary(*args, **kargs)
    126     
    127     def filter(self, func):
    128         """Returns a packet list filtered by a truth function"""
    129         return self.__class__([x for x in self.res if func(x)],
    130                               name="filtered %s"%self.listname)
    131     def make_table(self, *args, **kargs):
    132         """Prints a table using a function that returns for each packet its head column value, head row value and displayed value
    133         ex: p.make_table(lambda x:(x[IP].dst, x[TCP].dport, x[TCP].sprintf("%flags%")) """
    134         return make_table(self.res, *args, **kargs)
    135     def make_lined_table(self, *args, **kargs):
    136         """Same as make_table, but print a table with lines"""
    137         return make_lined_table(self.res, *args, **kargs)
    138     def make_tex_table(self, *args, **kargs):
    139         """Same as make_table, but print a table with LaTeX syntax"""
    140         return make_tex_table(self.res, *args, **kargs)
    141 
    142     def plot(self, f, lfilter=None, plot_xy=False, **kargs):
    143         """Applies a function to each packet to get a value that will be plotted
    144         with matplotlib. A list of matplotlib.lines.Line2D is returned.
    145 
    146         lfilter: a truth function that decides whether a packet must be plotted
    147         """
    148 
    149         # Get the list of packets
    150         if lfilter is None:
    151             l = [f(e) for e in self.res]
    152         else:
    153             l = [f(e) for e in self.res if lfilter(e)]
    154 
    155         # Mimic the default gnuplot output
    156         if kargs == {}:
    157             kargs = MATPLOTLIB_DEFAULT_PLOT_KARGS
    158         if plot_xy:
    159             lines = plt.plot(*zip(*l), **kargs)
    160         else:
    161             lines = plt.plot(l, **kargs)
    162 
    163         # Call show() if matplotlib is not inlined
    164         if not MATPLOTLIB_INLINED:
    165             plt.show()
    166 
    167         return lines
    168 
    169     def diffplot(self, f, delay=1, lfilter=None, **kargs):
    170         """diffplot(f, delay=1, lfilter=None)
    171         Applies a function to couples (l[i],l[i+delay])
    172 
    173         A list of matplotlib.lines.Line2D is returned.
    174         """
    175 
    176         # Get the list of packets
    177         if lfilter is None:
    178             l = [f(self.res[i], self.res[i+1])
    179                     for i in range(len(self.res) - delay)]
    180         else:
    181             l = [f(self.res[i], self.res[i+1])
    182                     for i in range(len(self.res) - delay)
    183                         if lfilter(self.res[i])]
    184 
    185         # Mimic the default gnuplot output
    186         if kargs == {}:
    187             kargs = MATPLOTLIB_DEFAULT_PLOT_KARGS
    188         lines = plt.plot(l, **kargs)
    189 
    190         # Call show() if matplotlib is not inlined
    191         if not MATPLOTLIB_INLINED:
    192             plt.show()
    193 
    194         return lines
    195 
    196     def multiplot(self, f, lfilter=None, plot_xy=False, **kargs):
    197         """Uses a function that returns a label and a value for this label, then
    198         plots all the values label by label.
    199 
    200         A list of matplotlib.lines.Line2D is returned.
    201         """
    202 
    203         # Get the list of packets
    204         if lfilter is None:
    205             l = (f(e) for e in self.res)
    206         else:
    207             l = (f(e) for e in self.res if lfilter(e))
    208 
    209         # Apply the function f to the packets
    210         d = {}
    211         for k, v in l:
    212             d.setdefault(k, []).append(v)
    213 
    214         # Mimic the default gnuplot output
    215         if not kargs:
    216             kargs = MATPLOTLIB_DEFAULT_PLOT_KARGS
    217 
    218         if plot_xy:
    219             lines = [plt.plot(*zip(*pl), **dict(kargs, label=k))
    220                      for k, pl in six.iteritems(d)]
    221         else:
    222             lines = [plt.plot(pl, **dict(kargs, label=k))
    223                      for k, pl in six.iteritems(d)]
    224         plt.legend(loc="center right", bbox_to_anchor=(1.5, 0.5))
    225 
    226         # Call show() if matplotlib is not inlined
    227         if not MATPLOTLIB_INLINED:
    228             plt.show()
    229 
    230         return lines
    231 
    232     def rawhexdump(self):
    233         """Prints an hexadecimal dump of each packet in the list"""
    234         for p in self:
    235             hexdump(self._elt2pkt(p))
    236 
    237     def hexraw(self, lfilter=None):
    238         """Same as nsummary(), except that if a packet has a Raw layer, it will be hexdumped
    239         lfilter: a truth function that decides whether a packet must be displayed"""
    240         for i, res in enumerate(self.res):
    241             p = self._elt2pkt(res)
    242             if lfilter is not None and not lfilter(p):
    243                 continue
    244             print("%s %s %s" % (conf.color_theme.id(i,fmt="%04i"),
    245                                 p.sprintf("%.time%"),
    246                                 self._elt2sum(res)))
    247             if p.haslayer(conf.raw_layer):
    248                 hexdump(p.getlayer(conf.raw_layer).load)
    249 
    250     def hexdump(self, lfilter=None):
    251         """Same as nsummary(), except that packets are also hexdumped
    252         lfilter: a truth function that decides whether a packet must be displayed"""
    253         for i, res in enumerate(self.res):
    254             p = self._elt2pkt(res)
    255             if lfilter is not None and not lfilter(p):
    256                 continue
    257             print("%s %s %s" % (conf.color_theme.id(i,fmt="%04i"),
    258                                 p.sprintf("%.time%"),
    259                                 self._elt2sum(res)))
    260             hexdump(p)
    261 
    262     def padding(self, lfilter=None):
    263         """Same as hexraw(), for Padding layer"""
    264         for i, res in enumerate(self.res):
    265             p = self._elt2pkt(res)
    266             if p.haslayer(conf.padding_layer):
    267                 if lfilter is None or lfilter(p):
    268                     print("%s %s %s" % (conf.color_theme.id(i,fmt="%04i"),
    269                                         p.sprintf("%.time%"),
    270                                         self._elt2sum(res)))
    271                     hexdump(p.getlayer(conf.padding_layer).load)
    272 
    273     def nzpadding(self, lfilter=None):
    274         """Same as padding() but only non null padding"""
    275         for i, res in enumerate(self.res):
    276             p = self._elt2pkt(res)
    277             if p.haslayer(conf.padding_layer):
    278                 pad = p.getlayer(conf.padding_layer).load
    279                 if pad == pad[0]*len(pad):
    280                     continue
    281                 if lfilter is None or lfilter(p):
    282                     print("%s %s %s" % (conf.color_theme.id(i,fmt="%04i"),
    283                                         p.sprintf("%.time%"),
    284                                         self._elt2sum(res)))
    285                     hexdump(p.getlayer(conf.padding_layer).load)
    286         
    287 
    288     def conversations(self, getsrcdst=None,**kargs):
    289         """Graphes a conversations between sources and destinations and display it
    290         (using graphviz and imagemagick)
    291         getsrcdst: a function that takes an element of the list and
    292                    returns the source, the destination and optionally
    293                    a label. By default, returns the IP source and
    294                    destination from IP and ARP layers
    295         type: output type (svg, ps, gif, jpg, etc.), passed to dot's "-T" option
    296         target: filename or redirect. Defaults pipe to Imagemagick's display program
    297         prog: which graphviz program to use"""
    298         if getsrcdst is None:
    299             def getsrcdst(pkt):
    300                 if 'IP' in pkt:
    301                     return (pkt['IP'].src, pkt['IP'].dst)
    302                 if 'ARP' in pkt:
    303                     return (pkt['ARP'].psrc, pkt['ARP'].pdst)
    304                 raise TypeError()
    305         conv = {}
    306         for p in self.res:
    307             p = self._elt2pkt(p)
    308             try:
    309                 c = getsrcdst(p)
    310             except:
    311                 # No warning here: it's OK that getsrcdst() raises an
    312                 # exception, since it might be, for example, a
    313                 # function that expects a specific layer in each
    314                 # packet. The try/except approach is faster and
    315                 # considered more Pythonic than adding tests.
    316                 continue
    317             if len(c) == 3:
    318                 conv.setdefault(c[:2], set()).add(c[2])
    319             else:
    320                 conv[c] = conv.get(c, 0) + 1
    321         gr = 'digraph "conv" {\n'
    322         for (s, d), l in six.iteritems(conv):
    323             gr += '\t "%s" -> "%s" [label="%s"]\n' % (
    324                 s, d, ', '.join(str(x) for x in l) if isinstance(l, set) else l
    325             )
    326         gr += "}\n"        
    327         return do_graph(gr, **kargs)
    328 
    329     def afterglow(self, src=None, event=None, dst=None, **kargs):
    330         """Experimental clone attempt of http://sourceforge.net/projects/afterglow
    331         each datum is reduced as src -> event -> dst and the data are graphed.
    332         by default we have IP.src -> IP.dport -> IP.dst"""
    333         if src is None:
    334             src = lambda x: x['IP'].src
    335         if event is None:
    336             event = lambda x: x['IP'].dport
    337         if dst is None:
    338             dst = lambda x: x['IP'].dst
    339         sl = {}
    340         el = {}
    341         dl = {}
    342         for i in self.res:
    343             try:
    344                 s,e,d = src(i),event(i),dst(i)
    345                 if s in sl:
    346                     n,l = sl[s]
    347                     n += 1
    348                     if e not in l:
    349                         l.append(e)
    350                     sl[s] = (n,l)
    351                 else:
    352                     sl[s] = (1,[e])
    353                 if e in el:
    354                     n,l = el[e]
    355                     n+=1
    356                     if d not in l:
    357                         l.append(d)
    358                     el[e] = (n,l)
    359                 else:
    360                     el[e] = (1,[d])
    361                 dl[d] = dl.get(d,0)+1
    362             except:
    363                 continue
    364 
    365         import math
    366         def normalize(n):
    367             return 2+math.log(n)/4.0
    368 
    369         def minmax(x):
    370             m, M = reduce(lambda a, b: (min(a[0], b[0]), max(a[1], b[1])),
    371                           ((a, a) for a in x))
    372             if m == M:
    373                 m = 0
    374             if M == 0:
    375                 M = 1
    376             return m, M
    377 
    378         mins, maxs = minmax(x for x, _ in six.itervalues(sl))
    379         mine, maxe = minmax(x for x, _ in six.itervalues(el))
    380         mind, maxd = minmax(six.itervalues(dl))
    381     
    382         gr = 'digraph "afterglow" {\n\tedge [len=2.5];\n'
    383 
    384         gr += "# src nodes\n"
    385         for s in sl:
    386             n,l = sl[s]; n = 1+float(n-mins)/(maxs-mins)
    387             gr += '"src.%s" [label = "%s", shape=box, fillcolor="#FF0000", style=filled, fixedsize=1, height=%.2f,width=%.2f];\n' % (repr(s),repr(s),n,n)
    388         gr += "# event nodes\n"
    389         for e in el:
    390             n,l = el[e]; n = n = 1+float(n-mine)/(maxe-mine)
    391             gr += '"evt.%s" [label = "%s", shape=circle, fillcolor="#00FFFF", style=filled, fixedsize=1, height=%.2f, width=%.2f];\n' % (repr(e),repr(e),n,n)
    392         for d in dl:
    393             n = dl[d]; n = n = 1+float(n-mind)/(maxd-mind)
    394             gr += '"dst.%s" [label = "%s", shape=triangle, fillcolor="#0000ff", style=filled, fixedsize=1, height=%.2f, width=%.2f];\n' % (repr(d),repr(d),n,n)
    395 
    396         gr += "###\n"
    397         for s in sl:
    398             n,l = sl[s]
    399             for e in l:
    400                 gr += ' "src.%s" -> "evt.%s";\n' % (repr(s),repr(e)) 
    401         for e in el:
    402             n,l = el[e]
    403             for d in l:
    404                 gr += ' "evt.%s" -> "dst.%s";\n' % (repr(e),repr(d)) 
    405             
    406         gr += "}"
    407         return do_graph(gr, **kargs)
    408 
    409 
    410     def _dump_document(self, **kargs):
    411         import pyx
    412         d = pyx.document.document()
    413         l = len(self.res)
    414         for i, res in enumerate(self.res):
    415             c = self._elt2pkt(res).canvas_dump(**kargs)
    416             cbb = c.bbox()
    417             c.text(cbb.left(),cbb.top()+1,r"\font\cmssfont=cmss12\cmssfont{Frame %i/%i}" % (i,l),[pyx.text.size.LARGE])
    418             if conf.verb >= 2:
    419                 os.write(1, b".")
    420             d.append(pyx.document.page(c, paperformat=pyx.document.paperformat.A4,
    421                                        margin=1*pyx.unit.t_cm,
    422                                        fittosize=1))
    423         return d
    424                      
    425                  
    426 
    427     def psdump(self, filename = None, **kargs):
    428         """Creates a multi-page postcript file with a psdump of every packet
    429         filename: name of the file to write to. If empty, a temporary file is used and
    430                   conf.prog.psreader is called"""
    431         d = self._dump_document(**kargs)
    432         if filename is None:
    433             filename = get_temp_file(autoext=".ps")
    434             d.writePSfile(filename)
    435             with ContextManagerSubprocess("psdump()"):
    436                 subprocess.Popen([conf.prog.psreader, filename+".ps"])
    437         else:
    438             d.writePSfile(filename)
    439         print()
    440         
    441     def pdfdump(self, filename = None, **kargs):
    442         """Creates a PDF file with a psdump of every packet
    443         filename: name of the file to write to. If empty, a temporary file is used and
    444                   conf.prog.pdfreader is called"""
    445         d = self._dump_document(**kargs)
    446         if filename is None:
    447             filename = get_temp_file(autoext=".pdf")
    448             d.writePDFfile(filename)
    449             with ContextManagerSubprocess("psdump()"):
    450                 subprocess.Popen([conf.prog.pdfreader, filename+".pdf"])
    451         else:
    452             d.writePDFfile(filename)
    453         print()
    454 
    455     def sr(self,multi=0):
    456         """sr([multi=1]) -> (SndRcvList, PacketList)
    457         Matches packets in the list and return ( (matched couples), (unmatched packets) )"""
    458         remain = self.res[:]
    459         sr = []
    460         i = 0
    461         while i < len(remain):
    462             s = remain[i]
    463             j = i
    464             while j < len(remain)-1:
    465                 j += 1
    466                 r = remain[j]
    467                 if r.answers(s):
    468                     sr.append((s,r))
    469                     if multi:
    470                         remain[i]._answered=1
    471                         remain[j]._answered=2
    472                         continue
    473                     del(remain[j])
    474                     del(remain[i])
    475                     i -= 1
    476                     break
    477             i += 1
    478         if multi:
    479             remain = [x for x in remain if not hasattr(x, "_answered")]
    480         return SndRcvList(sr),PacketList(remain)
    481 
    482     def sessions(self, session_extractor=None):
    483         if session_extractor is None:
    484             def session_extractor(p):
    485                 sess = "Other"
    486                 if 'Ether' in p:
    487                     if 'IP' in p:
    488                         if 'TCP' in p:
    489                             sess = p.sprintf("TCP %IP.src%:%r,TCP.sport% > %IP.dst%:%r,TCP.dport%")
    490                         elif 'UDP' in p:
    491                             sess = p.sprintf("UDP %IP.src%:%r,UDP.sport% > %IP.dst%:%r,UDP.dport%")
    492                         elif 'ICMP' in p:
    493                             sess = p.sprintf("ICMP %IP.src% > %IP.dst% type=%r,ICMP.type% code=%r,ICMP.code% id=%ICMP.id%")
    494                         else:
    495                             sess = p.sprintf("IP %IP.src% > %IP.dst% proto=%IP.proto%")
    496                     elif 'ARP' in p:
    497                         sess = p.sprintf("ARP %ARP.psrc% > %ARP.pdst%")
    498                     else:
    499                         sess = p.sprintf("Ethernet type=%04xr,Ether.type%")
    500                 return sess
    501         sessions = defaultdict(self.__class__)
    502         for p in self.res:
    503             sess = session_extractor(self._elt2pkt(p))
    504             sessions[sess].append(p)
    505         return dict(sessions)
    506     
    507     def replace(self, *args, **kargs):
    508         """
    509         lst.replace(<field>,[<oldvalue>,]<newvalue>)
    510         lst.replace( (fld,[ov],nv),(fld,[ov,]nv),...)
    511           if ov is None, all values are replaced
    512         ex:
    513           lst.replace( IP.src, "192.168.1.1", "10.0.0.1" )
    514           lst.replace( IP.ttl, 64 )
    515           lst.replace( (IP.ttl, 64), (TCP.sport, 666, 777), )
    516         """
    517         delete_checksums = kargs.get("delete_checksums",False)
    518         x=PacketList(name="Replaced %s" % self.listname)
    519         if not isinstance(args[0], tuple):
    520             args = (args,)
    521         for p in self.res:
    522             p = self._elt2pkt(p)
    523             copied = False
    524             for scheme in args:
    525                 fld = scheme[0]
    526                 old = scheme[1] # not used if len(scheme) == 2
    527                 new = scheme[-1]
    528                 for o in fld.owners:
    529                     if o in p:
    530                         if len(scheme) == 2 or p[o].getfieldval(fld.name) == old:
    531                             if not copied:
    532                                 p = p.copy()
    533                                 if delete_checksums:
    534                                     p.delete_checksums()
    535                                 copied = True
    536                             setattr(p[o], fld.name, new)
    537             x.append(p)
    538         return x
    539 
    540 
    541 class SndRcvList(PacketList):
    542     __slots__ = []
    543     def __init__(self, res=None, name="Results", stats=None):
    544         PacketList.__init__(self, res, name, stats)
    545     def _elt2pkt(self, elt):
    546         return elt[1]
    547     def _elt2sum(self, elt):
    548         return "%s ==> %s" % (elt[0].summary(),elt[1].summary()) 
    549