Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/tcp_server_socket.h"
      6 
      7 #include <string>
      8 #include <vector>
      9 
     10 #include "base/compiler_specific.h"
     11 #include "base/memory/ref_counted.h"
     12 #include "base/memory/scoped_ptr.h"
     13 #include "net/base/address_list.h"
     14 #include "net/base/io_buffer.h"
     15 #include "net/base/ip_endpoint.h"
     16 #include "net/base/net_errors.h"
     17 #include "net/base/test_completion_callback.h"
     18 #include "net/socket/tcp_client_socket.h"
     19 #include "testing/gtest/include/gtest/gtest.h"
     20 #include "testing/platform_test.h"
     21 
     22 namespace net {
     23 
     24 namespace {
     25 const int kListenBacklog = 5;
     26 
     27 class TCPServerSocketTest : public PlatformTest {
     28  protected:
     29   TCPServerSocketTest()
     30       : socket_(NULL, NetLog::Source()) {
     31   }
     32 
     33   void SetUpIPv4() {
     34     IPEndPoint address;
     35     ParseAddress("127.0.0.1", 0, &address);
     36     ASSERT_EQ(OK, socket_.Listen(address, kListenBacklog));
     37     ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
     38   }
     39 
     40   void SetUpIPv6(bool* success) {
     41     *success = false;
     42     IPEndPoint address;
     43     ParseAddress("::1", 0, &address);
     44     if (socket_.Listen(address, kListenBacklog) != 0) {
     45       LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is "
     46           "disabled. Skipping the test";
     47       return;
     48     }
     49     ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_));
     50     *success = true;
     51   }
     52 
     53   void ParseAddress(std::string ip_str, int port, IPEndPoint* address) {
     54     IPAddressNumber ip_number;
     55     bool rv = ParseIPLiteralToNumber(ip_str, &ip_number);
     56     if (!rv)
     57       return;
     58     *address = IPEndPoint(ip_number, port);
     59   }
     60 
     61   static IPEndPoint GetPeerAddress(StreamSocket* socket) {
     62     IPEndPoint address;
     63     EXPECT_EQ(OK, socket->GetPeerAddress(&address));
     64     return address;
     65   }
     66 
     67   AddressList local_address_list() const {
     68     return AddressList(local_address_);
     69   }
     70 
     71   TCPServerSocket socket_;
     72   IPEndPoint local_address_;
     73 };
     74 
     75 TEST_F(TCPServerSocketTest, Accept) {
     76   ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
     77 
     78   TestCompletionCallback connect_callback;
     79   TCPClientSocket connecting_socket(local_address_list(),
     80                                     NULL, NetLog::Source());
     81   connecting_socket.Connect(connect_callback.callback());
     82 
     83   TestCompletionCallback accept_callback;
     84   scoped_ptr<StreamSocket> accepted_socket;
     85   int result = socket_.Accept(&accepted_socket, accept_callback.callback());
     86   if (result == ERR_IO_PENDING)
     87     result = accept_callback.WaitForResult();
     88   ASSERT_EQ(OK, result);
     89 
     90   ASSERT_TRUE(accepted_socket.get() != NULL);
     91 
     92   // Both sockets should be on the loopback network interface.
     93   EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
     94             local_address_.address());
     95 
     96   EXPECT_EQ(OK, connect_callback.WaitForResult());
     97 }
     98 
     99 // Test Accept() callback.
    100 TEST_F(TCPServerSocketTest, AcceptAsync) {
    101   ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
    102 
    103   TestCompletionCallback accept_callback;
    104   scoped_ptr<StreamSocket> accepted_socket;
    105 
    106   ASSERT_EQ(ERR_IO_PENDING,
    107             socket_.Accept(&accepted_socket, accept_callback.callback()));
    108 
    109   TestCompletionCallback connect_callback;
    110   TCPClientSocket connecting_socket(local_address_list(),
    111                                     NULL, NetLog::Source());
    112   connecting_socket.Connect(connect_callback.callback());
    113 
    114   EXPECT_EQ(OK, connect_callback.WaitForResult());
    115   EXPECT_EQ(OK, accept_callback.WaitForResult());
    116 
    117   EXPECT_TRUE(accepted_socket != NULL);
    118 
    119   // Both sockets should be on the loopback network interface.
    120   EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
    121             local_address_.address());
    122 }
    123 
    124 // Accept two connections simultaneously.
    125 TEST_F(TCPServerSocketTest, Accept2Connections) {
    126   ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
    127 
    128   TestCompletionCallback accept_callback;
    129   scoped_ptr<StreamSocket> accepted_socket;
    130 
    131   ASSERT_EQ(ERR_IO_PENDING,
    132             socket_.Accept(&accepted_socket, accept_callback.callback()));
    133 
    134   TestCompletionCallback connect_callback;
    135   TCPClientSocket connecting_socket(local_address_list(),
    136                                     NULL, NetLog::Source());
    137   connecting_socket.Connect(connect_callback.callback());
    138 
    139   TestCompletionCallback connect_callback2;
    140   TCPClientSocket connecting_socket2(local_address_list(),
    141                                      NULL, NetLog::Source());
    142   connecting_socket2.Connect(connect_callback2.callback());
    143 
    144   EXPECT_EQ(OK, accept_callback.WaitForResult());
    145 
    146   TestCompletionCallback accept_callback2;
    147   scoped_ptr<StreamSocket> accepted_socket2;
    148   int result = socket_.Accept(&accepted_socket2, accept_callback2.callback());
    149   if (result == ERR_IO_PENDING)
    150     result = accept_callback2.WaitForResult();
    151   ASSERT_EQ(OK, result);
    152 
    153   EXPECT_EQ(OK, connect_callback.WaitForResult());
    154 
    155   EXPECT_TRUE(accepted_socket != NULL);
    156   EXPECT_TRUE(accepted_socket2 != NULL);
    157   EXPECT_NE(accepted_socket.get(), accepted_socket2.get());
    158 
    159   EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
    160             local_address_.address());
    161   EXPECT_EQ(GetPeerAddress(accepted_socket2.get()).address(),
    162             local_address_.address());
    163 }
    164 
    165 TEST_F(TCPServerSocketTest, AcceptIPv6) {
    166   bool initialized = false;
    167   ASSERT_NO_FATAL_FAILURE(SetUpIPv6(&initialized));
    168   if (!initialized)
    169     return;
    170 
    171   TestCompletionCallback connect_callback;
    172   TCPClientSocket connecting_socket(local_address_list(),
    173                                     NULL, NetLog::Source());
    174   connecting_socket.Connect(connect_callback.callback());
    175 
    176   TestCompletionCallback accept_callback;
    177   scoped_ptr<StreamSocket> accepted_socket;
    178   int result = socket_.Accept(&accepted_socket, accept_callback.callback());
    179   if (result == ERR_IO_PENDING)
    180     result = accept_callback.WaitForResult();
    181   ASSERT_EQ(OK, result);
    182 
    183   ASSERT_TRUE(accepted_socket.get() != NULL);
    184 
    185   // Both sockets should be on the loopback network interface.
    186   EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
    187             local_address_.address());
    188 
    189   EXPECT_EQ(OK, connect_callback.WaitForResult());
    190 }
    191 
    192 TEST_F(TCPServerSocketTest, AcceptIO) {
    193   ASSERT_NO_FATAL_FAILURE(SetUpIPv4());
    194 
    195   TestCompletionCallback connect_callback;
    196   TCPClientSocket connecting_socket(local_address_list(),
    197                                     NULL, NetLog::Source());
    198   connecting_socket.Connect(connect_callback.callback());
    199 
    200   TestCompletionCallback accept_callback;
    201   scoped_ptr<StreamSocket> accepted_socket;
    202   int result = socket_.Accept(&accepted_socket, accept_callback.callback());
    203   ASSERT_EQ(OK, accept_callback.GetResult(result));
    204 
    205   ASSERT_TRUE(accepted_socket.get() != NULL);
    206 
    207   // Both sockets should be on the loopback network interface.
    208   EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(),
    209             local_address_.address());
    210 
    211   EXPECT_EQ(OK, connect_callback.WaitForResult());
    212 
    213   const std::string message("test message");
    214   std::vector<char> buffer(message.size());
    215 
    216   size_t bytes_written = 0;
    217   while (bytes_written < message.size()) {
    218     scoped_refptr<net::IOBufferWithSize> write_buffer(
    219         new net::IOBufferWithSize(message.size() - bytes_written));
    220     memmove(write_buffer->data(), message.data(), message.size());
    221 
    222     TestCompletionCallback write_callback;
    223     int write_result = accepted_socket->Write(
    224         write_buffer.get(), write_buffer->size(), write_callback.callback());
    225     write_result = write_callback.GetResult(write_result);
    226     ASSERT_TRUE(write_result >= 0);
    227     ASSERT_TRUE(bytes_written + write_result <= message.size());
    228     bytes_written += write_result;
    229   }
    230 
    231   size_t bytes_read = 0;
    232   while (bytes_read < message.size()) {
    233     scoped_refptr<net::IOBufferWithSize> read_buffer(
    234         new net::IOBufferWithSize(message.size() - bytes_read));
    235     TestCompletionCallback read_callback;
    236     int read_result = connecting_socket.Read(
    237         read_buffer.get(), read_buffer->size(), read_callback.callback());
    238     read_result = read_callback.GetResult(read_result);
    239     ASSERT_TRUE(read_result >= 0);
    240     ASSERT_TRUE(bytes_read + read_result <= message.size());
    241     memmove(&buffer[bytes_read], read_buffer->data(), read_result);
    242     bytes_read += read_result;
    243   }
    244 
    245   std::string received_message(buffer.begin(), buffer.end());
    246   ASSERT_EQ(message, received_message);
    247 }
    248 
    249 }  // namespace
    250 
    251 }  // namespace net
    252