Home | History | Annotate | Download | only in server
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include <linux/netfilter/nfnetlink_log.h>
     18 
     19 #include <arpa/inet.h>
     20 #include <sys/socket.h>
     21 #include <netinet/in.h>
     22 #include <netinet/ip.h>
     23 #include <netinet/tcp.h>
     24 
     25 #include <gmock/gmock.h>
     26 #include <gtest/gtest.h>
     27 
     28 #include "NetlinkManager.h"
     29 #include "WakeupController.h"
     30 
     31 using ::testing::StrictMock;
     32 using ::testing::Test;
     33 using ::testing::DoAll;
     34 using ::testing::SaveArg;
     35 using ::testing::Return;
     36 using ::testing::_;
     37 
     38 namespace android {
     39 namespace net {
     40 
     41 const uint32_t kDefaultPacketCopyRange = WakeupController::kDefaultPacketCopyRange;
     42 
     43 using netdutils::status::ok;
     44 
     45 class MockNetdEventListener {
     46   public:
     47     MOCK_METHOD10(onWakeupEvent, void(
     48             const std::string& prefix, int uid, int ether, int ipNextHeader,
     49             std::vector<uint8_t> dstHw, const std::string& srcIp, const std::string& dstIp,
     50             int srcPort, int dstPort, uint64_t timestampNs));
     51 };
     52 
     53 class MockIptablesRestore : public IptablesRestoreInterface {
     54   public:
     55     ~MockIptablesRestore() override = default;
     56     MOCK_METHOD3(execute, int(const IptablesTarget target, const std::string& commands,
     57                               std::string* output));
     58 };
     59 
     60 class MockNFLogListener : public NFLogListenerInterface {
     61   public:
     62     ~MockNFLogListener() override = default;
     63     MOCK_METHOD2(subscribe, netdutils::Status(uint16_t nfLogGroup, const DispatchFn& fn));
     64     MOCK_METHOD3(subscribe,
     65             netdutils::Status(uint16_t nfLogGroup, uint32_t copyRange, const DispatchFn& fn));
     66     MOCK_METHOD1(unsubscribe, netdutils::Status(uint16_t nfLogGroup));
     67 };
     68 
     69 class WakeupControllerTest : public Test {
     70   protected:
     71     WakeupControllerTest() {
     72         EXPECT_CALL(mListener,
     73             subscribe(NetlinkManager::NFLOG_WAKEUP_GROUP, kDefaultPacketCopyRange, _))
     74             .WillOnce(DoAll(SaveArg<2>(&mMessageHandler), Return(ok)));
     75         EXPECT_CALL(mListener,
     76             unsubscribe(NetlinkManager::NFLOG_WAKEUP_GROUP)).WillOnce(Return(ok));
     77         mController.init(&mListener);
     78     }
     79 
     80     StrictMock<MockNetdEventListener> mEventListener;
     81     StrictMock<MockIptablesRestore> mIptables;
     82     StrictMock<MockNFLogListener> mListener;
     83     WakeupController mController{
     84         [this](const WakeupController::ReportArgs& args) {
     85             mEventListener.onWakeupEvent(args.prefix, args.uid, args.ethertype, args.ipNextHeader,
     86                                          args.dstHw, args.srcIp, args.dstIp, args.srcPort,
     87                                          args.dstPort, args.timestampNs);
     88         },
     89         &mIptables};
     90     NFLogListenerInterface::DispatchFn mMessageHandler;
     91 };
     92 
     93 TEST_F(WakeupControllerTest, msgHandlerWithPartialAttributes) {
     94     const char kPrefix[] = "test:prefix";
     95     const uid_t kUid = 8734;
     96     const gid_t kGid = 2222;
     97     const uint64_t kNsPerS = 1000000000ULL;
     98     const uint64_t kTsNs = 9999 + (34 * kNsPerS);
     99 
    100     struct Msg {
    101         nlmsghdr nlmsg;
    102         nfgenmsg nfmsg;
    103         nlattr uidAttr;
    104         uid_t uid;
    105         nlattr gidAttr;
    106         gid_t gid;
    107         nlattr tsAttr;
    108         timespec ts;
    109         nlattr prefixAttr;
    110         char prefix[sizeof(kPrefix)];
    111     } msg = {};
    112 
    113     msg.uidAttr.nla_type = NFULA_UID;
    114     msg.uidAttr.nla_len = sizeof(msg.uidAttr) + sizeof(msg.uid);
    115     msg.uid = htonl(kUid);
    116 
    117     msg.gidAttr.nla_type = NFULA_GID;
    118     msg.gidAttr.nla_len = sizeof(msg.gidAttr) + sizeof(msg.gid);
    119     msg.gid = htonl(kGid);
    120 
    121     msg.tsAttr.nla_type = NFULA_TIMESTAMP;
    122     msg.tsAttr.nla_len = sizeof(msg.tsAttr) + sizeof(msg.ts);
    123     msg.ts.tv_sec = htonl(kTsNs / kNsPerS);
    124     msg.ts.tv_nsec = htonl(kTsNs % kNsPerS);
    125 
    126     msg.prefixAttr.nla_type = NFULA_PREFIX;
    127     msg.prefixAttr.nla_len = sizeof(msg.prefixAttr) + sizeof(msg.prefix);
    128     memcpy(msg.prefix, kPrefix, sizeof(kPrefix));
    129 
    130     auto payload = drop(netdutils::makeSlice(msg), offsetof(Msg, uidAttr));
    131     EXPECT_CALL(mEventListener,
    132             onWakeupEvent(kPrefix, kUid, -1, -1, std::vector<uint8_t>(), "", "", -1, -1, kTsNs));
    133     mMessageHandler(msg.nlmsg, msg.nfmsg, payload);
    134 }
    135 
    136 TEST_F(WakeupControllerTest, msgHandler) {
    137     const char kPrefix[] = "test:prefix";
    138     const uid_t kUid = 8734;
    139     const gid_t kGid = 2222;
    140     const std::vector<uint8_t> kMacAddr = {11, 22, 33, 44, 55, 66};
    141     const char* kSrcIpAddr = "192.168.2.1";
    142     const char* kDstIpAddr = "192.168.2.23";
    143     const uint16_t kEthertype = 0x800;
    144     const uint8_t kIpNextHeader = 6;
    145     const uint16_t kSrcPort = 1238;
    146     const uint16_t kDstPort = 4567;
    147     const uint64_t kNsPerS = 1000000000ULL;
    148     const uint64_t kTsNs = 9999 + (34 * kNsPerS);
    149 
    150     struct Msg {
    151         nlmsghdr nlmsg;
    152         nfgenmsg nfmsg;
    153         nlattr uidAttr;
    154         uid_t uid;
    155         nlattr gidAttr;
    156         gid_t gid;
    157         nlattr tsAttr;
    158         timespec ts;
    159         nlattr prefixAttr;
    160         char prefix[sizeof(kPrefix)];
    161         nlattr packetHeaderAttr;
    162         struct nfulnl_msg_packet_hdr packetHeader;
    163         nlattr hardwareAddrAttr;
    164         struct nfulnl_msg_packet_hw hardwareAddr;
    165         nlattr packetPayloadAttr;
    166         struct iphdr ipHeader;
    167         struct tcphdr tcpHeader;
    168     } msg = {};
    169 
    170     msg.prefixAttr.nla_type = NFULA_PREFIX;
    171     msg.prefixAttr.nla_len = sizeof(msg.prefixAttr) + sizeof(msg.prefix);
    172     memcpy(msg.prefix, kPrefix, sizeof(kPrefix));
    173 
    174     msg.uidAttr.nla_type = NFULA_UID;
    175     msg.uidAttr.nla_len = sizeof(msg.uidAttr) + sizeof(msg.uid);
    176     msg.uid = htonl(kUid);
    177 
    178     msg.gidAttr.nla_type = NFULA_GID;
    179     msg.gidAttr.nla_len = sizeof(msg.gidAttr) + sizeof(msg.gid);
    180     msg.gid = htonl(kGid);
    181 
    182     msg.tsAttr.nla_type = NFULA_TIMESTAMP;
    183     msg.tsAttr.nla_len = sizeof(msg.tsAttr) + sizeof(msg.ts);
    184     msg.ts.tv_sec = htonl(kTsNs / kNsPerS);
    185     msg.ts.tv_nsec = htonl(kTsNs % kNsPerS);
    186 
    187     msg.packetHeaderAttr.nla_type = NFULA_PACKET_HDR;
    188     msg.packetHeaderAttr.nla_len = sizeof(msg.packetHeaderAttr) + sizeof(msg.packetHeader);
    189     msg.packetHeader.hw_protocol = htons(kEthertype);
    190 
    191     msg.hardwareAddrAttr.nla_type = NFULA_HWADDR;
    192     msg.hardwareAddrAttr.nla_len = sizeof(msg.hardwareAddrAttr) + sizeof(msg.hardwareAddr);
    193     msg.hardwareAddr.hw_addrlen = htons(kMacAddr.size());
    194     std::copy(kMacAddr.begin(), kMacAddr.end(), msg.hardwareAddr.hw_addr);
    195 
    196     msg.packetPayloadAttr.nla_type = NFULA_PAYLOAD;
    197     msg.packetPayloadAttr.nla_len =
    198             sizeof(msg.packetPayloadAttr) + sizeof(msg.ipHeader) + sizeof(msg.tcpHeader);
    199     msg.ipHeader.protocol = IPPROTO_TCP;
    200     msg.ipHeader.ihl = sizeof(msg.ipHeader) / 4; // ipv4 IHL counts 32 bit words.
    201     inet_pton(AF_INET, kSrcIpAddr, &msg.ipHeader.saddr);
    202     inet_pton(AF_INET, kDstIpAddr, &msg.ipHeader.daddr);
    203     msg.tcpHeader.th_sport = htons(kSrcPort);
    204     msg.tcpHeader.th_dport = htons(kDstPort);
    205 
    206     auto payload = drop(netdutils::makeSlice(msg), offsetof(Msg, uidAttr));
    207     EXPECT_CALL(mEventListener, onWakeupEvent(kPrefix, kUid, kEthertype, kIpNextHeader, kMacAddr,
    208                                               kSrcIpAddr, kDstIpAddr, kSrcPort, kDstPort, kTsNs));
    209     mMessageHandler(msg.nlmsg, msg.nfmsg, payload);
    210 }
    211 
    212 TEST_F(WakeupControllerTest, badAttr) {
    213     const char kPrefix[] = "test:prefix";
    214     const uid_t kUid = 8734;
    215     const gid_t kGid = 2222;
    216     const uint64_t kNsPerS = 1000000000ULL;
    217     const uint64_t kTsNs = 9999 + (34 * kNsPerS);
    218 
    219     struct Msg {
    220         nlmsghdr nlmsg;
    221         nfgenmsg nfmsg;
    222         nlattr uidAttr;
    223         uid_t uid;
    224         nlattr invalid0;
    225         nlattr invalid1;
    226         nlattr gidAttr;
    227         gid_t gid;
    228         nlattr tsAttr;
    229         timespec ts;
    230         nlattr prefixAttr;
    231         char prefix[sizeof(kPrefix)];
    232     } msg = {};
    233 
    234     msg.uidAttr.nla_type = 999;
    235     msg.uidAttr.nla_len = sizeof(msg.uidAttr) + sizeof(msg.uid);
    236     msg.uid = htonl(kUid);
    237 
    238     msg.invalid0.nla_type = 0;
    239     msg.invalid0.nla_len = 0;
    240     msg.invalid1.nla_type = 0;
    241     msg.invalid1.nla_len = 1;
    242 
    243     msg.gidAttr.nla_type = NFULA_GID;
    244     msg.gidAttr.nla_len = sizeof(msg.gidAttr) + sizeof(msg.gid);
    245     msg.gid = htonl(kGid);
    246 
    247     msg.tsAttr.nla_type = NFULA_TIMESTAMP;
    248     msg.tsAttr.nla_len = sizeof(msg.tsAttr) - 2;
    249     msg.ts.tv_sec = htonl(kTsNs / kNsPerS);
    250     msg.ts.tv_nsec = htonl(kTsNs % kNsPerS);
    251 
    252     msg.prefixAttr.nla_type = NFULA_UID;
    253     msg.prefixAttr.nla_len = sizeof(msg.prefixAttr) + sizeof(msg.prefix);
    254     memcpy(msg.prefix, kPrefix, sizeof(kPrefix));
    255 
    256     auto payload = drop(netdutils::makeSlice(msg), offsetof(Msg, uidAttr));
    257     EXPECT_CALL(mEventListener,
    258             onWakeupEvent("", 1952805748, -1, -1, std::vector<uint8_t>(), "", "", -1, -1, 0));
    259     mMessageHandler(msg.nlmsg, msg.nfmsg, payload);
    260 }
    261 
    262 TEST_F(WakeupControllerTest, unterminatedString) {
    263     char ones[20] = {};
    264     memset(ones, 1, sizeof(ones));
    265 
    266     struct Msg {
    267         nlmsghdr nlmsg;
    268         nfgenmsg nfmsg;
    269         nlattr prefixAttr;
    270         char prefix[sizeof(ones)];
    271     } msg = {};
    272 
    273     msg.prefixAttr.nla_type = NFULA_PREFIX;
    274     msg.prefixAttr.nla_len = sizeof(msg.prefixAttr) + sizeof(msg.prefix);
    275     memcpy(msg.prefix, ones, sizeof(ones));
    276 
    277     const auto expected = std::string(ones, sizeof(ones) - 1);
    278     auto payload = drop(netdutils::makeSlice(msg), offsetof(Msg, prefixAttr));
    279     EXPECT_CALL(mEventListener,
    280             onWakeupEvent(expected, -1, -1, -1, std::vector<uint8_t>(), "", "", -1, -1, 0));
    281     mMessageHandler(msg.nlmsg, msg.nfmsg, payload);
    282 }
    283 
    284 TEST_F(WakeupControllerTest, addInterface) {
    285     const char kPrefix[] = "test:prefix";
    286     const char kIfName[] = "wlan8";
    287     const uint32_t kMark = 0x12345678;
    288     const uint32_t kMask = 0x0F0F0F0F;
    289     const char kExpected[] =
    290         "*mangle\n-A wakeupctrl_mangle_INPUT -i test:prefix"
    291         " -j NFLOG --nflog-prefix wlan8 --nflog-group 3 --nflog-threshold 8"
    292         " -m mark --mark 0x12345678/0x0f0f0f0f -m limit --limit 10/s\nCOMMIT\n";
    293     EXPECT_CALL(mIptables, execute(V4V6, kExpected, _)).WillOnce(Return(0));
    294     mController.addInterface(kPrefix, kIfName, kMark, kMask);
    295 }
    296 
    297 TEST_F(WakeupControllerTest, delInterface) {
    298     const char kPrefix[] = "test:prefix";
    299     const char kIfName[] = "wlan8";
    300     const uint32_t kMark = 0x12345678;
    301     const uint32_t kMask = 0xF0F0F0F0;
    302     const char kExpected[] =
    303         "*mangle\n-D wakeupctrl_mangle_INPUT -i test:prefix"
    304         " -j NFLOG --nflog-prefix wlan8 --nflog-group 3 --nflog-threshold 8"
    305         " -m mark --mark 0x12345678/0xf0f0f0f0 -m limit --limit 10/s\nCOMMIT\n";
    306     EXPECT_CALL(mIptables, execute(V4V6, kExpected, _)).WillOnce(Return(0));
    307     mController.delInterface(kPrefix, kIfName, kMark, kMask);
    308 }
    309 
    310 }  // namespace net
    311 }  // namespace android
    312