1 // Copyright 2015 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "mojo/public/cpp/bindings/interface_endpoint_client.h" 6 7 #include <stdint.h> 8 9 #include "base/bind.h" 10 #include "base/location.h" 11 #include "base/logging.h" 12 #include "base/macros.h" 13 #include "base/memory/ptr_util.h" 14 #include "base/sequenced_task_runner.h" 15 #include "base/stl_util.h" 16 #include "mojo/public/cpp/bindings/associated_group.h" 17 #include "mojo/public/cpp/bindings/associated_group_controller.h" 18 #include "mojo/public/cpp/bindings/interface_endpoint_controller.h" 19 #include "mojo/public/cpp/bindings/lib/task_runner_helper.h" 20 #include "mojo/public/cpp/bindings/lib/validation_util.h" 21 #include "mojo/public/cpp/bindings/sync_call_restrictions.h" 22 23 namespace mojo { 24 25 // ---------------------------------------------------------------------------- 26 27 namespace { 28 29 void DetermineIfEndpointIsConnected( 30 const base::WeakPtr<InterfaceEndpointClient>& client, 31 base::OnceCallback<void(bool)> callback) { 32 std::move(callback).Run(client && !client->encountered_error()); 33 } 34 35 // When receiving an incoming message which expects a repsonse, 36 // InterfaceEndpointClient creates a ResponderThunk object and passes it to the 37 // incoming message receiver. When the receiver finishes processing the message, 38 // it can provide a response using this object. 39 class ResponderThunk : public MessageReceiverWithStatus { 40 public: 41 explicit ResponderThunk( 42 const base::WeakPtr<InterfaceEndpointClient>& endpoint_client, 43 scoped_refptr<base::SequencedTaskRunner> runner) 44 : endpoint_client_(endpoint_client), 45 accept_was_invoked_(false), 46 task_runner_(std::move(runner)) {} 47 ~ResponderThunk() override { 48 if (!accept_was_invoked_) { 49 // The Service handled a message that was expecting a response 50 // but did not send a response. 51 // We raise an error to signal the calling application that an error 52 // condition occurred. Without this the calling application would have no 53 // way of knowing it should stop waiting for a response. 54 if (task_runner_->RunsTasksInCurrentSequence()) { 55 // Please note that even if this code is run from a different task 56 // runner on the same thread as |task_runner_|, it is okay to directly 57 // call InterfaceEndpointClient::RaiseError(), because it will raise 58 // error from the correct task runner asynchronously. 59 if (endpoint_client_) { 60 endpoint_client_->RaiseError(); 61 } 62 } else { 63 task_runner_->PostTask( 64 FROM_HERE, 65 base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_)); 66 } 67 } 68 } 69 70 // MessageReceiver implementation: 71 bool PrefersSerializedMessages() override { 72 return endpoint_client_ && endpoint_client_->PrefersSerializedMessages(); 73 } 74 75 bool Accept(Message* message) override { 76 DCHECK(task_runner_->RunsTasksInCurrentSequence()); 77 accept_was_invoked_ = true; 78 DCHECK(message->has_flag(Message::kFlagIsResponse)); 79 80 bool result = false; 81 82 if (endpoint_client_) 83 result = endpoint_client_->Accept(message); 84 85 return result; 86 } 87 88 // MessageReceiverWithStatus implementation: 89 bool IsConnected() override { 90 DCHECK(task_runner_->RunsTasksInCurrentSequence()); 91 return endpoint_client_ && !endpoint_client_->encountered_error(); 92 } 93 94 void IsConnectedAsync(base::OnceCallback<void(bool)> callback) override { 95 if (task_runner_->RunsTasksInCurrentSequence()) { 96 DetermineIfEndpointIsConnected(endpoint_client_, std::move(callback)); 97 } else { 98 task_runner_->PostTask( 99 FROM_HERE, base::BindOnce(&DetermineIfEndpointIsConnected, 100 endpoint_client_, std::move(callback))); 101 } 102 } 103 104 private: 105 base::WeakPtr<InterfaceEndpointClient> endpoint_client_; 106 bool accept_was_invoked_; 107 scoped_refptr<base::SequencedTaskRunner> task_runner_; 108 109 DISALLOW_COPY_AND_ASSIGN(ResponderThunk); 110 }; 111 112 } // namespace 113 114 // ---------------------------------------------------------------------------- 115 116 InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo( 117 bool* in_response_received) 118 : response_received(in_response_received) {} 119 120 InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {} 121 122 // ---------------------------------------------------------------------------- 123 124 InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk( 125 InterfaceEndpointClient* owner) 126 : owner_(owner) {} 127 128 InterfaceEndpointClient::HandleIncomingMessageThunk:: 129 ~HandleIncomingMessageThunk() {} 130 131 bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept( 132 Message* message) { 133 return owner_->HandleValidatedMessage(message); 134 } 135 136 // ---------------------------------------------------------------------------- 137 138 InterfaceEndpointClient::InterfaceEndpointClient( 139 ScopedInterfaceEndpointHandle handle, 140 MessageReceiverWithResponderStatus* receiver, 141 std::unique_ptr<MessageReceiver> payload_validator, 142 bool expect_sync_requests, 143 scoped_refptr<base::SequencedTaskRunner> runner, 144 uint32_t interface_version) 145 : expect_sync_requests_(expect_sync_requests), 146 handle_(std::move(handle)), 147 incoming_receiver_(receiver), 148 thunk_(this), 149 filters_(&thunk_), 150 task_runner_(std::move(runner)), 151 control_message_proxy_(this), 152 control_message_handler_(interface_version), 153 weak_ptr_factory_(this) { 154 DCHECK(handle_.is_valid()); 155 156 // TODO(yzshen): the way to use validator (or message filter in general) 157 // directly is a little awkward. 158 if (payload_validator) 159 filters_.Append(std::move(payload_validator)); 160 161 if (handle_.pending_association()) { 162 handle_.SetAssociationEventHandler(base::Bind( 163 &InterfaceEndpointClient::OnAssociationEvent, base::Unretained(this))); 164 } else { 165 InitControllerIfNecessary(); 166 } 167 } 168 169 InterfaceEndpointClient::~InterfaceEndpointClient() { 170 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); 171 if (controller_) 172 handle_.group_controller()->DetachEndpointClient(handle_); 173 } 174 175 AssociatedGroup* InterfaceEndpointClient::associated_group() { 176 if (!associated_group_) 177 associated_group_ = std::make_unique<AssociatedGroup>(handle_); 178 return associated_group_.get(); 179 } 180 181 ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() { 182 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); 183 DCHECK(!has_pending_responders()); 184 185 if (!handle_.is_valid()) 186 return ScopedInterfaceEndpointHandle(); 187 188 handle_.SetAssociationEventHandler( 189 ScopedInterfaceEndpointHandle::AssociationEventCallback()); 190 191 if (controller_) { 192 controller_ = nullptr; 193 handle_.group_controller()->DetachEndpointClient(handle_); 194 } 195 196 return std::move(handle_); 197 } 198 199 void InterfaceEndpointClient::AddFilter( 200 std::unique_ptr<MessageReceiver> filter) { 201 filters_.Append(std::move(filter)); 202 } 203 204 void InterfaceEndpointClient::RaiseError() { 205 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); 206 207 if (!handle_.pending_association()) 208 handle_.group_controller()->RaiseError(); 209 } 210 211 void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason, 212 const std::string& description) { 213 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); 214 215 auto handle = PassHandle(); 216 handle.ResetWithReason(custom_reason, description); 217 } 218 219 bool InterfaceEndpointClient::PrefersSerializedMessages() { 220 auto* controller = handle_.group_controller(); 221 return controller && controller->PrefersSerializedMessages(); 222 } 223 224 bool InterfaceEndpointClient::Accept(Message* message) { 225 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); 226 DCHECK(!message->has_flag(Message::kFlagExpectsResponse)); 227 DCHECK(!handle_.pending_association()); 228 229 // This has to been done even if connection error has occurred. For example, 230 // the message contains a pending associated request. The user may try to use 231 // the corresponding associated interface pointer after sending this message. 232 // That associated interface pointer has to join an associated group in order 233 // to work properly. 234 if (!message->associated_endpoint_handles()->empty()) 235 message->SerializeAssociatedEndpointHandles(handle_.group_controller()); 236 237 if (encountered_error_) 238 return false; 239 240 InitControllerIfNecessary(); 241 242 return controller_->SendMessage(message); 243 } 244 245 bool InterfaceEndpointClient::AcceptWithResponder( 246 Message* message, 247 std::unique_ptr<MessageReceiver> responder) { 248 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); 249 DCHECK(message->has_flag(Message::kFlagExpectsResponse)); 250 DCHECK(!handle_.pending_association()); 251 252 // Please see comments in Accept(). 253 if (!message->associated_endpoint_handles()->empty()) 254 message->SerializeAssociatedEndpointHandles(handle_.group_controller()); 255 256 if (encountered_error_) 257 return false; 258 259 InitControllerIfNecessary(); 260 261 // Reserve 0 in case we want it to convey special meaning in the future. 262 uint64_t request_id = next_request_id_++; 263 if (request_id == 0) 264 request_id = next_request_id_++; 265 266 message->set_request_id(request_id); 267 268 bool is_sync = message->has_flag(Message::kFlagIsSync); 269 if (!controller_->SendMessage(message)) 270 return false; 271 272 if (!is_sync) { 273 async_responders_[request_id] = std::move(responder); 274 return true; 275 } 276 277 SyncCallRestrictions::AssertSyncCallAllowed(); 278 279 bool response_received = false; 280 sync_responses_.insert(std::make_pair( 281 request_id, std::make_unique<SyncResponseInfo>(&response_received))); 282 283 base::WeakPtr<InterfaceEndpointClient> weak_self = 284 weak_ptr_factory_.GetWeakPtr(); 285 controller_->SyncWatch(&response_received); 286 // Make sure that this instance hasn't been destroyed. 287 if (weak_self) { 288 DCHECK(base::ContainsKey(sync_responses_, request_id)); 289 auto iter = sync_responses_.find(request_id); 290 DCHECK_EQ(&response_received, iter->second->response_received); 291 if (response_received) { 292 ignore_result(responder->Accept(&iter->second->response)); 293 } else { 294 DVLOG(1) << "Mojo sync call returns without receiving a response. " 295 << "Typcially it is because the interface has been " 296 << "disconnected."; 297 } 298 sync_responses_.erase(iter); 299 } 300 301 return true; 302 } 303 304 bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) { 305 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); 306 return filters_.Accept(message); 307 } 308 309 void InterfaceEndpointClient::NotifyError( 310 const base::Optional<DisconnectReason>& reason) { 311 DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); 312 313 if (encountered_error_) 314 return; 315 encountered_error_ = true; 316 317 // Response callbacks may hold on to resource, and there's no need to keep 318 // them alive any longer. Note that it's allowed that a pending response 319 // callback may own this endpoint, so we simply move the responders onto the 320 // stack here and let them be destroyed when the stack unwinds. 321 AsyncResponderMap responders = std::move(async_responders_); 322 323 control_message_proxy_.OnConnectionError(); 324 325 if (error_handler_) { 326 std::move(error_handler_).Run(); 327 } else if (error_with_reason_handler_) { 328 if (reason) { 329 std::move(error_with_reason_handler_) 330 .Run(reason->custom_reason, reason->description); 331 } else { 332 std::move(error_with_reason_handler_).Run(0, std::string()); 333 } 334 } 335 } 336 337 void InterfaceEndpointClient::QueryVersion( 338 const base::Callback<void(uint32_t)>& callback) { 339 control_message_proxy_.QueryVersion(callback); 340 } 341 342 void InterfaceEndpointClient::RequireVersion(uint32_t version) { 343 control_message_proxy_.RequireVersion(version); 344 } 345 346 void InterfaceEndpointClient::FlushForTesting() { 347 control_message_proxy_.FlushForTesting(); 348 } 349 350 void InterfaceEndpointClient::InitControllerIfNecessary() { 351 if (controller_ || handle_.pending_association()) 352 return; 353 354 controller_ = handle_.group_controller()->AttachEndpointClient(handle_, this, 355 task_runner_); 356 if (expect_sync_requests_) 357 controller_->AllowWokenUpBySyncWatchOnSameThread(); 358 } 359 360 void InterfaceEndpointClient::OnAssociationEvent( 361 ScopedInterfaceEndpointHandle::AssociationEvent event) { 362 if (event == ScopedInterfaceEndpointHandle::ASSOCIATED) { 363 InitControllerIfNecessary(); 364 } else if (event == 365 ScopedInterfaceEndpointHandle::PEER_CLOSED_BEFORE_ASSOCIATION) { 366 task_runner_->PostTask(FROM_HERE, 367 base::Bind(&InterfaceEndpointClient::NotifyError, 368 weak_ptr_factory_.GetWeakPtr(), 369 handle_.disconnect_reason())); 370 } 371 } 372 373 bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) { 374 DCHECK_EQ(handle_.id(), message->interface_id()); 375 376 if (encountered_error_) { 377 // This message is received after error has been encountered. For associated 378 // interfaces, this means the remote side sends a 379 // PeerAssociatedEndpointClosed event but continues to send more messages 380 // for the same interface. Close the pipe because this shouldn't happen. 381 DVLOG(1) << "A message is received for an interface after it has been " 382 << "disconnected. Closing the pipe."; 383 return false; 384 } 385 386 if (message->has_flag(Message::kFlagExpectsResponse)) { 387 std::unique_ptr<MessageReceiverWithStatus> responder = 388 std::make_unique<ResponderThunk>(weak_ptr_factory_.GetWeakPtr(), 389 task_runner_); 390 if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) { 391 return control_message_handler_.AcceptWithResponder(message, 392 std::move(responder)); 393 } else { 394 return incoming_receiver_->AcceptWithResponder(message, 395 std::move(responder)); 396 } 397 } else if (message->has_flag(Message::kFlagIsResponse)) { 398 uint64_t request_id = message->request_id(); 399 400 if (message->has_flag(Message::kFlagIsSync)) { 401 auto it = sync_responses_.find(request_id); 402 if (it == sync_responses_.end()) 403 return false; 404 it->second->response = std::move(*message); 405 *it->second->response_received = true; 406 return true; 407 } 408 409 auto it = async_responders_.find(request_id); 410 if (it == async_responders_.end()) 411 return false; 412 std::unique_ptr<MessageReceiver> responder = std::move(it->second); 413 async_responders_.erase(it); 414 return responder->Accept(message); 415 } else { 416 if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) 417 return control_message_handler_.Accept(message); 418 419 return incoming_receiver_->Accept(message); 420 } 421 } 422 423 } // namespace mojo 424