Home | History | Annotate | Download | only in lib
      1 // Copyright 2015 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "mojo/public/cpp/bindings/interface_endpoint_client.h"
      6 
      7 #include <stdint.h>
      8 
      9 #include <utility>
     10 
     11 #include "base/bind.h"
     12 #include "base/location.h"
     13 #include "base/macros.h"
     14 #include "base/memory/ptr_util.h"
     15 #include "base/single_thread_task_runner.h"
     16 #include "base/stl_util.h"
     17 #include "mojo/public/cpp/bindings/associated_group.h"
     18 #include "mojo/public/cpp/bindings/associated_group_controller.h"
     19 #include "mojo/public/cpp/bindings/interface_endpoint_controller.h"
     20 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
     21 
     22 namespace mojo {
     23 
     24 // ----------------------------------------------------------------------------
     25 
     26 namespace {
     27 
     28 void DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient>& client,
     29                    const std::string& message) {
     30   bool is_valid = client && !client->encountered_error();
     31   DCHECK(!is_valid) << message;
     32 }
     33 
     34 // When receiving an incoming message which expects a repsonse,
     35 // InterfaceEndpointClient creates a ResponderThunk object and passes it to the
     36 // incoming message receiver. When the receiver finishes processing the message,
     37 // it can provide a response using this object.
     38 class ResponderThunk : public MessageReceiverWithStatus {
     39  public:
     40   explicit ResponderThunk(
     41       const base::WeakPtr<InterfaceEndpointClient>& endpoint_client,
     42       scoped_refptr<base::SingleThreadTaskRunner> runner)
     43       : endpoint_client_(endpoint_client),
     44         accept_was_invoked_(false),
     45         task_runner_(std::move(runner)) {}
     46   ~ResponderThunk() override {
     47     if (!accept_was_invoked_) {
     48       // The Mojo application handled a message that was expecting a response
     49       // but did not send a response.
     50       // We raise an error to signal the calling application that an error
     51       // condition occurred. Without this the calling application would have no
     52       // way of knowing it should stop waiting for a response.
     53       if (task_runner_->RunsTasksOnCurrentThread()) {
     54         // Please note that even if this code is run from a different task
     55         // runner on the same thread as |task_runner_|, it is okay to directly
     56         // call InterfaceEndpointClient::RaiseError(), because it will raise
     57         // error from the correct task runner asynchronously.
     58         if (endpoint_client_) {
     59           endpoint_client_->RaiseError();
     60         }
     61       } else {
     62         task_runner_->PostTask(
     63             FROM_HERE,
     64             base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_));
     65       }
     66     }
     67   }
     68 
     69   // MessageReceiver implementation:
     70   bool Accept(Message* message) override {
     71     DCHECK(task_runner_->RunsTasksOnCurrentThread());
     72     accept_was_invoked_ = true;
     73     DCHECK(message->has_flag(Message::kFlagIsResponse));
     74 
     75     bool result = false;
     76 
     77     if (endpoint_client_)
     78       result = endpoint_client_->Accept(message);
     79 
     80     return result;
     81   }
     82 
     83   // MessageReceiverWithStatus implementation:
     84   bool IsValid() override {
     85     DCHECK(task_runner_->RunsTasksOnCurrentThread());
     86     return endpoint_client_ && !endpoint_client_->encountered_error();
     87   }
     88 
     89   void DCheckInvalid(const std::string& message) override {
     90     if (task_runner_->RunsTasksOnCurrentThread()) {
     91       DCheckIfInvalid(endpoint_client_, message);
     92     } else {
     93       task_runner_->PostTask(
     94           FROM_HERE, base::Bind(&DCheckIfInvalid, endpoint_client_, message));
     95     }
     96  }
     97 
     98  private:
     99   base::WeakPtr<InterfaceEndpointClient> endpoint_client_;
    100   bool accept_was_invoked_;
    101   scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
    102 
    103   DISALLOW_COPY_AND_ASSIGN(ResponderThunk);
    104 };
    105 
    106 }  // namespace
    107 
    108 // ----------------------------------------------------------------------------
    109 
    110 InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo(
    111     bool* in_response_received)
    112     : response_received(in_response_received) {}
    113 
    114 InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {}
    115 
    116 // ----------------------------------------------------------------------------
    117 
    118 InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk(
    119     InterfaceEndpointClient* owner)
    120     : owner_(owner) {}
    121 
    122 InterfaceEndpointClient::HandleIncomingMessageThunk::
    123     ~HandleIncomingMessageThunk() {}
    124 
    125 bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept(
    126     Message* message) {
    127   return owner_->HandleValidatedMessage(message);
    128 }
    129 
    130 // ----------------------------------------------------------------------------
    131 
    132 InterfaceEndpointClient::InterfaceEndpointClient(
    133     ScopedInterfaceEndpointHandle handle,
    134     MessageReceiverWithResponderStatus* receiver,
    135     std::unique_ptr<MessageFilter> payload_validator,
    136     bool expect_sync_requests,
    137     scoped_refptr<base::SingleThreadTaskRunner> runner)
    138     : handle_(std::move(handle)),
    139       incoming_receiver_(receiver),
    140       payload_validator_(std::move(payload_validator)),
    141       thunk_(this),
    142       next_request_id_(1),
    143       encountered_error_(false),
    144       task_runner_(std::move(runner)),
    145       weak_ptr_factory_(this) {
    146   DCHECK(handle_.is_valid());
    147   DCHECK(handle_.is_local());
    148 
    149   // TODO(yzshen): the way to use validator (or message filter in general)
    150   // directly is a little awkward.
    151   payload_validator_->set_sink(&thunk_);
    152 
    153   controller_ = handle_.group_controller()->AttachEndpointClient(
    154       handle_, this, task_runner_);
    155   if (expect_sync_requests)
    156     controller_->AllowWokenUpBySyncWatchOnSameThread();
    157 }
    158 
    159 InterfaceEndpointClient::~InterfaceEndpointClient() {
    160   DCHECK(thread_checker_.CalledOnValidThread());
    161 
    162   handle_.group_controller()->DetachEndpointClient(handle_);
    163 }
    164 
    165 AssociatedGroup* InterfaceEndpointClient::associated_group() {
    166   if (!associated_group_)
    167     associated_group_ = handle_.group_controller()->CreateAssociatedGroup();
    168   return associated_group_.get();
    169 }
    170 
    171 uint32_t InterfaceEndpointClient::interface_id() const {
    172   DCHECK(thread_checker_.CalledOnValidThread());
    173   return handle_.id();
    174 }
    175 
    176 ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() {
    177   DCHECK(thread_checker_.CalledOnValidThread());
    178   DCHECK(!has_pending_responders());
    179 
    180   if (!handle_.is_valid())
    181     return ScopedInterfaceEndpointHandle();
    182 
    183   controller_ = nullptr;
    184   handle_.group_controller()->DetachEndpointClient(handle_);
    185 
    186   return std::move(handle_);
    187 }
    188 
    189 void InterfaceEndpointClient::RaiseError() {
    190   DCHECK(thread_checker_.CalledOnValidThread());
    191 
    192   handle_.group_controller()->RaiseError();
    193 }
    194 
    195 bool InterfaceEndpointClient::Accept(Message* message) {
    196   DCHECK(thread_checker_.CalledOnValidThread());
    197   DCHECK(controller_);
    198   DCHECK(!message->has_flag(Message::kFlagExpectsResponse));
    199 
    200   if (encountered_error_)
    201     return false;
    202 
    203   return controller_->SendMessage(message);
    204 }
    205 
