Home | History | Annotate | Download | only in system
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "mojo/system/raw_channel.h"
      6 
      7 #include <string.h>
      8 
      9 #include <algorithm>
     10 
     11 #include "base/bind.h"
     12 #include "base/location.h"
     13 #include "base/logging.h"
     14 #include "base/message_loop/message_loop.h"
     15 #include "base/stl_util.h"
     16 #include "mojo/system/message_in_transit.h"
     17 #include "mojo/system/transport_data.h"
     18 
     19 namespace mojo {
     20 namespace system {
     21 
     22 const size_t kReadSize = 4096;
     23 
     24 // RawChannel::ReadBuffer ------------------------------------------------------
     25 
     26 RawChannel::ReadBuffer::ReadBuffer()
     27     : buffer_(kReadSize),
     28       num_valid_bytes_(0) {
     29 }
     30 
     31 RawChannel::ReadBuffer::~ReadBuffer() {
     32 }
     33 
     34 void RawChannel::ReadBuffer::GetBuffer(char** addr, size_t* size) {
     35   DCHECK_GE(buffer_.size(), num_valid_bytes_ + kReadSize);
     36   *addr = &buffer_[0] + num_valid_bytes_;
     37   *size = kReadSize;
     38 }
     39 
     40 // RawChannel::WriteBuffer -----------------------------------------------------
     41 
     42 RawChannel::WriteBuffer::WriteBuffer(size_t serialized_platform_handle_size)
     43     : serialized_platform_handle_size_(serialized_platform_handle_size),
     44       platform_handles_offset_(0),
     45       data_offset_(0) {
     46 }
     47 
     48 RawChannel::WriteBuffer::~WriteBuffer() {
     49   STLDeleteElements(&message_queue_);
     50 }
     51 
     52 bool RawChannel::WriteBuffer::HavePlatformHandlesToSend() const {
     53   if (message_queue_.empty())
     54     return false;
     55 
     56   const TransportData* transport_data =
     57       message_queue_.front()->transport_data();
     58   if (!transport_data)
     59     return false;
     60 
     61   const embedder::PlatformHandleVector* all_platform_handles =
     62       transport_data->platform_handles();
     63   if (!all_platform_handles) {
     64     DCHECK_EQ(platform_handles_offset_, 0u);
     65     return false;
     66   }
     67   if (platform_handles_offset_ >= all_platform_handles->size()) {
     68     DCHECK_EQ(platform_handles_offset_, all_platform_handles->size());
     69     return false;
     70   }
     71 
     72   return true;
     73 }
     74 
     75 void RawChannel::WriteBuffer::GetPlatformHandlesToSend(
     76     size_t* num_platform_handles,
     77     embedder::PlatformHandle** platform_handles,
     78     void** serialization_data) {
     79   DCHECK(HavePlatformHandlesToSend());
     80 
     81   TransportData* transport_data = message_queue_.front()->transport_data();
     82   embedder::PlatformHandleVector* all_platform_handles =
     83       transport_data->platform_handles();
     84   *num_platform_handles =
     85       all_platform_handles->size() - platform_handles_offset_;
     86   *platform_handles = &(*all_platform_handles)[platform_handles_offset_];
     87   size_t serialization_data_offset =
     88       transport_data->platform_handle_table_offset();
     89   DCHECK_GT(serialization_data_offset, 0u);
     90   serialization_data_offset +=
     91       platform_handles_offset_ * serialized_platform_handle_size_;
     92   *serialization_data =
     93       static_cast<char*>(transport_data->buffer()) + serialization_data_offset;
     94 }
     95 
     96 void RawChannel::WriteBuffer::GetBuffers(std::vector<Buffer>* buffers) const {
     97   buffers->clear();
     98 
     99   if (message_queue_.empty())
    100     return;
    101 
    102   MessageInTransit* message = message_queue_.front();
    103   DCHECK_LT(data_offset_, message->total_size());
    104   size_t bytes_to_write = message->total_size() - data_offset_;
    105 
    106   size_t transport_data_buffer_size = message->transport_data() ?
    107       message->transport_data()->buffer_size() : 0;
    108 
    109   if (!transport_data_buffer_size) {
    110     // Only write from the main buffer.
    111     DCHECK_LT(data_offset_, message->main_buffer_size());
    112     DCHECK_LE(bytes_to_write, message->main_buffer_size());
    113     Buffer buffer = {
    114         static_cast<const char*>(message->main_buffer()) + data_offset_,
    115         bytes_to_write};
    116     buffers->push_back(buffer);
    117     return;
    118   }
    119 
    120   if (data_offset_ >= message->main_buffer_size()) {
    121     // Only write from the transport data buffer.
    122     DCHECK_LT(data_offset_ - message->main_buffer_size(),
    123               transport_data_buffer_size);
    124     DCHECK_LE(bytes_to_write, transport_data_buffer_size);
    125     Buffer buffer = {
    126         static_cast<const char*>(message->transport_data()->buffer()) +
    127             (data_offset_ - message->main_buffer_size()),
    128         bytes_to_write};
    129     buffers->push_back(buffer);
    130     return;
    131   }
    132 
    133   // TODO(vtl): We could actually send out buffers from multiple messages, with
    134   // the "stopping" condition being reaching a message with platform handles
    135   // attached.
    136 
    137   // Write from both buffers.
    138   DCHECK_EQ(bytes_to_write, message->main_buffer_size() - data_offset_ +
    139                                 transport_data_buffer_size);
    140   Buffer buffer1 = {
    141     static_cast<const char*>(message->main_buffer()) + data_offset_,
    142     message->main_buffer_size() - data_offset_
    143   };
    144   buffers->push_back(buffer1);
    145   Buffer buffer2 = {
    146     static_cast<const char*>(message->transport_data()->buffer()),
    147     transport_data_buffer_size
    148   };
    149   buffers->push_back(buffer2);
    150 }
    151 
    152 // RawChannel ------------------------------------------------------------------
    153 
    154 RawChannel::RawChannel()
    155     : message_loop_for_io_(NULL),
    156       delegate_(NULL),
    157       read_stopped_(false),
    158       write_stopped_(false),
    159       weak_ptr_factory_(this) {
    160 }
    161 
    162 RawChannel::~RawChannel() {
    163   DCHECK(!read_buffer_);
    164   DCHECK(!write_buffer_);
    165 
    166   // No need to take the |write_lock_| here -- if there are still weak pointers
    167   // outstanding, then we're hosed anyway (since we wouldn't be able to
    168   // invalidate them cleanly, since we might not be on the I/O thread).
    169   DCHECK(!weak_ptr_factory_.HasWeakPtrs());
    170 }
    171 
    172 bool RawChannel::Init(Delegate* delegate) {
    173   DCHECK(delegate);
    174 
    175   DCHECK(!delegate_);
    176   delegate_ = delegate;
    177 
    178   CHECK_EQ(base::MessageLoop::current()->type(), base::MessageLoop::TYPE_IO);
    179   DCHECK(!message_loop_for_io_);
    180   message_loop_for_io_ =
    181       static_cast<base::MessageLoopForIO*>(base::MessageLoop::current());
    182 
    183   // No need to take the lock. No one should be using us yet.
    184   DCHECK(!read_buffer_);
    185   read_buffer_.reset(new ReadBuffer);
    186   DCHECK(!write_buffer_);
    187   write_buffer_.reset(new WriteBuffer(GetSerializedPlatformHandleSize()));
    188 
    189   if (!OnInit()) {
    190     delegate_ = NULL;
    191     message_loop_for_io_ = NULL;
    192     read_buffer_.reset();
    193     write_buffer_.reset();
    194     return false;
    195   }
    196 
    197   if (ScheduleRead() != IO_PENDING) {
    198     // This will notify the delegate about the read failure. Although we're on
    199     // the I/O thread, don't call it in the nested context.
    200     message_loop_for_io_->PostTask(
    201         FROM_HERE,
    202         base::Bind(&RawChannel::OnReadCompleted, weak_ptr_factory_.GetWeakPtr(),
    203                    false, 0));
    204   }
    205 
    206   // ScheduleRead() failure is treated as a read failure (by notifying the
    207   // delegate), not as an init failure.
    208   return true;
    209 }
    210 
    211 void RawChannel::Shutdown() {
    212   DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
    213 
    214   base::AutoLock locker(write_lock_);
    215 
    216   LOG_IF(WARNING, !write_buffer_->message_queue_.empty())
    217       << "Shutting down RawChannel with write buffer nonempty";
    218 
    219   // Reset the delegate so that it won't receive further calls.
    220   delegate_ = NULL;
    221   read_stopped_ = true;
    222   write_stopped_ = true;
    223   weak_ptr_factory_.InvalidateWeakPtrs();
    224 
    225   OnShutdownNoLock(read_buffer_.Pass(), write_buffer_.Pass());
    226 }
    227 
    228 // Reminder: This must be thread-safe.
    229 bool RawChannel::WriteMessage(scoped_ptr<MessageInTransit> message) {
    230   DCHECK(message);
    231 
    232   base::AutoLock locker(write_lock_);
    233   if (write_stopped_)
    234     return false;
    235 
    236   if (!write_buffer_->message_queue_.empty()) {
    237     EnqueueMessageNoLock(message.Pass());
    238     return true;
    239   }
    240 
    241   EnqueueMessageNoLock(message.Pass());
    242   DCHECK_EQ(write_buffer_->data_offset_, 0u);
    243 
    244   size_t platform_handles_written = 0;
    245   size_t bytes_written = 0;
    246   IOResult io_result = WriteNoLock(&platform_handles_written, &bytes_written);
    247   if (io_result == IO_PENDING)
    248     return true;
    249 
    250   bool result = OnWriteCompletedNoLock(io_result == IO_SUCCEEDED,
    251                                        platform_handles_written,
    252                                        bytes_written);
    253   if (!result) {
    254     // Even if we're on the I/O thread, don't call |OnFatalError()| in the
    255     // nested context.
    256     message_loop_for_io_->PostTask(
    257         FROM_HERE,
    258         base::Bind(&RawChannel::CallOnFatalError,
    259                    weak_ptr_factory_.GetWeakPtr(),
    260                    Delegate::FATAL_ERROR_WRITE));
    261   }
    262 
    263   return result;
    264 }
    265 
    266 // Reminder: This must be thread-safe.
    267 bool RawChannel::IsWriteBufferEmpty() {
    268   base::AutoLock locker(write_lock_);
    269   return write_buffer_->message_queue_.empty();
    270 }
    271 
// Handles completion of a read on the I/O thread. It is also posted by Init()
// with |result| == false when the initial ScheduleRead() fails.
//   |result|     - whether the read succeeded; false is reported to the
//                  delegate as FATAL_ERROR_READ.
//   |bytes_read| - number of new bytes now sitting right after the valid data
//                  in |read_buffer_|.
// Parses and dispatches every complete message in the buffer, compacts the
// leftovers to the front, guarantees |kReadSize| bytes of free space, then
// either schedules another asynchronous read or reads again synchronously.
void RawChannel::OnReadCompleted(bool result, size_t bytes_read) {
  DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);

  if (read_stopped_) {
    NOTREACHED();
    return;
  }

  IOResult io_result = result ? IO_SUCCEEDED : IO_FAILED;

  // Keep reading data in a loop, and dispatch messages if enough data is
  // received. Exit the loop if any of the following happens:
  //   - one or more messages were dispatched;
  //   - the last read failed, was a partial read or would block;
  //   - |Shutdown()| was called.
  // Strictly, the loop itself only ends when the next read is pending
  // (|IO_PENDING|). A failed follow-up read is reported as FATAL_ERROR_READ at
  // the top of the next iteration.
  do {
    if (io_result != IO_SUCCEEDED) {
      read_stopped_ = true;
      CallOnFatalError(Delegate::FATAL_ERROR_READ);
      return;
    }

    // The new bytes were written just past the previously valid data (see
    // ReadBuffer::GetBuffer()).
    read_buffer_->num_valid_bytes_ += bytes_read;

    // Dispatch all the messages that we can.
    bool did_dispatch_message = false;
    // Tracks the offset of the first undispatched message in |read_buffer_|.
    // Currently, we copy data to ensure that this is zero at the beginning.
    size_t read_buffer_start = 0;
    size_t remaining_bytes = read_buffer_->num_valid_bytes_;
    size_t message_size;
    // Note that we rely on short-circuit evaluation here:
    //   - |read_buffer_start| may be an invalid index into
    //     |read_buffer_->buffer_| if |remaining_bytes| is zero.
    //   - |message_size| is only valid if |GetNextMessageSize()| returns true.
    // TODO(vtl): Use |message_size| more intelligently (e.g., to request the
    // next read).
    // TODO(vtl): Validate that |message_size| is sane.
    while (remaining_bytes > 0 &&
           MessageInTransit::GetNextMessageSize(
               &read_buffer_->buffer_[read_buffer_start], remaining_bytes,
               &message_size) &&
           remaining_bytes >= message_size) {
      // A complete message is available starting at |read_buffer_start|.
      MessageInTransit::View
          message_view(message_size, &read_buffer_->buffer_[read_buffer_start]);
      DCHECK_EQ(message_view.total_size(), message_size);

      // Malformed input from the peer is fatal for the read side.
      const char* error_message = NULL;
      if (!message_view.IsValid(GetSerializedPlatformHandleSize(),
                                &error_message)) {
        DCHECK(error_message);
        LOG(WARNING) << "Received invalid message: " << error_message;
        read_stopped_ = true;
        CallOnFatalError(Delegate::FATAL_ERROR_READ);
        return;
      }

      if (message_view.type() == MessageInTransit::kTypeRawChannel) {
        // Control messages for the channel itself are consumed here and never
        // reach the delegate.
        if (!OnReadMessageForRawChannel(message_view)) {
          read_stopped_ = true;
          CallOnFatalError(Delegate::FATAL_ERROR_READ);
          return;
        }
      } else {
        // Collect any platform handles the message refers to, using the
        // handle table in its transport data buffer.
        embedder::ScopedPlatformHandleVectorPtr platform_handles;
        if (message_view.transport_data_buffer()) {
          size_t num_platform_handles;
          const void* platform_handle_table;
          TransportData::GetPlatformHandleTable(
              message_view.transport_data_buffer(),
              &num_platform_handles,
              &platform_handle_table);

          if (num_platform_handles > 0) {
            platform_handles =
                GetReadPlatformHandles(num_platform_handles,
                                       platform_handle_table).Pass();
            // A null result means the implementation could not supply the
            // requested number of handles.
            if (!platform_handles) {
              LOG(WARNING) << "Invalid number of platform handles received";
              read_stopped_ = true;
              CallOnFatalError(Delegate::FATAL_ERROR_READ);
              return;
            }
          }
        }

        // TODO(vtl): In the case that we aren't expecting any platform handles,
        // for the POSIX implementation, we should confirm that none are stored.

        // Dispatch the message.
        DCHECK(delegate_);
        delegate_->OnReadMessage(message_view, platform_handles.Pass());
        if (read_stopped_) {
          // |Shutdown()| was called in |OnReadMessage()|.
          // TODO(vtl): Add test for this case.
          // Shutdown() has also handed |read_buffer_| off, so it must not be
          // touched again.
          return;
        }
      }

      did_dispatch_message = true;

      // Update our state.
      read_buffer_start += message_size;
      remaining_bytes -= message_size;
    }

    if (read_buffer_start > 0) {
      // Move data back to start.
      // Only the bytes of a trailing, incomplete message (if any) remain valid.
      read_buffer_->num_valid_bytes_ = remaining_bytes;
      if (read_buffer_->num_valid_bytes_ > 0) {
        memmove(&read_buffer_->buffer_[0],
                &read_buffer_->buffer_[read_buffer_start], remaining_bytes);
      }
      // Not read again before the variable is redeclared next iteration.
      read_buffer_start = 0;
    }

    // ReadBuffer::GetBuffer() requires |kReadSize| bytes of free space after
    // the valid data; grow the buffer if that is no longer available.
    if (read_buffer_->buffer_.size() - read_buffer_->num_valid_bytes_ <
            kReadSize) {
      // Use power-of-2 buffer sizes.
      // TODO(vtl): Make sure the buffer doesn't get too large (and enforce the
      // maximum message size to whatever extent necessary).
      // TODO(vtl): We may often be able to peek at the header and get the real
      // required extra space (which may be much bigger than |kReadSize|).
      size_t new_size = std::max(read_buffer_->buffer_.size(), kReadSize);
      while (new_size < read_buffer_->num_valid_bytes_ + kReadSize)
        new_size *= 2;

      // TODO(vtl): It's suboptimal to zero out the fresh memory.
      read_buffer_->buffer_.resize(new_size, 0);
    }

    // (1) If we dispatched any messages, stop reading for now (and let the
    // message loop do its thing for another round).
    // TODO(vtl): Is this the behavior we want? (Alternatives: i. Dispatch only
    // a single message. Risks: slower, more complex if we want to avoid lots of
    // copying. ii. Keep reading until there's no more data and dispatch all the
    // messages we can. Risks: starvation of other users of the message loop.)
    // (2) If we didn't max out |kReadSize|, stop reading for now.
    // Otherwise read again synchronously; |Read()| reports the new byte count
    // through |bytes_read|.
    bool schedule_for_later = did_dispatch_message || bytes_read < kReadSize;
    bytes_read = 0;
    io_result = schedule_for_later ? ScheduleRead() : Read(&bytes_read);
  } while (io_result != IO_PENDING);
}
    415 
    416 void RawChannel::OnWriteCompleted(bool result,
    417                                   size_t platform_handles_written,
    418                                   size_t bytes_written) {
    419   DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
    420 
    421   bool did_fail = false;
    422   {
    423     base::AutoLock locker(write_lock_);
    424     DCHECK_EQ(write_stopped_, write_buffer_->message_queue_.empty());
    425 
    426     if (write_stopped_) {
    427       NOTREACHED();
    428       return;
    429     }
    430 
    431     did_fail = !OnWriteCompletedNoLock(result,
    432                                        platform_handles_written,
    433                                        bytes_written);
    434   }
    435 
    436   if (did_fail)
    437     CallOnFatalError(Delegate::FATAL_ERROR_WRITE);
    438 }
    439 
    440 void RawChannel::EnqueueMessageNoLock(scoped_ptr<MessageInTransit> message) {
    441   write_lock_.AssertAcquired();
    442   write_buffer_->message_queue_.push_back(message.release());
    443 }
    444 
    445 bool RawChannel::OnReadMessageForRawChannel(
    446     const MessageInTransit::View& message_view) {
    447   // No non-implementation specific |RawChannel| control messages.
    448   LOG(ERROR) << "Invalid control message (subtype " << message_view.subtype()
    449              << ")";
    450   return false;
    451 }
    452 
    453 void RawChannel::CallOnFatalError(Delegate::FatalError fatal_error) {
    454   DCHECK_EQ(base::MessageLoop::current(), message_loop_for_io_);
    455   // TODO(vtl): Add a "write_lock_.AssertNotAcquired()"?
    456   if (delegate_)
    457     delegate_->OnFatalError(fatal_error);
    458 }
    459 
    460 bool RawChannel::OnWriteCompletedNoLock(bool result,
    461                                         size_t platform_handles_written,
    462                                         size_t bytes_written) {
    463   write_lock_.AssertAcquired();
    464 
    465   DCHECK(!write_stopped_);
    466   DCHECK(!write_buffer_->message_queue_.empty());
    467 
    468   if (result) {
    469     write_buffer_->platform_handles_offset_ += platform_handles_written;
    470     write_buffer_->data_offset_ += bytes_written;
    471 
    472     MessageInTransit* message = write_buffer_->message_queue_.front();
    473     if (write_buffer_->data_offset_ >= message->total_size()) {
    474       // Complete write.
    475       DCHECK_EQ(write_buffer_->data_offset_, message->total_size());
    476       write_buffer_->message_queue_.pop_front();
    477       delete message;
    478       write_buffer_->platform_handles_offset_ = 0;
    479       write_buffer_->data_offset_ = 0;
    480 
    481       if (write_buffer_->message_queue_.empty())
    482         return true;
    483     }
    484 
    485     // Schedule the next write.
    486     IOResult io_result = ScheduleWriteNoLock();
    487     if (io_result == IO_PENDING)
    488       return true;
    489     DCHECK_EQ(io_result, IO_FAILED);
    490   }
    491 
    492   write_stopped_ = true;
    493   STLDeleteElements(&write_buffer_->message_queue_);
    494   write_buffer_->platform_handles_offset_ = 0;
    495   write_buffer_->data_offset_ = 0;
    496   return false;
    497 }
    498 
    499 }  // namespace system
    500 }  // namespace mojo
    501