Home | History | Annotate | Download | only in net
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
#include "common/libs/net/netlink_client.h"

#include <errno.h>
#include <linux/rtnetlink.h>
#include <linux/sockios.h>
#include <net/if.h>
#include <string.h>
#include <sys/socket.h>

#include <memory>

#include "common/libs/fs/shared_fd.h"
#include "common/libs/glog/logging.h"
     26 
     27 namespace cvd {
     28 namespace {
     29 // NetlinkClient implementation.
     30 // Talks to libnetlink to apply network changes.
     31 class NetlinkClientImpl : public NetlinkClient {
     32  public:
     33   NetlinkClientImpl() = default;
     34   virtual ~NetlinkClientImpl() = default;
     35 
     36   virtual bool Send(const NetlinkRequest& message);
     37 
     38   // Initialize NetlinkClient instance.
     39   // Open netlink channel and initialize interface list.
     40   // Parameter |type| specifies which netlink target to address, eg.
     41   // NETLINK_ROUTE.
     42   // Returns true, if initialization was successful.
     43   bool OpenNetlink(int type);
     44 
     45  private:
     46   bool CheckResponse(uint32_t seq_no);
     47 
     48   SharedFD netlink_fd_;
     49   sockaddr_nl address_;
     50 };
     51 
     52 bool NetlinkClientImpl::CheckResponse(uint32_t seq_no) {
     53   uint32_t len;
     54   char buf[4096];
     55   struct iovec iov = { buf, sizeof(buf) };
     56   struct sockaddr_nl sa;
     57   struct msghdr msg = { &sa, sizeof(sa), &iov, 1, NULL, 0, 0 };
     58   struct nlmsghdr *nh;
     59 
     60   int result = netlink_fd_->RecvMsg(&msg, 0);
     61   if (result  < 0) {
     62     LOG(ERROR) << "Netlink error: " << strerror(errno);
     63     return false;
     64   }
     65 
     66   len = static_cast<uint32_t>(result);
     67   LOG(INFO) << "Received netlink response (" << len << " bytes)";
     68 
     69   for (nh = reinterpret_cast<nlmsghdr*>(buf);
     70        NLMSG_OK(nh, len);
     71        nh = NLMSG_NEXT(nh, len)) {
     72     if (nh->nlmsg_seq != seq_no) {
     73       // This really shouldn't happen. If we see this, it means somebody is
     74       // issuing netlink requests using the same socket as us, and ignoring
     75       // responses.
     76       LOG(WARNING) << "Sequence number mismatch: "
     77                    << nh->nlmsg_seq << " != " << seq_no;
     78       continue;
     79     }
     80 
     81     // This is the end of multi-part message.
     82     // It indicates there's nothing more netlink wants to tell us.
     83     // It also means we failed to find the response to our request.
     84     if (nh->nlmsg_type == NLMSG_DONE)
     85       break;
     86 
     87     // This is the 'nlmsgerr' package carrying response to our request.
     88     // It carries an 'error' value (errno) along with the netlink header info
     89     // that caused this error.
     90     if (nh->nlmsg_type == NLMSG_ERROR) {
     91       nlmsgerr* err = reinterpret_cast<nlmsgerr*>(nh + 1);
     92       if (err->error < 0) {
     93         LOG(ERROR) << "Failed to complete netlink request: "
     94                    << "Netlink error: " << err->error
     95                    << ", errno: " << strerror(errno);
     96         return false;
     97       }
     98       return true;
     99     }
    100   }
    101 
    102   LOG(ERROR) << "No response from netlink.";
    103   return false;
    104 }
    105 
    106 bool NetlinkClientImpl::Send(const NetlinkRequest& message) {
    107   struct sockaddr_nl netlink_addr;
    108   struct iovec netlink_iov = {
    109     message.RequestData(),
    110     message.RequestLength()
    111   };
    112   struct msghdr msg;
    113   memset(&msg, 0, sizeof(msg));
    114   memset(&netlink_addr, 0, sizeof(netlink_addr));
    115 
    116   msg.msg_name = &address_;
    117   msg.msg_namelen = sizeof(address_);
    118   msg.msg_iov = &netlink_iov;
    119   msg.msg_iovlen = sizeof(netlink_iov) / sizeof(iovec);
    120 
    121   if (netlink_fd_->SendMsg(&msg, 0) < 0) {
    122     LOG(ERROR) << "Failed to send netlink message: "
    123                << strerror(errno);
    124 
    125     return false;
    126   }
    127 
    128   return CheckResponse(message.SeqNo());
    129 }
    130 
    131 bool NetlinkClientImpl::OpenNetlink(int type) {
    132   netlink_fd_ = SharedFD::Socket(AF_NETLINK, SOCK_RAW, type);
    133   if (!netlink_fd_->IsOpen()) return false;
    134 
    135   address_.nl_family = AF_NETLINK;
    136   address_.nl_groups = 0;
    137 
    138   netlink_fd_->Bind(reinterpret_cast<sockaddr*>(&address_), sizeof(address_));
    139 
    140   return true;
    141 }
    142 
    143 class NetlinkClientFactoryImpl : public NetlinkClientFactory {
    144  public:
    145   NetlinkClientFactoryImpl() = default;
    146   ~NetlinkClientFactoryImpl() override = default;
    147 
    148   std::unique_ptr<NetlinkClient> New(int type) override {
    149     auto client_raw = new NetlinkClientImpl();
    150     // Use RVO when possible.
    151     std::unique_ptr<NetlinkClient> client(client_raw);
    152 
    153     if (!client_raw->OpenNetlink(type)) {
    154       // Note: deletes client_raw.
    155       client.reset();
    156     }
    157     return client;
    158   }
    159 };
    160 
    161 }  // namespace
    162 
    163 NetlinkClientFactory* NetlinkClientFactory::Default() {
    164   static NetlinkClientFactory &factory = *new NetlinkClientFactoryImpl();
    165   return &factory;
    166 }
    167 
    168 }  // namespace cvd
    169