1 // 2 // Copyright (C) 2015 The Android Open Source Project 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 17 #include "shill/net/netlink_packet.h" 18 19 #include <algorithm> 20 21 #include <base/logging.h> 22 23 #include "shill/net/byte_string.h" 24 25 namespace shill { 26 27 NetlinkPacket::NetlinkPacket(const unsigned char* buf, size_t len) 28 : consumed_bytes_(0) { 29 if (!buf || len < sizeof(header_)) { 30 LOG(ERROR) << "Cannot retrieve header."; 31 return; 32 } 33 34 memcpy(&header_, buf, sizeof(header_)); 35 if (len < header_.nlmsg_len || header_.nlmsg_len < sizeof(header_)) { 36 LOG(ERROR) << "Discarding incomplete / invalid message."; 37 return; 38 } 39 40 payload_.reset( 41 new ByteString(buf + sizeof(header_), len - sizeof(header_))); 42 } 43 44 NetlinkPacket::~NetlinkPacket() { 45 } 46 47 bool NetlinkPacket::IsValid() const { 48 return payload_ != nullptr; 49 } 50 51 size_t NetlinkPacket::GetLength() const { 52 return GetNlMsgHeader().nlmsg_len; 53 } 54 55 uint16_t NetlinkPacket::GetMessageType() const { 56 return GetNlMsgHeader().nlmsg_type; 57 } 58 59 uint32_t NetlinkPacket::GetMessageSequence() const { 60 return GetNlMsgHeader().nlmsg_seq; 61 } 62 63 size_t NetlinkPacket::GetRemainingLength() const { 64 return GetPayload().GetLength() - consumed_bytes_; 65 } 66 67 const ByteString& NetlinkPacket::GetPayload() const { 68 CHECK(IsValid()); 69 return *payload_.get(); 70 } 71 72 bool NetlinkPacket::ConsumeAttributes( 73 const AttributeList::NewFromIdMethod& factory, 74 const AttributeListRefPtr& attributes) { 75 bool result = attributes->Decode(GetPayload(), consumed_bytes_, factory); 76 consumed_bytes_ = GetPayload().GetLength(); 77 return result; 78 } 79 80 bool NetlinkPacket::ConsumeData(size_t len, void* data) { 81 if (GetRemainingLength() < len) { 82 LOG(ERROR) << "Not enough bytes remaining."; 83 return false; 84 } 85 86 memcpy(data, payload_->GetData() + consumed_bytes_, len); 87 consumed_bytes_ = std::min(payload_->GetLength(), 88 consumed_bytes_ + NLMSG_ALIGN(len)); 89 return true; 90 } 91 92 93 const nlmsghdr& NetlinkPacket::GetNlMsgHeader() const { 94 CHECK(IsValid()); 95 return header_; 96 } 97 98 bool NetlinkPacket::GetGenlMsgHdr(genlmsghdr* header) const { 99 if (GetPayload().GetLength() < sizeof(*header)) { 100 return false; 101 } 102 memcpy(header, payload_->GetConstData(), sizeof(*header)); 103 return true; 104 } 105 106 MutableNetlinkPacket::MutableNetlinkPacket(const unsigned char* buf, size_t len) 107 : NetlinkPacket(buf, len) { 108 } 109 110 MutableNetlinkPacket::~MutableNetlinkPacket() { 111 } 112 113 void MutableNetlinkPacket::ResetConsumedBytes() { 114 set_consumed_bytes(0); 115 } 116 117 nlmsghdr* MutableNetlinkPacket::GetMutableHeader() { 118 CHECK(IsValid()); 119 return mutable_header(); 120 } 121 122 ByteString* MutableNetlinkPacket::GetMutablePayload() { 123 CHECK(IsValid()); 124 return mutable_payload(); 125 } 126 127 void MutableNetlinkPacket::SetMessageType(uint16_t type) { 128 mutable_header()->nlmsg_type = type; 129 } 130 131 void MutableNetlinkPacket::SetMessageSequence(uint32_t sequence) { 132 mutable_header()->nlmsg_seq = sequence; 133 } 134 135 } // namespace shill. 136