Home | History | Annotate | Download | only in net_test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2015 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 """Partial Python implementation of sock_diag functionality."""
     18 
     19 # pylint: disable=g-bad-todo
     20 
     21 import errno
     22 from socket import *  # pylint: disable=wildcard-import
     23 import struct
     24 
     25 import cstruct
     26 import net_test
     27 import netlink
     28 
### Base netlink constants. See include/uapi/linux/netlink.h.
# Netlink protocol number of sock_diag; used as SockDiag.FAMILY below.
NETLINK_SOCK_DIAG = 4

# NOTE: SockDiag maps integers back to these names by prefix via
# _GetConstantName (prefixes "SOCK_" and "INET_DIAG"), so renaming or
# re-prefixing any of them changes what decoded/debug output reports.

### sock_diag constants. See include/uapi/linux/sock_diag.h.
# Message types.
SOCK_DIAG_BY_FAMILY = 20  # Used for both dumps and single-socket lookups.
SOCK_DESTROY = 21  # Used by SockDiag.CloseSocket.

### inet_diag constants. See include/uapi/linux/inet_diag.h
# Message types.
TCPDIAG_GETSOCK = 18

# Request attributes.
# Attribute type carrying PackBytecode() output in a dump request.
INET_DIAG_REQ_BYTECODE = 1

# Extensions.
# Attribute types in dump replies. A caller requests extension X by setting
# bit (X - 1) in InetDiagReqV2.ext (see the __main__ example).
# NOTE(review): these values overlap INET_DIAG_REQ_* and INET_DIAG_BC_*, which
# share the "INET_DIAG" prefix used by SockDiag._Decode. Which name wins for a
# given value depends on _GetConstantName; confirm.
INET_DIAG_NONE = 0
INET_DIAG_MEMINFO = 1
INET_DIAG_INFO = 2
INET_DIAG_VEGASINFO = 3
INET_DIAG_CONG = 4
INET_DIAG_TOS = 5
INET_DIAG_TCLASS = 6
INET_DIAG_SKMEMINFO = 7
INET_DIAG_SHUTDOWN = 8
INET_DIAG_DCTCPINFO = 9

# Bytecode operations.
# Opcodes accepted by SockDiag.PackBytecode:
#   NOP/JMP/AUTO take no argument;
#   S_GE/S_LE/D_GE/D_LE take a port;
#   S_COND/D_COND take an (address, prefix_len, port) tuple.
INET_DIAG_BC_NOP = 0
INET_DIAG_BC_JMP = 1
INET_DIAG_BC_S_GE = 2
INET_DIAG_BC_S_LE = 3
INET_DIAG_BC_D_GE = 4
INET_DIAG_BC_D_LE = 5
INET_DIAG_BC_AUTO = 6
INET_DIAG_BC_S_COND = 7
INET_DIAG_BC_D_COND = 8
     66 
# Data structure formats.
# These aren't constants, they're classes. So, pylint: disable=invalid-name
# Each cstruct.Struct call takes (type name, struct-module format string,
# space-separated field names[, list of nested struct types]). Presumably each
# "S" in the format is filled by the next nested type in that list; confirm
# against cstruct.

# struct inet_diag_sockid: ports and addresses are network byte order.
# Addresses are 16 bytes; IPv4 uses the first 4 (see SockDiag.PaddedAddress).
# NOTE(review): the kernel header declares idiag_if as a host-order __u32, but
# "!" packs iface big-endian. On little-endian hosts a nonzero iface is
# byte-swapped relative to the kernel's value; confirm before relying on it.
InetDiagSockId = cstruct.Struct(
    "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie")
# struct inet_diag_req_v2: the dump/lookup request header. "x" is a pad byte.
InetDiagReqV2 = cstruct.Struct(
    "InetDiagReqV2", "=BBBxIS", "family protocol ext states id",
    [InetDiagSockId])
# struct inet_diag_msg: the per-socket header in SOCK_DIAG_BY_FAMILY replies.
InetDiagMsg = cstruct.Struct(
    "InetDiagMsg", "=BBBBSLLLLL",
    "family state timer retrans id expires rqueue wqueue uid inode",
    [InetDiagSockId])
