Home | History | Annotate | Download | only in lib
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
#include "mojo/public/cpp/bindings/message.h"

#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <limits>
#include <utility>

#include "base/bind.h"
#include "base/lazy_instance.h"
#include "base/logging.h"
#include "base/strings/stringprintf.h"
#include "base/threading/thread_local.h"
#include "mojo/public/cpp/bindings/associated_group_controller.h"
#include "mojo/public/cpp/bindings/lib/array_internal.h"
     21 
     22 namespace mojo {
     23 
     24 namespace {
     25 
     26 base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>::
     27     DestructorAtExit g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER;
     28 
     29 base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::
     30     DestructorAtExit g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER;
     31 
     32 void DoNotifyBadMessage(Message message, const std::string& error) {
     33   message.NotifyBadMessage(error);
     34 }
     35 
     36 }  // namespace
     37 
     38 Message::Message() {
     39 }
     40 
     41 Message::Message(Message&& other)
     42     : buffer_(std::move(other.buffer_)),
     43       handles_(std::move(other.handles_)),
     44       associated_endpoint_handles_(
     45           std::move(other.associated_endpoint_handles_)) {}
     46 
     47 Message::~Message() {
     48   CloseHandles();
     49 }
     50 
     51 Message& Message::operator=(Message&& other) {
     52   Reset();
     53   std::swap(other.buffer_, buffer_);
     54   std::swap(other.handles_, handles_);
     55   std::swap(other.associated_endpoint_handles_, associated_endpoint_handles_);
     56   return *this;
     57 }
     58 
     59 void Message::Reset() {
     60   CloseHandles();
     61   handles_.clear();
     62   associated_endpoint_handles_.clear();
     63   buffer_.reset();
     64 }
     65 
     66 void Message::Initialize(size_t capacity, bool zero_initialized) {
     67   DCHECK(!buffer_);
     68   buffer_.reset(new internal::MessageBuffer(capacity, zero_initialized));
     69 }
     70 
     71 void Message::InitializeFromMojoMessage(ScopedMessageHandle message,
     72                                         uint32_t num_bytes,
     73                                         std::vector<Handle>* handles) {
     74   DCHECK(!buffer_);
     75   buffer_.reset(new internal::MessageBuffer(std::move(message), num_bytes));
     76   handles_.swap(*handles);
     77 }
     78 
     79 const uint8_t* Message::payload() const {
     80   if (version() < 2)
     81     return data() + header()->num_bytes;
     82 
     83   return static_cast<const uint8_t*>(header_v2()->payload.Get());
     84 }
     85 
     86 uint32_t Message::payload_num_bytes() const {
     87   DCHECK_GE(data_num_bytes(), header()->num_bytes);
     88   size_t num_bytes;
     89   if (version() < 2) {
     90     num_bytes = data_num_bytes() - header()->num_bytes;
     91   } else {
     92     auto payload = reinterpret_cast<uintptr_t>(header_v2()->payload.Get());
     93     if (!payload) {
     94       num_bytes = 0;
     95     } else {
     96       auto payload_end =
     97           reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get());
     98       if (!payload_end)
     99         payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes());
    100       DCHECK_GE(payload_end, payload);
    101       num_bytes = payload_end - payload;
    102     }
    103   }
    104   DCHECK_LE(num_bytes, std::numeric_limits<uint32_t>::max());
    105   return static_cast<uint32_t>(num_bytes);
    106 }
    107 
    108 uint32_t Message::payload_num_interface_ids() const {
    109   auto* array_pointer =
    110       version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
    111   return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0;
    112 }
    113 
    114 const uint32_t* Message::payload_interface_ids() const {
    115   auto* array_pointer =
    116       version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
    117   return array_pointer ? array_pointer->storage() : nullptr;
    118 }
    119 
    120 ScopedMessageHandle Message::TakeMojoMessage() {
    121   // If there are associated endpoints transferred,
    122   // SerializeAssociatedEndpointHandles() must be called before this method.
    123   DCHECK(associated_endpoint_handles_.empty());
    124 
    125   if (handles_.empty())  // Fast path for the common case: No handles.
    126     return buffer_->TakeMessage();
    127 
    128   // Allocate a new message with space for the handles, then copy the buffer
    129   // contents into it.
    130   //
    131   // TODO(rockot): We could avoid this copy by extending GetSerializedSize()
    132   // behavior to collect handles. It's unoptimized for now because it's much
    133   // more common to have messages with no handles.
    134   ScopedMessageHandle new_message;
    135   MojoResult rv = AllocMessage(
    136       data_num_bytes(),
    137       handles_.empty() ? nullptr
    138                        : reinterpret_cast<const MojoHandle*>(handles_.data()),
    139       handles_.size(),
    140       MOJO_ALLOC_MESSAGE_FLAG_NONE,
    141       &new_message);
    142   CHECK_EQ(rv, MOJO_RESULT_OK);
    143   handles_.clear();
    144 
    145   void* new_buffer = nullptr;
    146   rv = GetMessageBuffer(new_message.get(), &new_buffer);
    147   CHECK_EQ(rv, MOJO_RESULT_OK);
    148 
    149   memcpy(new_buffer, data(), data_num_bytes());
    150   buffer_.reset();
    151 
    152   return new_message;
    153 }
    154 
    155 void Message::NotifyBadMessage(const std::string& error) {
    156   DCHECK(buffer_);
    157   buffer_->NotifyBadMessage(error);
    158 }
    159 
    160 void Message::CloseHandles() {
    161   for (std::vector<Handle>::iterator it = handles_.begin();
    162        it != handles_.end(); ++it) {
    163     if (it->is_valid())
    164       CloseRaw(*it);
    165   }
    166 }
    167 
    168 void Message::SerializeAssociatedEndpointHandles(
    169     AssociatedGroupController* group_controller) {
    170   if (associated_endpoint_handles_.empty())
    171     return;
    172 
    173   DCHECK_GE(version(), 2u);
    174   DCHECK(header_v2()->payload_interface_ids.is_null());
    175 
    176   size_t size = associated_endpoint_handles_.size();
    177   auto* data = internal::Array_Data<uint32_t>::New(size, buffer());
    178   header_v2()->payload_interface_ids.Set(data);
    179 
    180   for (size_t i = 0; i < size; ++i) {
    181     ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i];
    182 
    183     DCHECK(handle.pending_association());
    184     data->storage()[i] =
    185         group_controller->AssociateInterface(std::move(handle));
    186   }
    187   associated_endpoint_handles_.clear();
    188 }
    189 
    190 bool Message::DeserializeAssociatedEndpointHandles(
    191     AssociatedGroupController* group_controller) {
    192   associated_endpoint_handles_.clear();
    193 
    194   uint32_t num_ids = payload_num_interface_ids();
    195   if (num_ids == 0)
    196     return true;
    197 
    198   associated_endpoint_handles_.reserve(num_ids);
    199   uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage();
    200   bool result = true;
    201   for (uint32_t i = 0; i < num_ids; ++i) {
    202     auto handle = group_controller->CreateLocalEndpointHandle(ids[i]);
    203     if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) {
    204       // |ids[i]| itself is valid but handle creation failed. In that case, mark
    205       // deserialization as failed but continue to deserialize the rest of
    206       // handles.
    207       result = false;
    208     }
    209 
    210     associated_endpoint_handles_.push_back(std::move(handle));
    211     ids[i] = kInvalidInterfaceId;
    212   }
    213   return result;
    214 }
    215 
    216 PassThroughFilter::PassThroughFilter() {}
    217 
    218 PassThroughFilter::~PassThroughFilter() {}
    219 
    220 bool PassThroughFilter::Accept(Message* message) { return true; }
    221 
    222 SyncMessageResponseContext::SyncMessageResponseContext()
    223     : outer_context_(current()) {
    224   g_tls_sync_response_context.Get().Set(this);
    225 }
    226 
    227 SyncMessageResponseContext::~SyncMessageResponseContext() {
    228   DCHECK_EQ(current(), this);
    229   g_tls_sync_response_context.Get().Set(outer_context_);
    230 }
    231 
    232 // static
    233 SyncMessageResponseContext* SyncMessageResponseContext::current() {
    234   return g_tls_sync_response_context.Get().Get();
    235 }
    236 
    237 void SyncMessageResponseContext::ReportBadMessage(const std::string& error) {
    238   GetBadMessageCallback().Run(error);
    239 }
    240 
    241 const ReportBadMessageCallback&
    242 SyncMessageResponseContext::GetBadMessageCallback() {
    243   if (bad_message_callback_.is_null()) {
    244     bad_message_callback_ =
    245         base::Bind(&DoNotifyBadMessage, base::Passed(&response_));
    246   }
    247   return bad_message_callback_;
    248 }
    249 
    250 MojoResult ReadMessage(MessagePipeHandle handle, Message* message) {
    251   MojoResult rv;
    252 
    253   std::vector<Handle> handles;
    254   ScopedMessageHandle mojo_message;
    255   uint32_t num_bytes = 0, num_handles = 0;
    256   rv = ReadMessageNew(handle,
    257                       &mojo_message,
    258                       &num_bytes,
    259                       nullptr,
    260                       &num_handles,
    261                       MOJO_READ_MESSAGE_FLAG_NONE);
    262   if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) {
    263     DCHECK_GT(num_handles, 0u);
    264     handles.resize(num_handles);
    265     rv = ReadMessageNew(handle,
    266                         &mojo_message,
    267                         &num_bytes,
    268                         reinterpret_cast<MojoHandle*>(handles.data()),
    269                         &num_handles,
    270                         MOJO_READ_MESSAGE_FLAG_NONE);
    271   }
    272 
    273   if (rv != MOJO_RESULT_OK)
    274     return rv;
    275 
    276   message->InitializeFromMojoMessage(
    277       std::move(mojo_message), num_bytes, &handles);
    278   return MOJO_RESULT_OK;
    279 }
    280 
    281 void ReportBadMessage(const std::string& error) {
    282   internal::MessageDispatchContext* context =
    283       internal::MessageDispatchContext::current();
    284   DCHECK(context);
    285   context->GetBadMessageCallback().Run(error);
    286 }
    287 
    288 ReportBadMessageCallback GetBadMessageCallback() {
    289   internal::MessageDispatchContext* context =
    290       internal::MessageDispatchContext::current();
    291   DCHECK(context);
    292   return context->GetBadMessageCallback();
    293 }
    294 
    295 namespace internal {
    296 
    297 MessageHeaderV2::MessageHeaderV2() = default;
    298 
    299 MessageDispatchContext::MessageDispatchContext(Message* message)
    300     : outer_context_(current()), message_(message) {
    301   g_tls_message_dispatch_context.Get().Set(this);
    302 }
    303 
    304 MessageDispatchContext::~MessageDispatchContext() {
    305   DCHECK_EQ(current(), this);
    306   g_tls_message_dispatch_context.Get().Set(outer_context_);
    307 }
    308 
    309 // static
    310 MessageDispatchContext* MessageDispatchContext::current() {
    311   return g_tls_message_dispatch_context.Get().Get();
    312 }
    313 
    314 const ReportBadMessageCallback&
    315 MessageDispatchContext::GetBadMessageCallback() {
    316   if (bad_message_callback_.is_null()) {
    317     bad_message_callback_ =
    318         base::Bind(&DoNotifyBadMessage, base::Passed(message_));
    319   }
    320   return bad_message_callback_;
    321 }
    322 
    323 // static
    324 void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) {
    325   SyncMessageResponseContext* context = SyncMessageResponseContext::current();
    326   if (context)
    327     context->response_ = std::move(*message);
    328 }
    329 
    330 }  // namespace internal
    331 
    332 }  // namespace mojo
    333