Home | History | Annotate | Download | only in net_test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2014 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 import errno
     18 import os
     19 import random
     20 from socket import *  # pylint: disable=wildcard-import
     21 import struct
     22 import time           # pylint: disable=unused-import
     23 import unittest
     24 
     25 from scapy import all as scapy
     26 
     27 import iproute
     28 import multinetwork_base
     29 import net_test
     30 
# ICMP echo parameters used by the ping tests. PING_IDENT is both the echo ID
# and the port that CheckPingPacket binds its ping socket to.
PING_IDENT = 0xff19
PING_PAYLOAD = "foobarbaz"
PING_SEQ = 3
# ToS / traffic class set on ping sockets and expected on echo packets.
PING_TOS = 0x83

# Control-message type passed at level SOL_IPV6 in CheckPingRouting-style
# cmsgs (see CheckPktinfoRouting). NOTE(review): presumably defined here
# because the socket module does not export it -- confirm.
IPV6_FLOWINFO = 11


# Payload for all test UDP packets: a serialized DNS AAAA query with the
# recursion-desired bit set. The query ID is chosen randomly once, at import.
UDP_PAYLOAD = str(scapy.DNS(rd=1,
                            id=random.randint(0, 65535),
                            qd=scapy.DNSQR(qname="wWW.GoOGle.CoM",
                                           qtype="AAAA")))


# Paths of the sysctls the tests read or write.
IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
SYNCOOKIES_SYSCTL = "/proc/sys/net/ipv4/tcp_syncookies"
TCP_MARK_ACCEPT_SYSCTL = "/proc/sys/net/ipv4/tcp_fwmark_accept"

# Feature probes: True when the running kernel exposes the sysctl file.
# Tests that need the feature are skipped otherwise.
HAVE_MARK_REFLECT = os.path.isfile(IPV4_MARK_REFLECT_SYSCTL)
HAVE_TCP_MARK_ACCEPT = os.path.isfile(TCP_MARK_ACCEPT_SYSCTL)

