1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include <errno.h> 6 #include <fcntl.h> 7 #include <poll.h> 8 #include <sys/socket.h> 9 #include <sys/stat.h> 10 #include <sys/time.h> 11 #include <sys/types.h> 12 #include <sys/un.h> 13 #include <unistd.h> 14 15 #include <cstring> 16 #include <queue> 17 #include <string> 18 19 #include "base/bind.h" 20 #include "base/callback.h" 21 #include "base/compiler_specific.h" 22 #include "base/file_util.h" 23 #include "base/files/file_path.h" 24 #include "base/memory/ref_counted.h" 25 #include "base/memory/scoped_ptr.h" 26 #include "base/message_loop/message_loop.h" 27 #include "base/posix/eintr_wrapper.h" 28 #include "base/synchronization/condition_variable.h" 29 #include "base/synchronization/lock.h" 30 #include "base/threading/platform_thread.h" 31 #include "base/threading/thread.h" 32 #include "net/socket/socket_descriptor.h" 33 #include "net/socket/unix_domain_socket_posix.h" 34 #include "testing/gtest/include/gtest/gtest.h" 35 36 using std::queue; 37 using std::string; 38 39 namespace net { 40 namespace { 41 42 const char kSocketFilename[] = "unix_domain_socket_for_testing"; 43 const char kInvalidSocketPath[] = "/invalid/path"; 44 const char kMsg[] = "hello"; 45 46 enum EventType { 47 EVENT_ACCEPT, 48 EVENT_AUTH_DENIED, 49 EVENT_AUTH_GRANTED, 50 EVENT_CLOSE, 51 EVENT_LISTEN, 52 EVENT_READ, 53 }; 54 55 string MakeSocketPath(const string& socket_file_name) { 56 base::FilePath temp_dir; 57 base::GetTempDir(&temp_dir); 58 return temp_dir.Append(socket_file_name).value(); 59 } 60 61 string MakeSocketPath() { 62 return MakeSocketPath(kSocketFilename); 63 } 64 65 class EventManager : public base::RefCounted<EventManager> { 66 public: 67 EventManager() : condition_(&mutex_) {} 68 69 bool HasPendingEvent() { 70 base::AutoLock lock(mutex_); 71 return !events_.empty(); 72 } 73 74 void Notify(EventType event) { 75 base::AutoLock lock(mutex_); 76 events_.push(event); 77 condition_.Broadcast(); 78 } 79 80 EventType WaitForEvent() { 81 base::AutoLock lock(mutex_); 82 while (events_.empty()) 83 condition_.Wait(); 84 EventType event = events_.front(); 85 events_.pop(); 86 return event; 87 } 88 89 private: 90 friend class base::RefCounted<EventManager>; 91 virtual ~EventManager() {} 92 93 queue<EventType> events_; 94 base::Lock mutex_; 95 base::ConditionVariable condition_; 96 }; 97 98 class TestListenSocketDelegate : public StreamListenSocket::Delegate { 99 public: 100 explicit TestListenSocketDelegate( 101 const scoped_refptr<EventManager>& event_manager) 102 : event_manager_(event_manager) {} 103 104 virtual void DidAccept(StreamListenSocket* server, 105 scoped_ptr<StreamListenSocket> connection) OVERRIDE { 106 LOG(ERROR) << __PRETTY_FUNCTION__; 107 connection_ = connection.Pass(); 108 Notify(EVENT_ACCEPT); 109 } 110 111 virtual void DidRead(StreamListenSocket* connection, 112 const char* data, 113 int len) OVERRIDE { 114 { 115 base::AutoLock lock(mutex_); 116 DCHECK(len); 117 data_.assign(data, len - 1); 118 } 119 Notify(EVENT_READ); 120 } 121 122 virtual void DidClose(StreamListenSocket* sock) OVERRIDE { 123 Notify(EVENT_CLOSE); 124 } 125 126 void OnListenCompleted() { 127 Notify(EVENT_LISTEN); 128 } 129 130 string ReceivedData() { 131 base::AutoLock lock(mutex_); 132 return data_; 133 } 134 135 private: 136 void Notify(EventType event) { 137 event_manager_->Notify(event); 138 } 139 140 const scoped_refptr<EventManager> event_manager_; 141 scoped_ptr<StreamListenSocket> connection_; 142 base::Lock mutex_; 143 string data_; 144 }; 145 146 bool UserCanConnectCallback( 147 bool allow_user, const scoped_refptr<EventManager>& event_manager, 148 uid_t, gid_t) { 149 event_manager->Notify( 150 allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED); 151 return allow_user; 152 } 153 154 class UnixDomainSocketTestHelper : public testing::Test { 155 public: 156 void CreateAndListen() { 157 socket_ = UnixDomainSocket::CreateAndListen( 158 file_path_.value(), socket_delegate_.get(), MakeAuthCallback()); 159 socket_delegate_->OnListenCompleted(); 160 } 161 162 protected: 163 UnixDomainSocketTestHelper(const string& path, bool allow_user) 164 : file_path_(path), 165 allow_user_(allow_user) {} 166 167 virtual void SetUp() OVERRIDE { 168 event_manager_ = new EventManager(); 169 socket_delegate_.reset(new TestListenSocketDelegate(event_manager_)); 170 DeleteSocketFile(); 171 } 172 173 virtual void TearDown() OVERRIDE { 174 DeleteSocketFile(); 175 socket_.reset(); 176 socket_delegate_.reset(); 177 event_manager_ = NULL; 178 } 179 180 UnixDomainSocket::AuthCallback MakeAuthCallback() { 181 return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_); 182 } 183 184 void DeleteSocketFile() { 185 ASSERT_FALSE(file_path_.empty()); 186 base::DeleteFile(file_path_, false /* not recursive */); 187 } 188 189 SocketDescriptor CreateClientSocket() { 190 const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0); 191 if (sock < 0) { 192 LOG(ERROR) << "socket() error"; 193 return kInvalidSocket; 194 } 195 sockaddr_un addr; 196 memset(&addr, 0, sizeof(addr)); 197 addr.sun_family = AF_UNIX; 198 socklen_t addr_len; 199 strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path)); 200 addr_len = sizeof(sockaddr_un); 201 if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) { 202 LOG(ERROR) << "connect() error"; 203 return kInvalidSocket; 204 } 205 return sock; 206 } 207 208 scoped_ptr<base::Thread> CreateAndRunServerThread() { 209 base::Thread::Options options; 210 options.message_loop_type = base::MessageLoop::TYPE_IO; 211 scoped_ptr<base::Thread> thread(new base::Thread("socketio_test")); 212 thread->StartWithOptions(options); 213 thread->message_loop()->PostTask( 214 FROM_HERE, 215 base::Bind(&UnixDomainSocketTestHelper::CreateAndListen, 216 base::Unretained(this))); 217 return thread.Pass(); 218 } 219 220 const base::FilePath file_path_; 221 const bool allow_user_; 222 scoped_refptr<EventManager> event_manager_; 223 scoped_ptr<TestListenSocketDelegate> socket_delegate_; 224 scoped_ptr<UnixDomainSocket> socket_; 225 }; 226 227 class UnixDomainSocketTest : public UnixDomainSocketTestHelper { 228 protected: 229 UnixDomainSocketTest() 230 : UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {} 231 }; 232 233 class UnixDomainSocketTestWithInvalidPath : public UnixDomainSocketTestHelper { 234 protected: 235 UnixDomainSocketTestWithInvalidPath() 236 : UnixDomainSocketTestHelper(kInvalidSocketPath, true) {} 237 }; 238 239 class UnixDomainSocketTestWithForbiddenUser 240 : public UnixDomainSocketTestHelper { 241 protected: 242 UnixDomainSocketTestWithForbiddenUser() 243 : UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {} 244 }; 245 246 TEST_F(UnixDomainSocketTest, CreateAndListen) { 247 CreateAndListen(); 248 EXPECT_FALSE(socket_.get() == NULL); 249 } 250 251 TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) { 252 CreateAndListen(); 253 EXPECT_TRUE(socket_.get() == NULL); 254 } 255 256 #ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED 257 // Test with an invalid path to make sure that the socket is not backed by a 258 // file. 259 TEST_F(UnixDomainSocketTestWithInvalidPath, 260 CreateAndListenWithAbstractNamespace) { 261 socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace( 262 file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); 263 EXPECT_FALSE(socket_.get() == NULL); 264 } 265 266 TEST_F(UnixDomainSocketTest, TestFallbackName) { 267 scoped_ptr<UnixDomainSocket> existing_socket = 268 UnixDomainSocket::CreateAndListenWithAbstractNamespace( 269 file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); 270 EXPECT_FALSE(existing_socket.get() == NULL); 271 // First, try to bind socket with the same name with no fallback name. 272 socket_ = 273 UnixDomainSocket::CreateAndListenWithAbstractNamespace( 274 file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback()); 275 EXPECT_TRUE(socket_.get() == NULL); 276 // Now with a fallback name. 277 const char kFallbackSocketName[] = "unix_domain_socket_for_testing_2"; 278 socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace( 279 file_path_.value(), 280 MakeSocketPath(kFallbackSocketName), 281 socket_delegate_.get(), 282 MakeAuthCallback()); 283 EXPECT_FALSE(socket_.get() == NULL); 284 } 285 #endif 286 287 TEST_F(UnixDomainSocketTest, TestWithClient) { 288 const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread(); 289 EventType event = event_manager_->WaitForEvent(); 290 ASSERT_EQ(EVENT_LISTEN, event); 291 292 // Create the client socket. 293 const SocketDescriptor sock = CreateClientSocket(); 294 ASSERT_NE(kInvalidSocket, sock); 295 event = event_manager_->WaitForEvent(); 296 ASSERT_EQ(EVENT_AUTH_GRANTED, event); 297 event = event_manager_->WaitForEvent(); 298 ASSERT_EQ(EVENT_ACCEPT, event); 299 300 // Send a message from the client to the server. 301 ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0)); 302 ASSERT_NE(-1, ret); 303 ASSERT_EQ(sizeof(kMsg), static_cast<size_t>(ret)); 304 event = event_manager_->WaitForEvent(); 305 ASSERT_EQ(EVENT_READ, event); 306 ASSERT_EQ(kMsg, socket_delegate_->ReceivedData()); 307 308 // Close the client socket. 309 ret = IGNORE_EINTR(close(sock)); 310 event = event_manager_->WaitForEvent(); 311 ASSERT_EQ(EVENT_CLOSE, event); 312 } 313 314 TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) { 315 const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread(); 316 EventType event = event_manager_->WaitForEvent(); 317 ASSERT_EQ(EVENT_LISTEN, event); 318 const SocketDescriptor sock = CreateClientSocket(); 319 ASSERT_NE(kInvalidSocket, sock); 320 321 event = event_manager_->WaitForEvent(); 322 ASSERT_EQ(EVENT_AUTH_DENIED, event); 323 324 // Wait until the file descriptor is closed by the server. 325 struct pollfd poll_fd; 326 poll_fd.fd = sock; 327 poll_fd.events = POLLIN; 328 poll(&poll_fd, 1, -1 /* rely on GTest for timeout handling */); 329 330 // Send() must fail. 331 ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0)); 332 ASSERT_EQ(-1, ret); 333 ASSERT_EQ(EPIPE, errno); 334 ASSERT_FALSE(event_manager_->HasPendingEvent()); 335 } 336 337 } // namespace 338 } // namespace net 339