1 // Copyright (c) 2011 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "net/udp/udp_client_socket.h" 6 #include "net/udp/udp_server_socket.h" 7 8 #include "base/basictypes.h" 9 #include "base/metrics/histogram.h" 10 #include "net/base/io_buffer.h" 11 #include "net/base/ip_endpoint.h" 12 #include "net/base/net_errors.h" 13 #include "net/base/net_test_suite.h" 14 #include "net/base/net_util.h" 15 #include "net/base/sys_addrinfo.h" 16 #include "net/base/test_completion_callback.h" 17 #include "testing/gtest/include/gtest/gtest.h" 18 #include "testing/platform_test.h" 19 20 namespace net { 21 22 namespace { 23 24 class UDPSocketTest : public PlatformTest { 25 public: 26 UDPSocketTest() 27 : buffer_(new IOBufferWithSize(kMaxRead)) { 28 } 29 30 // Blocks until data is read from the socket. 31 std::string RecvFromSocket(UDPServerSocket* socket) { 32 TestCompletionCallback callback; 33 34 int rv = socket->RecvFrom(buffer_, kMaxRead, &recv_from_address_, 35 &callback); 36 if (rv == ERR_IO_PENDING) 37 rv = callback.WaitForResult(); 38 if (rv < 0) 39 return ""; // error! 40 return std::string(buffer_->data(), rv); 41 } 42 43 // Loop until |msg| has been written to the socket or until an 44 // error occurs. 45 // If |address| is specified, then it is used for the destination 46 // to send to. Otherwise, will send to the last socket this server 47 // received from. 48 int SendToSocket(UDPServerSocket* socket, std::string msg) { 49 return SendToSocket(socket, msg, recv_from_address_); 50 } 51 52 int SendToSocket(UDPServerSocket* socket, 53 std::string msg, 54 const IPEndPoint& address) { 55 TestCompletionCallback callback; 56 57 int length = msg.length(); 58 scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg)); 59 scoped_refptr<DrainableIOBuffer> buffer( 60 new DrainableIOBuffer(io_buffer, length)); 61 62 int bytes_sent = 0; 63 while (buffer->BytesRemaining()) { 64 int rv = socket->SendTo(buffer, buffer->BytesRemaining(), 65 address, &callback); 66 if (rv == ERR_IO_PENDING) 67 rv = callback.WaitForResult(); 68 if (rv <= 0) 69 return bytes_sent > 0 ? bytes_sent : rv; 70 bytes_sent += rv; 71 buffer->DidConsume(rv); 72 } 73 return bytes_sent; 74 } 75 76 std::string ReadSocket(UDPClientSocket* socket) { 77 TestCompletionCallback callback; 78 79 int rv = socket->Read(buffer_, kMaxRead, &callback); 80 if (rv == ERR_IO_PENDING) 81 rv = callback.WaitForResult(); 82 if (rv < 0) 83 return ""; // error! 84 return std::string(buffer_->data(), rv); 85 } 86 87 // Loop until |msg| has been written to the socket or until an 88 // error occurs. 89 int WriteSocket(UDPClientSocket* socket, std::string msg) { 90 TestCompletionCallback callback; 91 92 int length = msg.length(); 93 scoped_refptr<StringIOBuffer> io_buffer(new StringIOBuffer(msg)); 94 scoped_refptr<DrainableIOBuffer> buffer( 95 new DrainableIOBuffer(io_buffer, length)); 96 97 int bytes_sent = 0; 98 while (buffer->BytesRemaining()) { 99 int rv = socket->Write(buffer, buffer->BytesRemaining(), &callback); 100 if (rv == ERR_IO_PENDING) 101 rv = callback.WaitForResult(); 102 if (rv <= 0) 103 return bytes_sent > 0 ? bytes_sent : rv; 104 bytes_sent += rv; 105 buffer->DidConsume(rv); 106 } 107 return bytes_sent; 108 } 109 110 protected: 111 static const int kMaxRead = 1024; 112 scoped_refptr<IOBufferWithSize> buffer_; 113 IPEndPoint recv_from_address_; 114 }; 115 116 // Creates and address from an ip/port and returns it in |address|. 117 void CreateUDPAddress(std::string ip_str, int port, IPEndPoint* address) { 118 IPAddressNumber ip_number; 119 bool rv = ParseIPLiteralToNumber(ip_str, &ip_number); 120 if (!rv) 121 return; 122 *address = IPEndPoint(ip_number, port); 123 } 124 125 TEST_F(UDPSocketTest, Connect) { 126 const int kPort = 9999; 127 std::string simple_message("hello world!"); 128 129 // Setup the server to listen. 130 IPEndPoint bind_address; 131 CreateUDPAddress("0.0.0.0", kPort, &bind_address); 132 UDPServerSocket server(NULL, NetLog::Source()); 133 int rv = server.Listen(bind_address); 134 EXPECT_EQ(OK, rv); 135 136 // Setup the client. 137 IPEndPoint server_address; 138 CreateUDPAddress("127.0.0.1", kPort, &server_address); 139 UDPClientSocket client(NULL, NetLog::Source()); 140 rv = client.Connect(server_address); 141 EXPECT_EQ(OK, rv); 142 143 // Client sends to the server. 144 rv = WriteSocket(&client, simple_message); 145 EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv)); 146 147 // Server waits for message. 148 std::string str = RecvFromSocket(&server); 149 DCHECK(simple_message == str); 150 151 // Server echoes reply. 152 rv = SendToSocket(&server, simple_message); 153 EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv)); 154 155 // Client waits for response. 156 str = ReadSocket(&client); 157 DCHECK(simple_message == str); 158 } 159 160 // In this test, we verify that connect() on a socket will have the effect 161 // of filtering reads on this socket only to data read from the destination 162 // we connected to. 163 // 164 // The purpose of this test is that some documentation indicates that connect 165 // binds the client's sends to send to a particular server endpoint, but does 166 // not bind the client's reads to only be from that endpoint, and that we need 167 // to always use recvfrom() to disambiguate. 168 TEST_F(UDPSocketTest, VerifyConnectBindsAddr) { 169 const int kPort1 = 9999; 170 const int kPort2 = 10000; 171 std::string simple_message("hello world!"); 172 std::string foreign_message("BAD MESSAGE TO GET!!"); 173 174 // Setup the first server to listen. 175 IPEndPoint bind_address; 176 CreateUDPAddress("0.0.0.0", kPort1, &bind_address); 177 UDPServerSocket server1(NULL, NetLog::Source()); 178 int rv = server1.Listen(bind_address); 179 EXPECT_EQ(OK, rv); 180 181 // Setup the second server to listen. 182 CreateUDPAddress("0.0.0.0", kPort2, &bind_address); 183 UDPServerSocket server2(NULL, NetLog::Source()); 184 rv = server2.Listen(bind_address); 185 EXPECT_EQ(OK, rv); 186 187 // Setup the client, connected to server 1. 188 IPEndPoint server_address; 189 CreateUDPAddress("127.0.0.1", kPort1, &server_address); 190 UDPClientSocket client(NULL, NetLog::Source()); 191 rv = client.Connect(server_address); 192 EXPECT_EQ(OK, rv); 193 194 // Client sends to server1. 195 rv = WriteSocket(&client, simple_message); 196 EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv)); 197 198 // Server1 waits for message. 199 std::string str = RecvFromSocket(&server1); 200 DCHECK(simple_message == str); 201 202 // Get the client's address. 203 IPEndPoint client_address; 204 rv = client.GetLocalAddress(&client_address); 205 EXPECT_EQ(OK, rv); 206 207 // Server2 sends reply. 208 rv = SendToSocket(&server2, foreign_message, 209 client_address); 210 EXPECT_EQ(foreign_message.length(), static_cast<size_t>(rv)); 211 212 // Server1 sends reply. 213 rv = SendToSocket(&server1, simple_message, 214 client_address); 215 EXPECT_EQ(simple_message.length(), static_cast<size_t>(rv)); 216 217 // Client waits for response. 218 str = ReadSocket(&client); 219 DCHECK(simple_message == str); 220 } 221 222 TEST_F(UDPSocketTest, ClientGetLocalPeerAddresses) { 223 struct TestData { 224 std::string remote_address; 225 std::string local_address; 226 bool may_fail; 227 } tests[] = { 228 { "127.0.00.1", "127.0.0.1", false }, 229 { "192.168.1.1", "127.0.0.1", false }, 230 { "::1", "::1", true }, 231 { "2001:db8:0::42", "::1", true }, 232 }; 233 for (size_t i = 0; i < ARRAYSIZE_UNSAFE(tests); i++) { 234 SCOPED_TRACE(std::string("Connecting from ") + tests[i].local_address + 235 std::string(" to ") + tests[i].remote_address); 236 237 net::IPAddressNumber ip_number; 238 net::ParseIPLiteralToNumber(tests[i].remote_address, &ip_number); 239 net::IPEndPoint remote_address(ip_number, 80); 240 net::ParseIPLiteralToNumber(tests[i].local_address, &ip_number); 241 net::IPEndPoint local_address(ip_number, 80); 242 243 UDPClientSocket client(NULL, NetLog::Source()); 244 int rv = client.Connect(remote_address); 245 if (tests[i].may_fail && rv == ERR_ADDRESS_UNREACHABLE) { 246 // Connect() may return ERR_ADDRESS_UNREACHABLE for IPv6 247 // addresses if IPv6 is not configured. 248 continue; 249 } 250 251 EXPECT_LE(ERR_IO_PENDING, rv); 252 253 IPEndPoint fetched_local_address; 254 rv = client.GetLocalAddress(&fetched_local_address); 255 EXPECT_EQ(OK, rv); 256 257 // TODO(mbelshe): figure out how to verify the IP and port. 258 // The port is dynamically generated by the udp stack. 259 // The IP is the real IP of the client, not necessarily 260 // loopback. 261 //EXPECT_EQ(local_address.address(), fetched_local_address.address()); 262 263 IPEndPoint fetched_remote_address; 264 rv = client.GetPeerAddress(&fetched_remote_address); 265 EXPECT_EQ(OK, rv); 266 267 EXPECT_EQ(remote_address, fetched_remote_address); 268 } 269 } 270 271 TEST_F(UDPSocketTest, ServerGetLocalAddress) { 272 IPEndPoint bind_address; 273 CreateUDPAddress("127.0.0.1", 0, &bind_address); 274 UDPServerSocket server(NULL, NetLog::Source()); 275 int rv = server.Listen(bind_address); 276 EXPECT_EQ(OK, rv); 277 278 IPEndPoint local_address; 279 rv = server.GetLocalAddress(&local_address); 280 EXPECT_EQ(rv, 0); 281 282 // Verify that port was allocated. 283 EXPECT_GT(local_address.port(), 0); 284 EXPECT_EQ(local_address.address(), bind_address.address()); 285 } 286 287 TEST_F(UDPSocketTest, ServerGetPeerAddress) { 288 IPEndPoint bind_address; 289 CreateUDPAddress("127.0.0.1", 0, &bind_address); 290 UDPServerSocket server(NULL, NetLog::Source()); 291 int rv = server.Listen(bind_address); 292 EXPECT_EQ(OK, rv); 293 294 IPEndPoint peer_address; 295 rv = server.GetPeerAddress(&peer_address); 296 EXPECT_EQ(rv, ERR_SOCKET_NOT_CONNECTED); 297 } 298 299 // Close the socket while read is pending. 300 TEST_F(UDPSocketTest, CloseWithPendingRead) { 301 IPEndPoint bind_address; 302 CreateUDPAddress("127.0.0.1", 0, &bind_address); 303 UDPServerSocket server(NULL, NetLog::Source()); 304 int rv = server.Listen(bind_address); 305 EXPECT_EQ(OK, rv); 306 307 TestCompletionCallback callback; 308 IPEndPoint from; 309 rv = server.RecvFrom(buffer_, kMaxRead, &from, &callback); 310 EXPECT_EQ(rv, ERR_IO_PENDING); 311 312 server.Close(); 313 314 EXPECT_FALSE(callback.have_result()); 315 } 316 317 } // namespace 318 319 } // namespace net 320