Home | History | Annotate | Download | only in core
      1 // Copyright 2016 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "mojo/core/channel.h"
      6 
      7 #include <stdint.h>
      8 #include <windows.h>
      9 
     10 #include <algorithm>
     11 #include <limits>
     12 #include <memory>
     13 
     14 #include "base/bind.h"
     15 #include "base/containers/queue.h"
     16 #include "base/location.h"
     17 #include "base/macros.h"
     18 #include "base/memory/ref_counted.h"
     19 #include "base/message_loop/message_loop_current.h"
     20 #include "base/message_loop/message_pump_for_io.h"
     21 #include "base/process/process_handle.h"
     22 #include "base/synchronization/lock.h"
     23 #include "base/task_runner.h"
     24 #include "base/win/scoped_handle.h"
     25 #include "base/win/win_util.h"
     26 
     27 namespace mojo {
     28 namespace core {
     29 
     30 namespace {
     31 
     32 class ChannelWin : public Channel,
     33                    public base::MessageLoopCurrent::DestructionObserver,
     34                    public base::MessagePumpForIO::IOHandler {
     35  public:
     36   ChannelWin(Delegate* delegate,
     37              ConnectionParams connection_params,
     38              scoped_refptr<base::TaskRunner> io_task_runner)
     39       : Channel(delegate), self_(this), io_task_runner_(io_task_runner) {
     40     if (connection_params.server_endpoint().is_valid()) {
     41       handle_ = connection_params.TakeServerEndpoint()
     42                     .TakePlatformHandle()
     43                     .TakeHandle();
     44       needs_connection_ = true;
     45     } else {
     46       handle_ =
     47           connection_params.TakeEndpoint().TakePlatformHandle().TakeHandle();
     48     }
     49 
     50     CHECK(handle_.IsValid());
     51   }
     52 
     53   void Start() override {
     54     io_task_runner_->PostTask(
     55         FROM_HERE, base::BindOnce(&ChannelWin::StartOnIOThread, this));
     56   }
     57 
     58   void ShutDownImpl() override {
     59     // Always shut down asynchronously when called through the public interface.
     60     io_task_runner_->PostTask(
     61         FROM_HERE, base::BindOnce(&ChannelWin::ShutDownOnIOThread, this));
     62   }
     63 
     64   void Write(MessagePtr message) override {
     65     if (remote_process().is_valid()) {
     66       // If we know the remote process handle, we transfer all outgoing handles
     67       // to the process now rewriting them in the message.
     68       std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
     69       for (auto& handle : handles) {
     70         if (handle.handle().is_valid())
     71           handle.TransferToProcess(remote_process().Clone());
     72       }
     73       message->SetHandles(std::move(handles));
     74     }
     75 
     76     bool write_error = false;
     77     {
     78       base::AutoLock lock(write_lock_);
     79       if (reject_writes_)
     80         return;
     81 
     82       bool write_now = !delay_writes_ && outgoing_messages_.empty();
     83       outgoing_messages_.emplace_back(std::move(message));
     84       if (write_now && !WriteNoLock(outgoing_messages_.front()))
     85         reject_writes_ = write_error = true;
     86     }
     87     if (write_error) {
     88       // Do not synchronously invoke OnWriteError(). Write() may have been
     89       // called by the delegate and we don't want to re-enter it.
     90       io_task_runner_->PostTask(FROM_HERE,
     91                                 base::BindOnce(&ChannelWin::OnWriteError, this,
     92                                                Error::kDisconnected));
     93     }
     94   }
     95 
     96   void LeakHandle() override {
     97     DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
     98     leak_handle_ = true;
     99   }
    100 
    101   bool GetReadPlatformHandles(const void* payload,
    102                               size_t payload_size,
    103                               size_t num_handles,
    104                               const void* extra_header,
    105                               size_t extra_header_size,
    106                               std::vector<PlatformHandle>* handles,
    107                               bool* deferred) override {
    108     DCHECK(extra_header);
    109     if (num_handles > std::numeric_limits<uint16_t>::max())
    110       return false;
    111     using HandleEntry = Channel::Message::HandleEntry;
    112     size_t handles_size = sizeof(HandleEntry) * num_handles;
    113     if (handles_size > extra_header_size)
    114       return false;
    115     handles->reserve(num_handles);
    116     const HandleEntry* extra_header_handles =
    117         reinterpret_cast<const HandleEntry*>(extra_header);
    118     for (size_t i = 0; i < num_handles; i++) {
    119       HANDLE handle_value =
    120           base::win::Uint32ToHandle(extra_header_handles[i].handle);
    121       if (remote_process().is_valid()) {
    122         // If we know the remote process's handle, we assume it doesn't know
    123         // ours; that means any handle values still belong to that process, and
    124         // we need to transfer them to this process.
    125         handle_value = PlatformHandleInTransit::TakeIncomingRemoteHandle(
    126                            handle_value, remote_process().get())
    127                            .ReleaseHandle();
    128       }
    129       handles->emplace_back(base::win::ScopedHandle(std::move(handle_value)));
    130     }
    131     return true;
    132   }
    133 
    134  private:
    135   // May run on any thread.
    136   ~ChannelWin() override {}
    137 
    138   void StartOnIOThread() {
    139     base::MessageLoopCurrent::Get()->AddDestructionObserver(this);
    140     base::MessageLoopCurrentForIO::Get()->RegisterIOHandler(handle_.Get(),
    141                                                             this);
    142 
    143     if (needs_connection_) {
    144       BOOL ok = ::ConnectNamedPipe(handle_.Get(), &connect_context_.overlapped);
    145       if (ok) {
    146         PLOG(ERROR) << "Unexpected success while waiting for pipe connection";
    147         OnError(Error::kConnectionFailed);
    148         return;
    149       }
    150 
    151       const DWORD err = GetLastError();
    152       switch (err) {
    153         case ERROR_PIPE_CONNECTED:
    154           break;
    155         case ERROR_IO_PENDING:
    156           is_connect_pending_ = true;
    157           AddRef();
    158           return;
    159         case ERROR_NO_DATA:
    160         default:
    161           OnError(Error::kConnectionFailed);
    162           return;
    163       }
    164     }
    165 
    166     // Now that we have registered our IOHandler, we can start writing.
    167     {
    168       base::AutoLock lock(write_lock_);
    169       if (delay_writes_) {
    170         delay_writes_ = false;
    171         WriteNextNoLock();
    172       }
    173     }
    174 
    175     // Keep this alive in case we synchronously run shutdown, via OnError(),
    176     // as a result of a ReadFile() failure on the channel.
    177     scoped_refptr<ChannelWin> keep_alive(this);
    178     ReadMore(0);
    179   }
    180 
    181   void ShutDownOnIOThread() {
    182     base::MessageLoopCurrent::Get()->RemoveDestructionObserver(this);
    183 
    184     // TODO(https://crbug.com/583525): This function is expected to be called
    185     // once, and |handle_| should be valid at this point.
    186     CHECK(handle_.IsValid());
    187     CancelIo(handle_.Get());
    188     if (leak_handle_)
    189       ignore_result(handle_.Take());
    190     else
    191       handle_.Close();
    192 
    193     // Allow |this| to be destroyed as soon as no IO is pending.
    194     self_ = nullptr;
    195   }
    196 
    197   // base::MessageLoopCurrent::DestructionObserver:
    198   void WillDestroyCurrentMessageLoop() override {
    199     DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
    200     if (self_)
    201       ShutDownOnIOThread();
    202   }
    203 
    204   // base::MessageLoop::IOHandler:
    205   void OnIOCompleted(base::MessagePumpForIO::IOContext* context,
    206                      DWORD bytes_transfered,
    207                      DWORD error) override {
    208     if (error != ERROR_SUCCESS) {
    209       if (context == &write_context_) {
    210         {
    211           base::AutoLock lock(write_lock_);
    212           reject_writes_ = true;
    213         }
    214         OnWriteError(Error::kDisconnected);
    215       } else {
    216         OnError(Error::kDisconnected);
    217       }
    218     } else if (context == &connect_context_) {
    219       DCHECK(is_connect_pending_);
    220       is_connect_pending_ = false;
    221       ReadMore(0);
    222 
    223       base::AutoLock lock(write_lock_);
    224       if (delay_writes_) {
    225         delay_writes_ = false;
    226         WriteNextNoLock();
    227       }
    228     } else if (context == &read_context_) {
    229       OnReadDone(static_cast<size_t>(bytes_transfered));
    230     } else {
    231       CHECK(context == &write_context_);
    232       OnWriteDone(static_cast<size_t>(bytes_transfered));
    233     }
    234     Release();
    235   }
    236 
    237   void OnReadDone(size_t bytes_read) {
    238     DCHECK(is_read_pending_);
    239     is_read_pending_ = false;
    240 
    241     if (bytes_read > 0) {
    242       size_t next_read_size = 0;
    243       if (OnReadComplete(bytes_read, &next_read_size)) {
    244         ReadMore(next_read_size);
    245       } else {
    246         OnError(Error::kReceivedMalformedData);
    247       }
    248     } else if (bytes_read == 0) {
    249       OnError(Error::kDisconnected);
    250     }
    251   }
    252 
    253   void OnWriteDone(size_t bytes_written) {
    254     if (bytes_written == 0)
    255       return;
    256 
    257     bool write_error = false;
    258     {
    259       base::AutoLock lock(write_lock_);
    260 
    261       DCHECK(is_write_pending_);
    262       is_write_pending_ = false;
    263       DCHECK(!outgoing_messages_.empty());
    264 
    265       Channel::MessagePtr message = std::move(outgoing_messages_.front());
    266       outgoing_messages_.pop_front();
    267 
    268       // Invalidate all the scoped handles so we don't attempt to close them.
    269       std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
    270       for (auto& handle : handles)
    271         handle.CompleteTransit();
    272 
    273       // Overlapped WriteFile() to a pipe should always fully complete.
    274       if (message->data_num_bytes() != bytes_written)
    275         reject_writes_ = write_error = true;
    276       else if (!WriteNextNoLock())
    277         reject_writes_ = write_error = true;
    278     }
    279     if (write_error)
    280       OnWriteError(Error::kDisconnected);
    281   }
    282 
    283   void ReadMore(size_t next_read_size_hint) {
    284     DCHECK(!is_read_pending_);
    285 
    286     size_t buffer_capacity = next_read_size_hint;
    287     char* buffer = GetReadBuffer(&buffer_capacity);
    288     DCHECK_GT(buffer_capacity, 0u);
    289 
    290     BOOL ok =
    291         ::ReadFile(handle_.Get(), buffer, static_cast<DWORD>(buffer_capacity),
    292                    NULL, &read_context_.overlapped);
    293     if (ok || GetLastError() == ERROR_IO_PENDING) {
    294       is_read_pending_ = true;
    295       AddRef();
    296     } else {
    297       OnError(Error::kDisconnected);
    298     }
    299   }
    300 
    301   // Attempts to write a message directly to the channel. If the full message
    302   // cannot be written, it's queued and a wait is initiated to write the message
    303   // ASAP on the I/O thread.
    304   bool WriteNoLock(const Channel::MessagePtr& message) {
    305     BOOL ok = WriteFile(handle_.Get(), message->data(),
    306                         static_cast<DWORD>(message->data_num_bytes()), NULL,
    307                         &write_context_.overlapped);
    308     if (ok || GetLastError() == ERROR_IO_PENDING) {
    309       is_write_pending_ = true;
    310       AddRef();
    311       return true;
    312     }
    313     return false;
    314   }
    315 
    316   bool WriteNextNoLock() {
    317     if (outgoing_messages_.empty())
    318       return true;
    319     return WriteNoLock(outgoing_messages_.front());
    320   }
    321 
    322   void OnWriteError(Error error) {
    323     DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
    324     DCHECK(reject_writes_);
    325 
    326     if (error == Error::kDisconnected) {
    327       // If we can't write because the pipe is disconnected then continue
    328       // reading to fetch any in-flight messages, relying on end-of-stream to
    329       // signal the actual disconnection.
    330       if (is_read_pending_ || is_connect_pending_)
    331         return;
    332     }
    333 
    334     OnError(error);
    335   }
    336 
    337   // Keeps the Channel alive at least until explicit shutdown on the IO thread.
    338   scoped_refptr<Channel> self_;
    339 
    340   // The pipe handle this Channel uses for communication.
    341   base::win::ScopedHandle handle_;
    342 
    343   // Indicates whether |handle_| must wait for a connection.
    344   bool needs_connection_ = false;
    345 
    346   const scoped_refptr<base::TaskRunner> io_task_runner_;
    347 
    348   base::MessagePumpForIO::IOContext connect_context_;
    349   base::MessagePumpForIO::IOContext read_context_;
    350   bool is_connect_pending_ = false;
    351   bool is_read_pending_ = false;
    352 
    353   // Protects all fields potentially accessed on multiple threads via Write().
    354   base::Lock write_lock_;
    355   base::MessagePumpForIO::IOContext write_context_;
    356   base::circular_deque<Channel::MessagePtr> outgoing_messages_;
    357   bool delay_writes_ = true;
    358   bool reject_writes_ = false;
    359   bool is_write_pending_ = false;
    360 
    361   bool leak_handle_ = false;
    362 
    363   DISALLOW_COPY_AND_ASSIGN(ChannelWin);
    364 };
    365 
    366 }  // namespace
    367 
    368 // static
    369 scoped_refptr<Channel> Channel::Create(
    370     Delegate* delegate,
    371     ConnectionParams connection_params,
    372     scoped_refptr<base::TaskRunner> io_task_runner) {
    373   return new ChannelWin(delegate, std::move(connection_params), io_task_runner);
    374 }
    375 
    376 }  // namespace core
    377 }  // namespace mojo
    378