Home | History | Annotate | Download | only in core
      1 // Copyright 2016 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "mojo/core/broker_host.h"
      6 
      7 #include <utility>
      8 
      9 #include "base/logging.h"
     10 #include "base/memory/platform_shared_memory_region.h"
     11 #include "base/memory/ref_counted.h"
     12 #include "base/threading/thread_task_runner_handle.h"
     13 #include "build/build_config.h"
     14 #include "mojo/core/broker_messages.h"
     15 #include "mojo/core/platform_handle_utils.h"
     16 
     17 #if defined(OS_WIN)
     18 #include <windows.h>
     19 #endif
     20 
     21 namespace mojo {
     22 namespace core {
     23 
     24 BrokerHost::BrokerHost(base::ProcessHandle client_process,
     25                        ConnectionParams connection_params,
     26                        const ProcessErrorCallback& process_error_callback)
     27     : process_error_callback_(process_error_callback)
     28 #if defined(OS_WIN)
     29       ,
     30       client_process_(ScopedProcessHandle::CloneFrom(client_process))
     31 #endif
     32 {
     33   CHECK(connection_params.endpoint().is_valid() ||
     34         connection_params.server_endpoint().is_valid());
     35 
     36   base::MessageLoopCurrent::Get()->AddDestructionObserver(this);
     37 
     38   channel_ = Channel::Create(this, std::move(connection_params),
     39                              base::ThreadTaskRunnerHandle::Get());
     40   channel_->Start();
     41 }
     42 
     43 BrokerHost::~BrokerHost() {
     44   // We're always destroyed on the creation thread, which is the IO thread.
     45   base::MessageLoopCurrent::Get()->RemoveDestructionObserver(this);
     46 
     47   if (channel_)
     48     channel_->ShutDown();
     49 }
     50 
     51 bool BrokerHost::PrepareHandlesForClient(
     52     std::vector<PlatformHandleInTransit>* handles) {
     53 #if defined(OS_WIN)
     54   bool handles_ok = true;
     55   for (auto& handle : *handles) {
     56     if (!handle.TransferToProcess(client_process_.Clone()))
     57       handles_ok = false;
     58   }
     59   return handles_ok;
     60 #else
     61   return true;
     62 #endif
     63 }
     64 
     65 bool BrokerHost::SendChannel(PlatformHandle handle) {
     66   CHECK(handle.is_valid());
     67   CHECK(channel_);
     68 
     69 #if defined(OS_WIN)
     70   InitData* data;
     71   Channel::MessagePtr message =
     72       CreateBrokerMessage(BrokerMessageType::INIT, 1, 0, &data);
     73   data->pipe_name_length = 0;
     74 #else
     75   Channel::MessagePtr message =
     76       CreateBrokerMessage(BrokerMessageType::INIT, 1, nullptr);
     77 #endif
     78   std::vector<PlatformHandleInTransit> handles(1);
     79   handles[0] = PlatformHandleInTransit(std::move(handle));
     80 
     81   // This may legitimately fail on Windows if the client process is in another
     82   // session, e.g., is an elevated process.
     83   if (!PrepareHandlesForClient(&handles))
     84     return false;
     85 
     86   message->SetHandles(std::move(handles));
     87   channel_->Write(std::move(message));
     88   return true;
     89 }
     90 
     91 #if defined(OS_WIN)
     92 
     93 void BrokerHost::SendNamedChannel(const base::StringPiece16& pipe_name) {
     94   InitData* data;
     95   base::char16* name_data;
     96   Channel::MessagePtr message = CreateBrokerMessage(
     97       BrokerMessageType::INIT, 0, sizeof(*name_data) * pipe_name.length(),
     98       &data, reinterpret_cast<void**>(&name_data));
     99   data->pipe_name_length = static_cast<uint32_t>(pipe_name.length());
    100   std::copy(pipe_name.begin(), pipe_name.end(), name_data);
    101   channel_->Write(std::move(message));
    102 }
    103 
    104 #endif  // defined(OS_WIN)
    105 
    106 void BrokerHost::OnBufferRequest(uint32_t num_bytes) {
    107   base::subtle::PlatformSharedMemoryRegion region =
    108       base::subtle::PlatformSharedMemoryRegion::CreateWritable(num_bytes);
    109 
    110   std::vector<PlatformHandleInTransit> handles(2);
    111   if (region.IsValid()) {
    112     PlatformHandle h[2];
    113     ExtractPlatformHandlesFromSharedMemoryRegionHandle(
    114         region.PassPlatformHandle(), &h[0], &h[1]);
    115     handles[0] = PlatformHandleInTransit(std::move(h[0]));
    116     handles[1] = PlatformHandleInTransit(std::move(h[1]));
    117 #if !defined(OS_POSIX) || defined(OS_ANDROID) || defined(OS_FUCHSIA) || \
    118     (defined(OS_MACOSX) && !defined(OS_IOS))
    119     // Non-POSIX systems, as well as Android, Fuchsia, and non-iOS Mac, only use
    120     // a single handle to represent a writable region.
    121     DCHECK(!handles[1].handle().is_valid());
    122     handles.resize(1);
    123 #else
    124     DCHECK(handles[1].handle().is_valid());
    125 #endif
    126   }
    127 
    128   BufferResponseData* response;
    129   Channel::MessagePtr message = CreateBrokerMessage(
    130       BrokerMessageType::BUFFER_RESPONSE, handles.size(), 0, &response);
    131   if (!handles.empty()) {
    132     base::UnguessableToken guid = region.GetGUID();
    133     response->guid_high = guid.GetHighForSerialization();
    134     response->guid_low = guid.GetLowForSerialization();
    135     PrepareHandlesForClient(&handles);
    136     message->SetHandles(std::move(handles));
    137   }
    138 
    139   channel_->Write(std::move(message));
    140 }
    141 
    142 void BrokerHost::OnChannelMessage(const void* payload,
    143                                   size_t payload_size,
    144                                   std::vector<PlatformHandle> handles) {
    145   if (payload_size < sizeof(BrokerMessageHeader))
    146     return;
    147 
    148   const BrokerMessageHeader* header =
    149       static_cast<const BrokerMessageHeader*>(payload);
    150   switch (header->type) {
    151     case BrokerMessageType::BUFFER_REQUEST:
    152       if (payload_size ==
    153           sizeof(BrokerMessageHeader) + sizeof(BufferRequestData)) {
    154         const BufferRequestData* request =
    155             reinterpret_cast<const BufferRequestData*>(header + 1);
    156         OnBufferRequest(request->size);
    157       }
    158       break;
    159 
    160     default:
    161       DLOG(ERROR) << "Unexpected broker message type: " << header->type;
    162       break;
    163   }
    164 }
    165 
    166 void BrokerHost::OnChannelError(Channel::Error error) {
    167   if (process_error_callback_ &&
    168       error == Channel::Error::kReceivedMalformedData) {
    169     process_error_callback_.Run("Broker host received malformed message");
    170   }
    171 
    172   delete this;
    173 }
    174 
    175 void BrokerHost::WillDestroyCurrentMessageLoop() {
    176   delete this;
    177 }
    178 
    179 }  // namespace core
    180 }  // namespace mojo
    181