Home | History | Annotate | Download | only in lib
      1 // Copyright 2015 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "mojo/public/cpp/bindings/interface_endpoint_client.h"
      6 
      7 #include <stdint.h>
      8 
      9 #include "base/bind.h"
     10 #include "base/location.h"
     11 #include "base/logging.h"
     12 #include "base/macros.h"
     13 #include "base/memory/ptr_util.h"
     14 #include "base/sequenced_task_runner.h"
     15 #include "base/stl_util.h"
     16 #include "mojo/public/cpp/bindings/associated_group.h"
     17 #include "mojo/public/cpp/bindings/associated_group_controller.h"
     18 #include "mojo/public/cpp/bindings/interface_endpoint_controller.h"
     19 #include "mojo/public/cpp/bindings/lib/task_runner_helper.h"
     20 #include "mojo/public/cpp/bindings/lib/validation_util.h"
     21 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
     22 
     23 namespace mojo {
     24 
     25 // ----------------------------------------------------------------------------
     26 
     27 namespace {
     28 
     29 void DetermineIfEndpointIsConnected(
     30     const base::WeakPtr<InterfaceEndpointClient>& client,
     31     base::OnceCallback<void(bool)> callback) {
     32   std::move(callback).Run(client && !client->encountered_error());
     33 }
     34 
     35 // When receiving an incoming message which expects a repsonse,
     36 // InterfaceEndpointClient creates a ResponderThunk object and passes it to the
     37 // incoming message receiver. When the receiver finishes processing the message,
     38 // it can provide a response using this object.
     39 class ResponderThunk : public MessageReceiverWithStatus {
     40  public:
     41   explicit ResponderThunk(
     42       const base::WeakPtr<InterfaceEndpointClient>& endpoint_client,
     43       scoped_refptr<base::SequencedTaskRunner> runner)
     44       : endpoint_client_(endpoint_client),
     45         accept_was_invoked_(false),
     46         task_runner_(std::move(runner)) {}
     47   ~ResponderThunk() override {
     48     if (!accept_was_invoked_) {
     49       // The Service handled a message that was expecting a response
     50       // but did not send a response.
     51       // We raise an error to signal the calling application that an error
     52       // condition occurred. Without this the calling application would have no
     53       // way of knowing it should stop waiting for a response.
     54       if (task_runner_->RunsTasksInCurrentSequence()) {
     55         // Please note that even if this code is run from a different task
     56         // runner on the same thread as |task_runner_|, it is okay to directly
     57         // call InterfaceEndpointClient::RaiseError(), because it will raise
     58         // error from the correct task runner asynchronously.
     59         if (endpoint_client_) {
     60           endpoint_client_->RaiseError();
     61         }
     62       } else {
     63         task_runner_->PostTask(
     64             FROM_HERE,
     65             base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_));
     66       }
     67     }
     68   }
     69 
     70   // MessageReceiver implementation:
     71   bool PrefersSerializedMessages() override {
     72     return endpoint_client_ && endpoint_client_->PrefersSerializedMessages();
     73   }
     74 
     75   bool Accept(Message* message) override {
     76     DCHECK(task_runner_->RunsTasksInCurrentSequence());
     77     accept_was_invoked_ = true;
     78     DCHECK(message->has_flag(Message::kFlagIsResponse));
     79 
     80     bool result = false;
     81 
     82     if (endpoint_client_)
     83       result = endpoint_client_->Accept(message);
     84 
     85     return result;
     86   }
     87 
     88   // MessageReceiverWithStatus implementation:
     89   bool IsConnected() override {
     90     DCHECK(task_runner_->RunsTasksInCurrentSequence());
     91     return endpoint_client_ && !endpoint_client_->encountered_error();
     92   }
     93 
     94   void IsConnectedAsync(base::OnceCallback<void(bool)> callback) override {
     95     if (task_runner_->RunsTasksInCurrentSequence()) {
     96       DetermineIfEndpointIsConnected(endpoint_client_, std::move(callback));
     97     } else {
     98       task_runner_->PostTask(
     99           FROM_HERE, base::BindOnce(&DetermineIfEndpointIsConnected,
    100                                     endpoint_client_, std::move(callback)));
    101     }
    102   }
    103 
    104  private:
    105   base::WeakPtr<InterfaceEndpointClient> endpoint_client_;
    106   bool accept_was_invoked_;
    107   scoped_refptr<base::SequencedTaskRunner> task_runner_;
    108 
    109   DISALLOW_COPY_AND_ASSIGN(ResponderThunk);
    110 };
    111 
    112 }  // namespace
    113 
    114 // ----------------------------------------------------------------------------
    115 
    116 InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo(
    117     bool* in_response_received)
    118     : response_received(in_response_received) {}
    119 
    120 InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {}
    121 
    122 // ----------------------------------------------------------------------------
    123 
    124 InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk(
    125     InterfaceEndpointClient* owner)
    126     : owner_(owner) {}
    127 
    128 InterfaceEndpointClient::HandleIncomingMessageThunk::
    129     ~HandleIncomingMessageThunk() {}
    130 
    131 bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept(
    132     Message* message) {
    133   return owner_->HandleValidatedMessage(message);
    134 }
    135 
    136 // ----------------------------------------------------------------------------
    137 
    138 InterfaceEndpointClient::InterfaceEndpointClient(
    139     ScopedInterfaceEndpointHandle handle,
    140     MessageReceiverWithResponderStatus* receiver,
    141     std::unique_ptr<MessageReceiver> payload_validator,
    142     bool expect_sync_requests,
    143     scoped_refptr<base::SequencedTaskRunner> runner,
    144     uint32_t interface_version)
    145     : expect_sync_requests_(expect_sync_requests),
    146       handle_(std::move(handle)),
    147       incoming_receiver_(receiver),
    148       thunk_(this),
    149       filters_(&thunk_),
    150       task_runner_(std::move(runner)),
    151       control_message_proxy_(this),
    152       control_message_handler_(interface_version),
    153       weak_ptr_factory_(this) {
    154   DCHECK(handle_.is_valid());
    155 
    156   // TODO(yzshen): the way to use validator (or message filter in general)
    157   // directly is a little awkward.
    158   if (payload_validator)
    159     filters_.Append(std::move(payload_validator));
    160 
    161   if (handle_.pending_association()) {
    162     handle_.SetAssociationEventHandler(base::Bind(
    163         &InterfaceEndpointClient::OnAssociationEvent, base::Unretained(this)));
    164   } else {
    165     InitControllerIfNecessary();
    166   }
    167 }
    168 
    169 InterfaceEndpointClient::~InterfaceEndpointClient() {
    170   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    171   if (controller_)
    172     handle_.group_controller()->DetachEndpointClient(handle_);
    173 }
    174 
    175 AssociatedGroup* InterfaceEndpointClient::associated_group() {
    176   if (!associated_group_)
    177     associated_group_ = std::make_unique<AssociatedGroup>(handle_);
    178   return associated_group_.get();
    179 }
    180 
    181 ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() {
    182   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    183   DCHECK(!has_pending_responders());
    184 
    185   if (!handle_.is_valid())
    186     return ScopedInterfaceEndpointHandle();
    187 
    188   handle_.SetAssociationEventHandler(
    189       ScopedInterfaceEndpointHandle::AssociationEventCallback());
    190 
    191   if (controller_) {
    192     controller_ = nullptr;
    193     handle_.group_controller()->DetachEndpointClient(handle_);
    194   }
    195 
    196   return std::move(handle_);
    197 }
    198 
    199 void InterfaceEndpointClient::AddFilter(
    200     std::unique_ptr<MessageReceiver> filter) {
    201   filters_.Append(std::move(filter));
    202 }
    203 
    204 void InterfaceEndpointClient::RaiseError() {
    205   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    206 
    207   if (!handle_.pending_association())
    208     handle_.group_controller()->RaiseError();
    209 }
    210 
    211 void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason,
    212                                               const std::string& description) {
    213   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    214 
    215   auto handle = PassHandle();
    216   handle.ResetWithReason(custom_reason, description);
    217 }
    218 
    219 bool InterfaceEndpointClient::PrefersSerializedMessages() {
    220   auto* controller = handle_.group_controller();
    221   return controller && controller->PrefersSerializedMessages();
    222 }
    223 
    224 bool InterfaceEndpointClient::Accept(Message* message) {
    225   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    226   DCHECK(!message->has_flag(Message::kFlagExpectsResponse));
    227   DCHECK(!handle_.pending_association());
    228 
    229   // This has to been done even if connection error has occurred. For example,
    230   // the message contains a pending associated request. The user may try to use
    231   // the corresponding associated interface pointer after sending this message.
    232   // That associated interface pointer has to join an associated group in order
    233   // to work properly.
    234   if (!message->associated_endpoint_handles()->empty())
    235     message->SerializeAssociatedEndpointHandles(handle_.group_controller());
    236 
    237   if (encountered_error_)
    238     return false;
    239 
    240   InitControllerIfNecessary();
    241 
    242   return controller_->SendMessage(message);
    243 }
    244 
