Home | History | Annotate | Download | only in system
      1 // Copyright 2014 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 
      5 #include "mojo/system/raw_channel.h"
      6 
      7 #include <stdint.h>
      8 
      9 #include <vector>
     10 
     11 #include "base/bind.h"
     12 #include "base/location.h"
     13 #include "base/logging.h"
     14 #include "base/macros.h"
     15 #include "base/memory/scoped_ptr.h"
     16 #include "base/memory/scoped_vector.h"
     17 #include "base/rand_util.h"
     18 #include "base/synchronization/lock.h"
     19 #include "base/synchronization/waitable_event.h"
     20 #include "base/threading/platform_thread.h"  // For |Sleep()|.
     21 #include "base/threading/simple_thread.h"
     22 #include "base/time/time.h"
     23 #include "build/build_config.h"
     24 #include "mojo/common/test/test_utils.h"
     25 #include "mojo/embedder/platform_channel_pair.h"
     26 #include "mojo/embedder/platform_handle.h"
     27 #include "mojo/embedder/scoped_platform_handle.h"
     28 #include "mojo/system/message_in_transit.h"
     29 #include "mojo/system/test_utils.h"
     30 #include "testing/gtest/include/gtest/gtest.h"
     31 
     32 namespace mojo {
     33 namespace system {
     34 namespace {
     35 
     36 scoped_ptr<MessageInTransit> MakeTestMessage(uint32_t num_bytes) {
     37   std::vector<unsigned char> bytes(num_bytes, 0);
     38   for (size_t i = 0; i < num_bytes; i++)
     39     bytes[i] = static_cast<unsigned char>(i + num_bytes);
     40   return make_scoped_ptr(
     41       new MessageInTransit(MessageInTransit::kTypeMessagePipeEndpoint,
     42                            MessageInTransit::kSubtypeMessagePipeEndpointData,
     43                            num_bytes, bytes.empty() ? NULL : &bytes[0]));
     44 }
     45 
     46 bool CheckMessageData(const void* bytes, uint32_t num_bytes) {
     47   const unsigned char* b = static_cast<const unsigned char*>(bytes);
     48   for (uint32_t i = 0; i < num_bytes; i++) {
     49     if (b[i] != static_cast<unsigned char>(i + num_bytes))
     50       return false;
     51   }
     52   return true;
     53 }
     54 
// Initializes |raw_channel| with |delegate|. The tests bind this function and
// run it on the I/O thread via |PostTaskAndWait()|. A failed |Init()| aborts the
// test through |CHECK()|.
void InitOnIOThread(RawChannel* raw_channel, RawChannel::Delegate* delegate) {
  CHECK(raw_channel->Init(delegate));
}
     58 
     59 bool WriteTestMessageToHandle(const embedder::PlatformHandle& handle,
     60                               uint32_t num_bytes) {
     61   scoped_ptr<MessageInTransit> message(MakeTestMessage(num_bytes));
     62 
     63   size_t write_size = 0;
     64   mojo::test::BlockingWrite(
     65       handle, message->main_buffer(), message->main_buffer_size(), &write_size);
     66   return write_size == message->main_buffer_size();
     67 }
     68 
     69 // -----------------------------------------------------------------------------
     70 
     71 class RawChannelTest : public testing::Test {
     72  public:
     73   RawChannelTest() : io_thread_(test::TestIOThread::kManualStart) {}
     74   virtual ~RawChannelTest() {}
     75 
     76   virtual void SetUp() OVERRIDE {
     77     embedder::PlatformChannelPair channel_pair;
     78     handles[0] = channel_pair.PassServerHandle();
     79     handles[1] = channel_pair.PassClientHandle();
     80     io_thread_.Start();
     81   }
     82 
     83   virtual void TearDown() OVERRIDE {
     84     io_thread_.Stop();
     85     handles[0].reset();
     86     handles[1].reset();
     87   }
     88 
     89  protected:
     90   test::TestIOThread* io_thread() { return &io_thread_; }
     91 
     92   embedder::ScopedPlatformHandle handles[2];
     93 
     94  private:
     95   test::TestIOThread io_thread_;
     96 
     97   DISALLOW_COPY_AND_ASSIGN(RawChannelTest);
     98 };
     99 
    100 // RawChannelTest.WriteMessage -------------------------------------------------
    101 
// Delegate for a |RawChannel| that is only written to. Any incoming message
// aborts the test via |CHECK()|. The only fatal error tolerated is a read error,
// which is reported once the connection is closed.
class WriteOnlyRawChannelDelegate : public RawChannel::Delegate {
 public:
  WriteOnlyRawChannelDelegate() {}
  virtual ~WriteOnlyRawChannelDelegate() {}

  // |RawChannel::Delegate| implementation:
  virtual void OnReadMessage(
      const MessageInTransit::View& /*message_view*/,
      embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE {
    CHECK(false);  // Should not get called.
  }
  virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
    // We'll get a read error when the connection is closed.
    CHECK_EQ(fatal_error, FATAL_ERROR_READ);
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(WriteOnlyRawChannelDelegate);
};
    121 
    122 static const int64_t kMessageReaderSleepMs = 1;
    123 static const size_t kMessageReaderMaxPollIterations = 3000;
    124 
    125 class TestMessageReaderAndChecker {
    126  public:
    127   explicit TestMessageReaderAndChecker(embedder::PlatformHandle handle)
    128       : handle_(handle) {}
    129   ~TestMessageReaderAndChecker() { CHECK(bytes_.empty()); }
    130 
    131   bool ReadAndCheckNextMessage(uint32_t expected_size) {
    132     unsigned char buffer[4096];
    133 
    134     for (size_t i = 0; i < kMessageReaderMaxPollIterations;) {
    135       size_t read_size = 0;
    136       CHECK(mojo::test::NonBlockingRead(handle_, buffer, sizeof(buffer),
    137                                         &read_size));
    138 
    139       // Append newly-read data to |bytes_|.
    140       bytes_.insert(bytes_.end(), buffer, buffer + read_size);
    141 
    142       // If we have the header....
    143       size_t message_size;
    144       if (MessageInTransit::GetNextMessageSize(
    145               bytes_.empty() ? NULL : &bytes_[0],
    146               bytes_.size(),
    147               &message_size)) {
    148         // If we've read the whole message....
    149         if (bytes_.size() >= message_size) {
    150           bool rv = true;
    151           MessageInTransit::View message_view(message_size, &bytes_[0]);
    152           CHECK_EQ(message_view.main_buffer_size(), message_size);
    153 
    154           if (message_view.num_bytes() != expected_size) {
    155             LOG(ERROR) << "Wrong size: " << message_size << " instead of "
    156                        << expected_size << " bytes.";
    157             rv = false;
    158           } else if (!CheckMessageData(message_view.bytes(),
    159                                        message_view.num_bytes())) {
    160             LOG(ERROR) << "Incorrect message bytes.";
    161             rv = false;
    162           }
    163 
    164           // Erase message data.
    165           bytes_.erase(bytes_.begin(),
    166                        bytes_.begin() +
    167                            message_view.main_buffer_size());
    168           return rv;
    169         }
    170       }
    171 
    172       if (static_cast<size_t>(read_size) < sizeof(buffer)) {
    173         i++;
    174         base::PlatformThread::Sleep(
    175             base::TimeDelta::FromMilliseconds(kMessageReaderSleepMs));
    176       }
    177     }
    178 
    179     LOG(ERROR) << "Too many iterations.";
    180     return false;
    181   }
    182 
    183  private:
    184   const embedder::PlatformHandle handle_;
    185 
    186   // The start of the received data should always be on a message boundary.
    187   std::vector<unsigned char> bytes_;
    188 
    189   DISALLOW_COPY_AND_ASSIGN(TestMessageReaderAndChecker);
    190 };
    191 
    192 // Tests writing (and verifies reading using our own custom reader).
    193 TEST_F(RawChannelTest, WriteMessage) {
    194   WriteOnlyRawChannelDelegate delegate;
    195   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    196   TestMessageReaderAndChecker checker(handles[1].get());
    197   io_thread()->PostTaskAndWait(FROM_HERE,
    198                                base::Bind(&InitOnIOThread, rc.get(),
    199                                           base::Unretained(&delegate)));
    200 
    201   // Write and read, for a variety of sizes.
    202   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
    203     EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
    204     EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
    205   }
    206 
    207   // Write/queue and read afterwards, for a variety of sizes.
    208   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    209     EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size)));
    210   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    211     EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size;
    212 
    213   io_thread()->PostTaskAndWait(FROM_HERE,
    214                                base::Bind(&RawChannel::Shutdown,
    215                                           base::Unretained(rc.get())));
    216 }
    217 
    218 // RawChannelTest.OnReadMessage ------------------------------------------------
    219 
    220 class ReadCheckerRawChannelDelegate : public RawChannel::Delegate {
    221  public:
    222   ReadCheckerRawChannelDelegate()
    223       : done_event_(false, false),
    224         position_(0) {}
    225   virtual ~ReadCheckerRawChannelDelegate() {}
    226 
    227   // |RawChannel::Delegate| implementation (called on the I/O thread):
    228   virtual void OnReadMessage(
    229       const MessageInTransit::View& message_view,
    230       embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
    231     EXPECT_FALSE(platform_handles);
    232 
    233     size_t position;
    234     size_t expected_size;
    235     bool should_signal = false;
    236     {
    237       base::AutoLock locker(lock_);
    238       CHECK_LT(position_, expected_sizes_.size());
    239       position = position_;
    240       expected_size = expected_sizes_[position];
    241       position_++;
    242       if (position_ >= expected_sizes_.size())
    243         should_signal = true;
    244     }
    245 
    246     EXPECT_EQ(expected_size, message_view.num_bytes()) << position;
    247     if (message_view.num_bytes() == expected_size) {
    248       EXPECT_TRUE(CheckMessageData(message_view.bytes(),
    249                   message_view.num_bytes())) << position;
    250     }
    251 
    252     if (should_signal)
    253       done_event_.Signal();
    254   }
    255   virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
    256     // We'll get a read error when the connection is closed.
    257     CHECK_EQ(fatal_error, FATAL_ERROR_READ);
    258   }
    259 
    260   // Waits for all the messages (of sizes |expected_sizes_|) to be seen.
    261   void Wait() {
    262     done_event_.Wait();
    263   }
    264 
    265   void SetExpectedSizes(const std::vector<uint32_t>& expected_sizes) {
    266     base::AutoLock locker(lock_);
    267     CHECK_EQ(position_, expected_sizes_.size());
    268     expected_sizes_ = expected_sizes;
    269     position_ = 0;
    270   }
    271 
    272  private:
    273   base::WaitableEvent done_event_;
    274 
    275   base::Lock lock_;  // Protects the following members.
    276   std::vector<uint32_t> expected_sizes_;
    277   size_t position_;
    278 
    279   DISALLOW_COPY_AND_ASSIGN(ReadCheckerRawChannelDelegate);
    280 };
    281 
    282 // Tests reading (writing using our own custom writer).
    283 TEST_F(RawChannelTest, OnReadMessage) {
    284   ReadCheckerRawChannelDelegate delegate;
    285   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    286   io_thread()->PostTaskAndWait(FROM_HERE,
    287                                base::Bind(&InitOnIOThread, rc.get(),
    288                                           base::Unretained(&delegate)));
    289 
    290   // Write and read, for a variety of sizes.
    291   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) {
    292     delegate.SetExpectedSizes(std::vector<uint32_t>(1, size));
    293 
    294     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
    295 
    296     delegate.Wait();
    297   }
    298 
    299   // Set up reader and write as fast as we can.
    300   // Write/queue and read afterwards, for a variety of sizes.
    301   std::vector<uint32_t> expected_sizes;
    302   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    303     expected_sizes.push_back(size);
    304   delegate.SetExpectedSizes(expected_sizes);
    305   for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1)
    306     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size));
    307   delegate.Wait();
    308 
    309   io_thread()->PostTaskAndWait(FROM_HERE,
    310                                base::Bind(&RawChannel::Shutdown,
    311                                           base::Unretained(rc.get())));
    312 }
    313 
    314 // RawChannelTest.WriteMessageAndOnReadMessage ---------------------------------
    315 
    316 class RawChannelWriterThread : public base::SimpleThread {
    317  public:
    318   RawChannelWriterThread(RawChannel* raw_channel, size_t write_count)
    319       : base::SimpleThread("raw_channel_writer_thread"),
    320         raw_channel_(raw_channel),
    321         left_to_write_(write_count) {
    322   }
    323 
    324   virtual ~RawChannelWriterThread() {
    325     Join();
    326   }
    327 
    328  private:
    329   virtual void Run() OVERRIDE {
    330     static const int kMaxRandomMessageSize = 25000;
    331 
    332     while (left_to_write_-- > 0) {
    333       EXPECT_TRUE(raw_channel_->WriteMessage(MakeTestMessage(
    334           static_cast<uint32_t>(base::RandInt(1, kMaxRandomMessageSize)))));
    335     }
    336   }
    337 
    338   RawChannel* const raw_channel_;
    339   size_t left_to_write_;
    340 
    341   DISALLOW_COPY_AND_ASSIGN(RawChannelWriterThread);
    342 };
    343 
    344 class ReadCountdownRawChannelDelegate : public RawChannel::Delegate {
    345  public:
    346   explicit ReadCountdownRawChannelDelegate(size_t expected_count)
    347       : done_event_(false, false),
    348         expected_count_(expected_count),
    349         count_(0) {}
    350   virtual ~ReadCountdownRawChannelDelegate() {}
    351 
    352   // |RawChannel::Delegate| implementation (called on the I/O thread):
    353   virtual void OnReadMessage(
    354       const MessageInTransit::View& message_view,
    355       embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
    356     EXPECT_FALSE(platform_handles);
    357 
    358     EXPECT_LT(count_, expected_count_);
    359     count_++;
    360 
    361     EXPECT_TRUE(CheckMessageData(message_view.bytes(),
    362                 message_view.num_bytes()));
    363 
    364     if (count_ >= expected_count_)
    365       done_event_.Signal();
    366   }
    367   virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
    368     // We'll get a read error when the connection is closed.
    369     CHECK_EQ(fatal_error, FATAL_ERROR_READ);
    370   }
    371 
    372   // Waits for all the messages to have been seen.
    373   void Wait() {
    374     done_event_.Wait();
    375   }
    376 
    377  private:
    378   base::WaitableEvent done_event_;
    379   size_t expected_count_;
    380   size_t count_;
    381 
    382   DISALLOW_COPY_AND_ASSIGN(ReadCountdownRawChannelDelegate);
    383 };
    384 
