1 // 2 // Copyright (C) 2013 The Android Open Source Project 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 17 #ifndef SHILL_NET_NETLINK_MESSAGE_H_ 18 #define SHILL_NET_NETLINK_MESSAGE_H_ 19 20 #include <linux/netlink.h> 21 22 #include <map> 23 #include <string> 24 25 #include <base/bind.h> 26 27 #include <gtest/gtest_prod.h> // for FRIEND_TEST. 28 29 #include "shill/net/byte_string.h" 30 #include "shill/net/shill_export.h" 31 32 struct nlmsghdr; 33 34 namespace shill { 35 36 // Netlink messages are sent over netlink sockets to talk between user-space 37 // programs (like shill) and kernel modules (like the cfg80211 module). Each 38 // kernel module that talks netlink potentially adds its own family header to 39 // the nlmsghdr (the top-level netlink message header) and, potentially, uses a 40 // different payload format. The NetlinkMessage class represents that which 41 // is common between the different types of netlink message. 42 // 43 // The common portions of Netlink Messages start with a |nlmsghdr|. Those 44 // messages look something like the following: 45 // 46 // |<--------------NetlinkPacket::GetLength()------------->| 47 // | |<--NetlinkPacket::GetPayload().GetLength() --->| 48 // | | | 49 // -----+-----+-+---------------------------------------------+-+---- 50 // ... | | | netlink payload | | 51 // | | +------------+-+------------------------------+ | 52 // | nl | | | | | | nl 53 // | msg |p| (optional) |p| |p| msg ... 54 // | hdr |a| family |a| family payload |a| hdr 55 // | |d| header |d| |d| 56 // | | | | | | | 57 // -----+-----+-+------------+-+------------------------------+-+---- 58 // ^ 59 // | 60 // +-- nlmsg payload (NetlinkPacket::GetPayload()) 61 // 62 // All NetlinkMessages sent to the kernel need a valid message type (which is 63 // found in the nlmsghdr structure) and all NetlinkMessages received from the 64 // kernel have a valid message type. Some message types (NLMSG_NOOP, 65 // NLMSG_ERROR, and GENL_ID_CTRL, for example) are allocated statically; for 66 // those, the |message_type_| is assigned directly. 67 // 68 // Other message types ("nl80211", for example), are assigned by the kernel 69 // dynamically. To get the message type, pass a closure to assign the 70 // message_type along with the sting to NetlinkManager::GetFamily: 71 // 72 // nl80211_type = netlink_manager->GetFamily(Nl80211Message::kMessageType); 73 // 74 // Do all of this before you start to create NetlinkMessages so that 75 // NetlinkMessage can be instantiated with a valid |message_type_|. 76 77 class NetlinkPacket; 78 79 class SHILL_EXPORT NetlinkMessage { 80 public: 81 // Describes the context of the netlink message for parsing purposes. 82 struct MessageContext { 83 MessageContext() : nl80211_cmd(0), is_broadcast(false) {} 84 85 size_t nl80211_cmd; 86 bool is_broadcast; 87 }; 88 89 static const uint32_t kBroadcastSequenceNumber; 90 static const uint16_t kIllegalMessageType; 91 92 explicit NetlinkMessage(uint16_t message_type) : 93 flags_(0), message_type_(message_type), 94 sequence_number_(kBroadcastSequenceNumber) {} 95 virtual ~NetlinkMessage() {} 96 97 // Returns a string of bytes representing the message (with it headers) and 98 // any necessary padding. These bytes are appropriately formatted to be 99 // written to a netlink socket. 100 virtual ByteString Encode(uint32_t sequence_number) = 0; 101 102 // Initializes the |NetlinkMessage| from a complete and legal message 103 // (potentially received from the kernel via a netlink socket). 104 virtual bool InitFromPacket(NetlinkPacket* packet, MessageContext context); 105 106 uint16_t message_type() const { return message_type_; } 107 void AddFlag(uint16_t new_flag) { flags_ |= new_flag; } 108 void AddAckFlag() { flags_ |= NLM_F_ACK; } 109 uint16_t flags() const { return flags_; } 110 uint32_t sequence_number() const { return sequence_number_; } 111 // Logs the message. Allows a different log level (presumably more 112 // stringent) for the body of the message than the header. 113 virtual void Print(int header_log_level, int detail_log_level) const = 0; 114 115 // Logs the message's raw bytes (with minimal interpretation). 116 static void PrintBytes(int log_level, const unsigned char* buf, 117 size_t num_bytes); 118 119 // Logs a netlink message (with minimal interpretation). 120 static void PrintPacket(int log_level, const NetlinkPacket& packet); 121 122 protected: 123 friend class NetlinkManagerTest; 124 FRIEND_TEST(NetlinkManagerTest, NL80211_CMD_NOTIFY_CQM); 125 126 // Returns a string of bytes representing an |nlmsghdr|, filled-in, and its 127 // padding. 128 virtual ByteString EncodeHeader(uint32_t sequence_number); 129 // Reads the |nlmsghdr|. Subclasses may read additional data from the 130 // payload. 131 virtual bool InitAndStripHeader(NetlinkPacket* packet); 132 133 uint16_t flags_; 134 uint16_t message_type_; 135 uint32_t sequence_number_; 136 137 private: 138 static void PrintHeader(int log_level, const nlmsghdr* header); 139 static void PrintPayload(int log_level, const unsigned char* buf, 140 size_t num_bytes); 141 142 DISALLOW_COPY_AND_ASSIGN(NetlinkMessage); 143 }; 144 145 146 // The Error and Ack messages are received from the kernel and are combined, 147 // here, because they look so much alike (the only difference is that the 148 // error code is 0 for the Ack messages). Error messages are received from 149 // the kernel in response to a sent message when there's a problem (such as 150 // a malformed message or a busy kernel module). Ack messages are received 151 // from the kernel when a sent message has the NLM_F_ACK flag set, indicating 152 // that an Ack is requested. 153 class SHILL_EXPORT ErrorAckMessage : public NetlinkMessage { 154 public: 155 static const uint16_t kMessageType; 156 157 ErrorAckMessage() : NetlinkMessage(kMessageType), error_(0) {} 158 explicit ErrorAckMessage(uint32_t err) 159 : NetlinkMessage(kMessageType), error_(err) {} 160 static uint16_t GetMessageType() { return kMessageType; } 161 bool InitFromPacket(NetlinkPacket* packet, MessageContext context) override; 162 ByteString Encode(uint32_t sequence_number) override; 163 void Print(int header_log_level, int detail_log_level) const override; 164 std::string ToString() const; 165 uint32_t error() const { return -error_; } 166 167 private: 168 uint32_t error_; 169 170 DISALLOW_COPY_AND_ASSIGN(ErrorAckMessage); 171 }; 172 173 174 class SHILL_EXPORT NoopMessage : public NetlinkMessage { 175 public: 176 static const uint16_t kMessageType; 177 178 NoopMessage() : NetlinkMessage(kMessageType) {} 179 static uint16_t GetMessageType() { return kMessageType; } 180 virtual ByteString Encode(uint32_t sequence_number); 181 virtual void Print(int header_log_level, int detail_log_level) const; 182 std::string ToString() const { return "<NOOP>"; } 183 184 private: 185 DISALLOW_COPY_AND_ASSIGN(NoopMessage); 186 }; 187 188 189 class SHILL_EXPORT DoneMessage : public NetlinkMessage { 190 public: 191 static const uint16_t kMessageType; 192 193 DoneMessage() : NetlinkMessage(kMessageType) {} 194 static uint16_t GetMessageType() { return kMessageType; } 195 virtual ByteString Encode(uint32_t sequence_number); 196 virtual void Print(int header_log_level, int detail_log_level) const; 197 std::string ToString() const { return "<DONE with multipart message>"; } 198 199 private: 200 DISALLOW_COPY_AND_ASSIGN(DoneMessage); 201 }; 202 203 204 class SHILL_EXPORT OverrunMessage : public NetlinkMessage { 205 public: 206 static const uint16_t kMessageType; 207 208 OverrunMessage() : NetlinkMessage(kMessageType) {} 209 static uint16_t GetMessageType() { return kMessageType; } 210 virtual ByteString Encode(uint32_t sequence_number); 211 virtual void Print(int header_log_level, int detail_log_level) const; 212 std::string ToString() const { return "<OVERRUN - data lost>"; } 213 214 private: 215 DISALLOW_COPY_AND_ASSIGN(OverrunMessage); 216 }; 217 218 219 class SHILL_EXPORT UnknownMessage : public NetlinkMessage { 220 public: 221 UnknownMessage(uint16_t message_type, ByteString message_body) : 222 NetlinkMessage(message_type), message_body_(message_body) {} 223 virtual ByteString Encode(uint32_t sequence_number); 224 virtual void Print(int header_log_level, int detail_log_level) const; 225 226 private: 227 ByteString message_body_; 228 229 DISALLOW_COPY_AND_ASSIGN(UnknownMessage); 230 }; 231 232 233 // 234 // Factory class. 235 // 236 237 class SHILL_EXPORT NetlinkMessageFactory { 238 public: 239 typedef base::Callback<NetlinkMessage*(const NetlinkPacket& packet)> 240 FactoryMethod; 241 242 NetlinkMessageFactory() {} 243 244 // Adds a message factory for a specific message_type. Intended to be used 245 // at initialization. 246 bool AddFactoryMethod(uint16_t message_type, FactoryMethod factory); 247 248 // Ownership of the message is passed to the caller and, as such, he should 249 // delete it. 250 NetlinkMessage* CreateMessage(NetlinkPacket* packet, 251 NetlinkMessage::MessageContext context) const; 252 253 private: 254 std::map<uint16_t, FactoryMethod> factories_; 255 256 DISALLOW_COPY_AND_ASSIGN(NetlinkMessageFactory); 257 }; 258 259 } // namespace shill 260 261 #endif // SHILL_NET_NETLINK_MESSAGE_H_ 262