1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/socket/tcp_socket.h" 6 7 #include <string.h> 8 9 #include <string> 10 #include <vector> 11 12 #include "base/memory/ref_counted.h" 13 #include "base/memory/scoped_ptr.h" 14 #include "net/base/address_list.h" 15 #include "net/base/io_buffer.h" 16 #include "net/base/ip_endpoint.h" 17 #include "net/base/net_errors.h" 18 #include "net/base/test_completion_callback.h" 19 #include "net/socket/tcp_client_socket.h" 20 #include "testing/gtest/include/gtest/gtest.h" 21 #include "testing/platform_test.h" 22 23 namespace net { 24 25 namespace { 26 const int kListenBacklog = 5; 27 28 class TCPSocketTest : public PlatformTest { 29 protected: 30 TCPSocketTest() : socket_(NULL, NetLog::Source()) { 31 } 32 33 void SetUpListenIPv4() { 34 IPEndPoint address; 35 ParseAddress("127.0.0.1", 0, &address); 36 37 ASSERT_EQ(OK, socket_.Open(ADDRESS_FAMILY_IPV4)); 38 ASSERT_EQ(OK, socket_.Bind(address)); 39 ASSERT_EQ(OK, socket_.Listen(kListenBacklog)); 40 ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); 41 } 42 43 void SetUpListenIPv6(bool* success) { 44 *success = false; 45 IPEndPoint address; 46 ParseAddress("::1", 0, &address); 47 48 if (socket_.Open(ADDRESS_FAMILY_IPV6) != OK || 49 socket_.Bind(address) != OK || 50 socket_.Listen(kListenBacklog) != OK) { 51 LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is " 52 "disabled. Skipping the test"; 53 return; 54 } 55 ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); 56 *success = true; 57 } 58 59 void ParseAddress(const std::string& ip_str, int port, IPEndPoint* address) { 60 IPAddressNumber ip_number; 61 bool rv = ParseIPLiteralToNumber(ip_str, &ip_number); 62 if (!rv) 63 return; 64 *address = IPEndPoint(ip_number, port); 65 } 66 67 void TestAcceptAsync() { 68 TestCompletionCallback accept_callback; 69 scoped_ptr<TCPSocket> accepted_socket; 70 IPEndPoint accepted_address; 71 ASSERT_EQ(ERR_IO_PENDING, 72 socket_.Accept(&accepted_socket, &accepted_address, 73 accept_callback.callback())); 74 75 TestCompletionCallback connect_callback; 76 TCPClientSocket connecting_socket(local_address_list(), 77 NULL, NetLog::Source()); 78 connecting_socket.Connect(connect_callback.callback()); 79 80 EXPECT_EQ(OK, connect_callback.WaitForResult()); 81 EXPECT_EQ(OK, accept_callback.WaitForResult()); 82 83 EXPECT_TRUE(accepted_socket.get()); 84 85 // Both sockets should be on the loopback network interface. 86 EXPECT_EQ(accepted_address.address(), local_address_.address()); 87 } 88 89 AddressList local_address_list() const { 90 return AddressList(local_address_); 91 } 92 93 TCPSocket socket_; 94 IPEndPoint local_address_; 95 }; 96 97 // Test listening and accepting with a socket bound to an IPv4 address. 98 TEST_F(TCPSocketTest, Accept) { 99 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); 100 101 TestCompletionCallback connect_callback; 102 // TODO(yzshen): Switch to use TCPSocket when it supports client socket 103 // operations. 104 TCPClientSocket connecting_socket(local_address_list(), 105 NULL, NetLog::Source()); 106 connecting_socket.Connect(connect_callback.callback()); 107 108 TestCompletionCallback accept_callback; 109 scoped_ptr<TCPSocket> accepted_socket; 110 IPEndPoint accepted_address; 111 int result = socket_.Accept(&accepted_socket, &accepted_address, 112 accept_callback.callback()); 113 if (result == ERR_IO_PENDING) 114 result = accept_callback.WaitForResult(); 115 ASSERT_EQ(OK, result); 116 117 EXPECT_TRUE(accepted_socket.get()); 118 119 // Both sockets should be on the loopback network interface. 120 EXPECT_EQ(accepted_address.address(), local_address_.address()); 121 122 EXPECT_EQ(OK, connect_callback.WaitForResult()); 123 } 124 125 // Test Accept() callback. 126 TEST_F(TCPSocketTest, AcceptAsync) { 127 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); 128 TestAcceptAsync(); 129 } 130 131 #if defined(OS_WIN) 132 // Test Accept() for AdoptListenSocket. 133 TEST_F(TCPSocketTest, AcceptForAdoptedListenSocket) { 134 // Create a socket to be used with AdoptListenSocket. 135 SOCKET existing_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); 136 ASSERT_EQ(OK, socket_.AdoptListenSocket(existing_socket)); 137 138 IPEndPoint address; 139 ParseAddress("127.0.0.1", 0, &address); 140 SockaddrStorage storage; 141 ASSERT_TRUE(address.ToSockAddr(storage.addr, &storage.addr_len)); 142 ASSERT_EQ(0, bind(existing_socket, storage.addr, storage.addr_len)); 143 144 ASSERT_EQ(OK, socket_.Listen(kListenBacklog)); 145 ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); 146 147 TestAcceptAsync(); 148 } 149 #endif 150 151 // Accept two connections simultaneously. 152 TEST_F(TCPSocketTest, Accept2Connections) { 153 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); 154 155 TestCompletionCallback accept_callback; 156 scoped_ptr<TCPSocket> accepted_socket; 157 IPEndPoint accepted_address; 158 159 ASSERT_EQ(ERR_IO_PENDING, 160 socket_.Accept(&accepted_socket, &accepted_address, 161 accept_callback.callback())); 162 163 TestCompletionCallback connect_callback; 164 TCPClientSocket connecting_socket(local_address_list(), 165 NULL, NetLog::Source()); 166 connecting_socket.Connect(connect_callback.callback()); 167 168 TestCompletionCallback connect_callback2; 169 TCPClientSocket connecting_socket2(local_address_list(), 170 NULL, NetLog::Source()); 171 connecting_socket2.Connect(connect_callback2.callback()); 172 173 EXPECT_EQ(OK, accept_callback.WaitForResult()); 174 175 TestCompletionCallback accept_callback2; 176 scoped_ptr<TCPSocket> accepted_socket2; 177 IPEndPoint accepted_address2; 178 179 int result = socket_.Accept(&accepted_socket2, &accepted_address2, 180 accept_callback2.callback()); 181 if (result == ERR_IO_PENDING) 182 result = accept_callback2.WaitForResult(); 183 ASSERT_EQ(OK, result); 184 185 EXPECT_EQ(OK, connect_callback.WaitForResult()); 186 EXPECT_EQ(OK, connect_callback2.WaitForResult()); 187 188 EXPECT_TRUE(accepted_socket.get()); 189 EXPECT_TRUE(accepted_socket2.get()); 190 EXPECT_NE(accepted_socket.get(), accepted_socket2.get()); 191 192 EXPECT_EQ(accepted_address.address(), local_address_.address()); 193 EXPECT_EQ(accepted_address2.address(), local_address_.address()); 194 } 195 196 // Test listening and accepting with a socket bound to an IPv6 address. 197 TEST_F(TCPSocketTest, AcceptIPv6) { 198 bool initialized = false; 199 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv6(&initialized)); 200 if (!initialized) 201 return; 202 203 TestCompletionCallback connect_callback; 204 TCPClientSocket connecting_socket(local_address_list(), 205 NULL, NetLog::Source()); 206 connecting_socket.Connect(connect_callback.callback()); 207 208 TestCompletionCallback accept_callback; 209 scoped_ptr<TCPSocket> accepted_socket; 210 IPEndPoint accepted_address; 211 int result = socket_.Accept(&accepted_socket, &accepted_address, 212 accept_callback.callback()); 213 if (result == ERR_IO_PENDING) 214 result = accept_callback.WaitForResult(); 215 ASSERT_EQ(OK, result); 216 217 EXPECT_TRUE(accepted_socket.get()); 218 219 // Both sockets should be on the loopback network interface. 220 EXPECT_EQ(accepted_address.address(), local_address_.address()); 221 222 EXPECT_EQ(OK, connect_callback.WaitForResult()); 223 } 224 225 TEST_F(TCPSocketTest, ReadWrite) { 226 ASSERT_NO_FATAL_FAILURE(SetUpListenIPv4()); 227 228 TestCompletionCallback connect_callback; 229 TCPSocket connecting_socket(NULL, NetLog::Source()); 230 int result = connecting_socket.Open(ADDRESS_FAMILY_IPV4); 231 ASSERT_EQ(OK, result); 232 connecting_socket.Connect(local_address_, connect_callback.callback()); 233 234 TestCompletionCallback accept_callback; 235 scoped_ptr<TCPSocket> accepted_socket; 236 IPEndPoint accepted_address; 237 result = socket_.Accept(&accepted_socket, &accepted_address, 238 accept_callback.callback()); 239 ASSERT_EQ(OK, accept_callback.GetResult(result)); 240 241 ASSERT_TRUE(accepted_socket.get()); 242 243 // Both sockets should be on the loopback network interface. 244 EXPECT_EQ(accepted_address.address(), local_address_.address()); 245 246 EXPECT_EQ(OK, connect_callback.WaitForResult()); 247 248 const std::string message("test message"); 249 std::vector<char> buffer(message.size()); 250 251 size_t bytes_written = 0; 252 while (bytes_written < message.size()) { 253 scoped_refptr<IOBufferWithSize> write_buffer( 254 new IOBufferWithSize(message.size() - bytes_written)); 255 memmove(write_buffer->data(), message.data() + bytes_written, 256 message.size() - bytes_written); 257 258 TestCompletionCallback write_callback; 259 int write_result = accepted_socket->Write( 260 write_buffer.get(), write_buffer->size(), write_callback.callback()); 261 write_result = write_callback.GetResult(write_result); 262 ASSERT_TRUE(write_result >= 0); 263 bytes_written += write_result; 264 ASSERT_TRUE(bytes_written <= message.size()); 265 } 266 267 size_t bytes_read = 0; 268 while (bytes_read < message.size()) { 269 scoped_refptr<IOBufferWithSize> read_buffer( 270 new IOBufferWithSize(message.size() - bytes_read)); 271 TestCompletionCallback read_callback; 272 int read_result = connecting_socket.Read( 273 read_buffer.get(), read_buffer->size(), read_callback.callback()); 274 read_result = read_callback.GetResult(read_result); 275 ASSERT_TRUE(read_result >= 0); 276 ASSERT_TRUE(bytes_read + read_result <= message.size()); 277 memmove(&buffer[bytes_read], read_buffer->data(), read_result); 278 bytes_read += read_result; 279 } 280 281 std::string received_message(buffer.begin(), buffer.end()); 282 ASSERT_EQ(message, received_message); 283 } 284 285 } // namespace 286 } // namespace net 287