Home | History | Annotate | Download | only in net_test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2014 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 import errno
     18 import os
     19 import random
     20 from socket import *  # pylint: disable=wildcard-import
     21 import struct
     22 import time           # pylint: disable=unused-import
     23 import unittest
     24 
     25 from scapy import all as scapy
     26 
     27 import iproute
     28 import multinetwork_base
     29 import net_test
     30 import packets
     31 
     32 # For brevity.
     33 UDP_PAYLOAD = net_test.UDP_PAYLOAD
     34 
     35 IPV6_FLOWINFO = 11
     36 
     37 IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
     38 IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
     39 SYNCOOKIES_SYSCTL = "/proc/sys/net/ipv4/tcp_syncookies"
     40 TCP_MARK_ACCEPT_SYSCTL = "/proc/sys/net/ipv4/tcp_fwmark_accept"
     41 
     42 # The IP[V6]UNICAST_IF socket option was added between 3.1 and 3.4.
     43 HAVE_UNICAST_IF = net_test.LINUX_VERSION >= (3, 4, 0)
     44 
     45 
     46 class ConfigurationError(AssertionError):
     47   pass
     48 
     49 
     50 class InboundMarkingTest(multinetwork_base.MultiNetworkBaseTest):
     51 
     52   @classmethod
     53   def _SetInboundMarking(cls, netid, is_add):
     54     for version in [4, 6]:
     55       # Run iptables to set up incoming packet marking.
     56       iface = cls.GetInterfaceName(netid)
     57       add_del = "-A" if is_add else "-D"
     58       iptables = {4: "iptables", 6: "ip6tables"}[version]
     59       args = "%s %s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
     60           iptables, add_del, iface, netid)
     61       iptables = "/sbin/" + iptables
     62       ret = os.spawnvp(os.P_WAIT, iptables, args.split(" "))
     63       if ret:
     64         raise ConfigurationError("Setup command failed: %s" % args)
     65 
     66   @classmethod
     67   def setUpClass(cls):
     68     super(InboundMarkingTest, cls).setUpClass()
     69     for netid in cls.tuns:
     70       cls._SetInboundMarking(netid, True)
     71 
     72   @classmethod
     73   def tearDownClass(cls):
     74     for netid in cls.tuns:
     75       cls._SetInboundMarking(netid, False)
     76     super(InboundMarkingTest, cls).tearDownClass()
     77 
     78   @classmethod
     79   def SetMarkReflectSysctls(cls, value):
     80     cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
     81     try:
     82       cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
     83     except IOError:
     84       # This does not exist if we use the version of the patch that uses a
     85       # common sysctl for IPv4 and IPv6.
     86       pass
     87 
     88 
     89 class OutgoingTest(multinetwork_base.MultiNetworkBaseTest):
     90 
     91   # How many times to run outgoing packet tests.
     92   ITERATIONS = 5
     93 
     94   def CheckPingPacket(self, version, netid, routing_mode, dstaddr, packet):
     95     s = self.BuildSocket(version, net_test.PingSocket, netid, routing_mode)
     96 
     97     myaddr = self.MyAddress(version, netid)
     98     s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
     99     s.bind((myaddr, packets.PING_IDENT))
    100     net_test.SetSocketTos(s, packets.PING_TOS)
    101 
    102     desc, expected = packets.ICMPEcho(version, myaddr, dstaddr)
    103     msg = "IPv%d ping: expected %s on %s" % (
    104         version, desc, self.GetInterfaceName(netid))
    105 
    106     s.sendto(packet + packets.PING_PAYLOAD, (dstaddr, 19321))
    107 
    108     self.ExpectPacketOn(netid, msg, expected)
    109 
    110   def CheckTCPSYNPacket(self, version, netid, routing_mode, dstaddr):
    111     s = self.BuildSocket(version, net_test.TCPSocket, netid, routing_mode)
    112 
    113     if version == 6 and dstaddr.startswith("::ffff"):
    114       version = 4
    115     myaddr = self.MyAddress(version, netid)
    116     desc, expected = packets.SYN(53, version, myaddr, dstaddr,
    117                                  sport=None, seq=None)
    118 
    119     # Non-blocking TCP connects always return EINPROGRESS.
    120     self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
    121     msg = "IPv%s TCP connect: expected %s on %s" % (
    122         version, desc, self.GetInterfaceName(netid))
    123     self.ExpectPacketOn(netid, msg, expected)
    124     s.close()
    125 
    126   def CheckUDPPacket(self, version, netid, routing_mode, dstaddr):
    127     s = self.BuildSocket(version, net_test.UDPSocket, netid, routing_mode)
    128 
    129     if version == 6 and dstaddr.startswith("::ffff"):
    130       version = 4
    131     myaddr = self.MyAddress(version, netid)
    132     desc, expected = packets.UDP(version, myaddr, dstaddr, sport=None)
    133     msg = "IPv%s UDP %%s: expected %s on %s" % (
    134         version, desc, self.GetInterfaceName(netid))
    135 
    136     s.sendto(UDP_PAYLOAD, (dstaddr, 53))
    137     self.ExpectPacketOn(netid, msg % "sendto", expected)
    138 
    139     # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
    140     if routing_mode != "ucast_oif":
    141       s.connect((dstaddr, 53))
    142       s.send(UDP_PAYLOAD)
    143       self.ExpectPacketOn(netid, msg % "connect/send", expected)
    144       s.close()
    145 
    146   def CheckRawGrePacket(self, version, netid, routing_mode, dstaddr):
    147     s = self.BuildSocket(version, net_test.RawGRESocket, netid, routing_mode)
    148 
    149     inner_version = {4: 6, 6: 4}[version]
    150     inner_src = self.MyAddress(inner_version, netid)
    151     inner_dst = self.GetRemoteAddress(inner_version)
    152     inner = str(packets.UDP(inner_version, inner_src, inner_dst, sport=None)[1])
    153 
    154     ethertype = {4: net_test.ETH_P_IP, 6: net_test.ETH_P_IPV6}[inner_version]
    155     # A GRE header can be as simple as two zero bytes and the ethertype.
    156     packet = struct.pack("!i", ethertype) + inner
    157     myaddr = self.MyAddress(version, netid)
    158 
    159     s.sendto(packet, (dstaddr, IPPROTO_GRE))
    160     desc, expected = packets.GRE(version, myaddr, dstaddr, ethertype, inner)
    161     msg = "Raw IPv%d GRE with inner IPv%d UDP: expected %s on %s" % (
    162         version, inner_version, desc, self.GetInterfaceName(netid))
    163     self.ExpectPacketOn(netid, msg, expected)
    164 
    165   def CheckOutgoingPackets(self, routing_mode):
    166     v4addr = self.IPV4_ADDR
    167     v6addr = self.IPV6_ADDR
    168     v4mapped = "::ffff:" + v4addr
    169 
    170     for _ in xrange(self.ITERATIONS):
    171       for netid in self.tuns:
    172 
    173         self.CheckPingPacket(4, netid, routing_mode, v4addr, self.IPV4_PING)
    174         # Kernel bug.
    175         if routing_mode != "oif":
    176           self.CheckPingPacket(6, netid, routing_mode, v6addr, self.IPV6_PING)
    177 
    178         # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
    179         if routing_mode != "ucast_oif":
    180           self.CheckTCPSYNPacket(4, netid, routing_mode, v4addr)
    181           self.CheckTCPSYNPacket(6, netid, routing_mode, v6addr)
    182           self.CheckTCPSYNPacket(6, netid, routing_mode, v4mapped)
    183 
    184         self.CheckUDPPacket(4, netid, routing_mode, v4addr)
    185         self.CheckUDPPacket(6, netid, routing_mode, v6addr)
    186         self.CheckUDPPacket(6, netid, routing_mode, v4mapped)
    187 
    188         # Creating raw sockets on non-root UIDs requires properly setting
    189         # capabilities, which is hard to do from Python.
    190         # IP_UNICAST_IF is not supported on raw sockets.
    191         if routing_mode not in ["uid", "ucast_oif"]:
    192           self.CheckRawGrePacket(4, netid, routing_mode, v4addr)
    193           self.CheckRawGrePacket(6, netid, routing_mode, v6addr)
    194 
    195   def testMarkRouting(self):
    196     """Checks that socket marking selects the right outgoing interface."""
    197     self.CheckOutgoingPackets("mark")
    198 
    199   @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    200   def testUidRouting(self):
    201     """Checks that UID routing selects the right outgoing interface."""
    202     self.CheckOutgoingPackets("uid")
    203 
    204   def testOifRouting(self):
    205     """Checks that oif routing selects the right outgoing interface."""
    206     self.CheckOutgoingPackets("oif")
    207 
    208   @unittest.skipUnless(HAVE_UNICAST_IF, "no support for UNICAST_IF")
    209   def testUcastOifRouting(self):
    210     """Checks that ucast oif routing selects the right outgoing interface."""
    211     self.CheckOutgoingPackets("ucast_oif")
    212 
    213   def CheckRemarking(self, version, use_connect):
    214     # Remarking or resetting UNICAST_IF on connected sockets does not work.
    215     if use_connect:
    216       modes = ["oif"]
    217     else:
    218       modes = ["mark", "oif"]
    219       if HAVE_UNICAST_IF:
    220         modes += ["ucast_oif"]
    221 
    222     for mode in modes:
    223       s = net_test.UDPSocket(self.GetProtocolFamily(version))
    224 
    225       # Figure out what packets to expect.
    226       unspec = {4: "0.0.0.0", 6: "::"}[version]
    227       sport = packets.RandomPort()
    228       s.bind((unspec, sport))
    229       dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
    230       desc, expected = packets.UDP(version, unspec, dstaddr, sport)
    231 
    232       # If we're testing connected sockets, connect the socket on the first
    233       # netid now.
    234       if use_connect:
    235         netid = self.tuns.keys()[0]
    236         self.SelectInterface(s, netid, mode)
    237         s.connect((dstaddr, 53))
    238         expected.src = self.MyAddress(version, netid)
    239 
    240       # For each netid, select that network without closing the socket, and
    241       # check that the packets sent on that socket go out on the right network.
    242       for netid in self.tuns:
    243         self.SelectInterface(s, netid, mode)
    244         if not use_connect:
    245           expected.src = self.MyAddress(version, netid)
    246         s.sendto(UDP_PAYLOAD, (dstaddr, 53))
    247         connected_str = "Connected" if use_connect else "Unconnected"
    248         msg = "%s UDPv%d socket remarked using %s: expecting %s on %s" % (
    249             connected_str, version, mode, desc, self.GetInterfaceName(netid))
    250         self.ExpectPacketOn(netid, msg, expected)
    251         self.SelectInterface(s, None, mode)
    252 
    253   def testIPv4Remarking(self):
    254     """Checks that updating the mark on an IPv4 socket changes routing."""
    255     self.CheckRemarking(4, False)
    256     self.CheckRemarking(4, True)
    257 
    258   def testIPv6Remarking(self):
    259     """Checks that updating the mark on an IPv6 socket changes routing."""
    260     self.CheckRemarking(6, False)
    261     self.CheckRemarking(6, True)
    262 
    263   def testIPv6StickyPktinfo(self):
    264     for _ in xrange(self.ITERATIONS):
    265       for netid in self.tuns:
    266         s = net_test.UDPSocket(AF_INET6)
    267 
    268         # Set a flowlabel.
    269         net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xdead)
    270         s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_FLOWINFO_SEND, 1)
    271 
    272         # Set some destination options.
    273         nonce = "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c"
    274         dstopts = "".join([
    275             "\x11\x02",              # Next header=UDP, 24 bytes of options.
    276             "\x01\x06", "\x00" * 6,  # PadN, 6 bytes of padding.
    277             "\x8b\x0c",              # ILNP nonce, 12 bytes.
    278             nonce
    279         ])
    280         s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, dstopts)
    281         s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_HOPS, 255)
    282 
    283         pktinfo = multinetwork_base.MakePktInfo(6, None, self.ifindices[netid])
    284 
    285         # Set the sticky pktinfo option.
    286         s.setsockopt(net_test.SOL_IPV6, IPV6_PKTINFO, pktinfo)
    287 
    288         # Specify the flowlabel in the destination address.
    289         s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 53, 0xdead, 0))
    290 
    291         sport = s.getsockname()[1]
    292         srcaddr = self.MyAddress(6, netid)
    293         expected = (scapy.IPv6(src=srcaddr, dst=net_test.IPV6_ADDR,
    294                                fl=0xdead, hlim=255) /
    295                     scapy.IPv6ExtHdrDestOpt(
    296                         options=[scapy.PadN(optdata="\x00\x00\x00\x00\x00\x00"),
    297                                  scapy.HBHOptUnknown(otype=0x8b,
    298                                                      optdata=nonce)]) /
    299                     scapy.UDP(sport=sport, dport=53) /
    300                     UDP_PAYLOAD)
    301         msg = "IPv6 UDP using sticky pktinfo: expected UDP packet on %s" % (
    302             self.GetInterfaceName(netid))
    303         self.ExpectPacketOn(netid, msg, expected)
    304 
    305   def CheckPktinfoRouting(self, version):
    306     for _ in xrange(self.ITERATIONS):
    307       for netid in self.tuns:
    308         family = self.GetProtocolFamily(version)
    309         s = net_test.UDPSocket(family)
    310 
    311         if version == 6:
    312           # Create a flowlabel so we can use it.
    313           net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xbeef)
    314 
    315           # Specify some arbitrary options.
    316           cmsgs = [
    317               (net_test.SOL_IPV6, IPV6_HOPLIMIT, 39),
    318               (net_test.SOL_IPV6, IPV6_TCLASS, 0x83),
    319               (net_test.SOL_IPV6, IPV6_FLOWINFO, int(htonl(0xbeef))),
    320           ]
    321         else:
    322           # Support for setting IPv4 TOS and TTL via cmsg only appeared in 3.13.
    323           cmsgs = []
    324           s.setsockopt(net_test.SOL_IP, IP_TTL, 39)
    325           s.setsockopt(net_test.SOL_IP, IP_TOS, 0x83)
    326 
    327         dstaddr = self.GetRemoteAddress(version)
    328         self.SendOnNetid(version, s, dstaddr, 53, netid, UDP_PAYLOAD, cmsgs)
    329 
    330         sport = s.getsockname()[1]
    331         srcaddr = self.MyAddress(version, netid)
    332 
    333         desc, expected = packets.UDPWithOptions(version, srcaddr, dstaddr,
    334                                                 sport=sport)
    335 
    336         msg = "IPv%d UDP using pktinfo routing: expected %s on %s" % (
    337             version, desc, self.GetInterfaceName(netid))
    338         self.ExpectPacketOn(netid, msg, expected)
    339 
    340   def testIPv4PktinfoRouting(self):
    341     self.CheckPktinfoRouting(4)
    342 
    343   def testIPv6PktinfoRouting(self):
    344     self.CheckPktinfoRouting(6)
    345 
    346 
    347 class MarkTest(InboundMarkingTest):
    348 
    349   def CheckReflection(self, version, gen_packet, gen_reply):
    350     """Checks that replies go out on the same interface as the original.
    351 
    352     For each combination:
    353      - Calls gen_packet to generate a packet to that IP address.
    354      - Writes the packet generated by gen_packet on the given tun
    355        interface, causing the kernel to receive it.
    356      - Checks that the kernel's reply matches the packet generated by
    357        gen_reply.
    358 
    359     Args:
    360       version: An integer, 4 or 6.
    361       gen_packet: A function taking an IP version (an integer), a source
    362         address and a destination address (strings), and returning a scapy
    363         packet.
    364       gen_reply: A function taking the same arguments as gen_packet,
    365         plus a scapy packet, and returning a scapy packet.
    366     """
    367     for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
    368       # Generate a test packet.
    369       desc, packet = gen_packet(version, remoteaddr, myaddr)
    370 
    371       # Test with mark reflection enabled and disabled.
    372       for reflect in [0, 1]:
    373         self.SetMarkReflectSysctls(reflect)
    374         # HACK: IPv6 ping replies always do a routing lookup with the
    375         # interface the ping came in on. So even if mark reflection is not
    376         # working, IPv6 ping replies will be properly reflected. Don't
    377         # fail when that happens.
    378         if reflect or desc == "ICMPv6 echo":
    379           reply_desc, reply = gen_reply(version, myaddr, remoteaddr, packet)
    380         else:
    381           reply_desc, reply = None, None
    382 
    383         msg = self._FormatMessage(iif, ip_if, "reflect=%d" % reflect,
    384                                   desc, reply_desc)
    385         self._ReceiveAndExpectResponse(netid, packet, reply, msg)
    386 
    387   def SYNToClosedPort(self, *args):
    388     return packets.SYN(999, *args)
    389 
    390   def testIPv4ICMPErrorsReflectMark(self):
    391     self.CheckReflection(4, packets.UDP, packets.ICMPPortUnreachable)
    392 
    393   def testIPv6ICMPErrorsReflectMark(self):
    394     self.CheckReflection(6, packets.UDP, packets.ICMPPortUnreachable)
    395 
    396   def testIPv4PingRepliesReflectMarkAndTos(self):
    397     self.CheckReflection(4, packets.ICMPEcho, packets.ICMPReply)
    398 
    399   def testIPv6PingRepliesReflectMarkAndTos(self):
    400     self.CheckReflection(6, packets.ICMPEcho, packets.ICMPReply)
    401 
    402   def testIPv4RSTsReflectMark(self):
    403     self.CheckReflection(4, self.SYNToClosedPort, packets.RST)
    404 
    405   def testIPv6RSTsReflectMark(self):
    406     self.CheckReflection(6, self.SYNToClosedPort, packets.RST)
    407 
    408 
    409 class TCPAcceptTest(InboundMarkingTest):
    410 
    411   MODE_BINDTODEVICE = "SO_BINDTODEVICE"
    412   MODE_INCOMING_MARK = "incoming mark"
    413   MODE_EXPLICIT_MARK = "explicit mark"
    414   MODE_UID = "uid"
    415 
    416   @classmethod
    417   def setUpClass(cls):
    418     super(TCPAcceptTest, cls).setUpClass()
    419 
    420     # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
    421     # will accept both IPv4 and IPv6 connections. We do this here instead of in
    422     # each test so we can use the same socket every time. That way, if a kernel
    423     # bug causes incoming packets to mark the listening socket instead of the
    424     # accepted socket, the test will fail as soon as the next address/interface
    425     # combination is tried.
    426     cls.listenport = 1234
    427     cls.listensocket = net_test.IPv6TCPSocket()
    428     cls.listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    429     cls.listensocket.bind(("::", cls.listenport))
    430     cls.listensocket.listen(100)
    431 
    432   def BounceSocket(self, s):
    433     """Attempts to invalidate a socket's destination cache entry."""
    434     if s.family == AF_INET:
    435       tos = s.getsockopt(SOL_IP, IP_TOS)
    436       s.setsockopt(net_test.SOL_IP, IP_TOS, 53)
    437       s.setsockopt(net_test.SOL_IP, IP_TOS, tos)
    438     else:
    439       # UDP, 8 bytes dstopts; PAD1, 4 bytes padding; 4 bytes zeros.
    440       pad8 = "".join(["\x11\x00", "\x01\x04", "\x00" * 4])
    441       s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, pad8)
    442       s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, "")
    443 
    444   def _SetTCPMarkAcceptSysctl(self, value):
    445     self.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
    446 
    447   def CheckTCPConnection(self, mode, listensocket, netid, version,
    448                          myaddr, remoteaddr, packet, reply, msg):
    449     establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
    450 
    451     # Attempt to confuse the kernel.
    452     self.BounceSocket(listensocket)
    453 
    454     self.ReceivePacketOn(netid, establishing_ack)
    455 
    456     # If we're using UID routing, the accept() call has to be run as a UID that
    457     # is routed to the specified netid, because the UID of the socket returned
    458     # by accept() is the effective UID of the process that calls it. It doesn't
    459     # need to be the same UID; any UID that selects the same interface will do.
    460     with net_test.RunAsUid(self.UidForNetid(netid)):
    461       s, _ = listensocket.accept()
    462 
    463     try:
    464       # Check that data sent on the connection goes out on the right interface.
    465       desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
    466                                payload=UDP_PAYLOAD)
    467       s.send(UDP_PAYLOAD)
    468       self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
    469       self.BounceSocket(s)
    470 
    471       # Keep up our end of the conversation.
    472       ack = packets.ACK(version, remoteaddr, myaddr, data)[1]
    473       self.BounceSocket(listensocket)
    474       self.ReceivePacketOn(netid, ack)
    475 
    476       mark = self.GetSocketMark(s)
    477     finally:
    478       self.BounceSocket(s)
    479       s.close()
    480 
    481     if mode == self.MODE_INCOMING_MARK:
    482       self.assertEquals(netid, mark,
    483                         msg + ": Accepted socket: Expected mark %d, got %d" % (
    484                             netid, mark))
    485     elif mode != self.MODE_EXPLICIT_MARK:
    486       self.assertEquals(0, self.GetSocketMark(listensocket))
    487 
    488     # Check the FIN was sent on the right interface, and ack it. We don't expect
    489     # this to fail because by the time the connection is established things are
    490     # likely working, but a) extra tests are always good and b) extra packets
    491     # like the FIN (and retransmitted FINs) could cause later tests that expect
    492     # no packets to fail.
    493     desc, fin = packets.FIN(version, myaddr, remoteaddr, ack)
    494     self.ExpectPacketOn(netid, msg + ": expecting %s after close" % desc, fin)
    495 
    496     desc, finack = packets.FIN(version, remoteaddr, myaddr, fin)
    497     self.ReceivePacketOn(netid, finack)
    498 
    499     # Since we called close() earlier, the userspace socket object is gone, so
    500     # the socket has no UID. If we're doing UID routing, the ack might be routed
    501     # incorrectly. Not much we can do here.
    502     desc, finackack = packets.ACK(version, myaddr, remoteaddr, finack)
    503     if mode != self.MODE_UID:
    504       self.ExpectPacketOn(netid, msg + ": expecting final ack", finackack)
    505     else:
    506       self.ClearTunQueues()
    507 
    508   def CheckTCP(self, version, modes):
    509     """Checks that incoming TCP connections work.
    510 
    511     Args:
    512       version: An integer, 4 or 6.
    513       modes: A list of modes to excercise.
    514     """
    515     for syncookies in [0, 2]:
    516       for mode in modes:
    517         for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
    518           if mode == self.MODE_UID:
    519             listensocket = self.BuildSocket(6, net_test.TCPSocket, netid, mode)
    520             listensocket.listen(100)
    521           else:
    522             listensocket = self.listensocket
    523 
    524           listenport = listensocket.getsockname()[1]
    525 
    526           accept_sysctl = 1 if mode == self.MODE_INCOMING_MARK else 0
    527           self._SetTCPMarkAcceptSysctl(accept_sysctl)
    528 
    529           bound_dev = iif if mode == self.MODE_BINDTODEVICE else None
    530           self.BindToDevice(listensocket, bound_dev)
    531 
    532           mark = netid if mode == self.MODE_EXPLICIT_MARK else 0
    533           self.SetSocketMark(listensocket, mark)
    534 
    535           # Generate the packet here instead of in the outer loop, so
    536           # subsequent TCP connections use different source ports and
    537           # retransmissions from old connections don't confuse subsequent
    538           # tests.
    539           desc, packet = packets.SYN(listenport, version, remoteaddr, myaddr)
    540 
    541           if mode:
    542             reply_desc, reply = packets.SYNACK(version, myaddr, remoteaddr,
    543                                                packet)
    544           else:
    545             reply_desc, reply = None, None
    546 
    547           extra = "mode=%s, syncookies=%d" % (mode, syncookies)
    548           msg = self._FormatMessage(iif, ip_if, extra, desc, reply_desc)
    549           reply = self._ReceiveAndExpectResponse(netid, packet, reply, msg)
    550           if reply:
    551             self.CheckTCPConnection(mode, listensocket, netid, version, myaddr,
    552                                     remoteaddr, packet, reply, msg)
    553 
    554   def testBasicTCP(self):
    555     self.CheckTCP(4, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])
    556     self.CheckTCP(6, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])
    557 
    558   def testIPv4MarkAccept(self):
    559     self.CheckTCP(4, [self.MODE_INCOMING_MARK])
    560 
    561   def testIPv6MarkAccept(self):
    562     self.CheckTCP(6, [self.MODE_INCOMING_MARK])
    563 
    564   @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    565   def testIPv4UidAccept(self):
    566     self.CheckTCP(4, [self.MODE_UID])
    567 
    568   @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    569   def testIPv6UidAccept(self):
    570     self.CheckTCP(6, [self.MODE_UID])
    571 
    572   def testIPv6ExplicitMark(self):
    573     self.CheckTCP(6, [self.MODE_EXPLICIT_MARK])
    574 
    575 
    576 class RATest(multinetwork_base.MultiNetworkBaseTest):
    577 
    578   def testDoesNotHaveObsoleteSysctl(self):
    579     self.assertFalse(os.path.isfile(
    580         "/proc/sys/net/ipv6/route/autoconf_table_offset"))
    581 
    582   @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
    583                        "no support for per-table autoconf")
    584   def testPurgeDefaultRouters(self):
    585 
    586     def CheckIPv6Connectivity(expect_connectivity):
    587       for netid in self.NETIDS:
    588         s = net_test.UDPSocket(AF_INET6)
    589         self.SetSocketMark(s, netid)
    590         if expect_connectivity:
    591           self.assertTrue(s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 1234)))
    592         else:
    593           self.assertRaisesErrno(errno.ENETUNREACH, s.sendto, UDP_PAYLOAD,
    594                                  (net_test.IPV6_ADDR, 1234))
    595 
    596     try:
    597       CheckIPv6Connectivity(True)
    598       self.SetIPv6SysctlOnAllIfaces("accept_ra", 1)
    599       self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)
    600       CheckIPv6Connectivity(False)
    601     finally:
    602       self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
    603       for netid in self.NETIDS:
    604         self.SendRA(netid)
    605       CheckIPv6Connectivity(True)
    606 
    607   def testOnlinkCommunication(self):
    608     """Checks that on-link communication goes direct and not through routers."""
    609     for netid in self.tuns:
    610       # Send a UDP packet to a random on-link destination.
    611       s = net_test.UDPSocket(AF_INET6)
    612       iface = self.GetInterfaceName(netid)
    613       self.BindToDevice(s, iface)
    614       # dstaddr can never be our address because GetRandomDestination only fills
    615       # in the lower 32 bits, but our address has 0xff in the byte before that
    616       # (since it's constructed from the EUI-64 and so has ff:fe in the middle).
    617       dstaddr = self.GetRandomDestination(self.IPv6Prefix(netid))
    618       s.sendto(UDP_PAYLOAD, (dstaddr, 53))
    619 
    620       # Expect an NS for that destination on the interface.
    621       myaddr = self.MyAddress(6, netid)
    622       mymac = self.MyMacAddress(netid)
    623       desc, expected = packets.NS(myaddr, dstaddr, mymac)
    624       msg = "Sending UDP packet to on-link destination: expecting %s" % desc
    625       time.sleep(0.0001)  # Required to make the test work on kernel 3.1(!)
    626       self.ExpectPacketOn(netid, msg, expected)
    627 
    628       # Send an NA.
    629       tgtmac = "02:00:00:00:%02x:99" % netid
    630       _, reply = packets.NA(dstaddr, myaddr, tgtmac)
    631       # Don't use ReceivePacketOn, since that uses the router's MAC address as
    632       # the source. Instead, construct our own Ethernet header with source
    633       # MAC of tgtmac.
    634       reply = scapy.Ether(src=tgtmac, dst=mymac) / reply
    635       self.ReceiveEtherPacketOn(netid, reply)
    636 
    637       # Expect the kernel to send the original UDP packet now that the ND cache
    638       # entry has been populated.
    639       sport = s.getsockname()[1]
    640       desc, expected = packets.UDP(6, myaddr, dstaddr, sport=sport)
    641       msg = "After NA response, expecting %s" % desc
    642       self.ExpectPacketOn(netid, msg, expected)
    643 
    644   # This test documents a known issue: routing tables are never deleted.
    645   @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
    646                        "no support for per-table autoconf")
    647   def testLeftoverRoutes(self):
    648     def GetNumRoutes():
    649       return len(open("/proc/net/ipv6_route").readlines())
    650 
    651     num_routes = GetNumRoutes()
    652     for i in xrange(10, 20):
    653       try:
    654         self.tuns[i] = self.CreateTunInterface(i)
    655         self.SendRA(i)
    656         self.tuns[i].close()
    657       finally:
    658         del self.tuns[i]
    659     self.assertLess(num_routes, GetNumRoutes())
    660 
    661 
    662 class PMTUTest(InboundMarkingTest):
    663 
    664   PAYLOAD_SIZE = 1400
    665 
    666   # Socket options to change PMTU behaviour.
    667   IP_MTU_DISCOVER = 10
    668   IP_PMTUDISC_DO = 1
    669   IPV6_DONTFRAG = 62
    670 
    671   # Socket options to get the MTU.
    672   IP_MTU = 14
    673   IPV6_PATHMTU = 61
    674 
    675   def GetSocketMTU(self, version, s):
    676     if version == 6:
    677       ip6_mtuinfo = s.getsockopt(net_test.SOL_IPV6, self.IPV6_PATHMTU, 32)
    678       unused_sockaddr, mtu = struct.unpack("=28sI", ip6_mtuinfo)
    679       return mtu
    680     else:
    681       return s.getsockopt(net_test.SOL_IP, self.IP_MTU)
    682 
    683   def DisableFragmentationAndReportErrors(self, version, s):
    684     if version == 4:
    685       s.setsockopt(net_test.SOL_IP, self.IP_MTU_DISCOVER, self.IP_PMTUDISC_DO)
    686       s.setsockopt(net_test.SOL_IP, net_test.IP_RECVERR, 1)
    687     else:
    688       s.setsockopt(net_test.SOL_IPV6, self.IPV6_DONTFRAG, 1)
    689       s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)
    690 
    691   def CheckPMTU(self, version, use_connect, modes):
    692 
    693     def SendBigPacket(version, s, dstaddr, netid, payload):
    694       if use_connect:
    695         s.send(payload)
    696       else:
    697         self.SendOnNetid(version, s, dstaddr, 1234, netid, payload, [])
    698 
    699     for netid in self.tuns:
    700       for mode in modes:
    701         s = self.BuildSocket(version, net_test.UDPSocket, netid, mode)
    702         self.DisableFragmentationAndReportErrors(version, s)
    703 
    704         srcaddr = self.MyAddress(version, netid)
    705         dst_prefix, intermediate = {
    706             4: ("172.19.", "172.16.9.12"),
    707             6: ("2001:db8::", "2001:db8::1")
    708         }[version]
    709         dstaddr = self.GetRandomDestination(dst_prefix)
    710 
    711         if use_connect:
    712           s.connect((dstaddr, 1234))
    713 
    714         payload = self.PAYLOAD_SIZE * "a"
    715 
    716         # Send a packet and receive a packet too big.
    717         SendBigPacket(version, s, dstaddr, netid, payload)
    718         received = self.ReadAllPacketsOn(netid)
    719         self.assertEquals(1, len(received))
    720         _, toobig = packets.ICMPPacketTooBig(version, intermediate, srcaddr,
    721                                              received[0])
    722         self.ReceivePacketOn(netid, toobig)
    723 
    724         # Check that another send on the same socket returns EMSGSIZE.
    725         self.assertRaisesErrno(
    726             errno.EMSGSIZE,
    727             SendBigPacket, version, s, dstaddr, netid, payload)
    728 
    729         # If this is a connected socket, make sure the socket MTU was set.
    730         # Note that in IPv4 this only started working in Linux 3.6!
    731         if use_connect and (version == 6 or net_test.LINUX_VERSION >= (3, 6)):
    732           self.assertEquals(1280, self.GetSocketMTU(version, s))
    733 
    734         s.close()
    735 
    736         # Check that other sockets pick up the PMTU we have been told about by
    737         # connecting another socket to the same destination and getting its MTU.
    738         # This new socket can use any method to select its outgoing interface;
    739         # here we use a mark for simplicity.
    740         s2 = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
    741         s2.connect((dstaddr, 1234))
    742         self.assertEquals(1280, self.GetSocketMTU(version, s2))
    743 
    744         # Also check the MTU reported by ip route get, this time using the oif.
    745         routes = self.iproute.GetRoutes(dstaddr, self.ifindices[netid], 0, None)
    746         self.assertTrue(routes)
    747         route = routes[0]
    748         rtmsg, attributes = route
    749         self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
    750         metrics = attributes["RTA_METRICS"]
    751         self.assertEquals(metrics["RTAX_MTU"], 1280)
    752 
    753   def testIPv4BasicPMTU(self):
    754     """Tests IPv4 path MTU discovery.
    755 
    756     Relevant kernel commits:
    757       upstream net-next:
    758         6a66271 ipv4, fib: pass LOOPBACK_IFINDEX instead of 0 to flowi4_iif
    759 
    760       android-3.10:
    761         4bc64dd ipv4, fib: pass LOOPBACK_IFINDEX instead of 0 to flowi4_iif
    762     """
    763 
    764     self.CheckPMTU(4, True, ["mark", "oif"])
    765     self.CheckPMTU(4, False, ["mark", "oif"])
    766 
    767   def testIPv6BasicPMTU(self):
    768     self.CheckPMTU(6, True, ["mark", "oif"])
    769     self.CheckPMTU(6, False, ["mark", "oif"])
    770 
    771   @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    772   def testIPv4UIDPMTU(self):
    773     self.CheckPMTU(4, True, ["uid"])
    774     self.CheckPMTU(4, False, ["uid"])
    775 
    776   @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    777   def testIPv6UIDPMTU(self):
    778     self.CheckPMTU(6, True, ["uid"])
    779     self.CheckPMTU(6, False, ["uid"])
    780 
    781   # Making Path MTU Discovery work on unmarked  sockets requires that mark
    782   # reflection be enabled. Otherwise the kernel has no way to know what routing
    783   # table the original packet used, and thus it won't be able to clone the
    784   # correct route.
    785 
    786   def testIPv4UnmarkedSocketPMTU(self):
    787     self.SetMarkReflectSysctls(1)
    788     try:
    789       self.CheckPMTU(4, False, [None])
    790     finally:
    791       self.SetMarkReflectSysctls(0)
    792 
    793   def testIPv6UnmarkedSocketPMTU(self):
    794     self.SetMarkReflectSysctls(1)
    795     try:
    796       self.CheckPMTU(6, False, [None])
    797     finally:
    798       self.SetMarkReflectSysctls(0)
    799 
    800 
    801 @unittest.skipUnless(multinetwork_base.HAVE_UID_ROUTING, "no UID routes")
    802 class UidRoutingTest(multinetwork_base.MultiNetworkBaseTest):
    803   """Tests that per-UID routing works properly.
    804 
    805   Relevant kernel commits:
    806     android-3.4:
    807       0b42874 net: core: Support UID-based routing.
    808       0836a0c Handle 'sk' being NULL in UID-based routing.
    809 
    810     android-3.10:
    811       99a6ea4 net: core: Support UID-based routing.
    812       455b09d Handle 'sk' being NULL in UID-based routing.
    813   """
    814 
    815   def GetRulesAtPriority(self, version, priority):
    816     rules = self.iproute.DumpRules(version)
    817     out = [(rule, attributes) for rule, attributes in rules
    818            if attributes.get("FRA_PRIORITY", 0) == priority]
    819     return out
    820 
    821   def CheckInitialTablesHaveNoUIDs(self, version):
    822     rules = []
    823     for priority in [0, 32766, 32767]:
    824       rules.extend(self.GetRulesAtPriority(version, priority))
    825     for _, attributes in rules:
    826       self.assertNotIn("FRA_UID_START", attributes)
    827       self.assertNotIn("FRA_UID_END", attributes)
    828 
    829   def testIPv4InitialTablesHaveNoUIDs(self):
    830     self.CheckInitialTablesHaveNoUIDs(4)
    831 
    832   def testIPv6InitialTablesHaveNoUIDs(self):
    833     self.CheckInitialTablesHaveNoUIDs(6)
    834 
    835   def CheckGetAndSetRules(self, version):
    836     def Random():
    837       return random.randint(1000000, 2000000)
    838 
    839     start, end = tuple(sorted([Random(), Random()]))
    840     table = Random()
    841     priority = Random()
    842 
    843     try:
    844       self.iproute.UidRangeRule(version, True, start, end, table,
    845                                 priority=priority)
    846 
    847       rules = self.GetRulesAtPriority(version, priority)
    848       self.assertTrue(rules)
    849       _, attributes = rules[-1]
    850       self.assertEquals(priority, attributes["FRA_PRIORITY"])
    851       self.assertEquals(start, attributes["FRA_UID_START"])
    852       self.assertEquals(end, attributes["FRA_UID_END"])
    853       self.assertEquals(table, attributes["FRA_TABLE"])
    854     finally:
    855       self.iproute.UidRangeRule(version, False, start, end, table,
    856                                 priority=priority)
    857 
    858   def testIPv4GetAndSetRules(self):
    859     self.CheckGetAndSetRules(4)
    860 
    861   def testIPv6GetAndSetRules(self):
    862     self.CheckGetAndSetRules(6)
    863 
    864   def ExpectNoRoute(self, addr, oif, mark, uid):
    865     # The lack of a route may be either an error, or an unreachable route.
    866     try:
    867       routes = self.iproute.GetRoutes(addr, oif, mark, uid)
    868       rtmsg, _ = routes[0]
    869       self.assertEquals(iproute.RTN_UNREACHABLE, rtmsg.type)
    870     except IOError, e:
    871       if int(e.errno) != -int(errno.ENETUNREACH):
    872         raise e
    873 
    874   def ExpectRoute(self, addr, oif, mark, uid):
    875     routes = self.iproute.GetRoutes(addr, oif, mark, uid)
    876     rtmsg, _ = routes[0]
    877     self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
    878 
    879   def CheckGetRoute(self, version, addr):
    880     self.ExpectNoRoute(addr, 0, 0, 0)
    881     for netid in self.NETIDS:
    882       uid = self.UidForNetid(netid)
    883       self.ExpectRoute(addr, 0, 0, uid)
    884     self.ExpectNoRoute(addr, 0, 0, 0)
    885 
    886   def testIPv4RouteGet(self):
    887     self.CheckGetRoute(4, net_test.IPV4_ADDR)
    888 
    889   def testIPv6RouteGet(self):
    890     self.CheckGetRoute(6, net_test.IPV6_ADDR)
    891 
    892 
    893 class RulesTest(net_test.NetworkTest):
    894 
    895   RULE_PRIORITY = 99999
    896 
    897   def setUp(self):
    898     self.iproute = iproute.IPRoute()
    899     for version in [4, 6]:
    900       self.iproute.DeleteRulesAtPriority(version, self.RULE_PRIORITY)
    901 
    902   def tearDown(self):
    903     for version in [4, 6]:
    904       self.iproute.DeleteRulesAtPriority(version, self.RULE_PRIORITY)
    905 
    906   def testRuleDeletionMatchesTable(self):
    907     for version in [4, 6]:
    908       # Add rules with mark 300 pointing at tables 301 and 302.
    909       # This checks for a kernel bug where deletion request for tables > 256
    910       # ignored the table.
    911       self.iproute.FwmarkRule(version, True, 300, 301,
    912                               priority=self.RULE_PRIORITY)
    913       self.iproute.FwmarkRule(version, True, 300, 302,
    914                               priority=self.RULE_PRIORITY)
    915       # Delete rule with mark 300 pointing at table 302.
    916       self.iproute.FwmarkRule(version, False, 300, 302,
    917                               priority=self.RULE_PRIORITY)
    918       # Check that the rule pointing at table 301 is still around.
    919       attributes = [a for _, a in self.iproute.DumpRules(version)
    920                     if a.get("FRA_PRIORITY", 0) == self.RULE_PRIORITY]
    921       self.assertEquals(1, len(attributes))
    922       self.assertEquals(301, attributes[0]["FRA_TABLE"])
    923 
    924 
    925 if __name__ == "__main__":
    926   unittest.main()
    927