1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "mojo/public/cpp/bindings/lib/router.h" 6 7 #include <stdint.h> 8 9 #include <utility> 10 11 #include "base/bind.h" 12 #include "base/location.h" 13 #include "base/logging.h" 14 #include "base/memory/ptr_util.h" 15 #include "base/stl_util.h" 16 #include "mojo/public/cpp/bindings/sync_call_restrictions.h" 17 18 namespace mojo { 19 namespace internal { 20 21 // ---------------------------------------------------------------------------- 22 23 namespace { 24 25 void DCheckIfInvalid(const base::WeakPtr<Router>& router, 26 const std::string& message) { 27 bool is_valid = router && !router->encountered_error() && router->is_valid(); 28 DCHECK(!is_valid) << message; 29 } 30 31 class ResponderThunk : public MessageReceiverWithStatus { 32 public: 33 explicit ResponderThunk(const base::WeakPtr<Router>& router, 34 scoped_refptr<base::SingleThreadTaskRunner> runner) 35 : router_(router), 36 accept_was_invoked_(false), 37 task_runner_(std::move(runner)) {} 38 ~ResponderThunk() override { 39 if (!accept_was_invoked_) { 40 // The Mojo application handled a message that was expecting a response 41 // but did not send a response. 42 // We raise an error to signal the calling application that an error 43 // condition occurred. Without this the calling application would have no 44 // way of knowing it should stop waiting for a response. 45 if (task_runner_->RunsTasksOnCurrentThread()) { 46 // Please note that even if this code is run from a different task 47 // runner on the same thread as |task_runner_|, it is okay to directly 48 // call Router::RaiseError(), because it will raise error from the 49 // correct task runner asynchronously. 50 if (router_) 51 router_->RaiseError(); 52 } else { 53 task_runner_->PostTask(FROM_HERE, 54 base::Bind(&Router::RaiseError, router_)); 55 } 56 } 57 } 58 59 // MessageReceiver implementation: 60 bool Accept(Message* message) override { 61 DCHECK(task_runner_->RunsTasksOnCurrentThread()); 62 accept_was_invoked_ = true; 63 DCHECK(message->has_flag(Message::kFlagIsResponse)); 64 65 bool result = false; 66 67 if (router_) 68 result = router_->Accept(message); 69 70 return result; 71 } 72 73 // MessageReceiverWithStatus implementation: 74 bool IsValid() override { 75 DCHECK(task_runner_->RunsTasksOnCurrentThread()); 76 return router_ && !router_->encountered_error() && router_->is_valid(); 77 } 78 79 void DCheckInvalid(const std::string& message) override { 80 if (task_runner_->RunsTasksOnCurrentThread()) { 81 DCheckIfInvalid(router_, message); 82 } else { 83 task_runner_->PostTask(FROM_HERE, 84 base::Bind(&DCheckIfInvalid, router_, message)); 85 } 86 } 87 88 private: 89 base::WeakPtr<Router> router_; 90 bool accept_was_invoked_; 91 scoped_refptr<base::SingleThreadTaskRunner> task_runner_; 92 }; 93 94 } // namespace 95 96 // ---------------------------------------------------------------------------- 97 98 Router::SyncResponseInfo::SyncResponseInfo(bool* in_response_received) 99 : response_received(in_response_received) {} 100 101 Router::SyncResponseInfo::~SyncResponseInfo() {} 102 103 // ---------------------------------------------------------------------------- 104 105 Router::HandleIncomingMessageThunk::HandleIncomingMessageThunk(Router* router) 106 : router_(router) { 107 } 108 109 Router::HandleIncomingMessageThunk::~HandleIncomingMessageThunk() { 110 } 111 112 bool Router::HandleIncomingMessageThunk::Accept(Message* message) { 113 return router_->HandleIncomingMessage(message); 114 } 115 116 // ---------------------------------------------------------------------------- 117 118 Router::Router(ScopedMessagePipeHandle message_pipe, 119 FilterChain filters, 120 bool expects_sync_requests, 121 scoped_refptr<base::SingleThreadTaskRunner> runner) 122 : thunk_(this), 123 filters_(std::move(filters)), 124 connector_(std::move(message_pipe), 125 Connector::SINGLE_THREADED_SEND, 126 std::move(runner)), 127 incoming_receiver_(nullptr), 128 next_request_id_(0), 129 testing_mode_(false), 130 pending_task_for_messages_(false), 131 encountered_error_(false), 132 weak_factory_(this) { 133 filters_.SetSink(&thunk_); 134 if (expects_sync_requests) 135 connector_.AllowWokenUpBySyncWatchOnSameThread(); 136 connector_.set_incoming_receiver(filters_.GetHead()); 137 connector_.set_connection_error_handler( 138 base::Bind(&Router::OnConnectionError, base::Unretained(this))); 139 } 140 141 Router::~Router() {} 142 143 bool Router::Accept(Message* message) { 144 DCHECK(thread_checker_.CalledOnValidThread()); 145 DCHECK(!message->has_flag(Message::kFlagExpectsResponse)); 146 return connector_.Accept(message); 147 } 148 149 bool Router::AcceptWithResponder(Message* message, MessageReceiver* responder) { 150 DCHECK(thread_checker_.CalledOnValidThread()); 151 DCHECK(message->has_flag(Message::kFlagExpectsResponse)); 152 153 // Reserve 0 in case we want it to convey special meaning in the future. 154 uint64_t request_id = next_request_id_++; 155 if (request_id == 0) 156 request_id = next_request_id_++; 157 158 bool is_sync = message->has_flag(Message::kFlagIsSync); 159 message->set_request_id(request_id); 160 if (!connector_.Accept(message)) 161 return false; 162 163 if (!is_sync) { 164 // We assume ownership of |responder|. 165 async_responders_[request_id] = base::WrapUnique(responder); 166 return true; 167 } 168 169 SyncCallRestrictions::AssertSyncCallAllowed(); 170 171 bool response_received = false; 172 std::unique_ptr<MessageReceiver> sync_responder(responder); 173 sync_responses_.insert(std::make_pair( 174 request_id, base::WrapUnique(new SyncResponseInfo(&response_received)))); 175 176 base::WeakPtr<Router> weak_self = weak_factory_.GetWeakPtr(); 177 connector_.SyncWatch(&response_received); 178 // Make sure that this instance hasn't been destroyed. 179 if (weak_self) { 180 DCHECK(ContainsKey(sync_responses_, request_id)); 181 auto iter = sync_responses_.find(request_id); 182 DCHECK_EQ(&response_received, iter->second->response_received); 183 if (response_received) { 184 std::unique_ptr<Message> response = std::move(iter->second->response); 185 ignore_result(sync_responder->Accept(response.get())); 186 } 187 sync_responses_.erase(iter); 188 } 189 190 // Return true means that we take ownership of |responder|. 191 return true; 192 } 193 194 void Router::EnableTestingMode() { 195 DCHECK(thread_checker_.CalledOnValidThread()); 196 testing_mode_ = true; 197 connector_.set_enforce_errors_from_incoming_receiver(false); 198 } 199 200 bool Router::HandleIncomingMessage(Message* message) { 201 DCHECK(thread_checker_.CalledOnValidThread()); 202 203 const bool during_sync_call = 204 connector_.during_sync_handle_watcher_callback(); 205 if (!message->has_flag(Message::kFlagIsSync) && 206 (during_sync_call || !pending_messages_.empty())) { 207 std::unique_ptr<Message> pending_message(new Message); 208 message->MoveTo(pending_message.get()); 209 pending_messages_.push(std::move(pending_message)); 210 211 if (!pending_task_for_messages_) { 212 pending_task_for_messages_ = true; 213 connector_.task_runner()->PostTask( 214 FROM_HERE, base::Bind(&Router::HandleQueuedMessages, 215 weak_factory_.GetWeakPtr())); 216 } 217 218 return true; 219 } 220 221 return HandleMessageInternal(message); 222 } 223 224 void Router::HandleQueuedMessages() { 225 DCHECK(thread_checker_.CalledOnValidThread()); 226 DCHECK(pending_task_for_messages_); 227 228 base::WeakPtr<Router> weak_self = weak_factory_.GetWeakPtr(); 229 while (!pending_messages_.empty()) { 230 std::unique_ptr<Message> message(std::move(pending_messages_.front())); 231 pending_messages_.pop(); 232 233 bool result = HandleMessageInternal(message.get()); 234 if (!weak_self) 235 return; 236 237 if (!result && !testing_mode_) { 238 connector_.RaiseError(); 239 break; 240 } 241 } 242 243 pending_task_for_messages_ = false; 244 245 // We may have already seen a connection error from the connector, but 246 // haven't notified the user because we want to process all the queued 247 // messages first. We should do it now. 248 if (connector_.encountered_error() && !encountered_error_) 249 OnConnectionError(); 250 } 251 252 bool Router::HandleMessageInternal(Message* message) { 253 if (message->has_flag(Message::kFlagExpectsResponse)) { 254 if (!incoming_receiver_) 255 return false; 256 257 MessageReceiverWithStatus* responder = new ResponderThunk( 258 weak_factory_.GetWeakPtr(), connector_.task_runner()); 259 bool ok = incoming_receiver_->AcceptWithResponder(message, responder); 260 if (!ok) 261 delete responder; 262 return ok; 263 264 } else if (message->has_flag(Message::kFlagIsResponse)) { 265 uint64_t request_id = message->request_id(); 266 267 if (message->has_flag(Message::kFlagIsSync)) { 268 auto it = sync_responses_.find(request_id); 269 if (it == sync_responses_.end()) { 270 DCHECK(testing_mode_); 271 return false; 272 } 273 it->second->response.reset(new Message()); 274 message->MoveTo(it->second->response.get()); 275 *it->second->response_received = true; 276 return true; 277 } 278 279 auto it = async_responders_.find(request_id); 280 if (it == async_responders_.end()) { 281 DCHECK(testing_mode_); 282 return false; 283 } 284 std::unique_ptr<MessageReceiver> responder = std::move(it->second); 285 async_responders_.erase(it); 286 return responder->Accept(message); 287 } else { 288 if (!incoming_receiver_) 289 return false; 290 291 return incoming_receiver_->Accept(message); 292 } 293 } 294 295 void Router::OnConnectionError() { 296 if (encountered_error_) 297 return; 298 299 if (!pending_messages_.empty()) { 300 // After all the pending messages are processed, we will check whether an 301 // error has been encountered and run the user's connection error handler 302 // if necessary. 303 DCHECK(pending_task_for_messages_); 304 return; 305 } 306 307 if (connector_.during_sync_handle_watcher_callback()) { 308 // We don't want the error handler to reenter an ongoing sync call. 309 connector_.task_runner()->PostTask( 310 FROM_HERE, 311 base::Bind(&Router::OnConnectionError, weak_factory_.GetWeakPtr())); 312 return; 313 } 314 315 encountered_error_ = true; 316 if (!error_handler_.is_null()) 317 error_handler_.Run(); 318 } 319 320 // ---------------------------------------------------------------------------- 321 322 } // namespace internal 323 } // namespace mojo 324