Home | History | Annotate | Download | only in forwarder2
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "tools/android/forwarder2/socket.h"
      6 
      7 #include <arpa/inet.h>
      8 #include <fcntl.h>
      9 #include <netdb.h>
     10 #include <netinet/in.h>
     11 #include <stdio.h>
     12 #include <string.h>
     13 #include <sys/socket.h>
     14 #include <sys/types.h>
     15 #include <unistd.h>
     16 
     17 #include "base/logging.h"
     18 #include "base/posix/eintr_wrapper.h"
     19 #include "base/safe_strerror_posix.h"
     20 #include "tools/android/common/net.h"
     21 #include "tools/android/forwarder2/common.h"
     22 
     namespace {

     // Timeout value for WaitForEvent() meaning "block until ready".
     const int kNoTimeout = -1;
     // Maximum time Connect() waits for a pending connection to complete.
     const int kConnectTimeOut = 10;  // Seconds.

     // Returns true iff |family| is an IP address family (IPv4 or IPv6), i.e. the
     // socket is a TCP socket rather than a unix domain socket.
     bool FamilyIsTCP(int family) {
       switch (family) {
         case AF_INET:
         case AF_INET6:
           return true;
         default:
           return false;
       }
     }

     }  // namespace
     31 
     32 namespace forwarder2 {
     33 
     34 bool Socket::BindUnix(const std::string& path) {
     35   errno = 0;
     36   if (!InitUnixSocket(path) || !BindAndListen()) {
     37     Close();
     38     return false;
     39   }
     40   return true;
     41 }
     42 
     43 bool Socket::BindTcp(const std::string& host, int port) {
     44   errno = 0;
     45   if (!InitTcpSocket(host, port) || !BindAndListen()) {
     46     Close();
     47     return false;
     48   }
     49   return true;
     50 }
     51 
     52 bool Socket::ConnectUnix(const std::string& path) {
     53   errno = 0;
     54   if (!InitUnixSocket(path) || !Connect()) {
     55     Close();
     56     return false;
     57   }
     58   return true;
     59 }
     60 
     61 bool Socket::ConnectTcp(const std::string& host, int port) {
     62   errno = 0;
     63   if (!InitTcpSocket(host, port) || !Connect()) {
     64     Close();
     65     return false;
     66   }
     67   return true;
     68 }
     69 
     70 Socket::Socket()
     71     : socket_(-1),
     72       port_(0),
     73       socket_error_(false),
     74       family_(AF_INET),
     75       addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
     76       addr_len_(sizeof(sockaddr)) {
     77   memset(&addr_, 0, sizeof(addr_));
     78 }
     79 
     80 Socket::~Socket() {
     81   Close();
     82 }
     83 
     84 void Socket::Shutdown() {
     85   if (!IsClosed()) {
     86     PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
     87   }
     88 }
     89 
     90 void Socket::Close() {
     91   if (!IsClosed()) {
     92     CloseFD(socket_);
     93     socket_ = -1;
     94   }
     95 }
     96 
     97 bool Socket::InitSocketInternal() {
     98   socket_ = socket(family_, SOCK_STREAM, 0);
     99   if (socket_ < 0) {
    100     PLOG(ERROR) << "socket";
    101     return false;
    102   }
    103   tools::DisableNagle(socket_);
    104   int reuse_addr = 1;
    105   setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
    106              sizeof(reuse_addr));
    107   if (!SetNonBlocking())
    108     return false;
    109   return true;
    110 }
    111 
    112 bool Socket::SetNonBlocking() {
    113   const int flags = fcntl(socket_, F_GETFL);
    114   if (flags < 0) {
    115     PLOG(ERROR) << "fcntl";
    116     return false;
    117   }
    118   if (flags & O_NONBLOCK)
    119     return true;
    120   if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) {
    121     PLOG(ERROR) << "fcntl";
    122     return false;
    123   }
    124   return true;
    125 }
    126 
    127 bool Socket::InitUnixSocket(const std::string& path) {
    128   static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
    129   // For abstract sockets we need one extra byte for the leading zero.
    130   if (path.size() + 2 /* '\0' */ > kPathMax) {
    131     LOG(ERROR) << "The provided path is too big to create a unix "
    132                << "domain socket: " << path;
    133     return false;
    134   }
    135   family_ = PF_UNIX;
    136   addr_.addr_un.sun_family = family_;
    137   // Copied from net/socket/unix_domain_socket_posix.cc
    138   // Convert the path given into abstract socket name. It must start with
    139   // the '\0' character, so we are adding it. |addr_len| must specify the
    140   // length of the structure exactly, as potentially the socket name may
    141   // have '\0' characters embedded (although we don't support this).
    142   // Note that addr_.addr_un.sun_path is already zero initialized.
    143   memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
    144   addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
    145   addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
    146   return InitSocketInternal();
    147 }
    148 
    149 bool Socket::InitTcpSocket(const std::string& host, int port) {
    150   port_ = port;
    151   if (host.empty()) {
    152     // Use localhost: INADDR_LOOPBACK
    153     family_ = AF_INET;
    154     addr_.addr4.sin_family = family_;
    155     addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    156   } else if (!Resolve(host)) {
    157     return false;
    158   }
    159   CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
    160   if (family_ == AF_INET) {
    161     addr_.addr4.sin_port = htons(port_);
    162     addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
    163     addr_len_ = sizeof(addr_.addr4);
    164   } else if (family_ == AF_INET6) {
    165     addr_.addr6.sin6_port = htons(port_);
    166     addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
    167     addr_len_ = sizeof(addr_.addr6);
    168   }
    169   return InitSocketInternal();
    170 }
    171 
    172 bool Socket::BindAndListen() {
    173   errno = 0;
    174   if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
    175       HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
    176     PLOG(ERROR) << "bind/listen";
    177     SetSocketError();
    178     return false;
    179   }
    180   if (port_ == 0 && FamilyIsTCP(family_)) {
    181     SockAddr addr;
    182     memset(&addr, 0, sizeof(addr));
    183     socklen_t addrlen = 0;
    184     sockaddr* addr_ptr = NULL;
    185     uint16* port_ptr = NULL;
    186     if (family_ == AF_INET) {
    187       addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
    188       port_ptr = &addr.addr4.sin_port;
    189       addrlen = sizeof(addr.addr4);
    190     } else if (family_ == AF_INET6) {
    191       addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
    192       port_ptr = &addr.addr6.sin6_port;
    193       addrlen = sizeof(addr.addr6);
    194     }
    195     errno = 0;
    196     if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
    197       PLOG(ERROR) << "getsockname";
    198       SetSocketError();
    199       return false;
    200     }
    201     port_ = ntohs(*port_ptr);
    202   }
    203   return true;
    204 }
    205 
    206 bool Socket::Accept(Socket* new_socket) {
    207   DCHECK(new_socket != NULL);
    208   if (!WaitForEvent(READ, kNoTimeout)) {
    209     SetSocketError();
    210     return false;
    211   }
    212   errno = 0;
    213   int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
    214   if (new_socket_fd < 0) {
    215     SetSocketError();
    216     return false;
    217   }
    218   tools::DisableNagle(new_socket_fd);
    219   new_socket->socket_ = new_socket_fd;
    220   if (!new_socket->SetNonBlocking())
    221     return false;
    222   return true;
    223 }
    224 
    225 bool Socket::Connect() {
    226   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
    227   errno = 0;
    228   if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
    229       errno != EINPROGRESS) {
    230     SetSocketError();
    231     return false;
    232   }
    233   // Wait for connection to complete, or receive a notification.
    234   if (!WaitForEvent(WRITE, kConnectTimeOut)) {
    235     SetSocketError();
    236     return false;
    237   }
    238   int socket_errno;
    239   socklen_t opt_len = sizeof(socket_errno);
    240   if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
    241     PLOG(ERROR) << "getsockopt()";
    242     SetSocketError();
    243     return false;
    244   }
    245   if (socket_errno != 0) {
    246     LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
    247     SetSocketError();
    248     return false;
    249   }
    250   return true;
    251 }
    252 
    // Resolves |host| with getaddrinfo() (any address family, stream sockets)
    // and copies the first result into |addr_|, updating |family_|. Returns
    // false, marking the socket as failed, if resolution fails or the first
    // result is neither IPv4 nor IPv6.
    bool Socket::Resolve(const std::string& host) {
      struct addrinfo hints;
      memset(&hints, 0, sizeof(hints));
      hints.ai_family = AF_UNSPEC;
      hints.ai_socktype = SOCK_STREAM;
      hints.ai_flags |= AI_CANONNAME;

      struct addrinfo* res = NULL;
      const int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
      if (errcode != 0) {
        LOG(ERROR) << "getaddrinfo(" << host << "): " << gai_strerror(errcode);
        // |res| is not populated when getaddrinfo() fails, so it must not be
        // passed to freeaddrinfo() here.
        errno = 0;
        SetSocketError();
        return false;
      }
      switch (res->ai_family) {
        case AF_INET:
          memcpy(&addr_.addr4,
                 reinterpret_cast<sockaddr_in*>(res->ai_addr),
                 sizeof(sockaddr_in));
          break;
        case AF_INET6:
          memcpy(&addr_.addr6,
                 reinterpret_cast<sockaddr_in6*>(res->ai_addr),
                 sizeof(sockaddr_in6));
          break;
        default:
          // Report an error instead of storing a family that InitTcpSocket()
          // would reject with a CHECK failure.
          LOG(ERROR) << "Unsupported address family " << res->ai_family
                     << " for host " << host;
          freeaddrinfo(res);
          errno = 0;
          SetSocketError();
          return false;
      }
      family_ = res->ai_family;
      freeaddrinfo(res);
      return true;
    }
    284 
    285 int Socket::GetPort() {
    286   if (!FamilyIsTCP(family_)) {
    287     LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
    288     return 0;
    289   }
    290   return port_;
    291 }
    292 
    293 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
    294   int bytes_read = 0;
    295   int ret = 1;
    296   while (bytes_read < num_bytes && ret > 0) {
    297     ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
    298     if (ret >= 0)
    299       bytes_read += ret;
    300   }
    301   return bytes_read;
    302 }
    303 
    304 void Socket::SetSocketError() {
    305   socket_error_ = true;
    306   DCHECK_NE(EAGAIN, errno);
    307   DCHECK_NE(EWOULDBLOCK, errno);
    308   Close();
    309 }
    310 
    311 int Socket::Read(void* buffer, size_t buffer_size) {
    312   if (!WaitForEvent(READ, kNoTimeout)) {
    313     SetSocketError();
    314     return 0;
    315   }
    316   int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
    317   if (ret < 0) {
    318     PLOG(ERROR) << "read";
    319     SetSocketError();
    320   }
    321   return ret;
    322 }
    323 
    324 int Socket::NonBlockingRead(void* buffer, size_t buffer_size) {
    325   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
    326   int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
    327   if (ret < 0) {
    328     PLOG(ERROR) << "read";
    329     SetSocketError();
    330   }
    331   return ret;
    332 }
    333 
    334 int Socket::Write(const void* buffer, size_t count) {
    335   if (!WaitForEvent(WRITE, kNoTimeout)) {
    336     SetSocketError();
    337     return 0;
    338   }
    339   int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
    340   if (ret < 0) {
    341     PLOG(ERROR) << "send";
    342     SetSocketError();
    343   }
    344   return ret;
    345 }
    346 
    347 int Socket::NonBlockingWrite(const void* buffer, size_t count) {
    348   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
    349   int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
    350   if (ret < 0) {
    351     PLOG(ERROR) << "send";
    352     SetSocketError();
    353   }
    354   return ret;
    355 }
    356 
    357 int Socket::WriteString(const std::string& buffer) {
    358   return WriteNumBytes(buffer.c_str(), buffer.size());
    359 }
    360 
    361 void Socket::AddEventFd(int event_fd) {
    362   Event event;
    363   event.fd = event_fd;
    364   event.was_fired = false;
    365   events_.push_back(event);
    366 }
    367 
    368 bool Socket::DidReceiveEventOnFd(int fd) const {
    369   for (size_t i = 0; i < events_.size(); ++i)
    370     if (events_[i].fd == fd)
    371       return events_[i].was_fired;
    372   return false;
    373 }
    374 
    375 bool Socket::DidReceiveEvent() const {
    376   for (size_t i = 0; i < events_.size(); ++i)
    377     if (events_[i].was_fired)
    378       return true;
    379   return false;
    380 }
    381 
    382 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
    383   int bytes_written = 0;
    384   int ret = 1;
    385   while (bytes_written < num_bytes && ret > 0) {
    386     ret = Write(static_cast<const char*>(buffer) + bytes_written,
    387                 num_bytes - bytes_written);
    388     if (ret >= 0)
    389       bytes_written += ret;
    390   }
    391   return bytes_written;
    392 }
    393 
    394 bool Socket::WaitForEvent(EventType type, int timeout_secs) {
    395   if (socket_ == -1)
    396     return true;
    397   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
    398   fd_set read_fds;
    399   fd_set write_fds;
    400   FD_ZERO(&read_fds);
    401   FD_ZERO(&write_fds);
    402   if (type == READ)
    403     FD_SET(socket_, &read_fds);
    404   else
    405     FD_SET(socket_, &write_fds);
    406   for (size_t i = 0; i < events_.size(); ++i)
    407     FD_SET(events_[i].fd, &read_fds);
    408   timeval tv = {};
    409   timeval* tv_ptr = NULL;
    410   if (timeout_secs > 0) {
    411     tv.tv_sec = timeout_secs;
    412     tv.tv_usec = 0;
    413     tv_ptr = &tv;
    414   }
    415   int max_fd = socket_;
    416   for (size_t i = 0; i < events_.size(); ++i)
    417     if (events_[i].fd > max_fd)
    418       max_fd = events_[i].fd;
    419   if (HANDLE_EINTR(
    420           select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
    421     PLOG(ERROR) << "select";
    422     return false;
    423   }
    424   bool event_was_fired = false;
    425   for (size_t i = 0; i < events_.size(); ++i) {
    426     if (FD_ISSET(events_[i].fd, &read_fds)) {
    427       events_[i].was_fired = true;
    428       event_was_fired = true;
    429     }
    430   }
    431   return !event_was_fired;
    432 }
    433 
    434 // static
    435 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
    436   Socket socket;
    437   if (!socket.ConnectUnix(path))
    438     return -1;
    439   ucred ucred;
    440   socklen_t len = sizeof(ucred);
    441   if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
    442     CHECK_NE(ENOPROTOOPT, errno);
    443     return -1;
    444   }
    445   return ucred.pid;
    446 }
    447 
    448 }  // namespace forwarder2
    449