1 /* 2 * Copyright (C) 2015 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Tests socket functionality using loopback connections. The UDP tests assume that no packets are 18 // lost, which should be the case for loopback communication, but is not guaranteed. 19 // 20 // Also tests our SocketMock class to make sure it works as expected and reports errors properly 21 // if the mock expectations aren't met during a test. 22 23 #include "socket.h" 24 #include "socket_mock.h" 25 26 #include <list> 27 28 #include <gtest/gtest-spi.h> 29 #include <gtest/gtest.h> 30 31 static constexpr int kShortTimeoutMs = 10; 32 static constexpr int kTestTimeoutMs = 3000; 33 34 // Creates connected sockets |server| and |client|. Returns true on success. 35 bool MakeConnectedSockets(Socket::Protocol protocol, std::unique_ptr<Socket>* server, 36 std::unique_ptr<Socket>* client, 37 const std::string& hostname = "localhost") { 38 *server = Socket::NewServer(protocol, 0); 39 if (*server == nullptr) { 40 ADD_FAILURE() << "Failed to create server."; 41 return false; 42 } 43 44 *client = Socket::NewClient(protocol, hostname, (*server)->GetLocalPort(), nullptr); 45 if (*client == nullptr) { 46 ADD_FAILURE() << "Failed to create client."; 47 return false; 48 } 49 50 // TCP passes the client off to a new socket. 51 if (protocol == Socket::Protocol::kTcp) { 52 *server = (*server)->Accept(); 53 if (*server == nullptr) { 54 ADD_FAILURE() << "Failed to accept client connection."; 55 return false; 56 } 57 } 58 59 return true; 60 } 61 62 // Sends a string over a Socket. Returns true if the full string (without terminating char) 63 // was sent. 64 static bool SendString(Socket* sock, const std::string& message) { 65 return sock->Send(message.c_str(), message.length()); 66 } 67 68 // Receives a string from a Socket. Returns true if the full string (without terminating char) 69 // was received. 70 static bool ReceiveString(Socket* sock, const std::string& message) { 71 std::string received(message.length(), '\0'); 72 ssize_t bytes = sock->ReceiveAll(&received[0], received.length(), kTestTimeoutMs); 73 return static_cast<size_t>(bytes) == received.length() && received == message; 74 } 75 76 // Tests sending packets client -> server, then server -> client. 77 TEST(SocketTest, TestSendAndReceive) { 78 std::unique_ptr<Socket> server, client; 79 80 for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { 81 ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); 82 83 EXPECT_TRUE(SendString(client.get(), "foo")); 84 EXPECT_TRUE(ReceiveString(server.get(), "foo")); 85 86 EXPECT_TRUE(SendString(server.get(), "bar baz")); 87 EXPECT_TRUE(ReceiveString(client.get(), "bar baz")); 88 } 89 } 90 91 TEST(SocketTest, TestReceiveTimeout) { 92 std::unique_ptr<Socket> server, client; 93 char buffer[16]; 94 95 for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { 96 ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); 97 98 EXPECT_EQ(-1, server->Receive(buffer, sizeof(buffer), kShortTimeoutMs)); 99 EXPECT_TRUE(server->ReceiveTimedOut()); 100 101 EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kShortTimeoutMs)); 102 EXPECT_TRUE(client->ReceiveTimedOut()); 103 } 104 105 // UDP will wait for timeout if the other side closes. 106 ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kUdp, &server, &client)); 107 EXPECT_EQ(0, server->Close()); 108 EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kShortTimeoutMs)); 109 EXPECT_TRUE(client->ReceiveTimedOut()); 110 } 111 112 TEST(SocketTest, TestReceiveFailure) { 113 std::unique_ptr<Socket> server, client; 114 char buffer[16]; 115 116 for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { 117 ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); 118 119 EXPECT_EQ(0, server->Close()); 120 EXPECT_EQ(-1, server->Receive(buffer, sizeof(buffer), kTestTimeoutMs)); 121 EXPECT_FALSE(server->ReceiveTimedOut()); 122 123 EXPECT_EQ(0, client->Close()); 124 EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kTestTimeoutMs)); 125 EXPECT_FALSE(client->ReceiveTimedOut()); 126 } 127 128 // TCP knows right away when the other side closes and returns 0 to indicate EOF. 129 ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kTcp, &server, &client)); 130 EXPECT_EQ(0, server->Close()); 131 EXPECT_EQ(0, client->Receive(buffer, sizeof(buffer), kTestTimeoutMs)); 132 EXPECT_FALSE(client->ReceiveTimedOut()); 133 } 134 135 // Tests sending and receiving large packets. 136 TEST(SocketTest, TestLargePackets) { 137 std::string message(1024, '\0'); 138 std::unique_ptr<Socket> server, client; 139 140 for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { 141 ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); 142 143 // Run through the test a few times. 144 for (int i = 0; i < 10; ++i) { 145 // Use a different message each iteration to prevent false positives. 146 for (size_t j = 0; j < message.length(); ++j) { 147 message[j] = static_cast<char>(i + j); 148 } 149 150 EXPECT_TRUE(SendString(client.get(), message)); 151 EXPECT_TRUE(ReceiveString(server.get(), message)); 152 } 153 } 154 } 155 156 // Tests UDP receive overflow when the UDP packet is larger than the receive buffer. 157 TEST(SocketTest, TestUdpReceiveOverflow) { 158 std::unique_ptr<Socket> server, client; 159 ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kUdp, &server, &client)); 160 161 EXPECT_TRUE(SendString(client.get(), "1234567890")); 162 163 // This behaves differently on different systems, either truncating the packet or returning -1. 164 char buffer[5]; 165 ssize_t bytes = server->Receive(buffer, 5, kTestTimeoutMs); 166 if (bytes == 5) { 167 EXPECT_EQ(0, memcmp(buffer, "12345", 5)); 168 } else { 169 EXPECT_EQ(-1, bytes); 170 } 171 } 172 173 // Tests UDP multi-buffer send. 174 TEST(SocketTest, TestUdpSendBuffers) { 175 std::unique_ptr<Socket> sock = Socket::NewServer(Socket::Protocol::kUdp, 0); 176 std::vector<std::string> data{"foo", "bar", "12345"}; 177 std::vector<cutils_socket_buffer_t> buffers{{data[0].data(), data[0].length()}, 178 {data[1].data(), data[1].length()}, 179 {data[2].data(), data[2].length()}}; 180 ssize_t mock_return_value = 0; 181 182 // Mock out socket_send_buffers() to verify we're sending in the correct buffers and 183 // return |mock_return_value|. 184 sock->socket_send_buffers_function_ = [&buffers, &mock_return_value]( 185 cutils_socket_t /*cutils_sock*/, cutils_socket_buffer_t* sent_buffers, 186 size_t num_sent_buffers) -> ssize_t { 187 EXPECT_EQ(buffers.size(), num_sent_buffers); 188 for (size_t i = 0; i < num_sent_buffers; ++i) { 189 EXPECT_EQ(buffers[i].data, sent_buffers[i].data); 190 EXPECT_EQ(buffers[i].length, sent_buffers[i].length); 191 } 192 return mock_return_value; 193 }; 194 195 mock_return_value = strlen("foobar12345"); 196 EXPECT_TRUE(sock->Send(buffers)); 197 198 mock_return_value -= 1; 199 EXPECT_FALSE(sock->Send(buffers)); 200 201 mock_return_value = 0; 202 EXPECT_FALSE(sock->Send(buffers)); 203 204 mock_return_value = -1; 205 EXPECT_FALSE(sock->Send(buffers)); 206 } 207 208 // Tests TCP re-sending until socket_send_buffers() sends all data. This is a little complicated, 209 // but the general idea is that we intercept calls to socket_send_buffers() using a lambda mock 210 // function that simulates partial writes. 211 TEST(SocketTest, TestTcpSendBuffers) { 212 std::unique_ptr<Socket> sock = Socket::NewServer(Socket::Protocol::kTcp, 0); 213 std::vector<std::string> data{"foo", "bar", "12345"}; 214 std::vector<cutils_socket_buffer_t> buffers{{data[0].data(), data[0].length()}, 215 {data[1].data(), data[1].length()}, 216 {data[2].data(), data[2].length()}}; 217 218 // Test breaking up the buffered send at various points. 219 std::list<std::string> test_sends[] = { 220 // Successes. 221 {"foobar12345"}, 222 {"f", "oob", "ar12345"}, 223 {"fo", "obar12", "345"}, 224 {"foo", "bar12345"}, 225 {"foob", "ar123", "45"}, 226 {"f", "o", "o", "b", "a", "r", "1", "2", "3", "4", "5"}, 227 228 // Failures. 229 {}, 230 {"f"}, 231 {"foo", "bar"}, 232 {"fo", "obar12"}, 233 {"foobar1234"} 234 }; 235 236 for (auto& test : test_sends) { 237 ssize_t bytes_sent = 0; 238 bool expect_success = true; 239 240 // Create a mock function for custom socket_send_buffers() behavior. This function will 241 // check to make sure the input buffers start at the next unsent byte, then return the 242 // number of bytes indicated by the next entry in |test|. 243 sock->socket_send_buffers_function_ = [&bytes_sent, &data, &expect_success, &test]( 244 cutils_socket_t /*cutils_sock*/, cutils_socket_buffer_t* buffers, 245 size_t num_buffers) -> ssize_t { 246 EXPECT_TRUE(num_buffers > 0); 247 248 // Failure case - pretend we errored out before sending all the buffers. 249 if (test.empty()) { 250 expect_success = false; 251 return -1; 252 } 253 254 // Count the bytes we've sent to find where the next buffer should start and how many 255 // bytes should be left in it. 256 size_t byte_count = bytes_sent, data_index = 0; 257 while (data_index < data.size()) { 258 if (byte_count >= data[data_index].length()) { 259 byte_count -= data[data_index].length(); 260 ++data_index; 261 } else { 262 break; 263 } 264 } 265 void* expected_next_byte = &data[data_index][byte_count]; 266 size_t expected_next_size = data[data_index].length() - byte_count; 267 268 EXPECT_EQ(data.size() - data_index, num_buffers); 269 EXPECT_EQ(expected_next_byte, buffers[0].data); 270 EXPECT_EQ(expected_next_size, buffers[0].length); 271 272 std::string to_send = std::move(test.front()); 273 test.pop_front(); 274 bytes_sent += to_send.length(); 275 return to_send.length(); 276 }; 277 278 EXPECT_EQ(expect_success, sock->Send(buffers)); 279 EXPECT_TRUE(test.empty()); 280 } 281 } 282 283 TEST(SocketMockTest, TestSendSuccess) { 284 SocketMock mock; 285 286 mock.ExpectSend("foo"); 287 EXPECT_TRUE(SendString(&mock, "foo")); 288 289 mock.ExpectSend("abc"); 290 mock.ExpectSend("123"); 291 EXPECT_TRUE(SendString(&mock, "abc")); 292 EXPECT_TRUE(SendString(&mock, "123")); 293 } 294 295 TEST(SocketMockTest, TestSendFailure) { 296 SocketMock* mock = new SocketMock; 297 298 mock->ExpectSendFailure("foo"); 299 EXPECT_FALSE(SendString(mock, "foo")); 300 301 EXPECT_NONFATAL_FAILURE(SendString(mock, "foo"), "no message was expected"); 302 303 mock->ExpectSend("foo"); 304 EXPECT_NONFATAL_FAILURE(SendString(mock, "bar"), "expected foo, but got bar"); 305 EXPECT_TRUE(SendString(mock, "foo")); 306 307 mock->AddReceive("foo"); 308 EXPECT_NONFATAL_FAILURE(SendString(mock, "foo"), "called out-of-order"); 309 EXPECT_TRUE(ReceiveString(mock, "foo")); 310 311 mock->ExpectSend("foo"); 312 EXPECT_NONFATAL_FAILURE(delete mock, "1 event(s) were not handled"); 313 } 314 315 TEST(SocketMockTest, TestReceiveSuccess) { 316 SocketMock mock; 317 318 mock.AddReceive("foo"); 319 EXPECT_TRUE(ReceiveString(&mock, "foo")); 320 321 mock.AddReceive("abc"); 322 mock.AddReceive("123"); 323 EXPECT_TRUE(ReceiveString(&mock, "abc")); 324 EXPECT_TRUE(ReceiveString(&mock, "123")); 325 326 // Make sure ReceiveAll() can piece together multiple receives. 327 mock.AddReceive("foo"); 328 mock.AddReceive("bar"); 329 mock.AddReceive("123"); 330 EXPECT_TRUE(ReceiveString(&mock, "foobar123")); 331 } 332 333 TEST(SocketMockTest, TestReceiveFailure) { 334 SocketMock* mock = new SocketMock; 335 336 mock->AddReceiveFailure(); 337 EXPECT_FALSE(ReceiveString(mock, "foo")); 338 EXPECT_FALSE(mock->ReceiveTimedOut()); 339 340 mock->AddReceiveTimeout(); 341 EXPECT_FALSE(ReceiveString(mock, "foo")); 342 EXPECT_TRUE(mock->ReceiveTimedOut()); 343 344 mock->AddReceive("foo"); 345 mock->AddReceiveFailure(); 346 EXPECT_FALSE(ReceiveString(mock, "foobar")); 347 348 EXPECT_NONFATAL_FAILURE(ReceiveString(mock, "foo"), "no message was ready"); 349 350 mock->ExpectSend("foo"); 351 EXPECT_NONFATAL_FAILURE(ReceiveString(mock, "foo"), "called out-of-order"); 352 EXPECT_TRUE(SendString(mock, "foo")); 353 354 char c; 355 mock->AddReceive("foo"); 356 EXPECT_NONFATAL_FAILURE(mock->Receive(&c, 1, 0), "not enough bytes (1) for foo"); 357 EXPECT_TRUE(ReceiveString(mock, "foo")); 358 359 mock->AddReceive("foo"); 360 EXPECT_NONFATAL_FAILURE(delete mock, "1 event(s) were not handled"); 361 } 362 363 TEST(SocketMockTest, TestAcceptSuccess) { 364 SocketMock mock; 365 366 SocketMock* mock_handler = new SocketMock; 367 mock.AddAccept(std::unique_ptr<SocketMock>(mock_handler)); 368 EXPECT_EQ(mock_handler, mock.Accept().get()); 369 370 mock.AddAccept(nullptr); 371 EXPECT_EQ(nullptr, mock.Accept().get()); 372 } 373 374 TEST(SocketMockTest, TestAcceptFailure) { 375 SocketMock* mock = new SocketMock; 376 377 EXPECT_NONFATAL_FAILURE(mock->Accept(), "no socket was ready"); 378 379 mock->ExpectSend("foo"); 380 EXPECT_NONFATAL_FAILURE(mock->Accept(), "called out-of-order"); 381 EXPECT_TRUE(SendString(mock, "foo")); 382 383 mock->AddAccept(nullptr); 384 EXPECT_NONFATAL_FAILURE(delete mock, "1 event(s) were not handled"); 385 } 386