Home | History | Annotate | Download | only in jni
      1 /*
      2  * Copyright (C) 2006 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #define LOG_TAG "LocalSocketImpl"
     18 
     19 #include <nativehelper/JNIHelp.h>
     20 #include "jni.h"
     21 #include "utils/Log.h"
     22 #include "utils/misc.h"
     23 
     24 #include <stdio.h>
     25 #include <string.h>
     26 #include <sys/types.h>
     27 #include <sys/socket.h>
     28 #include <sys/un.h>
     29 #include <arpa/inet.h>
     30 #include <netinet/in.h>
     31 #include <stdlib.h>
     32 #include <errno.h>
     33 #include <unistd.h>
     34 #include <sys/ioctl.h>
     35 
     36 #include <cutils/sockets.h>
     37 #include <netinet/tcp.h>
     38 #include <nativehelper/ScopedUtfChars.h>
     39 
     40 namespace android {
     41 
     42 template <typename T>
     43 void UNUSED(T t) {}
     44 
     45 static jfieldID field_inboundFileDescriptors;
     46 static jfieldID field_outboundFileDescriptors;
     47 static jclass class_Credentials;
     48 static jclass class_FileDescriptor;
     49 static jmethodID method_CredentialsInit;
     50 
     51 /* private native void connectLocal(FileDescriptor fd,
     52  * String name, int namespace) throws IOException
     53  */
     54 static void
     55 socket_connect_local(JNIEnv *env, jobject object,
     56                         jobject fileDescriptor, jstring name, jint namespaceId)
     57 {
     58     int ret;
     59     int fd;
     60 
     61     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
     62 
     63     if (env->ExceptionCheck()) {
     64         return;
     65     }
     66 
     67     ScopedUtfChars nameUtf8(env, name);
     68 
     69     ret = socket_local_client_connect(
     70                 fd,
     71                 nameUtf8.c_str(),
     72                 namespaceId,
     73                 SOCK_STREAM);
     74 
     75     if (ret < 0) {
     76         jniThrowIOException(env, errno);
     77         return;
     78     }
     79 }
     80 
     81 #define DEFAULT_BACKLOG 4
     82 
     83 /* private native void bindLocal(FileDescriptor fd, String name, namespace)
     84  * throws IOException;
     85  */
     86 
     87 static void
     88 socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor,
     89                 jstring name, jint namespaceId)
     90 {
     91     int ret;
     92     int fd;
     93 
     94     if (name == NULL) {
     95         jniThrowNullPointerException(env, NULL);
     96         return;
     97     }
     98 
     99     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
    100 
    101     if (env->ExceptionCheck()) {
    102         return;
    103     }
    104 
    105     ScopedUtfChars nameUtf8(env, name);
    106 
    107     ret = socket_local_server_bind(fd, nameUtf8.c_str(), namespaceId);
    108 
    109     if (ret < 0) {
    110         jniThrowIOException(env, errno);
    111         return;
    112     }
    113 }
    114 
    115 /**
    116  * Processes ancillary data, handling only
    117  * SCM_RIGHTS. Creates appropriate objects and sets appropriate
    118  * fields in the LocalSocketImpl object. Returns 0 on success
    119  * or -1 if an exception was thrown.
    120  */
    121 static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
    122 {
    123     struct cmsghdr *cmsgptr;
    124 
    125     for (cmsgptr = CMSG_FIRSTHDR(pMsg);
    126             cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) {
    127 
    128         if (cmsgptr->cmsg_level != SOL_SOCKET) {
    129             continue;
    130         }
    131 
    132         if (cmsgptr->cmsg_type == SCM_RIGHTS) {
    133             int *pDescriptors = (int *)CMSG_DATA(cmsgptr);
    134             jobjectArray fdArray;
    135             int count
    136                 = ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int));
    137 
    138             if (count < 0) {
    139                 jniThrowException(env, "java/io/IOException",
    140                     "invalid cmsg length");
    141                 return -1;
    142             }
    143 
    144             fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL);
    145 
    146             if (fdArray == NULL) {
    147                 return -1;
    148             }
    149 
    150             for (int i = 0; i < count; i++) {
    151                 jobject fdObject
    152                         = jniCreateFileDescriptor(env, pDescriptors[i]);
    153 
    154                 if (env->ExceptionCheck()) {
    155                     return -1;
    156                 }
    157 
    158                 env->SetObjectArrayElement(fdArray, i, fdObject);
    159 
    160                 if (env->ExceptionCheck()) {
    161                     return -1;
    162                 }
    163             }
    164 
    165             env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
    166 
    167             if (env->ExceptionCheck()) {
    168                 return -1;
    169             }
    170         }
    171     }
    172 
    173     return 0;
    174 }
    175 
    176 /**
    177  * Reads data from a socket into buf, processing any ancillary data
    178  * and adding it to thisJ.
    179  *
    180  * Returns the length of normal data read, or -1 if an exception has
    181  * been thrown in this function.
    182  */
    183 static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
    184         void *buffer, size_t len)
    185 {
    186     ssize_t ret;
    187     struct msghdr msg;
    188     struct iovec iv;
    189     unsigned char *buf = (unsigned char *)buffer;
    190     // Enough buffer for a pile of fd's. We throw an exception if
    191     // this buffer is too small.
    192     struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100];
    193 
    194     memset(&msg, 0, sizeof(msg));
    195     memset(&iv, 0, sizeof(iv));
    196 
    197     iv.iov_base = buf;
    198     iv.iov_len = len;
    199 
    200     msg.msg_iov = &iv;
    201     msg.msg_iovlen = 1;
    202     msg.msg_control = cmsgbuf;
    203     msg.msg_controllen = sizeof(cmsgbuf);
    204 
    205     ret = TEMP_FAILURE_RETRY(recvmsg(fd, &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC));
    206 
    207     if (ret < 0 && errno == EPIPE) {
    208         // Treat this as an end of stream
    209         return 0;
    210     }
    211 
    212     if (ret < 0) {
    213         jniThrowIOException(env, errno);
    214         return -1;
    215     }
    216 
    217     if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) {
    218         // To us, any of the above flags are a fatal error
    219 
    220         jniThrowException(env, "java/io/IOException",
    221                 "Unexpected error or truncation during recvmsg()");
    222 
    223         return -1;
    224     }
    225 
    226     if (ret >= 0) {
    227         socket_process_cmsg(env, thisJ, &msg);
    228     }
    229 
    230     return ret;
    231 }
    232 
    233 /**
    234  * Writes all the data in the specified buffer to the specified socket.
    235  *
    236  * Returns 0 on success or -1 if an exception was thrown.
    237  */
    238 static int socket_write_all(JNIEnv *env, jobject object, int fd,
    239         void *buf, size_t len)
    240 {
    241     ssize_t ret;
    242     struct msghdr msg;
    243     unsigned char *buffer = (unsigned char *)buf;
    244     memset(&msg, 0, sizeof(msg));
    245 
    246     jobjectArray outboundFds
    247             = (jobjectArray)env->GetObjectField(
    248                 object, field_outboundFileDescriptors);
    249 
    250     if (env->ExceptionCheck()) {
    251         return -1;
    252     }
    253 
    254     struct cmsghdr *cmsg;
    255     int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
    256     int fds[countFds];
    257     char msgbuf[CMSG_SPACE(countFds)];
    258 
    259     // Add any pending outbound file descriptors to the message
    260     if (outboundFds != NULL) {
    261 
    262         if (env->ExceptionCheck()) {
    263             return -1;
    264         }
    265 
    266         for (int i = 0; i < countFds; i++) {
    267             jobject fdObject = env->GetObjectArrayElement(outboundFds, i);
    268             if (env->ExceptionCheck()) {
    269                 return -1;
    270             }
    271 
    272             fds[i] = jniGetFDFromFileDescriptor(env, fdObject);
    273             if (env->ExceptionCheck()) {
    274                 return -1;
    275             }
    276         }
    277 
    278         // See "man cmsg" really
    279         msg.msg_control = msgbuf;
    280         msg.msg_controllen = sizeof msgbuf;
    281         cmsg = CMSG_FIRSTHDR(&msg);
    282         cmsg->cmsg_level = SOL_SOCKET;
    283         cmsg->cmsg_type = SCM_RIGHTS;
    284         cmsg->cmsg_len = CMSG_LEN(sizeof fds);
    285         memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
    286     }
    287 
    288     // We only write our msg_control during the first write
    289     while (len > 0) {
    290         struct iovec iv;
    291         memset(&iv, 0, sizeof(iv));
    292 
    293         iv.iov_base = buffer;
    294         iv.iov_len = len;
    295 
    296         msg.msg_iov = &iv;
    297         msg.msg_iovlen = 1;
    298 
    299         do {
    300             ret = sendmsg(fd, &msg, MSG_NOSIGNAL);
    301         } while (ret < 0 && errno == EINTR);
    302 
    303         if (ret < 0) {
    304             jniThrowIOException(env, errno);
    305             return -1;
    306         }
    307 
    308         buffer += ret;
    309         len -= ret;
    310 
    311         // Wipes out any msg_control too
    312         memset(&msg, 0, sizeof(msg));
    313     }
    314 
    315     return 0;
    316 }
    317 
    318 static jint socket_read (JNIEnv *env, jobject object, jobject fileDescriptor)
    319 {
    320     int fd;
    321     int err;
    322 
    323     if (fileDescriptor == NULL) {
    324         jniThrowNullPointerException(env, NULL);
    325         return (jint)-1;
    326     }
    327 
    328     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
    329 
    330     if (env->ExceptionCheck()) {
    331         return (jint)0;
    332     }
    333 
    334     unsigned char buf;
    335 
    336     err = socket_read_all(env, object, fd, &buf, 1);
    337 
    338     if (err < 0) {
    339         jniThrowIOException(env, errno);
    340         return (jint)0;
    341     }
    342 
    343     if (err == 0) {
    344         // end of file
    345         return (jint)-1;
    346     }
    347 
    348     return (jint)buf;
    349 }
    350 
    351 static jint socket_readba (JNIEnv *env, jobject object,
    352         jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
    353 {
    354     int fd;
    355     jbyte* byteBuffer;
    356     int ret;
    357 
    358     if (fileDescriptor == NULL || buffer == NULL) {
    359         jniThrowNullPointerException(env, NULL);
    360         return (jint)-1;
    361     }
    362 
    363     if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
    364         jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
    365         return (jint)-1;
    366     }
    367 
    368     if (len == 0) {
    369         // because socket_read_all returns 0 on EOF
    370         return 0;
    371     }
    372 
    373     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
    374 
    375     if (env->ExceptionCheck()) {
    376         return (jint)-1;
    377     }
    378 
    379     byteBuffer = env->GetByteArrayElements(buffer, NULL);
    380 
    381     if (NULL == byteBuffer) {
    382         // an exception will have been thrown
    383         return (jint)-1;
    384     }
    385 
    386     ret = socket_read_all(env, object,
    387             fd, byteBuffer + off, len);
    388 
    389     // A return of -1 above means an exception is pending
    390 
    391     env->ReleaseByteArrayElements(buffer, byteBuffer, 0);
    392 
    393     return (jint) ((ret == 0) ? -1 : ret);
    394 }
    395 
    396 static void socket_write (JNIEnv *env, jobject object,
    397         jint b, jobject fileDescriptor)
    398 {
    399     int fd;
    400     int err;
    401 
    402     if (fileDescriptor == NULL) {
    403         jniThrowNullPointerException(env, NULL);
    404         return;
    405     }
    406 
    407     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
    408 
    409     if (env->ExceptionCheck()) {
    410         return;
    411     }
    412 
    413     err = socket_write_all(env, object, fd, &b, 1);
    414     UNUSED(err);
    415     // A return of -1 above means an exception is pending
    416 }
    417 
    418 static void socket_writeba (JNIEnv *env, jobject object,
    419         jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
    420 {
    421     int fd;
    422     int err;
    423     jbyte* byteBuffer;
    424 
    425     if (fileDescriptor == NULL || buffer == NULL) {
    426         jniThrowNullPointerException(env, NULL);
    427         return;
    428     }
    429 
    430     if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
    431         jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
    432         return;
    433     }
    434 
    435     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
    436 
    437     if (env->ExceptionCheck()) {
    438         return;
    439     }
    440 
    441     byteBuffer = env->GetByteArrayElements(buffer,NULL);
    442 
    443     if (NULL == byteBuffer) {
    444         // an exception will have been thrown
    445         return;
    446     }
    447 
    448     err = socket_write_all(env, object, fd,
    449             byteBuffer + off, len);
    450     UNUSED(err);
    451     // A return of -1 above means an exception is pending
    452 
    453     env->ReleaseByteArrayElements(buffer, byteBuffer, JNI_ABORT);
    454 }
    455 
    456 static jobject socket_get_peer_credentials(JNIEnv *env,
    457         jobject object, jobject fileDescriptor)
    458 {
    459     int err;
    460     int fd;
    461 
    462     if (fileDescriptor == NULL) {
    463         jniThrowNullPointerException(env, NULL);
    464         return NULL;
    465     }
    466 
    467     fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
    468 
    469     if (env->ExceptionCheck()) {
    470         return NULL;
    471     }
    472 
    473     struct ucred creds;
    474 
    475     memset(&creds, 0, sizeof(creds));
    476     socklen_t szCreds = sizeof(creds);
    477 
    478     err = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
    479 
    480     if (err < 0) {
    481         jniThrowIOException(env, errno);
    482         return NULL;
    483     }
    484 
    485     if (szCreds == 0) {
    486         return NULL;
    487     }
    488 
    489     return env->NewObject(class_Credentials, method_CredentialsInit,
    490             creds.pid, creds.uid, creds.gid);
    491 }
    492 
    493 /*
    494  * JNI registration.
    495  */
    496 static const JNINativeMethod gMethods[] = {
    497      /* name, signature, funcPtr */
    498     {"connectLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V",
    499                                                 (void*)socket_connect_local},
    500     {"bindLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V", (void*)socket_bind_local},
    501     {"read_native", "(Ljava/io/FileDescriptor;)I", (void*) socket_read},
    502     {"readba_native", "([BIILjava/io/FileDescriptor;)I", (void*) socket_readba},
    503     {"writeba_native", "([BIILjava/io/FileDescriptor;)V", (void*) socket_writeba},
    504     {"write_native", "(ILjava/io/FileDescriptor;)V", (void*) socket_write},
    505     {"getPeerCredentials_native",
    506             "(Ljava/io/FileDescriptor;)Landroid/net/Credentials;",
    507             (void*) socket_get_peer_credentials}
    508 };
    509 
    510 int register_android_net_LocalSocketImpl(JNIEnv *env)
    511 {
    512     jclass clazz;
    513 
    514     clazz = env->FindClass("android/net/LocalSocketImpl");
    515 
    516     if (clazz == NULL) {
    517         goto error;
    518     }
    519 
    520     field_inboundFileDescriptors = env->GetFieldID(clazz,
    521             "inboundFileDescriptors", "[Ljava/io/FileDescriptor;");
    522 
    523     if (field_inboundFileDescriptors == NULL) {
    524         goto error;
    525     }
    526 
    527     field_outboundFileDescriptors = env->GetFieldID(clazz,
    528             "outboundFileDescriptors", "[Ljava/io/FileDescriptor;");
    529 
    530     if (field_outboundFileDescriptors == NULL) {
    531         goto error;
    532     }
    533 
    534     class_Credentials = env->FindClass("android/net/Credentials");
    535 
    536     if (class_Credentials == NULL) {
    537         goto error;
    538     }
    539 
    540     class_Credentials = (jclass)env->NewGlobalRef(class_Credentials);
    541 
    542     class_FileDescriptor = env->FindClass("java/io/FileDescriptor");
    543 
    544     if (class_FileDescriptor == NULL) {
    545         goto error;
    546     }
    547 
    548     class_FileDescriptor = (jclass)env->NewGlobalRef(class_FileDescriptor);
    549 
    550     method_CredentialsInit
    551             = env->GetMethodID(class_Credentials, "<init>", "(III)V");
    552 
    553     if (method_CredentialsInit == NULL) {
    554         goto error;
    555     }
    556 
    557     return jniRegisterNativeMethods(env,
    558         "android/net/LocalSocketImpl", gMethods, NELEM(gMethods));
    559 
    560 error:
    561     ALOGE("Error registering android.net.LocalSocketImpl");
    562     return -1;
    563 }
    564 
    565 };
    566