Home | History | Annotate | Download | only in test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2015 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 """Partial Python implementation of sock_diag functionality."""
     18 
     19 # pylint: disable=g-bad-todo
     20 
     21 import errno
     22 import os
     23 from socket import *  # pylint: disable=wildcard-import
     24 import struct
     25 
     26 import csocket
     27 import cstruct
     28 import net_test
     29 import netlink
     30 
     31 ### Base netlink constants. See include/uapi/linux/netlink.h.
     32 NETLINK_SOCK_DIAG = 4
     33 
     34 ### sock_diag constants. See include/uapi/linux/sock_diag.h.
     35 # Message types.
     36 SOCK_DIAG_BY_FAMILY = 20
     37 SOCK_DESTROY = 21
     38 
     39 ### inet_diag_constants. See include/uapi/linux/inet_diag.h
     40 # Message types.
     41 TCPDIAG_GETSOCK = 18
     42 
     43 # Request attributes.
     44 INET_DIAG_REQ_BYTECODE = 1
     45 
     46 # Extensions.
     47 INET_DIAG_NONE = 0
     48 INET_DIAG_MEMINFO = 1
     49 INET_DIAG_INFO = 2
     50 INET_DIAG_VEGASINFO = 3
     51 INET_DIAG_CONG = 4
     52 INET_DIAG_TOS = 5
     53 INET_DIAG_TCLASS = 6
     54 INET_DIAG_SKMEMINFO = 7
     55 INET_DIAG_SHUTDOWN = 8
     56 INET_DIAG_DCTCPINFO = 9
     57 INET_DIAG_DCTCPINFO = 9
     58 INET_DIAG_PROTOCOL = 10
     59 INET_DIAG_SKV6ONLY = 11
     60 INET_DIAG_LOCALS = 12
     61 INET_DIAG_PEERS = 13
     62 INET_DIAG_PAD = 14
     63 INET_DIAG_MARK = 15
     64 
     65 # Bytecode operations.
     66 INET_DIAG_BC_NOP = 0
     67 INET_DIAG_BC_JMP = 1
     68 INET_DIAG_BC_S_GE = 2
     69 INET_DIAG_BC_S_LE = 3
     70 INET_DIAG_BC_D_GE = 4
     71 INET_DIAG_BC_D_LE = 5
     72 INET_DIAG_BC_AUTO = 6
     73 INET_DIAG_BC_S_COND = 7
     74 INET_DIAG_BC_D_COND = 8
     75 INET_DIAG_BC_DEV_COND = 9
     76 INET_DIAG_BC_MARK_COND = 10
     77 
     78 # Data structure formats.
     79 # These aren't constants, they're classes. So, pylint: disable=invalid-name
     80 InetDiagSockId = cstruct.Struct(
     81     "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie")
     82 InetDiagReqV2 = cstruct.Struct(
     83     "InetDiagReqV2", "=BBBxIS", "family protocol ext states id",
     84     [InetDiagSockId])
     85 InetDiagMsg = cstruct.Struct(
     86     "InetDiagMsg", "=BBBBSLLLLL",
     87     "family state timer retrans id expires rqueue wqueue uid inode",
     88     [InetDiagSockId])
     89 InetDiagMeminfo = cstruct.Struct(
     90     "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
     91 InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
     92 InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
     93                                   "family prefix_len port")
     94 InetDiagMarkcond = cstruct.Struct("InetDiagMarkcond", "=II", "mark mask")
     95 
     96 SkMeminfo = cstruct.Struct(
     97     "SkMeminfo", "=IIIIIIII",
     98     "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog")
     99 TcpInfo = cstruct.Struct(
    100     "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII",
    101     "state ca_state retransmits probes backoff options wscale "
    102     "rto ato snd_mss rcv_mss "
    103     "unacked sacked lost retrans fackets "
    104     "last_data_sent last_ack_sent last_data_recv last_ack_recv "
    105     "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering "
    106     "rcv_rtt rcv_space "
    107     "total_retrans")  # As of linux 3.13, at least.
    108 
    109 TCP_TIME_WAIT = 6
    110 ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << TCP_TIME_WAIT)
    111 
    112 
    113 class SockDiag(netlink.NetlinkSocket):
    114 
    115   FAMILY = NETLINK_SOCK_DIAG
    116   NL_DEBUG = []
    117 
    118   def _Decode(self, command, msg, nla_type, nla_data):
    119     """Decodes netlink attributes to Python types."""
    120     if msg.family == AF_INET or msg.family == AF_INET6:
    121       if isinstance(msg, InetDiagReqV2):
    122         prefix = "INET_DIAG_REQ"
    123       else:
    124         prefix = "INET_DIAG"
    125       name = self._GetConstantName(__name__, nla_type, prefix)
    126     else:
    127       # Don't know what this is. Leave it as an integer.
    128       name = nla_type
    129 
    130     if name in ["INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS",
    131                 "INET_DIAG_SKV6ONLY"]:
    132       data = ord(nla_data)
    133     elif name == "INET_DIAG_CONG":
    134       data = nla_data.strip("\x00")
    135     elif name == "INET_DIAG_MEMINFO":
    136       data = InetDiagMeminfo(nla_data)
    137     elif name == "INET_DIAG_INFO":
    138       # TODO: Catch the exception and try something else if it's not TCP.
    139       data = TcpInfo(nla_data)
    140     elif name == "INET_DIAG_SKMEMINFO":
    141       data = SkMeminfo(nla_data)
    142     elif name == "INET_DIAG_MARK":
    143       data = struct.unpack("=I", nla_data)[0]
    144     elif name == "INET_DIAG_REQ_BYTECODE":
    145       data = self.DecodeBytecode(nla_data)
    146     elif name in ["INET_DIAG_LOCALS", "INET_DIAG_PEERS"]:
    147       data = []
    148       while len(nla_data):
    149         # The SCTP diag code always appears to copy sizeof(sockaddr_storage)
    150         # bytes, but does so from a union sctp_addr which is at most as long
    151         # as a sockaddr_in6.
    152         addr, nla_data = cstruct.Read(nla_data, csocket.SockaddrStorage)
    153         if addr.family == AF_INET:
    154           addr = csocket.SockaddrIn(addr.Pack())
    155         elif addr.family == AF_INET6:
    156           addr = csocket.SockaddrIn6(addr.Pack())
    157         data.append(addr)
    158     else:
    159       data = nla_data
    160 
    161     return name, data
    162 
    163   def MaybeDebugCommand(self, command, unused_flags, data):
    164     name = self._GetConstantName(__name__, command, "SOCK_")
    165     if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG:
    166       return
    167     parsed = self._ParseNLMsg(data, InetDiagReqV2)
    168     print "%s %s" % (name, str(parsed))
    169 
    170   @staticmethod
    171   def _EmptyInetDiagSockId():
    172     return InetDiagSockId(("\x00" * len(InetDiagSockId)))
    173 
    174   @staticmethod
    175   def PackBytecode(instructions):
    176     """Compiles instructions to inet_diag bytecode.
    177 
    178     The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
    179     and no are relative jump offsets measured in instructions. The yes branch
    180     is taken if the instruction matches.
    181 
    182     To accept, jump 1 past the last instruction. To reject, jump 2 past the
    183     last instruction.
    184 
    185     The target of a no jump is only valid if it is reachable by following
    186     only yes jumps from the first instruction - see inet_diag_bc_audit and
    187     valid_cc. This means that if cond1 and cond2 are two mutually exclusive
    188     filter terms, it is not possible to implement cond1 OR cond2 using:
    189 
    190       ...
    191       cond1 2 1 arg
    192       cond2 1 2 arg
    193       accept
    194       reject
    195 
    196     but only using:
    197 
    198       ...
    199       cond1 1 2 arg
    200       jmp   1 2
    201       cond2 1 2 arg
    202       accept
    203       reject
    204 
    205     The jmp instruction ignores yes and always jumps to no, but yes must be 1
    206     or the bytecode won't validate. It doesn't have to be jmp - any instruction
    207     that is guaranteed not to match on real data will do.
    208 
    209     Args:
    210       instructions: list of instruction tuples
    211 
    212     Returns:
    213       A string, the raw bytecode.
    214     """
    215     args = []
    216     positions = [0]
    217 
    218     for op, yes, no, arg in instructions:
    219 
    220       if yes <= 0 or no <= 0:
    221         raise ValueError("Jumps must be > 0")
    222 
    223       if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
    224         arg = ""
    225       elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
    226                   INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
    227         arg = "\x00\x00" + struct.pack("=H", arg)
    228       elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
    229         addr, prefixlen, port = arg
    230         family = AF_INET6 if ":" in addr else AF_INET
    231         addr = inet_pton(family, addr)
    232         arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr
    233       elif op == INET_DIAG_BC_MARK_COND:
    234         if isinstance(arg, tuple):
    235           mark, mask = arg
    236         else:
    237           mark, mask = arg, 0xffffffff
    238         arg = InetDiagMarkcond((mark, mask)).Pack()
    239       else:
    240         raise ValueError("Unsupported opcode %d" % op)
    241 
    242       args.append(arg)
    243       length = len(InetDiagBcOp) + len(arg)
    244       positions.append(positions[-1] + length)
    245 
    246     # Reject label.
    247     positions.append(positions[-1] + 4)  # Why 4? Because the kernel uses 4.
    248     assert len(args) == len(instructions) == len(positions) - 2
    249 
    250     # print positions
    251 
    252     packed = ""
    253     for i, (op, yes, no, arg) in enumerate(instructions):
    254       yes = positions[i + yes] - positions[i]
    255       no = positions[i + no] - positions[i]
    256       instruction = InetDiagBcOp((op, yes, no)).Pack() + args[i]
    257       #print "%3d: %d %3d %3d %s %s" % (positions[i], op, yes, no,
    258       #                                 arg, instruction.encode("hex"))
    259       packed += instruction
    260     #print
    261 
    262     return packed
    263 
    264   @staticmethod
    265   def DecodeBytecode(bytecode):
    266     instructions = []
    267     try:
    268       while bytecode:
    269         op, rest = cstruct.Read(bytecode, InetDiagBcOp)
    270 
    271         if op.code in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
    272           arg = None
    273         elif op.code in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
    274                          INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
    275           op, rest = cstruct.Read(rest, InetDiagBcOp)
    276           arg = op.no
    277         elif op.code in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
    278           cond, rest = cstruct.Read(rest, InetDiagHostcond)
    279           if cond.family == 0:
    280             arg = (None, cond.prefix_len, cond.port)
    281           else:
    282             addrlen = 4 if cond.family == AF_INET else 16
    283             addr, rest = rest[:addrlen], rest[addrlen:]
    284             addr = inet_ntop(cond.family, addr)
    285             arg = (addr, cond.prefix_len, cond.port)
    286         elif op.code == INET_DIAG_BC_DEV_COND:
    287           attrlen = struct.calcsize("=I")
    288           attr, rest = rest[:attrlen], rest[attrlen:]
    289           arg = struct.unpack("=I", attr)
    290         elif op.code == INET_DIAG_BC_MARK_COND:
    291           arg, rest = cstruct.Read(rest, InetDiagMarkcond)
    292         else:
    293           raise ValueError("Unknown opcode %d" % op.code)
    294         instructions.append((op, arg))
    295         bytecode = rest
    296 
    297       return instructions
    298     except (TypeError, ValueError):
    299       return "???"
    300 
    301   def Dump(self, diag_req, bytecode):
    302     if bytecode:
    303       bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)
    304 
    305     out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)
    306     return out
    307 
    308   def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
    309                          states=ALL_NON_TIME_WAIT):
    310     """Dumps IPv4 or IPv6 sockets matching the specified parameters."""
    311     # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
    312     # results in ENOENT.
    313     if sock_id is None:
    314       sock_id = self._EmptyInetDiagSockId()
    315 
    316     sockets = []
    317     for family in [AF_INET, AF_INET6]:
    318       diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
    319       sockets += self.Dump(diag_req, bytecode)
    320 
    321     return sockets
    322 
    323   @staticmethod
    324   def GetRawAddress(family, addr):
    325     """Fetches the source address from an InetDiagMsg."""
    326     addrlen = {AF_INET:4, AF_INET6: 16}[family]
    327     return inet_ntop(family, addr[:addrlen])
    328 
    329   @staticmethod
    330   def GetSourceAddress(diag_msg):
    331     """Fetches the source address from an InetDiagMsg."""
    332     return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src)
    333 
    334   @staticmethod
    335   def GetDestinationAddress(diag_msg):
    336     """Fetches the source address from an InetDiagMsg."""
    337     return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst)
    338 
    339   @staticmethod
    340   def RawAddress(addr):
    341     """Converts an IP address string to binary format."""
    342     family = AF_INET6 if ":" in addr else AF_INET
    343     return inet_pton(family, addr)
    344 
    345   @staticmethod
    346   def PaddedAddress(addr):
    347     """Converts an IP address string to binary format for InetDiagSockId."""
    348     padded = SockDiag.RawAddress(addr)
    349     if len(padded) < 16:
    350       padded += "\x00" * (16 - len(padded))
    351     return padded
    352 
    353   @staticmethod
    354   def DiagReqFromSocket(s):
    355     """Creates an InetDiagReqV2 that matches the specified socket."""
    356     family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    357     protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
    358     if net_test.LINUX_VERSION >= (3, 8):
    359       iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE,
    360                            net_test.IFNAMSIZ)
    361       iface = GetInterfaceIndex(iface) if iface else 0
    362     else:
    363       iface = 0
    364     src, sport = s.getsockname()[:2]
    365     try:
    366       dst, dport = s.getpeername()[:2]
    367     except error, e:
    368       if e.errno == errno.ENOTCONN:
    369         dport = 0
    370         dst = "::" if family == AF_INET6 else "0.0.0.0"
    371       else:
    372         raise e
    373     src = SockDiag.PaddedAddress(src)
    374     dst = SockDiag.PaddedAddress(dst)
    375     sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
    376     return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))
    377 
    378   def FindSockInfoFromFd(self, s):
    379     """Gets a diag_msg and attrs from the kernel for the specified socket."""
    380     req = self.DiagReqFromSocket(s)
    381     # The kernel doesn't use idiag_src and idiag_dst when dumping sockets, it
    382     # only uses them when targeting a specific socket with a cookie. Check the
    383     # the inode number to ensure we don't mistakenly match another socket on
    384     # the same port but with a different IP address.
    385     inode = os.fstat(s.fileno()).st_ino
    386     results = self.Dump(req, "")
    387     if len(results) == 0:
    388       raise ValueError("Dump of %s returned no sockets" % req)
    389     for diag_msg, attrs in results:
    390       if diag_msg.inode == inode:
    391         return diag_msg, attrs
    392     raise ValueError("Dump of %s did not contain inode %d" % (req, inode))
    393 
    394   def FindSockDiagFromFd(self, s):
    395     """Gets an InetDiagMsg from the kernel for the specified socket."""
    396     return self.FindSockInfoFromFd(s)[0]
    397 
    398   def GetSockInfo(self, req):
    399     """Gets a diag_msg and attrs from the kernel for the specified request."""
    400     self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
    401     return self._GetMsg(InetDiagMsg)
    402 
    403   @staticmethod
    404   def DiagReqFromDiagMsg(d, protocol):
    405     """Constructs a diag_req from a diag_msg the kernel has given us."""
    406     return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))
    407 
    408   def CloseSocket(self, req):
    409     self._SendNlRequest(SOCK_DESTROY, req.Pack(),
    410                         netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)
    411 
    412   def CloseSocketFromFd(self, s):
    413     diag_msg, attrs = self.FindSockInfoFromFd(s)
    414     protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL)
    415     req = self.DiagReqFromDiagMsg(diag_msg, protocol)
    416     return self.CloseSocket(req)
    417 
    418 
    419 if __name__ == "__main__":
    420   n = SockDiag()
    421   n.DEBUG = True
    422   bytecode = ""
    423   sock_id = n._EmptyInetDiagSockId()
    424   sock_id.dport = 443
    425   ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
    426   states = 0xffffffff
    427   diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, "",
    428                                    sock_id=sock_id, ext=ext, states=states)
    429   print diag_msgs
    430