Home | History | Annotate | Download | only in net
      1 //
      2 // Copyright (C) 2013 The Android Open Source Project
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 
     17 #ifndef SHILL_NET_NETLINK_MESSAGE_H_
     18 #define SHILL_NET_NETLINK_MESSAGE_H_
     19 
#include <linux/netlink.h>

#include <map>
#include <string>
#include <utility>

#include <base/bind.h>

#include <gtest/gtest_prod.h>  // for FRIEND_TEST.

#include "shill/net/byte_string.h"
#include "shill/net/shill_export.h"
     31 
     32 struct nlmsghdr;
     33 
     34 namespace shill {
     35 
     36 // Netlink messages are sent over netlink sockets to talk between user-space
     37 // programs (like shill) and kernel modules (like the cfg80211 module).  Each
     38 // kernel module that talks netlink potentially adds its own family header to
     39 // the nlmsghdr (the top-level netlink message header) and, potentially, uses a
     40 // different payload format.  The NetlinkMessage class represents that which
     41 // is common between the different types of netlink message.
     42 //
     43 // The common portions of Netlink Messages start with a |nlmsghdr|.  Those
     44 // messages look something like the following:
     45 //
     46 //         |<--------------NetlinkPacket::GetLength()------------->|
     47 //         |       |<--NetlinkPacket::GetPayload().GetLength() --->|
     48 //         |       |                                               |
     49 //    -----+-----+-+---------------------------------------------+-+----
     50 //     ... |     | |                 netlink payload             | |
     51 //         |     | +------------+-+------------------------------+ |
     52 //         | nl  | |            | |                              | | nl
     53 //         | msg |p| (optional) |p|                              |p| msg ...
     54 //         | hdr |a| family     |a|        family payload        |a| hdr
     55 //         |     |d| header     |d|                              |d|
     56 //         |     | |            | |                              | |
     57 //    -----+-----+-+------------+-+------------------------------+-+----
     58 //                  ^
     59 //                  |
     60 //                  +-- nlmsg payload (NetlinkPacket::GetPayload())
     61 //
     62 // All NetlinkMessages sent to the kernel need a valid message type (which is
     63 // found in the nlmsghdr structure) and all NetlinkMessages received from the
     64 // kernel have a valid message type.  Some message types (NLMSG_NOOP,
     65 // NLMSG_ERROR, and GENL_ID_CTRL, for example) are allocated statically; for
     66 // those, the |message_type_| is assigned directly.
     67 //
// Other message types ("nl80211", for example) are assigned by the kernel
// dynamically.  To get the message type, pass a closure to assign the
// message_type along with the string to NetlinkManager::GetFamily:
     71 //
     72 //  nl80211_type = netlink_manager->GetFamily(Nl80211Message::kMessageType);
     73 //
     74 // Do all of this before you start to create NetlinkMessages so that
     75 // NetlinkMessage can be instantiated with a valid |message_type_|.
     76 
     77 class NetlinkPacket;
     78 
     79 class SHILL_EXPORT NetlinkMessage {
     80  public:
     81   // Describes the context of the netlink message for parsing purposes.
     82   struct MessageContext {
     83     MessageContext() : nl80211_cmd(0), is_broadcast(false) {}
     84 
     85     size_t nl80211_cmd;
     86     bool is_broadcast;
     87   };
     88 
     89   static const uint32_t kBroadcastSequenceNumber;
     90   static const uint16_t kIllegalMessageType;
     91 
     92   explicit NetlinkMessage(uint16_t message_type) :
     93       flags_(0), message_type_(message_type),
     94       sequence_number_(kBroadcastSequenceNumber) {}
     95   virtual ~NetlinkMessage() {}
     96 
     97   // Returns a string of bytes representing the message (with it headers) and
     98   // any necessary padding.  These bytes are appropriately formatted to be
     99   // written to a netlink socket.
    100   virtual ByteString Encode(uint32_t sequence_number) = 0;
    101 
    102   // Initializes the |NetlinkMessage| from a complete and legal message
    103   // (potentially received from the kernel via a netlink socket).
    104   virtual bool InitFromPacket(NetlinkPacket* packet, MessageContext context);
    105 
    106   uint16_t message_type() const { return message_type_; }
    107   void AddFlag(uint16_t new_flag) { flags_ |= new_flag; }
    108   void AddAckFlag() { flags_ |= NLM_F_ACK; }
    109   uint16_t flags() const { return flags_; }
    110   uint32_t sequence_number() const { return sequence_number_; }
    111   // Logs the message.  Allows a different log level (presumably more
    112   // stringent) for the body of the message than the header.
    113   virtual void Print(int header_log_level, int detail_log_level) const = 0;
    114 
    115   // Logs the message's raw bytes (with minimal interpretation).
    116   static void PrintBytes(int log_level, const unsigned char* buf,
    117                          size_t num_bytes);
    118 
    119   // Logs a netlink message (with minimal interpretation).
    120   static void PrintPacket(int log_level, const NetlinkPacket& packet);
    121 
    122  protected:
    123   friend class NetlinkManagerTest;
    124   FRIEND_TEST(NetlinkManagerTest, NL80211_CMD_NOTIFY_CQM);
    125 
    126   // Returns a string of bytes representing an |nlmsghdr|, filled-in, and its
    127   // padding.
    128   virtual ByteString EncodeHeader(uint32_t sequence_number);
    129   // Reads the |nlmsghdr|.  Subclasses may read additional data from the
    130   // payload.
    131   virtual bool InitAndStripHeader(NetlinkPacket* packet);
    132 
    133   uint16_t flags_;
    134   uint16_t message_type_;
    135   uint32_t sequence_number_;
    136 
    137  private:
    138   static void PrintHeader(int log_level, const nlmsghdr* header);
    139   static void PrintPayload(int log_level, const unsigned char* buf,
    140                            size_t num_bytes);
    141 
    142   DISALLOW_COPY_AND_ASSIGN(NetlinkMessage);
    143 };
    144 
    145 
    146 // The Error and Ack messages are received from the kernel and are combined,
    147 // here, because they look so much alike (the only difference is that the
    148 // error code is 0 for the Ack messages).  Error messages are received from
    149 // the kernel in response to a sent message when there's a problem (such as
    150 // a malformed message or a busy kernel module).  Ack messages are received
    151 // from the kernel when a sent message has the NLM_F_ACK flag set, indicating
    152 // that an Ack is requested.
    153 class SHILL_EXPORT ErrorAckMessage : public NetlinkMessage {
    154  public:
    155   static const uint16_t kMessageType;
    156 
    157   ErrorAckMessage() : NetlinkMessage(kMessageType), error_(0) {}
    158   explicit ErrorAckMessage(uint32_t err)
    159       : NetlinkMessage(kMessageType), error_(err) {}
    160   static uint16_t GetMessageType() { return kMessageType; }
    161   bool InitFromPacket(NetlinkPacket* packet, MessageContext context) override;
    162   ByteString Encode(uint32_t sequence_number) override;
    163   void Print(int header_log_level, int detail_log_level) const override;
    164   std::string ToString() const;
    165   uint32_t error() const { return -error_; }
    166 
    167  private:
    168   uint32_t error_;
    169 
    170   DISALLOW_COPY_AND_ASSIGN(ErrorAckMessage);
    171 };
    172 
    173 
    174 class SHILL_EXPORT NoopMessage : public NetlinkMessage {
    175  public:
    176   static const uint16_t kMessageType;
    177 
    178   NoopMessage() : NetlinkMessage(kMessageType) {}
    179   static uint16_t GetMessageType() { return kMessageType; }
    180   virtual ByteString Encode(uint32_t sequence_number);
    181   virtual void Print(int header_log_level, int detail_log_level) const;
    182   std::string ToString() const { return "<NOOP>"; }
    183 
    184  private:
    185   DISALLOW_COPY_AND_ASSIGN(NoopMessage);
    186 };
    187 
    188 
    189 class SHILL_EXPORT DoneMessage : public NetlinkMessage {
    190  public:
    191   static const uint16_t kMessageType;
    192 
    193   DoneMessage() : NetlinkMessage(kMessageType) {}
    194   static uint16_t GetMessageType() { return kMessageType; }
    195   virtual ByteString Encode(uint32_t sequence_number);
    196   virtual void Print(int header_log_level, int detail_log_level) const;
    197   std::string ToString() const { return "<DONE with multipart message>"; }
    198 
    199  private:
    200   DISALLOW_COPY_AND_ASSIGN(DoneMessage);
    201 };
    202 
    203 
    204 class SHILL_EXPORT OverrunMessage : public NetlinkMessage {
    205  public:
    206   static const uint16_t kMessageType;
    207 
    208   OverrunMessage() : NetlinkMessage(kMessageType) {}
    209   static uint16_t GetMessageType() { return kMessageType; }
    210   virtual ByteString Encode(uint32_t sequence_number);
    211   virtual void Print(int header_log_level, int detail_log_level) const;
    212   std::string ToString() const { return "<OVERRUN - data lost>"; }
    213 
    214  private:
    215   DISALLOW_COPY_AND_ASSIGN(OverrunMessage);
    216 };
    217 
    218 
    219 class SHILL_EXPORT UnknownMessage : public NetlinkMessage {
    220  public:
    221   UnknownMessage(uint16_t message_type, ByteString message_body) :
    222       NetlinkMessage(message_type), message_body_(message_body) {}
    223   virtual ByteString Encode(uint32_t sequence_number);
    224   virtual void Print(int header_log_level, int detail_log_level) const;
    225 
    226  private:
    227   ByteString message_body_;
    228 
    229   DISALLOW_COPY_AND_ASSIGN(UnknownMessage);
    230 };
    231 
    232 
    233 //
    234 // Factory class.
    235 //
    236 
    237 class SHILL_EXPORT NetlinkMessageFactory {
    238  public:
    239   typedef base::Callback<NetlinkMessage*(const NetlinkPacket& packet)>
    240       FactoryMethod;
    241 
    242   NetlinkMessageFactory() {}
    243 
    244   // Adds a message factory for a specific message_type.  Intended to be used
    245   // at initialization.
    246   bool AddFactoryMethod(uint16_t message_type, FactoryMethod factory);
    247 
    248   // Ownership of the message is passed to the caller and, as such, he should
    249   // delete it.
    250   NetlinkMessage* CreateMessage(NetlinkPacket* packet,
    251                                 NetlinkMessage::MessageContext context) const;
    252 
    253  private:
    254   std::map<uint16_t, FactoryMethod> factories_;
    255 
    256   DISALLOW_COPY_AND_ASSIGN(NetlinkMessageFactory);
    257 };
    258 
    259 }  // namespace shill
    260 
    261 #endif  // SHILL_NET_NETLINK_MESSAGE_H_
    262