Home | History | Annotate | Download | only in net
      1 /*
      2  * Copyright (C) 2016 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "wificond/net/nl80211_packet.h"
     18 
     19 #include <android-base/logging.h>
     20 
     21 using std::make_unique;
     22 using std::unique_ptr;
     23 using std::vector;
     24 
     25 namespace android {
     26 namespace wificond {
     27 
     28 NL80211Packet::NL80211Packet(const vector<uint8_t>& data)
     29     : data_(data) {
     30   data_ = data;
     31 }
     32 
     33 NL80211Packet::NL80211Packet(const NL80211Packet& packet) {
     34   data_ = packet.data_;
     35   LOG(WARNING) << "Copy constructor is only used for unit tests";
     36 }
     37 
     38 NL80211Packet::NL80211Packet(uint16_t type,
     39                              uint8_t command,
     40                              uint32_t sequence,
     41                              uint32_t pid) {
     42   // Initialize the netlink header and generic netlink header.
     43   // NLMSG_HDRLEN and GENL_HDRLEN already include the padding size.
     44   data_.resize(NLMSG_HDRLEN + GENL_HDRLEN, 0);
     45   // Initialize length field.
     46   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
     47   nl_header->nlmsg_len = data_.size();
     48   // Add NLM_F_REQUEST flag.
     49   nl_header->nlmsg_flags = nl_header->nlmsg_flags | NLM_F_REQUEST;
     50   nl_header->nlmsg_type = type;
     51   nl_header->nlmsg_seq = sequence;
     52   nl_header->nlmsg_pid = pid;
     53 
     54   genlmsghdr* genl_header =
     55       reinterpret_cast<genlmsghdr*>(data_.data() + NLMSG_HDRLEN);
     56   genl_header->version = 1;
     57   genl_header->cmd = command;
     58   // genl_header->reserved is aready 0.
     59 }
     60 
     61 bool NL80211Packet::IsValid() const {
     62   // Verify the size of packet.
     63   if (data_.size() < NLMSG_HDRLEN) {
     64     LOG(ERROR) << "Cannot retrieve netlink header.";
     65     return false;
     66   }
     67 
     68   const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
     69 
     70   // If type < NLMSG_MIN_TYPE, this should be a reserved control message,
     71   // which doesn't carry a generic netlink header.
     72   if (GetMessageType() >= NLMSG_MIN_TYPE) {
     73     if (data_.size() < NLMSG_HDRLEN + GENL_HDRLEN ||
     74         nl_header->nlmsg_len < NLMSG_HDRLEN + GENL_HDRLEN) {
     75       LOG(ERROR) << "Cannot retrieve generic netlink header.";
     76       return false;
     77     }
     78   }
     79   // If it is an ERROR message, it should be long enough to carry an extra error
     80   // code field.
     81   // Kernel uses int for this field.
     82   if (GetMessageType() == NLMSG_ERROR) {
     83     if (data_.size() < NLMSG_HDRLEN + sizeof(int) ||
     84         nl_header->nlmsg_len < NLMSG_HDRLEN + sizeof(int)) {
     85      LOG(ERROR) << "Broken error message.";
     86      return false;
     87     }
     88   }
     89 
     90   // Verify the netlink header.
     91   if (data_.size() < nl_header->nlmsg_len ||
     92       nl_header->nlmsg_len < sizeof(nlmsghdr)) {
     93     LOG(ERROR) << "Discarding incomplete / invalid message.";
     94     return false;
     95   }
     96   return true;
     97 }
     98 
     99 bool NL80211Packet::IsDump() const {
    100   return GetFlags() & NLM_F_DUMP;
    101 }
    102 
    103 bool NL80211Packet::IsMulti() const {
    104   return GetFlags() & NLM_F_MULTI;
    105 }
    106 
    107 uint8_t NL80211Packet::GetCommand() const {
    108   const genlmsghdr* genl_header = reinterpret_cast<const genlmsghdr*>(
    109       data_.data() + NLMSG_HDRLEN);
    110   return genl_header->cmd;
    111 }
    112 
    113 uint16_t NL80211Packet::GetFlags() const {
    114   const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
    115   return nl_header->nlmsg_flags;
    116 }
    117 
    118 uint16_t NL80211Packet::GetMessageType() const {
    119   const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
    120   return nl_header->nlmsg_type;
    121 }
    122 
    123 uint32_t NL80211Packet::GetMessageSequence() const {
    124   const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
    125   return nl_header->nlmsg_seq;
    126 }
    127 
    128 uint32_t NL80211Packet::GetPortId() const {
    129   const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
    130   return nl_header->nlmsg_pid;
    131 }
    132 
    133 int NL80211Packet::GetErrorCode() const {
    134   return -*reinterpret_cast<const int*>(data_.data() + NLMSG_HDRLEN);
    135 }
    136 
    137 const vector<uint8_t>& NL80211Packet::GetConstData() const {
    138   return data_;
    139 }
    140 
    141 void NL80211Packet::SetCommand(uint8_t command) {
    142   genlmsghdr* genl_header = reinterpret_cast<genlmsghdr*>(
    143       data_.data() + NLMSG_HDRLEN);
    144   genl_header->cmd = command;
    145 }
    146 
    147 void NL80211Packet::AddFlag(uint16_t flag) {
    148   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
    149   nl_header->nlmsg_flags |= flag;
    150 }
    151 
    152 void NL80211Packet::SetFlags(uint16_t flags) {
    153   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
    154   nl_header->nlmsg_flags = flags;
    155 }
    156 
    157 void NL80211Packet::SetMessageType(uint16_t message_type) {
    158   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
    159   nl_header->nlmsg_type = message_type;
    160 }
    161 
    162 void NL80211Packet::SetMessageSequence(uint32_t message_sequence) {
    163   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
    164   nl_header->nlmsg_seq = message_sequence;
    165 }
    166 
    167 void NL80211Packet::SetPortId(uint32_t pid) {
    168   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
    169   nl_header->nlmsg_pid = pid;
    170 }
    171 
    172 void NL80211Packet::AddAttribute(const BaseNL80211Attr& attribute) {
    173   const vector<uint8_t>& append_data = attribute.GetConstData();
    174   // Append the data of |attribute| to |this|.
    175   data_.insert(data_.end(), append_data.begin(), append_data.end());
    176   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
    177   // We don't need to worry about padding for a nl80211 packet.
    178   // Because as long as all sub attributes have padding, the payload is aligned.
    179   nl_header->nlmsg_len += append_data.size();
    180 }
    181 
    182 void NL80211Packet::AddFlagAttribute(int attribute_id) {
    183   // We only need to append a header for flag attribute.
    184   // Make space for the new attribute.
    185   data_.resize(data_.size() + NLA_HDRLEN, 0);
    186   nlattr* flag_header =
    187       reinterpret_cast<nlattr*>(data_.data() + data_.size() - NLA_HDRLEN);
    188   flag_header->nla_type = attribute_id;
    189   flag_header->nla_len = NLA_HDRLEN;
    190   nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
    191   nl_header->nlmsg_len += NLA_HDRLEN;
    192 }
    193 
    194 bool NL80211Packet::HasAttribute(int id) const {
    195   return BaseNL80211Attr::GetAttributeImpl(
    196       data_.data() + NLMSG_HDRLEN + GENL_HDRLEN,
    197       data_.size() - NLMSG_HDRLEN - GENL_HDRLEN,
    198       id, nullptr, nullptr);
    199 }
    200 
    201 bool NL80211Packet::GetAttribute(int id,
    202     NL80211NestedAttr* attribute) const {
    203   uint8_t* start = nullptr;
    204   uint8_t* end = nullptr;
    205   if (!BaseNL80211Attr::GetAttributeImpl(
    206           data_.data() + NLMSG_HDRLEN + GENL_HDRLEN,
    207           data_.size() - NLMSG_HDRLEN - GENL_HDRLEN,
    208           id, &start, &end) ||
    209       start == nullptr ||
    210       end == nullptr) {
    211     return false;
    212   }
    213   *attribute = NL80211NestedAttr(vector<uint8_t>(start, end));
    214   if (!attribute->IsValid()) {
    215     return false;
    216   }
    217   return true;
    218 }
    219 
    220 bool NL80211Packet::GetAllAttributes(
    221     vector<BaseNL80211Attr>* attributes) const {
    222   const uint8_t* ptr = data_.data() + NLMSG_HDRLEN + GENL_HDRLEN;
    223   const uint8_t* end_ptr = data_.data() + data_.size();
    224   while (ptr + NLA_HDRLEN <= end_ptr) {
    225     auto header = reinterpret_cast<const nlattr*>(ptr);
    226     if (ptr + NLA_ALIGN(header->nla_len) > end_ptr ||
    227       header->nla_len == 0) {
    228       LOG(ERROR) << "broken nl80211 atrribute.";
    229       return false;
    230     }
    231     attributes->emplace_back(
    232         header->nla_type,
    233         vector<uint8_t>(ptr + NLA_HDRLEN, ptr + header->nla_len));
    234     ptr += NLA_ALIGN(header->nla_len);
    235   }
    236   return true;
    237 }
    238 
    239 void NL80211Packet::DebugLog() const {
    240   const uint8_t* ptr = data_.data() + NLMSG_HDRLEN + GENL_HDRLEN;
    241   const uint8_t* end_ptr = data_.data() + data_.size();
    242   while (ptr + NLA_HDRLEN <= end_ptr) {
    243     const nlattr* header = reinterpret_cast<const nlattr*>(ptr);
    244     if (ptr + NLA_ALIGN(header->nla_len) > end_ptr) {
    245       LOG(ERROR) << "broken nl80211 atrribute.";
    246       return;
    247     }
    248     LOG(INFO) << "Have attribute with nla_type=" << header->nla_type
    249               << " and nla_len=" << header->nla_len;
    250     if (header->nla_len == 0) {
    251       LOG(ERROR) << "0 is a bad nla_len";
    252       return;
    253     }
    254     ptr += NLA_ALIGN(header->nla_len);
    255   }
    256 }
    257 
    258 }  // namespace wificond
    259 }  // namespace android
    260