Home | History | Annotate | Download | only in common
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include "chre_host/socket_server.h"
     18 
#include <poll.h>
#include <signal.h>

#include <cassert>
#include <cerrno>
#include <cinttypes>
#include <csignal>
#include <cstdlib>
#include <cstring>
#include <map>
#include <mutex>
     27 
     28 #include <cutils/sockets.h>
     29 
     30 #include "chre_host/log.h"
     31 
     32 namespace android {
     33 namespace chre {
     34 
     35 std::atomic<bool> SocketServer::sSignalReceived(false);
     36 
     37 namespace {
     38 
     39 void maskAllSignals() {
     40   sigset_t signalMask;
     41   sigfillset(&signalMask);
     42   if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
     43     LOG_ERROR("Couldn't mask all signals", errno);
     44   }
     45 }
     46 
     47 void maskAllSignalsExceptIntAndTerm() {
     48   sigset_t signalMask;
     49   sigfillset(&signalMask);
     50   sigdelset(&signalMask, SIGINT);
     51   sigdelset(&signalMask, SIGTERM);
     52   if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
     53     LOG_ERROR("Couldn't mask all signals except INT/TERM", errno);
     54   }
     55 }
     56 
     57 }  // anonymous namespace
     58 
     59 SocketServer::SocketServer() {
     60   // Initialize the socket fds field for all inactive client slots to -1, so
     61   // poll skips over it, and we don't attempt to send on it
     62   for (size_t i = 1; i <= kMaxActiveClients; i++) {
     63     mPollFds[i].fd = -1;
     64     mPollFds[i].events = POLLIN;
     65   }
     66 }
     67 
     68 void SocketServer::run(const char *socketName, bool allowSocketCreation,
     69                        ClientMessageCallback clientMessageCallback) {
     70   mClientMessageCallback = clientMessageCallback;
     71 
     72   mSockFd = android_get_control_socket(socketName);
     73   if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
     74     LOGI("Didn't inherit socket, creating...");
     75     mSockFd = socket_local_server(socketName,
     76                                   ANDROID_SOCKET_NAMESPACE_RESERVED,
     77                                   SOCK_SEQPACKET);
     78   }
     79 
     80   if (mSockFd == INVALID_SOCKET) {
     81     LOGE("Couldn't get/create socket");
     82   } else {
     83     int ret = listen(mSockFd, kMaxPendingConnectionRequests);
     84     if (ret < 0) {
     85       LOG_ERROR("Couldn't listen on socket", errno);
     86     } else {
     87       serviceSocket();
     88     }
     89 
     90     {
     91       std::lock_guard<std::mutex> lock(mClientsMutex);
     92       for (const auto& pair : mClients) {
     93         int clientSocket = pair.first;
     94         if (close(clientSocket) != 0) {
     95           LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
     96                pair.second.clientId, strerror(errno));
     97         }
     98       }
     99       mClients.clear();
    100     }
    101     close(mSockFd);
    102   }
    103 }
    104 
    105 void SocketServer::sendToAllClients(const void *data, size_t length) {
    106   std::lock_guard<std::mutex> lock(mClientsMutex);
    107 
    108   int deliveredCount = 0;
    109   for (const auto& pair : mClients) {
    110     int clientSocket = pair.first;
    111     uint16_t clientId = pair.second.clientId;
    112     if (sendToClientSocket(data, length, clientSocket, clientId)) {
    113       deliveredCount++;
    114     } else if (errno == EINTR) {
    115       // Exit early if we were interrupted - we should only get this for
    116       // SIGINT/SIGTERM, so we should exit quickly
    117       break;
    118     }
    119   }
    120 
    121   if (deliveredCount == 0) {
    122     LOGW("Got message but didn't deliver to any clients");
    123   }
    124 }
    125 
    126 bool SocketServer::sendToClientById(const void *data, size_t length,
    127                                     uint16_t clientId) {
    128   std::lock_guard<std::mutex> lock(mClientsMutex);
    129 
    130   bool sent = false;
    131   for (const auto& pair : mClients) {
    132     uint16_t thisClientId = pair.second.clientId;
    133     if (thisClientId == clientId) {
    134       int clientSocket = pair.first;
    135       sent = sendToClientSocket(data, length, clientSocket, thisClientId);
    136       break;
    137     }
    138   }
    139 
    140   return sent;
    141 }
    142 
    143 void SocketServer::acceptClientConnection() {
    144   int clientSocket = accept(mSockFd, NULL, NULL);
    145   if (clientSocket < 0) {
    146     LOG_ERROR("Couldn't accept client connection", errno);
    147   } else if (mClients.size() >= kMaxActiveClients) {
    148     LOGW("Rejecting client request - maximum number of clients reached");
    149     close(clientSocket);
    150   } else {
    151     ClientData clientData;
    152     clientData.clientId = mNextClientId++;
    153 
    154     // We currently don't handle wraparound - if we're getting this many
    155     // connects/disconnects, then something is wrong.
    156     // TODO: can handle this properly by iterating over the existing clients to
    157     // avoid a conflict.
    158     if (clientData.clientId == 0) {
    159       LOGE("Couldn't allocate client ID");
    160       std::exit(-1);
    161     }
    162 
    163     bool slotFound = false;
    164     for (size_t i = 1; i <= kMaxActiveClients; i++) {
    165       if (mPollFds[i].fd < 0) {
    166         mPollFds[i].fd = clientSocket;
    167         slotFound = true;
    168         break;
    169       }
    170     }
    171 
    172     if (!slotFound) {
    173       LOGE("Couldn't find slot for client!");
    174       assert(slotFound);
    175       close(clientSocket);
    176     } else {
    177       {
    178         std::lock_guard<std::mutex> lock(mClientsMutex);
    179         mClients[clientSocket] = clientData;
    180       }
    181       LOGI("Accepted new client connection (count %zu), assigned client ID %"
    182            PRIu16, mClients.size(), clientData.clientId);
    183     }
    184   }
    185 }
    186 
    187 void SocketServer::handleClientData(int clientSocket) {
    188   const ClientData& clientData = mClients[clientSocket];
    189   uint16_t clientId = clientData.clientId;
    190 
    191   ssize_t packetSize = recv(
    192       clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT);
    193   if (packetSize < 0) {
    194     LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
    195          strerror(errno));
    196   } else if (packetSize == 0) {
    197     LOGI("Client %" PRIu16 " disconnected", clientId);
    198     disconnectClient(clientSocket);
    199   } else {
    200     LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
    201     mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize);
    202   }
    203 }
    204 
    205 void SocketServer::disconnectClient(int clientSocket) {
    206   {
    207     std::lock_guard<std::mutex> lock(mClientsMutex);
    208     mClients.erase(clientSocket);
    209   }
    210   close(clientSocket);
    211 
    212   bool removed = false;
    213   for (size_t i = 1; i <= kMaxActiveClients; i++) {
    214     if (mPollFds[i].fd == clientSocket) {
    215       mPollFds[i].fd = -1;
    216       removed = true;
    217       break;
    218     }
    219   }
    220 
    221   if (!removed) {
    222     LOGE("Out of sync");
    223     assert(removed);
    224   }
    225 }
    226 
    227 bool SocketServer::sendToClientSocket(const void *data, size_t length,
    228                                       int clientSocket, uint16_t clientId) {
    229   errno = 0;
    230   ssize_t bytesSent = send(clientSocket, data, length, 0);
    231   if (bytesSent < 0) {
    232     LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s",
    233          length, clientId, strerror(errno));
    234   } else if (bytesSent == 0) {
    235     LOGW("Client %" PRIu16 " disconnected before message could be delivered",
    236          clientId);
    237   } else {
    238     LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
    239          clientId);
    240   }
    241 
    242   return (bytesSent > 0);
    243 }
    244 
