Home | History | Annotate | Download | only in test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2014 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 import fcntl
     18 import os
     19 import random
     20 import re
     21 from socket import *  # pylint: disable=wildcard-import
     22 import struct
     23 import unittest
     24 
     25 from scapy import all as scapy
     26 
     27 import csocket
     28 
     29 # TODO: Move these to csocket.py.
     30 SOL_IPV6 = 41
     31 IP_RECVERR = 11
     32 IPV6_RECVERR = 25
     33 IP_TRANSPARENT = 19
     34 IPV6_TRANSPARENT = 75
     35 IPV6_TCLASS = 67
     36 IPV6_FLOWLABEL_MGR = 32
     37 IPV6_FLOWINFO_SEND = 33
     38 
     39 SO_BINDTODEVICE = 25
     40 SO_MARK = 36
     41 SO_PROTOCOL = 38
     42 SO_DOMAIN = 39
     43 SO_COOKIE = 57
     44 
     45 ETH_P_IP = 0x0800
     46 ETH_P_IPV6 = 0x86dd
     47 
     48 IPPROTO_GRE = 47
     49 
     50 SIOCSIFHWADDR = 0x8924
     51 
     52 IPV6_FL_A_GET = 0
     53 IPV6_FL_A_PUT = 1
     54 IPV6_FL_A_RENEW = 1
     55 
     56 IPV6_FL_F_CREATE = 1
     57 IPV6_FL_F_EXCL = 2
     58 
     59 IPV6_FL_S_NONE = 0
     60 IPV6_FL_S_EXCL = 1
     61 IPV6_FL_S_ANY = 255
     62 
# Width of the interface-name field in the buffers this module packs for
# interface ioctls (used as "%ds" in the struct formats).
IFNAMSIZ = 16

# Fixed 8-byte ping packets for IPv4 and IPv6.
# NOTE(review): the leading 0x08 / 0x80 byte is presumably the ICMP / ICMPv6
# echo-request type - confirm against the callers.
IPV4_PING = "\x08\x00\x00\x00\x0a\xce\x00\x03"
IPV6_PING = "\x80\x00\x00\x00\x0a\xce\x00\x03"

# Default destination addresses for tests.
IPV4_ADDR = "8.8.8.8"
IPV6_ADDR = "2001:4860:4860::8888"

# Expected first line of /proc/net/{icmp6,raw6,udp6}; ReadProcNetSocket
# asserts that the file header matches this string exactly.
IPV6_SEQ_DGRAM_HEADER = ("  sl  "
                         "local_address                         "
                         "remote_address                        "
                         "st tx_queue rx_queue tr tm->when retrnsmt"
                         "   uid  timeout inode ref pointer drops\n")

# Arbitrary packet payload.
# Built once at import time; the DNS query id is random per process.
UDP_PAYLOAD = str(scapy.DNS(rd=1,
                            id=random.randint(0, 65535),
                            qd=scapy.DNSQR(qname="wWW.GoOGle.CoM",
                                           qtype="AAAA")))

# Unix group to use if we want to open sockets as non-root.
AID_INET = 3003

# Kernel log verbosity levels.
KERN_INFO = 6

# Result of csocket.LinuxVersion(), evaluated once at import time.
LINUX_VERSION = csocket.LinuxVersion()
     90 
     91 
     92 def SetSocketTimeout(sock, ms):
     93   s = ms / 1000
     94   us = (ms % 1000) * 1000
     95   sock.setsockopt(SOL_SOCKET, SO_RCVTIMEO, struct.pack("LL", s, us))
     96 
     97 
     98 def SetSocketTos(s, tos):
     99   level = {AF_INET: SOL_IP, AF_INET6: SOL_IPV6}[s.family]
    100   option = {AF_INET: IP_TOS, AF_INET6: IPV6_TCLASS}[s.family]
    101   s.setsockopt(level, option, tos)
    102 
    103 
    104 def SetNonBlocking(fd):
    105   flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
    106   fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
    107 
    108 
    109 # Convenience functions to create sockets.
    110 def Socket(family, sock_type, protocol):
    111   s = socket(family, sock_type, protocol)
    112   SetSocketTimeout(s, 5000)
    113   return s
    114 
    115 
    116 def PingSocket(family):
    117   proto = {AF_INET: IPPROTO_ICMP, AF_INET6: IPPROTO_ICMPV6}[family]
    118   return Socket(family, SOCK_DGRAM, proto)
    119 
    120 
def IPv4PingSocket():
  """Returns PingSocket(AF_INET), i.e. an IPv4 ping socket."""
  return PingSocket(AF_INET)
    123 
    124 
def IPv6PingSocket():
  """Returns PingSocket(AF_INET6), i.e. an IPv6 ping socket."""
  return PingSocket(AF_INET6)
    127 
    128 
    129 def TCPSocket(family):
    130   s = Socket(family, SOCK_STREAM, IPPROTO_TCP)
    131   SetNonBlocking(s.fileno())
    132   return s
    133 
    134 
