Home | History | Annotate | Download | only in socket
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/tcp_socket.h"
      6 
      7 #include <string.h>
      8 
      9 #include <string>
     10 #include <vector>
     11 
     12 #include "base/memory/ref_counted.h"
     13 #include "base/memory/scoped_ptr.h"
     14 #include "net/base/address_list.h"
     15 #include "net/base/io_buffer.h"
     16 #include "net/base/ip_endpoint.h"
     17 #include "net/base/net_errors.h"
     18 #include "net/base/test_completion_callback.h"
     19 #include "net/socket/tcp_client_socket.h"
     20 #include "testing/gtest/include/gtest/gtest.h"
     21 #include "testing/platform_test.h"
     22 
     23 namespace net {
     24 
     25 namespace {
     26 const int kListenBacklog = 5;
     27 
     28 class TCPSocketTest : public PlatformTest {
     29  protected:
     30   TCPSocketTest() : socket_(NULL, NetLog::Source()) {
     31   }
     32 
     33   void SetUpListenIPv4() {
     34     IPEndPoint address;
     35     ParseAddress("127.0.0.1", 0, &address);
     36 
     37     ASSERT_EQ(OK, socket_.Open(ADDRESS_FAMILY_IPV4));
     38     ASSERT_EQ(OK, socket_.Bind(address));
     39     ASSERT_EQ(OK, socket_.Listen(kListenBacklog));
     40     ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
     41   }
     42 
     43   void SetUpListenIPv6(bool* success) {
     44     *success = false;
     45     IPEndPoint address;
     46     ParseAddress("::1", 0, &address);
     47 
     48     if (socket_.Open(ADDRESS_FAMILY_IPV6) != OK ||
     49         socket_.Bind(address) != OK ||
     50         socket_.Listen(kListenBacklog) != OK) {
     51       LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
     52           "disabled. Skipping the test";
     53       return;
     54     }
     55     ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
     56     *success = true;
     57   }
     58 
     59   void ParseAddress(const std::string& ip_str, int port, IPEndPoint* address) {
     60     IPAddressNumber ip_number;
     61     bool rv = ParseIPLiteralToNumber(ip_str, &ip_number);
     62     if (!rv)
     63       return;
     64     *address = IPEndPoint(ip_number, port);
     65   }
     66 
     67   void TestAcceptAsync() {
     68     TestCompletionCallback accept_callback;
     69     scoped_ptr<TCPSocket> accepted_socket;
     70     IPEndPoint accepted_address;
     71     ASSERT_EQ(ERR_IO_PENDING,
     72               socket_.Accept(&accepted_socket, &accepted_address,
     73                              accept_callback.callback()));
     74 
     75     TestCompletionCallback connect_callback;
     76     TCPClientSocket connecting_socket(local_address_list(),
     77                                       NULL, NetLog::Source());
     78     connecting_socket.Connect(connect_callback.callback());
     79 
     80     EXPECT_EQ(OK, connect_callback.WaitForResult());
     81     EXPECT_EQ(OK, accept_callback.WaitForResult());
     82 
     83     EXPECT_TRUE(accepted_socket.get());
     84 
     85     // Both sockets should be on the loopback network interface.
     86     EXPECT_EQ(accepted_address.address(), local_address_.address());
     87   }
     88 
     89   AddressList local_address_list() const {
     90     return AddressList(local_address_);
     91   }
     92 
     93   TCPSocket socket_;
     94   IPEndPoint local_address_;
     95 };
     96 
     97 // Test listening and accepting with a socket bound to an IPv4 address.
     98 TEST_F(TCPSocketTest, Accept) {
     99   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
    100 
    101   TestCompletionCallback connect_callback;
    102   // TODO(yzshen): Switch to use TCPSocket when it supports client socket
    103   // operations.
    104   TCPClientSocket connecting_socket(local_address_list(),
    105                                     NULL, NetLog::Source());
    106   connecting_socket.Connect(connect_callback.callback());
    107 
    108   TestCompletionCallback accept_callback;
    109   scoped_ptr<TCPSocket> accepted_socket;
    110   IPEndPoint accepted_address;
    111   int result = socket_.Accept(&accepted_socket, &accepted_address,
    112                               accept_callback.callback());
    113   if (result == ERR_IO_PENDING)
    114     result = accept_callback.WaitForResult();
    115   ASSERT_EQ(OK, result);
    116 
    117   EXPECT_TRUE(accepted_socket.get());
    118 
    119   // Both sockets should be on the loopback network interface.
    120   EXPECT_EQ(accepted_address.address(), local_address_.address());
    121 
    122   EXPECT_EQ(OK, connect_callback.WaitForResult());
    123 }
    124 
    125 // Test Accept() callback.
    126 TEST_F(TCPSocketTest, AcceptAsync) {
    127   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
    128   TestAcceptAsync();
    129 }
    130 
    131 #if defined(OS_WIN)
    132 // Test Accept() for AdoptListenSocket.
    133 TEST_F(TCPSocketTest, AcceptForAdoptedListenSocket) {
    134   // Create a socket to be used with AdoptListenSocket.
    135   SOCKET existing_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    136   ASSERT_EQ(OK, socket_.AdoptListenSocket(existing_socket));
    137 
    138   IPEndPoint address;
    139   ParseAddress("127.0.0.1", 0, &address);
    140   SockaddrStorage storage;
    141   ASSERT_TRUE(address.ToSockAddr(storage.addr, &storage.addr_len));
    142   ASSERT_EQ(0, bind(existing_socket, storage.addr, storage.addr_len));
    143 
    144   ASSERT_EQ(OK, socket_.Listen(kListenBacklog));
    145   ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
    146 
    147   TestAcceptAsync();
    148 }
    149 #endif
    150 
    151 // Accept two connections simultaneously.
    152 TEST_F(TCPSocketTest, Accept2Connections) {
    153   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
    154 
    155   TestCompletionCallback accept_callback;
    156   scoped_ptr<TCPSocket> accepted_socket;
    157   IPEndPoint accepted_address;
    158 
    159   ASSERT_EQ(ERR_IO_PENDING,
    160             socket_.Accept(&accepted_socket, &accepted_address,
    161                            accept_callback.callback()));
    162 
    163   TestCompletionCallback connect_callback;
    164   TCPClientSocket connecting_socket(local_address_list(),
    165                                     NULL, NetLog::Source());
    166   connecting_socket.Connect(connect_callback.callback());
    167 
    168   TestCompletionCallback connect_callback2;
    169   TCPClientSocket connecting_socket2(local_address_list(),
    170                                      NULL, NetLog::Source());
    171   connecting_socket2.Connect(connect_callback2.callback());
    172 
    173   EXPECT_EQ(OK, accept_callback.WaitForResult());
    174 
    175   TestCompletionCallback accept_callback2;
    176   scoped_ptr<TCPSocket> accepted_socket2;
    177   IPEndPoint accepted_address2;
    178 
    179   int result = socket_.Accept(&accepted_socket2, &accepted_address2,
    180                               accept_callback2.callback());
    181   if (result == ERR_IO_PENDING)
    182     result = accept_callback2.WaitForResult();
    183   ASSERT_EQ(OK, result);
    184 
    185   EXPECT_EQ(OK, connect_callback.WaitForResult());
    186   EXPECT_EQ(OK, connect_callback2.WaitForResult());
    187 
    188   EXPECT_TRUE(accepted_socket.get());
    189   EXPECT_TRUE(accepted_socket2.get());
    190   EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
    191 
    192   EXPECT_EQ(accepted_address.address(), local_address_.address());
    193   EXPECT_EQ(accepted_address2.address(), local_address_.address());
    194 }
    195 
    196 // Test listening and accepting with a socket bound to an IPv6 address.
    197 TEST_F(TCPSocketTest, AcceptIPv6) {
    198   bool initialized = false;
    199   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv6(&initialized));
    200   if (!initialized)
    201     return;
    202 
    203   TestCompletionCallback connect_callback;
    204   TCPClientSocket connecting_socket(local_address_list(),
    205                                     NULL, NetLog::Source());
    206   connecting_socket.Connect(connect_callback.callback());
    207 
    208   TestCompletionCallback accept_callback;
    209   scoped_ptr<TCPSocket> accepted_socket;
    210   IPEndPoint accepted_address;
    211   int result = socket_.Accept(&accepted_socket, &accepted_address,
    212                               accept_callback.callback());
    213   if (result == ERR_IO_PENDING)
    214     result = accept_callback.WaitForResult();
    215   ASSERT_EQ(OK, result);
    216 
    217   EXPECT_TRUE(accepted_socket.get());
    218 
    219   // Both sockets should be on the loopback network interface.
    220   EXPECT_EQ(accepted_address.address(), local_address_.address());
    221 
    222   EXPECT_EQ(OK, connect_callback.WaitForResult());
    223 }
    224 
    225 TEST_F(TCPSocketTest, ReadWrite) {
    226   ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4());
    227 
    228   TestCompletionCallback connect_callback;
    229   TCPSocket connecting_socket(NULL, NetLog::Source());
    230   int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4);
    231   ASSERT_EQ(OK, result);
    232   connecting_socket.Connect(local_address_, connect_callback.callback());
    233 
    234   TestCompletionCallback accept_callback;
    235   scoped_ptr<TCPSocket> accepted_socket;
    236   IPEndPoint accepted_address;
    237   result = socket_.Accept(&accepted_socket, &accepted_address,
    238                           accept_callback.callback());
    239   ASSERT_EQ(OK, accept_callback.GetResult(result));
    240 
    241   ASSERT_TRUE(accepted_socket.get());
    242 
    243   // Both sockets should be on the loopback network interface.
    244   EXPECT_EQ(accepted_address.address(), local_address_.address());
    245 
    246   EXPECT_EQ(OK, connect_callback.WaitForResult());
    247 
    248   const std::string message("test message");
    249   std::vector<char> buffer(message.size());
    250 
    251   size_t bytes_written = 0;
    252   while (bytes_written < message.size()) {
    253     scoped_refptr<IOBufferWithSize> write_buffer(
    254         new IOBufferWithSize(message.size() - bytes_written));
    255     memmove(write_buffer->data(), message.data() + bytes_written,
    256             message.size() - bytes_written);
    257 
    258     TestCompletionCallback write_callback;
    259     int write_result = accepted_socket->Write(
    260         write_buffer.get(), write_buffer->size(), write_callback.callback());
    261     write_result = write_callback.GetResult(write_result);
    262     ASSERT_TRUE(write_result >= 0);
    263     bytes_written += write_result;
    264     ASSERT_TRUE(bytes_written <= message.size());
    265   }
    266 
    267   size_t bytes_read = 0;
    268   while (bytes_read < message.size()) {
    269     scoped_refptr<IOBufferWithSize> read_buffer(
    270         new IOBufferWithSize(message.size() - bytes_read));
    271     TestCompletionCallback read_callback;
    272     int read_result = connecting_socket.Read(
    273         read_buffer.get(), read_buffer->size(), read_callback.callback());
    274     read_result = read_callback.GetResult(read_result);
    275     ASSERT_TRUE(read_result >= 0);
    276     ASSERT_TRUE(bytes_read + read_result <= message.size());
    277     memmove(&buffer[bytes_read], read_buffer->data(), read_result);
    278     bytes_read += read_result;
    279   }
    280 
    281   std::string received_message(buffer.begin(), buffer.end());
    282   ASSERT_EQ(message, received_message);
    283 }
    284 
    285 }  // namespace
    286 }  // namespace net
    287