Home | History | Annotate | Download | only in test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2015 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 # pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
     18 from errno import *  # pylint: disable=wildcard-import
     19 import os
     20 import random
     21 import select
     22 from socket import *  # pylint: disable=wildcard-import
     23 import struct
     24 import threading
     25 import time
     26 import unittest
     27 
     28 import multinetwork_base
     29 import net_test
     30 import packets
     31 import sock_diag
     32 import tcp_test
     33 
     34 
     35 NUM_SOCKETS = 30
     36 NO_BYTECODE = ""
     37 HAVE_SO_COOKIE_SUPPORT = net_test.LINUX_VERSION >= (4, 9, 0)
     38 
     39 IPPROTO_SCTP = 132
     40 
     41 def HaveUdpDiag():
     42   # There is no way to tell whether a dump succeeded: if the appropriate handler
     43   # wasn't found, __inet_diag_dump just returns an empty result instead of an
     44   # error. So, just check to see if a UDP dump returns no sockets when we know
     45   # it should return one.
     46   s = socket(AF_INET6, SOCK_DGRAM, 0)
     47   s.bind(("::", 0))
     48   s.connect((s.getsockname()))
     49   sd = sock_diag.SockDiag()
     50   have_udp_diag = len(sd.DumpAllInetSockets(IPPROTO_UDP, "")) > 0
     51   s.close()
     52   return have_udp_diag
     53 
     54 def HaveSctp():
     55   if net_test.LINUX_VERSION < (4, 7, 0):
     56     return False
     57   try:
     58     s = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP)
     59     s.close()
     60     return True
     61   except IOError:
     62     return False
     63 
     64 HAVE_UDP_DIAG = HaveUdpDiag()
     65 HAVE_SCTP = HaveSctp()
     66 
     67 
     68 class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
     69   """Basic tests for SOCK_DIAG functionality.
     70 
     71     Relevant kernel commits:
     72       android-3.4:
     73         ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
     74         99ee451 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
     75 
     76       android-3.10:
     77         3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
     78         f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
     79 
     80       android-3.18:
     81         e603010 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
     82 
     83       android-4.4:
     84         525ee59 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
     85   """
     86   @staticmethod
     87   def _CreateLotsOfSockets(socktype):
     88     # Dict mapping (addr, sport, dport) tuples to socketpairs.
     89     socketpairs = {}
     90     for _ in xrange(NUM_SOCKETS):
     91       family, addr = random.choice([
     92           (AF_INET, "127.0.0.1"),
     93           (AF_INET6, "::1"),
     94           (AF_INET6, "::ffff:127.0.0.1")])
     95       socketpair = net_test.CreateSocketPair(family, socktype, addr)
     96       sport, dport = (socketpair[0].getsockname()[1],
     97                       socketpair[1].getsockname()[1])
     98       socketpairs[(addr, sport, dport)] = socketpair
     99     return socketpairs
    100 
    101   def assertSocketClosed(self, sock):
    102     self.assertRaisesErrno(ENOTCONN, sock.getpeername)
    103 
    104   def assertSocketConnected(self, sock):
    105     sock.getpeername()  # No errors? Socket is alive and connected.
    106 
    107   def assertSocketsClosed(self, socketpair):
    108     for sock in socketpair:
    109       self.assertSocketClosed(sock)
    110 
    111   def assertMarkIs(self, mark, attrs):
    112     self.assertEqual(mark, attrs.get("INET_DIAG_MARK", None))
    113 
    114   def assertSockInfoMatchesSocket(self, s, info):
    115     diag_msg, attrs = info
    116     family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    117     self.assertEqual(diag_msg.family, family)
    118 
    119     src, sport = s.getsockname()[0:2]
    120     self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
    121     self.assertEqual(diag_msg.id.sport, sport)
    122 
    123     if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
    124       dst, dport = s.getpeername()[0:2]
    125       self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
    126       self.assertEqual(diag_msg.id.dport, dport)
    127     else:
    128       self.assertRaisesErrno(ENOTCONN, s.getpeername)
    129 
    130     mark = s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
    131     self.assertMarkIs(mark, attrs)
    132 
    133   def PackAndCheckBytecode(self, instructions):
    134     bytecode = self.sock_diag.PackBytecode(instructions)
    135     decoded = self.sock_diag.DecodeBytecode(bytecode)
    136     self.assertEquals(len(instructions), len(decoded))
    137     self.assertFalse("???" in decoded)
    138     return bytecode
    139 
    140   def _EventDuringBlockingCall(self, sock, call, expected_errno, event):
    141     """Simulates an external event during a blocking call on sock.
    142 
    143     Args:
    144       sock: The socket to use.
    145       call: A function, the call to make. Takes one parameter, sock.
    146       expected_errno: The value that call is expected to fail with, or None if
    147         call is expected to succeed.
    148       event: A function, the event that will happen during the blocking call.
    149         Takes one parameter, sock.
    150     """
    151     thread = SocketExceptionThread(sock, call)
    152     thread.start()
    153     time.sleep(0.1)
    154     event(sock)
    155     thread.join(1)
    156     self.assertFalse(thread.is_alive())
    157     if expected_errno is not None:
    158       self.assertIsNotNone(thread.exception)
    159       self.assertTrue(isinstance(thread.exception, IOError),
    160                       "Expected IOError, got %s" % thread.exception)
    161       self.assertEqual(expected_errno, thread.exception.errno)
    162     else:
    163       self.assertIsNone(thread.exception)
    164     self.assertSocketClosed(sock)
    165 
    166   def CloseDuringBlockingCall(self, sock, call, expected_errno):
    167     self._EventDuringBlockingCall(
    168         sock, call, expected_errno,
    169         lambda sock: self.sock_diag.CloseSocketFromFd(sock))
    170 
    171   def setUp(self):
    172     super(SockDiagBaseTest, self).setUp()
    173     self.sock_diag = sock_diag.SockDiag()
    174     self.socketpairs = {}
    175 
    176   def tearDown(self):
    177     for socketpair in self.socketpairs.values():
    178       for s in socketpair:
    179         s.close()
    180     super(SockDiagBaseTest, self).tearDown()
    181 
    182 
    183 class SockDiagTest(SockDiagBaseTest):
    184 
    185   def testFindsMappedSockets(self):
    186     """Tests that inet_diag_find_one_icsk can find mapped sockets."""
    187     socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
    188                                            "::ffff:127.0.0.1")
    189     for sock in socketpair:
    190       diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
    191       diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
    192       self.sock_diag.GetSockInfo(diag_req)
    193       # No errors? Good.
    194 
    195   def testFindsAllMySockets(self):
    196     """Tests that basic socket dumping works."""
    197     self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    198     sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
    199     self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
    200 
    201     # Find the cookies for all of our sockets.
    202     cookies = {}
    203     for diag_msg, unused_attrs in sockets:
    204       addr = self.sock_diag.GetSourceAddress(diag_msg)
    205       sport = diag_msg.id.sport
    206       dport = diag_msg.id.dport
    207       if (addr, sport, dport) in self.socketpairs:
    208         cookies[(addr, sport, dport)] = diag_msg.id.cookie
    209       elif (addr, dport, sport) in self.socketpairs:
    210         cookies[(addr, sport, dport)] = diag_msg.id.cookie
    211 
    212     # Did we find all the cookies?
    213     self.assertEquals(2 * NUM_SOCKETS, len(cookies))
    214 
    215     socketpairs = self.socketpairs.values()
    216     random.shuffle(socketpairs)
    217     for socketpair in socketpairs:
    218       for sock in socketpair:
    219         # Check that we can find a diag_msg by scanning a dump.
    220         self.assertSockInfoMatchesSocket(
    221             sock,
    222             self.sock_diag.FindSockInfoFromFd(sock))
    223         cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie
    224 
    225         # Check that we can find a diag_msg once we know the cookie.
    226         req = self.sock_diag.DiagReqFromSocket(sock)
    227         req.id.cookie = cookie
    228         info = self.sock_diag.GetSockInfo(req)
    229         self.assertSockInfoMatchesSocket(sock, info)
    230 
    231   def testBytecodeCompilation(self):
    232     # pylint: disable=bad-whitespace
    233     instructions = [
    234         (sock_diag.INET_DIAG_BC_S_GE,   1, 8, 0),                      # 0
    235         (sock_diag.INET_DIAG_BC_D_LE,   1, 7, 0xffff),                 # 8
    236         (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)),       # 16
    237         (sock_diag.INET_DIAG_BC_JMP,    1, 3, None),                   # 44
    238         (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)),  # 48
    239         (sock_diag.INET_DIAG_BC_D_LE,   1, 3, 0x6665),  # not used     # 64
    240         (sock_diag.INET_DIAG_BC_NOP,    1, 1, None),                   # 72
    241                                                                        # 76 acc
    242                                                                        # 80 rej
    243     ]
    244     # pylint: enable=bad-whitespace
    245     bytecode = self.PackAndCheckBytecode(instructions)
    246     expected = (
    247         "0208500000000000"
    248         "050848000000ffff"
    249         "071c20000a800000ffffffff00000000000000000000000000000001"
    250         "01041c00"
    251         "0718200002200000ffffffff7f000001"
    252         "0508100000006566"
    253         "00040400"
    254     )
    255     states = 1 << tcp_test.TCP_ESTABLISHED
    256     self.assertMultiLineEqual(expected, bytecode.encode("hex"))
    257     self.assertEquals(76, len(bytecode))
    258     self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    259     filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode,
    260                                                         states=states)
    261     allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE,
    262                                                    states=states)
    263     self.assertItemsEqual(allsockets, filteredsockets)
    264 
    265     # Pick a few sockets in hash table order, and check that the bytecode we
    266     # compiled selects them properly.
    267     for socketpair in self.socketpairs.values()[:20]:
    268       for s in socketpair:
    269         diag_msg = self.sock_diag.FindSockDiagFromFd(s)
    270         instructions = [
    271             (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
    272             (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
    273             (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
    274             (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
    275         ]
    276         bytecode = self.PackAndCheckBytecode(instructions)
    277         self.assertEquals(32, len(bytecode))
    278         sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
    279         self.assertEquals(1, len(sockets))
    280 
    281         # TODO: why doesn't comparing the cstructs work?
    282         self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack())
    283 
    284   def testCrossFamilyBytecode(self):
    285     """Checks for a cross-family bug in inet_diag_hostcond matching.
    286 
    287     Relevant kernel commits:
    288       android-3.4:
    289         f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
    290     """
    291     # TODO: this is only here because the test fails if there are any open
    292     # sockets other than the ones it creates itself. Make the bytecode more
    293     # specific and remove it.
    294     states = 1 << tcp_test.TCP_ESTABLISHED
    295     self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, "",
    296                                                        states=states))
    297 
    298     unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
    299     unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")
    300 
    301     bytecode4 = self.PackAndCheckBytecode([
    302         (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
    303     bytecode6 = self.PackAndCheckBytecode([
    304         (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])
    305 
    306     # IPv4/v6 filters must never match IPv6/IPv4 sockets...
    307     v4socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4,
    308                                                   states=states)
    309     self.assertTrue(v4socks)
    310     self.assertTrue(all(d.family == AF_INET for d, _ in v4socks))
    311 
    312     v6socks = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6,
    313                                                   states=states)
    314     self.assertTrue(v6socks)
    315     self.assertTrue(all(d.family == AF_INET6 for d, _ in v6socks))
    316 
    317     # Except for mapped addresses, which match both IPv4 and IPv6.
    318     pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
    319                                       "::ffff:127.0.0.1")
    320     diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
    321     v4socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
    322                                                                bytecode4,
    323                                                                states=states)]
    324     v6socks = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
    325                                                                bytecode6,
    326                                                                states=states)]
    327     self.assertTrue(all(d in v4socks for d in diag_msgs))
    328     self.assertTrue(all(d in v6socks for d in diag_msgs))
    329 
    330   def testPortComparisonValidation(self):
    331     """Checks for a bug in validating port comparison bytecode.
    332 
    333     Relevant kernel commits:
    334       android-3.4:
    335         5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
    336     """
    337     bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
    338     self.assertEquals("???",
    339                       self.sock_diag.DecodeBytecode(bytecode))
    340     self.assertRaisesErrno(
    341         EINVAL,
    342         self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())
    343 
    344   def testNonSockDiagCommand(self):
    345     def DiagDump(code):
    346       sock_id = self.sock_diag._EmptyInetDiagSockId()
    347       req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
    348                                      sock_id))
    349       self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")
    350 
    351     op = sock_diag.SOCK_DIAG_BY_FAMILY
    352     DiagDump(op)  # No errors? Good.
    353     self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
    354 
    355   def CheckSocketCookie(self, inet, addr):
    356     """Tests that getsockopt SO_COOKIE can get cookie for all sockets."""
    357     socketpair = net_test.CreateSocketPair(inet, SOCK_STREAM, addr)
    358     for sock in socketpair:
    359       diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
    360       cookie = sock.getsockopt(net_test.SOL_SOCKET, net_test.SO_COOKIE, 8)
    361       self.assertEqual(diag_msg.id.cookie, cookie)
    362 
    363   @unittest.skipUnless(HAVE_SO_COOKIE_SUPPORT, "SO_COOKIE not supported")
    364   def testGetsockoptcookie(self):
    365     self.CheckSocketCookie(AF_INET, "127.0.0.1")
    366     self.CheckSocketCookie(AF_INET6, "::1")
    367 
    368 
    369 class SockDestroyTest(SockDiagBaseTest):
    370   """Tests that SOCK_DESTROY works correctly.
    371 
    372   Relevant kernel commits:
    373     net-next:
    374       b613f56 net: diag: split inet_diag_dump_one_icsk into two
    375       64be0ae net: diag: Add the ability to destroy a socket.
    376       6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
    377       c1e64e2 net: diag: Support destroying TCP sockets.
    378       2010b93 net: tcp: deal with listen sockets properly in tcp_abort.
    379 
    380     android-3.4:
    381       d48ec88 net: diag: split inet_diag_dump_one_icsk into two
    382       2438189 net: diag: Add the ability to destroy a socket.
    383       7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
    384       44047b2 net: diag: Support destroying TCP sockets.
    385       200dae7 net: tcp: deal with listen sockets properly in tcp_abort.
    386 
    387     android-3.10:
    388       9eaff90 net: diag: split inet_diag_dump_one_icsk into two
    389       d60326c net: diag: Add the ability to destroy a socket.
    390       3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
    391       529dfc6 net: diag: Support destroying TCP sockets.
    392       9c712fe net: tcp: deal with listen sockets properly in tcp_abort.
    393 
    394     android-3.18:
    395       100263d net: diag: split inet_diag_dump_one_icsk into two
    396       194c5f3 net: diag: Add the ability to destroy a socket.
    397       8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
    398       b80585a net: diag: Support destroying TCP sockets.
    399       476c6ce net: tcp: deal with listen sockets properly in tcp_abort.
    400 
    401     android-4.1:
    402       56eebf8 net: diag: split inet_diag_dump_one_icsk into two
    403       fb486c9 net: diag: Add the ability to destroy a socket.
    404       0c02b7e net: diag: Support SOCK_DESTROY for inet sockets.
    405       67c71d8 net: diag: Support destroying TCP sockets.
    406       a76e0ec net: tcp: deal with listen sockets properly in tcp_abort.
    407       e6e277b net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
    408 
    409     android-4.4:
    410       76c83a9 net: diag: split inet_diag_dump_one_icsk into two
    411       f7cf791 net: diag: Add the ability to destroy a socket.
    412       1c42248 net: diag: Support SOCK_DESTROY for inet sockets.
    413       c9e8440d net: diag: Support destroying TCP sockets.
    414       3d9502c tcp: diag: add support for request sockets to tcp_abort()
    415       001cf75 net: tcp: deal with listen sockets properly in tcp_abort.
    416   """
    417 
    418   def testClosesSockets(self):
    419     self.socketpairs = self._CreateLotsOfSockets(SOCK_STREAM)
    420     for _, socketpair in self.socketpairs.iteritems():
    421       # Close one of the sockets.
    422       # This will send a RST that will close the other side as well.
    423       s = random.choice(socketpair)
    424       if random.randrange(0, 2) == 1:
    425         self.sock_diag.CloseSocketFromFd(s)
    426       else:
    427         diag_msg = self.sock_diag.FindSockDiagFromFd(s)
    428 
    429         # Get the cookie wrong and ensure that we get an error and the socket
    430         # is not closed.
    431         real_cookie = diag_msg.id.cookie
    432         diag_msg.id.cookie = os.urandom(len(real_cookie))
    433         req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
    434         self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
    435         self.assertSocketConnected(s)
    436 
    437         # Now close it with the correct cookie.
    438         req.id.cookie = real_cookie
    439         self.sock_diag.CloseSocket(req)
    440 
    441       # Check that both sockets in the pair are closed.
    442       self.assertSocketsClosed(socketpair)
    443 
    444   # TODO:
    445   # Test that killing unix sockets returns EOPNOTSUPP.
    446 
    447 
    448 class SocketExceptionThread(threading.Thread):
    449 
    450   def __init__(self, sock, operation):
    451     self.exception = None
    452     super(SocketExceptionThread, self).__init__()
    453     self.daemon = True
    454     self.sock = sock
    455     self.operation = operation
    456 
    457   def run(self):
    458     try:
    459       self.operation(self.sock)
    460     except (IOError, AssertionError), e:
    461       self.exception = e
    462 
    463 
    464 class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
    465 
    466   def testIpv4MappedSynRecvSocket(self):
    467     """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.
    468 
    469     Relevant kernel commits:
    470          android-3.4:
    471            457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
    472     """
    473     netid = random.choice(self.tuns.keys())
    474     self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)
    475     sock_id = self.sock_diag._EmptyInetDiagSockId()
    476     sock_id.sport = self.port
    477     states = 1 << tcp_test.TCP_SYN_RECV
    478     req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
    479     children = self.sock_diag.Dump(req, NO_BYTECODE)
    480 
    481     self.assertTrue(children)
    482     for child, unused_args in children:
    483       self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
    484       self.assertEqual(self.sock_diag.PaddedAddress(self.remotesockaddr),
    485                        child.id.dst)
    486       self.assertEqual(self.sock_diag.PaddedAddress(self.mysockaddr),
    487                        child.id.src)
    488 
    489 
    490 class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
    491 
    492   def setUp(self):
    493     super(SockDestroyTcpTest, self).setUp()
    494     self.netid = random.choice(self.tuns.keys())
    495 
    496   def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
    497     """Closes the socket and checks whether a RST is sent or not."""
    498     if sock is not None:
    499       self.assertIsNone(req, "Must specify sock or req, not both")
    500       self.sock_diag.CloseSocketFromFd(sock)
    501       self.assertRaisesErrno(EINVAL, sock.accept)
    502     else:
    503       self.assertIsNone(sock, "Must specify sock or req, not both")
    504       self.sock_diag.CloseSocket(req)
    505 
    506     if expect_reset:
    507       desc, rst = self.RstPacket()
    508       msg = "%s: expecting %s: " % (msg, desc)
    509       self.ExpectPacketOn(self.netid, msg, rst)
    510     else:
    511       msg = "%s: " % msg
    512       self.ExpectNoPacketsOn(self.netid, msg)
    513 
    514     if sock is not None and do_close:
    515       sock.close()
    516 
    517   def CheckTcpReset(self, state, statename):
    518     for version in [4, 5, 6]:
    519       msg = "Closing incoming IPv%d %s socket" % (version, statename)
    520       self.IncomingConnection(version, state, self.netid)
    521       self.CheckRstOnClose(self.s, None, False, msg)
    522       if state != tcp_test.TCP_LISTEN:
    523         msg = "Closing accepted IPv%d %s socket" % (version, statename)
    524         self.CheckRstOnClose(self.accepted, None, True, msg)
    525 
    526   def testTcpResets(self):
    527     """Checks that closing sockets in appropriate states sends a RST."""
    528     self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
    529     self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
    530     self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
    531 
    532   def testFinWait1Socket(self):
    533     for version in [4, 5, 6]:
    534       self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
    535 
    536       # Get the cookie so we can find this socket after we close it.
    537       diag_msg = self.sock_diag.FindSockDiagFromFd(self.accepted)
    538       diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
    539 
    540       # Close the socket and check that it goes into FIN_WAIT1 and sends a FIN.
    541       net_test.EnableFinWait(self.accepted)
    542       self.accepted.close()
    543       diag_req.states = 1 << tcp_test.TCP_FIN_WAIT1
    544       diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
    545       self.assertEquals(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
    546       desc, fin = self.FinPacket()
    547       self.ExpectPacketOn(self.netid, "Closing FIN_WAIT1 socket", fin)
    548 
    549       # Destroy the socket and expect no RST.
    550       self.CheckRstOnClose(None, diag_req, False, "Closing FIN_WAIT1 socket")
    551       diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
    552 
    553       # The socket is still there in FIN_WAIT1: SOCK_DESTROY did nothing
    554       # because userspace had already closed it.
    555       self.assertEquals(tcp_test.TCP_FIN_WAIT1, diag_msg.state)
    556 
    557       # ACK the FIN so we don't trip over retransmits in future tests.
    558       finversion = 4 if version == 5 else version
    559       desc, finack = packets.ACK(finversion, self.remoteaddr, self.myaddr, fin)
    560       diag_msg, attrs = self.sock_diag.GetSockInfo(diag_req)
    561       self.ReceivePacketOn(self.netid, finack)
    562 
    563       # See if we can find the resulting FIN_WAIT2 socket. This does not appear
    564       # to work on 3.10.
    565       if net_test.LINUX_VERSION >= (3, 18):
    566         diag_req.states = 1 << tcp_test.TCP_FIN_WAIT2
    567         infos = self.sock_diag.Dump(diag_req, "")
    568         self.assertTrue(any(diag_msg.state == tcp_test.TCP_FIN_WAIT2
    569                             for diag_msg, attrs in infos),
    570                         "Expected to find FIN_WAIT2 socket in %s" % infos)
    571 
    572   def FindChildSockets(self, s):
    573     """Finds the SYN_RECV child sockets of a given listening socket."""
    574     d = self.sock_diag.FindSockDiagFromFd(self.s)
    575     req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    576     req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
    577     req.id.cookie = "\x00" * 8
    578 
    579     bad_bytecode = self.PackAndCheckBytecode(
    580         [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (0xffff, 0xffff))])
    581     self.assertEqual([], self.sock_diag.Dump(req, bad_bytecode))
    582 
    583     bytecode = self.PackAndCheckBytecode(
    584         [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (self.netid, 0xffff))])
    585     children = self.sock_diag.Dump(req, bytecode)
    586     return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    587             for d, _ in children]
    588 
    589   def CheckChildSocket(self, version, statename, parent_first):
    590     state = getattr(tcp_test, statename)
    591 
    592     self.IncomingConnection(version, state, self.netid)
    593 
    594     d = self.sock_diag.FindSockDiagFromFd(self.s)
    595     parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
    596     children = self.FindChildSockets(self.s)
    597     self.assertEquals(1, len(children))
    598 
    599     is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
    600     expected_state = tcp_test.TCP_ESTABLISHED if is_established else state
    601 
    602     # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
    603     # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
    604     # Before 4.4, we can see those sockets in dumps, but we can't fetch
    605     # or close them.
    606     can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)
    607 
    608     for child in children:
    609       if can_close_children:
    610         diag_msg, attrs = self.sock_diag.GetSockInfo(child)
    611         self.assertEquals(diag_msg.state, expected_state)
    612         self.assertMarkIs(self.netid, attrs)
    613       else:
    614         self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
    615 
    616     def CloseParent(expect_reset):
    617       msg = "Closing parent IPv%d %s socket %s child" % (
    618           version, statename, "before" if parent_first else "after")
    619       self.CheckRstOnClose(self.s, None, expect_reset, msg)
    620       self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, parent)
    621 
    622     def CheckChildrenClosed():
    623       for child in children:
    624         self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
    625 
    626     def CloseChildren():
    627       for child in children:
    628         msg = "Closing child IPv%d %s socket %s parent" % (
    629             version, statename, "after" if parent_first else "before")
    630         self.sock_diag.GetSockInfo(child)
    631         self.CheckRstOnClose(None, child, is_established, msg)
    632         self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockInfo, child)
    633       CheckChildrenClosed()
    634 
    635     if parent_first:
    636       # Closing the parent will close child sockets, which will send a RST,
    637       # iff they are already established.
    638       CloseParent(is_established)
    639       if is_established:
    640         CheckChildrenClosed()
    641       elif can_close_children:
    642         CloseChildren()
    643         CheckChildrenClosed()
    644       self.s.close()
    645     else:
    646       if can_close_children:
    647         CloseChildren()
    648       CloseParent(False)
    649       self.s.close()
    650 
    651   def testChildSockets(self):
    652     for version in [4, 5, 6]:
    653       self.CheckChildSocket(version, "TCP_SYN_RECV", False)
    654       self.CheckChildSocket(version, "TCP_SYN_RECV", True)
    655       self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
    656       self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)
    657 
    658   def testAcceptInterrupted(self):
    659     """Tests that accept() is interrupted by SOCK_DESTROY."""
    660     for version in [4, 5, 6]:
    661       self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
    662       self.assertRaisesErrno(ENOTCONN, self.s.recv, 4096)
    663       self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
    664       self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
    665       self.assertRaisesErrno(EINVAL, self.s.accept)
    666       # TODO: this should really return an error such as ENOTCONN...
    667       self.assertEquals("", self.s.recv(4096))
    668 
    669   def testReadInterrupted(self):
    670     """Tests that read() is interrupted by SOCK_DESTROY."""
    671     for version in [4, 5, 6]:
    672       self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
    673       self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
    674                                    ECONNABORTED)
    675       # Writing returns EPIPE, and reading returns EOF.
    676       self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
    677       self.assertEquals("", self.accepted.recv(4096))
    678       self.assertEquals("", self.accepted.recv(4096))
    679 
    680   def testConnectInterrupted(self):
    681     """Tests that connect() is interrupted by SOCK_DESTROY."""
    682     for version in [4, 5, 6]:
    683       family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
    684       s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
    685       self.SelectInterface(s, self.netid, "mark")
    686 
    687       remotesockaddr = self.GetRemoteSocketAddress(version)
    688       remoteaddr = self.GetRemoteAddress(version)
    689       s.bind(("", 0))
    690       _, sport = s.getsockname()[:2]
    691       self.CloseDuringBlockingCall(
    692           s, lambda sock: sock.connect((remotesockaddr, 53)), ECONNABORTED)
    693       desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
    694                               remoteaddr, sport=sport, seq=None)
    695       self.ExpectPacketOn(self.netid, desc, syn)
    696       msg = "SOCK_DESTROY of socket in connect, expected no RST"
    697       self.ExpectNoPacketsOn(self.netid, msg)
    698 
    699 
    700 class PollOnCloseTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
    701   """Tests that the effect of SOCK_DESTROY on poll matches TCP RSTs.
    702 
    703   The behaviour of poll() in these cases is not what we might expect: if only
    704   POLLIN is specified, it will return POLLIN|POLLERR|POLLHUP, but if POLLOUT
    705   is (also) specified, it will only return POLLOUT.
    706   """
    707 
    708   POLLIN_OUT = select.POLLIN | select.POLLOUT
    709   POLLIN_ERR_HUP = select.POLLIN | select.POLLERR | select.POLLHUP
    710 
    711   def setUp(self):
    712     super(PollOnCloseTest, self).setUp()
    713     self.netid = random.choice(self.tuns.keys())
    714 
    715   POLL_FLAGS = [(select.POLLIN, "IN"), (select.POLLOUT, "OUT"),
    716                 (select.POLLERR, "ERR"), (select.POLLHUP, "HUP")]
    717 
    718   def PollResultToString(self, poll_events, ignoremask):
    719     out = []
    720     for fd, event in poll_events:
    721       flags = [name for (flag, name) in self.POLL_FLAGS
    722                if event & flag & ~ignoremask != 0]
    723       out.append((fd, "|".join(flags)))
    724     return out
    725 
    726   def BlockingPoll(self, sock, mask, expected, ignoremask):
    727     p = select.poll()
    728     p.register(sock, mask)
    729     expected_fds = [(sock.fileno(), expected)]
    730     # Don't block forever or we'll hang continuous test runs on failure.
    731     # A 5-second timeout should be long enough not to be flaky.
    732     actual_fds = p.poll(5000)
    733     self.assertEqual(self.PollResultToString(expected_fds, ignoremask),
    734                      self.PollResultToString(actual_fds, ignoremask))
    735 
    736   def RstDuringBlockingCall(self, sock, call, expected_errno):
    737     self._EventDuringBlockingCall(
    738         sock, call, expected_errno,
    739         lambda _: self.ReceiveRstPacketOn(self.netid))
    740 
    741   def assertSocketErrors(self, errno):
    742     # The first operation returns the expected errno.
    743     self.assertRaisesErrno(errno, self.accepted.recv, 4096)
    744 
    745     # Subsequent operations behave as normal.
    746     self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
    747     self.assertEquals("", self.accepted.recv(4096))
    748     self.assertEquals("", self.accepted.recv(4096))
    749 
    750   def CheckPollDestroy(self, mask, expected, ignoremask):
    751     """Interrupts a poll() with SOCK_DESTROY."""
    752     for version in [4, 5, 6]:
    753       self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
    754       self.CloseDuringBlockingCall(
    755           self.accepted,
    756           lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
    757           None)
    758       self.assertSocketErrors(ECONNABORTED)
    759 
    760   def CheckPollRst(self, mask, expected, ignoremask):
    761     """Interrupts a poll() by receiving a TCP RST."""
    762     for version in [4, 5, 6]:
    763       self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
    764       self.RstDuringBlockingCall(
    765           self.accepted,
    766           lambda sock: self.BlockingPoll(sock, mask, expected, ignoremask),
    767           None)
    768       self.assertSocketErrors(ECONNRESET)
    769 
    770   def testReadPollRst(self):
    771     # Until 3d4762639d ("tcp: remove poll() flakes when receiving RST"), poll()
    772     # would sometimes return POLLERR and sometimes POLLIN|POLLERR|POLLHUP. This
    773     # is due to a race inside the kernel and thus is not visible on the VM, only
    774     # on physical hardware.
    775     if net_test.LINUX_VERSION < (4, 14, 0):
    776       ignoremask = select.POLLIN | select.POLLHUP
    777     else:
    778       ignoremask = 0
    779     self.CheckPollRst(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)
    780 
    781   def testWritePollRst(self):
    782     self.CheckPollRst(select.POLLOUT, select.POLLOUT, 0)
    783 
    784   def testReadWritePollRst(self):
    785     self.CheckPollRst(self.POLLIN_OUT, select.POLLOUT, 0)
    786 
    787   def testReadPollDestroy(self):
    788     # tcp_abort has the same race that tcp_reset has, but it's not fixed yet.
    789     ignoremask = select.POLLIN | select.POLLHUP
    790     self.CheckPollDestroy(select.POLLIN, self.POLLIN_ERR_HUP, ignoremask)
    791 
    792   def testWritePollDestroy(self):
    793     self.CheckPollDestroy(select.POLLOUT, select.POLLOUT, 0)
    794 
    795   def testReadWritePollDestroy(self):
    796     self.CheckPollDestroy(self.POLLIN_OUT, select.POLLOUT, 0)
    797 
    798 
    799 @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
    800 class SockDestroyUdpTest(SockDiagBaseTest):
    801 
    802   """Tests SOCK_DESTROY on UDP sockets.
    803 
    804     Relevant kernel commits:
    805       upstream net-next:
    806         5d77dca net: diag: support SOCK_DESTROY for UDP sockets
    807         f95bf34 net: diag: make udp_diag_destroy work for mapped addresses.
    808   """
    809 
    810   def testClosesUdpSockets(self):
    811     self.socketpairs = self._CreateLotsOfSockets(SOCK_DGRAM)
    812     for _, socketpair in self.socketpairs.iteritems():
    813       s1, s2 = socketpair
    814 
    815       self.assertSocketConnected(s1)
    816       self.sock_diag.CloseSocketFromFd(s1)
    817       self.assertSocketClosed(s1)
    818 
    819       self.assertSocketConnected(s2)
    820       self.sock_diag.CloseSocketFromFd(s2)
    821       self.assertSocketClosed(s2)
    822 
    823   def BindToRandomPort(self, s, addr):
    824     ATTEMPTS = 20
    825     for i in xrange(20):
    826       port = random.randrange(1024, 65535)
    827       try:
    828         s.bind((addr, port))
    829         return port
    830       except error, e:
    831         if e.errno != EADDRINUSE:
    832           raise e
    833     raise ValueError("Could not find a free port on %s after %d attempts" %
    834                      (addr, ATTEMPTS))
    835 
    836   def testSocketAddressesAfterClose(self):
    837     for version in 4, 5, 6:
    838       netid = random.choice(self.NETIDS)
    839       dst = self.GetRemoteSocketAddress(version)
    840       family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
    841       unspec = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
    842 
    843       # Closing a socket that was not explicitly bound (i.e., bound via
    844       # connect(), not bind()) clears the source address and port.
    845       s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
    846       self.SelectInterface(s, netid, "mark")
    847       s.connect((dst, 53))
    848       self.sock_diag.CloseSocketFromFd(s)
    849       self.assertEqual((unspec, 0), s.getsockname()[:2])
    850 
    851       # Closing a socket bound to an IP address leaves the address as is.
    852       s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
    853       src = self.MySocketAddress(version, netid)
    854       s.bind((src, 0))
    855       s.connect((dst, 53))
    856       port = s.getsockname()[1]
    857       self.sock_diag.CloseSocketFromFd(s)
    858       self.assertEqual((src, 0), s.getsockname()[:2])
    859 
    860       # Closing a socket bound to a port leaves the port as is.
    861       s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
    862       port = self.BindToRandomPort(s, "")
    863       s.connect((dst, 53))
    864       self.sock_diag.CloseSocketFromFd(s)
    865       self.assertEqual((unspec, port), s.getsockname()[:2])
    866 
    867       # Closing a socket bound to IP address and port leaves both as is.
    868       s = self.BuildSocket(version, net_test.UDPSocket, netid, "mark")
    869       src = self.MySocketAddress(version, netid)
    870       port = self.BindToRandomPort(s, src)
    871       self.sock_diag.CloseSocketFromFd(s)
    872       self.assertEqual((src, port), s.getsockname()[:2])
    873 
    874   def testReadInterrupted(self):
    875     """Tests that read() is interrupted by SOCK_DESTROY."""
    876     for version in [4, 5, 6]:
    877       family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
    878       s = net_test.UDPSocket(family)
    879       self.SelectInterface(s, random.choice(self.NETIDS), "mark")
    880       addr = self.GetRemoteAddress(version)
    881 
    882       # Check that reads on connected sockets are interrupted.
    883       s.connect((addr, 53))
    884       self.assertEquals(3, s.send("foo"))
    885       self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
    886                                    ECONNABORTED)
    887 
    888       # A destroyed socket is no longer connected, but still usable.
    889       self.assertRaisesErrno(EDESTADDRREQ, s.send, "foo")
    890       self.assertEquals(3, s.sendto("foo", (addr, 53)))
    891 
    892       # Check that reads on unconnected sockets are also interrupted.
    893       self.CloseDuringBlockingCall(s, lambda sock: sock.recv(4096),
    894                                    ECONNABORTED)
    895 
    896 class SockDestroyPermissionTest(SockDiagBaseTest):
    897 
    898   def CheckPermissions(self, socktype):
    899     s = socket(AF_INET6, socktype, 0)
    900     self.SelectInterface(s, random.choice(self.NETIDS), "mark")
    901     if socktype == SOCK_STREAM:
    902       s.listen(1)
    903       expectedstate = tcp_test.TCP_LISTEN
    904     else:
    905       s.connect((self.GetRemoteAddress(6), 53))
    906       expectedstate = tcp_test.TCP_ESTABLISHED
    907 
    908     with net_test.RunAsUid(12345):
    909       self.assertRaisesErrno(
    910           EPERM, self.sock_diag.CloseSocketFromFd, s)
    911 
    912     self.sock_diag.CloseSocketFromFd(s)
    913     self.assertRaises(ValueError, self.sock_diag.CloseSocketFromFd, s)
    914 
    915 
    916   @unittest.skipUnless(HAVE_UDP_DIAG, "INET_UDP_DIAG not enabled")
    917   def testUdp(self):
    918     self.CheckPermissions(SOCK_DGRAM)
    919 
    920   def testTcp(self):
    921     self.CheckPermissions(SOCK_STREAM)
    922 
    923 
    924 class SockDiagMarkTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
    925 
    926   """Tests SOCK_DIAG bytecode filters that use marks.
    927 
    928     Relevant kernel commits:
    929       upstream net-next:
    930         627cc4a net: diag: slightly refactor the inet_diag_bc_audit error checks.
    931         a52e95a net: diag: allow socket bytecode filters to match socket marks
    932         d545cac net: inet: diag: expose the socket mark to privileged processes.
    933   """
    934 
    935   def FilterEstablishedSockets(self, mark, mask):
    936     instructions = [(sock_diag.INET_DIAG_BC_MARK_COND, 1, 2, (mark, mask))]
    937     bytecode = self.sock_diag.PackBytecode(instructions)
    938     return self.sock_diag.DumpAllInetSockets(
    939         IPPROTO_TCP, bytecode, states=(1 << tcp_test.TCP_ESTABLISHED))
    940 
    941   def assertSamePorts(self, ports, diag_msgs):
    942     expected = sorted(ports)
    943     actual = sorted([msg[0].id.sport for msg in diag_msgs])
    944     self.assertEquals(expected, actual)
    945 
    946   def SockInfoMatchesSocket(self, s, info):
    947     try:
    948       self.assertSockInfoMatchesSocket(s, info)
    949       return True
    950     except AssertionError:
    951       return False
    952 
    953   @staticmethod
    954   def SocketDescription(s):
    955     return "%s -> %s" % (str(s.getsockname()), str(s.getpeername()))
    956 
    957   def assertFoundSockets(self, infos, sockets):
    958     matches = {}
    959     for s in sockets:
    960       match = None
    961       for info in infos:
    962         if self.SockInfoMatchesSocket(s, info):
    963           if match:
    964             self.fail("Socket %s matched both %s and %s" %
    965                       (self.SocketDescription(s), match, info))
    966           matches[s] = info
    967       self.assertTrue(s in matches, "Did not find socket %s in dump" %
    968                       self.SocketDescription(s))
    969 
    970     for i in infos:
    971        if i not in matches.values():
    972          self.fail("Too many sockets in dump, first unexpected: %s" % str(i))
    973 
    974   def testMarkBytecode(self):
    975     family, addr = random.choice([
    976         (AF_INET, "127.0.0.1"),
    977         (AF_INET6, "::1"),
    978         (AF_INET6, "::ffff:127.0.0.1")])
    979     s1, s2 = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
    980     s1.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xfff1234)
    981     s2.setsockopt(SOL_SOCKET, net_test.SO_MARK, 0xf0f1235)
    982 
    983     infos = self.FilterEstablishedSockets(0x1234, 0xffff)
    984     self.assertFoundSockets(infos, [s1])
    985 
    986     infos = self.FilterEstablishedSockets(0x1234, 0xfffe)
    987     self.assertFoundSockets(infos, [s1, s2])
    988 
    989     infos = self.FilterEstablishedSockets(0x1235, 0xffff)
    990     self.assertFoundSockets(infos, [s2])
    991 
    992     infos = self.FilterEstablishedSockets(0x0, 0x0)
    993     self.assertFoundSockets(infos, [s1, s2])
    994 
    995     infos = self.FilterEstablishedSockets(0xfff0000, 0xf0fed00)
    996     self.assertEquals(0, len(infos))
    997 
    998     with net_test.RunAsUid(12345):
    999         self.assertRaisesErrno(EPERM, self.FilterEstablishedSockets,
   1000                                0xfff0000, 0xf0fed00)
   1001 
   1002   @staticmethod
   1003   def SetRandomMark(s):
   1004     # Python doesn't like marks that don't fit into a signed int.
   1005     mark = random.randrange(0, 2**31 - 1)
   1006     s.setsockopt(SOL_SOCKET, net_test.SO_MARK, mark)
   1007     return mark
   1008 
   1009   def assertSocketMarkIs(self, s, mark):
   1010     diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s)
   1011     self.assertMarkIs(mark, attrs)
   1012     with net_test.RunAsUid(12345):
   1013       diag_msg, attrs = self.sock_diag.FindSockInfoFromFd(s)
   1014       self.assertMarkIs(None, attrs)
   1015 
   1016   def testMarkInAttributes(self):
   1017     testcases = [(AF_INET, "127.0.0.1"),
   1018                  (AF_INET6, "::1"),
   1019                  (AF_INET6, "::ffff:127.0.0.1")]
   1020     for family, addr in testcases:
   1021       # TCP listen sockets.
   1022       server = socket(family, SOCK_STREAM, 0)
   1023       server.bind((addr, 0))
   1024       port = server.getsockname()[1]
   1025       server.listen(1)  # Or the socket won't be in the hashtables.
   1026       server_mark = self.SetRandomMark(server)
   1027       self.assertSocketMarkIs(server, server_mark)
   1028 
   1029       # TCP client sockets.
   1030       client = socket(family, SOCK_STREAM, 0)
   1031       client_mark = self.SetRandomMark(client)
   1032       client.connect((addr, port))
   1033       self.assertSocketMarkIs(client, client_mark)
   1034 
   1035       # TCP server sockets.
   1036       accepted, _ = server.accept()
   1037       self.assertSocketMarkIs(accepted, server_mark)
   1038 
   1039       accepted_mark = self.SetRandomMark(accepted)
   1040       self.assertSocketMarkIs(accepted, accepted_mark)
   1041       self.assertSocketMarkIs(server, server_mark)
   1042 
   1043       server.close()
   1044       client.close()
   1045 
   1046       # Other TCP states are tested in SockDestroyTcpTest.
   1047 
   1048       # UDP sockets.
   1049       if HAVE_UDP_DIAG:
   1050         s = socket(family, SOCK_DGRAM, 0)
   1051         mark = self.SetRandomMark(s)
   1052         s.connect(("", 53))
   1053         self.assertSocketMarkIs(s, mark)
   1054         s.close()
   1055 
   1056       # Basic test for SCTP. sctp_diag was only added in 4.7.
   1057       if HAVE_SCTP:
   1058         s = socket(family, SOCK_STREAM, IPPROTO_SCTP)
   1059         s.bind((addr, 0))
   1060         s.listen(1)
   1061         mark = self.SetRandomMark(s)
   1062         self.assertSocketMarkIs(s, mark)
   1063         sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_SCTP, NO_BYTECODE)
   1064         self.assertEqual(1, len(sockets))
   1065         self.assertEqual(mark, sockets[0][1].get("INET_DIAG_MARK", None))
   1066         s.close()
   1067 
   1068 
# Run every test in this module via unittest when executed as a script.
if __name__ == "__main__":
  unittest.main()
   1071