1 #include "uds/service_endpoint.h" 2 3 #include <poll.h> 4 #include <sys/epoll.h> 5 #include <sys/eventfd.h> 6 #include <sys/socket.h> 7 #include <sys/un.h> 8 #include <algorithm> // std::min 9 10 #include <android-base/logging.h> 11 #include <android-base/strings.h> 12 #include <cutils/sockets.h> 13 #include <pdx/service.h> 14 #include <selinux/selinux.h> 15 #include <uds/channel_manager.h> 16 #include <uds/client_channel_factory.h> 17 #include <uds/ipc_helper.h> 18 19 namespace { 20 21 constexpr int kMaxBackLogForSocketListen = 1; 22 23 using android::pdx::BorrowedChannelHandle; 24 using android::pdx::BorrowedHandle; 25 using android::pdx::ChannelReference; 26 using android::pdx::ErrorStatus; 27 using android::pdx::FileReference; 28 using android::pdx::LocalChannelHandle; 29 using android::pdx::LocalHandle; 30 using android::pdx::Status; 31 using android::pdx::uds::ChannelInfo; 32 using android::pdx::uds::ChannelManager; 33 34 struct MessageState { 35 bool GetLocalFileHandle(int index, LocalHandle* handle) { 36 if (index < 0) { 37 handle->Reset(index); 38 } else if (static_cast<size_t>(index) < request.file_descriptors.size()) { 39 *handle = std::move(request.file_descriptors[index]); 40 } else { 41 return false; 42 } 43 return true; 44 } 45 46 bool GetLocalChannelHandle(int index, LocalChannelHandle* handle) { 47 if (index < 0) { 48 *handle = LocalChannelHandle{nullptr, index}; 49 } else if (static_cast<size_t>(index) < request.channels.size()) { 50 auto& channel_info = request.channels[index]; 51 *handle = ChannelManager::Get().CreateHandle( 52 std::move(channel_info.data_fd), 53 std::move(channel_info.pollin_event_fd), 54 std::move(channel_info.pollhup_event_fd)); 55 } else { 56 return false; 57 } 58 return true; 59 } 60 61 Status<FileReference> PushFileHandle(BorrowedHandle handle) { 62 if (!handle) 63 return handle.Get(); 64 response.file_descriptors.push_back(std::move(handle)); 65 return response.file_descriptors.size() - 1; 66 } 67 68 Status<ChannelReference> PushChannelHandle(BorrowedChannelHandle handle) { 69 if (!handle) 70 return handle.value(); 71 72 if (auto* channel_data = 73 ChannelManager::Get().GetChannelData(handle.value())) { 74 ChannelInfo<BorrowedHandle> channel_info{ 75 channel_data->data_fd(), channel_data->pollin_event_fd(), 76 channel_data->pollhup_event_fd()}; 77 response.channels.push_back(std::move(channel_info)); 78 return response.channels.size() - 1; 79 } else { 80 return ErrorStatus{EINVAL}; 81 } 82 } 83 84 Status<ChannelReference> PushChannelHandle(BorrowedHandle data_fd, 85 BorrowedHandle pollin_event_fd, 86 BorrowedHandle pollhup_event_fd) { 87 if (!data_fd || !pollin_event_fd || !pollhup_event_fd) 88 return ErrorStatus{EINVAL}; 89 ChannelInfo<BorrowedHandle> channel_info{std::move(data_fd), 90 std::move(pollin_event_fd), 91 std::move(pollhup_event_fd)}; 92 response.channels.push_back(std::move(channel_info)); 93 return response.channels.size() - 1; 94 } 95 96 Status<size_t> WriteData(const iovec* vector, size_t vector_length) { 97 size_t size = 0; 98 for (size_t i = 0; i < vector_length; i++) { 99 const auto* data = reinterpret_cast<const uint8_t*>(vector[i].iov_base); 100 response_data.insert(response_data.end(), data, data + vector[i].iov_len); 101 size += vector[i].iov_len; 102 } 103 return size; 104 } 105 106 Status<size_t> ReadData(const iovec* vector, size_t vector_length) { 107 size_t size_remaining = request_data.size() - request_data_read_pos; 108 size_t size = 0; 109 for (size_t i = 0; i < vector_length && size_remaining > 0; i++) { 110 size_t size_to_copy = std::min(size_remaining, vector[i].iov_len); 111 memcpy(vector[i].iov_base, request_data.data() + request_data_read_pos, 112 size_to_copy); 113 size += size_to_copy; 114 request_data_read_pos += size_to_copy; 115 size_remaining -= size_to_copy; 116 } 117 return size; 118 } 119 120 android::pdx::uds::RequestHeader<LocalHandle> request; 121 android::pdx::uds::ResponseHeader<BorrowedHandle> response; 122 std::vector<LocalHandle> sockets_to_close; 123 std::vector<uint8_t> request_data; 124 size_t request_data_read_pos{0}; 125 std::vector<uint8_t> response_data; 126 }; 127 128 } // anonymous namespace 129 130 namespace android { 131 namespace pdx { 132 namespace uds { 133 134 Endpoint::Endpoint(const std::string& endpoint_path, bool blocking, 135 bool use_init_socket_fd) 136 : endpoint_path_{ClientChannelFactory::GetEndpointPath(endpoint_path)}, 137 is_blocking_{blocking} { 138 LocalHandle fd; 139 if (use_init_socket_fd) { 140 // Cut off the /dev/socket/ prefix from the full socket path and use the 141 // resulting "name" to retrieve the file descriptor for the socket created 142 // by the init process. 143 constexpr char prefix[] = "/dev/socket/"; 144 CHECK(android::base::StartsWith(endpoint_path_, prefix)) 145 << "Endpoint::Endpoint: Socket name '" << endpoint_path_ 146 << "' must begin with '" << prefix << "'"; 147 std::string socket_name = endpoint_path_.substr(sizeof(prefix) - 1); 148 fd.Reset(android_get_control_socket(socket_name.c_str())); 149 CHECK(fd.IsValid()) 150 << "Endpoint::Endpoint: Unable to obtain the control socket fd for '" 151 << socket_name << "'"; 152 fcntl(fd.Get(), F_SETFD, FD_CLOEXEC); 153 } else { 154 fd.Reset(socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0)); 155 CHECK(fd.IsValid()) << "Endpoint::Endpoint: Failed to create socket: " 156 << strerror(errno); 157 158 sockaddr_un local; 159 local.sun_family = AF_UNIX; 160 strncpy(local.sun_path, endpoint_path_.c_str(), sizeof(local.sun_path)); 161 local.sun_path[sizeof(local.sun_path) - 1] = '\0'; 162 163 unlink(local.sun_path); 164 int ret = 165 bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local)); 166 CHECK_EQ(ret, 0) << "Endpoint::Endpoint: bind error: " << strerror(errno); 167 } 168 Init(std::move(fd)); 169 } 170 171 Endpoint::Endpoint(LocalHandle socket_fd) { Init(std::move(socket_fd)); } 172 173 void Endpoint::Init(LocalHandle socket_fd) { 174 if (socket_fd) { 175 CHECK_EQ(listen(socket_fd.Get(), kMaxBackLogForSocketListen), 0) 176 << "Endpoint::Endpoint: listen error: " << strerror(errno); 177 } 178 cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)); 179 CHECK(cancel_event_fd_.IsValid()) 180 << "Endpoint::Endpoint: Failed to create event fd: " << strerror(errno); 181 182 epoll_fd_.Reset(epoll_create1(EPOLL_CLOEXEC)); 183 CHECK(epoll_fd_.IsValid()) 184 << "Endpoint::Endpoint: Failed to create epoll fd: " << strerror(errno); 185 186 if (socket_fd) { 187 epoll_event socket_event; 188 socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT; 189 socket_event.data.fd = socket_fd.Get(); 190 int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, socket_fd.Get(), 191 &socket_event); 192 CHECK_EQ(ret, 0) 193 << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: " 194 << strerror(errno); 195 } 196 197 epoll_event cancel_event; 198 cancel_event.events = EPOLLIN; 199 cancel_event.data.fd = cancel_event_fd_.Get(); 200 201 int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(), 202 &cancel_event); 203 CHECK_EQ(ret, 0) 204 << "Endpoint::Endpoint: Failed to add cancel event fd to epoll fd: " 205 << strerror(errno); 206 socket_fd_ = std::move(socket_fd); 207 } 208 209 void* Endpoint::AllocateMessageState() { return new MessageState; } 210 211 void Endpoint::FreeMessageState(void* state) { 212 delete static_cast<MessageState*>(state); 213 } 214 215 Status<void> Endpoint::AcceptConnection(Message* message) { 216 if (!socket_fd_) 217 return ErrorStatus(EBADF); 218 219 sockaddr_un remote; 220 socklen_t addrlen = sizeof(remote); 221 LocalHandle connection_fd{accept4(socket_fd_.Get(), 222 reinterpret_cast<sockaddr*>(&remote), 223 &addrlen, SOCK_CLOEXEC)}; 224 if (!connection_fd) { 225 ALOGE("Endpoint::AcceptConnection: failed to accept connection: %s", 226 strerror(errno)); 227 return ErrorStatus(errno); 228 } 229 230 LocalHandle local_socket; 231 LocalHandle remote_socket; 232 auto status = CreateChannelSocketPair(&local_socket, &remote_socket); 233 if (!status) 234 return status; 235 236 // Borrow the local channel handle before we move it into OnNewChannel(). 237 BorrowedHandle channel_handle = local_socket.Borrow(); 238 status = OnNewChannel(std::move(local_socket)); 239 if (!status) 240 return status; 241 242 // Send the channel socket fd to the client. 243 ChannelConnectionInfo<LocalHandle> connection_info; 244 connection_info.channel_fd = std::move(remote_socket); 245 status = SendData(connection_fd.Borrow(), connection_info); 246 247 if (status) { 248 // Get the CHANNEL_OPEN message from client over the channel socket. 249 status = ReceiveMessageForChannel(channel_handle, message); 250 } else { 251 CloseChannel(GetChannelId(channel_handle)); 252 } 253 254 // Don't need the connection socket anymore. Further communication should 255 // happen over the channel socket. 256 shutdown(connection_fd.Get(), SHUT_WR); 257 return status; 258 } 259 260 Status<void> Endpoint::SetService(Service* service) { 261 service_ = service; 262 return {}; 263 } 264 265 Status<void> Endpoint::SetChannel(int channel_id, Channel* channel) { 266 std::lock_guard<std::mutex> autolock(channel_mutex_); 267 auto channel_data = channels_.find(channel_id); 268 if (channel_data == channels_.end()) 269 return ErrorStatus{EINVAL}; 270 channel_data->second.channel_state = channel; 271 return {}; 272 } 273 274 Status<void> Endpoint::OnNewChannel(LocalHandle channel_fd) { 275 std::lock_guard<std::mutex> autolock(channel_mutex_); 276 Status<void> status; 277 status.PropagateError(OnNewChannelLocked(std::move(channel_fd), nullptr)); 278 return status; 279 } 280 281 Status<std::pair<int32_t, Endpoint::ChannelData*>> Endpoint::OnNewChannelLocked( 282 LocalHandle channel_fd, Channel* channel_state) { 283 epoll_event event; 284 event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT; 285 event.data.fd = channel_fd.Get(); 286 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, channel_fd.Get(), &event) < 0) { 287 ALOGE( 288 "Endpoint::OnNewChannelLocked: Failed to add channel to endpoint: %s\n", 289 strerror(errno)); 290 return ErrorStatus(errno); 291 } 292 ChannelData channel_data; 293 channel_data.data_fd = std::move(channel_fd); 294 channel_data.channel_state = channel_state; 295 for (;;) { 296 // Try new channel IDs until we find one which is not already in the map. 297 if (last_channel_id_++ == std::numeric_limits<int32_t>::max()) 298 last_channel_id_ = 1; 299 auto iter = channels_.lower_bound(last_channel_id_); 300 if (iter == channels_.end() || iter->first != last_channel_id_) { 301 channel_fd_to_id_.emplace(channel_data.data_fd.Get(), last_channel_id_); 302 iter = channels_.emplace_hint(iter, last_channel_id_, 303 std::move(channel_data)); 304 return std::make_pair(last_channel_id_, &iter->second); 305 } 306 } 307 } 308 309 Status<void> Endpoint::ReenableEpollEvent(const BorrowedHandle& fd) { 310 epoll_event event; 311 event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT; 312 event.data.fd = fd.Get(); 313 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_MOD, fd.Get(), &event) < 0) { 314 ALOGE( 315 "Endpoint::ReenableEpollEvent: Failed to re-enable channel to " 316 "endpoint: %s\n", 317 strerror(errno)); 318 return ErrorStatus(errno); 319 } 320 return {}; 321 } 322 323 Status<void> Endpoint::CloseChannel(int channel_id) { 324 std::lock_guard<std::mutex> autolock(channel_mutex_); 325 return CloseChannelLocked(channel_id); 326 } 327 328 Status<void> Endpoint::CloseChannelLocked(int32_t channel_id) { 329 ALOGD_IF(TRACE, "Endpoint::CloseChannelLocked: channel_id=%d", channel_id); 330 331 auto iter = channels_.find(channel_id); 332 if (iter == channels_.end()) 333 return ErrorStatus{EINVAL}; 334 335 int channel_fd = iter->second.data_fd.Get(); 336 Status<void> status; 337 epoll_event dummy; // See BUGS in man 2 epoll_ctl. 338 if (epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_DEL, channel_fd, &dummy) < 0) { 339 status.SetError(errno); 340 ALOGE( 341 "Endpoint::CloseChannelLocked: Failed to remove channel from endpoint: " 342 "%s\n", 343 strerror(errno)); 344 } else { 345 status.SetValue(); 346 } 347 348 channel_fd_to_id_.erase(channel_fd); 349 channels_.erase(iter); 350 return status; 351 } 352 353 Status<void> Endpoint::ModifyChannelEvents(int channel_id, int clear_mask, 354 int set_mask) { 355 std::lock_guard<std::mutex> autolock(channel_mutex_); 356 357 auto search = channels_.find(channel_id); 358 if (search != channels_.end()) { 359 auto& channel_data = search->second; 360 channel_data.event_set.ModifyEvents(clear_mask, set_mask); 361 return {}; 362 } 363 364 return ErrorStatus{EINVAL}; 365 } 366 367 Status<void> Endpoint::CreateChannelSocketPair(LocalHandle* local_socket, 368 LocalHandle* remote_socket) { 369 Status<void> status; 370 char* endpoint_context = nullptr; 371 // Make sure the channel socket has the correct SELinux label applied. 372 // Here we get the label from the endpoint file descriptor, which should be 373 // something like "u:object_r:pdx_service_endpoint_socket:s0" and replace 374 // "endpoint" with "channel" to produce the channel label such as this: 375 // "u:object_r:pdx_service_channel_socket:s0". 376 if (fgetfilecon_raw(socket_fd_.Get(), &endpoint_context) > 0) { 377 std::string channel_context = endpoint_context; 378 freecon(endpoint_context); 379 const std::string suffix = "_endpoint_socket"; 380 auto pos = channel_context.find(suffix); 381 if (pos != std::string::npos) { 382 channel_context.replace(pos, suffix.size(), "_channel_socket"); 383 } else { 384 ALOGW( 385 "Endpoint::CreateChannelSocketPair: Endpoint security context '%s' " 386 "does not contain expected substring '%s'", 387 channel_context.c_str(), suffix.c_str()); 388 } 389 ALOGE_IF(setsockcreatecon_raw(channel_context.c_str()) == -1, 390 "Endpoint::CreateChannelSocketPair: Failed to set channel socket " 391 "security context: %s", 392 strerror(errno)); 393 } else { 394 ALOGE( 395 "Endpoint::CreateChannelSocketPair: Failed to obtain the endpoint " 396 "socket's security context: %s", 397 strerror(errno)); 398 } 399 400 int channel_pair[2] = {}; 401 if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_pair) == -1) { 402 ALOGE("Endpoint::CreateChannelSocketPair: Failed to create socket pair: %s", 403 strerror(errno)); 404 status.SetError(errno); 405 return status; 406 } 407 408 setsockcreatecon_raw(nullptr); 409 410 local_socket->Reset(channel_pair[0]); 411 remote_socket->Reset(channel_pair[1]); 412 413 int optval = 1; 414 if (setsockopt(local_socket->Get(), SOL_SOCKET, SO_PASSCRED, &optval, 415 sizeof(optval)) == -1) { 416 ALOGE( 417 "Endpoint::CreateChannelSocketPair: Failed to enable the receiving of " 418 "the credentials for channel %d: %s", 419 local_socket->Get(), strerror(errno)); 420 status.SetError(errno); 421 } 422 return status; 423 } 424 425 Status<RemoteChannelHandle> Endpoint::PushChannel(Message* message, 426 int /*flags*/, 427 Channel* channel, 428 int* channel_id) { 429 LocalHandle local_socket; 430 LocalHandle remote_socket; 431 auto status = CreateChannelSocketPair(&local_socket, &remote_socket); 432 if (!status) 433 return status.error_status(); 434 435 std::lock_guard<std::mutex> autolock(channel_mutex_); 436 auto channel_data_status = 437 OnNewChannelLocked(std::move(local_socket), channel); 438 if (!channel_data_status) 439 return channel_data_status.error_status(); 440 441 ChannelData* channel_data; 442 std::tie(*channel_id, channel_data) = channel_data_status.take(); 443 444 // Flags are ignored for now. 445 // TODO(xiaohuit): Implement those. 446 447 auto* state = static_cast<MessageState*>(message->GetState()); 448 Status<ChannelReference> ref = state->PushChannelHandle( 449 remote_socket.Borrow(), channel_data->event_set.pollin_event_fd(), 450 channel_data->event_set.pollhup_event_fd()); 451 if (!ref) 452 return ref.error_status(); 453 state->sockets_to_close.push_back(std::move(remote_socket)); 454 return RemoteChannelHandle{ref.get()}; 455 } 456 457 Status<int> Endpoint::CheckChannel(const Message* /*message*/, 458 ChannelReference /*ref*/, 459 Channel** /*channel*/) { 460 // TODO(xiaohuit): Implement this. 461 return ErrorStatus(EFAULT); 462 } 463 464 Channel* Endpoint::GetChannelState(int32_t channel_id) { 465 std::lock_guard<std::mutex> autolock(channel_mutex_); 466 auto channel_data = channels_.find(channel_id); 467 return (channel_data != channels_.end()) ? channel_data->second.channel_state 468 : nullptr; 469 } 470 471 BorrowedHandle Endpoint::GetChannelSocketFd(int32_t channel_id) { 472 std::lock_guard<std::mutex> autolock(channel_mutex_); 473 BorrowedHandle handle; 474 auto channel_data = channels_.find(channel_id); 475 if (channel_data != channels_.end()) 476 handle = channel_data->second.data_fd.Borrow(); 477 return handle; 478 } 479 480 Status<std::pair<BorrowedHandle, BorrowedHandle>> Endpoint::GetChannelEventFd( 481 int32_t channel_id) { 482 std::lock_guard<std::mutex> autolock(channel_mutex_); 483 auto channel_data = channels_.find(channel_id); 484 if (channel_data != channels_.end()) { 485 return {{channel_data->second.event_set.pollin_event_fd(), 486 channel_data->second.event_set.pollhup_event_fd()}}; 487 } 488 return ErrorStatus(ENOENT); 489 } 490 491 int32_t Endpoint::GetChannelId(const BorrowedHandle& channel_fd) { 492 std::lock_guard<std::mutex> autolock(channel_mutex_); 493 auto iter = channel_fd_to_id_.find(channel_fd.Get()); 494 return (iter != channel_fd_to_id_.end()) ? iter->second : -1; 495 } 496 497 Status<void> Endpoint::ReceiveMessageForChannel( 498 const BorrowedHandle& channel_fd, Message* message) { 499 RequestHeader<LocalHandle> request; 500 int32_t channel_id = GetChannelId(channel_fd); 501 auto status = ReceiveData(channel_fd.Borrow(), &request); 502 if (!status) { 503 if (status.error() == ESHUTDOWN) { 504 BuildCloseMessage(channel_id, message); 505 return {}; 506 } else { 507 CloseChannel(channel_id); 508 return status; 509 } 510 } 511 512 MessageInfo info; 513 info.pid = request.cred.pid; 514 info.tid = -1; 515 info.cid = channel_id; 516 info.mid = request.is_impulse ? Message::IMPULSE_MESSAGE_ID 517 : GetNextAvailableMessageId(); 518 info.euid = request.cred.uid; 519 info.egid = request.cred.gid; 520 info.op = request.op; 521 info.flags = 0; 522 info.service = service_; 523 info.channel = GetChannelState(channel_id); 524 info.send_len = request.send_len; 525 info.recv_len = request.max_recv_len; 526 info.fd_count = request.file_descriptors.size(); 527 static_assert(sizeof(info.impulse) == request.impulse_payload.size(), 528 "Impulse payload sizes must be the same in RequestHeader and " 529 "MessageInfo"); 530 memcpy(info.impulse, request.impulse_payload.data(), 531 request.impulse_payload.size()); 532 *message = Message{info}; 533 auto* state = static_cast<MessageState*>(message->GetState()); 534 state->request = std::move(request); 535 if (request.send_len > 0 && !request.is_impulse) { 536 state->request_data.resize(request.send_len); 537 status = ReceiveData(channel_fd, state->request_data.data(), 538 state->request_data.size()); 539 } 540 541 if (status && request.is_impulse) 542 status = ReenableEpollEvent(channel_fd); 543 544 if (!status) { 545 if (status.error() == ESHUTDOWN) { 546 BuildCloseMessage(channel_id, message); 547 return {}; 548 } else { 549 CloseChannel(channel_id); 550 return status; 551 } 552 } 553 554 return status; 555 } 556 557 void Endpoint::BuildCloseMessage(int32_t channel_id, Message* message) { 558 ALOGD_IF(TRACE, "Endpoint::BuildCloseMessage: channel_id=%d", channel_id); 559 MessageInfo info; 560 info.pid = -1; 561 info.tid = -1; 562 info.cid = channel_id; 563 info.mid = GetNextAvailableMessageId(); 564 info.euid = -1; 565 info.egid = -1; 566 info.op = opcodes::CHANNEL_CLOSE; 567 info.flags = 0; 568 info.service = service_; 569 info.channel = GetChannelState(channel_id); 570 info.send_len = 0; 571 info.recv_len = 0; 572 info.fd_count = 0; 573 *message = Message{info}; 574 } 575 576 Status<void> Endpoint::MessageReceive(Message* message) { 577 // Receive at most one event from the epoll set. This should prevent multiple 578 // dispatch threads from attempting to handle messages on the same socket at 579 // the same time. 580 epoll_event event; 581 int count = RETRY_EINTR( 582 epoll_wait(epoll_fd_.Get(), &event, 1, is_blocking_ ? -1 : 0)); 583 if (count < 0) { 584 ALOGE("Endpoint::MessageReceive: Failed to wait for epoll events: %s\n", 585 strerror(errno)); 586 return ErrorStatus{errno}; 587 } else if (count == 0) { 588 return ErrorStatus{ETIMEDOUT}; 589 } 590 591 if (event.data.fd == cancel_event_fd_.Get()) { 592 return ErrorStatus{ESHUTDOWN}; 593 } 594 595 if (socket_fd_ && event.data.fd == socket_fd_.Get()) { 596 auto status = AcceptConnection(message); 597 if (!status) 598 return status; 599 return ReenableEpollEvent(socket_fd_.Borrow()); 600 } 601 602 BorrowedHandle channel_fd{event.data.fd}; 603 return ReceiveMessageForChannel(channel_fd, message); 604 } 605 606 Status<void> Endpoint::MessageReply(Message* message, int return_code) { 607 const int32_t channel_id = message->GetChannelId(); 608 auto channel_socket = GetChannelSocketFd(channel_id); 609 if (!channel_socket) 610 return ErrorStatus{EBADF}; 611 612 auto* state = static_cast<MessageState*>(message->GetState()); 613 switch (message->GetOp()) { 614 case opcodes::CHANNEL_CLOSE: 615 return CloseChannel(channel_id); 616 617 case opcodes::CHANNEL_OPEN: 618 if (return_code < 0) { 619 return CloseChannel(channel_id); 620 } else { 621 // Open messages do not have a payload and may not transfer any channels 622 // or file descriptors on behalf of the service. 623 state->response_data.clear(); 624 state->response.file_descriptors.clear(); 625 state->response.channels.clear(); 626 627 // Return the channel event-related fds in a single ChannelInfo entry 628 // with an empty data_fd member. 629 auto status = GetChannelEventFd(channel_id); 630 if (!status) 631 return status.error_status(); 632 633 auto handles = status.take(); 634 state->response.channels.push_back({BorrowedHandle(), 635 std::move(handles.first), 636 std::move(handles.second)}); 637 return_code = 0; 638 } 639 break; 640 } 641 642 state->response.ret_code = return_code; 643 state->response.recv_len = state->response_data.size(); 644 auto status = SendData(channel_socket, state->response); 645 if (status && !state->response_data.empty()) { 646 status = SendData(channel_socket, state->response_data.data(), 647 state->response_data.size()); 648 } 649 650 if (status) 651 status = ReenableEpollEvent(channel_socket); 652 653 return status; 654 } 655 656 Status<void> Endpoint::MessageReplyFd(Message* message, unsigned int push_fd) { 657 auto* state = static_cast<MessageState*>(message->GetState()); 658 auto ref = state->PushFileHandle(BorrowedHandle{static_cast<int>(push_fd)}); 659 if (!ref) 660 return ref.error_status(); 661 return MessageReply(message, ref.get()); 662 } 663 664 Status<void> Endpoint::MessageReplyChannelHandle( 665 Message* message, const LocalChannelHandle& handle) { 666 auto* state = static_cast<MessageState*>(message->GetState()); 667 auto ref = state->PushChannelHandle(handle.Borrow()); 668 if (!ref) 669 return ref.error_status(); 670 return MessageReply(message, ref.get()); 671 } 672 673 Status<void> Endpoint::MessageReplyChannelHandle( 674 Message* message, const BorrowedChannelHandle& handle) { 675 auto* state = static_cast<MessageState*>(message->GetState()); 676 auto ref = state->PushChannelHandle(handle.Duplicate()); 677 if (!ref) 678 return ref.error_status(); 679 return MessageReply(message, ref.get()); 680 } 681 682 Status<void> Endpoint::MessageReplyChannelHandle( 683 Message* message, const RemoteChannelHandle& handle) { 684 return MessageReply(message, handle.value()); 685 } 686 687 Status<size_t> Endpoint::ReadMessageData(Message* message, const iovec* vector, 688 size_t vector_length) { 689 auto* state = static_cast<MessageState*>(message->GetState()); 690 return state->ReadData(vector, vector_length); 691 } 692 693 Status<size_t> Endpoint::WriteMessageData(Message* message, const iovec* vector, 694 size_t vector_length) { 695 auto* state = static_cast<MessageState*>(message->GetState()); 696 return state->WriteData(vector, vector_length); 697 } 698 699 Status<FileReference> Endpoint::PushFileHandle(Message* message, 700 const LocalHandle& handle) { 701 auto* state = static_cast<MessageState*>(message->GetState()); 702 return state->PushFileHandle(handle.Borrow()); 703 } 704 705 Status<FileReference> Endpoint::PushFileHandle(Message* message, 706 const BorrowedHandle& handle) { 707 auto* state = static_cast<MessageState*>(message->GetState()); 708 return state->PushFileHandle(handle.Duplicate()); 709 } 710 711 Status<FileReference> Endpoint::PushFileHandle(Message* /*message*/, 712 const RemoteHandle& handle) { 713 return handle.Get(); 714 } 715 716 Status<ChannelReference> Endpoint::PushChannelHandle( 717 Message* message, const LocalChannelHandle& handle) { 718 auto* state = static_cast<MessageState*>(message->GetState()); 719 return state->PushChannelHandle(handle.Borrow()); 720 } 721 722 Status<ChannelReference> Endpoint::PushChannelHandle( 723 Message* message, const BorrowedChannelHandle& handle) { 724 auto* state = static_cast<MessageState*>(message->GetState()); 725 return state->PushChannelHandle(handle.Duplicate()); 726 } 727 728 Status<ChannelReference> Endpoint::PushChannelHandle( 729 Message* /*message*/, const RemoteChannelHandle& handle) { 730 return handle.value(); 731 } 732 733 LocalHandle Endpoint::GetFileHandle(Message* message, FileReference ref) const { 734 LocalHandle handle; 735 auto* state = static_cast<MessageState*>(message->GetState()); 736 state->GetLocalFileHandle(ref, &handle); 737 return handle; 738 } 739 740 LocalChannelHandle Endpoint::GetChannelHandle(Message* message, 741 ChannelReference ref) const { 742 LocalChannelHandle handle; 743 auto* state = static_cast<MessageState*>(message->GetState()); 744 state->GetLocalChannelHandle(ref, &handle); 745 return handle; 746 } 747 748 Status<void> Endpoint::Cancel() { 749 if (eventfd_write(cancel_event_fd_.Get(), 1) < 0) 750 return ErrorStatus{errno}; 751 return {}; 752 } 753 754 std::unique_ptr<Endpoint> Endpoint::Create(const std::string& endpoint_path, 755 mode_t /*unused_mode*/, 756 bool blocking) { 757 return std::unique_ptr<Endpoint>(new Endpoint(endpoint_path, blocking)); 758 } 759 760 std::unique_ptr<Endpoint> Endpoint::CreateAndBindSocket( 761 const std::string& endpoint_path, bool blocking) { 762 return std::unique_ptr<Endpoint>( 763 new Endpoint(endpoint_path, blocking, false)); 764 } 765 766 std::unique_ptr<Endpoint> Endpoint::CreateFromSocketFd(LocalHandle socket_fd) { 767 return std::unique_ptr<Endpoint>(new Endpoint(std::move(socket_fd))); 768 } 769 770 Status<void> Endpoint::RegisterNewChannelForTests(LocalHandle channel_fd) { 771 int optval = 1; 772 if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval, 773 sizeof(optval)) == -1) { 774 ALOGE( 775 "Endpoint::RegisterNewChannelForTests: Failed to enable the receiving" 776 "of the credentials for channel %d: %s", 777 channel_fd.Get(), strerror(errno)); 778 return ErrorStatus(errno); 779 } 780 return OnNewChannel(std::move(channel_fd)); 781 } 782 783 } // namespace uds 784 } // namespace pdx 785 } // namespace android 786