// Sends |message|, which must expect a response, and arranges for |responder|
// to receive that response.
//
// Async messages: |responder| is stored in |async_responders_| under the newly
// assigned request ID, and this returns true right after sending.
// Sync messages (kFlagIsSync): this blocks in controller_->SyncWatch() and, if
// a response was received, passes it to |responder| before returning. It
// returns true even if the wait ended without a response.
//
// Returns false if an error was already encountered or the send failed; in
// those cases |responder| is destroyed without being invoked.
bool InterfaceEndpointClient::AcceptWithResponder(
    Message* message,
    std::unique_ptr<MessageReceiver> responder) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(message->has_flag(Message::kFlagExpectsResponse));
  DCHECK(!handle_.pending_association());

  // Please see comments in Accept().
  if (!message->associated_endpoint_handles()->empty())
    message->SerializeAssociatedEndpointHandles(handle_.group_controller());

  if (encountered_error_)
    return false;

  InitControllerIfNecessary();

  // Reserve 0 in case we want it to convey special meaning in the future.
  uint64_t request_id = next_request_id_++;
  if (request_id == 0)
    request_id = next_request_id_++;

  message->set_request_id(request_id);

  // Read the sync flag before SendMessage(). NOTE(review): presumably because
  // sending may consume |message|; confirm.
  bool is_sync = message->has_flag(Message::kFlagIsSync);
  if (!controller_->SendMessage(message))
    return false;

  if (!is_sync) {
    // HandleValidatedMessage() looks this up by request ID when the response
    // arrives.
    async_responders_[request_id] = std::move(responder);
    return true;
  }

  SyncCallRestrictions::AssertSyncCallAllowed();

  // |response_received| lives on this stack frame. HandleValidatedMessage()
  // stores the matching sync response in the SyncResponseInfo and sets this
  // flag to true through the pointer registered here.
  bool response_received = false;
  sync_responses_.insert(std::make_pair(
      request_id, std::make_unique<SyncResponseInfo>(&response_received)));

  // This instance may be destroyed while SyncWatch() is waiting, so take a
  // WeakPtr beforehand and check it afterwards.
  base::WeakPtr<InterfaceEndpointClient> weak_self =
      weak_ptr_factory_.GetWeakPtr();
  controller_->SyncWatch(&response_received);
  // Make sure that this instance hasn't been destroyed.
  if (weak_self) {
    DCHECK(base::ContainsKey(sync_responses_, request_id));
    auto iter = sync_responses_.find(request_id);
    DCHECK_EQ(&response_received, iter->second->response_received);
    if (response_received) {
      ignore_result(responder->Accept(&iter->second->response));
    } else {
      DVLOG(1) << "Mojo sync call returns without receiving a response. "
               << "Typcially it is because the interface has been "
               << "disconnected.";
    }
    // Unregister before |response_received| goes out of scope.
    sync_responses_.erase(iter);
  }

  return true;
}
    303 
    304 bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) {
    305   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    306   return filters_.Accept(message);
    307 }
    308 
    309 void InterfaceEndpointClient::NotifyError(
    310     const base::Optional<DisconnectReason>& reason) {
    311   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
    312 
    313   if (encountered_error_)
    314     return;
    315   encountered_error_ = true;
    316 
    317   // Response callbacks may hold on to resource, and there's no need to keep
    318   // them alive any longer. Note that it's allowed that a pending response
    319   // callback may own this endpoint, so we simply move the responders onto the
    320   // stack here and let them be destroyed when the stack unwinds.
    321   AsyncResponderMap responders = std::move(async_responders_);
    322 
    323   control_message_proxy_.OnConnectionError();
    324 
    325   if (error_handler_) {
    326     std::move(error_handler_).Run();
    327   } else if (error_with_reason_handler_) {
    328     if (reason) {
    329       std::move(error_with_reason_handler_)
    330           .Run(reason->custom_reason, reason->description);
    331     } else {
    332       std::move(error_with_reason_handler_).Run(0, std::string());
    333     }
    334   }
    335 }
    336 
    337 void InterfaceEndpointClient::QueryVersion(
    338     const base::Callback<void(uint32_t)>& callback) {
    339   control_message_proxy_.QueryVersion(callback);
    340 }
    341 
    342 void InterfaceEndpointClient::RequireVersion(uint32_t version) {
    343   control_message_proxy_.RequireVersion(version);
    344 }
    345 
    346 void InterfaceEndpointClient::FlushForTesting() {
    347   control_message_proxy_.FlushForTesting();
    348 }
    349 
    350 void InterfaceEndpointClient::InitControllerIfNecessary() {
    351   if (controller_ || handle_.pending_association())
    352     return;
    353 
    354   controller_ = handle_.group_controller()->AttachEndpointClient(handle_, this,
    355                                                                  task_runner_);
    356   if (expect_sync_requests_)
    357     controller_->AllowWokenUpBySyncWatchOnSameThread();
    358 }
    359 
    360 void InterfaceEndpointClient::OnAssociationEvent(
    361     ScopedInterfaceEndpointHandle::AssociationEvent event) {
    362   if (event == ScopedInterfaceEndpointHandle::ASSOCIATED) {
    363     InitControllerIfNecessary();
    364   } else if (event ==
    365              ScopedInterfaceEndpointHandle::PEER_CLOSED_BEFORE_ASSOCIATION) {
    366     task_runner_->PostTask(FROM_HERE,
    367                            base::Bind(&InterfaceEndpointClient::NotifyError,
    368                                       weak_ptr_factory_.GetWeakPtr(),
    369                                       handle_.disconnect_reason()));
    370   }
    371 }
    372 
