Home | History | Annotate | Download | only in fastboot
      1 /*
      2  * Copyright (C) 2015 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 // Tests socket functionality using loopback connections. The UDP tests assume that no packets are
     18 // lost, which should be the case for loopback communication, but is not guaranteed.
     19 //
     20 // Also tests our SocketMock class to make sure it works as expected and reports errors properly
     21 // if the mock expectations aren't met during a test.
     22 
     23 #include "socket.h"
     24 #include "socket_mock.h"
     25 
     26 #include <list>
     27 
     28 #include <gtest/gtest-spi.h>
     29 #include <gtest/gtest.h>
     30 
     31 static constexpr int kShortTimeoutMs = 10;
     32 static constexpr int kTestTimeoutMs = 3000;
     33 
     34 // Creates connected sockets |server| and |client|. Returns true on success.
     35 bool MakeConnectedSockets(Socket::Protocol protocol, std::unique_ptr<Socket>* server,
     36                           std::unique_ptr<Socket>* client,
     37                           const std::string hostname = "localhost") {
     38     *server = Socket::NewServer(protocol, 0);
     39     if (*server == nullptr) {
     40         ADD_FAILURE() << "Failed to create server.";
     41         return false;
     42     }
     43 
     44     *client = Socket::NewClient(protocol, hostname, (*server)->GetLocalPort(), nullptr);
     45     if (*client == nullptr) {
     46         ADD_FAILURE() << "Failed to create client.";
     47         return false;
     48     }
     49 
     50     // TCP passes the client off to a new socket.
     51     if (protocol == Socket::Protocol::kTcp) {
     52         *server = (*server)->Accept();
     53         if (*server == nullptr) {
     54             ADD_FAILURE() << "Failed to accept client connection.";
     55             return false;
     56         }
     57     }
     58 
     59     return true;
     60 }
     61 
     62 // Sends a string over a Socket. Returns true if the full string (without terminating char)
     63 // was sent.
     64 static bool SendString(Socket* sock, const std::string& message) {
     65     return sock->Send(message.c_str(), message.length());
     66 }
     67 
     68 // Receives a string from a Socket. Returns true if the full string (without terminating char)
     69 // was received.
     70 static bool ReceiveString(Socket* sock, const std::string& message) {
     71     std::string received(message.length(), '\0');
     72     ssize_t bytes = sock->ReceiveAll(&received[0], received.length(), kTestTimeoutMs);
     73     return static_cast<size_t>(bytes) == received.length() && received == message;
     74 }
     75 
     76 // Tests sending packets client -> server, then server -> client.
     77 TEST(SocketTest, TestSendAndReceive) {
     78     std::unique_ptr<Socket> server, client;
     79 
     80     for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) {
     81         ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client));
     82 
     83         EXPECT_TRUE(SendString(client.get(), "foo"));
     84         EXPECT_TRUE(ReceiveString(server.get(), "foo"));
     85 
     86         EXPECT_TRUE(SendString(server.get(), "bar baz"));
     87         EXPECT_TRUE(ReceiveString(client.get(), "bar baz"));
     88     }
     89 }
     90 
     91 TEST(SocketTest, TestReceiveTimeout) {
     92     std::unique_ptr<Socket> server, client;
     93     char buffer[16];
     94 
     95     for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) {
     96         ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client));
     97 
     98         EXPECT_EQ(-1, server->Receive(buffer, sizeof(buffer), kShortTimeoutMs));
     99         EXPECT_TRUE(server->ReceiveTimedOut());
    100 
    101         EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kShortTimeoutMs));
    102         EXPECT_TRUE(client->ReceiveTimedOut());
    103     }
    104 
    105     // UDP will wait for timeout if the other side closes.
    106     ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kUdp, &server, &client));
    107     EXPECT_EQ(0, server->Close());
    108     EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kShortTimeoutMs));
    109     EXPECT_TRUE(client->ReceiveTimedOut());
    110 }
    111 
    112 TEST(SocketTest, TestReceiveFailure) {
    113     std::unique_ptr<Socket> server, client;
    114     char buffer[16];
    115 
    116     for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) {
    117         ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client));
    118 
    119         EXPECT_EQ(0, server->Close());
    120         EXPECT_EQ(-1, server->Receive(buffer, sizeof(buffer), kTestTimeoutMs));
    121         EXPECT_FALSE(server->ReceiveTimedOut());
    122 
    123         EXPECT_EQ(0, client->Close());
    124         EXPECT_EQ(-1, client->Receive(buffer, sizeof(buffer), kTestTimeoutMs));
    125         EXPECT_FALSE(client->ReceiveTimedOut());
    126     }
    127 
    128     // TCP knows right away when the other side closes and returns 0 to indicate EOF.
    129     ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kTcp, &server, &client));
    130     EXPECT_EQ(0, server->Close());
    131     EXPECT_EQ(0, client->Receive(buffer, sizeof(buffer), kTestTimeoutMs));
    132     EXPECT_FALSE(client->ReceiveTimedOut());
    133 }
    134 
    135 // Tests sending and receiving large packets.
    136 TEST(SocketTest, TestLargePackets) {
    137     std::string message(1024, '\0');
    138     std::unique_ptr<Socket> server, client;
    139 
    140     for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) {
    141         ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client));
    142 
    143         // Run through the test a few times.
    144         for (int i = 0; i < 10; ++i) {
    145             // Use a different message each iteration to prevent false positives.
    146             for (size_t j = 0; j < message.length(); ++j) {
    147                 message[j] = static_cast<char>(i + j);
    148             }
    149 
    150             EXPECT_TRUE(SendString(client.get(), message));
    151             EXPECT_TRUE(ReceiveString(server.get(), message));
    152         }
    153     }
    154 }
    155 
    156 // Tests UDP receive overflow when the UDP packet is larger than the receive buffer.
    157 TEST(SocketTest, TestUdpReceiveOverflow) {
    158     std::unique_ptr<Socket> server, client;
    159     ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kUdp, &server, &client));
    160 
    161     EXPECT_TRUE(SendString(client.get(), "1234567890"));
    162 
    163     // This behaves differently on different systems, either truncating the packet or returning -1.
    164     char buffer[5];
    165     ssize_t bytes = server->Receive(buffer, 5, kTestTimeoutMs);
    166     if (bytes == 5) {
    167         EXPECT_EQ(0, memcmp(buffer, "12345", 5));
    168     } else {
    169         EXPECT_EQ(-1, bytes);
    170     }
    171 }
    172 
    173 // Tests UDP multi-buffer send.
    174 TEST(SocketTest, TestUdpSendBuffers) {
    175     std::unique_ptr<Socket> sock = Socket::NewServer(Socket::Protocol::kUdp, 0);
    176     std::vector<std::string> data{"foo", "bar", "12345"};
    177     std::vector<cutils_socket_buffer_t> buffers{{data[0].data(), data[0].length()},
    178                                                 {data[1].data(), data[1].length()},
    179                                                 {data[2].data(), data[2].length()}};
    180     ssize_t mock_return_value = 0;
    181 
    182     // Mock out socket_send_buffers() to verify we're sending in the correct buffers and
    183     // return |mock_return_value|.
    184     sock->socket_send_buffers_function_ = [&buffers, &mock_return_value](
    185             cutils_socket_t /*cutils_sock*/, cutils_socket_buffer_t* sent_buffers,
    186             size_t num_sent_buffers) -> ssize_t {
    187         EXPECT_EQ(buffers.size(), num_sent_buffers);
    188         for (size_t i = 0; i < num_sent_buffers; ++i) {
    189             EXPECT_EQ(buffers[i].data, sent_buffers[i].data);
    190             EXPECT_EQ(buffers[i].length, sent_buffers[i].length);
    191         }
    192         return mock_return_value;
    193     };
    194 
    195     mock_return_value = strlen("foobar12345");
    196     EXPECT_TRUE(sock->Send(buffers));
    197 
    198     mock_return_value -= 1;
    199     EXPECT_FALSE(sock->Send(buffers));
    200 
    201     mock_return_value = 0;
    202     EXPECT_FALSE(sock->Send(buffers));
    203 
    204     mock_return_value = -1;
    205     EXPECT_FALSE(sock->Send(buffers));
    206 }
    207 
    208 // Tests TCP re-sending until socket_send_buffers() sends all data. This is a little complicated,
    209 // but the general idea is that we intercept calls to socket_send_buffers() using a lambda mock
    210 // function that simulates partial writes.
    211 TEST(SocketTest, TestTcpSendBuffers) {
    212     std::unique_ptr<Socket> sock = Socket::NewServer(Socket::Protocol::kTcp, 0);
    213     std::vector<std::string> data{"foo", "bar", "12345"};
    214     std::vector<cutils_socket_buffer_t> buffers{{data[0].data(), data[0].length()},
    215                                                 {data[1].data(), data[1].length()},
    216                                                 {data[2].data(), data[2].length()}};
    217 
    218     // Test breaking up the buffered send at various points.
    219     std::list<std::string> test_sends[] = {
    220             // Successes.
    221             {"foobar12345"},
    222             {"f", "oob", "ar12345"},
    223             {"fo", "obar12", "345"},
    224             {"foo", "bar12345"},
    225             {"foob", "ar123", "45"},
    226             {"f", "o", "o", "b", "a", "r", "1", "2", "3", "4", "5"},
    227 
    228             // Failures.
    229             {},
    230             {"f"},
    231             {"foo", "bar"},
    232             {"fo", "obar12"},
    233             {"foobar1234"}
    234     };
    235 
    236     for (auto& test : test_sends) {
    237         ssize_t bytes_sent = 0;
    238         bool expect_success = true;
    239 
    240         // Create a mock function for custom socket_send_buffers() behavior. This function will
    241         // check to make sure the input buffers start at the next unsent byte, then return the
    242         // number of bytes indicated by the next entry in |test|.
    243         sock->socket_send_buffers_function_ = [&bytes_sent, &data, &expect_success, &test](
    244                 cutils_socket_t /*cutils_sock*/, cutils_socket_buffer_t* buffers,
    245                 size_t num_buffers) -> ssize_t {
    246             EXPECT_TRUE(num_buffers > 0);
    247 
    248             // Failure case - pretend we errored out before sending all the buffers.
    249             if (test.empty()) {
    250                 expect_success = false;
    251                 return -1;
    252             }
    253 
    254             // Count the bytes we've sent to find where the next buffer should start and how many
    255             // bytes should be left in it.
    256             size_t byte_count = bytes_sent, data_index = 0;
    257             while (data_index < data.size()) {
    258                 if (byte_count >= data[data_index].length()) {
    259                     byte_count -= data[data_index].length();
    260                     ++data_index;
    261                 } else {
    262                     break;
    263                 }
    264             }
    265             void* expected_next_byte = &data[data_index][byte_count];
    266             size_t expected_next_size = data[data_index].length() - byte_count;
    267 
    268             EXPECT_EQ(data.size() - data_index, num_buffers);
    269             EXPECT_EQ(expected_next_byte, buffers[0].data);
    270             EXPECT_EQ(expected_next_size, buffers[0].length);
    271 
    272             std::string to_send = std::move(test.front());
    273             test.pop_front();
    274             bytes_sent += to_send.length();
    275             return to_send.length();
    276         };
    277 
    278         EXPECT_EQ(expect_success, sock->Send(buffers));
    279         EXPECT_TRUE(test.empty());
    280     }
    281 }
    282 
    283 TEST(SocketMockTest, TestSendSuccess) {
    284     SocketMock mock;
    285 
    286     mock.ExpectSend("foo");
    287     EXPECT_TRUE(SendString(&mock, "foo"));
    288 
    289     mock.ExpectSend("abc");
    290     mock.ExpectSend("123");
    291     EXPECT_TRUE(SendString(&mock, "abc"));
    292     EXPECT_TRUE(SendString(&mock, "123"));
    293 }
    294 
    295 TEST(SocketMockTest, TestSendFailure) {
    296     SocketMock* mock = new SocketMock;
    297 
    298     mock->ExpectSendFailure("foo");
    299     EXPECT_FALSE(SendString(mock, "foo"));
    300 
    301     EXPECT_NONFATAL_FAILURE(SendString(mock, "foo"), "no message was expected");
    302 
    303     mock->ExpectSend("foo");
    304     EXPECT_NONFATAL_FAILURE(SendString(mock, "bar"), "expected foo, but got bar");
    305     EXPECT_TRUE(SendString(mock, "foo"));
    306 
    307     mock->AddReceive("foo");
    308     EXPECT_NONFATAL_FAILURE(SendString(mock, "foo"), "called out-of-order");
    309     EXPECT_TRUE(ReceiveString(mock, "foo"));
    310 
    311     mock->ExpectSend("foo");
    312     EXPECT_NONFATAL_FAILURE(delete mock, "1 event(s) were not handled");
    313 }
    314 
    315 TEST(SocketMockTest, TestReceiveSuccess) {
    316     SocketMock mock;
    317 
    318     mock.AddReceive("foo");
    319     EXPECT_TRUE(ReceiveString(&mock, "foo"));
    320 
    321     mock.AddReceive("abc");
    322     mock.AddReceive("123");
    323     EXPECT_TRUE(ReceiveString(&mock, "abc"));
    324     EXPECT_TRUE(ReceiveString(&mock, "123"));
    325 
    326     // Make sure ReceiveAll() can piece together multiple receives.
    327     mock.AddReceive("foo");
    328     mock.AddReceive("bar");
    329     mock.AddReceive("123");
    330     EXPECT_TRUE(ReceiveString(&mock, "foobar123"));
    331 }
    332 
    333 TEST(SocketMockTest, TestReceiveFailure) {
    334     SocketMock* mock = new SocketMock;
    335 
    336     mock->AddReceiveFailure();
    337     EXPECT_FALSE(ReceiveString(mock, "foo"));
    338     EXPECT_FALSE(mock->ReceiveTimedOut());
    339 
    340     mock->AddReceiveTimeout();
    341     EXPECT_FALSE(ReceiveString(mock, "foo"));
    342     EXPECT_TRUE(mock->ReceiveTimedOut());
    343 
    344     mock->AddReceive("foo");
    345     mock->AddReceiveFailure();
    346     EXPECT_FALSE(ReceiveString(mock, "foobar"));
    347 
    348     EXPECT_NONFATAL_FAILURE(ReceiveString(mock, "foo"), "no message was ready");
    349 
    350     mock->ExpectSend("foo");
    351     EXPECT_NONFATAL_FAILURE(ReceiveString(mock, "foo"), "called out-of-order");
    352     EXPECT_TRUE(SendString(mock, "foo"));
    353 
    354     char c;
    355     mock->AddReceive("foo");
    356     EXPECT_NONFATAL_FAILURE(mock->Receive(&c, 1, 0), "not enough bytes (1) for foo");
    357     EXPECT_TRUE(ReceiveString(mock, "foo"));
    358 
    359     mock->AddReceive("foo");
    360     EXPECT_NONFATAL_FAILURE(delete mock, "1 event(s) were not handled");
    361 }
    362 
    363 TEST(SocketMockTest, TestAcceptSuccess) {
    364     SocketMock mock;
    365 
    366     SocketMock* mock_handler = new SocketMock;
    367     mock.AddAccept(std::unique_ptr<SocketMock>(mock_handler));
    368     EXPECT_EQ(mock_handler, mock.Accept().get());
    369 
    370     mock.AddAccept(nullptr);
    371     EXPECT_EQ(nullptr, mock.Accept().get());
    372 }
    373 
    374 TEST(SocketMockTest, TestAcceptFailure) {
    375     SocketMock* mock = new SocketMock;
    376 
    377     EXPECT_NONFATAL_FAILURE(mock->Accept(), "no socket was ready");
    378 
    379     mock->ExpectSend("foo");
    380     EXPECT_NONFATAL_FAILURE(mock->Accept(), "called out-of-order");
    381     EXPECT_TRUE(SendString(mock, "foo"));
    382 
    383     mock->AddAccept(nullptr);
    384     EXPECT_NONFATAL_FAILURE(delete mock, "1 event(s) were not handled");
    385 }
    386