Home | History | Annotate | Download | only in net
      1 //
      2 // Copyright (C) 2012 The Android Open Source Project
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 
     17 #include "shill/net/netlink_socket.h"
     18 
     19 #include <string>
     20 
     21 #include <linux/if_packet.h>
     22 #include <linux/netlink.h>
     23 #include <sys/socket.h>
     24 
     25 #include <base/logging.h>
     26 
     27 #include "shill/net/netlink_message.h"
     28 #include "shill/net/sockets.h"
     29 
     // SOL_NETLINK comes from a newer linux/socket.h than some toolchains ship.
     // Newer system headers already define it, so only supply the value when
     // it is missing instead of unconditionally redefining the macro.
     #ifndef SOL_NETLINK
     #define SOL_NETLINK 270
     #endif
     32 
     33 namespace shill {
     34 
     35 // Keep this large enough to avoid overflows on IPv6 SNM routing update spikes
     36 const int NetlinkSocket::kReceiveBufferSize = 512 * 1024;
     37 
     38 NetlinkSocket::NetlinkSocket() : sequence_number_(0), file_descriptor_(-1) {}
     39 
     40 NetlinkSocket::~NetlinkSocket() {
     41   if (sockets_ && (file_descriptor_ >= 0)) {
     42     sockets_->Close(file_descriptor_);
     43   }
     44 }
     45 
     46 bool NetlinkSocket::Init() {
     47   // Allows for a test to set |sockets_| before calling |Init|.
     48   if (sockets_) {
     49     LOG(INFO) << "|sockets_| already has a value -- this must be a test.";
     50   } else {
     51     sockets_.reset(new Sockets);
     52   }
     53 
     54   // The following is stolen directly from RTNLHandler.
     55   // TODO(wdg): refactor this and RTNLHandler together to use common code.
     56   // crbug.com/221940
     57 
     58   file_descriptor_ = sockets_->Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC);
     59   if (file_descriptor_ < 0) {
     60     LOG(ERROR) << "Failed to open netlink socket";
     61     return false;
     62   }
     63 
     64   if (sockets_->SetReceiveBuffer(file_descriptor_, kReceiveBufferSize)) {
     65     LOG(ERROR) << "Failed to increase receive buffer size";
     66   }
     67 
     68   struct sockaddr_nl addr;
     69   memset(&addr, 0, sizeof(addr));
     70   addr.nl_family = AF_NETLINK;
     71 
     72   if (sockets_->Bind(file_descriptor_,
     73                     reinterpret_cast<struct sockaddr*>(&addr),
     74                     sizeof(addr)) < 0) {
     75     sockets_->Close(file_descriptor_);
     76     file_descriptor_ = -1;
     77     LOG(ERROR) << "Netlink socket bind failed";
     78     return false;
     79   }
     80   VLOG(2) << "Netlink socket started";
     81 
     82   return true;
     83 }
     84 
     85 bool NetlinkSocket::RecvMessage(ByteString* message) {
     86   if (!message) {
     87     LOG(ERROR) << "Null |message|";
     88     return false;
     89   }
     90 
     91   // Determine the amount of data currently waiting.
     92   const size_t kDummyReadByteCount = 1;
     93   ByteString dummy_read(kDummyReadByteCount);
     94   ssize_t result;
     95   result = sockets_->RecvFrom(
     96       file_descriptor_,
     97       dummy_read.GetData(),
     98       dummy_read.GetLength(),
     99       MSG_TRUNC | MSG_PEEK,
    100       nullptr,
    101       nullptr);
    102   if (result < 0) {
    103     PLOG(ERROR) << "Socket recvfrom failed.";
    104     return false;
    105   }
    106 
    107   // Read the data that was waiting when we did our previous read.
    108   message->Resize(result);
    109   result = sockets_->RecvFrom(
    110       file_descriptor_,
    111       message->GetData(),
    112       message->GetLength(),
    113       0,
    114       nullptr,
    115       nullptr);
    116   if (result < 0) {
    117     PLOG(ERROR) << "Second socket recvfrom failed.";
    118     return false;
    119   }
    120   return true;
    121 }
    122 
    123 bool NetlinkSocket::SendMessage(const ByteString& out_msg) {
    124   ssize_t result = sockets_->Send(file_descriptor(), out_msg.GetConstData(),
    125                                   out_msg.GetLength(), 0);
    126   if (!result) {
    127     PLOG(ERROR) << "Send failed.";
    128     return false;
    129   }
    130   if (result != static_cast<ssize_t>(out_msg.GetLength())) {
    131     LOG(ERROR) << "Only sent " << result << " bytes out of "
    132                << out_msg.GetLength() << ".";
    133     return false;
    134   }
    135 
    136   return true;
    137 }
    138 
    139 bool NetlinkSocket::SubscribeToEvents(uint32_t group_id) {
    140   int err = setsockopt(file_descriptor_, SOL_NETLINK, NETLINK_ADD_MEMBERSHIP,
    141                        &group_id, sizeof(group_id));
    142   if (err < 0) {
    143     PLOG(ERROR) << "setsockopt didn't work.";
    144     return false;
    145   }
    146   return true;
    147 }
    148 
    149 uint32_t NetlinkSocket::GetSequenceNumber() {
    150   if (++sequence_number_ == NetlinkMessage::kBroadcastSequenceNumber)
    151     ++sequence_number_;
    152   return sequence_number_;
    153 }
    154 
    155 }  // namespace shill.
    156