Home | History | Annotate | Download | only in forwarder2
      1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "tools/android/forwarder2/socket.h"
      6 
      7 #include <arpa/inet.h>
      8 #include <fcntl.h>
      9 #include <netdb.h>
     10 #include <netinet/in.h>
     11 #include <stdio.h>
     12 #include <string.h>
     13 #include <sys/socket.h>
     14 #include <sys/types.h>
     15 #include <unistd.h>
     16 
     17 #include "base/logging.h"
     18 #include "base/posix/eintr_wrapper.h"
     19 #include "base/safe_strerror_posix.h"
     20 #include "tools/android/common/net.h"
     21 #include "tools/android/forwarder2/common.h"
     22 
     23 namespace {
     24 const int kNoTimeout = -1;
     25 const int kConnectTimeOut = 10;  // Seconds.
     26 
     27 bool FamilyIsTCP(int family) {
     28   return family == AF_INET || family == AF_INET6;
     29 }
     30 }  // namespace
     31 
     32 namespace forwarder2 {
     33 
     34 bool Socket::BindUnix(const std::string& path) {
     35   errno = 0;
     36   if (!InitUnixSocket(path) || !BindAndListen()) {
     37     Close();
     38     return false;
     39   }
     40   return true;
     41 }
     42 
     43 bool Socket::BindTcp(const std::string& host, int port) {
     44   errno = 0;
     45   if (!InitTcpSocket(host, port) || !BindAndListen()) {
     46     Close();
     47     return false;
     48   }
     49   return true;
     50 }
     51 
     52 bool Socket::ConnectUnix(const std::string& path) {
     53   errno = 0;
     54   if (!InitUnixSocket(path) || !Connect()) {
     55     Close();
     56     return false;
     57   }
     58   return true;
     59 }
     60 
     61 bool Socket::ConnectTcp(const std::string& host, int port) {
     62   errno = 0;
     63   if (!InitTcpSocket(host, port) || !Connect()) {
     64     Close();
     65     return false;
     66   }
     67   return true;
     68 }
     69 
     70 Socket::Socket()
     71     : socket_(-1),
     72       port_(0),
     73       socket_error_(false),
     74       family_(AF_INET),
     75       addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)),
     76       addr_len_(sizeof(sockaddr)) {
     77   memset(&addr_, 0, sizeof(addr_));
     78 }
     79 
     80 Socket::~Socket() {
     81   Close();
     82 }
     83 
     84 void Socket::Shutdown() {
     85   if (!IsClosed()) {
     86     PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR));
     87   }
     88 }
     89 
     90 void Socket::Close() {
     91   if (!IsClosed()) {
     92     CloseFD(socket_);
     93     socket_ = -1;
     94   }
     95 }
     96 
     97 bool Socket::InitSocketInternal() {
     98   socket_ = socket(family_, SOCK_STREAM, 0);
     99   if (socket_ < 0)
    100     return false;
    101   tools::DisableNagle(socket_);
    102   int reuse_addr = 1;
    103   setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR,
    104              &reuse_addr, sizeof(reuse_addr));
    105   return true;
    106 }
    107 
    108 bool Socket::InitUnixSocket(const std::string& path) {
    109   static const size_t kPathMax = sizeof(addr_.addr_un.sun_path);
    110   // For abstract sockets we need one extra byte for the leading zero.
    111   if (path.size() + 2 /* '\0' */ > kPathMax) {
    112     LOG(ERROR) << "The provided path is too big to create a unix "
    113                << "domain socket: " << path;
    114     return false;
    115   }
    116   family_ = PF_UNIX;
    117   addr_.addr_un.sun_family = family_;
    118   // Copied from net/socket/unix_domain_socket_posix.cc
    119   // Convert the path given into abstract socket name. It must start with
    120   // the '\0' character, so we are adding it. |addr_len| must specify the
    121   // length of the structure exactly, as potentially the socket name may
    122   // have '\0' characters embedded (although we don't support this).
    123   // Note that addr_.addr_un.sun_path is already zero initialized.
    124   memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size());
    125   addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1;
    126   addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un);
    127   return InitSocketInternal();
    128 }
    129 
    130 bool Socket::InitTcpSocket(const std::string& host, int port) {
    131   port_ = port;
    132   if (host.empty()) {
    133     // Use localhost: INADDR_LOOPBACK
    134     family_ = AF_INET;
    135     addr_.addr4.sin_family = family_;
    136     addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
    137   } else if (!Resolve(host)) {
    138     return false;
    139   }
    140   CHECK(FamilyIsTCP(family_)) << "Invalid socket family.";
    141   if (family_ == AF_INET) {
    142     addr_.addr4.sin_port = htons(port_);
    143     addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4);
    144     addr_len_ = sizeof(addr_.addr4);
    145   } else if (family_ == AF_INET6) {
    146     addr_.addr6.sin6_port = htons(port_);
    147     addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6);
    148     addr_len_ = sizeof(addr_.addr6);
    149   }
    150   return InitSocketInternal();
    151 }
    152 
    153 bool Socket::BindAndListen() {
    154   errno = 0;
    155   if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 ||
    156       HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) {
    157     SetSocketError();
    158     return false;
    159   }
    160   if (port_ == 0 && FamilyIsTCP(family_)) {
    161     SockAddr addr;
    162     memset(&addr, 0, sizeof(addr));
    163     socklen_t addrlen = 0;
    164     sockaddr* addr_ptr = NULL;
    165     uint16* port_ptr = NULL;
    166     if (family_ == AF_INET) {
    167       addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4);
    168       port_ptr = &addr.addr4.sin_port;
    169       addrlen = sizeof(addr.addr4);
    170     } else if (family_ == AF_INET6) {
    171       addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6);
    172       port_ptr = &addr.addr6.sin6_port;
    173       addrlen = sizeof(addr.addr6);
    174     }
    175     errno = 0;
    176     if (getsockname(socket_, addr_ptr, &addrlen) != 0) {
    177       LOG(ERROR) << "getsockname error: " << safe_strerror(errno);;
    178       SetSocketError();
    179       return false;
    180     }
    181     port_ = ntohs(*port_ptr);
    182   }
    183   return true;
    184 }
    185 
    186 bool Socket::Accept(Socket* new_socket) {
    187   DCHECK(new_socket != NULL);
    188   if (!WaitForEvent(READ, kNoTimeout)) {
    189     SetSocketError();
    190     return false;
    191   }
    192   errno = 0;
    193   int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
    194   if (new_socket_fd < 0) {
    195     SetSocketError();
    196     return false;
    197   }
    198 
    199   tools::DisableNagle(new_socket_fd);
    200   new_socket->socket_ = new_socket_fd;
    201   return true;
    202 }
    203 
    204 bool Socket::Connect() {
    205   // Set non-block because we use select for connect.
    206   const int kFlags = fcntl(socket_, F_GETFL);
    207   DCHECK(!(kFlags & O_NONBLOCK));
    208   fcntl(socket_, F_SETFL, kFlags | O_NONBLOCK);
    209   errno = 0;
    210   if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
    211       errno != EINPROGRESS) {
    212     SetSocketError();
    213     PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
    214     return false;
    215   }
    216   // Wait for connection to complete, or receive a notification.
    217   if (!WaitForEvent(WRITE, kConnectTimeOut)) {
    218     SetSocketError();
    219     PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
    220     return false;
    221   }
    222   int socket_errno;
    223   socklen_t opt_len = sizeof(socket_errno);
    224   if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
    225     LOG(ERROR) << "getsockopt(): " << safe_strerror(errno);
    226     SetSocketError();
    227     PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
    228     return false;
    229   }
    230   if (socket_errno != 0) {
    231     LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
    232     SetSocketError();
    233     PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags));
    234     return false;
    235   }
    236   fcntl(socket_, F_SETFL, kFlags);
    237   return true;
    238 }
    239 
    240 bool Socket::Resolve(const std::string& host) {
    241   struct addrinfo hints;
    242   struct addrinfo* res;
    243   memset(&hints, 0, sizeof(hints));
    244   hints.ai_family = AF_UNSPEC;
    245   hints.ai_socktype = SOCK_STREAM;
    246   hints.ai_flags |= AI_CANONNAME;
    247 
    248   int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
    249   if (errcode != 0) {
    250     SetSocketError();
    251     freeaddrinfo(res);
    252     return false;
    253   }
    254   family_ = res->ai_family;
    255   switch (res->ai_family) {
    256     case AF_INET:
    257       memcpy(&addr_.addr4,
    258              reinterpret_cast<sockaddr_in*>(res->ai_addr),
    259              sizeof(sockaddr_in));
    260       break;
    261     case AF_INET6:
    262       memcpy(&addr_.addr6,
    263              reinterpret_cast<sockaddr_in6*>(res->ai_addr),
    264              sizeof(sockaddr_in6));
    265       break;
    266   }
    267   freeaddrinfo(res);
    268   return true;
    269 }
    270 
    271 int Socket::GetPort() {
    272   if (!FamilyIsTCP(family_)) {
    273     LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
    274     return 0;
    275   }
    276   return port_;
    277 }
    278 
    279 bool Socket::IsFdInSet(const fd_set& fds) const {
    280   if (IsClosed())
    281     return false;
    282   return FD_ISSET(socket_, &fds);
    283 }
    284 
    285 bool Socket::AddFdToSet(fd_set* fds) const {
    286   if (IsClosed())
    287     return false;
    288   FD_SET(socket_, fds);
    289   return true;
    290 }
    291 
    292 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
    293   int bytes_read = 0;
    294   int ret = 1;
    295   while (bytes_read < num_bytes && ret > 0) {
    296     ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
    297     if (ret >= 0)
    298       bytes_read += ret;
    299   }
    300   return bytes_read;
    301 }
    302 
    303 void Socket::SetSocketError() {
    304   socket_error_ = true;
    305   // We never use non-blocking socket.
    306   DCHECK(errno != EAGAIN && errno != EWOULDBLOCK);
    307   Close();
    308 }
    309 
    310 int Socket::Read(void* buffer, size_t buffer_size) {
    311   if (!WaitForEvent(READ, kNoTimeout)) {
    312     SetSocketError();
    313     return 0;
    314   }
    315   int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
    316   if (ret < 0)
    317     SetSocketError();
    318   return ret;
    319 }
    320 
    321 int Socket::Write(const void* buffer, size_t count) {
    322   int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
    323   if (ret < 0)
    324     SetSocketError();
    325   return ret;
    326 }
    327 
    328 int Socket::WriteString(const std::string& buffer) {
    329   return WriteNumBytes(buffer.c_str(), buffer.size());
    330 }
    331 
    332 void Socket::AddEventFd(int event_fd) {
    333   Event event;
    334   event.fd = event_fd;
    335   event.was_fired = false;
    336   events_.push_back(event);
    337 }
    338 
    339 bool Socket::DidReceiveEventOnFd(int fd) const {
    340   for (size_t i = 0; i < events_.size(); ++i)
    341     if (events_[i].fd == fd)
    342       return events_[i].was_fired;
    343   return false;
    344 }
    345 
    346 bool Socket::DidReceiveEvent() const {
    347   for (size_t i = 0; i < events_.size(); ++i)
    348     if (events_[i].was_fired)
    349       return true;
    350   return false;
    351 }
    352 
    353 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
    354   int bytes_written = 0;
    355   int ret = 1;
    356   while (bytes_written < num_bytes && ret > 0) {
    357     ret = Write(static_cast<const char*>(buffer) + bytes_written,
    358                 num_bytes - bytes_written);
    359     if (ret >= 0)
    360       bytes_written += ret;
    361   }
    362   return bytes_written;
    363 }
    364 
    365 bool Socket::WaitForEvent(EventType type, int timeout_secs) {
    366   if (events_.empty() || socket_ == -1)
    367     return true;
    368   fd_set read_fds;
    369   fd_set write_fds;
    370   FD_ZERO(&read_fds);
    371   FD_ZERO(&write_fds);
    372   if (type == READ)
    373     FD_SET(socket_, &read_fds);
    374   else
    375     FD_SET(socket_, &write_fds);
    376   for (size_t i = 0; i < events_.size(); ++i)
    377     FD_SET(events_[i].fd, &read_fds);
    378   timeval tv = {};
    379   timeval* tv_ptr = NULL;
    380   if (timeout_secs > 0) {
    381     tv.tv_sec = timeout_secs;
    382     tv.tv_usec = 0;
    383     tv_ptr = &tv;
    384   }
    385   int max_fd = socket_;
    386   for (size_t i = 0; i < events_.size(); ++i)
    387     if (events_[i].fd > max_fd)
    388       max_fd = events_[i].fd;
    389   if (HANDLE_EINTR(
    390           select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
    391     return false;
    392   }
    393   bool event_was_fired = false;
    394   for (size_t i = 0; i < events_.size(); ++i) {
    395     if (FD_ISSET(events_[i].fd, &read_fds)) {
    396       events_[i].was_fired = true;
    397       event_was_fired = true;
    398     }
    399   }
    400   return !event_was_fired;
    401 }
    402 
    403 // static
    404 int Socket::GetHighestFileDescriptor(const Socket& s1, const Socket& s2) {
    405   return std::max(s1.socket_, s2.socket_);
    406 }
    407 
    408 // static
    409 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
    410   Socket socket;
    411   if (!socket.ConnectUnix(path))
    412     return -1;
    413   ucred ucred;
    414   socklen_t len = sizeof(ucred);
    415   if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
    416     CHECK_NE(ENOPROTOOPT, errno);
    417     return -1;
    418   }
    419   return ucred.pid;
    420 }
    421 
    422 }  // namespace forwarder2
    423