Home | History | Annotate | Download | only in socket
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "net/socket/unix_domain_socket_posix.h"
      6 
      7 #include <cstring>
      8 #include <string>
      9 
     10 #include <errno.h>
     11 #include <sys/socket.h>
     12 #include <sys/stat.h>
     13 #include <sys/types.h>
     14 #include <sys/un.h>
     15 #include <unistd.h>
     16 
     17 #include "base/bind.h"
     18 #include "base/callback.h"
     19 #include "base/posix/eintr_wrapper.h"
     20 #include "base/threading/platform_thread.h"
     21 #include "build/build_config.h"
     22 #include "net/base/net_errors.h"
     23 #include "net/base/net_util.h"
     24 #include "net/socket/socket_descriptor.h"
     25 
     26 namespace net {
     27 
     28 namespace {
     29 
     30 bool NoAuthenticationCallback(uid_t, gid_t) {
     31   return true;
     32 }
     33 
     34 bool GetPeerIds(int socket, uid_t* user_id, gid_t* group_id) {
     35 #if defined(OS_LINUX) || defined(OS_ANDROID)
     36   struct ucred user_cred;
     37   socklen_t len = sizeof(user_cred);
     38   if (getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &user_cred, &len) == -1)
     39     return false;
     40   *user_id = user_cred.uid;
     41   *group_id = user_cred.gid;
     42 #else
     43   if (getpeereid(socket, user_id, group_id) == -1)
     44     return false;
     45 #endif
     46   return true;
     47 }
     48 
     49 }  // namespace
     50 
     51 // static
     52 UnixDomainSocket::AuthCallback UnixDomainSocket::NoAuthentication() {
     53   return base::Bind(NoAuthenticationCallback);
     54 }
     55 
     56 // static
     57 scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListenInternal(
     58     const std::string& path,
     59     const std::string& fallback_path,
     60     StreamListenSocket::Delegate* del,
     61     const AuthCallback& auth_callback,
     62     bool use_abstract_namespace) {
     63   SocketDescriptor s = CreateAndBind(path, use_abstract_namespace);
     64   if (s == kInvalidSocket && !fallback_path.empty())
     65     s = CreateAndBind(fallback_path, use_abstract_namespace);
     66   if (s == kInvalidSocket)
     67     return scoped_ptr<UnixDomainSocket>();
     68   scoped_ptr<UnixDomainSocket> sock(
     69       new UnixDomainSocket(s, del, auth_callback));
     70   sock->Listen();
     71   return sock.Pass();
     72 }
     73 
     74 // static
     75 scoped_ptr<UnixDomainSocket> UnixDomainSocket::CreateAndListen(
     76     const std::string& path,
     77     StreamListenSocket::Delegate* del,
     78     const AuthCallback& auth_callback) {
     79   return CreateAndListenInternal(path, "", del, auth_callback, false);
     80 }
     81 
     82 #if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
     83 // static
     84 scoped_ptr<UnixDomainSocket>
     85 UnixDomainSocket::CreateAndListenWithAbstractNamespace(
     86     const std::string& path,
     87     const std::string& fallback_path,
     88     StreamListenSocket::Delegate* del,
     89     const AuthCallback& auth_callback) {
     90   return
     91       CreateAndListenInternal(path, fallback_path, del, auth_callback, true);
     92 }
     93 #endif
     94 
     95 UnixDomainSocket::UnixDomainSocket(
     96     SocketDescriptor s,
     97     StreamListenSocket::Delegate* del,
     98     const AuthCallback& auth_callback)
     99     : StreamListenSocket(s, del),
    100       auth_callback_(auth_callback) {}
    101 
    102 UnixDomainSocket::~UnixDomainSocket() {}
    103 
    104 // static
    105 SocketDescriptor UnixDomainSocket::CreateAndBind(const std::string& path,
    106                                                  bool use_abstract_namespace) {
    107   sockaddr_un addr;
    108   static const size_t kPathMax = sizeof(addr.sun_path);
    109   if (use_abstract_namespace + path.size() + 1 /* '\0' */ > kPathMax)
    110     return kInvalidSocket;
    111   const SocketDescriptor s = CreatePlatformSocket(PF_UNIX, SOCK_STREAM, 0);
    112   if (s == kInvalidSocket)
    113     return kInvalidSocket;
    114   memset(&addr, 0, sizeof(addr));
    115   addr.sun_family = AF_UNIX;
    116   socklen_t addr_len;
    117   if (use_abstract_namespace) {
    118     // Convert the path given into abstract socket name. It must start with
    119     // the '\0' character, so we are adding it. |addr_len| must specify the
    120     // length of the structure exactly, as potentially the socket name may
    121     // have '\0' characters embedded (although we don't support this).
    122     // Note that addr.sun_path is already zero initialized.
    123     memcpy(addr.sun_path + 1, path.c_str(), path.size());
    124     addr_len = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
    125   } else {
    126     memcpy(addr.sun_path, path.c_str(), path.size());
    127     addr_len = sizeof(sockaddr_un);
    128   }
    129   if (bind(s, reinterpret_cast<sockaddr*>(&addr), addr_len)) {
    130     LOG(ERROR) << "Could not bind unix domain socket to " << path;
    131     if (use_abstract_namespace)
    132       LOG(ERROR) << " (with abstract namespace enabled)";
    133     if (IGNORE_EINTR(close(s)) < 0)
    134       LOG(ERROR) << "close() error";
    135     return kInvalidSocket;
    136   }
    137   return s;
    138 }
    139 
    140 void UnixDomainSocket::Accept() {
    141   SocketDescriptor conn = StreamListenSocket::AcceptSocket();
    142   if (conn == kInvalidSocket)
    143     return;
    144   uid_t user_id;
    145   gid_t group_id;
    146   if (!GetPeerIds(conn, &user_id, &group_id) ||
    147       !auth_callback_.Run(user_id, group_id)) {
    148     if (IGNORE_EINTR(close(conn)) < 0)
    149       LOG(ERROR) << "close() error";
    150     return;
    151   }
    152   scoped_ptr<UnixDomainSocket> sock(
    153       new UnixDomainSocket(conn, socket_delegate_, auth_callback_));
    154   // It's up to the delegate to AddRef if it wants to keep it around.
    155   sock->WatchSocket(WAITING_READ);
    156   socket_delegate_->DidAccept(this, sock.PassAs<StreamListenSocket>());
    157 }
    158 
    159 UnixDomainSocketFactory::UnixDomainSocketFactory(
    160     const std::string& path,
    161     const UnixDomainSocket::AuthCallback& auth_callback)
    162     : path_(path),
    163       auth_callback_(auth_callback) {}
    164 
    165 UnixDomainSocketFactory::~UnixDomainSocketFactory() {}
    166 
    167 scoped_ptr<StreamListenSocket> UnixDomainSocketFactory::CreateAndListen(
    168     StreamListenSocket::Delegate* delegate) const {
    169   return UnixDomainSocket::CreateAndListen(
    170       path_, delegate, auth_callback_).PassAs<StreamListenSocket>();
    171 }
    172 
    173 #if defined(SOCKET_ABSTRACT_NAMESPACE_SUPPORTED)
    174 
    175 UnixDomainSocketWithAbstractNamespaceFactory::
    176 UnixDomainSocketWithAbstractNamespaceFactory(
    177     const std::string& path,
    178     const std::string& fallback_path,
    179     const UnixDomainSocket::AuthCallback& auth_callback)
    180     : UnixDomainSocketFactory(path, auth_callback),
    181       fallback_path_(fallback_path) {}
    182 
    183 UnixDomainSocketWithAbstractNamespaceFactory::
    184 ~UnixDomainSocketWithAbstractNamespaceFactory() {}
    185 
    186 scoped_ptr<StreamListenSocket>
    187 UnixDomainSocketWithAbstractNamespaceFactory::CreateAndListen(
    188     StreamListenSocket::Delegate* delegate) const {
    189   return UnixDomainSocket::CreateAndListenWithAbstractNamespace(
    190              path_, fallback_path_, delegate, auth_callback_)
    191          .PassAs<StreamListenSocket>();
    192 }
    193 
    194 #endif
    195 
    196 }  // namespace net
    197