#!/usr/bin/python
#
# Copyright 2015 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the SIOCKILLADDR ioctl on loopback TCP sockets."""

import contextlib
import errno
import fcntl
import resource
import os
from socket import *  # pylint: disable=wildcard-import
import struct
import threading
import time
import unittest

import csocket
import cstruct
import net_test

IPV4_LOOPBACK_ADDR = "127.0.0.1"
IPV6_LOOPBACK_ADDR = "::1"
LOOPBACK_DEV = "lo"
LOOPBACK_IFINDEX = 1

# ioctl request number passed to fcntl.ioctl() by KillAddrIoctl().
SIOCKILLADDR = 0x8939

# NOTE(review): DEFAULT_TCP_PORT is not referenced anywhere in this file;
# ExchangeMessage binds to an ephemeral port instead.
DEFAULT_TCP_PORT = 8001
DEFAULT_BUFFER_SIZE = 20
DEFAULT_TEST_MESSAGE = "TCP NUKE ADDR TEST"
DEFAULT_TEST_RUNS = 100
HASH_TEST_RUNS = 4000
# RLIMIT_NOFILE used by TcpNukeAddrHashTest, which opens 2 * HASH_TEST_RUNS
# socket pairs at once.
HASH_TEST_NOFILE = 16384


# IPv4 request: 16-byte interface name followed by a 16-byte packed sockaddr.
Ifreq = cstruct.Struct("Ifreq", "=16s16s", "name data")
# IPv6 request: 16-byte address, unsigned prefix length, signed ifindex.
In6Ifreq = cstruct.Struct("In6Ifreq", "=16sIi", "addr prefixlen ifindex")


@contextlib.contextmanager
def RunInBackground(thread):
  """Starts a thread and waits until it joins.

  The thread is started before the try block so that a failing start() is
  reported as-is instead of being masked by join() raising RuntimeError on a
  thread that never started.

  Args:
    thread: A not yet started threading.Thread object.

  Yields:
    The started thread.
  """
  thread.start()
  try:
    yield thread
  finally:
    thread.join()


def TcpAcceptAndReceive(listening_sock, buffer_size=DEFAULT_BUFFER_SIZE):
  """Accepts a single connection and blocks receiving data from it.

  Args:
    listening_sock: A socket in LISTEN state.
    buffer_size: Size of buffer where to read a message.
  """
  connection, _ = listening_sock.accept()
  with contextlib.closing(connection):
    # The received payload is intentionally discarded; a single recv() is
    # enough to complete the exchange.
    connection.recv(buffer_size)


def ExchangeMessage(addr_family, ip_addr):
  """Creates a listening socket, accepts a connection and sends data to it.

  A background thread accepts the connection and reads from it while the
  calling thread connects and sends DEFAULT_TEST_MESSAGE.

  Args:
    addr_family: The address family (e.g. AF_INET6).
    ip_addr: The IP address (IPv4 or IPv6 depending on the addr_family).
  """
  # Bind to port 0 so the kernel picks a free port, then connect to it.
  with contextlib.closing(
      socket(addr_family, SOCK_STREAM)) as listening_socket:
    listening_socket.bind((ip_addr, 0))
    test_addr = listening_socket.getsockname()
    listening_socket.listen(1)
    # NOTE(review): if connect() below raises, RunInBackground's join() waits
    # on a thread still blocked in accept(), which can hang the test.
    with RunInBackground(threading.Thread(target=TcpAcceptAndReceive,
                                          args=(listening_socket,))):
      with contextlib.closing(
          socket(addr_family, SOCK_STREAM)) as client_socket:
        client_socket.connect(test_addr)
        # sendall() guarantees the whole message is written (or raises).
        client_socket.sendall(DEFAULT_TEST_MESSAGE)


