1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "tools/android/forwarder2/socket.h" 6 7 #include <arpa/inet.h> 8 #include <fcntl.h> 9 #include <netdb.h> 10 #include <netinet/in.h> 11 #include <stdio.h> 12 #include <string.h> 13 #include <sys/socket.h> 14 #include <sys/types.h> 15 #include <unistd.h> 16 17 #include "base/logging.h" 18 #include "base/posix/eintr_wrapper.h" 19 #include "base/safe_strerror_posix.h" 20 #include "tools/android/common/net.h" 21 #include "tools/android/forwarder2/common.h" 22 23 namespace { 24 const int kNoTimeout = -1; 25 const int kConnectTimeOut = 10; // Seconds. 26 27 bool FamilyIsTCP(int family) { 28 return family == AF_INET || family == AF_INET6; 29 } 30 } // namespace 31 32 namespace forwarder2 { 33 34 bool Socket::BindUnix(const std::string& path) { 35 errno = 0; 36 if (!InitUnixSocket(path) || !BindAndListen()) { 37 Close(); 38 return false; 39 } 40 return true; 41 } 42 43 bool Socket::BindTcp(const std::string& host, int port) { 44 errno = 0; 45 if (!InitTcpSocket(host, port) || !BindAndListen()) { 46 Close(); 47 return false; 48 } 49 return true; 50 } 51 52 bool Socket::ConnectUnix(const std::string& path) { 53 errno = 0; 54 if (!InitUnixSocket(path) || !Connect()) { 55 Close(); 56 return false; 57 } 58 return true; 59 } 60 61 bool Socket::ConnectTcp(const std::string& host, int port) { 62 errno = 0; 63 if (!InitTcpSocket(host, port) || !Connect()) { 64 Close(); 65 return false; 66 } 67 return true; 68 } 69 70 Socket::Socket() 71 : socket_(-1), 72 port_(0), 73 socket_error_(false), 74 family_(AF_INET), 75 addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)), 76 addr_len_(sizeof(sockaddr)) { 77 memset(&addr_, 0, sizeof(addr_)); 78 } 79 80 Socket::~Socket() { 81 Close(); 82 } 83 84 void Socket::Shutdown() { 85 if (!IsClosed()) { 86 PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR)); 87 } 88 } 89 90 void Socket::Close() { 91 if (!IsClosed()) { 92 CloseFD(socket_); 93 socket_ = -1; 94 } 95 } 96 97 bool Socket::InitSocketInternal() { 98 socket_ = socket(family_, SOCK_STREAM, 0); 99 if (socket_ < 0) 100 return false; 101 tools::DisableNagle(socket_); 102 int reuse_addr = 1; 103 setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr, 104 sizeof(reuse_addr)); 105 if (!SetNonBlocking()) 106 return false; 107 return true; 108 } 109 110 bool Socket::SetNonBlocking() { 111 const int flags = fcntl(socket_, F_GETFL); 112 if (flags < 0) { 113 PLOG(ERROR) << "fcntl"; 114 return false; 115 } 116 if (flags & O_NONBLOCK) 117 return true; 118 if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) { 119 PLOG(ERROR) << "fcntl"; 120 return false; 121 } 122 return true; 123 } 124 125 bool Socket::InitUnixSocket(const std::string& path) { 126 static const size_t kPathMax = sizeof(addr_.addr_un.sun_path); 127 // For abstract sockets we need one extra byte for the leading zero. 128 if (path.size() + 2 /* '\0' */ > kPathMax) { 129 LOG(ERROR) << "The provided path is too big to create a unix " 130 << "domain socket: " << path; 131 return false; 132 } 133 family_ = PF_UNIX; 134 addr_.addr_un.sun_family = family_; 135 // Copied from net/socket/unix_domain_socket_posix.cc 136 // Convert the path given into abstract socket name. It must start with 137 // the '\0' character, so we are adding it. |addr_len| must specify the 138 // length of the structure exactly, as potentially the socket name may 139 // have '\0' characters embedded (although we don't support this). 140 // Note that addr_.addr_un.sun_path is already zero initialized. 141 memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size()); 142 addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1; 143 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un); 144 return InitSocketInternal(); 145 } 146 147 bool Socket::InitTcpSocket(const std::string& host, int port) { 148 port_ = port; 149 if (host.empty()) { 150 // Use localhost: INADDR_LOOPBACK 151 family_ = AF_INET; 152 addr_.addr4.sin_family = family_; 153 addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); 154 } else if (!Resolve(host)) { 155 return false; 156 } 157 CHECK(FamilyIsTCP(family_)) << "Invalid socket family."; 158 if (family_ == AF_INET) { 159 addr_.addr4.sin_port = htons(port_); 160 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4); 161 addr_len_ = sizeof(addr_.addr4); 162 } else if (family_ == AF_INET6) { 163 addr_.addr6.sin6_port = htons(port_); 164 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6); 165 addr_len_ = sizeof(addr_.addr6); 166 } 167 return InitSocketInternal(); 168 } 169 170 bool Socket::BindAndListen() { 171 errno = 0; 172 if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 || 173 HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) { 174 SetSocketError(); 175 return false; 176 } 177 if (port_ == 0 && FamilyIsTCP(family_)) { 178 SockAddr addr; 179 memset(&addr, 0, sizeof(addr)); 180 socklen_t addrlen = 0; 181 sockaddr* addr_ptr = NULL; 182 uint16* port_ptr = NULL; 183 if (family_ == AF_INET) { 184 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4); 185 port_ptr = &addr.addr4.sin_port; 186 addrlen = sizeof(addr.addr4); 187 } else if (family_ == AF_INET6) { 188 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6); 189 port_ptr = &addr.addr6.sin6_port; 190 addrlen = sizeof(addr.addr6); 191 } 192 errno = 0; 193 if (getsockname(socket_, addr_ptr, &addrlen) != 0) { 194 PLOG(ERROR) << "getsockname"; 195 SetSocketError(); 196 return false; 197 } 198 port_ = ntohs(*port_ptr); 199 } 200 return true; 201 } 202 203 bool Socket::Accept(Socket* new_socket) { 204 DCHECK(new_socket != NULL); 205 if (!WaitForEvent(READ, kNoTimeout)) { 206 SetSocketError(); 207 return false; 208 } 209 errno = 0; 210 int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL)); 211 if (new_socket_fd < 0) { 212 SetSocketError(); 213 return false; 214 } 215 tools::DisableNagle(new_socket_fd); 216 new_socket->socket_ = new_socket_fd; 217 if (!new_socket->SetNonBlocking()) 218 return false; 219 return true; 220 } 221 222 bool Socket::Connect() { 223 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK); 224 errno = 0; 225 if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 && 226 errno != EINPROGRESS) { 227 SetSocketError(); 228 return false; 229 } 230 // Wait for connection to complete, or receive a notification. 231 if (!WaitForEvent(WRITE, kConnectTimeOut)) { 232 SetSocketError(); 233 return false; 234 } 235 int socket_errno; 236 socklen_t opt_len = sizeof(socket_errno); 237 if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) { 238 PLOG(ERROR) << "getsockopt()"; 239 SetSocketError(); 240 return false; 241 } 242 if (socket_errno != 0) { 243 LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno); 244 SetSocketError(); 245 return false; 246 } 247 return true; 248 } 249 250 bool Socket::Resolve(const std::string& host) { 251 struct addrinfo hints; 252 struct addrinfo* res; 253 memset(&hints, 0, sizeof(hints)); 254 hints.ai_family = AF_UNSPEC; 255 hints.ai_socktype = SOCK_STREAM; 256 hints.ai_flags |= AI_CANONNAME; 257 258 int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res); 259 if (errcode != 0) { 260 errno = 0; 261 SetSocketError(); 262 freeaddrinfo(res); 263 return false; 264 } 265 family_ = res->ai_family; 266 switch (res->ai_family) { 267 case AF_INET: 268 memcpy(&addr_.addr4, 269 reinterpret_cast<sockaddr_in*>(res->ai_addr), 270 sizeof(sockaddr_in)); 271 break; 272 case AF_INET6: 273 memcpy(&addr_.addr6, 274 reinterpret_cast<sockaddr_in6*>(res->ai_addr), 275 sizeof(sockaddr_in6)); 276 break; 277 } 278 freeaddrinfo(res); 279 return true; 280 } 281 282 int Socket::GetPort() { 283 if (!FamilyIsTCP(family_)) { 284 LOG(ERROR) << "Can't call GetPort() on an unix domain socket."; 285 return 0; 286 } 287 return port_; 288 } 289 290 bool Socket::IsFdInSet(const fd_set& fds) const { 291 if (IsClosed()) 292 return false; 293 return FD_ISSET(socket_, &fds); 294 } 295 296 bool Socket::AddFdToSet(fd_set* fds) const { 297 if (IsClosed()) 298 return false; 299 FD_SET(socket_, fds); 300 return true; 301 } 302 303 int Socket::ReadNumBytes(void* buffer, size_t num_bytes) { 304 int bytes_read = 0; 305 int ret = 1; 306 while (bytes_read < num_bytes && ret > 0) { 307 ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read); 308 if (ret >= 0) 309 bytes_read += ret; 310 } 311 return bytes_read; 312 } 313 314 void Socket::SetSocketError() { 315 socket_error_ = true; 316 DCHECK_NE(EAGAIN, errno); 317 DCHECK_NE(EWOULDBLOCK, errno); 318 Close(); 319 } 320 321 int Socket::Read(void* buffer, size_t buffer_size) { 322 if (!WaitForEvent(READ, kNoTimeout)) { 323 SetSocketError(); 324 return 0; 325 } 326 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size)); 327 if (ret < 0) { 328 PLOG(ERROR) << "read"; 329 SetSocketError(); 330 } 331 return ret; 332 } 333 334 int Socket::NonBlockingRead(void* buffer, size_t buffer_size) { 335 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK); 336 int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size)); 337 if (ret < 0) { 338 PLOG(ERROR) << "read"; 339 SetSocketError(); 340 } 341 return ret; 342 } 343 344 int Socket::Write(const void* buffer, size_t count) { 345 if (!WaitForEvent(WRITE, kNoTimeout)) { 346 SetSocketError(); 347 return 0; 348 } 349 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL)); 350 if (ret < 0) { 351 PLOG(ERROR) << "send"; 352 SetSocketError(); 353 } 354 return ret; 355 } 356 357 int Socket::NonBlockingWrite(const void* buffer, size_t count) { 358 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK); 359 int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL)); 360 if (ret < 0) { 361 PLOG(ERROR) << "send"; 362 SetSocketError(); 363 } 364 return ret; 365 } 366 367 int Socket::WriteString(const std::string& buffer) { 368 return WriteNumBytes(buffer.c_str(), buffer.size()); 369 } 370 371 void Socket::AddEventFd(int event_fd) { 372 Event event; 373 event.fd = event_fd; 374 event.was_fired = false; 375 events_.push_back(event); 376 } 377 378 bool Socket::DidReceiveEventOnFd(int fd) const { 379 for (size_t i = 0; i < events_.size(); ++i) 380 if (events_[i].fd == fd) 381 return events_[i].was_fired; 382 return false; 383 } 384 385 bool Socket::DidReceiveEvent() const { 386 for (size_t i = 0; i < events_.size(); ++i) 387 if (events_[i].was_fired) 388 return true; 389 return false; 390 } 391 392 int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) { 393 int bytes_written = 0; 394 int ret = 1; 395 while (bytes_written < num_bytes && ret > 0) { 396 ret = Write(static_cast<const char*>(buffer) + bytes_written, 397 num_bytes - bytes_written); 398 if (ret >= 0) 399 bytes_written += ret; 400 } 401 return bytes_written; 402 } 403 404 bool Socket::WaitForEvent(EventType type, int timeout_secs) { 405 if (socket_ == -1) 406 return true; 407 DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK); 408 fd_set read_fds; 409 fd_set write_fds; 410 FD_ZERO(&read_fds); 411 FD_ZERO(&write_fds); 412 if (type == READ) 413 FD_SET(socket_, &read_fds); 414 else 415 FD_SET(socket_, &write_fds); 416 for (size_t i = 0; i < events_.size(); ++i) 417 FD_SET(events_[i].fd, &read_fds); 418 timeval tv = {}; 419 timeval* tv_ptr = NULL; 420 if (timeout_secs > 0) { 421 tv.tv_sec = timeout_secs; 422 tv.tv_usec = 0; 423 tv_ptr = &tv; 424 } 425 int max_fd = socket_; 426 for (size_t i = 0; i < events_.size(); ++i) 427 if (events_[i].fd > max_fd) 428 max_fd = events_[i].fd; 429 if (HANDLE_EINTR( 430 select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr)) <= 0) { 431 PLOG(ERROR) << "select"; 432 return false; 433 } 434 bool event_was_fired = false; 435 for (size_t i = 0; i < events_.size(); ++i) { 436 if (FD_ISSET(events_[i].fd, &read_fds)) { 437 events_[i].was_fired = true; 438 event_was_fired = true; 439 } 440 } 441 return !event_was_fired; 442 } 443 444 // static 445 int Socket::GetHighestFileDescriptor(const Socket& s1, const Socket& s2) { 446 return std::max(s1.socket_, s2.socket_); 447 } 448 449 // static 450 pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) { 451 Socket socket; 452 if (!socket.ConnectUnix(path)) 453 return -1; 454 ucred ucred; 455 socklen_t len = sizeof(ucred); 456 if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) { 457 CHECK_NE(ENOPROTOOPT, errno); 458 return -1; 459 } 460 return ucred.pid; 461 } 462 463 } // namespace forwarder2 464