Home | History | Annotate | Download | only in test
      1 #!/usr/bin/python
      2 #
      3 # Copyright 2014 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 # http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 """Partial Python implementation of iproute functionality."""
     18 
     19 # pylint: disable=g-bad-todo
     20 
     21 import os
     22 import socket
     23 import struct
     24 import sys
     25 
     26 import cstruct
     27 import util
     28 
     29 ### Base netlink constants. See include/uapi/linux/netlink.h.
     30 NETLINK_ROUTE = 0
     31 NETLINK_SOCK_DIAG = 4
     32 NETLINK_XFRM = 6
     33 NETLINK_GENERIC = 16
     34 
     35 # Request constants.
     36 NLM_F_REQUEST = 1
     37 NLM_F_ACK = 4
     38 NLM_F_REPLACE = 0x100
     39 NLM_F_EXCL = 0x200
     40 NLM_F_CREATE = 0x400
     41 NLM_F_DUMP = 0x300
     42 
     43 # Message types.
     44 NLMSG_ERROR = 2
     45 NLMSG_DONE = 3
     46 
     47 # Data structure formats.
     48 # These aren't constants, they're classes. So, pylint: disable=invalid-name
     49 NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
     50 NLMsgErr = cstruct.Struct("NLMsgErr", "=i", "error")
     51 NLAttr = cstruct.Struct("NLAttr", "=HH", "nla_len nla_type")
     52 
     53 # Alignment / padding.
     54 NLA_ALIGNTO = 4
     55 
     56 # List of attributes that can appear more than once in a given netlink message.
     57 # These can appear more than once but don't seem to contain any data.
     58 DUP_ATTRS_OK = ["INET_DIAG_NONE", "IFLA_PAD"]
     59 
     60 class NetlinkSocket(object):
     61   """A basic netlink socket object."""
     62 
     63   BUFSIZE = 65536
     64   DEBUG = False
     65   # List of netlink messages to print, e.g., [], ["NEIGH", "ROUTE"], or ["ALL"]
     66   NL_DEBUG = []
     67 
     68   def _Debug(self, s):
     69     if self.DEBUG:
     70       print s
     71 
     72   def _NlAttr(self, nla_type, data):
     73     datalen = len(data)
     74     # Pad the data if it's not a multiple of NLA_ALIGNTO bytes long.
     75     padding = "\x00" * util.GetPadLength(NLA_ALIGNTO, datalen)
     76     nla_len = datalen + len(NLAttr)
     77     return NLAttr((nla_len, nla_type)).Pack() + data + padding
     78 
     79   def _NlAttrIPAddress(self, nla_type, family, address):
     80     return self._NlAttr(nla_type, socket.inet_pton(family, address))
     81 
     82   def _NlAttrStr(self, nla_type, value):
     83     value = value + "\x00"
     84     return self._NlAttr(nla_type, value.encode("UTF-8"))
     85 
     86   def _NlAttrU32(self, nla_type, value):
     87     return self._NlAttr(nla_type, struct.pack("=I", value))
     88 
     89   def _GetConstantName(self, module, value, prefix):
     90     thismodule = sys.modules[module]
     91     for name in dir(thismodule):
     92       if name.startswith("INET_DIAG_BC"):
     93         continue
     94       if (name.startswith(prefix) and
     95           not name.startswith(prefix + "F_") and
     96           name.isupper() and getattr(thismodule, name) == value):
     97           return name
     98     return value
     99 
    100   def _Decode(self, command, msg, nla_type, nla_data):
    101     """No-op, nonspecific version of decode."""
    102     return nla_type, nla_data
    103 
    104   def _ReadNlAttr(self, data):
    105     # Read the nlattr header.
    106     nla, data = cstruct.Read(data, NLAttr)
    107 
    108     # Read the data.
    109     datalen = nla.nla_len - len(nla)
    110     padded_len = util.GetPadLength(NLA_ALIGNTO, datalen) + datalen
    111     nla_data, data = data[:datalen], data[padded_len:]
    112 
    113     return nla, nla_data, data
    114 
    115   def _ParseAttributes(self, command, msg, data, nested=0):
    116     """Parses and decodes netlink attributes.
    117 
    118     Takes a block of NLAttr data structures, decodes them using Decode, and
    119     returns the result in a dict keyed by attribute number.
    120 
    121     Args:
    122       command: An integer, the rtnetlink command being carried out.
    123       msg: A Struct, the type of the data after the netlink header.
    124       data: A byte string containing a sequence of NLAttr data structures.
    125       nested: An integer, how deep we're currently nested.
    126 
    127     Returns:
    128       A dictionary mapping attribute types (integers) to decoded values.
    129 
    130     Raises:
    131       ValueError: There was a duplicate attribute type.
    132     """
    133     attributes = {}
    134     while data:
    135       nla, nla_data, data = self._ReadNlAttr(data)
    136 
    137       # If it's an attribute we know about, try to decode it.
    138       nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data)
    139 
    140       if nla_name in attributes and nla_name not in DUP_ATTRS_OK:
    141         raise ValueError("Duplicate attribute %s" % nla_name)
    142 
    143       attributes[nla_name] = nla_data
    144       if not nested:
    145         self._Debug("      %s" % (str((nla_name, nla_data))))
    146 
    147     return attributes
    148 
    149   def _OpenNetlinkSocket(self, family, groups=None):
    150     sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, family)
    151     if groups:
    152       sock.bind((0,  groups))
    153     sock.connect((0, 0))  # The kernel.
    154     return sock
    155 
    156   def __init__(self, family):
    157     # Global sequence number.
    158     self.seq = 0
    159     self.sock = self._OpenNetlinkSocket(family)
    160     self.pid = self.sock.getsockname()[1]
    161 
    162   def MaybeDebugCommand(self, command, flags, data):
    163     # Default no-op implementation to be overridden by subclasses.
    164     pass
    165 
    166   def _Send(self, msg):
    167     # self._Debug(msg.encode("hex"))
    168     self.seq += 1
    169     self.sock.send(msg)
    170 
    171   def _Recv(self):
    172     data = self.sock.recv(self.BUFSIZE)
    173     # self._Debug(data.encode("hex"))
    174     return data
    175 
    176   def _ExpectDone(self):
    177     response = self._Recv()
    178     hdr = NLMsgHdr(response)
    179     if hdr.type != NLMSG_DONE:
    180       raise ValueError("Expected DONE, got type %d" % hdr.type)
    181 
    182   def _ParseAck(self, response):
    183     # Find the error code.
    184     hdr, data = cstruct.Read(response, NLMsgHdr)
    185     if hdr.type == NLMSG_ERROR:
    186       error = NLMsgErr(data).error
    187       if error:
    188         raise IOError(-error, os.strerror(-error))
    189     else:
    190       raise ValueError("Expected ACK, got type %d" % hdr.type)
    191 
    192   def _ExpectAck(self):
    193     response = self._Recv()
    194     self._ParseAck(response)
    195 
    196   def _SendNlRequest(self, command, data, flags):
    197     """Sends a netlink request and expects an ack."""
    198     length = len(NLMsgHdr) + len(data)
    199     nlmsg = NLMsgHdr((length, command, flags, self.seq, self.pid)).Pack()
    200 
    201     self.MaybeDebugCommand(command, flags, nlmsg + data)
    202 
    203     # Send the message.
    204     self._Send(nlmsg + data)
    205 
    206     if flags & NLM_F_ACK:
    207       self._ExpectAck()
    208 
    209   def _ParseNLMsg(self, data, msgtype):
    210     """Parses a Netlink message into a header and a dictionary of attributes."""
    211     nlmsghdr, data = cstruct.Read(data, NLMsgHdr)
    212     self._Debug("  %s" % nlmsghdr)
    213 
    214     if nlmsghdr.type == NLMSG_ERROR or nlmsghdr.type == NLMSG_DONE:
    215       print "done"
    216       return (None, None), data
    217 
    218     nlmsg, data = cstruct.Read(data, msgtype)
    219     self._Debug("    %s" % nlmsg)
    220 
    221     # Parse the attributes in the nlmsg.
    222     attrlen = nlmsghdr.length - len(nlmsghdr) - len(nlmsg)
    223     attributes = self._ParseAttributes(nlmsghdr.type, nlmsg, data[:attrlen])
    224     data = data[attrlen:]
    225     return (nlmsg, attributes), data
    226 
    227   def _GetMsg(self, msgtype):
    228     data = self._Recv()
    229     if NLMsgHdr(data).type == NLMSG_ERROR:
    230       self._ParseAck(data)
    231     return self._ParseNLMsg(data, msgtype)[0]
    232 
    233   def _GetMsgList(self, msgtype, data, expect_done):
    234     out = []
    235     while data:
    236       msg, data = self._ParseNLMsg(data, msgtype)
    237       if msg is None:
    238         break
    239       out.append(msg)
    240     if expect_done:
    241       self._ExpectDone()
    242     return out
    243 
    244   def _Dump(self, command, msg, msgtype, attrs):
    245     """Sends a dump request and returns a list of decoded messages.
    246 
    247     Args:
    248       command: An integer, the command to run (e.g., RTM_NEWADDR).
    249       msg: A struct, the request (e.g., a RTMsg). May be None.
    250       msgtype: A cstruct.Struct, the data type to parse the dump results as.
    251       attrs: A string, the raw bytes of any request attributes to include.
    252 
    253     Returns:
    254       A list of (msg, attrs) tuples where msg is of type msgtype and attrs is
    255       a dict of attributes.
    256     """
    257     # Create a netlink dump request containing the msg.
    258     flags = NLM_F_DUMP | NLM_F_REQUEST
    259     msg = "" if msg is None else msg.Pack()
    260     length = len(NLMsgHdr) + len(msg) + len(attrs)
    261     nlmsghdr = NLMsgHdr((length, command, flags, self.seq, self.pid))
    262 
    263     # Send the request.
    264     request = nlmsghdr.Pack() + msg + attrs
    265     self.MaybeDebugCommand(command, flags, request)
    266     self._Send(request)
    267 
    268     # Keep reading netlink messages until we get a NLMSG_DONE.
    269     out = []
    270     while True:
    271       data = self._Recv()
    272       response_type = NLMsgHdr(data).type
    273       if response_type == NLMSG_DONE:
    274         break
    275       elif response_type == NLMSG_ERROR:
    276         # Likely means that the kernel didn't like our dump request.
    277         # Parse the error and throw an exception.
    278         self._ParseAck(data)
    279       out.extend(self._GetMsgList(msgtype, data, False))
    280 
    281     return out
    282