Home | History | Annotate | Download | only in net_test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2015 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 import contextlib
     18 import errno
     19 import fcntl
     20 import resource
     21 import os
     22 from socket import *  # pylint: disable=wildcard-import
     23 import struct
     24 import threading
     25 import time
     26 import unittest
     27 
     28 import csocket
     29 import cstruct
     30 import net_test
     31 
     32 IPV4_LOOPBACK_ADDR = "127.0.0.1"
     33 IPV6_LOOPBACK_ADDR = "::1"
     34 LOOPBACK_DEV = "lo"
     35 LOOPBACK_IFINDEX = 1
     36 
     37 SIOCKILLADDR = 0x8939
     38 
     39 DEFAULT_TCP_PORT = 8001
     40 DEFAULT_BUFFER_SIZE = 20
     41 DEFAULT_TEST_MESSAGE = "TCP NUKE ADDR TEST"
     42 DEFAULT_TEST_RUNS = 100
     43 HASH_TEST_RUNS = 4000
     44 HASH_TEST_NOFILE = 16384
     45 
     46 
     47 Ifreq = cstruct.Struct("Ifreq", "=16s16s", "name data")
     48 In6Ifreq = cstruct.Struct("In6Ifreq", "=16sIi", "addr prefixlen ifindex")
     49 
     50 @contextlib.contextmanager
     51 def RunInBackground(thread):
     52   """Starts a thread and waits until it joins.
     53 
     54   Args:
     55     thread: A not yet started threading.Thread object.
     56   """
     57   try:
     58     thread.start()
     59     yield thread
     60   finally:
     61     thread.join()
     62 
     63 
     64 def TcpAcceptAndReceive(listening_sock, buffer_size=DEFAULT_BUFFER_SIZE):
     65   """Accepts a single connection and blocks receiving data from it.
     66 
     67   Args:
     68     listening_socket: A socket in LISTEN state.
     69     buffer_size: Size of buffer where to read a message.
     70   """
     71   connection, _ = listening_sock.accept()
     72   with contextlib.closing(connection):
     73     _ = connection.recv(buffer_size)
     74 
     75 
     76 def ExchangeMessage(addr_family, ip_addr):
     77   """Creates a listening socket, accepts a connection and sends data to it.
     78 
     79   Args:
     80     addr_family: The address family (e.g. AF_INET6).
     81     ip_addr: The IP address (IPv4 or IPv6 depending on the addr_family).
     82     tcp_port: The TCP port to listen on.
     83   """
     84   # Bind to a random port and connect to it.
     85   test_addr = (ip_addr, 0)
     86   with contextlib.closing(
     87       socket(addr_family, SOCK_STREAM)) as listening_socket:
     88     listening_socket.bind(test_addr)
     89     test_addr = listening_socket.getsockname()
     90     listening_socket.listen(1)
     91     with RunInBackground(threading.Thread(target=TcpAcceptAndReceive,
     92                                           args=(listening_socket,))):
     93       with contextlib.closing(
     94           socket(addr_family, SOCK_STREAM)) as client_socket:
     95         client_socket.connect(test_addr)
     96         client_socket.send(DEFAULT_TEST_MESSAGE)
     97 
     98 
     99 def KillAddrIoctl(addr):
    100   """Calls the SIOCKILLADDR ioctl on the provided IP address.
    101 
    102   Args:
    103     addr The IP address to pass to the ioctl.
    104 
    105   Raises:
    106     ValueError: If addr is of an unsupported address family.
    107   """
    108   family, _, _, _, _ = getaddrinfo(addr, None, AF_UNSPEC, SOCK_DGRAM, 0,
    109                                    AI_NUMERICHOST)[0]
    110   if family == AF_INET6:
    111     addr = inet_pton(AF_INET6, addr)
    112     ifreq = In6Ifreq((addr, 128, LOOPBACK_IFINDEX)).Pack()
    113   elif family == AF_INET:
    114     addr = inet_pton(AF_INET, addr)
    115     sockaddr = csocket.SockaddrIn((AF_INET, 0, addr)).Pack()
    116     ifreq = Ifreq((LOOPBACK_DEV, sockaddr)).Pack()
    117   else:
    118     raise ValueError('Address family %r not supported.' % family)
    119   datagram_socket = socket(family, SOCK_DGRAM)
    120   fcntl.ioctl(datagram_socket.fileno(), SIOCKILLADDR, ifreq)
    121   datagram_socket.close()
    122 
    123 
    124 class ExceptionalReadThread(threading.Thread):
    125 
    126   def __init__(self, sock):
    127     self.sock = sock
    128     self.exception = None
    129     super(ExceptionalReadThread, self).__init__()
    130     self.daemon = True
    131 
    132   def run(self):
    133     try:
    134       read = self.sock.recv(4096)
    135     except Exception, e:
    136       self.exception = e
    137 
    138 # For convenience.
    139 def CreateIPv4SocketPair():
    140   return net_test.CreateSocketPair(AF_INET, SOCK_STREAM, IPV4_LOOPBACK_ADDR)
    141 
    142 def CreateIPv6SocketPair():
    143   return net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, IPV6_LOOPBACK_ADDR)
    144 
    145 
    146 class TcpNukeAddrTest(net_test.NetworkTest):
    147 
    148   def testTimewaitSockets(self):
    149     """Tests that SIOCKILLADDR works as expected.
    150 
    151     Relevant kernel commits:
    152       https://www.codeaurora.org/cgit/quic/la/kernel/msm-3.18/commit/net/ipv4/tcp.c?h=aosp/android-3.10&id=1dcd3a1fa2fe78251cc91700eb1d384ab02e2dd6
    153     """
    154     for i in xrange(DEFAULT_TEST_RUNS):
    155       ExchangeMessage(AF_INET6, IPV6_LOOPBACK_ADDR)
    156       KillAddrIoctl(IPV6_LOOPBACK_ADDR)
    157       ExchangeMessage(AF_INET, IPV4_LOOPBACK_ADDR)
    158       KillAddrIoctl(IPV4_LOOPBACK_ADDR)
    159       # Test passes if kernel does not crash.
    160 
    161   def testClosesIPv6Sockets(self):
    162     """Tests that SIOCKILLADDR closes IPv6 sockets and unblocks threads."""
    163 
    164     threadpairs = []
    165 
    166     for i in xrange(DEFAULT_TEST_RUNS):
    167       clientsock, acceptedsock = CreateIPv6SocketPair()
    168       clientthread = ExceptionalReadThread(clientsock)
    169       clientthread.start()
    170       serverthread = ExceptionalReadThread(acceptedsock)
    171       serverthread.start()
    172       threadpairs.append((clientthread, serverthread))
    173 
    174     KillAddrIoctl(IPV6_LOOPBACK_ADDR)
    175 
    176     def CheckThreadException(thread):
    177       thread.join(100)
    178       self.assertFalse(thread.is_alive())
    179       self.assertIsNotNone(thread.exception)
    180       self.assertTrue(isinstance(thread.exception, IOError))
    181       self.assertEquals(errno.ETIMEDOUT, thread.exception.errno)
    182       self.assertRaisesErrno(errno.ENOTCONN, thread.sock.getpeername)
    183       self.assertRaisesErrno(errno.EISCONN, thread.sock.connect,
    184                              (IPV6_LOOPBACK_ADDR, 53))
    185       self.assertRaisesErrno(errno.EPIPE, thread.sock.send, "foo")
    186 
    187     for clientthread, serverthread in threadpairs:
    188       CheckThreadException(clientthread)
    189       CheckThreadException(serverthread)
    190 
    191   def assertSocketsClosed(self, socketpair):
    192     for sock in socketpair:
    193       self.assertRaisesErrno(errno.ENOTCONN, sock.getpeername)
    194 
    195   def assertSocketsNotClosed(self, socketpair):
    196     for sock in socketpair:
    197       self.assertTrue(sock.getpeername())
    198 
    199   def testAddresses(self):
    200     socketpair = CreateIPv4SocketPair()
    201     KillAddrIoctl("::")
    202     self.assertSocketsNotClosed(socketpair)
    203     KillAddrIoctl("::1")
    204     self.assertSocketsNotClosed(socketpair)
    205     KillAddrIoctl("127.0.0.3")
    206     self.assertSocketsNotClosed(socketpair)
    207     KillAddrIoctl("0.0.0.0")
    208     self.assertSocketsNotClosed(socketpair)
    209     KillAddrIoctl("127.0.0.1")
    210     self.assertSocketsClosed(socketpair)
    211 
    212     socketpair = CreateIPv6SocketPair()
    213     KillAddrIoctl("0.0.0.0")
    214     self.assertSocketsNotClosed(socketpair)
    215     KillAddrIoctl("127.0.0.1")
    216     self.assertSocketsNotClosed(socketpair)
    217     KillAddrIoctl("::2")
    218     self.assertSocketsNotClosed(socketpair)
    219     KillAddrIoctl("::")
    220     self.assertSocketsNotClosed(socketpair)
    221     KillAddrIoctl("::1")
    222     self.assertSocketsClosed(socketpair)
    223 
    224 
    225 class TcpNukeAddrHashTest(net_test.NetworkTest):
    226 
    227   def setUp(self):
    228     self.nofile = resource.getrlimit(resource.RLIMIT_NOFILE)
    229     resource.setrlimit(resource.RLIMIT_NOFILE, (HASH_TEST_NOFILE,
    230                                                 HASH_TEST_NOFILE))
    231 
    232   def tearDown(self):
    233     resource.setrlimit(resource.RLIMIT_NOFILE, self.nofile)
    234 
    235   def testClosesAllSockets(self):
    236     socketpairs = []
    237     for i in xrange(HASH_TEST_RUNS):
    238       socketpairs.append(CreateIPv4SocketPair())
    239       socketpairs.append(CreateIPv6SocketPair())
    240 
    241     KillAddrIoctl(IPV4_LOOPBACK_ADDR)
    242     KillAddrIoctl(IPV6_LOOPBACK_ADDR)
    243 
    244     for socketpair in socketpairs:
    245       for sock in socketpair:
    246         self.assertRaisesErrno(errno.ENOTCONN, sock.getpeername)
    247 
    248 
if __name__ == "__main__":
  # Run all test cases defined in this module.
  unittest.main()
    251