// Sends |message|, which must expect a response, after stamping it with a new
// non-zero request ID. |responder| receives the response.
//
// Async messages: after a successful send, |responder| is stored in
// |async_responders_| and run later by HandleValidatedMessage().
// Sync messages (kFlagIsSync): waits in controller_->SyncWatch() on a local
// flag, then hands the stored response (if one arrived) to |responder|.
//
// Returns false, without taking ownership of |responder|, if an error was
// already encountered or the send fails. Returns true, having taken ownership
// of |responder|, otherwise.
bool InterfaceEndpointClient::AcceptWithResponder(Message* message,
                                                  MessageReceiver* responder) {
  DCHECK(thread_checker_.CalledOnValidThread());
  DCHECK(controller_);
  DCHECK(message->has_flag(Message::kFlagExpectsResponse));

  if (encountered_error_)
    return false;

  // Reserve 0 in case we want it to convey special meaning in the future.
  uint64_t request_id = next_request_id_++;
  if (request_id == 0)
    request_id = next_request_id_++;

  message->set_request_id(request_id);

  // The sync flag is read before the send so |message| is not inspected after
  // it has been handed to SendMessage().
  // NOTE(review): presumably SendMessage() may consume |message| — confirm.
  bool is_sync = message->has_flag(Message::kFlagIsSync);
  if (!controller_->SendMessage(message))
    return false;

  if (!is_sync) {
    // We assume ownership of |responder|.
    async_responders_[request_id] = base::WrapUnique(responder);
    return true;
  }

  SyncCallRestrictions::AssertSyncCallAllowed();

  // |response_received| lives on this stack frame. HandleValidatedMessage()
  // sets it through the SyncResponseInfo registered below when the sync
  // response for |request_id| arrives.
  bool response_received = false;
  std::unique_ptr<MessageReceiver> sync_responder(responder);
  sync_responses_.insert(std::make_pair(
      request_id, base::WrapUnique(new SyncResponseInfo(&response_received))));

  // SyncWatch() may run work that destroys |this|, so take a weak pointer
  // first and check it before touching any member afterwards.
  base::WeakPtr<InterfaceEndpointClient> weak_self =
      weak_ptr_factory_.GetWeakPtr();
  controller_->SyncWatch(&response_received);
  // Make sure that this instance hasn't been destroyed.
  if (weak_self) {
    DCHECK(ContainsKey(sync_responses_, request_id));
    auto iter = sync_responses_.find(request_id);
    DCHECK_EQ(&response_received, iter->second->response_received);
    if (response_received) {
      // Deliver the stored response; the responder's result is not propagated.
      std::unique_ptr<Message> response = std::move(iter->second->response);
      ignore_result(sync_responder->Accept(response.get()));
    }
    sync_responses_.erase(iter);
  }

  // Return true means that we take ownership of |responder|.
  return true;
}
    257 
    258 bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) {
    259   DCHECK(thread_checker_.CalledOnValidThread());
    260 
    261   return payload_validator_->Accept(message);
    262 }
    263 
    264 void InterfaceEndpointClient::NotifyError() {
    265   DCHECK(thread_checker_.CalledOnValidThread());
    266 
    267   if (encountered_error_)
    268     return;
    269   encountered_error_ = true;
    270   if (!error_handler_.is_null())
    271     error_handler_.Run();
    272 }
    273 
    274 bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) {
    275   DCHECK_EQ(handle_.id(), message->interface_id());
    276 
    277   if (message->has_flag(Message::kFlagExpectsResponse)) {
    278     if (!incoming_receiver_)
    279       return false;
    280 
    281     MessageReceiverWithStatus* responder =
    282         new ResponderThunk(weak_ptr_factory_.GetWeakPtr(), task_runner_);
    283     bool ok = incoming_receiver_->AcceptWithResponder(message, responder);
    284     if (!ok)
    285       delete responder;
    286     return ok;
    287   } else if (message->has_flag(Message::kFlagIsResponse)) {
    288     uint64_t request_id = message->request_id();
    289 
    290     if (message->has_flag(Message::kFlagIsSync)) {
    291       auto it = sync_responses_.find(request_id);
    292       if (it == sync_responses_.end())
    293         return false;
    294       it->second->response.reset(new Message());
    295       message->MoveTo(it->second->response.get());
    296       *it->second->response_received = true;
    297       return true;
    298     }
    299 
    300     auto it = async_responders_.find(request_id);
    301     if (it == async_responders_.end())
    302       return false;
    303     std::unique_ptr<MessageReceiver> responder = std::move(it->second);
    304     async_responders_.erase(it);
    305     return responder->Accept(message);
    306   } else {
    307     if (!incoming_receiver_)
    308       return false;
    309 
    310     return incoming_receiver_->Accept(message);
    311   }
    312 }
    313 
    314 }  // namespace mojo
    315