Home | History | Annotate | Download | only in libpdx_uds
#include "uds/ipc_helper.h"

#include <alloca.h>
#include <errno.h>
#include <fcntl.h>
#include <log/log.h>
#include <poll.h>
#include <string.h>
#include <sys/inotify.h>
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <iterator>
#include <limits>
#include <numeric>
#include <string>
#include <vector>

#include <pdx/service.h>
#include <pdx/utility.h>
     16 
     17 namespace android {
     18 namespace pdx {
     19 namespace uds {
     20 
     21 namespace {
     22 
     23 // Default implementations of Send/Receive interfaces to use standard socket
     24 // send/sendmsg/recv/recvmsg functions.
     25 class SocketSender : public SendInterface {
     26  public:
     27   ssize_t Send(int socket_fd, const void* data, size_t size,
     28                int flags) override {
     29     return send(socket_fd, data, size, flags);
     30   }
     31   ssize_t SendMessage(int socket_fd, const msghdr* msg, int flags) override {
     32     return sendmsg(socket_fd, msg, flags);
     33   }
     34 } g_socket_sender;
     35 
     36 class SocketReceiver : public RecvInterface {
     37  public:
     38   ssize_t Receive(int socket_fd, void* data, size_t size, int flags) override {
     39     return recv(socket_fd, data, size, flags);
     40   }
     41   ssize_t ReceiveMessage(int socket_fd, msghdr* msg, int flags) override {
     42     return recvmsg(socket_fd, msg, flags);
     43   }
     44 } g_socket_receiver;
     45 
     46 }  // anonymous namespace
     47 
     48 // Helper wrappers around send()/sendmsg() which repeat send() calls on data
     49 // that was not sent with the initial call to send/sendmsg. This is important to
     50 // handle transmissions interrupted by signals.
     51 Status<void> SendAll(SendInterface* sender, const BorrowedHandle& socket_fd,
     52                      const void* data, size_t size) {
     53   Status<void> ret;
     54   const uint8_t* ptr = static_cast<const uint8_t*>(data);
     55   while (size > 0) {
     56     ssize_t size_written =
     57         RETRY_EINTR(sender->Send(socket_fd.Get(), ptr, size, MSG_NOSIGNAL));
     58     if (size_written < 0) {
     59       ret.SetError(errno);
     60       ALOGE("SendAll: Failed to send data over socket: %s",
     61             ret.GetErrorMessage().c_str());
     62       break;
     63     }
     64     size -= size_written;
     65     ptr += size_written;
     66   }
     67   return ret;
     68 }
     69 
     70 Status<void> SendMsgAll(SendInterface* sender, const BorrowedHandle& socket_fd,
     71                         const msghdr* msg) {
     72   Status<void> ret;
     73   ssize_t sent_size =
     74       RETRY_EINTR(sender->SendMessage(socket_fd.Get(), msg, MSG_NOSIGNAL));
     75   if (sent_size < 0) {
     76     ret.SetError(errno);
     77     ALOGE("SendMsgAll: Failed to send data over socket: %s",
     78           ret.GetErrorMessage().c_str());
     79     return ret;
     80   }
     81 
     82   ssize_t chunk_start_offset = 0;
     83   for (size_t i = 0; i < msg->msg_iovlen; i++) {
     84     ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
     85     if (sent_size < chunk_end_offset) {
     86       size_t offset_within_chunk = sent_size - chunk_start_offset;
     87       size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
     88       const uint8_t* chunk_base =
     89           static_cast<const uint8_t*>(msg->msg_iov[i].iov_base);
     90       ret = SendAll(sender, socket_fd, chunk_base + offset_within_chunk,
     91                     data_size);
     92       if (!ret)
     93         break;
     94       sent_size += data_size;
     95     }
     96     chunk_start_offset = chunk_end_offset;
     97   }
     98   return ret;
     99 }
    100 
    101 // Helper wrappers around recv()/recvmsg() which repeat recv() calls on data
    102 // that was not received with the initial call to recvmsg(). This is important
    103 // to handle transmissions interrupted by signals as well as the case when
    104 // initial data did not arrive in a single chunk over the socket (e.g. socket
    105 // buffer was full at the time of transmission, and only portion of initial
    106 // message was sent and the rest was blocked until the buffer was cleared by the
    107 // receiving side).
    108 Status<void> RecvMsgAll(RecvInterface* receiver,
    109                         const BorrowedHandle& socket_fd, msghdr* msg) {
    110   Status<void> ret;
    111   ssize_t size_read = RETRY_EINTR(receiver->ReceiveMessage(
    112       socket_fd.Get(), msg, MSG_WAITALL | MSG_CMSG_CLOEXEC));
    113   if (size_read < 0) {
    114     ret.SetError(errno);
    115     ALOGE("RecvMsgAll: Failed to receive data from socket: %s",
    116           ret.GetErrorMessage().c_str());
    117     return ret;
    118   } else if (size_read == 0) {
    119     ret.SetError(ESHUTDOWN);
    120     ALOGW("RecvMsgAll: Socket has been shut down");
    121     return ret;
    122   }
    123 
    124   ssize_t chunk_start_offset = 0;
    125   for (size_t i = 0; i < msg->msg_iovlen; i++) {
    126     ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
    127     if (size_read < chunk_end_offset) {
    128       size_t offset_within_chunk = size_read - chunk_start_offset;
    129       size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
    130       uint8_t* chunk_base = static_cast<uint8_t*>(msg->msg_iov[i].iov_base);
    131       ret = RecvAll(receiver, socket_fd, chunk_base + offset_within_chunk,
    132                     data_size);
    133       if (!ret)
    134         break;
    135       size_read += data_size;
    136     }
    137     chunk_start_offset = chunk_end_offset;
    138   }
    139   return ret;
    140 }
    141 
    142 Status<void> RecvAll(RecvInterface* receiver, const BorrowedHandle& socket_fd,
    143                      void* data, size_t size) {
    144   Status<void> ret;
    145   uint8_t* ptr = static_cast<uint8_t*>(data);
    146   while (size > 0) {
    147     ssize_t size_read = RETRY_EINTR(receiver->Receive(
    148         socket_fd.Get(), ptr, size, MSG_WAITALL | MSG_CMSG_CLOEXEC));
    149     if (size_read < 0) {
    150       ret.SetError(errno);
    151       ALOGE("RecvAll: Failed to receive data from socket: %s",
    152             ret.GetErrorMessage().c_str());
    153       break;
    154     } else if (size_read == 0) {
    155       ret.SetError(ESHUTDOWN);
    156       ALOGW("RecvAll: Socket has been shut down");
    157       break;
    158     }
    159     size -= size_read;
    160     ptr += size_read;
    161   }
    162   return ret;
    163 }
    164 
    165 uint32_t kMagicPreamble = 0x7564736d;  // 'udsm'.
    166 
    167 struct MessagePreamble {
    168   uint32_t magic{0};
    169   uint32_t data_size{0};
    170   uint32_t fd_count{0};
    171 };
    172 
    173 Status<void> SendPayload::Send(const BorrowedHandle& socket_fd) {
    174   return Send(socket_fd, nullptr);
    175 }
    176 
    177 Status<void> SendPayload::Send(const BorrowedHandle& socket_fd,
    178                                const ucred* cred) {
    179   SendInterface* sender = sender_ ? sender_ : &g_socket_sender;
    180   MessagePreamble preamble;
    181   preamble.magic = kMagicPreamble;
    182   preamble.data_size = buffer_.size();
    183   preamble.fd_count = file_handles_.size();
    184   Status<void> ret = SendAll(sender, socket_fd, &preamble, sizeof(preamble));
    185   if (!ret)
    186     return ret;
    187 
    188   msghdr msg = {};
    189   iovec recv_vect = {buffer_.data(), buffer_.size()};
    190   msg.msg_iov = &recv_vect;
    191   msg.msg_iovlen = 1;
    192 
    193   if (cred || !file_handles_.empty()) {
    194     const size_t fd_bytes = file_handles_.size() * sizeof(int);
    195     msg.msg_controllen = (cred ? CMSG_SPACE(sizeof(ucred)) : 0) +
    196                          (fd_bytes == 0 ? 0 : CMSG_SPACE(fd_bytes));
    197     msg.msg_control = alloca(msg.msg_controllen);
    198 
    199     cmsghdr* control = CMSG_FIRSTHDR(&msg);
    200     if (cred) {
    201       control->cmsg_level = SOL_SOCKET;
    202       control->cmsg_type = SCM_CREDENTIALS;
    203       control->cmsg_len = CMSG_LEN(sizeof(ucred));
    204       memcpy(CMSG_DATA(control), cred, sizeof(ucred));
    205       control = CMSG_NXTHDR(&msg, control);
    206     }
    207 
    208     if (fd_bytes) {
    209       control->cmsg_level = SOL_SOCKET;
    210       control->cmsg_type = SCM_RIGHTS;
    211       control->cmsg_len = CMSG_LEN(fd_bytes);
    212       memcpy(CMSG_DATA(control), file_handles_.data(), fd_bytes);
    213     }
    214   }
    215 
    216   return SendMsgAll(sender, socket_fd, &msg);
    217 }
    218 
    219 // MessageWriter
    220 void* SendPayload::GetNextWriteBufferSection(size_t size) {
    221   return buffer_.grow_by(size);
    222 }
    223 
    224 OutputResourceMapper* SendPayload::GetOutputResourceMapper() { return this; }
    225 
    226 // OutputResourceMapper
    227 Status<FileReference> SendPayload::PushFileHandle(const LocalHandle& handle) {
    228   if (handle) {
    229     const int ref = file_handles_.size();
    230     file_handles_.push_back(handle.Get());
    231     return ref;
    232   } else {
    233     return handle.Get();
    234   }
    235 }
    236 
    237 Status<FileReference> SendPayload::PushFileHandle(
    238     const BorrowedHandle& handle) {
    239   if (handle) {
    240     const int ref = file_handles_.size();
    241     file_handles_.push_back(handle.Get());
    242     return ref;
    243   } else {
    244     return handle.Get();
    245   }
    246 }
    247 
    248 Status<FileReference> SendPayload::PushFileHandle(const RemoteHandle& handle) {
    249   return handle.Get();
    250 }
    251 
    252 Status<ChannelReference> SendPayload::PushChannelHandle(
    253     const LocalChannelHandle& /*handle*/) {
    254   return ErrorStatus{EOPNOTSUPP};
    255 }
    256 Status<ChannelReference> SendPayload::PushChannelHandle(
    257     const BorrowedChannelHandle& /*handle*/) {
    258   return ErrorStatus{EOPNOTSUPP};
    259 }
    260 Status<ChannelReference> SendPayload::PushChannelHandle(
    261     const RemoteChannelHandle& /*handle*/) {
    262   return ErrorStatus{EOPNOTSUPP};
    263 }
    264 
    265 Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd) {
    266   return Receive(socket_fd, nullptr);
    267 }
    268 
    269 Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
    270                                      ucred* cred) {
    271   RecvInterface* receiver = receiver_ ? receiver_ : &g_socket_receiver;
    272   MessagePreamble preamble;
    273   Status<void> ret = RecvAll(receiver, socket_fd, &preamble, sizeof(preamble));
    274   if (!ret)
    275     return ret;
    276 
    277   if (preamble.magic != kMagicPreamble) {
    278     ALOGE("ReceivePayload::Receive: Message header is invalid");
    279     ret.SetError(EIO);
    280     return ret;
    281   }
    282 
    283   buffer_.resize(preamble.data_size);
    284   file_handles_.clear();
    285   read_pos_ = 0;
    286 
    287   msghdr msg = {};
    288   iovec recv_vect = {buffer_.data(), buffer_.size()};
    289   msg.msg_iov = &recv_vect;
    290   msg.msg_iovlen = 1;
    291 
    292   if (cred || preamble.fd_count) {
    293     const size_t receive_fd_bytes = preamble.fd_count * sizeof(int);
    294     msg.msg_controllen =
    295         (cred ? CMSG_SPACE(sizeof(ucred)) : 0) +
    296         (receive_fd_bytes == 0 ? 0 : CMSG_SPACE(receive_fd_bytes));
    297     msg.msg_control = alloca(msg.msg_controllen);
    298   }
    299 
    300   ret = RecvMsgAll(receiver, socket_fd, &msg);
    301   if (!ret)
    302     return ret;
    303 
    304   bool cred_available = false;
    305   file_handles_.reserve(preamble.fd_count);
    306   cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
    307   while (cmsg) {
    308     if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS &&
    309         cred && cmsg->cmsg_len == CMSG_LEN(sizeof(ucred))) {
    310       cred_available = true;
    311       memcpy(cred, CMSG_DATA(cmsg), sizeof(ucred));
    312     } else if (cmsg->cmsg_level == SOL_SOCKET &&
    313                cmsg->cmsg_type == SCM_RIGHTS) {
    314       socklen_t payload_len = cmsg->cmsg_len - CMSG_LEN(0);
    315       const int* fds = reinterpret_cast<const int*>(CMSG_DATA(cmsg));
    316       size_t fd_count = payload_len / sizeof(int);
    317       std::transform(fds, fds + fd_count, std::back_inserter(file_handles_),
    318                      [](int fd) { return LocalHandle{fd}; });
    319     }
    320     cmsg = CMSG_NXTHDR(&msg, cmsg);
    321   }
    322 
    323   if (cred && !cred_available) {
    324     ALOGE("ReceivePayload::Receive: Failed to obtain message credentials");
    325     ret.SetError(EIO);
    326   }
    327 
    328   return ret;
    329 }
    330 
    331 // MessageReader
    332 MessageReader::BufferSection ReceivePayload::GetNextReadBufferSection() {
    333   return {buffer_.data() + read_pos_, &*buffer_.end()};
    334 }
    335 
    336 void ReceivePayload::ConsumeReadBufferSectionData(const void* new_start) {
    337   read_pos_ = PointerDistance(new_start, buffer_.data());
    338 }
    339 
    340 InputResourceMapper* ReceivePayload::GetInputResourceMapper() { return this; }
    341 
    342 // InputResourceMapper
    343 bool ReceivePayload::GetFileHandle(FileReference ref, LocalHandle* handle) {
    344   if (ref < 0) {
    345     *handle = LocalHandle{ref};
    346     return true;
    347   }
    348   if (static_cast<size_t>(ref) > file_handles_.size())
    349     return false;
    350   *handle = std::move(file_handles_[ref]);
    351   return true;
    352 }
    353 
    354 bool ReceivePayload::GetChannelHandle(ChannelReference /*ref*/,
    355                                       LocalChannelHandle* /*handle*/) {
    356   return false;
    357 }
    358 
    359 Status<void> SendData(const BorrowedHandle& socket_fd, const void* data,
    360                       size_t size) {
    361   return SendAll(&g_socket_sender, socket_fd, data, size);
    362 }
    363 
    364 Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
    365                             size_t count) {
    366   msghdr msg = {};
    367   msg.msg_iov = const_cast<iovec*>(data);
    368   msg.msg_iovlen = count;
    369   return SendMsgAll(&g_socket_sender, socket_fd, &msg);
    370 }
    371 
    372 Status<void> ReceiveData(const BorrowedHandle& socket_fd, void* data,
    373                          size_t size) {
    374   return RecvAll(&g_socket_receiver, socket_fd, data, size);
    375 }
    376 
    377 Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
    378                                const iovec* data, size_t count) {
    379   msghdr msg = {};
    380   msg.msg_iov = const_cast<iovec*>(data);
    381   msg.msg_iovlen = count;
    382   return RecvMsgAll(&g_socket_receiver, socket_fd, &msg);
    383 }
    384 
    385 size_t CountVectorSize(const iovec* vector, size_t count) {
    386   return std::accumulate(
    387       vector, vector + count, size_t{0},
    388       [](size_t size, const iovec& vec) { return size + vec.iov_len; });
    389 }
    390 
    391 void InitRequest(android::pdx::uds::RequestHeader<BorrowedHandle>* request,
    392                  int opcode, uint32_t send_len, uint32_t max_recv_len,
    393                  bool is_impulse) {
    394   request->op = opcode;
    395   request->cred.pid = getpid();
    396   request->cred.uid = geteuid();
    397   request->cred.gid = getegid();
    398   request->send_len = send_len;
    399   request->max_recv_len = max_recv_len;
    400   request->is_impulse = is_impulse;
    401 }
    402 
    403 Status<void> WaitForEndpoint(const std::string& endpoint_path,
    404                              int64_t timeout_ms) {
    405   // Endpoint path must be absolute.
    406   if (endpoint_path.empty() || endpoint_path.front() != '/')
    407     return ErrorStatus(EINVAL);
    408 
    409   // Create inotify fd.
    410   LocalHandle fd{inotify_init()};
    411   if (!fd)
    412     return ErrorStatus(errno);
    413 
    414   // Set the inotify fd to non-blocking.
    415   int ret = fcntl(fd.Get(), F_GETFL);
    416   fcntl(fd.Get(), F_SETFL, ret | O_NONBLOCK);
    417 
    418   // Setup the pollfd.
    419   pollfd pfd = {fd.Get(), POLLIN, 0};
    420 
    421   // Find locations of each path separator.
    422   std::vector<size_t> separators{0};  // The path is absolute, so '/' is at #0.
    423   size_t pos = endpoint_path.find('/', 1);
    424   while (pos != std::string::npos) {
    425     separators.push_back(pos);
    426     pos = endpoint_path.find('/', pos + 1);
    427   }
    428   separators.push_back(endpoint_path.size());
    429 
    430   // Walk down the path, checking for existence and waiting if needed.
    431   pos = 1;
    432   size_t links = 0;
    433   std::string current;
    434   while (pos < separators.size() && links <= MAXSYMLINKS) {
    435     std::string previous = current;
    436     current = endpoint_path.substr(0, separators[pos]);
    437 
    438     // Check for existence; proceed to setup a watch if not.
    439     if (access(current.c_str(), F_OK) < 0) {
    440       if (errno != ENOENT)
    441         return ErrorStatus(errno);
    442 
    443       // Extract the name of the path component to wait for.
    444       std::string next = current.substr(
    445           separators[pos - 1] + 1, separators[pos] - separators[pos - 1] - 1);
    446 
    447       // Add a watch on the last existing directory we reach.
    448       int wd = inotify_add_watch(
    449           fd.Get(), previous.c_str(),
    450           IN_CREATE | IN_DELETE_SELF | IN_MOVE_SELF | IN_MOVED_TO);
    451       if (wd < 0) {
    452         if (errno != ENOENT)
    453           return ErrorStatus(errno);
    454         // Restart at the beginning if previous was deleted.
    455         links = 0;
    456         current.clear();
    457         pos = 1;
    458         continue;
    459       }
    460 
    461       // Make sure current didn't get created before the watch was added.
    462       ret = access(current.c_str(), F_OK);
    463       if (ret < 0) {
    464         if (errno != ENOENT)
    465           return ErrorStatus(errno);
    466 
    467         bool exit_poll = false;
    468         while (!exit_poll) {
    469           // Wait for an event or timeout.
    470           ret = poll(&pfd, 1, timeout_ms);
    471           if (ret <= 0)
    472             return ErrorStatus(ret == 0 ? ETIMEDOUT : errno);
    473 
    474           // Read events.
    475           char buffer[sizeof(inotify_event) + NAME_MAX + 1];
    476 
    477           ret = read(fd.Get(), buffer, sizeof(buffer));
    478           if (ret < 0) {
    479             if (errno == EAGAIN || errno == EWOULDBLOCK)
    480               continue;
    481             else
    482               return ErrorStatus(errno);
    483           } else if (static_cast<size_t>(ret) < sizeof(struct inotify_event)) {
    484             return ErrorStatus(EIO);
    485           }
    486 
    487           auto* event = reinterpret_cast<const inotify_event*>(buffer);
    488           auto* end = reinterpret_cast<const inotify_event*>(buffer + ret);
    489           while (event < end) {
    490             std::string event_for;
    491             if (event->len > 0)
    492               event_for = event->name;
    493 
    494             if (event->mask & (IN_CREATE | IN_MOVED_TO)) {
    495               // See if this is the droid we're looking for.
    496               if (next == event_for) {
    497                 exit_poll = true;
    498                 break;
    499               }
    500             } else if (event->mask & (IN_DELETE_SELF | IN_MOVE_SELF)) {
    501               // Restart at the beginning if our watch dir is deleted.
    502               links = 0;
    503               current.clear();
    504               pos = 0;
    505               exit_poll = true;
    506               break;
    507             }
    508 
    509             event = reinterpret_cast<const inotify_event*>(AdvancePointer(
    510                 event, sizeof(struct inotify_event) + event->len));
    511           }  // while (event < end)
    512         }    // while (!exit_poll)
    513       }      // Current dir doesn't exist.
    514       ret = inotify_rm_watch(fd.Get(), wd);
    515       if (ret < 0 && errno != EINVAL)
    516         return ErrorStatus(errno);
    517     }  // if (access(current.c_str(), F_OK) < 0)
    518 
    519     // Check for symbolic link and update link count.
    520     struct stat stat_buf;
    521     ret = lstat(current.c_str(), &stat_buf);
    522     if (ret < 0 && errno != ENOENT)
    523       return ErrorStatus(errno);
    524     else if (ret == 0 && S_ISLNK(stat_buf.st_mode))
    525       links++;
    526     pos++;
    527   }  // while (pos < separators.size() && links <= MAXSYMLINKS)
    528 
    529   return {};
    530 }
    531 
    532 }  // namespace uds
    533 }  // namespace pdx
    534 }  // namespace android
    535