Home | History | Annotate | Download | only in test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2014 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 """Base module for multinetwork tests."""
     18 
     19 import errno
     20 import fcntl
     21 import os
     22 import posix
     23 import random
     24 import re
     25 from socket import *  # pylint: disable=wildcard-import
     26 import struct
     27 import time
     28 
     29 from scapy import all as scapy
     30 
     31 import csocket
     32 import iproute
     33 import net_test
     34 
     35 
# Interface flags and ioctl request number used with /dev/net/tun when the
# test interfaces are created (see CreateTunInterface).
IFF_TUN = 1
IFF_TAP = 2
IFF_NO_PI = 0x1000
TUNSETIFF = 0x400454ca

# SOL_SOCKET option used by BindToDevice to bind a socket to an interface name.
SO_BINDTODEVICE = 25

# Setsockopt values.
IP_UNICAST_IF = 50
IPV6_MULTICAST_IF = 17
IPV6_UNICAST_IF = 76

# Cmsg values.
IP_TTL = 2
IPV6_2292PKTOPTIONS = 6
IPV6_FLOWINFO = 11
IPV6_HOPLIMIT = 52  # Different from IPV6_UNICAST_HOPS, this is cmsg only.


# Paths of the sysctls that the setup code reads and writes.
AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"

# True if the accept_ra_rt_table sysctl file exists on this system. When it
# does not, the IPv6 routes are configured manually (see _RunSetupCommands).
HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)
     60 
     61 
     62 class UnexpectedPacketError(AssertionError):
     63   pass
     64 
     65 
     66 def MakePktInfo(version, addr, ifindex):
     67   family = {4: AF_INET, 6: AF_INET6}[version]
     68   if not addr:
     69     addr = {4: "0.0.0.0", 6: "::"}[version]
     70   if addr:
     71     addr = inet_pton(family, addr)
     72   if version == 6:
     73     return csocket.In6Pktinfo((addr, ifindex)).Pack()
     74   else:
     75     return csocket.InPktinfo((ifindex, addr, "\x00" * 4)).Pack()
     76 
     77 
