Home | History | Annotate | Download | only in socket
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/unix_domain_listen_socket_posix.h"
      6 
      7 #include <errno.h>
      8 #include <fcntl.h>
      9 #include <poll.h>
     10 #include <sys/socket.h>
     11 #include <sys/stat.h>
     12 #include <sys/time.h>
     13 #include <sys/types.h>
     14 #include <sys/un.h>
     15 #include <unistd.h>
     16 
     17 #include <cstring>
     18 #include <queue>
     19 #include <string>
     20 
     21 #include "base/bind.h"
     22 #include "base/callback.h"
     23 #include "base/compiler_specific.h"
     24 #include "base/files/file_path.h"
     25 #include "base/files/file_util.h"
     26 #include "base/files/scoped_temp_dir.h"
     27 #include "base/memory/ref_counted.h"
     28 #include "base/memory/scoped_ptr.h"
     29 #include "base/message_loop/message_loop.h"
     30 #include "base/posix/eintr_wrapper.h"
     31 #include "base/synchronization/condition_variable.h"
     32 #include "base/synchronization/lock.h"
     33 #include "base/threading/platform_thread.h"
     34 #include "base/threading/thread.h"
     35 #include "net/socket/socket_descriptor.h"
     36 #include "testing/gtest/include/gtest/gtest.h"
     37 
     38 using std::queue;
     39 using std::string;
     40 
     41 namespace net {
     42 namespace deprecated {
     43 namespace {
     44 
     45 const char kSocketFilename[] = "socket_for_testing";
     46 const char kInvalidSocketPath[] = "/invalid/path";
     47 const char kMsg[] = "hello";
     48 
     49 enum EventType {
     50   EVENT_ACCEPT,
     51   EVENT_AUTH_DENIED,
     52   EVENT_AUTH_GRANTED,
     53   EVENT_CLOSE,
     54   EVENT_LISTEN,
     55   EVENT_READ,
     56 };
     57 
     58 class EventManager : public base::RefCounted<EventManager> {
     59  public:
     60   EventManager() : condition_(&mutex_) {}
     61 
     62   bool HasPendingEvent() {
     63     base::AutoLock lock(mutex_);
     64     return !events_.empty();
     65   }
     66 
     67   void Notify(EventType event) {
     68     base::AutoLock lock(mutex_);
     69     events_.push(event);
     70     condition_.Broadcast();
     71   }
     72 
     73   EventType WaitForEvent() {
     74     base::AutoLock lock(mutex_);
     75     while (events_.empty())
     76       condition_.Wait();
     77     EventType event = events_.front();
     78     events_.pop();
     79     return event;
     80   }
     81 
     82  private:
     83   friend class base::RefCounted<EventManager>;
     84   virtual ~EventManager() {}
     85 
     86   queue<EventType> events_;
     87   base::Lock mutex_;
     88   base::ConditionVariable condition_;
     89 };
     90 
     91 class TestListenSocketDelegate : public StreamListenSocket::Delegate {
     92  public:
     93   explicit TestListenSocketDelegate(
     94       const scoped_refptr<EventManager>& event_manager)
     95       : event_manager_(event_manager) {}
     96 
     97   virtual void DidAccept(StreamListenSocket* server,
     98                          scoped_ptr<StreamListenSocket> connection) OVERRIDE {
     99     LOG(ERROR) << __PRETTY_FUNCTION__;
    100     connection_ = connection.Pass();
    101     Notify(EVENT_ACCEPT);
    102   }
    103 
    104   virtual void DidRead(StreamListenSocket* connection,
    105                        const char* data,
    106                        int len) OVERRIDE {
    107     {
    108       base::AutoLock lock(mutex_);
    109       DCHECK(len);
    110       data_.assign(data, len - 1);
    111     }
    112     Notify(EVENT_READ);
    113   }
    114 
    115   virtual void DidClose(StreamListenSocket* sock) OVERRIDE {
    116     Notify(EVENT_CLOSE);
    117   }
    118 
    119   void OnListenCompleted() {
    120     Notify(EVENT_LISTEN);
    121   }
    122 
    123   string ReceivedData() {
    124     base::AutoLock lock(mutex_);
    125     return data_;
    126   }
    127 
    128  private:
    129   void Notify(EventType event) {
    130     event_manager_->Notify(event);
    131   }
    132 
    133   const scoped_refptr<EventManager> event_manager_;
    134   scoped_ptr<StreamListenSocket> connection_;
    135   base::Lock mutex_;
    136   string data_;
    137 };
    138 
    139 bool UserCanConnectCallback(
    140     bool allow_user, const scoped_refptr<EventManager>& event_manager,
    141     const UnixDomainServerSocket::Credentials&) {
    142   event_manager->Notify(
    143       allow_user ? EVENT_AUTH_GRANTED : EVENT_AUTH_DENIED);
    144   return allow_user;
    145 }
    146 
    147 class UnixDomainListenSocketTestHelper : public testing::Test {
    148  public:
    149   void CreateAndListen() {
    150     socket_ = UnixDomainListenSocket::CreateAndListen(
    151         file_path_.value(), socket_delegate_.get(), MakeAuthCallback());
    152     socket_delegate_->OnListenCompleted();
    153   }
    154 
    155  protected:
    156   UnixDomainListenSocketTestHelper(const string& path_str, bool allow_user)
    157       : allow_user_(allow_user) {
    158     file_path_ = base::FilePath(path_str);
    159     if (!file_path_.IsAbsolute()) {
    160       EXPECT_TRUE(temp_dir_.CreateUniqueTempDir());
    161       file_path_ = GetTempSocketPath(file_path_.value());
    162     }
    163     // Beware that if path_str is an absolute path, this class doesn't delete
    164     // the file. It must be an invalid path and cannot be created by unittests.
    165   }
    166 
    167   base::FilePath GetTempSocketPath(const std::string socket_name) {
    168     DCHECK(temp_dir_.IsValid());
    169     return temp_dir_.path().Append(socket_name);
    170   }
    171 
    172   virtual void SetUp() OVERRIDE {
    173     event_manager_ = new EventManager();
    174     socket_delegate_.reset(new TestListenSocketDelegate(event_manager_));
    175   }
    176 
    177   virtual void TearDown() OVERRIDE {
    178     socket_.reset();
    179     socket_delegate_.reset();
    180     event_manager_ = NULL;
    181   }
    182 
    183   UnixDomainListenSocket::AuthCallback MakeAuthCallback() {
    184     return base::Bind(&UserCanConnectCallback, allow_user_, event_manager_);
    185   }
    186 
    187   SocketDescriptor CreateClientSocket() {
    188     const SocketDescriptor sock = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
    189     if (sock < 0) {
    190       LOG(ERROR) << "socket() error";
    191       return kInvalidSocket;
    192     }
    193     sockaddr_un addr;
    194     memset(&addr, 0, sizeof(addr));
    195     addr.sun_family = AF_UNIX;
    196     socklen_t addr_len;
    197     strncpy(addr.sun_path, file_path_.value().c_str(), sizeof(addr.sun_path));
    198     addr_len = sizeof(sockaddr_un);
    199     if (connect(sock, reinterpret_cast<sockaddr*>(&addr), addr_len) != 0) {
    200       LOG(ERROR) << "connect() error: " << strerror(errno)
    201                  << ": path=" << file_path_.value();
    202       return kInvalidSocket;
    203     }
    204     return sock;
    205   }
    206 
    207   scoped_ptr<base::Thread> CreateAndRunServerThread() {
    208     base::Thread::Options options;
    209     options.message_loop_type = base::MessageLoop::TYPE_IO;
    210     scoped_ptr<base::Thread> thread(new base::Thread("socketio_test"));
    211     thread->StartWithOptions(options);
    212     thread->message_loop()->PostTask(
    213         FROM_HERE,
    214         base::Bind(&UnixDomainListenSocketTestHelper::CreateAndListen,
    215                    base::Unretained(this)));
    216     return thread.Pass();
    217   }
    218 
    219   base::ScopedTempDir temp_dir_;
    220   base::FilePath file_path_;
    221   const bool allow_user_;
    222   scoped_refptr<EventManager> event_manager_;
    223   scoped_ptr<TestListenSocketDelegate> socket_delegate_;
    224   scoped_ptr<UnixDomainListenSocket> socket_;
    225 };
    226 
    227 class UnixDomainListenSocketTest : public UnixDomainListenSocketTestHelper {
    228  protected:
    229   UnixDomainListenSocketTest()
    230       : UnixDomainListenSocketTestHelper(kSocketFilename,
    231                                          true /* allow user */) {}
    232 };
    233 
    234 class UnixDomainListenSocketTestWithInvalidPath
    235     : public UnixDomainListenSocketTestHelper {
    236  protected:
    237   UnixDomainListenSocketTestWithInvalidPath()
    238       : UnixDomainListenSocketTestHelper(kInvalidSocketPath, true) {}
    239 };
    240 
    241 class UnixDomainListenSocketTestWithForbiddenUser
    242     : public UnixDomainListenSocketTestHelper {
    243  protected:
    244   UnixDomainListenSocketTestWithForbiddenUser()
    245       : UnixDomainListenSocketTestHelper(kSocketFilename,
    246                                          false /* forbid user */) {}
    247 };
    248 
    249 TEST_F(UnixDomainListenSocketTest, CreateAndListen) {
    250   CreateAndListen();
    251   EXPECT_FALSE(socket_.get() == NULL);
    252 }
    253 
    254 TEST_F(UnixDomainListenSocketTestWithInvalidPath,
    255        CreateAndListenWithInvalidPath) {
    256   CreateAndListen();
    257   EXPECT_TRUE(socket_.get() == NULL);
    258 }
    259 
    260 #ifdef SOCKET_ABSTRACT_NAMESPACE_SUPPORTED
    261 // Test with an invalid path to make sure that the socket is not backed by a
    262 // file.
    263 TEST_F(UnixDomainListenSocketTestWithInvalidPath,
    264        CreateAndListenWithAbstractNamespace) {
    265   socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
    266       file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
    267   EXPECT_FALSE(socket_.get() == NULL);
    268 }
    269 
    270 TEST_F(UnixDomainListenSocketTest, TestFallbackName) {
    271   scoped_ptr<UnixDomainListenSocket> existing_socket =
    272       UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
    273           file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
    274   EXPECT_FALSE(existing_socket.get() == NULL);
    275   // First, try to bind socket with the same name with no fallback name.
    276   socket_ =
    277       UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
    278           file_path_.value(), "", socket_delegate_.get(), MakeAuthCallback());
    279   EXPECT_TRUE(socket_.get() == NULL);
    280   // Now with a fallback name.
    281   const char kFallbackSocketName[] = "socket_for_testing_2";
    282   socket_ = UnixDomainListenSocket::CreateAndListenWithAbstractNamespace(
    283       file_path_.value(),
    284       GetTempSocketPath(kFallbackSocketName).value(),
    285       socket_delegate_.get(),
    286       MakeAuthCallback());
    287   EXPECT_FALSE(socket_.get() == NULL);
    288 }
    289 #endif
    290 
    291 TEST_F(UnixDomainListenSocketTest, TestWithClient) {
    292   const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
    293   EventType event = event_manager_->WaitForEvent();
    294   ASSERT_EQ(EVENT_LISTEN, event);
    295 
    296   // Create the client socket.
    297   const SocketDescriptor sock = CreateClientSocket();
    298   ASSERT_NE(kInvalidSocket, sock);
    299   event = event_manager_->WaitForEvent();
    300   ASSERT_EQ(EVENT_AUTH_GRANTED, event);
    301   event = event_manager_->WaitForEvent();
    302   ASSERT_EQ(EVENT_ACCEPT, event);
    303 
    304   // Send a message from the client to the server.
    305   ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
    306   ASSERT_NE(-1, ret);
    307   ASSERT_EQ(sizeof(kMsg), static_cast<size_t>(ret));
    308   event = event_manager_->WaitForEvent();
    309   ASSERT_EQ(EVENT_READ, event);
    310   ASSERT_EQ(kMsg, socket_delegate_->ReceivedData());
    311 
    312   // Close the client socket.
    313   ret = IGNORE_EINTR(close(sock));
    314   event = event_manager_->WaitForEvent();
    315   ASSERT_EQ(EVENT_CLOSE, event);
    316 }
    317 
    318 TEST_F(UnixDomainListenSocketTestWithForbiddenUser, TestWithForbiddenUser) {
    319   const scoped_ptr<base::Thread> server_thread = CreateAndRunServerThread();
    320   EventType event = event_manager_->WaitForEvent();
    321   ASSERT_EQ(EVENT_LISTEN, event);
    322   const SocketDescriptor sock = CreateClientSocket();
    323   ASSERT_NE(kInvalidSocket, sock);
    324 
    325   event = event_manager_->WaitForEvent();
    326   ASSERT_EQ(EVENT_AUTH_DENIED, event);
    327 
    328   // Wait until the file descriptor is closed by the server.
    329   struct pollfd poll_fd;
    330   poll_fd.fd = sock;
    331   poll_fd.events = POLLIN;
    332   poll(&poll_fd, 1, -1 /* rely on GTest for timeout handling */);
    333 
    334   // Send() must fail.
    335   ssize_t ret = HANDLE_EINTR(send(sock, kMsg, sizeof(kMsg), 0));
    336   ASSERT_EQ(-1, ret);
    337   ASSERT_EQ(EPIPE, errno);
    338   ASSERT_FALSE(event_manager_->HasPendingEvent());
    339 }
    340 
    341 }  // namespace
    342 }  // namespace deprecated
    343 }  // namespace net
    344