Home | History | Annotate | Download | only in system
      1 // Copyright 2016 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "mojo/edk/system/channel.h"
      6 
      7 #include <stdint.h>
      8 #include <windows.h>
      9 
     10 #include <algorithm>
     11 #include <deque>
     12 #include <limits>
     13 #include <memory>
     14 
     15 #include "base/bind.h"
     16 #include "base/location.h"
     17 #include "base/macros.h"
     18 #include "base/memory/ref_counted.h"
     19 #include "base/message_loop/message_loop.h"
     20 #include "base/synchronization/lock.h"
     21 #include "base/task_runner.h"
     22 #include "base/win/win_util.h"
     23 #include "mojo/edk/embedder/platform_handle_vector.h"
     24 
     25 namespace mojo {
     26 namespace edk {
     27 
     28 namespace {
     29 
     30 // A view over a Channel::Message object. The write queue uses these since
     31 // large messages may need to be sent in chunks.
     32 class MessageView {
     33  public:
     34   // Owns |message|. |offset| indexes the first unsent byte in the message.
     35   MessageView(Channel::MessagePtr message, size_t offset)
     36       : message_(std::move(message)),
     37         offset_(offset) {
     38     DCHECK_GT(message_->data_num_bytes(), offset_);
     39   }
     40 
     41   MessageView(MessageView&& other) { *this = std::move(other); }
     42 
     43   MessageView& operator=(MessageView&& other) {
     44     message_ = std::move(other.message_);
     45     offset_ = other.offset_;
     46     return *this;
     47   }
     48 
     49   ~MessageView() {}
     50 
     51   const void* data() const {
     52     return static_cast<const char*>(message_->data()) + offset_;
     53   }
     54 
     55   size_t data_num_bytes() const { return message_->data_num_bytes() - offset_; }
     56 
     57   size_t data_offset() const { return offset_; }
     58   void advance_data_offset(size_t num_bytes) {
     59     DCHECK_GE(message_->data_num_bytes(), offset_ + num_bytes);
     60     offset_ += num_bytes;
     61   }
     62 
     63   Channel::MessagePtr TakeChannelMessage() { return std::move(message_); }
     64 
     65  private:
     66   Channel::MessagePtr message_;
     67   size_t offset_;
     68 
     69   DISALLOW_COPY_AND_ASSIGN(MessageView);
     70 };
     71 
     72 class ChannelWin : public Channel,
     73                    public base::MessageLoop::DestructionObserver,
     74                    public base::MessageLoopForIO::IOHandler {
     75  public:
     76   ChannelWin(Delegate* delegate,
     77              ScopedPlatformHandle handle,
     78              scoped_refptr<base::TaskRunner> io_task_runner)
     79       : Channel(delegate),
     80         self_(this),
     81         handle_(std::move(handle)),
     82         io_task_runner_(io_task_runner) {
     83     CHECK(handle_.is_valid());
     84 
     85     wait_for_connect_ = handle_.get().needs_connection;
     86   }
     87 
     88   void Start() override {
     89     io_task_runner_->PostTask(
     90         FROM_HERE, base::Bind(&ChannelWin::StartOnIOThread, this));
     91   }
     92 
     93   void ShutDownImpl() override {
     94     // Always shut down asynchronously when called through the public interface.
     95     io_task_runner_->PostTask(
     96         FROM_HERE, base::Bind(&ChannelWin::ShutDownOnIOThread, this));
     97   }
     98 
     99   void Write(MessagePtr message) override {
    100     bool write_error = false;
    101     {
    102       base::AutoLock lock(write_lock_);
    103       if (reject_writes_)
    104         return;
    105 
    106       bool write_now = !delay_writes_ && outgoing_messages_.empty();
    107       outgoing_messages_.emplace_back(std::move(message), 0);
    108 
    109       if (write_now && !WriteNoLock(outgoing_messages_.front()))
    110         reject_writes_ = write_error = true;
    111     }
    112     if (write_error) {
    113       // Do not synchronously invoke OnError(). Write() may have been called by
    114       // the delegate and we don't want to re-enter it.
    115       io_task_runner_->PostTask(FROM_HERE,
    116                                 base::Bind(&ChannelWin::OnError, this));
    117     }
    118   }
    119 
    120   void LeakHandle() override {
    121     DCHECK(io_task_runner_->RunsTasksOnCurrentThread());
    122     leak_handle_ = true;
    123   }
    124 
    125   bool GetReadPlatformHandles(
    126       size_t num_handles,
    127       const void* extra_header,
    128       size_t extra_header_size,
    129       ScopedPlatformHandleVectorPtr* handles) override {
    130     if (num_handles > std::numeric_limits<uint16_t>::max())
    131       return false;
    132     using HandleEntry = Channel::Message::HandleEntry;
    133     size_t handles_size = sizeof(HandleEntry) * num_handles;
    134     if (handles_size > extra_header_size)
    135       return false;
    136     DCHECK(extra_header);
    137     handles->reset(new PlatformHandleVector(num_handles));
    138     const HandleEntry* extra_header_handles =
    139         reinterpret_cast<const HandleEntry*>(extra_header);
    140     for (size_t i = 0; i < num_handles; i++) {
    141       (*handles)->at(i).handle =
    142           base::win::Uint32ToHandle(extra_header_handles[i].handle);
    143     }
    144     return true;
    145   }
    146 
    147  private:
    148   // May run on any thread.
    149   ~ChannelWin() override {}
    150 
    151   void StartOnIOThread() {
    152     base::MessageLoop::current()->AddDestructionObserver(this);
    153     base::MessageLoopForIO::current()->RegisterIOHandler(
    154         handle_.get().handle, this);
    155 
    156     if (wait_for_connect_) {
    157       BOOL ok = ConnectNamedPipe(handle_.get().handle,
    158                                  &connect_context_.overlapped);
    159       if (ok) {
    160         PLOG(ERROR) << "Unexpected success while waiting for pipe connection";
    161         OnError();
    162         return;
    163       }
    164 
    165       const DWORD err = GetLastError();
    166       switch (err) {
    167         case ERROR_PIPE_CONNECTED:
    168           wait_for_connect_ = false;
    169           break;
    170         case ERROR_IO_PENDING:
    171           AddRef();
    172           return;
    173         case ERROR_NO_DATA:
    174           OnError();
    175           return;
    176       }
    177     }
    178 
    179     // Now that we have registered our IOHandler, we can start writing.
    180     {
    181       base::AutoLock lock(write_lock_);
    182       if (delay_writes_) {
    183         delay_writes_ = false;
    184         WriteNextNoLock();
    185       }
    186     }
    187 
    188     // Keep this alive in case we synchronously run shutdown.
    189     scoped_refptr<ChannelWin> keep_alive(this);
    190     ReadMore(0);
    191   }
    192 
    193   void ShutDownOnIOThread() {
    194     base::MessageLoop::current()->RemoveDestructionObserver(this);
    195 
    196     // BUG(crbug.com/583525): This function is expected to be called once, and
    197     // |handle_| should be valid at this point.
    198     CHECK(handle_.is_valid());
    199     CancelIo(handle_.get().handle);
    200     if (leak_handle_)
    201       ignore_result(handle_.release());
    202     handle_.reset();
    203 
    204     // May destroy the |this| if it was the last reference.
    205     self_ = nullptr;
    206   }
    207 
    208   // base::MessageLoop::DestructionObserver:
    209   void WillDestroyCurrentMessageLoop() override {
    210     DCHECK(io_task_runner_->RunsTasksOnCurrentThread());
    211     if (self_)
    212       ShutDownOnIOThread();
    213   }
    214 
    215   // base::MessageLoop::IOHandler:
    216   void OnIOCompleted(base::MessageLoopForIO::IOContext* context,
    217                      DWORD bytes_transfered,
    218                      DWORD error) override {
    219     if (error != ERROR_SUCCESS) {
    220       OnError();
    221     } else if (context == &connect_context_) {
    222       DCHECK(wait_for_connect_);
    223       wait_for_connect_ = false;
    224       ReadMore(0);
    225 
    226       base::AutoLock lock(write_lock_);
    227       if (delay_writes_) {
    228         delay_writes_ = false;
    229         WriteNextNoLock();
    230       }
    231     } else if (context == &read_context_) {
    232       OnReadDone(static_cast<size_t>(bytes_transfered));
    233     } else {
    234       CHECK(context == &write_context_);
    235       OnWriteDone(static_cast<size_t>(bytes_transfered));
    236     }
    237     Release();  // Balancing reference taken after ReadFile / WriteFile.
    238   }
    239 
    240   void OnReadDone(size_t bytes_read) {
    241     if (bytes_read > 0) {
    242       size_t next_read_size = 0;
    243       if (OnReadComplete(bytes_read, &next_read_size)) {
    244         ReadMore(next_read_size);
    245       } else {
    246         OnError();
    247       }
    248     } else if (bytes_read == 0) {
    249       OnError();
    250     }
    251   }
    252 
    253   void OnWriteDone(size_t bytes_written) {
    254     if (bytes_written == 0)
    255       return;
    256 
    257     bool write_error = false;
    258     {
    259       base::AutoLock lock(write_lock_);
    260 
    261       DCHECK(!outgoing_messages_.empty());
    262 
    263       MessageView& message_view = outgoing_messages_.front();
    264       message_view.advance_data_offset(bytes_written);
    265       if (message_view.data_num_bytes() == 0) {
    266         Channel::MessagePtr message = message_view.TakeChannelMessage();
    267         outgoing_messages_.pop_front();
    268 
    269         // Clear any handles so they don't get closed on destruction.
    270         ScopedPlatformHandleVectorPtr handles = message->TakeHandles();
    271         if (handles)
    272           handles->clear();
    273       }
    274 
    275       if (!WriteNextNoLock())
    276         reject_writes_ = write_error = true;
    277     }
    278     if (write_error)
    279       OnError();
    280   }
    281 
    282   void ReadMore(size_t next_read_size_hint) {
    283     size_t buffer_capacity = next_read_size_hint;
    284     char* buffer = GetReadBuffer(&buffer_capacity);
    285     DCHECK_GT(buffer_capacity, 0u);
    286 
    287     BOOL ok = ReadFile(handle_.get().handle,
    288                        buffer,
    289                        static_cast<DWORD>(buffer_capacity),
    290                        NULL,
    291                        &read_context_.overlapped);
    292 
    293     if (ok || GetLastError() == ERROR_IO_PENDING) {
    294       AddRef();  // Will be balanced in OnIOCompleted
    295     } else {
    296       OnError();
    297     }
    298   }
    299 
    300   // Attempts to write a message directly to the channel. If the full message
    301   // cannot be written, it's queued and a wait is initiated to write the message
    302   // ASAP on the I/O thread.
    303   bool WriteNoLock(const MessageView& message_view) {
    304     BOOL ok = WriteFile(handle_.get().handle,
    305                         message_view.data(),
    306                         static_cast<DWORD>(message_view.data_num_bytes()),
    307                         NULL,
    308                         &write_context_.overlapped);
    309 
    310     if (ok || GetLastError() == ERROR_IO_PENDING) {
    311       AddRef();  // Will be balanced in OnIOCompleted.
    312       return true;
    313     }
    314     return false;
    315   }
    316 
    317   bool WriteNextNoLock() {
    318     if (outgoing_messages_.empty())
    319       return true;
    320     return WriteNoLock(outgoing_messages_.front());
    321   }
    322 
    323   // Keeps the Channel alive at least until explicit shutdown on the IO thread.
    324   scoped_refptr<Channel> self_;
    325 
    326   ScopedPlatformHandle handle_;
    327   scoped_refptr<base::TaskRunner> io_task_runner_;
    328 
    329   base::MessageLoopForIO::IOContext connect_context_;
    330   base::MessageLoopForIO::IOContext read_context_;
    331   base::MessageLoopForIO::IOContext write_context_;
    332 
    333   // Protects |reject_writes_| and |outgoing_messages_|.
    334   base::Lock write_lock_;
    335 
    336   bool delay_writes_ = true;
    337 
    338   bool reject_writes_ = false;
    339   std::deque<MessageView> outgoing_messages_;
    340 
    341   bool wait_for_connect_;
    342 
    343   bool leak_handle_ = false;
    344 
    345   DISALLOW_COPY_AND_ASSIGN(ChannelWin);
    346 };
    347 
    348 }  // namespace
    349 
    350 // static
    351 scoped_refptr<Channel> Channel::Create(
    352     Delegate* delegate,
    353     ConnectionParams connection_params,
    354     scoped_refptr<base::TaskRunner> io_task_runner) {
    355   return new ChannelWin(delegate, connection_params.TakeChannelHandle(),
    356                         io_task_runner);
    357 }
    358 
    359 }  // namespace edk
    360 }  // namespace mojo
    361