Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright 2017, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "socket.h"
     18 
     19 #include "message.h"
     20 #include "utils.h"
     21 
     22 #include <errno.h>
     23 #include <linux/if_packet.h>
     24 #include <netinet/ip.h>
     25 #include <netinet/udp.h>
     26 #include <string.h>
     27 #include <sys/socket.h>
     28 #include <sys/types.h>
     29 #include <sys/uio.h>
     30 #include <unistd.h>
     31 
     32 // Combine the checksum of |buffer| with |size| bytes with |checksum|. This is
     33 // used for checksum calculations for IP and UDP.
     34 static uint32_t addChecksum(const uint8_t* buffer,
     35                             size_t size,
     36                             uint32_t checksum) {
     37     const uint16_t* data = reinterpret_cast<const uint16_t*>(buffer);
     38     while (size > 1) {
     39         checksum += *data++;
     40         size -= 2;
     41     }
     42     if (size > 0) {
     43         // Odd size, add the last byte
     44         checksum += *reinterpret_cast<const uint8_t*>(data);
     45     }
     46     // msw is the most significant word, the upper 16 bits of the checksum
     47     for (uint32_t msw = checksum >> 16; msw != 0; msw = checksum >> 16) {
     48         checksum = (checksum & 0xFFFF) + msw;
     49     }
     50     return checksum;
     51 }
     52 
     53 // Convenienct template function for checksum calculation
     54 template<typename T>
     55 static uint32_t addChecksum(const T& data, uint32_t checksum) {
     56     return addChecksum(reinterpret_cast<const uint8_t*>(&data), sizeof(T), checksum);
     57 }
     58 
     59 // Finalize the IP or UDP |checksum| by inverting and truncating it.
     60 static uint32_t finishChecksum(uint32_t checksum) {
     61     return ~checksum & 0xFFFF;
     62 }
     63 
     64 Socket::Socket() : mSocketFd(-1) {
     65 }
     66 
     67 Socket::~Socket() {
     68     if (mSocketFd != -1) {
     69         ::close(mSocketFd);
     70         mSocketFd = -1;
     71     }
     72 }
     73 
     74 
     75 Result Socket::open(int domain, int type, int protocol) {
     76     if (mSocketFd != -1) {
     77         return Result::error("Socket already open");
     78     }
     79     mSocketFd = ::socket(domain, type, protocol);
     80     if (mSocketFd == -1) {
     81         return Result::error("Failed to open socket: %s", strerror(errno));
     82     }
     83     return Result::success();
     84 }
     85 
     86 Result Socket::bind(const void* sockaddr, size_t sockaddrLength) {
     87     if (mSocketFd == -1) {
     88         return Result::error("Socket not open");
     89     }
     90 
     91     int status = ::bind(mSocketFd,
     92                         reinterpret_cast<const struct sockaddr*>(sockaddr),
     93                         sockaddrLength);
     94     if (status != 0) {
     95         return Result::error("Unable to bind raw socket: %s", strerror(errno));
     96     }
     97 
     98     return Result::success();
     99 }
    100 
    101 Result Socket::bindIp(in_addr_t address, uint16_t port) {
    102     struct sockaddr_in sockaddr;
    103     memset(&sockaddr, 0, sizeof(sockaddr));
    104     sockaddr.sin_family = AF_INET;
    105     sockaddr.sin_port = htons(port);
    106     sockaddr.sin_addr.s_addr = address;
    107 
    108     return bind(&sockaddr, sizeof(sockaddr));
    109 }
    110 
    111 Result Socket::bindRaw(unsigned int interfaceIndex) {
    112     struct sockaddr_ll sockaddr;
    113     memset(&sockaddr, 0, sizeof(sockaddr));
    114     sockaddr.sll_family = AF_PACKET;
    115     sockaddr.sll_protocol = htons(ETH_P_IP);
    116     sockaddr.sll_ifindex = interfaceIndex;
    117 
    118     return bind(&sockaddr, sizeof(sockaddr));
    119 }
    120 
    121 Result Socket::sendOnInterface(unsigned int interfaceIndex,
    122                                in_addr_t destinationAddress,
    123                                uint16_t destinationPort,
    124                                const Message& message) {
    125     if (mSocketFd == -1) {
    126         return Result::error("Socket not open");
    127     }
    128 
    129     char controlData[CMSG_SPACE(sizeof(struct in_pktinfo))] = { 0 };
    130     struct sockaddr_in addr;
    131     memset(&addr, 0, sizeof(addr));
    132     addr.sin_family = AF_INET;
    133     addr.sin_port = htons(destinationPort);
    134     addr.sin_addr.s_addr = destinationAddress;
    135 
    136     struct msghdr header;
    137     memset(&header, 0, sizeof(header));
    138     struct iovec iov;
    139     // The struct member is non-const since it's used for receiving but it's
    140     // safe to cast away const for sending.
    141     iov.iov_base = const_cast<uint8_t*>(message.data());
    142     iov.iov_len = message.size();
    143     header.msg_name = &addr;
    144     header.msg_namelen = sizeof(addr);
    145     header.msg_iov = &iov;
    146     header.msg_iovlen = 1;
    147     header.msg_control = &controlData;
    148     header.msg_controllen = sizeof(controlData);
    149 
    150     struct cmsghdr* controlHeader = CMSG_FIRSTHDR(&header);
    151     controlHeader->cmsg_level = IPPROTO_IP;
    152     controlHeader->cmsg_type = IP_PKTINFO;
    153     controlHeader->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo));
    154     auto packetInfo =
    155         reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(controlHeader));
    156     memset(packetInfo, 0, sizeof(*packetInfo));
    157     packetInfo->ipi_ifindex = interfaceIndex;
    158 
    159     ssize_t status = ::sendmsg(mSocketFd, &header, 0);
    160     if (status <= 0) {
    161         return Result::error("Failed to send packet: %s", strerror(errno));
    162     }
    163     return Result::success();
    164 }
    165 
    166 Result Socket::sendRawUdp(in_addr_t source,
    167                           uint16_t sourcePort,
    168                           in_addr_t destination,
    169                           uint16_t destinationPort,
    170                           unsigned int interfaceIndex,
    171                           const Message& message) {
    172     struct iphdr ip;
    173     struct udphdr udp;
    174 
    175     ip.version = IPVERSION;
    176     ip.ihl = sizeof(ip) >> 2;
    177     ip.tos = 0;
    178     ip.tot_len = htons(sizeof(ip) + sizeof(udp) + message.size());
    179     ip.id = 0;
    180     ip.frag_off = 0;
    181     ip.ttl = IPDEFTTL;
    182     ip.protocol = IPPROTO_UDP;
    183     ip.check = 0;
    184     ip.saddr = source;
    185     ip.daddr = destination;
    186     ip.check = finishChecksum(addChecksum(ip, 0));
    187 
    188     udp.source = htons(sourcePort);
    189     udp.dest = htons(destinationPort);
    190     udp.len = htons(sizeof(udp) + message.size());
    191     udp.check = 0;
    192 
    193     uint32_t udpChecksum = 0;
    194     udpChecksum = addChecksum(ip.saddr, udpChecksum);
    195     udpChecksum = addChecksum(ip.daddr, udpChecksum);
    196     udpChecksum = addChecksum(htons(IPPROTO_UDP), udpChecksum);
    197     udpChecksum = addChecksum(udp.len, udpChecksum);
    198     udpChecksum = addChecksum(udp, udpChecksum);
    199     udpChecksum = addChecksum(message.data(), message.size(), udpChecksum);
    200     udp.check = finishChecksum(udpChecksum);
    201 
    202     struct iovec iov[3];
    203 
    204     iov[0].iov_base = static_cast<void*>(&ip);
    205     iov[0].iov_len = sizeof(ip);
    206     iov[1].iov_base = static_cast<void*>(&udp);
    207     iov[1].iov_len = sizeof(udp);
    208     // sendmsg requires these to be non-const but for sending won't modify them
    209     iov[2].iov_base = static_cast<void*>(const_cast<uint8_t*>(message.data()));
    210     iov[2].iov_len = message.size();
    211 
    212     struct sockaddr_ll dest;
    213     memset(&dest, 0, sizeof(dest));
    214     dest.sll_family = AF_PACKET;
    215     dest.sll_protocol = htons(ETH_P_IP);
    216     dest.sll_ifindex = interfaceIndex;
    217     dest.sll_halen = ETH_ALEN;
    218     memset(dest.sll_addr, 0xFF, ETH_ALEN);
    219 
    220     struct msghdr header;
    221     memset(&header, 0, sizeof(header));
    222     header.msg_name = &dest;
    223     header.msg_namelen = sizeof(dest);
    224     header.msg_iov = iov;
    225     header.msg_iovlen = sizeof(iov) / sizeof(iov[0]);
    226 
    227     ssize_t res = ::sendmsg(mSocketFd, &header, 0);
    228     if (res == -1) {
    229         return Result::error("Failed to send message: %s", strerror(errno));
    230     }
    231     return Result::success();
    232 }
    233 
    234 Result Socket::receiveFromInterface(Message* message,
    235                                     unsigned int* interfaceIndex) {
    236     char controlData[CMSG_SPACE(sizeof(struct in_pktinfo))];
    237     struct msghdr header;
    238     memset(&header, 0, sizeof(header));
    239     struct iovec iov;
    240     iov.iov_base = message->data();
    241     iov.iov_len = message->capacity();
    242     header.msg_iov = &iov;
    243     header.msg_iovlen = 1;
    244     header.msg_control = &controlData;
    245     header.msg_controllen = sizeof(controlData);
    246 
    247     ssize_t bytesRead = ::recvmsg(mSocketFd, &header, 0);
    248     if (bytesRead < 0) {
    249         return Result::error("Error receiving on socket: %s", strerror(errno));
    250     }
    251     message->setSize(static_cast<size_t>(bytesRead));
    252     if (header.msg_controllen >= sizeof(struct cmsghdr)) {
    253         for (struct cmsghdr* ctrl = CMSG_FIRSTHDR(&header);
    254              ctrl;
    255              ctrl = CMSG_NXTHDR(&header, ctrl)) {
    256             if (ctrl->cmsg_level == SOL_IP &&
    257                 ctrl->cmsg_type == IP_PKTINFO) {
    258                 auto packetInfo =
    259                     reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(ctrl));
    260                 *interfaceIndex = packetInfo->ipi_ifindex;
    261             }
    262         }
    263     }
    264     return Result::success();
    265 }
    266 
    267 Result Socket::receiveRawUdp(uint16_t expectedPort,
    268                              Message* message,
    269                              bool* isValid) {
    270     struct iphdr ip;
    271     struct udphdr udp;
    272 
    273     struct iovec iov[3];
    274     iov[0].iov_base = &ip;
    275     iov[0].iov_len = sizeof(ip);
    276     iov[1].iov_base = &udp;
    277     iov[1].iov_len = sizeof(udp);
    278     iov[2].iov_base = message->data();
    279     iov[2].iov_len = message->capacity();
    280 
    281     ssize_t bytesRead = ::readv(mSocketFd, iov, 3);
    282     if (bytesRead < 0) {
    283         return Result::error("Unable to read from socket: %s", strerror(errno));
    284     }
    285     if (static_cast<size_t>(bytesRead) < sizeof(ip) + sizeof(udp)) {
    286         // Not enough bytes to even cover IP and UDP headers
    287         *isValid = false;
    288         return Result::success();
    289     }
    290     *isValid = ip.version == IPVERSION &&
    291                ip.ihl == (sizeof(ip) >> 2) &&
    292                ip.protocol == IPPROTO_UDP &&
    293                udp.dest == htons(expectedPort);
    294 
    295     message->setSize(bytesRead - sizeof(ip) - sizeof(udp));
    296     return Result::success();
    297 }
    298 
    299 Result Socket::enableOption(int level, int optionName) {
    300     if (mSocketFd == -1) {
    301         return Result::error("Socket not open");
    302     }
    303 
    304     int enabled = 1;
    305     int status = ::setsockopt(mSocketFd,
    306                               level,
    307                               optionName,
    308                               &enabled,
    309                               sizeof(enabled));
    310     if (status == -1) {
    311         return Result::error("Failed to set socket option: %s",
    312                              strerror(errno));
    313     }
    314     return Result::success();
    315 }
    316