1 /* 2 * Copyright 2008, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <stdlib.h> 18 #include <unistd.h> 19 #include <sys/uio.h> 20 #include <sys/socket.h> 21 #include <netinet/in.h> 22 #include <netinet/ip.h> 23 #include <netinet/udp.h> 24 #include <linux/if_packet.h> 25 #include <linux/if_ether.h> 26 #include <errno.h> 27 28 #ifdef ANDROID 29 #define LOG_TAG "DHCP" 30 #include <cutils/log.h> 31 #else 32 #include <stdio.h> 33 #include <string.h> 34 #define LOGD printf 35 #define LOGW printf 36 #endif 37 38 #include "dhcpmsg.h" 39 40 int fatal(); 41 42 int open_raw_socket(const char *ifname, uint8_t *hwaddr, int if_index) 43 { 44 int s, flag; 45 struct sockaddr_ll bindaddr; 46 47 if((s = socket(PF_PACKET, SOCK_DGRAM, htons(ETH_P_IP))) < 0) { 48 return fatal("socket(PF_PACKET)"); 49 } 50 51 memset(&bindaddr, 0, sizeof(bindaddr)); 52 bindaddr.sll_family = AF_PACKET; 53 bindaddr.sll_protocol = htons(ETH_P_IP); 54 bindaddr.sll_halen = ETH_ALEN; 55 memcpy(bindaddr.sll_addr, hwaddr, ETH_ALEN); 56 bindaddr.sll_ifindex = if_index; 57 58 if (bind(s, (struct sockaddr *)&bindaddr, sizeof(bindaddr)) < 0) { 59 return fatal("Cannot bind raw socket to interface"); 60 } 61 62 return s; 63 } 64 65 static uint32_t checksum(void *buffer, unsigned int count, uint32_t startsum) 66 { 67 uint16_t *up = (uint16_t *)buffer; 68 uint32_t sum = startsum; 69 uint32_t upper16; 70 71 while (count > 1) { 72 sum += *up++; 73 count -= 2; 74 } 75 if (count > 0) { 76 sum += (uint16_t) *(uint8_t *)up; 77 } 78 while ((upper16 = (sum >> 16)) != 0) { 79 sum = (sum & 0xffff) + upper16; 80 } 81 return sum; 82 } 83 84 static uint32_t finish_sum(uint32_t sum) 85 { 86 return ~sum & 0xffff; 87 } 88 89 int send_packet(int s, int if_index, struct dhcp_msg *msg, int size, 90 uint32_t saddr, uint32_t daddr, uint32_t sport, uint32_t dport) 91 { 92 struct iphdr ip; 93 struct udphdr udp; 94 struct iovec iov[3]; 95 uint32_t udpsum; 96 uint16_t temp; 97 struct msghdr msghdr; 98 struct sockaddr_ll destaddr; 99 100 ip.version = IPVERSION; 101 ip.ihl = sizeof(ip) >> 2; 102 ip.tos = 0; 103 ip.tot_len = htons(sizeof(ip) + sizeof(udp) + size); 104 ip.id = 0; 105 ip.frag_off = 0; 106 ip.ttl = IPDEFTTL; 107 ip.protocol = IPPROTO_UDP; 108 ip.check = 0; 109 ip.saddr = saddr; 110 ip.daddr = daddr; 111 ip.check = finish_sum(checksum(&ip, sizeof(ip), 0)); 112 113 udp.source = htons(sport); 114 udp.dest = htons(dport); 115 udp.len = htons(sizeof(udp) + size); 116 udp.check = 0; 117 118 /* Calculate checksum for pseudo header */ 119 udpsum = checksum(&ip.saddr, sizeof(ip.saddr), 0); 120 udpsum = checksum(&ip.daddr, sizeof(ip.daddr), udpsum); 121 temp = htons(IPPROTO_UDP); 122 udpsum = checksum(&temp, sizeof(temp), udpsum); 123 temp = udp.len; 124 udpsum = checksum(&temp, sizeof(temp), udpsum); 125 126 /* Add in the checksum for the udp header */ 127 udpsum = checksum(&udp, sizeof(udp), udpsum); 128 129 /* Add in the checksum for the data */ 130 udpsum = checksum(msg, size, udpsum); 131 udp.check = finish_sum(udpsum); 132 133 iov[0].iov_base = (char *)&ip; 134 iov[0].iov_len = sizeof(ip); 135 iov[1].iov_base = (char *)&udp; 136 iov[1].iov_len = sizeof(udp); 137 iov[2].iov_base = (char *)msg; 138 iov[2].iov_len = size; 139 memset(&destaddr, 0, sizeof(destaddr)); 140 destaddr.sll_family = AF_PACKET; 141 destaddr.sll_protocol = htons(ETH_P_IP); 142 destaddr.sll_ifindex = if_index; 143 destaddr.sll_halen = ETH_ALEN; 144 memcpy(destaddr.sll_addr, "\xff\xff\xff\xff\xff\xff", ETH_ALEN); 145 146 msghdr.msg_name = &destaddr; 147 msghdr.msg_namelen = sizeof(destaddr); 148 msghdr.msg_iov = iov; 149 msghdr.msg_iovlen = sizeof(iov) / sizeof(struct iovec); 150 msghdr.msg_flags = 0; 151 msghdr.msg_control = 0; 152 msghdr.msg_controllen = 0; 153 return sendmsg(s, &msghdr, 0); 154 } 155 156 int receive_packet(int s, struct dhcp_msg *msg) 157 { 158 int nread; 159 int is_valid; 160 struct dhcp_packet { 161 struct iphdr ip; 162 struct udphdr udp; 163 struct dhcp_msg dhcp; 164 } packet; 165 int dhcp_size; 166 uint32_t sum; 167 uint16_t temp; 168 uint32_t saddr, daddr; 169 170 nread = read(s, &packet, sizeof(packet)); 171 if (nread < 0) { 172 return -1; 173 } 174 /* 175 * The raw packet interface gives us all packets received by the 176 * network interface. We need to filter out all packets that are 177 * not meant for us. 178 */ 179 is_valid = 0; 180 if (nread < (int)(sizeof(struct iphdr) + sizeof(struct udphdr))) { 181 #if VERBOSE 182 LOGD("Packet is too small (%d) to be a UDP datagram", nread); 183 #endif 184 } else if (packet.ip.version != IPVERSION || packet.ip.ihl != (sizeof(packet.ip) >> 2)) { 185 #if VERBOSE 186 LOGD("Not a valid IP packet"); 187 #endif 188 } else if (nread < ntohs(packet.ip.tot_len)) { 189 #if VERBOSE 190 LOGD("Packet was truncated (read %d, needed %d)", nread, ntohs(packet.ip.tot_len)); 191 #endif 192 } else if (packet.ip.protocol != IPPROTO_UDP) { 193 #if VERBOSE 194 LOGD("IP protocol (%d) is not UDP", packet.ip.protocol); 195 #endif 196 } else if (packet.udp.dest != htons(PORT_BOOTP_CLIENT)) { 197 #if VERBOSE 198 LOGD("UDP dest port (%d) is not DHCP client", ntohs(packet.udp.dest)); 199 #endif 200 } else { 201 is_valid = 1; 202 } 203 204 if (!is_valid) { 205 return -1; 206 } 207 208 /* Seems like it's probably a valid DHCP packet */ 209 /* validate IP header checksum */ 210 sum = finish_sum(checksum(&packet.ip, sizeof(packet.ip), 0)); 211 if (sum != 0) { 212 LOGW("IP header checksum failure (0x%x)", packet.ip.check); 213 return -1; 214 } 215 /* 216 * Validate the UDP checksum. 217 * Since we don't need the IP header anymore, we "borrow" it 218 * to construct the pseudo header used in the checksum calculation. 219 */ 220 dhcp_size = ntohs(packet.udp.len) - sizeof(packet.udp); 221 saddr = packet.ip.saddr; 222 daddr = packet.ip.daddr; 223 nread = ntohs(packet.ip.tot_len); 224 memset(&packet.ip, 0, sizeof(packet.ip)); 225 packet.ip.saddr = saddr; 226 packet.ip.daddr = daddr; 227 packet.ip.protocol = IPPROTO_UDP; 228 packet.ip.tot_len = packet.udp.len; 229 temp = packet.udp.check; 230 packet.udp.check = 0; 231 sum = finish_sum(checksum(&packet, nread, 0)); 232 packet.udp.check = temp; 233 if (temp != sum) { 234 LOGW("UDP header checksum failure (0x%x should be 0x%x)", sum, temp); 235 return -1; 236 } 237 memcpy(msg, &packet.dhcp, dhcp_size); 238 return dhcp_size; 239 } 240