Home | History | Annotate | Download | only in system
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "mojo/system/raw_channel.h"
      6 
      7 #include <stdint.h>
      8 
      9 #include <vector>
     10 
     11 #include "base/bind.h"
     12 #include "base/location.h"
     13 #include "base/logging.h"
     14 #include "base/macros.h"
     15 #include "base/memory/scoped_ptr.h"
     16 #include "base/memory/scoped_vector.h"
     17 #include "base/rand_util.h"
     18 #include "base/synchronization/lock.h"
     19 #include "base/synchronization/waitable_event.h"
     20 #include "base/test/test_io_thread.h"
     21 #include "base/threading/platform_thread.h"  // For |Sleep()|.
     22 #include "base/threading/simple_thread.h"
     23 #include "base/time/time.h"
     24 #include "build/build_config.h"
     25 #include "mojo/common/test/test_utils.h"
     26 #include "mojo/embedder/platform_channel_pair.h"
     27 #include "mojo/embedder/platform_handle.h"
     28 #include "mojo/embedder/scoped_platform_handle.h"
     29 #include "mojo/system/message_in_transit.h"
     30 #include "mojo/system/test_utils.h"
     31 #include "testing/gtest/include/gtest/gtest.h"
     32 
     33 namespace mojo {
     34 namespace system {
     35 namespace {
     36 
     37 scoped_ptr<MessageInTransit> MakeTestMessage(uint32_t num_bytes) {
     38   std::vector<unsigned char> bytes(num_bytes, 0);
     39   for (size_t i = 0; i < num_bytes; i++)
     40     bytes[i] = static_cast<unsigned char>(i + num_bytes);
     41   return make_scoped_ptr(
     42       new MessageInTransit(MessageInTransit::kTypeMessagePipeEndpoint,
     43                            MessageInTransit::kSubtypeMessagePipeEndpointData,
     44                            num_bytes,
     45                            bytes.empty() ? nullptr : &bytes[0]));
     46 }
     47 
     48 bool CheckMessageData(const void* bytes, uint32_t num_bytes) {
     49   const unsigned char* b = static_cast<const unsigned char*>(bytes);
     50   for (uint32_t i = 0; i < num_bytes; i++) {
     51     if (b[i] != static_cast<unsigned char>(i + num_bytes))
     52       return false;
     53   }
     54   return true;
     55 }
     56 
     57 void InitOnIOThread(RawChannel* raw_channel, RawChannel::Delegate* delegate) {
     58   CHECK(raw_channel->Init(delegate));
     59 }
     60 
     61 bool WriteTestMessageToHandle(const embedder::PlatformHandle& handle,
     62                               uint32_t num_bytes) {
     63   scoped_ptr<MessageInTransit> message(MakeTestMessage(num_bytes));
     64 
     65   size_t write_size = 0;
     66   mojo::test::BlockingWrite(
     67       handle, message->main_buffer(), message->main_buffer_size(), &write_size);
     68   return write_size == message->main_buffer_size();
     69 }
     70 
     71 // -----------------------------------------------------------------------------
     72 
     73 class RawChannelTest : public testing::Test {
     74  public:
     75   RawChannelTest() : io_thread_(base::TestIOThread::kManualStart) {}
     76   virtual ~RawChannelTest() {}
     77 
     78   virtual void SetUp() OVERRIDE {
     79     embedder::PlatformChannelPair channel_pair;
     80     handles[0] = channel_pair.PassServerHandle();
     81     handles[1] = channel_pair.PassClientHandle();
     82     io_thread_.Start();
     83   }
     84 
     85   virtual void TearDown() OVERRIDE {
     86     io_thread_.Stop();
     87     handles[0].reset();
     88     handles[1].reset();
     89   }
     90 
     91  protected:
     92   base::TestIOThread* io_thread() { return &io_thread_; }
     93 
     94   embedder::ScopedPlatformHandle handles[2];
     95 
     96  private:
     97   base::TestIOThread io_thread_;
     98 
     99   DISALLOW_COPY_AND_ASSIGN(RawChannelTest);
    100 };
    101 
    102 // RawChannelTest.WriteMessage -------------------------------------------------
    103 
    104 class WriteOnlyRawChannelDelegate : public RawChannel::Delegate {
    105  public:
    106   WriteOnlyRawChannelDelegate() {}
    107   virtual ~WriteOnlyRawChannelDelegate() {}
    108 
    109   // |RawChannel::Delegate| implementation:
    110   virtual void OnReadMessage(
    111       const MessageInTransit::View& /*message_view*/,
    112       embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE {
    113     CHECK(false);  // Should not get called.
    114   }
    115   virtual void OnError(Error error) OVERRIDE {
    116     // We'll get a read (shutdown) error when the connection is closed.
    117     CHECK_EQ(error, ERROR_READ_SHUTDOWN);
    118   }
    119 
    120  private:
    121   DISALLOW_COPY_AND_ASSIGN(WriteOnlyRawChannelDelegate);
    122 };
    123 
    124 static const int64_t kMessageReaderSleepMs = 1;
    125 static const size_t kMessageReaderMaxPollIterations = 3000;
    126 
    127 class TestMessageReaderAndChecker {
    128  public:
    129   explicit TestMessageReaderAndChecker(embedder::PlatformHandle handle)
    130       : handle_(handle) {}
    131   ~TestMessageReaderAndChecker() { CHECK(bytes_.empty()); }
    132 
    133   bool ReadAndCheckNextMessage(uint32_t expected_size) {
    134     unsigned char buffer[4096];
    135 
    136     for (size_t i = 0; i < kMessageReaderMaxPollIterations;) {
    137       size_t read_size = 0;
    138       CHECK(mojo::test::NonBlockingRead(
    139           handle_, buffer, sizeof(buffer), &read_size));
    140 
    141       // Append newly-read data to |bytes_|.
    142       bytes_.insert(bytes_.end(), buffer, buffer + read_size);
    143 
    144       // If we have the header....
    145       size_t message_size;
    146       if (MessageInTransit::GetNextMessageSize(
    147               bytes_.empty() ? nullptr : &bytes_[0],
    148               bytes_.size(),
    149               &message_size)) {
    150         // If we've read the whole message....
    151         if (bytes_.size() >= message_size) {
    152           bool rv = true;
    153           MessageInTransit::View message_view(message_size, &bytes_[0]);
    154           CHECK_EQ(message_view.main_buffer_size(), message_size);
    155 
    156           if (message_view.num_bytes() != expected_size) {
    157             LOG(ERROR) << "Wrong size: " << message_size << " instead of "
    158                        << expected_size << " bytes.";
    159             rv = false;
    160           } else if (!CheckMessageData(message_view.bytes(),
    161                                        message_view.num_bytes())) {
    162             LOG(ERROR) << "Incorrect message bytes.";
    163             rv = false;
    164           }
    165 
    166           // Erase message data.
    167           bytes_.erase(bytes_.begin(),
    168                        bytes_.begin() + message_view.main_buffer_size());
    169           return rv;
    170         }
    171       }
    172 
    173       if (static_cast<size_t>(read_size) < sizeof(buffer)) {
    174         i++;
    175         base::PlatformThread::Sleep(
    176             base::TimeDelta::FromMilliseconds(kMessageReaderSleepMs));
    177       }
    178     }
    179 
    180     LOG(ERROR) << "Too many iterations.";
    181     return false;
    182   }
    183 
    184  private:
    185   const embedder::PlatformHandle handle_;
    186 
    187   // The start of the received data should always be on a message boundary.
    188   std::vector<unsigned char> bytes_;
    189 
    190   DISALLOW_COPY_AND_ASSIGN(TestMessageReaderAndChecker);
    191 };
    192 
    193 // Tests writing (and verifies reading using our own custom reader).
    194 TEST_F(RawChannelTest, WriteMessage) {
    195   WriteOnlyRawChannelDelegate delegate;
    196   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    197   TestMessageReaderAndChecker checker(handles[1].get());
    198   io_thread()->PostTaskAndWait(
    199       FROM_HERE,
    200       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
    201 
    202   // Write and read, for a variety of sizes.
    203   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
    204     EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
    205     EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
    206   }
    207 
    208   // Write/queue and read afterwards, for a variety of sizes.
    209   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    210     EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
    211   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    212     EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
    213 
    214   io_thread()->PostTaskAndWait(
    215       FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
    216 }
    217 
    218 // RawChannelTest.OnReadMessage ------------------------------------------------
    219 
    220 class ReadCheckerRawChannelDelegate : public RawChannel::Delegate {
    221  public:
    222   ReadCheckerRawChannelDelegate() : done_event_(false, false), position_(0) {}
    223   virtual ~ReadCheckerRawChannelDelegate() {}
    224 
    225   // |RawChannel::Delegate| implementation (called on the I/O thread):
    226   virtual void OnReadMessage(
    227       const MessageInTransit::View& message_view,
    228       embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
    229     EXPECT_FALSE(platform_handles);
    230 
    231     size_t position;
    232     size_t expected_size;
    233     bool should_signal = false;
    234     {
    235       base::AutoLock locker(lock_);
    236       CHECK_LT(position_, expected_sizes_.size());
    237       position = position_;
    238       expected_size = expected_sizes_[position];
    239       position_++;
    240       if (position_ >= expected_sizes_.size())
    241         should_signal = true;
    242     }
    243 
    244     EXPECT_EQ(expected_size, message_view.num_bytes()) << position;
    245     if (message_view.num_bytes() == expected_size) {
    246       EXPECT_TRUE(
    247           CheckMessageData(message_view.bytes(), message_view.num_bytes()))
    248           << position;
    249     }
    250 
    251     if (should_signal)
    252       done_event_.Signal();
    253   }
    254   virtual void OnError(Error error) OVERRIDE {
    255     // We'll get a read (shutdown) error when the connection is closed.
    256     CHECK_EQ(error, ERROR_READ_SHUTDOWN);
    257   }
    258 
    259   // Waits for all the messages (of sizes |expected_sizes_|) to be seen.
    260   void Wait() { done_event_.Wait(); }
    261 
    262   void SetExpectedSizes(const std::vector<uint32_t>& expected_sizes) {
    263     base::AutoLock locker(lock_);
    264     CHECK_EQ(position_, expected_sizes_.size());
    265     expected_sizes_ = expected_sizes;
    266     position_ = 0;
    267   }
    268 
    269  private:
    270   base::WaitableEvent done_event_;
    271 
    272   base::Lock lock_;  // Protects the following members.
    273   std::vector<uint32_t> expected_sizes_;
    274   size_t position_;
    275 
    276   DISALLOW_COPY_AND_ASSIGN(ReadCheckerRawChannelDelegate);
    277 };
    278 
    279 // Tests reading (writing using our own custom writer).
    280 TEST_F(RawChannelTest, OnReadMessage) {
    281   ReadCheckerRawChannelDelegate delegate;
    282   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    283   io_thread()->PostTaskAndWait(
    284       FROM_HERE,
    285       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
    286 
    287   // Write and read, for a variety of sizes.
    288   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
    289     delegate.SetExpectedSizes(std::vector<uint32_t>(1, size));
    290 
    291     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
    292 
    293     delegate.Wait();
    294   }
    295 
    296   // Set up reader and write as fast as we can.
    297   // Write/queue and read afterwards, for a variety of sizes.
    298   std::vector<uint32_t> expected_sizes;
    299   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    300     expected_sizes.push_back(size);
    301   delegate.SetExpectedSizes(expected_sizes);
    302   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    303     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
    304   delegate.Wait();
    305 
    306   io_thread()->PostTaskAndWait(
    307       FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
    308 }
    309 
    310 // RawChannelTest.WriteMessageAndOnReadMessage ---------------------------------
    311 
    312 class RawChannelWriterThread : public base::SimpleThread {
    313  public:
    314   RawChannelWriterThread(RawChannel* raw_channel, size_t write_count)
    315       : base::SimpleThread("raw_channel_writer_thread"),
    316         raw_channel_(raw_channel),
    317         left_to_write_(write_count) {}
    318 
    319   virtual ~RawChannelWriterThread() { Join(); }
    320 
    321  private:
    322   virtual void Run() OVERRIDE {
    323     static const int kMaxRandomMessageSize = 25000;
    324 
    325     while (left_to_write_-- > 0) {
    326       EXPECT_TRUE(raw_channel_->WriteMessage(MakeTestMessage(
    327           static_cast<uint32_t>(base::RandInt(1, kMaxRandomMessageSize)))));
    328     }
    329   }
    330 
    331   RawChannel* const raw_channel_;
    332   size_t left_to_write_;
    333 
    334   DISALLOW_COPY_AND_ASSIGN(RawChannelWriterThread);
    335 };
    336 
    337 class ReadCountdownRawChannelDelegate : public RawChannel::Delegate {
    338  public:
    339   explicit ReadCountdownRawChannelDelegate(size_t expected_count)
    340       : done_event_(false, false), expected_count_(expected_count), count_(0) {}
    341   virtual ~ReadCountdownRawChannelDelegate() {}
    342 
    343   // |RawChannel::Delegate| implementation (called on the I/O thread):
    344   virtual void OnReadMessage(
    345       const MessageInTransit::View& message_view,
    346       embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
    347     EXPECT_FALSE(platform_handles);
    348 
    349     EXPECT_LT(count_, expected_count_);
    350     count_++;
    351 
    352     EXPECT_TRUE(
    353         CheckMessageData(message_view.bytes(), message_view.num_bytes()));
    354 
    355     if (count_ >= expected_count_)
    356       done_event_.Signal();
    357   }
    358   virtual void OnError(Error error) OVERRIDE {
    359     // We'll get a read (shutdown) error when the connection is closed.
    360     CHECK_EQ(error, ERROR_READ_SHUTDOWN);
    361   }
    362 
    363   // Waits for all the messages to have been seen.
    364   void Wait() { done_event_.Wait(); }
    365 
    366  private:
    367   base::WaitableEvent done_event_;
    368   size_t expected_count_;
    369   size_t count_;
    370 
    371   DISALLOW_COPY_AND_ASSIGN(ReadCountdownRawChannelDelegate);
    372 };
    373 
    374 TEST_F(RawChannelTest, WriteMessageAndOnReadMessage) {
    375   static const size_t kNumWriterThreads = 10;
    376   static const size_t kNumWriteMessagesPerThread = 4000;
    377 
    378   WriteOnlyRawChannelDelegate writer_delegate;
    379   scoped_ptr<RawChannel> writer_rc(RawChannel::Create(handles[0].Pass()));
    380   io_thread()->PostTaskAndWait(FROM_HERE,
    381                                base::Bind(&InitOnIOThread,
    382                                           writer_rc.get(),
    383                                           base::Unretained(&writer_delegate)));
    384 
    385   ReadCountdownRawChannelDelegate reader_delegate(kNumWriterThreads *
    386                                                   kNumWriteMessagesPerThread);
    387   scoped_ptr<RawChannel> reader_rc(RawChannel::Create(handles[1].Pass()));
    388   io_thread()->PostTaskAndWait(FROM_HERE,
    389                                base::Bind(&InitOnIOThread,
    390                                           reader_rc.get(),
    391                                           base::Unretained(&reader_delegate)));
    392 
    393   {
    394     ScopedVector<RawChannelWriterThread> writer_threads;
    395     for (size_t i = 0; i < kNumWriterThreads; i++) {
    396       writer_threads.push_back(new RawChannelWriterThread(
    397           writer_rc.get(), kNumWriteMessagesPerThread));
    398     }
    399     for (size_t i = 0; i < writer_threads.size(); i++)
    400       writer_threads[i]->Start();
    401   }  // Joins all the writer threads.
    402 
    403   // Sleep a bit, to let any extraneous reads be processed. (There shouldn't be
    404   // any, but we want to know about them.)
    405   base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100));
    406 
    407   // Wait for reading to finish.
    408   reader_delegate.Wait();
    409 
    410   io_thread()->PostTaskAndWait(
    411       FROM_HERE,
    412       base::Bind(&RawChannel::Shutdown, base::Unretained(reader_rc.get())));
    413 
    414   io_thread()->PostTaskAndWait(
    415       FROM_HERE,
    416       base::Bind(&RawChannel::Shutdown, base::Unretained(writer_rc.get())));
    417 }
    418 
    419 // RawChannelTest.OnError ------------------------------------------------------
    420 
    421 class ErrorRecordingRawChannelDelegate
    422     : public ReadCountdownRawChannelDelegate {
    423  public:
    424   ErrorRecordingRawChannelDelegate(size_t expected_read_count,
    425                                    bool expect_read_error,
    426                                    bool expect_write_error)
    427       : ReadCountdownRawChannelDelegate(expected_read_count),
    428         got_read_error_event_(false, false),
    429         got_write_error_event_(false, false),
    430         expecting_read_error_(expect_read_error),
    431         expecting_write_error_(expect_write_error) {}
    432 
    433   virtual ~ErrorRecordingRawChannelDelegate() {}
    434 
    435   virtual void OnError(Error error) OVERRIDE {
    436     switch (error) {
    437       case ERROR_READ_SHUTDOWN:
    438         ASSERT_TRUE(expecting_read_error_);
    439         expecting_read_error_ = false;
    440         got_read_error_event_.Signal();
    441         break;
    442       case ERROR_READ_BROKEN:
    443         // TODO(vtl): Test broken connections.
    444         CHECK(false);
    445         break;
    446       case ERROR_READ_BAD_MESSAGE:
    447         // TODO(vtl): Test reception/detection of bad messages.
    448         CHECK(false);
    449         break;
    450       case ERROR_READ_UNKNOWN:
    451         // TODO(vtl): Test however it is we might get here.
    452         CHECK(false);
    453         break;
    454       case ERROR_WRITE:
    455         ASSERT_TRUE(expecting_write_error_);
    456         expecting_write_error_ = false;
    457         got_write_error_event_.Signal();
    458         break;
    459     }
    460   }
    461 
    462   void WaitForReadError() { got_read_error_event_.Wait(); }
    463   void WaitForWriteError() { got_write_error_event_.Wait(); }
    464 
    465  private:
    466   base::WaitableEvent got_read_error_event_;
    467   base::WaitableEvent got_write_error_event_;
    468 
    469   bool expecting_read_error_;
    470   bool expecting_write_error_;
    471 
    472   DISALLOW_COPY_AND_ASSIGN(ErrorRecordingRawChannelDelegate);
    473 };
    474 
    475 // Tests (fatal) errors.
    476 TEST_F(RawChannelTest, OnError) {
    477   ErrorRecordingRawChannelDelegate delegate(0, true, true);
    478   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    479   io_thread()->PostTaskAndWait(
    480       FROM_HERE,
    481       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
    482 
    483   // Close the handle of the other end, which should make writing fail.
    484   handles[1].reset();
    485 
    486   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
    487 
    488   // We should get a write error.
    489   delegate.WaitForWriteError();
    490 
    491   // We should also get a read error.
    492   delegate.WaitForReadError();
    493 
    494   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(2)));
    495 
    496   // Sleep a bit, to make sure we don't get another |OnError()|
    497   // notification. (If we actually get another one, |OnError()| crashes.)
    498   base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(20));
    499 
    500   io_thread()->PostTaskAndWait(
    501       FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
    502 }
    503 
    504 // RawChannelTest.ReadUnaffectedByWriteError -----------------------------------
    505 
    506 TEST_F(RawChannelTest, ReadUnaffectedByWriteError) {
    507   const size_t kMessageCount = 5;
    508 
    509   // Write a few messages into the other end.
    510   uint32_t message_size = 1;
    511   for (size_t i = 0; i < kMessageCount;
    512        i++, message_size += message_size / 2 + 1)
    513     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), message_size));
    514 
    515   // Close the other end, which should make writing fail.
    516   handles[1].reset();
    517 
    518   // Only start up reading here. The system buffer should still contain the
    519   // messages that were written.
    520   ErrorRecordingRawChannelDelegate delegate(kMessageCount, true, true);
    521   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    522   io_thread()->PostTaskAndWait(
    523       FROM_HERE,
    524       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
    525 
    526   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
    527 
    528   // We should definitely get a write error.
    529   delegate.WaitForWriteError();
    530 
    531   // Wait for reading to finish. A writing failure shouldn't affect reading.
    532   delegate.Wait();
    533 
    534   // And then we should get a read error.
    535   delegate.WaitForReadError();
    536 
    537   io_thread()->PostTaskAndWait(
    538       FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
    539 }
    540 
    541 // RawChannelTest.WriteMessageAfterShutdown ------------------------------------
    542 
    543 // Makes sure that calling |WriteMessage()| after |Shutdown()| behaves
    544 // correctly.
    545 TEST_F(RawChannelTest, WriteMessageAfterShutdown) {
    546   WriteOnlyRawChannelDelegate delegate;
    547   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    548   io_thread()->PostTaskAndWait(
    549       FROM_HERE,
    550       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
    551   io_thread()->PostTaskAndWait(
    552       FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get())));
    553 
    554   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
    555 }
    556 
    557 // RawChannelTest.ShutdownOnReadMessage ----------------------------------------
    558 
    559 class ShutdownOnReadMessageRawChannelDelegate : public RawChannel::Delegate {
    560  public:
    561   explicit ShutdownOnReadMessageRawChannelDelegate(RawChannel* raw_channel)
    562       : raw_channel_(raw_channel),
    563         done_event_(false, false),
    564         did_shutdown_(false) {}
    565   virtual ~ShutdownOnReadMessageRawChannelDelegate() {}
    566 
    567   // |RawChannel::Delegate| implementation (called on the I/O thread):
    568   virtual void OnReadMessage(
    569       const MessageInTransit::View& message_view,
    570       embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
    571     EXPECT_FALSE(platform_handles);
    572     EXPECT_FALSE(did_shutdown_);
    573     EXPECT_TRUE(
    574         CheckMessageData(message_view.bytes(), message_view.num_bytes()));
    575     raw_channel_->Shutdown();
    576     did_shutdown_ = true;
    577     done_event_.Signal();
    578   }
    579   virtual void OnError(Error /*error*/) OVERRIDE {
    580     CHECK(false);  // Should not get called.
    581   }
    582 
    583   // Waits for shutdown.
    584   void Wait() {
    585     done_event_.Wait();
    586     EXPECT_TRUE(did_shutdown_);
    587   }
    588 
    589  private:
    590   RawChannel* const raw_channel_;
    591   base::WaitableEvent done_event_;
    592   bool did_shutdown_;
    593 
    594   DISALLOW_COPY_AND_ASSIGN(ShutdownOnReadMessageRawChannelDelegate);
    595 };
    596 
    597 TEST_F(RawChannelTest, ShutdownOnReadMessage) {
    598   // Write a few messages into the other end.
    599   for (size_t count = 0; count < 5; count++)
    600     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), 10));
    601 
    602   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    603   ShutdownOnReadMessageRawChannelDelegate delegate(rc.get());
    604   io_thread()->PostTaskAndWait(
    605       FROM_HERE,
    606       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
    607 
    608   // Wait for the delegate, which will shut the |RawChannel| down.
    609   delegate.Wait();
    610 }
    611 
    612 // RawChannelTest.ShutdownOnError{Read, Write} ---------------------------------
    613 
    614 class ShutdownOnErrorRawChannelDelegate : public RawChannel::Delegate {
    615  public:
    616   ShutdownOnErrorRawChannelDelegate(RawChannel* raw_channel,
    617                                     Error shutdown_on_error_type)
    618       : raw_channel_(raw_channel),
    619         shutdown_on_error_type_(shutdown_on_error_type),
    620         done_event_(false, false),
    621         did_shutdown_(false) {}
    622   virtual ~ShutdownOnErrorRawChannelDelegate() {}
    623 
    624   // |RawChannel::Delegate| implementation (called on the I/O thread):
    625   virtual void OnReadMessage(
    626       const MessageInTransit::View& /*message_view*/,
    627       embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE {
    628     CHECK(false);  // Should not get called.
    629   }
    630   virtual void OnError(Error error) OVERRIDE {
    631     EXPECT_FALSE(did_shutdown_);
    632     if (error != shutdown_on_error_type_)
    633       return;
    634     raw_channel_->Shutdown();
    635     did_shutdown_ = true;
    636     done_event_.Signal();
    637   }
    638 
    639   // Waits for shutdown.
    640   void Wait() {
    641     done_event_.Wait();
    642     EXPECT_TRUE(did_shutdown_);
    643   }
    644 
    645  private:
    646   RawChannel* const raw_channel_;
    647   const Error shutdown_on_error_type_;
    648   base::WaitableEvent done_event_;
    649   bool did_shutdown_;
    650 
    651   DISALLOW_COPY_AND_ASSIGN(ShutdownOnErrorRawChannelDelegate);
    652 };
    653 
    654 TEST_F(RawChannelTest, ShutdownOnErrorRead) {
    655   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    656   ShutdownOnErrorRawChannelDelegate delegate(
    657       rc.get(), RawChannel::Delegate::ERROR_READ_SHUTDOWN);
    658   io_thread()->PostTaskAndWait(
    659       FROM_HERE,
    660       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
    661 
    662   // Close the handle of the other end, which should stuff fail.
    663   handles[1].reset();
    664 
    665   // Wait for the delegate, which will shut the |RawChannel| down.
    666   delegate.Wait();
    667 }
    668 
    669 TEST_F(RawChannelTest, ShutdownOnErrorWrite) {
    670   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    671   ShutdownOnErrorRawChannelDelegate delegate(rc.get(),
    672                                              RawChannel::Delegate::ERROR_WRITE);
    673   io_thread()->PostTaskAndWait(
    674       FROM_HERE,
    675       base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate)));
    676 
    677   // Close the handle of the other end, which should stuff fail.
    678   handles[1].reset();
    679 
    680   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
    681 
    682   // Wait for the delegate, which will shut the |RawChannel| down.
    683   delegate.Wait();
    684 }
    685 
    686 }  // namespace
    687 }  // namespace system
    688 }  // namespace mojo
    689