1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "tools/android/forwarder2/socket.h" 6 7 #include <arpa/inet.h> 8 #include <fcntl.h> 9 #include <netdb.h> 10 #include <netinet/in.h> 11 #include <stdio.h> 12 #include <string.h> 13 #include <sys/socket.h> 14 #include <sys/types.h> 15 #include <unistd.h> 16 17 #include "base/logging.h" 18 #include "base/posix/eintr_wrapper.h" 19 #include "base/safe_strerror_posix.h" 20 #include "tools/android/common/net.h" 21 #include "tools/android/forwarder2/common.h" 22 23 namespace { 24 const int kNoTimeout = -1; 25 const int kConnectTimeOut = 10; // Seconds. 26 27 bool FamilyIsTCP(int family) { 28 return family == AF_INET || family == AF_INET6; 29 } 30 } // namespace 31 32 namespace forwarder2 { 33 34 bool Socket::BindUnix(const std::string& path) { 35 errno = 0; 36 if (!InitUnixSocket(path) || !BindAndListen()) { 37 Close(); 38 return false; 39 } 40 return true; 41 } 42 43 bool Socket::BindTcp(const std::string& host, int port) { 44 errno = 0; 45 if (!InitTcpSocket(host, port) || !BindAndListen()) { 46 Close(); 47 return false; 48 } 49 return true; 50 } 51 52 bool Socket::ConnectUnix(const std::string& path) { 53 errno = 0; 54 if (!InitUnixSocket(path) || !Connect()) { 55 Close(); 56 return false; 57 } 58 return true; 59 } 60 61 bool Socket::ConnectTcp(const std::string& host, int port) { 62 errno = 0; 63 if (!InitTcpSocket(host, port) || !Connect()) { 64 Close(); 65 return false; 66 } 67 return true; 68 } 69 70 Socket::Socket() 71 : socket_(-1), 72 port_(0), 73 socket_error_(false), 74 family_(AF_INET), 75 addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)), 76 addr_len_(sizeof(sockaddr)) { 77 memset(&addr_, 0, sizeof(addr_)); 78 } 79 80 Socket::~Socket() { 81 Close(); 82 } 83 84 void Socket::Shutdown() { 85 if (!IsClosed()) { 86 PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR)); 87 
} 88 } 89 90 void Socket::Close() { 91 if (!IsClosed()) { 92 CloseFD(socket_); 93 socket_ = -1; 94 } 95 } 96 97 bool Socket::InitSocketInternal() { 98 socket_ = socket(family_, SOCK_STREAM, 0); 99 if (socket_ < 0) 100 return false; 101 tools::DisableNagle(socket_); 102 int reuse_addr = 1; 103 setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, 104 &reuse_addr, sizeof(reuse_addr)); 105 return true; 106 } 107 108 bool Socket::InitUnixSocket(const std::string& path) { 109 static const size_t kPathMax = sizeof(addr_.addr_un.sun_path); 110 // For abstract sockets we need one extra byte for the leading zero. 111 if (path.size() + 2 /* '\0' */ > kPathMax) { 112 LOG(ERROR) << "The provided path is too big to create a unix " 113 << "domain socket: " << path; 114 return false; 115 } 116 family_ = PF_UNIX; 117 addr_.addr_un.sun_family = family_; 118 // Copied from net/socket/unix_domain_socket_posix.cc 119 // Convert the path given into abstract socket name. It must start with 120 // the '\0' character, so we are adding it. |addr_len| must specify the 121 // length of the structure exactly, as potentially the socket name may 122 // have '\0' characters embedded (although we don't support this). 123 // Note that addr_.addr_un.sun_path is already zero initialized. 
124 memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size()); 125 addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1; 126 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un); 127 return InitSocketInternal(); 128 } 129 130 bool Socket::InitTcpSocket(const std::string& host, int port) { 131 port_ = port; 132 if (host.empty()) { 133 // Use localhost: INADDR_LOOPBACK 134 family_ = AF_INET; 135 addr_.addr4.sin_family = family_; 136 addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); 137 } else if (!Resolve(host)) { 138 return false; 139 } 140 CHECK(FamilyIsTCP(family_)) << "Invalid socket family."; 141 if (family_ == AF_INET) { 142 addr_.addr4.sin_port = htons(port_); 143 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4); 144 addr_len_ = sizeof(addr_.addr4); 145 } else if (family_ == AF_INET6) { 146 addr_.addr6.sin6_port = htons(port_); 147 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6); 148 addr_len_ = sizeof(addr_.addr6); 149 } 150 return InitSocketInternal(); 151 } 152 153 bool Socket::BindAndListen() { 154 errno = 0; 155 if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 || 156 HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) { 157 SetSocketError(); 158 return false; 159 } 160 if (port_ == 0 && FamilyIsTCP(family_)) { 161 SockAddr addr; 162 memset(&addr, 0, sizeof(addr)); 163 socklen_t addrlen = 0; 164 sockaddr* addr_ptr = NULL; 165 uint16* port_ptr = NULL; 166 if (family_ == AF_INET) { 167 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4); 168 port_ptr = &addr.addr4.sin_port; 169 addrlen = sizeof(addr.addr4); 170 } else if (family_ == AF_INET6) { 171 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6); 172 port_ptr = &addr.addr6.sin6_port; 173 addrlen = sizeof(addr.addr6); 174 } 175 errno = 0; 176 if (getsockname(socket_, addr_ptr, &addrlen) != 0) { 177 LOG(ERROR) << "getsockname error: " << safe_strerror(errno);; 178 SetSocketError(); 179 return false; 180 } 181 port_ = ntohs(*port_ptr); 182 } 183 return true; 184 } 
// Waits for an incoming connection and stores the accepted descriptor in
// |new_socket|. Returns false (and sets the error state of this socket) if
// the wait is interrupted by an event or accept() fails.
bool Socket::Accept(Socket* new_socket) {
  DCHECK(new_socket != NULL);
  if (!WaitForEvent(READ, kNoTimeout)) {
    SetSocketError();
    return false;
  }
  errno = 0;
  int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
  if (new_socket_fd < 0) {
    SetSocketError();
    return false;
  }

  tools::DisableNagle(new_socket_fd);
  new_socket->socket_ = new_socket_fd;
  return true;
}

// Connects the socket to the prepared address. The socket is switched to
// non-blocking mode for the duration of the connect so that the wait can be
// interrupted, and switched back on success. On failure the error state is
// set (which closes the socket) and false is returned.
bool Socket::Connect() {
  // Set non-block because we use select for connect.
  const int kFlags = fcntl(socket_, F_GETFL);
  if (kFlags < 0) {
    // Without the current flags we can neither set nor restore O_NONBLOCK.
    LOG(ERROR) << "fcntl(F_GETFL): " << safe_strerror(errno);
    SetSocketError();
    return false;
  }
  DCHECK(!(kFlags & O_NONBLOCK));
  fcntl(socket_, F_SETFL, kFlags | O_NONBLOCK);
  errno = 0;
  // Note: on every failure path below SetSocketError() closes the socket, so
  // there is no descriptor left whose flags would need to be restored.
  if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
      errno != EINPROGRESS) {
    SetSocketError();
    return false;
  }
  // Wait for connection to complete, or receive a notification.
  if (!WaitForEvent(WRITE, kConnectTimeOut)) {
    SetSocketError();
    return false;
  }
  int socket_errno;
  socklen_t opt_len = sizeof(socket_errno);
  if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
    LOG(ERROR) << "getsockopt(): " << safe_strerror(errno);
    SetSocketError();
    return false;
  }
  if (socket_errno != 0) {
    LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
    SetSocketError();
    return false;
  }
  // Connected: go back to the original (blocking) mode.
  fcntl(socket_, F_SETFL, kFlags);
  return true;
}

