Home | History | Annotate | Download | only in libnetutils
      1 /*
      2  * Copyright 2008, The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *     http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include <errno.h>
     18 #include <stdlib.h>
     19 #include <string.h>
     20 #include <sys/socket.h>
     21 #include <sys/uio.h>
     22 #include <linux/if_ether.h>
     23 #include <linux/if_packet.h>
     24 #include <netinet/in.h>
     25 #include <netinet/ip.h>
     26 #include <netinet/udp.h>
     27 #include <unistd.h>
     28 
     29 #ifdef ANDROID
     30 #define LOG_TAG "DHCP"
     31 #include <log/log.h>
     32 #else
     33 #include <stdio.h>
     34 #define ALOGD printf
     35 #define ALOGW printf
     36 #endif
     37 
     38 #include "dhcpmsg.h"
     39 
     40 int fatal();
     41 
     42 int open_raw_socket(const char *ifname __attribute__((unused)), uint8_t *hwaddr, int if_index)
     43 {
     44     int s;
     45     struct sockaddr_ll bindaddr;
     46 
     47     if((s = socket(PF_PACKET, SOCK_DGRAM, htons(ETH_P_IP))) < 0) {
     48         return fatal("socket(PF_PACKET)");
     49     }
     50 
     51     memset(&bindaddr, 0, sizeof(bindaddr));
     52     bindaddr.sll_family = AF_PACKET;
     53     bindaddr.sll_protocol = htons(ETH_P_IP);
     54     bindaddr.sll_halen = ETH_ALEN;
     55     memcpy(bindaddr.sll_addr, hwaddr, ETH_ALEN);
     56     bindaddr.sll_ifindex = if_index;
     57 
     58     if (bind(s, (struct sockaddr *)&bindaddr, sizeof(bindaddr)) < 0) {
     59         return fatal("Cannot bind raw socket to interface");
     60     }
     61 
     62     return s;
     63 }
     64 
     65 static uint32_t checksum(void *buffer, unsigned int count, uint32_t startsum)
     66 {
     67     uint16_t *up = (uint16_t *)buffer;
     68     uint32_t sum = startsum;
     69     uint32_t upper16;
     70 
     71     while (count > 1) {
     72         sum += *up++;
     73         count -= 2;
     74     }
     75     if (count > 0) {
     76         sum += (uint16_t) *(uint8_t *)up;
     77     }
     78     while ((upper16 = (sum >> 16)) != 0) {
     79         sum = (sum & 0xffff) + upper16;
     80     }
     81     return sum;
     82 }
     83 
     84 static uint32_t finish_sum(uint32_t sum)
     85 {
     86     return ~sum & 0xffff;
     87 }
     88 
     89 int send_packet(int s, int if_index, struct dhcp_msg *msg, int size,
     90                 uint32_t saddr, uint32_t daddr, uint32_t sport, uint32_t dport)
     91 {
     92     struct iphdr ip;
     93     struct udphdr udp;
     94     struct iovec iov[3];
     95     uint32_t udpsum;
     96     uint16_t temp;
     97     struct msghdr msghdr;
     98     struct sockaddr_ll destaddr;
     99 
    100     ip.version = IPVERSION;
    101     ip.ihl = sizeof(ip) >> 2;
    102     ip.tos = 0;
    103     ip.tot_len = htons(sizeof(ip) + sizeof(udp) + size);
    104     ip.id = 0;
    105     ip.frag_off = 0;
    106     ip.ttl = IPDEFTTL;
    107     ip.protocol = IPPROTO_UDP;
    108     ip.check = 0;
    109     ip.saddr = saddr;
    110     ip.daddr = daddr;
    111     ip.check = finish_sum(checksum(&ip, sizeof(ip), 0));
    112 
    113     udp.source = htons(sport);
    114     udp.dest = htons(dport);
    115     udp.len = htons(sizeof(udp) + size);
    116     udp.check = 0;
    117 
    118     /* Calculate checksum for pseudo header */
    119     udpsum = checksum(&ip.saddr, sizeof(ip.saddr), 0);
    120     udpsum = checksum(&ip.daddr, sizeof(ip.daddr), udpsum);
    121     temp = htons(IPPROTO_UDP);
    122     udpsum = checksum(&temp, sizeof(temp), udpsum);
    123     temp = udp.len;
    124     udpsum = checksum(&temp, sizeof(temp), udpsum);
    125 
    126     /* Add in the checksum for the udp header */
    127     udpsum = checksum(&udp, sizeof(udp), udpsum);
    128 
    129     /* Add in the checksum for the data */
    130     udpsum = checksum(msg, size, udpsum);
    131     udp.check = finish_sum(udpsum);
    132 
    133     iov[0].iov_base = (char *)&ip;
    134     iov[0].iov_len = sizeof(ip);
    135     iov[1].iov_base = (char *)&udp;
    136     iov[1].iov_len = sizeof(udp);
    137     iov[2].iov_base = (char *)msg;
    138     iov[2].iov_len = size;
    139     memset(&destaddr, 0, sizeof(destaddr));
    140     destaddr.sll_family = AF_PACKET;
    141     destaddr.sll_protocol = htons(ETH_P_IP);
    142     destaddr.sll_ifindex = if_index;
    143     destaddr.sll_halen = ETH_ALEN;
    144     memcpy(destaddr.sll_addr, "\xff\xff\xff\xff\xff\xff", ETH_ALEN);
    145 
    146     msghdr.msg_name = &destaddr;
    147     msghdr.msg_namelen = sizeof(destaddr);
    148     msghdr.msg_iov = iov;
    149     msghdr.msg_iovlen = sizeof(iov) / sizeof(struct iovec);
    150     msghdr.msg_flags = 0;
    151     msghdr.msg_control = 0;
    152     msghdr.msg_controllen = 0;
    153     return sendmsg(s, &msghdr, 0);
    154 }
    155 
    156 int receive_packet(int s, struct dhcp_msg *msg)
    157 {
    158     int nread;
    159     int is_valid;
    160     struct dhcp_packet {
    161         struct iphdr ip;
    162         struct udphdr udp;
    163         struct dhcp_msg dhcp;
    164     } packet;
    165     int dhcp_size;
    166     uint32_t sum;
    167     uint16_t temp;
    168     uint32_t saddr, daddr;
    169 
    170     nread = read(s, &packet, sizeof(packet));
    171     if (nread < 0) {
    172         return -1;
    173     }
    174     /*
    175      * The raw packet interface gives us all packets received by the
    176      * network interface. We need to filter out all packets that are
    177      * not meant for us.
    178      */
    179     is_valid = 0;
    180     if (nread < (int)(sizeof(struct iphdr) + sizeof(struct udphdr))) {
    181 #if VERBOSE
    182         ALOGD("Packet is too small (%d) to be a UDP datagram", nread);
    183 #endif
    184     } else if (packet.ip.version != IPVERSION || packet.ip.ihl != (sizeof(packet.ip) >> 2)) {
    185 #if VERBOSE
    186         ALOGD("Not a valid IP packet");
    187 #endif
    188     } else if (nread < ntohs(packet.ip.tot_len)) {
    189 #if VERBOSE
    190         ALOGD("Packet was truncated (read %d, needed %d)", nread, ntohs(packet.ip.tot_len));
    191 #endif
    192     } else if (packet.ip.protocol != IPPROTO_UDP) {
    193 #if VERBOSE
    194         ALOGD("IP protocol (%d) is not UDP", packet.ip.protocol);
    195 #endif
    196     } else if (packet.udp.dest != htons(PORT_BOOTP_CLIENT)) {
    197 #if VERBOSE
    198         ALOGD("UDP dest port (%d) is not DHCP client", ntohs(packet.udp.dest));
    199 #endif
    200     } else {
    201         is_valid = 1;
    202     }
    203 
    204     if (!is_valid) {
    205         return -1;
    206     }
    207 
    208     /* Seems like it's probably a valid DHCP packet */
    209     /* validate IP header checksum */
    210     sum = finish_sum(checksum(&packet.ip, sizeof(packet.ip), 0));
    211     if (sum != 0) {
    212         ALOGW("IP header checksum failure (0x%x)", packet.ip.check);
    213         return -1;
    214     }
    215     /*
    216      * Validate the UDP checksum.
    217      * Since we don't need the IP header anymore, we "borrow" it
    218      * to construct the pseudo header used in the checksum calculation.
    219      */
    220     dhcp_size = ntohs(packet.udp.len) - sizeof(packet.udp);
    221     /*
    222      * check validity of dhcp_size.
    223      * 1) cannot be negative or zero.
    224      * 2) src buffer contains enough bytes to copy
    225      * 3) cannot exceed destination buffer
    226      */
    227     if ((dhcp_size <= 0) ||
    228         ((int)(nread - sizeof(struct iphdr) - sizeof(struct udphdr)) < dhcp_size) ||
    229         ((int)sizeof(struct dhcp_msg) < dhcp_size)) {
    230 #if VERBOSE
    231         ALOGD("Malformed Packet");
    232 #endif
    233         return -1;
    234     }
    235     saddr = packet.ip.saddr;
    236     daddr = packet.ip.daddr;
    237     nread = ntohs(packet.ip.tot_len);
    238     memset(&packet.ip, 0, sizeof(packet.ip));
    239     packet.ip.saddr = saddr;
    240     packet.ip.daddr = daddr;
    241     packet.ip.protocol = IPPROTO_UDP;
    242     packet.ip.tot_len = packet.udp.len;
    243     temp = packet.udp.check;
    244     packet.udp.check = 0;
    245     sum = finish_sum(checksum(&packet, nread, 0));
    246     packet.udp.check = temp;
    247     if (!sum)
    248         sum = finish_sum(sum);
    249     if (temp != sum) {
    250         ALOGW("UDP header checksum failure (0x%x should be 0x%x)", sum, temp);
    251         return -1;
    252     }
    253     memcpy(msg, &packet.dhcp, dhcp_size);
    254     return dhcp_size;
    255 }
    256