1 // Copyright 2013 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "mojo/public/cpp/bindings/message.h" 6 7 #include <stddef.h> 8 #include <stdint.h> 9 #include <stdlib.h> 10 11 #include <algorithm> 12 #include <utility> 13 14 #include "base/bind.h" 15 #include "base/lazy_instance.h" 16 #include "base/logging.h" 17 #include "base/strings/stringprintf.h" 18 #include "base/threading/thread_local.h" 19 #include "mojo/public/cpp/bindings/associated_group_controller.h" 20 #include "mojo/public/cpp/bindings/lib/array_internal.h" 21 22 namespace mojo { 23 24 namespace { 25 26 base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>:: 27 DestructorAtExit g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER; 28 29 base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>:: 30 DestructorAtExit g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER; 31 32 void DoNotifyBadMessage(Message message, const std::string& error) { 33 message.NotifyBadMessage(error); 34 } 35 36 } // namespace 37 38 Message::Message() { 39 } 40 41 Message::Message(Message&& other) 42 : buffer_(std::move(other.buffer_)), 43 handles_(std::move(other.handles_)), 44 associated_endpoint_handles_( 45 std::move(other.associated_endpoint_handles_)) {} 46 47 Message::~Message() { 48 CloseHandles(); 49 } 50 51 Message& Message::operator=(Message&& other) { 52 Reset(); 53 std::swap(other.buffer_, buffer_); 54 std::swap(other.handles_, handles_); 55 std::swap(other.associated_endpoint_handles_, associated_endpoint_handles_); 56 return *this; 57 } 58 59 void Message::Reset() { 60 CloseHandles(); 61 handles_.clear(); 62 associated_endpoint_handles_.clear(); 63 buffer_.reset(); 64 } 65 66 void Message::Initialize(size_t capacity, bool zero_initialized) { 67 DCHECK(!buffer_); 68 buffer_.reset(new internal::MessageBuffer(capacity, zero_initialized)); 69 } 70 71 void Message::InitializeFromMojoMessage(ScopedMessageHandle message, 72 uint32_t num_bytes, 73 std::vector<Handle>* handles) { 74 DCHECK(!buffer_); 75 buffer_.reset(new internal::MessageBuffer(std::move(message), num_bytes)); 76 handles_.swap(*handles); 77 } 78 79 const uint8_t* Message::payload() const { 80 if (version() < 2) 81 return data() + header()->num_bytes; 82 83 return static_cast<const uint8_t*>(header_v2()->payload.Get()); 84 } 85 86 uint32_t Message::payload_num_bytes() const { 87 DCHECK_GE(data_num_bytes(), header()->num_bytes); 88 size_t num_bytes; 89 if (version() < 2) { 90 num_bytes = data_num_bytes() - header()->num_bytes; 91 } else { 92 auto payload = reinterpret_cast<uintptr_t>(header_v2()->payload.Get()); 93 if (!payload) { 94 num_bytes = 0; 95 } else { 96 auto payload_end = 97 reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get()); 98 if (!payload_end) 99 payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes()); 100 DCHECK_GE(payload_end, payload); 101 num_bytes = payload_end - payload; 102 } 103 } 104 DCHECK_LE(num_bytes, std::numeric_limits<uint32_t>::max()); 105 return static_cast<uint32_t>(num_bytes); 106 } 107 108 uint32_t Message::payload_num_interface_ids() const { 109 auto* array_pointer = 110 version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get(); 111 return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0; 112 } 113 114 const uint32_t* Message::payload_interface_ids() const { 115 auto* array_pointer = 116 version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get(); 117 return array_pointer ? array_pointer->storage() : nullptr; 118 } 119 120 ScopedMessageHandle Message::TakeMojoMessage() { 121 // If there are associated endpoints transferred, 122 // SerializeAssociatedEndpointHandles() must be called before this method. 123 DCHECK(associated_endpoint_handles_.empty()); 124 125 if (handles_.empty()) // Fast path for the common case: No handles. 126 return buffer_->TakeMessage(); 127 128 // Allocate a new message with space for the handles, then copy the buffer 129 // contents into it. 130 // 131 // TODO(rockot): We could avoid this copy by extending GetSerializedSize() 132 // behavior to collect handles. It's unoptimized for now because it's much 133 // more common to have messages with no handles. 134 ScopedMessageHandle new_message; 135 MojoResult rv = AllocMessage( 136 data_num_bytes(), 137 handles_.empty() ? nullptr 138 : reinterpret_cast<const MojoHandle*>(handles_.data()), 139 handles_.size(), 140 MOJO_ALLOC_MESSAGE_FLAG_NONE, 141 &new_message); 142 CHECK_EQ(rv, MOJO_RESULT_OK); 143 handles_.clear(); 144 145 void* new_buffer = nullptr; 146 rv = GetMessageBuffer(new_message.get(), &new_buffer); 147 CHECK_EQ(rv, MOJO_RESULT_OK); 148 149 memcpy(new_buffer, data(), data_num_bytes()); 150 buffer_.reset(); 151 152 return new_message; 153 } 154 155 void Message::NotifyBadMessage(const std::string& error) { 156 DCHECK(buffer_); 157 buffer_->NotifyBadMessage(error); 158 } 159 160 void Message::CloseHandles() { 161 for (std::vector<Handle>::iterator it = handles_.begin(); 162 it != handles_.end(); ++it) { 163 if (it->is_valid()) 164 CloseRaw(*it); 165 } 166 } 167 168 void Message::SerializeAssociatedEndpointHandles( 169 AssociatedGroupController* group_controller) { 170 if (associated_endpoint_handles_.empty()) 171 return; 172 173 DCHECK_GE(version(), 2u); 174 DCHECK(header_v2()->payload_interface_ids.is_null()); 175 176 size_t size = associated_endpoint_handles_.size(); 177 auto* data = internal::Array_Data<uint32_t>::New(size, buffer()); 178 header_v2()->payload_interface_ids.Set(data); 179 180 for (size_t i = 0; i < size; ++i) { 181 ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i]; 182 183 DCHECK(handle.pending_association()); 184 data->storage()[i] = 185 group_controller->AssociateInterface(std::move(handle)); 186 } 187 associated_endpoint_handles_.clear(); 188 } 189 190 bool Message::DeserializeAssociatedEndpointHandles( 191 AssociatedGroupController* group_controller) { 192 associated_endpoint_handles_.clear(); 193 194 uint32_t num_ids = payload_num_interface_ids(); 195 if (num_ids == 0) 196 return true; 197 198 associated_endpoint_handles_.reserve(num_ids); 199 uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage(); 200 bool result = true; 201 for (uint32_t i = 0; i < num_ids; ++i) { 202 auto handle = group_controller->CreateLocalEndpointHandle(ids[i]); 203 if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) { 204 // |ids[i]| itself is valid but handle creation failed. In that case, mark 205 // deserialization as failed but continue to deserialize the rest of 206 // handles. 207 result = false; 208 } 209 210 associated_endpoint_handles_.push_back(std::move(handle)); 211 ids[i] = kInvalidInterfaceId; 212 } 213 return result; 214 } 215 216 PassThroughFilter::PassThroughFilter() {} 217 218 PassThroughFilter::~PassThroughFilter() {} 219 220 bool PassThroughFilter::Accept(Message* message) { return true; } 221 222 SyncMessageResponseContext::SyncMessageResponseContext() 223 : outer_context_(current()) { 224 g_tls_sync_response_context.Get().Set(this); 225 } 226 227 SyncMessageResponseContext::~SyncMessageResponseContext() { 228 DCHECK_EQ(current(), this); 229 g_tls_sync_response_context.Get().Set(outer_context_); 230 } 231 232 // static 233 SyncMessageResponseContext* SyncMessageResponseContext::current() { 234 return g_tls_sync_response_context.Get().Get(); 235 } 236 237 void SyncMessageResponseContext::ReportBadMessage(const std::string& error) { 238 GetBadMessageCallback().Run(error); 239 } 240 241 const ReportBadMessageCallback& 242 SyncMessageResponseContext::GetBadMessageCallback() { 243 if (bad_message_callback_.is_null()) { 244 bad_message_callback_ = 245 base::Bind(&DoNotifyBadMessage, base::Passed(&response_)); 246 } 247 return bad_message_callback_; 248 } 249 250 MojoResult ReadMessage(MessagePipeHandle handle, Message* message) { 251 MojoResult rv; 252 253 std::vector<Handle> handles; 254 ScopedMessageHandle mojo_message; 255 uint32_t num_bytes = 0, num_handles = 0; 256 rv = ReadMessageNew(handle, 257 &mojo_message, 258 &num_bytes, 259 nullptr, 260 &num_handles, 261 MOJO_READ_MESSAGE_FLAG_NONE); 262 if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) { 263 DCHECK_GT(num_handles, 0u); 264 handles.resize(num_handles); 265 rv = ReadMessageNew(handle, 266 &mojo_message, 267 &num_bytes, 268 reinterpret_cast<MojoHandle*>(handles.data()), 269 &num_handles, 270 MOJO_READ_MESSAGE_FLAG_NONE); 271 } 272 273 if (rv != MOJO_RESULT_OK) 274 return rv; 275 276 message->InitializeFromMojoMessage( 277 std::move(mojo_message), num_bytes, &handles); 278 return MOJO_RESULT_OK; 279 } 280 281 void ReportBadMessage(const std::string& error) { 282 internal::MessageDispatchContext* context = 283 internal::MessageDispatchContext::current(); 284 DCHECK(context); 285 context->GetBadMessageCallback().Run(error); 286 } 287 288 ReportBadMessageCallback GetBadMessageCallback() { 289 internal::MessageDispatchContext* context = 290 internal::MessageDispatchContext::current(); 291 DCHECK(context); 292 return context->GetBadMessageCallback(); 293 } 294 295 namespace internal { 296 297 MessageHeaderV2::MessageHeaderV2() = default; 298 299 MessageDispatchContext::MessageDispatchContext(Message* message) 300 : outer_context_(current()), message_(message) { 301 g_tls_message_dispatch_context.Get().Set(this); 302 } 303 304 MessageDispatchContext::~MessageDispatchContext() { 305 DCHECK_EQ(current(), this); 306 g_tls_message_dispatch_context.Get().Set(outer_context_); 307 } 308 309 // static 310 MessageDispatchContext* MessageDispatchContext::current() { 311 return g_tls_message_dispatch_context.Get().Get(); 312 } 313 314 const ReportBadMessageCallback& 315 MessageDispatchContext::GetBadMessageCallback() { 316 if (bad_message_callback_.is_null()) { 317 bad_message_callback_ = 318 base::Bind(&DoNotifyBadMessage, base::Passed(message_)); 319 } 320 return bad_message_callback_; 321 } 322 323 // static 324 void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) { 325 SyncMessageResponseContext* context = SyncMessageResponseContext::current(); 326 if (context) 327 context->response_ = std::move(*message); 328 } 329 330 } // namespace internal 331 332 } // namespace mojo 333