// Resolves |host| and stores the first result in |addr_| and |family_|.
// Returns false and sets the error state if the lookup fails.
bool Socket::Resolve(const std::string& host) {
  struct addrinfo hints;
  struct addrinfo* res = NULL;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags |= AI_CANONNAME;

  int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
  if (errcode != 0) {
    LOG(ERROR) << "getaddrinfo(" << host << "): " << gai_strerror(errcode);
    // getaddrinfo() reports failures through its return value; whatever is
    // left in errno is unrelated to this socket and must not trip the check
    // in SetSocketError().
    errno = 0;
    SetSocketError();
    // |res| is not filled in on failure, so there is nothing to free.
    return false;
  }
  family_ = res->ai_family;
  switch (res->ai_family) {
    case AF_INET:
      memcpy(&addr_.addr4,
             reinterpret_cast<sockaddr_in*>(res->ai_addr),
             sizeof(sockaddr_in));
      break;
    case AF_INET6:
      memcpy(&addr_.addr6,
             reinterpret_cast<sockaddr_in6*>(res->ai_addr),
             sizeof(sockaddr_in6));
      break;
  }
  freeaddrinfo(res);
  return true;
}

// Returns the port of a TCP socket, or 0 (with an error logged) for a unix
// domain socket.
int Socket::GetPort() {
  if (!FamilyIsTCP(family_)) {
    LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
    return 0;
  }
  return port_;
}

// Returns true if this socket is open and its descriptor is set in |fds|.
bool Socket::IsFdInSet(const fd_set& fds) const {
  if (IsClosed())
    return false;
  return FD_ISSET(socket_, &fds);
}

// Adds this socket's descriptor to |fds|. Returns false if the socket is
// closed.
bool Socket::AddFdToSet(fd_set* fds) const {
  if (IsClosed())
    return false;
  FD_SET(socket_, fds);
  return true;
}

// Reads until |num_bytes| bytes have been stored in |buffer| or a read
// returns zero or an error. Returns the number of bytes actually read.
int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
  int bytes_read = 0;
  int ret = 1;
  // |bytes_read| never goes negative, so the cast is value preserving.
  while (static_cast<size_t>(bytes_read) < num_bytes && ret > 0) {
    ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
    if (ret >= 0)
      bytes_read += ret;
  }
  return bytes_read;
}

