1 // Copyright 2015 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "mojo/public/cpp/bindings/interface_endpoint_client.h" 6 7 #include <stdint.h> 8 9 #include <utility> 10 11 #include "base/bind.h" 12 #include "base/location.h" 13 #include "base/macros.h" 14 #include "base/memory/ptr_util.h" 15 #include "base/single_thread_task_runner.h" 16 #include "base/stl_util.h" 17 #include "mojo/public/cpp/bindings/associated_group.h" 18 #include "mojo/public/cpp/bindings/associated_group_controller.h" 19 #include "mojo/public/cpp/bindings/interface_endpoint_controller.h" 20 #include "mojo/public/cpp/bindings/sync_call_restrictions.h" 21 22 namespace mojo { 23 24 // ---------------------------------------------------------------------------- 25 26 namespace { 27 28 void DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient>& client, 29 const std::string& message) { 30 bool is_valid = client && !client->encountered_error(); 31 DCHECK(!is_valid) << message; 32 } 33 34 // When receiving an incoming message which expects a repsonse, 35 // InterfaceEndpointClient creates a ResponderThunk object and passes it to the 36 // incoming message receiver. When the receiver finishes processing the message, 37 // it can provide a response using this object. 38 class ResponderThunk : public MessageReceiverWithStatus { 39 public: 40 explicit ResponderThunk( 41 const base::WeakPtr<InterfaceEndpointClient>& endpoint_client, 42 scoped_refptr<base::SingleThreadTaskRunner> runner) 43 : endpoint_client_(endpoint_client), 44 accept_was_invoked_(false), 45 task_runner_(std::move(runner)) {} 46 ~ResponderThunk() override { 47 if (!accept_was_invoked_) { 48 // The Mojo application handled a message that was expecting a response 49 // but did not send a response. 50 // We raise an error to signal the calling application that an error 51 // condition occurred. Without this the calling application would have no 52 // way of knowing it should stop waiting for a response. 53 if (task_runner_->RunsTasksOnCurrentThread()) { 54 // Please note that even if this code is run from a different task 55 // runner on the same thread as |task_runner_|, it is okay to directly 56 // call InterfaceEndpointClient::RaiseError(), because it will raise 57 // error from the correct task runner asynchronously. 58 if (endpoint_client_) { 59 endpoint_client_->RaiseError(); 60 } 61 } else { 62 task_runner_->PostTask( 63 FROM_HERE, 64 base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_)); 65 } 66 } 67 } 68 69 // MessageReceiver implementation: 70 bool Accept(Message* message) override { 71 DCHECK(task_runner_->RunsTasksOnCurrentThread()); 72 accept_was_invoked_ = true; 73 DCHECK(message->has_flag(Message::kFlagIsResponse)); 74 75 bool result = false; 76 77 if (endpoint_client_) 78 result = endpoint_client_->Accept(message); 79 80 return result; 81 } 82 83 // MessageReceiverWithStatus implementation: 84 bool IsValid() override { 85 DCHECK(task_runner_->RunsTasksOnCurrentThread()); 86 return endpoint_client_ && !endpoint_client_->encountered_error(); 87 } 88 89 void DCheckInvalid(const std::string& message) override { 90 if (task_runner_->RunsTasksOnCurrentThread()) { 91 DCheckIfInvalid(endpoint_client_, message); 92 } else { 93 task_runner_->PostTask( 94 FROM_HERE, base::Bind(&DCheckIfInvalid, endpoint_client_, message)); 95 } 96 } 97 98 private: 99 base::WeakPtr<InterfaceEndpointClient> endpoint_client_; 100 bool accept_was_invoked_; 101 scoped_refptr<base::SingleThreadTaskRunner> task_runner_; 102 103 DISALLOW_COPY_AND_ASSIGN(ResponderThunk); 104 }; 105 106 } // namespace 107 108 // ---------------------------------------------------------------------------- 109 110 InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo( 111 bool* in_response_received) 112 : response_received(in_response_received) {} 113 114 InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {} 115 116 // ---------------------------------------------------------------------------- 117 118 InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk( 119 InterfaceEndpointClient* owner) 120 : owner_(owner) {} 121 122 InterfaceEndpointClient::HandleIncomingMessageThunk:: 123 ~HandleIncomingMessageThunk() {} 124 125 bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept( 126 Message* message) { 127 return owner_->HandleValidatedMessage(message); 128 } 129 130 // ---------------------------------------------------------------------------- 131 132 InterfaceEndpointClient::InterfaceEndpointClient( 133 ScopedInterfaceEndpointHandle handle, 134 MessageReceiverWithResponderStatus* receiver, 135 std::unique_ptr<MessageFilter> payload_validator, 136 bool expect_sync_requests, 137 scoped_refptr<base::SingleThreadTaskRunner> runner) 138 : handle_(std::move(handle)), 139 incoming_receiver_(receiver), 140 payload_validator_(std::move(payload_validator)), 141 thunk_(this), 142 next_request_id_(1), 143 encountered_error_(false), 144 task_runner_(std::move(runner)), 145 weak_ptr_factory_(this) { 146 DCHECK(handle_.is_valid()); 147 DCHECK(handle_.is_local()); 148 149 // TODO(yzshen): the way to use validator (or message filter in general) 150 // directly is a little awkward. 151 payload_validator_->set_sink(&thunk_); 152 153 controller_ = handle_.group_controller()->AttachEndpointClient( 154 handle_, this, task_runner_); 155 if (expect_sync_requests) 156 controller_->AllowWokenUpBySyncWatchOnSameThread(); 157 } 158 159 InterfaceEndpointClient::~InterfaceEndpointClient() { 160 DCHECK(thread_checker_.CalledOnValidThread()); 161 162 handle_.group_controller()->DetachEndpointClient(handle_); 163 } 164 165 AssociatedGroup* InterfaceEndpointClient::associated_group() { 166 if (!associated_group_) 167 associated_group_ = handle_.group_controller()->CreateAssociatedGroup(); 168 return associated_group_.get(); 169 } 170 171 uint32_t InterfaceEndpointClient::interface_id() const { 172 DCHECK(thread_checker_.CalledOnValidThread()); 173 return handle_.id(); 174 } 175 176 ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() { 177 DCHECK(thread_checker_.CalledOnValidThread()); 178 DCHECK(!has_pending_responders()); 179 180 if (!handle_.is_valid()) 181 return ScopedInterfaceEndpointHandle(); 182 183 controller_ = nullptr; 184 handle_.group_controller()->DetachEndpointClient(handle_); 185 186 return std::move(handle_); 187 } 188 189 void InterfaceEndpointClient::RaiseError() { 190 DCHECK(thread_checker_.CalledOnValidThread()); 191 192 handle_.group_controller()->RaiseError(); 193 } 194 195 bool InterfaceEndpointClient::Accept(Message* message) { 196 DCHECK(thread_checker_.CalledOnValidThread()); 197 DCHECK(controller_); 198 DCHECK(!message->has_flag(Message::kFlagExpectsResponse)); 199 200 if (encountered_error_) 201 return false; 202 203 return controller_->SendMessage(message); 204 } 205 206 bool InterfaceEndpointClient::AcceptWithResponder(Message* message, 207 MessageReceiver* responder) { 208 DCHECK(thread_checker_.CalledOnValidThread()); 209 DCHECK(controller_); 210 DCHECK(message->has_flag(Message::kFlagExpectsResponse)); 211 212 if (encountered_error_) 213 return false; 214 215 // Reserve 0 in case we want it to convey special meaning in the future. 216 uint64_t request_id = next_request_id_++; 217 if (request_id == 0) 218 request_id = next_request_id_++; 219 220 message->set_request_id(request_id); 221 222 bool is_sync = message->has_flag(Message::kFlagIsSync); 223 if (!controller_->SendMessage(message)) 224 return false; 225 226 if (!is_sync) { 227 // We assume ownership of |responder|. 228 async_responders_[request_id] = base::WrapUnique(responder); 229 return true; 230 } 231 232 SyncCallRestrictions::AssertSyncCallAllowed(); 233 234 bool response_received = false; 235 std::unique_ptr<MessageReceiver> sync_responder(responder); 236 sync_responses_.insert(std::make_pair( 237 request_id, base::WrapUnique(new SyncResponseInfo(&response_received)))); 238 239 base::WeakPtr<InterfaceEndpointClient> weak_self = 240 weak_ptr_factory_.GetWeakPtr(); 241 controller_->SyncWatch(&response_received); 242 // Make sure that this instance hasn't been destroyed. 243 if (weak_self) { 244 DCHECK(ContainsKey(sync_responses_, request_id)); 245 auto iter = sync_responses_.find(request_id); 246 DCHECK_EQ(&response_received, iter->second->response_received); 247 if (response_received) { 248 std::unique_ptr<Message> response = std::move(iter->second->response); 249 ignore_result(sync_responder->Accept(response.get())); 250 } 251 sync_responses_.erase(iter); 252 } 253 254 // Return true means that we take ownership of |responder|. 255 return true; 256 } 257 258 bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) { 259 DCHECK(thread_checker_.CalledOnValidThread()); 260 261 return payload_validator_->Accept(message); 262 } 263 264 void InterfaceEndpointClient::NotifyError() { 265 DCHECK(thread_checker_.CalledOnValidThread()); 266 267 if (encountered_error_) 268 return; 269 encountered_error_ = true; 270 if (!error_handler_.is_null()) 271 error_handler_.Run(); 272 } 273 274 bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) { 275 DCHECK_EQ(handle_.id(), message->interface_id()); 276 277 if (message->has_flag(Message::kFlagExpectsResponse)) { 278 if (!incoming_receiver_) 279 return false; 280 281 MessageReceiverWithStatus* responder = 282 new ResponderThunk(weak_ptr_factory_.GetWeakPtr(), task_runner_); 283 bool ok = incoming_receiver_->AcceptWithResponder(message, responder); 284 if (!ok) 285 delete responder; 286 return ok; 287 } else if (message->has_flag(Message::kFlagIsResponse)) { 288 uint64_t request_id = message->request_id(); 289 290 if (message->has_flag(Message::kFlagIsSync)) { 291 auto it = sync_responses_.find(request_id); 292 if (it == sync_responses_.end()) 293 return false; 294 it->second->response.reset(new Message()); 295 message->MoveTo(it->second->response.get()); 296 *it->second->response_received = true; 297 return true; 298 } 299 300 auto it = async_responders_.find(request_id); 301 if (it == async_responders_.end()) 302 return false; 303 std::unique_ptr<MessageReceiver> responder = std::move(it->second); 304 async_responders_.erase(it); 305 return responder->Accept(message); 306 } else { 307 if (!incoming_receiver_) 308 return false; 309 310 return incoming_receiver_->Accept(message); 311 } 312 } 313 314 } // namespace mojo 315