// Stress test. |kNumWriterThreads| threads concurrently write random-sized
// messages through one |RawChannel|, and a second |RawChannel| on the other end
// must read every one of them intact.
TEST_F(RawChannelTest, WriteMessageAndOnReadMessage) {
  static const size_t kNumWriterThreads = 10;
  static const size_t kNumWriteMessagesPerThread = 4000;

  WriteOnlyRawChannelDelegate writer_delegate;
  scoped_ptr<RawChannel> writer_rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&InitOnIOThread, writer_rc.get(),
                                          base::Unretained(&writer_delegate)));

  // The reader expects exactly the total number of messages all writers send.
  ReadCountdownRawChannelDelegate reader_delegate(
      kNumWriterThreads * kNumWriteMessagesPerThread);
  scoped_ptr<RawChannel> reader_rc(RawChannel::Create(handles[1].Pass()));
  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&InitOnIOThread, reader_rc.get(),
                                          base::Unretained(&reader_delegate)));

  {
    ScopedVector<RawChannelWriterThread> writer_threads;
    for (size_t i = 0; i < kNumWriterThreads; i++) {
      writer_threads.push_back(new RawChannelWriterThread(
          writer_rc.get(), kNumWriteMessagesPerThread));
    }
    for (size_t i = 0; i < writer_threads.size(); i++)
      writer_threads[i]->Start();
  }  // Joins all the writer threads.

  // Sleep a bit, to let any extraneous reads be processed. (There shouldn't be
  // any, but we want to know about them.)
  base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100));

  // Wait for reading to finish.
  reader_delegate.Wait();

  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&RawChannel::Shutdown,
                                          base::Unretained(reader_rc.get())));

  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&RawChannel::Shutdown,
                                          base::Unretained(writer_rc.get())));
}
    427 
    428 // RawChannelTest.OnFatalError -------------------------------------------------
    429 
