1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/tcp_server_socket.h" 6 7 #include <string> 8 #include <vector> 9 10 #include "base/compiler_specific.h" 11 #include "base/memory/ref_counted.h" 12 #include "base/memory/scoped_ptr.h" 13 #include "net/base/address_list.h" 14 #include "net/base/io_buffer.h" 15 #include "net/base/ip_endpoint.h" 16 #include "net/base/net_errors.h" 17 #include "net/base/test_completion_callback.h" 18 #include "net/socket/tcp_client_socket.h" 19 #include "testing/gtest/include/gtest/gtest.h" 20 #include "testing/platform_test.h" 21 22 namespace net { 23 24 namespace { 25 const int kListenBacklog = 5; 26 27 class TCPServerSocketTest : public PlatformTest { 28 protected: 29 TCPServerSocketTest() 30 : socket_(NULL, NetLog::Source()) { 31 } 32 33 void SetUpIPv4() { 34 IPEndPoint address; 35 ParseAddress("127.0.0.1", 0, &address); 36 ASSERT_EQ(OK, socket_.Listen(address, kListenBacklog)); 37 ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); 38 } 39 40 void SetUpIPv6(bool* success) { 41 *success = false; 42 IPEndPoint address; 43 ParseAddress("::1", 0, &address); 44 if (socket_.Listen(address, kListenBacklog) != 0) { 45 LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is " 46 "disabled. Skipping the test"; 47 return; 48 } 49 ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); 50 *success = true; 51 } 52 53 void ParseAddress(std::string ip_str, int port, IPEndPoint* address) { 54 IPAddressNumber ip_number; 55 bool rv = ParseIPLiteralToNumber(ip_str, &ip_number); 56 if (!rv) 57 return; 58 *address = IPEndPoint(ip_number, port); 59 } 60 61 static IPEndPoint GetPeerAddress(StreamSocket* socket) { 62 IPEndPoint address; 63 EXPECT_EQ(OK, socket->GetPeerAddress(&address)); 64 return address; 65 } 66 67 AddressList local_address_list() const { 68 return AddressList(local_address_); 69 } 70 71 TCPServerSocket socket_; 72 IPEndPoint local_address_; 73 }; 74 75 TEST_F(TCPServerSocketTest, Accept) { 76 ASSERT_NO_FATAL_FAILURE(SetUpIPv4()); 77 78 TestCompletionCallback connect_callback; 79 TCPClientSocket connecting_socket(local_address_list(), 80 NULL, NetLog::Source()); 81 connecting_socket.Connect(connect_callback.callback()); 82 83 TestCompletionCallback accept_callback; 84 scoped_ptr<StreamSocket> accepted_socket; 85 int result = socket_.Accept(&accepted_socket, accept_callback.callback()); 86 if (result == ERR_IO_PENDING) 87 result = accept_callback.WaitForResult(); 88 ASSERT_EQ(OK, result); 89 90 ASSERT_TRUE(accepted_socket.get() != NULL); 91 92 // Both sockets should be on the loopback network interface. 93 EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), 94 local_address_.address()); 95 96 EXPECT_EQ(OK, connect_callback.WaitForResult()); 97 } 98 99 // Test Accept() callback. 100 TEST_F(TCPServerSocketTest, AcceptAsync) { 101 ASSERT_NO_FATAL_FAILURE(SetUpIPv4()); 102 103 TestCompletionCallback accept_callback; 104 scoped_ptr<StreamSocket> accepted_socket; 105 106 ASSERT_EQ(ERR_IO_PENDING, 107 socket_.Accept(&accepted_socket, accept_callback.callback())); 108 109 TestCompletionCallback connect_callback; 110 TCPClientSocket connecting_socket(local_address_list(), 111 NULL, NetLog::Source()); 112 connecting_socket.Connect(connect_callback.callback()); 113 114 EXPECT_EQ(OK, connect_callback.WaitForResult()); 115 EXPECT_EQ(OK, accept_callback.WaitForResult()); 116 117 EXPECT_TRUE(accepted_socket != NULL); 118 119 // Both sockets should be on the loopback network interface. 120 EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), 121 local_address_.address()); 122 } 123 124 // Accept two connections simultaneously. 125 TEST_F(TCPServerSocketTest, Accept2Connections) { 126 ASSERT_NO_FATAL_FAILURE(SetUpIPv4()); 127 128 TestCompletionCallback accept_callback; 129 scoped_ptr<StreamSocket> accepted_socket; 130 131 ASSERT_EQ(ERR_IO_PENDING, 132 socket_.Accept(&accepted_socket, accept_callback.callback())); 133 134 TestCompletionCallback connect_callback; 135 TCPClientSocket connecting_socket(local_address_list(), 136 NULL, NetLog::Source()); 137 connecting_socket.Connect(connect_callback.callback()); 138 139 TestCompletionCallback connect_callback2; 140 TCPClientSocket connecting_socket2(local_address_list(), 141 NULL, NetLog::Source()); 142 connecting_socket2.Connect(connect_callback2.callback()); 143 144 EXPECT_EQ(OK, accept_callback.WaitForResult()); 145 146 TestCompletionCallback accept_callback2; 147 scoped_ptr<StreamSocket> accepted_socket2; 148 int result = socket_.Accept(&accepted_socket2, accept_callback2.callback()); 149 if (result == ERR_IO_PENDING) 150 result = accept_callback2.WaitForResult(); 151 ASSERT_EQ(OK, result); 152 153 EXPECT_EQ(OK, connect_callback.WaitForResult()); 154 155 EXPECT_TRUE(accepted_socket != NULL); 156 EXPECT_TRUE(accepted_socket2 != NULL); 157 EXPECT_NE(accepted_socket.get(), accepted_socket2.get()); 158 159 EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), 160 local_address_.address()); 161 EXPECT_EQ(GetPeerAddress(accepted_socket2.get()).address(), 162 local_address_.address()); 163 } 164 165 TEST_F(TCPServerSocketTest, AcceptIPv6) { 166 bool initialized = false; 167 ASSERT_NO_FATAL_FAILURE(SetUpIPv6(&initialized)); 168 if (!initialized) 169 return; 170 171 TestCompletionCallback connect_callback; 172 TCPClientSocket connecting_socket(local_address_list(), 173 NULL, NetLog::Source()); 174 connecting_socket.Connect(connect_callback.callback()); 175 176 TestCompletionCallback accept_callback; 177 scoped_ptr<StreamSocket> accepted_socket; 178 int result = socket_.Accept(&accepted_socket, accept_callback.callback()); 179 if (result == ERR_IO_PENDING) 180 result = accept_callback.WaitForResult(); 181 ASSERT_EQ(OK, result); 182 183 ASSERT_TRUE(accepted_socket.get() != NULL); 184 185 // Both sockets should be on the loopback network interface. 186 EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), 187 local_address_.address()); 188 189 EXPECT_EQ(OK, connect_callback.WaitForResult()); 190 } 191 192 TEST_F(TCPServerSocketTest, AcceptIO) { 193 ASSERT_NO_FATAL_FAILURE(SetUpIPv4()); 194 195 TestCompletionCallback connect_callback; 196 TCPClientSocket connecting_socket(local_address_list(), 197 NULL, NetLog::Source()); 198 connecting_socket.Connect(connect_callback.callback()); 199 200 TestCompletionCallback accept_callback; 201 scoped_ptr<StreamSocket> accepted_socket; 202 int result = socket_.Accept(&accepted_socket, accept_callback.callback()); 203 ASSERT_EQ(OK, accept_callback.GetResult(result)); 204 205 ASSERT_TRUE(accepted_socket.get() != NULL); 206 207 // Both sockets should be on the loopback network interface. 208 EXPECT_EQ(GetPeerAddress(accepted_socket.get()).address(), 209 local_address_.address()); 210 211 EXPECT_EQ(OK, connect_callback.WaitForResult()); 212 213 const std::string message("test message"); 214 std::vector<char> buffer(message.size()); 215 216 size_t bytes_written = 0; 217 while (bytes_written < message.size()) { 218 scoped_refptr<net::IOBufferWithSize> write_buffer( 219 new net::IOBufferWithSize(message.size() - bytes_written)); 220 memmove(write_buffer->data(), message.data(), message.size()); 221 222 TestCompletionCallback write_callback; 223 int write_result = accepted_socket->Write( 224 write_buffer.get(), write_buffer->size(), write_callback.callback()); 225 write_result = write_callback.GetResult(write_result); 226 ASSERT_TRUE(write_result >= 0); 227 ASSERT_TRUE(bytes_written + write_result <= message.size()); 228 bytes_written += write_result; 229 } 230 231 size_t bytes_read = 0; 232 while (bytes_read < message.size()) { 233 scoped_refptr<net::IOBufferWithSize> read_buffer( 234 new net::IOBufferWithSize(message.size() - bytes_read)); 235 TestCompletionCallback read_callback; 236 int read_result = connecting_socket.Read( 237 read_buffer.get(), read_buffer->size(), read_callback.callback()); 238 read_result = read_callback.GetResult(read_result); 239 ASSERT_TRUE(read_result >= 0); 240 ASSERT_TRUE(bytes_read + read_result <= message.size()); 241 memmove(&buffer[bytes_read], read_buffer->data(), read_result); 242 bytes_read += read_result; 243 } 244 245 std::string received_message(buffer.begin(), buffer.end()); 246 ASSERT_EQ(message, received_message); 247 } 248 249 } // namespace 250 251 } // namespace net 252