Home | History | Annotate | Download | only in net
      1 //
      2 // Copyright (C) 2012 The Android Open Source Project
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 
     17 #include "shill/net/rtnl_handler.h"
     18 
     19 #include <string>
     20 
     21 #include <gtest/gtest.h>
     22 #include <net/if.h>
     23 #include <sys/socket.h>
     24 #include <linux/netlink.h>  // Needs typedefs from sys/socket.h.
     25 #include <linux/rtnetlink.h>
     26 #include <sys/ioctl.h>
     27 
     28 #include <base/bind.h>
     29 
     30 #include "shill/mock_log.h"
     31 #include "shill/net/mock_io_handler_factory.h"
     32 #include "shill/net/mock_sockets.h"
     33 #include "shill/net/rtnl_message.h"
     34 
     35 using base::Bind;
     36 using base::Callback;
     37 using base::Unretained;
     38 using std::string;
     39 using testing::_;
     40 using testing::A;
     41 using testing::DoAll;
     42 using testing::ElementsAre;
     43 using testing::HasSubstr;
     44 using testing::Return;
     45 using testing::ReturnArg;
     46 using testing::StrictMock;
     47 using testing::Test;
     48 
     49 namespace shill {
     50 
     51 namespace {
     52 
     53 const int kTestInterfaceIndex = 4;
     54 
     55 ACTION(SetInterfaceIndex) {
     56   if (arg2) {
     57     reinterpret_cast<struct ifreq*>(arg2)->ifr_ifindex = kTestInterfaceIndex;
     58   }
     59 }
     60 
     61 MATCHER_P(MessageType, message_type, "") {
     62   return std::get<0>(arg).type() == message_type;
     63 }
     64 
     65 }  // namespace
     66 
     67 class RTNLHandlerTest : public Test {
     68  public:
     69   RTNLHandlerTest()
     70       : sockets_(new StrictMock<MockSockets>()),
     71         callback_(Bind(&RTNLHandlerTest::HandlerCallback, Unretained(this))),
     72         dummy_message_(RTNLMessage::kTypeLink,
     73                        RTNLMessage::kModeGet,
     74                        0,
     75                        0,
     76                        0,
     77                        0,
     78                        IPAddress::kFamilyUnknown) {
     79   }
     80 
     81   virtual void SetUp() {
     82     RTNLHandler::GetInstance()->io_handler_factory_ = &io_handler_factory_;
     83     RTNLHandler::GetInstance()->sockets_.reset(sockets_);
     84   }
     85 
     86   virtual void TearDown() {
     87     RTNLHandler::GetInstance()->Stop();
     88   }
     89 
     90   uint32_t GetRequestSequence() {
     91     return RTNLHandler::GetInstance()->request_sequence_;
     92   }
     93 
     94   void SetRequestSequence(uint32_t sequence) {
     95     RTNLHandler::GetInstance()->request_sequence_ = sequence;
     96   }
     97 
     98   bool IsSequenceInErrorMaskWindow(uint32_t sequence) {
     99     return RTNLHandler::GetInstance()->IsSequenceInErrorMaskWindow(sequence);
    100   }
    101 
    102   void SetErrorMask(uint32_t sequence,
    103                     const RTNLHandler::ErrorMask& error_mask) {
    104     return RTNLHandler::GetInstance()->SetErrorMask(sequence, error_mask);
    105   }
    106 
    107   RTNLHandler::ErrorMask GetAndClearErrorMask(uint32_t sequence) {
    108     return RTNLHandler::GetInstance()->GetAndClearErrorMask(sequence);
    109   }
    110 
    111   int GetErrorWindowSize() {
    112     return  RTNLHandler::kErrorWindowSize;
    113   }
    114 
    115   MOCK_METHOD1(HandlerCallback, void(const RTNLMessage&));
    116 
    117  protected:
    118   static const int kTestSocket;
    119   static const int kTestDeviceIndex;
    120   static const char kTestDeviceName[];
    121 
    122   void AddLink();
    123   void AddNeighbor();
    124   void StartRTNLHandler();
    125   void StopRTNLHandler();
    126   void ReturnError(uint32_t sequence, int error_number);
    127 
    128   MockSockets* sockets_;
    129   StrictMock<MockIOHandlerFactory> io_handler_factory_;
    130   Callback<void(const RTNLMessage&)> callback_;
    131   RTNLMessage dummy_message_;
    132 };
    133 
    134 const int RTNLHandlerTest::kTestSocket = 123;
    135 const int RTNLHandlerTest::kTestDeviceIndex = 123456;
    136 const char RTNLHandlerTest::kTestDeviceName[] = "test-device";
    137 
    138 void RTNLHandlerTest::StartRTNLHandler() {
    139   EXPECT_CALL(*sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_ROUTE))
    140       .WillOnce(Return(kTestSocket));
    141   EXPECT_CALL(*sockets_, Bind(kTestSocket, _, sizeof(sockaddr_nl)))
    142       .WillOnce(Return(0));
    143   EXPECT_CALL(*sockets_, SetReceiveBuffer(kTestSocket, _)).WillOnce(Return(0));
    144   EXPECT_CALL(io_handler_factory_, CreateIOInputHandler(kTestSocket, _, _));
    145   RTNLHandler::GetInstance()->Start(0);
    146 }
    147 
    148 void RTNLHandlerTest::StopRTNLHandler() {
    149   EXPECT_CALL(*sockets_, Close(kTestSocket)).WillOnce(Return(0));
    150   RTNLHandler::GetInstance()->Stop();
    151 }
    152 
    153 void RTNLHandlerTest::AddLink() {
    154   RTNLMessage message(RTNLMessage::kTypeLink,
    155                       RTNLMessage::kModeAdd,
    156                       0,
    157                       0,
    158                       0,
    159                       kTestDeviceIndex,
    160                       IPAddress::kFamilyIPv4);
    161   message.SetAttribute(static_cast<uint16_t>(IFLA_IFNAME),
    162                        ByteString(string(kTestDeviceName), true));
    163   ByteString b(message.Encode());
    164   InputData data(b.GetData(), b.GetLength());
    165   RTNLHandler::GetInstance()->ParseRTNL(&data);
    166 }
    167 
    168 void RTNLHandlerTest::AddNeighbor() {
    169   RTNLMessage message(RTNLMessage::kTypeNeighbor,
    170                       RTNLMessage::kModeAdd,
    171                       0,
    172                       0,
    173                       0,
    174                       kTestDeviceIndex,
    175                       IPAddress::kFamilyIPv4);
    176   ByteString encoded(message.Encode());
    177   InputData data(encoded.GetData(), encoded.GetLength());
    178   RTNLHandler::GetInstance()->ParseRTNL(&data);
    179 }
    180 
    181 void RTNLHandlerTest::ReturnError(uint32_t sequence, int error_number) {
    182   struct {
    183     struct nlmsghdr hdr;
    184     struct nlmsgerr err;
    185   } errmsg;
    186 
    187   memset(&errmsg, 0, sizeof(errmsg));
    188   errmsg.hdr.nlmsg_type = NLMSG_ERROR;
    189   errmsg.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(errmsg.err));
    190   errmsg.hdr.nlmsg_seq = sequence;
    191   errmsg.err.error = -error_number;
    192 
    193   InputData data(reinterpret_cast<unsigned char*>(&errmsg), sizeof(errmsg));
    194   RTNLHandler::GetInstance()->ParseRTNL(&data);
    195 }
    196 
    197 TEST_F(RTNLHandlerTest, ListenersInvoked) {
    198   StartRTNLHandler();
    199 
    200   std::unique_ptr<RTNLListener> link_listener(
    201       new RTNLListener(RTNLHandler::kRequestLink, callback_));
    202   std::unique_ptr<RTNLListener> neighbor_listener(
    203       new RTNLListener(RTNLHandler::kRequestNeighbor, callback_));
    204 
    205   EXPECT_CALL(*this, HandlerCallback(A<const RTNLMessage&>()))
    206       .With(MessageType(RTNLMessage::kTypeLink));
    207   EXPECT_CALL(*this, HandlerCallback(A<const RTNLMessage&>()))
    208       .With(MessageType(RTNLMessage::kTypeNeighbor));
    209 
    210   AddLink();
    211   AddNeighbor();
    212 
    213   StopRTNLHandler();
    214 }
    215 
    216 TEST_F(RTNLHandlerTest, GetInterfaceName) {
    217   EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex(""));
    218   {
    219     struct ifreq ifr;
    220     string name(sizeof(ifr.ifr_name), 'x');
    221     EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex(name));
    222   }
    223 
    224   const int kTestSocket = 123;
    225   EXPECT_CALL(*sockets_, Socket(PF_INET, SOCK_DGRAM, 0))
    226       .Times(3)
    227       .WillOnce(Return(-1))
    228       .WillRepeatedly(Return(kTestSocket));
    229   EXPECT_CALL(*sockets_, Ioctl(kTestSocket, SIOCGIFINDEX, _))
    230       .WillOnce(Return(-1))
    231       .WillOnce(DoAll(SetInterfaceIndex(), Return(0)));
    232   EXPECT_CALL(*sockets_, Close(kTestSocket))
    233       .Times(2)
    234       .WillRepeatedly(Return(0));
    235   EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex("eth0"));
    236   EXPECT_EQ(-1, RTNLHandler::GetInstance()->GetInterfaceIndex("wlan0"));
    237   EXPECT_EQ(kTestInterfaceIndex,
    238             RTNLHandler::GetInstance()->GetInterfaceIndex("usb0"));
    239 }
    240 
    241 TEST_F(RTNLHandlerTest, IsSequenceInErrorMaskWindow) {
    242   const uint32_t kRequestSequence = 1234;
    243   SetRequestSequence(kRequestSequence);
    244   EXPECT_FALSE(IsSequenceInErrorMaskWindow(kRequestSequence + 1));
    245   EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence));
    246   EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence - 1));
    247   EXPECT_TRUE(IsSequenceInErrorMaskWindow(kRequestSequence -
    248                                           GetErrorWindowSize() + 1));
    249   EXPECT_FALSE(IsSequenceInErrorMaskWindow(kRequestSequence -
    250                                            GetErrorWindowSize()));
    251   EXPECT_FALSE(IsSequenceInErrorMaskWindow(kRequestSequence -
    252                                            GetErrorWindowSize() - 1));
    253 }
    254 
    255 TEST_F(RTNLHandlerTest, SendMessageReturnsErrorAndAdvancesSequenceNumber) {
    256   StartRTNLHandler();
    257   const uint32_t kSequenceNumber = 123;
    258   SetRequestSequence(kSequenceNumber);
    259   EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(Return(-1));
    260   EXPECT_FALSE(RTNLHandler::GetInstance()->SendMessage(&dummy_message_));
    261 
    262   // Sequence number should still increment even if there was a failure.
    263   EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence());
    264   StopRTNLHandler();
    265 }
    266 
    267 TEST_F(RTNLHandlerTest, SendMessageWithEmptyMask) {
    268   StartRTNLHandler();
    269   const uint32_t kSequenceNumber = 123;
    270   SetRequestSequence(kSequenceNumber);
    271   SetErrorMask(kSequenceNumber, {1, 2, 3});
    272   EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>());
    273   EXPECT_TRUE(RTNLHandler::GetInstance()->SendMessageWithErrorMask(
    274       &dummy_message_, {}));
    275   EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence());
    276   EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber).empty());
    277   StopRTNLHandler();
    278 }
    279 
    280 TEST_F(RTNLHandlerTest, SendMessageWithErrorMask) {
    281   StartRTNLHandler();
    282   const uint32_t kSequenceNumber = 123;
    283   SetRequestSequence(kSequenceNumber);
    284   EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>());
    285   EXPECT_TRUE(RTNLHandler::GetInstance()->SendMessageWithErrorMask(
    286       &dummy_message_, {1, 2, 3}));
    287   EXPECT_EQ(kSequenceNumber + 1, GetRequestSequence());
    288   EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber + 1).empty());
    289   EXPECT_THAT(GetAndClearErrorMask(kSequenceNumber), ElementsAre(1, 2, 3));
    290 
    291   // A second call to GetAndClearErrorMask() returns an empty vector.
    292   EXPECT_TRUE(GetAndClearErrorMask(kSequenceNumber).empty());
    293   StopRTNLHandler();
    294 }
    295 
    296 TEST_F(RTNLHandlerTest, SendMessageInferredErrorMasks) {
    297   struct {
    298     RTNLMessage::Type type;
    299     RTNLMessage::Mode mode;
    300     RTNLHandler::ErrorMask mask;
    301   } expectations[] = {
    302     { RTNLMessage::kTypeLink, RTNLMessage::kModeGet, {} },
    303     { RTNLMessage::kTypeLink, RTNLMessage::kModeAdd, {EEXIST} },
    304     { RTNLMessage::kTypeLink, RTNLMessage::kModeDelete, {ESRCH, ENODEV} },
    305     { RTNLMessage::kTypeAddress, RTNLMessage::kModeDelete,
    306          {ESRCH, ENODEV, EADDRNOTAVAIL} }
    307   };
    308   const uint32_t kSequenceNumber = 123;
    309   EXPECT_CALL(*sockets_, Send(_, _, _, 0)).WillRepeatedly(ReturnArg<2>());
    310   for (const auto& expectation : expectations) {
    311     SetRequestSequence(kSequenceNumber);
    312     RTNLMessage message(expectation.type,
    313                         expectation.mode,
    314                         0,
    315                         0,
    316                         0,
    317                         0,
    318                         IPAddress::kFamilyUnknown);
    319     EXPECT_TRUE(RTNLHandler::GetInstance()->SendMessage(&message));
    320     EXPECT_EQ(expectation.mask, GetAndClearErrorMask(kSequenceNumber));
    321   }
    322 }
    323 
    324 TEST_F(RTNLHandlerTest, MaskedError) {
    325   StartRTNLHandler();
    326   const uint32_t kSequenceNumber = 123;
    327   SetRequestSequence(kSequenceNumber);
    328   EXPECT_CALL(*sockets_, Send(kTestSocket, _, _, 0)).WillOnce(ReturnArg<2>());
    329   EXPECT_TRUE(RTNLHandler::GetInstance()->SendMessageWithErrorMask(
    330       &dummy_message_, {1, 2, 3}));
    331   ScopedMockLog log;
    332 
    333   // This error will be not be masked since this sequence number has no mask.
    334   EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 1"))).Times(1);
    335   ReturnError(kSequenceNumber - 1, 1);
    336 
    337   // This error will be masked.
    338   EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 2"))).Times(0);
    339   ReturnError(kSequenceNumber, 2);
    340 
    341   // This second error will be not be masked since the error mask was removed.
    342   EXPECT_CALL(log, Log(logging::LOG_ERROR, _, HasSubstr("error 3"))).Times(1);
    343   ReturnError(kSequenceNumber, 3);
    344 
    345   StopRTNLHandler();
    346 }
    347 
    348 }  // namespace shill
    349