Home | History | Annotate | Download | only in base
      1 /*
      2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include "webrtc/base/natsocketfactory.h"
     12 #include "webrtc/base/natserver.h"
     13 #include "webrtc/base/logging.h"
     14 #include "webrtc/base/socketadapters.h"
     15 
     16 namespace rtc {
     17 
     18 RouteCmp::RouteCmp(NAT* nat) : symmetric(nat->IsSymmetric()) {
     19 }
     20 
     21 size_t RouteCmp::operator()(const SocketAddressPair& r) const {
     22   size_t h = r.source().Hash();
     23   if (symmetric)
     24     h ^= r.destination().Hash();
     25   return h;
     26 }
     27 
     28 bool RouteCmp::operator()(
     29       const SocketAddressPair& r1, const SocketAddressPair& r2) const {
     30   if (r1.source() < r2.source())
     31     return true;
     32   if (r2.source() < r1.source())
     33     return false;
     34   if (symmetric && (r1.destination() < r2.destination()))
     35     return true;
     36   if (symmetric && (r2.destination() < r1.destination()))
     37     return false;
     38   return false;
     39 }
     40 
     41 AddrCmp::AddrCmp(NAT* nat)
     42     : use_ip(nat->FiltersIP()), use_port(nat->FiltersPort()) {
     43 }
     44 
     45 size_t AddrCmp::operator()(const SocketAddress& a) const {
     46   size_t h = 0;
     47   if (use_ip)
     48     h ^= HashIP(a.ipaddr());
     49   if (use_port)
     50     h ^= a.port() | (a.port() << 16);
     51   return h;
     52 }
     53 
     54 bool AddrCmp::operator()(
     55       const SocketAddress& a1, const SocketAddress& a2) const {
     56   if (use_ip && (a1.ipaddr() < a2.ipaddr()))
     57     return true;
     58   if (use_ip && (a2.ipaddr() < a1.ipaddr()))
     59     return false;
     60   if (use_port && (a1.port() < a2.port()))
     61     return true;
     62   if (use_port && (a2.port() < a1.port()))
     63     return false;
     64   return false;
     65 }
     66 
     67 // Proxy socket that will capture the external destination address intended for
     68 // a TCP connection to the NAT server.
     69 class NATProxyServerSocket : public AsyncProxyServerSocket {
     70  public:
     71   NATProxyServerSocket(AsyncSocket* socket)
     72       : AsyncProxyServerSocket(socket, kNATEncodedIPv6AddressSize) {
     73     BufferInput(true);
     74   }
     75 
     76   void SendConnectResult(int err, const SocketAddress& addr) override {
     77     char code = err ? 1 : 0;
     78     BufferedReadAdapter::DirectSend(&code, sizeof(char));
     79   }
     80 
     81  protected:
     82   void ProcessInput(char* data, size_t* len) override {
     83     if (*len < 2) {
     84       return;
     85     }
     86 
     87     int family = data[1];
     88     ASSERT(family == AF_INET || family == AF_INET6);
     89     if ((family == AF_INET && *len < kNATEncodedIPv4AddressSize) ||
     90         (family == AF_INET6 && *len < kNATEncodedIPv6AddressSize)) {
     91       return;
     92     }
     93 
     94     SocketAddress dest_addr;
     95     size_t address_length = UnpackAddressFromNAT(data, *len, &dest_addr);
     96 
     97     *len -= address_length;
     98     if (*len > 0) {
     99       memmove(data, data + address_length, *len);
    100     }
    101 
    102     bool remainder = (*len > 0);
    103     BufferInput(false);
    104     SignalConnectRequest(this, dest_addr);
    105     if (remainder) {
    106       SignalReadEvent(this);
    107     }
    108   }
    109 
    110 };
    111 
    112 class NATProxyServer : public ProxyServer {
    113  public:
    114   NATProxyServer(SocketFactory* int_factory, const SocketAddress& int_addr,
    115                  SocketFactory* ext_factory, const SocketAddress& ext_ip)
    116       : ProxyServer(int_factory, int_addr, ext_factory, ext_ip) {
    117   }
    118 
    119  protected:
    120   AsyncProxyServerSocket* WrapSocket(AsyncSocket* socket) override {
    121     return new NATProxyServerSocket(socket);
    122   }
    123 };
    124 
    125 NATServer::NATServer(
    126     NATType type, SocketFactory* internal,
    127     const SocketAddress& internal_udp_addr,
    128     const SocketAddress& internal_tcp_addr,
    129     SocketFactory* external, const SocketAddress& external_ip)
    130     : external_(external), external_ip_(external_ip.ipaddr(), 0) {
    131   nat_ = NAT::Create(type);
    132 
    133   udp_server_socket_ = AsyncUDPSocket::Create(internal, internal_udp_addr);
    134   udp_server_socket_->SignalReadPacket.connect(this,
    135                                                &NATServer::OnInternalUDPPacket);
    136   tcp_proxy_server_ = new NATProxyServer(internal, internal_tcp_addr, external,
    137                                          external_ip);
    138 
    139   int_map_ = new InternalMap(RouteCmp(nat_));
    140   ext_map_ = new ExternalMap();
    141 }
    142 
    143 NATServer::~NATServer() {
    144   for (InternalMap::iterator iter = int_map_->begin();
    145        iter != int_map_->end();
    146        iter++)
    147     delete iter->second;
    148 
    149   delete nat_;
    150   delete udp_server_socket_;
    151   delete tcp_proxy_server_;
    152   delete int_map_;
    153   delete ext_map_;
    154 }
    155 
    156 void NATServer::OnInternalUDPPacket(
    157     AsyncPacketSocket* socket, const char* buf, size_t size,
    158     const SocketAddress& addr, const PacketTime& packet_time) {
    159   // Read the intended destination from the wire.
    160   SocketAddress dest_addr;
    161   size_t length = UnpackAddressFromNAT(buf, size, &dest_addr);
    162 
    163   // Find the translation for these addresses (allocating one if necessary).
    164   SocketAddressPair route(addr, dest_addr);
    165   InternalMap::iterator iter = int_map_->find(route);
    166   if (iter == int_map_->end()) {
    167     Translate(route);
    168     iter = int_map_->find(route);
    169   }
    170   ASSERT(iter != int_map_->end());
    171 
    172   // Allow the destination to send packets back to the source.
    173   iter->second->WhitelistInsert(dest_addr);
    174 
    175   // Send the packet to its intended destination.
    176   rtc::PacketOptions options;
    177   iter->second->socket->SendTo(buf + length, size - length, dest_addr, options);
    178 }
    179 
    180 void NATServer::OnExternalUDPPacket(
    181     AsyncPacketSocket* socket, const char* buf, size_t size,
    182     const SocketAddress& remote_addr, const PacketTime& packet_time) {
    183   SocketAddress local_addr = socket->GetLocalAddress();
    184 
    185   // Find the translation for this addresses.
    186   ExternalMap::iterator iter = ext_map_->find(local_addr);
    187   ASSERT(iter != ext_map_->end());
    188 
    189   // Allow the NAT to reject this packet.
    190   if (ShouldFilterOut(iter->second, remote_addr)) {
    191     LOG(LS_INFO) << "Packet from " << remote_addr.ToSensitiveString()
    192                  << " was filtered out by the NAT.";
    193     return;
    194   }
    195 
    196   // Forward this packet to the internal address.
    197   // First prepend the address in a quasi-STUN format.
    198   scoped_ptr<char[]> real_buf(new char[size + kNATEncodedIPv6AddressSize]);
    199   size_t addrlength = PackAddressForNAT(real_buf.get(),
    200                                         size + kNATEncodedIPv6AddressSize,
    201                                         remote_addr);
    202   // Copy the data part after the address.
    203   rtc::PacketOptions options;
    204   memcpy(real_buf.get() + addrlength, buf, size);
    205   udp_server_socket_->SendTo(real_buf.get(), size + addrlength,
    206                              iter->second->route.source(), options);
    207 }
    208 
    209 void NATServer::Translate(const SocketAddressPair& route) {
    210   AsyncUDPSocket* socket = AsyncUDPSocket::Create(external_, external_ip_);
    211 
    212   if (!socket) {
    213     LOG(LS_ERROR) << "Couldn't find a free port!";
    214     return;
    215   }
    216 
    217   TransEntry* entry = new TransEntry(route, socket, nat_);
    218   (*int_map_)[route] = entry;
    219   (*ext_map_)[socket->GetLocalAddress()] = entry;
    220   socket->SignalReadPacket.connect(this, &NATServer::OnExternalUDPPacket);
    221 }
    222 
    223 bool NATServer::ShouldFilterOut(TransEntry* entry,
    224                                 const SocketAddress& ext_addr) {
    225   return entry->WhitelistContains(ext_addr);
    226 }
    227 
    228 NATServer::TransEntry::TransEntry(
    229     const SocketAddressPair& r, AsyncUDPSocket* s, NAT* nat)
    230     : route(r), socket(s) {
    231   whitelist = new AddressSet(AddrCmp(nat));
    232 }
    233 
    234 NATServer::TransEntry::~TransEntry() {
    235   delete whitelist;
    236   delete socket;
    237 }
    238 
    239 void NATServer::TransEntry::WhitelistInsert(const SocketAddress& addr) {
    240   CritScope cs(&crit_);
    241   whitelist->insert(addr);
    242 }
    243 
    244 bool NATServer::TransEntry::WhitelistContains(const SocketAddress& ext_addr) {
    245   CritScope cs(&crit_);
    246   return whitelist->find(ext_addr) == whitelist->end();
    247 }
    248 
    249 }  // namespace rtc
    250