Home | History | Annotate | Download | only in udp
      1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/udp/udp_client_socket.h"
      6 #include "net/udp/udp_server_socket.h"
      7 
      8 #include "base/basictypes.h"
      9 #include "base/metrics/histogram.h"
     10 #include "net/base/io_buffer.h"
     11 #include "net/base/ip_endpoint.h"
     12 #include "net/base/net_errors.h"
     13 #include "net/base/net_test_suite.h"
     14 #include "net/base/net_util.h"
     15 #include "net/base/sys_addrinfo.h"
     16 #include "net/base/test_completion_callback.h"
     17 #include "testing/gtest/include/gtest/gtest.h"
     18 #include "testing/platform_test.h"
     19 
     20 namespace net {
     21 
     22 namespace {
     23 
     24 class UDPSocketTest : public PlatformTest {
     25  public:
     26   UDPSocketTest()
     27       : buffer_(new IOBufferWithSize(kMaxRead)) {
     28   }
     29 
     30   // Blocks until data is read from the socket.
     31   std::string RecvFromSocket(UDPServerSocket* socket) {
     32     TestCompletionCallback callback;
     33 
     34     int rv = socket->RecvFrom(buffer_, kMaxRead, &recv_from_address_,
     35                               &callback);
     36     if (rv == ERR_IO_PENDING)
     37       rv = callback.WaitForResult();
     38     if (rv < 0)
     39       return "";  // error!
     40     return std::string(buffer_->data(), rv);
     41   }
     42 
     43   // Loop until |msg| has been written to the socket or until an
     44   // error occurs.
     45   // If |address| is specified, then it is used for the destination
     46   // to send to. Otherwise, will send to the last socket this server
     47   // received from.
     48   int SendToSocket(UDPServerSocket* socket, std::string msg) {
     49     return SendToSocket(socket, msg, recv_from_address_);
     50   }
     51 
     52   int SendToSocket(UDPServerSocket* socket,
     53                    std::string msg,
     54                    const IPEndPoint& address) {
     55     TestCompletionCallback callback;
     56 
     57     int length = msg.length();
     58     scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg));
     59     scoped_refptr<DrainableIOBuffer> buffer(
     60         new DrainableIOBuffer(io_buffer, length));
     61 
     62     int bytes_sent = 0;
     63     while (buffer->BytesRemaining()) {
     64       int rv = socket->SendTo(buffer, buffer->BytesRemaining(),
     65                               address, &callback);
     66       if (rv == ERR_IO_PENDING)
     67         rv = callback.WaitForResult();
     68       if (rv <= 0)
     69         return bytes_sent > 0 ? bytes_sent : rv;
     70       bytes_sent += rv;
     71       buffer->DidConsume(rv);
     72     }
     73     return bytes_sent;
     74   }
     75 
     76   std::string ReadSocket(UDPClientSocket* socket) {
     77     TestCompletionCallback callback;
     78 
     79     int rv = socket->Read(buffer_, kMaxRead, &callback);
     80     if (rv == ERR_IO_PENDING)
     81       rv = callback.WaitForResult();
     82     if (rv < 0)
     83       return "";  // error!
     84     return std::string(buffer_->data(), rv);
     85   }
     86 
     87   // Loop until |msg| has been written to the socket or until an
     88   // error occurs.
     89   int WriteSocket(UDPClientSocket* socket, std::string msg) {
     90     TestCompletionCallback callback;
     91 
     92     int length = msg.length();
     93     scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg));
     94     scoped_refptr<DrainableIOBuffer> buffer(
     95         new DrainableIOBuffer(io_buffer, length));
     96 
     97     int bytes_sent = 0;
     98     while (buffer->BytesRemaining()) {
     99       int rv = socket->Write(buffer, buffer->BytesRemaining(), &callback);
    100       if (rv == ERR_IO_PENDING)
    101         rv = callback.WaitForResult();
    102       if (rv <= 0)
    103         return bytes_sent > 0 ? bytes_sent : rv;
    104       bytes_sent += rv;
    105       buffer->DidConsume(rv);
    106     }
    107     return bytes_sent;
    108   }
    109 
    110  protected:
    111   static const int kMaxRead = 1024;
    112   scoped_refptr<IOBufferWithSize> buffer_;
    113   IPEndPoint recv_from_address_;
    114 };
    115 
    116 // Creates and address from an ip/port and returns it in |address|.
    117 void CreateUDPAddress(std::string ip_str, int port, IPEndPoint* address) {
    118   IPAddressNumber ip_number;
    119   bool rv = ParseIPLiteralToNumber(ip_str, &ip_number);
    120   if (!rv)
    121     return;
    122   *address = IPEndPoint(ip_number, port);
    123 }
    124 
    125 TEST_F(UDPSocketTest, Connect) {
    126   const int kPort = 9999;
    127   std::string simple_message("hello world!");
    128 
    129   // Setup the server to listen.
    130   IPEndPoint bind_address;
    131   CreateUDPAddress("0.0.0.0", kPort, &bind_address);
    132   UDPServerSocket server(NULL, NetLog::Source());
    133   int rv = server.Listen(bind_address);
    134   EXPECT_EQ(OK, rv);
    135 
    136   // Setup the client.
    137   IPEndPoint server_address;
    138   CreateUDPAddress("127.0.0.1", kPort, &server_address);
    139   UDPClientSocket client(NULL, NetLog::Source());
    140   rv = client.Connect(server_address);
    141   EXPECT_EQ(OK, rv);
    142 
    143   // Client sends to the server.
    144   rv = WriteSocket(&client, simple_message);
    145   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
    146 
    147   // Server waits for message.
    148   std::string str = RecvFromSocket(&server);
    149   DCHECK(simple_message == str);
    150 
    151   // Server echoes reply.
    152   rv = SendToSocket(&server, simple_message);
    153   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
    154 
    155   // Client waits for response.
    156   str = ReadSocket(&client);
    157   DCHECK(simple_message == str);
    158 }
    159 
    160 // In this test, we verify that connect() on a socket will have the effect
    161 // of filtering reads on this socket only to data read from the destination
    162 // we connected to.
    163 //
    164 // The purpose of this test is that some documentation indicates that connect
    165 // binds the client's sends to send to a particular server endpoint, but does
    166 // not bind the client's reads to only be from that endpoint, and that we need
    167 // to always use recvfrom() to disambiguate.
    168 TEST_F(UDPSocketTest, VerifyConnectBindsAddr) {
    169   const int kPort1 = 9999;
    170   const int kPort2 = 10000;
    171   std::string simple_message("hello world!");
    172   std::string foreign_message("BAD MESSAGE TO GET!!");
    173 
    174   // Setup the first server to listen.
    175   IPEndPoint bind_address;
    176   CreateUDPAddress("0.0.0.0", kPort1, &bind_address);
    177   UDPServerSocket server1(NULL, NetLog::Source());
    178   int rv = server1.Listen(bind_address);
    179   EXPECT_EQ(OK, rv);
    180 
    181   // Setup the second server to listen.
    182   CreateUDPAddress("0.0.0.0", kPort2, &bind_address);
    183   UDPServerSocket server2(NULL, NetLog::Source());
    184   rv = server2.Listen(bind_address);
    185   EXPECT_EQ(OK, rv);
    186 
    187   // Setup the client, connected to server 1.
    188   IPEndPoint server_address;
    189   CreateUDPAddress("127.0.0.1", kPort1, &server_address);
    190   UDPClientSocket client(NULL, NetLog::Source());
    191   rv = client.Connect(server_address);
    192   EXPECT_EQ(OK, rv);
    193 
    194   // Client sends to server1.
    195   rv = WriteSocket(&client, simple_message);
    196   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
    197 
    198   // Server1 waits for message.
    199   std::string str = RecvFromSocket(&server1);
    200   DCHECK(simple_message == str);
    201 
    202   // Get the client's address.
    203   IPEndPoint client_address;
    204   rv = client.GetLocalAddress(&client_address);
    205   EXPECT_EQ(OK, rv);
    206 
    207   // Server2 sends reply.
    208   rv = SendToSocket(&server2, foreign_message,
    209                     client_address);
    210   EXPECT_EQ(foreign_message.length(), static_cast<size_t>(rv));
    211 
    212   // Server1 sends reply.
    213   rv = SendToSocket(&server1, simple_message,
    214                     client_address);
    215   EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv));
    216 
    217   // Client waits for response.
    218   str = ReadSocket(&client);
    219   DCHECK(simple_message == str);
    220 }
    221 
    222 TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) {
    223   struct TestData {
    224     std::string remote_address;
    225     std::string local_address;
    226     bool may_fail;
    227   } tests[] = {
    228     { "127.0.00.1", "127.0.0.1", false },
    229     { "192.168.1.1", "127.0.0.1", false },
    230     { "::1", "::1", true },
    231     { "2001:db8:0::42", "::1", true },
    232   };
    233   for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); i++) {
    234     SCOPED_TRACE(std::string("Connecting from ") +  tests[i].local_address +
    235                  std::string(" to ") + tests[i].remote_address);
    236 
    237     net::IPAddressNumber ip_number;
    238     net::ParseIPLiteralToNumber(tests[i].remote_address, &ip_number);
    239     net::IPEndPoint remote_address(ip_number, 80);
    240     net::ParseIPLiteralToNumber(tests[i].local_address, &ip_number);
    241     net::IPEndPoint local_address(ip_number, 80);
    242 
    243     UDPClientSocket client(NULL, NetLog::Source());
    244     int rv = client.Connect(remote_address);
    245     if (tests[i].may_fail && rv == ERR_ADDRESS_UNREACHABLE) {
    246       // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6
    247       // addresses if IPv6 is not configured.
    248       continue;
    249     }
    250 
    251     EXPECT_LE(ERR_IO_PENDING, rv);
    252 
    253     IPEndPoint fetched_local_address;
    254     rv = client.GetLocalAddress(&fetched_local_address);
    255     EXPECT_EQ(OK, rv);
    256 
    257     // TODO(mbelshe): figure out how to verify the IP and port.
    258     //                The port is dynamically generated by the udp stack.
    259     //                The IP is the real IP of the client, not necessarily
    260     //                loopback.
    261     //EXPECT_EQ(local_address.address(), fetched_local_address.address());
    262 
    263     IPEndPoint fetched_remote_address;
    264     rv = client.GetPeerAddress(&fetched_remote_address);
    265     EXPECT_EQ(OK, rv);
    266 
    267     EXPECT_EQ(remote_address, fetched_remote_address);
    268   }
    269 }
    270 
    271 TEST_F(UDPSocketTest, ServerGetLocalAddress) {
    272   IPEndPoint bind_address;
    273   CreateUDPAddress("127.0.0.1", 0, &bind_address);
    274   UDPServerSocket server(NULL, NetLog::Source());
    275   int rv = server.Listen(bind_address);
    276   EXPECT_EQ(OK, rv);
    277 
    278   IPEndPoint local_address;
    279   rv = server.GetLocalAddress(&local_address);
    280   EXPECT_EQ(rv, 0);
    281 
    282   // Verify that port was allocated.
    283   EXPECT_GT(local_address.port(), 0);
    284   EXPECT_EQ(local_address.address(), bind_address.address());
    285 }
    286 
    287 TEST_F(UDPSocketTest, ServerGetPeerAddress) {
    288   IPEndPoint bind_address;
    289   CreateUDPAddress("127.0.0.1", 0, &bind_address);
    290   UDPServerSocket server(NULL, NetLog::Source());
    291   int rv = server.Listen(bind_address);
    292   EXPECT_EQ(OK, rv);
    293 
    294   IPEndPoint peer_address;
    295   rv = server.GetPeerAddress(&peer_address);
    296   EXPECT_EQ(rv, ERR_SOCKET_NOT_CONNECTED);
    297 }
    298 
    299 // Close the socket while read is pending.
    300 TEST_F(UDPSocketTest, CloseWithPendingRead) {
    301   IPEndPoint bind_address;
    302   CreateUDPAddress("127.0.0.1", 0, &bind_address);
    303   UDPServerSocket server(NULL, NetLog::Source());
    304   int rv = server.Listen(bind_address);
    305   EXPECT_EQ(OK, rv);
    306 
    307   TestCompletionCallback callback;
    308   IPEndPoint from;
    309   rv = server.RecvFrom(buffer_, kMaxRead, &from, &callback);
    310   EXPECT_EQ(rv, ERR_IO_PENDING);
    311 
    312   server.Close();
    313 
    314   EXPECT_FALSE(callback.have_result());
    315 }
    316 
    317 }  // namespace
    318 
    319 }  // namespace net
    320