1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style license that can be 3 // found in the LICENSE file. 4 5 #include "remoting/protocol/authenticator_test_base.h" 6 7 #include "base/base64.h" 8 #include "base/files/file_path.h" 9 #include "base/files/file_util.h" 10 #include "base/path_service.h" 11 #include "base/test/test_timeouts.h" 12 #include "base/timer/timer.h" 13 #include "net/base/net_errors.h" 14 #include "net/base/test_data_directory.h" 15 #include "remoting/base/rsa_key_pair.h" 16 #include "remoting/protocol/authenticator.h" 17 #include "remoting/protocol/channel_authenticator.h" 18 #include "remoting/protocol/fake_stream_socket.h" 19 #include "testing/gtest/include/gtest/gtest.h" 20 #include "third_party/webrtc/libjingle/xmllite/xmlelement.h" 21 22 using testing::_; 23 using testing::SaveArg; 24 25 namespace remoting { 26 namespace protocol { 27 28 namespace { 29 30 ACTION_P(QuitThreadOnCounter, counter) { 31 --(*counter); 32 EXPECT_GE(*counter, 0); 33 if (*counter == 0) 34 base::MessageLoop::current()->Quit(); 35 } 36 37 } // namespace 38 39 AuthenticatorTestBase::MockChannelDoneCallback::MockChannelDoneCallback() {} 40 41 AuthenticatorTestBase::MockChannelDoneCallback::~MockChannelDoneCallback() {} 42 43 AuthenticatorTestBase::AuthenticatorTestBase() {} 44 45 AuthenticatorTestBase::~AuthenticatorTestBase() {} 46 47 void AuthenticatorTestBase::SetUp() { 48 base::FilePath certs_dir(net::GetTestCertsDirectory()); 49 50 base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); 51 ASSERT_TRUE(base::ReadFileToString(cert_path, &host_cert_)); 52 53 base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); 54 std::string key_string; 55 ASSERT_TRUE(base::ReadFileToString(key_path, &key_string)); 56 std::string key_base64; 57 base::Base64Encode(key_string, &key_base64); 58 key_pair_ = RsaKeyPair::FromString(key_base64); 59 ASSERT_TRUE(key_pair_.get()); 60 host_public_key_ = key_pair_->GetPublicKey(); 61 } 62 63 void AuthenticatorTestBase::RunAuthExchange() { 64 ContinueAuthExchangeWith(client_.get(), 65 host_.get(), 66 client_->started(), 67 host_->started()); 68 } 69 70 void AuthenticatorTestBase::RunHostInitiatedAuthExchange() { 71 ContinueAuthExchangeWith(host_.get(), 72 client_.get(), 73 host_->started(), 74 client_->started()); 75 } 76 77 // static 78 // This function sends a message from the sender and receiver and recursively 79 // calls itself to the send the next message from the receiver to the sender 80 // untils the authentication completes. 81 void AuthenticatorTestBase::ContinueAuthExchangeWith(Authenticator* sender, 82 Authenticator* receiver, 83 bool sender_started, 84 bool receiver_started) { 85 scoped_ptr<buzz::XmlElement> message; 86 ASSERT_NE(Authenticator::WAITING_MESSAGE, sender->state()); 87 if (sender->state() == Authenticator::ACCEPTED || 88 sender->state() == Authenticator::REJECTED) 89 return; 90 91 // Verify that once the started flag for either party is set to true, 92 // it should always stay true. 93 if (receiver_started) { 94 ASSERT_TRUE(receiver->started()); 95 } 96 97 if (sender_started) { 98 ASSERT_TRUE(sender->started()); 99 } 100 101 ASSERT_EQ(Authenticator::MESSAGE_READY, sender->state()); 102 message = sender->GetNextMessage(); 103 ASSERT_TRUE(message.get()); 104 ASSERT_NE(Authenticator::MESSAGE_READY, sender->state()); 105 106 ASSERT_EQ(Authenticator::WAITING_MESSAGE, receiver->state()); 107 receiver->ProcessMessage(message.get(), base::Bind( 108 &AuthenticatorTestBase::ContinueAuthExchangeWith, 109 base::Unretained(receiver), base::Unretained(sender), 110 receiver->started(), sender->started())); 111 } 112 113 void AuthenticatorTestBase::RunChannelAuth(bool expected_fail) { 114 client_fake_socket_.reset(new FakeStreamSocket()); 115 host_fake_socket_.reset(new FakeStreamSocket()); 116 client_fake_socket_->PairWith(host_fake_socket_.get()); 117 118 client_auth_->SecureAndAuthenticate( 119 client_fake_socket_.PassAs<net::StreamSocket>(), 120 base::Bind(&AuthenticatorTestBase::OnClientConnected, 121 base::Unretained(this))); 122 123 host_auth_->SecureAndAuthenticate( 124 host_fake_socket_.PassAs<net::StreamSocket>(), 125 base::Bind(&AuthenticatorTestBase::OnHostConnected, 126 base::Unretained(this))); 127 128 // Expect two callbacks to be called - the client callback and the host 129 // callback. 130 int callback_counter = 2; 131 132 EXPECT_CALL(client_callback_, OnDone(net::OK)) 133 .WillOnce(QuitThreadOnCounter(&callback_counter)); 134 if (expected_fail) { 135 EXPECT_CALL(host_callback_, OnDone(net::ERR_FAILED)) 136 .WillOnce(QuitThreadOnCounter(&callback_counter)); 137 } else { 138 EXPECT_CALL(host_callback_, OnDone(net::OK)) 139 .WillOnce(QuitThreadOnCounter(&callback_counter)); 140 } 141 142 // Ensure that .Run() does not run unbounded if the callbacks are never 143 // called. 144 base::Timer shutdown_timer(false, false); 145 shutdown_timer.Start(FROM_HERE, 146 TestTimeouts::action_timeout(), 147 base::MessageLoop::QuitClosure()); 148 message_loop_.Run(); 149 shutdown_timer.Stop(); 150 151 testing::Mock::VerifyAndClearExpectations(&client_callback_); 152 testing::Mock::VerifyAndClearExpectations(&host_callback_); 153 154 if (!expected_fail) { 155 ASSERT_TRUE(client_socket_.get() != NULL); 156 ASSERT_TRUE(host_socket_.get() != NULL); 157 } 158 } 159 160 void AuthenticatorTestBase::OnHostConnected( 161 int error, 162 scoped_ptr<net::StreamSocket> socket) { 163 host_callback_.OnDone(error); 164 host_socket_ = socket.Pass(); 165 } 166 167 void AuthenticatorTestBase::OnClientConnected( 168 int error, 169 scoped_ptr<net::StreamSocket> socket) { 170 client_callback_.OnDone(error); 171 client_socket_ = socket.Pass(); 172 } 173 174 } // namespace protocol 175 } // namespace remoting 176