1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "tools/android/forwarder2/socket.h" 6 7 #include <arpa/inet.h> 8 #include <fcntl.h> 9 #include <netdb.h> 10 #include <netinet/in.h> 11 #include <stdio.h> 12 #include <string.h> 13 #include <sys/socket.h> 14 #include <sys/types.h> 15 #include <unistd.h> 16 17 #include "base/logging.h" 18 #include "base/posix/eintr_wrapper.h" 19 #include "base/safe_strerror_posix.h" 20 #include "tools/android/common/net.h" 21 #include "tools/android/forwarder2/common.h" 22 23 namespace { 24 const int kNoTimeout = -1; 25 const int kConnectTimeOut = 10; // Seconds. 26 27 bool FamilyIsTCP(int family) { 28 return family == AF_INET || family == AF_INET6; 29 } 30 } // namespace 31 32 namespace forwarder2 { 33 34 bool Socket::BindUnix(const std::string& path) { 35 errno = 0; 36 if (!InitUnixSocket(path) || !BindAndListen()) { 37 Close(); 38 return false; 39 } 40 return true; 41 } 42 43 bool Socket::BindTcp(const std::string& host, int port) { 44 errno = 0; 45 if (!InitTcpSocket(host, port) || !BindAndListen()) { 46 Close(); 47 return false; 48 } 49 return true; 50 } 51 52 bool Socket::ConnectUnix(const std::string& path) { 53 errno = 0; 54 if (!InitUnixSocket(path) || !Connect()) { 55 Close(); 56 return false; 57 } 58 return true; 59 } 60 61 bool Socket::ConnectTcp(const std::string& host, int port) { 62 errno = 0; 63 if (!InitTcpSocket(host, port) || !Connect()) { 64 Close(); 65 return false; 66 } 67 return true; 68 } 69 70 Socket::Socket() 71 : socket_(-1), 72 port_(0), 73 socket_error_(false), 74 family_(AF_INET), 75 addr_ptr_(reinterpret_cast<sockaddr*>(&addr_.addr4)), 76 addr_len_(sizeof(sockaddr)) { 77 memset(&addr_, 0, sizeof(addr_)); 78 } 79 80 Socket::~Socket() { 81 Close(); 82 } 83 84 void Socket::Shutdown() { 85 if (!IsClosed()) { 86 PRESERVE_ERRNO_HANDLE_EINTR(shutdown(socket_, SHUT_RDWR)); 87 
} 88 } 89 90 void Socket::Close() { 91 if (!IsClosed()) { 92 CloseFD(socket_); 93 socket_ = -1; 94 } 95 } 96 97 bool Socket::InitSocketInternal() { 98 socket_ = socket(family_, SOCK_STREAM, 0); 99 if (socket_ < 0) { 100 PLOG(ERROR) << "socket"; 101 return false; 102 } 103 tools::DisableNagle(socket_); 104 int reuse_addr = 1; 105 setsockopt(socket_, SOL_SOCKET, SO_REUSEADDR, &reuse_addr, 106 sizeof(reuse_addr)); 107 if (!SetNonBlocking()) 108 return false; 109 return true; 110 } 111 112 bool Socket::SetNonBlocking() { 113 const int flags = fcntl(socket_, F_GETFL); 114 if (flags < 0) { 115 PLOG(ERROR) << "fcntl"; 116 return false; 117 } 118 if (flags & O_NONBLOCK) 119 return true; 120 if (fcntl(socket_, F_SETFL, flags | O_NONBLOCK) < 0) { 121 PLOG(ERROR) << "fcntl"; 122 return false; 123 } 124 return true; 125 } 126 127 bool Socket::InitUnixSocket(const std::string& path) { 128 static const size_t kPathMax = sizeof(addr_.addr_un.sun_path); 129 // For abstract sockets we need one extra byte for the leading zero. 130 if (path.size() + 2 /* '\0' */ > kPathMax) { 131 LOG(ERROR) << "The provided path is too big to create a unix " 132 << "domain socket: " << path; 133 return false; 134 } 135 family_ = PF_UNIX; 136 addr_.addr_un.sun_family = family_; 137 // Copied from net/socket/unix_domain_socket_posix.cc 138 // Convert the path given into abstract socket name. It must start with 139 // the '\0' character, so we are adding it. |addr_len| must specify the 140 // length of the structure exactly, as potentially the socket name may 141 // have '\0' characters embedded (although we don't support this). 142 // Note that addr_.addr_un.sun_path is already zero initialized. 
143 memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size()); 144 addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1; 145 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un); 146 return InitSocketInternal(); 147 } 148 149 bool Socket::InitTcpSocket(const std::string& host, int port) { 150 port_ = port; 151 if (host.empty()) { 152 // Use localhost: INADDR_LOOPBACK 153 family_ = AF_INET; 154 addr_.addr4.sin_family = family_; 155 addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); 156 } else if (!Resolve(host)) { 157 return false; 158 } 159 CHECK(FamilyIsTCP(family_)) << "Invalid socket family."; 160 if (family_ == AF_INET) { 161 addr_.addr4.sin_port = htons(port_); 162 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4); 163 addr_len_ = sizeof(addr_.addr4); 164 } else if (family_ == AF_INET6) { 165 addr_.addr6.sin6_port = htons(port_); 166 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6); 167 addr_len_ = sizeof(addr_.addr6); 168 } 169 return InitSocketInternal(); 170 } 171 172 bool Socket::BindAndListen() { 173 errno = 0; 174 if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 || 175 HANDLE_EINTR(listen(socket_, SOMAXCONN)) < 0) { 176 PLOG(ERROR) << "bind/listen"; 177 SetSocketError(); 178 return false; 179 } 180 if (port_ == 0 && FamilyIsTCP(family_)) { 181 SockAddr addr; 182 memset(&addr, 0, sizeof(addr)); 183 socklen_t addrlen = 0; 184 sockaddr* addr_ptr = NULL; 185 uint16* port_ptr = NULL; 186 if (family_ == AF_INET) { 187 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4); 188 port_ptr = &addr.addr4.sin_port; 189 addrlen = sizeof(addr.addr4); 190 } else if (family_ == AF_INET6) { 191 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr6); 192 port_ptr = &addr.addr6.sin6_port; 193 addrlen = sizeof(addr.addr6); 194 } 195 errno = 0; 196 if (getsockname(socket_, addr_ptr, &addrlen) != 0) { 197 PLOG(ERROR) << "getsockname"; 198 SetSocketError(); 199 return false; 200 } 201 port_ = ntohs(*port_ptr); 202 } 203 return true; 204 } 
// Waits until a connection is pending, then accepts it into |new_socket| and
// makes it non-blocking. If the wait fails (including when a registered event
// fd fires) or accept() fails, this socket is put into the error state, which
// closes it.
bool Socket::Accept(Socket* new_socket) {
  DCHECK(new_socket != NULL);
  if (!WaitForEvent(READ, kNoTimeout)) {
    SetSocketError();
    return false;
  }
  errno = 0;
  int new_socket_fd = HANDLE_EINTR(accept(socket_, NULL, NULL));
  if (new_socket_fd < 0) {
    SetSocketError();
    return false;
  }
  tools::DisableNagle(new_socket_fd);
  new_socket->socket_ = new_socket_fd;
  if (!new_socket->SetNonBlocking())
    return false;
  return true;
}

// Starts a non-blocking connect to |addr_|, waits up to kConnectTimeOut
// seconds for it to finish, then checks SO_ERROR for the outcome.
bool Socket::Connect() {
  DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
  errno = 0;
  // On a non-blocking socket EINPROGRESS just means "still connecting".
  if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 &&
      errno != EINPROGRESS) {
    SetSocketError();
    return false;
  }
  // Wait for connection to complete, or receive a notification.
  if (!WaitForEvent(WRITE, kConnectTimeOut)) {
    SetSocketError();
    return false;
  }
  int socket_errno;
  socklen_t opt_len = sizeof(socket_errno);
  if (getsockopt(socket_, SOL_SOCKET, SO_ERROR, &socket_errno, &opt_len) < 0) {
    PLOG(ERROR) << "getsockopt()";
    SetSocketError();
    return false;
  }
  if (socket_errno != 0) {
    LOG(ERROR) << "Could not connect to host: " << safe_strerror(socket_errno);
    SetSocketError();
    return false;
  }
  return true;
}

// Resolves |host| with getaddrinfo() and copies the first result into |addr_|,
// updating |family_|. On failure logs the resolver error, puts the socket into
// the error state and returns false.
bool Socket::Resolve(const std::string& host) {
  struct addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_flags |= AI_CANONNAME;

  struct addrinfo* res = NULL;
  const int errcode = getaddrinfo(host.c_str(), NULL, &hints, &res);
  if (errcode != 0) {
    // getaddrinfo() reports failure through its return value (not errno) and
    // does not hand back a list to free, so |res| must NOT be passed to
    // freeaddrinfo() here.
    LOG(ERROR) << "getaddrinfo(" << host << "): " << gai_strerror(errcode);
    errno = 0;
    SetSocketError();
    return false;
  }
  family_ = res->ai_family;
  switch (res->ai_family) {
    case AF_INET:
      memcpy(&addr_.addr4, res->ai_addr, sizeof(sockaddr_in));
      break;
    case AF_INET6:
      memcpy(&addr_.addr6, res->ai_addr, sizeof(sockaddr_in6));
      break;
  }
  freeaddrinfo(res);
  return true;
}

// Returns the TCP port (the kernel-assigned one after binding to port 0).
// Unix sockets have no port: logs an error and returns 0.
int Socket::GetPort() {
  if (!FamilyIsTCP(family_)) {
    LOG(ERROR) << "Can't call GetPort() on an unix domain socket.";
    return 0;
  }
  return port_;
}

// Reads until |num_bytes| bytes have arrived, EOF, or an error. Returns the
// number of bytes actually read.
int Socket::ReadNumBytes(void* buffer, size_t num_bytes) {
  size_t bytes_read = 0;
  int ret = 1;
  while (bytes_read < num_bytes && ret > 0) {
    ret = Read(static_cast<char*>(buffer) + bytes_read, num_bytes - bytes_read);
    if (ret > 0)
      bytes_read += ret;
  }
  return static_cast<int>(bytes_read);
}

// Records that the socket failed and closes it. A would-block errno is not an
// error for this class, hence the DCHECKs.
void Socket::SetSocketError() {
  socket_error_ = true;
  DCHECK_NE(EAGAIN, errno);
  DCHECK_NE(EWOULDBLOCK, errno);
  Close();
}

// Waits for readability, then performs a single read(). Returns 0 if the wait
// failed (the socket is then in the error state), otherwise read()'s result; a
// negative result also puts the socket into the error state.
int Socket::Read(void* buffer, size_t buffer_size) {
  if (!WaitForEvent(READ, kNoTimeout)) {
    SetSocketError();
    return 0;
  }
  int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
  if (ret < 0) {
    PLOG(ERROR) << "read";
    SetSocketError();
  }
  return ret;
}

// Performs a single read() without waiting. Any negative result is treated as
// an error. NOTE(review): a would-block result would trip the DCHECKs in
// SetSocketError(), so callers presumably only call this when data is ready.
int Socket::NonBlockingRead(void* buffer, size_t buffer_size) {
  DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
  int ret = HANDLE_EINTR(read(socket_, buffer, buffer_size));
  if (ret < 0) {
    PLOG(ERROR) << "read";
    SetSocketError();
  }
  return ret;
}

// Waits for writability, then performs a single send(). MSG_NOSIGNAL keeps a
// closed peer from raising SIGPIPE. Returns 0 if the wait failed, otherwise
// send()'s result; a negative result puts the socket into the error state.
int Socket::Write(const void* buffer, size_t count) {
  if (!WaitForEvent(WRITE, kNoTimeout)) {
    SetSocketError();
    return 0;
  }
  int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
  if (ret < 0) {
    PLOG(ERROR) << "send";
    SetSocketError();
  }
  return ret;
}

// Performs a single send() without waiting; any negative result is treated as
// an error.
int Socket::NonBlockingWrite(const void* buffer, size_t count) {
  DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
  int ret = HANDLE_EINTR(send(socket_, buffer, count, MSG_NOSIGNAL));
  if (ret < 0) {
    PLOG(ERROR) << "send";
    SetSocketError();
  }
  return ret;
}

// Writes all characters of |buffer| (no terminating NUL). Returns the number
// of bytes written.
int Socket::WriteString(const std::string& buffer) {
  return WriteNumBytes(buffer.c_str(), buffer.size());
}

// Registers |event_fd| so that WaitForEvent() also watches it for
// readability and gives up waiting when it fires.
void Socket::AddEventFd(int event_fd) {
  Event event;
  event.fd = event_fd;
  event.was_fired = false;
  events_.push_back(event);
}

// Returns whether the registered event fd |fd| has fired; false if |fd| was
// never registered.
bool Socket::DidReceiveEventOnFd(int fd) const {
  for (size_t i = 0; i < events_.size(); ++i)
    if (events_[i].fd == fd)
      return events_[i].was_fired;
  return false;
}

// Returns whether any registered event fd has fired.
bool Socket::DidReceiveEvent() const {
  for (size_t i = 0; i < events_.size(); ++i)
    if (events_[i].was_fired)
      return true;
  return false;
}

// Writes until |num_bytes| bytes have been sent or a write fails. Returns the
// number of bytes actually written.
int Socket::WriteNumBytes(const void* buffer, size_t num_bytes) {
  size_t bytes_written = 0;
  int ret = 1;
  while (bytes_written < num_bytes && ret > 0) {
    ret = Write(static_cast<const char*>(buffer) + bytes_written,
                num_bytes - bytes_written);
    if (ret > 0)
      bytes_written += ret;
  }
  return static_cast<int>(bytes_written);
}

// Waits until |socket_| is readable (READ) or writable (WRITE), or until a
// registered event fd becomes readable. |timeout_secs| <= 0 waits forever.
// Returns true only if the wait ended without any event fd firing; fired
// events are recorded in |events_|. Returns false on timeout (errno set to
// ETIMEDOUT), on select() failure, or when an event fired. A closed socket
// returns true immediately.
bool Socket::WaitForEvent(EventType type, int timeout_secs) {
  if (socket_ == -1)
    return true;
  DCHECK(fcntl(socket_, F_GETFL) & O_NONBLOCK);
  fd_set read_fds;
  fd_set write_fds;
  FD_ZERO(&read_fds);
  FD_ZERO(&write_fds);
  if (type == READ)
    FD_SET(socket_, &read_fds);
  else
    FD_SET(socket_, &write_fds);
  for (size_t i = 0; i < events_.size(); ++i)
    FD_SET(events_[i].fd, &read_fds);
  timeval tv = {};
  timeval* tv_ptr = NULL;
  if (timeout_secs > 0) {
    tv.tv_sec = timeout_secs;
    tv.tv_usec = 0;
    tv_ptr = &tv;
  }
  int max_fd = socket_;
  for (size_t i = 0; i < events_.size(); ++i)
    if (events_[i].fd > max_fd)
      max_fd = events_[i].fd;
  const int ready = HANDLE_EINTR(
      select(max_fd + 1, &read_fds, &write_fds, NULL, tv_ptr));
  if (ready < 0) {
    PLOG(ERROR) << "select";
    return false;
  }
  if (ready == 0) {
    // select() does not set errno on timeout, so PLOG would print a stale
    // value (e.g. EINPROGRESS left by Connect()). Report the timeout itself.
    errno = ETIMEDOUT;
    LOG(ERROR) << "select: timed out after " << timeout_secs << " seconds";
    return false;
  }
  bool event_was_fired = false;
  for (size_t i = 0; i < events_.size(); ++i) {
    if (FD_ISSET(events_[i].fd, &read_fds)) {
      events_[i].was_fired = true;
      event_was_fired = true;
    }
  }
  return !event_was_fired;
}

// static
// Connects to the abstract unix socket |path| and returns the peer's pid as
// reported by SO_PEERCRED, or -1 if connecting or the query fails.
pid_t Socket::GetUnixDomainSocketProcessOwner(const std::string& path) {
  Socket socket;
  if (!socket.ConnectUnix(path))
    return -1;
  ucred ucred;
  socklen_t len = sizeof(ucred);
  if (getsockopt(socket.socket_, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == -1) {
    CHECK_NE(ENOPROTOOPT, errno);
    return -1;
  }
  return ucred.pid;
}

}  // namespace forwarder2