Home | History | Annotate | Download | only in shill
      1 //
      2 // Copyright (C) 2013 The Android Open Source Project
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 
     17 #include "shill/icmp.h"
     18 
     19 #include <netinet/in.h>
     20 #include <netinet/ip_icmp.h>
     21 
     22 #include <gtest/gtest.h>
     23 
     24 #include "shill/mock_log.h"
     25 #include "shill/net/ip_address.h"
     26 #include "shill/net/mock_sockets.h"
     27 
     28 using testing::_;
     29 using testing::HasSubstr;
     30 using testing::InSequence;
     31 using testing::Return;
     32 using testing::StrictMock;
     33 using testing::Test;
     34 
     35 namespace shill {
     36 
     37 namespace {
     38 
     39 // These binary blobs representing ICMP headers and their respective checksums
     40 // were taken directly from Wireshark ICMP packet captures and are given in big
     41 // endian. The checksum field is zeroed in |kIcmpEchoRequestEvenLen| and
     42 // |kIcmpEchoRequestOddLen| so the checksum can be calculated on the header in
     43 // IcmpTest.ComputeIcmpChecksum.
     44 const uint8_t kIcmpEchoRequestEvenLen[] = {0x08, 0x00, 0x00, 0x00,
     45                                            0x71, 0x50, 0x00, 0x00};
     46 const uint8_t kIcmpEchoRequestEvenLenChecksum[] = {0x86, 0xaf};
     47 const uint8_t kIcmpEchoRequestOddLen[] = {0x08, 0x00, 0x00, 0x00, 0xac, 0x51,
     48                                           0x00, 0x00, 0x00, 0x00, 0x01};
     49 const uint8_t kIcmpEchoRequestOddLenChecksum[] = {0x4a, 0xae};
     50 
     51 }  // namespace
     52 
     53 class IcmpTest : public Test {
     54  public:
     55   IcmpTest() {}
     56   virtual ~IcmpTest() {}
     57 
     58   virtual void SetUp() {
     59     sockets_ = new StrictMock<MockSockets>();
     60     // Passes ownership.
     61     icmp_.sockets_.reset(sockets_);
     62   }
     63 
     64   virtual void TearDown() {
     65     if (icmp_.IsStarted()) {
     66       EXPECT_CALL(*sockets_, Close(kSocketFD));
     67       icmp_.Stop();
     68     }
     69     EXPECT_FALSE(icmp_.IsStarted());
     70   }
     71 
     72  protected:
     73   static const int kSocketFD;
     74   static const char kIPAddress[];
     75 
     76   int GetSocket() { return icmp_.socket_; }
     77   bool StartIcmp() { return StartIcmpWithFD(kSocketFD); }
     78   bool StartIcmpWithFD(int fd) {
     79     EXPECT_CALL(*sockets_, Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP))
     80         .WillOnce(Return(fd));
     81     EXPECT_CALL(*sockets_, SetNonBlocking(fd)).WillOnce(Return(0));
     82     bool start_status = icmp_.Start();
     83     EXPECT_TRUE(start_status);
     84     EXPECT_EQ(fd, icmp_.socket_);
     85     EXPECT_TRUE(icmp_.IsStarted());
     86     return start_status;
     87   }
     88   uint16_t ComputeIcmpChecksum(const struct icmphdr &hdr, size_t len) {
     89     return Icmp::ComputeIcmpChecksum(hdr, len);
     90   }
     91 
     92   // Owned by Icmp, and tracked here only for mocks.
     93   MockSockets* sockets_;
     94 
     95   Icmp icmp_;
     96 };
     97 
     98 
     99 const int IcmpTest::kSocketFD = 456;
    100 const char IcmpTest::kIPAddress[] = "10.0.1.1";
    101 
    102 
    103 TEST_F(IcmpTest, Constructor) {
    104   EXPECT_EQ(-1, GetSocket());
    105   EXPECT_FALSE(icmp_.IsStarted());
    106 }
    107 
    108 TEST_F(IcmpTest, SocketOpenFail) {
    109   ScopedMockLog log;
    110   EXPECT_CALL(log,
    111       Log(logging::LOG_ERROR, _,
    112           HasSubstr("Could not create ICMP socket"))).Times(1);
    113 
    114   EXPECT_CALL(*sockets_, Socket(AF_INET, SOCK_RAW, IPPROTO_ICMP))
    115       .WillOnce(Return(-1));
    116   EXPECT_FALSE(icmp_.Start());
    117   EXPECT_FALSE(icmp_.IsStarted());
    118 }
    119 
    120 TEST_F(IcmpTest, SocketNonBlockingFail) {
    121   ScopedMockLog log;
    122   EXPECT_CALL(log,
    123       Log(logging::LOG_ERROR, _,
    124           HasSubstr("Could not set socket to be non-blocking"))).Times(1);
    125 
    126   EXPECT_CALL(*sockets_, Socket(_, _, _)).WillOnce(Return(kSocketFD));
    127   EXPECT_CALL(*sockets_, SetNonBlocking(kSocketFD)).WillOnce(Return(-1));
    128   EXPECT_CALL(*sockets_, Close(kSocketFD));
    129   EXPECT_FALSE(icmp_.Start());
    130   EXPECT_FALSE(icmp_.IsStarted());
    131 }
    132 
    133 TEST_F(IcmpTest, StartMultipleTimes) {
    134   const int kFirstSocketFD = kSocketFD + 1;
    135   StartIcmpWithFD(kFirstSocketFD);
    136   EXPECT_CALL(*sockets_, Close(kFirstSocketFD));
    137   StartIcmp();
    138 }
    139 
    140 MATCHER_P(IsIcmpHeader, header, "") {
    141   return memcmp(arg, &header, sizeof(header)) == 0;
    142 }
    143 
    144 
    145 MATCHER_P(IsSocketAddress, address, "") {
    146   const struct sockaddr_in* sock_addr =
    147       reinterpret_cast<const struct sockaddr_in*>(arg);
    148   return sock_addr->sin_family == address.family() &&
    149       memcmp(&sock_addr->sin_addr.s_addr, address.GetConstData(),
    150              address.GetLength()) == 0;
    151 }
    152 
    153 TEST_F(IcmpTest, TransmitEchoRequest) {
    154   StartIcmp();
    155   // Address isn't valid.
    156   EXPECT_FALSE(
    157       icmp_.TransmitEchoRequest(IPAddress(IPAddress::kFamilyIPv4), 1, 1));
    158 
    159   // IPv6 adresses aren't implemented.
    160   IPAddress ipv6_destination(IPAddress::kFamilyIPv6);
    161   EXPECT_TRUE(ipv6_destination.SetAddressFromString(
    162       "fe80::1aa9:5ff:abcd:1234"));
    163   EXPECT_FALSE(icmp_.TransmitEchoRequest(ipv6_destination, 1, 1));
    164 
    165   IPAddress ipv4_destination(IPAddress::kFamilyIPv4);
    166   EXPECT_TRUE(ipv4_destination.SetAddressFromString(kIPAddress));
    167 
    168   struct icmphdr icmp_header;
    169   memset(&icmp_header, 0, sizeof(icmp_header));
    170   icmp_header.type = ICMP_ECHO;
    171   icmp_header.code = Icmp::kIcmpEchoCode;
    172   icmp_header.un.echo.id = 1;
    173   icmp_header.un.echo.sequence = 1;
    174   icmp_header.checksum = ComputeIcmpChecksum(icmp_header, sizeof(icmp_header));
    175 
    176   EXPECT_CALL(*sockets_, SendTo(kSocketFD,
    177                                 IsIcmpHeader(icmp_header),
    178                                 sizeof(icmp_header),
    179                                 0,
    180                                 IsSocketAddress(ipv4_destination),
    181                                 sizeof(sockaddr_in)))
    182       .WillOnce(Return(-1))
    183       .WillOnce(Return(0))
    184       .WillOnce(Return(sizeof(icmp_header) - 1))
    185       .WillOnce(Return(sizeof(icmp_header)));
    186   {
    187     InSequence seq;
    188     ScopedMockLog log;
    189     EXPECT_CALL(log,
    190         Log(logging::LOG_ERROR, _,
    191             HasSubstr("Socket sendto failed"))).Times(1);
    192     EXPECT_CALL(log,
    193         Log(logging::LOG_ERROR, _,
    194             HasSubstr("less than the expected result"))).Times(2);
    195 
    196     EXPECT_FALSE(icmp_.TransmitEchoRequest(ipv4_destination, 1, 1));
    197     EXPECT_FALSE(icmp_.TransmitEchoRequest(ipv4_destination, 1, 1));
    198     EXPECT_FALSE(icmp_.TransmitEchoRequest(ipv4_destination, 1, 1));
    199     EXPECT_TRUE(icmp_.TransmitEchoRequest(ipv4_destination, 1, 1));
    200   }
    201 }
    202 
    203 TEST_F(IcmpTest, ComputeIcmpChecksum) {
    204   EXPECT_EQ(*reinterpret_cast<const uint16_t*>(kIcmpEchoRequestEvenLenChecksum),
    205             ComputeIcmpChecksum(*reinterpret_cast<const struct icmphdr*>(
    206                                     kIcmpEchoRequestEvenLen),
    207                                 sizeof(kIcmpEchoRequestEvenLen)));
    208   EXPECT_EQ(*reinterpret_cast<const uint16_t*>(kIcmpEchoRequestOddLenChecksum),
    209             ComputeIcmpChecksum(*reinterpret_cast<const struct icmphdr*>(
    210                                     kIcmpEchoRequestOddLen),
    211                                 sizeof(kIcmpEchoRequestOddLen)));
    212 }
    213 
    214 }  // namespace shill
    215