// Records that an error happened on this socket and closes it.
void Socket::SetSocketError() {
  socket_error_ = true;
  // We never use non-blocking socket.
  DCHECK(errno != EAGAIN && errno != EWOULDBLOCK);
  Close();
}

// Reads up to |buffer_size| bytes into |buffer|. Returns the result of
// read(), or 0 if the wait for data was interrupted. The error state is set
// when the wait is interrupted or read() fails.
int Socket::Read(void* buffer, size_t buffer_size) {
  if (!WaitForEvent(READ, kNoTimeout)) {
    SetSocketError();
    return 0;
  }
  int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
  if (ret < 0)
    SetSocketError();
  return ret;
}

// Sends up to |count| bytes from |buffer|. Returns the result of send() and
// sets the error state if it fails. MSG_NOSIGNAL is passed so that a closed
// peer does not raise SIGPIPE.
int Socket::Write(const void* buffer, size_t count) {
  int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
  if (ret < 0)
    SetSocketError();
  return ret;
}

// Writes the contents of |buffer| (without a terminating '\0'). Returns the
// number of bytes written.
int Socket::WriteString(const std::string& buffer) {
  return WriteNumBytes(buffer.c_str(), buffer.size());
}

// Registers |event_fd| as a descriptor that WaitForEvent() also watches for
// readability.
void Socket::AddEventFd(int event_fd) {
  Event event;
  event.fd = event_fd;
  event.was_fired = false;
  events_.push_back(event);
}

// Returns whether the registered event descriptor |fd| has been seen
// readable by WaitForEvent(). Returns false if |fd| was never registered.
bool Socket::DidReceiveEventOnFd(int fd) const {
  for (size_t i = 0; i < events_.size(); ++i)
    if (events_[i].fd == fd)
      return events_[i].was_fired;
  return false;
}

// Returns true if any registered event descriptor has been seen readable by
// WaitForEvent().
bool Socket::DidReceiveEvent() const {
  for (size_t i = 0; i < events_.size(); ++i)
    if (events_[i].was_fired)
      return true;
  return false;
}

// Writes until |num_bytes| bytes from |buffer| have been sent or a write
// returns zero or an error. Returns the number of bytes actually written.
int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
  int bytes_written = 0;
  int ret = 1;
  // |bytes_written| never goes negative, so the cast is value preserving.
  while (static_cast<size_t>(bytes_written) < num_bytes && ret > 0) {
    ret = Write(static_cast<const char*>(buffer) + bytes_written,
                num_bytes - bytes_written);
    if (ret >= 0)
      bytes_written += ret;
  }
  return bytes_written;
}

// Waits with select() until the socket is ready for |type| or one of the
// registered event descriptors becomes readable.
// Returns true immediately, without waiting, when no event descriptor is
// registered or the socket is closed. Otherwise returns false if select()
// fails or times out, or if any event descriptor fired (in which case it is
// marked as fired), and true if only the socket became ready.
// A |timeout_secs| that is not strictly positive means no timeout.
// NOTE(review): descriptors are not checked against FD_SETSIZE before being
// added to the sets.
bool Socket::WaitForEvent(EventType type, int timeout_secs) {
  if (events_.empty() || socket_ == -1)
    return true;
  fd_set read_fds;
  fd_set write_fds;
  FD_ZERO(&read_fds);
  FD_ZERO(&write_fds);
  if (type == READ)
    FD_SET(socket_, &read_fds);
  else
    FD_SET(socket_, &write_fds);
  for (size_t i = 0; i < events_.size(); ++i)
    FD_SET(events_[i].fd, &read_fds);
  timeval tv = {};
  timeval* tv_ptr = NULL;
  if (timeout_secs > 0) {
    tv.tv_sec = timeout_secs;
    tv.tv_usec = 0;
    tv_ptr = &tv;
  }
  int max_fd = socket_;
  for (size_t i = 0; i < events_.size(); ++i)
    if (events_[i].fd > max_fd)
      max_fd = events_[i].fd;
  if (HANDLE_EINTR(
          select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) {
    return false;
  }
  bool event_was_fired = false;
  for (size_t i = 0; i < events_.size(); ++i) {
    if (FD_ISSET(events_[i].fd, &read_fds)) {
      events_[i].was_fired = true;
      event_was_fired = true;
    }
  }
  return !event_was_fired;
}

// static
// Returns the larger of the two sockets' descriptors.
int Socket::GetHighestFileDescriptor(const Socket& s1, const Socket& s2) {
  return std::max(s1.socket_, s2.socket_);
}

// static
// Connects to the unix domain socket named |path| and returns the pid
// reported by SO_PEERCRED for its peer, or -1 if the connection or the
// credentials query fails.
pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
  Socket socket;
  if (!socket.ConnectUnix(path))
    return -1;
  struct ucred peer_credentials;
  socklen_t len = sizeof(peer_credentials);
  if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &peer_credentials,
                 &len) == -1) {
    CHECK_NE(ENOPROTOOPT, errno);
    return -1;
  }
  return peer_credentials.pid;
}

}  // namespace forwarder2