Home | History | Annotate | Download | only in dhcp_client
      1 //
      2 // Copyright (C) 2015 The Android Open Source Project
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 
#include "dhcp_client/dhcpv4.h"

#include <linux/filter.h>
#include <linux/if_packet.h>
#include <net/ethernet.h>
#include <net/if.h>
#include <net/if_arp.h>
#include <netinet/ip.h>
#include <netinet/udp.h>

#include <random>
#include <vector>

#include <base/bind.h>
#include <base/logging.h>

#include "dhcp_client/dhcp_message.h"
     33 
     34 using base::Bind;
     35 using base::Unretained;
     36 using shill::ByteString;
     37 using shill::IOHandlerFactoryContainer;
     38 
     39 namespace dhcp_client {
     40 
     41 namespace {
     42 // UDP port numbers for DHCP.
     43 const uint16_t kDHCPServerPort = 67;
     44 const uint16_t kDHCPClientPort = 68;
     45 
     46 const int kInvalidSocketDescriptor = -1;
     47 
     48 // RFC 791: the minimum value for a correct header is 20 octets.
     49 // The maximum value is 60 octets.
     50 const size_t kIPHeaderMinLength = 20;
     51 const size_t kIPHeaderMaxLength = 60;
     52 
     53 // Socket filter for dhcp packet.
     54 const sock_filter dhcp_bpf_filter[] = {
     55   BPF_STMT(BPF_LD + BPF_B + BPF_ABS, 23 - ETH_HLEN),
     56   BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, IPPROTO_UDP, 0, 6),
     57   BPF_STMT(BPF_LD + BPF_H + BPF_ABS, 20 - ETH_HLEN),
     58   BPF_JUMP(BPF_JMP + BPF_JSET + BPF_K, 0x1fff, 4, 0),
     59   BPF_STMT(BPF_LDX + BPF_B + BPF_MSH, 14 - ETH_HLEN),
     60   BPF_STMT(BPF_LD + BPF_H + BPF_IND, 16 - ETH_HLEN),
     61   BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, kDHCPClientPort, 0, 1),
     62   BPF_STMT(BPF_RET + BPF_K, 0x0fffffff),
     63   BPF_STMT(BPF_RET + BPF_K, 0),
     64 };
     65 const int dhcp_bpf_filter_len =
     66     sizeof(dhcp_bpf_filter) / sizeof(dhcp_bpf_filter[0]);
     67 }  // namespace
     68 
