Home | History | Annotate | Download | only in socket
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/unix_domain_client_socket_posix.h"
      6 
      7 #include <unistd.h>
      8 
      9 #include "base/bind.h"
     10 #include "base/files/file_path.h"
     11 #include "base/files/scoped_temp_dir.h"
     12 #include "base/memory/scoped_ptr.h"
     13 #include "base/posix/eintr_wrapper.h"
     14 #include "net/base/io_buffer.h"
     15 #include "net/base/net_errors.h"
     16 #include "net/base/test_completion_callback.h"
     17 #include "net/socket/socket_libevent.h"
     18 #include "net/socket/unix_domain_server_socket_posix.h"
     19 #include "testing/gtest/include/gtest/gtest.h"
     20 
     21 namespace net {
     22 namespace {
     23 
     24 const char kSocketFilename[] = "socket_for_testing";
     25 
     26 bool UserCanConnectCallback(
     27     bool allow_user, const UnixDomainServerSocket::Credentials& credentials) {
     28   // Here peers are running in same process.
     29 #if defined(OS_LINUX) || defined(OS_ANDROID)
     30   EXPECT_EQ(getpid(), credentials.process_id);
     31 #endif
     32   EXPECT_EQ(getuid(), credentials.user_id);
     33   EXPECT_EQ(getgid(), credentials.group_id);
     34   return allow_user;
     35 }
     36 
     37 UnixDomainServerSocket::AuthCallback CreateAuthCallback(bool allow_user) {
     38   return base::Bind(&UserCanConnectCallback, allow_user);
     39 }
     40 
     41 // Connects socket synchronously.
     42 int ConnectSynchronously(StreamSocket* socket) {
     43   TestCompletionCallback connect_callback;
     44   int rv = socket->Connect(connect_callback.callback());
     45   if (rv == ERR_IO_PENDING)
     46     rv = connect_callback.WaitForResult();
     47   return rv;
     48 }
     49 
     50 // Reads data from |socket| until it fills |buf| at least up to |min_data_len|.
     51 // Returns length of data read, or a net error.
     52 int ReadSynchronously(StreamSocket* socket,
     53                       IOBuffer* buf,
     54                       int buf_len,
     55                       int min_data_len) {
     56   DCHECK_LE(min_data_len, buf_len);
     57   scoped_refptr<DrainableIOBuffer> read_buf(
     58       new DrainableIOBuffer(buf, buf_len));
     59   TestCompletionCallback read_callback;
     60   // Iterate reading several times (but not infinite) until it reads at least
     61   // |min_data_len| bytes into |buf|.
     62   for (int retry_count = 10;
     63        retry_count > 0 && (read_buf->BytesConsumed() < min_data_len ||
     64                            // Try at least once when min_data_len == 0.
     65                            min_data_len == 0);
     66        --retry_count) {
     67     int rv = socket->Read(
     68         read_buf.get(), read_buf->BytesRemaining(), read_callback.callback());
     69     EXPECT_GE(read_buf->BytesRemaining(), rv);
     70     if (rv == ERR_IO_PENDING) {
     71       // If |min_data_len| is 0, returns ERR_IO_PENDING to distinguish the case
     72       // when some data has been read.
     73       if (min_data_len == 0) {
     74         // No data has been read because of for-loop condition.
     75         DCHECK_EQ(0, read_buf->BytesConsumed());
     76         return ERR_IO_PENDING;
     77       }
     78       rv = read_callback.WaitForResult();
     79     }
     80     EXPECT_NE(ERR_IO_PENDING, rv);
     81     if (rv < 0)
     82       return rv;
     83     read_buf->DidConsume(rv);
     84   }
     85   EXPECT_LE(0, read_buf->BytesRemaining());
     86   return read_buf->BytesConsumed();
     87 }
     88 
     89 // Writes data to |socket| until it completes writing |buf| up to |buf_len|.
     90 // Returns length of data written, or a net error.
     91 int WriteSynchronously(StreamSocket* socket,
     92                        IOBuffer* buf,
     93                        int buf_len) {
     94   scoped_refptr<DrainableIOBuffer> write_buf(
     95       new DrainableIOBuffer(buf, buf_len));
     96   TestCompletionCallback write_callback;
     97   // Iterate writing several times (but not infinite) until it writes buf fully.
     98   for (int retry_count = 10;
     99        retry_count > 0 && write_buf->BytesRemaining() > 0;
    100        --retry_count) {
    101     int rv = socket->Write(write_buf.get(),
    102                            write_buf->BytesRemaining(),
    103                            write_callback.callback());
    104     EXPECT_GE(write_buf->BytesRemaining(), rv);
    105     if (rv == ERR_IO_PENDING)
    106       rv = write_callback.WaitForResult();
    107     EXPECT_NE(ERR_IO_PENDING, rv);
    108     if (rv < 0)
    109       return rv;
    110     write_buf->DidConsume(rv);
    111   }
    112   EXPECT_LE(0, write_buf->BytesRemaining());
    113   return write_buf->BytesConsumed();
    114 }
    115 
    116 class UnixDomainClientSocketTest : public testing::Test {
    117  protected:
    118   UnixDomainClientSocketTest() {
    119     EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
    120     socket_path_ = temp_dir_.path().Append(kSocketFilename).value();
    121   }
    122 
    123   base::ScopedTempDir temp_dir_;
    124   std::string socket_path_;
    125 };
    126 
    127 TEST_F(UnixDomainClientSocketTest, Connect) {
    128   const bool kUseAbstractNamespace = false;
    129 
    130   UnixDomainServerSocket server_socket(CreateAuthCallback(true),
    131                                        kUseAbstractNamespace);
    132   EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
    133 
    134   scoped_ptr<StreamSocket> accepted_socket;
    135   TestCompletionCallback accept_callback;
    136   EXPECT_EQ(ERR_IO_PENDING,
    137             server_socket.Accept(&accepted_socket, accept_callback.callback()));
    138   EXPECT_FALSE(accepted_socket);
    139 
    140   UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
    141   EXPECT_FALSE(client_socket.IsConnected());
    142 
    143   EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
    144   EXPECT_TRUE(client_socket.IsConnected());
    145   // Server has not yet been notified of the connection.
    146   EXPECT_FALSE(accepted_socket);
    147 
    148   EXPECT_EQ(OK, accept_callback.WaitForResult());
    149   EXPECT_TRUE(accepted_socket);
    150   EXPECT_TRUE(accepted_socket->IsConnected());
    151 }
    152 
    153 TEST_F(UnixDomainClientSocketTest, ConnectWithSocketDescriptor) {
    154   const bool kUseAbstractNamespace = false;
    155 
    156   UnixDomainServerSocket server_socket(CreateAuthCallback(true),
    157                                        kUseAbstractNamespace);
    158   EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
    159 
    160   SocketDescriptor accepted_socket_fd = kInvalidSocket;
    161   TestCompletionCallback accept_callback;
    162   EXPECT_EQ(ERR_IO_PENDING,
    163             server_socket.AcceptSocketDescriptor(&accepted_socket_fd,
    164                                                  accept_callback.callback()));
    165   EXPECT_EQ(kInvalidSocket, accepted_socket_fd);
    166 
    167   UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
    168   EXPECT_FALSE(client_socket.IsConnected());
    169 
    170   EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
    171   EXPECT_TRUE(client_socket.IsConnected());
    172   // Server has not yet been notified of the connection.
    173   EXPECT_EQ(kInvalidSocket, accepted_socket_fd);
    174 
    175   EXPECT_EQ(OK, accept_callback.WaitForResult());
    176   EXPECT_NE(kInvalidSocket, accepted_socket_fd);
    177 
    178   SocketDescriptor client_socket_fd = client_socket.ReleaseConnectedSocket();
    179   EXPECT_NE(kInvalidSocket, client_socket_fd);
    180 
    181   // Now, re-wrap client_socket_fd in a UnixDomainClientSocket and try a read
    182   // to be sure it hasn't gotten accidentally closed.
    183   SockaddrStorage addr;
    184   ASSERT_TRUE(UnixDomainClientSocket::FillAddress(socket_path_, false, &addr));
    185   scoped_ptr<SocketLibevent> adopter(new SocketLibevent);
    186   adopter->AdoptConnectedSocket(client_socket_fd, addr);
    187   UnixDomainClientSocket rewrapped_socket(adopter.Pass());
    188   EXPECT_TRUE(rewrapped_socket.IsConnected());
    189 
    190   // Try to read data.
    191   const int kReadDataSize = 10;
    192   scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadDataSize));
    193   TestCompletionCallback read_callback;
    194   EXPECT_EQ(ERR_IO_PENDING,
    195             rewrapped_socket.Read(
    196                 read_buffer.get(), kReadDataSize, read_callback.callback()));
    197 
    198   EXPECT_EQ(0, IGNORE_EINTR(close(accepted_socket_fd)));
    199 }
    200 
    201 TEST_F(UnixDomainClientSocketTest, ConnectWithAbstractNamespace) {
    202   const bool kUseAbstractNamespace = true;
    203 
    204   UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
    205   EXPECT_FALSE(client_socket.IsConnected());
    206 
    207 #if defined(OS_ANDROID) || defined(OS_LINUX)
    208   UnixDomainServerSocket server_socket(CreateAuthCallback(true),
    209                                        kUseAbstractNamespace);
    210   EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
    211 
    212   scoped_ptr<StreamSocket> accepted_socket;
    213   TestCompletionCallback accept_callback;
    214   EXPECT_EQ(ERR_IO_PENDING,
    215             server_socket.Accept(&accepted_socket, accept_callback.callback()));
    216   EXPECT_FALSE(accepted_socket);
    217 
    218   EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
    219   EXPECT_TRUE(client_socket.IsConnected());
    220   // Server has not yet beend notified of the connection.
    221   EXPECT_FALSE(accepted_socket);
    222 
    223   EXPECT_EQ(OK, accept_callback.WaitForResult());
    224   EXPECT_TRUE(accepted_socket);
    225   EXPECT_TRUE(accepted_socket->IsConnected());
    226 #else
    227   EXPECT_EQ(ERR_ADDRESS_INVALID, ConnectSynchronously(&client_socket));
    228 #endif
    229 }
    230 
    231 TEST_F(UnixDomainClientSocketTest, ConnectToNonExistentSocket) {
    232   const bool kUseAbstractNamespace = false;
    233 
    234   UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
    235   EXPECT_FALSE(client_socket.IsConnected());
    236   EXPECT_EQ(ERR_FILE_NOT_FOUND, ConnectSynchronously(&client_socket));
    237 }
    238 
    239 TEST_F(UnixDomainClientSocketTest,
    240        ConnectToNonExistentSocketWithAbstractNamespace) {
    241   const bool kUseAbstractNamespace = true;
    242 
    243   UnixDomainClientSocket client_socket(socket_path_, kUseAbstractNamespace);
    244   EXPECT_FALSE(client_socket.IsConnected());
    245 
    246   TestCompletionCallback connect_callback;
    247 #if defined(OS_ANDROID) || defined(OS_LINUX)
    248   EXPECT_EQ(ERR_CONNECTION_REFUSED, ConnectSynchronously(&client_socket));
    249 #else
    250   EXPECT_EQ(ERR_ADDRESS_INVALID, ConnectSynchronously(&client_socket));
    251 #endif
    252 }
    253 
    254 TEST_F(UnixDomainClientSocketTest, DisconnectFromClient) {
    255   UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
    256   EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
    257   scoped_ptr<StreamSocket> accepted_socket;
    258   TestCompletionCallback accept_callback;
    259   EXPECT_EQ(ERR_IO_PENDING,
    260             server_socket.Accept(&accepted_socket, accept_callback.callback()));
    261   UnixDomainClientSocket client_socket(socket_path_, false);
    262   EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
    263 
    264   EXPECT_EQ(OK, accept_callback.WaitForResult());
    265   EXPECT_TRUE(accepted_socket->IsConnected());
    266   EXPECT_TRUE(client_socket.IsConnected());
    267 
    268   // Try to read data.
    269   const int kReadDataSize = 10;
    270   scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadDataSize));
    271   TestCompletionCallback read_callback;
    272   EXPECT_EQ(ERR_IO_PENDING,
    273             accepted_socket->Read(
    274                 read_buffer.get(), kReadDataSize, read_callback.callback()));
    275 
    276   // Disconnect from client side.
    277   client_socket.Disconnect();
    278   EXPECT_FALSE(client_socket.IsConnected());
    279   EXPECT_FALSE(accepted_socket->IsConnected());
    280 
    281   // Connection closed by peer.
    282   EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
    283   // Note that read callback won't be called when the connection is closed
    284   // locally before the peer closes it. SocketLibevent just clears callbacks.
    285 }
    286 
    287 TEST_F(UnixDomainClientSocketTest, DisconnectFromServer) {
    288   UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
    289   EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
    290   scoped_ptr<StreamSocket> accepted_socket;
    291   TestCompletionCallback accept_callback;
    292   EXPECT_EQ(ERR_IO_PENDING,
    293             server_socket.Accept(&accepted_socket, accept_callback.callback()));
    294   UnixDomainClientSocket client_socket(socket_path_, false);
    295   EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
    296 
    297   EXPECT_EQ(OK, accept_callback.WaitForResult());
    298   EXPECT_TRUE(accepted_socket->IsConnected());
    299   EXPECT_TRUE(client_socket.IsConnected());
    300 
    301   // Try to read data.
    302   const int kReadDataSize = 10;
    303   scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadDataSize));
    304   TestCompletionCallback read_callback;
    305   EXPECT_EQ(ERR_IO_PENDING,
    306             client_socket.Read(
    307                 read_buffer.get(), kReadDataSize, read_callback.callback()));
    308 
    309   // Disconnect from server side.
    310   accepted_socket->Disconnect();
    311   EXPECT_FALSE(accepted_socket->IsConnected());
    312   EXPECT_FALSE(client_socket.IsConnected());
    313 
    314   // Connection closed by peer.
    315   EXPECT_EQ(0 /* EOF */, read_callback.WaitForResult());
    316   // Note that read callback won't be called when the connection is closed
    317   // locally before the peer closes it. SocketLibevent just clears callbacks.
    318 }
    319 
    320 TEST_F(UnixDomainClientSocketTest, ReadAfterWrite) {
    321   UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
    322   EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
    323   scoped_ptr<StreamSocket> accepted_socket;
    324   TestCompletionCallback accept_callback;
    325   EXPECT_EQ(ERR_IO_PENDING,
    326             server_socket.Accept(&accepted_socket, accept_callback.callback()));
    327   UnixDomainClientSocket client_socket(socket_path_, false);
    328   EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
    329 
    330   EXPECT_EQ(OK, accept_callback.WaitForResult());
    331   EXPECT_TRUE(accepted_socket->IsConnected());
    332   EXPECT_TRUE(client_socket.IsConnected());
    333 
    334   // Send data from client to server.
    335   const int kWriteDataSize = 10;
    336   scoped_refptr<IOBuffer> write_buffer(
    337       new StringIOBuffer(std::string(kWriteDataSize, 'd')));
    338   EXPECT_EQ(
    339       kWriteDataSize,
    340       WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));
    341 
    342   // The buffer is bigger than write data size.
    343   const int kReadBufferSize = kWriteDataSize * 2;
    344   scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadBufferSize));
    345   EXPECT_EQ(kWriteDataSize,
    346             ReadSynchronously(accepted_socket.get(),
    347                               read_buffer.get(),
    348                               kReadBufferSize,
    349                               kWriteDataSize));
    350   EXPECT_EQ(std::string(write_buffer->data(), kWriteDataSize),
    351             std::string(read_buffer->data(), kWriteDataSize));
    352 
    353   // Send data from server and client.
    354   EXPECT_EQ(kWriteDataSize,
    355             WriteSynchronously(
    356                 accepted_socket.get(), write_buffer.get(), kWriteDataSize));
    357 
    358   // Read multiple times.
    359   const int kSmallReadBufferSize = kWriteDataSize / 3;
    360   EXPECT_EQ(kSmallReadBufferSize,
    361             ReadSynchronously(&client_socket,
    362                               read_buffer.get(),
    363                               kSmallReadBufferSize,
    364                               kSmallReadBufferSize));
    365   EXPECT_EQ(std::string(write_buffer->data(), kSmallReadBufferSize),
    366             std::string(read_buffer->data(), kSmallReadBufferSize));
    367 
    368   EXPECT_EQ(kWriteDataSize - kSmallReadBufferSize,
    369             ReadSynchronously(&client_socket,
    370                               read_buffer.get(),
    371                               kReadBufferSize,
    372                               kWriteDataSize - kSmallReadBufferSize));
    373   EXPECT_EQ(std::string(write_buffer->data() + kSmallReadBufferSize,
    374                         kWriteDataSize - kSmallReadBufferSize),
    375             std::string(read_buffer->data(),
    376                         kWriteDataSize - kSmallReadBufferSize));
    377 
    378   // No more data.
    379   EXPECT_EQ(
    380       ERR_IO_PENDING,
    381       ReadSynchronously(&client_socket, read_buffer.get(), kReadBufferSize, 0));
    382 
    383   // Disconnect from server side after read-write.
    384   accepted_socket->Disconnect();
    385   EXPECT_FALSE(accepted_socket->IsConnected());
    386   EXPECT_FALSE(client_socket.IsConnected());
    387 }
    388 
    389 TEST_F(UnixDomainClientSocketTest, ReadBeforeWrite) {
    390   UnixDomainServerSocket server_socket(CreateAuthCallback(true), false);
    391   EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1));
    392   scoped_ptr<StreamSocket> accepted_socket;
    393   TestCompletionCallback accept_callback;
    394   EXPECT_EQ(ERR_IO_PENDING,
    395             server_socket.Accept(&accepted_socket, accept_callback.callback()));
    396   UnixDomainClientSocket client_socket(socket_path_, false);
    397   EXPECT_EQ(OK, ConnectSynchronously(&client_socket));
    398 
    399   EXPECT_EQ(OK, accept_callback.WaitForResult());
    400   EXPECT_TRUE(accepted_socket->IsConnected());
    401   EXPECT_TRUE(client_socket.IsConnected());
    402 
    403   // Wait for data from client.
    404   const int kWriteDataSize = 10;
    405   const int kReadBufferSize = kWriteDataSize * 2;
    406   const int kSmallReadBufferSize = kWriteDataSize / 3;
    407   // Read smaller than write data size first.
    408   scoped_refptr<IOBuffer> read_buffer(new IOBuffer(kReadBufferSize));
    409   TestCompletionCallback read_callback;
    410   EXPECT_EQ(
    411       ERR_IO_PENDING,
    412       accepted_socket->Read(
    413           read_buffer.get(), kSmallReadBufferSize, read_callback.callback()));
    414 
    415   scoped_refptr<IOBuffer> write_buffer(
    416       new StringIOBuffer(std::string(kWriteDataSize, 'd')));
    417   EXPECT_EQ(
    418       kWriteDataSize,
    419       WriteSynchronously(&client_socket, write_buffer.get(), kWriteDataSize));
    420 
    421   // First read completed.
    422   int rv = read_callback.WaitForResult();
    423   EXPECT_LT(0, rv);
    424   EXPECT_LE(rv, kSmallReadBufferSize);
    425 
    426   // Read remaining data.
    427   const int kExpectedRemainingDataSize = kWriteDataSize - rv;
    428   EXPECT_LE(0, kExpectedRemainingDataSize);
    429   EXPECT_EQ(kExpectedRemainingDataSize,
    430             ReadSynchronously(accepted_socket.get(),
    431                               read_buffer.get(),
    432                               kReadBufferSize,
    433                               kExpectedRemainingDataSize));
    434   // No more data.
    435   EXPECT_EQ(ERR_IO_PENDING,
    436             ReadSynchronously(
    437                 accepted_socket.get(), read_buffer.get(), kReadBufferSize, 0));
    438 
    439   // Disconnect from server side after read-write.
    440   accepted_socket->Disconnect();
    441   EXPECT_FALSE(accepted_socket->IsConnected());
    442   EXPECT_FALSE(client_socket.IsConnected());
    443 }
    444 
    445 }  // namespace
    446 }  // namespace net
    447