Home | History | Annotate | Download | only in base
      1 /*
      2  *  Copyright 2004 The WebRTC Project Authors. All rights reserved.
      3  *
      4  *  Use of this source code is governed by a BSD-style license
      5  *  that can be found in the LICENSE file in the root of the source
      6  *  tree. An additional intellectual property rights grant can be found
      7  *  in the file PATENTS.  All contributing project authors may
      8  *  be found in the AUTHORS file in the root of the source tree.
      9  */
     10 
     11 #include <algorithm>
     12 #include <string>
     13 
     14 #include "webrtc/base/gunit.h"
     15 #include "webrtc/base/logging.h"
     16 #include "webrtc/base/natserver.h"
     17 #include "webrtc/base/natsocketfactory.h"
     18 #include "webrtc/base/nethelpers.h"
     19 #include "webrtc/base/network.h"
     20 #include "webrtc/base/physicalsocketserver.h"
     21 #include "webrtc/base/testclient.h"
     22 #include "webrtc/base/asynctcpsocket.h"
     23 #include "webrtc/base/virtualsocketserver.h"
     24 
     25 using namespace rtc;
     26 
     27 bool CheckReceive(
     28     TestClient* client, bool should_receive, const char* buf, size_t size) {
     29   return (should_receive) ?
     30       client->CheckNextPacket(buf, size, 0) :
     31       client->CheckNoPacket();
     32 }
     33 
     34 TestClient* CreateTestClient(
     35       SocketFactory* factory, const SocketAddress& local_addr) {
     36   AsyncUDPSocket* socket = AsyncUDPSocket::Create(factory, local_addr);
     37   return new TestClient(socket);
     38 }
     39 
     40 TestClient* CreateTCPTestClient(AsyncSocket* socket) {
     41   AsyncTCPSocket* packet_socket = new AsyncTCPSocket(socket, false);
     42   return new TestClient(packet_socket);
     43 }
     44 
     45 // Tests that when sending from internal_addr to external_addrs through the
     46 // NAT type specified by nat_type, all external addrs receive the sent packet
     47 // and, if exp_same is true, all use the same mapped-address on the NAT.
     48 void TestSend(
     49       SocketServer* internal, const SocketAddress& internal_addr,
     50       SocketServer* external, const SocketAddress external_addrs[4],
     51       NATType nat_type, bool exp_same) {
     52   Thread th_int(internal);
     53   Thread th_ext(external);
     54 
     55   SocketAddress server_addr = internal_addr;
     56   server_addr.SetPort(0);  // Auto-select a port
     57   NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
     58                                  external, external_addrs[0]);
     59   NATSocketFactory* natsf = new NATSocketFactory(internal,
     60                                                  nat->internal_udp_address(),
     61                                                  nat->internal_tcp_address());
     62 
     63   TestClient* in = CreateTestClient(natsf, internal_addr);
     64   TestClient* out[4];
     65   for (int i = 0; i < 4; i++)
     66     out[i] = CreateTestClient(external, external_addrs[i]);
     67 
     68   th_int.Start();
     69   th_ext.Start();
     70 
     71   const char* buf = "filter_test";
     72   size_t len = strlen(buf);
     73 
     74   in->SendTo(buf, len, out[0]->address());
     75   SocketAddress trans_addr;
     76   EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
     77 
     78   for (int i = 1; i < 4; i++) {
     79     in->SendTo(buf, len, out[i]->address());
     80     SocketAddress trans_addr2;
     81     EXPECT_TRUE(out[i]->CheckNextPacket(buf, len, &trans_addr2));
     82     bool are_same = (trans_addr == trans_addr2);
     83     ASSERT_EQ(are_same, exp_same) << "same translated address";
     84     ASSERT_NE(AF_UNSPEC, trans_addr.family());
     85     ASSERT_NE(AF_UNSPEC, trans_addr2.family());
     86   }
     87 
     88   th_int.Stop();
     89   th_ext.Stop();
     90 
     91   delete nat;
     92   delete natsf;
     93   delete in;
     94   for (int i = 0; i < 4; i++)
     95     delete out[i];
     96 }
     97 
     98 // Tests that when sending from external_addrs to internal_addr, the packet
     99 // is delivered according to the specified filter_ip and filter_port rules.
    100 void TestRecv(
    101       SocketServer* internal, const SocketAddress& internal_addr,
    102       SocketServer* external, const SocketAddress external_addrs[4],
    103       NATType nat_type, bool filter_ip, bool filter_port) {
    104   Thread th_int(internal);
    105   Thread th_ext(external);
    106 
    107   SocketAddress server_addr = internal_addr;
    108   server_addr.SetPort(0);  // Auto-select a port
    109   NATServer* nat = new NATServer(nat_type, internal, server_addr, server_addr,
    110                                  external, external_addrs[0]);
    111   NATSocketFactory* natsf = new NATSocketFactory(internal,
    112                                                  nat->internal_udp_address(),
    113                                                  nat->internal_tcp_address());
    114 
    115   TestClient* in = CreateTestClient(natsf, internal_addr);
    116   TestClient* out[4];
    117   for (int i = 0; i < 4; i++)
    118     out[i] = CreateTestClient(external, external_addrs[i]);
    119 
    120   th_int.Start();
    121   th_ext.Start();
    122 
    123   const char* buf = "filter_test";
    124   size_t len = strlen(buf);
    125 
    126   in->SendTo(buf, len, out[0]->address());
    127   SocketAddress trans_addr;
    128   EXPECT_TRUE(out[0]->CheckNextPacket(buf, len, &trans_addr));
    129 
    130   out[1]->SendTo(buf, len, trans_addr);
    131   EXPECT_TRUE(CheckReceive(in, !filter_ip, buf, len));
    132 
    133   out[2]->SendTo(buf, len, trans_addr);
    134   EXPECT_TRUE(CheckReceive(in, !filter_port, buf, len));
    135 
    136   out[3]->SendTo(buf, len, trans_addr);
    137   EXPECT_TRUE(CheckReceive(in, !filter_ip && !filter_port, buf, len));
    138 
    139   th_int.Stop();
    140   th_ext.Stop();
    141 
    142   delete nat;
    143   delete natsf;
    144   delete in;
    145   for (int i = 0; i < 4; i++)
    146     delete out[i];
    147 }
    148 
    149 // Tests that NATServer allocates bindings properly.
    150 void TestBindings(
    151     SocketServer* internal, const SocketAddress& internal_addr,
    152     SocketServer* external, const SocketAddress external_addrs[4]) {
    153   TestSend(internal, internal_addr, external, external_addrs,
    154            NAT_OPEN_CONE, true);
    155   TestSend(internal, internal_addr, external, external_addrs,
    156            NAT_ADDR_RESTRICTED, true);
    157   TestSend(internal, internal_addr, external, external_addrs,
    158            NAT_PORT_RESTRICTED, true);
    159   TestSend(internal, internal_addr, external, external_addrs,
    160            NAT_SYMMETRIC, false);
    161 }
    162 
    163 // Tests that NATServer filters packets properly.
    164 void TestFilters(
    165     SocketServer* internal, const SocketAddress& internal_addr,
    166     SocketServer* external, const SocketAddress external_addrs[4]) {
    167   TestRecv(internal, internal_addr, external, external_addrs,
    168            NAT_OPEN_CONE, false, false);
    169   TestRecv(internal, internal_addr, external, external_addrs,
    170            NAT_ADDR_RESTRICTED, true, false);
    171   TestRecv(internal, internal_addr, external, external_addrs,
    172            NAT_PORT_RESTRICTED, true, true);
    173   TestRecv(internal, internal_addr, external, external_addrs,
    174            NAT_SYMMETRIC, true, true);
    175 }
    176 
    177 bool TestConnectivity(const SocketAddress& src, const IPAddress& dst) {
    178   // The physical NAT tests require connectivity to the selected ip from the
    179   // internal address used for the NAT. Things like firewalls can break that, so
    180   // check to see if it's worth even trying with this ip.
    181   scoped_ptr<PhysicalSocketServer> pss(new PhysicalSocketServer());
    182   scoped_ptr<AsyncSocket> client(pss->CreateAsyncSocket(src.family(),
    183                                                         SOCK_DGRAM));
    184   scoped_ptr<AsyncSocket> server(pss->CreateAsyncSocket(src.family(),
    185                                                         SOCK_DGRAM));
    186   if (client->Bind(SocketAddress(src.ipaddr(), 0)) != 0 ||
    187       server->Bind(SocketAddress(dst, 0)) != 0) {
    188     return false;
    189   }
    190   const char* buf = "hello other socket";
    191   size_t len = strlen(buf);
    192   int sent = client->SendTo(buf, len, server->GetLocalAddress());
    193   SocketAddress addr;
    194   const size_t kRecvBufSize = 64;
    195   char recvbuf[kRecvBufSize];
    196   Thread::Current()->SleepMs(100);
    197   int received = server->RecvFrom(recvbuf, kRecvBufSize, &addr);
    198   return received == sent && ::memcmp(buf, recvbuf, len) == 0;
    199 }
    200 
    201 void TestPhysicalInternal(const SocketAddress& int_addr) {
    202   BasicNetworkManager network_manager;
    203   network_manager.set_ipv6_enabled(true);
    204   network_manager.StartUpdating();
    205   // Process pending messages so the network list is updated.
    206   Thread::Current()->ProcessMessages(0);
    207 
    208   std::vector<Network*> networks;
    209   network_manager.GetNetworks(&networks);
    210   networks.erase(std::remove_if(networks.begin(), networks.end(),
    211                                 [](rtc::Network* network) {
    212                                   return rtc::kDefaultNetworkIgnoreMask &
    213                                          network->type();
    214                                 }),
    215                  networks.end());
    216   if (networks.empty()) {
    217     LOG(LS_WARNING) << "Not enough network adapters for test.";
    218     return;
    219   }
    220 
    221   SocketAddress ext_addr1(int_addr);
    222   SocketAddress ext_addr2;
    223   // Find an available IP with matching family. The test breaks if int_addr
    224   // can't talk to ip, so check for connectivity as well.
    225   for (std::vector<Network*>::iterator it = networks.begin();
    226       it != networks.end(); ++it) {
    227     const IPAddress& ip = (*it)->GetBestIP();
    228     if (ip.family() == int_addr.family() && TestConnectivity(int_addr, ip)) {
    229       ext_addr2.SetIP(ip);
    230       break;
    231     }
    232   }
    233   if (ext_addr2.IsNil()) {
    234     LOG(LS_WARNING) << "No available IP of same family as " << int_addr;
    235     return;
    236   }
    237 
    238   LOG(LS_INFO) << "selected ip " << ext_addr2.ipaddr();
    239 
    240   SocketAddress ext_addrs[4] = {
    241       SocketAddress(ext_addr1),
    242       SocketAddress(ext_addr2),
    243       SocketAddress(ext_addr1),
    244       SocketAddress(ext_addr2)
    245   };
    246 
    247   scoped_ptr<PhysicalSocketServer> int_pss(new PhysicalSocketServer());
    248   scoped_ptr<PhysicalSocketServer> ext_pss(new PhysicalSocketServer());
    249 
    250   TestBindings(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
    251   TestFilters(int_pss.get(), int_addr, ext_pss.get(), ext_addrs);
    252 }
    253 
    254 TEST(NatTest, TestPhysicalIPv4) {
    255   TestPhysicalInternal(SocketAddress("127.0.0.1", 0));
    256 }
    257 
    258 TEST(NatTest, TestPhysicalIPv6) {
    259   if (HasIPv6Enabled()) {
    260     TestPhysicalInternal(SocketAddress("::1", 0));
    261   } else {
    262     LOG(LS_WARNING) << "No IPv6, skipping";
    263   }
    264 }
    265 
    266 namespace {
    267 
    268 class TestVirtualSocketServer : public VirtualSocketServer {
    269  public:
    270   explicit TestVirtualSocketServer(SocketServer* ss)
    271       : VirtualSocketServer(ss),
    272         ss_(ss) {}
    273   // Expose this publicly
    274   IPAddress GetNextIP(int af) { return VirtualSocketServer::GetNextIP(af); }
    275 
    276  private:
    277   scoped_ptr<SocketServer> ss_;
    278 };
    279 
    280 }  // namespace
    281 
    282 void TestVirtualInternal(int family) {
    283   scoped_ptr<TestVirtualSocketServer> int_vss(new TestVirtualSocketServer(
    284       new PhysicalSocketServer()));
    285   scoped_ptr<TestVirtualSocketServer> ext_vss(new TestVirtualSocketServer(
    286       new PhysicalSocketServer()));
    287 
    288   SocketAddress int_addr;
    289   SocketAddress ext_addrs[4];
    290   int_addr.SetIP(int_vss->GetNextIP(family));
    291   ext_addrs[0].SetIP(ext_vss->GetNextIP(int_addr.family()));
    292   ext_addrs[1].SetIP(ext_vss->GetNextIP(int_addr.family()));
    293   ext_addrs[2].SetIP(ext_addrs[0].ipaddr());
    294   ext_addrs[3].SetIP(ext_addrs[1].ipaddr());
    295 
    296   TestBindings(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
    297   TestFilters(int_vss.get(), int_addr, ext_vss.get(), ext_addrs);
    298 }
    299 
    300 TEST(NatTest, TestVirtualIPv4) {
    301   TestVirtualInternal(AF_INET);
    302 }
    303 
    304 TEST(NatTest, TestVirtualIPv6) {
    305   if (HasIPv6Enabled()) {
    306     TestVirtualInternal(AF_INET6);
    307   } else {
    308     LOG(LS_WARNING) << "No IPv6, skipping";
    309   }
    310 }
    311 
    312 class NatTcpTest : public testing::Test, public sigslot::has_slots<> {
    313  public:
    314   NatTcpTest()
    315       : int_addr_("192.168.0.1", 0),
    316         ext_addr_("10.0.0.1", 0),
    317         connected_(false),
    318         int_pss_(new PhysicalSocketServer()),
    319         ext_pss_(new PhysicalSocketServer()),
    320         int_vss_(new TestVirtualSocketServer(int_pss_)),
    321         ext_vss_(new TestVirtualSocketServer(ext_pss_)),
    322         int_thread_(new Thread(int_vss_.get())),
    323         ext_thread_(new Thread(ext_vss_.get())),
    324         nat_(new NATServer(NAT_OPEN_CONE, int_vss_.get(), int_addr_, int_addr_,
    325                            ext_vss_.get(), ext_addr_)),
    326         natsf_(new NATSocketFactory(int_vss_.get(),
    327                                     nat_->internal_udp_address(),
    328                                     nat_->internal_tcp_address())) {
    329     int_thread_->Start();
    330     ext_thread_->Start();
    331   }
    332 
    333   void OnConnectEvent(AsyncSocket* socket) {
    334     connected_ = true;
    335   }
    336 
    337   void OnAcceptEvent(AsyncSocket* socket) {
    338     accepted_.reset(server_->Accept(NULL));
    339   }
    340 
    341   void OnCloseEvent(AsyncSocket* socket, int error) {
    342   }
    343 
    344   void ConnectEvents() {
    345     server_->SignalReadEvent.connect(this, &NatTcpTest::OnAcceptEvent);
    346     client_->SignalConnectEvent.connect(this, &NatTcpTest::OnConnectEvent);
    347   }
    348 
    349   SocketAddress int_addr_;
    350   SocketAddress ext_addr_;
    351   bool connected_;
    352   PhysicalSocketServer* int_pss_;
    353   PhysicalSocketServer* ext_pss_;
    354   rtc::scoped_ptr<TestVirtualSocketServer> int_vss_;
    355   rtc::scoped_ptr<TestVirtualSocketServer> ext_vss_;
    356   rtc::scoped_ptr<Thread> int_thread_;
    357   rtc::scoped_ptr<Thread> ext_thread_;
    358   rtc::scoped_ptr<NATServer> nat_;
    359   rtc::scoped_ptr<NATSocketFactory> natsf_;
    360   rtc::scoped_ptr<AsyncSocket> client_;
    361   rtc::scoped_ptr<AsyncSocket> server_;
    362   rtc::scoped_ptr<AsyncSocket> accepted_;
    363 };
    364 
    365 TEST_F(NatTcpTest, DISABLED_TestConnectOut) {
    366   server_.reset(ext_vss_->CreateAsyncSocket(SOCK_STREAM));
    367   server_->Bind(ext_addr_);
    368   server_->Listen(5);
    369 
    370   client_.reset(natsf_->CreateAsyncSocket(SOCK_STREAM));
    371   EXPECT_GE(0, client_->Bind(int_addr_));
    372   EXPECT_GE(0, client_->Connect(server_->GetLocalAddress()));
    373 
    374   ConnectEvents();
    375 
    376   EXPECT_TRUE_WAIT(connected_, 1000);
    377   EXPECT_EQ(client_->GetRemoteAddress(), server_->GetLocalAddress());
    378   EXPECT_EQ(accepted_->GetRemoteAddress().ipaddr(), ext_addr_.ipaddr());
    379 
    380   rtc::scoped_ptr<rtc::TestClient> in(CreateTCPTestClient(client_.release()));
    381   rtc::scoped_ptr<rtc::TestClient> out(
    382       CreateTCPTestClient(accepted_.release()));
    383 
    384   const char* buf = "test_packet";
    385   size_t len = strlen(buf);
    386 
    387   in->Send(buf, len);
    388   SocketAddress trans_addr;
    389   EXPECT_TRUE(out->CheckNextPacket(buf, len, &trans_addr));
    390 
    391   out->Send(buf, len);
    392   EXPECT_TRUE(in->CheckNextPacket(buf, len, &trans_addr));
    393 }
    394 // #endif
    395