Home | History | Annotate | Download | only in ipc
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <sys/socket.h>
      6 
      7 #include "base/bind.h"
      8 #include "base/files/file_path.h"
      9 #include "base/path_service.h"
     10 #include "base/posix/eintr_wrapper.h"
     11 #include "base/synchronization/waitable_event.h"
     12 #include "base/threading/thread.h"
     13 #include "base/threading/thread_restrictions.h"
     14 #include "ipc/unix_domain_socket_util.h"
     15 #include "testing/gtest/include/gtest/gtest.h"
     16 
     17 namespace {
     18 
     19 class SocketAcceptor : public base::MessageLoopForIO::Watcher {
     20  public:
     21   SocketAcceptor(int fd, base::MessageLoopProxy* target_thread)
     22       : server_fd_(-1),
     23         target_thread_(target_thread),
     24         started_watching_event_(false, false),
     25         accepted_event_(false, false) {
     26     target_thread->PostTask(FROM_HERE,
     27         base::Bind(&SocketAcceptor::StartWatching, base::Unretained(this), fd));
     28   }
     29 
     30   virtual ~SocketAcceptor() {
     31     Close();
     32   }
     33 
     34   int server_fd() const { return server_fd_; }
     35 
     36   void WaitUntilReady() {
     37     started_watching_event_.Wait();
     38   }
     39 
     40   void WaitForAccept() {
     41     accepted_event_.Wait();
     42   }
     43 
     44   void Close() {
     45     if (watcher_.get()) {
     46       target_thread_->PostTask(FROM_HERE,
     47           base::Bind(&SocketAcceptor::StopWatching, base::Unretained(this),
     48               watcher_.release()));
     49     }
     50   }
     51 
     52  private:
     53   void StartWatching(int fd) {
     54     watcher_.reset(new base::MessageLoopForIO::FileDescriptorWatcher);
     55     base::MessageLoopForIO::current()->WatchFileDescriptor(
     56         fd, true, base::MessageLoopForIO::WATCH_READ, watcher_.get(), this);
     57     started_watching_event_.Signal();
     58   }
     59   void StopWatching(base::MessageLoopForIO::FileDescriptorWatcher* watcher) {
     60     watcher->StopWatchingFileDescriptor();
     61     delete watcher;
     62   }
     63   virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE {
     64     ASSERT_EQ(-1, server_fd_);
     65     IPC::ServerAcceptConnection(fd, &server_fd_);
     66     watcher_->StopWatchingFileDescriptor();
     67     accepted_event_.Signal();
     68   }
     69   virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE {}
     70 
     71   int server_fd_;
     72   base::MessageLoopProxy* target_thread_;
     73   scoped_ptr<base::MessageLoopForIO::FileDescriptorWatcher> watcher_;
     74   base::WaitableEvent started_watching_event_;
     75   base::WaitableEvent accepted_event_;
     76 
     77   DISALLOW_COPY_AND_ASSIGN(SocketAcceptor);
     78 };
     79 
     80 const base::FilePath GetChannelDir() {
     81 #if defined(OS_ANDROID)
     82   base::FilePath tmp_dir;
     83   PathService::Get(base::DIR_CACHE, &tmp_dir);
     84   return tmp_dir;
     85 #else
     86   return base::FilePath("/var/tmp");
     87 #endif
     88 }
     89 
     90 class TestUnixSocketConnection {
     91  public:
     92   TestUnixSocketConnection()
     93       : worker_("WorkerThread"),
     94         server_listen_fd_(-1),
     95         server_fd_(-1),
     96         client_fd_(-1) {
     97     socket_name_ = GetChannelDir().Append("TestSocket");
     98     base::Thread::Options options;
     99     options.message_loop_type = base::MessageLoop::TYPE_IO;
    100     worker_.StartWithOptions(options);
    101   }
    102 
    103   bool CreateServerSocket() {
    104     IPC::CreateServerUnixDomainSocket(socket_name_, &server_listen_fd_);
    105     if (server_listen_fd_ < 0)
    106       return false;
    107     struct stat socket_stat;
    108     stat(socket_name_.value().c_str(), &socket_stat);
    109     EXPECT_TRUE(S_ISSOCK(socket_stat.st_mode));
    110     acceptor_.reset(new SocketAcceptor(server_listen_fd_,
    111                                        worker_.message_loop_proxy().get()));
    112     acceptor_->WaitUntilReady();
    113     return true;
    114   }
    115 
    116   bool CreateClientSocket() {
    117     DCHECK(server_listen_fd_ >= 0);
    118     IPC::CreateClientUnixDomainSocket(socket_name_, &client_fd_);
    119     if (client_fd_ < 0)
    120       return false;
    121     acceptor_->WaitForAccept();
    122     server_fd_ = acceptor_->server_fd();
    123     return server_fd_ >= 0;
    124   }
    125 
    126   virtual ~TestUnixSocketConnection() {
    127     if (client_fd_ >= 0)
    128       close(client_fd_);
    129     if (server_fd_ >= 0)
    130       close(server_fd_);
    131     if (server_listen_fd_ >= 0) {
    132       close(server_listen_fd_);
    133       unlink(socket_name_.value().c_str());
    134     }
    135   }
    136 
    137   int client_fd() const { return client_fd_; }
    138   int server_fd() const { return server_fd_; }
    139 
    140  private:
    141   base::Thread worker_;
    142   base::FilePath socket_name_;
    143   int server_listen_fd_;
    144   int server_fd_;
    145   int client_fd_;
    146   scoped_ptr<SocketAcceptor> acceptor_;
    147 };
    148 
    149 // Ensure that IPC::CreateServerUnixDomainSocket creates a socket that
    150 // IPC::CreateClientUnixDomainSocket can successfully connect to.
    151 TEST(UnixDomainSocketUtil, Connect) {
    152   TestUnixSocketConnection connection;
    153   ASSERT_TRUE(connection.CreateServerSocket());
    154   ASSERT_TRUE(connection.CreateClientSocket());
    155 }
    156 
    157 // Ensure that messages can be sent across the resulting socket.
    158 TEST(UnixDomainSocketUtil, SendReceive) {
    159   TestUnixSocketConnection connection;
    160   ASSERT_TRUE(connection.CreateServerSocket());
    161   ASSERT_TRUE(connection.CreateClientSocket());
    162 
    163   const char buffer[] = "Hello, server!";
    164   size_t buf_len = sizeof(buffer);
    165   size_t sent_bytes =
    166       HANDLE_EINTR(send(connection.client_fd(), buffer, buf_len, 0));
    167   ASSERT_EQ(buf_len, sent_bytes);
    168   char recv_buf[sizeof(buffer)];
    169   size_t received_bytes =
    170       HANDLE_EINTR(recv(connection.server_fd(), recv_buf, buf_len, 0));
    171   ASSERT_EQ(buf_len, received_bytes);
    172   ASSERT_EQ(0, memcmp(recv_buf, buffer, buf_len));
    173 }
    174 
    175 }  // namespace
    176