Home | History | Annotate | Download | only in base
      1 /*
      2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include "webrtc/base/natsocketfactory.h"
     12 #include "webrtc/base/natserver.h"
     13 #include "webrtc/base/logging.h"
     14 
     15 namespace rtc {
     16 
     17 RouteCmp::RouteCmp(NAT* nat) : symmetric(nat->IsSymmetric()) {
     18 }
     19 
     20 size_t RouteCmp::operator()(const SocketAddressPair& r) const {
     21   size_t h = r.source().Hash();
     22   if (symmetric)
     23     h ^= r.destination().Hash();
     24   return h;
     25 }
     26 
     27 bool RouteCmp::operator()(
     28       const SocketAddressPair& r1, const SocketAddressPair& r2) const {
     29   if (r1.source() < r2.source())
     30     return true;
     31   if (r2.source() < r1.source())
     32     return false;
     33   if (symmetric && (r1.destination() < r2.destination()))
     34     return true;
     35   if (symmetric && (r2.destination() < r1.destination()))
     36     return false;
     37   return false;
     38 }
     39 
     40 AddrCmp::AddrCmp(NAT* nat)
     41     : use_ip(nat->FiltersIP()), use_port(nat->FiltersPort()) {
     42 }
     43 
     44 size_t AddrCmp::operator()(const SocketAddress& a) const {
     45   size_t h = 0;
     46   if (use_ip)
     47     h ^= HashIP(a.ipaddr());
     48   if (use_port)
     49     h ^= a.port() | (a.port() << 16);
     50   return h;
     51 }
     52 
     53 bool AddrCmp::operator()(
     54       const SocketAddress& a1, const SocketAddress& a2) const {
     55   if (use_ip && (a1.ipaddr() < a2.ipaddr()))
     56     return true;
     57   if (use_ip && (a2.ipaddr() < a1.ipaddr()))
     58     return false;
     59   if (use_port && (a1.port() < a2.port()))
     60     return true;
     61   if (use_port && (a2.port() < a1.port()))
     62     return false;
     63   return false;
     64 }
     65 
     66 NATServer::NATServer(
     67     NATType type, SocketFactory* internal, const SocketAddress& internal_addr,
     68     SocketFactory* external, const SocketAddress& external_ip)
     69     : external_(external), external_ip_(external_ip.ipaddr(), 0) {
     70   nat_ = NAT::Create(type);
     71 
     72   server_socket_ = AsyncUDPSocket::Create(internal, internal_addr);
     73   server_socket_->SignalReadPacket.connect(this, &NATServer::OnInternalPacket);
     74 
     75   int_map_ = new InternalMap(RouteCmp(nat_));
     76   ext_map_ = new ExternalMap();
     77 }
     78 
     79 NATServer::~NATServer() {
     80   for (InternalMap::iterator iter = int_map_->begin();
     81        iter != int_map_->end();
     82        iter++)
     83     delete iter->second;
     84 
     85   delete nat_;
     86   delete server_socket_;
     87   delete int_map_;
     88   delete ext_map_;
     89 }
     90 
     91 void NATServer::OnInternalPacket(
     92     AsyncPacketSocket* socket, const char* buf, size_t size,
     93     const SocketAddress& addr, const PacketTime& packet_time) {
     94 
     95   // Read the intended destination from the wire.
     96   SocketAddress dest_addr;
     97   size_t length = UnpackAddressFromNAT(buf, size, &dest_addr);
     98 
     99   // Find the translation for these addresses (allocating one if necessary).
    100   SocketAddressPair route(addr, dest_addr);
    101   InternalMap::iterator iter = int_map_->find(route);
    102   if (iter == int_map_->end()) {
    103     Translate(route);
    104     iter = int_map_->find(route);
    105   }
    106   ASSERT(iter != int_map_->end());
    107 
    108   // Allow the destination to send packets back to the source.
    109   iter->second->WhitelistInsert(dest_addr);
    110 
    111   // Send the packet to its intended destination.
    112   rtc::PacketOptions options;
    113   iter->second->socket->SendTo(buf + length, size - length, dest_addr, options);
    114 }
    115 
    116 void NATServer::OnExternalPacket(
    117     AsyncPacketSocket* socket, const char* buf, size_t size,
    118     const SocketAddress& remote_addr, const PacketTime& packet_time) {
    119 
    120   SocketAddress local_addr = socket->GetLocalAddress();
    121 
    122   // Find the translation for this addresses.
    123   ExternalMap::iterator iter = ext_map_->find(local_addr);
    124   ASSERT(iter != ext_map_->end());
    125 
    126   // Allow the NAT to reject this packet.
    127   if (ShouldFilterOut(iter->second, remote_addr)) {
    128     LOG(LS_INFO) << "Packet from " << remote_addr.ToSensitiveString()
    129                  << " was filtered out by the NAT.";
    130     return;
    131   }
    132 
    133   // Forward this packet to the internal address.
    134   // First prepend the address in a quasi-STUN format.
    135   scoped_ptr<char[]> real_buf(new char[size + kNATEncodedIPv6AddressSize]);
    136   size_t addrlength = PackAddressForNAT(real_buf.get(),
    137                                         size + kNATEncodedIPv6AddressSize,
    138                                         remote_addr);
    139   // Copy the data part after the address.
    140   rtc::PacketOptions options;
    141   memcpy(real_buf.get() + addrlength, buf, size);
    142   server_socket_->SendTo(real_buf.get(), size + addrlength,
    143                          iter->second->route.source(), options);
    144 }
    145 
    146 void NATServer::Translate(const SocketAddressPair& route) {
    147   AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_);
    148 
    149   if (!socket) {
    150     LOG(LS_ERROR) << "Couldn't find a free port!";
    151     return;
    152   }
    153 
    154   TransEntry* entry = new TransEntry(route, socket, nat_);
    155   (*int_map_)[route] = entry;
    156   (*ext_map_)[socket->GetLocalAddress()] = entry;
    157   socket->SignalReadPacket.connect(this, &NATServer::OnExternalPacket);
    158 }
    159 
    160 bool NATServer::ShouldFilterOut(TransEntry* entry,
    161                                 const SocketAddress& ext_addr) {
    162   return entry->WhitelistContains(ext_addr);
    163 }
    164 
    165 NATServer::TransEntry::TransEntry(
    166     const SocketAddressPair& r, AsyncUDPSocket* s, NAT* nat)
    167     : route(r), socket(s) {
    168   whitelist = new AddressSet(AddrCmp(nat));
    169 }
    170 
    171 NATServer::TransEntry::~TransEntry() {
    172   delete whitelist;
    173   delete socket;
    174 }
    175 
    176 void NATServer::TransEntry::WhitelistInsert(const SocketAddress& addr) {
    177   CritScope cs(&crit_);
    178   whitelist->insert(addr);
    179 }
    180 
    181 bool NATServer::TransEntry::WhitelistContains(const SocketAddress& ext_addr) {
    182   CritScope cs(&crit_);
    183   return whitelist->find(ext_addr) == whitelist->end();
    184 }
    185 
    186 }  // namespace rtc
    187