1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <sys/socket.h> 6 7 #include "base/bind.h" 8 #include "base/files/file_path.h" 9 #include "base/path_service.h" 10 #include "base/posix/eintr_wrapper.h" 11 #include "base/synchronization/waitable_event.h" 12 #include "base/threading/thread.h" 13 #include "base/threading/thread_restrictions.h" 14 #include "ipc/unix_domain_socket_util.h" 15 #include "testing/gtest/include/gtest/gtest.h" 16 17 namespace { 18 19 class SocketAcceptor : public base::MessageLoopForIO::Watcher { 20 public: 21 SocketAcceptor(int fd, base::MessageLoopProxy* target_thread) 22 : server_fd_(-1), 23 target_thread_(target_thread), 24 started_watching_event_(false, false), 25 accepted_event_(false, false) { 26 target_thread->PostTask(FROM_HERE, 27 base::Bind(&SocketAcceptor::StartWatching, base::Unretained(this), fd)); 28 } 29 30 virtual ~SocketAcceptor() { 31 Close(); 32 } 33 34 int server_fd() const { return server_fd_; } 35 36 void WaitUntilReady() { 37 started_watching_event_.Wait(); 38 } 39 40 void WaitForAccept() { 41 accepted_event_.Wait(); 42 } 43 44 void Close() { 45 if (watcher_.get()) { 46 target_thread_->PostTask(FROM_HERE, 47 base::Bind(&SocketAcceptor::StopWatching, base::Unretained(this), 48 watcher_.release())); 49 } 50 } 51 52 private: 53 void StartWatching(int fd) { 54 watcher_.reset(new base::MessageLoopForIO::FileDescriptorWatcher); 55 base::MessageLoopForIO::current()->WatchFileDescriptor( 56 fd, true, base::MessageLoopForIO::WATCH_READ, watcher_.get(), this); 57 started_watching_event_.Signal(); 58 } 59 void StopWatching(base::MessageLoopForIO::FileDescriptorWatcher* watcher) { 60 watcher->StopWatchingFileDescriptor(); 61 delete watcher; 62 } 63 virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE { 64 ASSERT_EQ(-1, server_fd_); 65 IPC::ServerAcceptConnection(fd, &server_fd_); 66 watcher_->StopWatchingFileDescriptor(); 67 accepted_event_.Signal(); 68 } 69 virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE {} 70 71 int server_fd_; 72 base::MessageLoopProxy* target_thread_; 73 scoped_ptr<base::MessageLoopForIO::FileDescriptorWatcher> watcher_; 74 base::WaitableEvent started_watching_event_; 75 base::WaitableEvent accepted_event_; 76 77 DISALLOW_COPY_AND_ASSIGN(SocketAcceptor); 78 }; 79 80 const base::FilePath GetChannelDir() { 81 #if defined(OS_ANDROID) 82 base::FilePath tmp_dir; 83 PathService::Get(base::DIR_CACHE, &tmp_dir); 84 return tmp_dir; 85 #else 86 return base::FilePath("/var/tmp"); 87 #endif 88 } 89 90 class TestUnixSocketConnection { 91 public: 92 TestUnixSocketConnection() 93 : worker_("WorkerThread"), 94 server_listen_fd_(-1), 95 server_fd_(-1), 96 client_fd_(-1) { 97 socket_name_ = GetChannelDir().Append("TestSocket"); 98 base::Thread::Options options; 99 options.message_loop_type = base::MessageLoop::TYPE_IO; 100 worker_.StartWithOptions(options); 101 } 102 103 bool CreateServerSocket() { 104 IPC::CreateServerUnixDomainSocket(socket_name_, &server_listen_fd_); 105 if (server_listen_fd_ < 0) 106 return false; 107 struct stat socket_stat; 108 stat(socket_name_.value().c_str(), &socket_stat); 109 EXPECT_TRUE(S_ISSOCK(socket_stat.st_mode)); 110 acceptor_.reset(new SocketAcceptor(server_listen_fd_, 111 worker_.message_loop_proxy().get())); 112 acceptor_->WaitUntilReady(); 113 return true; 114 } 115 116 bool CreateClientSocket() { 117 DCHECK(server_listen_fd_ >= 0); 118 IPC::CreateClientUnixDomainSocket(socket_name_, &client_fd_); 119 if (client_fd_ < 0) 120 return false; 121 acceptor_->WaitForAccept(); 122 server_fd_ = acceptor_->server_fd(); 123 return server_fd_ >= 0; 124 } 125 126 virtual ~TestUnixSocketConnection() { 127 if (client_fd_ >= 0) 128 close(client_fd_); 129 if (server_fd_ >= 0) 130 close(server_fd_); 131 if (server_listen_fd_ >= 0) { 132 close(server_listen_fd_); 133 unlink(socket_name_.value().c_str()); 134 } 135 } 136 137 int client_fd() const { return client_fd_; } 138 int server_fd() const { return server_fd_; } 139 140 private: 141 base::Thread worker_; 142 base::FilePath socket_name_; 143 int server_listen_fd_; 144 int server_fd_; 145 int client_fd_; 146 scoped_ptr<SocketAcceptor> acceptor_; 147 }; 148 149 // Ensure that IPC::CreateServerUnixDomainSocket creates a socket that 150 // IPC::CreateClientUnixDomainSocket can successfully connect to. 151 TEST(UnixDomainSocketUtil, Connect) { 152 TestUnixSocketConnection connection; 153 ASSERT_TRUE(connection.CreateServerSocket()); 154 ASSERT_TRUE(connection.CreateClientSocket()); 155 } 156 157 // Ensure that messages can be sent across the resulting socket. 158 TEST(UnixDomainSocketUtil, SendReceive) { 159 TestUnixSocketConnection connection; 160 ASSERT_TRUE(connection.CreateServerSocket()); 161 ASSERT_TRUE(connection.CreateClientSocket()); 162 163 const char buffer[] = "Hello, server!"; 164 size_t buf_len = sizeof(buffer); 165 size_t sent_bytes = 166 HANDLE_EINTR(send(connection.client_fd(), buffer, buf_len, 0)); 167 ASSERT_EQ(buf_len, sent_bytes); 168 char recv_buf[sizeof(buffer)]; 169 size_t received_bytes = 170 HANDLE_EINTR(recv(connection.server_fd(), recv_buf, buf_len, 0)); 171 ASSERT_EQ(buf_len, received_bytes); 172 ASSERT_EQ(0, memcmp(recv_buf, buffer, buf_len)); 173 } 174 175 } // namespace 176