def IPv4TCPSocket():
  """Returns TCPSocket(AF_INET), i.e. a non-blocking IPv4 TCP socket."""
  return TCPSocket(AF_INET)
    137 
    138 
def IPv6TCPSocket():
  """Returns TCPSocket(AF_INET6), i.e. a non-blocking IPv6 TCP socket."""
  return TCPSocket(AF_INET6)
    141 
    142 
def UDPSocket(family):
  """Returns a SOCK_DGRAM / IPPROTO_UDP socket of the given family."""
  return Socket(family, SOCK_DGRAM, IPPROTO_UDP)
    145 
    146 
    147 def RawGRESocket(family):
    148   s = Socket(family, SOCK_RAW, IPPROTO_GRE)
    149   return s
    150 
    151 
    152 def BindRandomPort(version, sock):
    153   addr = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
    154   sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
    155   sock.bind((addr, 0))
    156   if sock.getsockopt(SOL_SOCKET, SO_PROTOCOL) == IPPROTO_TCP:
    157     sock.listen(100)
    158   port = sock.getsockname()[1]
    159   return port
    160 
    161 
    162 def EnableFinWait(sock):
    163   # Disabling SO_LINGER causes sockets to go into FIN_WAIT on close().
    164   sock.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack("ii", 0, 0))
    165 
    166 
    167 def DisableFinWait(sock):
    168   # Enabling SO_LINGER with a timeout of zero causes close() to send RST.
    169   sock.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack("ii", 1, 0))
    170 
    171 
    172 def CreateSocketPair(family, socktype, addr):
    173   clientsock = socket(family, socktype, 0)
    174   listensock = socket(family, socktype, 0)
    175   listensock.bind((addr, 0))
    176   addr = listensock.getsockname()
    177   if socktype == SOCK_STREAM:
    178     listensock.listen(1)
    179   clientsock.connect(listensock.getsockname())
    180   if socktype == SOCK_STREAM:
    181     acceptedsock, _ = listensock.accept()
    182     DisableFinWait(clientsock)
    183     DisableFinWait(acceptedsock)
    184     listensock.close()
    185   else:
    186     listensock.connect(clientsock.getsockname())
    187     acceptedsock = listensock
    188   return clientsock, acceptedsock
    189 
    190 
    191 def GetInterfaceIndex(ifname):
    192   s = IPv4PingSocket()
    193   ifr = struct.pack("%dsi" % IFNAMSIZ, ifname, 0)
    194   ifr = fcntl.ioctl(s, scapy.SIOCGIFINDEX, ifr)
    195   return struct.unpack("%dsi" % IFNAMSIZ, ifr)[1]
    196 
    197 
    198 def SetInterfaceHWAddr(ifname, hwaddr):
    199   s = IPv4PingSocket()
    200   hwaddr = hwaddr.replace(":", "")
    201   hwaddr = hwaddr.decode("hex")
    202   if len(hwaddr) != 6:
    203     raise ValueError("Unknown hardware address length %d" % len(hwaddr))
    204   ifr = struct.pack("%dsH6s" % IFNAMSIZ, ifname, scapy.ARPHDR_ETHER, hwaddr)
    205   fcntl.ioctl(s, SIOCSIFHWADDR, ifr)
    206 
    207 
    208 def SetInterfaceState(ifname, up):
    209   s = IPv4PingSocket()
    210   ifr = struct.pack("%dsH" % IFNAMSIZ, ifname, 0)
    211   ifr = fcntl.ioctl(s, scapy.SIOCGIFFLAGS, ifr)
    212   _, flags = struct.unpack("%dsH" % IFNAMSIZ, ifr)
    213   if up:
    214     flags |= scapy.IFF_UP
    215   else:
    216     flags &= ~scapy.IFF_UP
    217   ifr = struct.pack("%dsH" % IFNAMSIZ, ifname, flags)
    218   ifr = fcntl.ioctl(s, scapy.SIOCSIFFLAGS, ifr)
    219 
    220 
def SetInterfaceUp(ifname):
  """Sets IFF_UP on ifname; shorthand for SetInterfaceState(ifname, True)."""
  return SetInterfaceState(ifname, True)
    223 
    224 