def KillAddrIoctl(addr):
  """Calls the SIOCKILLADDR ioctl on the provided IP address.

  Args:
    addr: The IP address to pass to the ioctl, as a numeric string.

  Raises:
    ValueError: If addr is of an unsupported address family.
    socket.gaierror: If addr is not a numeric IP address.
  """
  # AI_NUMERICHOST: only literal addresses are accepted; no name lookup.
  family = getaddrinfo(addr, None, AF_UNSPEC, SOCK_DGRAM, 0,
                       AI_NUMERICHOST)[0][0]
  if family == AF_INET6:
    addr = inet_pton(AF_INET6, addr)
    ifreq = In6Ifreq((addr, 128, LOOPBACK_IFINDEX)).Pack()
  elif family == AF_INET:
    addr = inet_pton(AF_INET, addr)
    sockaddr = csocket.SockaddrIn((AF_INET, 0, addr)).Pack()
    ifreq = Ifreq((LOOPBACK_DEV, sockaddr)).Pack()
  else:
    raise ValueError('Address family %r not supported.' % family)
  # closing() releases the fd even when the ioctl itself raises.
  with contextlib.closing(socket(family, SOCK_DGRAM)) as datagram_socket:
    fcntl.ioctl(datagram_socket.fileno(), SIOCKILLADDR, ifreq)


class ExceptionalReadThread(threading.Thread):
  """Daemon thread that blocks in recv() and records the exception it gets.

  Attributes:
    sock: The socket the thread reads from.
    exception: The exception raised by recv(), or None if it returned.
  """

  def __init__(self, sock):
    super(ExceptionalReadThread, self).__init__()
    self.sock = sock
    self.exception = None
    # Daemon so a recv() that never returns cannot keep the process alive.
    self.daemon = True

  def run(self):
    try:
      self.sock.recv(4096)
    except Exception as e:  # pylint: disable=broad-except
      # Stored for the test thread to inspect after join().
      self.exception = e


# For convenience.
def CreateIPv4SocketPair():
  """Returns a connected TCP socket pair over IPv4 loopback."""
  return net_test.CreateSocketPair(AF_INET, SOCK_STREAM, IPV4_LOOPBACK_ADDR)


def CreateIPv6SocketPair():
  """Returns a connected TCP socket pair over IPv6 loopback."""
  return net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, IPV6_LOOPBACK_ADDR)


def _CloseSockets(socketpairs):
  """Closes every socket in an iterable of socket pairs."""
  for socketpair in socketpairs:
    for sock in socketpair:
      sock.close()


