Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <errno.h>
      6 #include <fcntl.h>
      7 #include <poll.h>
      8 #include <sys/socket.h>
      9 #include <sys/stat.h>
     10 #include <sys/time.h>
     11 #include <sys/types.h>
     12 #include <sys/un.h>
     13 #include <unistd.h>
     14 
     15 #include <cstring>
     16 #include <queue>
     17 #include <string>
     18 
     19 #include "base/bind.h"
     20 #include "base/callback.h"
     21 #include "base/compiler_specific.h"
     22 #include "base/file_util.h"
     23 #include "base/files/file_path.h"
     24 #include "base/memory/ref_counted.h"
     25 #include "base/memory/scoped_ptr.h"
     26 #include "base/message_loop/message_loop.h"
     27 #include "base/posix/eintr_wrapper.h"
     28 #include "base/synchronization/condition_variable.h"
     29 #include "base/synchronization/lock.h"
     30 #include "base/threading/platform_thread.h"
     31 #include "base/threading/thread.h"
     32 #include "net/socket/socket_descriptor.h"
     33 #include "net/socket/unix_domain_socket_posix.h"
     34 #include "testing/gtest/include/gtest/gtest.h"
     35 
     36 using std::queue;
     37 using std::string;
     38 
     39 namespace net {
     40 namespace {
     41 
     42 const char kSocketFilename[] = "unix_domain_socket_for_testing";
     43 const char kInvalidSocketPath[] = "/invalid/path";
     44 const char kMsg[] = "hello";
     45 
     46 enum EventType {
     47   EVENT_ACCEPT,
     48   EVENT_AUTH_DENIED,
     49   EVENT_AUTH_GRANTED,
     50   EVENT_CLOSE,
     51   EVENT_LISTEN,
     52   EVENT_READ,
     53 };
     54 
     55 string MakeSocketPath(const string& socket_file_name) {
     56   base::FilePath temp_dir;
     57   base::GetTempDir(&temp_dir);
     58   return temp_dir.Append(socket_file_name).value();
     59 }
     60 
     61 string MakeSocketPath() {
     62   return MakeSocketPath(kSocketFilename);
     63 }
     64 
     65 class EventManager : public base::RefCounted<EventManager> {
     66  public:
     67   EventManager() : condition_(&mutex_) {}
     68 
     69   bool HasPendingEvent() {
     70     base::AutoLock lock(mutex_);
     71     return !events_.empty();
     72   }
     73 
     74   void Notify(EventType event) {
     75     base::AutoLock lock(mutex_);
     76     events_.push(event);
     77     condition_.Broadcast();
     78   }
     79 
     80   EventType WaitForEvent() {
     81     base::AutoLock lock(mutex_);
     82     while (events_.empty())
     83       condition_.Wait();
     84     EventType event = events_.front();
     85     events_.pop();
     86     return event;
     87   }
     88 
     89  private:
     90   friend class base::RefCounted<EventManager>;
     91   virtual ~EventManager() {}
     92 
     93   queue<EventType> events_;
     94   base::Lock mutex_;
     95   base::ConditionVariable condition_;
     96 };
     97 
     98 class TestListenSocketDelegate : public StreamListenSocket::Delegate {
     99  public:
    100   explicit TestListenSocketDelegate(
    101       const scoped_refptr<EventManager>& event_manager)
    102       : event_manager_(event_manager) {}
    103 
    104   virtual void DidAccept(StreamListenSocket* server,
    105                          scoped_ptr<StreamListenSocket> connection) OVERRIDE {
    106     LOG(ERROR) << __PRETTY_FUNCTION__;
    107     connection_ = connection.Pass();
    108     Notify(EVENT_ACCEPT);
    109   }
    110 
    111   virtual void DidRead(StreamListenSocket* connection,
    112                        const char* data,
    113                        int len) OVERRIDE {
    114     {
    115       base::AutoLock lock(mutex_);
    116       DCHECK(len);
    117       data_.assign(data, len - 1);
    118     }
    119     Notify(EVENT_READ);
    120   }
    121 
    122   virtual void DidClose(StreamListenSocket* sock) OVERRIDE {
    123     Notify(EVENT_CLOSE);
    124   }
    125 
    126   void OnListenCompleted() {
    127     Notify(EVENT_LISTEN);
    128   }
    129 
    130   string ReceivedData() {
    131     base::AutoLock lock(mutex_);
    132     return data_;
    133   }
    134 
    135  private:
    136   void Notify(EventType event) {
    137     event_manager_->Notify(event);
    138   }
    139 
    140   const scoped_refptr<EventManager> event_manager_;
    141   scoped_ptr<StreamListenSocket> connection_;
    142   base::Lock mutex_;
    143   string data_;
    144 };
    145 
    146 bool UserCanConnectCallback(
    147     bool allow_user, const scoped_refptr<EventManager>& event_manager,
    148     uid_t, gid_t) {
    149   event_manager->Notify(
    150       allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED);
    151   return allow_user;
    152 }
    153 
    154 class UnixDomainSocketTestHelper : public testing::Test {
    155  public:
    156   void CreateAndListen() {
    157     socket_ = UnixDomainSocket::CreateAndListen(
    158         file_path_.value(), socket_delegate_.get(), MakeAuthCallback());
    159     socket_delegate_->OnListenCompleted();
    160   }
    161 
    162  protected:
    163   UnixDomainSocketTestHelper(const string& path, bool allow_user)
    164       : file_path_(path),
    165         allow_user_(allow_user) {}
    166 
    167   virtual void SetUp() OVERRIDE {
    168     event_manager_ = new EventManager();
    169     socket_delegate_.reset(new TestListenSocketDelegate(event_manager_));
    170     DeleteSocketFile();
    171   }
    172 
    173   virtual void TearDown() OVERRIDE {
    174     DeleteSocketFile();
    175     socket_.reset();
    176     socket_delegate_.reset();
    177     event_manager_ = NULL;
    178   }
    179 
    180   UnixDomainSocket::AuthCallback MakeAuthCallback() {
    181     return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_);
    182   }
    183 
    184   void DeleteSocketFile() {
    185     ASSERT_FALSE(file_path_.empty());
    186     base::DeleteFile(file_path_, false /* not recursive */);
    187   }
    188 
    189   SocketDescriptor CreateClientSocket() {
    190     const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
    191     if (sock < 0) {
    192       LOG(ERROR) << "socket() error";
    193       return kInvalidSocket;
    194     }
    195     sockaddr_un addr;
    196     memset(&addr, 0, sizeof(addr));
    197     addr.sun_family = AF_UNIX;
    198     socklen_t addr_len;
    199     strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path));
    200     addr_len = sizeof(sockaddr_un);
    201     if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) {
    202       LOG(ERROR) << "connect() error";
    203       return kInvalidSocket;
    204     }
    205     return sock;
    206   }
    207 
    208   scoped_ptr<base::Thread> CreateAndRunServerThread() {
    209     base::Thread::Options options;
    210     options.message_loop_type = base::MessageLoop::TYPE_IO;
    211     scoped_ptr<base::Thread> thread(new base::Thread("socketio_test"));
    212     thread->StartWithOptions(options);
    213     thread->message_loop()->PostTask(
    214         FROM_HERE,
    215         base::Bind(&UnixDomainSocketTestHelper::CreateAndListen,
    216                    base::Unretained(this)));
    217     return thread.Pass();
    218   }
    219 
    220   const base::FilePath file_path_;
    221   const bool allow_user_;
    222   scoped_refptr<EventManager> event_manager_;
    223   scoped_ptr<TestListenSocketDelegate> socket_delegate_;
    224   scoped_ptr<UnixDomainSocket> socket_;
    225 };
    226 
    227 class UnixDomainSocketTest : public UnixDomainSocketTestHelper {
    228  protected:
    229   UnixDomainSocketTest()
    230       : UnixDomainSocketTestHelper(MakeSocketPath(), true /* allow user */) {}
    231 };
    232 
    233 class UnixDomainSocketTestWithInvalidPath : public UnixDomainSocketTestHelper {
    234  protected:
    235   UnixDomainSocketTestWithInvalidPath()
    236       : UnixDomainSocketTestHelper(kInvalidSocketPath, true) {}
    237 };
    238 
    239 class UnixDomainSocketTestWithForbiddenUser
    240     : public UnixDomainSocketTestHelper {
    241  protected:
    242   UnixDomainSocketTestWithForbiddenUser()
    243       : UnixDomainSocketTestHelper(MakeSocketPath(), false /* forbid user */) {}
    244 };
    245 
    246 TEST_F(UnixDomainSocketTest, CreateAndListen) {
    247   CreateAndListen();
    248   EXPECT_FALSE(socket_.get() == NULL);
    249 }
    250 
    251 TEST_F(UnixDomainSocketTestWithInvalidPath, CreateAndListenWithInvalidPath) {
    252   CreateAndListen();
    253   EXPECT_TRUE(socket_.get() == NULL);
    254 }
    255 
    256 #ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
    257 // Test with an invalid path to make sure that the socket is not backed by a
    258 // file.
    259 TEST_F(UnixDomainSocketTestWithInvalidPath,
    260        CreateAndListenWithAbstractNamespace) {
    261   socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace(
    262       file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
    263   EXPECT_FALSE(socket_.get() == NULL);
    264 }
    265 
    266 TEST_F(UnixDomainSocketTest, TestFallbackName) {
    267   scoped_ptr<UnixDomainSocket> existing_socket =
    268       UnixDomainSocket::CreateAndListenWithAbstractNamespace(
    269           file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
    270   EXPECT_FALSE(existing_socket.get() == NULL);
    271   // First, try to bind socket with the same name with no fallback name.
    272   socket_ =
    273       UnixDomainSocket::CreateAndListenWithAbstractNamespace(
    274           file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
    275   EXPECT_TRUE(socket_.get() == NULL);
    276   // Now with a fallback name.
    277   const char kFallbackSocketName[] = "unix_domain_socket_for_testing_2";
    278   socket_ = UnixDomainSocket::CreateAndListenWithAbstractNamespace(
    279       file_path_.value(),
    280       MakeSocketPath(kFallbackSocketName),
    281       socket_delegate_.get(),
    282       MakeAuthCallback());
    283   EXPECT_FALSE(socket_.get() == NULL);
    284 }
    285 #endif
    286 
    287 TEST_F(UnixDomainSocketTest, TestWithClient) {
    288   const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
    289   EventType event = event_manager_->WaitForEvent();
    290   ASSERT_EQ(EVENT_LISTEN, event);
    291 
    292   // Create the client socket.
    293   const SocketDescriptor sock = CreateClientSocket();
    294   ASSERT_NE(kInvalidSocket, sock);
    295   event = event_manager_->WaitForEvent();
    296   ASSERT_EQ(EVENT_AUTH_GRANTED, event);
    297   event = event_manager_->WaitForEvent();
    298   ASSERT_EQ(EVENT_ACCEPT, event);
    299 
    300   // Send a message from the client to the server.
    301   ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
    302   ASSERT_NE(-1, ret);
    303   ASSERT_EQ(sizeof(kMsg), static_cast<size_t>(ret));
    304   event = event_manager_->WaitForEvent();
    305   ASSERT_EQ(EVENT_READ, event);
    306   ASSERT_EQ(kMsg, socket_delegate_->ReceivedData());
    307 
    308   // Close the client socket.
    309   ret = IGNORE_EINTR(close(sock));
    310   event = event_manager_->WaitForEvent();
    311   ASSERT_EQ(EVENT_CLOSE, event);
    312 }
    313 
    314 TEST_F(UnixDomainSocketTestWithForbiddenUser, TestWithForbiddenUser) {
    315   const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
    316   EventType event = event_manager_->WaitForEvent();
    317   ASSERT_EQ(EVENT_LISTEN, event);
    318   const SocketDescriptor sock = CreateClientSocket();
    319   ASSERT_NE(kInvalidSocket, sock);
    320 
    321   event = event_manager_->WaitForEvent();
    322   ASSERT_EQ(EVENT_AUTH_DENIED, event);
    323 
    324   // Wait until the file descriptor is closed by the server.
    325   struct pollfd poll_fd;
    326   poll_fd.fd = sock;
    327   poll_fd.events = POLLIN;
    328   poll(&poll_fd, 1, -1 /* rely on GTest for timeout handling */);
    329 
    330   // Send() must fail.
    331   ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
    332   ASSERT_EQ(-1, ret);
    333   ASSERT_EQ(EPIPE, errno);
    334   ASSERT_FALSE(event_manager_->HasPendingEvent());
    335 }
    336 
    337 }  // namespace
    338 }  // namespace net
    339