1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "chre_host/socket_server.h" 18 19 #include <poll.h> 20 21 #include <cassert> 22 #include <cinttypes> 23 #include <csignal> 24 #include <cstdlib> 25 #include <map> 26 #include <mutex> 27 28 #include <cutils/sockets.h> 29 30 #include "chre_host/log.h" 31 32 namespace android { 33 namespace chre { 34 35 std::atomic<bool> SocketServer::sSignalReceived(false); 36 37 namespace { 38 39 void maskAllSignals() { 40 sigset_t signalMask; 41 sigfillset(&signalMask); 42 if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) { 43 LOG_ERROR("Couldn't mask all signals", errno); 44 } 45 } 46 47 void maskAllSignalsExceptIntAndTerm() { 48 sigset_t signalMask; 49 sigfillset(&signalMask); 50 sigdelset(&signalMask, SIGINT); 51 sigdelset(&signalMask, SIGTERM); 52 if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) { 53 LOG_ERROR("Couldn't mask all signals except INT/TERM", errno); 54 } 55 } 56 57 } // anonymous namespace 58 59 SocketServer::SocketServer() { 60 // Initialize the socket fds field for all inactive client slots to -1, so 61 // poll skips over it, and we don't attempt to send on it 62 for (size_t i = 1; i <= kMaxActiveClients; i++) { 63 mPollFds[i].fd = -1; 64 mPollFds[i].events = POLLIN; 65 } 66 } 67 68 void SocketServer::run(const char *socketName, bool allowSocketCreation, 69 ClientMessageCallback clientMessageCallback) { 70 mClientMessageCallback = clientMessageCallback; 71 72 mSockFd = android_get_control_socket(socketName); 73 if (mSockFd == INVALID_SOCKET && allowSocketCreation) { 74 LOGI("Didn't inherit socket, creating..."); 75 mSockFd = socket_local_server(socketName, 76 ANDROID_SOCKET_NAMESPACE_RESERVED, 77 SOCK_SEQPACKET); 78 } 79 80 if (mSockFd == INVALID_SOCKET) { 81 LOGE("Couldn't get/create socket"); 82 } else { 83 int ret = listen(mSockFd, kMaxPendingConnectionRequests); 84 if (ret < 0) { 85 LOG_ERROR("Couldn't listen on socket", errno); 86 } else { 87 serviceSocket(); 88 } 89 90 { 91 std::lock_guard<std::mutex> lock(mClientsMutex); 92 for (const auto& pair : mClients) { 93 int clientSocket = pair.first; 94 if (close(clientSocket) != 0) { 95 LOGI("Couldn't close client %" PRIu16 "'s socket: %s", 96 pair.second.clientId, strerror(errno)); 97 } 98 } 99 mClients.clear(); 100 } 101 close(mSockFd); 102 } 103 } 104 105 void SocketServer::sendToAllClients(const void *data, size_t length) { 106 std::lock_guard<std::mutex> lock(mClientsMutex); 107 108 int deliveredCount = 0; 109 for (const auto& pair : mClients) { 110 int clientSocket = pair.first; 111 uint16_t clientId = pair.second.clientId; 112 if (sendToClientSocket(data, length, clientSocket, clientId)) { 113 deliveredCount++; 114 } else if (errno == EINTR) { 115 // Exit early if we were interrupted - we should only get this for 116 // SIGINT/SIGTERM, so we should exit quickly 117 break; 118 } 119 } 120 121 if (deliveredCount == 0) { 122 LOGW("Got message but didn't deliver to any clients"); 123 } 124 } 125 126 bool SocketServer::sendToClientById(const void *data, size_t length, 127 uint16_t clientId) { 128 std::lock_guard<std::mutex> lock(mClientsMutex); 129 130 bool sent = false; 131 for (const auto& pair : mClients) { 132 uint16_t thisClientId = pair.second.clientId; 133 if (thisClientId == clientId) { 134 int clientSocket = pair.first; 135 sent = sendToClientSocket(data, length, clientSocket, thisClientId); 136 break; 137 } 138 } 139 140 return sent; 141 } 142 143 void SocketServer::acceptClientConnection() { 144 int clientSocket = accept(mSockFd, NULL, NULL); 145 if (clientSocket < 0) { 146 LOG_ERROR("Couldn't accept client connection", errno); 147 } else if (mClients.size() >= kMaxActiveClients) { 148 LOGW("Rejecting client request - maximum number of clients reached"); 149 close(clientSocket); 150 } else { 151 ClientData clientData; 152 clientData.clientId = mNextClientId++; 153 154 // We currently don't handle wraparound - if we're getting this many 155 // connects/disconnects, then something is wrong. 156 // TODO: can handle this properly by iterating over the existing clients to 157 // avoid a conflict. 158 if (clientData.clientId == 0) { 159 LOGE("Couldn't allocate client ID"); 160 std::exit(-1); 161 } 162 163 bool slotFound = false; 164 for (size_t i = 1; i <= kMaxActiveClients; i++) { 165 if (mPollFds[i].fd < 0) { 166 mPollFds[i].fd = clientSocket; 167 slotFound = true; 168 break; 169 } 170 } 171 172 if (!slotFound) { 173 LOGE("Couldn't find slot for client!"); 174 assert(slotFound); 175 close(clientSocket); 176 } else { 177 { 178 std::lock_guard<std::mutex> lock(mClientsMutex); 179 mClients[clientSocket] = clientData; 180 } 181 LOGI("Accepted new client connection (count %zu), assigned client ID %" 182 PRIu16, mClients.size(), clientData.clientId); 183 } 184 } 185 } 186 187 void SocketServer::handleClientData(int clientSocket) { 188 const ClientData& clientData = mClients[clientSocket]; 189 uint16_t clientId = clientData.clientId; 190 191 ssize_t packetSize = recv( 192 clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT); 193 if (packetSize < 0) { 194 LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId, 195 strerror(errno)); 196 } else if (packetSize == 0) { 197 LOGI("Client %" PRIu16 " disconnected", clientId); 198 disconnectClient(clientSocket); 199 } else { 200 LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId); 201 mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize); 202 } 203 } 204 205 void SocketServer::disconnectClient(int clientSocket) { 206 { 207 std::lock_guard<std::mutex> lock(mClientsMutex); 208 mClients.erase(clientSocket); 209 } 210 close(clientSocket); 211 212 bool removed = false; 213 for (size_t i = 1; i <= kMaxActiveClients; i++) { 214 if (mPollFds[i].fd == clientSocket) { 215 mPollFds[i].fd = -1; 216 removed = true; 217 break; 218 } 219 } 220 221 if (!removed) { 222 LOGE("Out of sync"); 223 assert(removed); 224 } 225 } 226 227 bool SocketServer::sendToClientSocket(const void *data, size_t length, 228 int clientSocket, uint16_t clientId) { 229 errno = 0; 230 ssize_t bytesSent = send(clientSocket, data, length, 0); 231 if (bytesSent < 0) { 232 LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s", 233 length, clientId, strerror(errno)); 234 } else if (bytesSent == 0) { 235 LOGW("Client %" PRIu16 " disconnected before message could be delivered", 236 clientId); 237 } else { 238 LOGV("Delivered message of size %zu bytes to client %" PRIu16, length, 239 clientId); 240 } 241 242 return (bytesSent > 0); 243 } 244 245 void SocketServer::serviceSocket() { 246 constexpr size_t kListenIndex = 0; 247 static_assert(kListenIndex == 0, "Code assumes that the first index is " 248 "always the listen socket"); 249 250 mPollFds[kListenIndex].fd = mSockFd; 251 mPollFds[kListenIndex].events = POLLIN; 252 253 // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM, 254 // and ignore other signals 255 sigset_t signalMask; 256 sigfillset(&signalMask); 257 sigdelset(&signalMask, SIGINT); 258 sigdelset(&signalMask, SIGTERM); 259 260 // Masking signals here ensure that after this point, we won't handle INT/TERM 261 // until after we call into ppoll() 262 maskAllSignals(); 263 std::signal(SIGINT, signalHandler); 264 std::signal(SIGTERM, signalHandler); 265 266 LOGI("Ready to accept connections"); 267 while (!sSignalReceived) { 268 int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask); 269 maskAllSignalsExceptIntAndTerm(); 270 if (ret == -1) { 271 LOGI("Exiting poll loop: %s", strerror(errno)); 272 break; 273 } 274 275 if (mPollFds[kListenIndex].revents & POLLIN) { 276 acceptClientConnection(); 277 } 278 279 for (size_t i = 1; i <= kMaxActiveClients; i++) { 280 if (mPollFds[i].fd < 0) { 281 continue; 282 } 283 284 if (mPollFds[i].revents & POLLIN) { 285 handleClientData(mPollFds[i].fd); 286 } 287 } 288 289 // Mask all signals to ensure that sSignalReceived can't become true between 290 // checking it in the while condition and calling into ppoll() 291 maskAllSignals(); 292 } 293 } 294 295 void SocketServer::signalHandler(int signal) { 296 LOGD("Caught signal %d", signal); 297 sSignalReceived = true; 298 } 299 300 } // namespace chre 301 } // namespace android 302