1 // Copyright 2016 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "mojo/core/broker.h" 6 7 #include <fcntl.h> 8 #include <unistd.h> 9 10 #include <utility> 11 #include <vector> 12 13 #include "base/logging.h" 14 #include "base/memory/platform_shared_memory_region.h" 15 #include "build/build_config.h" 16 #include "mojo/core/broker_messages.h" 17 #include "mojo/core/channel.h" 18 #include "mojo/core/platform_handle_utils.h" 19 #include "mojo/public/cpp/platform/socket_utils_posix.h" 20 21 namespace mojo { 22 namespace core { 23 24 namespace { 25 26 Channel::MessagePtr WaitForBrokerMessage( 27 int socket_fd, 28 BrokerMessageType expected_type, 29 size_t expected_num_handles, 30 size_t expected_data_size, 31 std::vector<PlatformHandle>* incoming_handles) { 32 Channel::MessagePtr message(new Channel::Message( 33 sizeof(BrokerMessageHeader) + expected_data_size, expected_num_handles)); 34 std::vector<base::ScopedFD> incoming_fds; 35 ssize_t read_result = 36 SocketRecvmsg(socket_fd, const_cast<void*>(message->data()), 37 message->data_num_bytes(), &incoming_fds, true /* block */); 38 bool error = false; 39 if (read_result < 0) { 40 PLOG(ERROR) << "Recvmsg error"; 41 error = true; 42 } else if (static_cast<size_t>(read_result) != message->data_num_bytes()) { 43 LOG(ERROR) << "Invalid node channel message"; 44 error = true; 45 } else if (incoming_fds.size() != expected_num_handles) { 46 LOG(ERROR) << "Received unexpected number of handles"; 47 error = true; 48 } 49 50 if (error) 51 return nullptr; 52 53 const BrokerMessageHeader* header = 54 reinterpret_cast<const BrokerMessageHeader*>(message->payload()); 55 if (header->type != expected_type) { 56 LOG(ERROR) << "Unexpected message"; 57 return nullptr; 58 } 59 60 incoming_handles->reserve(incoming_fds.size()); 61 for (size_t i = 0; i < incoming_fds.size(); ++i) 62 incoming_handles->emplace_back(std::move(incoming_fds[i])); 63 64 return message; 65 } 66 67 } // namespace 68 69 Broker::Broker(PlatformHandle handle) : sync_channel_(std::move(handle)) { 70 CHECK(sync_channel_.is_valid()); 71 72 int fd = sync_channel_.GetFD().get(); 73 // Mark the channel as blocking. 74 int flags = fcntl(fd, F_GETFL); 75 PCHECK(flags != -1); 76 flags = fcntl(fd, F_SETFL, flags & ~O_NONBLOCK); 77 PCHECK(flags != -1); 78 79 // Wait for the first message, which should contain a handle. 80 std::vector<PlatformHandle> incoming_platform_handles; 81 if (WaitForBrokerMessage(fd, BrokerMessageType::INIT, 1, 0, 82 &incoming_platform_handles)) { 83 inviter_endpoint_ = 84 PlatformChannelEndpoint(std::move(incoming_platform_handles[0])); 85 } 86 } 87 88 Broker::~Broker() = default; 89 90 PlatformChannelEndpoint Broker::GetInviterEndpoint() { 91 return std::move(inviter_endpoint_); 92 } 93 94 base::WritableSharedMemoryRegion Broker::GetWritableSharedMemoryRegion( 95 size_t num_bytes) { 96 base::AutoLock lock(lock_); 97 98 BufferRequestData* buffer_request; 99 Channel::MessagePtr out_message = CreateBrokerMessage( 100 BrokerMessageType::BUFFER_REQUEST, 0, 0, &buffer_request); 101 buffer_request->size = num_bytes; 102 ssize_t write_result = 103 SocketWrite(sync_channel_.GetFD().get(), out_message->data(), 104 out_message->data_num_bytes()); 105 if (write_result < 0) { 106 PLOG(ERROR) << "Error sending sync broker message"; 107 return base::WritableSharedMemoryRegion(); 108 } else if (static_cast<size_t>(write_result) != 109 out_message->data_num_bytes()) { 110 LOG(ERROR) << "Error sending complete broker message"; 111 return base::WritableSharedMemoryRegion(); 112 } 113 114 #if !defined(OS_POSIX) || defined(OS_ANDROID) || defined(OS_FUCHSIA) || \ 115 (defined(OS_MACOSX) && !defined(OS_IOS)) 116 // Non-POSIX systems, as well as Android, Fuchsia, and non-iOS Mac, only use 117 // a single handle to represent a writable region. 118 constexpr size_t kNumExpectedHandles = 1; 119 #else 120 constexpr size_t kNumExpectedHandles = 2; 121 #endif 122 123 std::vector<PlatformHandle> handles; 124 Channel::MessagePtr message = WaitForBrokerMessage( 125 sync_channel_.GetFD().get(), BrokerMessageType::BUFFER_RESPONSE, 126 kNumExpectedHandles, sizeof(BufferResponseData), &handles); 127 if (message) { 128 const BufferResponseData* data; 129 if (!GetBrokerMessageData(message.get(), &data)) 130 return base::WritableSharedMemoryRegion(); 131 132 if (handles.size() == 1) 133 handles.emplace_back(); 134 return base::WritableSharedMemoryRegion::Deserialize( 135 base::subtle::PlatformSharedMemoryRegion::Take( 136 CreateSharedMemoryRegionHandleFromPlatformHandles( 137 std::move(handles[0]), std::move(handles[1])), 138 base::subtle::PlatformSharedMemoryRegion::Mode::kWritable, 139 num_bytes, 140 base::UnguessableToken::Deserialize(data->guid_high, 141 data->guid_low))); 142 } 143 144 return base::WritableSharedMemoryRegion(); 145 } 146 147 } // namespace core 148 } // namespace mojo 149