class MultiNetworkBaseTest(net_test.NetworkTest):
  """Base class for all multinetwork tests.

  This class does not contain any test code, but contains code to set up and
  tear down a multi-network environment using multiple tun interfaces. The
  environment is designed to be similar to a real Android device in terms of
  rules and routes, and supports IPv4 and IPv6.

  Tests wishing to use this environment should inherit from this class and
  ensure that any setUpClass, tearDownClass, setUp, and tearDown methods they
  implement also call the superclass versions.
  """

  # Must be between 1 and 255, since each netid is formatted as a single
  # "%02x" octet in MAC addresses and is also put in IIDs.
  NETIDS = [100, 150, 200, 250]

  # Stores sysctl values to write back when the test completes. Maps the
  # sysctl path to the contents it had before SetSysctl first wrote to it.
  saved_sysctls = {}

  # Whether to output setup commands.
  DEBUG = False

  # The size of our UID ranges.
  UID_RANGE_SIZE = 1000

  # Rule priorities.
  PRIORITY_UID = 100
  PRIORITY_OIF = 200
  PRIORITY_FWMARK = 300
  PRIORITY_IIF = 400
  PRIORITY_DEFAULT = 999
  PRIORITY_UNREACHABLE = 1000

  # Actual device routing is more complicated, involving more than one rule
  # per NetId, but here we make do with just one rule that selects the lower
  # 16 bits.
  NETID_FWMASK = 0xffff

  # For convenience.
  IPV4_ADDR = net_test.IPV4_ADDR
  IPV6_ADDR = net_test.IPV6_ADDR
  IPV4_ADDR2 = net_test.IPV4_ADDR2
  IPV6_ADDR2 = net_test.IPV6_ADDR2
  IPV4_PING = net_test.IPV4_PING
  IPV6_PING = net_test.IPV6_PING

  # Lifetime, in seconds, advertised in the router advertisements that SendRA
  # generates.
  RA_VALIDITY = 300 # seconds
    125 
    126   @classmethod
    127   def UidRangeForNetid(cls, netid):
    128     return (
    129         cls.UID_RANGE_SIZE * netid,
    130         cls.UID_RANGE_SIZE * (netid + 1) - 1
    131     )
    132 
    133   @classmethod
    134   def UidForNetid(cls, netid):
    135     if not netid:
    136       return 0
    137     return random.randint(*cls.UidRangeForNetid(netid))
    138 
    139   @classmethod
    140   def _TableForNetid(cls, netid):
    141     if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
    142       return cls.ifindices[netid] + (-cls.AUTOCONF_TABLE_OFFSET)
    143     else:
    144       return netid
    145 
    146   @staticmethod
    147   def GetInterfaceName(netid):
    148     return "nettest%d" % netid
    149 
    150   @staticmethod
    151   def RouterMacAddress(netid):
    152     return "02:00:00:00:%02x:00" % netid
    153 
    154   @staticmethod
    155   def MyMacAddress(netid):
    156     return "02:00:00:00:%02x:01" % netid
    157 
    158   @staticmethod
    159   def _RouterAddress(netid, version):
    160     if version == 6:
    161       return "fe80::%02x00" % netid
    162     elif version == 4:
    163       return "10.0.%d.1" % netid
    164     else:
    165       raise ValueError("Don't support IPv%s" % version)
    166 
    167   @classmethod
    168   def _MyIPv4Address(cls, netid):
    169     return "10.0.%d.2" % netid
    170 
    171   @classmethod
    172   def _MyIPv6Address(cls, netid):
    173     return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)
    174 
    175   @classmethod
    176   def MyAddress(cls, version, netid):
    177     return {4: cls._MyIPv4Address(netid),
    178             5: cls._MyIPv4Address(netid),
    179             6: cls._MyIPv6Address(netid)}[version]
    180 
    181   @classmethod
    182   def MySocketAddress(cls, version, netid):
    183     addr = cls.MyAddress(version, netid)
    184     return "::ffff:" + addr if version == 5 else addr
    185 
    186   @classmethod
    187   def MyLinkLocalAddress(cls, netid):
    188     return net_test.GetLinkAddress(cls.GetInterfaceName(netid), True)
    189 
    190   @staticmethod
    191   def OnlinkPrefixLen(version):
    192     return {4: 24, 6: 64}[version]
    193 
    194   @staticmethod
    195   def OnlinkPrefix(version, netid):
    196     return {4: "10.0.%d.0" % netid,
    197             6: "2001:db8:%02x::" % netid}[version]
    198 
    199   @staticmethod
    200   def GetRandomDestination(prefix):
    201     if "." in prefix:
    202       return prefix + "%d.%d" % (random.randint(0, 255), random.randint(0, 255))
    203     else:
    204       return prefix + "%x:%x" % (random.randint(0, 65535),
    205                                  random.randint(0, 65535))
    206 
    207   def GetProtocolFamily(self, version):
    208     return {4: AF_INET, 6: AF_INET6}[version]
    209 
    210   @classmethod
    211   def CreateTunInterface(cls, netid):
    212     iface = cls.GetInterfaceName(netid)
    213     try:
    214       f = open("/dev/net/tun", "r+b")
    215     except IOError:
    216       f = open("/dev/tun", "r+b")
    217     ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
    218     ifr += "\x00" * (40 - len(ifr))
    219     fcntl.ioctl(f, TUNSETIFF, ifr)
    220     # Give ourselves a predictable MAC address.
    221     net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
    222     # Disable DAD so we don't have to wait for it.
    223     cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
    224     # Set accept_ra to 2, because that's what we use.
    225     cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_ra" % iface, 2)
    226     net_test.SetInterfaceUp(iface)
    227     net_test.SetNonBlocking(f)
    228     return f
    229 
    230   @classmethod
    231   def SendRA(cls, netid, retranstimer=None, reachabletime=0, options=()):
    232     validity = cls.RA_VALIDITY # seconds
    233     macaddr = cls.RouterMacAddress(netid)
    234     lladdr = cls._RouterAddress(netid, 6)
    235 
    236     if retranstimer is None:
    237       # If no retrans timer was specified, pick one that's as long as the
    238       # router lifetime. This ensures that no spurious ND retransmits
    239       # will interfere with test expectations.
    240       retranstimer = validity * 1000  # Lifetime is in s, retrans timer in ms.
    241 
    242     # We don't want any routes in the main table. If the kernel doesn't support
    243     # putting RA routes into per-interface tables, configure routing manually.
    244     routerlifetime = validity if HAVE_AUTOCONF_TABLE else 0
    245 
    246     ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
    247           scapy.IPv6(src=lladdr, hlim=255) /
    248           scapy.ICMPv6ND_RA(reachabletime=reachabletime,
    249                             retranstimer=retranstimer,
    250                             routerlifetime=routerlifetime) /
    251           scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
    252           scapy.ICMPv6NDOptPrefixInfo(prefix=cls.OnlinkPrefix(6, netid),
    253                                       prefixlen=cls.OnlinkPrefixLen(6),
    254                                       L=1, A=1,
    255                                       validlifetime=validity,
    256                                       preferredlifetime=validity))
    257     for option in options:
    258       ra /= option
    259     posix.write(cls.tuns[netid].fileno(), str(ra))
    260 
    261   @classmethod
    262   def _RunSetupCommands(cls, netid, is_add):
    263     for version in [4, 6]:
    264       # Find out how to configure things.
    265       iface = cls.GetInterfaceName(netid)
    266       ifindex = cls.ifindices[netid]
    267       macaddr = cls.RouterMacAddress(netid)
    268       router = cls._RouterAddress(netid, version)
    269       table = cls._TableForNetid(netid)
    270 
    271       # Set up routing rules.
    272       start, end = cls.UidRangeForNetid(netid)
    273       cls.iproute.UidRangeRule(version, is_add, start, end, table,
    274                                cls.PRIORITY_UID)
    275       cls.iproute.OifRule(version, is_add, iface, table, cls.PRIORITY_OIF)
    276       cls.iproute.FwmarkRule(version, is_add, netid, cls.NETID_FWMASK, table,
    277                              cls.PRIORITY_FWMARK)
    278 
    279       # Configure routing and addressing.
    280       #
    281       # IPv6 uses autoconf for everything, except if per-device autoconf routing
    282       # tables are not supported, in which case the default route (only) is
    283       # configured manually. For IPv4 we have to manually configure addresses,
    284       # routes, and neighbour cache entries (since we don't reply to ARP or ND).
    285       #
    286       # Since deleting addresses also causes routes to be deleted, we need to
    287       # be careful with ordering or the delete commands will fail with ENOENT.
    288       #
    289       # A real Android system will have both IPv4 and IPv6 routes for
    290       # directly-connected subnets in the per-interface routing tables. Ensure
    291       # we create those as well.
    292       do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
    293       if is_add:
    294         if version == 4:
    295           cls.iproute.AddAddress(cls._MyIPv4Address(netid),
    296                                  cls.OnlinkPrefixLen(4), ifindex)
    297           cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
    298         if do_routing:
    299           cls.iproute.AddRoute(version, table,
    300                                cls.OnlinkPrefix(version, netid),
    301                                cls.OnlinkPrefixLen(version), None, ifindex)
    302           cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
    303       else:
    304         if do_routing:
    305           cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
    306           cls.iproute.DelRoute(version, table,
    307                                cls.OnlinkPrefix(version, netid),
    308                                cls.OnlinkPrefixLen(version), None, ifindex)
    309         if version == 4:
    310           cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
    311           cls.iproute.DelAddress(cls._MyIPv4Address(netid),
    312                                  cls.OnlinkPrefixLen(4), ifindex)
    313 
    314   @classmethod
    315   def SetMarkReflectSysctls(cls, value):
    316     """Makes kernel-generated replies use the mark of the original packet."""
    317     cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
    318     cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
    319 
    320   @classmethod
    321   def _SetInboundMarking(cls, netid, iface, is_add):
    322     for version in [4, 6]:
    323       # Run iptables to set up incoming packet marking.
    324       add_del = "-A" if is_add else "-D"
    325       iptables = {4: "iptables", 6: "ip6tables"}[version]
    326       args = "%s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
    327           add_del, iface, netid)
    328       if net_test.RunIptablesCommand(version, args):
    329         raise ConfigurationError("Setup command failed: %s" % args)
    330 
    331   @classmethod
    332   def SetInboundMarks(cls, is_add):
    333     for netid in cls.tuns:
    334       cls._SetInboundMarking(netid, cls.GetInterfaceName(netid), is_add)
    335 
    336   @classmethod
    337   def SetDefaultNetwork(cls, netid):
    338     table = cls._TableForNetid(netid) if netid else None
    339     for version in [4, 6]:
    340       is_add = table is not None
    341       cls.iproute.DefaultRule(version, is_add, table, cls.PRIORITY_DEFAULT)
    342 
    343   @classmethod
    344   def ClearDefaultNetwork(cls):
    345     cls.SetDefaultNetwork(None)
    346 
    347   @classmethod
    348   def GetSysctl(cls, sysctl):
    349     return open(sysctl, "r").read()
    350 
    351   @classmethod
    352   def SetSysctl(cls, sysctl, value):
    353     # Only save each sysctl value the first time we set it. This is so we can
    354     # set it to arbitrary values multiple times and still write it back
    355     # correctly at the end.
    356     if sysctl not in cls.saved_sysctls:
    357       cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
    358     open(sysctl, "w").write(str(value) + "\n")
    359 
    360   @classmethod
    361   def SetIPv6SysctlOnAllIfaces(cls, sysctl, value):
    362     for netid in cls.tuns:
    363       iface = cls.GetInterfaceName(netid)
    364       name = "/proc/sys/net/ipv6/conf/%s/%s" % (iface, sysctl)
    365       cls.SetSysctl(name, value)
    366 
    367   @classmethod
    368   def _RestoreSysctls(cls):
    369     for sysctl, value in cls.saved_sysctls.iteritems():
    370       try:
    371         open(sysctl, "w").write(value)
    372       except IOError:
    373         pass
    374 
    375   @classmethod
    376   def _ICMPRatelimitFilename(cls, version):
    377     return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
    378                                6: "ipv6/icmp/ratelimit"}[version]
    379 
    380   @classmethod
    381   def _SetICMPRatelimit(cls, version, limit):
    382     cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)
    383 
    384   @classmethod
    385   def setUpClass(cls):
    386     # This is per-class setup instead of per-testcase setup because shelling out
    387     # to ip and iptables is slow, and because routing configuration doesn't
    388     # change during the test.
    389     cls.iproute = iproute.IPRoute()
    390     cls.tuns = {}
    391     cls.ifindices = {}
    392     if HAVE_AUTOCONF_TABLE:
    393       cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
    394       cls.AUTOCONF_TABLE_OFFSET = -1000
    395     else:
    396       cls.AUTOCONF_TABLE_OFFSET = None
    397 
    398     # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
    399     for version in [4, 6]:
    400       cls._SetICMPRatelimit(version, 0)
    401 
    402     for version in [4, 6]:
    403       cls.iproute.UnreachableRule(version, True, cls.PRIORITY_UNREACHABLE)
    404 
    405     for netid in cls.NETIDS:
    406       cls.tuns[netid] = cls.CreateTunInterface(netid)
    407       iface = cls.GetInterfaceName(netid)
    408       cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)
    409 
    410       cls.SendRA(netid)
    411       cls._RunSetupCommands(netid, True)
    412 
    413     # Don't print lots of "device foo entered promiscuous mode" warnings.
    414     cls.loglevel = cls.GetConsoleLogLevel()
    415     cls.SetConsoleLogLevel(net_test.KERN_INFO)
    416 
    417     # When running on device, don't send connections through FwmarkServer.
    418     os.environ["ANDROID_NO_USE_FWMARK_CLIENT"] = "1"
    419 
    420     # Uncomment to look around at interface and rule configuration while
    421     # running in the background. (Once the test finishes running, all the
    422     # interfaces and rules are gone.)
    423     # time.sleep(30)
    424 
    425   @classmethod
    426   def tearDownClass(cls):
    427     del os.environ["ANDROID_NO_USE_FWMARK_CLIENT"]
    428 
    429     for version in [4, 6]:
    430       try:
    431         cls.iproute.UnreachableRule(version, False, cls.PRIORITY_UNREACHABLE)
    432       except IOError:
    433         pass
    434 
    435     for netid in cls.tuns:
    436       cls._RunSetupCommands(netid, False)
    437       cls.tuns[netid].close()
    438 
    439     cls._RestoreSysctls()
    440     cls.SetConsoleLogLevel(cls.loglevel)
    441 
    442   def setUp(self):
    443     self.ClearTunQueues()
    444 
    445   def SetSocketMark(self, s, netid):
    446     if netid is None:
    447       netid = 0
    448     s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
    449 
    450   def GetSocketMark(self, s):
    451     return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    452 
    453   def ClearSocketMark(self, s):
    454     self.SetSocketMark(s, 0)
    455 
    456   def BindToDevice(self, s, iface):
    457     if not iface:
    458       iface = ""
    459     s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)
    460 
    461   def SetUnicastInterface(self, s, ifindex):
    462     # Otherwise, Python thinks it's a 1-byte option.
    463     ifindex = struct.pack("!I", ifindex)
    464 
    465     # Always set the IPv4 interface, because it will be used even on IPv6
    466     # sockets if the destination address is a mapped address.
    467     s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex)
    468     if s.family == AF_INET6:
    469       s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)
    470 
    471   def GetRemoteAddress(self, version):
    472     return {4: self.IPV4_ADDR,
    473             5: self.IPV4_ADDR,
    474             6: self.IPV6_ADDR}[version]
    475 
    476   def GetRemoteSocketAddress(self, version):
    477     addr = self.GetRemoteAddress(version)
    478     return "::ffff:" + addr if version == 5 else addr
    479 
    480   def GetOtherRemoteSocketAddress(self, version):
    481     return {4: self.IPV4_ADDR2,
    482             5: "::ffff:" + self.IPV4_ADDR2,
    483             6: self.IPV6_ADDR2}[version]
    484 
    485   def SelectInterface(self, s, netid, mode):
    486     if mode == "uid":
    487       os.fchown(s.fileno(), self.UidForNetid(netid), -1)
    488     elif mode == "mark":
    489       self.SetSocketMark(s, netid)
    490     elif mode == "oif":
    491       iface = self.GetInterfaceName(netid) if netid else ""
    492       self.BindToDevice(s, iface)
    493     elif mode == "ucast_oif":
    494       self.SetUnicastInterface(s, self.ifindices.get(netid, 0))
    495     else:
    496       raise ValueError("Unknown interface selection mode %s" % mode)
    497 
    498   def BuildSocket(self, version, constructor, netid, routing_mode):
    499     if version == 5: version = 6
    500     s = constructor(self.GetProtocolFamily(version))
    501 
    502     if routing_mode not in [None, "uid"]:
    503       self.SelectInterface(s, netid, routing_mode)
    504     elif routing_mode == "uid":
    505       os.fchown(s.fileno(), self.UidForNetid(netid), -1)
    506 
    507     return s
    508 
    509   def RandomNetid(self, exclude=None):
    510     """Return a random netid from the list of netids
    511 
    512     Args:
    513       exclude: a netid or list of netids that should not be chosen
    514     """
    515     if exclude is None:
    516       exclude = []
    517     elif isinstance(exclude, int):
    518         exclude = [exclude]
    519     diff = [netid for netid in self.NETIDS if netid not in exclude]
    520     return random.choice(diff)
    521 
    522   def SendOnNetid(self, version, s, dstaddr, dstport, netid, payload, cmsgs):
    523     if netid is not None:
    524       pktinfo = MakePktInfo(version, None, self.ifindices[netid])
    525       cmsg_level, cmsg_name = {
    526           4: (net_test.SOL_IP, csocket.IP_PKTINFO),
    527           6: (net_test.SOL_IPV6, csocket.IPV6_PKTINFO)}[version]
    528       cmsgs.append((cmsg_level, cmsg_name, pktinfo))
    529     csocket.Sendmsg(s, (dstaddr, dstport), payload, cmsgs, csocket.MSG_CONFIRM)
    530 
    531   def ReceiveEtherPacketOn(self, netid, packet):
    532     posix.write(self.tuns[netid].fileno(), str(packet))
    533 
    534   def ReceivePacketOn(self, netid, ip_packet):
    535     routermac = self.RouterMacAddress(netid)
    536     mymac = self.MyMacAddress(netid)
    537     packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
    538     self.ReceiveEtherPacketOn(netid, packet)
    539 
    540   def ReadAllPacketsOn(self, netid, include_multicast=False):
    541     """Return all queued packets on a netid as a list.
    542 
    543     Args:
    544       netid: The netid from which to read packets
    545       include_multicast: A boolean, whether to remove multicast packets
    546         (default=False)
    547     """
    548     packets = []
    549     retries = 0
    550     max_retries = 1
    551     while True:
    552       try:
    553         packet = posix.read(self.tuns[netid].fileno(), 4096)
    554         if not packet:
    555           break
    556         ether = scapy.Ether(packet)
    557         # Multicast frames are frames where the first byte of the destination
    558         # MAC address has 1 in the least-significant bit.
    559         if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
    560           packets.append(ether.payload)
    561       except OSError, e:
    562         # EAGAIN means there are no more packets waiting.
    563         if re.match(e.message, os.strerror(errno.EAGAIN)):
    564           # If we didn't see any packets, try again for good luck.
    565           if not packets and retries < max_retries:
    566             time.sleep(0.01)
    567             retries += 1
    568             continue
    569           else:
    570             break
    571         # Anything else is unexpected.
    572         else:
    573           raise e
    574     return packets
    575 
    576   def InvalidateDstCache(self, version, netid):
    577     """Invalidates destination cache entries of sockets on the specified table.
    578 
    579     Creates and then deletes a low-priority throw route in the table for the
    580     given netid, which invalidates the destination cache entries of any sockets
    581     that refer to routes in that table.
    582 
    583     The fact that this method actually invalidates destination cache entries is
    584     tested by OutgoingTest#testIPv[46]Remarking, which checks that the kernel
    585     does not re-route sockets when they are remarked, but does re-route them if
    586     this method is called.
    587 
    588     Args:
    589       version: The IP version, 4 or 6.
    590       netid: The netid to invalidate dst caches on.
    591     """
    592     iface = self.GetInterfaceName(netid)
    593     ifindex = self.ifindices[netid]
    594     table = self._TableForNetid(netid)
    595     for action in [iproute.RTM_NEWROUTE, iproute.RTM_DELROUTE]:
    596       self.iproute._Route(version, iproute.RTPROT_STATIC, action, table,
    597                           "default", 0, nexthop=None, dev=None, mark=None,
    598                           uid=None, route_type=iproute.RTN_THROW,
    599                           priority=100000)
    600 
    601   def ClearTunQueues(self):
    602     # Keep reading packets on all netids until we get no packets on any of them.
    603     waiting = None
    604     while waiting != 0:
    605       waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
    606 
    607   def assertPacketMatches(self, expected, actual):
    608     # The expected packet is just a rough sketch of the packet we expect to
    609     # receive. For example, it doesn't contain fields we can't predict, such as
    610     # initial TCP sequence numbers, or that depend on the host implementation
    611     # and settings, such as TCP options. To check whether the packet matches
    612     # what we expect, instead of just checking all the known fields one by one,
    613     # we blank out fields in the actual packet and then compare the whole
    614     # packets to each other as strings. Because we modify the actual packet,
    615     # make a copy here.
    616     actual = actual.copy()
    617 
    618     # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
    619     actualip = actual.getlayer("IP")
    620     expectedip = expected.getlayer("IP")
    621     if actualip and expectedip:
    622       actualip.id = expectedip.id
    623       actualip.flags &= 5
    624       actualip.chksum = None  # Change the header, recalculate the checksum.
    625 
    626     # Blank out the flow label, since new kernels randomize it by default.
    627     actualipv6 = actual.getlayer("IPv6")
    628     expectedipv6 = expected.getlayer("IPv6")
    629     if actualipv6 and expectedipv6:
    630       actualipv6.fl = expectedipv6.fl
    631 
    632     # Blank out UDP fields that we can't predict (e.g., the source port for
    633     # kernel-originated packets).
    634     actualudp = actual.getlayer("UDP")
    635     expectedudp = expected.getlayer("UDP")
    636     if actualudp and expectedudp:
    637       if expectedudp.sport is None:
    638         actualudp.sport = None
    639         actualudp.chksum = None
    640       elif actualudp.chksum == 0xffff:
    641         # Scapy does not appear to change 0 to 0xffff as required by RFC 768.
    642         actualudp.chksum = 0
    643 
    644     # Since the TCP code below messes with options, recalculate the length.
    645     if actualip:
    646       actualip.len = None
    647     if actualipv6:
    648       actualipv6.plen = None
    649 
    650     # Blank out TCP fields that we can't predict.
    651     actualtcp = actual.getlayer("TCP")
    652     expectedtcp = expected.getlayer("TCP")
    653     if actualtcp and expectedtcp:
    654       actualtcp.dataofs = expectedtcp.dataofs
    655       actualtcp.options = expectedtcp.options
    656       actualtcp.window = expectedtcp.window
    657       if expectedtcp.sport is None:
    658         actualtcp.sport = None
    659       if expectedtcp.seq is None:
    660         actualtcp.seq = None
    661       if expectedtcp.ack is None:
    662         actualtcp.ack = None
    663       actualtcp.chksum = None
    664 
    665     # Serialize the packet so that expected packet fields that are only set when
    666     # a packet is serialized e.g., the checksum) are filled in.
    667     expected_real = expected.__class__(str(expected))
    668     actual_real = actual.__class__(str(actual))
    669     # repr() can be expensive. Call it only if the test is going to fail and we
    670     # want to see the error.
    671     if expected_real != actual_real:
    672       self.assertEquals(repr(expected_real), repr(actual_real))
    673 
    674   def PacketMatches(self, expected, actual):
    675     try:
    676       self.assertPacketMatches(expected, actual)
    677       return True
    678     except AssertionError:
    679       return False
    680 
    681   def ExpectNoPacketsOn(self, netid, msg):
    682     packets = self.ReadAllPacketsOn(netid)
    683     if packets:
    684       firstpacket = repr(packets[0])
    685     else:
    686       firstpacket = ""
    687     self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)
    688 
    689   def ExpectPacketOn(self, netid, msg, expected):
    690     # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
    691     # multicast packets unless the packet we expect to see is a multicast
    692     # packet. For now the only tests that use this are IPv6.
    693     ipv6 = expected.getlayer("IPv6")
    694     if ipv6 and ipv6.dst.startswith("ff"):
    695       include_multicast = True
    696     else:
    697       include_multicast = False
    698 
    699     packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
    700     self.assertTrue(packets, msg + ": received no packets")
    701 
    702     # If we receive a packet that matches what we expected, return it.
    703     for packet in packets:
    704       if self.PacketMatches(expected, packet):
    705         return packet
    706 
    707     # None of the packets matched. Call assertPacketMatches to output a diff
    708     # between the expected packet and the last packet we received. In theory,
    709     # we'd output a diff to the packet that's the best match for what we
    710     # expected, but this is good enough for now.
    711     try:
    712       self.assertPacketMatches(expected, packets[-1])
    713     except Exception, e:
    714       raise UnexpectedPacketError(
    715           "%s: diff with last packet:\n%s" % (msg, e.message))
    716 
    717   def Combinations(self, version):
    718     """Produces a list of combinations to test."""
    719     combinations = []
    720 
    721     # Check packets addressed to the IP addresses of all our interfaces...
    722     for dest_ip_netid in self.tuns:
    723       ip_if = self.GetInterfaceName(dest_ip_netid)
    724       myaddr = self.MyAddress(version, dest_ip_netid)
    725       prefix = {4: "172.22.", 6: "2001:db8:aaaa::"}[version]
    726       remoteaddr = self.GetRandomDestination(prefix)
    727 
    728       # ... coming in on all our interfaces.
    729       for netid in self.tuns:
    730         iif = self.GetInterfaceName(netid)
    731         combinations.append((netid, iif, ip_if, myaddr, remoteaddr))
    732 
    733     return combinations
    734 
    735   def _FormatMessage(self, iif, ip_if, extra, desc, reply_desc):
    736     msg = "Receiving %s on %s to %s IP, %s" % (desc, iif, ip_if, extra)
    737     if reply_desc:
    738       msg += ": Expecting %s on %s" % (reply_desc, iif)
    739     else:
    740       msg += ": Expecting no packets on %s" % iif
    741     return msg
    742 
    743   def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
    744     self.ReceivePacketOn(netid, packet)
    745     if reply:
    746       return self.ExpectPacketOn(netid, msg, reply)
    747     else:
    748       self.ExpectNoPacketsOn(netid, msg)
    749       return None
    750 
    751 
    752 class InboundMarkingTest(MultiNetworkBaseTest):
    753   """Class that automatically sets up inbound marking."""
    754 
    755   @classmethod
    756   def setUpClass(cls):
    757     super(InboundMarkingTest, cls).setUpClass()
    758     cls.SetInboundMarks(True)
    759 
    760   @classmethod
    761   def tearDownClass(cls):
    762     cls.SetInboundMarks(False)
    763     super(InboundMarkingTest, cls).tearDownClass()
    764