Home | History | Annotate | Download | only in lib
      1 // Copyright 2013 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "mojo/public/cpp/bindings/message.h"
      6 
      7 #include <stddef.h>
      8 #include <stdint.h>
      9 #include <stdlib.h>
     10 
     11 #include <algorithm>
     12 #include <utility>
     13 
     14 #include "base/bind.h"
     15 #include "base/lazy_instance.h"
     16 #include "base/logging.h"
     17 #include "base/numerics/safe_math.h"
     18 #include "base/strings/stringprintf.h"
     19 #include "base/threading/thread_local.h"
     20 #include "mojo/public/cpp/bindings/associated_group_controller.h"
     21 #include "mojo/public/cpp/bindings/lib/array_internal.h"
     22 #include "mojo/public/cpp/bindings/lib/unserialized_message_context.h"
     23 
     24 namespace mojo {
     25 
     26 namespace {
     27 
     28 base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>::
     29     Leaky g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER;
     30 
     31 base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::Leaky
     32     g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER;
     33 
     34 void DoNotifyBadMessage(Message message, const std::string& error) {
     35   message.NotifyBadMessage(error);
     36 }
     37 
     38 template <typename HeaderType>
     39 void AllocateHeaderFromBuffer(internal::Buffer* buffer, HeaderType** header) {
     40   *header = buffer->AllocateAndGet<HeaderType>();
     41   (*header)->num_bytes = sizeof(HeaderType);
     42 }
     43 
     44 void WriteMessageHeader(uint32_t name,
     45                         uint32_t flags,
     46                         size_t payload_interface_id_count,
     47                         internal::Buffer* payload_buffer) {
     48   if (payload_interface_id_count > 0) {
     49     // Version 2
     50     internal::MessageHeaderV2* header;
     51     AllocateHeaderFromBuffer(payload_buffer, &header);
     52     header->version = 2;
     53     header->name = name;
     54     header->flags = flags;
     55     // The payload immediately follows the header.
     56     header->payload.Set(header + 1);
     57   } else if (flags &
     58              (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) {
     59     // Version 1
     60     internal::MessageHeaderV1* header;
     61     AllocateHeaderFromBuffer(payload_buffer, &header);
     62     header->version = 1;
     63     header->name = name;
     64     header->flags = flags;
     65   } else {
     66     internal::MessageHeader* header;
     67     AllocateHeaderFromBuffer(payload_buffer, &header);
     68     header->version = 0;
     69     header->name = name;
     70     header->flags = flags;
     71   }
     72 }
     73 
     74 void CreateSerializedMessageObject(uint32_t name,
     75                                    uint32_t flags,
     76                                    size_t payload_size,
     77                                    size_t payload_interface_id_count,
     78                                    std::vector<ScopedHandle>* handles,
     79                                    ScopedMessageHandle* out_handle,
     80                                    internal::Buffer* out_buffer) {
     81   ScopedMessageHandle handle;
     82   MojoResult rv = mojo::CreateMessage(&handle);
     83   DCHECK_EQ(MOJO_RESULT_OK, rv);
     84   DCHECK(handle.is_valid());
     85 
     86   void* buffer;
     87   uint32_t buffer_size;
     88   size_t total_size = internal::ComputeSerializedMessageSize(
     89       flags, payload_size, payload_interface_id_count);
     90   DCHECK(base::IsValueInRangeForNumericType<uint32_t>(total_size));
     91   DCHECK(!handles ||
     92          base::IsValueInRangeForNumericType<uint32_t>(handles->size()));
     93   rv = MojoAppendMessageData(
     94       handle->value(), static_cast<uint32_t>(total_size),
     95       handles ? reinterpret_cast<MojoHandle*>(handles->data()) : nullptr,
     96       handles ? static_cast<uint32_t>(handles->size()) : 0, nullptr, &buffer,
     97       &buffer_size);
     98   DCHECK_EQ(MOJO_RESULT_OK, rv);
     99   if (handles) {
    100     // Handle ownership has been taken by MojoAppendMessageData.
    101     for (size_t i = 0; i < handles->size(); ++i)
    102       ignore_result(handles->at(i).release());
    103   }
    104 
    105   internal::Buffer payload_buffer(handle.get(), total_size, buffer,
    106                                   buffer_size);
    107 
    108   // Make sure we zero the memory first!
    109   memset(payload_buffer.data(), 0, total_size);
    110   WriteMessageHeader(name, flags, payload_interface_id_count, &payload_buffer);
    111 
    112   *out_handle = std::move(handle);
    113   *out_buffer = std::move(payload_buffer);
    114 }
    115 
    116 void SerializeUnserializedContext(MojoMessageHandle message,
    117                                   uintptr_t context_value) {
    118   auto* context =
    119       reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
    120   void* buffer;
    121   uint32_t buffer_size;
    122   MojoResult attach_result = MojoAppendMessageData(
    123       message, 0, nullptr, 0, nullptr, &buffer, &buffer_size);
    124   if (attach_result != MOJO_RESULT_OK)
    125     return;
    126 
    127   internal::Buffer payload_buffer(MessageHandle(message), 0, buffer,
    128                                   buffer_size);
    129   WriteMessageHeader(context->message_name(), context->message_flags(),
    130                      0 /* payload_interface_id_count */, &payload_buffer);
    131 
    132   // We need to copy additional header data which may have been set after
    133   // message construction, as this codepath may be reached at some arbitrary
    134   // time between message send and message dispatch.
    135   static_cast<internal::MessageHeader*>(buffer)->interface_id =
    136       context->header()->interface_id;
    137   if (context->header()->flags &
    138       (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) {
    139     DCHECK_GE(context->header()->version, 1u);
    140     static_cast<internal::MessageHeaderV1*>(buffer)->request_id =
    141         context->header()->request_id;
    142   }
    143 
    144   internal::SerializationContext serialization_context;
    145   context->Serialize(&serialization_context, &payload_buffer);
    146 
    147   // TODO(crbug.com/753433): Support lazy serialization of associated endpoint
    148   // handles. See corresponding TODO in the bindings generator for proof that
    149   // this DCHECK is indeed valid.
    150   DCHECK(serialization_context.associated_endpoint_handles()->empty());
    151   if (!serialization_context.handles()->empty())
    152     payload_buffer.AttachHandles(serialization_context.mutable_handles());
    153   payload_buffer.Seal();
    154 }
    155 
    156 void DestroyUnserializedContext(uintptr_t context) {
    157   delete reinterpret_cast<internal::UnserializedMessageContext*>(context);
    158 }
    159 
    160 ScopedMessageHandle CreateUnserializedMessageObject(
    161     std::unique_ptr<internal::UnserializedMessageContext> context) {
    162   ScopedMessageHandle handle;
    163   MojoResult rv = mojo::CreateMessage(&handle);
    164   DCHECK_EQ(MOJO_RESULT_OK, rv);
    165   DCHECK(handle.is_valid());
    166 
    167   rv = MojoSetMessageContext(
    168       handle->value(), reinterpret_cast<uintptr_t>(context.release()),
    169       &SerializeUnserializedContext, &DestroyUnserializedContext, nullptr);
    170   DCHECK_EQ(MOJO_RESULT_OK, rv);
    171   return handle;
    172 }
    173 
    174 }  // namespace
    175 
    176 Message::Message() = default;
    177 
    178 Message::Message(Message&& other)
    179     : handle_(std::move(other.handle_)),
    180       payload_buffer_(std::move(other.payload_buffer_)),
    181       handles_(std::move(other.handles_)),
    182       associated_endpoint_handles_(
    183           std::move(other.associated_endpoint_handles_)),
    184       transferable_(other.transferable_),
    185       serialized_(other.serialized_) {
    186   other.transferable_ = false;
    187   other.serialized_ = false;
    188 #if defined(ENABLE_IPC_FUZZER)
    189   interface_name_ = other.interface_name_;
    190   method_name_ = other.method_name_;
    191 #endif
    192 }
    193 
    194 Message::Message(std::unique_ptr<internal::UnserializedMessageContext> context)
    195     : Message(CreateUnserializedMessageObject(std::move(context))) {}
    196 
    197 Message::Message(uint32_t name,
    198                  uint32_t flags,
    199                  size_t payload_size,
    200                  size_t payload_interface_id_count,
    201                  std::vector<ScopedHandle>* handles) {
    202   CreateSerializedMessageObject(name, flags, payload_size,
    203                                 payload_interface_id_count, handles, &handle_,
    204                                 &payload_buffer_);
    205   transferable_ = true;
    206   serialized_ = true;
    207 }
    208 
    209 Message::Message(ScopedMessageHandle handle) {
    210   DCHECK(handle.is_valid());
    211 
    212   uintptr_t context_value = 0;
    213   MojoResult get_context_result =
    214       MojoGetMessageContext(handle->value(), nullptr, &context_value);
    215   if (get_context_result == MOJO_RESULT_NOT_FOUND) {
    216     // It's a serialized message. Extract handles if possible.
    217     uint32_t num_bytes;
    218     void* buffer;
    219     uint32_t num_handles = 0;
    220     MojoResult rv = MojoGetMessageData(handle->value(), nullptr, &buffer,
    221                                        &num_bytes, nullptr, &num_handles);
    222     if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) {
    223       handles_.resize(num_handles);
    224       rv = MojoGetMessageData(handle->value(), nullptr, &buffer, &num_bytes,
    225                               reinterpret_cast<MojoHandle*>(handles_.data()),
    226                               &num_handles);
    227     } else {
    228       // No handles, so it's safe to retransmit this message if the caller
    229       // really wants to.
    230       transferable_ = true;
    231     }
    232 
    233     if (rv != MOJO_RESULT_OK) {
    234       // Failed to deserialize handles. Leave the Message uninitialized.
    235       return;
    236     }
    237 
    238     payload_buffer_ = internal::Buffer(buffer, num_bytes, num_bytes);
    239     serialized_ = true;
    240   } else {
    241     DCHECK_EQ(MOJO_RESULT_OK, get_context_result);
    242     auto* context =
    243         reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
    244     // Dummy data address so common header accessors still behave properly. The
    245     // choice is V1 reflects unserialized message capabilities: we may or may
    246     // not need to support request IDs (which require at least V1), but we never
    247     // (for now, anyway) need to support associated interface handles (V2).
    248     payload_buffer_ =
    249         internal::Buffer(context->header(), sizeof(internal::MessageHeaderV1),
    250                          sizeof(internal::MessageHeaderV1));
    251     transferable_ = true;
    252     serialized_ = false;
    253   }
    254 
    255   handle_ = std::move(handle);
    256 }
    257 
    258 Message::~Message() = default;
    259 
    260 Message& Message::operator=(Message&& other) {
    261   handle_ = std::move(other.handle_);
    262   payload_buffer_ = std::move(other.payload_buffer_);
    263   handles_ = std::move(other.handles_);
    264   associated_endpoint_handles_ = std::move(other.associated_endpoint_handles_);
    265   transferable_ = other.transferable_;
    266   other.transferable_ = false;
    267   serialized_ = other.serialized_;
    268   other.serialized_ = false;
    269 #if defined(ENABLE_IPC_FUZZER)
    270   interface_name_ = other.interface_name_;
    271   method_name_ = other.method_name_;
    272 #endif
    273   return *this;
    274 }
    275 
    276 void Message::Reset() {
    277   handle_.reset();
    278   payload_buffer_.Reset();
    279   handles_.clear();
    280   associated_endpoint_handles_.clear();
    281   transferable_ = false;
    282   serialized_ = false;
    283 }
    284 
    285 const uint8_t* Message::payload() const {
    286   if (version() < 2)
    287     return data() + header()->num_bytes;
    288 
    289   DCHECK(!header_v2()->payload.is_null());
    290   return static_cast<const uint8_t*>(header_v2()->payload.Get());
    291 }
    292 
    293 uint32_t Message::payload_num_bytes() const {
    294   DCHECK_GE(data_num_bytes(), header()->num_bytes);
    295   size_t num_bytes;
    296   if (version() < 2) {
    297     num_bytes = data_num_bytes() - header()->num_bytes;
    298   } else {
    299     auto payload_begin =
    300         reinterpret_cast<uintptr_t>(header_v2()->payload.Get());
    301     auto payload_end =
    302         reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get());
    303     if (!payload_end)
    304       payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes());
    305     DCHECK_GE(payload_end, payload_begin);
    306     num_bytes = payload_end - payload_begin;
    307   }
    308   DCHECK(base::IsValueInRangeForNumericType<uint32_t>(num_bytes));
    309   return static_cast<uint32_t>(num_bytes);
    310 }
    311 
    312 uint32_t Message::payload_num_interface_ids() const {
    313   auto* array_pointer =
    314       version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
    315   return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0;
    316 }
    317 
    318 const uint32_t* Message::payload_interface_ids() const {
    319   auto* array_pointer =
    320       version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
    321   return array_pointer ? array_pointer->storage() : nullptr;
    322 }
    323 
    324 void Message::AttachHandlesFromSerializationContext(
    325     internal::SerializationContext* context) {
    326   if (context->handles()->empty() &&
    327       context->associated_endpoint_handles()->empty()) {
    328     // No handles attached, so no extra serialization work.
    329     return;
    330   }
    331 
    332   if (context->associated_endpoint_handles()->empty()) {
    333     // Attaching only non-associated handles is easier since we don't have to
    334     // modify the message header. Faster path for that.
    335     payload_buffer_.AttachHandles(context->mutable_handles());
    336     return;
    337   }
    338 
    339   // Allocate a new message with enough space to hold all attached handles. Copy
    340   // this message's contents into the new one and use it to replace ourself.
    341   //
    342   // TODO(rockot): We could avoid the extra full message allocation by instead
    343   // growing the buffer and carefully moving its contents around. This errs on
    344   // the side of less complexity with probably only marginal performance cost.
    345   uint32_t payload_size = payload_num_bytes();
    346   mojo::Message new_message(name(), header()->flags, payload_size,
    347                             context->associated_endpoint_handles()->size(),
    348                             context->mutable_handles());
    349   std::swap(*context->mutable_associated_endpoint_handles(),
    350             new_message.associated_endpoint_handles_);
    351   memcpy(new_message.payload_buffer()->AllocateAndGet(payload_size), payload(),
    352          payload_size);
    353   *this = std::move(new_message);
    354 }
    355 
    356 ScopedMessageHandle Message::TakeMojoMessage() {
    357   // If there are associated endpoints transferred,
    358   // SerializeAssociatedEndpointHandles() must be called before this method.
    359   DCHECK(associated_endpoint_handles_.empty());
    360   DCHECK(transferable_);
    361   payload_buffer_.Seal();
    362   auto handle = std::move(handle_);
    363   Reset();
    364   return handle;
    365 }
    366 
    367 void Message::NotifyBadMessage(const std::string& error) {
    368   DCHECK(handle_.is_valid());
    369   mojo::NotifyBadMessage(handle_.get(), error);
    370 }
    371 
    372 void Message::SerializeAssociatedEndpointHandles(
    373     AssociatedGroupController* group_controller) {
    374   if (associated_endpoint_handles_.empty())
    375     return;
    376 
    377   DCHECK_GE(version(), 2u);
    378   DCHECK(header_v2()->payload_interface_ids.is_null());
    379   DCHECK(payload_buffer_.is_valid());
    380   DCHECK(handle_.is_valid());
    381 
    382   size_t size = associated_endpoint_handles_.size();
    383 
    384   internal::Array_Data<uint32_t>::BufferWriter handle_writer;
    385   handle_writer.Allocate(size, &payload_buffer_);
    386   header_v2()->payload_interface_ids.Set(handle_writer.data());
    387 
    388   for (size_t i = 0; i < size; ++i) {
    389     ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i];
    390 
    391     DCHECK(handle.pending_association());
    392     handle_writer->storage()[i] =
    393         group_controller->AssociateInterface(std::move(handle));
    394   }
    395   associated_endpoint_handles_.clear();
    396 }
    397 
    398 bool Message::DeserializeAssociatedEndpointHandles(
    399     AssociatedGroupController* group_controller) {
    400   if (!serialized_)
    401     return true;
    402 
    403   associated_endpoint_handles_.clear();
    404 
    405   uint32_t num_ids = payload_num_interface_ids();
    406   if (num_ids == 0)
    407     return true;
    408 
    409   associated_endpoint_handles_.reserve(num_ids);
    410   uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage();
    411   bool result = true;
    412   for (uint32_t i = 0; i < num_ids; ++i) {
    413     auto handle = group_controller->CreateLocalEndpointHandle(ids[i]);
    414     if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) {
    415       // |ids[i]| itself is valid but handle creation failed. In that case, mark
    416       // deserialization as failed but continue to deserialize the rest of
    417       // handles.
    418       result = false;
    419     }
    420 
    421     associated_endpoint_handles_.push_back(std::move(handle));
    422     ids[i] = kInvalidInterfaceId;
    423   }
    424   return result;
    425 }
    426 
    427 void Message::SerializeIfNecessary() {
    428   MojoResult rv = MojoSerializeMessage(handle_->value(), nullptr);
    429   if (rv == MOJO_RESULT_FAILED_PRECONDITION)
    430     return;
    431 
    432   // Reconstruct this Message instance from the serialized message's handle.
    433   *this = Message(std::move(handle_));
    434 }
    435 
    436 std::unique_ptr<internal::UnserializedMessageContext>
    437 Message::TakeUnserializedContext(
    438     const internal::UnserializedMessageContext::Tag* tag) {
    439   DCHECK(handle_.is_valid());
    440   uintptr_t context_value = 0;
    441   MojoResult rv =
    442       MojoGetMessageContext(handle_->value(), nullptr, &context_value);
    443   if (rv == MOJO_RESULT_NOT_FOUND)
    444     return nullptr;
    445   DCHECK_EQ(MOJO_RESULT_OK, rv);
    446 
    447   auto* context =
    448       reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
    449   if (context->tag() != tag)
    450     return nullptr;
    451 
    452   // Detach the context from the message.
    453   rv = MojoSetMessageContext(handle_->value(), 0, nullptr, nullptr, nullptr);
    454   DCHECK_EQ(MOJO_RESULT_OK, rv);
    455   return base::WrapUnique(context);
    456 }
    457 
    458 bool MessageReceiver::PrefersSerializedMessages() {
    459   return false;
    460 }
    461 
    462 PassThroughFilter::PassThroughFilter() {}
    463 
    464 PassThroughFilter::~PassThroughFilter() {}
    465 
    466 bool PassThroughFilter::Accept(Message* message) {
    467   return true;
    468 }
    469 
    470 SyncMessageResponseContext::SyncMessageResponseContext()
    471     : outer_context_(current()) {
    472   g_tls_sync_response_context.Get().Set(this);
    473 }
    474 
    475 SyncMessageResponseContext::~SyncMessageResponseContext() {
    476   DCHECK_EQ(current(), this);
    477   g_tls_sync_response_context.Get().Set(outer_context_);
    478 }
    479 
    480 // static
    481 SyncMessageResponseContext* SyncMessageResponseContext::current() {
    482   return g_tls_sync_response_context.Get().Get();
    483 }
    484 
    485 void SyncMessageResponseContext::ReportBadMessage(const std::string& error) {
    486   GetBadMessageCallback().Run(error);
    487 }
    488 
    489 ReportBadMessageCallback SyncMessageResponseContext::GetBadMessageCallback() {
    490   DCHECK(!response_.IsNull());
    491   return base::BindOnce(&DoNotifyBadMessage, std::move(response_));
    492 }
    493 
    494 MojoResult ReadMessage(MessagePipeHandle handle, Message* message) {
    495   ScopedMessageHandle message_handle;
    496   MojoResult rv =
    497       ReadMessageNew(handle, &message_handle, MOJO_READ_MESSAGE_FLAG_NONE);
    498   if (rv != MOJO_RESULT_OK)
    499     return rv;
    500 
    501   *message = Message(std::move(message_handle));
    502   return MOJO_RESULT_OK;
    503 }
    504 
    505 void ReportBadMessage(const std::string& error) {
    506   internal::MessageDispatchContext* context =
    507       internal::MessageDispatchContext::current();
    508   DCHECK(context);
    509   context->GetBadMessageCallback().Run(error);
    510 }
    511 
    512 ReportBadMessageCallback GetBadMessageCallback() {
    513   internal::MessageDispatchContext* context =
    514       internal::MessageDispatchContext::current();
    515   DCHECK(context);
    516   return context->GetBadMessageCallback();
    517 }
    518 
    519 namespace internal {
    520 
    521 MessageHeaderV2::MessageHeaderV2() = default;
    522 
    523 MessageDispatchContext::MessageDispatchContext(Message* message)
    524     : outer_context_(current()), message_(message) {
    525   g_tls_message_dispatch_context.Get().Set(this);
    526 }
    527 
    528 MessageDispatchContext::~MessageDispatchContext() {
    529   DCHECK_EQ(current(), this);
    530   g_tls_message_dispatch_context.Get().Set(outer_context_);
    531 }
    532 
    533 // static
    534 MessageDispatchContext* MessageDispatchContext::current() {
    535   return g_tls_message_dispatch_context.Get().Get();
    536 }
    537 
    538 ReportBadMessageCallback MessageDispatchContext::GetBadMessageCallback() {
    539   DCHECK(!message_->IsNull());
    540   return base::BindOnce(&DoNotifyBadMessage, std::move(*message_));
    541 }
    542 
    543 // static
    544 void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) {
    545   SyncMessageResponseContext* context = SyncMessageResponseContext::current();
    546   if (context)
    547     context->response_ = std::move(*message);
    548 }
    549 
    550 }  // namespace internal
    551 
    552 }  // namespace mojo
    553