// Dispatches a message that has passed through the filter chain.
//
// - Requests expecting a response go to |control_message_handler_| (control
//   messages) or |incoming_receiver_|, together with a ResponderThunk through
//   which the reply can be sent.
// - Responses are matched by request ID against |sync_responses_| (sync) or
//   |async_responders_| (async). An unknown ID rejects the message.
// - All other messages go to |control_message_handler_| or
//   |incoming_receiver_| without a responder.
//
// Returns false to reject the message. Every message received after an error
// has been encountered is rejected.
bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) {
  DCHECK_EQ(handle_.id(), message->interface_id());

  if (encountered_error_) {
    // This message is received after error has been encountered. For associated
    // interfaces, this means the remote side sends a
    // PeerAssociatedEndpointClosed event but continues to send more messages
    // for the same interface. Close the pipe because this shouldn't happen.
    DVLOG(1) << "A message is received for an interface after it has been "
             << "disconnected. Closing the pipe.";
    return false;
  }

  if (message->has_flag(Message::kFlagExpectsResponse)) {
    // The thunk holds only a WeakPtr to this client.
    std::unique_ptr<MessageReceiverWithStatus> responder =
        std::make_unique<ResponderThunk>(weak_ptr_factory_.GetWeakPtr(),
                                         task_runner_);
    if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) {
      return control_message_handler_.AcceptWithResponder(message,
                                                          std::move(responder));
    } else {
      return incoming_receiver_->AcceptWithResponder(message,
                                                     std::move(responder));
    }
  } else if (message->has_flag(Message::kFlagIsResponse)) {
    uint64_t request_id = message->request_id();

    if (message->has_flag(Message::kFlagIsSync)) {
      // Store the response and set the flag that the blocked
      // AcceptWithResponder() call is watching. It consumes the response
      // itself.
      auto it = sync_responses_.find(request_id);
      if (it == sync_responses_.end())
        return false;
      it->second->response = std::move(*message);
      *it->second->response_received = true;
      return true;
    }

    // Take ownership of the responder and erase its entry before running it,
    // so each responder is used at most once.
    auto it = async_responders_.find(request_id);
    if (it == async_responders_.end())
      return false;
    std::unique_ptr<MessageReceiver> responder = std::move(it->second);
    async_responders_.erase(it);
    return responder->Accept(message);
  } else {
    if (mojo::internal::ControlMessageHandler::IsControlMessage(message))
      return control_message_handler_.Accept(message);

    return incoming_receiver_->Accept(message);
  }
}
    422 
    423 }  // namespace mojo
    424