// Creates a DHCPv4 client for a single network interface. The constructor
// only records its arguments and sets the initial state. No socket is opened
// until Start() is called.
//
// interface_name:   passed to BindToDevice() when the raw socket is created.
// interface_index:  used as sll_ifindex when binding and sending.
// hardware_address, network_id, request_hostname, arp_gateway, unicast_arp,
// event_dispatcher: stored in members; none of them is read elsewhere in this
//                   file.
DHCPV4::DHCPV4(const std::string& interface_name,
               const ByteString& hardware_address,
               unsigned int interface_index,
               const std::string& network_id,
               bool request_hostname,
               bool arp_gateway,
               bool unicast_arp,
               EventDispatcherInterface* event_dispatcher)
    : interface_name_(interface_name),
      hardware_address_(hardware_address),
      interface_index_(interface_index),
      network_id_(network_id),
      request_hostname_(request_hostname),
      arp_gateway_(arp_gateway),
      unicast_arp_(unicast_arp),
      event_dispatcher_(event_dispatcher),
      // Factory used by Start() to create the socket input handler.
      io_handler_factory_(
          IOHandlerFactoryContainer::GetInstance()->GetIOHandlerFactory()),
      // In INIT, ParseRawPacket() ignores all server messages.
      state_(State::INIT),
      // Addresses MakeRawPacket() writes as IPv4 source (from_) and
      // destination (to_) of outgoing packets.
      from_(INADDR_ANY),
      to_(INADDR_BROADCAST),
      // No socket yet; CreateRawSocket() assigns a real descriptor.
      socket_(kInvalidSocketDescriptor),
      sockets_(new shill::Sockets()),
      // Source of the IP identification field in MakeRawPacket().
      // NOTE(review): seeding from time() is predictable and identical for
      // clients started in the same second; consider std::random_device.
      random_engine_(time(nullptr)) {
}
     94 
     95 DHCPV4::~DHCPV4() {
     96   Stop();
     97 }
     98 
     99 void DHCPV4::ParseRawPacket(shill::InputData* data) {
    100   if (data->len < sizeof(iphdr)) {
    101     LOG(ERROR) << "Invalid packet length from buffer";
    102     return;
    103   }
    104   // The socket filter has finished part the header validation.
    105   // This function will perform the remaining part.
    106   int header_len = ValidatePacketHeader(data->buf, data->len);
    107   if (header_len == -1) {
    108     return;
    109   }
    110   unsigned char* buffer = data->buf + header_len;
    111   DHCPMessage msg;
    112   if (!DHCPMessage::InitFromBuffer(buffer, data->len - header_len, &msg)) {
    113     LOG(ERROR) << "Failed to initialize DHCP message from buffer";
    114     return;
    115   }
    116   // In INIT state the client ignores all messages from server.
    117   if (state_ == State::INIT) {
    118     return;
    119   }
    120   // Check transaction id with the existing one.
    121   if (msg.transaction_id() != transaction_id_) {
    122     LOG(ERROR) << "Transaction id(xid) doesn't match";
    123     return;
    124   }
    125   uint8_t message_type = msg.message_type();
    126   switch (message_type) {
    127     case kDHCPMessageTypeOffer:
    128       HandleOffer(msg);
    129       break;
    130     case kDHCPMessageTypeAck:
    131       HandleAck(msg);
    132       break;
    133     case kDHCPMessageTypeNak:
    134       HandleNak(msg);
    135       break;
    136     default:
    137       LOG(ERROR) << "Invalid message type: "
    138                  << static_cast<int>(message_type);
    139   }
    140 }
    141 
    142 void DHCPV4::OnReadError(const std::string& error_msg) {
    143   LOG(INFO) << __func__;
    144 }
    145 
    146 bool DHCPV4::Start() {
    147   if (!CreateRawSocket()) {
    148     return false;
    149   }
    150 
    151   input_handler_.reset(io_handler_factory_->CreateIOInputHandler(
    152       socket_,
    153       Bind(&DHCPV4::ParseRawPacket, Unretained(this)),
    154       Bind(&DHCPV4::OnReadError, Unretained(this))));
    155   return true;
    156 }
    157 
    158 void DHCPV4::Stop() {
    159   input_handler_.reset();
    160   if (socket_ != kInvalidSocketDescriptor) {
    161     sockets_->Close(socket_);
    162   }
    163 }
    164 
    165 bool DHCPV4::CreateRawSocket() {
    166   int fd = sockets_->Socket(PF_PACKET,
    167                             SOCK_DGRAM | SOCK_CLOEXEC | SOCK_NONBLOCK,
    168                             htons(ETHERTYPE_IP));
    169   if (fd == kInvalidSocketDescriptor) {
    170     PLOG(ERROR) << "Failed to create socket";
    171     return false;
    172   }
    173   shill::ScopedSocketCloser socket_closer(sockets_.get(), fd);
    174 
    175   // Apply the socket filter.
    176   sock_fprog pf;
    177   memset(&pf, 0, sizeof(pf));
    178   pf.filter = const_cast<sock_filter*>(dhcp_bpf_filter);
    179   pf.len = dhcp_bpf_filter_len;
    180 
    181   if (sockets_->AttachFilter(fd, &pf) != 0) {
    182     PLOG(ERROR) << "Failed to attach filter";
    183     return false;
    184   }
    185 
    186   if (sockets_->ReuseAddress(fd) == -1) {
    187     PLOG(ERROR) << "Failed to reuse socket address";
    188     return false;
    189   }
    190 
    191   if (sockets_->BindToDevice(fd, interface_name_) < 0) {
    192     PLOG(ERROR) << "Failed to bind socket to device";
    193     return false;
    194   }
    195 
    196   struct sockaddr_ll local;
    197   memset(&local, 0, sizeof(local));
    198   local.sll_family = PF_PACKET;
    199   local.sll_protocol = htons(ETHERTYPE_IP);
    200   local.sll_ifindex = static_cast<int>(interface_index_);
    201 
    202   if (sockets_->Bind(fd,
    203                      reinterpret_cast<struct sockaddr*>(&local),
    204                      sizeof(local)) < 0) {
    205     PLOG(ERROR) << "Failed to bind to address";
    206     return false;
    207   }
    208 
    209   socket_ = socket_closer.Release();
    210   return true;
    211 }
    212 
    213 void DHCPV4::HandleOffer(const DHCPMessage& msg) {
    214   return;
    215 }
    216 
    217 void DHCPV4::HandleAck(const DHCPMessage& msg) {
    218   return;
    219 }
    220 
    221 void DHCPV4::HandleNak(const DHCPMessage& msg) {
    222   return;
    223 }
    224 
    225 bool DHCPV4::MakeRawPacket(const DHCPMessage& message, ByteString* output) {
    226   ByteString payload;
    227   if (!message.Serialize(&payload)) {
    228     LOG(ERROR) << "Failed to serialzie dhcp message";
    229     return false;
    230   }
    231   const size_t header_len = sizeof(struct iphdr) + sizeof(struct udphdr);
    232   const size_t payload_len = payload.GetLength();
    233 
    234   char buffer[header_len + payload_len];
    235   memset(buffer, 0, header_len + payload_len);
    236   struct iphdr* ip = reinterpret_cast<struct iphdr*>(buffer);
    237   struct udphdr* udp = reinterpret_cast<struct udphdr*>(buffer + sizeof(*ip));
    238 
    239   if (!payload.CopyData(payload_len, buffer + header_len)) {
    240     LOG(ERROR) << "Failed to copy data from payload";
    241     return false;
    242   }
    243   udp->uh_sport = htons(kDHCPClientPort);
    244   udp->uh_dport = htons(kDHCPServerPort);
    245   udp->uh_ulen =
    246       htons(static_cast<uint16_t>(sizeof(*udp) + payload.GetLength()));
    247 
    248   // Fill pseudo header (for UDP checksum computing):
    249   // Protocol.
    250   ip->protocol = IPPROTO_UDP;
    251   // Source IP address.
    252   ip->saddr = htonl(from_);
    253   // Destination IP address.
    254   ip->daddr = htonl(to_);
    255   // Total length, use udp packet length for pseudo header.
    256   ip->tot_len = udp->uh_ulen;
    257   // Calculate udp checksum based on:
    258   // IPV4 pseudo header, UDP header, and payload.
    259   udp->uh_sum = htons(DHCPMessage::ComputeChecksum(
    260       reinterpret_cast<const uint8_t*>(buffer),
    261       header_len + payload_len));
    262 
    263   // IP version.
    264   ip->version = IPVERSION;
    265   // IP header length.
    266   ip->ihl = sizeof(*ip) >> 2;
    267   // Fragment offset field.
    268   // The DHCP packet is always smaller than MTU,
    269   // so fragmentation is not needed.
    270   ip->frag_off = 0;
    271   // Identification.
    272   ip->id = static_cast<uint16_t>(
    273       std::uniform_int_distribution<unsigned int>()(
    274           random_engine_) % UINT16_MAX + 1);
    275   // Time to live.
    276   ip->ttl = IPDEFTTL;
    277   // Total length.
    278   ip->tot_len = htons(static_cast<uint16_t>(header_len+ payload.GetLength()));
    279   // Calculate IP Checksum only based on IP header.
    280   ip->check = htons(DHCPMessage::ComputeChecksum(
    281       reinterpret_cast<const uint8_t*>(ip),
    282       sizeof(*ip)));
    283 
    284   *output = ByteString(buffer, header_len + payload_len);
    285   return true;
    286 }
    287 
    288 bool DHCPV4::SendRawPacket(const ByteString& packet) {
    289   struct sockaddr_ll remote;
    290   memset(&remote, 0, sizeof(remote));
    291   remote.sll_family = AF_PACKET;
    292   remote.sll_protocol = htons(ETHERTYPE_IP);
    293   remote.sll_ifindex = interface_index_;
    294   remote.sll_hatype = htons(ARPHRD_ETHER);
    295   // Use broadcast hardware address.
    296   remote.sll_halen = IFHWADDRLEN;
    297   memset(remote.sll_addr, 0xff, IFHWADDRLEN);
    298 
    299   size_t result = sockets_->SendTo(socket_,
    300                                    packet.GetConstData(),
    301                                    packet.GetLength(),
    302                                    0,
    303                                    reinterpret_cast<struct sockaddr *>(&remote),
    304                                    sizeof(remote));
    305 
    306   if (result != packet.GetLength()) {
    307     PLOG(ERROR) << "Socket sento failed";
    308     return false;
    309   }
    310   return true;
    311 }
    312 
    313 int DHCPV4::ValidatePacketHeader(const unsigned char* buffer, size_t len) {
    314   const struct iphdr* ip =
    315       reinterpret_cast<const struct iphdr*>(buffer);
    316   const size_t ip_header_len = static_cast<size_t>(ip->ihl) << 2;
    317   if (ip_header_len < kIPHeaderMinLength ||
    318       ip_header_len > kIPHeaderMaxLength) {
    319     LOG(ERROR) << "Invalid Internet Header Length: "
    320                << ip_header_len << " bytes";
    321     return -1;
    322   }
    323   if (ip->tot_len != len) {
    324     LOG(ERROR) << "Invalid IP total length";
    325     return -1;
    326   }
    327   // TODO(nywang): Validate other ip header fields.
    328 
    329   const struct udphdr* udp =
    330       reinterpret_cast<const struct udphdr*>(buffer + ip_header_len);
    331   if (udp->uh_sport != htons(kDHCPServerPort) ||
    332       udp->uh_dport != htons(kDHCPClientPort)) {
    333     LOG(ERROR) << "Invlaid UDP ports";
    334     return -1;
    335   }
    336   if (udp->uh_ulen != len - ip_header_len) {
    337     LOG(ERROR) << "Invalid UDP total length";
    338     return -1;
    339   }
    340   // TODO(nywang): Validate UDP checksum.
    341 
    342   return ip_header_len + sizeof(*udp);
    343 }
    344 
    345 }  // namespace dhcp_client
    346 
    347