1 // Copyright 2014 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "mojo/system/raw_channel.h" 6 7 #include <stdint.h> 8 9 #include <vector> 10 11 #include "base/bind.h" 12 #include "base/location.h" 13 #include "base/logging.h" 14 #include "base/macros.h" 15 #include "base/memory/scoped_ptr.h" 16 #include "base/memory/scoped_vector.h" 17 #include "base/rand_util.h" 18 #include "base/synchronization/lock.h" 19 #include "base/synchronization/waitable_event.h" 20 #include "base/test/test_io_thread.h" 21 #include "base/threading/platform_thread.h" // For |Sleep()|. 22 #include "base/threading/simple_thread.h" 23 #include "base/time/time.h" 24 #include "build/build_config.h" 25 #include "mojo/common/test/test_utils.h" 26 #include "mojo/embedder/platform_channel_pair.h" 27 #include "mojo/embedder/platform_handle.h" 28 #include "mojo/embedder/scoped_platform_handle.h" 29 #include "mojo/system/message_in_transit.h" 30 #include "mojo/system/test_utils.h" 31 #include "testing/gtest/include/gtest/gtest.h" 32 33 namespace mojo { 34 namespace system { 35 namespace { 36 37 scoped_ptr<MessageInTransit> MakeTestMessage(uint32_t num_bytes) { 38 std::vector<unsigned char> bytes(num_bytes, 0); 39 for (size_t i = 0; i < num_bytes; i++) 40 bytes[i] = static_cast<unsigned char>(i + num_bytes); 41 return make_scoped_ptr( 42 new MessageInTransit(MessageInTransit::kTypeMessagePipeEndpoint, 43 MessageInTransit::kSubtypeMessagePipeEndpointData, 44 num_bytes, 45 bytes.empty() ? nullptr : &bytes[0])); 46 } 47 48 bool CheckMessageData(const void* bytes, uint32_t num_bytes) { 49 const unsigned char* b = static_cast<const unsigned char*>(bytes); 50 for (uint32_t i = 0; i < num_bytes; i++) { 51 if (b[i] != static_cast<unsigned char>(i + num_bytes)) 52 return false; 53 } 54 return true; 55 } 56 57 void InitOnIOThread(RawChannel* raw_channel, RawChannel::Delegate* delegate) { 58 CHECK(raw_channel->Init(delegate)); 59 } 60 61 bool WriteTestMessageToHandle(const embedder::PlatformHandle& handle, 62 uint32_t num_bytes) { 63 scoped_ptr<MessageInTransit> message(MakeTestMessage(num_bytes)); 64 65 size_t write_size = 0; 66 mojo::test::BlockingWrite( 67 handle, message->main_buffer(), message->main_buffer_size(), &write_size); 68 return write_size == message->main_buffer_size(); 69 } 70 71 // ----------------------------------------------------------------------------- 72 73 class RawChannelTest : public testing::Test { 74 public: 75 RawChannelTest() : io_thread_(base::TestIOThread::kManualStart) {} 76 virtual ~RawChannelTest() {} 77 78 virtual void SetUp() OVERRIDE { 79 embedder::PlatformChannelPair channel_pair; 80 handles[0] = channel_pair.PassServerHandle(); 81 handles[1] = channel_pair.PassClientHandle(); 82 io_thread_.Start(); 83 } 84 85 virtual void TearDown() OVERRIDE { 86 io_thread_.Stop(); 87 handles[0].reset(); 88 handles[1].reset(); 89 } 90 91 protected: 92 base::TestIOThread* io_thread() { return &io_thread_; } 93 94 embedder::ScopedPlatformHandle handles[2]; 95 96 private: 97 base::TestIOThread io_thread_; 98 99 DISALLOW_COPY_AND_ASSIGN(RawChannelTest); 100 }; 101 102 // RawChannelTest.WriteMessage ------------------------------------------------- 103 104 class WriteOnlyRawChannelDelegate : public RawChannel::Delegate { 105 public: 106 WriteOnlyRawChannelDelegate() {} 107 virtual ~WriteOnlyRawChannelDelegate() {} 108 109 // |RawChannel::Delegate| implementation: 110 virtual void OnReadMessage( 111 const MessageInTransit::View& /*message_view*/, 112 embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE { 113 CHECK(false); // Should not get called. 114 } 115 virtual void OnError(Error error) OVERRIDE { 116 // We'll get a read (shutdown) error when the connection is closed. 117 CHECK_EQ(error, ERROR_READ_SHUTDOWN); 118 } 119 120 private: 121 DISALLOW_COPY_AND_ASSIGN(WriteOnlyRawChannelDelegate); 122 }; 123 124 static const int64_t kMessageReaderSleepMs = 1; 125 static const size_t kMessageReaderMaxPollIterations = 3000; 126 127 class TestMessageReaderAndChecker { 128 public: 129 explicit TestMessageReaderAndChecker(embedder::PlatformHandle handle) 130 : handle_(handle) {} 131 ~TestMessageReaderAndChecker() { CHECK(bytes_.empty()); } 132 133 bool ReadAndCheckNextMessage(uint32_t expected_size) { 134 unsigned char buffer[4096]; 135 136 for (size_t i = 0; i < kMessageReaderMaxPollIterations;) { 137 size_t read_size = 0; 138 CHECK(mojo::test::NonBlockingRead( 139 handle_, buffer, sizeof(buffer), &read_size)); 140 141 // Append newly-read data to |bytes_|. 142 bytes_.insert(bytes_.end(), buffer, buffer + read_size); 143 144 // If we have the header.... 145 size_t message_size; 146 if (MessageInTransit::GetNextMessageSize( 147 bytes_.empty() ? nullptr : &bytes_[0], 148 bytes_.size(), 149 &message_size)) { 150 // If we've read the whole message.... 151 if (bytes_.size() >= message_size) { 152 bool rv = true; 153 MessageInTransit::View message_view(message_size, &bytes_[0]); 154 CHECK_EQ(message_view.main_buffer_size(), message_size); 155 156 if (message_view.num_bytes() != expected_size) { 157 LOG(ERROR) << "Wrong size: " << message_size << " instead of " 158 << expected_size << " bytes."; 159 rv = false; 160 } else if (!CheckMessageData(message_view.bytes(), 161 message_view.num_bytes())) { 162 LOG(ERROR) << "Incorrect message bytes."; 163 rv = false; 164 } 165 166 // Erase message data. 167 bytes_.erase(bytes_.begin(), 168 bytes_.begin() + message_view.main_buffer_size()); 169 return rv; 170 } 171 } 172 173 if (static_cast<size_t>(read_size) < sizeof(buffer)) { 174 i++; 175 base::PlatformThread::Sleep( 176 base::TimeDelta::FromMilliseconds(kMessageReaderSleepMs)); 177 } 178 } 179 180 LOG(ERROR) << "Too many iterations."; 181 return false; 182 } 183 184 private: 185 const embedder::PlatformHandle handle_; 186 187 // The start of the received data should always be on a message boundary. 188 std::vector<unsigned char> bytes_; 189 190 DISALLOW_COPY_AND_ASSIGN(TestMessageReaderAndChecker); 191 }; 192 193 // Tests writing (and verifies reading using our own custom reader). 194 TEST_F(RawChannelTest, WriteMessage) { 195 WriteOnlyRawChannelDelegate delegate; 196 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 197 TestMessageReaderAndChecker checker(handles[1].get()); 198 io_thread()->PostTaskAndWait( 199 FROM_HERE, 200 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate))); 201 202 // Write and read, for a variety of sizes. 203 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) { 204 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size))); 205 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size; 206 } 207 208 // Write/queue and read afterwards, for a variety of sizes. 209 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) 210 EXPECT_TRUE(rc->WriteMessage(MakeTestMessage(size))); 211 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) 212 EXPECT_TRUE(checker.ReadAndCheckNextMessage(size)) << size; 213 214 io_thread()->PostTaskAndWait( 215 FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get()))); 216 } 217 218 // RawChannelTest.OnReadMessage ------------------------------------------------ 219 220 class ReadCheckerRawChannelDelegate : public RawChannel::Delegate { 221 public: 222 ReadCheckerRawChannelDelegate() : done_event_(false, false), position_(0) {} 223 virtual ~ReadCheckerRawChannelDelegate() {} 224 225 // |RawChannel::Delegate| implementation (called on the I/O thread): 226 virtual void OnReadMessage( 227 const MessageInTransit::View& message_view, 228 embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE { 229 EXPECT_FALSE(platform_handles); 230 231 size_t position; 232 size_t expected_size; 233 bool should_signal = false; 234 { 235 base::AutoLock locker(lock_); 236 CHECK_LT(position_, expected_sizes_.size()); 237 position = position_; 238 expected_size = expected_sizes_[position]; 239 position_++; 240 if (position_ >= expected_sizes_.size()) 241 should_signal = true; 242 } 243 244 EXPECT_EQ(expected_size, message_view.num_bytes()) << position; 245 if (message_view.num_bytes() == expected_size) { 246 EXPECT_TRUE( 247 CheckMessageData(message_view.bytes(), message_view.num_bytes())) 248 << position; 249 } 250 251 if (should_signal) 252 done_event_.Signal(); 253 } 254 virtual void OnError(Error error) OVERRIDE { 255 // We'll get a read (shutdown) error when the connection is closed. 256 CHECK_EQ(error, ERROR_READ_SHUTDOWN); 257 } 258 259 // Waits for all the messages (of sizes |expected_sizes_|) to be seen. 260 void Wait() { done_event_.Wait(); } 261 262 void SetExpectedSizes(const std::vector<uint32_t>& expected_sizes) { 263 base::AutoLock locker(lock_); 264 CHECK_EQ(position_, expected_sizes_.size()); 265 expected_sizes_ = expected_sizes; 266 position_ = 0; 267 } 268 269 private: 270 base::WaitableEvent done_event_; 271 272 base::Lock lock_; // Protects the following members. 273 std::vector<uint32_t> expected_sizes_; 274 size_t position_; 275 276 DISALLOW_COPY_AND_ASSIGN(ReadCheckerRawChannelDelegate); 277 }; 278 279 // Tests reading (writing using our own custom writer). 280 TEST_F(RawChannelTest, OnReadMessage) { 281 ReadCheckerRawChannelDelegate delegate; 282 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 283 io_thread()->PostTaskAndWait( 284 FROM_HERE, 285 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate))); 286 287 // Write and read, for a variety of sizes. 288 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) { 289 delegate.SetExpectedSizes(std::vector<uint32_t>(1, size)); 290 291 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size)); 292 293 delegate.Wait(); 294 } 295 296 // Set up reader and write as fast as we can. 297 // Write/queue and read afterwards, for a variety of sizes. 298 std::vector<uint32_t> expected_sizes; 299 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) 300 expected_sizes.push_back(size); 301 delegate.SetExpectedSizes(expected_sizes); 302 for (uint32_t size = 1; size < 5 * 1000 * 1000; size += size / 2 + 1) 303 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), size)); 304 delegate.Wait(); 305 306 io_thread()->PostTaskAndWait( 307 FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get()))); 308 } 309 310 // RawChannelTest.WriteMessageAndOnReadMessage --------------------------------- 311 312 class RawChannelWriterThread : public base::SimpleThread { 313 public: 314 RawChannelWriterThread(RawChannel* raw_channel, size_t write_count) 315 : base::SimpleThread("raw_channel_writer_thread"), 316 raw_channel_(raw_channel), 317 left_to_write_(write_count) {} 318 319 virtual ~RawChannelWriterThread() { Join(); } 320 321 private: 322 virtual void Run() OVERRIDE { 323 static const int kMaxRandomMessageSize = 25000; 324 325 while (left_to_write_-- > 0) { 326 EXPECT_TRUE(raw_channel_->WriteMessage(MakeTestMessage( 327 static_cast<uint32_t>(base::RandInt(1, kMaxRandomMessageSize))))); 328 } 329 } 330 331 RawChannel* const raw_channel_; 332 size_t left_to_write_; 333 334 DISALLOW_COPY_AND_ASSIGN(RawChannelWriterThread); 335 }; 336 337 class ReadCountdownRawChannelDelegate : public RawChannel::Delegate { 338 public: 339 explicit ReadCountdownRawChannelDelegate(size_t expected_count) 340 : done_event_(false, false), expected_count_(expected_count), count_(0) {} 341 virtual ~ReadCountdownRawChannelDelegate() {} 342 343 // |RawChannel::Delegate| implementation (called on the I/O thread): 344 virtual void OnReadMessage( 345 const MessageInTransit::View& message_view, 346 embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE { 347 EXPECT_FALSE(platform_handles); 348 349 EXPECT_LT(count_, expected_count_); 350 count_++; 351 352 EXPECT_TRUE( 353 CheckMessageData(message_view.bytes(), message_view.num_bytes())); 354 355 if (count_ >= expected_count_) 356 done_event_.Signal(); 357 } 358 virtual void OnError(Error error) OVERRIDE { 359 // We'll get a read (shutdown) error when the connection is closed. 360 CHECK_EQ(error, ERROR_READ_SHUTDOWN); 361 } 362 363 // Waits for all the messages to have been seen. 364 void Wait() { done_event_.Wait(); } 365 366 private: 367 base::WaitableEvent done_event_; 368 size_t expected_count_; 369 size_t count_; 370 371 DISALLOW_COPY_AND_ASSIGN(ReadCountdownRawChannelDelegate); 372 }; 373 374 TEST_F(RawChannelTest, WriteMessageAndOnReadMessage) { 375 static const size_t kNumWriterThreads = 10; 376 static const size_t kNumWriteMessagesPerThread = 4000; 377 378 WriteOnlyRawChannelDelegate writer_delegate; 379 scoped_ptr<RawChannel> writer_rc(RawChannel::Create(handles[0].Pass())); 380 io_thread()->PostTaskAndWait(FROM_HERE, 381 base::Bind(&InitOnIOThread, 382 writer_rc.get(), 383 base::Unretained(&writer_delegate))); 384 385 ReadCountdownRawChannelDelegate reader_delegate(kNumWriterThreads * 386 kNumWriteMessagesPerThread); 387 scoped_ptr<RawChannel> reader_rc(RawChannel::Create(handles[1].Pass())); 388 io_thread()->PostTaskAndWait(FROM_HERE, 389 base::Bind(&InitOnIOThread, 390 reader_rc.get(), 391 base::Unretained(&reader_delegate))); 392 393 { 394 ScopedVector<RawChannelWriterThread> writer_threads; 395 for (size_t i = 0; i < kNumWriterThreads; i++) { 396 writer_threads.push_back(new RawChannelWriterThread( 397 writer_rc.get(), kNumWriteMessagesPerThread)); 398 } 399 for (size_t i = 0; i < writer_threads.size(); i++) 400 writer_threads[i]->Start(); 401 } // Joins all the writer threads. 402 403 // Sleep a bit, to let any extraneous reads be processed. (There shouldn't be 404 // any, but we want to know about them.) 405 base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(100)); 406 407 // Wait for reading to finish. 408 reader_delegate.Wait(); 409 410 io_thread()->PostTaskAndWait( 411 FROM_HERE, 412 base::Bind(&RawChannel::Shutdown, base::Unretained(reader_rc.get()))); 413 414 io_thread()->PostTaskAndWait( 415 FROM_HERE, 416 base::Bind(&RawChannel::Shutdown, base::Unretained(writer_rc.get()))); 417 } 418 419 // RawChannelTest.OnError ------------------------------------------------------ 420 421 class ErrorRecordingRawChannelDelegate 422 : public ReadCountdownRawChannelDelegate { 423 public: 424 ErrorRecordingRawChannelDelegate(size_t expected_read_count, 425 bool expect_read_error, 426 bool expect_write_error) 427 : ReadCountdownRawChannelDelegate(expected_read_count), 428 got_read_error_event_(false, false), 429 got_write_error_event_(false, false), 430 expecting_read_error_(expect_read_error), 431 expecting_write_error_(expect_write_error) {} 432 433 virtual ~ErrorRecordingRawChannelDelegate() {} 434 435 virtual void OnError(Error error) OVERRIDE { 436 switch (error) { 437 case ERROR_READ_SHUTDOWN: 438 ASSERT_TRUE(expecting_read_error_); 439 expecting_read_error_ = false; 440 got_read_error_event_.Signal(); 441 break; 442 case ERROR_READ_BROKEN: 443 // TODO(vtl): Test broken connections. 444 CHECK(false); 445 break; 446 case ERROR_READ_BAD_MESSAGE: 447 // TODO(vtl): Test reception/detection of bad messages. 448 CHECK(false); 449 break; 450 case ERROR_READ_UNKNOWN: 451 // TODO(vtl): Test however it is we might get here. 452 CHECK(false); 453 break; 454 case ERROR_WRITE: 455 ASSERT_TRUE(expecting_write_error_); 456 expecting_write_error_ = false; 457 got_write_error_event_.Signal(); 458 break; 459 } 460 } 461 462 void WaitForReadError() { got_read_error_event_.Wait(); } 463 void WaitForWriteError() { got_write_error_event_.Wait(); } 464 465 private: 466 base::WaitableEvent got_read_error_event_; 467 base::WaitableEvent got_write_error_event_; 468 469 bool expecting_read_error_; 470 bool expecting_write_error_; 471 472 DISALLOW_COPY_AND_ASSIGN(ErrorRecordingRawChannelDelegate); 473 }; 474 475 // Tests (fatal) errors. 476 TEST_F(RawChannelTest, OnError) { 477 ErrorRecordingRawChannelDelegate delegate(0, true, true); 478 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 479 io_thread()->PostTaskAndWait( 480 FROM_HERE, 481 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate))); 482 483 // Close the handle of the other end, which should make writing fail. 484 handles[1].reset(); 485 486 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1))); 487 488 // We should get a write error. 489 delegate.WaitForWriteError(); 490 491 // We should also get a read error. 492 delegate.WaitForReadError(); 493 494 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(2))); 495 496 // Sleep a bit, to make sure we don't get another |OnError()| 497 // notification. (If we actually get another one, |OnError()| crashes.) 498 base::PlatformThread::Sleep(base::TimeDelta::FromMilliseconds(20)); 499 500 io_thread()->PostTaskAndWait( 501 FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get()))); 502 } 503 504 // RawChannelTest.ReadUnaffectedByWriteError ----------------------------------- 505 506 TEST_F(RawChannelTest, ReadUnaffectedByWriteError) { 507 const size_t kMessageCount = 5; 508 509 // Write a few messages into the other end. 510 uint32_t message_size = 1; 511 for (size_t i = 0; i < kMessageCount; 512 i++, message_size += message_size / 2 + 1) 513 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), message_size)); 514 515 // Close the other end, which should make writing fail. 516 handles[1].reset(); 517 518 // Only start up reading here. The system buffer should still contain the 519 // messages that were written. 520 ErrorRecordingRawChannelDelegate delegate(kMessageCount, true, true); 521 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 522 io_thread()->PostTaskAndWait( 523 FROM_HERE, 524 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate))); 525 526 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1))); 527 528 // We should definitely get a write error. 529 delegate.WaitForWriteError(); 530 531 // Wait for reading to finish. A writing failure shouldn't affect reading. 532 delegate.Wait(); 533 534 // And then we should get a read error. 535 delegate.WaitForReadError(); 536 537 io_thread()->PostTaskAndWait( 538 FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get()))); 539 } 540 541 // RawChannelTest.WriteMessageAfterShutdown ------------------------------------ 542 543 // Makes sure that calling |WriteMessage()| after |Shutdown()| behaves 544 // correctly. 545 TEST_F(RawChannelTest, WriteMessageAfterShutdown) { 546 WriteOnlyRawChannelDelegate delegate; 547 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 548 io_thread()->PostTaskAndWait( 549 FROM_HERE, 550 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate))); 551 io_thread()->PostTaskAndWait( 552 FROM_HERE, base::Bind(&RawChannel::Shutdown, base::Unretained(rc.get()))); 553 554 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1))); 555 } 556 557 // RawChannelTest.ShutdownOnReadMessage ---------------------------------------- 558 559 class ShutdownOnReadMessageRawChannelDelegate : public RawChannel::Delegate { 560 public: 561 explicit ShutdownOnReadMessageRawChannelDelegate(RawChannel* raw_channel) 562 : raw_channel_(raw_channel), 563 done_event_(false, false), 564 did_shutdown_(false) {} 565 virtual ~ShutdownOnReadMessageRawChannelDelegate() {} 566 567 // |RawChannel::Delegate| implementation (called on the I/O thread): 568 virtual void OnReadMessage( 569 const MessageInTransit::View& message_view, 570 embedder::ScopedPlatformHandleVectorPtr platform_handles) OVERRIDE { 571 EXPECT_FALSE(platform_handles); 572 EXPECT_FALSE(did_shutdown_); 573 EXPECT_TRUE( 574 CheckMessageData(message_view.bytes(), message_view.num_bytes())); 575 raw_channel_->Shutdown(); 576 did_shutdown_ = true; 577 done_event_.Signal(); 578 } 579 virtual void OnError(Error /*error*/) OVERRIDE { 580 CHECK(false); // Should not get called. 581 } 582 583 // Waits for shutdown. 584 void Wait() { 585 done_event_.Wait(); 586 EXPECT_TRUE(did_shutdown_); 587 } 588 589 private: 590 RawChannel* const raw_channel_; 591 base::WaitableEvent done_event_; 592 bool did_shutdown_; 593 594 DISALLOW_COPY_AND_ASSIGN(ShutdownOnReadMessageRawChannelDelegate); 595 }; 596 597 TEST_F(RawChannelTest, ShutdownOnReadMessage) { 598 // Write a few messages into the other end. 599 for (size_t count = 0; count < 5; count++) 600 EXPECT_TRUE(WriteTestMessageToHandle(handles[1].get(), 10)); 601 602 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 603 ShutdownOnReadMessageRawChannelDelegate delegate(rc.get()); 604 io_thread()->PostTaskAndWait( 605 FROM_HERE, 606 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate))); 607 608 // Wait for the delegate, which will shut the |RawChannel| down. 609 delegate.Wait(); 610 } 611 612 // RawChannelTest.ShutdownOnError{Read, Write} --------------------------------- 613 614 class ShutdownOnErrorRawChannelDelegate : public RawChannel::Delegate { 615 public: 616 ShutdownOnErrorRawChannelDelegate(RawChannel* raw_channel, 617 Error shutdown_on_error_type) 618 : raw_channel_(raw_channel), 619 shutdown_on_error_type_(shutdown_on_error_type), 620 done_event_(false, false), 621 did_shutdown_(false) {} 622 virtual ~ShutdownOnErrorRawChannelDelegate() {} 623 624 // |RawChannel::Delegate| implementation (called on the I/O thread): 625 virtual void OnReadMessage( 626 const MessageInTransit::View& /*message_view*/, 627 embedder::ScopedPlatformHandleVectorPtr /*platform_handles*/) OVERRIDE { 628 CHECK(false); // Should not get called. 629 } 630 virtual void OnError(Error error) OVERRIDE { 631 EXPECT_FALSE(did_shutdown_); 632 if (error != shutdown_on_error_type_) 633 return; 634 raw_channel_->Shutdown(); 635 did_shutdown_ = true; 636 done_event_.Signal(); 637 } 638 639 // Waits for shutdown. 640 void Wait() { 641 done_event_.Wait(); 642 EXPECT_TRUE(did_shutdown_); 643 } 644 645 private: 646 RawChannel* const raw_channel_; 647 const Error shutdown_on_error_type_; 648 base::WaitableEvent done_event_; 649 bool did_shutdown_; 650 651 DISALLOW_COPY_AND_ASSIGN(ShutdownOnErrorRawChannelDelegate); 652 }; 653 654 TEST_F(RawChannelTest, ShutdownOnErrorRead) { 655 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 656 ShutdownOnErrorRawChannelDelegate delegate( 657 rc.get(), RawChannel::Delegate::ERROR_READ_SHUTDOWN); 658 io_thread()->PostTaskAndWait( 659 FROM_HERE, 660 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate))); 661 662 // Close the handle of the other end, which should stuff fail. 663 handles[1].reset(); 664 665 // Wait for the delegate, which will shut the |RawChannel| down. 666 delegate.Wait(); 667 } 668 669 TEST_F(RawChannelTest, ShutdownOnErrorWrite) { 670 scoped_ptr<RawChannel> rc(RawChannel::Create(handles[0].Pass())); 671 ShutdownOnErrorRawChannelDelegate delegate(rc.get(), 672 RawChannel::Delegate::ERROR_WRITE); 673 io_thread()->PostTaskAndWait( 674 FROM_HERE, 675 base::Bind(&InitOnIOThread, rc.get(), base::Unretained(&delegate))); 676 677 // Close the handle of the other end, which should stuff fail. 678 handles[1].reset(); 679 680 EXPECT_FALSE(rc->WriteMessage(MakeTestMessage(1))); 681 682 // Wait for the delegate, which will shut the |RawChannel| down. 683 delegate.Wait(); 684 } 685 686 } // namespace 687 } // namespace system 688 } // namespace mojo 689