// Delegate that counts reads like |ReadCountdownRawChannelDelegate| and also
// expects at most one read fatal error and one write fatal error. Which of the
// two are expected is chosen by the constructor flags. The test can wait for
// each error separately.
class FatalErrorRecordingRawChannelDelegate
    : public ReadCountdownRawChannelDelegate {
 public:
  FatalErrorRecordingRawChannelDelegate(size_t expected_read_count,
                                        bool expect_read_error,
                                        bool expect_write_error)
      : ReadCountdownRawChannelDelegate(expected_read_count),
        got_read_fatal_error_event_(false, false),
        got_write_fatal_error_event_(false, false),
        expecting_read_error_(expect_read_error),
        expecting_write_error_(expect_write_error) {
  }

  virtual ~FatalErrorRecordingRawChannelDelegate() {}

  // Records |fatal_error| and signals the matching event. An error that is not
  // (or is no longer) expected fails the |ASSERT_TRUE()|. That records a test
  // failure and returns early, without signaling.
  virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
    switch (fatal_error) {
      case FATAL_ERROR_READ:
        ASSERT_TRUE(expecting_read_error_);
        // Each error type is accepted at most once.
        expecting_read_error_ = false;
        got_read_fatal_error_event_.Signal();
        break;
      case FATAL_ERROR_WRITE:
        ASSERT_TRUE(expecting_write_error_);
        expecting_write_error_ = false;
        got_write_fatal_error_event_.Signal();
        break;
    }
  }

  // Block until the corresponding fatal error has been reported.
  void WaitForReadFatalError() { got_read_fatal_error_event_.Wait(); }
  void WaitForWriteFatalError() { got_write_fatal_error_event_.Wait(); }

 private:
  base::WaitableEvent got_read_fatal_error_event_;
  base::WaitableEvent got_write_fatal_error_event_;

  // Whether a fatal error of each type is still acceptable.
  bool expecting_read_error_;
  bool expecting_write_error_;

  DISALLOW_COPY_AND_ASSIGN(FatalErrorRecordingRawChannelDelegate);
};
    472 
