1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "mojo/system/raw_channel.h" 6 7 #include <stdint.h> 8 9 #include <vector> 10 11 #include "base/bind.h" 12 #include "base/location.h" 13 #include "base/logging.h" 14 #include "base/macros.h" 15 #include "base/memory/scoped_ptr.h" 16 #include "base/memory/scoped_vector.h" 17 #include "base/rand_util.h" 18 #include "base/synchronization/lock.h" 19 #include "base/synchronization/waitable_event.h" 20 #include "base/threading/platform_thread.h" // For |Sleep()|. 21 #include "base/threading/simple_thread.h" 22 #include "base/time/time.h" 23 #include "build/build_config.h" 24 #include "mojo/common/test/test_utils.h" 25 #include "mojo/embedder/platform_channel_pair.h" 26 #include "mojo/embedder/platform_handle.h" 27 #include "mojo/embedder/scoped_platform_handle.h" 28 #include "mojo/system/message_in_transit.h" 29 #include "mojo/system/test_utils.h" 30 #include "testing/gtest/include/gtest/gtest.h" 31 32 namespace mojo { 33 namespace system { 34 namespace { 35 36 scoped_ptr<MessageInTransit> MakeTestMessage(uint32_t num_bytes) { 37 std::vector<unsigned char> bytes(num_bytes, 0); 38 for (size_t i = 0; i < num_bytes; i++) 39 bytes[i] = static_cast<unsigned char>(i + num_bytes); 40 return make_scoped_ptr( 41 new MessageInTransit(MessageInTransit::kTypeMessagePipeEndpoint, 42 MessageInTransit::kSubtypeMessagePipeEndpointData, 43 num_bytes, bytes.empty() ? NULL : &bytes[0])); 44 } 45 46 bool CheckMessageData(const void* bytes, uint32_t num_bytes) { 47 const unsigned char* b = static_cast<const unsigned char*>(bytes); 48 for (uint32_t i = 0; i < num_bytes; i++) { 49 if (b[i] != static_cast<unsigned char>(i + num_bytes)) 50 return false; 51 } 52 return true; 53 } 54 55 void InitOnIOThread(RawChannel* raw_channel, RawChannel::Delegate* delegate) { 56 CHECK(raw_channel->Init(delegate)); 57 } 58 59 bool WriteTestMessageToHandle(const embedder::PlatformHandle& handle, 60 uint32_t num_bytes) { 61 scoped_ptr<MessageInTransit> message(MakeTestMessage(num_bytes)); 62 63 size_t write_size = 0; 64 mojo::test::BlockingWrite( 65 handle, message->main_buffer(), message->main_buffer_size(), &write_size); 66 return write_size == message->main_buffer_size(); 67 } 68 69 // ----------------------------------------------------------------------------- 70 71 class RawChannelTest : public testing::Test { 72 public: 73 RawChannelTest() : io_thread_(test::TestIOThread::kManualStart) {} 74 virtual ~RawChannelTest() {} 75 76 virtual void SetUp() OVERRIDE { 77 embedder::PlatformChannelPair channel_pair; 78 handles[0] = channel_pair.PassServerHandle(); 79 handles[1] = channel_pair.PassClientHandle(); 80 io_thread_.Start(); 81 } 82 83 virtual void TearDown() OVERRIDE { 84 io_thread_.Stop(); 85 handles[0].reset(); 86 handles[1].reset(); 87 } 88 89 protected: 90 test::TestIOThread* io_thread() { return &io_thread_; } 91 92 embedder::ScopedPlatformHandle handles[2]; 93 94 private: 95 test::TestIOThread io_thread_; 96 97 DISALLOW_COPY_AND_ASSIGN(RawChannelTest); 98 }; 99 100 // RawChannelTest.WriteMessage ------------------------------------------------- 101 102 class WriteOnlyRawChannelDelegate : public RawChannel::Delegate { 103 public: 104 WriteOnlyRawChannelDelegate() {} 105 virtual ~WriteOnlyRawChannelDelegate() {} 106 107 // |RawChannel::Delegate| implementation: 108 virtual void OnReadMessage( 109 const MessageInTransit::View& /*message_view*/, 110 embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE { 111 CHECK(false); // Should not get called. 112 } 113 virtual void OnFatalError(FatalError fatal_error) OVERRIDE { 114 // We'll get a read error when the connection is closed. 115 CHECK_EQ(fatal_error, FATAL_ERROR_READ); 116 } 117 118 private: 119 DISALLOW_COPY_AND_ASSIGN(WriteOnlyRawChannelDelegate); 120 }; 121 122 static const int64_t kMessageReaderSleepMs = 1; 123 static const size_t kMessageReaderMaxPollIterations = 3000; 124 125 class TestMessageReaderAndChecker { 126 public: 127 explicit TestMessageReaderAndChecker(embedder::PlatformHandle handle) 128 : handle_(handle) {} 129 ~TestMessageReaderAndChecker() { CHECK(bytes_.empty()); } 130 131 bool ReadAndCheckNextMessage(uint32_t expected_size) { 132 unsigned char buffer[4096]; 133 134 for (size_t i = 0; i < kMessageReaderMaxPollIterations;) { 135 size_t read_size = 0; 136 CHECK(mojo::test::NonBlockingRead(handle_, buffer, sizeof(buffer), 137 &read_size)); 138 139 // Append newly-read data to |bytes_|. 140 bytes_.insert(bytes_.end(), buffer, buffer + read_size); 141 142 // If we have the header.... 143 size_t message_size; 144 if (MessageInTransit::GetNextMessageSize( 145 bytes_.empty() ? NULL : &bytes_[0], 146 bytes_.size(), 147 &message_size)) { 148 // If we've read the whole message.... 149 if (bytes_.size() >= message_size) { 150 bool rv = true; 151 MessageInTransit::View message_view(message_size, &bytes_[0]); 152 CHECK_EQ(message_view.main_buffer_size(), message_size); 153 154 if (message_view.num_bytes() != expected_size) { 155 LOG(ERROR) << "Wrong size: " << message_size << " instead of " 156 << expected_size << " bytes."; 157 rv = false; 158 } else if (!CheckMessageData(message_view.bytes(), 159 message_view.num_bytes())) { 160 LOG(ERROR) << "Incorrect message bytes."; 161 rv = false; 162 } 163 164 // Erase message data. 165 bytes_.erase(bytes_.begin(), 166 bytes_.begin() + 167 message_view.main_buffer_size()); 168 return rv; 169 } 170 } 171 172 if (static_cast<size_t>(read_size) < sizeof(buffer)) { 173 i++; 174 base::PlatformThread::Sleep( 175 base::TimeDelta::FromMilliseconds(kMessageReaderSleepMs)); 176 } 177 } 178 179 LOG(ERROR) << "Too many iterations."; 180 return false; 181 } 182 183 private: 184 const embedder::PlatformHandle handle_; 185 186 // The start of the received data should always be on a message boundary. 187 std::vector<unsigned char> bytes_; 188 189 DISALLOW_COPY_AND_ASSIGN(TestMessageReaderAndChecker); 190 }; 191 192 // Tests writing (and verifies reading using our own custom reader). 193 TEST_F(RawChannelTest, WriteMessage) { 194 WriteOnlyRawChannelDelegate delegate; 195 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 196 TestMessageReaderAndChecker checker(handles[1].get()); 197 io_thread()->PostTaskAndWait(FROM_HERE, 198 base::Bind(&InitOnIOThread, rc.get(), 199 base::Unretained(&delegate))); 200 201 // Write and read, for a variety of sizes. 202 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) { 203 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size))); 204 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size; 205 } 206 207 // Write/queue and read afterwards, for a variety of sizes. 208 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) 209 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size))); 210 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) 211 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size; 212 213 io_thread()->PostTaskAndWait(FROM_HERE, 214 base::Bind(&RawChannel::Shutdown, 215 base::Unretained(rc.get()))); 216 } 217 218 // RawChannelTest.OnReadMessage ------------------------------------------------ 219 220 class ReadCheckerRawChannelDelegate : public RawChannel::Delegate { 221 public: 222 ReadCheckerRawChannelDelegate() 223 : done_event_(false, false), 224 position_(0) {} 225 virtual ~ReadCheckerRawChannelDelegate() {} 226 227 // |RawChannel::Delegate| implementation (called on the I/O thread): 228 virtual void OnReadMessage( 229 const MessageInTransit::View& message_view, 230 embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE { 231 EXPECT_FALSE(platform_handles); 232 233 size_t position; 234 size_t expected_size; 235 bool should_signal = false; 236 { 237 base::AutoLock locker(lock_); 238 CHECK_LT(position_, expected_sizes_.size()); 239 position = position_; 240 expected_size = expected_sizes_[position]; 241 position_++; 242 if (position_ >= expected_sizes_.size()) 243 should_signal = true; 244 } 245 246 EXPECT_EQ(expected_size, message_view.num_bytes()) << position; 247 if (message_view.num_bytes() == expected_size) { 248 EXPECT_TRUE(CheckMessageData(message_view.bytes(), 249 message_view.num_bytes())) << position; 250 } 251 252 if (should_signal) 253 done_event_.Signal(); 254 } 255 virtual void OnFatalError(FatalError fatal_error) OVERRIDE { 256 // We'll get a read error when the connection is closed. 257 CHECK_EQ(fatal_error, FATAL_ERROR_READ); 258 } 259 260 // Waits for all the messages (of sizes |expected_sizes_|) to be seen. 261 void Wait() { 262 done_event_.Wait(); 263 } 264 265 void SetExpectedSizes(const std::vector<uint32_t>& expected_sizes) { 266 base::AutoLock locker(lock_); 267 CHECK_EQ(position_, expected_sizes_.size()); 268 expected_sizes_ = expected_sizes; 269 position_ = 0; 270 } 271 272 private: 273 base::WaitableEvent done_event_; 274 275 base::Lock lock_; // Protects the following members. 276 std::vector<uint32_t> expected_sizes_; 277 size_t position_; 278 279 DISALLOW_COPY_AND_ASSIGN(ReadCheckerRawChannelDelegate); 280 }; 281 282 // Tests reading (writing using our own custom writer). 283 TEST_F(RawChannelTest, OnReadMessage) { 284 ReadCheckerRawChannelDelegate delegate; 285 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 286 io_thread()->PostTaskAndWait(FROM_HERE, 287 base::Bind(&InitOnIOThread, rc.get(), 288 base::Unretained(&delegate))); 289 290 // Write and read, for a variety of sizes. 291 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) { 292 delegate.SetExpectedSizes(std::vector<uint32_t>(1, size)); 293 294 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size)); 295 296 delegate.Wait(); 297 } 298 299 // Set up reader and write as fast as we can. 300 // Write/queue and read afterwards, for a variety of sizes. 301 std::vector<uint32_t> expected_sizes; 302 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) 303 expected_sizes.push_back(size); 304 delegate.SetExpectedSizes(expected_sizes); 305 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) 306 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size)); 307 delegate.Wait(); 308 309 io_thread()->PostTaskAndWait(FROM_HERE, 310 base::Bind(&RawChannel::Shutdown, 311 base::Unretained(rc.get()))); 312 } 313 314 // RawChannelTest.WriteMessageAndOnReadMessage --------------------------------- 315 316 class RawChannelWriterThread : public base::SimpleThread { 317 public: 318 RawChannelWriterThread(RawChannel* raw_channel, size_t write_count) 319 : base::SimpleThread("raw_channel_writer_thread"), 320 raw_channel_(raw_channel), 321 left_to_write_(write_count) { 322 } 323 324 virtual ~RawChannelWriterThread() { 325 Join(); 326 } 327 328 private: 329 virtual void Run() OVERRIDE { 330 static const int kMaxRandomMessageSize = 25000; 331 332 while (left_to_write_-- > 0) { 333 EXPECT_TRUE(raw_channel_->WriteMessage(MakeTestMessage( 334 static_cast<uint32_t>(base::RandInt(1, kMaxRandomMessageSize))))); 335 } 336 } 337 338 RawChannel* const raw_channel_; 339 size_t left_to_write_; 340 341 DISALLOW_COPY_AND_ASSIGN(RawChannelWriterThread); 342 }; 343 344 class ReadCountdownRawChannelDelegate : public RawChannel::Delegate { 345 public: 346 explicit ReadCountdownRawChannelDelegate(size_t expected_count) 347 : done_event_(false, false), 348 expected_count_(expected_count), 349 count_(0) {} 350 virtual ~ReadCountdownRawChannelDelegate() {} 351 352 // |RawChannel::Delegate| implementation (called on the I/O thread): 353 virtual void OnReadMessage( 354 const MessageInTransit::View& message_view, 355 embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE { 356 EXPECT_FALSE(platform_handles); 357 358 EXPECT_LT(count_, expected_count_); 359 count_++; 360 361 EXPECT_TRUE(CheckMessageData(message_view.bytes(), 362 message_view.num_bytes())); 363 364 if (count_ >= expected_count_) 365 done_event_.Signal(); 366 } 367 virtual void OnFatalError(FatalError fatal_error) OVERRIDE { 368 // We'll get a read error when the connection is closed. 369 CHECK_EQ(fatal_error, FATAL_ERROR_READ); 370 } 371 372 // Waits for all the messages to have been seen. 373 void Wait() { 374 done_event_.Wait(); 375 } 376 377 private: 378 base::WaitableEvent done_event_; 379 size_t expected_count_; 380 size_t count_; 381 382 DISALLOW_COPY_AND_ASSIGN(ReadCountdownRawChannelDelegate); 383 }; 384 385 TEST_F(RawChannelTest, WriteMessageAndOnReadMessage) { 386 static const size_t kNumWriterThreads = 10; 387 static const size_t kNumWriteMessagesPerThread = 4000; 388 389 WriteOnlyRawChannelDelegate writer_delegate; 390 scoped_ptr<RawChannel> writer_rc(RawChannel::Create(handles[0].Pass())); 391 io_thread()->PostTaskAndWait(FROM_HERE, 392 base::Bind(&InitOnIOThread, writer_rc.get(), 393 base::Unretained(&writer_delegate))); 394 395 ReadCountdownRawChannelDelegate reader_delegate( 396 kNumWriterThreads * kNumWriteMessagesPerThread); 397 scoped_ptr<RawChannel> reader_rc(RawChannel::Create(handles[1].Pass())); 398 io_thread()->PostTaskAndWait(FROM_HERE, 399 base::Bind(&InitOnIOThread, reader_rc.get(), 400 base::Unretained(&reader_delegate))); 401 402 { 403 ScopedVector<RawChannelWriterThread> writer_threads; 404 for (size_t i = 0; i < kNumWriterThreads; i++) { 405 writer_threads.push_back(new RawChannelWriterThread( 406 writer_rc.get(), kNumWriteMessagesPerThread)); 407 } 408 for (size_t i = 0; i < writer_threads.size(); i++) 409 writer_threads[i]->Start(); 410 } // Joins all the writer threads. 411 412 // Sleep a bit, to let any extraneous reads be processed. (There shouldn't be 413 // any, but we want to know about them.) 414 base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100)); 415 416 // Wait for reading to finish. 417 reader_delegate.Wait(); 418 419 io_thread()->PostTaskAndWait(FROM_HERE, 420 base::Bind(&RawChannel::Shutdown, 421 base::Unretained(reader_rc.get()))); 422 423 io_thread()->PostTaskAndWait(FROM_HERE, 424 base::Bind(&RawChannel::Shutdown, 425 base::Unretained(writer_rc.get()))); 426 } 427 428 // RawChannelTest.OnFatalError ------------------------------------------------- 429 430 class FatalErrorRecordingRawChannelDelegate 431 : public ReadCountdownRawChannelDelegate { 432 public: 433 FatalErrorRecordingRawChannelDelegate(size_t expected_read_count, 434 bool expect_read_error, 435 bool expect_write_error) 436 : ReadCountdownRawChannelDelegate(expected_read_count), 437 got_read_fatal_error_event_(false, false), 438 got_write_fatal_error_event_(false, false), 439 expecting_read_error_(expect_read_error), 440 expecting_write_error_(expect_write_error) { 441 } 442 443 virtual ~FatalErrorRecordingRawChannelDelegate() {} 444 445 virtual void OnFatalError(FatalError fatal_error) OVERRIDE { 446 switch (fatal_error) { 447 case FATAL_ERROR_READ: 448 ASSERT_TRUE(expecting_read_error_); 449 expecting_read_error_ = false; 450 got_read_fatal_error_event_.Signal(); 451 break; 452 case FATAL_ERROR_WRITE: 453 ASSERT_TRUE(expecting_write_error_); 454 expecting_write_error_ = false; 455 got_write_fatal_error_event_.Signal(); 456 break; 457 } 458 } 459 460 void WaitForReadFatalError() { got_read_fatal_error_event_.Wait(); } 461 void WaitForWriteFatalError() { got_write_fatal_error_event_.Wait(); } 462 463 private: 464 base::WaitableEvent got_read_fatal_error_event_; 465 base::WaitableEvent got_write_fatal_error_event_; 466 467 bool expecting_read_error_; 468 bool expecting_write_error_; 469 470 DISALLOW_COPY_AND_ASSIGN(FatalErrorRecordingRawChannelDelegate); 471 }; 472 473 // Tests fatal errors. 474 TEST_F(RawChannelTest, OnFatalError) { 475 FatalErrorRecordingRawChannelDelegate delegate(0, true, true); 476 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 477 io_thread()->PostTaskAndWait(FROM_HERE, 478 base::Bind(&InitOnIOThread, rc.get(), 479 base::Unretained(&delegate))); 480 481 // Close the handle of the other end, which should make writing fail. 482 handles[1].reset(); 483 484 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1))); 485 486 // We should get a write fatal error. 487 delegate.WaitForWriteFatalError(); 488 489 // We should also get a read fatal error. 490 delegate.WaitForReadFatalError(); 491 492 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(2))); 493 494 // Sleep a bit, to make sure we don't get another |OnFatalError()| 495 // notification. (If we actually get another one, |OnFatalError()| crashes.) 496 base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(20)); 497 498 io_thread()->PostTaskAndWait(FROM_HERE, 499 base::Bind(&RawChannel::Shutdown, 500 base::Unretained(rc.get()))); 501 } 502 503 // RawChannelTest.ReadUnaffectedByWriteFatalError ------------------------------ 504 505 TEST_F(RawChannelTest, ReadUnaffectedByWriteFatalError) { 506 const size_t kMessageCount = 5; 507 508 // Write a few messages into the other end. 509 uint32_t message_size = 1; 510 for (size_t i = 0; i < kMessageCount; 511 i++, message_size += message_size / 2 + 1) 512 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), message_size)); 513 514 // Close the other end, which should make writing fail. 515 handles[1].reset(); 516 517 // Only start up reading here. The system buffer should still contain the 518 // messages that were written. 519 FatalErrorRecordingRawChannelDelegate delegate(kMessageCount, true, true); 520 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 521 io_thread()->PostTaskAndWait(FROM_HERE, 522 base::Bind(&InitOnIOThread, rc.get(), 523 base::Unretained(&delegate))); 524 525 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1))); 526 527 // We should definitely get a write fatal error. 528 delegate.WaitForWriteFatalError(); 529 530 // Wait for reading to finish. A writing failure shouldn't affect reading. 531 delegate.Wait(); 532 533 // And then we should get a read fatal error. 534 delegate.WaitForReadFatalError(); 535 536 io_thread()->PostTaskAndWait(FROM_HERE, 537 base::Bind(&RawChannel::Shutdown, 538 base::Unretained(rc.get()))); 539 } 540 541 // RawChannelTest.WriteMessageAfterShutdown ------------------------------------ 542 543 // Makes sure that calling |WriteMessage()| after |Shutdown()| behaves 544 // correctly. 545 TEST_F(RawChannelTest, WriteMessageAfterShutdown) { 546 WriteOnlyRawChannelDelegate delegate; 547 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 548 io_thread()->PostTaskAndWait(FROM_HERE, 549 base::Bind(&InitOnIOThread, rc.get(), 550 base::Unretained(&delegate))); 551 io_thread()->PostTaskAndWait(FROM_HERE, 552 base::Bind(&RawChannel::Shutdown, 553 base::Unretained(rc.get()))); 554 555 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1))); 556 } 557 558 // RawChannelTest.ShutdownOnReadMessage ---------------------------------------- 559 560 class ShutdownOnReadMessageRawChannelDelegate : public RawChannel::Delegate { 561 public: 562 explicit ShutdownOnReadMessageRawChannelDelegate(RawChannel* raw_channel) 563 : raw_channel_(raw_channel), 564 done_event_(false, false), 565 did_shutdown_(false) {} 566 virtual ~ShutdownOnReadMessageRawChannelDelegate() {} 567 568 // |RawChannel::Delegate| implementation (called on the I/O thread): 569 virtual void OnReadMessage( 570 const MessageInTransit::View& message_view, 571 embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE { 572 EXPECT_FALSE(platform_handles); 573 EXPECT_FALSE(did_shutdown_); 574 EXPECT_TRUE(CheckMessageData(message_view.bytes(), 575 message_view.num_bytes())); 576 raw_channel_->Shutdown(); 577 did_shutdown_ = true; 578 done_event_.Signal(); 579 } 580 virtual void OnFatalError(FatalError /*fatal_error*/) OVERRIDE { 581 CHECK(false); // Should not get called. 582 } 583 584 // Waits for shutdown. 585 void Wait() { 586 done_event_.Wait(); 587 EXPECT_TRUE(did_shutdown_); 588 } 589 590 private: 591 RawChannel* const raw_channel_; 592 base::WaitableEvent done_event_; 593 bool did_shutdown_; 594 595 DISALLOW_COPY_AND_ASSIGN(ShutdownOnReadMessageRawChannelDelegate); 596 }; 597 598 TEST_F(RawChannelTest, ShutdownOnReadMessage) { 599 // Write a few messages into the other end. 600 for (size_t count = 0; count < 5; count++) 601 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), 10)); 602 603 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 604 ShutdownOnReadMessageRawChannelDelegate delegate(rc.get()); 605 io_thread()->PostTaskAndWait(FROM_HERE, 606 base::Bind(&InitOnIOThread, rc.get(), 607 base::Unretained(&delegate))); 608 609 // Wait for the delegate, which will shut the |RawChannel| down. 610 delegate.Wait(); 611 } 612 613 // RawChannelTest.ShutdownOnFatalError{Read, Write} ---------------------------- 614 615 class ShutdownOnFatalErrorRawChannelDelegate : public RawChannel::Delegate { 616 public: 617 ShutdownOnFatalErrorRawChannelDelegate(RawChannel* raw_channel, 618 FatalError shutdown_on_error_type) 619 : raw_channel_(raw_channel), 620 shutdown_on_error_type_(shutdown_on_error_type), 621 done_event_(false, false), 622 did_shutdown_(false) {} 623 virtual ~ShutdownOnFatalErrorRawChannelDelegate() {} 624 625 // |RawChannel::Delegate| implementation (called on the I/O thread): 626 virtual void OnReadMessage( 627 const MessageInTransit::View& /*message_view*/, 628 embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE { 629 CHECK(false); // Should not get called. 630 } 631 virtual void OnFatalError(FatalError fatal_error) OVERRIDE { 632 EXPECT_FALSE(did_shutdown_); 633 if (fatal_error != shutdown_on_error_type_) 634 return; 635 raw_channel_->Shutdown(); 636 did_shutdown_ = true; 637 done_event_.Signal(); 638 } 639 640 // Waits for shutdown. 641 void Wait() { 642 done_event_.Wait(); 643 EXPECT_TRUE(did_shutdown_); 644 } 645 646 private: 647 RawChannel* const raw_channel_; 648 const FatalError shutdown_on_error_type_; 649 base::WaitableEvent done_event_; 650 bool did_shutdown_; 651 652 DISALLOW_COPY_AND_ASSIGN(ShutdownOnFatalErrorRawChannelDelegate); 653 }; 654 655 TEST_F(RawChannelTest, ShutdownOnFatalErrorRead) { 656 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 657 ShutdownOnFatalErrorRawChannelDelegate delegate( 658 rc.get(), RawChannel::Delegate::FATAL_ERROR_READ); 659 io_thread()->PostTaskAndWait(FROM_HERE, 660 base::Bind(&InitOnIOThread, rc.get(), 661 base::Unretained(&delegate))); 662 663 // Close the handle of the other end, which should stuff fail. 664 handles[1].reset(); 665 666 // Wait for the delegate, which will shut the |RawChannel| down. 667 delegate.Wait(); 668 } 669 670 TEST_F(RawChannelTest, ShutdownOnFatalErrorWrite) { 671 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 672 ShutdownOnFatalErrorRawChannelDelegate delegate( 673 rc.get(), RawChannel::Delegate::FATAL_ERROR_WRITE); 674 io_thread()->PostTaskAndWait(FROM_HERE, 675 base::Bind(&InitOnIOThread, rc.get(), 676 base::Unretained(&delegate))); 677 678 // Close the handle of the other end, which should stuff fail. 679 handles[1].reset(); 680 681 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1))); 682 683 // Wait for the delegate, which will shut the |RawChannel| down. 684 delegate.Wait(); 685 } 686 687 } // namespace 688 } // namespace system 689 } // namespace mojo 690