class TcpNukeAddrTest(net_test.NetworkTest):

  def testTimewaitSockets(self):
    """Tests that SIOCKILLADDR works as expected.

    Relevant kernel commits:
      https://www.codeaurora.org/cgit/quic/la/kernel/msm-3.18/commit/net/ipv4/tcp.c?h=aosp/android-3.10&id=1dcd3a1fa2fe78251cc91700eb1d384ab02e2dd6
    """
    for _ in range(DEFAULT_TEST_RUNS):
      ExchangeMessage(AF_INET6, IPV6_LOOPBACK_ADDR)
      KillAddrIoctl(IPV6_LOOPBACK_ADDR)
      ExchangeMessage(AF_INET, IPV4_LOOPBACK_ADDR)
      KillAddrIoctl(IPV4_LOOPBACK_ADDR)
      # Test passes if kernel does not crash.

  def testClosesIPv6Sockets(self):
    """Tests that SIOCKILLADDR closes IPv6 sockets and unblocks threads."""

    threadpairs = []

    def CheckThreadException(thread):
      """Asserts the ioctl woke `thread` with ETIMEDOUT and left it unusable."""
      thread.join(100)
      self.assertFalse(thread.is_alive())
      self.assertIsNotNone(thread.exception)
      self.assertIsInstance(thread.exception, IOError)
      self.assertEqual(errno.ETIMEDOUT, thread.exception.errno)
      self.assertRaisesErrno(errno.ENOTCONN, thread.sock.getpeername)
      self.assertRaisesErrno(errno.EISCONN, thread.sock.connect,
                             (IPV6_LOOPBACK_ADDR, 53))
      self.assertRaisesErrno(errno.EPIPE, thread.sock.send, "foo")

    try:
      for _ in range(DEFAULT_TEST_RUNS):
        clientsock, acceptedsock = CreateIPv6SocketPair()
        clientthread = ExceptionalReadThread(clientsock)
        clientthread.start()
        serverthread = ExceptionalReadThread(acceptedsock)
        serverthread.start()
        threadpairs.append((clientthread, serverthread))

      KillAddrIoctl(IPV6_LOOPBACK_ADDR)

      for clientthread, serverthread in threadpairs:
        CheckThreadException(clientthread)
        CheckThreadException(serverthread)
    finally:
      # Release the fds so they do not leak into later tests.
      _CloseSockets((client.sock, server.sock)
                    for client, server in threadpairs)

  def assertSocketsClosed(self, socketpair):
    """Asserts that every socket in `socketpair` has no peer (ENOTCONN)."""
    for sock in socketpair:
      self.assertRaisesErrno(errno.ENOTCONN, sock.getpeername)

  def assertSocketsNotClosed(self, socketpair):
    """Asserts that every socket in `socketpair` still has a peer."""
    for sock in socketpair:
      self.assertTrue(sock.getpeername())

  def testAddresses(self):
    """Tests that only the exact socket address is affected by the ioctl."""
    socketpair = CreateIPv4SocketPair()
    try:
      KillAddrIoctl("::")
      self.assertSocketsNotClosed(socketpair)
      KillAddrIoctl("::1")
      self.assertSocketsNotClosed(socketpair)
      KillAddrIoctl("127.0.0.3")
      self.assertSocketsNotClosed(socketpair)
      KillAddrIoctl("0.0.0.0")
      self.assertSocketsNotClosed(socketpair)
      KillAddrIoctl("127.0.0.1")
      self.assertSocketsClosed(socketpair)
    finally:
      _CloseSockets([socketpair])

    socketpair = CreateIPv6SocketPair()
    try:
      KillAddrIoctl("0.0.0.0")
      self.assertSocketsNotClosed(socketpair)
      KillAddrIoctl("127.0.0.1")
      self.assertSocketsNotClosed(socketpair)
      KillAddrIoctl("::2")
      self.assertSocketsNotClosed(socketpair)
      KillAddrIoctl("::")
      self.assertSocketsNotClosed(socketpair)
      KillAddrIoctl("::1")
      self.assertSocketsClosed(socketpair)
    finally:
      _CloseSockets([socketpair])


class TcpNukeAddrHashTest(net_test.NetworkTest):

  def setUp(self):
    # Raise the fd limit so the test can hold thousands of sockets at once;
    # the previous limit is restored in tearDown().
    self.nofile = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (HASH_TEST_NOFILE,
                                                HASH_TEST_NOFILE))

  def tearDown(self):
    resource.setrlimit(resource.RLIMIT_NOFILE, self.nofile)

  def testClosesAllSockets(self):
    """Tests that the ioctl closes every socket on the killed addresses."""
    socketpairs = []
    try:
      for _ in range(HASH_TEST_RUNS):
        socketpairs.append(CreateIPv4SocketPair())
        socketpairs.append(CreateIPv6SocketPair())

      KillAddrIoctl(IPV4_LOOPBACK_ADDR)
      KillAddrIoctl(IPV6_LOOPBACK_ADDR)

      for socketpair in socketpairs:
        for sock in socketpair:
          self.assertRaisesErrno(errno.ENOTCONN, sock.getpeername)
    finally:
      # Close explicitly so ~16000 fds are not left for the garbage collector.
      _CloseSockets(socketpairs)


if __name__ == "__main__":
  unittest.main()