Home | History | Annotate | Download | only in base
      1 /*
      2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include "webrtc/base/natsocketfactory.h"
     12 
     13 #include "webrtc/base/arraysize.h"
     14 #include "webrtc/base/logging.h"
     15 #include "webrtc/base/natserver.h"
     16 #include "webrtc/base/virtualsocketserver.h"
     17 
     18 namespace rtc {
     19 
     20 // Packs the given socketaddress into the buffer in buf, in the quasi-STUN
     21 // format that the natserver uses.
     22 // Returns 0 if an invalid address is passed.
     23 size_t PackAddressForNAT(char* buf, size_t buf_size,
     24                          const SocketAddress& remote_addr) {
     25   const IPAddress& ip = remote_addr.ipaddr();
     26   int family = ip.family();
     27   buf[0] = 0;
     28   buf[1] = family;
     29   // Writes the port.
     30   *(reinterpret_cast<uint16_t*>(&buf[2])) = HostToNetwork16(remote_addr.port());
     31   if (family == AF_INET) {
     32     ASSERT(buf_size >= kNATEncodedIPv4AddressSize);
     33     in_addr v4addr = ip.ipv4_address();
     34     memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
     35     return kNATEncodedIPv4AddressSize;
     36   } else if (family == AF_INET6) {
     37     ASSERT(buf_size >= kNATEncodedIPv6AddressSize);
     38     in6_addr v6addr = ip.ipv6_address();
     39     memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
     40     return kNATEncodedIPv6AddressSize;
     41   }
     42   return 0U;
     43 }
     44 
     45 // Decodes the remote address from a packet that has been encoded with the nat's
     46 // quasi-STUN format. Returns the length of the address (i.e., the offset into
     47 // data where the original packet starts).
     48 size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
     49                             SocketAddress* remote_addr) {
     50   ASSERT(buf_size >= 8);
     51   ASSERT(buf[0] == 0);
     52   int family = buf[1];
     53   uint16_t port =
     54       NetworkToHost16(*(reinterpret_cast<const uint16_t*>(&buf[2])));
     55   if (family == AF_INET) {
     56     const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
     57     *remote_addr = SocketAddress(IPAddress(*v4addr), port);
     58     return kNATEncodedIPv4AddressSize;
     59   } else if (family == AF_INET6) {
     60     ASSERT(buf_size >= 20);
     61     const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
     62     *remote_addr = SocketAddress(IPAddress(*v6addr), port);
     63     return kNATEncodedIPv6AddressSize;
     64   }
     65   return 0U;
     66 }
     67 
     68 
     69 // NATSocket
     70 class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
     71  public:
     72   explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
     73       : sf_(sf), family_(family), type_(type), connected_(false),
     74         socket_(NULL), buf_(NULL), size_(0) {
     75   }
     76 
     77   ~NATSocket() override {
     78     delete socket_;
     79     delete[] buf_;
     80   }
     81 
     82   SocketAddress GetLocalAddress() const override {
     83     return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
     84   }
     85 
     86   SocketAddress GetRemoteAddress() const override {
     87     return remote_addr_;  // will be NIL if not connected
     88   }
     89 
     90   int Bind(const SocketAddress& addr) override {
     91     if (socket_) {  // already bound, bubble up error
     92       return -1;
     93     }
     94 
     95     int result;
     96     socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
     97     result = (socket_) ? socket_->Bind(addr) : -1;
     98     if (result >= 0) {
     99       socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
    100       socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
    101       socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
    102       socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
    103     } else {
    104       server_addr_.Clear();
    105       delete socket_;
    106       socket_ = NULL;
    107     }
    108 
    109     return result;
    110   }
    111 
    112   int Connect(const SocketAddress& addr) override {
    113     if (!socket_) {  // socket must be bound, for now
    114       return -1;
    115     }
    116 
    117     int result = 0;
    118     if (type_ == SOCK_STREAM) {
    119       result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
    120     } else {
    121       connected_ = true;
    122     }
    123 
    124     if (result >= 0) {
    125       remote_addr_ = addr;
    126     }
    127 
    128     return result;
    129   }
    130 
    131   int Send(const void* data, size_t size) override {
    132     ASSERT(connected_);
    133     return SendTo(data, size, remote_addr_);
    134   }
    135 
    136   int SendTo(const void* data,
    137              size_t size,
    138              const SocketAddress& addr) override {
    139     ASSERT(!connected_ || addr == remote_addr_);
    140     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
    141       return socket_->SendTo(data, size, addr);
    142     }
    143     // This array will be too large for IPv4 packets, but only by 12 bytes.
    144     scoped_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
    145     size_t addrlength = PackAddressForNAT(buf.get(),
    146                                           size + kNATEncodedIPv6AddressSize,
    147                                           addr);
    148     size_t encoded_size = size + addrlength;
    149     memcpy(buf.get() + addrlength, data, size);
    150     int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
    151     if (result >= 0) {
    152       ASSERT(result == static_cast<int>(encoded_size));
    153       result = result - static_cast<int>(addrlength);
    154     }
    155     return result;
    156   }
    157 
    158   int Recv(void* data, size_t size) override {
    159     SocketAddress addr;
    160     return RecvFrom(data, size, &addr);
    161   }
    162 
    163   int RecvFrom(void* data, size_t size, SocketAddress* out_addr) override {
    164     if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
    165       return socket_->RecvFrom(data, size, out_addr);
    166     }
    167     // Make sure we have enough room to read the requested amount plus the
    168     // largest possible header address.
    169     SocketAddress remote_addr;
    170     Grow(size + kNATEncodedIPv6AddressSize);
    171 
    172     // Read the packet from the socket.
    173     int result = socket_->RecvFrom(buf_, size_, &remote_addr);
    174     if (result >= 0) {
    175       ASSERT(remote_addr == server_addr_);
    176 
    177       // TODO: we need better framing so we know how many bytes we can
    178       // return before we need to read the next address. For UDP, this will be
    179       // fine as long as the reader always reads everything in the packet.
    180       ASSERT((size_t)result < size_);
    181 
    182       // Decode the wire packet into the actual results.
    183       SocketAddress real_remote_addr;
    184       size_t addrlength = UnpackAddressFromNAT(buf_, result, &real_remote_addr);
    185       memcpy(data, buf_ + addrlength, result - addrlength);
    186 
    187       // Make sure this packet should be delivered before returning it.
    188       if (!connected_ || (real_remote_addr == remote_addr_)) {
    189         if (out_addr)
    190           *out_addr = real_remote_addr;
    191         result = result - static_cast<int>(addrlength);
    192       } else {
    193         LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
    194                       << real_remote_addr.ToString();
    195         result = 0;  // Tell the caller we didn't read anything
    196       }
    197     }
    198 
    199     return result;
    200   }
    201 
    202   int Close() override {
    203     int result = 0;
    204     if (socket_) {
    205       result = socket_->Close();
    206       if (result >= 0) {
    207         connected_ = false;
    208         remote_addr_ = SocketAddress();
    209         delete socket_;
    210         socket_ = NULL;
    211       }
    212     }
    213     return result;
    214   }
    215 
    216   int Listen(int backlog) override { return socket_->Listen(backlog); }
    217   AsyncSocket* Accept(SocketAddress* paddr) override {
    218     return socket_->Accept(paddr);
    219   }
    220   int GetError() const override { return socket_->GetError(); }
    221   void SetError(int error) override { socket_->SetError(error); }
    222   ConnState GetState() const override {
    223     return connected_ ? CS_CONNECTED : CS_CLOSED;
    224   }
    225   int EstimateMTU(uint16_t* mtu) override { return socket_->EstimateMTU(mtu); }
    226   int GetOption(Option opt, int* value) override {
    227     return socket_->GetOption(opt, value);
    228   }
    229   int SetOption(Option opt, int value) override {
    230     return socket_->SetOption(opt, value);
    231   }
    232 
    233   void OnConnectEvent(AsyncSocket* socket) {
    234     // If we're NATed, we need to send a message with the real addr to use.
    235     ASSERT(socket == socket_);
    236     if (server_addr_.IsNil()) {
    237       connected_ = true;
    238       SignalConnectEvent(this);
    239     } else {
    240       SendConnectRequest();
    241     }
    242   }
    243   void OnReadEvent(AsyncSocket* socket) {
    244     // If we're NATed, we need to process the connect reply.
    245     ASSERT(socket == socket_);
    246     if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
    247       HandleConnectReply();
    248     } else {
    249       SignalReadEvent(this);
    250     }
    251   }
    252   void OnWriteEvent(AsyncSocket* socket) {
    253     ASSERT(socket == socket_);
    254     SignalWriteEvent(this);
    255   }
    256   void OnCloseEvent(AsyncSocket* socket, int error) {
    257     ASSERT(socket == socket_);
    258     SignalCloseEvent(this, error);
    259   }
    260 
    261  private:
    262   // Makes sure the buffer is at least the given size.
    263   void Grow(size_t new_size) {
    264     if (size_ < new_size) {
    265       delete[] buf_;
    266       size_ = new_size;
    267       buf_ = new char[size_];
    268     }
    269   }
    270 
    271   // Sends the destination address to the server to tell it to connect.
    272   void SendConnectRequest() {
    273     char buf[kNATEncodedIPv6AddressSize];
    274     size_t length = PackAddressForNAT(buf, arraysize(buf), remote_addr_);
    275     socket_->Send(buf, length);
    276   }
    277 
    278   // Handles the byte sent back from the server and fires the appropriate event.
    279   void HandleConnectReply() {
    280     char code;
    281     socket_->Recv(&code, sizeof(code));
    282     if (code == 0) {
    283       connected_ = true;
    284       SignalConnectEvent(this);
    285     } else {
    286       Close();
    287       SignalCloseEvent(this, code);
    288     }
    289   }
    290 
    291   NATInternalSocketFactory* sf_;
    292   int family_;
    293   int type_;
    294   bool connected_;
    295   SocketAddress remote_addr_;
    296   SocketAddress server_addr_;  // address of the NAT server
    297   AsyncSocket* socket_;
    298   char* buf_;
    299   size_t size_;
    300 };
    301 
    302 // NATSocketFactory
    303 NATSocketFactory::NATSocketFactory(SocketFactory* factory,
    304                                    const SocketAddress& nat_udp_addr,
    305                                    const SocketAddress& nat_tcp_addr)
    306     : factory_(factory), nat_udp_addr_(nat_udp_addr),
    307       nat_tcp_addr_(nat_tcp_addr) {
    308 }
    309 
    310 Socket* NATSocketFactory::CreateSocket(int type) {
    311   return CreateSocket(AF_INET, type);
    312 }
    313 
    314 Socket* NATSocketFactory::CreateSocket(int family, int type) {
    315   return new NATSocket(this, family, type);
    316 }
    317 
    318 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
    319   return CreateAsyncSocket(AF_INET, type);
    320 }
    321 
    322 AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
    323   return new NATSocket(this, family, type);
    324 }
    325 
    326 AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
    327     const SocketAddress& local_addr, SocketAddress* nat_addr) {
    328   if (type == SOCK_STREAM) {
    329     *nat_addr = nat_tcp_addr_;
    330   } else {
    331     *nat_addr = nat_udp_addr_;
    332   }
    333   return factory_->CreateAsyncSocket(family, type);
    334 }
    335 
    336 // NATSocketServer
    337 NATSocketServer::NATSocketServer(SocketServer* server)
    338     : server_(server), msg_queue_(NULL) {
    339 }
    340 
    341 NATSocketServer::Translator* NATSocketServer::GetTranslator(
    342     const SocketAddress& ext_ip) {
    343   return nats_.Get(ext_ip);
    344 }
    345 
    346 NATSocketServer::Translator* NATSocketServer::AddTranslator(
    347     const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
    348   // Fail if a translator already exists with this extternal address.
    349   if (nats_.Get(ext_ip))
    350     return NULL;
    351 
    352   return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
    353 }
    354 
    355 void NATSocketServer::RemoveTranslator(
    356     const SocketAddress& ext_ip) {
    357   nats_.Remove(ext_ip);
    358 }
    359 
    360 Socket* NATSocketServer::CreateSocket(int type) {
    361   return CreateSocket(AF_INET, type);
    362 }
    363 
    364 Socket* NATSocketServer::CreateSocket(int family, int type) {
    365   return new NATSocket(this, family, type);
    366 }
    367 
    368 AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
    369   return CreateAsyncSocket(AF_INET, type);
    370 }
    371 
    372 AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
    373   return new NATSocket(this, family, type);
    374 }
    375 
    376 void NATSocketServer::SetMessageQueue(MessageQueue* queue) {
    377   msg_queue_ = queue;
    378   server_->SetMessageQueue(queue);
    379 }
    380 
    381 bool NATSocketServer::Wait(int cms, bool process_io) {
    382   return server_->Wait(cms, process_io);
    383 }
    384 
    385 void NATSocketServer::WakeUp() {
    386   server_->WakeUp();
    387 }
    388 
    389 AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type,
    390     const SocketAddress& local_addr, SocketAddress* nat_addr) {
    391   AsyncSocket* socket = NULL;
    392   Translator* nat = nats_.FindClient(local_addr);
    393   if (nat) {
    394     socket = nat->internal_factory()->CreateAsyncSocket(family, type);
    395     *nat_addr = (type == SOCK_STREAM) ?
    396         nat->internal_tcp_address() : nat->internal_udp_address();
    397   } else {
    398     socket = server_->CreateAsyncSocket(family, type);
    399   }
    400   return socket;
    401 }
    402 
    403 // NATSocketServer::Translator
    404 NATSocketServer::Translator::Translator(
    405     NATSocketServer* server, NATType type, const SocketAddress& int_ip,
    406     SocketFactory* ext_factory, const SocketAddress& ext_ip)
    407     : server_(server) {
    408   // Create a new private network, and a NATServer running on the private
    409   // network that bridges to the external network. Also tell the private
    410   // network to use the same message queue as us.
    411   VirtualSocketServer* internal_server = new VirtualSocketServer(server_);
    412   internal_server->SetMessageQueue(server_->queue());
    413   internal_factory_.reset(internal_server);
    414   nat_server_.reset(new NATServer(type, internal_server, int_ip, int_ip,
    415                                   ext_factory, ext_ip));
    416 }
    417 
// Defaulted: no manual cleanup is written here; the members (nat_server_,
// internal_factory_, nats_, clients_) are destroyed by their own destructors.
NATSocketServer::Translator::~Translator() = default;
    419 
    420 NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
    421     const SocketAddress& ext_ip) {
    422   return nats_.Get(ext_ip);
    423 }
    424 
    425 NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
    426     const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
    427   // Fail if a translator already exists with this extternal address.
    428   if (nats_.Get(ext_ip))
    429     return NULL;
    430 
    431   AddClient(ext_ip);
    432   return nats_.Add(ext_ip,
    433                    new Translator(server_, type, int_ip, server_, ext_ip));
    434 }
    435 void NATSocketServer::Translator::RemoveTranslator(
    436     const SocketAddress& ext_ip) {
    437   nats_.Remove(ext_ip);
    438   RemoveClient(ext_ip);
    439 }
    440 
    441 bool NATSocketServer::Translator::AddClient(
    442     const SocketAddress& int_ip) {
    443   // Fail if a client already exists with this internal address.
    444   if (clients_.find(int_ip) != clients_.end())
    445     return false;
    446 
    447   clients_.insert(int_ip);
    448   return true;
    449 }
    450 
    451 void NATSocketServer::Translator::RemoveClient(
    452     const SocketAddress& int_ip) {
    453   std::set<SocketAddress>::iterator it = clients_.find(int_ip);
    454   if (it != clients_.end()) {
    455     clients_.erase(it);
    456   }
    457 }
    458 
    459 NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
    460     const SocketAddress& int_ip) {
    461   // See if we have the requested IP, or any of our children do.
    462   return (clients_.find(int_ip) != clients_.end()) ?
    463       this : nats_.FindClient(int_ip);
    464 }
    465 
    466 // NATSocketServer::TranslatorMap
    467 NATSocketServer::TranslatorMap::~TranslatorMap() {
    468   for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
    469     delete it->second;
    470   }
    471 }
    472 
    473 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
    474     const SocketAddress& ext_ip) {
    475   TranslatorMap::iterator it = find(ext_ip);
    476   return (it != end()) ? it->second : NULL;
    477 }
    478 
    479 NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
    480     const SocketAddress& ext_ip, Translator* nat) {
    481   (*this)[ext_ip] = nat;
    482   return nat;
    483 }
    484 
    485 void NATSocketServer::TranslatorMap::Remove(
    486     const SocketAddress& ext_ip) {
    487   TranslatorMap::iterator it = find(ext_ip);
    488   if (it != end()) {
    489     delete it->second;
    490     erase(it);
    491   }
    492 }
    493 
    494 NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
    495     const SocketAddress& int_ip) {
    496   Translator* nat = NULL;
    497   for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
    498     nat = it->second->FindClient(int_ip);
    499   }
    500   return nat;
    501 }
    502 
    503 }  // namespace rtc
    504