1 /* 2 * Copyright 2017, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "socket.h" 18 19 #include "message.h" 20 #include "utils.h" 21 22 #include <errno.h> 23 #include <linux/if_packet.h> 24 #include <netinet/ip.h> 25 #include <netinet/udp.h> 26 #include <string.h> 27 #include <sys/socket.h> 28 #include <sys/types.h> 29 #include <sys/uio.h> 30 #include <unistd.h> 31 32 // Combine the checksum of |buffer| with |size| bytes with |checksum|. This is 33 // used for checksum calculations for IP and UDP. 34 static uint32_t addChecksum(const uint8_t* buffer, 35 size_t size, 36 uint32_t checksum) { 37 const uint16_t* data = reinterpret_cast<const uint16_t*>(buffer); 38 while (size > 1) { 39 checksum += *data++; 40 size -= 2; 41 } 42 if (size > 0) { 43 // Odd size, add the last byte 44 checksum += *reinterpret_cast<const uint8_t*>(data); 45 } 46 // msw is the most significant word, the upper 16 bits of the checksum 47 for (uint32_t msw = checksum >> 16; msw != 0; msw = checksum >> 16) { 48 checksum = (checksum & 0xFFFF) + msw; 49 } 50 return checksum; 51 } 52 53 // Convenienct template function for checksum calculation 54 template<typename T> 55 static uint32_t addChecksum(const T& data, uint32_t checksum) { 56 return addChecksum(reinterpret_cast<const uint8_t*>(&data), sizeof(T), checksum); 57 } 58 59 // Finalize the IP or UDP |checksum| by inverting and truncating it. 60 static uint32_t finishChecksum(uint32_t checksum) { 61 return ~checksum & 0xFFFF; 62 } 63 64 Socket::Socket() : mSocketFd(-1) { 65 } 66 67 Socket::~Socket() { 68 if (mSocketFd != -1) { 69 ::close(mSocketFd); 70 mSocketFd = -1; 71 } 72 } 73 74 75 Result Socket::open(int domain, int type, int protocol) { 76 if (mSocketFd != -1) { 77 return Result::error("Socket already open"); 78 } 79 mSocketFd = ::socket(domain, type, protocol); 80 if (mSocketFd == -1) { 81 return Result::error("Failed to open socket: %s", strerror(errno)); 82 } 83 return Result::success(); 84 } 85 86 Result Socket::bind(const void* sockaddr, size_t sockaddrLength) { 87 if (mSocketFd == -1) { 88 return Result::error("Socket not open"); 89 } 90 91 int status = ::bind(mSocketFd, 92 reinterpret_cast<const struct sockaddr*>(sockaddr), 93 sockaddrLength); 94 if (status != 0) { 95 return Result::error("Unable to bind raw socket: %s", strerror(errno)); 96 } 97 98 return Result::success(); 99 } 100 101 Result Socket::bindIp(in_addr_t address, uint16_t port) { 102 struct sockaddr_in sockaddr; 103 memset(&sockaddr, 0, sizeof(sockaddr)); 104 sockaddr.sin_family = AF_INET; 105 sockaddr.sin_port = htons(port); 106 sockaddr.sin_addr.s_addr = address; 107 108 return bind(&sockaddr, sizeof(sockaddr)); 109 } 110 111 Result Socket::bindRaw(unsigned int interfaceIndex) { 112 struct sockaddr_ll sockaddr; 113 memset(&sockaddr, 0, sizeof(sockaddr)); 114 sockaddr.sll_family = AF_PACKET; 115 sockaddr.sll_protocol = htons(ETH_P_IP); 116 sockaddr.sll_ifindex = interfaceIndex; 117 118 return bind(&sockaddr, sizeof(sockaddr)); 119 } 120 121 Result Socket::sendOnInterface(unsigned int interfaceIndex, 122 in_addr_t destinationAddress, 123 uint16_t destinationPort, 124 const Message& message) { 125 if (mSocketFd == -1) { 126 return Result::error("Socket not open"); 127 } 128 129 char controlData[CMSG_SPACE(sizeof(struct in_pktinfo))] = { 0 }; 130 struct sockaddr_in addr; 131 memset(&addr, 0, sizeof(addr)); 132 addr.sin_family = AF_INET; 133 addr.sin_port = htons(destinationPort); 134 addr.sin_addr.s_addr = destinationAddress; 135 136 struct msghdr header; 137 memset(&header, 0, sizeof(header)); 138 struct iovec iov; 139 // The struct member is non-const since it's used for receiving but it's 140 // safe to cast away const for sending. 141 iov.iov_base = const_cast<uint8_t*>(message.data()); 142 iov.iov_len = message.size(); 143 header.msg_name = &addr; 144 header.msg_namelen = sizeof(addr); 145 header.msg_iov = &iov; 146 header.msg_iovlen = 1; 147 header.msg_control = &controlData; 148 header.msg_controllen = sizeof(controlData); 149 150 struct cmsghdr* controlHeader = CMSG_FIRSTHDR(&header); 151 controlHeader->cmsg_level = IPPROTO_IP; 152 controlHeader->cmsg_type = IP_PKTINFO; 153 controlHeader->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo)); 154 auto packetInfo = 155 reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(controlHeader)); 156 memset(packetInfo, 0, sizeof(*packetInfo)); 157 packetInfo->ipi_ifindex = interfaceIndex; 158 159 ssize_t status = ::sendmsg(mSocketFd, &header, 0); 160 if (status <= 0) { 161 return Result::error("Failed to send packet: %s", strerror(errno)); 162 } 163 return Result::success(); 164 } 165 166 Result Socket::sendRawUdp(in_addr_t source, 167 uint16_t sourcePort, 168 in_addr_t destination, 169 uint16_t destinationPort, 170 unsigned int interfaceIndex, 171 const Message& message) { 172 struct iphdr ip; 173 struct udphdr udp; 174 175 ip.version = IPVERSION; 176 ip.ihl = sizeof(ip) >> 2; 177 ip.tos = 0; 178 ip.tot_len = htons(sizeof(ip) + sizeof(udp) + message.size()); 179 ip.id = 0; 180 ip.frag_off = 0; 181 ip.ttl = IPDEFTTL; 182 ip.protocol = IPPROTO_UDP; 183 ip.check = 0; 184 ip.saddr = source; 185 ip.daddr = destination; 186 ip.check = finishChecksum(addChecksum(ip, 0)); 187 188 udp.source = htons(sourcePort); 189 udp.dest = htons(destinationPort); 190 udp.len = htons(sizeof(udp) + message.size()); 191 udp.check = 0; 192 193 uint32_t udpChecksum = 0; 194 udpChecksum = addChecksum(ip.saddr, udpChecksum); 195 udpChecksum = addChecksum(ip.daddr, udpChecksum); 196 udpChecksum = addChecksum(htons(IPPROTO_UDP), udpChecksum); 197 udpChecksum = addChecksum(udp.len, udpChecksum); 198 udpChecksum = addChecksum(udp, udpChecksum); 199 udpChecksum = addChecksum(message.data(), message.size(), udpChecksum); 200 udp.check = finishChecksum(udpChecksum); 201 202 struct iovec iov[3]; 203 204 iov[0].iov_base = static_cast<void*>(&ip); 205 iov[0].iov_len = sizeof(ip); 206 iov[1].iov_base = static_cast<void*>(&udp); 207 iov[1].iov_len = sizeof(udp); 208 // sendmsg requires these to be non-const but for sending won't modify them 209 iov[2].iov_base = static_cast<void*>(const_cast<uint8_t*>(message.data())); 210 iov[2].iov_len = message.size(); 211 212 struct sockaddr_ll dest; 213 memset(&dest, 0, sizeof(dest)); 214 dest.sll_family = AF_PACKET; 215 dest.sll_protocol = htons(ETH_P_IP); 216 dest.sll_ifindex = interfaceIndex; 217 dest.sll_halen = ETH_ALEN; 218 memset(dest.sll_addr, 0xFF, ETH_ALEN); 219 220 struct msghdr header; 221 memset(&header, 0, sizeof(header)); 222 header.msg_name = &dest; 223 header.msg_namelen = sizeof(dest); 224 header.msg_iov = iov; 225 header.msg_iovlen = sizeof(iov) / sizeof(iov[0]); 226 227 ssize_t res = ::sendmsg(mSocketFd, &header, 0); 228 if (res == -1) { 229 return Result::error("Failed to send message: %s", strerror(errno)); 230 } 231 return Result::success(); 232 } 233 234 Result Socket::receiveFromInterface(Message* message, 235 unsigned int* interfaceIndex) { 236 char controlData[CMSG_SPACE(sizeof(struct in_pktinfo))]; 237 struct msghdr header; 238 memset(&header, 0, sizeof(header)); 239 struct iovec iov; 240 iov.iov_base = message->data(); 241 iov.iov_len = message->capacity(); 242 header.msg_iov = &iov; 243 header.msg_iovlen = 1; 244 header.msg_control = &controlData; 245 header.msg_controllen = sizeof(controlData); 246 247 ssize_t bytesRead = ::recvmsg(mSocketFd, &header, 0); 248 if (bytesRead < 0) { 249 return Result::error("Error receiving on socket: %s", strerror(errno)); 250 } 251 message->setSize(static_cast<size_t>(bytesRead)); 252 if (header.msg_controllen >= sizeof(struct cmsghdr)) { 253 for (struct cmsghdr* ctrl = CMSG_FIRSTHDR(&header); 254 ctrl; 255 ctrl = CMSG_NXTHDR(&header, ctrl)) { 256 if (ctrl->cmsg_level == SOL_IP && 257 ctrl->cmsg_type == IP_PKTINFO) { 258 auto packetInfo = 259 reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(ctrl)); 260 *interfaceIndex = packetInfo->ipi_ifindex; 261 } 262 } 263 } 264 return Result::success(); 265 } 266 267 Result Socket::receiveRawUdp(uint16_t expectedPort, 268 Message* message, 269 bool* isValid) { 270 struct iphdr ip; 271 struct udphdr udp; 272 273 struct iovec iov[3]; 274 iov[0].iov_base = &ip; 275 iov[0].iov_len = sizeof(ip); 276 iov[1].iov_base = &udp; 277 iov[1].iov_len = sizeof(udp); 278 iov[2].iov_base = message->data(); 279 iov[2].iov_len = message->capacity(); 280 281 ssize_t bytesRead = ::readv(mSocketFd, iov, 3); 282 if (bytesRead < 0) { 283 return Result::error("Unable to read from socket: %s", strerror(errno)); 284 } 285 if (static_cast<size_t>(bytesRead) < sizeof(ip) + sizeof(udp)) { 286 // Not enough bytes to even cover IP and UDP headers 287 *isValid = false; 288 return Result::success(); 289 } 290 *isValid = ip.version == IPVERSION && 291 ip.ihl == (sizeof(ip) >> 2) && 292 ip.protocol == IPPROTO_UDP && 293 udp.dest == htons(expectedPort); 294 295 message->setSize(bytesRead - sizeof(ip) - sizeof(udp)); 296 return Result::success(); 297 } 298 299 Result Socket::enableOption(int level, int optionName) { 300 if (mSocketFd == -1) { 301 return Result::error("Socket not open"); 302 } 303 304 int enabled = 1; 305 int status = ::setsockopt(mSocketFd, 306 level, 307 optionName, 308 &enabled, 309 sizeof(enabled)); 310 if (status == -1) { 311 return Result::error("Failed to set socket option: %s", 312 strerror(errno)); 313 } 314 return Result::success(); 315 } 316