Home | History | Annotate | Download | only in net
      1 //
      2 // Copyright (C) 2015 The Android Open Source Project
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 
     17 #include "shill/net/netlink_packet.h"
     18 
     19 #include <algorithm>
     20 
     21 #include <base/logging.h>
     22 
     23 #include "shill/net/byte_string.h"
     24 
     25 namespace shill {
     26 
     27 NetlinkPacket::NetlinkPacket(const unsigned char* buf, size_t len)
     28     : consumed_bytes_(0) {
     29   if (!buf || len < sizeof(header_)) {
     30     LOG(ERROR) << "Cannot retrieve header.";
     31     return;
     32   }
     33 
     34   memcpy(&header_, buf, sizeof(header_));
     35   if (len < header_.nlmsg_len || header_.nlmsg_len < sizeof(header_)) {
     36     LOG(ERROR) << "Discarding incomplete / invalid message.";
     37     return;
     38   }
     39 
     40   payload_.reset(
     41       new ByteString(buf + sizeof(header_), len - sizeof(header_)));
     42 }
     43 
     44 NetlinkPacket::~NetlinkPacket() {
     45 }
     46 
     47 bool NetlinkPacket::IsValid() const {
     48   return payload_ != nullptr;
     49 }
     50 
     51 size_t NetlinkPacket::GetLength() const {
     52   return GetNlMsgHeader().nlmsg_len;
     53 }
     54 
     55 uint16_t NetlinkPacket::GetMessageType() const {
     56   return GetNlMsgHeader().nlmsg_type;
     57 }
     58 
     59 uint32_t NetlinkPacket::GetMessageSequence() const {
     60   return GetNlMsgHeader().nlmsg_seq;
     61 }
     62 
     63 size_t NetlinkPacket::GetRemainingLength() const {
     64   return GetPayload().GetLength() - consumed_bytes_;
     65 }
     66 
     67 const ByteString& NetlinkPacket::GetPayload() const {
     68   CHECK(IsValid());
     69   return *payload_.get();
     70 }
     71 
     72 bool NetlinkPacket::ConsumeAttributes(
     73     const AttributeList::NewFromIdMethod& factory,
     74     const AttributeListRefPtr& attributes) {
     75   bool result = attributes->Decode(GetPayload(), consumed_bytes_, factory);
     76   consumed_bytes_ = GetPayload().GetLength();
     77   return result;
     78 }
     79 
     80 bool NetlinkPacket::ConsumeData(size_t len, void* data) {
     81   if (GetRemainingLength() < len) {
     82     LOG(ERROR) << "Not enough bytes remaining.";
     83     return false;
     84   }
     85 
     86   memcpy(data, payload_->GetData() + consumed_bytes_, len);
     87   consumed_bytes_ = std::min(payload_->GetLength(),
     88                              consumed_bytes_ + NLMSG_ALIGN(len));
     89   return true;
     90 }
     91 
     92 
     93 const nlmsghdr& NetlinkPacket::GetNlMsgHeader() const {
     94   CHECK(IsValid());
     95   return header_;
     96 }
     97 
     98 bool NetlinkPacket::GetGenlMsgHdr(genlmsghdr* header) const {
     99   if (GetPayload().GetLength() < sizeof(*header)) {
    100     return false;
    101   }
    102   memcpy(header, payload_->GetConstData(), sizeof(*header));
    103   return true;
    104 }
    105 
    106 MutableNetlinkPacket::MutableNetlinkPacket(const unsigned char* buf, size_t len)
    107     : NetlinkPacket(buf, len) {
    108 }
    109 
    110 MutableNetlinkPacket::~MutableNetlinkPacket() {
    111 }
    112 
    113 void MutableNetlinkPacket::ResetConsumedBytes() {
    114   set_consumed_bytes(0);
    115 }
    116 
    117 nlmsghdr* MutableNetlinkPacket::GetMutableHeader() {
    118   CHECK(IsValid());
    119   return mutable_header();
    120 }
    121 
    122 ByteString* MutableNetlinkPacket::GetMutablePayload() {
    123   CHECK(IsValid());
    124   return mutable_payload();
    125 }
    126 
    127 void MutableNetlinkPacket::SetMessageType(uint16_t type) {
    128   mutable_header()->nlmsg_type = type;
    129 }
    130 
    131 void MutableNetlinkPacket::SetMessageSequence(uint32_t sequence) {
    132   mutable_header()->nlmsg_seq = sequence;
    133 }
    134 
    135 }  // namespace shill.
    136