Home | History | Annotate | Download | only in scapy
      1 ## This file is part of Scapy
## See http://www.secdev.org/projects/scapy for more information
      3 ## Copyright (C) Philippe Biondi <phil (at] secdev.org>
      4 ## This program is published under a GPLv2 license
      5 
      6 """
      7 Generators and packet meta classes.
      8 """
      9 
     10 ###############
     11 ## Generators ##
     12 ################
     13 
     14 from __future__ import absolute_import
     15 import re,random,socket
     16 import types
     17 from scapy.modules.six.moves import range
     18 
     19 class Gen(object):
     20     __slots__ = []
     21     def __iter__(self):
     22         return iter([])
     23     
     24 class SetGen(Gen):
     25     def __init__(self, values, _iterpacket=1):
     26         self._iterpacket=_iterpacket
     27         if isinstance(values, (list, BasePacketList)):
     28             self.values = list(values)
     29         elif (isinstance(values, tuple) and (2 <= len(values) <= 3) and \
     30              all(hasattr(i, "__int__") for i in values)):
     31             # We use values[1] + 1 as stop value for (x)range to maintain
     32             # the behavior of using tuples as field `values`
     33             self.values = [range(*((int(values[0]), int(values[1]) + 1)
     34                                     + tuple(int(v) for v in values[2:])))]
     35         else:
     36             self.values = [values]
     37     def transf(self, element):
     38         return element
     39     def __iter__(self):
     40         for i in self.values:
     41             if (isinstance(i, Gen) and
     42                 (self._iterpacket or not isinstance(i,BasePacket))) or (
     43                     isinstance(i, (range, types.GeneratorType))):
     44                 for j in i:
     45                     yield j
     46             else:
     47                 yield i
     48     def __repr__(self):
     49         return "<SetGen %r>" % self.values
     50 
     51 class Net(Gen):
     52     """Generate a list of IPs from a network address or a name"""
     53     name = "ip"
     54     ip_regex = re.compile(r"^(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)\.(\*|[0-2]?[0-9]?[0-9](-[0-2]?[0-9]?[0-9])?)(/[0-3]?[0-9])?$")
     55 
     56     @staticmethod
     57     def _parse_digit(a,netmask):
     58         netmask = min(8,max(netmask,0))
     59         if a == "*":
     60             a = (0,256)
     61         elif a.find("-") >= 0:
     62             x, y = [int(d) for d in a.split('-')]
     63             if x > y:
     64                 y = x
     65             a = (x &  (0xff<<netmask) , max(y, (x | (0xff>>(8-netmask))))+1)
     66         else:
     67             a = (int(a) & (0xff<<netmask),(int(a) | (0xff>>(8-netmask)))+1)
     68         return a
     69 
     70     @classmethod
     71     def _parse_net(cls, net):
     72         tmp=net.split('/')+["32"]
     73         if not cls.ip_regex.match(net):
     74             tmp[0]=socket.gethostbyname(tmp[0])
     75         netmask = int(tmp[1])
     76         ret_list = [cls._parse_digit(x, y-netmask) for (x, y) in zip(tmp[0].split('.'), [8, 16, 24, 32])]
     77         return ret_list, netmask
     78 
     79     def __init__(self, net):
     80         self.repr=net
     81         self.parsed,self.netmask = self._parse_net(net)
     82 
     83     def __str__(self):
     84         try:
     85             return next(self.__iter__())
     86         except StopIteration:
     87             return None
     88                                                                                                
     89     def __iter__(self):
     90         for d in range(*self.parsed[3]):
     91             for c in range(*self.parsed[2]):
     92                 for b in range(*self.parsed[1]):
     93                     for a in range(*self.parsed[0]):
     94                         yield "%i.%i.%i.%i" % (a,b,c,d)
     95     def choice(self):
     96         ip = []
     97         for v in self.parsed:
     98             ip.append(str(random.randint(v[0],v[1]-1)))
     99         return ".".join(ip) 
    100                           
    101     def __repr__(self):
    102         return "Net(%r)" % self.repr
    103     def __eq__(self, other):
    104         if hasattr(other, "parsed"):
    105             p2 = other.parsed
    106         else:
    107             p2,nm2 = self._parse_net(other)
    108         return self.parsed == p2
    109     def __contains__(self, other):
    110         if hasattr(other, "parsed"):
    111             p2 = other.parsed
    112         else:
    113             p2,nm2 = self._parse_net(other)
    114         for (a1,b1),(a2,b2) in zip(self.parsed,p2):
    115             if a1 > a2 or b1 < b2:
    116                 return False
    117         return True
    118     def __rcontains__(self, other):        
    119         return self in self.__class__(other)
    120         
    121 
    122 class OID(Gen):
    123     name = "OID"
    124     def __init__(self, oid):
    125         self.oid = oid        
    126         self.cmpt = []
    127         fmt = []        
    128         for i in oid.split("."):
    129             if "-" in i:
    130                 fmt.append("%i")
    131                 self.cmpt.append(tuple(map(int, i.split("-"))))
    132             else:
    133                 fmt.append(i)
    134         self.fmt = ".".join(fmt)
    135     def __repr__(self):
    136         return "OID(%r)" % self.oid
    137     def __iter__(self):        
    138         ii = [k[0] for k in self.cmpt]
    139         while True:
    140             yield self.fmt % tuple(ii)
    141             i = 0
    142             while True:
    143                 if i >= len(ii):
    144                     raise StopIteration
    145                 if ii[i] < self.cmpt[i][1]:
    146                     ii[i]+=1
    147                     break
    148                 else:
    149                     ii[i] = self.cmpt[i][0]
    150                 i += 1
    151 
    152 
    153  
    154 ######################################
    155 ## Packet abstract and base classes ##
    156 ######################################
    157 
    158 class Packet_metaclass(type):
    159     def __new__(cls, name, bases, dct):
    160         if "fields_desc" in dct: # perform resolution of references to other packets
    161             current_fld = dct["fields_desc"]
    162             resolved_fld = []
    163             for f in current_fld:
    164                 if isinstance(f, Packet_metaclass): # reference to another fields_desc
    165                     for f2 in f.fields_desc:
    166                         resolved_fld.append(f2)
    167                 else:
    168                     resolved_fld.append(f)
    169         else: # look for a fields_desc in parent classes
    170             resolved_fld = None
    171             for b in bases:
    172                 if hasattr(b,"fields_desc"):
    173                     resolved_fld = b.fields_desc
    174                     break
    175 
    176         if resolved_fld: # perform default value replacements
    177             final_fld = []
    178             for f in resolved_fld:
    179                 if f.name in dct:
    180                     f = f.copy()
    181                     f.default = dct[f.name]
    182                     del(dct[f.name])
    183                 final_fld.append(f)
    184 
    185             dct["fields_desc"] = final_fld
    186 
    187         if "__slots__" not in dct:
    188             dct["__slots__"] = []
    189         for attr in ["name", "overload_fields"]:
    190             try:
    191                 dct["_%s" % attr] = dct.pop(attr)
    192             except KeyError:
    193                 pass
    194         newcls = super(Packet_metaclass, cls).__new__(cls, name, bases, dct)
    195         newcls.__all_slots__ = set(
    196             attr
    197             for cls in newcls.__mro__ if hasattr(cls, "__slots__")
    198             for attr in cls.__slots__
    199         )
    200 
    201         if hasattr(newcls, "aliastypes"):
    202             newcls.aliastypes = [newcls] + newcls.aliastypes
    203         else:
    204             newcls.aliastypes = [newcls]
    205 
    206         if hasattr(newcls,"register_variant"):
    207             newcls.register_variant()
    208         for f in newcls.fields_desc:
    209             if hasattr(f, "register_owner"):
    210                 f.register_owner(newcls)
    211         from scapy import config
    212         config.conf.layers.register(newcls)
    213         return newcls
    214 
    215     def __getattr__(self, attr):
    216         for k in self.fields_desc:
    217             if k.name == attr:
    218                 return k
    219         raise AttributeError(attr)
    220 
    221     def __call__(cls, *args, **kargs):
    222         if "dispatch_hook" in cls.__dict__:
    223             try:
    224                 cls = cls.dispatch_hook(*args, **kargs)
    225             except:
    226                 from scapy import config
    227                 if config.conf.debug_dissector:
    228                     raise
    229                 cls = config.conf.raw_layer
    230         i = cls.__new__(cls, cls.__name__, cls.__bases__, cls.__dict__)
    231         i.__init__(*args, **kargs)
    232         return i
    233 
    234 class Field_metaclass(type):
    235     def __new__(cls, name, bases, dct):
    236         if "__slots__" not in dct:
    237             dct["__slots__"] = []
    238         newcls = super(Field_metaclass, cls).__new__(cls, name, bases, dct)
    239         return newcls
    240 
    241 class NewDefaultValues(Packet_metaclass):
    242     """NewDefaultValues is deprecated (not needed anymore)
    243     
    244     remove this:
    245         __metaclass__ = NewDefaultValues
    246     and it should still work.
    247     """    
    248     def __new__(cls, name, bases, dct):
    249         from scapy.error import log_loading
    250         import traceback
    251         try:
    252             for tb in traceback.extract_stack()+[("??",-1,None,"")]:
    253                 f,l,_,line = tb
    254                 if line.startswith("class"):
    255                     break
    256         except:
    257             f,l="??",-1
    258             raise
    259         log_loading.warning("Deprecated (no more needed) use of NewDefaultValues  (%s l. %i).", f, l)
    260         
    261         return super(NewDefaultValues, cls).__new__(cls, name, bases, dct)
    262 
    263 class BasePacket(Gen):
    264     __slots__ = []
    265 
    266 
    267 #############################
    268 ## Packet list base class  ##
    269 #############################
    270 
    271 class BasePacketList(object):
    272     __slots__ = []
    273