Home | History | Annotate | Download | only in system
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "mojo/system/raw_channel.h"
      6 
      7 #include <string.h>
      8 
      9 #include <algorithm>
     10 
     11 #include "base/bind.h"
     12 #include "base/location.h"
     13 #include "base/logging.h"
     14 #include "base/message_loop/message_loop.h"
     15 #include "base/stl_util.h"
     16 #include "mojo/system/message_in_transit.h"
     17 #include "mojo/system/transport_data.h"
     18 
     19 namespace mojo {
     20 namespace system {
     21 
     22 const size_t kReadSize = 4096;
     23 
     24 // RawChannel::ReadBuffer ------------------------------------------------------
     25 
     26 RawChannel::ReadBuffer::ReadBuffer() : buffer_(kReadSize), num_valid_bytes_(0) {
     27 }
     28 
     29 RawChannel::ReadBuffer::~ReadBuffer() {
     30 }
     31 
     32 void RawChannel::ReadBuffer::GetBuffer(char** addr, size_t* size) {
     33   DCHECK_GE(buffer_.size(), num_valid_bytes_ + kReadSize);
     34   *addr = &buffer_[0] + num_valid_bytes_;
     35   *size = kReadSize;
     36 }
     37 
     38 // RawChannel::WriteBuffer -----------------------------------------------------
     39 
     40 RawChannel::WriteBuffer::WriteBuffer(size_t serialized_platform_handle_size)
     41     : serialized_platform_handle_size_(serialized_platform_handle_size),
     42       platform_handles_offset_(0),
     43       data_offset_(0) {
     44 }
     45 
     46 RawChannel::WriteBuffer::~WriteBuffer() {
     47   STLDeleteElements(&message_queue_);
     48 }
     49 
     50 bool RawChannel::WriteBuffer::HavePlatformHandlesToSend() const {
     51   if (message_queue_.empty())
     52     return false;
     53 
     54   const TransportData* transport_data =
     55       message_queue_.front()->transport_data();
     56   if (!transport_data)
     57     return false;
     58 
     59   const embedder::PlatformHandleVector* all_platform_handles =
     60       transport_data->platform_handles();
     61   if (!all_platform_handles) {
     62     DCHECK_EQ(platform_handles_offset_, 0u);
     63     return false;
     64   }
     65   if (platform_handles_offset_ >= all_platform_handles->size()) {
     66     DCHECK_EQ(platform_handles_offset_, all_platform_handles->size());
     67     return false;
     68   }
     69 
     70   return true;
     71 }
     72 
     73 void RawChannel::WriteBuffer::GetPlatformHandlesToSend(
     74     size_t* num_platform_handles,
     75     embedder::PlatformHandle** platform_handles,
     76     void** serialization_data) {
     77   DCHECK(HavePlatformHandlesToSend());
     78 
     79   TransportData* transport_data = message_queue_.front()->transport_data();
     80   embedder::PlatformHandleVector* all_platform_handles =
     81       transport_data->platform_handles();
     82   *num_platform_handles =
     83       all_platform_handles->size() - platform_handles_offset_;
     84   *platform_handles = &(*all_platform_handles)[platform_handles_offset_];
     85   size_t serialization_data_offset =
     86       transport_data->platform_handle_table_offset();
     87   DCHECK_GT(serialization_data_offset, 0u);
     88   serialization_data_offset +=
     89       platform_handles_offset_ * serialized_platform_handle_size_;
     90   *serialization_data =
     91       static_cast<char*>(transport_data->buffer()) + serialization_data_offset;
     92 }
     93 
     94 void RawChannel::WriteBuffer::GetBuffers(std::vector<Buffer>* buffers) const {
     95   buffers->clear();
     96 
     97   if (message_queue_.empty())
     98     return;
     99 
    100   MessageInTransit* message = message_queue_.front();
    101   DCHECK_LT(data_offset_, message->total_size());
    102   size_t bytes_to_write = message->total_size() - data_offset_;
    103 
    104   size_t transport_data_buffer_size =
    105       message->transport_data() ? message->transport_data()->buffer_size() : 0;
    106 
    107   if (!transport_data_buffer_size) {
    108     // Only write from the main buffer.
    109     DCHECK_LT(data_offset_, message->main_buffer_size());
    110     DCHECK_LE(bytes_to_write, message->main_buffer_size());
    111     Buffer buffer = {
    112         static_cast<const char*>(message->main_buffer()) + data_offset_,
    113         bytes_to_write};
    114     buffers->push_back(buffer);
    115     return;
    116   }
    117 
    118   if (data_offset_ >= message->main_buffer_size()) {
    119     // Only write from the transport data buffer.
    120     DCHECK_LT(data_offset_ - message->main_buffer_size(),
    121               transport_data_buffer_size);
    122     DCHECK_LE(bytes_to_write, transport_data_buffer_size);
    123     Buffer buffer = {
    124         static_cast<const char*>(message->transport_data()->buffer()) +
    125             (data_offset_ - message->main_buffer_size()),
    126         bytes_to_write};
    127     buffers->push_back(buffer);
    128     return;
    129   }
    130 
    131   // TODO(vtl): We could actually send out buffers from multiple messages, with
    132   // the "stopping" condition being reaching a message with platform handles
    133   // attached.
    134 
    135   // Write from both buffers.
    136   DCHECK_EQ(
    137       bytes_to_write,
    138       message->main_buffer_size() - data_offset_ + transport_data_buffer_size);
    139   Buffer buffer1 = {
    140       static_cast<const char*>(message->main_buffer()) + data_offset_,
    141       message->main_buffer_size() - data_offset_};
    142   buffers->push_back(buffer1);
    143   Buffer buffer2 = {
    144       static_cast<const char*>(message->transport_data()->buffer()),
    145       transport_data_buffer_size};
    146   buffers->push_back(buffer2);
    147 }
    148 
    149 // RawChannel ------------------------------------------------------------------
    150 
    151 RawChannel::RawChannel()
    152     : message_loop_for_io_(nullptr),
    153       delegate_(nullptr),
    154       read_stopped_(false),
    155       write_stopped_(false),
    156       weak_ptr_factory_(this) {
    157 }
    158 
    159 RawChannel::~RawChannel() {
    160   DCHECK(!read_buffer_);
    161   DCHECK(!write_buffer_);
    162 
    163   // No need to take the |write_lock_| here -- if there are still weak pointers
    164   // outstanding, then we're hosed anyway (since we wouldn't be able to
    165   // invalidate them cleanly, since we might not be on the I/O thread).
    166   DCHECK(!weak_ptr_factory_.HasWeakPtrs());
    167 }
    168 
    169 bool RawChannel::Init(Delegate* delegate) {
    170   DCHECK(delegate);
    171 
    172   DCHECK(!delegate_);
    173   delegate_ = delegate;
    174 
    175   CHECK_EQ(base::MessageLoop::current()->type(), base::MessageLoop::TYPE_IO);
    176   DCHECK(!message_loop_for_io_);
    177   message_loop_for_io_ =
    178       static_cast<base::MessageLoopForIO*>(base::MessageLoop::current());
    179 
    180   // No need to take the lock. No one should be using us yet.
    181   DCHECK(!read_buffer_);
    182   read_buffer_.reset(new ReadBuffer);
    183   DCHECK(!write_buffer_);
    184   write_buffer_.reset(new WriteBuffer(GetSerializedPlatformHandleSize()));
    185 
    186   if (!OnInit()) {
    187     delegate_ = nullptr;
    188     message_loop_for_io_ = nullptr;
    189     read_buffer_.reset();
    190     write_buffer_.reset();
    191     return false;
    192   }
    193 
    194   IOResult io_result = ScheduleRead();
    195   if (io_result != IO_PENDING) {
    196     // This will notify the delegate about the read failure. Although we're on
    197     // the I/O thread, don't call it in the nested context.
    198     message_loop_for_io_->PostTask(FROM_HERE,
    199                                    base::Bind(&RawChannel::OnReadCompleted,
    200                                               weak_ptr_factory_.GetWeakPtr(),
    201                                               io_result,
    202                                               0));
    203   }
    204 
    205   // ScheduleRead() failure is treated as a read failure (by notifying the
    206   // delegate), not as an init failure.
    207   return true;
    208 }
    209 
    210 void RawChannel::Shutdown() {
    211   DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
    212 
    213   base::AutoLock locker(write_lock_);
    214 
    215   LOG_IF(WARNING, !write_buffer_->message_queue_.empty())
    216       << "Shutting down RawChannel with write buffer nonempty";
    217 
    218   // Reset the delegate so that it won't receive further calls.
    219   delegate_ = nullptr;
    220   read_stopped_ = true;
    221   write_stopped_ = true;
    222   weak_ptr_factory_.InvalidateWeakPtrs();
    223 
    224   OnShutdownNoLock(read_buffer_.Pass(), write_buffer_.Pass());
    225 }
    226 
    227 // Reminder: This must be thread-safe.
    228 bool RawChannel::WriteMessage(scoped_ptr<MessageInTransit> message) {
    229   DCHECK(message);
    230 
    231   base::AutoLock locker(write_lock_);
    232   if (write_stopped_)
    233     return false;
    234 
    235   if (!write_buffer_->message_queue_.empty()) {
    236     EnqueueMessageNoLock(message.Pass());
    237     return true;
    238   }
    239 
    240   EnqueueMessageNoLock(message.Pass());
    241   DCHECK_EQ(write_buffer_->data_offset_, 0u);
    242 
    243   size_t platform_handles_written = 0;
    244   size_t bytes_written = 0;
    245   IOResult io_result = WriteNoLock(&platform_handles_written, &bytes_written);
    246   if (io_result == IO_PENDING)
    247     return true;
    248 
    249   bool result = OnWriteCompletedNoLock(
    250       io_result, platform_handles_written, bytes_written);
    251   if (!result) {
    252     // Even if we're on the I/O thread, don't call |OnError()| in the nested
    253     // context.
    254     message_loop_for_io_->PostTask(FROM_HERE,
    255                                    base::Bind(&RawChannel::CallOnError,
    256                                               weak_ptr_factory_.GetWeakPtr(),
    257                                               Delegate::ERROR_WRITE));
    258   }
    259 
    260   return result;
    261 }
    262 
    263 // Reminder: This must be thread-safe.
    264 bool RawChannel::IsWriteBufferEmpty() {
    265   base::AutoLock locker(write_lock_);
    266   return write_buffer_->message_queue_.empty();
    267 }
    268 
// Handles the outcome of a read on the I/O thread.
//
// On success it:
//   1. adds |bytes_read| newly received bytes to |read_buffer_|;
//   2. validates and dispatches every complete message now in the buffer;
//   3. compacts any trailing partial message to the front of the buffer;
//   4. grows the buffer if needed so another |kReadSize| bytes will fit;
//   5. either schedules the next asynchronous read or reads again
//      synchronously.
// A failed |io_result|, or an invalid message, stops reading and reports the
// matching error to the delegate.
//
// Init() also posts this method when its initial ScheduleRead() does not
// return IO_PENDING.
void RawChannel::OnReadCompleted(IOResult io_result, size_t bytes_read) {
  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);

  if (read_stopped_) {
    NOTREACHED();
    return;
  }

  // Keep reading data in a loop, and dispatch messages if enough data is
  // received. Exit the loop if any of the following happens:
  //   - one or more messages were dispatched;
  //   - the last read failed, was a partial read or would block;
  //   - |Shutdown()| was called.
  do {
    switch (io_result) {
      case IO_SUCCEEDED:
        break;
      case IO_FAILED_SHUTDOWN:
      case IO_FAILED_BROKEN:
      case IO_FAILED_UNKNOWN:
        // Reading is over for good; translate the failure into a delegate
        // error.
        read_stopped_ = true;
        CallOnError(ReadIOResultToError(io_result));
        return;
      case IO_PENDING:
        // IO_PENDING never reaches this switch: Init() only posts non-pending
        // results, and the loop condition below exits on IO_PENDING.
        NOTREACHED();
        return;
    }

    // Account for the newly received bytes. Presumably the platform read
    // filled the region handed out by ReadBuffer::GetBuffer(), just past the
    // previously valid bytes -- confirm in the implementation's Read().
    read_buffer_->num_valid_bytes_ += bytes_read;

    // Dispatch all the messages that we can.
    bool did_dispatch_message = false;
    // Tracks the offset of the first undispatched message in |read_buffer_|.
    // Currently, we copy data to ensure that this is zero at the beginning.
    size_t read_buffer_start = 0;
    size_t remaining_bytes = read_buffer_->num_valid_bytes_;
    size_t message_size;
    // Note that we rely on short-circuit evaluation here:
    //   - |read_buffer_start| may be an invalid index into
    //     |read_buffer_->buffer_| if |remaining_bytes| is zero.
    //   - |message_size| is only valid if |GetNextMessageSize()| returns true.
    // TODO(vtl): Use |message_size| more intelligently (e.g., to request the
    // next read).
    // TODO(vtl): Validate that |message_size| is sane.
    while (remaining_bytes > 0 && MessageInTransit::GetNextMessageSize(
                                      &read_buffer_->buffer_[read_buffer_start],
                                      remaining_bytes,
                                      &message_size) &&
           remaining_bytes >= message_size) {
      // The view is built directly over the bytes in |read_buffer_->buffer_|.
      MessageInTransit::View message_view(
          message_size, &read_buffer_->buffer_[read_buffer_start]);
      DCHECK_EQ(message_view.total_size(), message_size);

      const char* error_message = nullptr;
      if (!message_view.IsValid(GetSerializedPlatformHandleSize(),
                                &error_message)) {
        DCHECK(error_message);
        LOG(ERROR) << "Received invalid message: " << error_message;
        read_stopped_ = true;
        CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
        return;
      }

      if (message_view.type() == MessageInTransit::kTypeRawChannel) {
        // Control messages addressed to RawChannel itself are handled here
        // and never forwarded to the delegate.
        if (!OnReadMessageForRawChannel(message_view)) {
          read_stopped_ = true;
          CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
          return;
        }
      } else {
        // Messages with a transport data buffer may advertise platform
        // handles; fetch them through the implementation-specific
        // GetReadPlatformHandles(). A null result is treated as a bad message.
        embedder::ScopedPlatformHandleVectorPtr platform_handles;
        if (message_view.transport_data_buffer()) {
          size_t num_platform_handles;
          const void* platform_handle_table;
          TransportData::GetPlatformHandleTable(
              message_view.transport_data_buffer(),
              &num_platform_handles,
              &platform_handle_table);

          if (num_platform_handles > 0) {
            platform_handles =
                GetReadPlatformHandles(num_platform_handles,
                                       platform_handle_table).Pass();
            if (!platform_handles) {
              LOG(ERROR) << "Invalid number of platform handles received";
              read_stopped_ = true;
              CallOnError(Delegate::ERROR_READ_BAD_MESSAGE);
              return;
            }
          }
        }

        // TODO(vtl): In the case that we aren't expecting any platform handles,
        // for the POSIX implementation, we should confirm that none are stored.

        // Dispatch the message.
        DCHECK(delegate_);
        delegate_->OnReadMessage(message_view, platform_handles.Pass());
        if (read_stopped_) {
          // |Shutdown()| was called in |OnReadMessage()|.
          // TODO(vtl): Add test for this case.
          return;
        }
      }

      did_dispatch_message = true;

      // Update our state.
      read_buffer_start += message_size;
      remaining_bytes -= message_size;
    }

    if (read_buffer_start > 0) {
      // Move data back to start.
      // This slides any trailing partial message to offset 0, so the next
      // pass can again begin parsing at the front of the buffer.
      read_buffer_->num_valid_bytes_ = remaining_bytes;
      if (read_buffer_->num_valid_bytes_ > 0) {
        memmove(&read_buffer_->buffer_[0],
                &read_buffer_->buffer_[read_buffer_start],
                remaining_bytes);
      }
      read_buffer_start = 0;
    }

    // Keep at least |kReadSize| free bytes after the valid data, as
    // ReadBuffer::GetBuffer() DCHECKs before handing out the read region.
    if (read_buffer_->buffer_.size() - read_buffer_->num_valid_bytes_ <
        kReadSize) {
      // Use power-of-2 buffer sizes.
      // TODO(vtl): Make sure the buffer doesn't get too large (and enforce the
      // maximum message size to whatever extent necessary).
      // TODO(vtl): We may often be able to peek at the header and get the real
      // required extra space (which may be much bigger than |kReadSize|).
      size_t new_size = std::max(read_buffer_->buffer_.size(), kReadSize);
      while (new_size < read_buffer_->num_valid_bytes_ + kReadSize)
        new_size *= 2;

      // TODO(vtl): It's suboptimal to zero out the fresh memory.
      read_buffer_->buffer_.resize(new_size, 0);
    }

    // (1) If we dispatched any messages, stop reading for now (and let the
    // message loop do its thing for another round).
    // TODO(vtl): Is this the behavior we want? (Alternatives: i. Dispatch only
    // a single message. Risks: slower, more complex if we want to avoid lots of
    // copying. ii. Keep reading until there's no more data and dispatch all the
    // messages we can. Risks: starvation of other users of the message loop.)
    // (2) If we didn't max out |kReadSize|, stop reading for now.
    bool schedule_for_later = did_dispatch_message || bytes_read < kReadSize;
    // Reset so that a ScheduleRead() result, which carries no byte count,
    // does not re-add the previous count on the next pass.
    bytes_read = 0;
    io_result = schedule_for_later ? ScheduleRead() : Read(&bytes_read);
    // Any non-pending result, including a failure from ScheduleRead(), goes
    // around the loop again and is handled by the switch at the top.
  } while (io_result != IO_PENDING);
}
    419 
    420 void RawChannel::OnWriteCompleted(IOResult io_result,
    421                                   size_t platform_handles_written,
    422                                   size_t bytes_written) {
    423   DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
    424   DCHECK_NE(io_result, IO_PENDING);
    425 
    426   bool did_fail = false;
    427   {
    428     base::AutoLock locker(write_lock_);
    429     DCHECK_EQ(write_stopped_, write_buffer_->message_queue_.empty());
    430 
    431     if (write_stopped_) {
    432       NOTREACHED();
    433       return;
    434     }
    435 
    436     did_fail = !OnWriteCompletedNoLock(
    437                    io_result, platform_handles_written, bytes_written);
    438   }
    439 
    440   if (did_fail)
    441     CallOnError(Delegate::ERROR_WRITE);
    442 }
    443 
    444 void RawChannel::EnqueueMessageNoLock(scoped_ptr<MessageInTransit> message) {
    445   write_lock_.AssertAcquired();
    446   write_buffer_->message_queue_.push_back(message.release());
    447 }
    448 
    449 bool RawChannel::OnReadMessageForRawChannel(
    450     const MessageInTransit::View& message_view) {
    451   // No non-implementation specific |RawChannel| control messages.
    452   LOG(ERROR) << "Invalid control message (subtype " << message_view.subtype()
    453              << ")";
    454   return false;
    455 }
    456 
    457 // static
    458 RawChannel::Delegate::Error RawChannel::ReadIOResultToError(
    459     IOResult io_result) {
    460   switch (io_result) {
    461     case IO_FAILED_SHUTDOWN:
    462       return Delegate::ERROR_READ_SHUTDOWN;
    463     case IO_FAILED_BROKEN:
    464       return Delegate::ERROR_READ_BROKEN;
    465     case IO_FAILED_UNKNOWN:
    466       return Delegate::ERROR_READ_UNKNOWN;
    467     case IO_SUCCEEDED:
    468     case IO_PENDING:
    469       NOTREACHED();
    470       break;
    471   }
    472   return Delegate::ERROR_READ_UNKNOWN;
    473 }
    474 
    475 void RawChannel::CallOnError(Delegate::Error error) {
    476   DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
    477   // TODO(vtl): Add a "write_lock_.AssertNotAcquired()"?
    478   if (delegate_)
    479     delegate_->OnError(error);
    480 }
    481 
    482 bool RawChannel::OnWriteCompletedNoLock(IOResult io_result,
    483                                         size_t platform_handles_written,
    484                                         size_t bytes_written) {
    485   write_lock_.AssertAcquired();
    486 
    487   DCHECK(!write_stopped_);
    488   DCHECK(!write_buffer_->message_queue_.empty());
    489 
    490   if (io_result == IO_SUCCEEDED) {
    491     write_buffer_->platform_handles_offset_ += platform_handles_written;
    492     write_buffer_->data_offset_ += bytes_written;
    493 
    494     MessageInTransit* message = write_buffer_->message_queue_.front();
    495     if (write_buffer_->data_offset_ >= message->total_size()) {
    496       // Complete write.
    497       CHECK_EQ(write_buffer_->data_offset_, message->total_size());
    498       write_buffer_->message_queue_.pop_front();
    499       delete message;
    500       write_buffer_->platform_handles_offset_ = 0;
    501       write_buffer_->data_offset_ = 0;
    502 
    503       if (write_buffer_->message_queue_.empty())
    504         return true;
    505     }
    506 
    507     // Schedule the next write.
    508     io_result = ScheduleWriteNoLock();
    509     if (io_result == IO_PENDING)
    510       return true;
    511     DCHECK_NE(io_result, IO_SUCCEEDED);
    512   }
    513 
    514   write_stopped_ = true;
    515   STLDeleteElements(&write_buffer_->message_queue_);
    516   write_buffer_->platform_handles_offset_ = 0;
    517   write_buffer_->data_offset_ = 0;
    518   return false;
    519 }
    520 
    521 }  // namespace system
    522 }  // namespace mojo
    523