1 // Copyright 2016 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "mojo/edk/system/channel.h" 6 7 #include <stdint.h> 8 #include <windows.h> 9 10 #include <algorithm> 11 #include <deque> 12 #include <limits> 13 #include <memory> 14 15 #include "base/bind.h" 16 #include "base/location.h" 17 #include "base/macros.h" 18 #include "base/memory/ref_counted.h" 19 #include "base/message_loop/message_loop.h" 20 #include "base/synchronization/lock.h" 21 #include "base/task_runner.h" 22 #include "base/win/win_util.h" 23 #include "mojo/edk/embedder/platform_handle_vector.h" 24 25 namespace mojo { 26 namespace edk { 27 28 namespace { 29 30 // A view over a Channel::Message object. The write queue uses these since 31 // large messages may need to be sent in chunks. 32 class MessageView { 33 public: 34 // Owns |message|. |offset| indexes the first unsent byte in the message. 35 MessageView(Channel::MessagePtr message, size_t offset) 36 : message_(std::move(message)), 37 offset_(offset) { 38 DCHECK_GT(message_->data_num_bytes(), offset_); 39 } 40 41 MessageView(MessageView&& other) { *this = std::move(other); } 42 43 MessageView& operator=(MessageView&& other) { 44 message_ = std::move(other.message_); 45 offset_ = other.offset_; 46 return *this; 47 } 48 49 ~MessageView() {} 50 51 const void* data() const { 52 return static_cast<const char*>(message_->data()) + offset_; 53 } 54 55 size_t data_num_bytes() const { return message_->data_num_bytes() - offset_; } 56 57 size_t data_offset() const { return offset_; } 58 void advance_data_offset(size_t num_bytes) { 59 DCHECK_GE(message_->data_num_bytes(), offset_ + num_bytes); 60 offset_ += num_bytes; 61 } 62 63 Channel::MessagePtr TakeChannelMessage() { return std::move(message_); } 64 65 private: 66 Channel::MessagePtr message_; 67 size_t offset_; 68 69 DISALLOW_COPY_AND_ASSIGN(MessageView); 70 }; 71 72 class ChannelWin : public Channel, 73 public base::MessageLoop::DestructionObserver, 74 public base::MessageLoopForIO::IOHandler { 75 public: 76 ChannelWin(Delegate* delegate, 77 ScopedPlatformHandle handle, 78 scoped_refptr<base::TaskRunner> io_task_runner) 79 : Channel(delegate), 80 self_(this), 81 handle_(std::move(handle)), 82 io_task_runner_(io_task_runner) { 83 CHECK(handle_.is_valid()); 84 85 wait_for_connect_ = handle_.get().needs_connection; 86 } 87 88 void Start() override { 89 io_task_runner_->PostTask( 90 FROM_HERE, base::Bind(&ChannelWin::StartOnIOThread, this)); 91 } 92 93 void ShutDownImpl() override { 94 // Always shut down asynchronously when called through the public interface. 95 io_task_runner_->PostTask( 96 FROM_HERE, base::Bind(&ChannelWin::ShutDownOnIOThread, this)); 97 } 98 99 void Write(MessagePtr message) override { 100 bool write_error = false; 101 { 102 base::AutoLock lock(write_lock_); 103 if (reject_writes_) 104 return; 105 106 bool write_now = !delay_writes_ && outgoing_messages_.empty(); 107 outgoing_messages_.emplace_back(std::move(message), 0); 108 109 if (write_now && !WriteNoLock(outgoing_messages_.front())) 110 reject_writes_ = write_error = true; 111 } 112 if (write_error) { 113 // Do not synchronously invoke OnError(). Write() may have been called by 114 // the delegate and we don't want to re-enter it. 115 io_task_runner_->PostTask(FROM_HERE, 116 base::Bind(&ChannelWin::OnError, this)); 117 } 118 } 119 120 void LeakHandle() override { 121 DCHECK(io_task_runner_->RunsTasksOnCurrentThread()); 122 leak_handle_ = true; 123 } 124 125 bool GetReadPlatformHandles( 126 size_t num_handles, 127 const void* extra_header, 128 size_t extra_header_size, 129 ScopedPlatformHandleVectorPtr* handles) override { 130 if (num_handles > std::numeric_limits<uint16_t>::max()) 131 return false; 132 using HandleEntry = Channel::Message::HandleEntry; 133 size_t handles_size = sizeof(HandleEntry) * num_handles; 134 if (handles_size > extra_header_size) 135 return false; 136 DCHECK(extra_header); 137 handles->reset(new PlatformHandleVector(num_handles)); 138 const HandleEntry* extra_header_handles = 139 reinterpret_cast<const HandleEntry*>(extra_header); 140 for (size_t i = 0; i < num_handles; i++) { 141 (*handles)->at(i).handle = 142 base::win::Uint32ToHandle(extra_header_handles[i].handle); 143 } 144 return true; 145 } 146 147 private: 148 // May run on any thread. 149 ~ChannelWin() override {} 150 151 void StartOnIOThread() { 152 base::MessageLoop::current()->AddDestructionObserver(this); 153 base::MessageLoopForIO::current()->RegisterIOHandler( 154 handle_.get().handle, this); 155 156 if (wait_for_connect_) { 157 BOOL ok = ConnectNamedPipe(handle_.get().handle, 158 &connect_context_.overlapped); 159 if (ok) { 160 PLOG(ERROR) << "Unexpected success while waiting for pipe connection"; 161 OnError(); 162 return; 163 } 164 165 const DWORD err = GetLastError(); 166 switch (err) { 167 case ERROR_PIPE_CONNECTED: 168 wait_for_connect_ = false; 169 break; 170 case ERROR_IO_PENDING: 171 AddRef(); 172 return; 173 case ERROR_NO_DATA: 174 OnError(); 175 return; 176 } 177 } 178 179 // Now that we have registered our IOHandler, we can start writing. 180 { 181 base::AutoLock lock(write_lock_); 182 if (delay_writes_) { 183 delay_writes_ = false; 184 WriteNextNoLock(); 185 } 186 } 187 188 // Keep this alive in case we synchronously run shutdown. 189 scoped_refptr<ChannelWin> keep_alive(this); 190 ReadMore(0); 191 } 192 193 void ShutDownOnIOThread() { 194 base::MessageLoop::current()->RemoveDestructionObserver(this); 195 196 // BUG(crbug.com/583525): This function is expected to be called once, and 197 // |handle_| should be valid at this point. 198 CHECK(handle_.is_valid()); 199 CancelIo(handle_.get().handle); 200 if (leak_handle_) 201 ignore_result(handle_.release()); 202 handle_.reset(); 203 204 // May destroy the |this| if it was the last reference. 205 self_ = nullptr; 206 } 207 208 // base::MessageLoop::DestructionObserver: 209 void WillDestroyCurrentMessageLoop() override { 210 DCHECK(io_task_runner_->RunsTasksOnCurrentThread()); 211 if (self_) 212 ShutDownOnIOThread(); 213 } 214 215 // base::MessageLoop::IOHandler: 216 void OnIOCompleted(base::MessageLoopForIO::IOContext* context, 217 DWORD bytes_transfered, 218 DWORD error) override { 219 if (error != ERROR_SUCCESS) { 220 OnError(); 221 } else if (context == &connect_context_) { 222 DCHECK(wait_for_connect_); 223 wait_for_connect_ = false; 224 ReadMore(0); 225 226 base::AutoLock lock(write_lock_); 227 if (delay_writes_) { 228 delay_writes_ = false; 229 WriteNextNoLock(); 230 } 231 } else if (context == &read_context_) { 232 OnReadDone(static_cast<size_t>(bytes_transfered)); 233 } else { 234 CHECK(context == &write_context_); 235 OnWriteDone(static_cast<size_t>(bytes_transfered)); 236 } 237 Release(); // Balancing reference taken after ReadFile / WriteFile. 238 } 239 240 void OnReadDone(size_t bytes_read) { 241 if (bytes_read > 0) { 242 size_t next_read_size = 0; 243 if (OnReadComplete(bytes_read, &next_read_size)) { 244 ReadMore(next_read_size); 245 } else { 246 OnError(); 247 } 248 } else if (bytes_read == 0) { 249 OnError(); 250 } 251 } 252 253 void OnWriteDone(size_t bytes_written) { 254 if (bytes_written == 0) 255 return; 256 257 bool write_error = false; 258 { 259 base::AutoLock lock(write_lock_); 260 261 DCHECK(!outgoing_messages_.empty()); 262 263 MessageView& message_view = outgoing_messages_.front(); 264 message_view.advance_data_offset(bytes_written); 265 if (message_view.data_num_bytes() == 0) { 266 Channel::MessagePtr message = message_view.TakeChannelMessage(); 267 outgoing_messages_.pop_front(); 268 269 // Clear any handles so they don't get closed on destruction. 270 ScopedPlatformHandleVectorPtr handles = message->TakeHandles(); 271 if (handles) 272 handles->clear(); 273 } 274 275 if (!WriteNextNoLock()) 276 reject_writes_ = write_error = true; 277 } 278 if (write_error) 279 OnError(); 280 } 281 282 void ReadMore(size_t next_read_size_hint) { 283 size_t buffer_capacity = next_read_size_hint; 284 char* buffer = GetReadBuffer(&buffer_capacity); 285 DCHECK_GT(buffer_capacity, 0u); 286 287 BOOL ok = ReadFile(handle_.get().handle, 288 buffer, 289 static_cast<DWORD>(buffer_capacity), 290 NULL, 291 &read_context_.overlapped); 292 293 if (ok || GetLastError() == ERROR_IO_PENDING) { 294 AddRef(); // Will be balanced in OnIOCompleted 295 } else { 296 OnError(); 297 } 298 } 299 300 // Attempts to write a message directly to the channel. If the full message 301 // cannot be written, it's queued and a wait is initiated to write the message 302 // ASAP on the I/O thread. 303 bool WriteNoLock(const MessageView& message_view) { 304 BOOL ok = WriteFile(handle_.get().handle, 305 message_view.data(), 306 static_cast<DWORD>(message_view.data_num_bytes()), 307 NULL, 308 &write_context_.overlapped); 309 310 if (ok || GetLastError() == ERROR_IO_PENDING) { 311 AddRef(); // Will be balanced in OnIOCompleted. 312 return true; 313 } 314 return false; 315 } 316 317 bool WriteNextNoLock() { 318 if (outgoing_messages_.empty()) 319 return true; 320 return WriteNoLock(outgoing_messages_.front()); 321 } 322 323 // Keeps the Channel alive at least until explicit shutdown on the IO thread. 324 scoped_refptr<Channel> self_; 325 326 ScopedPlatformHandle handle_; 327 scoped_refptr<base::TaskRunner> io_task_runner_; 328 329 base::MessageLoopForIO::IOContext connect_context_; 330 base::MessageLoopForIO::IOContext read_context_; 331 base::MessageLoopForIO::IOContext write_context_; 332 333 // Protects |reject_writes_| and |outgoing_messages_|. 334 base::Lock write_lock_; 335 336 bool delay_writes_ = true; 337 338 bool reject_writes_ = false; 339 std::deque<MessageView> outgoing_messages_; 340 341 bool wait_for_connect_; 342 343 bool leak_handle_ = false; 344 345 DISALLOW_COPY_AND_ASSIGN(ChannelWin); 346 }; 347 348 } // namespace 349 350 // static 351 scoped_refptr<Channel> Channel::Create( 352 Delegate* delegate, 353 ConnectionParams connection_params, 354 scoped_refptr<base::TaskRunner> io_task_runner) { 355 return new ChannelWin(delegate, connection_params.TakeChannelHandle(), 356 io_task_runner); 357 } 358 359 } // namespace edk 360 } // namespace mojo 361