/**
 * Runs the poll loop.
 *
 * Each iteration accepts new connections on the listen socket (poll slot 0)
 * and reads packets from connected clients (slots 1..kMaxActiveClients).
 *
 * Installs signalHandler() for SIGINT and SIGTERM. Returns once
 * sSignalReceived is set, or when ppoll() fails (e.g. when it is interrupted
 * by one of those signals).
 *
 * The calling thread's signal mask is not restored before returning.
 */
void SocketServer::serviceSocket() {
  constexpr size_t kListenIndex = 0;
  static_assert(kListenIndex == 0, "Code assumes that the first index is "
                "always the listen socket");

  mPollFds[kListenIndex].fd = mSockFd;
  mPollFds[kListenIndex].events = POLLIN;

  // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
  // and ignore other signals
  sigset_t signalMask;
  sigfillset(&signalMask);
  sigdelset(&signalMask, SIGINT);
  sigdelset(&signalMask, SIGTERM);

  // Masking signals here ensures that after this point, we won't handle
  // INT/TERM until after we call into ppoll()
  maskAllSignals();
  std::signal(SIGINT, signalHandler);
  std::signal(SIGTERM, signalHandler);

  LOGI("Ready to accept connections");
  while (!sSignalReceived) {
    // ppoll() installs signalMask only for the duration of the wait, so
    // INT/TERM can interrupt the wait without racing the check above
    int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
    // Allow INT/TERM while events are processed; the handler only sets
    // sSignalReceived, which is re-checked before the next ppoll()
    maskAllSignalsExceptIntAndTerm();
    if (ret == -1) {
      LOGI("Exiting poll loop: %s", strerror(errno));
      break;
    }

    if (mPollFds[kListenIndex].revents & POLLIN) {
      acceptClientConnection();
    }

    // A slot filled by acceptClientConnection() above had fd -1 during this
    // ppoll(), so its revents is 0 and it is first serviced next iteration
    for (size_t i = 1; i <= kMaxActiveClients; i++) {
      if (mPollFds[i].fd < 0) {
        continue;
      }

      if (mPollFds[i].revents & POLLIN) {
        handleClientData(mPollFds[i].fd);
      }
    }

    // Mask all signals to ensure that sSignalReceived can't become true between
    // checking it in the while condition and calling into ppoll()
    maskAllSignals();
  }
}
    294 
    295 void SocketServer::signalHandler(int signal) {
    296   LOGD("Caught signal %d", signal);
    297   sSignalReceived = true;
    298 }
    299 
    300 }  // namespace chre
    301 }  // namespace android
    302