1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "remoting/protocol/authenticator_test_base.h" 6 7 #include "base/base64.h" 8 #include "base/file_util.h" 9 #include "base/files/file_path.h" 10 #include "base/path_service.h" 11 #include "base/test/test_timeouts.h" 12 #include "base/timer/timer.h" 13 #include "net/base/test_data_directory.h" 14 #include "remoting/base/rsa_key_pair.h" 15 #include "remoting/protocol/authenticator.h" 16 #include "remoting/protocol/channel_authenticator.h" 17 #include "remoting/protocol/fake_session.h" 18 #include "testing/gtest/include/gtest/gtest.h" 19 #include "third_party/libjingle/source/talk/xmllite/xmlelement.h" 20 21 using testing::_; 22 using testing::SaveArg; 23 24 namespace remoting { 25 namespace protocol { 26 27 namespace { 28 29 ACTION_P(QuitThreadOnCounter, counter) { 30 --(*counter); 31 EXPECT_GE(*counter, 0); 32 if (*counter == 0) 33 base::MessageLoop::current()->Quit(); 34 } 35 36 } // namespace 37 38 AuthenticatorTestBase::MockChannelDoneCallback::MockChannelDoneCallback() {} 39 40 AuthenticatorTestBase::MockChannelDoneCallback::~MockChannelDoneCallback() {} 41 42 AuthenticatorTestBase::AuthenticatorTestBase() {} 43 44 AuthenticatorTestBase::~AuthenticatorTestBase() {} 45 46 void AuthenticatorTestBase::SetUp() { 47 base::FilePath certs_dir(net::GetTestCertsDirectory()); 48 49 base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); 50 ASSERT_TRUE(base::ReadFileToString(cert_path, &host_cert_)); 51 52 base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); 53 std::string key_string; 54 ASSERT_TRUE(base::ReadFileToString(key_path, &key_string)); 55 std::string key_base64; 56 base::Base64Encode(key_string, &key_base64); 57 key_pair_ = RsaKeyPair::FromString(key_base64); 58 ASSERT_TRUE(key_pair_.get()); 59 host_public_key_ = key_pair_->GetPublicKey(); 60 } 61 62 void AuthenticatorTestBase::RunAuthExchange() { 63 ContinueAuthExchangeWith(client_.get(), host_.get()); 64 } 65 66 void AuthenticatorTestBase::RunHostInitiatedAuthExchange() { 67 ContinueAuthExchangeWith(host_.get(), client_.get()); 68 } 69 70 // static 71 void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender, 72 Authenticator* receiver) { 73 scoped_ptr<buzz::XmlElement> message; 74 ASSERT_NE(Authenticator::WAITING_MESSAGE, sender->state()); 75 if (sender->state() == Authenticator::ACCEPTED || 76 sender->state() == Authenticator::REJECTED) 77 return; 78 // Pass message from client to host. 79 ASSERT_EQ(Authenticator::MESSAGE_READY, sender->state()); 80 message = sender->GetNextMessage(); 81 ASSERT_TRUE(message.get()); 82 ASSERT_NE(Authenticator::MESSAGE_READY, sender->state()); 83 84 ASSERT_EQ(Authenticator::WAITING_MESSAGE, receiver->state()); 85 receiver->ProcessMessage(message.get(), base::Bind( 86 &AuthenticatorTestBase::ContinueAuthExchangeWith, 87 base::Unretained(receiver), base::Unretained(sender))); 88 } 89 90 void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) { 91 client_fake_socket_.reset(new FakeSocket()); 92 host_fake_socket_.reset(new FakeSocket()); 93 client_fake_socket_->PairWith(host_fake_socket_.get()); 94 95 client_auth_->SecureAndAuthenticate( 96 client_fake_socket_.PassAs<net::StreamSocket>(), 97 base::Bind(&AuthenticatorTestBase::OnClientConnected, 98 base::Unretained(this))); 99 100 host_auth_->SecureAndAuthenticate( 101 host_fake_socket_.PassAs<net::StreamSocket>(), 102 base::Bind(&AuthenticatorTestBase::OnHostConnected, 103 base::Unretained(this))); 104 105 // Expect two callbacks to be called - the client callback and the host 106 // callback. 107 int callback_counter = 2; 108 109 EXPECT_CALL(client_callback_, OnDone(net::OK)) 110 .WillOnce(QuitThreadOnCounter(&callback_counter)); 111 if (expected_fail) { 112 EXPECT_CALL(host_callback_, OnDone(net::ERR_FAILED)) 113 .WillOnce(QuitThreadOnCounter(&callback_counter)); 114 } else { 115 EXPECT_CALL(host_callback_, OnDone(net::OK)) 116 .WillOnce(QuitThreadOnCounter(&callback_counter)); 117 } 118 119 // Ensure that .Run() does not run unbounded if the callbacks are never 120 // called. 121 base::Timer shutdown_timer(false, false); 122 shutdown_timer.Start(FROM_HERE, 123 TestTimeouts::action_timeout(), 124 base::MessageLoop::QuitClosure()); 125 message_loop_.Run(); 126 shutdown_timer.Stop(); 127 128 testing::Mock::VerifyAndClearExpectations(&client_callback_); 129 testing::Mock::VerifyAndClearExpectations(&host_callback_); 130 131 if (!expected_fail) { 132 ASSERT_TRUE(client_socket_.get() != NULL); 133 ASSERT_TRUE(host_socket_.get() != NULL); 134 } 135 } 136 137 void AuthenticatorTestBase::OnHostConnected( 138 net::Error error, 139 scoped_ptr<net::StreamSocket> socket) { 140 host_callback_.OnDone(error); 141 host_socket_ = socket.Pass(); 142 } 143 144 void AuthenticatorTestBase::OnClientConnected( 145 net::Error error, 146 scoped_ptr<net::StreamSocket> socket) { 147 client_callback_.OnDone(error); 148 client_socket_ = socket.Pass(); 149 } 150 151 } // namespace protocol 152 } // namespace remoting 153