# Payload of the INET_DIAG_MEMINFO attribute.
InetDiagMeminfo = cstruct.Struct(
    "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
# One bytecode instruction header: opcode, yes offset, no offset.
# Native format with no padding needed (1 + 1 + 2 bytes).
InetDiagBcOp = cstruct.Struct("InetDiagBcOp", "BBH", "code yes no")
# Argument of INET_DIAG_BC_S_COND/D_COND; PackBytecode appends the raw
# address bytes after it.
InetDiagHostcond = cstruct.Struct("InetDiagHostcond", "=BBxxi",
                                  "family prefix_len port")

# Payload of the INET_DIAG_SKMEMINFO attribute.
SkMeminfo = cstruct.Struct(
    "SkMeminfo", "=IIIIIIII",
    "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog")
# Payload of INET_DIAG_INFO for TCP sockets (struct tcp_info prefix).
TcpInfo = cstruct.Struct(
    "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII",
    "state ca_state retransmits probes backoff options wscale "
    "rto ato snd_mss rcv_mss "
    "unacked sacked lost retrans fackets "
    "last_data_sent last_ack_sent last_data_recv last_ack_recv "
    "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering "
    "rcv_rtt rcv_space "
    "total_retrans")  # As of linux 3.13, at least.
     96 
     97 TCP_TIME_WAIT = 6
     98 ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << TCP_TIME_WAIT)
     99 
    100 
    101 class SockDiag(netlink.NetlinkSocket):
    102 
    103   FAMILY = NETLINK_SOCK_DIAG
    104   NL_DEBUG = []
    105 
    106   def _Decode(self, command, msg, nla_type, nla_data):
    107     """Decodes netlink attributes to Python types."""
    108     if msg.family == AF_INET or msg.family == AF_INET6:
    109       name = self._GetConstantName(__name__, nla_type, "INET_DIAG")
    110     else:
    111       # Don't know what this is. Leave it as an integer.
    112       name = nla_type
    113 
    114     if name in ["INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS"]:
    115       data = ord(nla_data)
    116     elif name == "INET_DIAG_CONG":
    117       data = nla_data.strip("\x00")
    118     elif name == "INET_DIAG_MEMINFO":
    119       data = InetDiagMeminfo(nla_data)
    120     elif name == "INET_DIAG_INFO":
    121       # TODO: Catch the exception and try something else if it's not TCP.
    122       data = TcpInfo(nla_data)
    123     elif name == "INET_DIAG_SKMEMINFO":
    124       data = SkMeminfo(nla_data)
    125     else:
    126       data = nla_data
    127 
    128     return name, data
    129 
    130   def MaybeDebugCommand(self, command, data):
    131     name = self._GetConstantName(__name__, command, "SOCK_")
    132     if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG:
    133       return
    134     parsed = self._ParseNLMsg(data, InetDiagReqV2)
    135     print "%s %s" % (name, str(parsed))
    136 
    137   @staticmethod
    138   def _EmptyInetDiagSockId():
    139     return InetDiagSockId(("\x00" * len(InetDiagSockId)))
    140 
    141   def PackBytecode(self, instructions):
    142     """Compiles instructions to inet_diag bytecode.
    143 
    144     The input is a list of (INET_DIAG_BC_xxx, yes, no, arg) tuples, where yes
    145     and no are relative jump offsets measured in instructions. The yes branch
    146     is taken if the instruction matches.
    147 
    148     To accept, jump 1 past the last instruction. To reject, jump 2 past the
    149     last instruction.
    150 
    151     The target of a no jump is only valid if it is reachable by following
    152     only yes jumps from the first instruction - see inet_diag_bc_audit and
    153     valid_cc. This means that if cond1 and cond2 are two mutually exclusive
    154     filter terms, it is not possible to implement cond1 OR cond2 using:
    155 
    156       ...
    157       cond1 2 1 arg
    158       cond2 1 2 arg
    159       accept
    160       reject
    161 
    162     but only using:
    163 
    164       ...
    165       cond1 1 2 arg
    166       jmp   1 2
    167       cond2 1 2 arg
    168       accept
    169       reject
    170 
    171     The jmp instruction ignores yes and always jumps to no, but yes must be 1
    172     or the bytecode won't validate. It doesn't have to be jmp - any instruction
    173     that is guaranteed not to match on real data will do.
    174 
    175     Args:
    176       instructions: list of instruction tuples
    177 
    178     Returns:
    179       A string, the raw bytecode.
    180     """
    181     args = []
    182     positions = [0]
    183 
    184     for op, yes, no, arg in instructions:
    185 
    186       if yes <= 0 or no <= 0:
    187         raise ValueError("Jumps must be > 0")
    188 
    189       if op in [INET_DIAG_BC_NOP, INET_DIAG_BC_JMP, INET_DIAG_BC_AUTO]:
    190         arg = ""
    191       elif op in [INET_DIAG_BC_S_GE, INET_DIAG_BC_S_LE,
    192                   INET_DIAG_BC_D_GE, INET_DIAG_BC_D_LE]:
    193         arg = "\x00\x00" + struct.pack("=H", arg)
    194       elif op in [INET_DIAG_BC_S_COND, INET_DIAG_BC_D_COND]:
    195         addr, prefixlen, port = arg
    196         family = AF_INET6 if ":" in addr else AF_INET
    197         addr = inet_pton(family, addr)
    198         arg = InetDiagHostcond((family, prefixlen, port)).Pack() + addr
    199       else:
    200         raise ValueError("Unsupported opcode %d" % op)
    201 
    202       args.append(arg)
    203       length = len(InetDiagBcOp) + len(arg)
    204       positions.append(positions[-1] + length)
    205 
    206     # Reject label.
    207     positions.append(positions[-1] + 4)  # Why 4? Because the kernel uses 4.
    208     assert len(args) == len(instructions) == len(positions) - 2
    209 
    210     # print positions
    211 
    212     packed = ""
    213     for i, (op, yes, no, arg) in enumerate(instructions):
    214       yes = positions[i + yes] - positions[i]
    215       no = positions[i + no] - positions[i]
    216       instruction = InetDiagBcOp((op, yes, no)).Pack() + args[i]
    217       #print "%3d: %d %3d %3d %s %s" % (positions[i], op, yes, no,
    218       #                                 arg, instruction.encode("hex"))
    219       packed += instruction
    220     #print
    221 
    222     return packed
    223 
    224   def Dump(self, diag_req, bytecode=""):
    225     out = self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg, bytecode)
    226     return out
    227 
    228   def DumpAllInetSockets(self, protocol, bytecode, sock_id=None, ext=0,
    229                          states=ALL_NON_TIME_WAIT):
    230     """Dumps IPv4 or IPv6 sockets matching the specified parameters."""
    231     # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
    232     # results in ENOENT.
    233     if sock_id is None:
    234       sock_id = self._EmptyInetDiagSockId()
    235 
    236     if bytecode:
    237       bytecode = self._NlAttr(INET_DIAG_REQ_BYTECODE, bytecode)
    238 
    239     sockets = []
    240     for family in [AF_INET, AF_INET6]:
    241       diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
    242       sockets += self.Dump(diag_req, bytecode)
    243 
    244     return sockets
    245 
    246   @staticmethod
    247   def GetRawAddress(family, addr):
    248     """Fetches the source address from an InetDiagMsg."""
    249     addrlen = {AF_INET:4, AF_INET6: 16}[family]
    250     return inet_ntop(family, addr[:addrlen])
    251 
    252   @staticmethod
    253   def GetSourceAddress(diag_msg):
    254     """Fetches the source address from an InetDiagMsg."""
    255     return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src)
    256 
    257   @staticmethod
    258   def GetDestinationAddress(diag_msg):
    259     """Fetches the source address from an InetDiagMsg."""
    260     return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst)
    261 
    262   @staticmethod
    263   def RawAddress(addr):
    264     """Converts an IP address string to binary format."""
    265     family = AF_INET6 if ":" in addr else AF_INET
    266     return inet_pton(family, addr)
    267 
    268   @staticmethod
    269   def PaddedAddress(addr):
    270     """Converts an IP address string to binary format for InetDiagSockId."""
    271     padded = SockDiag.RawAddress(addr)
    272     if len(padded) < 16:
    273       padded += "\x00" * (16 - len(padded))
    274     return padded
    275 
    276   @staticmethod
    277   def DiagReqFromSocket(s):
    278     """Creates an InetDiagReqV2 that matches the specified socket."""
    279     family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
    280     protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
    281     if net_test.LINUX_VERSION >= (3, 8):
    282       iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE,
    283                            net_test.IFNAMSIZ)
    284       iface = GetInterfaceIndex(iface) if iface else 0
    285     else:
    286       iface = 0
    287     src, sport = s.getsockname()[:2]
    288     try:
    289       dst, dport = s.getpeername()[:2]
    290     except error, e:
    291       if e.errno == errno.ENOTCONN:
    292         dport = 0
    293         dst = "::" if family == AF_INET6 else "0.0.0.0"
    294       else:
    295         raise e
    296     src = SockDiag.PaddedAddress(src)
    297     dst = SockDiag.PaddedAddress(dst)
    298     sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
    299     return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))
    300 
    301   def FindSockDiagFromReq(self, req):
    302     for diag_msg, attrs in self.Dump(req):
    303       return diag_msg
    304     raise ValueError("Dump of %s returned no sockets" % req)
    305 
    306   def FindSockDiagFromFd(self, s):
    307     """Gets an InetDiagMsg from the kernel for the specified socket."""
    308     req = self.DiagReqFromSocket(s)
    309     return self.FindSockDiagFromReq(req)
    310 
    311   def GetSockDiag(self, req):
    312     """Gets an InetDiagMsg from the kernel for the specified request."""
    313     self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
    314     return self._GetMsg(InetDiagMsg)[0]
    315 
    316   @staticmethod
    317   def DiagReqFromDiagMsg(d, protocol):
    318     """Constructs a diag_req from a diag_msg the kernel has given us."""
    319     return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))
    320 
    321   def CloseSocket(self, req):
    322     self._SendNlRequest(SOCK_DESTROY, req.Pack(),
    323                         netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)
    324 
    325   def CloseSocketFromFd(self, s):
    326     diag_msg = self.FindSockDiagFromFd(s)
    327     protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL)
    328     req = self.DiagReqFromDiagMsg(diag_msg, protocol)
    329     return self.CloseSocket(req)
    330 
    331 
    332 if __name__ == "__main__":
    333   n = SockDiag()
    334   n.DEBUG = True
    335   bytecode = ""
    336   sock_id = n._EmptyInetDiagSockId()
    337   sock_id.dport = 443
    338   ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
    339   states = 0xffffffff
    340   diag_msgs = n.DumpAllInetSockets(IPPROTO_TCP, "",
    341                                    sock_id=sock_id, ext=ext, states=states)
    342   print diag_msgs
    343