Home | History | Annotate | Download | only in forwarder2
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
#include "tools/android/forwarder2/socket.h"

#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include "base/logging.h"
#include "base/posix/eintr_wrapper.h"
#include "base/safe_strerror_posix.h"
#include "tools/android/common/net.h"
#include "tools/android/forwarder2/common.h"
     22 
     23 namespace {
     24 const int kNoTimeout = -1;
     25 const int kConnectTimeOut = 10;  // Seconds.
     26 
     27 bool FamilyIsTCP(int family) {
     28   return family == AF_INET || family == AF_INET6;
     29 }
     30 }  // namespace
     31 
     32 namespace forwarder2 {
     33 
     34 bool Socket::BindUnix(const std::string& path) {
     35   errno = 0;
     36   if (!InitUnixSocket(path) || !BindAndListen()) {
     37     Close();
     38     return false;
     39   }
     40   return true;
     41 }
     42 
     43 bool Socket::BindTcp(const std::string& host, int port) {
     44   errno = 0;
     45   if (!InitTcpSocket(host, port) || !BindAndListen()) {
     46     Close();
     47     return false;
     48   }
     49   return true;
     50 }
     51 
     52 bool Socket::ConnectUnix(const std::string& path) {
     53   errno = 0;
     54   if (!InitUnixSocket(path) || !Connect()) {
     55     Close();
     56     return false;
     57   }
     58   return true;
     59 }
     60 
     61 bool Socket::ConnectTcp(const std::string& host, int port) {
     62   errno = 0;
     63   if (!InitTcpSocket(host, port) || !Connect()) {
     64     Close();
     65     return false;
     66   }
     67   return true;
     68 }
     69 
     70 Socket::Socket()
     71     : socket_(-1),
     72       port_(0),
     73       socket_error_(false),
     74       family_(AF_INET),
     75       addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
     76       addr_len_(sizeof(sockaddr)) {
     77   memset(&addr_, 0, sizeof(addr_));
     78 }
     79 
     80 Socket::~Socket() {
     81   Close();
     82 }
     83 
     84 void Socket::Shutdown() {
     85   if (!IsClosed()) {
     86     PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
     87   }
     88 }
     89 
     90 void Socket::Close() {
     91   if (!IsClosed()) {
     92     CloseFD(socket_);
     93     socket_ = -1;
     94   }
     95 }
     96 
     97 bool Socket::InitSocketInternal() {
     98   socket_ = socket(family_, SOCK_STREAM, 0);
     99   if (socket_ < 0)
    100     return false;
    101   tools::DisableNagle(socket_);
    102   int reuse_addr = 1;
    103   setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr,
    104              sizeof(reuse_addr));
    105   if (!SetNonBlocking())
    106     return false;
    107   return true;
    108 }
    109 
    110 bool Socket::SetNonBlocking() {
    111   const int flags = fcntl(socket_, F_GETFL);
    112   if (flags < 0) {
    113     PLOG(ERROR) << "fcntl";
    114     return false;
    115   }
    116   if (flags & O_NONBLOCK)
    117     return true;
    118   if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) {
    119     PLOG(ERROR) << "fcntl";
    120     return false;
    121   }
    122   return true;
    123 }
    124 
    125 bool Socket::InitUnixSocket(const std::string& path) {
    126   static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
    127   // For abstract sockets we need one extra byte for the leading zero.
    128   if (path.size() + 2 /* '\0' */ > kPathMax) {
    129     LOG(ERROR) << "The provided path is too big to create a unix "
    130                << "domain socket: " << path;
    131     return false;
    132   }
    133   family_ = PF_UNIX;
    134   addr_.addr_un.sun_family = family_;
    135   // Copied from net/socket/unix_domain_socket_posix.cc
    136   // Convert the path given into abstract socket name. It must start with
    137   // the '\0' character, so we are adding it. |addr_len| must specify the
    138   // length of the structure exactly, as potentially the socket name may
    139   // have '\0' characters embedded (although we don't support this).
    140   // Note that addr_.addr_un.sun_path is already zero initialized.
    141   memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
    142   addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
    143   addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
    144   return InitSocketInternal();
    145 }
    146 
    147 bool Socket::InitTcpSocket(const std::string& host, int port) {
    148   port_ = port;
    149   if (host.empty()) {
    150     // Use localhost: INADDR_LOOPBACK
    151     family_ = AF_INET;
    152     addr_.addr4.sin_family = family_;
    153     addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    154   } else if (!Resolve(host)) {
    155     return false;
    156   }
    157   CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
    158   if (family_ == AF_INET) {
    159     addr_.addr4.sin_port = htons(port_);
    160     addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
    161     addr_len_ = sizeof(addr_.addr4);
    162   } else if (family_ == AF_INET6) {
    163     addr_.addr6.sin6_port = htons(port_);
    164     addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
    165     addr_len_ = sizeof(addr_.addr6);
    166   }
    167   return InitSocketInternal();
    168 }
    169 
    170 bool Socket::BindAndListen() {
    171   errno = 0;
    172   if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
    173       HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
    174     SetSocketError();
    175     return false;
    176   }
    177   if (port_ == 0 && FamilyIsTCP(family_)) {
    178     SockAddr addr;
    179     memset(&addr, 0, sizeof(addr));
    180     socklen_t addrlen = 0;
    181     sockaddr* addr_ptr = NULL;
    182     uint16* port_ptr = NULL;
    183     if (family_ == AF_INET) {
    184       addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
    185       port_ptr = &addr.addr4.sin_port;
    186       addrlen = sizeof(addr.addr4);
    187     } else if (family_ == AF_INET6) {
    188       addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
    189       port_ptr = &addr.addr6.sin6_port;
    190       addrlen = sizeof(addr.addr6);
    191     }
    192     errno = 0;
    193     if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
    194       PLOG(ERROR) << "getsockname";
    195       SetSocketError();
    196       return false;
    197     }
    198     port_ = ntohs(*port_ptr);
    199   }
    200   return true;
    201 }
    202 
    203 bool Socket::Accept(Socket* new_socket) {
    204   DCHECK(new_socket != NULL);
    205   if (!WaitForEvent(READ, kNoTimeout)) {
    206     SetSocketError();
    207     return false;
    208   }
    209   errno = 0;
    210   int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
    211   if (new_socket_fd < 0) {
    212     SetSocketError();
    213     return false;
    214   }
    215   tools::DisableNagle(new_socket_fd);
    216   new_socket->socket_ = new_socket_fd;
    217   if (!new_socket->SetNonBlocking())
    218     return false;
    219   return true;
    220 }
    221 
    222 bool Socket::Connect() {
    223   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
    224   errno = 0;
    225   if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
    226       errno != EINPROGRESS) {
    227     SetSocketError();
    228     return false;
    229   }
    230   // Wait for connection to complete, or receive a notification.
    231   if (!WaitForEvent(WRITE, kConnectTimeOut)) {
    232     SetSocketError();
    233     return false;
    234   }
    235   int socket_errno;
    236   socklen_t opt_len = sizeof(socket_errno);
    237   if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
    238     PLOG(ERROR) << "getsockopt()";
    239     SetSocketError();
    240     return false;
    241   }
    242   if (socket_errno != 0) {
    243     LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
    244     SetSocketError();
    245     return false;
    246   }
    247   return true;
    248 }
    249 
    250 bool Socket::Resolve(const std::string& host) {
    251   struct addrinfo hints;
    252   struct addrinfo* res;
    253   memset(&hints, 0, sizeof(hints));
    254   hints.ai_family = AF_UNSPEC;
    255   hints.ai_socktype = SOCK_STREAM;
    256   hints.ai_flags |= AI_CANONNAME;
    257 
    258   int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
    259   if (errcode != 0) {
    260     errno = 0;
    261     SetSocketError();
    262     freeaddrinfo(res);
    263     return false;
    264   }
    265   family_ = res->ai_family;
    266   switch (res->ai_family) {
    267     case AF_INET:
    268       memcpy(&addr_.addr4,
    269              reinterpret_cast<sockaddr_in*>(res->ai_addr),
    270              sizeof(sockaddr_in));
    271       break;
    272     case AF_INET6:
    273       memcpy(&addr_.addr6,
    274              reinterpret_cast<sockaddr_in6*>(res->ai_addr),
    275              sizeof(sockaddr_in6));
    276       break;
    277   }
    278   freeaddrinfo(res);
    279   return true;
    280 }
    281 
    282 int Socket::GetPort() {
    283   if (!FamilyIsTCP(family_)) {
    284     LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
    285     return 0;
    286   }
    287   return port_;
    288 }
    289 
    290 bool Socket::IsFdInSet(const fd_set& fds) const {
    291   if (IsClosed())
    292     return false;
    293   return FD_ISSET(socket_, &fds);
    294 }
    295 
    296 bool Socket::AddFdToSet(fd_set* fds) const {
    297   if (IsClosed())
    298     return false;
    299   FD_SET(socket_, fds);
    300   return true;
    301 }
    302 
    303 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
    304   int bytes_read = 0;
    305   int ret = 1;
    306   while (bytes_read < num_bytes && ret > 0) {
    307     ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
    308     if (ret >= 0)
    309       bytes_read += ret;
    310   }
    311   return bytes_read;
    312 }
    313 
    314 void Socket::SetSocketError() {
    315   socket_error_ = true;
    316   DCHECK_NE(EAGAIN, errno);
    317   DCHECK_NE(EWOULDBLOCK, errno);
    318   Close();
    319 }
    320 
    321 int Socket::Read(void* buffer, size_t buffer_size) {
    322   if (!WaitForEvent(READ, kNoTimeout)) {
    323     SetSocketError();
    324     return 0;
    325   }
    326   int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
    327   if (ret < 0) {
    328     PLOG(ERROR) << "read";
    329     SetSocketError();
    330   }
    331   return ret;
    332 }
    333 
    334 int Socket::NonBlockingRead(void* buffer, size_t buffer_size) {
    335   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
    336   int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
    337   if (ret < 0) {
    338     PLOG(ERROR) << "read";
    339     SetSocketError();
    340   }
    341   return ret;
    342 }
    343 
    344 int Socket::Write(const void* buffer, size_t count) {
    345   if (!WaitForEvent(WRITE, kNoTimeout)) {
    346     SetSocketError();
    347     return 0;
    348   }
    349   int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
    350   if (ret < 0) {
    351     PLOG(ERROR) << "send";
    352     SetSocketError();
    353   }
    354   return ret;
    355 }
    356 
    357 int Socket::NonBlockingWrite(const void* buffer, size_t count) {
    358   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
    359   int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
    360   if (ret < 0) {
    361     PLOG(ERROR) << "send";
    362     SetSocketError();
    363   }
    364   return ret;
    365 }
    366 
    367 int Socket::WriteString(const std::string& buffer) {
    368   return WriteNumBytes(buffer.c_str(), buffer.size());
    369 }
    370 
    371 void Socket::AddEventFd(int event_fd) {
    372   Event event;
    373   event.fd = event_fd;
    374   event.was_fired = false;
    375   events_.push_back(event);
    376 }
    377 
    378 bool Socket::DidReceiveEventOnFd(int fd) const {
    379   for (size_t i = 0; i < events_.size(); ++i)
    380     if (events_[i].fd == fd)
    381       return events_[i].was_fired;
    382   return false;
    383 }
    384 
    385 bool Socket::DidReceiveEvent() const {
    386   for (size_t i = 0; i < events_.size(); ++i)
    387     if (events_[i].was_fired)
    388       return true;
    389   return false;
    390 }
    391 
    392 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
    393   int bytes_written = 0;
    394   int ret = 1;
    395   while (bytes_written < num_bytes && ret > 0) {
    396     ret = Write(static_cast<const char*>(buffer) + bytes_written,
    397                 num_bytes - bytes_written);
    398     if (ret >= 0)
    399       bytes_written += ret;
    400   }
    401   return bytes_written;
    402 }
    403 
    404 bool Socket::WaitForEvent(EventType type, int timeout_secs) {
    405   if (socket_ == -1)
    406     return true;
    407   DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
    408   fd_set read_fds;
    409   fd_set write_fds;
    410   FD_ZERO(&read_fds);
    411   FD_ZERO(&write_fds);
    412   if (type == READ)
    413     FD_SET(socket_, &read_fds);
    414   else
    415     FD_SET(socket_, &write_fds);
    416   for (size_t i = 0; i < events_.size(); ++i)
    417     FD_SET(events_[i].fd, &read_fds);
    418   timeval tv = {};
    419   timeval* tv_ptr = NULL;
    420   if (timeout_secs > 0) {
    421     tv.tv_sec = timeout_secs;
    422     tv.tv_usec = 0;
    423     tv_ptr = &tv;
    424   }
    425   int max_fd = socket_;
    426   for (size_t i = 0; i < events_.size(); ++i)
    427     if (events_[i].fd > max_fd)
    428       max_fd = events_[i].fd;
    429   if (HANDLE_EINTR(
    430           select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
    431     PLOG(ERROR) << "select";
    432     return false;
    433   }
    434   bool event_was_fired = false;
    435   for (size_t i = 0; i < events_.size(); ++i) {
    436     if (FD_ISSET(events_[i].fd, &read_fds)) {
    437       events_[i].was_fired = true;
    438       event_was_fired = true;
    439     }
    440   }
    441   return !event_was_fired;
    442 }
    443 
    444 // static
    445 int Socket::GetHighestFileDescriptor(const Socket& s1, const Socket& s2) {
    446   return std::max(s1.socket_, s2.socket_);
    447 }
    448 
    449 // static
    450 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
    451   Socket socket;
    452   if (!socket.ConnectUnix(path))
    453     return -1;
    454   ucred ucred;
    455   socklen_t len = sizeof(ucred);
    456   if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
    457     CHECK_NE(ENOPROTOOPT, errno);
    458     return -1;
    459   }
    460   return ucred.pid;
    461 }
    462 
    463 }  // namespace forwarder2
    464