Home | History | Annotate | Download | only in core
      1 // Copyright 2016 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include <windows.h>
      6 
      7 #include <limits>
      8 #include <utility>
      9 
     10 #include "base/debug/alias.h"
     11 #include "base/memory/platform_shared_memory_region.h"
     12 #include "base/numerics/safe_conversions.h"
     13 #include "base/strings/string_piece.h"
     14 #include "mojo/core/broker.h"
     15 #include "mojo/core/broker_messages.h"
     16 #include "mojo/core/channel.h"
     17 #include "mojo/core/platform_handle_utils.h"
     18 #include "mojo/public/cpp/platform/named_platform_channel.h"
     19 
     20 namespace mojo {
     21 namespace core {
     22 
     23 namespace {
     24 
     25 // 256 bytes should be enough for anyone!
     26 const size_t kMaxBrokerMessageSize = 256;
     27 
     28 bool TakeHandlesFromBrokerMessage(Channel::Message* message,
     29                                   size_t num_handles,
     30                                   PlatformHandle* out_handles) {
     31   if (message->num_handles() != num_handles) {
     32     DLOG(ERROR) << "Received unexpected number of handles in broker message";
     33     return false;
     34   }
     35 
     36   std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
     37   DCHECK_EQ(handles.size(), num_handles);
     38   DCHECK(out_handles);
     39 
     40   for (size_t i = 0; i < num_handles; ++i)
     41     out_handles[i] = handles[i].TakeHandle();
     42   return true;
     43 }
     44 
     45 Channel::MessagePtr WaitForBrokerMessage(HANDLE pipe_handle,
     46                                          BrokerMessageType expected_type) {
     47   char buffer[kMaxBrokerMessageSize];
     48   DWORD bytes_read = 0;
     49   BOOL result = ::ReadFile(pipe_handle, buffer, kMaxBrokerMessageSize,
     50                            &bytes_read, nullptr);
     51   if (!result) {
     52     // The pipe may be broken if the browser side has been closed, e.g. during
     53     // browser shutdown. In that case the ReadFile call will fail and we
     54     // shouldn't continue waiting.
     55     PLOG(ERROR) << "Error reading broker pipe";
     56     return nullptr;
     57   }
     58 
     59   Channel::MessagePtr message =
     60       Channel::Message::Deserialize(buffer, static_cast<size_t>(bytes_read));
     61   if (!message || message->payload_size() < sizeof(BrokerMessageHeader)) {
     62     LOG(ERROR) << "Invalid broker message";
     63 
     64     base::debug::Alias(&buffer[0]);
     65     base::debug::Alias(&bytes_read);
     66     CHECK(false);
     67     return nullptr;
     68   }
     69 
     70   const BrokerMessageHeader* header =
     71       reinterpret_cast<const BrokerMessageHeader*>(message->payload());
     72   if (header->type != expected_type) {
     73     LOG(ERROR) << "Unexpected broker message type";
     74 
     75     base::debug::Alias(&buffer[0]);
     76     base::debug::Alias(&bytes_read);
     77     CHECK(false);
     78     return nullptr;
     79   }
     80 
     81   return message;
     82 }
     83 
     84 }  // namespace
     85 
     86 Broker::Broker(PlatformHandle handle) : sync_channel_(std::move(handle)) {
     87   CHECK(sync_channel_.is_valid());
     88   Channel::MessagePtr message = WaitForBrokerMessage(
     89       sync_channel_.GetHandle().Get(), BrokerMessageType::INIT);
     90 
     91   // If we fail to read a message (broken pipe), just return early. The inviter
     92   // handle will be null and callers must handle this gracefully.
     93   if (!message)
     94     return;
     95 
     96   PlatformHandle endpoint_handle;
     97   if (TakeHandlesFromBrokerMessage(message.get(), 1, &endpoint_handle)) {
     98     inviter_endpoint_ = PlatformChannelEndpoint(std::move(endpoint_handle));
     99   } else {
    100     // If the message has no handles, we expect it to carry pipe name instead.
    101     const BrokerMessageHeader* header =
    102         static_cast<const BrokerMessageHeader*>(message->payload());
    103     CHECK_GE(message->payload_size(),
    104              sizeof(BrokerMessageHeader) + sizeof(InitData));
    105     const InitData* data = reinterpret_cast<const InitData*>(header + 1);
    106     CHECK_EQ(message->payload_size(),
    107              sizeof(BrokerMessageHeader) + sizeof(InitData) +
    108                  data->pipe_name_length * sizeof(base::char16));
    109     const base::char16* name_data =
    110         reinterpret_cast<const base::char16*>(data + 1);
    111     CHECK(data->pipe_name_length);
    112     inviter_endpoint_ = NamedPlatformChannel::ConnectToServer(
    113         base::StringPiece16(name_data, data->pipe_name_length).as_string());
    114   }
    115 }
    116 
    117 Broker::~Broker() {}
    118 
    119 PlatformChannelEndpoint Broker::GetInviterEndpoint() {
    120   return std::move(inviter_endpoint_);
    121 }
    122 
    123 base::WritableSharedMemoryRegion Broker::GetWritableSharedMemoryRegion(
    124     size_t num_bytes) {
    125   base::AutoLock lock(lock_);
    126   BufferRequestData* buffer_request;
    127   Channel::MessagePtr out_message = CreateBrokerMessage(
    128       BrokerMessageType::BUFFER_REQUEST, 0, 0, &buffer_request);
    129   buffer_request->size = base::checked_cast<uint32_t>(num_bytes);
    130   DWORD bytes_written = 0;
    131   BOOL result =
    132       ::WriteFile(sync_channel_.GetHandle().Get(), out_message->data(),
    133                   static_cast<DWORD>(out_message->data_num_bytes()),
    134                   &bytes_written, nullptr);
    135   if (!result ||
    136       static_cast<size_t>(bytes_written) != out_message->data_num_bytes()) {
    137     PLOG(ERROR) << "Error sending sync broker message";
    138     return base::WritableSharedMemoryRegion();
    139   }
    140 
    141   PlatformHandle handle;
    142   Channel::MessagePtr response = WaitForBrokerMessage(
    143       sync_channel_.GetHandle().Get(), BrokerMessageType::BUFFER_RESPONSE);
    144   if (response && TakeHandlesFromBrokerMessage(response.get(), 1, &handle)) {
    145     BufferResponseData* data;
    146     if (!GetBrokerMessageData(response.get(), &data))
    147       return base::WritableSharedMemoryRegion();
    148     return base::WritableSharedMemoryRegion::Deserialize(
    149         base::subtle::PlatformSharedMemoryRegion::Take(
    150             CreateSharedMemoryRegionHandleFromPlatformHandles(std::move(handle),
    151                                                               PlatformHandle()),
    152             base::subtle::PlatformSharedMemoryRegion::Mode::kWritable,
    153             num_bytes,
    154             base::UnguessableToken::Deserialize(data->guid_high,
    155                                                 data->guid_low)));
    156   }
    157 
    158   return base::WritableSharedMemoryRegion();
    159 }
    160 
    161 }  // namespace core
    162 }  // namespace mojo
    163