def SetInterfaceDown(ifname):
  """Clears IFF_UP on ifname; shorthand for SetInterfaceState(ifname, False)."""
  return SetInterfaceState(ifname, False)
    227 
    228 
    229 def CanonicalizeIPv6Address(addr):
    230   return inet_ntop(AF_INET6, inet_pton(AF_INET6, addr))
    231 
    232 
    233 def FormatProcAddress(unformatted):
    234   groups = []
    235   for i in xrange(0, len(unformatted), 4):
    236     groups.append(unformatted[i:i+4])
    237   formatted = ":".join(groups)
    238   # Compress the address.
    239   address = CanonicalizeIPv6Address(formatted)
    240   return address
    241 
    242 
    243 def FormatSockStatAddress(address):
    244   if ":" in address:
    245     family = AF_INET6
    246   else:
    247     family = AF_INET
    248   binary = inet_pton(family, address)
    249   out = ""
    250   for i in xrange(0, len(binary), 4):
    251     out += "%08X" % struct.unpack("=L", binary[i:i+4])
    252   return out
    253 
    254 
    255 def GetLinkAddress(ifname, linklocal):
    256   addresses = open("/proc/net/if_inet6").readlines()
    257   for address in addresses:
    258     address = [s for s in address.strip().split(" ") if s]
    259     if address[5] == ifname:
    260       if (linklocal and address[0].startswith("fe80")
    261           or not linklocal and not address[0].startswith("fe80")):
    262         # Convert the address from raw hex to something with colons in it.
    263         return FormatProcAddress(address[0])
    264   return None
    265 
    266 
    267 def GetDefaultRoute(version=6):
    268   if version == 6:
    269     routes = open("/proc/net/ipv6_route").readlines()
    270     for route in routes:
    271       route = [s for s in route.strip().split(" ") if s]
    272       if (route[0] == "00000000000000000000000000000000" and route[1] == "00"
    273           # Routes in non-default tables end up in /proc/net/ipv6_route!!!
    274           and route[9] != "lo" and not route[9].startswith("nettest")):
    275         return FormatProcAddress(route[4]), route[9]
    276     raise ValueError("No IPv6 default route found")
    277   elif version == 4:
    278     routes = open("/proc/net/route").readlines()
    279     for route in routes:
    280       route = [s for s in route.strip().split("\t") if s]
    281       if route[1] == "00000000" and route[7] == "00000000":
    282         gw, iface = route[2], route[0]
    283         gw = inet_ntop(AF_INET, gw.decode("hex")[::-1])
    284         return gw, iface
    285     raise ValueError("No IPv4 default route found")
    286   else:
    287     raise ValueError("Don't know about IPv%s" % version)
    288 
    289 
    290 def GetDefaultRouteInterface():
    291   unused_gw, iface = GetDefaultRoute()
    292   return iface
    293 
    294 
    295 def MakeFlowLabelOption(addr, label):
    296   # struct in6_flowlabel_req {
    297   #         struct in6_addr flr_dst;
    298   #         __be32  flr_label;
    299   #         __u8    flr_action;
    300   #         __u8    flr_share;
    301   #         __u16   flr_flags;
    302   #         __u16   flr_expires;
    303   #         __u16   flr_linger;
    304   #         __u32   __flr_pad;
    305   #         /* Options in format of IPV6_PKTOPTIONS */
    306   # };
    307   fmt = "16sIBBHHH4s"
    308   assert struct.calcsize(fmt) == 32
    309   addr = inet_pton(AF_INET6, addr)
    310   assert len(addr) == 16
    311   label = htonl(label & 0xfffff)
    312   action = IPV6_FL_A_GET
    313   share = IPV6_FL_S_ANY
    314   flags = IPV6_FL_F_CREATE
    315   pad = "\x00" * 4
    316   return struct.pack(fmt, addr, label, action, share, flags, 0, 0, pad)
    317 
    318 
    319 def SetFlowLabel(s, addr, label):
    320   opt = MakeFlowLabelOption(addr, label)
    321   s.setsockopt(SOL_IPV6, IPV6_FLOWLABEL_MGR, opt)
    322   # Caller also needs to do s.setsockopt(SOL_IPV6, IPV6_FLOWINFO_SEND, 1).
    323 
    324 
    325 def RunIptablesCommand(version, args):
    326   iptables = {4: "iptables", 6: "ip6tables"}[version]
    327   iptables_path = "/sbin/" + iptables
    328   if not os.access(iptables_path, os.X_OK):
    329     iptables_path = "/system/bin" + iptables
    330   return os.spawnvp(os.P_WAIT, iptables_path, [iptables_path] + args.split(" "))
    331 
    332 
# Determine network configuration.
# HAVE_IPV4 / HAVE_IPV6 record whether GetDefaultRoute() found a default
# route for that IP version when this module was imported; GetDefaultRoute
# raises ValueError when it finds none.
try:
  GetDefaultRoute(version=4)
  HAVE_IPV4 = True
except ValueError:
  HAVE_IPV4 = False

try:
  GetDefaultRoute(version=6)
  HAVE_IPV6 = True
