Home | History | Annotate | Download | only in src
      1 #include <alloca.h>
      2 #include <errno.h>
      3 #include <sys/socket.h>
      4 #include <sys/types.h>
      5 #include <pthread.h>
      6 #include <string.h>
      7 #include <arpa/inet.h>
      8 
      9 #define LOG_TAG "SocketClient"
     10 #include <cutils/log.h>
     11 
     12 #include <sysutils/SocketClient.h>
     13 
     14 SocketClient::SocketClient(int socket, bool owned) {
     15     init(socket, owned, false);
     16 }
     17 
     18 SocketClient::SocketClient(int socket, bool owned, bool useCmdNum) {
     19     init(socket, owned, useCmdNum);
     20 }
     21 
     22 void SocketClient::init(int socket, bool owned, bool useCmdNum) {
     23     mSocket = socket;
     24     mSocketOwned = owned;
     25     mUseCmdNum = useCmdNum;
     26     pthread_mutex_init(&mWriteMutex, NULL);
     27     pthread_mutex_init(&mRefCountMutex, NULL);
     28     mPid = -1;
     29     mUid = -1;
     30     mGid = -1;
     31     mRefCount = 1;
     32     mCmdNum = 0;
     33 
     34     struct ucred creds;
     35     socklen_t szCreds = sizeof(creds);
     36     memset(&creds, 0, szCreds);
     37 
     38     int err = getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
     39     if (err == 0) {
     40         mPid = creds.pid;
     41         mUid = creds.uid;
     42         mGid = creds.gid;
     43     }
     44 }
     45 
     46 SocketClient::~SocketClient()
     47 {
     48     if (mSocketOwned) {
     49         close(mSocket);
     50     }
     51 }
     52 
     53 int SocketClient::sendMsg(int code, const char *msg, bool addErrno) {
     54     return sendMsg(code, msg, addErrno, mUseCmdNum);
     55 }
     56 
     57 int SocketClient::sendMsg(int code, const char *msg, bool addErrno, bool useCmdNum) {
     58     char *buf;
     59     int ret = 0;
     60 
     61     if (addErrno) {
     62         if (useCmdNum) {
     63             ret = asprintf(&buf, "%d %d %s (%s)", code, getCmdNum(), msg, strerror(errno));
     64         } else {
     65             ret = asprintf(&buf, "%d %s (%s)", code, msg, strerror(errno));
     66         }
     67     } else {
     68         if (useCmdNum) {
     69             ret = asprintf(&buf, "%d %d %s", code, getCmdNum(), msg);
     70         } else {
     71             ret = asprintf(&buf, "%d %s", code, msg);
     72         }
     73     }
     74     /* Send the zero-terminated message */
     75     if (ret != -1) {
     76         ret = sendMsg(buf);
     77         free(buf);
     78     }
     79     return ret;
     80 }
     81 
     82 /** send 3-digit code, null, binary-length, binary data */
     83 int SocketClient::sendBinaryMsg(int code, const void *data, int len) {
     84 
     85     /* 4 bytes for the code & null + 4 bytes for the len */
     86     char buf[8];
     87     /* Write the code */
     88     snprintf(buf, 4, "%.3d", code);
     89     /* Write the len */
     90     uint32_t tmp = htonl(len);
     91     memcpy(buf + 4, &tmp, sizeof(uint32_t));
     92 
     93     pthread_mutex_lock(&mWriteMutex);
     94     int result = sendDataLocked(buf, sizeof(buf));
     95     if (result == 0 && len > 0) {
     96         result = sendDataLocked(data, len);
     97     }
     98     pthread_mutex_unlock(&mWriteMutex);
     99 
    100     return result;
    101 }
    102 
    103 // Sends the code (c-string null-terminated).
    104 int SocketClient::sendCode(int code) {
    105     char buf[4];
    106     snprintf(buf, sizeof(buf), "%.3d", code);
    107     return sendData(buf, sizeof(buf));
    108 }
    109 
    110 char *SocketClient::quoteArg(const char *arg) {
    111     int len = strlen(arg);
    112     char *result = (char *)malloc(len * 2 + 3);
    113     char *current = result;
    114     const char *end = arg + len;
    115     char *oldresult;
    116 
    117     if(result == NULL) {
    118         SLOGW("malloc error (%s)", strerror(errno));
    119         return NULL;
    120     }
    121 
    122     *(current++) = '"';
    123     while (arg < end) {
    124         switch (*arg) {
    125         case '\\':
    126         case '"':
    127             *(current++) = '\\'; // fallthrough
    128         default:
    129             *(current++) = *(arg++);
    130         }
    131     }
    132     *(current++) = '"';
    133     *(current++) = '\0';
    134     oldresult = result; // save pointer in case realloc fails
    135     result = (char *)realloc(result, current-result);
    136     return result ? result : oldresult;
    137 }
    138 
    139 
    140 int SocketClient::sendMsg(const char *msg) {
    141     // Send the message including null character
    142     if (sendData(msg, strlen(msg) + 1) != 0) {
    143         SLOGW("Unable to send msg '%s'", msg);
    144         return -1;
    145     }
    146     return 0;
    147 }
    148 
    149 int SocketClient::sendData(const void *data, int len) {
    150 
    151     pthread_mutex_lock(&mWriteMutex);
    152     int rc = sendDataLocked(data, len);
    153     pthread_mutex_unlock(&mWriteMutex);
    154 
    155     return rc;
    156 }
    157 
    158 int SocketClient::sendDataLocked(const void *data, int len) {
    159     int rc = 0;
    160     const char *p = (const char*) data;
    161     int brtw = len;
    162 
    163     if (mSocket < 0) {
    164         errno = EHOSTUNREACH;
    165         return -1;
    166     }
    167 
    168     if (len == 0) {
    169         return 0;
    170     }
    171 
    172     while (brtw > 0) {
    173         rc = send(mSocket, p, brtw, MSG_NOSIGNAL);
    174         if (rc > 0) {
    175             p += rc;
    176             brtw -= rc;
    177             continue;
    178         }
    179 
    180         if (rc < 0 && errno == EINTR)
    181             continue;
    182 
    183         if (rc == 0) {
    184             SLOGW("0 length write :(");
    185             errno = EIO;
    186         } else {
    187             SLOGW("write error (%s)", strerror(errno));
    188         }
    189         return -1;
    190     }
    191     return 0;
    192 }
    193 
    194 void SocketClient::incRef() {
    195     pthread_mutex_lock(&mRefCountMutex);
    196     mRefCount++;
    197     pthread_mutex_unlock(&mRefCountMutex);
    198 }
    199 
    200 bool SocketClient::decRef() {
    201     bool deleteSelf = false;
    202     pthread_mutex_lock(&mRefCountMutex);
    203     mRefCount--;
    204     if (mRefCount == 0) {
    205         deleteSelf = true;
    206     } else if (mRefCount < 0) {
    207         SLOGE("SocketClient refcount went negative!");
    208     }
    209     pthread_mutex_unlock(&mRefCountMutex);
    210     if (deleteSelf) {
    211         delete this;
    212     }
    213     return deleteSelf;
    214 }
    215