Home | History | Annotate | Download | only in test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2014 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 import ctypes
     18 import errno
     19 import os
     20 import random
     21 from socket import *  # pylint: disable=wildcard-import
     22 import struct
     23 import time           # pylint: disable=unused-import
     24 import unittest
     25 
     26 from scapy import all as scapy
     27 
     28 import csocket
     29 import iproute
     30 import multinetwork_base
     31 import net_test
     32 import packets
     33 
     34 # For brevity.
     35 UDP_PAYLOAD = net_test.UDP_PAYLOAD
     36 
     37 IPV6_FLOWINFO = 11
     38 
     39 SYNCOOKIES_SYSCTL = "/proc/sys/net/ipv4/tcp_syncookies"
     40 TCP_MARK_ACCEPT_SYSCTL = "/proc/sys/net/ipv4/tcp_fwmark_accept"
     41 
     42 # The IP[V6]UNICAST_IF socket option was added between 3.1 and 3.4.
     43 HAVE_UNICAST_IF = net_test.LINUX_VERSION >= (3, 4, 0)
     44 
     45 # RTPROT_RA is working properly with 4.14
     46 HAVE_RTPROT_RA = net_test.LINUX_VERSION >= (4, 14, 0)
     47 
     48 class ConfigurationError(AssertionError):
     49   pass
     50 
     51 
     52 class OutgoingTest(multinetwork_base.MultiNetworkBaseTest):
     53 
     54   # How many times to run outgoing packet tests.
     55   ITERATIONS = 5
     56 
     57   def CheckPingPacket(self, version, netid, routing_mode, packet):
     58     s = self.BuildSocket(version, net_test.PingSocket, netid, routing_mode)
     59 
     60     myaddr = self.MyAddress(version, netid)
     61     mysockaddr = self.MySocketAddress(version, netid)
     62     s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
     63     s.bind((mysockaddr, packets.PING_IDENT))
     64     net_test.SetSocketTos(s, packets.PING_TOS)
     65 
     66     dstaddr = self.GetRemoteAddress(version)
     67     dstsockaddr = self.GetRemoteSocketAddress(version)
     68     desc, expected = packets.ICMPEcho(version, myaddr, dstaddr)
     69     msg = "IPv%d ping: expected %s on %s" % (
     70         version, desc, self.GetInterfaceName(netid))
     71 
     72     s.sendto(packet + packets.PING_PAYLOAD, (dstsockaddr, 19321))
     73 
     74     self.ExpectPacketOn(netid, msg, expected)
     75 
     76   def CheckTCPSYNPacket(self, version, netid, routing_mode):
     77     s = self.BuildSocket(version, net_test.TCPSocket, netid, routing_mode)
     78 
     79     myaddr = self.MyAddress(version, netid)
     80     dstaddr = self.GetRemoteAddress(version)
     81     dstsockaddr = self.GetRemoteSocketAddress(version)
     82     desc, expected = packets.SYN(53, version, myaddr, dstaddr,
     83                                  sport=None, seq=None)
     84 
     85 
     86     # Non-blocking TCP connects always return EINPROGRESS.
     87     self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstsockaddr, 53))
     88     msg = "IPv%s TCP connect: expected %s on %s" % (
     89         version, desc, self.GetInterfaceName(netid))
     90     self.ExpectPacketOn(netid, msg, expected)
     91     s.close()
     92 
     93   def CheckUDPPacket(self, version, netid, routing_mode):
     94     s = self.BuildSocket(version, net_test.UDPSocket, netid, routing_mode)
     95 
     96     myaddr = self.MyAddress(version, netid)
     97     dstaddr = self.GetRemoteAddress(version)
     98     dstsockaddr = self.GetRemoteSocketAddress(version)
     99 
    100     desc, expected = packets.UDP(version, myaddr, dstaddr, sport=None)
    101     msg = "IPv%s UDP %%s: expected %s on %s" % (
    102         version, desc, self.GetInterfaceName(netid))
    103 
    104     s.sendto(UDP_PAYLOAD, (dstsockaddr, 53))
    105     self.ExpectPacketOn(netid, msg % "sendto", expected)
    106 
    107     # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
    108     if routing_mode != "ucast_oif":
    109       s.connect((dstsockaddr, 53))
    110       s.send(UDP_PAYLOAD)
    111       self.ExpectPacketOn(netid, msg % "connect/send", expected)
    112       s.close()
    113 
    114   def CheckRawGrePacket(self, version, netid, routing_mode):
    115     s = self.BuildSocket(version, net_test.RawGRESocket, netid, routing_mode)
    116 
    117     inner_version = {4: 6, 6: 4}[version]
    118     inner_src = self.MyAddress(inner_version, netid)
    119     inner_dst = self.GetRemoteAddress(inner_version)
    120     inner = str(packets.UDP(inner_version, inner_src, inner_dst, sport=None)[1])
    121 
    122     ethertype = {4: net_test.ETH_P_IP, 6: net_test.ETH_P_IPV6}[inner_version]
    123     # A GRE header can be as simple as two zero bytes and the ethertype.
    124     packet = struct.pack("!i", ethertype) + inner
    125     myaddr = self.MyAddress(version, netid)
    126     dstaddr = self.GetRemoteAddress(version)
    127 
    128     s.sendto(packet, (dstaddr, IPPROTO_GRE))
    129     desc, expected = packets.GRE(version, myaddr, dstaddr, ethertype, inner)
    130     msg = "Raw IPv%d GRE with inner IPv%d UDP: expected %s on %s" % (
    131         version, inner_version, desc, self.GetInterfaceName(netid))
    132     self.ExpectPacketOn(netid, msg, expected)
    133 
    134   def CheckOutgoingPackets(self, routing_mode):
    135     for _ in xrange(self.ITERATIONS):
    136       for netid in self.tuns:
    137 
    138         self.CheckPingPacket(4, netid, routing_mode, self.IPV4_PING)
    139         # Kernel bug.
    140         if routing_mode != "oif":
    141           self.CheckPingPacket(6, netid, routing_mode, self.IPV6_PING)
    142 
    143         # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP.
    144         if routing_mode != "ucast_oif":
    145           self.CheckTCPSYNPacket(4, netid, routing_mode)
    146           self.CheckTCPSYNPacket(6, netid, routing_mode)
    147           self.CheckTCPSYNPacket(5, netid, routing_mode)
    148 
    149         self.CheckUDPPacket(4, netid, routing_mode)
    150         self.CheckUDPPacket(6, netid, routing_mode)
    151         self.CheckUDPPacket(5, netid, routing_mode)
    152 
    153         # Creating raw sockets on non-root UIDs requires properly setting
    154         # capabilities, which is hard to do from Python.
    155         # IP_UNICAST_IF is not supported on raw sockets.
    156         if routing_mode not in ["uid", "ucast_oif"]:
    157           self.CheckRawGrePacket(4, netid, routing_mode)
    158           self.CheckRawGrePacket(6, netid, routing_mode)
    159 
    160   def testMarkRouting(self):
    161     """Checks that socket marking selects the right outgoing interface."""
    162     self.CheckOutgoingPackets("mark")
    163 
    164   def testUidRouting(self):
    165     """Checks that UID routing selects the right outgoing interface."""
    166     self.CheckOutgoingPackets("uid")
    167 
    168   def testOifRouting(self):
    169     """Checks that oif routing selects the right outgoing interface."""
    170     self.CheckOutgoingPackets("oif")
    171 
    172   @unittest.skipUnless(HAVE_UNICAST_IF, "no support for UNICAST_IF")
    173   def testUcastOifRouting(self):
    174     """Checks that ucast oif routing selects the right outgoing interface."""
    175     self.CheckOutgoingPackets("ucast_oif")
    176 
    177   def CheckRemarking(self, version, use_connect):
    178     modes = ["mark", "oif", "uid"]
    179     # Setting UNICAST_IF on connected sockets does not work.
    180     if not use_connect and HAVE_UNICAST_IF:
    181       modes += ["ucast_oif"]
    182 
    183     for mode in modes:
    184       s = net_test.UDPSocket(self.GetProtocolFamily(version))
    185 
    186       # Figure out what packets to expect.
    187       sport = net_test.BindRandomPort(version, s)
    188       dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
    189       unspec = {4: "0.0.0.0", 6: "::"}[version]  # Placeholder.
    190       desc, expected = packets.UDP(version, unspec, dstaddr, sport)
    191 
    192       # If we're testing connected sockets, connect the socket on the first
    193       # netid now.
    194       if use_connect:
    195         netid = self.tuns.keys()[0]
    196         self.SelectInterface(s, netid, mode)
    197         s.connect((dstaddr, 53))
    198         expected.src = self.MyAddress(version, netid)
    199 
    200       # For each netid, select that network without closing the socket, and
    201       # check that the packets sent on that socket go out on the right network.
    202       #
    203       # For connected sockets, routing is cached in the socket's destination
    204       # cache entry. In this case, we check that just re-selecting the netid
    205       # (except via SO_BINDTODEVICE) does not change routing, but that
    206       # subsequently invalidating the destination cache entry does. Arguably
    207       # this is a bug in the kernel because re-selecting the netid should cause
    208       # routing to change. But it is a convenient way to check that
    209       # InvalidateDstCache actually works.
    210       prevnetid = None
    211       for netid in self.tuns:
    212         self.SelectInterface(s, netid, mode)
    213         if not use_connect:
    214           expected.src = self.MyAddress(version, netid)
    215 
    216         def ExpectSendUsesNetid(netid):
    217           connected_str = "Connected" if use_connect else "Unconnected"
    218           msg = "%s UDPv%d socket remarked using %s: expecting %s on %s" % (
    219               connected_str, version, mode, desc, self.GetInterfaceName(netid))
    220           if use_connect:
    221             s.send(UDP_PAYLOAD)
    222           else:
    223             s.sendto(UDP_PAYLOAD, (dstaddr, 53))
    224           self.ExpectPacketOn(netid, msg, expected)
    225 
    226         if use_connect and mode in ["mark", "uid", "ucast_oif"]:
    227           # If we have a destination cache entry, packets are not rerouted...
    228           if prevnetid:
    229             ExpectSendUsesNetid(prevnetid)
    230             # ... until we invalidate it.
    231             self.InvalidateDstCache(version, prevnetid)
    232           ExpectSendUsesNetid(netid)
    233         else:
    234           ExpectSendUsesNetid(netid)
    235 
    236         self.SelectInterface(s, None, mode)
    237         prevnetid = netid
    238 
    239   def testIPv4Remarking(self):
    240     """Checks that updating the mark on an IPv4 socket changes routing."""
    241     self.CheckRemarking(4, False)
    242     self.CheckRemarking(4, True)
    243 
    244   def testIPv6Remarking(self):
    245     """Checks that updating the mark on an IPv6 socket changes routing."""
    246     self.CheckRemarking(6, False)
    247     self.CheckRemarking(6, True)
    248 
    249   def testIPv6StickyPktinfo(self):
    250     for _ in xrange(self.ITERATIONS):
    251       for netid in self.tuns:
    252         s = net_test.UDPSocket(AF_INET6)
    253 
    254         # Set a flowlabel.
    255         net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xdead)
    256         s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_FLOWINFO_SEND, 1)
    257 
    258         # Set some destination options.
    259         nonce = "\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c"
    260         dstopts = "".join([
    261             "\x11\x02",              # Next header=UDP, 24 bytes of options.
    262             "\x01\x06", "\x00" * 6,  # PadN, 6 bytes of padding.
    263             "\x8b\x0c",              # ILNP nonce, 12 bytes.
    264             nonce
    265         ])
    266         s.setsockopt(net_test.SOL_IPV6, IPV6_DSTOPTS, dstopts)
    267         s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_HOPS, 255)
    268 
    269         pktinfo = multinetwork_base.MakePktInfo(6, None, self.ifindices[netid])
    270 
    271         # Set the sticky pktinfo option.
    272         s.setsockopt(net_test.SOL_IPV6, IPV6_PKTINFO, pktinfo)
    273 
    274         # Specify the flowlabel in the destination address.
    275         s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 53, 0xdead, 0))
    276 
    277         sport = s.getsockname()[1]
    278         srcaddr = self.MyAddress(6, netid)
    279         expected = (scapy.IPv6(src=srcaddr, dst=net_test.IPV6_ADDR,
    280                                fl=0xdead, hlim=255) /
    281                     scapy.IPv6ExtHdrDestOpt(
    282                         options=[scapy.PadN(optdata="\x00\x00\x00\x00\x00\x00"),
    283                                  scapy.HBHOptUnknown(otype=0x8b,
    284                                                      optdata=nonce)]) /
    285                     scapy.UDP(sport=sport, dport=53) /
    286                     UDP_PAYLOAD)
    287         msg = "IPv6 UDP using sticky pktinfo: expected UDP packet on %s" % (
    288             self.GetInterfaceName(netid))
    289         self.ExpectPacketOn(netid, msg, expected)
    290 
    291   def CheckPktinfoRouting(self, version):
    292     for _ in xrange(self.ITERATIONS):
    293       for netid in self.tuns:
    294         family = self.GetProtocolFamily(version)
    295         s = net_test.UDPSocket(family)
    296 
    297         if version == 6:
    298           # Create a flowlabel so we can use it.
    299           net_test.SetFlowLabel(s, net_test.IPV6_ADDR, 0xbeef)
    300 
    301           # Specify some arbitrary options.
    302           # We declare the flowlabel as ctypes.c_uint32 because on a 32-bit
    303           # Python interpreter an integer greater than 0x7fffffff (such as our
    304           # chosen flowlabel after being passed through htonl) is converted to
    305           # long, and _MakeMsgControl doesn't know what to do with longs.
    306           cmsgs = [
    307               (net_test.SOL_IPV6, IPV6_HOPLIMIT, 39),
    308               (net_test.SOL_IPV6, IPV6_TCLASS, 0x83),
    309               (net_test.SOL_IPV6, IPV6_FLOWINFO, ctypes.c_uint(htonl(0xbeef))),
    310           ]
    311         else:
    312           # Support for setting IPv4 TOS and TTL via cmsg only appeared in 3.13.
    313           cmsgs = []
    314           s.setsockopt(net_test.SOL_IP, IP_TTL, 39)
    315           s.setsockopt(net_test.SOL_IP, IP_TOS, 0x83)
    316 
    317         dstaddr = self.GetRemoteAddress(version)
    318         self.SendOnNetid(version, s, dstaddr, 53, netid, UDP_PAYLOAD, cmsgs)
    319 
    320         sport = s.getsockname()[1]
    321         srcaddr = self.MyAddress(version, netid)
    322 
    323         desc, expected = packets.UDPWithOptions(version, srcaddr, dstaddr,
    324                                                 sport=sport)
    325 
    326         msg = "IPv%d UDP using pktinfo routing: expected %s on %s" % (
    327             version, desc, self.GetInterfaceName(netid))
    328         self.ExpectPacketOn(netid, msg, expected)
    329 
    330   def testIPv4PktinfoRouting(self):
    331     self.CheckPktinfoRouting(4)
    332 
    333   def testIPv6PktinfoRouting(self):
    334     self.CheckPktinfoRouting(6)
    335 
    336 
    337 class MarkTest(multinetwork_base.InboundMarkingTest):
    338 
    339   def CheckReflection(self, version, gen_packet, gen_reply):
    340     """Checks that replies go out on the same interface as the original.
    341 
    342     For each combination:
    343      - Calls gen_packet to generate a packet to that IP address.
    344      - Writes the packet generated by gen_packet on the given tun
    345        interface, causing the kernel to receive it.
    346      - Checks that the kernel's reply matches the packet generated by
    347        gen_reply.
    348 
    349     Args:
    350       version: An integer, 4 or 6.
    351       gen_packet: A function taking an IP version (an integer), a source
    352         address and a destination address (strings), and returning a scapy
    353         packet.
    354       gen_reply: A function taking the same arguments as gen_packet,
    355         plus a scapy packet, and returning a scapy packet.
    356     """
    357     for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
    358       # Generate a test packet.
    359       desc, packet = gen_packet(version, remoteaddr, myaddr)
    360 
    361       # Test with mark reflection enabled and disabled.
    362       for reflect in [0, 1]:
    363         self.SetMarkReflectSysctls(reflect)
    364         # HACK: IPv6 ping replies always do a routing lookup with the
    365         # interface the ping came in on. So even if mark reflection is not
    366         # working, IPv6 ping replies will be properly reflected. Don't
    367         # fail when that happens.
    368         if reflect or desc == "ICMPv6 echo":
    369           reply_desc, reply = gen_reply(version, myaddr, remoteaddr, packet)
    370         else:
    371           reply_desc, reply = None, None
    372 
    373         msg = self._FormatMessage(iif, ip_if, "reflect=%d" % reflect,
    374                                   desc, reply_desc)
    375         self._ReceiveAndExpectResponse(netid, packet, reply, msg)
    376 
    377   def SYNToClosedPort(self, *args):
    378     return packets.SYN(999, *args)
    379 
    380   def testIPv4ICMPErrorsReflectMark(self):
    381     self.CheckReflection(4, packets.UDP, packets.ICMPPortUnreachable)
    382 
    383   def testIPv6ICMPErrorsReflectMark(self):
    384     self.CheckReflection(6, packets.UDP, packets.ICMPPortUnreachable)
    385 
    386   def testIPv4PingRepliesReflectMarkAndTos(self):
    387     self.CheckReflection(4, packets.ICMPEcho, packets.ICMPReply)
    388 
    389   def testIPv6PingRepliesReflectMarkAndTos(self):
    390     self.CheckReflection(6, packets.ICMPEcho, packets.ICMPReply)
    391 
    392   def testIPv4RSTsReflectMark(self):
    393     self.CheckReflection(4, self.SYNToClosedPort, packets.RST)
    394 
    395   def testIPv6RSTsReflectMark(self):
    396     self.CheckReflection(6, self.SYNToClosedPort, packets.RST)
    397 
    398 
    399 class TCPAcceptTest(multinetwork_base.InboundMarkingTest):
    400 
    401   MODE_BINDTODEVICE = "SO_BINDTODEVICE"
    402   MODE_INCOMING_MARK = "incoming mark"
    403   MODE_EXPLICIT_MARK = "explicit mark"
    404   MODE_UID = "uid"
    405 
    406   @classmethod
    407   def setUpClass(cls):
    408     super(TCPAcceptTest, cls).setUpClass()
    409 
    410     # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
    411     # will accept both IPv4 and IPv6 connections. We do this here instead of in
    412     # each test so we can use the same socket every time. That way, if a kernel
    413     # bug causes incoming packets to mark the listening socket instead of the
    414     # accepted socket, the test will fail as soon as the next address/interface
    415     # combination is tried.
    416     cls.listensocket = net_test.IPv6TCPSocket()
    417     cls.listenport = net_test.BindRandomPort(6, cls.listensocket)
    418 
    419   def _SetTCPMarkAcceptSysctl(self, value):
    420     self.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
    421 
    422   def CheckTCPConnection(self, mode, listensocket, netid, version,
    423                          myaddr, remoteaddr, packet, reply, msg):
    424     establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
    425 
    426     # Attempt to confuse the kernel.
    427     self.InvalidateDstCache(version, netid)
    428 
    429     self.ReceivePacketOn(netid, establishing_ack)
    430 
    431     # If we're using UID routing, the accept() call has to be run as a UID that
    432     # is routed to the specified netid, because the UID of the socket returned
    433     # by accept() is the effective UID of the process that calls it. It doesn't
    434     # need to be the same UID; any UID that selects the same interface will do.
    435     with net_test.RunAsUid(self.UidForNetid(netid)):
    436       s, _ = listensocket.accept()
    437 
    438     try:
    439       # Check that data sent on the connection goes out on the right interface.
    440       desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
    441                                payload=UDP_PAYLOAD)
    442       s.send(UDP_PAYLOAD)
    443       self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
    444       self.InvalidateDstCache(version, netid)
    445 
    446       # Keep up our end of the conversation.
    447       ack = packets.ACK(version, remoteaddr, myaddr, data)[1]
    448       self.InvalidateDstCache(version, netid)
    449       self.ReceivePacketOn(netid, ack)
    450 
    451       mark = self.GetSocketMark(s)
    452     finally:
    453       self.InvalidateDstCache(version, netid)
    454       s.close()
    455       self.InvalidateDstCache(version, netid)
    456 
    457     if mode == self.MODE_INCOMING_MARK:
    458       self.assertEquals(netid, mark & self.NETID_FWMASK,
    459                         msg + ": Accepted socket: Expected mark %d, got %d" % (
    460                             netid, mark))
    461     elif mode != self.MODE_EXPLICIT_MARK:
    462       self.assertEquals(0, self.GetSocketMark(listensocket))
    463 
    464     # Check the FIN was sent on the right interface, and ack it. We don't expect
    465     # this to fail because by the time the connection is established things are
    466     # likely working, but a) extra tests are always good and b) extra packets
    467     # like the FIN (and retransmitted FINs) could cause later tests that expect
    468     # no packets to fail.
    469     desc, fin = packets.FIN(version, myaddr, remoteaddr, ack)
    470     self.ExpectPacketOn(netid, msg + ": expecting %s after close" % desc, fin)
    471 
    472     desc, finack = packets.FIN(version, remoteaddr, myaddr, fin)
    473     self.ReceivePacketOn(netid, finack)
    474 
    475     # Since we called close() earlier, the userspace socket object is gone, so
    476     # the socket has no UID. If we're doing UID routing, the ack might be routed
    477     # incorrectly. Not much we can do here.
    478     desc, finackack = packets.ACK(version, myaddr, remoteaddr, finack)
    479     self.ExpectPacketOn(netid, msg + ": expecting final ack", finackack)
    480 
    481   def CheckTCP(self, version, modes):
    482     """Checks that incoming TCP connections work.
    483 
    484     Args:
    485       version: An integer, 4 or 6.
    486       modes: A list of modes to excercise.
    487     """
    488     for syncookies in [0, 2]:
    489       for mode in modes:
    490         for netid, iif, ip_if, myaddr, remoteaddr in self.Combinations(version):
    491           listensocket = self.listensocket
    492           listenport = listensocket.getsockname()[1]
    493 
    494           accept_sysctl = 1 if mode == self.MODE_INCOMING_MARK else 0
    495           self._SetTCPMarkAcceptSysctl(accept_sysctl)
    496           self.SetMarkReflectSysctls(accept_sysctl)
    497 
    498           bound_dev = iif if mode == self.MODE_BINDTODEVICE else None
    499           self.BindToDevice(listensocket, bound_dev)
    500 
    501           mark = netid if mode == self.MODE_EXPLICIT_MARK else 0
    502           self.SetSocketMark(listensocket, mark)
    503 
    504           uid = self.UidForNetid(netid) if mode == self.MODE_UID else 0
    505           os.fchown(listensocket.fileno(), uid, -1)
    506 
    507           # Generate the packet here instead of in the outer loop, so
    508           # subsequent TCP connections use different source ports and
    509           # retransmissions from old connections don't confuse subsequent
    510           # tests.
    511           desc, packet = packets.SYN(listenport, version, remoteaddr, myaddr)
    512 
    513           if mode:
    514             reply_desc, reply = packets.SYNACK(version, myaddr, remoteaddr,
    515                                                packet)
    516           else:
    517             reply_desc, reply = None, None
    518 
    519           extra = "mode=%s, syncookies=%d" % (mode, syncookies)
    520           msg = self._FormatMessage(iif, ip_if, extra, desc, reply_desc)
    521           reply = self._ReceiveAndExpectResponse(netid, packet, reply, msg)
    522           if reply:
    523             self.CheckTCPConnection(mode, listensocket, netid, version, myaddr,
    524                                     remoteaddr, packet, reply, msg)
    525 
    526   def testBasicTCP(self):
    527     self.CheckTCP(4, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])
    528     self.CheckTCP(6, [None, self.MODE_BINDTODEVICE, self.MODE_EXPLICIT_MARK])
    529 
    530   def testIPv4MarkAccept(self):
    531     self.CheckTCP(4, [self.MODE_INCOMING_MARK])
    532 
    533   def testIPv6MarkAccept(self):
    534     self.CheckTCP(6, [self.MODE_INCOMING_MARK])
    535 
    536   def testIPv4UidAccept(self):
    537     self.CheckTCP(4, [self.MODE_UID])
    538 
    539   def testIPv6UidAccept(self):
    540     self.CheckTCP(6, [self.MODE_UID])
    541 
    542   def testIPv6ExplicitMark(self):
    543     self.CheckTCP(6, [self.MODE_EXPLICIT_MARK])
    544 
    545 @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
    546                      "need support for per-table autoconf")
    547 class RIOTest(multinetwork_base.MultiNetworkBaseTest):
    548   """Test for IPv6 RFC 4191 route information option
    549 
    550   Relevant kernel commits:
    551     upstream:
    552       f104a567e673 ipv6: use rt6_get_dflt_router to get default router in rt6_route_rcv
    553       bbea124bc99d net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.
    554 
    555     android-4.9:
    556       d860b2e8a7f1 FROMLIST: net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs
    557 
    558     android-4.4:
    559       e953f89b8563 net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.
    560 
    561     android-4.1:
    562       84f2f47716cd net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.
    563 
    564     android-3.18:
    565       65f8936934fa net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.
    566 
    567     android-3.10:
    568       161e88ebebc7 net: ipv6: Add sysctl for minimum prefix len acceptable in RIOs.
    569 
    570   """
    571 
    572   def setUp(self):
    573     super(RIOTest, self).setUp()
    574     self.NETID = random.choice(self.NETIDS)
    575     self.IFACE = self.GetInterfaceName(self.NETID)
    576     # return min/max plen to default values before each test case
    577     self.SetAcceptRaRtInfoMinPlen(0)
    578     self.SetAcceptRaRtInfoMaxPlen(0)
    579 
    580   def GetRoutingTable(self):
    581     return self._TableForNetid(self.NETID)
    582 
    583   def SetAcceptRaRtInfoMinPlen(self, plen):
    584     self.SetSysctl(
    585         "/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_min_plen"
    586         % self.IFACE, plen)
    587 
    588   def GetAcceptRaRtInfoMinPlen(self):
    589     return int(self.GetSysctl(
    590         "/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_min_plen" % self.IFACE))
    591 
    592   def SetAcceptRaRtInfoMaxPlen(self, plen):
    593     self.SetSysctl(
    594         "/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_max_plen"
    595         % self.IFACE, plen)
    596 
    597   def GetAcceptRaRtInfoMaxPlen(self):
    598     return int(self.GetSysctl(
    599         "/proc/sys/net/ipv6/conf/%s/accept_ra_rt_info_max_plen" % self.IFACE))
    600 
    601   def SendRIO(self, rtlifetime, plen, prefix, prf):
    602     options = scapy.ICMPv6NDOptRouteInfo(rtlifetime=rtlifetime, plen=plen,
    603                                          prefix=prefix, prf=prf)
    604     self.SendRA(self.NETID, options=(options,))
    605 
    606   def FindRoutesWithDestination(self, destination):
    607     canonical = net_test.CanonicalizeIPv6Address(destination)
    608     return [r for _, r in self.iproute.DumpRoutes(6, self.GetRoutingTable())
    609             if ('RTA_DST' in r and r['RTA_DST'] == canonical)]
    610 
    611   def FindRoutesWithGateway(self):
    612     return [r for _, r in self.iproute.DumpRoutes(6, self.GetRoutingTable())
    613             if 'RTA_GATEWAY' in r]
    614 
    615   def CountRoutes(self):
    616     return len(self.iproute.DumpRoutes(6, self.GetRoutingTable()))
    617 
    618   def GetRouteExpiration(self, route):
    619     return float(route['RTA_CACHEINFO'].expires) / 100.0
    620 
    621   def AssertExpirationInRange(self, routes, lifetime, epsilon):
    622     self.assertTrue(routes)
    623     found = False
    624     # Assert that at least one route in routes has the expected lifetime
    625     for route in routes:
    626       expiration = self.GetRouteExpiration(route)
    627       if expiration < lifetime - epsilon:
    628         continue
    629       if expiration > lifetime + epsilon:
    630         continue
    631       found = True
    632     self.assertTrue(found)
    633 
    634   def DelRA6(self, prefix, plen):
    635     version = 6
    636     netid = self.NETID
    637     table = self._TableForNetid(netid)
    638     router = self._RouterAddress(netid, version)
    639     ifindex = self.ifindices[netid]
    640     # We actually want to specify RTPROT_RA, however an upstream
    641     # kernel bug causes RAs to be installed with RTPROT_BOOT.
    642     if HAVE_RTPROT_RA:
    643        rtprot = iproute.RTPROT_RA
    644     else:
    645        rtprot = iproute.RTPROT_BOOT
    646     self.iproute._Route(version, rtprot, iproute.RTM_DELROUTE,
    647                         table, prefix, plen, router, ifindex, None, None)
    648 
    649   def testSetAcceptRaRtInfoMinPlen(self):
    650     for plen in xrange(-1, 130):
    651       self.SetAcceptRaRtInfoMinPlen(plen)
    652       self.assertEquals(plen, self.GetAcceptRaRtInfoMinPlen())
    653 
    654   def testSetAcceptRaRtInfoMaxPlen(self):
    655     for plen in xrange(-1, 130):
    656       self.SetAcceptRaRtInfoMaxPlen(plen)
    657       self.assertEquals(plen, self.GetAcceptRaRtInfoMaxPlen())
    658 
    659   def testZeroRtLifetime(self):
    660     PREFIX = "2001:db8:8901:2300::"
    661     RTLIFETIME = 73500
    662     PLEN = 56
    663     PRF = 0
    664     self.SetAcceptRaRtInfoMaxPlen(PLEN)
    665     self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    666     # Give the kernel time to notice our RA
    667     time.sleep(0.01)
    668     self.assertTrue(self.FindRoutesWithDestination(PREFIX))
    669     # RIO with rtlifetime = 0 should remove from routing table
    670     self.SendRIO(0, PLEN, PREFIX, PRF)
    671     # Give the kernel time to notice our RA
    672     time.sleep(0.01)
    673     self.assertFalse(self.FindRoutesWithDestination(PREFIX))
    674 
    675   def testMinPrefixLenRejection(self):
    676     PREFIX = "2001:db8:8902:2345::"
    677     RTLIFETIME = 70372
    678     PRF = 0
    679     # sweep from high to low to avoid spurious failures from late arrivals.
    680     for plen in xrange(130, 1, -1):
    681       self.SetAcceptRaRtInfoMinPlen(plen)
    682       # RIO with plen < min_plen should be ignored
    683       self.SendRIO(RTLIFETIME, plen - 1, PREFIX, PRF)
    684     # Give the kernel time to notice our RAs
    685     time.sleep(0.1)
    686     # Expect no routes
    687     routes = self.FindRoutesWithDestination(PREFIX)
    688     self.assertFalse(routes)
    689 
    690   def testMaxPrefixLenRejection(self):
    691     PREFIX = "2001:db8:8903:2345::"
    692     RTLIFETIME = 73078
    693     PRF = 0
    694     # sweep from low to high to avoid spurious failures from late arrivals.
    695     for plen in xrange(-1, 128, 1):
    696       self.SetAcceptRaRtInfoMaxPlen(plen)
    697       # RIO with plen > max_plen should be ignored
    698       self.SendRIO(RTLIFETIME, plen + 1, PREFIX, PRF)
    699     # Give the kernel time to notice our RAs
    700     time.sleep(0.1)
    701     # Expect no routes
    702     routes = self.FindRoutesWithDestination(PREFIX)
    703     self.assertFalse(routes)
    704 
    705   def testSimpleAccept(self):
    706     PREFIX = "2001:db8:8904:2345::"
    707     RTLIFETIME = 9993
    708     PRF = 0
    709     PLEN = 56
    710     self.SetAcceptRaRtInfoMinPlen(48)
    711     self.SetAcceptRaRtInfoMaxPlen(64)
    712     self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    713     # Give the kernel time to notice our RA
    714     time.sleep(0.01)
    715     routes = self.FindRoutesWithGateway()
    716     self.AssertExpirationInRange(routes, RTLIFETIME, 1)
    717     self.DelRA6(PREFIX, PLEN)
    718 
    719   def testEqualMinMaxAccept(self):
    720     PREFIX = "2001:db8:8905:2345::"
    721     RTLIFETIME = 6326
    722     PLEN = 21
    723     PRF = 0
    724     self.SetAcceptRaRtInfoMinPlen(PLEN)
    725     self.SetAcceptRaRtInfoMaxPlen(PLEN)
    726     self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    727     # Give the kernel time to notice our RA
    728     time.sleep(0.01)
    729     routes = self.FindRoutesWithGateway()
    730     self.AssertExpirationInRange(routes, RTLIFETIME, 1)
    731     self.DelRA6(PREFIX, PLEN)
    732 
    733   def testZeroLengthPrefix(self):
    734     PREFIX = "2001:db8:8906:2345::"
    735     RTLIFETIME = self.RA_VALIDITY * 2
    736     PLEN = 0
    737     PRF = 0
    738     # Max plen = 0 still allows default RIOs!
    739     self.SetAcceptRaRtInfoMaxPlen(PLEN)
    740     self.SendRA(self.NETID)
    741     # Give the kernel time to notice our RA
    742     time.sleep(0.01)
    743     default = self.FindRoutesWithGateway()
    744     self.AssertExpirationInRange(default, self.RA_VALIDITY, 1)
    745     # RIO with prefix length = 0, should overwrite default route lifetime
    746     # note that the RIO lifetime overwrites the RA lifetime.
    747     self.SendRIO(RTLIFETIME, PLEN, PREFIX, PRF)
    748     # Give the kernel time to notice our RA
    749     time.sleep(0.01)
    750     default = self.FindRoutesWithGateway()
    751     self.AssertExpirationInRange(default, RTLIFETIME, 1)
    752     self.DelRA6(PREFIX, PLEN)
    753 
    754   def testManyRIOs(self):
    755     RTLIFETIME = 68012
    756     PLEN = 56
    757     PRF = 0
    758     COUNT = 1000
    759     baseline = self.CountRoutes()
    760     self.SetAcceptRaRtInfoMaxPlen(56)
    761     # Send many RIOs compared to the expected number on a healthy system.
    762     for i in xrange(0, COUNT):
    763       prefix = "2001:db8:%x:1100::" % i
    764       self.SendRIO(RTLIFETIME, PLEN, prefix, PRF)
    765     time.sleep(0.1)
    766     self.assertEquals(COUNT + baseline, self.CountRoutes())
    767     for i in xrange(0, COUNT):
    768       prefix = "2001:db8:%x:1100::" % i
    769       self.DelRA6(prefix, PLEN)
    770     # Expect that we can return to baseline config without lingering routes.
    771     self.assertEquals(baseline, self.CountRoutes())
    772 
    773 class RATest(multinetwork_base.MultiNetworkBaseTest):
    774 
    775   def testDoesNotHaveObsoleteSysctl(self):
    776     self.assertFalse(os.path.isfile(
    777         "/proc/sys/net/ipv6/route/autoconf_table_offset"))
    778 
    779   @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
    780                        "no support for per-table autoconf")
    781   def testPurgeDefaultRouters(self):
    782 
    783     def CheckIPv6Connectivity(expect_connectivity):
    784       for netid in self.NETIDS:
    785         s = net_test.UDPSocket(AF_INET6)
    786         self.SetSocketMark(s, netid)
    787         if expect_connectivity:
    788           self.assertTrue(s.sendto(UDP_PAYLOAD, (net_test.IPV6_ADDR, 1234)))
    789         else:
    790           self.assertRaisesErrno(errno.ENETUNREACH, s.sendto, UDP_PAYLOAD,
    791                                  (net_test.IPV6_ADDR, 1234))
    792 
    793     try:
    794       CheckIPv6Connectivity(True)
    795       self.SetIPv6SysctlOnAllIfaces("accept_ra", 1)
    796       self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)
    797       CheckIPv6Connectivity(False)
    798     finally:
    799       self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
    800       for netid in self.NETIDS:
    801         self.SendRA(netid)
    802       CheckIPv6Connectivity(True)
    803 
    804   def testOnlinkCommunication(self):
    805     """Checks that on-link communication goes direct and not through routers."""
    806     for netid in self.tuns:
    807       # Send a UDP packet to a random on-link destination.
    808       s = net_test.UDPSocket(AF_INET6)
    809       iface = self.GetInterfaceName(netid)
    810       self.BindToDevice(s, iface)
    811       # dstaddr can never be our address because GetRandomDestination only fills
    812       # in the lower 32 bits, but our address has 0xff in the byte before that
    813       # (since it's constructed from the EUI-64 and so has ff:fe in the middle).
    814       dstaddr = self.GetRandomDestination(self.OnlinkPrefix(6, netid))
    815       s.sendto(UDP_PAYLOAD, (dstaddr, 53))
    816 
    817       # Expect an NS for that destination on the interface.
    818       myaddr = self.MyAddress(6, netid)
    819       mymac = self.MyMacAddress(netid)
    820       desc, expected = packets.NS(myaddr, dstaddr, mymac)
    821       msg = "Sending UDP packet to on-link destination: expecting %s" % desc
    822       time.sleep(0.0001)  # Required to make the test work on kernel 3.1(!)
    823       self.ExpectPacketOn(netid, msg, expected)
    824 
    825       # Send an NA.
    826       tgtmac = "02:00:00:00:%02x:99" % netid
    827       _, reply = packets.NA(dstaddr, myaddr, tgtmac)
    828       # Don't use ReceivePacketOn, since that uses the router's MAC address as
    829       # the source. Instead, construct our own Ethernet header with source
    830       # MAC of tgtmac.
    831       reply = scapy.Ether(src=tgtmac, dst=mymac) / reply
    832       self.ReceiveEtherPacketOn(netid, reply)
    833 
    834       # Expect the kernel to send the original UDP packet now that the ND cache
    835       # entry has been populated.
    836       sport = s.getsockname()[1]
    837       desc, expected = packets.UDP(6, myaddr, dstaddr, sport=sport)
    838       msg = "After NA response, expecting %s" % desc
    839       self.ExpectPacketOn(netid, msg, expected)
    840 
    841   # This test documents a known issue: routing tables are never deleted.
    842   @unittest.skipUnless(multinetwork_base.HAVE_AUTOCONF_TABLE,
    843                        "no support for per-table autoconf")
    844   def testLeftoverRoutes(self):
    845     def GetNumRoutes():
    846       return len(open("/proc/net/ipv6_route").readlines())
    847 
    848     num_routes = GetNumRoutes()
    849     for i in xrange(10, 20):
    850       try:
    851         self.tuns[i] = self.CreateTunInterface(i)
    852         self.SendRA(i)
    853         self.tuns[i].close()
    854       finally:
    855         del self.tuns[i]
    856     self.assertLess(num_routes, GetNumRoutes())
    857 
    858 
    859 class PMTUTest(multinetwork_base.InboundMarkingTest):
    860 
    861   PAYLOAD_SIZE = 1400
    862   dstaddrs = set()
    863 
    864   def GetSocketMTU(self, version, s):
    865     if version == 6:
    866       ip6_mtuinfo = s.getsockopt(net_test.SOL_IPV6, csocket.IPV6_PATHMTU, 32)
    867       unused_sockaddr, mtu = struct.unpack("=28sI", ip6_mtuinfo)
    868       return mtu
    869     else:
    870       return s.getsockopt(net_test.SOL_IP, csocket.IP_MTU)
    871 
    872   def DisableFragmentationAndReportErrors(self, version, s):
    873     if version == 4:
    874       s.setsockopt(net_test.SOL_IP, csocket.IP_MTU_DISCOVER,
    875                    csocket.IP_PMTUDISC_DO)
    876       s.setsockopt(net_test.SOL_IP, net_test.IP_RECVERR, 1)
    877     else:
    878       s.setsockopt(net_test.SOL_IPV6, csocket.IPV6_DONTFRAG, 1)
    879       s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)
    880 
    881   def CheckPMTU(self, version, use_connect, modes):
    882 
    883     def SendBigPacket(version, s, dstaddr, netid, payload):
    884       if use_connect:
    885         s.send(payload)
    886       else:
    887         self.SendOnNetid(version, s, dstaddr, 1234, netid, payload, [])
    888 
    889     for netid in self.tuns:
    890       for mode in modes:
    891         s = self.BuildSocket(version, net_test.UDPSocket, netid, mode)
    892         self.DisableFragmentationAndReportErrors(version, s)
    893 
    894         srcaddr = self.MyAddress(version, netid)
    895         dst_prefix, intermediate = {
    896             4: ("172.19.", "172.16.9.12"),
    897             6: ("2001:db8::", "2001:db8::1")
    898         }[version]
    899 
    900         # Run this test often enough (e.g., in presubmits), and eventually
    901         # we'll be unlucky enough to pick the same address twice, in which
    902         # case the test will fail because the kernel will already have seen
    903         # the lower MTU. Don't do this.
    904         dstaddr = self.GetRandomDestination(dst_prefix)
    905         while dstaddr in self.dstaddrs:
    906           dstaddr = self.GetRandomDestination(dst_prefix)
    907         self.dstaddrs.add(dstaddr)
    908 
    909         if use_connect:
    910           s.connect((dstaddr, 1234))
    911 
    912         payload = self.PAYLOAD_SIZE * "a"
    913 
    914         # Send a packet and receive a packet too big.
    915         SendBigPacket(version, s, dstaddr, netid, payload)
    916         received = self.ReadAllPacketsOn(netid)
    917         self.assertEquals(1, len(received),
    918                           "unexpected packets: %s" % received[1:])
    919         _, toobig = packets.ICMPPacketTooBig(version, intermediate, srcaddr,
    920                                              received[0])
    921         self.ReceivePacketOn(netid, toobig)
    922 
    923         # Check that another send on the same socket returns EMSGSIZE.
    924         self.assertRaisesErrno(
    925             errno.EMSGSIZE,
    926             SendBigPacket, version, s, dstaddr, netid, payload)
    927 
    928         # If this is a connected socket, make sure the socket MTU was set.
    929         # Note that in IPv4 this only started working in Linux 3.6!
    930         if use_connect and (version == 6 or net_test.LINUX_VERSION >= (3, 6)):
    931           self.assertEquals(packets.PTB_MTU, self.GetSocketMTU(version, s))
    932 
    933         s.close()
    934 
    935         # Check that other sockets pick up the PMTU we have been told about by
    936         # connecting another socket to the same destination and getting its MTU.
    937         # This new socket can use any method to select its outgoing interface;
    938         # here we use a mark for simplicity.
    939         s2 = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
    940         s2.connect((dstaddr, 1234))
    941         self.assertEquals(packets.PTB_MTU, self.GetSocketMTU(version, s2))
    942 
    943         # Also check the MTU reported by ip route get, this time using the oif.
    944         routes = self.iproute.GetRoutes(dstaddr, self.ifindices[netid], 0, None)
    945         self.assertTrue(routes)
    946         route = routes[0]
    947         rtmsg, attributes = route
    948         self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
    949         metrics = attributes["RTA_METRICS"]
    950         self.assertEquals(packets.PTB_MTU, metrics["RTAX_MTU"])
    951 
    952   def testIPv4BasicPMTU(self):
    953     """Tests IPv4 path MTU discovery.
    954 
    955     Relevant kernel commits:
    956       upstream net-next:
    957         6a66271 ipv4, fib: pass LOOPBACK_IFINDEX instead of 0 to flowi4_iif
    958 
    959       android-3.10:
    960         4bc64dd ipv4, fib: pass LOOPBACK_IFINDEX instead of 0 to flowi4_iif
    961     """
    962 
    963     self.CheckPMTU(4, True, ["mark", "oif"])
    964     self.CheckPMTU(4, False, ["mark", "oif"])
    965 
    966   def testIPv6BasicPMTU(self):
    967     self.CheckPMTU(6, True, ["mark", "oif"])
    968     self.CheckPMTU(6, False, ["mark", "oif"])
    969 
    970   def testIPv4UIDPMTU(self):
    971     self.CheckPMTU(4, True, ["uid"])
    972     self.CheckPMTU(4, False, ["uid"])
    973 
    974   def testIPv6UIDPMTU(self):
    975     self.CheckPMTU(6, True, ["uid"])
    976     self.CheckPMTU(6, False, ["uid"])
    977 
    978   # Making Path MTU Discovery work on unmarked  sockets requires that mark
    979   # reflection be enabled. Otherwise the kernel has no way to know what routing
    980   # table the original packet used, and thus it won't be able to clone the
    981   # correct route.
    982 
    983   def testIPv4UnmarkedSocketPMTU(self):
    984     self.SetMarkReflectSysctls(1)
    985     try:
    986       self.CheckPMTU(4, False, [None])
    987     finally:
    988       self.SetMarkReflectSysctls(0)
    989 
    990   def testIPv6UnmarkedSocketPMTU(self):
    991     self.SetMarkReflectSysctls(1)
    992     try:
    993       self.CheckPMTU(6, False, [None])
    994     finally:
    995       self.SetMarkReflectSysctls(0)
    996 
    997 
    998 class UidRoutingTest(multinetwork_base.MultiNetworkBaseTest):
    999   """Tests that per-UID routing works properly.
   1000 
   1001   Relevant kernel commits:
   1002     upstream net-next:
   1003       7d99569460 net: ipv4: Don't crash if passing a null sk to ip_do_redirect.
   1004       d109e61bfe net: ipv4: Don't crash if passing a null sk to ip_rt_update_pmtu.
   1005       35b80733b3 net: core: add missing check for uid_range in rule_exists.
   1006       e2d118a1cb net: inet: Support UID-based routing in IP protocols.
   1007       622ec2c9d5 net: core: add UID to flows, rules, and routes
   1008       86741ec254 net: core: Add a UID field to struct sock.
   1009 
   1010     android-3.18:
   1011       b004e79504 net: ipv4: Don't crash if passing a null sk to ip_rt_update_pmtu.
   1012       04c0eace81 net: inet: Support UID-based routing in IP protocols.
   1013       18c36d7b71 net: core: add UID to flows, rules, and routes
   1014       80e3440721 net: core: Add a UID field to struct sock.
   1015       fa8cc2c30c Revert "net: core: Support UID-based routing."
   1016       b585141890 Revert "Handle 'sk' being NULL in UID-based routing."
   1017       5115ab7514 Revert "net: core: fix UID-based routing build"
   1018       f9f4281f79 Revert "ANDROID: net: fib: remove duplicate assignment"
   1019 
   1020     android-4.4:
   1021       341965cf10 net: ipv4: Don't crash if passing a null sk to ip_rt_update_pmtu.
   1022       344afd627c net: inet: Support UID-based routing in IP protocols.
   1023       03441d56d8 net: core: add UID to flows, rules, and routes
   1024       eb964bdba7 net: core: Add a UID field to struct sock.
   1025       9789b697c6 Revert "net: core: Support UID-based routing."
   1026   """
   1027 
   1028   def GetRulesAtPriority(self, version, priority):
   1029     rules = self.iproute.DumpRules(version)
   1030     out = [(rule, attributes) for rule, attributes in rules
   1031            if attributes.get("FRA_PRIORITY", 0) == priority]
   1032     return out
   1033 
   1034   def CheckInitialTablesHaveNoUIDs(self, version):
   1035     rules = []
   1036     for priority in [0, 32766, 32767]:
   1037       rules.extend(self.GetRulesAtPriority(version, priority))
   1038     for _, attributes in rules:
   1039       self.assertNotIn("FRA_UID_RANGE", attributes)
   1040 
   1041   def testIPv4InitialTablesHaveNoUIDs(self):
   1042     self.CheckInitialTablesHaveNoUIDs(4)
   1043 
   1044   def testIPv6InitialTablesHaveNoUIDs(self):
   1045     self.CheckInitialTablesHaveNoUIDs(6)
   1046 
   1047   @staticmethod
   1048   def _Random():
   1049     return random.randint(1000000, 2000000)
   1050 
   1051   def CheckGetAndSetRules(self, version):
   1052     start, end = tuple(sorted([self._Random(), self._Random()]))
   1053     table = self._Random()
   1054     priority = self._Random()
   1055 
   1056     # Can't create a UID range to UID -1 because -1 is INVALID_UID...
   1057     self.assertRaisesErrno(
   1058         errno.EINVAL,
   1059         self.iproute.UidRangeRule, version, True, 100, 0xffffffff, table,
   1060         priority)
   1061 
   1062     # ... but -2 is valid.
   1063     self.iproute.UidRangeRule(version, True, 100, 0xfffffffe, table, priority)
   1064     self.iproute.UidRangeRule(version, False, 100, 0xfffffffe, table, priority)
   1065 
   1066     try:
   1067       # Create a UID range rule.
   1068       self.iproute.UidRangeRule(version, True, start, end, table, priority)
   1069 
   1070       # Check that deleting the wrong UID range doesn't work.
   1071       self.assertRaisesErrno(
   1072           errno.ENOENT,
   1073           self.iproute.UidRangeRule, version, False, start, end + 1, table,
   1074           priority)
   1075       self.assertRaisesErrno(errno.ENOENT,
   1076         self.iproute.UidRangeRule, version, False, start + 1, end, table,
   1077         priority)
   1078 
   1079       # Check that the UID range appears in dumps.
   1080       rules = self.GetRulesAtPriority(version, priority)
   1081       self.assertTrue(rules)
   1082       _, attributes = rules[-1]
   1083       self.assertEquals(priority, attributes["FRA_PRIORITY"])
   1084       uidrange = attributes["FRA_UID_RANGE"]
   1085       self.assertEquals(start, uidrange.start)
   1086       self.assertEquals(end, uidrange.end)
   1087       self.assertEquals(table, attributes["FRA_TABLE"])
   1088     finally:
   1089       self.iproute.UidRangeRule(version, False, start, end, table, priority)
   1090       self.assertRaisesErrno(
   1091           errno.ENOENT,
   1092           self.iproute.UidRangeRule, version, False, start, end, table,
   1093           priority)
   1094 
   1095     fwmask = 0xfefefefe
   1096     try:
   1097       # Create a rule without a UID range.
   1098       self.iproute.FwmarkRule(version, True, 300, fwmask, 301, priority + 1)
   1099 
   1100       # Check it doesn't have a UID range.
   1101       rules = self.GetRulesAtPriority(version, priority + 1)
   1102       self.assertTrue(rules)
   1103       for _, attributes in rules:
   1104         self.assertIn("FRA_TABLE", attributes)
   1105         self.assertNotIn("FRA_UID_RANGE", attributes)
   1106     finally:
   1107       self.iproute.FwmarkRule(version, False, 300, fwmask, 301, priority + 1)
   1108 
   1109     # Test that EEXIST worksfor UID range rules too. This behaviour was only
   1110     # added in 4.8.
   1111     if net_test.LINUX_VERSION >= (4, 8, 0):
   1112       ranges = [(100, 101), (100, 102), (99, 101), (1234, 5678)]
   1113       dup = ranges[0]
   1114       try:
   1115         # Check that otherwise identical rules with different UID ranges can be
   1116         # created without EEXIST.
   1117         for start, end in ranges:
   1118           self.iproute.UidRangeRule(version, True, start, end, table, priority)
   1119         # ... but EEXIST is returned if the UID range is identical.
   1120         self.assertRaisesErrno(
   1121           errno.EEXIST,
   1122           self.iproute.UidRangeRule, version, True, dup[0], dup[1], table,
   1123           priority)
   1124       finally:
   1125         # Clean up.
   1126         for start, end in ranges + [dup]:
   1127           try:
   1128             self.iproute.UidRangeRule(version, False, start, end, table,
   1129                                       priority)
   1130           except IOError:
   1131             pass
   1132 
   1133   def testIPv4GetAndSetRules(self):
   1134     self.CheckGetAndSetRules(4)
   1135 
   1136   def testIPv6GetAndSetRules(self):
   1137     self.CheckGetAndSetRules(6)
   1138 
   1139   @unittest.skipUnless(net_test.LINUX_VERSION >= (4, 9, 0), "not backported")
   1140   def testDeleteErrno(self):
   1141     for version in [4, 6]:
   1142       table = self._Random()
   1143       priority = self._Random()
   1144       self.assertRaisesErrno(
   1145           errno.EINVAL,
   1146           self.iproute.UidRangeRule, version, False, 100, 0xffffffff, table,
   1147           priority)
   1148 
   1149   def ExpectNoRoute(self, addr, oif, mark, uid):
   1150     # The lack of a route may be either an error, or an unreachable route.
   1151     try:
   1152       routes = self.iproute.GetRoutes(addr, oif, mark, uid)
   1153       rtmsg, _ = routes[0]
   1154       self.assertEquals(iproute.RTN_UNREACHABLE, rtmsg.type)
   1155     except IOError, e:
   1156       if int(e.errno) != int(errno.ENETUNREACH):
   1157         raise e
   1158 
   1159   def ExpectRoute(self, addr, oif, mark, uid):
   1160     routes = self.iproute.GetRoutes(addr, oif, mark, uid)
   1161     rtmsg, _ = routes[0]
   1162     self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
   1163 
   1164   def CheckGetRoute(self, version, addr):
   1165     self.ExpectNoRoute(addr, 0, 0, 0)
   1166     for netid in self.NETIDS:
   1167       uid = self.UidForNetid(netid)
   1168       self.ExpectRoute(addr, 0, 0, uid)
   1169     self.ExpectNoRoute(addr, 0, 0, 0)
   1170 
   1171   def testIPv4RouteGet(self):
   1172     self.CheckGetRoute(4, net_test.IPV4_ADDR)
   1173 
   1174   def testIPv6RouteGet(self):
   1175     self.CheckGetRoute(6, net_test.IPV6_ADDR)
   1176 
   1177   def testChangeFdAttributes(self):
   1178     netid = random.choice(self.NETIDS)
   1179     uid = self._Random()
   1180     table = self._TableForNetid(netid)
   1181     remoteaddr = self.GetRemoteAddress(6)
   1182     s = socket(AF_INET6, SOCK_DGRAM, 0)
   1183 
   1184     def CheckSendFails():
   1185       self.assertRaisesErrno(errno.ENETUNREACH,
   1186                              s.sendto, "foo", (remoteaddr, 53))
   1187     def CheckSendSucceeds():
   1188       self.assertEquals(len("foo"), s.sendto("foo", (remoteaddr, 53)))
   1189 
   1190     CheckSendFails()
   1191     self.iproute.UidRangeRule(6, True, uid, uid, table, self.PRIORITY_UID)
   1192     try:
   1193       CheckSendFails()
   1194       os.fchown(s.fileno(), uid, -1)
   1195       CheckSendSucceeds()
   1196       os.fchown(s.fileno(), -1, -1)
   1197       CheckSendSucceeds()
   1198       os.fchown(s.fileno(), -1, 12345)
   1199       CheckSendSucceeds()
   1200       os.fchmod(s.fileno(), 0777)
   1201       CheckSendSucceeds()
   1202       os.fchown(s.fileno(), 0, -1)
   1203       CheckSendFails()
   1204     finally:
   1205       self.iproute.UidRangeRule(6, False, uid, uid, table, self.PRIORITY_UID)
   1206 
   1207 
   1208 class RulesTest(net_test.NetworkTest):
   1209 
   1210   RULE_PRIORITY = 99999
   1211   FWMASK = 0xffffffff
   1212 
   1213   def setUp(self):
   1214     self.iproute = iproute.IPRoute()
   1215     for version in [4, 6]:
   1216       self.iproute.DeleteRulesAtPriority(version, self.RULE_PRIORITY)
   1217 
   1218   def tearDown(self):
   1219     for version in [4, 6]:
   1220       self.iproute.DeleteRulesAtPriority(version, self.RULE_PRIORITY)
   1221 
   1222   def testRuleDeletionMatchesTable(self):
   1223     for version in [4, 6]:
   1224       # Add rules with mark 300 pointing at tables 301 and 302.
   1225       # This checks for a kernel bug where deletion request for tables > 256
   1226       # ignored the table.
   1227       self.iproute.FwmarkRule(version, True, 300, self.FWMASK, 301,
   1228                               priority=self.RULE_PRIORITY)
   1229       self.iproute.FwmarkRule(version, True, 300, self.FWMASK, 302,
   1230                               priority=self.RULE_PRIORITY)
   1231       # Delete rule with mark 300 pointing at table 302.
   1232       self.iproute.FwmarkRule(version, False, 300, self.FWMASK, 302,
   1233                               priority=self.RULE_PRIORITY)
   1234       # Check that the rule pointing at table 301 is still around.
   1235       attributes = [a for _, a in self.iproute.DumpRules(version)
   1236                     if a.get("FRA_PRIORITY", 0) == self.RULE_PRIORITY]
   1237       self.assertEquals(1, len(attributes))
   1238       self.assertEquals(301, attributes[0]["FRA_TABLE"])
   1239 
   1240 
if __name__ == "__main__":
  # When run as a script, hand control to unittest's command-line runner.
  unittest.main()
   1243