Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/unix_domain_socket_posix.h"
      6 
      7 #include <cstring>
      8 #include <string>
      9 
     10 #include <errno.h>
     11 #include <sys/socket.h>
     12 #include <sys/stat.h>
     13 #include <sys/types.h>
     14 #include <sys/un.h>
     15 #include <unistd.h>
     16 
     17 #include "base/bind.h"
     18 #include "base/callback.h"
     19 #include "base/posix/eintr_wrapper.h"
     20 #include "base/threading/platform_thread.h"
     21 #include "build/build_config.h"
     22 #include "net/base/net_errors.h"
     23 #include "net/base/net_util.h"
     24 
     25 namespace net {
     26 
     27 namespace {
     28 
     29 bool NoAuthenticationCallback(uid_t, gid_t) {
     30   return true;
     31 }
     32 
     33 bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) {
     34 #if defined(OS_LINUX) || defined(OS_ANDROID)
     35   struct ucred user_cred;
     36   socklen_t len = sizeof(user_cred);
     37   if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) == -1)
     38     return false;
     39   *user_id = user_cred.uid;
     40   *group_id = user_cred.gid;
     41 #else
     42   if (getpeereid(socket, user_id, group_id) == -1)
     43     return false;
     44 #endif
     45   return true;
     46 }
     47 
     48 }  // namespace
     49 
     50 // static
     51 UnixDomainSocket::AuthCallback NoAuthentication() {
     52   return base::Bind(NoAuthenticationCallback);
     53 }
     54 
     55 // static
     56 UnixDomainSocket* UnixDomainSocket::CreateAndListenInternal(
     57     const std::string& path,
     58     const std::string& fallback_path,
     59     StreamListenSocket::Delegate* del,
     60     const AuthCallback& auth_callback,
     61     bool use_abstract_namespace) {
     62   SocketDescriptor s = CreateAndBind(path, use_abstract_namespace);
     63   if (s == kInvalidSocket && !fallback_path.empty())
     64     s = CreateAndBind(fallback_path, use_abstract_namespace);
     65   if (s == kInvalidSocket)
     66     return NULL;
     67   UnixDomainSocket* sock = new UnixDomainSocket(s, del, auth_callback);
     68   sock->Listen();
     69   return sock;
     70 }
     71 
     72 // static
     73 scoped_refptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen(
     74     const std::string& path,
     75     StreamListenSocket::Delegate* del,
     76     const AuthCallback& auth_callback) {
     77   return CreateAndListenInternal(path, "", del, auth_callback, false);
     78 }
     79 
     80 #if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
     81 // static
     82 scoped_refptr<UnixDomainSocket>
     83 UnixDomainSocket::CreateAndListenWithAbstractNamespace(
     84     const std::string& path,
     85     const std::string& fallback_path,
     86     StreamListenSocket::Delegate* del,
     87     const AuthCallback& auth_callback) {
     88   return make_scoped_refptr(
     89       CreateAndListenInternal(path, fallback_path, del, auth_callback, true));
     90 }
     91 #endif
     92 
     93 UnixDomainSocket::UnixDomainSocket(
     94     SocketDescriptor s,
     95     StreamListenSocket::Delegate* del,
     96     const AuthCallback& auth_callback)
     97     : StreamListenSocket(s, del),
     98       auth_callback_(auth_callback) {}
     99 
    100 UnixDomainSocket::~UnixDomainSocket() {}
    101 
    102 // static
    103 SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path,
    104                                                  bool use_abstract_namespace) {
    105   sockaddr_un addr;
    106   static const size_t kPathMax = sizeof(addr.sun_path);
    107   if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax)
    108     return kInvalidSocket;
    109   const SocketDescriptor s = socket(PF_UNIX, SOCK_STREAM, 0);
    110   if (s == kInvalidSocket)
    111     return kInvalidSocket;
    112   memset(&addr, 0, sizeof(addr));
    113   addr.sun_family = AF_UNIX;
    114   socklen_t addr_len;
    115   if (use_abstract_namespace) {
    116     // Convert the path given into abstract socket name. It must start with
    117     // the '\0' character, so we are adding it. |addr_len| must specify the
    118     // length of the structure exactly, as potentially the socket name may
    119     // have '\0' characters embedded (although we don't support this).
    120     // Note that addr.sun_path is already zero initialized.
    121     memcpy(addr.sun_path + 1, path.c_str(), path.size());
    122     addr_len = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
    123   } else {
    124     memcpy(addr.sun_path, path.c_str(), path.size());
    125     addr_len = sizeof(sockaddr_un);
    126   }
    127   if (bind(s, reinterpret_cast<sockaddr*>(&addr), addr_len)) {
    128     LOG(ERROR) << "Could not bind unix domain socket to " << path;
    129     if (use_abstract_namespace)
    130       LOG(ERROR) << " (with abstract namespace enabled)";
    131     if (HANDLE_EINTR(close(s)) < 0)
    132       LOG(ERROR) << "close() error";
    133     return kInvalidSocket;
    134   }
    135   return s;
    136 }
    137 
    138 void UnixDomainSocket::Accept() {
    139   SocketDescriptor conn = StreamListenSocket::AcceptSocket();
    140   if (conn == kInvalidSocket)
    141     return;
    142   uid_t user_id;
    143   gid_t group_id;
    144   if (!GetPeerIds(conn, &user_id, &group_id) ||
    145       !auth_callback_.Run(user_id, group_id)) {
    146     if (HANDLE_EINTR(close(conn)) < 0)
    147       LOG(ERROR) << "close() error";
    148     return;
    149   }
    150   scoped_refptr<UnixDomainSocket> sock(
    151       new UnixDomainSocket(conn, socket_delegate_, auth_callback_));
    152   // It's up to the delegate to AddRef if it wants to keep it around.
    153   sock->WatchSocket(WAITING_READ);
    154   socket_delegate_->DidAccept(this, sock.get());
    155 }
    156 
    157 UnixDomainSocketFactory::UnixDomainSocketFactory(
    158     const std::string& path,
    159     const UnixDomainSocket::AuthCallback& auth_callback)
    160     : path_(path),
    161       auth_callback_(auth_callback) {}
    162 
    163 UnixDomainSocketFactory::~UnixDomainSocketFactory() {}
    164 
    165 scoped_refptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen(
    166     StreamListenSocket::Delegate* delegate) const {
    167   return UnixDomainSocket::CreateAndListen(
    168       path_, delegate, auth_callback_);
    169 }
    170 
    171 #if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
    172 
    173 UnixDomainSocketWithAbstractNamespaceFactory::
    174 UnixDomainSocketWithAbstractNamespaceFactory(
    175     const std::string& path,
    176     const std::string& fallback_path,
    177     const UnixDomainSocket::AuthCallback& auth_callback)
    178     : UnixDomainSocketFactory(path, auth_callback),
    179       fallback_path_(fallback_path) {}
    180 
    181 UnixDomainSocketWithAbstractNamespaceFactory::
    182 ~UnixDomainSocketWithAbstractNamespaceFactory() {}
    183 
    184 scoped_refptr<StreamListenSocket>
    185 UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen(
    186     StreamListenSocket::Delegate* delegate) const {
    187   return UnixDomainSocket::CreateAndListenWithAbstractNamespace(
    188       path_, fallback_path_, delegate, auth_callback_);
    189 }
    190 
    191 #endif
    192 
    193 }  // namespace net
    194