# The IP[V6]UNICAST_IF socket option was added between 3.1 and 3.4.
HAVE_UNICAST_IF = net_test.LINUX_VERSION >= (3, 4, 0)
     55 
     56 
     57 class ConfigurationError(AssertionError):
     58   pass
     59 
     60 
     61 class Packets(object):
     62 
     63   TCP_FIN = 1
     64   TCP_SYN = 2
     65   TCP_RST = 4
     66   TCP_PSH = 8
     67   TCP_ACK = 16
     68 
     69   TCP_SEQ = 1692871236
     70   TCP_WINDOW = 14400
     71 
     72   @staticmethod
     73   def RandomPort():
     74     return random.randint(1025, 65535)
     75 
     76   @staticmethod
     77   def _GetIpLayer(version):
     78     return {4: scapy.IP, 6: scapy.IPv6}[version]
     79 
     80   @staticmethod
     81   def _SetPacketTos(packet, tos):
     82     if isinstance(packet, scapy.IPv6):
     83       packet.tc = tos
     84     elif isinstance(packet, scapy.IP):
     85       packet.tos = tos
     86     else:
     87       raise ValueError("Can't find ToS Field")
     88 
     89   @classmethod
     90   def UDP(cls, version, srcaddr, dstaddr, sport=0):
     91     ip = cls._GetIpLayer(version)
     92     # Can't just use "if sport" because None has meaning (it means unspecified).
     93     if sport == 0:
     94       sport = cls.RandomPort()
     95     return ("UDPv%d packet" % version,
     96             ip(src=srcaddr, dst=dstaddr) /
     97             scapy.UDP(sport=sport, dport=53) / UDP_PAYLOAD)
     98 
     99   @classmethod
    100   def UDPWithOptions(cls, version, srcaddr, dstaddr, sport=0):
    101     if version == 4:
    102       packet = (scapy.IP(src=srcaddr, dst=dstaddr, ttl=39, tos=0x83) /
    103                 scapy.UDP(sport=sport, dport=53) /
    104                 UDP_PAYLOAD)
    105     else:
    106       packet = (scapy.IPv6(src=srcaddr, dst=dstaddr,
    107                            fl=0xbeef, hlim=39, tc=0x83) /
    108                 scapy.UDP(sport=sport, dport=53) /
    109                 UDP_PAYLOAD)
    110     return ("UDPv%d packet with options" % version, packet)
    111 
    112   @classmethod
    113   def SYN(cls, dport, version, srcaddr, dstaddr, sport=0, seq=TCP_SEQ):
    114     ip = cls._GetIpLayer(version)
    115     if sport == 0:
    116       sport = cls.RandomPort()
    117     return ("TCP SYN",
    118             ip(src=srcaddr, dst=dstaddr) /
    119             scapy.TCP(sport=sport, dport=dport,
    120                       seq=seq, ack=0,
    121                       flags=cls.TCP_SYN, window=cls.TCP_WINDOW))
    122 
    123   @classmethod
    124   def RST(cls, version, srcaddr, dstaddr, packet):
    125     ip = cls._GetIpLayer(version)
    126     original = packet.getlayer("TCP")
    127     return ("TCP RST",
    128             ip(src=srcaddr, dst=dstaddr) /
    129             scapy.TCP(sport=original.dport, dport=original.sport,
    130                       ack=original.seq + 1, seq=None,
    131                       flags=cls.TCP_RST | cls.TCP_ACK, window=cls.TCP_WINDOW))
    132 
    133   @classmethod
    134   def SYNACK(cls, version, srcaddr, dstaddr, packet):
    135     ip = cls._GetIpLayer(version)
    136     original = packet.getlayer("TCP")
    137     return ("TCP SYN+ACK",
    138             ip(src=srcaddr, dst=dstaddr) /
    139             scapy.TCP(sport=original.dport, dport=original.sport,
    140                       ack=original.seq + 1, seq=None,
    141                       flags=cls.TCP_SYN | cls.TCP_ACK, window=None))
    142 
    143   @classmethod
    144   def ACK(cls, version, srcaddr, dstaddr, packet, payload=""):
    145     ip = cls._GetIpLayer(version)
    146     original = packet.getlayer("TCP")
    147     was_syn_or_fin = (original.flags & (cls.TCP_SYN | cls.TCP_FIN)) != 0
    148     ack_delta = was_syn_or_fin + len(original.payload)
    149     desc = "TCP data" if payload else "TCP ACK"
    150     flags = cls.TCP_ACK | cls.TCP_PSH if payload else cls.TCP_ACK
    151     return (desc,
    152             ip(src=srcaddr, dst=dstaddr) /
    153             scapy.TCP(sport=original.dport, dport=original.sport,
    154                       ack=original.seq + ack_delta, seq=original.ack,
    155                       flags=flags, window=cls.TCP_WINDOW) /
    156             payload)
    157 
    158   @classmethod
    159   def FIN(cls, version, srcaddr, dstaddr, packet):
    160     ip = cls._GetIpLayer(version)
    161     original = packet.getlayer("TCP")
    162     was_fin = (original.flags & cls.TCP_FIN) != 0
    163     return ("TCP FIN",
    164             ip(src=srcaddr, dst=dstaddr) /
    165             scapy.TCP(sport=original.dport, dport=original.sport,
    166                       ack=original.seq + was_fin, seq=original.ack,
    167                       flags=cls.TCP_ACK | cls.TCP_FIN, window=cls.TCP_WINDOW))
    168 
    169   @classmethod
    170   def GRE(cls, version, srcaddr, dstaddr, proto, packet):
    171     if version == 4:
    172       ip = scapy.IP(src=srcaddr, dst=dstaddr, proto=net_test.IPPROTO_GRE)
    173     else:
    174       ip = scapy.IPv6(src=srcaddr, dst=dstaddr, nh=net_test.IPPROTO_GRE)
    175     packet = ip / scapy.GRE(proto=proto) / packet
    176     return ("GRE packet", packet)
    177 
    178   @classmethod
    179   def ICMPPortUnreachable(cls, version, srcaddr, dstaddr, packet):
    180     if version == 4:
    181       # Linux hardcodes the ToS on ICMP errors to 0xc0 or greater because of
    182       # RFC 1812 4.3.2.5 (!).
    183       return ("ICMPv4 port unreachable",
    184               scapy.IP(src=srcaddr, dst=dstaddr, proto=1, tos=0xc0) /
    185               scapy.ICMPerror(type=3, code=3) / packet)
    186     else:
    187       return ("ICMPv6 port unreachable",
    188               scapy.IPv6(src=srcaddr, dst=dstaddr) /
    189               scapy.ICMPv6DestUnreach(code=4) / packet)
    190 
    191   @classmethod
    192   def ICMPPacketTooBig(cls, version, srcaddr, dstaddr, packet):
    193     if version == 4:
    194       return ("ICMPv4 fragmentation needed",
    195               scapy.IP(src=srcaddr, dst=dstaddr, proto=1) /
    196               scapy.ICMPerror(type=3, code=4, unused=1280) / str(packet)[:64])
    197     else:
    198       udp = packet.getlayer("UDP")
    199       udp.payload = str(udp.payload)[:1280-40-8]
    200       return ("ICMPv6 Packet Too Big",
    201               scapy.IPv6(src=srcaddr, dst=dstaddr) /
    202               scapy.ICMPv6PacketTooBig() / str(packet)[:1232])
    203 
    204   @classmethod
    205   def ICMPEcho(cls, version, srcaddr, dstaddr):
    206     ip = cls._GetIpLayer(version)
    207     icmp = {4: scapy.ICMP, 6: scapy.ICMPv6EchoRequest}[version]
    208     packet = (ip(src=srcaddr, dst=dstaddr) /
    209               icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
    210     cls._SetPacketTos(packet, PING_TOS)
    211     return ("ICMPv%d echo" % version, packet)
    212 
    213   @classmethod
    214   def ICMPReply(cls, version, srcaddr, dstaddr, packet):
    215     ip = cls._GetIpLayer(version)
    216     # Scapy doesn't provide an ICMP echo reply constructor.
    217     icmpv4_reply = lambda **kwargs: scapy.ICMP(type=0, **kwargs)
    218     icmp = {4: icmpv4_reply, 6: scapy.ICMPv6EchoReply}[version]
    219     packet = (ip(src=srcaddr, dst=dstaddr) /
    220               icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
    221     # IPv6 only started copying the tclass to echo replies in 3.14.
    222     if version == 4 or net_test.LINUX_VERSION >= (3, 14):
    223       cls._SetPacketTos(packet, PING_TOS)
    224     return ("ICMPv%d echo reply" % version, packet)
    225 
    226   @classmethod
    227   def NS(cls, srcaddr, tgtaddr, srcmac):
    228     solicited = inet_pton(AF_INET6, tgtaddr)
    229     last3bytes = tuple([ord(b) for b in solicited[-3:]])
    230     solicited = "ff02::1:ff%02x:%02x%02x" % last3bytes
    231     packet = (scapy.IPv6(src=srcaddr, dst=solicited) /
    232               scapy.ICMPv6ND_NS(tgt=tgtaddr) /
    233               scapy.ICMPv6NDOptSrcLLAddr(lladdr=srcmac))
    234     return ("ICMPv6 NS", packet)
    235 
    236   @classmethod
    237   def NA(cls, srcaddr, dstaddr, srcmac):
    238     packet = (scapy.IPv6(src=srcaddr, dst=dstaddr) /
    239               scapy.ICMPv6ND_NA(tgt=srcaddr, R=0, S=1, O=1) /
    240               scapy.ICMPv6NDOptDstLLAddr(lladdr=srcmac))
    241     return ("ICMPv6 NA", packet)
    242 
    243 
    244 class InboundMarkingTest(multinetwork_base.MultiNetworkBaseTest):
    245 
    246   @classmethod
    247   def _SetInboundMarking(cls, netid, is_add):
    248     for version in [4, 6]:
    249       # Run iptables to set up incoming packet marking.
    250       iface = cls.GetInterfaceName(netid)
    251       add_del = "-A" if is_add else "-D"
    252       iptables = {4: "iptables", 6: "ip6tables"}[version]
    253       args = "%s %s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
    254           iptables, add_del, iface, netid)
    255       iptables = "/sbin/" + iptables
    256       ret = os.spawnvp(os.P_WAIT, iptables, args.split(" "))
    257       if ret:
    258         raise ConfigurationError("Setup command failed: %s" % args)
    259 
    260   @classmethod
    261   def setUpClass(cls):
    262     super(InboundMarkingTest, cls).setUpClass()
    263     for netid in cls.tuns:
    264       cls._SetInboundMarking(netid, True)
    265 
    266   @classmethod
    267   def tearDownClass(cls):
    268     for netid in cls.tuns:
    269       cls._SetInboundMarking(netid, False)
    270     super(InboundMarkingTest, cls).tearDownClass()
    271 
    272   @classmethod
    273   def SetMarkReflectSysctls(cls, value):
    274     cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
    275     try:
    276       cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
    277     except IOError:
    278       # This does not exist if we use the version of the patch that uses a
    279       # common sysctl for IPv4 and IPv6.
    280       pass
    281 
    282 
    283 class OutgoingTest(multinetwork_base.MultiNetworkBaseTest):
    284 
    285   # How many times to run outgoing packet tests.
    286   ITERATIONS = 5
    287 
    288   def CheckPingPacket(self, version, netid, routing_mode, dstaddr, packet):
    289     s = self.BuildSocket(version, net_test.PingSocket, netid, routing_mode)
    290 
    291     myaddr = self.MyAddress(version, netid)
    292     s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    293     s.bind((myaddr, PING_IDENT))
    294     net_test.SetSocketTos(s, PING_TOS)
    295 
    296     desc, expected = Packets.ICMPEcho(version, myaddr, dstaddr)
    297     msg = "IPv%d ping: expected %s on %s" % (
    298         version, desc, self.GetInterfaceName(netid))
    299 
    300     s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321))
    301 
    302     self.ExpectPacketOn(netid, msg, expected)
    303 
    304   def CheckTCPSYNPacket(self, version, netid, routing_mode, dstaddr):
    305     s = self.BuildSocket(version, net_test.TCPSocket, netid, routing_mode)
    306 
    307     if version == 6 and dstaddr.startswith("::ffff"):
    308       version = 4
    309     myaddr = self.MyAddress(version, netid)
    310     desc, expected = Packets.SYN(53, version, myaddr, dstaddr,
    311                                  sport=None, seq=None)
    312 
    313     # Non-blocking TCP connects always return EINPROGRESS.
    314     self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
    315     msg = "IPv%s TCP connect: expected %s on %s" % (
    316         version, desc, self.GetInterfaceName(netid))
    317     self.ExpectPacketOn(netid, msg, expected)
    318     s.close()
    319 
    320   def CheckUDPPacket(self, version, netid, routing_mode, dstaddr):
    321     s = self.BuildSocket(version, net_test.UDPSocket, netid, routing_mode)
    322 
    323     if version == 6 and dstaddr.startswith("::ffff"):
    324       version = 4
    325     myaddr = self.MyAddress(version, netid)
    326     desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None)
    327     msg = "IPv%s UDP %%s: expected %s on %s" % (
    328         version, desc, self.GetInterfaceName(netid))
    329 
    330     s.sendto(UDP_PAYLOAD, (dstaddr, 53))
    331     self.ExpectPacketOn(netid, msg % "sendto", expected)
    332 
    333     # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
    334     if routing_mode != "ucast_oif":
    335       s.connect((dstaddr, 53))
    336       s.send(UDP_PAYLOAD)
    337       self.ExpectPacketOn(netid, msg % "connect/send", expected)
    338       s.close()
    339 
    340   def CheckRawGrePacket(self, version, netid, routing_mode, dstaddr):
    341     s = self.BuildSocket(version, net_test.RawGRESocket, netid, routing_mode)
    342 
    343     inner_version = {4: 6, 6: 4}[version]
    344     inner_src = self.MyAddress(inner_version, netid)
    345     inner_dst = self.GetRemoteAddress(inner_version)
    346     inner = str(Packets.UDP(inner_version, inner_src, inner_dst, sport=None)[1])
    347 
    348     ethertype = {4: net_test.ETH_P_IP, 6: net_test.ETH_P_IPV6}[inner_version]
    349     # A GRE header can be as simple as two zero bytes and the ethertype.
    350     packet = struct.pack("!i", ethertype) + inner
    351     myaddr = self.MyAddress(version, netid)
    352 
    353     s.sendto(packet, (dstaddr, IPPROTO_GRE))
    354     desc, expected = Packets.GRE(version, myaddr, dstaddr, ethertype, inner)
    355     msg = "Raw IPv%d GRE with inner IPv%d UDP: expected %s on %s" % (
    356         version, inner_version, desc, self.GetInterfaceName(netid))
    357     self.ExpectPacketOn(netid, msg, expected)
    358 
    359   def CheckOutgoingPackets(self, routing_mode):
    360     v4addr = self.IPV4_ADDR
    361     v6addr = self.IPV6_ADDR
    362     v4mapped = "::ffff:" + v4addr
    363 
    364     for _ in xrange(self.ITERATIONS):
    365       for netid in self.tuns:
    366 
    367         self.CheckPingPacket(4, netid, routing_mode, v4addr, self.IPV4_PING)
    368         # Kernel bug.
    369         if routing_mode != "oif":
    370           self.CheckPingPacket(6, netid, routing_mode, v6addr, self.IPV6_PING)
    371 
    372         # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
    373         if routing_mode != "ucast_oif":
    374           self.CheckTCPSYNPacket(4, netid, routing_mode, v4addr)
    375           self.CheckTCPSYNPacket(6, netid, routing_mode, v6addr)
    376           self.CheckTCPSYNPacket(6, netid, routing_mode, v4mapped)
    377 
    378         self.CheckUDPPacket(4, netid, routing_mode, v4addr)
    379         self.CheckUDPPacket(6, netid, routing_mode, v6addr)
    380         self.CheckUDPPacket(6, netid, routing_mode, v4mapped)
    381 
    382         # Creating raw sockets on non-root UIDs requires properly setting
    383         # capabilities, which is hard to do from Python.
    384         # IP_UNICAST_IF is not supported on raw sockets.
    385         if routing_mode not in ["uid", "ucast_oif"]:
    386           self.CheckRawGrePacket(4, netid, routing_mode, v4addr)
    387           self.CheckRawGrePacket(6, netid, routing_mode, v6addr)
    388 
    389   def testMarkRouting(self):
    390     """Checks that socket marking selects the right outgoing interface."""
    391     self.CheckOutgoingPackets("mark")
    392 
    393   @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    394   def testUidRouting(self):
    395     """Checks that UID routing selects the right outgoing interface."""
    396     self.CheckOutgoingPackets("uid")
    397 
    398   def testOifRouting(self):
    399     """Checks that oif routing selects the right outgoing interface."""
    400     self.CheckOutgoingPackets("oif")
    401 
    402   @unittest.skipUnless(HAVE_UNICAST_IF, "no support for UNICAST_IF")
    403   def testUcastOifRouting(self):
    404     """Checks that ucast oif routing selects the right outgoing interface."""
    405     self.CheckOutgoingPackets("ucast_oif")
    406 
    407   def CheckRemarking(self, version, use_connect):
    408     # Remarking or resetting UNICAST_IF on connected sockets does not work.
    409     if use_connect:
    410       modes = ["oif"]
    411     else:
    412       modes = ["mark", "oif"]
    413       if HAVE_UNICAST_IF:
    414         modes += ["ucast_oif"]
    415 
    416     for mode in modes:
    417       s = net_test.UDPSocket(self.GetProtocolFamily(version))
    418 
    419       # Figure out what packets to expect.
    420       unspec = {4: "0.0.0.0", 6: "::"}[version]
    421       sport = Packets.RandomPort()
    422       s.bind((unspec, sport))
    423       dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
    424       desc, expected = Packets.UDP(version, unspec, dstaddr, sport)
    425 
    426       # If we're testing connected sockets, connect the socket on the first
    427       # netid now.
    428       if use_connect:
    429         netid = self.tuns.keys()[0]
    430         self.SelectInterface(s, netid, mode)
    431         s.connect((dstaddr, 53))
    432         expected.src = self.MyAddress(version, netid)
    433 
    434       # For each netid, select that network without closing the socket, and
    435       # check that the packets sent on that socket go out on the right network.
    436       for netid in self.tuns:
    437         self.SelectInterface(s, netid, mode)
    438         if not use_connect:
    439           expected.src = self.MyAddress(version, netid)
    440         s.sendto(UDP_PAYLOAD, (dstaddr, 53))
    441         connected_str = "Connected" if use_connect else "Unconnected"
    442         msg = "%s UDPv%d socket remarked using %s: expecting %s on %s" % (
    443             connected_str, version, mode, desc, self.GetInterfaceName(netid))
    444         self.ExpectPacketOn(netid, msg, expected)
    445         self.SelectInterface(s, None, mode)
    446 
    447   def testIPv4Remarking(self):
    448     """Checks that updating the mark on an IPv4 socket changes routing."""
    449     self.CheckRemarking(4, False)
    450     self.CheckRemarking(4, True)
    451 
    452   def testIPv6Remarking(self):
    453     """Checks that updating the mark on an IPv6 socket changes routing."""
    454     self.CheckRemarking(6, False)
    455     self.CheckRemarking(6, True)
    456 
    457   def testIPv6StickyPktinfo(self):
    458     for _ in xrange(self.ITERATIONS):
    459       for netid in self.tuns:
    460         s = net_test.UDPSocket(AF_INET6)
    461 
    462         # Set a flowlabel.
    463         net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xdead)
    464         s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_FLOWINFO_SEND, 1)
    465 
    466         # Set some destination options.
    467         nonce = "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c"
    468         dstopts = "".join([
    469             "\x11\x02",              # Next header=UDP, 24 bytes of options.
    470             "\x01\x06", "\x00" * 6,  # PadN, 6 bytes of padding.
    471             "\x8b\x0c",              # ILNP nonce, 12 bytes.
    472             nonce
    473         ])
    474         s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, dstopts)
    475         s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_HOPS, 255)
    476 
    477         pktinfo = multinetwork_base.MakePktInfo(6, None, self.ifindices[netid])
    478 
    479         # Set the sticky pktinfo option.
    480         s.setsockopt(net_test.SOL_IPV6, IPV6_PKTINFO, pktinfo)
    481 
    482         # Specify the flowlabel in the destination address.
    483         s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 53, 0xdead, 0))
    484 
    485         sport = s.getsockname()[1]
    486         srcaddr = self.MyAddress(6, netid)
    487         expected = (scapy.IPv6(src=srcaddr, dst=net_test.IPV6_ADDR,
    488                                fl=0xdead, hlim=255) /
    489                     scapy.IPv6ExtHdrDestOpt(
    490                         options=[scapy.PadN(optdata="\x00\x00\x00\x00\x00\x00"),
    491                                  scapy.HBHOptUnknown(otype=0x8b,
    492                                                      optdata=nonce)]) /
    493                     scapy.UDP(sport=sport, dport=53) /
    494                     UDP_PAYLOAD)
    495         msg = "IPv6 UDP using sticky pktinfo: expected UDP packet on %s" % (
    496             self.GetInterfaceName(netid))
    497         self.ExpectPacketOn(netid, msg, expected)
    498 
    499   def CheckPktinfoRouting(self, version):
    500     for _ in xrange(self.ITERATIONS):
    501       for netid in self.tuns:
    502         family = self.GetProtocolFamily(version)
    503         s = net_test.UDPSocket(family)
    504 
    505         if version == 6:
    506           # Create a flowlabel so we can use it.
    507           net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xbeef)
    508 
    509           # Specify some arbitrary options.
    510           cmsgs = [
    511               (net_test.SOL_IPV6, IPV6_HOPLIMIT, 39),
    512               (net_test.SOL_IPV6, IPV6_TCLASS, 0x83),
    513               (net_test.SOL_IPV6, IPV6_FLOWINFO, int(htonl(0xbeef))),
    514           ]
    515         else:
    516           # Support for setting IPv4 TOS and TTL via cmsg only appeared in 3.13.
    517           cmsgs = []
    518           s.setsockopt(net_test.SOL_IP, IP_TTL, 39)
    519           s.setsockopt(net_test.SOL_IP, IP_TOS, 0x83)
    520 
    521         dstaddr = self.GetRemoteAddress(version)
    522         self.SendOnNetid(version, s, dstaddr, 53, netid, UDP_PAYLOAD, cmsgs)
    523 
    524         sport = s.getsockname()[1]
    525         srcaddr = self.MyAddress(version, netid)
    526 
    527         desc, expected = Packets.UDPWithOptions(version, srcaddr, dstaddr,
    528                                                 sport=sport)
    529 
    530         msg = "IPv%d UDP using pktinfo routing: expected %s on %s" % (
    531             version, desc, self.GetInterfaceName(netid))
    532         self.ExpectPacketOn(netid, msg, expected)
    533 
    534   def testIPv4PktinfoRouting(self):
    535     self.CheckPktinfoRouting(4)
    536 
    537   def testIPv6PktinfoRouting(self):
    538     self.CheckPktinfoRouting(6)
    539 
    540 
    541 class MarkTest(InboundMarkingTest):
    542 
    543   def CheckReflection(self, version, gen_packet, gen_reply):
    544     """Checks that replies go out on the same interface as the original.
    545 
    546     For each combination:
    547      - Calls gen_packet to generate a packet to that IP address.
    548      - Writes the packet generated by gen_packet on the given tun
    549        interface, causing the kernel to receive it.
    550      - Checks that the kernel's reply matches the packet generated by
    551        gen_reply.
    552 
    553     Args:
    554       version: An integer, 4 or 6.
    555       gen_packet: A function taking an IP version (an integer), a source
    556         address and a destination address (strings), and returning a scapy
    557         packet.
    558       gen_reply: A function taking the same arguments as gen_packet,
    559         plus a scapy packet, and returning a scapy packet.
    560     """
    561     for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
    562       # Generate a test packet.
    563       desc, packet = gen_packet(version, remoteaddr, myaddr)
    564 
    565       # Test with mark reflection enabled and disabled.
    566       for reflect in [0, 1]:
    567         self.SetMarkReflectSysctls(reflect)
    568         # HACK: IPv6 ping replies always do a routing lookup with the
    569         # interface the ping came in on. So even if mark reflection is not
    570         # working, IPv6 ping replies will be properly reflected. Don't
    571         # fail when that happens.
    572         if reflect or desc == "ICMPv6 echo":
    573           reply_desc, reply = gen_reply(version, myaddr, remoteaddr, packet)
    574         else:
    575           reply_desc, reply = None, None
    576 
    577         msg = self._FormatMessage(iif, ip_if, "reflect=%d" % reflect,
    578                                   desc, reply_desc)
    579         self._ReceiveAndExpectResponse(netid, packet, reply, msg)
    580 
    581   def SYNToClosedPort(self, *args):
    582     return Packets.SYN(999, *args)
    583 
    584   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
    585   def testIPv4ICMPErrorsReflectMark(self):
    586     self.CheckReflection(4, Packets.UDP, Packets.ICMPPortUnreachable)
    587 
    588   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
    589   def testIPv6ICMPErrorsReflectMark(self):
    590     self.CheckReflection(6, Packets.UDP, Packets.ICMPPortUnreachable)
    591 
    592   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
    593   def testIPv4PingRepliesReflectMarkAndTos(self):
    594     self.CheckReflection(4, Packets.ICMPEcho, Packets.ICMPReply)
    595 
    596   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
    597   def testIPv6PingRepliesReflectMarkAndTos(self):
    598     self.CheckReflection(6, Packets.ICMPEcho, Packets.ICMPReply)
    599 
    600   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
    601   def testIPv4RSTsReflectMark(self):
    602     self.CheckReflection(4, self.SYNToClosedPort, Packets.RST)
    603 
    604   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
    605   def testIPv6RSTsReflectMark(self):
    606     self.CheckReflection(6, self.SYNToClosedPort, Packets.RST)
    607 
    608 
    609 class TCPAcceptTest(InboundMarkingTest):
    610 
  # Ways CheckTCP makes an accepted connection use netid's network:
  #   MODE_BINDTODEVICE: the listening socket is bound to the incoming
  #     interface with SO_BINDTODEVICE.
  #   MODE_INCOMING_MARK: tcp_fwmark_accept is enabled (when available) and
  #     the accepted socket must carry the incoming packet's mark (netid).
  #   MODE_EXPLICIT_MARK: a mark equal to netid is chosen explicitly
  #     (NOTE(review): presumably applied to the listening socket; confirm in
  #     CheckTCP).
  #   MODE_UID: a per-netid listening socket is built with UID routing.
  MODE_BINDTODEVICE = "SO_BINDTODEVICE"
  MODE_INCOMING_MARK = "incoming mark"
  MODE_EXPLICIT_MARK = "explicit mark"
  MODE_UID = "uid"
    615 
    616   @classmethod
    617   def setUpClass(cls):
    618     super(TCPAcceptTest, cls).setUpClass()
    619 
    620     # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
    621     # will accept both IPv4 and IPv6 connections. We do this here instead of in
    622     # each test so we can use the same socket every time. That way, if a kernel
    623     # bug causes incoming packets to mark the listening socket instead of the
    624     # accepted socket, the test will fail as soon as the next address/interface
    625     # combination is tried.
    626     cls.listenport = 1234
    627     cls.listensocket = net_test.IPv6TCPSocket()
    628     cls.listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    629     cls.listensocket.bind(("::", cls.listenport))
    630     cls.listensocket.listen(100)
    631 
    632   def BounceSocket(self, s):
    633     """Attempts to invalidate a socket's destination cache entry."""
    634     if s.family == AF_INET:
    635       tos = s.getsockopt(SOL_IP, IP_TOS)
    636       s.setsockopt(net_test.SOL_IP, IP_TOS, 53)
    637       s.setsockopt(net_test.SOL_IP, IP_TOS, tos)
    638     else:
    639       # UDP, 8 bytes dstopts; PAD1, 4 bytes padding; 4 bytes zeros.
    640       pad8 = "".join(["\x11\x00", "\x01\x04", "\x00" * 4])
    641       s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, pad8)
    642       s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, "")
    643 
    644   def _SetTCPMarkAcceptSysctl(self, value):
    645     self.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
    646 
  def CheckTCPConnection(self, mode, listensocket, netid, version,
                         myaddr, remoteaddr, packet, reply, msg):
    """Completes a TCP connection, exchanges data, closes it, and checks marks.

    Injects the ACK that completes the handshake, accepts the connection, and
    checks that the accepted socket's data and its FIN leave on netid. In
    MODE_INCOMING_MARK it asserts that the accepted socket's mark is netid.
    In every other mode except MODE_EXPLICIT_MARK, it asserts that the
    listening socket's mark is 0.

    Args:
      mode: one of the MODE_* constants.
      listensocket: the listening socket the connection arrives on.
      netid: the network the connection arrives on and replies must use.
      version: 4 or 6.
      myaddr: the local address.
      remoteaddr: the simulated peer's address.
      packet: not used by this method (NOTE(review): presumably the SYN the
        caller injected).
      reply: the segment the peer acknowledges to complete the handshake
        (NOTE(review): presumably the expected SYN+ACK).
      msg: prefix for assertion messages.
    """
    # The peer's ACK that completes the three-way handshake.
    establishing_ack = Packets.ACK(version, remoteaddr, myaddr, reply)[1]

    # Attempt to confuse the kernel.
    self.BounceSocket(listensocket)

    self.ReceivePacketOn(netid, establishing_ack)

    # If we're using UID routing, the accept() call has to be run as a UID that
    # is routed to the specified netid, because the UID of the socket returned
    # by accept() is the effective UID of the process that calls it. It doesn't
    # need to be the same UID; any UID that selects the same interface will do.
    with net_test.RunAsUid(self.UidForNetid(netid)):
      s, _ = listensocket.accept()

    try:
      # Check that data sent on the connection goes out on the right interface.
      desc, data = Packets.ACK(version, myaddr, remoteaddr, establishing_ack,
                               payload=UDP_PAYLOAD)
      s.send(UDP_PAYLOAD)
      self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
      self.BounceSocket(s)

      # Keep up our end of the conversation.
      ack = Packets.ACK(version, remoteaddr, myaddr, data)[1]
      self.BounceSocket(listensocket)
      self.ReceivePacketOn(netid, ack)

      # Read the accepted socket's mark while it is still open.
      mark = self.GetSocketMark(s)
    finally:
      self.BounceSocket(s)
      s.close()

    if mode == self.MODE_INCOMING_MARK:
      self.assertEquals(netid, mark,
                        msg + ": Accepted socket: Expected mark %d, got %d" % (
                            netid, mark))
    elif mode != self.MODE_EXPLICIT_MARK:
      # The listening socket must not have picked up a mark.
      self.assertEquals(0, self.GetSocketMark(listensocket))

    # Check the FIN was sent on the right interface, and ack it. We don't expect
    # this to fail because by the time the connection is established things are
    # likely working, but a) extra tests are always good and b) extra packets
    # like the FIN (and retransmitted FINs) could cause later tests that expect
    # no packets to fail.
    desc, fin = Packets.FIN(version, myaddr, remoteaddr, ack)
    self.ExpectPacketOn(netid, msg + ": expecting %s after close" % desc, fin)

    desc, finack = Packets.FIN(version, remoteaddr, myaddr, fin)
    self.ReceivePacketOn(netid, finack)

    # Since we called close() earlier, the userspace socket object is gone, so
    # the socket has no UID. If we're doing UID routing, the ack might be routed
    # incorrectly. Not much we can do here.
    desc, finackack = Packets.ACK(version, myaddr, remoteaddr, finack)
    if mode != self.MODE_UID:
      self.ExpectPacketOn(netid, msg + ": expecting final ack", finackack)
    else:
      # Don't check the final ACK. NOTE(review): presumably this discards
      # whatever is queued on the tun interfaces so later tests start clean.
      self.ClearTunQueues()
    707 
    708   def CheckTCP(self, version, modes):
    709     """Checks that incoming TCP connections work.
    710 
    711     Args:
    712       version: An integer, 4 or 6.
    713       modes: A list of modes to excercise.
    714     """
    715     for syncookies in [0, 2]:
    716       for mode in modes:
    717         for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
    718           if mode == self.MODE_UID:
    719             listensocket = self.BuildSocket(6, net_test.TCPSocket, netid, mode)
    720             listensocket.listen(100)
    721           else:
    722             listensocket = self.listensocket
    723 
    724           listenport = listensocket.getsockname()[1]
    725 
    726           if HAVE_TCP_MARK_ACCEPT:
    727             accept_sysctl = 1 if mode == self.MODE_INCOMING_MARK else 0
    728             self._SetTCPMarkAcceptSysctl(accept_sysctl)
    729 
    730           bound_dev = iif if mode == self.MODE_BINDTODEVICE else None
    731           self.BindToDevice(listensocket, bound_dev)
    732 
    733           mark = netid if mode == self.MODE_EXPLICIT_MARK else 0
    734           self.SetSocketMark(listensocket, mark)
    735 
    736           # Generate the packet here instead of in the outer loop, so
    737           # subsequent TCP connections use different source ports and
    738           # retransmissions from old connections don't confuse subsequent
    739           # tests.
    740           desc, packet = Packets.SYN(listenport, version, remoteaddr, myaddr)
    741 
    742           if mode:
    743             reply_desc, reply = Packets.SYNACK(version, myaddr, remoteaddr,
    744                                                packet)
    745           else:
    746             reply_desc, reply = None, None
    747 
    748           extra = "mode=%s, syncookies=%d" % (mode, syncookies)
    749           msg = self._FormatMessage(iif, ip_if, extra, desc, reply_desc)
    750           reply = self._ReceiveAndExpectResponse(netid, packet, reply, msg)
    751           if reply:
    752             self.CheckTCPConnection(mode, listensocket, netid, version, myaddr,
    753                                     remoteaddr, packet, reply, msg)
    754 
    755   def testBasicTCP(self):
    756     self.CheckTCP(4, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])
    757     self.CheckTCP(6, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])
    758 
    759   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
    760   def testIPv4MarkAccept(self):
    761     self.CheckTCP(4, [self.MODE_INCOMING_MARK])
    762 
    763   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
    764   def testIPv6MarkAccept(self):
    765     self.CheckTCP(6, [self.MODE_INCOMING_MARK])
    766 
    767   @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    768   def testIPv4UidAccept(self):
    769     self.CheckTCP(4, [self.MODE_UID])
    770 
    771   @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    772   def testIPv6UidAccept(self):
    773     self.CheckTCP(6, [self.MODE_UID])
    774 
    775   def testIPv6ExplicitMark(self):
    776     self.CheckTCP(6, [self.MODE_EXPLICIT_MARK])
    777 
    778 
    779 class RATest(multinetwork_base.MultiNetworkBaseTest):
    780 
    781   def testDoesNotHaveObsoleteSysctl(self):
    782     self.assertFalse(os.path.isfile(
    783         "/proc/sys/net/ipv6/route/autoconf_table_offset"))
    784 
    785   @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
    786                        "no support for per-table autoconf")
    787   def testPurgeDefaultRouters(self):
    788 
    789     def CheckIPv6Connectivity(expect_connectivity):
    790       for netid in self.NETIDS:
    791         s = net_test.UDPSocket(AF_INET6)
    792         self.SetSocketMark(s, netid)
    793         if expect_connectivity:
    794           self.assertTrue(s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 1234)))
    795         else:
    796           self.assertRaisesErrno(errno.ENETUNREACH, s.sendto, UDP_PAYLOAD,
    797                                  (net_test.IPV6_ADDR, 1234))
    798 
    799     try:
    800       CheckIPv6Connectivity(True)
    801       self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)
    802       CheckIPv6Connectivity(False)
    803     finally:
    804       self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
    805       for netid in self.NETIDS:
    806         self.SendRA(netid)
    807       CheckIPv6Connectivity(True)
    808 
    809   def testOnlinkCommunication(self):
    810     """Checks that on-link communication goes direct and not through routers."""
    811     for netid in self.tuns:
    812       # Send a UDP packet to a random on-link destination.
    813       s = net_test.UDPSocket(AF_INET6)
    814       iface = self.GetInterfaceName(netid)
    815       self.BindToDevice(s, iface)
    816       # dstaddr can never be our address because GetRandomDestination only fills
    817       # in the lower 32 bits, but our address has 0xff in the byte before that
    818       # (since it's constructed from the EUI-64 and so has ff:fe in the middle).
    819       dstaddr = self.GetRandomDestination(self.IPv6Prefix(netid))
    820       s.sendto(UDP_PAYLOAD, (dstaddr, 53))
    821 
    822       # Expect an NS for that destination on the interface.
    823       myaddr = self.MyAddress(6, netid)
    824       mymac = self.MyMacAddress(netid)
    825       desc, expected = Packets.NS(myaddr, dstaddr, mymac)
    826       msg = "Sending UDP packet to on-link destination: expecting %s" % desc
    827       time.sleep(0.0001)  # Required to make the test work on kernel 3.1(!)
    828       self.ExpectPacketOn(netid, msg, expected)
    829 
    830       # Send an NA.
    831       tgtmac = "02:00:00:00:%02x:99" % netid
    832       _, reply = Packets.NA(dstaddr, myaddr, tgtmac)
    833       # Don't use ReceivePacketOn, since that uses the router's MAC address as
    834       # the source. Instead, construct our own Ethernet header with source
    835       # MAC of tgtmac.
    836       reply = scapy.Ether(src=tgtmac, dst=mymac) / reply
    837       self.ReceiveEtherPacketOn(netid, reply)
    838 
    839       # Expect the kernel to send the original UDP packet now that the ND cache
    840       # entry has been populated.
    841       sport = s.getsockname()[1]
    842       desc, expected = Packets.UDP(6, myaddr, dstaddr, sport=sport)
    843       msg = "After NA response, expecting %s" % desc
    844       self.ExpectPacketOn(netid, msg, expected)
    845 
    846   # This test documents a known issue: routing tables are never deleted.
    847   @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
    848                        "no support for per-table autoconf")
    849   def testLeftoverRoutes(self):
    850     def GetNumRoutes():
    851       return len(open("/proc/net/ipv6_route").readlines())
    852 
    853     num_routes = GetNumRoutes()
    854     for i in xrange(10, 20):
    855       try:
    856         self.tuns[i] = self.CreateTunInterface(i)
    857         self.SendRA(i)
    858         self.tuns[i].close()
    859       finally:
    860         del self.tuns[i]
    861     self.assertLess(num_routes, GetNumRoutes())
    862 
    863 
    864 class PMTUTest(InboundMarkingTest):
    865 
    866   PAYLOAD_SIZE = 1400
    867 
    868   # Socket options to change PMTU behaviour.
    869   IP_MTU_DISCOVER = 10
    870   IP_PMTUDISC_DO = 1
    871   IPV6_DONTFRAG = 62
    872 
    873   # Socket options to get the MTU.
    874   IP_MTU = 14
    875   IPV6_PATHMTU = 61
    876 
    877   def GetSocketMTU(self, version, s):
    878     if version == 6:
    879       ip6_mtuinfo = s.getsockopt(net_test.SOL_IPV6, self.IPV6_PATHMTU, 32)
    880       unused_sockaddr, mtu = struct.unpack("=28sI", ip6_mtuinfo)
    881       return mtu
    882     else:
    883       return s.getsockopt(net_test.SOL_IP, self.IP_MTU)
    884 
    885   def DisableFragmentationAndReportErrors(self, version, s):
    886     if version == 4:
    887       s.setsockopt(net_test.SOL_IP, self.IP_MTU_DISCOVER, self.IP_PMTUDISC_DO)
    888       s.setsockopt(net_test.SOL_IP, net_test.IP_RECVERR, 1)
    889     else:
    890       s.setsockopt(net_test.SOL_IPV6, self.IPV6_DONTFRAG, 1)
    891       s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)
    892 
    893   def CheckPMTU(self, version, use_connect, modes):
    894 
    895     def SendBigPacket(version, s, dstaddr, netid, payload):
    896       if use_connect:
    897         s.send(payload)
    898       else:
    899         self.SendOnNetid(version, s, dstaddr, 1234, netid, payload, [])
    900 
    901     for netid in self.tuns:
    902       for mode in modes:
    903         s = self.BuildSocket(version, net_test.UDPSocket, netid, mode)
    904         self.DisableFragmentationAndReportErrors(version, s)
    905 
    906         srcaddr = self.MyAddress(version, netid)
    907         dst_prefix, intermediate = {
    908             4: ("172.19.", "172.16.9.12"),
    909             6: ("2001:db8::", "2001:db8::1")
    910         }[version]
    911         dstaddr = self.GetRandomDestination(dst_prefix)
    912 
    913         if use_connect:
    914           s.connect((dstaddr, 1234))
    915 
    916         payload = self.PAYLOAD_SIZE * "a"
    917 
    918         # Send a packet and receive a packet too big.
    919         SendBigPacket(version, s, dstaddr, netid, payload)
    920         packets = self.ReadAllPacketsOn(netid)
    921         self.assertEquals(1, len(packets))
    922         _, toobig = Packets.ICMPPacketTooBig(version, intermediate, srcaddr,
    923                                              packets[0])
    924         self.ReceivePacketOn(netid, toobig)
    925 
    926         # Check that another send on the same socket returns EMSGSIZE.
    927         self.assertRaisesErrno(
    928             errno.EMSGSIZE,
    929             SendBigPacket, version, s, dstaddr, netid, payload)
    930 
    931         # If this is a connected socket, make sure the socket MTU was set.
    932         # Note that in IPv4 this only started working in Linux 3.6!
    933         if use_connect and (version == 6 or net_test.LINUX_VERSION >= (3, 6)):
    934           self.assertEquals(1280, self.GetSocketMTU(version, s))
    935 
    936         s.close()
    937 
    938         # Check that other sockets pick up the PMTU we have been told about by
    939         # connecting another socket to the same destination and getting its MTU.
    940         # This new socket can use any method to select its outgoing interface;
    941         # here we use a mark for simplicity.
    942         s2 = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
    943         s2.connect((dstaddr, 1234))
    944         self.assertEquals(1280, self.GetSocketMTU(version, s2))
    945 
    946         # Also check the MTU reported by ip route get, this time using the oif.
    947         routes = self.iproute.GetRoutes(dstaddr, self.ifindices[netid], 0, None)
    948         self.assertTrue(routes)
    949         route = routes[0]
    950         rtmsg, attributes = route
    951         self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
    952         metrics = attributes["RTA_METRICS"]
    953         self.assertEquals(metrics["RTAX_MTU"], 1280)
    954 
    955   def testIPv4BasicPMTU(self):
    956     self.CheckPMTU(4, True, ["mark", "oif"])
    957     self.CheckPMTU(4, False, ["mark", "oif"])
    958 
    959   def testIPv6BasicPMTU(self):
    960     self.CheckPMTU(6, True, ["mark", "oif"])
    961     self.CheckPMTU(6, False, ["mark", "oif"])
    962 
    963   @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    964   def testIPv4UIDPMTU(self):
    965     self.CheckPMTU(4, True, ["uid"])
    966     self.CheckPMTU(4, False, ["uid"])
    967 
    968   @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    969   def testIPv6UIDPMTU(self):
    970     self.CheckPMTU(6, True, ["uid"])
    971     self.CheckPMTU(6, False, ["uid"])
    972 
    973   # Making Path MTU Discovery work on unmarked  sockets requires that mark
    974   # reflection be enabled. Otherwise the kernel has no way to know what routing
    975   # table the original packet used, and thus it won't be able to clone the
    976   # correct route.
    977 
    978   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
    979   def testIPv4UnmarkedSocketPMTU(self):
    980     self.SetMarkReflectSysctls(1)
    981     try:
    982       self.CheckPMTU(4, False, [None])
    983     finally:
    984       self.SetMarkReflectSysctls(0)
    985 
    986   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
    987   def testIPv6UnmarkedSocketPMTU(self):
    988     self.SetMarkReflectSysctls(1)
    989     try:
    990       self.CheckPMTU(6, False, [None])
    991     finally:
    992       self.SetMarkReflectSysctls(0)
    993 
    994 
    995 @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    996 class UidRoutingTest(multinetwork_base.MultiNetworkBaseTest):
    997 
    998   def GetRulesAtPriority(self, version, priority):
    999     rules = self.iproute.DumpRules(version)
   1000     out = [(rule, attributes) for rule, attributes in rules
   1001            if attributes.get("FRA_PRIORITY", 0) == priority]
   1002     return out
   1003 
   1004   def CheckInitialTablesHaveNoUIDs(self, version):
   1005     rules = []
   1006     for priority in [0, 32766, 32767]:
   1007       rules.extend(self.GetRulesAtPriority(version, priority))
   1008     for _, attributes in rules:
   1009       self.assertNotIn("FRA_UID_START", attributes)
   1010       self.assertNotIn("FRA_UID_END", attributes)
   1011 
   1012   def testIPv4InitialTablesHaveNoUIDs(self):
   1013     self.CheckInitialTablesHaveNoUIDs(4)
   1014 
   1015   def testIPv6InitialTablesHaveNoUIDs(self):
   1016     self.CheckInitialTablesHaveNoUIDs(6)
   1017 
   1018   def CheckGetAndSetRules(self, version):
   1019     def Random():
   1020       return random.randint(1000000, 2000000)
   1021 
   1022     start, end = tuple(sorted([Random(), Random()]))
   1023     table = Random()
   1024     priority = Random()
   1025 
   1026     try:
   1027       self.iproute.UidRangeRule(version, True, start, end, table,
   1028                                 priority=priority)
   1029 
   1030       rules = self.GetRulesAtPriority(version, priority)
   1031       self.assertTrue(rules)
   1032       _, attributes = rules[-1]
   1033       self.assertEquals(priority, attributes["FRA_PRIORITY"])
   1034       self.assertEquals(start, attributes["FRA_UID_START"])
   1035       self.assertEquals(end, attributes["FRA_UID_END"])
   1036       self.assertEquals(table, attributes["FRA_TABLE"])
   1037     finally:
   1038       self.iproute.UidRangeRule(version, False, start, end, table,
   1039                                 priority=priority)
   1040 
   1041   def testIPv4GetAndSetRules(self):
   1042     self.CheckGetAndSetRules(4)
   1043 
   1044   def testIPv6GetAndSetRules(self):
   1045     self.CheckGetAndSetRules(6)
   1046 
   1047   def ExpectNoRoute(self, addr, oif, mark, uid):
   1048     # The lack of a route may be either an error, or an unreachable route.
   1049     try:
   1050       routes = self.iproute.GetRoutes(addr, oif, mark, uid)
   1051       rtmsg, _ = routes[0]
   1052       self.assertEquals(iproute.RTN_UNREACHABLE, rtmsg.type)
   1053     except IOError, e:
   1054       if int(e.errno) != -int(errno.ENETUNREACH):
   1055         raise e
   1056 
   1057   def ExpectRoute(self, addr, oif, mark, uid):
   1058     routes = self.iproute.GetRoutes(addr, oif, mark, uid)
   1059     rtmsg, _ = routes[0]
   1060     self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
   1061 
   1062   def CheckGetRoute(self, version, addr):
   1063     self.ExpectNoRoute(addr, 0, 0, 0)
   1064     for netid in self.NETIDS:
   1065       uid = self.UidForNetid(netid)
   1066       self.ExpectRoute(addr, 0, 0, uid)
   1067     self.ExpectNoRoute(addr, 0, 0, 0)
   1068 
   1069   def testIPv4RouteGet(self):
   1070     self.CheckGetRoute(4, net_test.IPV4_ADDR)
   1071 
   1072   def testIPv6RouteGet(self):
   1073     self.CheckGetRoute(6, net_test.IPV6_ADDR)
   1074 
   1075 
   1076 class RulesTest(net_test.NetworkTest):
   1077 
   1078   RULE_PRIORITY = 99999
   1079 
   1080   def setUp(self):
   1081     self.iproute = iproute.IPRoute()
   1082     for version in [4, 6]:
   1083       self.iproute.DeleteRulesAtPriority(version, self.RULE_PRIORITY)
   1084 
   1085   def tearDown(self):
   1086     for version in [4, 6]:
   1087       self.iproute.DeleteRulesAtPriority(version, self.RULE_PRIORITY)
   1088 
   1089   def testRuleDeletionMatchesTable(self):
   1090     for version in [4, 6]:
   1091       # Add rules with mark 300 pointing at tables 301 and 302.
   1092       # This checks for a kernel bug where deletion request for tables > 256
   1093       # ignored the table.
   1094       self.iproute.FwmarkRule(version, True, 300, 301,
   1095                               priority=self.RULE_PRIORITY)
   1096       self.iproute.FwmarkRule(version, True, 300, 302,
   1097                               priority=self.RULE_PRIORITY)
   1098       # Delete rule with mark 300 pointing at table 302.
   1099       self.iproute.FwmarkRule(version, False, 300, 302,
   1100                               priority=self.RULE_PRIORITY)
   1101       # Check that the rule pointing at table 301 is still around.
   1102       attributes = [a for _, a in self.iproute.DumpRules(version)
   1103                     if a.get("FRA_PRIORITY", 0) == self.RULE_PRIORITY]
   1104       self.assertEquals(1, len(attributes))
   1105       self.assertEquals(301, attributes[0]["FRA_TABLE"])
   1106 
   1107 
if __name__ == "__main__":
  # Run the test cases defined in this module when it is executed as a script.
  unittest.main()
   1110