// Tests fatal errors: closing the peer must yield exactly one write fatal error
// and one read fatal error. After that, writes must keep failing.
TEST_F(RawChannelTest, OnFatalError) {
  FatalErrorRecordingRawChannelDelegate delegate(0, true, true);
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&InitOnIOThread, rc.get(),
                                          base::Unretained(&delegate)));

  // Close the handle of the other end, which should make writing fail.
  handles[1].reset();

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));

  // We should get a write fatal error.
  delegate.WaitForWriteFatalError();

  // We should also get a read fatal error.
  delegate.WaitForReadFatalError();

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(2)));

  // Sleep a bit, to make sure we don't get another |OnFatalError()|
  // notification. (If we actually get another one, the |ASSERT_TRUE()| in
  // |OnFatalError()| records a test failure.)
  base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(20));

  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&RawChannel::Shutdown,
                                          base::Unretained(rc.get())));
}
    502 
    503 // RawChannelTest.ReadUnaffectedByWriteFatalError ------------------------------
    504 
    505 TEST_F(RawChannelTest, ReadUnaffectedByWriteFatalError) {
    506   const size_t kMessageCount = 5;
    507 
    508   // Write a few messages into the other end.
    509   uint32_t message_size = 1;
    510   for (size_t i = 0; i < kMessageCount;
    511        i++, message_size += message_size / 2 + 1)
    512     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), message_size));
    513 
    514   // Close the other end, which should make writing fail.
    515   handles[1].reset();
    516 
    517   // Only start up reading here. The system buffer should still contain the
    518   // messages that were written.
    519   FatalErrorRecordingRawChannelDelegate delegate(kMessageCount, true, true);
    520   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    521   io_thread()->PostTaskAndWait(FROM_HERE,
    522                                base::Bind(&InitOnIOThread, rc.get(),
    523                                           base::Unretained(&delegate)));
    524 
    525   EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
    526 
    527   // We should definitely get a write fatal error.
    528   delegate.WaitForWriteFatalError();
    529 
    530   // Wait for reading to finish. A writing failure shouldn't affect reading.
    531   delegate.Wait();
    532 
    533   // And then we should get a read fatal error.
    534   delegate.WaitForReadFatalError();
    535 
    536   io_thread()->PostTaskAndWait(FROM_HERE,
    537                                base::Bind(&RawChannel::Shutdown,
    538                                           base::Unretained(rc.get())));
    539 }
    540 
    541 // RawChannelTest.WriteMessageAfterShutdown ------------------------------------
    542 
