Home | History | Annotate | Download | only in src
      1 #include <alloca.h>
      2 #include <errno.h>
      3 #include <pthread.h>
      4 #include <signal.h>
      5 #include <string.h>
      6 #include <arpa/inet.h>
      7 #include <sys/socket.h>
      8 #include <sys/types.h>
      9 
     10 #define LOG_TAG "SocketClient"
     11 #include <cutils/log.h>
     12 
     13 #include <sysutils/SocketClient.h>
     14 
     15 SocketClient::SocketClient(int socket, bool owned) {
     16     init(socket, owned, false);
     17 }
     18 
     19 SocketClient::SocketClient(int socket, bool owned, bool useCmdNum) {
     20     init(socket, owned, useCmdNum);
     21 }
     22 
     23 void SocketClient::init(int socket, bool owned, bool useCmdNum) {
     24     mSocket = socket;
     25     mSocketOwned = owned;
     26     mUseCmdNum = useCmdNum;
     27     pthread_mutex_init(&mWriteMutex, NULL);
     28     pthread_mutex_init(&mRefCountMutex, NULL);
     29     mPid = -1;
     30     mUid = -1;
     31     mGid = -1;
     32     mRefCount = 1;
     33     mCmdNum = 0;
     34 
     35     struct ucred creds;
     36     socklen_t szCreds = sizeof(creds);
     37     memset(&creds, 0, szCreds);
     38 
     39     int err = getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
     40     if (err == 0) {
     41         mPid = creds.pid;
     42         mUid = creds.uid;
     43         mGid = creds.gid;
     44     }
     45 }
     46 
     47 SocketClient::~SocketClient() {
     48     if (mSocketOwned) {
     49         close(mSocket);
     50     }
     51 }
     52 
     53 int SocketClient::sendMsg(int code, const char *msg, bool addErrno) {
     54     return sendMsg(code, msg, addErrno, mUseCmdNum);
     55 }
     56 
     57 int SocketClient::sendMsg(int code, const char *msg, bool addErrno, bool useCmdNum) {
     58     char *buf;
     59     int ret = 0;
     60 
     61     if (addErrno) {
     62         if (useCmdNum) {
     63             ret = asprintf(&buf, "%d %d %s (%s)", code, getCmdNum(), msg, strerror(errno));
     64         } else {
     65             ret = asprintf(&buf, "%d %s (%s)", code, msg, strerror(errno));
     66         }
     67     } else {
     68         if (useCmdNum) {
     69             ret = asprintf(&buf, "%d %d %s", code, getCmdNum(), msg);
     70         } else {
     71             ret = asprintf(&buf, "%d %s", code, msg);
     72         }
     73     }
     74     // Send the zero-terminated message
     75     if (ret != -1) {
     76         ret = sendMsg(buf);
     77         free(buf);
     78     }
     79     return ret;
     80 }
     81 
     82 // send 3-digit code, null, binary-length, binary data
     83 int SocketClient::sendBinaryMsg(int code, const void *data, int len) {
     84 
     85     // 4 bytes for the code & null + 4 bytes for the len
     86     char buf[8];
     87     // Write the code
     88     snprintf(buf, 4, "%.3d", code);
     89     // Write the len
     90     uint32_t tmp = htonl(len);
     91     memcpy(buf + 4, &tmp, sizeof(uint32_t));
     92 
     93     struct iovec vec[2];
     94     vec[0].iov_base = (void *) buf;
     95     vec[0].iov_len = sizeof(buf);
     96     vec[1].iov_base = (void *) data;
     97     vec[1].iov_len = len;
     98 
     99     pthread_mutex_lock(&mWriteMutex);
    100     int result = sendDataLockedv(vec, (len > 0) ? 2 : 1);
    101     pthread_mutex_unlock(&mWriteMutex);
    102 
    103     return result;
    104 }
    105 
    106 // Sends the code (c-string null-terminated).
    107 int SocketClient::sendCode(int code) {
    108     char buf[4];
    109     snprintf(buf, sizeof(buf), "%.3d", code);
    110     return sendData(buf, sizeof(buf));
    111 }
    112 
    113 char *SocketClient::quoteArg(const char *arg) {
    114     int len = strlen(arg);
    115     char *result = (char *)malloc(len * 2 + 3);
    116     char *current = result;
    117     const char *end = arg + len;
    118     char *oldresult;
    119 
    120     if(result == NULL) {
    121         SLOGW("malloc error (%s)", strerror(errno));
    122         return NULL;
    123     }
    124 
    125     *(current++) = '"';
    126     while (arg < end) {
    127         switch (*arg) {
    128         case '\\':
    129         case '"':
    130             *(current++) = '\\'; // fallthrough
    131         default:
    132             *(current++) = *(arg++);
    133         }
    134     }
    135     *(current++) = '"';
    136     *(current++) = '\0';
    137     oldresult = result; // save pointer in case realloc fails
    138     result = (char *)realloc(result, current-result);
    139     return result ? result : oldresult;
    140 }
    141 
    142 
    143 int SocketClient::sendMsg(const char *msg) {
    144     // Send the message including null character
    145     if (sendData(msg, strlen(msg) + 1) != 0) {
    146         SLOGW("Unable to send msg '%s'", msg);
    147         return -1;
    148     }
    149     return 0;
    150 }
    151 
    152 int SocketClient::sendData(const void *data, int len) {
    153     struct iovec vec[1];
    154     vec[0].iov_base = (void *) data;
    155     vec[0].iov_len = len;
    156 
    157     pthread_mutex_lock(&mWriteMutex);
    158     int rc = sendDataLockedv(vec, 1);
    159     pthread_mutex_unlock(&mWriteMutex);
    160 
    161     return rc;
    162 }
    163 
    164 int SocketClient::sendDatav(struct iovec *iov, int iovcnt) {
    165     pthread_mutex_lock(&mWriteMutex);
    166     int rc = sendDataLockedv(iov, iovcnt);
    167     pthread_mutex_unlock(&mWriteMutex);
    168 
    169     return rc;
    170 }
    171 
    172 int SocketClient::sendDataLockedv(struct iovec *iov, int iovcnt) {
    173 
    174     if (mSocket < 0) {
    175         errno = EHOSTUNREACH;
    176         return -1;
    177     }
    178 
    179     if (iovcnt <= 0) {
    180         return 0;
    181     }
    182 
    183     int ret = 0;
    184     int e = 0; // SLOGW and sigaction are not inert regarding errno
    185     int current = 0;
    186 
    187     struct sigaction new_action, old_action;
    188     memset(&new_action, 0, sizeof(new_action));
    189     new_action.sa_handler = SIG_IGN;
    190     sigaction(SIGPIPE, &new_action, &old_action);
    191 
    192     for (;;) {
    193         ssize_t rc = TEMP_FAILURE_RETRY(
    194             writev(mSocket, iov + current, iovcnt - current));
    195 
    196         if (rc > 0) {
    197             size_t written = rc;
    198             while ((current < iovcnt) && (written >= iov[current].iov_len)) {
    199                 written -= iov[current].iov_len;
    200                 current++;
    201             }
    202             if (current == iovcnt) {
    203                 break;
    204             }
    205             iov[current].iov_base = (char *)iov[current].iov_base + written;
    206             iov[current].iov_len -= written;
    207             continue;
    208         }
    209 
    210         if (rc == 0) {
    211             e = EIO;
    212             SLOGW("0 length write :(");
    213         } else {
    214             e = errno;
    215             SLOGW("write error (%s)", strerror(e));
    216         }
    217         ret = -1;
    218         break;
    219     }
    220 
    221     sigaction(SIGPIPE, &old_action, &new_action);
    222 
    223     errno = e;
    224     return ret;
    225 }
    226 
    227 void SocketClient::incRef() {
    228     pthread_mutex_lock(&mRefCountMutex);
    229     mRefCount++;
    230     pthread_mutex_unlock(&mRefCountMutex);
    231 }
    232 
    233 bool SocketClient::decRef() {
    234     bool deleteSelf = false;
    235     pthread_mutex_lock(&mRefCountMutex);
    236     mRefCount--;
    237     if (mRefCount == 0) {
    238         deleteSelf = true;
    239     } else if (mRefCount < 0) {
    240         SLOGE("SocketClient refcount went negative!");
    241     }
    242     pthread_mutex_unlock(&mRefCountMutex);
    243     if (deleteSelf) {
    244         delete this;
    245     }
    246     return deleteSelf;
    247 }
    248