1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "mojo/public/cpp/bindings/message.h" 6 7 #include <stddef.h> 8 #include <stdint.h> 9 #include <stdlib.h> 10 11 #include <algorithm> 12 #include <utility> 13 14 #include "base/bind.h" 15 #include "base/lazy_instance.h" 16 #include "base/logging.h" 17 #include "base/numerics/safe_math.h" 18 #include "base/strings/stringprintf.h" 19 #include "base/threading/thread_local.h" 20 #include "mojo/public/cpp/bindings/associated_group_controller.h" 21 #include "mojo/public/cpp/bindings/lib/array_internal.h" 22 #include "mojo/public/cpp/bindings/lib/unserialized_message_context.h" 23 24 namespace mojo { 25 26 namespace { 27 28 base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>:: 29 Leaky g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER; 30 31 base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::Leaky 32 g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER; 33 34 void DoNotifyBadMessage(Message message, const std::string& error) { 35 message.NotifyBadMessage(error); 36 } 37 38 template <typename HeaderType> 39 void AllocateHeaderFromBuffer(internal::Buffer* buffer, HeaderType** header) { 40 *header = buffer->AllocateAndGet<HeaderType>(); 41 (*header)->num_bytes = sizeof(HeaderType); 42 } 43 44 void WriteMessageHeader(uint32_t name, 45 uint32_t flags, 46 size_t payload_interface_id_count, 47 internal::Buffer* payload_buffer) { 48 if (payload_interface_id_count > 0) { 49 // Version 2 50 internal::MessageHeaderV2* header; 51 AllocateHeaderFromBuffer(payload_buffer, &header); 52 header->version = 2; 53 header->name = name; 54 header->flags = flags; 55 // The payload immediately follows the header. 56 header->payload.Set(header + 1); 57 } else if (flags & 58 (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { 59 // Version 1 60 internal::MessageHeaderV1* header; 61 AllocateHeaderFromBuffer(payload_buffer, &header); 62 header->version = 1; 63 header->name = name; 64 header->flags = flags; 65 } else { 66 internal::MessageHeader* header; 67 AllocateHeaderFromBuffer(payload_buffer, &header); 68 header->version = 0; 69 header->name = name; 70 header->flags = flags; 71 } 72 } 73 74 void CreateSerializedMessageObject(uint32_t name, 75 uint32_t flags, 76 size_t payload_size, 77 size_t payload_interface_id_count, 78 std::vector<ScopedHandle>* handles, 79 ScopedMessageHandle* out_handle, 80 internal::Buffer* out_buffer) { 81 ScopedMessageHandle handle; 82 MojoResult rv = mojo::CreateMessage(&handle); 83 DCHECK_EQ(MOJO_RESULT_OK, rv); 84 DCHECK(handle.is_valid()); 85 86 void* buffer; 87 uint32_t buffer_size; 88 size_t total_size = internal::ComputeSerializedMessageSize( 89 flags, payload_size, payload_interface_id_count); 90 DCHECK(base::IsValueInRangeForNumericType<uint32_t>(total_size)); 91 DCHECK(!handles || 92 base::IsValueInRangeForNumericType<uint32_t>(handles->size())); 93 rv = MojoAppendMessageData( 94 handle->value(), static_cast<uint32_t>(total_size), 95 handles ? reinterpret_cast<MojoHandle*>(handles->data()) : nullptr, 96 handles ? static_cast<uint32_t>(handles->size()) : 0, nullptr, &buffer, 97 &buffer_size); 98 DCHECK_EQ(MOJO_RESULT_OK, rv); 99 if (handles) { 100 // Handle ownership has been taken by MojoAppendMessageData. 101 for (size_t i = 0; i < handles->size(); ++i) 102 ignore_result(handles->at(i).release()); 103 } 104 105 internal::Buffer payload_buffer(handle.get(), total_size, buffer, 106 buffer_size); 107 108 // Make sure we zero the memory first! 109 memset(payload_buffer.data(), 0, total_size); 110 WriteMessageHeader(name, flags, payload_interface_id_count, &payload_buffer); 111 112 *out_handle = std::move(handle); 113 *out_buffer = std::move(payload_buffer); 114 } 115 116 void SerializeUnserializedContext(MojoMessageHandle message, 117 uintptr_t context_value) { 118 auto* context = 119 reinterpret_cast<internal::UnserializedMessageContext*>(context_value); 120 void* buffer; 121 uint32_t buffer_size; 122 MojoResult attach_result = MojoAppendMessageData( 123 message, 0, nullptr, 0, nullptr, &buffer, &buffer_size); 124 if (attach_result != MOJO_RESULT_OK) 125 return; 126 127 internal::Buffer payload_buffer(MessageHandle(message), 0, buffer, 128 buffer_size); 129 WriteMessageHeader(context->message_name(), context->message_flags(), 130 0 /* payload_interface_id_count */, &payload_buffer); 131 132 // We need to copy additional header data which may have been set after 133 // message construction, as this codepath may be reached at some arbitrary 134 // time between message send and message dispatch. 135 static_cast<internal::MessageHeader*>(buffer)->interface_id = 136 context->header()->interface_id; 137 if (context->header()->flags & 138 (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) { 139 DCHECK_GE(context->header()->version, 1u); 140 static_cast<internal::MessageHeaderV1*>(buffer)->request_id = 141 context->header()->request_id; 142 } 143 144 internal::SerializationContext serialization_context; 145 context->Serialize(&serialization_context, &payload_buffer); 146 147 // TODO(crbug.com/753433): Support lazy serialization of associated endpoint 148 // handles. See corresponding TODO in the bindings generator for proof that 149 // this DCHECK is indeed valid. 150 DCHECK(serialization_context.associated_endpoint_handles()->empty()); 151 if (!serialization_context.handles()->empty()) 152 payload_buffer.AttachHandles(serialization_context.mutable_handles()); 153 payload_buffer.Seal(); 154 } 155 156 void DestroyUnserializedContext(uintptr_t context) { 157 delete reinterpret_cast<internal::UnserializedMessageContext*>(context); 158 } 159 160 ScopedMessageHandle CreateUnserializedMessageObject( 161 std::unique_ptr<internal::UnserializedMessageContext> context) { 162 ScopedMessageHandle handle; 163 MojoResult rv = mojo::CreateMessage(&handle); 164 DCHECK_EQ(MOJO_RESULT_OK, rv); 165 DCHECK(handle.is_valid()); 166 167 rv = MojoSetMessageContext( 168 handle->value(), reinterpret_cast<uintptr_t>(context.release()), 169 &SerializeUnserializedContext, &DestroyUnserializedContext, nullptr); 170 DCHECK_EQ(MOJO_RESULT_OK, rv); 171 return handle; 172 } 173 174 } // namespace 175 176 Message::Message() = default; 177 178 Message::Message(Message&& other) 179 : handle_(std::move(other.handle_)), 180 payload_buffer_(std::move(other.payload_buffer_)), 181 handles_(std::move(other.handles_)), 182 associated_endpoint_handles_( 183 std::move(other.associated_endpoint_handles_)), 184 transferable_(other.transferable_), 185 serialized_(other.serialized_) { 186 other.transferable_ = false; 187 other.serialized_ = false; 188 #if defined(ENABLE_IPC_FUZZER) 189 interface_name_ = other.interface_name_; 190 method_name_ = other.method_name_; 191 #endif 192 } 193 194 Message::Message(std::unique_ptr<internal::UnserializedMessageContext> context) 195 : Message(CreateUnserializedMessageObject(std::move(context))) {} 196 197 Message::Message(uint32_t name, 198 uint32_t flags, 199 size_t payload_size, 200 size_t payload_interface_id_count, 201 std::vector<ScopedHandle>* handles) { 202 CreateSerializedMessageObject(name, flags, payload_size, 203 payload_interface_id_count, handles, &handle_, 204 &payload_buffer_); 205 transferable_ = true; 206 serialized_ = true; 207 } 208 209 Message::Message(ScopedMessageHandle handle) { 210 DCHECK(handle.is_valid()); 211 212 uintptr_t context_value = 0; 213 MojoResult get_context_result = 214 MojoGetMessageContext(handle->value(), nullptr, &context_value); 215 if (get_context_result == MOJO_RESULT_NOT_FOUND) { 216 // It's a serialized message. Extract handles if possible. 217 uint32_t num_bytes; 218 void* buffer; 219 uint32_t num_handles = 0; 220 MojoResult rv = MojoGetMessageData(handle->value(), nullptr, &buffer, 221 &num_bytes, nullptr, &num_handles); 222 if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) { 223 handles_.resize(num_handles); 224 rv = MojoGetMessageData(handle->value(), nullptr, &buffer, &num_bytes, 225 reinterpret_cast<MojoHandle*>(handles_.data()), 226 &num_handles); 227 } else { 228 // No handles, so it's safe to retransmit this message if the caller 229 // really wants to. 230 transferable_ = true; 231 } 232 233 if (rv != MOJO_RESULT_OK) { 234 // Failed to deserialize handles. Leave the Message uninitialized. 235 return; 236 } 237 238 payload_buffer_ = internal::Buffer(buffer, num_bytes, num_bytes); 239 serialized_ = true; 240 } else { 241 DCHECK_EQ(MOJO_RESULT_OK, get_context_result); 242 auto* context = 243 reinterpret_cast<internal::UnserializedMessageContext*>(context_value); 244 // Dummy data address so common header accessors still behave properly. The 245 // choice is V1 reflects unserialized message capabilities: we may or may 246 // not need to support request IDs (which require at least V1), but we never 247 // (for now, anyway) need to support associated interface handles (V2). 248 payload_buffer_ = 249 internal::Buffer(context->header(), sizeof(internal::MessageHeaderV1), 250 sizeof(internal::MessageHeaderV1)); 251 transferable_ = true; 252 serialized_ = false; 253 } 254 255 handle_ = std::move(handle); 256 } 257 258 Message::~Message() = default; 259 260 Message& Message::operator=(Message&& other) { 261 handle_ = std::move(other.handle_); 262 payload_buffer_ = std::move(other.payload_buffer_); 263 handles_ = std::move(other.handles_); 264 associated_endpoint_handles_ = std::move(other.associated_endpoint_handles_); 265 transferable_ = other.transferable_; 266 other.transferable_ = false; 267 serialized_ = other.serialized_; 268 other.serialized_ = false; 269 #if defined(ENABLE_IPC_FUZZER) 270 interface_name_ = other.interface_name_; 271 method_name_ = other.method_name_; 272 #endif 273 return *this; 274 } 275 276 void Message::Reset() { 277 handle_.reset(); 278 payload_buffer_.Reset(); 279 handles_.clear(); 280 associated_endpoint_handles_.clear(); 281 transferable_ = false; 282 serialized_ = false; 283 } 284 285 const uint8_t* Message::payload() const { 286 if (version() < 2) 287 return data() + header()->num_bytes; 288 289 DCHECK(!header_v2()->payload.is_null()); 290 return static_cast<const uint8_t*>(header_v2()->payload.Get()); 291 } 292 293 uint32_t Message::payload_num_bytes() const { 294 DCHECK_GE(data_num_bytes(), header()->num_bytes); 295 size_t num_bytes; 296 if (version() < 2) { 297 num_bytes = data_num_bytes() - header()->num_bytes; 298 } else { 299 auto payload_begin = 300 reinterpret_cast<uintptr_t>(header_v2()->payload.Get()); 301 auto payload_end = 302 reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get()); 303 if (!payload_end) 304 payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes()); 305 DCHECK_GE(payload_end, payload_begin); 306 num_bytes = payload_end - payload_begin; 307 } 308 DCHECK(base::IsValueInRangeForNumericType<uint32_t>(num_bytes)); 309 return static_cast<uint32_t>(num_bytes); 310 } 311 312 uint32_t Message::payload_num_interface_ids() const { 313 auto* array_pointer = 314 version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get(); 315 return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0; 316 } 317 318 const uint32_t* Message::payload_interface_ids() const { 319 auto* array_pointer = 320 version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get(); 321 return array_pointer ? array_pointer->storage() : nullptr; 322 } 323 324 void Message::AttachHandlesFromSerializationContext( 325 internal::SerializationContext* context) { 326 if (context->handles()->empty() && 327 context->associated_endpoint_handles()->empty()) { 328 // No handles attached, so no extra serialization work. 329 return; 330 } 331 332 if (context->associated_endpoint_handles()->empty()) { 333 // Attaching only non-associated handles is easier since we don't have to 334 // modify the message header. Faster path for that. 335 payload_buffer_.AttachHandles(context->mutable_handles()); 336 return; 337 } 338 339 // Allocate a new message with enough space to hold all attached handles. Copy 340 // this message's contents into the new one and use it to replace ourself. 341 // 342 // TODO(rockot): We could avoid the extra full message allocation by instead 343 // growing the buffer and carefully moving its contents around. This errs on 344 // the side of less complexity with probably only marginal performance cost. 345 uint32_t payload_size = payload_num_bytes(); 346 mojo::Message new_message(name(), header()->flags, payload_size, 347 context->associated_endpoint_handles()->size(), 348 context->mutable_handles()); 349 std::swap(*context->mutable_associated_endpoint_handles(), 350 new_message.associated_endpoint_handles_); 351 memcpy(new_message.payload_buffer()->AllocateAndGet(payload_size), payload(), 352 payload_size); 353 *this = std::move(new_message); 354 } 355 356 ScopedMessageHandle Message::TakeMojoMessage() { 357 // If there are associated endpoints transferred, 358 // SerializeAssociatedEndpointHandles() must be called before this method. 359 DCHECK(associated_endpoint_handles_.empty()); 360 DCHECK(transferable_); 361 payload_buffer_.Seal(); 362 auto handle = std::move(handle_); 363 Reset(); 364 return handle; 365 } 366 367 void Message::NotifyBadMessage(const std::string& error) { 368 DCHECK(handle_.is_valid()); 369 mojo::NotifyBadMessage(handle_.get(), error); 370 } 371 372 void Message::SerializeAssociatedEndpointHandles( 373 AssociatedGroupController* group_controller) { 374 if (associated_endpoint_handles_.empty()) 375 return; 376 377 DCHECK_GE(version(), 2u); 378 DCHECK(header_v2()->payload_interface_ids.is_null()); 379 DCHECK(payload_buffer_.is_valid()); 380 DCHECK(handle_.is_valid()); 381 382 size_t size = associated_endpoint_handles_.size(); 383 384 internal::Array_Data<uint32_t>::BufferWriter handle_writer; 385 handle_writer.Allocate(size, &payload_buffer_); 386 header_v2()->payload_interface_ids.Set(handle_writer.data()); 387 388 for (size_t i = 0; i < size; ++i) { 389 ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i]; 390 391 DCHECK(handle.pending_association()); 392 handle_writer->storage()[i] = 393 group_controller->AssociateInterface(std::move(handle)); 394 } 395 associated_endpoint_handles_.clear(); 396 } 397 398 bool Message::DeserializeAssociatedEndpointHandles( 399 AssociatedGroupController* group_controller) { 400 if (!serialized_) 401 return true; 402 403 associated_endpoint_handles_.clear(); 404 405 uint32_t num_ids = payload_num_interface_ids(); 406 if (num_ids == 0) 407 return true; 408 409 associated_endpoint_handles_.reserve(num_ids); 410 uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage(); 411 bool result = true; 412 for (uint32_t i = 0; i < num_ids; ++i) { 413 auto handle = group_controller->CreateLocalEndpointHandle(ids[i]); 414 if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) { 415 // |ids[i]| itself is valid but handle creation failed. In that case, mark 416 // deserialization as failed but continue to deserialize the rest of 417 // handles. 418 result = false; 419 } 420 421 associated_endpoint_handles_.push_back(std::move(handle)); 422 ids[i] = kInvalidInterfaceId; 423 } 424 return result; 425 } 426 427 void Message::SerializeIfNecessary() { 428 MojoResult rv = MojoSerializeMessage(handle_->value(), nullptr); 429 if (rv == MOJO_RESULT_FAILED_PRECONDITION) 430 return; 431 432 // Reconstruct this Message instance from the serialized message's handle. 433 *this = Message(std::move(handle_)); 434 } 435 436 std::unique_ptr<internal::UnserializedMessageContext> 437 Message::TakeUnserializedContext( 438 const internal::UnserializedMessageContext::Tag* tag) { 439 DCHECK(handle_.is_valid()); 440 uintptr_t context_value = 0; 441 MojoResult rv = 442 MojoGetMessageContext(handle_->value(), nullptr, &context_value); 443 if (rv == MOJO_RESULT_NOT_FOUND) 444 return nullptr; 445 DCHECK_EQ(MOJO_RESULT_OK, rv); 446 447 auto* context = 448 reinterpret_cast<internal::UnserializedMessageContext*>(context_value); 449 if (context->tag() != tag) 450 return nullptr; 451 452 // Detach the context from the message. 453 rv = MojoSetMessageContext(handle_->value(), 0, nullptr, nullptr, nullptr); 454 DCHECK_EQ(MOJO_RESULT_OK, rv); 455 return base::WrapUnique(context); 456 } 457 458 bool MessageReceiver::PrefersSerializedMessages() { 459 return false; 460 } 461 462 PassThroughFilter::PassThroughFilter() {} 463 464 PassThroughFilter::~PassThroughFilter() {} 465 466 bool PassThroughFilter::Accept(Message* message) { 467 return true; 468 } 469 470 SyncMessageResponseContext::SyncMessageResponseContext() 471 : outer_context_(current()) { 472 g_tls_sync_response_context.Get().Set(this); 473 } 474 475 SyncMessageResponseContext::~SyncMessageResponseContext() { 476 DCHECK_EQ(current(), this); 477 g_tls_sync_response_context.Get().Set(outer_context_); 478 } 479 480 // static 481 SyncMessageResponseContext* SyncMessageResponseContext::current() { 482 return g_tls_sync_response_context.Get().Get(); 483 } 484 485 void SyncMessageResponseContext::ReportBadMessage(const std::string& error) { 486 GetBadMessageCallback().Run(error); 487 } 488 489 ReportBadMessageCallback SyncMessageResponseContext::GetBadMessageCallback() { 490 DCHECK(!response_.IsNull()); 491 return base::BindOnce(&DoNotifyBadMessage, std::move(response_)); 492 } 493 494 MojoResult ReadMessage(MessagePipeHandle handle, Message* message) { 495 ScopedMessageHandle message_handle; 496 MojoResult rv = 497 ReadMessageNew(handle, &message_handle, MOJO_READ_MESSAGE_FLAG_NONE); 498 if (rv != MOJO_RESULT_OK) 499 return rv; 500 501 *message = Message(std::move(message_handle)); 502 return MOJO_RESULT_OK; 503 } 504 505 void ReportBadMessage(const std::string& error) { 506 internal::MessageDispatchContext* context = 507 internal::MessageDispatchContext::current(); 508 DCHECK(context); 509 context->GetBadMessageCallback().Run(error); 510 } 511 512 ReportBadMessageCallback GetBadMessageCallback() { 513 internal::MessageDispatchContext* context = 514 internal::MessageDispatchContext::current(); 515 DCHECK(context); 516 return context->GetBadMessageCallback(); 517 } 518 519 namespace internal { 520 521 MessageHeaderV2::MessageHeaderV2() = default; 522 523 MessageDispatchContext::MessageDispatchContext(Message* message) 524 : outer_context_(current()), message_(message) { 525 g_tls_message_dispatch_context.Get().Set(this); 526 } 527 528 MessageDispatchContext::~MessageDispatchContext() { 529 DCHECK_EQ(current(), this); 530 g_tls_message_dispatch_context.Get().Set(outer_context_); 531 } 532 533 // static 534 MessageDispatchContext* MessageDispatchContext::current() { 535 return g_tls_message_dispatch_context.Get().Get(); 536 } 537 538 ReportBadMessageCallback MessageDispatchContext::GetBadMessageCallback() { 539 DCHECK(!message_->IsNull()); 540 return base::BindOnce(&DoNotifyBadMessage, std::move(*message_)); 541 } 542 543 // static 544 void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) { 545 SyncMessageResponseContext* context = SyncMessageResponseContext::current(); 546 if (context) 547 context->response_ = std::move(*message); 548 } 549 550 } // namespace internal 551 552 } // namespace mojo 553