except ValueError:
  HAVE_IPV6 = False
    345 
    346 class RunAsUidGid(object):
    347   """Context guard to run a code block as a given UID."""
    348 
    349   def __init__(self, uid, gid):
    350     self.uid = uid
    351     self.gid = gid
    352 
    353   def __enter__(self):
    354     if self.uid:
    355       self.saved_uid = os.geteuid()
    356       self.saved_groups = os.getgroups()
    357       os.setgroups(self.saved_groups + [AID_INET])
    358       os.seteuid(self.uid)
    359     if self.gid:
    360       self.saved_gid = os.getgid()
    361       os.setgid(self.gid)
    362 
    363   def __exit__(self, unused_type, unused_value, unused_traceback):
    364     if self.uid:
    365       os.seteuid(self.saved_uid)
    366       os.setgroups(self.saved_groups)
    367     if self.gid:
    368       os.setgid(self.saved_gid)
    369 
class RunAsUid(RunAsUidGid):
  """Context guard to run a code block as a given UID, leaving the GID alone."""

  def __init__(self, uid):
    # A gid of 0 makes RunAsUidGid skip the GID change.
    RunAsUidGid.__init__(self, uid, 0)
    375 
    376 
    377 class NetworkTest(unittest.TestCase):
    378 
    379   def assertRaisesErrno(self, err_num, f=None, *args):
    380     """Test that the system returns an errno error.
    381 
    382     This works similarly to unittest.TestCase.assertRaises. You can call it as
    383     an assertion, or use it as a context manager.
    384     e.g.
    385         self.assertRaisesErrno(errno.ENOENT, do_things, arg1, arg2)
    386     or
    387         with self.assertRaisesErrno(errno.ENOENT):
    388           do_things(arg1, arg2)
    389 
    390     Args:
    391       err_num: an errno constant
    392       f: (optional) A callable that should result in error
    393       *args: arguments passed to f
    394     """
    395     msg = os.strerror(err_num)
    396     if f is None:
    397       return self.assertRaisesRegexp(EnvironmentError, msg)
    398     else:
    399       self.assertRaisesRegexp(EnvironmentError, msg, f, *args)
    400 
    401   def ReadProcNetSocket(self, protocol):
    402     # Read file.
    403     filename = "/proc/net/%s" % protocol
    404     lines = open(filename).readlines()
    405 
    406     # Possibly check, and strip, header.
    407     if protocol in ["icmp6", "raw6", "udp6"]:
    408       self.assertEqual(IPV6_SEQ_DGRAM_HEADER, lines[0])
    409     lines = lines[1:]
    410 
    411     # Check contents.
    412     if protocol.endswith("6"):
    413       addrlen = 32
    414     else:
    415       addrlen = 8
    416 
    417     if protocol.startswith("tcp"):
    418       # Real sockets have 5 extra numbers, timewait sockets have none.
    419       end_regexp = "(| +[0-9]+ [0-9]+ [0-9]+ [0-9]+ -?[0-9]+|)$"
    420     elif re.match("icmp|udp|raw", protocol):
    421       # Drops.
    422       end_regexp = " +([0-9]+) *$"
    423     else:
    424       raise ValueError("Don't know how to parse %s" % filename)
    425 
    426     regexp = re.compile(r" *(\d+): "                    # bucket
    427                         "([0-9A-F]{%d}:[0-9A-F]{4}) "   # srcaddr, port
    428                         "([0-9A-F]{%d}:[0-9A-F]{4}) "   # dstaddr, port
    429                         "([0-9A-F][0-9A-F]) "           # state
    430                         "([0-9A-F]{8}:[0-9A-F]{8}) "    # mem
    431                         "([0-9A-F]{2}:[0-9A-F]{8}) "    # ?
    432                         "([0-9A-F]{8}) +"               # ?
    433                         "([0-9]+) +"                    # uid
    434                         "([0-9]+) +"                    # timeout
    435                         "([0-9]+) +"                    # inode
    436                         "([0-9]+) +"                    # refcnt
    437                         "([0-9a-f]+)"                   # sp
    438                         "%s"                            # icmp has spaces
    439                         % (addrlen, addrlen, end_regexp))
    440     # Return a list of lists with only source / dest addresses for now.
    441     # TODO: consider returning a dict or namedtuple instead.
    442     out = []
    443     for line in lines:
    444       (_, src, dst, state, mem,
    445        _, _, uid, _, _, refcnt, _, extra) = regexp.match(line).groups()
    446       out.append([src, dst, state, mem, uid, refcnt, extra])
    447     return out
    448 
    449   @staticmethod
    450   def GetConsoleLogLevel():
    451     return int(open("/proc/sys/kernel/printk").readline().split()[0])
    452 
    453   @staticmethod
    454   def SetConsoleLogLevel(level):
    455     return open("/proc/sys/kernel/printk", "w").write("%s\n" % level)
    456 
    457 
# Hand control to the unittest runner when this file is executed directly.
if __name__ == "__main__":
  unittest.main()
    460