Home | History | Annotate | Download | only in src
      1 #include <alloca.h>
      2 #include <errno.h>
      3 #include <sys/socket.h>
      4 #include <sys/types.h>
      5 #include <pthread.h>
      6 #include <string.h>
      7 #include <arpa/inet.h>
      8 
      9 #define LOG_TAG "SocketClient"
     10 #include <cutils/log.h>
     11 
     12 #include <sysutils/SocketClient.h>
     13 
     14 SocketClient::SocketClient(int socket, bool owned) {
     15     init(socket, owned, false);
     16 }
     17 
     18 SocketClient::SocketClient(int socket, bool owned, bool useCmdNum) {
     19     init(socket, owned, useCmdNum);
     20 }
     21 
     22 void SocketClient::init(int socket, bool owned, bool useCmdNum) {
     23     mSocket = socket;
     24     mSocketOwned = owned;
     25     mUseCmdNum = useCmdNum;
     26     pthread_mutex_init(&mWriteMutex, NULL);
     27     pthread_mutex_init(&mRefCountMutex, NULL);
     28     mPid = -1;
     29     mUid = -1;
     30     mGid = -1;
     31     mRefCount = 1;
     32     mCmdNum = 0;
     33 
     34     struct ucred creds;
     35     socklen_t szCreds = sizeof(creds);
     36     memset(&creds, 0, szCreds);
     37 
     38     int err = getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
     39     if (err == 0) {
     40         mPid = creds.pid;
     41         mUid = creds.uid;
     42         mGid = creds.gid;
     43     }
     44 }
     45 
     46 SocketClient::~SocketClient()
     47 {
     48     if (mSocketOwned) {
     49         close(mSocket);
     50     }
     51 }
     52 
     53 int SocketClient::sendMsg(int code, const char *msg, bool addErrno) {
     54     return sendMsg(code, msg, addErrno, mUseCmdNum);
     55 }
     56 
     57 int SocketClient::sendMsg(int code, const char *msg, bool addErrno, bool useCmdNum) {
     58     char *buf;
     59     int ret = 0;
     60 
     61     if (addErrno) {
     62         if (useCmdNum) {
     63             ret = asprintf(&buf, "%d %d %s (%s)", code, getCmdNum(), msg, strerror(errno));
     64         } else {
     65             ret = asprintf(&buf, "%d %s (%s)", code, msg, strerror(errno));
     66         }
     67     } else {
     68         if (useCmdNum) {
     69             ret = asprintf(&buf, "%d %d %s", code, getCmdNum(), msg);
     70         } else {
     71             ret = asprintf(&buf, "%d %s", code, msg);
     72         }
     73     }
     74     /* Send the zero-terminated message */
     75     if (ret != -1) {
     76         ret = sendMsg(buf);
     77         free(buf);
     78     }
     79     return ret;
     80 }
     81 
     82 /** send 3-digit code, null, binary-length, binary data */
     83 int SocketClient::sendBinaryMsg(int code, const void *data, int len) {
     84 
     85     /* 4 bytes for the code & null + 4 bytes for the len */
     86     char buf[8];
     87     /* Write the code */
     88     snprintf(buf, 4, "%.3d", code);
     89     /* Write the len */
     90     uint32_t tmp = htonl(len);
     91     memcpy(buf + 4, &tmp, sizeof(uint32_t));
     92 
     93     pthread_mutex_lock(&mWriteMutex);
     94     int result = sendDataLocked(buf, sizeof(buf));
     95     if (result == 0 && len > 0) {
     96         result = sendDataLocked(data, len);
     97     }
     98     pthread_mutex_unlock(&mWriteMutex);
     99 
    100     return result;
    101 }
    102 
    103 // Sends the code (c-string null-terminated).
    104 int SocketClient::sendCode(int code) {
    105     char buf[4];
    106     snprintf(buf, sizeof(buf), "%.3d", code);
    107     return sendData(buf, sizeof(buf));
    108 }
    109 
    110 char *SocketClient::quoteArg(const char *arg) {
    111     int len = strlen(arg);
    112     char *result = (char *)malloc(len * 2 + 3);
    113     char *current = result;
    114     const char *end = arg + len;
    115 
    116     *(current++) = '"';
    117     while (arg < end) {
    118         switch (*arg) {
    119         case '\\':
    120         case '"':
    121             *(current++) = '\\'; // fallthrough
    122         default:
    123             *(current++) = *(arg++);
    124         }
    125     }
    126     *(current++) = '"';
    127     *(current++) = '\0';
    128     result = (char *)realloc(result, current-result);
    129     return result;
    130 }
    131 
    132 
    133 int SocketClient::sendMsg(const char *msg) {
    134     // Send the message including null character
    135     if (sendData(msg, strlen(msg) + 1) != 0) {
    136         SLOGW("Unable to send msg '%s'", msg);
    137         return -1;
    138     }
    139     return 0;
    140 }
    141 
    142 int SocketClient::sendData(const void *data, int len) {
    143 
    144     pthread_mutex_lock(&mWriteMutex);
    145     int rc = sendDataLocked(data, len);
    146     pthread_mutex_unlock(&mWriteMutex);
    147 
    148     return rc;
    149 }
    150 
    151 int SocketClient::sendDataLocked(const void *data, int len) {
    152     int rc = 0;
    153     const char *p = (const char*) data;
    154     int brtw = len;
    155 
    156     if (mSocket < 0) {
    157         errno = EHOSTUNREACH;
    158         return -1;
    159     }
    160 
    161     if (len == 0) {
    162         return 0;
    163     }
    164 
    165     while (brtw > 0) {
    166         rc = send(mSocket, p, brtw, MSG_NOSIGNAL);
    167         if (rc > 0) {
    168             p += rc;
    169             brtw -= rc;
    170             continue;
    171         }
    172 
    173         if (rc < 0 && errno == EINTR)
    174             continue;
    175 
    176         if (rc == 0) {
    177             SLOGW("0 length write :(");
    178             errno = EIO;
    179         } else {
    180             SLOGW("write error (%s)", strerror(errno));
    181         }
    182         return -1;
    183     }
    184     return 0;
    185 }
    186 
    187 void SocketClient::incRef() {
    188     pthread_mutex_lock(&mRefCountMutex);
    189     mRefCount++;
    190     pthread_mutex_unlock(&mRefCountMutex);
    191 }
    192 
    193 bool SocketClient::decRef() {
    194     bool deleteSelf = false;
    195     pthread_mutex_lock(&mRefCountMutex);
    196     mRefCount--;
    197     if (mRefCount == 0) {
    198         deleteSelf = true;
    199     } else if (mRefCount < 0) {
    200         SLOGE("SocketClient refcount went negative!");
    201     }
    202     pthread_mutex_unlock(&mRefCountMutex);
    203     if (deleteSelf) {
    204         delete this;
    205     }
    206     return deleteSelf;
    207 }
    208