Home | History | Annotate | Download | only in shill
      1 //
      2 // Copyright (C) 2012 The Android Open Source Project
      3 //
      4 // Licensed under the Apache License, Version 2.0 (the "License");
      5 // you may not use this file except in compliance with the License.
      6 // You may obtain a copy of the License at
      7 //
      8 //      http://www.apache.org/licenses/LICENSE-2.0
      9 //
     10 // Unless required by applicable law or agreed to in writing, software
     11 // distributed under the License is distributed on an "AS IS" BASIS,
     12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 // See the License for the specific language governing permissions and
     14 // limitations under the License.
     15 //
     16 
     17 #include "shill/routing_table.h"
     18 
     19 #include <arpa/inet.h>
     20 #include <fcntl.h>
     21 #include <linux/netlink.h>
     22 #include <linux/rtnetlink.h>
     23 #include <netinet/ether.h>
     24 #include <net/if.h>  // NOLINT - must be included after netinet/ether.h
     25 #include <net/if_arp.h>
     26 #include <string.h>
     27 #include <sys/socket.h>
     28 #include <time.h>
     29 #include <unistd.h>
     30 
#include <algorithm>
#include <limits>
#include <memory>
#include <string>
     33 
     34 #include <base/bind.h>
     35 #include <base/files/file_path.h>
     36 #include <base/files/file_util.h>
     37 #include <base/stl_util.h>
     38 #include <base/strings/stringprintf.h>
     39 
     40 #include "shill/ipconfig.h"
     41 #include "shill/logging.h"
     42 #include "shill/net/byte_string.h"
     43 #include "shill/net/rtnl_handler.h"
     44 #include "shill/net/rtnl_listener.h"
     45 #include "shill/net/rtnl_message.h"
     46 #include "shill/routing_table_entry.h"
     47 
     48 using base::Bind;
     49 using base::FilePath;
     50 using base::Unretained;
     51 using std::deque;
     52 using std::string;
     53 using std::vector;
     54 
     55 namespace shill {
     56 
     57 namespace Logging {
     58 static auto kModuleLogScope = ScopeLogger::kRoute;
     59 static string ObjectID(RoutingTable* r) { return "(routing_table)"; }
     60 }
     61 
namespace {
// Lazily constructed, file-global instance returned by
// RoutingTable::GetInstance().
base::LazyInstance<RoutingTable> g_routing_table = LAZY_INSTANCE_INITIALIZER;
}  // namespace
     65 
// static
// procfs file that FlushCache() writes "-1" to in order to flush the IPv4
// route cache.
const char RoutingTable::kRouteFlushPath4[] = "/proc/sys/net/ipv4/route/flush";
// static
// IPv6 counterpart of kRouteFlushPath4, also written by FlushCache().
const char RoutingTable::kRouteFlushPath6[] = "/proc/sys/net/ipv6/route/flush";
     70 
     71 RoutingTable::RoutingTable()
     72     : route_callback_(Bind(&RoutingTable::RouteMsgHandler, Unretained(this))),
     73       rtnl_handler_(RTNLHandler::GetInstance()) {
     74   SLOG(this, 2) << __func__;
     75 }
     76 
     77 RoutingTable::~RoutingTable() {}
     78 
     79 RoutingTable* RoutingTable::GetInstance() {
     80   return g_routing_table.Pointer();
     81 }
     82 
     83 void RoutingTable::Start() {
     84   SLOG(this, 2) << __func__;
     85 
     86   route_listener_.reset(
     87       new RTNLListener(RTNLHandler::kRequestRoute, route_callback_));
     88   rtnl_handler_->RequestDump(RTNLHandler::kRequestRoute);
     89 }
     90 
     91 void RoutingTable::Stop() {
     92   SLOG(this, 2) << __func__;
     93 
     94   route_listener_.reset();
     95 }
     96 
     97 bool RoutingTable::AddRoute(int interface_index,
     98                             const RoutingTableEntry& entry) {
     99   SLOG(this, 2) << __func__ << ": "
    100                 << "destination " << entry.dst.ToString()
    101                 << " index " << interface_index
    102                 << " gateway " << entry.gateway.ToString()
    103                 << " metric " << entry.metric;
    104 
    105   CHECK(!entry.from_rtnl);
    106   if (!ApplyRoute(interface_index,
    107                   entry,
    108                   RTNLMessage::kModeAdd,
    109                   NLM_F_CREATE | NLM_F_EXCL)) {
    110     return false;
    111   }
    112   tables_[interface_index].push_back(entry);
    113   return true;
    114 }
    115 
    116 bool RoutingTable::GetDefaultRoute(int interface_index,
    117                                    IPAddress::Family family,
    118                                    RoutingTableEntry* entry) {
    119   RoutingTableEntry* found_entry;
    120   bool ret = GetDefaultRouteInternal(interface_index, family, &found_entry);
    121   if (ret) {
    122     *entry = *found_entry;
    123   }
    124   return ret;
    125 }
    126 
    127 bool RoutingTable::GetDefaultRouteInternal(int interface_index,
    128                                            IPAddress::Family family,
    129                                            RoutingTableEntry** entry) {
    130   SLOG(this, 2) << __func__ << " index " << interface_index
    131                 << " family " << IPAddress::GetAddressFamilyName(family);
    132 
    133   Tables::iterator table = tables_.find(interface_index);
    134   if (table == tables_.end()) {
    135     SLOG(this, 2) << __func__ << " no table";
    136     return false;
    137   }
    138 
    139   for (auto& nent : table->second) {
    140     if (nent.dst.IsDefault() && nent.dst.family() == family) {
    141       *entry = &nent;
    142       SLOG(this, 2) << __func__ << ": found"
    143                     << " gateway " << nent.gateway.ToString()
    144                     << " metric " << nent.metric;
    145       return true;
    146     }
    147   }
    148 
    149   SLOG(this, 2) << __func__ << " no route";
    150   return false;
    151 }
    152 
    153 bool RoutingTable::SetDefaultRoute(int interface_index,
    154                                    const IPAddress& gateway_address,
    155                                    uint32_t metric,
    156                                    uint8_t table_id) {
    157   SLOG(this, 2) << __func__ << " index " << interface_index
    158                 << " metric " << metric;
    159 
    160   RoutingTableEntry* old_entry;
    161 
    162   if (GetDefaultRouteInternal(interface_index,
    163                               gateway_address.family(),
    164                               &old_entry)) {
    165     if (old_entry->gateway.Equals(gateway_address)) {
    166       if (old_entry->metric != metric) {
    167         ReplaceMetric(interface_index, old_entry, metric);
    168       }
    169       return true;
    170     } else {
    171       // TODO(quiche): Update internal state as well?
    172       ApplyRoute(interface_index,
    173                  *old_entry,
    174                  RTNLMessage::kModeDelete,
    175                  0);
    176     }
    177   }
    178 
    179   IPAddress default_address(gateway_address.family());
    180   default_address.SetAddressToDefault();
    181 
    182   return AddRoute(interface_index,
    183                   RoutingTableEntry(default_address,
    184                                     default_address,
    185                                     gateway_address,
    186                                     metric,
    187                                     RT_SCOPE_UNIVERSE,
    188                                     false,
    189                                     table_id,
    190                                     RoutingTableEntry::kDefaultTag));
    191 }
    192 
    193 bool RoutingTable::ConfigureRoutes(int interface_index,
    194                                    const IPConfigRefPtr& ipconfig,
    195                                    uint32_t metric,
    196                                    uint8_t table_id) {
    197   bool ret = true;
    198 
    199   IPAddress::Family address_family = ipconfig->properties().address_family;
    200   const vector<IPConfig::Route>& routes = ipconfig->properties().routes;
    201 
    202   for (const auto& route : routes) {
    203     SLOG(this, 3) << "Installing route:"
    204                   << " Destination: " << route.host
    205                   << " Netmask: " << route.netmask
    206                   << " Gateway: " << route.gateway;
    207     IPAddress destination_address(address_family);
    208     IPAddress source_address(address_family);  // Left as default.
    209     IPAddress gateway_address(address_family);
    210     if (!destination_address.SetAddressFromString(route.host)) {
    211       LOG(ERROR) << "Failed to parse host "
    212                  << route.host;
    213       ret = false;
    214       continue;
    215     }
    216     if (!gateway_address.SetAddressFromString(route.gateway)) {
    217       LOG(ERROR) << "Failed to parse gateway "
    218                  << route.gateway;
    219       ret = false;
    220       continue;
    221     }
    222     destination_address.set_prefix(
    223         IPAddress::GetPrefixLengthFromMask(address_family, route.netmask));
    224     if (!AddRoute(interface_index,
    225                   RoutingTableEntry(destination_address,
    226                                     source_address,
    227                                     gateway_address,
    228                                     metric,
    229                                     RT_SCOPE_UNIVERSE,
    230                                     false,
    231                                     table_id,
    232                                     RoutingTableEntry::kDefaultTag))) {
    233       ret = false;
    234     }
    235   }
    236   return ret;
    237 }
    238 
    239 void RoutingTable::FlushRoutes(int interface_index) {
    240   SLOG(this, 2) << __func__;
    241 
    242   auto table = tables_.find(interface_index);
    243   if (table == tables_.end()) {
    244     return;
    245   }
    246 
    247   for (const auto& nent : table->second) {
    248     ApplyRoute(interface_index, nent, RTNLMessage::kModeDelete, 0);
    249   }
    250   table->second.clear();
    251 }
    252 
    253 void RoutingTable::FlushRoutesWithTag(int tag) {
    254   SLOG(this, 2) << __func__;
    255 
    256   for (auto& table : tables_) {
    257     for (auto nent = table.second.begin(); nent != table.second.end();) {
    258       if (nent->tag == tag) {
    259         ApplyRoute(table.first, *nent, RTNLMessage::kModeDelete, 0);
    260         nent = table.second.erase(nent);
    261       } else {
    262         ++nent;
    263       }
    264     }
    265   }
    266 }
    267 
    268 void RoutingTable::ResetTable(int interface_index) {
    269   tables_.erase(interface_index);
    270 }
    271 
    272 void RoutingTable::SetDefaultMetric(int interface_index, uint32_t metric) {
    273   SLOG(this, 2) << __func__ << " index " << interface_index
    274                 << " metric " << metric;
    275 
    276   RoutingTableEntry* entry;
    277   if (GetDefaultRouteInternal(
    278           interface_index, IPAddress::kFamilyIPv4, &entry) &&
    279       entry->metric != metric) {
    280     ReplaceMetric(interface_index, entry, metric);
    281   }
    282 
    283   if (GetDefaultRouteInternal(
    284           interface_index, IPAddress::kFamilyIPv6, &entry) &&
    285       entry->metric != metric) {
    286     ReplaceMetric(interface_index, entry, metric);
    287   }
    288 }
    289 
    290 // static
    291 bool RoutingTable::ParseRoutingTableMessage(const RTNLMessage& message,
    292                                             int* interface_index,
    293                                             RoutingTableEntry* entry) {
    294   if (message.type() != RTNLMessage::kTypeRoute ||
    295       message.family() == IPAddress::kFamilyUnknown ||
    296       !message.HasAttribute(RTA_OIF)) {
    297     return false;
    298   }
    299 
    300   const RTNLMessage::RouteStatus& route_status = message.route_status();
    301 
    302   if (route_status.type != RTN_UNICAST) {
    303     return false;
    304   }
    305 
    306   uint32_t interface_index_u32 = 0;
    307   if (!message.GetAttribute(RTA_OIF).ConvertToCPUUInt32(&interface_index_u32)) {
    308     return false;
    309   }
    310   *interface_index = interface_index_u32;
    311 
    312   uint32_t metric = 0;
    313   if (message.HasAttribute(RTA_PRIORITY)) {
    314     message.GetAttribute(RTA_PRIORITY).ConvertToCPUUInt32(&metric);
    315   }
    316 
    317   IPAddress default_addr(message.family());
    318   default_addr.SetAddressToDefault();
    319 
    320   ByteString dst_bytes(default_addr.address());
    321   if (message.HasAttribute(RTA_DST)) {
    322     dst_bytes = message.GetAttribute(RTA_DST);
    323   }
    324   ByteString src_bytes(default_addr.address());
    325   if (message.HasAttribute(RTA_SRC)) {
    326     src_bytes = message.GetAttribute(RTA_SRC);
    327   }
    328   ByteString gateway_bytes(default_addr.address());
    329   if (message.HasAttribute(RTA_GATEWAY)) {
    330     gateway_bytes = message.GetAttribute(RTA_GATEWAY);
    331   }
    332 
    333   entry->dst = IPAddress(message.family(), dst_bytes, route_status.dst_prefix);
    334   entry->src = IPAddress(message.family(), src_bytes, route_status.src_prefix);
    335   entry->gateway = IPAddress(message.family(), gateway_bytes);
    336   entry->metric = metric;
    337   entry->scope = route_status.scope;
    338   entry->from_rtnl = true;
    339   entry->table = route_status.table;
    340 
    341   return true;
    342 }
    343 
// Callback for RTNL route messages (bound in the constructor, registered in
// Start()). Messages that ParseRoutingTableMessage() rejects are ignored.
// Of the rest, two kinds are handled:
//  - While route queries from RequestRouteToHost() are outstanding, messages
//    with protocol RTPROT_UNSPEC are treated as query replies. A reply that
//    matches the oldest outstanding query may install a route through
//    AddRoute() and runs the query's callback.
//  - Messages with protocol RTPROT_BOOT are mirrored into |tables_|.
// Everything else is dropped.
void RoutingTable::RouteMsgHandler(const RTNLMessage& message) {
  int interface_index;
  RoutingTableEntry entry;

  if (!ParseRoutingTableMessage(message, &interface_index, &entry)) {
    return;
  }

  if (!route_queries_.empty() &&
      message.route_status().protocol == RTPROT_UNSPEC) {
    SLOG(this, 3) << __func__ << ": Message seq: " << message.seq()
                  << " mode " << message.mode()
                  << ", next query seq: " << route_queries_.front().sequence;

    // Purge queries that have expired (sequence number of this message is
    // greater than that of the head of the route query sequence).  Do the
    // math in a way that's roll-over independent.
    // The queue is non-empty on entry (checked above), and the loop returns
    // as soon as it becomes empty.
    const auto kuint32max = std::numeric_limits<uint32_t>::max();
    while (route_queries_.front().sequence - message.seq() > kuint32max / 2) {
      LOG(ERROR) << __func__ << ": Purging un-replied route request sequence "
                 << route_queries_.front().sequence
                 << " (< " << message.seq() << ")";
      route_queries_.pop_front();
      if (route_queries_.empty())
        return;
    }

    // Only the oldest remaining query can match. If this message does not
    // answer it, the message is dropped and the queue is left as is.
    const Query& query = route_queries_.front();
    if (query.sequence == message.seq()) {
      RoutingTableEntry add_entry(entry);
      // AddRoute() CHECKs that from_rtnl is false.
      add_entry.from_rtnl = false;
      add_entry.tag = query.tag;
      add_entry.table = query.table_id;
      // A reply without a gateway is not installed, but |added| stays true,
      // so its callback still runs.
      bool added = true;
      if (add_entry.gateway.IsDefault()) {
        SLOG(this, 2) << __func__ << ": Ignoring route result with no gateway "
                      << "since we don't need to plumb these.";
      } else {
        SLOG(this, 2) << __func__ << ": Adding host route to "
                      << add_entry.dst.ToString();
        added = AddRoute(interface_index, add_entry);
      }
      if (added && !query.callback.is_null()) {
        SLOG(this, 2) << "Running query callback.";
        query.callback.Run(interface_index, add_entry);
      }
      // The query is consumed whether or not the route was added.
      route_queries_.pop_front();
    }
    return;
  } else if (message.route_status().protocol != RTPROT_BOOT) {
    // Responses to route queries come back with a protocol of
    // RTPROT_UNSPEC.  Otherwise, normal route updates that we are
    // interested in come with a protocol of RTPROT_BOOT.
    return;
  }

  // NOTE: operator[] creates an empty table for |interface_index| if none
  // exists yet, even when the message turns out to be ignored.
  TableEntryVector& table = tables_[interface_index];
  // Only the first tracked entry with the same dst/src/gateway/scope is
  // considered. A delete removes it only if the metric also matches; an add
  // sets from_rtnl and takes over the kernel's metric.
  for (auto nent = table.begin(); nent != table.end(); ++nent)  {
    if (nent->dst.Equals(entry.dst) &&
        nent->src.Equals(entry.src) &&
        nent->gateway.Equals(entry.gateway) &&
        nent->scope == entry.scope) {
      if (message.mode() == RTNLMessage::kModeDelete &&
          nent->metric == entry.metric) {
        table.erase(nent);
      } else if (message.mode() == RTNLMessage::kModeAdd) {
        nent->from_rtnl = true;
        nent->metric = entry.metric;
      }
      return;
    }
  }

  // Untracked route: start tracking it on add; an untracked delete is a no-op.
  if (message.mode() == RTNLMessage::kModeAdd) {
    SLOG(this, 2) << __func__ << " adding"
                  << " destination " << entry.dst.ToString()
                  << " index " << interface_index
                  << " gateway " << entry.gateway.ToString()
                  << " metric " << entry.metric;
    table.push_back(entry);
  }
}
    426 
    427 bool RoutingTable::ApplyRoute(uint32_t interface_index,
    428                               const RoutingTableEntry& entry,
    429                               RTNLMessage::Mode mode,
    430                               unsigned int flags) {
    431   SLOG(this, 2) << base::StringPrintf(
    432       "%s: dst %s/%d src %s/%d index %d mode %d flags 0x%x",
    433       __func__, entry.dst.ToString().c_str(), entry.dst.prefix(),
    434       entry.src.ToString().c_str(), entry.src.prefix(),
    435       interface_index, mode, flags);
    436 
    437   RTNLMessage message(
    438       RTNLMessage::kTypeRoute,
    439       mode,
    440       NLM_F_REQUEST | flags,
    441       0,
    442       0,
    443       0,
    444       entry.dst.family());
    445 
    446   message.set_route_status(RTNLMessage::RouteStatus(
    447       entry.dst.prefix(),
    448       entry.src.prefix(),
    449       entry.table,
    450       RTPROT_BOOT,
    451       entry.scope,
    452       RTN_UNICAST,
    453       0));
    454 
    455   message.SetAttribute(RTA_DST, entry.dst.address());
    456   if (!entry.src.IsDefault()) {
    457     message.SetAttribute(RTA_SRC, entry.src.address());
    458   }
    459   if (!entry.gateway.IsDefault()) {
    460     message.SetAttribute(RTA_GATEWAY, entry.gateway.address());
    461   }
    462   message.SetAttribute(RTA_PRIORITY,
    463                        ByteString::CreateFromCPUUInt32(entry.metric));
    464   message.SetAttribute(RTA_OIF,
    465                        ByteString::CreateFromCPUUInt32(interface_index));
    466 
    467   return rtnl_handler_->SendMessage(&message);
    468 }
    469 
    470 // Somewhat surprisingly, the kernel allows you to create multiple routes
    471 // to the same destination through the same interface with different metrics.
    472 // Therefore, to change the metric on a route, we can't just use the
    473 // NLM_F_REPLACE flag by itself.  We have to explicitly remove the old route.
    474 // We do so after creating the route at a new metric so there is no traffic
    475 // disruption to existing network streams.
    476 void RoutingTable::ReplaceMetric(uint32_t interface_index,
    477                                  RoutingTableEntry* entry,
    478                                  uint32_t metric) {
    479   SLOG(this, 2) << __func__ << " index " << interface_index
    480                 << " metric " << metric;
    481   RoutingTableEntry new_entry = *entry;
    482   new_entry.metric = metric;
    483   // First create the route at the new metric.
    484   ApplyRoute(interface_index, new_entry, RTNLMessage::kModeAdd,
    485              NLM_F_CREATE | NLM_F_REPLACE);
    486   // Then delete the route at the old metric.
    487   ApplyRoute(interface_index, *entry, RTNLMessage::kModeDelete, 0);
    488   // Now, update our routing table (via |*entry|) from |new_entry|.
    489   *entry = new_entry;
    490 }
    491 
    492 bool RoutingTable::FlushCache() {
    493   static const char* kPaths[2] = { kRouteFlushPath4, kRouteFlushPath6 };
    494   bool ret = true;
    495 
    496   SLOG(this, 2) << __func__;
    497 
    498   for (size_t i = 0; i < arraysize(kPaths); ++i) {
    499     if (base::WriteFile(FilePath(kPaths[i]), "-1", 2) != 2) {
    500       LOG(ERROR) << base::StringPrintf("Cannot write to route flush file %s",
    501                                        kPaths[i]);
    502       ret = false;
    503     }
    504   }
    505 
    506   return ret;
    507 }
    508 
    509 bool RoutingTable::RequestRouteToHost(const IPAddress& address,
    510                                       int interface_index,
    511                                       int tag,
    512                                       const Query::Callback& callback,
    513                                       uint8_t table_id) {
    514   // Make sure we don't get a cached response that is no longer valid.
    515   FlushCache();
    516 
    517   RTNLMessage message(
    518       RTNLMessage::kTypeRoute,
    519       RTNLMessage::kModeQuery,
    520       NLM_F_REQUEST,
    521       0,
    522       0,
    523       interface_index,
    524       address.family());
    525 
    526   RTNLMessage::RouteStatus status;
    527   status.dst_prefix = address.prefix();
    528   message.set_route_status(status);
    529   message.SetAttribute(RTA_DST, address.address());
    530 
    531   if (interface_index != -1) {
    532     message.SetAttribute(RTA_OIF,
    533                          ByteString::CreateFromCPUUInt32(interface_index));
    534   }
    535 
    536   if (!rtnl_handler_->SendMessage(&message)) {
    537     return false;
    538   }
    539 
    540   // Save the sequence number of the request so we can create a route for
    541   // this host when we get a reply.
    542   route_queries_.push_back(Query(message.seq(), tag, callback, table_id));
    543 
    544   return true;
    545 }
    546 
    547 bool RoutingTable::CreateBlackholeRoute(int interface_index,
    548                                         IPAddress::Family family,
    549                                         uint32_t metric,
    550                                         uint8_t table_id) {
    551   SLOG(this, 2) << base::StringPrintf(
    552       "%s: index %d family %s metric %d",
    553       __func__, interface_index,
    554       IPAddress::GetAddressFamilyName(family).c_str(), metric);
    555 
    556   RTNLMessage message(
    557       RTNLMessage::kTypeRoute,
    558       RTNLMessage::kModeAdd,
    559       NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL,
    560       0,
    561       0,
    562       0,
    563       family);
    564 
    565   message.set_route_status(RTNLMessage::RouteStatus(
    566       0,
    567       0,
    568       table_id,
    569       RTPROT_BOOT,
    570       RT_SCOPE_UNIVERSE,
    571       RTN_BLACKHOLE,
    572       0));
    573 
    574   message.SetAttribute(RTA_PRIORITY,
    575                        ByteString::CreateFromCPUUInt32(metric));
    576   message.SetAttribute(RTA_OIF,
    577                        ByteString::CreateFromCPUUInt32(interface_index));
    578 
    579   return rtnl_handler_->SendMessage(&message);
    580 }
    581 
    582 bool RoutingTable::CreateLinkRoute(int interface_index,
    583                                    const IPAddress& local_address,
    584                                    const IPAddress& remote_address,
    585                                    uint8_t table_id) {
    586   if (!local_address.CanReachAddress(remote_address)) {
    587     LOG(ERROR) << __func__ << " failed: "
    588                << remote_address.ToString() << " is not reachable from "
    589                << local_address.ToString();
    590     return false;
    591   }
    592 
    593   IPAddress default_address(local_address.family());
    594   default_address.SetAddressToDefault();
    595   IPAddress destination_address(remote_address);
    596   destination_address.set_prefix(
    597       IPAddress::GetMaxPrefixLength(remote_address.family()));
    598   SLOG(this, 2) << "Creating link route to " << destination_address.ToString()
    599                 << " from " << local_address.ToString()
    600                 << " on interface index " << interface_index;
    601   return AddRoute(interface_index,
    602                   RoutingTableEntry(destination_address,
    603                                     local_address,
    604                                     default_address,
    605                                     0,
    606                                     RT_SCOPE_LINK,
    607                                     false,
    608                                     table_id,
    609                                     RoutingTableEntry::kDefaultTag));
    610 }
    611 
    612 }  // namespace shill
    613