Home | History | Annotate | Download | only in libpdx_uds
#include "uds/ipc_helper.h"

#include <alloca.h>
#include <errno.h>
#include <fcntl.h>
#include <log/log.h>
#include <poll.h>
#include <string.h>
#include <sys/inotify.h>
#include <sys/param.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <unistd.h>

#include <algorithm>
#include <chrono>
#include <iterator>
#include <limits>
#include <numeric>
#include <string>
#include <vector>

#include <pdx/service.h>
#include <pdx/utility.h>
     16 
     17 namespace android {
     18 namespace pdx {
     19 namespace uds {
     20 
     21 namespace {
     22 
     23 constexpr size_t kMaxFdCount =
     24     256;  // Total of 1KiB of data to transfer these FDs.
     25 
     26 // Default implementations of Send/Receive interfaces to use standard socket
     27 // send/sendmsg/recv/recvmsg functions.
     28 class SocketSender : public SendInterface {
     29  public:
     30   ssize_t Send(int socket_fd, const void* data, size_t size,
     31                int flags) override {
     32     return send(socket_fd, data, size, flags);
     33   }
     34   ssize_t SendMessage(int socket_fd, const msghdr* msg, int flags) override {
     35     return sendmsg(socket_fd, msg, flags);
     36   }
     37 } g_socket_sender;
     38 
     39 class SocketReceiver : public RecvInterface {
     40  public:
     41   ssize_t Receive(int socket_fd, void* data, size_t size, int flags) override {
     42     return recv(socket_fd, data, size, flags);
     43   }
     44   ssize_t ReceiveMessage(int socket_fd, msghdr* msg, int flags) override {
     45     return recvmsg(socket_fd, msg, flags);
     46   }
     47 } g_socket_receiver;
     48 
     49 }  // anonymous namespace
     50 
     51 // Helper wrappers around send()/sendmsg() which repeat send() calls on data
     52 // that was not sent with the initial call to send/sendmsg. This is important to
     53 // handle transmissions interrupted by signals.
     54 Status<void> SendAll(SendInterface* sender, const BorrowedHandle& socket_fd,
     55                      const void* data, size_t size) {
     56   Status<void> ret;
     57   const uint8_t* ptr = static_cast<const uint8_t*>(data);
     58   while (size > 0) {
     59     ssize_t size_written =
     60         RETRY_EINTR(sender->Send(socket_fd.Get(), ptr, size, MSG_NOSIGNAL));
     61     if (size_written < 0) {
     62       ret.SetError(errno);
     63       ALOGE("SendAll: Failed to send data over socket: %s",
     64             ret.GetErrorMessage().c_str());
     65       break;
     66     }
     67     size -= size_written;
     68     ptr += size_written;
     69   }
     70   return ret;
     71 }
     72 
     73 Status<void> SendMsgAll(SendInterface* sender, const BorrowedHandle& socket_fd,
     74                         const msghdr* msg) {
     75   Status<void> ret;
     76   ssize_t sent_size =
     77       RETRY_EINTR(sender->SendMessage(socket_fd.Get(), msg, MSG_NOSIGNAL));
     78   if (sent_size < 0) {
     79     ret.SetError(errno);
     80     ALOGE("SendMsgAll: Failed to send data over socket: %s",
     81           ret.GetErrorMessage().c_str());
     82     return ret;
     83   }
     84 
     85   ssize_t chunk_start_offset = 0;
     86   for (size_t i = 0; i < msg->msg_iovlen; i++) {
     87     ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
     88     if (sent_size < chunk_end_offset) {
     89       size_t offset_within_chunk = sent_size - chunk_start_offset;
     90       size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
     91       const uint8_t* chunk_base =
     92           static_cast<const uint8_t*>(msg->msg_iov[i].iov_base);
     93       ret = SendAll(sender, socket_fd, chunk_base + offset_within_chunk,
     94                     data_size);
     95       if (!ret)
     96         break;
     97       sent_size += data_size;
     98     }
     99     chunk_start_offset = chunk_end_offset;
    100   }
    101   return ret;
    102 }
    103 
    104 // Helper wrappers around recv()/recvmsg() which repeat recv() calls on data
    105 // that was not received with the initial call to recvmsg(). This is important
    106 // to handle transmissions interrupted by signals as well as the case when
    107 // initial data did not arrive in a single chunk over the socket (e.g. socket
    108 // buffer was full at the time of transmission, and only portion of initial
    109 // message was sent and the rest was blocked until the buffer was cleared by the
    110 // receiving side).
    111 Status<void> RecvMsgAll(RecvInterface* receiver,
    112                         const BorrowedHandle& socket_fd, msghdr* msg) {
    113   Status<void> ret;
    114   ssize_t size_read = RETRY_EINTR(receiver->ReceiveMessage(
    115       socket_fd.Get(), msg, MSG_WAITALL | MSG_CMSG_CLOEXEC));
    116   if (size_read < 0) {
    117     ret.SetError(errno);
    118     ALOGE("RecvMsgAll: Failed to receive data from socket: %s",
    119           ret.GetErrorMessage().c_str());
    120     return ret;
    121   } else if (size_read == 0) {
    122     ret.SetError(ESHUTDOWN);
    123     ALOGW("RecvMsgAll: Socket has been shut down");
    124     return ret;
    125   }
    126 
    127   ssize_t chunk_start_offset = 0;
    128   for (size_t i = 0; i < msg->msg_iovlen; i++) {
    129     ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
    130     if (size_read < chunk_end_offset) {
    131       size_t offset_within_chunk = size_read - chunk_start_offset;
    132       size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
    133       uint8_t* chunk_base = static_cast<uint8_t*>(msg->msg_iov[i].iov_base);
    134       ret = RecvAll(receiver, socket_fd, chunk_base + offset_within_chunk,
    135                     data_size);
    136       if (!ret)
    137         break;
    138       size_read += data_size;
    139     }
    140     chunk_start_offset = chunk_end_offset;
    141   }
    142   return ret;
    143 }
    144 
    145 Status<void> RecvAll(RecvInterface* receiver, const BorrowedHandle& socket_fd,
    146                      void* data, size_t size) {
    147   Status<void> ret;
    148   uint8_t* ptr = static_cast<uint8_t*>(data);
    149   while (size > 0) {
    150     ssize_t size_read = RETRY_EINTR(receiver->Receive(
    151         socket_fd.Get(), ptr, size, MSG_WAITALL | MSG_CMSG_CLOEXEC));
    152     if (size_read < 0) {
    153       ret.SetError(errno);
    154       ALOGE("RecvAll: Failed to receive data from socket: %s",
    155             ret.GetErrorMessage().c_str());
    156       break;
    157     } else if (size_read == 0) {
    158       ret.SetError(ESHUTDOWN);
    159       ALOGW("RecvAll: Socket has been shut down");
    160       break;
    161     }
    162     size -= size_read;
    163     ptr += size_read;
    164   }
    165   return ret;
    166 }
    167 
    168 uint32_t kMagicPreamble = 0x7564736d;  // 'udsm'.
    169 
    170 struct MessagePreamble {
    171   uint32_t magic{0};
    172   uint32_t data_size{0};
    173   uint32_t fd_count{0};
    174 };
    175 
    176 Status<void> SendPayload::Send(const BorrowedHandle& socket_fd) {
    177   return Send(socket_fd, nullptr);
    178 }
    179 
    180 Status<void> SendPayload::Send(const BorrowedHandle& socket_fd,
    181                                const ucred* cred, const iovec* data_vec,
    182                                size_t vec_count) {
    183   if (file_handles_.size() > kMaxFdCount) {
    184     ALOGE(
    185         "SendPayload::Send: Trying to send too many file descriptors (%zu), "
    186         "max allowed = %zu",
    187         file_handles_.size(), kMaxFdCount);
    188     return ErrorStatus{EINVAL};
    189   }
    190 
    191   SendInterface* sender = sender_ ? sender_ : &g_socket_sender;
    192   MessagePreamble preamble;
    193   preamble.magic = kMagicPreamble;
    194   preamble.data_size = buffer_.size();
    195   preamble.fd_count = file_handles_.size();
    196 
    197   msghdr msg = {};
    198   msg.msg_iovlen = 2 + vec_count;
    199   msg.msg_iov = static_cast<iovec*>(alloca(sizeof(iovec) * msg.msg_iovlen));
    200   msg.msg_iov[0].iov_base = &preamble;
    201   msg.msg_iov[0].iov_len = sizeof(preamble);
    202   msg.msg_iov[1].iov_base = buffer_.data();
    203   msg.msg_iov[1].iov_len = buffer_.size();
    204   for (size_t i = 0; i < vec_count; i++)
    205     msg.msg_iov[i + 2] = data_vec[i];
    206 
    207   if (cred || !file_handles_.empty()) {
    208     const size_t fd_bytes = file_handles_.size() * sizeof(int);
    209     msg.msg_controllen = (cred ? CMSG_SPACE(sizeof(ucred)) : 0) +
    210                          (fd_bytes == 0 ? 0 : CMSG_SPACE(fd_bytes));
    211     msg.msg_control = alloca(msg.msg_controllen);
    212 
    213     cmsghdr* control = CMSG_FIRSTHDR(&msg);
    214     if (cred) {
    215       control->cmsg_level = SOL_SOCKET;
    216       control->cmsg_type = SCM_CREDENTIALS;
    217       control->cmsg_len = CMSG_LEN(sizeof(ucred));
    218       memcpy(CMSG_DATA(control), cred, sizeof(ucred));
    219       control = CMSG_NXTHDR(&msg, control);
    220     }
    221 
    222     if (fd_bytes) {
    223       control->cmsg_level = SOL_SOCKET;
    224       control->cmsg_type = SCM_RIGHTS;
    225       control->cmsg_len = CMSG_LEN(fd_bytes);
    226       memcpy(CMSG_DATA(control), file_handles_.data(), fd_bytes);
    227     }
    228   }
    229 
    230   return SendMsgAll(sender, socket_fd, &msg);
    231 }
    232 
    233 // MessageWriter
    234 void* SendPayload::GetNextWriteBufferSection(size_t size) {
    235   return buffer_.grow_by(size);
    236 }
    237 
    238 OutputResourceMapper* SendPayload::GetOutputResourceMapper() { return this; }
    239 
    240 // OutputResourceMapper
    241 Status<FileReference> SendPayload::PushFileHandle(const LocalHandle& handle) {
    242   if (handle) {
    243     const int ref = file_handles_.size();
    244     file_handles_.push_back(handle.Get());
    245     return ref;
    246   } else {
    247     return handle.Get();
    248   }
    249 }
    250 
    251 Status<FileReference> SendPayload::PushFileHandle(
    252     const BorrowedHandle& handle) {
    253   if (handle) {
    254     const int ref = file_handles_.size();
    255     file_handles_.push_back(handle.Get());
    256     return ref;
    257   } else {
    258     return handle.Get();
    259   }
    260 }
    261 
    262 Status<FileReference> SendPayload::PushFileHandle(const RemoteHandle& handle) {
    263   return handle.Get();
    264 }
    265 
    266 Status<ChannelReference> SendPayload::PushChannelHandle(
    267     const LocalChannelHandle& /*handle*/) {
    268   return ErrorStatus{EOPNOTSUPP};
    269 }
    270 Status<ChannelReference> SendPayload::PushChannelHandle(
    271     const BorrowedChannelHandle& /*handle*/) {
    272   return ErrorStatus{EOPNOTSUPP};
    273 }
    274 Status<ChannelReference> SendPayload::PushChannelHandle(
    275     const RemoteChannelHandle& /*handle*/) {
    276   return ErrorStatus{EOPNOTSUPP};
    277 }
    278 
    279 Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd) {
    280   return Receive(socket_fd, nullptr);
    281 }
    282 
    283 Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
    284                                      ucred* cred) {
    285   RecvInterface* receiver = receiver_ ? receiver_ : &g_socket_receiver;
    286   MessagePreamble preamble;
    287   msghdr msg = {};
    288   iovec recv_vect = {&preamble, sizeof(preamble)};
    289   msg.msg_iov = &recv_vect;
    290   msg.msg_iovlen = 1;
    291   const size_t receive_fd_bytes = kMaxFdCount * sizeof(int);
    292   msg.msg_controllen = CMSG_SPACE(sizeof(ucred)) + CMSG_SPACE(receive_fd_bytes);
    293   msg.msg_control = alloca(msg.msg_controllen);
    294 
    295   Status<void> ret = RecvMsgAll(receiver, socket_fd, &msg);
    296   if (!ret)
    297     return ret;
    298 
    299   if (preamble.magic != kMagicPreamble) {
    300     ALOGE("ReceivePayload::Receive: Message header is invalid");
    301     ret.SetError(EIO);
    302     return ret;
    303   }
    304 
    305   buffer_.resize(preamble.data_size);
    306   file_handles_.clear();
    307   read_pos_ = 0;
    308 
    309   bool cred_available = false;
    310   file_handles_.reserve(preamble.fd_count);
    311   cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
    312   while (cmsg) {
    313     if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_CREDENTIALS &&
    314         cred && cmsg->cmsg_len == CMSG_LEN(sizeof(ucred))) {
    315       cred_available = true;
    316       memcpy(cred, CMSG_DATA(cmsg), sizeof(ucred));
    317     } else if (cmsg->cmsg_level == SOL_SOCKET &&
    318                cmsg->cmsg_type == SCM_RIGHTS) {
    319       socklen_t payload_len = cmsg->cmsg_len - CMSG_LEN(0);
    320       const int* fds = reinterpret_cast<const int*>(CMSG_DATA(cmsg));
    321       size_t fd_count = payload_len / sizeof(int);
    322       std::transform(fds, fds + fd_count, std::back_inserter(file_handles_),
    323                      [](int fd) { return LocalHandle{fd}; });
    324     }
    325     cmsg = CMSG_NXTHDR(&msg, cmsg);
    326   }
    327 
    328   ret = RecvAll(receiver, socket_fd, buffer_.data(), buffer_.size());
    329   if (!ret)
    330     return ret;
    331 
    332   if (cred && !cred_available) {
    333     ALOGE("ReceivePayload::Receive: Failed to obtain message credentials");
    334     ret.SetError(EIO);
    335   }
    336 
    337   return ret;
    338 }
    339 
    340 // MessageReader
    341 MessageReader::BufferSection ReceivePayload::GetNextReadBufferSection() {
    342   return {buffer_.data() + read_pos_, &*buffer_.end()};
    343 }
    344 
    345 void ReceivePayload::ConsumeReadBufferSectionData(const void* new_start) {
    346   read_pos_ = PointerDistance(new_start, buffer_.data());
    347 }
    348 
    349 InputResourceMapper* ReceivePayload::GetInputResourceMapper() { return this; }
    350 
    351 // InputResourceMapper
    352 bool ReceivePayload::GetFileHandle(FileReference ref, LocalHandle* handle) {
    353   if (ref < 0) {
    354     *handle = LocalHandle{ref};
    355     return true;
    356   }
    357   if (static_cast<size_t>(ref) > file_handles_.size())
    358     return false;
    359   *handle = std::move(file_handles_[ref]);
    360   return true;
    361 }
    362 
    363 bool ReceivePayload::GetChannelHandle(ChannelReference /*ref*/,
    364                                       LocalChannelHandle* /*handle*/) {
    365   return false;
    366 }
    367 
    368 Status<void> SendData(const BorrowedHandle& socket_fd, const void* data,
    369                       size_t size) {
    370   return SendAll(&g_socket_sender, socket_fd, data, size);
    371 }
    372 
    373 Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
    374                             size_t count) {
    375   msghdr msg = {};
    376   msg.msg_iov = const_cast<iovec*>(data);
    377   msg.msg_iovlen = count;
    378   return SendMsgAll(&g_socket_sender, socket_fd, &msg);
    379 }
    380 
    381 Status<void> ReceiveData(const BorrowedHandle& socket_fd, void* data,
    382                          size_t size) {
    383   return RecvAll(&g_socket_receiver, socket_fd, data, size);
    384 }
    385 
    386 Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
    387                                const iovec* data, size_t count) {
    388   msghdr msg = {};
    389   msg.msg_iov = const_cast<iovec*>(data);
    390   msg.msg_iovlen = count;
    391   return RecvMsgAll(&g_socket_receiver, socket_fd, &msg);
    392 }
    393 
    394 size_t CountVectorSize(const iovec* vector, size_t count) {
    395   return std::accumulate(
    396       vector, vector + count, size_t{0},
    397       [](size_t size, const iovec& vec) { return size + vec.iov_len; });
    398 }
    399 
    400 void InitRequest(android::pdx::uds::RequestHeader<BorrowedHandle>* request,
    401                  int opcode, uint32_t send_len, uint32_t max_recv_len,
    402                  bool is_impulse) {
    403   request->op = opcode;
    404   request->cred.pid = getpid();
    405   request->cred.uid = geteuid();
    406   request->cred.gid = getegid();
    407   request->send_len = send_len;
    408   request->max_recv_len = max_recv_len;
    409   request->is_impulse = is_impulse;
    410 }
    411 
    412 Status<void> WaitForEndpoint(const std::string& endpoint_path,
    413                              int64_t timeout_ms) {
    414   // Endpoint path must be absolute.
    415   if (endpoint_path.empty() || endpoint_path.front() != '/')
    416     return ErrorStatus(EINVAL);
    417 
    418   // Create inotify fd.
    419   LocalHandle fd{inotify_init()};
    420   if (!fd)
    421     return ErrorStatus(errno);
    422 
    423   // Set the inotify fd to non-blocking.
    424   int ret = fcntl(fd.Get(), F_GETFL);
    425   fcntl(fd.Get(), F_SETFL, ret | O_NONBLOCK);
    426 
    427   // Setup the pollfd.
    428   pollfd pfd = {fd.Get(), POLLIN, 0};
    429 
    430   // Find locations of each path separator.
    431   std::vector<size_t> separators{0};  // The path is absolute, so '/' is at #0.
    432   size_t pos = endpoint_path.find('/', 1);
    433   while (pos != std::string::npos) {
    434     separators.push_back(pos);
    435     pos = endpoint_path.find('/', pos + 1);
    436   }
    437   separators.push_back(endpoint_path.size());
    438 
    439   // Walk down the path, checking for existence and waiting if needed.
    440   pos = 1;
    441   size_t links = 0;
    442   std::string current;
    443   while (pos < separators.size() && links <= MAXSYMLINKS) {
    444     std::string previous = current;
    445     current = endpoint_path.substr(0, separators[pos]);
    446 
    447     // Check for existence; proceed to setup a watch if not.
    448     if (access(current.c_str(), F_OK) < 0) {
    449       if (errno != ENOENT)
    450         return ErrorStatus(errno);
    451 
    452       // Extract the name of the path component to wait for.
    453       std::string next = current.substr(
    454           separators[pos - 1] + 1, separators[pos] - separators[pos - 1] - 1);
    455 
    456       // Add a watch on the last existing directory we reach.
    457       int wd = inotify_add_watch(
    458           fd.Get(), previous.c_str(),
    459           IN_CREATE | IN_DELETE_SELF | IN_MOVE_SELF | IN_MOVED_TO);
    460       if (wd < 0) {
    461         if (errno != ENOENT)
    462           return ErrorStatus(errno);
    463         // Restart at the beginning if previous was deleted.
    464         links = 0;
    465         current.clear();
    466         pos = 1;
    467         continue;
    468       }
    469 
    470       // Make sure current didn't get created before the watch was added.
    471       ret = access(current.c_str(), F_OK);
    472       if (ret < 0) {
    473         if (errno != ENOENT)
    474           return ErrorStatus(errno);
    475 
    476         bool exit_poll = false;
    477         while (!exit_poll) {
    478           // Wait for an event or timeout.
    479           ret = poll(&pfd, 1, timeout_ms);
    480           if (ret <= 0)
    481             return ErrorStatus(ret == 0 ? ETIMEDOUT : errno);
    482 
    483           // Read events.
    484           char buffer[sizeof(inotify_event) + NAME_MAX + 1];
    485 
    486           ret = read(fd.Get(), buffer, sizeof(buffer));
    487           if (ret < 0) {
    488             if (errno == EAGAIN || errno == EWOULDBLOCK)
    489               continue;
    490             else
    491               return ErrorStatus(errno);
    492           } else if (static_cast<size_t>(ret) < sizeof(struct inotify_event)) {
    493             return ErrorStatus(EIO);
    494           }
    495 
    496           auto* event = reinterpret_cast<const inotify_event*>(buffer);
    497           auto* end = reinterpret_cast<const inotify_event*>(buffer + ret);
    498           while (event < end) {
    499             std::string event_for;
    500             if (event->len > 0)
    501               event_for = event->name;
    502 
    503             if (event->mask & (IN_CREATE | IN_MOVED_TO)) {
    504               // See if this is the droid we're looking for.
    505               if (next == event_for) {
    506                 exit_poll = true;
    507                 break;
    508               }
    509             } else if (event->mask & (IN_DELETE_SELF | IN_MOVE_SELF)) {
    510               // Restart at the beginning if our watch dir is deleted.
    511               links = 0;
    512               current.clear();
    513               pos = 0;
    514               exit_poll = true;
    515               break;
    516             }
    517 
    518             event = reinterpret_cast<const inotify_event*>(AdvancePointer(
    519                 event, sizeof(struct inotify_event) + event->len));
    520           }  // while (event < end)
    521         }    // while (!exit_poll)
    522       }      // Current dir doesn't exist.
    523       ret = inotify_rm_watch(fd.Get(), wd);
    524       if (ret < 0 && errno != EINVAL)
    525         return ErrorStatus(errno);
    526     }  // if (access(current.c_str(), F_OK) < 0)
    527 
    528     // Check for symbolic link and update link count.
    529     struct stat stat_buf;
    530     ret = lstat(current.c_str(), &stat_buf);
    531     if (ret < 0 && errno != ENOENT)
    532       return ErrorStatus(errno);
    533     else if (ret == 0 && S_ISLNK(stat_buf.st_mode))
    534       links++;
    535     pos++;
    536   }  // while (pos < separators.size() && links <= MAXSYMLINKS)
    537 
    538   return {};
    539 }
    540 
    541 }  // namespace uds
    542 }  // namespace pdx
    543 }  // namespace android
    544