// Makes sure that calling |WriteMessage()| after |Shutdown()| behaves
// correctly: the write is rejected (returns false).
TEST_F(RawChannelTest, WriteMessageAfterShutdown) {
  WriteOnlyRawChannelDelegate delegate;
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&InitOnIOThread, rc.get(),
                                          base::Unretained(&delegate)));
  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&RawChannel::Shutdown,
                                          base::Unretained(rc.get())));

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));
}
    557 
    558 // RawChannelTest.ShutdownOnReadMessage ----------------------------------------
    559 
// Delegate that shuts |raw_channel| down from inside the first
// |OnReadMessage()| call.
// - A second read is flagged by |EXPECT_FALSE(did_shutdown_)|.
// - Any fatal error aborts via |CHECK()|.
class ShutdownOnReadMessageRawChannelDelegate : public RawChannel::Delegate {
 public:
  explicit ShutdownOnReadMessageRawChannelDelegate(RawChannel* raw_channel)
      : raw_channel_(raw_channel),
        done_event_(false, false),
        did_shutdown_(false) {}
  virtual ~ShutdownOnReadMessageRawChannelDelegate() {}

  // |RawChannel::Delegate| implementation (called on the I/O thread):
  virtual void OnReadMessage(
      const MessageInTransit::View& message_view,
      embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE {
    EXPECT_FALSE(platform_handles);
    EXPECT_FALSE(did_shutdown_);
    EXPECT_TRUE(CheckMessageData(message_view.bytes(),
                message_view.num_bytes()));
    raw_channel_->Shutdown();
    did_shutdown_ = true;
    done_event_.Signal();
  }
  virtual void OnFatalError(FatalError /*fatal_error*/) OVERRIDE {
    CHECK(false);  // Should not get called.
  }

  // Waits for shutdown. |did_shutdown_| is written before |done_event_| is
  // signaled, so it is checked only after the wait.
  void Wait() {
    done_event_.Wait();
    EXPECT_TRUE(did_shutdown_);
  }

 private:
  RawChannel* const raw_channel_;
  base::WaitableEvent done_event_;
  bool did_shutdown_;

  DISALLOW_COPY_AND_ASSIGN(ShutdownOnReadMessageRawChannelDelegate);
};
    597 
    598 TEST_F(RawChannelTest, ShutdownOnReadMessage) {
    599   // Write a few messages into the other end.
    600   for (size_t count = 0; count < 5; count++)
    601     EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), 10));
    602 
    603   scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
    604   ShutdownOnReadMessageRawChannelDelegate delegate(rc.get());
    605   io_thread()->PostTaskAndWait(FROM_HERE,
    606                                base::Bind(&InitOnIOThread, rc.get(),
    607                                           base::Unretained(&delegate)));
    608 
    609   // Wait for the delegate, which will shut the |RawChannel| down.
    610   delegate.Wait();
    611 }
    612 
    613 // RawChannelTest.ShutdownOnFatalError{Read, Write} ----------------------------
    614 
    615 class ShutdownOnFatalErrorRawChannelDelegate : public RawChannel::Delegate {
    616  public:
    617   ShutdownOnFatalErrorRawChannelDelegate(RawChannel* raw_channel,
    618                                          FatalError shutdown_on_error_type)
    619       : raw_channel_(raw_channel),
    620         shutdown_on_error_type_(shutdown_on_error_type),
    621         done_event_(false, false),
    622         did_shutdown_(false) {}
    623   virtual ~ShutdownOnFatalErrorRawChannelDelegate() {}
    624 
    625   // |RawChannel::Delegate| implementation (called on the I/O thread):
    626   virtual void OnReadMessage(
    627       const MessageInTransit::View& /*message_view*/,
    628       embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE {
    629     CHECK(false);  // Should not get called.
    630   }
    631   virtual void OnFatalError(FatalError fatal_error) OVERRIDE {
    632     EXPECT_FALSE(did_shutdown_);
    633     if (fatal_error != shutdown_on_error_type_)
    634       return;
    635     raw_channel_->Shutdown();
    636     did_shutdown_ = true;
    637     done_event_.Signal();
    638   }
    639 
    640   // Waits for shutdown.
    641   void Wait() {
    642     done_event_.Wait();
    643     EXPECT_TRUE(did_shutdown_);
    644   }
    645 
    646  private:
    647   RawChannel* const raw_channel_;
    648   const FatalError shutdown_on_error_type_;
    649   base::WaitableEvent done_event_;
    650   bool did_shutdown_;
    651 
    652   DISALLOW_COPY_AND_ASSIGN(ShutdownOnFatalErrorRawChannelDelegate);
    653 };
    654 
// Checks that a delegate may call |Shutdown()| from within |OnFatalError()| for
// a read fatal error.
TEST_F(RawChannelTest, ShutdownOnFatalErrorRead) {
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  ShutdownOnFatalErrorRawChannelDelegate delegate(
      rc.get(), RawChannel::Delegate::FATAL_ERROR_READ);
  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&InitOnIOThread, rc.get(),
                                          base::Unretained(&delegate)));

  // Close the handle of the other end, which should make reading fail.
  handles[1].reset();

  // Wait for the delegate, which will shut the |RawChannel| down.
  delegate.Wait();
}
    669 
// Checks that a delegate may call |Shutdown()| from within |OnFatalError()| for
// a write fatal error.
TEST_F(RawChannelTest, ShutdownOnFatalErrorWrite) {
  scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass()));
  ShutdownOnFatalErrorRawChannelDelegate delegate(
      rc.get(), RawChannel::Delegate::FATAL_ERROR_WRITE);
  io_thread()->PostTaskAndWait(FROM_HERE,
                               base::Bind(&InitOnIOThread, rc.get(),
                                          base::Unretained(&delegate)));

  // Close the handle of the other end, which should make writing fail.
  handles[1].reset();

  EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1)));

  // Wait for the delegate, which will shut the |RawChannel| down.
  delegate.Wait();
}
    686 
    687 }  // namespace
    688 }  // namespace system
    689 }  // namespace mojo
    690