Home | History | Annotate | Download | only in net
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
#include "common/libs/net/netlink_client.h"

#include <linux/rtnetlink.h>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <glog/logging.h>

#include <cstring>
#include <iostream>
#include <memory>
     26 
     27 using ::testing::ElementsAreArray;
     28 using ::testing::MatchResultListener;
     29 using ::testing::Return;
     30 
     31 namespace cvd {
     32 namespace {
     33 extern "C" void klog_write(int /* level */, const char* /* format */, ...) {}
     34 
     35 // Dump hex buffer to test log.
     36 void Dump(MatchResultListener* result_listener, const char* title,
     37           const uint8_t* data, size_t length) {
     38   for (size_t item = 0; item < length;) {
     39     *result_listener << title;
     40     do {
     41       result_listener->stream()->width(2);
     42       result_listener->stream()->fill('0');
     43       *result_listener << std::hex << +data[item] << " ";
     44       ++item;
     45     } while (item & 0xf);
     46     *result_listener << "\n";
     47   }
     48 }
     49 
     50 // Compare two memory areas byte by byte, print information about first
     51 // difference. Dumps both bufferst to user log.
     52 bool Compare(MatchResultListener* result_listener,
     53              const uint8_t* exp, const uint8_t* act, size_t length) {
     54   for (size_t index = 0; index < length; ++index) {
     55     if (exp[index] != act[index]) {
     56       *result_listener << "\nUnexpected data at offset " << index << "\n";
     57       Dump(result_listener, "Data Expected: ", exp, length);
     58       Dump(result_listener, "  Data Actual: ", act, length);
     59       return false;
     60     }
     61   }
     62 
     63   return true;
     64 }
     65 
     66 // Matcher validating Netlink Request data.
     67 MATCHER_P2(RequestDataIs, data, length, "Matches expected request data") {
     68   size_t offset = sizeof(nlmsghdr);
     69   if (offset + length != arg.RequestLength()) {
     70     *result_listener << "Unexpected request length: "
     71                      << arg.RequestLength() - offset << " vs " << length;
     72     return false;
     73   }
     74 
     75   // Note: Request begins with header (nlmsghdr). Header is not covered by this
     76   // call.
     77   const uint8_t* exp_data = static_cast<const uint8_t*>(
     78       static_cast<const void*>(data));
     79   const uint8_t* act_data = static_cast<const uint8_t*>(arg.RequestData());
     80   return Compare(
     81       result_listener, exp_data, &act_data[offset], length);
     82 }
     83 
     84 MATCHER_P4(RequestHeaderIs, length, type, flags, seq,
     85            "Matches request header") {
     86   nlmsghdr* header = static_cast<nlmsghdr*>(arg.RequestData());
     87   if (arg.RequestLength() < sizeof(header)) {
     88     *result_listener << "Malformed header: too short.";
     89     return false;
     90   }
     91 
     92   if (header->nlmsg_len != length) {
     93     *result_listener << "Invalid message length: "
     94                      << header->nlmsg_len << " vs " << length;
     95     return false;
     96   }
     97 
     98   if (header->nlmsg_type != type) {
     99     *result_listener << "Invalid header type: "
    100                      << header->nlmsg_type << " vs " << type;
    101     return false;
    102   }
    103 
    104   if (header->nlmsg_flags != flags) {
    105     *result_listener << "Invalid header flags: "
    106                      << header->nlmsg_flags << " vs " << flags;
    107     return false;
    108   }
    109 
    110   if (header->nlmsg_seq != seq) {
    111     *result_listener << "Invalid header sequence number: "
    112                      << header->nlmsg_seq << " vs " << seq;
    113     return false;
    114   }
    115 
    116   return true;
    117 }
    118 }  // namespace
    119 
    120 class NetlinkClientTest : public ::testing::Test {
    121   void SetUp() {
    122     google::InstallFailureSignalHandler();
    123   }
    124  protected:
    125   std::unique_ptr<NetlinkClient> nl_client_;
    126 };
    127 
    128 TEST_F(NetlinkClientTest, BasicStringNode) {
    129   constexpr uint16_t kDummyTag = 0xfce2;
    130   constexpr char kLongString[] = "long string";
    131 
    132   struct {
    133     // 11 bytes of text + padding 0 + 4 bytes of header.
    134     const uint16_t attr_length = 0x10;
    135     const uint16_t attr_type = kDummyTag;
    136     char text[sizeof(kLongString)];  // sizeof includes padding 0.
    137   } expected;
    138 
    139   memcpy(&expected.text, kLongString, sizeof(kLongString));
    140 
    141   cvd::NetlinkRequest request(RTM_SETLINK, 0);
    142   request.AddString(kDummyTag, kLongString);
    143   EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected)));
    144 }
    145 
    146 TEST_F(NetlinkClientTest, BasicIntNode) {
    147   // Basic { Dummy: Value } test.
    148   constexpr uint16_t kDummyTag = 0xfce2;
    149   constexpr int32_t kValue = 0x1badd00d;
    150 
    151   struct {
    152     const uint16_t attr_length = 0x8;  // 4 bytes of value + 4 bytes of header.
    153     const uint16_t attr_type = kDummyTag;
    154     const uint32_t attr_value = kValue;
    155   } expected;
    156 
    157   cvd::NetlinkRequest request(RTM_SETLINK, 0);
    158   request.AddInt32(kDummyTag, kValue);
    159   EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected)));
    160 }
    161 
    162 TEST_F(NetlinkClientTest, SingleList) {
    163   // List: { Dummy: Value}
    164   constexpr uint16_t kDummyTag = 0xfce2;
    165   constexpr uint16_t kListTag = 0xcafe;
    166   constexpr int32_t kValue = 0x1badd00d;
    167 
    168   struct {
    169     const uint16_t list_length = 0xc;
    170     const uint16_t list_type = kListTag;
    171     const uint16_t attr_length = 0x8;  // 4 bytes of value + 4 bytes of header.
    172     const uint16_t attr_type = kDummyTag;
    173     const uint32_t attr_value = kValue;
    174   } expected;
    175 
    176   cvd::NetlinkRequest request(RTM_SETLINK, 0);
    177   request.PushList(kListTag);
    178   request.AddInt32(kDummyTag, kValue);
    179   request.PopList();
    180 
    181   EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected)));
    182 }
    183 
    184 TEST_F(NetlinkClientTest, NestedList) {
    185   // List1: { List2: { Dummy: Value}}
    186   constexpr uint16_t kDummyTag = 0xfce2;
    187   constexpr uint16_t kList1Tag = 0xcafe;
    188   constexpr uint16_t kList2Tag = 0xfeed;
    189   constexpr int32_t kValue = 0x1badd00d;
    190 
    191   struct {
    192     const uint16_t list1_length = 0x10;
    193     const uint16_t list1_type = kList1Tag;
    194     const uint16_t list2_length = 0xc;
    195     const uint16_t list2_type = kList2Tag;
    196     const uint16_t attr_length = 0x8;
    197     const uint16_t attr_type = kDummyTag;
    198     const uint32_t attr_value = kValue;
    199   } expected;
    200 
    201   cvd::NetlinkRequest request(RTM_SETLINK, 0);
    202   request.PushList(kList1Tag);
    203   request.PushList(kList2Tag);
    204   request.AddInt32(kDummyTag, kValue);
    205   request.PopList();
    206   request.PopList();
    207 
    208   EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected)));
    209 }
    210 
    211 TEST_F(NetlinkClientTest, ListSequence) {
    212   // List1: { Dummy1: Value1}, List2: { Dummy2: Value2 }
    213   constexpr uint16_t kDummy1Tag = 0xfce2;
    214   constexpr uint16_t kDummy2Tag = 0xfd38;
    215   constexpr uint16_t kList1Tag = 0xcafe;
    216   constexpr uint16_t kList2Tag = 0xfeed;
    217   constexpr int32_t kValue1 = 0x1badd00d;
    218   constexpr int32_t kValue2 = 0xfee1;
    219 
    220   struct {
    221     const uint16_t list1_length = 0xc;
    222     const uint16_t list1_type = kList1Tag;
    223     const uint16_t attr1_length = 0x8;
    224     const uint16_t attr1_type = kDummy1Tag;
    225     const uint32_t attr1_value = kValue1;
    226     const uint16_t list2_length = 0xc;
    227     const uint16_t list2_type = kList2Tag;
    228     const uint16_t attr2_length = 0x8;
    229     const uint16_t attr2_type = kDummy2Tag;
    230     const uint32_t attr2_value = kValue2;
    231   } expected;
    232 
    233   cvd::NetlinkRequest request(RTM_SETLINK, 0);
    234   request.PushList(kList1Tag);
    235   request.AddInt32(kDummy1Tag, kValue1);
    236   request.PopList();
    237   request.PushList(kList2Tag);
    238   request.AddInt32(kDummy2Tag, kValue2);
    239   request.PopList();
    240 
    241   EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected)));
    242 }
    243 
    244 TEST_F(NetlinkClientTest, ComplexList) {
    245   // List1: { List2: { Dummy1: Value1 }, Dummy2: Value2 }
    246   constexpr uint16_t kDummy1Tag = 0xfce2;
    247   constexpr uint16_t kDummy2Tag = 0xfd38;
    248   constexpr uint16_t kList1Tag = 0xcafe;
    249   constexpr uint16_t kList2Tag = 0xfeed;
    250   constexpr int32_t kValue1 = 0x1badd00d;
    251   constexpr int32_t kValue2 = 0xfee1;
    252 
    253   struct {
    254     const uint16_t list1_length = 0x18;
    255     const uint16_t list1_type = kList1Tag;
    256     const uint16_t list2_length = 0xc;  // Note, this only covers until kValue1.
    257     const uint16_t list2_type = kList2Tag;
    258     const uint16_t attr1_length = 0x8;
    259     const uint16_t attr1_type = kDummy1Tag;
    260     const uint32_t attr1_value = kValue1;
    261     const uint16_t attr2_length = 0x8;
    262     const uint16_t attr2_type = kDummy2Tag;
    263     const uint32_t attr2_value = kValue2;
    264   } expected;
    265 
    266   cvd::NetlinkRequest request(RTM_SETLINK, 0);
    267   request.PushList(kList1Tag);
    268   request.PushList(kList2Tag);
    269   request.AddInt32(kDummy1Tag, kValue1);
    270   request.PopList();
    271   request.AddInt32(kDummy2Tag, kValue2);
    272   request.PopList();
    273 
    274   EXPECT_THAT(request, RequestDataIs(&expected, sizeof(expected)));
    275 }
    276 
    277 TEST_F(NetlinkClientTest, SimpleNetlinkCreateHeader) {
    278   cvd::NetlinkRequest request(RTM_NEWLINK, NLM_F_CREATE | NLM_F_EXCL);
    279   constexpr char kValue[] = "random string";
    280   request.AddString(0, kValue);  // Have something to work with.
    281 
    282   constexpr size_t kMsgLength =
    283       sizeof(nlmsghdr) + sizeof(nlattr) + RTA_ALIGN(sizeof(kValue));
    284   int base_seq = request.SeqNo();
    285 
    286   EXPECT_THAT(request, RequestHeaderIs(
    287       kMsgLength,
    288       RTM_NEWLINK,
    289       NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL | NLM_F_REQUEST,
    290       base_seq));
    291 
    292   cvd::NetlinkRequest request2(RTM_NEWLINK, NLM_F_CREATE | NLM_F_EXCL);
    293   request2.AddString(0, kValue);  // Have something to work with.
    294   EXPECT_THAT(request2, RequestHeaderIs(
    295       kMsgLength,
    296       RTM_NEWLINK,
    297       NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL | NLM_F_REQUEST,
    298       base_seq + 1));
    299 }
    300 
    301 TEST_F(NetlinkClientTest, SimpleNetlinkUpdateHeader) {
    302   cvd::NetlinkRequest request(RTM_SETLINK, 0);
    303   constexpr char kValue[] = "random string";
    304   request.AddString(0, kValue);  // Have something to work with.
    305 
    306   constexpr size_t kMsgLength =
    307       sizeof(nlmsghdr) + sizeof(nlattr) + RTA_ALIGN(sizeof(kValue));
    308   int base_seq = request.SeqNo();
    309 
    310   EXPECT_THAT(request, RequestHeaderIs(
    311       kMsgLength, RTM_SETLINK, NLM_F_REQUEST | NLM_F_ACK, base_seq));
    312 
    313   cvd::NetlinkRequest request2(RTM_SETLINK, 0);
    314   request2.AddString(0, kValue);  // Have something to work with.
    315   EXPECT_THAT(request2, RequestHeaderIs(
    316       kMsgLength, RTM_SETLINK, NLM_F_REQUEST | NLM_F_ACK, base_seq + 1));
    317 }
    318 
    319 }  // namespace cvd
    320